# -*- coding: utf-8 -*-
"""
Created on Tue Sep 11 19:29:15 2018

@author: Kanazawa
"""
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
from chainer import serializers
import argparse
import numpy as np
import os
import matplotlib.pyplot as plt
import shutil
import pandas as pd
import winsound
import copy

import Net.unet
import Common.FileReader as rd
import Common.Utilities as ut
import Common.FileCreater as ct

# Trained U-Net snapshot, appended to this script's directory in main().
NPZ = "/LearnedModels/snapshot"

# Data set type, "fake" or "real"; it only selects MAX_HEIGHT below.
TYPE = "fake"
# Peaks whose baseline-corrected maximum is below this value are dropped
# from the peak table (see Labels2PeakTables).
MAX_HEIGHT = 0 if "fake" == TYPE else 1000

# Network input geometry: one sample of INPUT_CHANNEL x INPUT_SIZE values.
INPUT_SIZE = 1024
INPUT_CHANNEL = 4
# Forwarded as the first argument of ut.UpSamplingAndNormalizationForUnet;
# its exact meaning is defined in that helper (not visible here).
WIDTH = 5.0

# Per-sample class labels (index of the winning network output channel).
BASELINE, SINGLE_PEAK, VERTICAL_PEAK, START_DETECTION, END_DETECTION = range(5)

# Decimal places kept for times (T) and intensities (I) in the peak table.
MIN_DELTA_T = 5
MIN_DELTA_I = 1

def get_args():
    """Parse the command-line options of the LCMS test script.

    Returns:
        argparse.Namespace with ``gpu``, ``out``, ``resume`` and ``unit``.
    """
    parser = argparse.ArgumentParser(description='LCMS Test')
    optionSpecs = [
        (('--gpu', '-g'),
         dict(type=int, default=-1,
              help='GPU ID (negative value indicates CPU)')),
        (('--out', '-o'),
         dict(default='result',
              help='Directory to output the result')),
        (('--resume', '-r'),
         dict(default='',
              help='Resume the training from snapshot')),
        (('--unit', '-u'),
         dict(type=int, default=1000,
              help='Number of units')),
    ]
    for flags, settings in optionSpecs:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()

def _GetRunCenters(labels, value):
    """Return the rounded centre index of every run of consecutive *value* labels."""
    indexes = np.flatnonzero(np.asarray(labels) == value)
    if 0 == indexes.size:
        return []
    # A new run begins wherever two neighbouring hits are not adjacent samples.
    runs = np.split(indexes, np.flatnonzero(np.diff(indexes) != 1) + 1)
    # Built-in round() keeps the half-to-even tie breaking of the old code.
    return [int(round(np.average(run))) for run in runs]


def GetStartAndEnd(modifiedLabels):
    """Locate peak start and end markers in a label sequence.

    Each run of consecutive START_DETECTION labels is collapsed to its
    (rounded) centre index, and likewise for END_DETECTION.

    Args:
        modifiedLabels: sequence/array of per-sample class labels.

    Returns:
        (startAvgs, endAvgs): lists of ``[index, "s"]`` and ``[index, "e"]``
        in ascending index order.  When there is no start marker at all,
        both lists are empty (end markers are not reported on their own).
    """
    startAvgs = [[center, "s"]
                 for center in _GetRunCenters(modifiedLabels, START_DETECTION)]
    if not startAvgs:
        return [], []
    endAvgs = [[center, "e"]
               for center in _GetRunCenters(modifiedLabels, END_DETECTION)]
    return startAvgs, endAvgs

def GetInvaledIndex(modifiedLabels, mixes):
    """Pick redundant boundaries among neighbouring markers of the same kind.

    For every adjacent pair in *mixes* (sorted ``[index, "s"|"e"]`` entries)
    whose kinds match, the labels between the two indices (inclusive) are
    inspected.  The segment is "baseline dominated" when it has no
    baseline/peak samples at all or at least half of them are BASELINE.

    - two starts: drop the first if baseline dominated, otherwise the second;
    - two ends:   drop the second if baseline dominated, otherwise the first.

    Returns:
        List of indices into *mixes* to delete (may contain duplicates).
    """
    toDelete = []
    for position, (left, right) in enumerate(zip(mixes, mixes[1:])):
        kind = left[1]
        if kind != right[1] or kind not in ("s", "e"):
            continue
        segment = np.array(modifiedLabels[left[0]:right[0] + 1])
        baselineCount = np.count_nonzero(segment == BASELINE)
        peakCount = (np.count_nonzero(segment == SINGLE_PEAK)
                     + np.count_nonzero(segment == VERTICAL_PEAK))
        counted = baselineCount + peakCount
        baselineDominated = 0 == counted or baselineCount / counted >= 0.5
        if "s" == kind:
            toDelete.append(position if baselineDominated else position + 1)
        else:
            toDelete.append(position + 1 if baselineDominated else position)
    return toDelete

def _CountPeakClasses(segment):
    """Return (baseline, singlePeak, verticalPeak) sample counts in *segment*."""
    segment = np.asarray(segment)
    return (int(np.count_nonzero(BASELINE == segment)),
            int(np.count_nonzero(SINGLE_PEAK == segment)),
            int(np.count_nonzero(VERTICAL_PEAK == segment)))


def _MatchStartEndPairs(labels):
    """Return ``[start, end]`` index pairs of a start marker directly followed by an end marker.

    Markers come from GetStartAndEnd; the ones GetInvaledIndex flags as
    redundant are removed before pairing.
    """
    startAvgs, endAvgs = GetStartAndEnd(labels)
    mixes = sorted(startAvgs + endAvgs)
    deleteIx = set(GetInvaledIndex(labels, mixes))
    mixes = [mix for i, mix in enumerate(mixes) if i not in deleteIx]
    return [[left[0], right[0]]
            for left, right in zip(mixes, mixes[1:])
            if "s" == left[1] and "e" == right[1]]


def Labels2PeakTablesInnerFunction(labels):
    """Convert per-sample class labels into a table of peak intervals.

    Steps:
      1. Inside every start/end pair that is mostly peak, BASELINE samples
         are relabelled as the dominant peak class.
      2. START_DETECTION / END_DETECTION markers are re-derived from the
         baseline/peak transitions.
      3. Each new start/end pair that is not baseline dominated becomes a
         "single" peak, or a "vertical" peak.  Consecutive vertical peaks
         with no BASELINE sample between them share one cluster id.
      4. Gaps between neighbouring peaks of a cluster that are at least half
         peak samples are added as extra vertical peaks.

    Args:
        labels: per-sample class labels (list or 1-D array); not modified.

    Returns:
        DataFrame with columns ``kinds`` ("single"/"vertical"),
        ``clusterIndex`` ("single" for single peaks, otherwise an int cluster
        id), ``start`` and ``end`` (sample indices), sorted by ``start``.
    """
    columns = ["kinds", "clusterIndex", "start", "end"]
    labelArray = np.asarray(labels)
    # Always an array copy, so in-place edits and elementwise comparisons work
    # even when a plain list is passed (copy.copy kept lists as lists).
    modifiedLabels = labelArray.copy()

    # 1) Fill baseline holes inside start/end pairs that are mostly peak.
    for start, end in _MatchStartEndPairs(modifiedLabels):
        baseline, singlePeak, verticalPeak = _CountPeakClasses(labelArray[start:end])
        total = baseline + singlePeak + verticalPeak
        # The numerator must be parenthesised; the old expression computed
        # singlePeak + (verticalPeak / total), which was almost always >= 0.5.
        if 0 != total and (singlePeak + verticalPeak) / total >= 0.5:
            fill = SINGLE_PEAK if singlePeak > verticalPeak else VERTICAL_PEAK
            segment = modifiedLabels[start:end]  # view: writes go to modifiedLabels
            segment[BASELINE == segment] = fill

    # 2) Re-mark transitions.  Sequential on purpose: a one-sample peak gets
    #    START_DETECTION, which then prevents it from also becoming an end.
    peakClasses = (SINGLE_PEAK, VERTICAL_PEAK)
    for i in range(len(modifiedLabels) - 1):
        if BASELINE == modifiedLabels[i] and modifiedLabels[i + 1] in peakClasses:
            modifiedLabels[i + 1] = START_DETECTION
        if modifiedLabels[i] in peakClasses and BASELINE == modifiedLabels[i + 1]:
            modifiedLabels[i] = END_DETECTION

    pairs = _MatchStartEndPairs(modifiedLabels)
    if not pairs:
        return pd.DataFrame(index=[], columns=columns)

    # 3) Classify every pair.
    rows = []
    lastIndex = len(modifiedLabels) - 1
    clusterFlag = False   # True while consecutive vertical peaks form one cluster
    clusterIndex = -1     # last cluster id handed out; never reset, so ids stay unique
    preEndIndex = -1      # end of the previously accepted peak (-1: none)
    for startIndex, endIndex in pairs:
        baseline, singlePeak, verticalPeak = _CountPeakClasses(
            modifiedLabels[startIndex:endIndex + 1])
        # A baseline sample between this peak and the previous one ends the cluster.
        if -1 != preEndIndex and np.any(
                BASELINE == modifiedLabels[preEndIndex:startIndex + 1]):
            clusterFlag = False

        total = baseline + singlePeak + verticalPeak
        if 0 == total or baseline / total >= 0.5:
            # Not a peak.  The cluster counter is kept (the old code reset it to
            # -1, so later clusters reused ids and got merged with earlier ones).
            clusterFlag = False
            preEndIndex = -1
            continue

        if singlePeak >= verticalPeak:
            clusterFlag = False
            rows.append(["single", "single", startIndex, endIndex])
        else:
            # Widen to the first sample of the start-marker run ...
            assert START_DETECTION == modifiedLabels[startIndex]
            while 0 != startIndex and START_DETECTION == modifiedLabels[startIndex - 1]:
                startIndex -= 1
            # ... and one sample past the end-marker run (when possible).
            assert END_DETECTION == modifiedLabels[endIndex]
            while lastIndex != endIndex and END_DETECTION == modifiedLabels[endIndex + 1]:
                endIndex += 1
            if lastIndex != endIndex:
                endIndex += 1
            if not clusterFlag:
                clusterIndex += 1
                clusterFlag = True
            rows.append(["vertical", clusterIndex, startIndex, endIndex])
        preEndIndex = endIndex

    peaks = pd.DataFrame(rows, columns=columns).sort_values("start")

    # 4) Fill gaps between neighbouring vertical peaks of the same cluster.
    verticalPeaks = peaks[peaks["kinds"] == "vertical"]
    gapRows = []
    for clusterId in range(verticalPeaks["clusterIndex"].nunique()):
        clusterPeaks = verticalPeaks[verticalPeaks["clusterIndex"] == clusterId]
        ends = clusterPeaks["end"].tolist()
        starts = clusterPeaks["start"].tolist()
        for leftIndex, rightIndex in zip(ends[:-1], starts[1:]):
            if rightIndex == leftIndex:
                continue
            assert leftIndex <= rightIndex
            _, singlePeak, verticalPeak = _CountPeakClasses(
                modifiedLabels[leftIndex:rightIndex + 1])
            assert singlePeak + verticalPeak <= rightIndex - leftIndex
            if (singlePeak + verticalPeak) / (rightIndex - leftIndex) >= 0.5:
                gapRows.append(["vertical", clusterId, leftIndex, rightIndex])

    # DataFrame.append was removed in pandas 2.0; concatenate instead.
    if gapRows:
        peaks = pd.concat([peaks, pd.DataFrame(gapRows, columns=columns)],
                          ignore_index=True)
    return peaks.sort_values("start")
    
def Labels2PeakTables(upSamplingTimesInfo,
                      chromato,
                      labels):
    """Build peak-table rows from per-sample labels of one chromatogram.

    Peak intervals come from Labels2PeakTablesInnerFunction.  Each peak is
    baseline-corrected with a straight line between its end points; for
    vertical peaks the line spans the whole cluster.  A peak is dropped when
    its corrected maximum is below MAX_HEIGHT.

    Args:
        upSamplingTimesInfo: per-sample times, indexed like *chromato*
            (presumably the up-sampled retention times -- confirm).
        chromato: per-sample intensities of the same chromatogram.
        labels: per-sample class labels.

    Returns:
        List of ``[rTime, startTime, startInt, endTime, endInt, compounName]``
        rows with every value as a string, ordered by *numeric* retention
        time (the old code sorted the strings, so "10.1" came before "9.5").
    """
    peaks = Labels2PeakTablesInnerFunction(labels)
    intensities = np.asarray(chromato)
    rows = []  # (numeric rTime, list of strings)

    def linearBaseline(first, last):
        # Straight line between both end intensities, rounded like before.
        return [round(v) for v in np.linspace(intensities[first],
                                              intensities[last],
                                              num=last - first + 1)]

    def addRow(apexIndex, startIndex, endIndex, startInt, endInt):
        rTime = round(upSamplingTimesInfo[apexIndex], MIN_DELTA_T)
        sTime = round(upSamplingTimesInfo[startIndex], MIN_DELTA_T)
        sInt = round(startInt, MIN_DELTA_I)
        eTime = round(upSamplingTimesInfo[endIndex], MIN_DELTA_T)
        eInt = round(endInt, MIN_DELTA_I)
        assert sTime <= rTime
        assert rTime <= eTime
        rows.append((rTime, [str(rTime), str(sTime), str(sInt),
                             str(eTime), str(eInt), "N/A"]))

    for peak in peaks[peaks["kinds"] == "single"].itertuples():
        start, end = peak.start, peak.end
        corrected = intensities[start:(end + 1)] - linearBaseline(start, end)
        if max(corrected) < MAX_HEIGHT:
            continue
        addRow(start + np.argmax(corrected), start, end,
               intensities[start], intensities[end])

    verticalPeaks = peaks[peaks["kinds"] == "vertical"]
    for clusterId in range(verticalPeaks["clusterIndex"].nunique()):
        clusterPeaks = verticalPeaks[verticalPeaks["clusterIndex"] == clusterId]
        clusterStart = clusterPeaks["start"].min()
        clusterEnd = clusterPeaks["end"].max()
        baseline = linearBaseline(clusterStart, clusterEnd)
        corrected = intensities[clusterStart:(clusterEnd + 1)] - baseline
        for peak in clusterPeaks.itertuples():
            lo = peak.start - clusterStart
            hi = peak.end - clusterStart
            window = corrected[lo:(hi + 1)]
            if max(window) < MAX_HEIGHT:
                continue
            # Vertical peaks report the intensity of the shared cluster baseline.
            addRow(peak.start + np.argmax(window), peak.start, peak.end,
                   baseline[lo], baseline[hi])

    rows.sort(key=lambda row: row[0])
    return [row for _, row in rows]
    
def SavePredictedTxtDataInnerFunction(chromatoPath,
                                      resultPath,
                                      cnn,
                                      outputF):
    """Predict peak tables for every compound in one chromatogram file.

    Every line of *chromatoPath* containing "compound" names a compound.
    Its chromatogram is read and up-sampled, labelled by *cnn*, converted
    into a peak table, and written to the open file *outputF*.

    Args:
        chromatoPath: path of the input chromatogram text file.
        resultPath: unused here (kept for the existing call signature).
        cnn: model exposing ``predict`` (the U-Net loaded in main()).
        outputF: writable text file receiving the peak tables.
    """
    with open(chromatoPath) as chromatoFile:
        headerLines = [line for line in chromatoFile.readlines()
                       if 'compound' in line]

    for headerLine in headerLines:
        compoundID = headerLine.replace(" (input)\n", "")
        # assumes ReadChromatos returns a (times, intensities) pair -- confirm
        inputTime, chromato = np.array(rd.ReadChromatos(chromatoPath, compoundID))

        netInputs = []  # filled in place by the up-sampling helper
        timesInfo, upsampledInts = ut.UpSamplingAndNormalizationForUnet(
            WIDTH,
            INPUT_SIZE,
            {"time": inputTime, "intensity": chromato},
            netInputs)

        batch = netInputs[0].reshape(1, INPUT_CHANNEL, INPUT_SIZE).astype(np.float32)
        label = predict2label(cnn.predict(batch).data[0])

        peakTables = Labels2PeakTables(timesInfo[0], upsampledInts[0], label)

        onlyID = compoundID.replace("compound# ", "").strip(" ")
        assert(str.isdecimal(onlyID))
        ct.CreatePeakTableTextFile(outputF, peakTables, onlyID, "compound# ")
                
def SavePredictedTxtData(filePath, resultPath, cnn):
    """Run peak prediction on every chromatogram file below *filePath*.

    Each sub-folder of *filePath* is processed in turn.  The user is asked
    for an output name, and ``resultPath/<name>`` is created.  Every file in
    the sub-folder is copied to ``<name>_input.txt``, and its predicted peak
    tables are written to ``<name>_output.txt``.  When a sub-folder holds
    more than one file, the file's base name is added to both names
    (``<name>_<file>_…``).  Previously each file overwrote the outputs of
    the one before it.

    Args:
        filePath: folder whose sub-folders contain the input files.
        resultPath: folder receiving one output folder per sub-folder.
        cnn: model handed on to SavePredictedTxtDataInnerFunction.
    """
    for folderName in sorted(os.listdir(filePath)):
        folderPath = os.path.join(filePath, folderName)
        # Stray files next to the sub-folders used to crash os.listdir below.
        if not os.path.isdir(folderPath):
            continue
        fileNames = [name for name in sorted(os.listdir(folderPath))
                     if os.path.isfile(os.path.join(folderPath, name))]

        print('Please enter output file name.')
        filename = input()
        outputPath = os.path.join(resultPath, filename)
        print(outputPath)
        os.makedirs(outputPath, exist_ok=True)

        for fileName in fileNames:
            chromatoPath = os.path.join(folderPath, fileName)
            prefix = filename
            if len(fileNames) > 1:
                # Keep one output pair per input file instead of clobbering.
                prefix = filename + "_" + os.path.splitext(fileName)[0]

            with open(os.path.join(outputPath, prefix + "_output.txt"), "w") as outputF:
                shutil.copyfile(chromatoPath,
                                os.path.join(outputPath, prefix + "_input.txt"))
                SavePredictedTxtDataInnerFunction(chromatoPath,
                                                  outputPath,
                                                  cnn,
                                                  outputF)

def predict2label(predict):
    """Collapse per-channel network scores into one label per position.

    Args:
        predict: array-like of shape (channels, positions).

    Returns:
        Integer array of the winning channel index along axis 0.
    """
    scores = np.array(predict)
    labels = scores.argmax(axis=0)
    return labels

def main():
    """Load the trained U-Net snapshot and predict peaks for all test data.

    Inputs are read from ``<script dir>/Test/u-net`` and results are written
    below ``<script dir>/Results``.
    """
    args = get_args()

    print('GPU: {}'.format(args.gpu))
    print('')
    # NOTE(review): args.gpu is only printed; the model is never moved to a
    # GPU, so inference runs on the CPU regardless of --gpu.

    baseDir = os.path.dirname(os.path.abspath(__file__))
    cnn = Net.unet.UNET()
    serializers.load_npz(baseDir + NPZ, cnn, path="updater/model:main/")

    # os.path.join instead of hard-coded backslashes, which produced
    # directory names containing literal "\" on non-Windows systems.
    filePath = os.path.join(baseDir, "Test", "u-net")
    resultPath = os.path.join(baseDir, "Results")

    SavePredictedTxtData(filePath, resultPath, cnn)

if __name__ == "__main__":
    # Run the prediction pipeline only when executed as a script.
    main()