#!/usr/bin/env python
import numpy

import chainer
import chainer.functions as F
import chainer.links as L
import chainerx
# Selects which Generator architecture is built and evaluated:
#   0 - Generator.forward_inten returns its input unchanged;
#   1 - Generator.forward_inten applies an additional small network.
MODE = 0
# Number of components in each latent vector z (stored as Generator.z_order
# and used by Generator.make_hidden).
Z_ORDER = 3

def add_noise(device, h, sigma=0.2):
    """Perturb ``h`` with scaled standard-normal noise while training.

    Args:
        device: Device object. Its ``xp`` attribute supplies the array
            module; for ChainerX its ``fallback_device`` and ``send`` are
            used as well.
        h: Array or variable to perturb; only ``h.shape`` and ``+`` are used.
        sigma: Factor the standard-normal sample is multiplied by.

    Returns:
        ``h + sigma * noise`` when ``chainer.config.train`` is truthy,
        otherwise ``h`` itself, untouched.
    """
    if not chainer.config.train:
        return h

    if device.xp is chainerx:
        # Draw the sample with the fallback device's array module, then
        # transfer it to ``device``.
        fallback = device.fallback_device
        with chainer.using_device(fallback):
            noise = device.send(fallback.xp.random.randn(*h.shape))
    else:
        noise = device.xp.random.randn(*h.shape)
    return h + sigma * noise


class Generator(chainer.Chain):
    """Generator mapping latent vectors to 1-D profiles over a fixed grid.

    For every latent vector ``z`` and every grid point ``t`` the network
    computes ``exp(-u(t, z) ** 2)``, where ``u`` is the output of
    :meth:`forward_time`, and then passes the result through
    :meth:`forward_inten`.

    The module-level ``MODE`` selects the architecture:

    * ``MODE == 0``: five hidden layers (``l0``-``l4``) plus the output layer
      ``ll``; :meth:`forward_inten` is the identity.
    * ``MODE == 1``: four hidden layers (``l0``-``l3``) plus ``ll`` for
      :meth:`forward_time`, and a separate network (``l10``-``l12`` plus the
      output layer ``l1l``) for :meth:`forward_inten`.

    Args:
        t_len: Number of grid points.
        sigma: The grid spans ``[-sigma, sigma]``.
        wscale: Standard deviation passed to the ``Normal`` initializer of
            the ``l0``-``l4`` / ``ll`` weights.

    Raises:
        ValueError: If ``MODE`` is neither 0 nor 1.
    """

    def __init__(self, t_len=301, sigma=4, wscale=0.5):
        super(Generator, self).__init__()
        if MODE not in (0, 1):
            raise self._unsupported_mode()
        self.t_len = t_len
        self.hidden_order = 30
        self.z_order = Z_ORDER
        self.sigma = sigma
        # Grid of evaluation points; built on the first call to forward().
        self.t = None
        with self.init_scope():
            w = chainer.initializers.Normal(wscale)
            # Input sizes are left as None so they are inferred on first use.
            self.l0 = L.Linear(None, self.hidden_order, initialW=w)
            self.l1 = L.Linear(None, self.hidden_order, initialW=w)
            self.l2 = L.Linear(None, self.hidden_order, initialW=w)
            self.l3 = L.Linear(None, self.hidden_order, initialW=w)
            if MODE == 0:
                self.l4 = L.Linear(None, self.hidden_order, initialW=w)
            self.ll = L.Linear(None, 1, initialW=w)

            if MODE == 1:
                # Layers of the network applied by forward_inten().
                tw = chainer.initializers.Normal(0.35)
                self.l10 = L.Linear(None, self.hidden_order, initialW=tw)
                self.l11 = L.Linear(None, self.hidden_order, initialW=tw)
                self.l12 = L.Linear(None, self.hidden_order, initialW=tw)
                self.l1l = L.Linear(None, 1, initialW=tw)

    @staticmethod
    def _unsupported_mode():
        """Build the error raised when ``MODE`` is neither 0 nor 1."""
        return ValueError('MODE must be 0 or 1, got {!r}'.format(MODE))

    def _tile_latent(self, t, z):
        """Repeat every row of ``z`` once per grid point of ``t``.

        Returns a ``(len(z) * len(t), z.shape[1])`` variable in which row
        ``b * len(t) + i`` equals ``z[b]``.
        """
        n_t = t.shape[0]
        tiled = F.broadcast_to(z, (n_t, z.shape[0], z.shape[1]))
        tiled = F.transpose(tiled, (1, 0, 2))
        return F.reshape(tiled, (-1, z.shape[1]))

    def make_hidden(self, batchsize):
        """Sample ``batchsize`` latent vectors from a standard normal.

        Returns:
            A NumPy array of shape ``(batchsize, self.z_order)`` cast to
            ``chainer.get_dtype()``.
        """
        dtype = chainer.get_dtype()
        return numpy.random.randn(batchsize, self.z_order).astype(dtype)

    def forward_time(self, t, z):
        """Evaluate the time network on every ``(z[b], t[i])`` pair.

        Args:
            t: 1-D grid of length ``T``.
            z: Latent vectors of shape ``(B, Z)``.

        Returns:
            Variable of shape ``(B * T, 1)``: the network output plus the
            grid value itself (a residual connection on ``t``).

        Raises:
            ValueError: If ``MODE`` is neither 0 nor 1.
        """
        zz = self._tile_latent(t, z)
        tt = F.reshape(F.broadcast_to(t, (z.shape[0], t.shape[0])), (-1, 1))
        # The raw input is re-concatenated before every hidden layer after
        # the first one.
        x = F.concat([tt, zz])
        if MODE == 0:
            h = F.elu(self.l0(x))
            h = F.leaky_relu(self.l1(F.concat([h, x])))
            h = F.leaky_relu(self.l2(F.concat([h, x])))
            h = F.leaky_relu(self.l3(F.concat([h, x])))
            h = F.elu(self.l4(F.concat([h, x])))
        elif MODE == 1:
            h = F.elu(self.l0(x))
            h = F.elu(self.l1(F.concat([h, x])))
            h = F.elu(self.l2(F.concat([h, x])))
            h = F.elu(self.l3(F.concat([h, x])))
        else:
            raise self._unsupported_mode()
        return self.ll(h) + tt

    def forward_inten(self, y, t, z):
        """Post-process the profile ``y``.

        Args:
            y: Variable of shape ``(B * T, 1)``.
            t: 1-D grid of length ``T`` (only its length is used).
            z: Latent vectors of shape ``(B, Z)``.

        Returns:
            ``y`` itself when ``MODE == 0``; otherwise the ``(B * T, 1)``
            output of the ``l10``-``l12`` / ``l1l`` network.

        Raises:
            ValueError: If ``MODE`` is neither 0 nor 1.
        """
        if MODE == 0:
            return y
        if MODE == 1:
            zz = self._tile_latent(t, z)
            h = F.elu(self.l10(F.concat([y, zz])))
            h = F.leaky_relu(self.l11(h))
            h = F.leaky_relu(self.l12(F.concat([y, h])))
            # Use this network's own output layer; ``ll`` belongs to
            # forward_time() and ``l1l`` was otherwise never used.
            return self.l1l(h)
        raise self._unsupported_mode()

    def forward(self, z):
        """Generate one profile per latent vector.

        Args:
            z: Latent vectors of shape ``(B, Z)``.

        Returns:
            Variable of shape ``(B, 1, T)`` where ``T`` is the grid length.
        """
        if self.t is None:
            # NOTE(review): the grid is created once, with the array module
            # of the device the link is on at that moment, and is stored as a
            # plain attribute -- presumably it is not moved when the link is
            # later sent to another device; confirm before relying on that.
            grid = numpy.linspace(
                -self.sigma, self.sigma, self.t_len,
                dtype=chainer.get_dtype())
            self.t = chainer.Variable(data=self.device.xp.asarray(grid))
        h = F.exp(-1 * (self.forward_time(self.t, z)) ** 2)
        h = self.forward_inten(h, self.t, z)
        return F.reshape(h, (z.shape[0], 1, self.t.shape[0]))

    def MSE(self, z, real):
        """Return the mean squared error between ``forward(z)`` and ``real``."""
        h = self.forward(z)
        return F.mean_squared_error(h, real)


class Discriminator(chainer.Chain):
    """Discriminator: two 1-D convolutions followed by a small MLP.

    The (noise-perturbed) input is flattened and concatenated with the
    flattened convolution features before the fully connected layers, so the
    MLP sees both the raw signal and the extracted features.

    Args:
        t_len: Accepted for signature compatibility; not used here because
            all layer input sizes are inferred on first use.
        inten_sigma: Scale of the noise added to the input while training.
        hidden_noise: When true, ``add_noise`` (with its default scale) is
            also applied to the output of every hidden layer before the
            activation. Defaults to ``False``.
    """

    def __init__(self, t_len=301, inten_sigma=0.01, hidden_noise=False):
        super(Discriminator, self).__init__()
        self.hidden_order = 5
        self.sigma = inten_sigma
        self.hidden_noise = hidden_noise
        with self.init_scope():
            self.c0 = L.Convolution1D(None, 5, ksize=16, stride=2)
            self.c1 = L.Convolution1D(None, 5, ksize=16, stride=2)
            self.l0 = L.Linear(None, self.hidden_order)
            self.l1 = L.Linear(None, self.hidden_order)
            self.l2 = L.Linear(None, self.hidden_order)
            self.ll = L.Linear(None, 2)

    def _activate(self, h):
        """Apply leaky ReLU, preceded by noise when ``hidden_noise`` is set."""
        if self.hidden_noise:
            h = add_noise(self.device, h)
        return F.leaky_relu(h)

    def forward(self, x):
        """Score a batch of signals.

        Args:
            x: Input batch; presumably shaped ``(batch, 1, t_len)`` like the
                Generator output -- confirm against the caller.

        Returns:
            Variable of shape ``(batch, 2)`` produced by the final linear
            layer (no activation applied).
        """
        x = add_noise(self.device, x, sigma=self.sigma)
        h = self._activate(self.c0(x))
        h = self._activate(self.c1(h))
        # Skip connection: feed the flattened input alongside the flattened
        # convolution features. The result used to be discarded, so ``l0``
        # only ever saw the convolution features. NOTE: this changes the
        # inferred input size of ``l0``; snapshots saved before the fix will
        # not load into this layer.
        h = F.concat([
            F.reshape(h, (h.shape[0], -1)),
            F.reshape(x, (x.shape[0], -1)),
        ])
        h = self._activate(self.l0(h))
        h = self._activate(self.l1(h))
        h = self._activate(self.l2(h))
        return self.ll(h)
