﻿#!/usr/bin/env python

import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np
import matplotlib.pyplot as plt

# Number of output classes; UNET's final 1x1 convolution emits this many
# channels and the loss weight vector has one entry per class.
CLASSES = 5

# Per-class weights passed to the softmax cross entropy in UNET.__call__,
# listed here in the same order in which they are packed into the weight array
# (class index 0 first).
BACKGROUND_WEIGHT = 1
SINGLE_PEAK_WEIGHT = 4
VERTICAL_PEAK_WEIGHT = 8
START_DETECTION_WEIGHT = 30
END_DETECTION_WEIGHT = 30

# Number of channels of the network input.
INPUT_CHANNEL = 4

# Feature-map widths of the seven resolution levels; each level doubles the
# width of the previous one (8, 16, 32, 64, 128, 256, 512).
CHANNEL1 = 8
CHANNEL2 = 2 * CHANNEL1
CHANNEL3 = 2 * CHANNEL2
CHANNEL4 = 2 * CHANNEL3
CHANNEL5 = 2 * CHANNEL4
CHANNEL6 = 2 * CHANNEL5
CHANNEL7 = 2 * CHANNEL6

class UNET(chainer.Chain):
    """1-D U-Net that assigns one of ``CLASSES`` labels to every position.

    The encoder halves the sequence length six times with max pooling while
    widening the feature maps from ``CHANNEL1`` to ``CHANNEL7``; the decoder
    mirrors it with unpooling, and concatenates the encoder activation of the
    same resolution before each decoder convolution (skip connections).

    Input is expected as ``(batch, INPUT_CHANNEL, length)``.  Because of the
    six 2x poolings followed by six 2x unpoolings, ``length`` has to be a
    multiple of 64, otherwise the skip concatenations see mismatching lengths.
    """

    def __init__(self):
        super(UNET, self).__init__()

        def conv3(n_in, n_out):
            # Length-preserving 1-D convolution (ksize=3, stride=1, pad=1).
            return L.ConvolutionND(1, n_in, n_out, 3, 1, 1)

        def deconv3(n_in, n_out):
            # Length-preserving 1-D deconvolution (ksize=3, stride=1, pad=1).
            return L.DeconvolutionND(1, n_in, n_out, 3, 1, 1)

        # Links are created in the same order as before so that, for a fixed
        # random seed, the initial weights are unchanged.
        with self.init_scope():
            # ---- encoder convolutions -----------------------------------
            self.c0 = conv3(INPUT_CHANNEL, CHANNEL1)

            self.c1 = conv3(CHANNEL1, CHANNEL2)
            self.c2 = conv3(CHANNEL2, CHANNEL2)

            self.c3 = conv3(CHANNEL2, CHANNEL3)
            self.c4 = conv3(CHANNEL3, CHANNEL3)

            self.c5 = conv3(CHANNEL3, CHANNEL4)
            self.c6 = conv3(CHANNEL4, CHANNEL4)

            self.c7 = conv3(CHANNEL4, CHANNEL5)
            self.c8 = conv3(CHANNEL5, CHANNEL5)

            self.c9 = conv3(CHANNEL5, CHANNEL6)
            self.c10 = conv3(CHANNEL6, CHANNEL6)

            self.c11 = conv3(CHANNEL6, CHANNEL7)
            self.c12 = conv3(CHANNEL7, CHANNEL7)

            # ---- decoder: deconvolution after unpooling, then a ----------
            # ---- convolution over [skip, upsampled] (hence 2x inputs) ----
            self.dc12 = deconv3(CHANNEL7, CHANNEL6)
            self.dc11 = conv3(CHANNEL7, CHANNEL6)

            self.dc10 = deconv3(CHANNEL6, CHANNEL5)
            self.dc9 = conv3(CHANNEL6, CHANNEL5)

            self.dc8 = deconv3(CHANNEL5, CHANNEL4)
            self.dc7 = conv3(CHANNEL5, CHANNEL4)

            self.dc6 = deconv3(CHANNEL4, CHANNEL3)
            self.dc5 = conv3(CHANNEL4, CHANNEL3)

            self.dc4 = deconv3(CHANNEL3, CHANNEL2)
            self.dc3 = conv3(CHANNEL3, CHANNEL2)

            self.dc2 = deconv3(CHANNEL2, CHANNEL1)
            self.dc1 = conv3(CHANNEL2, CHANNEL1)

            # Final 1x1 convolution mapping features to per-class scores.
            self.dc0 = L.ConvolutionND(1, CHANNEL1, CLASSES, 1, 1, 0)

            # ---- batch normalisation for every encoder convolution -------
            self.bnc0 = L.BatchNormalization(CHANNEL1)
            self.bnc1 = L.BatchNormalization(CHANNEL2)
            self.bnc2 = L.BatchNormalization(CHANNEL2)
            self.bnc3 = L.BatchNormalization(CHANNEL3)
            self.bnc4 = L.BatchNormalization(CHANNEL3)
            self.bnc5 = L.BatchNormalization(CHANNEL4)
            self.bnc6 = L.BatchNormalization(CHANNEL4)
            self.bnc7 = L.BatchNormalization(CHANNEL5)
            self.bnc8 = L.BatchNormalization(CHANNEL5)
            self.bnc9 = L.BatchNormalization(CHANNEL6)
            self.bnc10 = L.BatchNormalization(CHANNEL6)
            self.bnc11 = L.BatchNormalization(CHANNEL7)
            self.bnc12 = L.BatchNormalization(CHANNEL7)

            # ---- batch normalisation for every decoder layer -------------
            self.bnd12 = L.BatchNormalization(CHANNEL6)
            self.bnd11 = L.BatchNormalization(CHANNEL6)
            self.bnd10 = L.BatchNormalization(CHANNEL5)
            self.bnd9 = L.BatchNormalization(CHANNEL5)
            self.bnd8 = L.BatchNormalization(CHANNEL4)
            self.bnd7 = L.BatchNormalization(CHANNEL4)
            self.bnd6 = L.BatchNormalization(CHANNEL3)
            self.bnd5 = L.BatchNormalization(CHANNEL3)
            self.bnd4 = L.BatchNormalization(CHANNEL2)
            self.bnd3 = L.BatchNormalization(CHANNEL2)
            self.bnd2 = L.BatchNormalization(CHANNEL1)
            self.bnd1 = L.BatchNormalization(CHANNEL1)

    def calc(self, x):
        """Run the network and return the raw per-position class scores.

        Args:
            x: Input of shape ``(batch, INPUT_CHANNEL, length)``; ``length``
                must be a multiple of 64 (see the class docstring).

        Returns:
            Variable of shape ``(batch, CLASSES, length)`` holding
            un-normalised scores (no softmax is applied here).
        """
        # ---- encoder: conv-BN-ReLU pairs, 2x max pooling between levels ----
        # e0, e2, e4, e6, e8 and e10 are kept for the skip connections; the
        # intermediate names are dropped as soon as they are no longer needed.
        e0 = F.relu(self.bnc0(self.c0(x)))

        e1 = F.relu(self.bnc1(self.c1(F.max_pooling_nd(e0, 2))))
        e2 = F.relu(self.bnc2(self.c2(e1)))
        del e1

        e3 = F.relu(self.bnc3(self.c3(F.max_pooling_nd(e2, 2))))
        e4 = F.relu(self.bnc4(self.c4(e3)))
        del e3

        e5 = F.relu(self.bnc5(self.c5(F.max_pooling_nd(e4, 2))))
        e6 = F.relu(self.bnc6(self.c6(e5)))
        del e5

        e7 = F.relu(self.bnc7(self.c7(F.max_pooling_nd(e6, 2))))
        e8 = F.relu(self.bnc8(self.c8(e7)))
        del e7

        e9 = F.relu(self.bnc9(self.c9(F.max_pooling_nd(e8, 2))))
        e10 = F.relu(self.bnc10(self.c10(e9)))
        del e9

        e11 = F.relu(self.bnc11(self.c11(F.max_pooling_nd(e10, 2))))
        e12 = F.relu(self.bnc12(self.c12(e11)))
        del e11

        # ---- decoder: unpool 2x, deconv, then conv over [skip, upsampled] --
        d12 = F.relu(self.bnd12(self.dc12(
            F.unpooling_nd(e12, 2, cover_all=False))))
        del e12
        d11 = F.relu(self.bnd11(self.dc11(F.concat([e10, d12]))))
        del d12, e10

        d10 = F.relu(self.bnd10(self.dc10(
            F.unpooling_nd(d11, 2, cover_all=False))))
        del d11
        d9 = F.relu(self.bnd9(self.dc9(F.concat([e8, d10]))))
        del d10, e8

        d8 = F.relu(self.bnd8(self.dc8(
            F.unpooling_nd(d9, 2, cover_all=False))))
        del d9
        d7 = F.relu(self.bnd7(self.dc7(F.concat([e6, d8]))))
        del d8, e6

        d6 = F.relu(self.bnd6(self.dc6(
            F.unpooling_nd(d7, 2, cover_all=False))))
        del d7
        d5 = F.relu(self.bnd5(self.dc5(F.concat([e4, d6]))))
        del d6, e4

        d4 = F.relu(self.bnd4(self.dc4(
            F.unpooling_nd(d5, 2, cover_all=False))))
        del d5
        d3 = F.relu(self.bnd3(self.dc3(F.concat([e2, d4]))))
        del d4, e2

        d2 = F.relu(self.bnd2(self.dc2(
            F.unpooling_nd(d3, 2, cover_all=False))))
        del d3
        d1 = F.relu(self.bnd1(self.dc1(F.concat([e0, d2]))))
        del d2, e0

        # 1x1 convolution to class scores; no activation on the output.
        return self.dc0(d1)

    def __call__(self, x, t):
        """Compute the class-weighted training loss and report metrics.

        Args:
            x: Network input, see :meth:`calc`.
            t: Integer class labels for every position of ``x``
                (presumably shape ``(batch, length)`` -- verify with caller).

        Returns:
            The weighted softmax cross entropy loss. ``loss`` and
            ``accuracy`` are additionally reported through
            ``chainer.report``.
        """
        h = self.calc(x)
        # Class index i is weighted by the i-th entry below.  The array is
        # created with self.xp so it lives on the same device as the model
        # (NumPy on CPU, CuPy on GPU) instead of always being sent to a GPU.
        weight = self.xp.array(
            [
                BACKGROUND_WEIGHT,
                SINGLE_PEAK_WEIGHT,
                VERTICAL_PEAK_WEIGHT,
                START_DETECTION_WEIGHT,
                END_DETECTION_WEIGHT,
            ],
            dtype=np.float32,
        )
        loss = F.softmax_cross_entropy(h, t, class_weight=weight)
        accuracy = F.accuracy(h, t)
        chainer.report({
            'loss': loss,
            'accuracy': accuracy,
        }, self)
        return loss

    def predict(self, x):
        """Return the class scores of ``x`` computed in inference mode.

        The forward pass runs with ``train`` set to ``False`` so that batch
        normalisation uses its accumulated statistics.  The computational
        graph is still recorded; wrap the call in
        ``chainer.no_backprop_mode()`` if gradients are not needed.

        Args:
            x: Network input, see :meth:`calc`.

        Returns:
            Variable of raw class scores, same as :meth:`calc`.
        """
        with chainer.using_config('train', False):
            h = self.calc(x)
            return h

