import pathlib
import itertools
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# 1. Register the fonts located in the parent directory
# import os
# import sys
# sys.path.insert(0, os.path.dirname(os.getcwd()))
# import vistk
# vistk.FONTS_DIR = '../fonts/'
# vistk.update_font_properties()

# 2. Reload the vistk package
# from importlib import reload
# reload(vistk)

# Directory scanned for ``*.ttf`` files by ``update_font_properties``; a
# relative path is resolved against the current working directory.
FONTS_DIR = 'fonts'

# Registries filled by the ``update_*`` functions of this module:
# font name -> absolute .ttf path, and short name -> tick formatter / locator.
FONT_PATHS = dict()
TICK_FORMATTERS = dict()
TICK_LOCATORS = dict()

# Figure widths (used as the figsize width, i.e. inches) per publication
# template, keyed by placement such as 'text' (full width) or 'column'.
PAPER_TEMPLATES = dict([
    ('IEEE CompSoc', dict(text=7.17, column=3.5)),
    ('SoCC', dict(text=6.5, column=3.3)),
    ('SoCC-large', dict(text=7.3, column=3.5)),
    ('SNU', dict(text=5.5, column=5.5, single_figure=3.85)),
    ('jupyter', dict(small=4, smallmedium=6, medium=8,
                     mediumlarge=10, large=16, full=24)),
])

# Number of discrete colours in each qualitative colormap; ``get_colors``
# uses it as the default normalisation upper bound for these maps.
COLORMAP_VMAX = dict(
    Pastel1=9, Pastel2=8, Paired=12, Accent=8, Dark2=8,
    Set1=9, Set2=8, Set3=12, tab10=10, tab20=20,
    tab20b=20, tab20c=20,
)

# Size of each group of related shades in the grouped colormaps; used by
# ``get_colors`` when ``shuffle`` is enabled.
COLORMAP_GROUP_SIZE = dict(tab20=2, tab20b=4, tab20c=4)


def get_colors(colormap, n, vmax=None, reverse=False, shuffle=True):
    """Sample ``n`` colors from a matplotlib colormap.

    Color ``i`` is ``cmap(Normalize(0, vmax)(i))``.

    Args:
        colormap: Colormap name (e.g. ``'tab10'``) or a ``Colormap`` instance.
        n: Number of colors to return.
        vmax: Upper bound of the normalisation range. An explicit value is
            always honoured. When omitted, ``COLORMAP_VMAX`` supplies it for
            the qualitative colormaps listed there, and ``n - 1`` is used
            otherwise so the samples span the whole colormap.
        reverse: Return the colors in reverse order.
        shuffle: For the grouped colormaps in ``COLORMAP_GROUP_SIZE``, visit
            the colormap with a stride of the group size, so the first
            ``vmax / group_size`` colors come from different groups.

    Returns:
        A list of ``n`` colors as returned by the colormap (RGBA tuples).
    """
    # pyplot.get_cmap exists in every matplotlib version; matplotlib.cm.get_cmap
    # (previously reached via plt.cm) was removed in matplotlib 3.9.
    cmap = plt.get_cmap(colormap)
    # Look tables up by the resolved name. For string input this is the same
    # key, and it also works when a (unhashable) Colormap object is passed.
    name = cmap.name

    if vmax is None:
        vmax = COLORMAP_VMAX.get(name, n - 1)
    norm = mpl.colors.Normalize(vmin=0, vmax=vmax)

    group_size = COLORMAP_GROUP_SIZE.get(name) if shuffle else None

    colors = []
    for i in range(n):
        position = i
        if group_size is not None:
            position = ((i * group_size) % vmax) + int(i / (vmax / group_size))
        colors.append(cmap(norm(position)))

    if reverse:
        colors.reverse()

    return colors


def cdf(a):
    """Return ``(x, y)`` points of the empirical CDF of ``a`` for plotting.

    ``x`` is ``a`` sorted in ascending order and ``y[i] = i / a.size``. The
    curve therefore starts at 0 and ends at ``(size - 1) / size``, which is
    the complement of ``ccdf``'s ``y``.
    """
    values = np.array(a)
    count = values.size
    return np.sort(values), np.arange(count) / float(count)


def ccdf(a):
    """Return ``(x, y)`` points of the complementary empirical CDF of ``a``.

    ``x`` is ``a`` sorted in ascending order and ``y[i] = (size - i) / size``.
    The curve starts at 1 and ends at ``1 / size``.
    """
    values = np.array(a)
    count = values.size
    remaining = np.full(values.shape, fill_value=count) - np.arange(count)
    return np.sort(values), remaining / float(count)


def set_type1_font():
    """Request Type 1 fonts for PDF/PS output, falling back to Type 42.

    If matplotlib rejects ``fonttype=1`` with a ValueError, the error is
    printed and both the 'pdf' and 'ps' groups are set to 42 (TrueType)
    instead. NOTE(review): matplotlib documents only 3 and 42 as valid values
    here, so the fallback is likely the path normally taken.
    """
    try:
        for output in ('pdf', 'ps'):
            plt.rc(output, fonttype=1)
    except ValueError as err:
        print(f'{err}, Use type 42 instead of type 1')
        for output in ('pdf', 'ps'):
            plt.rc(output, fonttype=42)


def update_font_properties(fonts_dir=None):
    """Register every ``*.ttf`` file of a directory in ``FONT_PATHS``.

    Each font is keyed by its file name without the ``.ttf`` extension (e.g.
    ``Times New Roman.ttf`` -> ``'Times New Roman'``) and mapped to the file's
    absolute path. Entries already present are kept; a font with the same key
    overwrites the previous path. A missing directory registers nothing.

    Args:
        fonts_dir: Directory to scan. Defaults to the module-level
            ``FONTS_DIR``, read at call time, so reassigning ``FONTS_DIR``
            before calling still works. Relative paths are resolved against
            the current working directory.
    """
    directory = pathlib.Path(FONTS_DIR if fonts_dir is None else fonts_dir)
    for font_file in directory.glob('*.ttf'):
        # Strip the 4-character '.ttf' suffix to form the lookup key.
        FONT_PATHS[font_file.name[:-4]] = str(font_file.absolute())


def update_tick_formatters():
    """(Re)create the named tick formatters that ``beautify`` accepts as strings.

    Keys: ``'null'`` (no labels), ``'scalar'``, ``'log'`` (base 10) and
    ``'eng'`` (engineering notation). Existing entries are replaced with new
    instances.
    """
    ticker = mpl.ticker
    TICK_FORMATTERS.update(
        null=ticker.NullFormatter(),
        scalar=ticker.ScalarFormatter(),
        log=ticker.LogFormatter(base=10),
        eng=ticker.EngFormatter(),
    )


def update_tick_locators():
    """(Re)create the named tick locators that ``beautify`` accepts as strings.

    Keys: ``'null'`` (no ticks), ``'log'`` (base 10) and ``'auto'``. Existing
    entries are replaced with new instances.
    """
    ticker = mpl.ticker
    TICK_LOCATORS.update(
        null=ticker.NullLocator(),
        log=ticker.LogLocator(base=10),
        auto=ticker.AutoLocator(),
    )


def _ticker_from_spec(spec, registry):
    """Resolve a tick formatter/locator argument of ``beautify``.

    A string names an entry of ``registry`` (``TICK_FORMATTERS`` or
    ``TICK_LOCATORS``); a shallow copy of that entry is returned. Matplotlib
    binds a formatter/locator to the Axis it is installed on, so sharing the
    registry instance would make every subplot compute ticks from the last
    Axis it was installed on. Any other value is returned unchanged.
    """
    import copy

    if isinstance(spec, str):
        return copy.copy(registry[spec])
    return spec


def beautify(ax, title=None, xlabel=None, ylabel=None,
             font='Times New Roman', fontsize=8, linewidth=1.0,
             spineslinewidth=0.8, titlepad=None,
             xlabelpad=2.0, ylabelpad=2.0, xtickpad=2.0, ytickpad=2.0,
             xfreq=None, yfreq=None, xrot=None, yrot=None,
             xtick_horizontal_alignment='center',
             ytick_horizontal_alignment='right',
             xscale=None, yscale=None, basex=10, basey=10,
             xtick_major_formatter=None, xtick_major_locator=None,
             ytick_major_formatter=None, ytick_major_locator=None,
             xtick_minor_formatter=None, xtick_minor_locator=None,
             ytick_minor_formatter=None, ytick_minor_locator=None,
             xlim=None, ylim=None, xticks=None, yticks=None,
             xticklabels=None, yticklabels=None,
             grid=True, grid_line_style='-', grid_line_width=0.5,
             grid_line_color='silver', grid_line_alpha=0.5,
             legend=False, legend_font_size=6, legend_loc='best',
             legend_ncol=2, legend_labelspacing=0.3, legend_borderpad=0.3,
             legend_handlelength=1.5, legend_handletextpad=0.5,
             legend_columnspacing=0.5, legend_linewidth=0.5):
    """Apply the module's publication styling to a matplotlib Axes in place.

    Every ``*=None`` option is left untouched when omitted.

    Args:
        ax: The Axes to style.
        title, xlabel, ylabel: Texts, formatted with ``str``.
        font: Key of ``FONT_PATHS`` (a registered .ttf name); raises
            ``KeyError`` if the font was not registered.
        fontsize: Size for the title, labels, tick labels and offset texts.
        linewidth: Tick mark width. spineslinewidth: width of all 4 spines.
        titlepad, xlabelpad, ylabelpad, xtickpad, ytickpad: Paddings.
        xfreq, yfreq: Install a ``MultipleLocator`` with this major spacing.
        xrot, yrot: Rotation of the major tick labels.
        xtick_horizontal_alignment, ytick_horizontal_alignment: Alignment of
            the major tick labels.
        xscale, yscale: Axis scale names; non-linear scales get
            ``base=basex`` / ``base=basey``.
        {x,y}tick_{major,minor}_{formatter,locator}: A matplotlib
            Formatter/Locator instance (used as-is) or a key of
            ``TICK_FORMATTERS`` / ``TICK_LOCATORS`` (a fresh copy is used).
        xlim, ylim, xticks, yticks, xticklabels, yticklabels: Passed to the
            matching ``Axes.set_*`` method.
        grid, grid_line_*: Draw a grid below the data with this line style.
        legend, legend_*: Draw a legend with these ``Axes.legend`` settings
            and frame line width.
    """
    fontprop = mpl.font_manager.FontProperties(
            fname=FONT_PATHS[font], size=fontsize)

    if title is not None:
        ax.set_title(f'{title}', fontproperties=fontprop, pad=titlepad)
    if xlabel is not None:
        ax.set_xlabel(f'{xlabel}', fontproperties=fontprop, labelpad=xlabelpad)
    if ylabel is not None:
        ax.set_ylabel(f'{ylabel}', fontproperties=fontprop, labelpad=ylabelpad)

    # https://github.com/matplotlib/matplotlib/issues/15845
    ax.tick_params(labelsize=fontsize, which='minor')

    # https://stackoverflow.com/questions/21512305/inconsistent-font-size-for-scientific-notation-in-axis
    ax.xaxis.get_offset_text().set_fontproperties(fontprop)
    ax.yaxis.get_offset_text().set_fontproperties(fontprop)

    for axis, rotation, alignment in (
            (ax.yaxis, yrot, ytick_horizontal_alignment),
            (ax.xaxis, xrot, xtick_horizontal_alignment)):
        for label in axis.get_majorticklabels():
            label.set_fontproperties(fontprop)
            label.set_rotation(rotation)
            label.set_horizontalalignment(alignment)
        for label in axis.get_minorticklabels():
            label.set_fontproperties(fontprop)

    ax.xaxis.set_tick_params(width=linewidth, pad=xtickpad)
    ax.yaxis.set_tick_params(width=linewidth, pad=ytickpad)

    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_linewidth(spineslinewidth)

    if xscale is not None:
        params = {} if xscale == 'linear' else {'base': basex}
        ax.set_xscale(xscale, **params)
    if yscale is not None:
        params = {} if yscale == 'linear' else {'base': basey}
        ax.set_yscale(yscale, **params)

    if xfreq is not None:
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(base=xfreq))
    if yfreq is not None:
        ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(base=yfreq))

    # Installed in this order: major formatters, major locators, minor
    # formatters, minor locators (x before y each time).
    tick_settings = (
        (xtick_major_formatter, TICK_FORMATTERS, ax.xaxis.set_major_formatter),
        (ytick_major_formatter, TICK_FORMATTERS, ax.yaxis.set_major_formatter),
        (xtick_major_locator, TICK_LOCATORS, ax.xaxis.set_major_locator),
        (ytick_major_locator, TICK_LOCATORS, ax.yaxis.set_major_locator),
        (xtick_minor_formatter, TICK_FORMATTERS, ax.xaxis.set_minor_formatter),
        (ytick_minor_formatter, TICK_FORMATTERS, ax.yaxis.set_minor_formatter),
        (xtick_minor_locator, TICK_LOCATORS, ax.xaxis.set_minor_locator),
        (ytick_minor_locator, TICK_LOCATORS, ax.yaxis.set_minor_locator),
    )
    for spec, registry, install in tick_settings:
        if spec is not None:
            install(_ticker_from_spec(spec, registry))

    if grid:
        ax.set_axisbelow(True)
        ax.grid(True, linestyle=grid_line_style, linewidth=grid_line_width,
                color=grid_line_color, alpha=grid_line_alpha)

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    if xticks is not None:
        ax.set_xticks(xticks)
    if yticks is not None:
        ax.set_yticks(yticks)

    if xticklabels is not None:
        ax.set_xticklabels(xticklabels)
    if yticklabels is not None:
        ax.set_yticklabels(yticklabels)

    if legend:
        legend_fontprop = mpl.font_manager.FontProperties(
                fname=FONT_PATHS[font], size=legend_font_size)
        leg = ax.legend(
                loc=legend_loc, prop=legend_fontprop, edgecolor='k',
                labelspacing=legend_labelspacing,
                borderpad=legend_borderpad,
                handlelength=legend_handlelength,
                handletextpad=legend_handletextpad,
                columnspacing=legend_columnspacing, ncol=legend_ncol)
        leg.get_frame().set_linewidth(legend_linewidth)


def central_legend(fig, handles, labels, font='Times New Roman', legend_font_size=6, legend_loc='best',
                   legend_ncol=2, legend_labelspacing=0.3, legend_borderpad=0.3,
                   legend_handlelength=1.5, legend_handletextpad=0.5,
                   legend_columnspacing=0.5, legend_linewidth=0.5, bbox_to_anchor=None):
    """Add a single figure-level legend for ``handles``/``labels`` to ``fig``.

    Args:
        fig: Figure that receives the legend.
        handles, labels: Legend entries, passed to ``fig.legend``.
        font: Key of ``FONT_PATHS`` (a registered .ttf name); raises
            ``KeyError`` if the font was not registered.
        legend_*: Size, location, column count, spacing and frame line width,
            with the same meaning as in ``beautify``.
        bbox_to_anchor: Anchor passed to ``fig.legend``. None (the default)
            or an empty sequence leaves matplotlib's default anchoring.

    Returns:
        The created legend.
    """
    # Matplotlib accepts a Bbox or a 2-/4-tuple here. The former default ``()``
    # was rejected, so every call relying on it failed. Keep accepting an
    # explicit empty sequence by mapping it to "no anchor".
    if isinstance(bbox_to_anchor, (tuple, list)) and not bbox_to_anchor:
        bbox_to_anchor = None

    legend_fontprop = mpl.font_manager.FontProperties(
            fname=FONT_PATHS[font], size=legend_font_size)
    legend = fig.legend(handles, labels,
                        loc=legend_loc, prop=legend_fontprop, edgecolor='k',
                        labelspacing=legend_labelspacing,
                        borderpad=legend_borderpad,
                        handlelength=legend_handlelength,
                        handletextpad=legend_handletextpad,
                        columnspacing=legend_columnspacing, ncol=legend_ncol,
                        bbox_to_anchor=bbox_to_anchor)
    legend.get_frame().set_linewidth(legend_linewidth)
    return legend

def create_subplots(nrows=1, ncols=1, template='IEEE CompSoc', width='text',
                    ratio=(4, 3), figsize=None, gridspec_kw=None, sharex=False, sharey=False):
    """Create a figure and subplot grid sized for a paper template.

    Args:
        nrows, ncols: Grid shape, passed to ``plt.subplots``.
        template: Key of ``PAPER_TEMPLATES``.
        width: Placement key within the template (e.g. ``'text'``,
            ``'column'``) giving the figure width.
        ratio: ``(w, h)`` aspect ratio of each subplot. The figure height is
            ``nrows * (fig_width / ncols) * h / w``.
        figsize: Explicit ``(width, height)``; overrides template and ratio.
        gridspec_kw: Optional GridSpec keyword dict, passed through unchanged.
        sharex, sharey: Passed to ``plt.subplots``.

    Returns:
        ``(fig, axes)`` exactly as returned by ``plt.subplots``.
    """
    if figsize is None:
        fig_width = PAPER_TEMPLATES[template][width]
        fig_height = nrows * ((fig_width / ncols) * (ratio[1] / ratio[0]))
    else:
        fig_width, fig_height = figsize
    # plt.subplots treats gridspec_kw=None as "no extra GridSpec options", so
    # no separate branch (or shared mutable ``{}`` default) is needed.
    return plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height),
                        gridspec_kw=gridspec_kw, sharex=sharex, sharey=sharey)



def flatten(axes):
    """Flatten a NumPy array of Axes in row-major order.

    Returns:
        ``(axes, None)`` unchanged when ``axes`` is not a NumPy array (e.g. the
        single Axes returned by ``plt.subplots()``). Otherwise it returns a
        list of the elements and a parallel list of their index tuples.
    """
    if isinstance(axes, np.ndarray):
        positions = list(np.ndindex(axes.shape))
        return [axes[pos] for pos in positions], positions
    return axes, None


def show():
    """Draw the current figure, leave interactive mode and show all figures.

    Interactive mode is turned off before ``plt.show()``. On GUI backends this
    makes the call block until the windows are closed.
    """
    for step in (plt.draw, plt.ioff, plt.show):
        step()


def savefig(fig, path, format='pdf', bbox_inches='tight', pad_inches=0):
    """Save ``fig`` to ``path``.

    The defaults write a PDF cropped tightly to the drawn content, with no
    padding.
    """
    options = dict(format=format, bbox_inches=bbox_inches, pad_inches=pad_inches)
    fig.savefig(path, **options)


def get_default_boxplot_params():
    """Return a fresh dict of default ``Axes.boxplot`` keyword arguments.

    Whiskers span the 5th-95th percentile, fliers are hidden, and the mean is
    drawn as a black line. Medians are firebrick lines and boxes are filled
    light blue with black edges.
    """
    return {
        'whis': (5, 95),
        'showfliers': False,
        'showmeans': True,
        'meanline': True,
        'patch_artist': True,
        'whiskerprops': {'linestyle': '-', 'linewidth': 0.5},
        'medianprops': {'linestyle': '-', 'linewidth': 1.0, 'color': 'firebrick'},
        'meanprops': {'linestyle': '-', 'linewidth': 1.0, 'color': 'k'},
        'boxprops': {'edgecolor': 'k', 'linewidth': 0.5, 'facecolor': '#a1cdff'},
    }


def boxplot(ax, data, **args):
    """Draw a box plot on ``ax`` with this module's default styling.

    Keyword arguments override ``get_default_boxplot_params()`` and are passed
    to ``ax.boxplot``. The exception is ``alpha``, which ``Axes.boxplot`` does
    not accept: it is consumed here and applied to the face color of each box.

    Two empty lines labelled 'Mean' (black) and 'Median' (firebrick) are also
    added, so a later ``ax.legend()`` lists them. Their styles are fixed and do
    not follow ``meanprops``/``medianprops`` overrides.

    Returns:
        The dict of artists returned by ``ax.boxplot``.
    """
    params = get_default_boxplot_params()
    params.update(args)
    # Forwarding ``alpha`` to Axes.boxplot raises TypeError, so handle it here.
    alpha = params.pop('alpha', None)

    bplot = ax.boxplot(data, **params)
    ax.plot([], [], '-', linewidth=1.0, color='k', label='Mean')
    ax.plot([], [], '-', linewidth=1.0, color='firebrick', label='Median')

    if alpha is not None:
        # Only touch this plot's boxes. ``ax.artists`` no longer holds patches
        # on current matplotlib and could include unrelated artists.
        for box in bplot['boxes']:
            # With patch_artist=False the boxes are lines without a face color.
            if hasattr(box, 'get_facecolor'):
                r, g, b, _ = box.get_facecolor()
                box.set_facecolor((r, g, b, alpha))

    return bplot


def get_default_hist_params():
    """Return a fresh dict of default ``Axes.hist`` keyword arguments.

    The defaults are 10 filled bins with thin black edges.
    """
    return {
        'bins': 10,
        'histtype': 'stepfilled',
        'edgecolor': 'k',
        'linewidth': 0.5,
    }


def hist(df, groupby=None, ncols=4, template='IEEE CompSoc', place='text',
         ratio=(4, 3), color='#1f77b4', colormap='tab10',
         colormap_reverse=False, colormap_shuffle=False, **args):
    """Plot one histogram per DataFrame column on a grid of subplots.

    Args:
        df: DataFrame whose columns are plotted; each column gets one subplot.
        groupby: Optional column name. It is not plotted itself; instead each
            of its unique values is drawn as a separate, labelled histogram in
            every subplot, colored from ``colormap``.
        ncols: Maximum number of subplot columns.
        template, place: Keys into ``PAPER_TEMPLATES`` for the figure width.
        ratio: ``(w, h)`` aspect ratio of each subplot.
        color: Histogram color when ``groupby`` is None.
        colormap, colormap_reverse, colormap_shuffle: Passed to ``get_colors``
            when ``groupby`` is given.
        **args: Override ``get_default_hist_params()`` for ``Axes.hist``.

    Returns:
        ``(fig, axes, index)``. ``fig`` and ``axes`` come from
        ``plt.subplots``. ``index`` holds one ``(column_name, row, col)``
        grid position per plotted column, and is empty when the grid is a
        single Axes.

    Raises:
        ValueError: if there is no column to plot.
    """
    columns = list(df.columns)
    if groupby is not None:
        columns = [c for c in columns if c != groupby]
    if not columns:
        raise ValueError('hist() needs at least one column to plot')

    ncols = min(ncols, len(columns))
    nrows = int(np.ceil(len(columns) / ncols))

    width = PAPER_TEMPLATES[template][place]
    height = nrows * ((width / ncols) * (ratio[1] / ratio[0]))

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(width, height))

    # plt.subplots returns a bare Axes for a 1x1 grid and a 1-D array when only
    # one row or one column exists. Walk the grid by flat (row-major) position
    # rather than assuming a 2-D array.
    if hasattr(axes, 'shape'):
        panels = list(axes.flat)[:len(columns)]
        index = [(name, *divmod(i, ncols)) for i, name in enumerate(columns)]
    else:
        panels = [axes]
        index = []

    params = get_default_hist_params()
    params.update(args)

    if groupby is not None:
        unique_labels = df[groupby].unique()
        colors = get_colors(
            colormap, len(unique_labels),
            reverse=colormap_reverse, shuffle=colormap_shuffle)
        for label, label_color in zip(unique_labels, colors):
            subset = df[df[groupby] == label]
            for panel, col_name in zip(panels, columns):
                x = subset[col_name].to_numpy(copy=True)
                panel.hist(x, **{**params, 'label': label, 'color': label_color})
    else:
        for panel, col_name in zip(panels, columns):
            x = df[col_name].to_numpy(copy=True)
            panel.hist(x, **{**params, 'color': color})

    return fig, axes, index


# Populate the font and tick registries once, at import time.
for _initialize in (update_font_properties, update_tick_formatters,
                    update_tick_locators):
    _initialize()
del _initialize