#!/usr/bin/env python3
import os
import multiprocessing
import sklearn
import numpy as np
import pickle
import sklearn
import sklearn.svm
import sklearn.pipeline
import sklearn.ensemble
import sklearn.multioutput
import sklearn.linear_model
import sklearn.preprocessing
import sklearn.neural_network
import sklearn.model_selection
import sklearn.feature_selection

from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR
from sklearn.ensemble import BaggingRegressor
from sklearn.ensemble import ExtraTreesRegressor

from migration_model import MigrationModel

class ETSModel(MigrationModel):
    """
    VM migration model based on extremely randomized trees (ExtraTrees).

    One ``ExtraTreesRegressor`` per target column is trained through
    ``MultiOutputRegressor``.  Features and targets are standardized with
    ``StandardScaler``, and NaN values are replaced by 0 both before
    training and before prediction.
    """

    # Class-level defaults so that using the model before fit() /
    # cross_validation() reaches the explicit checks below instead of
    # raising AttributeError.  These names are mangled to
    # _ETSModel__<name>, the same as the instance attributes assigned later.
    __cv_result = None
    __regr = None
    __X_scaler = None
    __Y_scaler = None

    def __init__(self, df, features=None, targets=None):
        super().__init__(df, features=features, targets=targets)

    def get_cv_result(self):
        """Return the dict produced by the last cross_validation() call.

        Returns None if cross_validation() has not been run yet.
        """
        return self.__cv_result

    @staticmethod
    def _zero_nan(a):
        """Return a copy of array *a* with every NaN replaced by 0."""
        return np.where(np.isnan(a), 0, a)

    @staticmethod
    def _train(X, Y):
        """Fit the scalers and the multi-output ExtraTrees regressor.

        :returns: ``(regr, X_scaler, Y_scaler)``; the regressor is trained
            on the standardized X and Y.
        """
        X_scaler = StandardScaler().fit(X)
        Y_scaler = StandardScaler().fit(Y)
        regr = sklearn.multioutput.MultiOutputRegressor(ExtraTreesRegressor())
        regr.fit(X_scaler.transform(X), Y_scaler.transform(Y))
        return regr, X_scaler, Y_scaler

    @staticmethod
    def _predict_with(regr, X_scaler, Y_scaler, X):
        """Predict targets for the 2-D array X, in original (unscaled) units."""
        return Y_scaler.inverse_transform(regr.predict(X_scaler.transform(X)))

    def _require_columns(self):
        """Raise AssertionError if the feature or target list is missing."""
        if self.features is None:
            raise AssertionError('The feature list is None.')
        if self.targets is None:
            raise AssertionError('The target list is None.')

    def cross_validation(self, n_splits=10):
        """Run unshuffled K-fold cross-validation of the model.

        A fresh pair of scalers and a fresh regressor are fitted on each
        training fold and used to predict the matching test fold.  The
        result is also stored and can be retrieved with get_cv_result().

        :param n_splits: number of folds passed to ``KFold``.
        :returns: dict with keys ``'X'`` and ``'Y'`` (the raw feature and
            target arrays, NaNs kept) and ``'CV'``, a list with one dict per
            fold holding ``'X_train'``, ``'Y_train'``, ``'X_test'``,
            ``'Y_test'`` (NaNs replaced by 0), the fitted ``'X_scaler'`` and
            ``'Y_scaler'``, and the de-standardized ``'prediction'`` for
            ``X_test``.
        :raises AssertionError: if the feature or target list is None.
        """
        self._require_columns()

        X = self.df[self.features].to_numpy()
        Y = self.df[self.targets].to_numpy()
        # NaN replacement is element-wise, so doing it once up front is
        # equivalent to doing it on every fold.
        X_clean = self._zero_nan(X)
        Y_clean = self._zero_nan(Y)

        result = {'X': X, 'Y': Y, 'CV': []}
        kf = sklearn.model_selection.KFold(n_splits=n_splits)
        for train_index, test_index in kf.split(X):
            X_train, X_test = X_clean[train_index], X_clean[test_index]
            Y_train, Y_test = Y_clean[train_index], Y_clean[test_index]

            regr, X_scaler, Y_scaler = self._train(X_train, Y_train)
            prediction = self._predict_with(regr, X_scaler, Y_scaler, X_test)

            result['CV'].append({
                'X_train': X_train,
                'Y_train': Y_train,
                'X_test': X_test,
                'Y_test': Y_test,
                'X_scaler': X_scaler,
                'Y_scaler': Y_scaler,
                'prediction': prediction,
            })

        self.__cv_result = result
        return result

    def fit(self):
        """Train the model on the whole data frame.

        NaN values in features and targets are replaced by 0.  The scalers
        and the regressor are then fitted and stored for predict().

        :raises AssertionError: if the feature or target list is None.
        """
        self._require_columns()
        X = self._zero_nan(self.df[self.features].to_numpy())
        Y = self._zero_nan(self.df[self.targets].to_numpy())
        # Assign only after training succeeds, so a failed fit() leaves the
        # previous state untouched.
        self.__regr, self.__X_scaler, self.__Y_scaler = self._train(X, Y)

    def predict(self, feature):
        """Predict all targets for a single sample.

        :param feature: mapping from each name in ``self.features`` to its
            value; other keys are ignored.
        :returns: dict mapping each name in ``self.targets`` to its
            predicted value.
        :raises AssertionError: if fit() has not been called.
        :raises KeyError: if a feature name is missing from *feature*.
        """
        if self.__regr is None:
            raise AssertionError('Regressor does not exist.')
        if self.__X_scaler is None:
            raise AssertionError('X_scaler does not exist.')
        if self.__Y_scaler is None:
            raise AssertionError('Y_scaler does not exist.')

        # Build a 1 x n_features row.  dtype=float turns None into NaN, and
        # NaNs are then zeroed exactly as they were for the training data.
        row = np.array([[feature[f] for f in self.features]], dtype=float)
        row = self._zero_nan(row)
        p = self._predict_with(
            self.__regr, self.__X_scaler, self.__Y_scaler, row)[0]
        return dict(zip(self.targets, p))