# from comet_ml import Experiment
# experiment = Experiment(api_key="hSd9vTj0EfMu72569YnVEvtvj")
from loader import *
from _layers import AttentionWithContext, Attention, AttentionDecoder
from keras.models import Model, load_model as keras_load_model
from keras.layers import Input, LSTM, Dense, Embedding, GRU, Activation, dot, concatenate, Bidirectional, TimeDistributed
from keras.utils import multi_gpu_model
from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
import tensorflow as tf
import tqdm
import pickle
from sklearn.metrics import classification_report

# REPRODUCIBLE
np.random.seed(42)
import random
random.seed(12345)
import os
os.environ['PYTHONHASHSEED'] = '0'
# config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config = tf.ConfigProto()
from keras import backend as K
tf.set_random_seed(1234)
# REPRODUCIBLE

###################################
# TensorFlow wizardry
# Don't pre-allocate memory; allocate as needed.
config.gpu_options.allow_growth = True
config.gpu_options.allocator_type = 'BFC'
sess = tf.Session(graph=tf.get_default_graph(), config=config)
K.set_session(sess)

# LOAD ICD-10 CLASSIFICATION MODEL
try:
    icd10_model = keras_load_model('models/icd10Classification_attention_extended.h5',
                                   custom_objects={'Attention': Attention})
except OSError:
    # If the saved classifier is missing, train it first, then load it.
    from classificationICD10 import *
    icd10_model = keras_load_model('models/icd10Classification_attention_extended.h5')

with open('models/icd10_tokenizer_extended.p', 'rb') as handle:
    icd10Tokenizer = pickle.load(handle)

with open('models/icd10_mappings_extended.p', 'rb') as handle:
    encoded_Y = pickle.load(handle)
# LOAD ICD-10 CLASSIFICATION MODEL

callbacks_list = [
    EarlyStopping(
        monitor='val_loss',
        patience=2,
        # min_delta=0.001
    ),
    ModelCheckpoint(
        filepath='models/s2s_att_extended.h5',
        monitor='val_loss',
        save_best_only=True,
    ),
    CSVLogger(
        append=False,
        filename='logs/s2s_att_extended_{}.csv'.format(date_label)
    )
]

latent_dim = 256
batch_size = 400
epochs = 1

train_data_generator = KerasBatchGenerator(batch_size,
                                           source_train, source_max_sequence_tokenizer, source_kerasTokenizer,
                                           target_train, target_max_sequence_tokenizer, target_kerasTokenizer)
validation_data_generator = KerasBatchGenerator(batch_size,
                                                source_val, source_max_sequence_tokenizer, source_kerasTokenizer,
                                                target_val, target_max_sequence_tokenizer, target_kerasTokenizer)

print("Let's train some stuff!")

# Define an input sequence and process it.
encoder_input = Input(shape=(source_max_sequence_tokenizer,), name='encoder_input')
x_encoder = source_embedding_layer(encoder_input)
encoder_out, state_h, state_c = LSTM(latent_dim, return_sequences=True, unroll=True,
                                     return_state=True, name='encoder_lstm')(x_encoder)
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_input = Input(shape=(target_max_sequence_tokenizer,), name='decoder_input')
x_decode = target_embedding_layer(decoder_input)
decoder_LSTM = LSTM(latent_dim, return_sequences=True, return_state=True, unroll=True, name='decoder_lstm')
decoder, state_h_decode, state_c_decode = decoder_LSTM(x_decode, initial_state=encoder_states)
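# The block below appears to implement Luong-style global attention with the
# 'dot' score, built from plain Keras layers. A rough sketch of what the
# dot/softmax/dot sequence computes (names refer to tensors in this script):
#
#     score(s, t)  = encoder_out_s . decoder_t          # first dot()
#     weights      = softmax(score)                     # Activation('softmax')
#     context_t    = sum_s weights(s, t) * encoder_out_s    # second dot()
#
# The concatenation [context_t ; decoder_t] is then fed through the final
# softmax Dense layer to predict the next target token.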
# Equation (7) with 'dot' score from Section 3.1 in the paper.
# Note that we reuse the Softmax activation layer instead of writing the tensor calculation by hand.
attention = dot([encoder_out, decoder], name='attention_dot', axes=[2, 2])
attention = Activation('softmax', name='attention_activation')(attention)
context = dot([attention, encoder_out], name='context_dot', axes=[1, 1])
decoder_combined_context = concatenate([context, decoder])
# print(decoder_combined_context)

decoder_dense = Dense(len(target_vocab) + 1, activation='softmax', name='dense_output')
decoder_out = decoder_dense(decoder_combined_context)  # equation (6) of the paper

# MODEL
model = Model([encoder_input, decoder_input], decoder_out)
model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

# model.fit([source_train, target_train],
#           target_train_onehot,
#           batch_size=batch_size,
#           callbacks=callbacks_list,
#           epochs=epochs,
#           validation_split=0.1
#           # validation_data=([source_val, target_val], target_val_onehot)
#           )
model.fit_generator(
    generator=train_data_generator.generate_data(),
    steps_per_epoch=int(len(source_train) / batch_size) + 1,
    epochs=epochs,
    callbacks=callbacks_list,
    validation_data=validation_data_generator.generate_data(),
    validation_steps=int(len(source_val) / batch_size) + 1,
    # use_multiprocessing=True,
    # workers=10
)

# INFERENCE MODELS
# Encoder inference model: expose the full encoder output sequence as well as
# the final states, since the attention layers need to attend over it.
encoder_model_inf = Model(encoder_input, [encoder_out] + encoder_states)
encoder_model_inf.summary()

# Decoder inference model. The encoder outputs are passed in as an explicit
# input so that the decoder graph is self-contained at inference time.
decoder_state_input_h = Input(shape=(latent_dim,), name='inference_decoder_input_h')
decoder_state_input_c = Input(shape=(latent_dim,), name='inference_decoder_input_c')
encoder_out_input = Input(shape=(source_max_sequence_tokenizer, latent_dim), name='inference_encoder_outputs')
decoder_input_states = [decoder_state_input_h, decoder_state_input_c]

decoder, decoder_h, decoder_c = decoder_LSTM(x_decode, initial_state=decoder_input_states)
decoder_states = [decoder_h, decoder_c]

attention = dot([encoder_out_input, decoder], axes=[2, 2])
attention = Activation('softmax')(attention)
context = dot([attention, encoder_out_input], axes=[1, 1])
decoder_combined_context = concatenate([context, decoder])
decoder_out = decoder_dense(decoder_combined_context)

decoder_model_inf = Model(inputs=[decoder_input, encoder_out_input] + decoder_input_states,
                          outputs=[decoder_out] + decoder_states)
decoder_model_inf.summary()


def decode_seq(inp_seq):
    # Run the encoder once; keep its output sequence for the attention layers
    # and its final states to initialise the decoder.
    encoder_out_val, state_h_val, state_c_val = encoder_model_inf.predict(inp_seq)
    states_val = [state_h_val, state_c_val]

    target_seq = np.zeros((1, target_max_sequence_tokenizer))
    target_seq[0, 0] = target_vocab['sos']

    translated_sent = []
    translated_index = []
    stop_condition = False

    while not stop_condition:
        # Greedy decoding: feed the previously predicted token and the current
        # states, then pick the most probable next token.
        decoder_out, decoder_h, decoder_c = decoder_model_inf.predict(
            x=[target_seq, encoder_out_val] + states_val)
        max_val_index = np.argmax(decoder_out[0, -1, :])
        try:
            sampled_word = target_index_to_word_dict[max_val_index]
        except KeyError:
            # Index outside the vocabulary (e.g. padding): treat as end of sequence.
            sampled_word = 'eos'
        translated_sent.append(sampled_word)
        translated_index.append(max_val_index)

        if sampled_word == 'eos' or len(translated_sent) > target_max_sequence_tokenizer:
            stop_condition = True

        # Feed the predicted token back in as the next decoder input.
        target_seq = np.zeros((1, target_max_sequence_tokenizer))
        target_seq[0, 0] = max_val_index
        states_val = [decoder_h, decoder_c]

    # Drop the final token (normally the 'eos' marker).
    return translated_sent[:-1], translated_index[:-1]


y_true = []
y_pred = []

source_val = source_kerasTokenizer.texts_to_sequences(source_val)
source_val = pad_sequences(source_val, maxlen=source_max_sequence_tokenizer, padding='post')
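# Evaluation loop (see below): each validation sequence is greedily decoded
# into target tokens with decode_seq(), the decoded text is classified by the
# separately trained ICD-10 model, and the predicted ICD-10 code is compared
# against the gold label from labels_val.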
# for seq_index in range(10):
for seq_index in tqdm.tqdm(range(len(source_val))):
    inp_seq = source_val[seq_index:seq_index + 1]
    translated_sent, translated_index = decode_seq(inp_seq)

    # PREDICT ICD-10
    source_word_sequence = icd10Tokenizer.texts_to_sequences([" ".join(translated_sent)])
    word_sequence = pad_sequences(source_word_sequence,
                                  maxlen=icd10_model.layers[0].input_shape[1], padding='post')
    icd10_code_index = icd10_model.predict(word_sequence)
    # print(icd10_code_index, type(icd10_code_index))
    max_val_index = np.argmax(icd10_code_index, axis=1)[0]
    # print(max_val_index)
    icd10_label = encoded_Y.inverse_transform(max_val_index)

    # print('-')
    # target_index = np.trim_zeros(target_val[seq_index], 'b')[1:-1]
    # print('Target indexes:', target_index)
    # print('Decoded indexes:', translated_index)
    #
    # print('Target sentence:', " ".join([target_index_to_word_dict[x] for x in target_index]))
    # print('Decoded sentence:', " ".join(translated_sent))
    #
    # print('Target ICD-10:', labels_val[seq_index])
    # print('Predict ICD-10:', icd10_label)

    y_true.append(labels_val[seq_index])
    y_pred.append(icd10_label)

report = classification_report(y_true, y_pred)
report_df = report_to_df(report)
report_df.to_csv('logs/classification_report_extended.csv')
print(report_df)
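# Minimal usage sketch (not part of the evaluation above): decode a single raw
# source line end-to-end. The guard flag and the example string are
# hypothetical placeholders; real inputs come from the loader's data.
RUN_DEMO = False
if RUN_DEMO:
    demo_text = "example certificate line"  # hypothetical placeholder input
    demo_seq = source_kerasTokenizer.texts_to_sequences([demo_text])
    demo_seq = pad_sequences(demo_seq, maxlen=source_max_sequence_tokenizer, padding='post')
    demo_sent, _ = decode_seq(demo_seq)
    demo_icd10_seq = pad_sequences(
        icd10Tokenizer.texts_to_sequences([" ".join(demo_sent)]),
        maxlen=icd10_model.layers[0].input_shape[1], padding='post')
    print('Decoded text:', " ".join(demo_sent))
    print('Predicted ICD-10:',
          encoded_Y.inverse_transform(np.argmax(icd10_model.predict(demo_icd10_seq), axis=1)[0]))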