import math
import numpy as np
import pandas as pd
from numpy import vstack
from numpy import sqrt
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch import Tensor
from torch.nn import Linear
from torch.nn import Sigmoid
from torch.nn import Module
from torch.optim import SGD
from torch.nn import MSELoss
from torch.nn.init import xavier_uniform_
import torch
import os.path
import matplotlib.pyplot as plt
import matplotlib as mpl


class CSVDataset(Dataset):
    # load the dataset
    def __init__(self, path, numfeatures, normalize):
        # load the csv file as a dataframe
        df = pd.read_csv(path, header=None)
        if normalize:
            # scale all columns to [0, 1]; note this fits the scaler on the
            # full dataset before the train/test split, so test statistics
            # leak into the scaling
            df = pd.DataFrame(MinMaxScaler().fit_transform(df))
        # store the inputs and outputs
        self.X = df.values[:, :numfeatures].astype('float32')
        self.y = df.values[:, numfeatures:].astype('float32')

    # number of rows in the dataset
    def __len__(self):
        return len(self.X)

    # get a row at an index
    def __getitem__(self, idx):
        return [self.X[idx], self.y[idx]]

    # get indexes for train and test rows
    def get_splits(self, n_test=0.33):
        # determine sizes
        test_size = round(n_test * len(self.X))
        train_size = len(self.X) - test_size
        # calculate the split
        return random_split(self, [train_size, test_size])
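
# A minimal usage sketch for CSVDataset (commented out; the path and feature
# count below are placeholders, not the real dataset):
#
#   ds = CSVDataset('Datasets/example.csv', numfeatures=8, normalize=True)
#   X0, y0 = ds[0]                      # one (inputs, targets) pair
#   train, test = ds.get_splits(0.33)   # random 67/33 row split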


class MLP(Module):
    # define model elements
    def __init__(self, numfeatures):
        super().__init__()
        # input to first hidden layer
        self.hidden1 = Linear(numfeatures, 500)
        xavier_uniform_(self.hidden1.weight)
        self.act1 = Sigmoid()
        # second hidden layer
        self.hidden2 = Linear(500, 200)
        xavier_uniform_(self.hidden2.weight)
        self.act2 = Sigmoid()
        # third hidden layer
        self.hidden3 = Linear(200, 100)
        xavier_uniform_(self.hidden3.weight)
        self.act3 = Sigmoid()
        # fourth hidden layer and output
        self.hidden4 = Linear(100, 6)
        xavier_uniform_(self.hidden4.weight)

    # forward propagate input
    def forward(self, X):
        # input to first hidden layer
        X = self.hidden1(X)
        X = self.act1(X)
        # second hidden layer
        X = self.hidden2(X)
        X = self.act2(X)
        # third hidden layer
        X = self.hidden3(X)
        X = self.act3(X)
        # fourth hidden layer and output
        X = self.hidden4(X)
        return X
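
# A quick shape sanity check (commented out; batch size 4 and numfeatures=8
# are arbitrary here): a (4, 8) input should yield a (4, 6) output.
#
#   dummy = torch.randn(4, 8)
#   assert MLP(8)(dummy).shape == (4, 6)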


# prepare the dataset
def prepare_data(path, numfeatures):
    # load the dataset
    dataset = CSVDataset(path, numfeatures, normalize=True)
    # calculate split
    train, test = dataset.get_splits()
    # prepare data loaders
    train_dl = DataLoader(train, batch_size=32, shuffle=True, pin_memory=True)
    test_dl = DataLoader(test, batch_size=1024, shuffle=False, pin_memory=True)
    return train_dl, test_dl
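
# Sketch for inspecting one training batch (commented out; with batch_size=32
# and 8 features / 6 targets, expect shapes (32, 8) and (32, 6)):
#
#   inputs, targets = next(iter(train_dl))
#   print(inputs.shape, targets.shape)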


# train the model
def train_model(train_dl, model, device):
    # define the loss and optimizer
    criterion = MSELoss()
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    # put the model in training mode
    model.train()
    # enumerate epochs
    for epoch in range(3):
        # enumerate mini batches
        for i, (inputs, targets) in enumerate(train_dl):
            inputs, targets = inputs.to(device), targets.to(device)
            # clear the gradients
            optimizer.zero_grad()
            # compute the model output
            yhat = model(inputs)
            # calculate loss
            loss = criterion(yhat, targets)
            # credit assignment
            loss.backward()
            # update model weights
            optimizer.step()
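
# A hedged variant of the training loop that also reports the mean loss per
# epoch (commented out; `train_model_verbose` is a hypothetical helper, not
# part of the original script):
#
#   def train_model_verbose(train_dl, model, device, epochs=3):
#       criterion = MSELoss()
#       optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
#       model.train()
#       for epoch in range(epochs):
#           running = 0.0
#           for inputs, targets in train_dl:
#               inputs, targets = inputs.to(device), targets.to(device)
#               optimizer.zero_grad()
#               loss = criterion(model(inputs), targets)
#               loss.backward()
#               optimizer.step()
#               running += loss.item()
#           print('epoch %d: mean loss %.4f' % (epoch, running / len(train_dl)))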


# evaluate the model
def evaluate_model(test_dl, model, device):
    model.eval()
    predictions, actuals = list(), list()
    # no gradients are needed for evaluation
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(test_dl):
            inputs, targets = inputs.to(device), targets.to(device)
            # evaluate the model on the test set
            yhat = model(inputs)
            # move to CPU and retrieve numpy arrays
            yhat = yhat.cpu().numpy()
            actual = targets.cpu().numpy()
            # store
            predictions.append(yhat)
            actuals.append(actual)
    predictions, actuals = vstack(predictions), vstack(actuals)
    # calculate mse
    mse = mean_squared_error(actuals, predictions)
    return mse
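
# Note: with 6 target columns, scikit-learn's mean_squared_error averages the
# error over all outputs (multioutput='uniform_average' is the default).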


# make a prediction for one row of data
def predict(row, model):
    # convert the row to a tensor on the same device as the model
    row = Tensor([row]).to(next(model.parameters()).device)
    # make prediction
    yhat = model(row)
    # retrieve numpy array
    yhat = yhat.detach().cpu().numpy()
    return yhat


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
# prepare the data
path = os.path.join('Datasets', 'XFEL_KW0_Results_2.csv')
numfeatures = 8
train_dl, test_dl = prepare_data(path, numfeatures)
# define the network
model = MLP(numfeatures)
model = model.to(device)
if os.path.isfile("model.pt"):
    # load the saved weights, mapping them onto the current device
    model.load_state_dict(torch.load("model.pt", map_location=device))
    model.eval()
else:
    # train the model
    print(model)
    train_model(train_dl, model, device)
    torch.save(model.state_dict(), "model.pt")

# evaluate the model
mse = evaluate_model(test_dl, model, device)
print('MSE: %.3f, RMSE: %.3f' % (mse, sqrt(mse)))
# # make a single prediction (row must hold numfeatures scaled input values;
# # the values below are placeholders, not real data)
# row = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
# yhat = predict(row, model)
# print('Predicted:', yhat)

# START: Plot dataset (exploratory; as written this expects `dataset` as a raw
# array, which is not in scope at this point in the script)
# datasetShape = dataset.shape
# divider = 1000
# filler = np.arange(math.ceil(datasetShape[0] / divider))
# dataset = np.array_split(dataset, divider)
# # print(dataset)
# # print(filler.shape)
# # print(dataset[0][0].shape)
# # print(datasetShape, filler)
# numfeatures = datasetShape[1]
# gs = mpl.gridspec.GridSpec(math.ceil(math.sqrt(numfeatures)), math.ceil(math.sqrt(numfeatures)))
# fig = plt.figure()
# for i in range(numfeatures):
#     ax = fig.add_subplot(gs[i])
#     ax.scatter(filler, dataset[0][i])
#     # ax.tight_layout()
# plt.show()
# END: Plot dataset