# from comet_ml import Experiment
# experiment = Experiment(api_key="<COMET_API_KEY>")  # key redacted; never commit real API keys

# from loader import *
from util import *
import numpy as np
import random
import tensorflow as tf
import traceback
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from collections import Counter
import os
import pickle

from keras import backend as K
from keras.preprocessing.sequence import pad_sequences
from keras.preprocessing.text import Tokenizer
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger, TensorBoard
from keras.layers import Embedding, Input, LSTM, Dense, Bidirectional
from keras.models import Model
from keras.utils import np_utils

from _layers import AttentionWithContext, Attention

# REPRODUCIBLE: fix all random seeds
np.random.seed(42)
random.seed(12345)
os.environ['PYTHONHASHSEED'] = '0'
tf.set_random_seed(1234)

config = tf.ConfigProto()
# config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)

# Don't pre-allocate memory; allocate as-needed
config.gpu_options.allow_growth = True
config.gpu_options.allocator_type = 'BFC'

sess = tf.Session(graph=tf.get_default_graph(), config=config)
K.set_session(sess)

tbCallBack = TensorBoard(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True)

callbacks_list=[
    EarlyStopping(
        monitor='val_loss',
        patience=2,
        min_delta=0.001
    ),
    ModelCheckpoint(
        filepath='models/icd10Classification_attention_duplicated.h5',
        monitor='val_loss',
        save_best_only=True,
    ),
    CSVLogger(
        append=True,
        filename='logs/icd10Classification_attention_duplicated_{}.csv'.format(date_label),
    ),
    tbCallBack
]

latent_dim = 512
epochs = 30
batch_size = 1000

tokenizer = TokenizePreprocessor()
kerasTokenizer = Tokenizer()
dataLoader = prepareData()
corpus, labels = dataLoader.prepareDictionaries(unbalanced=True, oversampled=True)
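# corpus: raw text documents; labels: the ICD-10 code assigned to each document
# (naming assumed from how these variables are used below)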

num_labels = len(set(labels))
print("Extracted {} data points with {} unique labels".format(len(corpus), num_labels))

# preparing the texts for input to the RNN
tokens = tokenizer.transform([x for x in corpus])
tmp = [item for item in list(set(flatten(tokens))) if item.strip()]
vocabulary = {item.strip(): i + 1 for i, item in enumerate(tmp)}
index_to_word_dict = {i + 1: item.strip() for i, item in enumerate(tmp)}
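# reuse the vocabulary built above as the Keras tokenizer's word index, so that
# texts_to_sequences produces indices consistent with the embedding matrix below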
kerasTokenizer.word_index = vocabulary
# saving
with open('models/icd10_tokenizer_duplicated.p', 'wb') as handle:
    pickle.dump(kerasTokenizer, handle)

source_word_sequence = kerasTokenizer.texts_to_sequences(corpus)
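# pad every document to the length of the longest one so the model receives fixed-length inputs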
max_sequence = max([len(x) for x in source_word_sequence])
word_sequence = pad_sequences(source_word_sequence, maxlen=max_sequence, padding='post')

embedding_weights = embedding_matrix(vocabulary)
embedding_layer = Embedding(
    embedding_weights.shape[0],
    embedding_weights.shape[1],
    weights=[embedding_weights],
    input_length=max_sequence,
    trainable=True,
    mask_zero=True)
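
# sanity check (assumption: embedding_matrix() reserves row 0 for the padding/mask index,
# which mask_zero=True above relies on); uncomment to verify during a run
# assert embedding_weights.shape[0] == len(vocabulary) + 1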

# preparing the labels as one-hot encoded vectors
encoder = LabelEncoder()
encoder.fit(labels)
with open('models/icd10_mappings_duplicated.p', 'wb') as handle:
    pickle.dump(encoder, handle)

encoded_Y = encoder.transform(labels)

# convert integers to dummy variables (i.e. one hot encoded)
labels_one_hot = np_utils.to_categorical(encoded_Y)
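# labels_one_hot has shape (num_samples, num_labels)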

X_train, X_test, Y_train, Y_test = train_test_split(word_sequence, labels_one_hot, test_size=0.02, random_state=777, stratify=labels)
print("Prepared data: ", len(X_train), len(Y_train), len(X_test), len(Y_test))

try:
    # LAYERS
    print("Creating Model...")
    inputs = Input(shape=(max_sequence,))
    embedding = embedding_layer(inputs)
    # bidirectional LSTM over the embedded tokens, pooled into a single vector by attention
    bi_lstm = Bidirectional(LSTM(latent_dim, return_sequences=True))
    lstm_out = bi_lstm(embedding)
    attention = Attention()(lstm_out)
    # softmax over the ICD-10 label set
    output_dense = Dense(num_labels, activation='softmax')
    predictions = output_dense(attention)

    # MODEL
    model = Model(inputs=inputs, outputs=predictions)
    adam = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
    model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
    model.summary()
    print("Traning Model...")
    model.fit(word_sequence, labels_one_hot,
          batch_size=batch_size,
          epochs=epochs,
          callbacks=callbacks_list,
          validation_data=[X_test, Y_test]
          # validation_split=0.25
    )

except Exception as e:
    print(e)
    traceback.print_exc()
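
# --- Usage sketch (illustrative only, not part of the original training run) ---
# A minimal sketch of how the artifacts saved above (Keras tokenizer, LabelEncoder,
# best checkpointed model) could be reloaded for inference. It assumes the custom
# Attention layer restores via custom_objects (i.e. it implements get_config) and that
# max_sequence matches the padding length used during training. The input text is
# purely hypothetical.
RUN_INFERENCE_DEMO = False
if RUN_INFERENCE_DEMO:
    from keras.models import load_model

    with open('models/icd10_tokenizer_duplicated.p', 'rb') as handle:
        loaded_tokenizer = pickle.load(handle)
    with open('models/icd10_mappings_duplicated.p', 'rb') as handle:
        loaded_encoder = pickle.load(handle)

    loaded_model = load_model('models/icd10Classification_attention_duplicated.h5',
                              custom_objects={'Attention': Attention})

    new_texts = ["example input text"]  # hypothetical document
    new_seq = loaded_tokenizer.texts_to_sequences(new_texts)
    new_seq = pad_sequences(new_seq, maxlen=max_sequence, padding='post')
    probabilities = loaded_model.predict(new_seq)
    predicted_codes = loaded_encoder.inverse_transform(probabilities.argmax(axis=-1))
    print("Predicted ICD-10 codes:", predicted_codes)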