import os
from typing import List

import pandas as pd

from pandas import DataFrame
from tqdm import tqdm

from app_context import AppContext
from util import LoggingMixin
from preprocessing import DataPreparationUtil as pdu


class Clef18Task1Data(LoggingMixin):
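    """Reader and preprocessing helpers for the CLEF eHealth 2018 Task 1
    death certificate corpora (French, Hungarian, Italian) and their
    ICD-10 dictionaries."""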

    def __init__(self):
        LoggingMixin.__init__(self, self.__class__.__name__)

    def read_train_certificates_by_id(self, lang_id: str) -> DataFrame:
        if lang_id == "it":
            return self.read_it_train_certificates()
        elif lang_id == "hu":
            return self.read_hu_train_certificates()
        elif lang_id == "fr":
            return self.read_fr_train_certificates()
        elif lang_id == "all-con":
            return self.read_all_con_certificates()
        else:
            raise ValueError("Unsupported language: " + lang_id)

    def read_test_certificates_by_lang(self, lang: str) -> DataFrame:
        if lang == "it":
            return self.read_it_test_certificates()
        elif lang == "hu":
            return self.read_hu_test_certificates()
        elif lang == "fr":
            return self.read_fr_test_certificates()
        else:
            raise ValueError("Unsupported language: " + lang)

    def read_dictionary_by_id(self, lang_id: str) -> DataFrame:
        if lang_id == "it":
            return self.read_it_dictionary()
        elif lang_id == "hu":
            return self.read_hu_dictionary()
        elif lang_id == "fr":
            return self.read_fr_dictionary()
        elif lang_id == "all-con":
            return self.read_all_con_dictionary()
        else:
            raise ValueError("Unsupported language: " + lang_id)
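    # Usage sketch (the language ids mirror the read_* methods above):
    #
    #   data = Clef18Task1Data()
    #   certificates = data.read_train_certificates_by_id("all-con")
    #   dictionary = data.read_dictionary_by_id("all-con")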

    # --------------------------------------------------------------------------------

    def read_it_train_certificates(self) -> DataFrame:
        base_folder = "data/train/IT/training/raw/corpus/"

        calculees_file = os.path.join(base_folder, "CausesCalculees_IT_1.csv")
        brutes_file = os.path.join(base_folder, "CausesBrutes_IT_1.csv")

        return self._read_certificates([calculees_file], [brutes_file], "it")

    def read_it_dictionary(self) -> DataFrame:
        base_folder = "data/train/IT/training/raw/dictionaries"
        dictionary_file = os.path.join(base_folder, "dictionary_IT.csv")
        return self._read_icd10_dictionary([dictionary_file], "iso-8859-1", "it")

    def read_it_test_certificates(self) -> DataFrame:
        brutes_file = "data/test/IT/test/raw/corpus/CausesBrutes_IT_2.csv"
        return self._read_test_data(brutes_file)

    # --------------------------------------------------------------------------------

    def read_hu_train_certificates(self) -> DataFrame:
        base_folder = "data/train/HU/training/raw/corpus/"

        calculees_file = os.path.join(base_folder, "CausesCalculees_HU_1.csv")
        brutes_file = os.path.join(base_folder, "CausesBrutes_HU_1.csv")

        return self._read_certificates([calculees_file], [brutes_file], "hu")

    def read_hu_dictionary(self) -> DataFrame:
        base_folder = "data/train/HU/training/raw/dictionaries"
        dictionary_file = os.path.join(base_folder, "Hungarian_dictionary_UTF8.csv")
        return self._read_icd10_dictionary([dictionary_file], "utf-8", "hu")

    def read_hu_test_certificates(self) -> DataFrame:
        brutes_file = "data/test/HU/test/raw/corpus/CausesBrutes_HU_2.csv"
        return self._read_test_data(brutes_file)

    # --------------------------------------------------------------------------------

    def read_fr_train_certificates(self) -> DataFrame:
        # FIXME: Load other training files from 2011-2015!
        base_folder = "data/train/FR/training/raw/corpus/"

        calculees_files = [
            os.path.join(base_folder, "CausesCalculees_FR_2006-2012.csv"),
            os.path.join(base_folder, "CausesCalculees_FR_2013.csv"),
            os.path.join(base_folder, "CausesCalculees_FR_2014.csv")
        ]

        brutes_files = [
            os.path.join(base_folder, "CausesBrutes_FR_2006-2012.csv"),
            os.path.join(base_folder, "CausesBrutes_FR_2013.csv"),
            os.path.join(base_folder, "CausesBrutes_FR_2014.csv")
        ]

        return self._read_certificates(calculees_files, brutes_files, "fr")

    def read_fr_dictionary(self) -> DataFrame:
        base_folder = "data/train/FR/training/aligned/dictionaries"
        dictionary_files = [
            os.path.join(base_folder, "Dictionnaire2006-2010.csv"),
            os.path.join(base_folder, "Dictionnaire2014.csv"),
            os.path.join(base_folder, "Dictionnaire2015.csv")

        ]
        return self._read_icd10_dictionary(dictionary_files, "utf-8", "fr")

    def read_fr_test_certificates(self) -> DataFrame:
        brutes_file = "data/test/FR/test/raw/corpus/CausesBrutes_FR_2.csv"
        return self._read_test_data(brutes_file)

    # --------------------------------------------------------------------------------

    def read_all_con_dictionary(self) -> DataFrame:
        return pd.concat([self.read_fr_dictionary(),
                          self.read_it_dictionary(),
                          self.read_hu_dictionary()])

    def read_all_con_certificates(self) -> DataFrame:
        all_certificates = pd.concat([self.read_fr_train_certificates(), self.read_it_train_certificates(), self.read_hu_train_certificates()])
        self.logger.info("Found %s death certificate lines", len(all_certificates))
        return all_certificates

    # --------------------------------------------------------------------------------

    def _read_certificates(self, calculees_files: List[str], brutes_files: List[str], language: str) -> DataFrame:
        calculees_data = []
        for calculees_file in calculees_files:
            self.logger.info("Reading calculees file from %s", calculees_file)
            calculees_data.append(pd.read_csv(calculees_file, sep=";", encoding="iso-8859-1", index_col=["YearCoded", "DocID", "LineID"],
                                              skipinitialspace=True))
            self.logger.info("Found %s death certificate entries", len(calculees_data[-1]))

        calculees_data = pd.concat(calculees_data)
        self.logger.info("Found %s death certificate lines in total", len(calculees_data))

        brutes_data = []
        for brutes_file in brutes_files:
            self.logger.info("Reading brutes file from %s", brutes_file)
            brutes_data.append(pd.read_csv(brutes_file, sep=";", encoding="iso-8859-1", index_col=["YearCoded", "DocID", "LineID"],
                                           skipinitialspace=True))
            self.logger.info("Found %s death certificate entries", len(brutes_data[-1]))

        brutes_data = pd.concat(brutes_data)

        # Align the raw certificate lines (brutes) with the expert-assigned
        # ICD-10 codes (calculees) on (YearCoded, DocID, LineID)
        joined_data = brutes_data.join(calculees_data, lsuffix="_b", rsuffix="_c")
        joined_data["ICD10"] = joined_data["ICD10"].astype(str)

        num_unchecked_data = len(joined_data)
        joined_data = joined_data.query("ICD10 != 'nan'")
        self.logger.info("Removed %s lines with ICD10 'nan'", num_unchecked_data - len(joined_data))

        joined_data = pdu.clean_text("RawText").fit_transform(joined_data)
        joined_data["Lang"] = language

        return joined_data[["RawText", "ICD10", "Lang"]]
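    # Join semantics (toy values, not real corpus data): one brutes row
    # ("2014", 7, 1) -> "renal failure, cardiac arrest" joined against two
    # calculees rows ("2014", 7, 1) -> N19 and ("2014", 7, 1) -> I46 yields
    # two result rows for the same certificate line, one per assigned code.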

    def _read_icd10_dictionary(self, dictionary_files: List[str], encoding: str, lang: str) -> DataFrame:
        dictionary_data = []
        for dict_file in dictionary_files:
            self.logger.info("Reading ICD10 dictionary from %s", dict_file)
            # on_bad_lines="skip" replaces the error_bad_lines=False flag removed in pandas 2.0
            dictionary_data.append(pd.read_csv(dict_file, sep=";", encoding=encoding,
                                               skipinitialspace=True, on_bad_lines="skip"))
        dictionary_data = pd.concat(dictionary_data)

        num_dictionary_entries = len(dictionary_data)
        self.logger.info("Found %s dictionary entries", num_dictionary_entries)

        # Not every dictionary provides a standardized diagnosis text
        if "Standardized" not in dictionary_data.columns:
            dictionary_data["Standardized"] = None

        dictionary_data = dictionary_data[["Icd1", "Standardized", "DiagnosisText"]]
        dictionary_data = dictionary_data.drop_duplicates()
        self.logger.info("Removed %s duplicates from dictionary", num_dictionary_entries - len(dictionary_data))

        dictionary_data.columns = ["ICD10", "Standardized", "DiagnosisText"]
        dictionary_data["ICD10"] = dictionary_data["ICD10"].astype(str)

        dictionary_data["Lang"] = lang

        return dictionary_data

    def _read_test_data(self, file: str) -> DataFrame:
        self.logger.info("Reading test certificates from %s", file)
        # on_bad_lines="skip" replaces the error_bad_lines=False flag removed in pandas 2.0
        test_data = pd.read_csv(file, sep=";", encoding="iso-8859-1", index_col=["YearCoded", "DocID", "LineID"],
                                skipinitialspace=True, on_bad_lines="skip")
        self.logger.info("Found %s test certificate lines.", len(test_data))

        return test_data

    # --------------------------------------------------------------------------------

    def filter_single_code_lines(self, certificate_df: DataFrame) -> DataFrame:
        # Lines whose (YearCoded, DocID, LineID) index occurs more than once
        # carry multiple ICD-10 codes and are dropped entirely
        multi_code_mask = certificate_df.index.duplicated(keep=False)
        self.logger.info("Start filtering %s lines with multiple codes", certificate_df.index[multi_code_mask].nunique())

        original_size = len(certificate_df)
        certificate_df = certificate_df[~multi_code_mask]
        self.logger.info("Filtered %s out of %s entries due to single code constraint", original_size - len(certificate_df), original_size)

        return certificate_df

    def add_masked_icd10_column(self, certificate_df: DataFrame, min_support: int, mask_code: str = "RARE-ICD10") -> DataFrame:
        code_frequency_distribution = certificate_df["ICD10"].value_counts()
        icd_masker = pdu.mask_icd10("ICD10", "ICD10_masked", code_frequency_distribution, min_support, mask_code)

        certificate_df = icd_masker.fit_transform(certificate_df)

        # value_counts() lacks the mask code entirely if no entry was masked
        num_masked_entries = certificate_df["ICD10_masked"].value_counts().get(mask_code, 0)
        self.logger.info("Added masked icd10 code column. Masked %s entries whose code has support below %s", num_masked_entries, min_support)

        return certificate_df
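    # Masking sketch (assuming pdu.mask_icd10 replaces codes whose frequency is
    # below min_support): with min_support=2 and counts {"J96": 5, "X59": 1},
    # the single "X59" row gets "RARE-ICD10" in ICD10_masked while the "J96"
    # rows keep their code.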

    def down_sample_by_icd10_frequency(self, certificate_df: DataFrame, max_freq: int) -> DataFrame:
        self.logger.info("Down-sampling data set with %s entries", len(certificate_df))
        icd10_codes = certificate_df["ICD10"].unique()

        data_sets = []
        for code in tqdm(icd10_codes, desc="down-sample", total=len(icd10_codes)):
            entries_by_code = certificate_df.query("ICD10 == '%s'" % code)
            if len(entries_by_code) > max_freq:
                # Keep at least one entry per distinct text ...
                unique_texts = entries_by_code["RawText"].unique()

                unique_entries = []
                for text in unique_texts:
                    unique_entries.append(entries_by_code.query("RawText == \"%s\"" % text)[0:1])

                # ... then fill up with a random sample (at least 10 entries)
                unique_entries.append(entries_by_code.sample(max(max_freq - len(unique_texts), 10)))
                entries_by_code = pd.concat(unique_entries)

            data_sets.append(entries_by_code)

        sampled_df = pd.concat(data_sets)
        sampled_df = sampled_df.sample(frac=1)  # Reshuffle!
        self.logger.info("Down-sampled data set contains %s entries", len(sampled_df))
        return sampled_df
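    # Down-sampling sketch: with max_freq=800, a code occurring 5000 times over
    # 300 distinct texts keeps one row per text plus max(800 - 300, 10) = 500
    # randomly sampled rows, i.e. roughly 800 rows instead of 5000.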

    def extend_certificates_by_dictionaries(self, certificate_df: DataFrame, dictionary_df: DataFrame) -> DataFrame:
        self.logger.info("Start extending certificate data set with dictionary entries (original size: %s)", len(certificate_df))

        dict_icd10_codes = dictionary_df["ICD10"].unique()
        cert_icd10_codes = certificate_df["ICD10"].unique()

        unseen_icd10_codes = [dict_icd10 for dict_icd10 in dict_icd10_codes if dict_icd10 not in cert_icd10_codes]

        unseen_mask = dictionary_df["ICD10"].isin(unseen_icd10_codes)
        lines_with_unseen_codes = dictionary_df.loc[unseen_mask]

        # Add one synthetic certificate line per dictionary entry whose code
        # never occurs in the training certificates
        new_rows = (lines_with_unseen_codes[["DiagnosisText", "ICD10", "Lang"]]
                    .rename(columns={"DiagnosisText": "RawText"})
                    .reset_index(drop=True))

        certificate_df = pd.concat([certificate_df, new_rows])

        self.logger.info("Extended cert data set with %s entries from dictionary (%s in total)", len(new_rows), len(certificate_df))
        return certificate_df

    def remove_duplicates_from_certificates(self, certificate_df: DataFrame) -> DataFrame:
        self.logger.info("Start removing duplicates from certificate data set (size: %s)", len(certificate_df))

        cleaned_cert_df = certificate_df.drop_duplicates(subset=["RawText", "ICD10"])
        self.logger.info("Removed %s duplicates from certificate data set (new size: %s)",
                         len(certificate_df) - len(cleaned_cert_df), len(cleaned_cert_df))

        return cleaned_cert_df

    def split_multi_code_lines(self, certificate_df: DataFrame) -> DataFrame:
        self.logger.info("Start splitting multi code lines in certificate data set (size: %s)", len(certificate_df))
        mask_multicode = certificate_df.index.duplicated(keep=False)
        singlecode_rows = certificate_df[~mask_multicode]
        multicode_rows = certificate_df[mask_multicode]

        # Rows sharing an index belong to one certificate line with several
        # codes; assign the i-th comma-separated text segment to the i-th code
        last_index = None
        text_pos = 0
        new_texts = []
        for index, row in tqdm(multicode_rows.iterrows(), desc="split-multi", total=len(multicode_rows)):
            split = str(row["RawText"]).split(",")
            text_pos = 0 if last_index != index else text_pos + 1

            # Keep the full text if there are fewer segments than codes
            new_texts.append(split[text_pos] if text_pos < len(split) else str(row["RawText"]))
            last_index = index

        multicode_rows = multicode_rows.copy()
        multicode_rows["RawText"] = new_texts

        result = pd.concat([singlecode_rows, multicode_rows])
        self.logger.info("Finished multi code line splitting. Re-attached %s split rows (new size: %s)", len(multicode_rows), len(result))
        return result
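    # Splitting sketch (toy values): a line indexed ("2014", 7, 1) occurring
    # twice with codes N19 and I46 and RawText "renal failure, cardiac arrest"
    # becomes "renal failure" paired with N19 and " cardiac arrest" paired
    # with I46 (segments are not re-trimmed).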

    def duplicate_less_frequent(self, certificate_df: DataFrame, min_freq: int) -> DataFrame:
        self.logger.info("Start duplicating less frequent ICD10 code entries (size: %s)", len(certificate_df))

        code_counts = certificate_df["ICD10"].value_counts()
        less_frequent_codes = {code for code, freq in code_counts.items() if freq < min_freq}

        less_frequent_mask = certificate_df["ICD10"].isin(less_frequent_codes)
        less_frequent_rows = certificate_df[less_frequent_mask]

        # Sample (with replacement) as many duplicates per code as are missing
        # to reach min_freq occurrences
        duplicate_frames = []
        for code in tqdm(less_frequent_codes, desc="build-dup", total=len(less_frequent_codes)):
            num_duplicates = min_freq - code_counts[code]
            duplicate_frames.append(less_frequent_rows.query("ICD10 == '%s'" % code).sample(num_duplicates, replace=True))

        new_rows = pd.concat(duplicate_frames) if duplicate_frames else certificate_df.iloc[0:0]
        result = pd.concat([certificate_df, new_rows])
        self.logger.info("Added %s duplicates to data set to guarantee min frequency %s (new size: %s)",
                         len(new_rows), min_freq, len(result))

        return result
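    # Duplication sketch: with min_freq=4, a code seen only once gets three
    # additional copies sampled (with replacement) from its existing rows, so
    # every code afterwards occurs at least four times.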

    def language_tag_data(self, data_set: DataFrame, text_column: str, lang_column: str) -> DataFrame:
        self.logger.info("Start word tagging column %s with language from %s (size: %s)", text_column, lang_column, len(data_set))

        word_tagger = pdu.tag_words_with_language(text_column, text_column, lang_column)
        data_set = word_tagger.fit_transform(data_set)

        self.logger.info("Finished word tagging data")
        return data_set

    def filter_nan_texts(self, certificate_df: DataFrame) -> DataFrame:
        self.logger.info("Start filtering nan texts (size: %s)", len(certificate_df))
        nan_mask = certificate_df["RawText"].isin(["nan"])
        certificate_df = certificate_df[~nan_mask]
        self.logger.info("Finished filtering nan texts (size: %s)", len(certificate_df))
        return certificate_df

def check_multi_label_distribution(ds_name: str, certificate_df: DataFrame):
    print("Data set: ", ds_name)

    # Build a per-line identifier from the first two index levels (YearCoded, DocID)
    certificate_df["ID"] = ["{}##{}".format(i[0], i[1]) for i in certificate_df.index]
    value_counts = certificate_df["ID"].value_counts()
    print(value_counts)

    total_sum = value_counts.sum()
    sum_more_than_one = 0
    sum_more_than_two = 0
    for _, count in value_counts.items():
        if count > 1:
            sum_more_than_one = sum_more_than_one + count
        if count > 2:
            sum_more_than_two = sum_more_than_two + count

    print(sum_more_than_one / total_sum)
    print(sum_more_than_two / total_sum)

    print(certificate_df["ICD10"].value_counts())
    print("\n\n\n")


def check_word_dictionary_overlap(cert_df: DataFrame, dict_df: DataFrame, dict_file: str):
    # Collect the first token of each line in the translation dictionary file
    words = set()
    with open(dict_file, "r", encoding="utf8") as dict_reader:
        for line in dict_reader:
            words.add(line.strip().split(" ")[0])

    cert_words = set()
    for _, row in cert_df.iterrows():
        for word in str(row["RawText"]).lower().split(" "):
            cert_words.add(word)

    dict_words = set()
    for _, row in dict_df.iterrows():
        for word in str(row["DiagnosisText"]).lower().split(" "):
            dict_words.add(word)

    inter_cert_words = words.intersection(cert_words)
    print(len(inter_cert_words) / len(cert_words))

    inter_dict_words = words.intersection(dict_words)
    print(len(inter_dict_words) / len(dict_words))

def check_multi_code_lines(cert_df: DataFrame):
    duplicate_ids = cert_df.index.duplicated(keep=False)
    lines_with_multiple_codes = cert_df[duplicate_ids]

    num_multi_code_lines = len(lines_with_multiple_codes.index.unique())

    line_occurrences = cert_df.index.value_counts()

    correct_single = set()
    incorrect_single = set()

    correct_multi = set()
    incorrect_multi = set()

    # A line is "correct" if its number of commas matches its number of
    # assigned codes minus one (one text segment per code)
    for i, row in cert_df.iterrows():
        num_commas = str(row["RawText"]).count(",")
        num_icd10_codes = line_occurrences[i]
        comma_count_matches = num_commas == num_icd10_codes - 1

        if num_icd10_codes == 1:
            if comma_count_matches:
                correct_single.add(i)
            else:
                incorrect_single.add(i)
        else:
            if comma_count_matches:
                correct_multi.add(i)
            else:
                incorrect_multi.add(i)

    print("Number of multi code lines: ", num_multi_code_lines)

    print("Multi line codes:")
    print("\tCorrect: ", len(correct_multi))
    print("\tIncorrect: ", len(incorrect_multi))

    print("Single line codes:")
    print("\tCorrect: ", len(correct_single))
    print("\tIncorrect: ", len(incorrect_single))

def check_label_distribution(cert_df: DataFrame):
    distribution = cert_df["ICD10"].value_counts()
    print(distribution)


if __name__ == "__main__":
    # Just for debugging / development purposes
    AppContext.initialize_by_app_name("Clef18Task1-Data")

    clef_task_data = Clef18Task1Data()

    #all_cert = clef_task_data.read_all_con_certificates()
    #check_label_distribution(all_cert)
    #clef_task_data.down_sample_by_icd10_frequency(all_cert, 4000)

    it_certificates = clef_task_data.read_it_train_certificates()
    it_dictionary = clef_task_data.read_it_dictionary()

    it_certificates = clef_task_data.extend_certificates_by_dictionaries(it_certificates, it_dictionary)
    it_certificates = clef_task_data.remove_duplicates_from_certificates(it_certificates)
    it_certificates = clef_task_data.split_multi_code_lines(it_certificates)
    it_certificates = clef_task_data.duplicate_less_frequent(it_certificates, 4)
    print("IT: ", len(it_certificates))

    hu_certificates = clef_task_data.read_hu_train_certificates()
    hu_dictionary = clef_task_data.read_hu_dictionary()
    hu_certificates = clef_task_data.extend_certificates_by_dictionaries(hu_certificates, hu_dictionary)
    hu_certificates = clef_task_data.remove_duplicates_from_certificates(hu_certificates)
    hu_certificates = clef_task_data.split_multi_code_lines(hu_certificates)
    hu_certificates = clef_task_data.duplicate_less_frequent(hu_certificates, 4)
    print("HU: ", len(hu_certificates))

    fr_certificates = clef_task_data.read_fr_train_certificates()
    fr_dictionary = clef_task_data.read_fr_dictionary()
    fr_certificates = clef_task_data.extend_certificates_by_dictionaries(fr_certificates, fr_dictionary)
    fr_certificates = clef_task_data.remove_duplicates_from_certificates(fr_certificates)
    fr_certificates = clef_task_data.split_multi_code_lines(fr_certificates)
    fr_certificates = clef_task_data.duplicate_less_frequent(fr_certificates, 4)
    print("FR: ", len(fr_certificates))

    #check_label_distribution(it_certificates)
    #it_certificates = clef_task_data.down_sample_by_icd10_frequency(it_certificates, 800)
    #check_label_distribution(it_certificates)

    #hu_certificates = clef_task_data.read_hu_train_certificates()
    #hu_dictionary = clef_task_data.read_hu_dictionary()

    #check_label_distribution(hu_certificates)
    #hu_certificates = clef_task_data.down_sample_by_icd10_frequency(hu_certificates, 2750)
    #check_label_distribution(hu_certificates)
    #check_word_dictionary_overlap(hu_certificates, hu_dictionary, "data/dictionary/hu-en.txt")

    # fr_certificates = clef_task_data.read_fr_train_certificates()
    # fr_dictionary = clef_task_data.read_fr_dictionary()
    # check_label_distribution(fr_certificates)
    # fr_certificates = clef_task_data.down_sample_by_icd10_frequency(fr_certificates, 2750)
    # check_label_distribution(fr_certificates)

    # check_word_dictionary_overlap(fr_certificates, fr_dictionary, "data/dictionary/fr-en.txt")

    # certificates = pdu.extract_icd10_chapter("ICD10", "ICD10_chapter").fit_transform(certificates)
    # certificates = pdu.extract_icd10_subchapter("ICD10", "ICD10_subchapter").fit_transform(certificates)
    # print(certificates["ICD10_chapter"].value_counts())
    # print(certificates["ICD10_subchapter"].value_counts())