# from comet_ml import Experiment
# experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"))  # never commit real API keys

from loader import *
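# The wildcard import above is expected to provide (among others): np,
# pad_sequences, KerasBatchGenerator, date_label, report_to_df, the
# source_*/target_* data splits, the fitted tokenizers and the embedding
# layers used throughout this script.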
from _layers import Attention
from keras.models import Model, load_model as keras_load_model
from keras.layers import Input, LSTM, Dense, Activation, dot, concatenate
from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
import tensorflow as tf
import tqdm

import pickle
from sklearn.metrics import classification_report

#REPRODUCIBLE
np.random.seed(42)
import random
random.seed(12345)
import os
os.environ['PYTHONHASHSEED'] = '0'  # note: only affects hashing if set before the interpreter starts

# For fully deterministic runs, also restrict TF to single-threaded execution:
# config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config = tf.ConfigProto()
from keras import backend as K
tf.set_random_seed(1234)
#REPRODUCIBLE


###################################
# TensorFlow wizardry
# Don't pre-allocate memory; allocate as-needed
config.gpu_options.allow_growth = True
config.gpu_options.allocator_type = 'BFC'
sess = tf.Session(graph=tf.get_default_graph(), config=config)
K.set_session(sess)

# LOAD ICD 10 CLASSIFICATION MODEL
try:
    icd10_model = keras_load_model('models/icd10Classification_attention_extended.h5',
                                   custom_objects={'Attention': Attention})
except OSError:
    # Model file missing: train it first, then load it (the custom layer
    # must be passed again, otherwise deserialisation fails).
    from classificationICD10 import *
    icd10_model = keras_load_model('models/icd10Classification_attention_extended.h5',
                                   custom_objects={'Attention': Attention})

with open('models/icd10_tokenizer_extended.p', 'rb') as handle:
    icd10Tokenizer = pickle.load(handle)

with open('models/icd10_mappings_extended.p', 'rb') as handle:
    encoded_Y = pickle.load(handle)
# LOAD ICD 10 CLASSIFICATION MODEL
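# Assumption based on how these objects are used below: icd10Tokenizer is a
# fitted keras Tokenizer, and encoded_Y behaves like a fitted sklearn
# LabelEncoder (it must support inverse_transform back to ICD-10 labels).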

callbacks_list = [
    EarlyStopping(
        monitor='val_loss',
        patience=2,
        # min_delta=0.001
    ),
    ModelCheckpoint(
        filepath='models/s2s_att_extended.h5',
        monitor='val_loss',
        save_best_only=True,
    ),
    CSVLogger(
        append=False,
        filename='logs/s2s_att_extended_{}.csv'.format(date_label)
    )
]
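# Stop after two epochs without val_loss improvement and keep only the
# best-scoring checkpoint on disk.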

latent_dim = 256
batch_size = 400
epochs = 1

train_data_generator = KerasBatchGenerator(batch_size,
                                           source_train,
                                           source_max_sequence_tokenizer,
                                           source_kerasTokenizer,
                                           target_train,
                                           target_max_sequence_tokenizer,
                                           target_kerasTokenizer
                                           )

validation_data_generator = KerasBatchGenerator(batch_size,
                                           source_val,
                                           source_max_sequence_tokenizer,
                                           source_kerasTokenizer,
                                           target_val,
                                           target_max_sequence_tokenizer,
                                           target_kerasTokenizer
                                           )
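# Assumed generator contract (KerasBatchGenerator is defined in loader):
# generate_data() yields batches of
#   ([encoder_input_batch, decoder_input_batch], decoder_target_onehot_batch)
# indefinitely, matching the [encoder_input, decoder_input] -> decoder_out
# model built below.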

print("Lets train some stuff!")
# Define an input sequence and process it.
encoder_input = Input(shape=(source_max_sequence_tokenizer, ), name='encoder_input')
x_encoder = source_embedding_layer(encoder_input)
encoder_out, state_h, state_c = LSTM(latent_dim, return_sequences=True, unroll=True, return_state=True, name='encoder_lstm')(x_encoder)
encoder_states = [state_h, state_c]
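# encoder_out keeps the hidden state at every source timestep (needed by the
# attention layer); encoder_states holds only the final [h, c] pair used to
# initialise the decoder.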

# Set up the decoder, using `encoder_states` as initial state.
decoder_input = Input(shape=(target_max_sequence_tokenizer, ), name='decoder_input')
x_decode = target_embedding_layer(decoder_input)
decoder_LSTM = LSTM(latent_dim, return_sequences=True, return_state=True, unroll=True, name='decoder_lstm')
decoder, state_h_decode , state_c_decode = decoder_LSTM(x_decode, initial_state=encoder_states)

# Luong-style global attention: the 'dot' score, equation (7) in Section 3.1
# of the paper. We reuse a softmax Activation layer instead of writing the
# tensor calculation by hand. Note the argument order: the softmax must
# normalise over the *source* positions (the last axis).
attention = dot([decoder, encoder_out], name='attention_dot', axes=[2, 2])  # (batch, T_tgt, T_src)
attention = Activation('softmax', name='attention_activation')(attention)   # weights over source positions
context = dot([attention, encoder_out], name='context_dot', axes=[2, 1])    # (batch, T_tgt, latent_dim)
decoder_combined_context = concatenate([context, decoder])

decoder_dense = Dense(len(target_vocab)+1, activation='softmax', name='dense_output')
decoder_out = decoder_dense(decoder_combined_context) # equation (6) of the paper
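# For reference, the attention math implemented above (Luong et al., 2015,
# global attention with the 'dot' score):
#   score(h_t, h_s) = h_t . h_s                   # dot score
#   a_t(s) = softmax_s(score(h_t, h_s))           # alignment weights over source
#   c_t    = sum_s a_t(s) * h_s                   # context vector
# Unlike the paper, no intermediate tanh layer is applied here: the softmax
# Dense maps the concatenation [c_t; h_t] directly to the target vocabulary.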

# MODEL
model = Model([encoder_input, decoder_input], decoder_out)
model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

model.fit_generator(
    generator=train_data_generator.generate_data(),
    steps_per_epoch=int(np.ceil(len(source_train) / batch_size)),
    epochs=epochs,
    callbacks=callbacks_list,
    validation_data=validation_data_generator.generate_data(),
    validation_steps=int(np.ceil(len(source_val) / batch_size)),
    # use_multiprocessing=True,
    # workers=10
)

# INFERENCE MODELS
# Encoder inference model: expose the full output sequence (for attention)
# as well as the final states (to initialise the decoder).
encoder_model_inf = Model(encoder_input, [encoder_out] + encoder_states)
encoder_model_inf.summary()

# Decoder inference model. The encoder outputs must enter through an explicit
# Input: reusing the training-graph `encoder_out` tensor here would leave the
# model disconnected from its declared inputs and Keras would refuse to build it.
decoder_state_input_h = Input(shape=(latent_dim,), name='inference_decoder_input_h')
decoder_state_input_c = Input(shape=(latent_dim,), name='inference_decoder_input_c')
decoder_input_states = [decoder_state_input_h, decoder_state_input_c]
encoder_out_input = Input(shape=(source_max_sequence_tokenizer, latent_dim),
                          name='inference_encoder_out')

decoder, decoder_h, decoder_c = decoder_LSTM(x_decode, initial_state=decoder_input_states)
decoder_states = [decoder_h, decoder_c]

attention = dot([decoder, encoder_out_input], axes=[2, 2])
attention = Activation('softmax')(attention)
context = dot([attention, encoder_out_input], axes=[2, 1])

decoder_combined_context = concatenate([context, decoder])

decoder_out = decoder_dense(decoder_combined_context)
decoder_model_inf = Model(inputs=[decoder_input, encoder_out_input] + decoder_input_states,
                          outputs=[decoder_out] + decoder_states)
decoder_model_inf.summary()
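# Because the LSTMs were built with unroll=True over fixed-length inputs, the
# inference decoder still consumes a full padded target sequence each call;
# decode_seq below therefore writes the current token at position 0 and reads
# the prediction for that same position.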

def decode_seq(inp_seq):

    # Run the encoder once: keep the full output sequence for attention and
    # the final states to initialise the decoder.
    encoder_out_val, state_h_val, state_c_val = encoder_model_inf.predict(inp_seq)
    states_val = [state_h_val, state_c_val]

    target_seq = np.zeros((1, target_max_sequence_tokenizer))
    target_seq[0, 0] = target_vocab['sos']

    translated_sent = []
    translated_index = []
    stop_condition = False

    while not stop_condition:

        decoder_out, decoder_h, decoder_c = decoder_model_inf.predict(
            x=[target_seq, encoder_out_val] + states_val)
        # Greedy decoding: the current token sits at position 0, so read the
        # prediction for that timestep (not the last, padded one).
        max_val_index = np.argmax(decoder_out[0, 0, :])

        try:
            sampled_word = target_index_to_word_dict[max_val_index]
        except KeyError:
            # Index not in the vocabulary mapping: treat it as end-of-sequence.
            sampled_word = 'eos'

        translated_sent.append(sampled_word)
        translated_index.append(max_val_index)

        if sampled_word == 'eos' or len(translated_sent) > target_max_sequence_tokenizer:
            stop_condition = True

        target_seq = np.zeros((1, target_max_sequence_tokenizer))
        target_seq[0, 0] = max_val_index
        states_val = [decoder_h, decoder_c]

    # Drop the trailing 'eos' (or the last token if the length limit was hit).
    return translated_sent[:-1], translated_index[:-1]

y_true = []
y_pred = []


# Re-encode the raw validation texts as padded index sequences for decoding.
source_val = source_kerasTokenizer.texts_to_sequences(source_val)
source_val = pad_sequences(source_val, maxlen=source_max_sequence_tokenizer, padding='post')

for seq_index in tqdm.tqdm(range(len(source_val))):
# for seq_index in range(10):
    inp_seq = source_val[seq_index:seq_index+1]
    translated_sent, translated_index = decode_seq(inp_seq)

    # PREDICT ICD10
    source_word_sequence = icd10Tokenizer.texts_to_sequences([" ".join(translated_sent)])
    word_sequence = pad_sequences(source_word_sequence, maxlen=icd10_model.layers[0].input_shape[1], padding='post')
    icd10_probs = icd10_model.predict(word_sequence)
    max_val_index = np.argmax(icd10_probs, axis=1)[0]
    # inverse_transform expects an array-like, so wrap and unwrap the single index
    icd10_label = encoded_Y.inverse_transform([max_val_index])[0]

    # For spot checks, print the target vs. decoded sentences and the gold
    # vs. predicted ICD-10 codes here.

    y_true.append(labels_val[seq_index])
    y_pred.append(icd10_label)

report = classification_report(y_true, y_pred)
report_df = report_to_df(report)
report_df.to_csv('logs/classification_report_extended.csv')
print(report_df)