# from comet_ml import Experiment
# experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"))  # never hard-code API keys

from loader import *  # provides np, date_label, the train/val splits, tokenizers and embedding layers
from _layers import Attention
from keras.models import Model, load_model as keras_load_model
from keras.layers import Input, LSTM, Dense
from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
import tensorflow as tf

import pickle
from seq2seq_base import run_pipeline_prediction

#REPRODUCIBLE
np.random.seed(42)
import random
random.seed(12345)
import os
# NOTE: PYTHONHASHSEED must be set before the Python process starts to have any
# effect; setting it here only documents the intended value.
os.environ['PYTHONHASHSEED'] = '0'

# For fully deterministic TF runs, also pin the thread pools:
# config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config = tf.ConfigProto()
from keras import backend as K
tf.set_random_seed(1234)
#REPRODUCIBLE


###################################
# TensorFlow wizardry

# Don't pre-allocate GPU memory; allocate as needed
config.gpu_options.allow_growth = True
config.gpu_options.allocator_type = 'BFC'

sess = tf.Session(graph=tf.get_default_graph(), config=config)
K.set_session(sess)

# LOAD ICD 10 CLASSIFICATION MODEL
try:
    icd10_model = keras_load_model('models/icd10Classification_attention_extended.h5',
                                   custom_objects={'Attention': Attention})
except OSError:
    # Model file missing: importing classificationICD10 is expected to train and
    # save the .h5; then retry the load with the custom layer registered.
    from classificationICD10 import *
    icd10_model = keras_load_model('models/icd10Classification_attention_extended.h5',
                                   custom_objects={'Attention': Attention})

with open('models/icd10_tokenizer.p', 'rb') as handle:
    icd10Tokenizer = pickle.load(handle)

with open('models/icd10_mappings.p', 'rb') as handle:
    encoded_Y = pickle.load(handle)
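
# icd10Tokenizer is assumed to be a fitted keras Tokenizer and encoded_Y a fitted
# sklearn LabelEncoder (as saved by the classifier's training script); they are
# used downstream roughly as:
#   seq = icd10Tokenizer.texts_to_sequences(["some diagnosis text"])
#   label = encoded_Y.inverse_transform(np.argmax(probs, axis=1))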
# LOAD ICD 10 CLASSIFICATION MODEL

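# Stop training after 2 epochs without val_loss improvement, checkpoint only the
# best model, and log per-epoch metrics to CSV.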
callbacks_list = [
    EarlyStopping(
        monitor='val_loss',
        patience=2,
        # min_delta=0.001
    ),
    ModelCheckpoint(
        filepath='models/s2s_extended.h5',
        monitor='val_loss',
        save_best_only=True,
    ),
    CSVLogger(
        append=False,
        filename='logs/s2s_extended_{}.csv'.format(date_label)
    )
]

latent_dim = 256  # dimensionality of the LSTM hidden/cell states
batch_size = 700
epochs = 30

train_data_generator = KerasBatchGenerator(batch_size,
                                           source_train,
                                           source_max_sequence_tokenizer,
                                           source_kerasTokenizer,
                                           target_train,
                                           target_max_sequence_tokenizer,
                                           target_kerasTokenizer
                                           )

validation_data_generator = KerasBatchGenerator(batch_size,
                                                source_val,
                                                source_max_sequence_tokenizer,
                                                source_kerasTokenizer,
                                                target_val,
                                                target_max_sequence_tokenizer,
                                                target_kerasTokenizer
                                                )

print("Lets train some stuff!")
# Define an input sequence and process it.
encoder_input = Input(shape=(source_max_sequence_tokenizer, ))
x = source_embedding_layer(encoder_input)
x, state_h, state_c = LSTM(latent_dim, return_state=True)(x)
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_input = Input(shape=(target_max_sequence_tokenizer, ))
x = target_embedding_layer(decoder_input)
decoder_LSTM = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_out, _ , _ = decoder_LSTM(x, initial_state=encoder_states)
decoder_dense = Dense(len(target_vocab)+1, activation='softmax')
decoder_out = decoder_dense(decoder_out)
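
# The softmax has len(target_vocab)+1 units because the Keras tokenizer reserves
# index 0 for padding, so word indices start at 1.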


# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_input, decoder_input], decoder_out)

# Compile & run training
model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

model.fit_generator(
    generator=train_data_generator.generate_data(),
    steps_per_epoch=int(np.ceil(len(source_train) / batch_size)),
    epochs=epochs,
    callbacks=callbacks_list,
    validation_data=validation_data_generator.generate_data(),
    validation_steps=int(np.ceil(len(source_val) / batch_size)),
    # use_multiprocessing=True,
    # workers=10
)

# INFERENCE MODELS
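# For decoding, the trained network is split in two: the encoder maps the source
# sequence to its final LSTM states, and the decoder is rewired to take its
# previous states as explicit inputs so it can be stepped one token at a time.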
# Encoder inference model
encoder_model_inf = Model(encoder_input, encoder_states)

# Decoder inference model (re-uses the trained decoder layers)
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_input_states = [decoder_state_input_h, decoder_state_input_c]

decoder_out, decoder_h, decoder_c = decoder_LSTM(x, initial_state=decoder_input_states)
decoder_states = [decoder_h , decoder_c]
decoder_out = decoder_dense(decoder_out)

decoder_model_inf = Model(inputs=[decoder_input] + decoder_input_states,
                          outputs=[decoder_out] + decoder_states)

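# Greedy decoding: seed the decoder with 'sos' and the encoder states, then feed
# back the argmax token and the returned LSTM states until 'eos' or the length cap.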
def decode_seq(inp_seq):
    states_val = encoder_model_inf.predict(inp_seq)

    target_seq = np.zeros((1, target_max_sequence_tokenizer))
    target_seq[0, 0] = target_vocab['sos']

    translated_sent = []
    translated_index = []
    stop_condition = False

    while not stop_condition:
        decoder_out, decoder_h, decoder_c = decoder_model_inf.predict(x=[target_seq] + states_val)
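        # The full zero-padded target_seq is fed at every step; sampling from the
        # last timestep assumes the target embedding masks padding, so the final
        # output corresponds to the single real token at position 0.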
        max_val_index = np.argmax(decoder_out[0, -1, :])
        try:
            sampled_word = target_index_to_word_dict[max_val_index]
        except KeyError:
            # Index not in the vocabulary (e.g. the padding index 0): emit 'eos'.
            sampled_word = 'eos'

        translated_sent.append(sampled_word)
        translated_index.append(max_val_index)

        if (sampled_word == 'eos') or (len(translated_sent) > target_max_sequence_tokenizer):
            stop_condition = True

        target_seq = np.zeros((1, target_max_sequence_tokenizer))
        target_seq[0, 0] = max_val_index
        states_val = [decoder_h, decoder_c]

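    # Drop the final appended token, which is 'eos' in the usual case; if the
    # length cap was hit instead, this drops the last generated word.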
    return translated_sent[:-1], translated_index[:-1]


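# Input length the ICD-10 classifier expects, read from its first (Input) layer.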
max_icd10_length = icd10_model.layers[0].input_shape[1]

run_pipeline_prediction(source_val, decode_seq, icd10_model, encoded_Y, labels_val,
                        source_kerasTokenizer, source_max_sequence_tokenizer,
                        icd10Tokenizer, max_icd10_length,
                        target_index_to_word_dict, target_val,
                        'logs/seq2seq')
