#!/usr/bin/env python3

import pandas as pd
import yaml
import pickle
import pprint
from linear import LinearModel
from svr import SVRModel
from ets import ETSModel
from dt import DTModel
import time

def load_dataset(path):
    """Read and return the pandas object pickled at *path*.

    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    dataset = pd.read_pickle(path)
    return dataset

def load_config(path):
    """Parse the YAML document at *path* and return the resulting object.

    Parsing uses ``yaml.FullLoader``, which is not intended for untrusted
    input; only point this at configuration files you control.
    """
    with open(path, 'r') as config_file:
        return yaml.load(config_file, Loader=yaml.FullLoader)


def build(df, config, path):
    """Fit one ETSModel per technique and pickle the resulting dict to *path*.

    Args:
        df: DataFrame with a ``technique`` column plus the columns named in
            ``config['features']`` and ``config['targets']``.
        config: Mapping providing ``features``, ``targets`` and ``techniques``
            (e.g. the result of ``load_config``). ``supported_models`` is not
            consulted: every technique is fitted with ``ETSModel``.
        path: Destination file; opened in binary write mode and overwritten.

    The pickled object is a ``dict`` mapping each technique name to its
    fitted model.
    """
    features = config['features']
    targets = config['targets']
    techniques = config['techniques']

    models = {}
    for technique in techniques:
        # Boolean-mask selection rather than interpolating the name into a
        # ``df.query`` string, which breaks for names containing quotes or
        # backslashes.
        df_new = df[df['technique'] == technique].copy()
        print(technique, len(df_new))
        model = ETSModel(df=df_new)
        model.set_features(features)
        model.set_targets(targets)
        model.fit()
        # Drop the model's reference to its training frame before pickling,
        # presumably to keep the saved file small.
        model.df = None
        models[technique] = model

    with open(path, 'wb') as f:
        pickle.dump(models, f, pickle.HIGHEST_PROTOCOL)

def remove_outliers(df, out_ranges):
    """Drop rows whose values fall outside per-column quantile bounds.

    Args:
        df: DataFrame to filter.
        out_ranges: Mapping ``column -> spec``. ``spec['min']`` and
            ``spec['max']`` are quantile fractions in [0, 1]; a row is kept
            only when ``quantile(min) < value < quantile(max)`` (strict on
            both sides) for every listed column. The legacy percent-label
            keys ``spec['min_per']`` / ``spec['max_per']`` (e.g. ``'99.5%'``)
            may still be present but are no longer needed.

    Returns:
        The filtered DataFrame. Columns are processed in mapping order, and
        each column's quantiles are computed on the rows that survived the
        earlier columns, so the order of ``out_ranges`` can affect the result.
    """
    before_n = len(df)
    for column, spec in out_ranges.items():
        values = df[column]
        # describe(percentiles=[q]) reports Series.quantile(q); calling
        # quantile directly avoids looking the result up by a formatted label
        # such as '99.5%', which raises KeyError on any formatting mismatch.
        lower = values.quantile(spec['min'])
        upper = values.quantile(spec['max'])
        df = df[(values > lower) & (values < upper)]
    after_n = len(df)
    print('Removing Outliers | Before : {}, After : {}'.format(before_n, after_n))

    return df

def load_model(path):
    """Unpickle and return the object stored at *path*.

    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)

def predict(df, model, techniques):
    """Run each technique's model on *df*.

    Args:
        df: Input passed unchanged to every model's ``predict`` method.
        model: Mapping from technique name to an object exposing ``predict``.
        techniques: Technique names to evaluate, in output order.

    Returns:
        Dict mapping each technique name to its model's prediction.
    """
    return {name: model[name].predict(df) for name in techniques}

def cross_validation(df, config, path, n_splits=10):
    """Cross-validate one ETSModel per technique and pickle the results.

    Args:
        df: DataFrame with a ``technique`` column plus the columns named in
            ``config['features']`` and ``config['targets']``.
        config: Mapping providing ``features``, ``targets`` and ``techniques``.
            ``supported_models`` is not consulted.
        path: Destination file; opened in binary write mode and overwritten.
        n_splits: Accepted for API compatibility but currently unused.

    The pickled object is a ``dict`` mapping each technique name to whatever
    ``ETSModel.cross_validation()`` returned for it.
    """
    features = config['features']
    targets = config['targets']
    techniques = config['techniques']

    cv_results = {}
    for technique in techniques:
        # Boolean-mask selection rather than interpolating the name into a
        # ``df.query`` string, which breaks for names containing quotes.
        subset = df[df['technique'] == technique].copy()
        model = ETSModel(df=subset)
        model.set_features(features)
        model.set_targets(targets)
        # NOTE(review): n_splits is not forwarded, so the fold count is
        # whatever ETSModel.cross_validation() uses internally. Confirm its
        # signature before passing n_splits through.
        cv_results[technique] = model.cross_validation()

    with open(path, 'wb') as f:
        pickle.dump(cv_results, f, pickle.HIGHEST_PROTOCOL)

if __name__ == "__main__":
    # Merge the synthetic, real-application and additional datasets, drop
    # DLTC runs with excluded xbzrle cache sizes, rename technique labels and
    # export the result to CSV.

    # 2**30 and 2**32 (1 GiB / 4 GiB if the column is in bytes — confirm).
    EXCLUDED_XBZRLE_CACHE_SIZES = [1 << 30, 1 << 32]
    # Legacy technique labels -> names written to the exported dataset.
    TECHNIQUE_RENAMES = {
        "DLTC": 'DCdlt',
        'DTC': 'DCZlib',
        'ZSTD': 'DCZstd',
        'DLTC-DTC': 'DCdlt-DCZlib',
        'DTC-NOWAIT': 'DCZlib-Async',
        'DLTC-DTC-NOWAIT': 'DCdlt-DCZlib-Async',
    }

    synthetic_df = load_dataset('../datasets/df_synthetic.pkl')
    realapp_df = load_dataset('../datasets/df_realapp.pkl')
    df = pd.concat([synthetic_df, realapp_df])

    additional_df = load_dataset('../datasets/additional_dataset.pkl')
    config = load_config('../model_config.yaml')
    df = df[config['columns']]
    additional_df = additional_df[config['columns']]
    df = pd.concat([df, additional_df])

    # Literal substring match; missing/non-string values count as non-DLTC
    # instead of yielding a NaN mask that pandas refuses to index with.
    is_dltc = df['migration_technique'].str.contains('DLTC', regex=False, na=False)
    dltc_df = df[is_dltc]
    not_dltc_df = df[~is_dltc]
    dltc_df = dltc_df[~dltc_df['xbzrle_cache_size'].isin(EXCLUDED_XBZRLE_CACHE_SIZES)]
    print(dltc_df['xbzrle_cache_size'].unique())
    # DLTC rows come first in the exported dataset.
    df = pd.concat([dltc_df, not_dltc_df])
    # outlier = {}
    # for feature in config['targets']:
    #     outlier[feature] = {}
    #     outlier[feature]['min'] = .001
    #     outlier[feature]['min_per'] = '0.1%'
    #     if feature in ['total_time_ms', 'downtime_ms', 'relative_performance_migration_src']:
    #         outlier[feature]['max'] = .95
    #         outlier[feature]['max_per'] = '95%'
    #     else:
    #         outlier[feature]['max'] = .995
    #         outlier[feature]['max_per'] = '99.5%'
    # df = remove_outliers(df, outlier)
    df['technique'] = df['technique'].replace(TECHNIQUE_RENAMES)
    df.to_csv('../datasets/dataset.csv', index=False)

    # build(df, config, f'model_objects/ets_{int(time.time())}.pkl')
    # cross_validation(df, config, f'model_objects/ets_svr_{int(time.time())}.pkl')

    # df = load_dataset('../datasets/df_synthetic.pkl')
    # df = df.iloc[1]
    # print(df)

    # model = load_model('model_objects/test.pkl')
    # prediction_results = predict(df[config['features']], model, config['techniques'])
    # pprint.pprint(prediction_results[df['technique']])
    # print(df['technique'])
    # print(df[config['targets']])
