import sys

import numpy as np
import hex3_old as hex3
import hex3_chen_old as hex3_chen
import copy
import pandas as pd


def _input_rejection_reason(u):
    """Return why candidate input ``u`` is rejected, or ``None`` if it is accepted.

    ``u`` holds the first two components of a point on the 3-simplex. Each
    component must lie in [0.1, 0.75], and their sum must not exceed 0.95, i.e.
    the implied third component ``1 - u[0] - u[1]`` must be at least 0.05.
    """
    if u[0] > 0.75 or u[0] < 0.1:
        return 'u0 too big/small'
    if u[1] > 0.75 or u[1] < 0.1:
        return 'u1 too big/small'
    if np.sum(u) > 0.95:
        return 'sum u too large'
    return None


def gen_dataset(N, ttratio=1.0):
    """Generate ``N`` successfully evaluated samples from the hex3 / hex3_chen models.

    Candidate inputs ``u`` are drawn uniformly from the 3-simplex, and only the
    first two components are kept. Candidate disturbance parameters are drawn
    uniformly from ``center +/- ttratio * variability`` for each entry of
    ``parspan`` below. A candidate is kept only if ``u`` passes
    ``_input_rejection_reason`` and both ``hex3.cost``/``hex3_chen.cost`` and
    ``hex3.grad`` report success. Anything else is counted as an error and
    skipped.

    Side effect: sets ``hex3.Ti_max`` and ``hex3_chen.Ti_max`` to 135.

    Args:
        N: number of successful samples to return.
        ttratio: scale factor applied to every parameter's variability
            (1.0 = the nominal box defined in ``parspan``).

    Returns:
        Tuple ``(u_span, d_span, J_span, J_span_chen, grad_span_chen, grad_span)``:
        ``u_span`` stacks the accepted inputs, ``d_span`` is an object array of
        parameter dicts, ``J_span`` / ``J_span_chen`` hold the NEGATED first
        entry of each model's ``cost['J']``, and the gradient arrays stack the
        ``'grad'`` entries of each model's gradient result.

    Raises:
        RuntimeError: if all candidates are consumed before ``N`` samples succeed.
    """
    # Uniform sampling from the unit simplex: sort uniforms on [0, 1], pad with
    # 0 and 1, and take consecutive differences. Reference: Smith, Noah A., and
    # Roy W. Tromble. "Sampling uniformly from the unit simplex." Johns Hopkins
    # University, Tech. Rep 29 (2004).
    # NOTE: the order of np.random calls below is what makes seeded datasets
    # reproducible; do not reorder them.
    dim = 3
    n_u_candidates = N * 20
    x = np.sort(np.random.rand(dim - 1, n_u_candidates), axis=0)
    x = np.concatenate([np.zeros((1, n_u_candidates)), x, np.ones((1, n_u_candidates))], axis=0)
    alpha = x[1:] - x[:-1]
    # Drop the last component; it is implied because each column sums to 1.
    alpha = alpha[:-1]

    parspan = {}
    # Defining disturbance box [center, variability]
    parspan['T0'] = [60, 10]  # C
    parspan['w0'] = [105, 25]  # kW/K
    parspan['wh1'] = [40, 10]  # kW/K
    parspan['wh2'] = [50, 10]  # kW/K
    parspan['wh3'] = [30, 10]  # kW/K
    parspan['Th1'] = [150, 30]  # C
    parspan['Th2'] = [150, 30]  # C
    parspan['Th3'] = [150, 30]  # C
    parspan['UA1'] = [65, 15]  # kW/K
    parspan['UA2'] = [80, 10]  # kW/K
    parspan['UA3'] = [95, 15]  # kW/K

    # Copied from transfer learning
    parspan['Ts'] = [0, 0]  # C
    parspan['h1'] = [0, 0]  # kW/K
    parspan['h2'] = [0, 0]  # kW/K
    parspan['h3'] = [0, 0]  # kW/K

    n_par_candidates = 10 * N
    randmatrix = np.random.rand(len(parspan), n_par_candidates)
    # Uniform on [center - ttratio*variability, center + ttratio*variability].
    parvec = {
        parname: span[0] + ttratio * (2 * randmatrix[k] - 1) * span[-1]
        for k, (parname, span) in enumerate(parspan.items())
    }

    par0 = [{key: value[j] for key, value in parvec.items()} for j in range(n_par_candidates)]

    # Generating measurements, priors and targets
    u_span = []
    d_span = []
    J_span = []
    J_span_chen = []
    grad_span_chen = []
    grad_span = []

    hex3_chen.Ti_max = 135
    hex3.Ti_max = 135

    # Both candidate pools are indexed by the same counter, so the smaller one
    # bounds how many candidates can be tried.
    n_candidates = min(len(par0), alpha.shape[1])

    errors = 0
    finished = 0
    i = 0
    while finished < N:
        if i >= n_candidates:
            raise RuntimeError(
                f'gen_dataset: exhausted all {n_candidates} candidates with only '
                f'{finished}/{N} successful samples ({errors} rejected or failed)')

        params = par0[i]
        u = alpha[:, i]
        i += 1  # every candidate is consumed exactly once, success or not

        reason = _input_rejection_reason(u)
        if reason is not None:
            print(reason)
            errors += 1
            continue

        # Calculate optimal output temp from optimal u
        try:
            cost = hex3.cost(u, copy.deepcopy(params))
            cost_chen = hex3_chen.cost(u, copy.deepcopy(params))

            if not cost['success'] or not cost_chen['success']:
                errors += 1
                print('Bad hex3 cost, errors: ', errors)
                continue
            gradient_chen = hex3_chen.grad(u, copy.deepcopy(params))
            grad = hex3.grad(u, copy.deepcopy(params))

            # NOTE(review): gradient_chen's success flag (if it has one) is not
            # checked here, unlike grad's — confirm whether that is intended.
            if not grad['success']:
                errors += 1
                print('Bad hex3 grad, errors: ', errors)
                continue

        except Exception as e:
            # Best-effort sampling: a failure on one candidate is logged and skipped.
            import traceback
            errors += 1
            print(f'Error "{e}"', file=sys.stderr)
            traceback.print_exc()
            print('Skipping idx', i - 1)
            continue

        print('Success solutions: ', finished + 1)
        finished += 1

        # Save values (copy u so the result does not alias ``alpha``)
        u_span.append(np.array(u))
        d_span.append(params)
        J_span.append(-cost['J'][0])
        J_span_chen.append(-cost_chen['J'][0])
        grad_span.append(grad['grad'])
        grad_span_chen.append(gradient_chen['grad'])

    return (
        np.array(u_span),
        np.array(d_span, dtype=object),
        np.array(J_span),
        np.array(J_span_chen),
        np.array(grad_span_chen),
        np.array(grad_span),
    )


def save_data(name, u_span, d_span, J_span, J_span_chen, g_span_chen, g_span):
    """Write one generated dataset to a CSV file (no index column).

    Column order is:
    - inputs ``u0..``;
    - one column per key of the parameter dicts in ``d_span``;
    - ``J`` and ``J_chen``;
    - gradient columns ``g0..``;
    - Chen-gradient columns ``gc0..``.

    ``load_data`` splits the file back apart by position, so this order matters.

    Args:
        name: output path, passed straight to ``DataFrame.to_csv``.
        u_span: 2-D array, one row per sample.
        d_span: sequence (or object array) of parameter dicts, one per sample.
        J_span, J_span_chen: 1-D arrays of per-sample costs.
        g_span_chen, g_span: 2-D gradient arrays, one row per sample; their
            widths may differ.
    """
    def _numbered_columns(prefix, arr):
        # One column per column of ``arr``, named prefix0, prefix1, ...
        return pd.DataFrame.from_dict({f'{prefix}{k}': col for k, col in enumerate(arr.T)})

    u_span_pd = _numbered_columns('u', u_span)
    d_span_pd = pd.DataFrame.from_records(list(d_span))
    J_span_pd = pd.DataFrame.from_dict({'J': J_span})
    J_span_chen_pd = pd.DataFrame.from_dict({'J_chen': J_span_chen})
    # Each gradient block is sized from its own array.
    g_span_pd = _numbered_columns('g', g_span)
    g_span_chen_pd = _numbered_columns('gc', g_span_chen)

    frames = pd.concat([u_span_pd, d_span_pd, J_span_pd, J_span_chen_pd, g_span_pd, g_span_chen_pd], axis=1)

    frames.to_csv(name, index=False)


def load_data(name):
    """Read a CSV and split its columns into named groups by position.

    Layout assumed:
    - first two columns: ``u``;
    - last six columns, in order: ``J`` (1), ``J_chen`` (1), ``g`` (2), ``gc`` (2);
    - everything in between: ``d``.

    Each value of the returned dict is a DataFrame slice of the file.
    """
    table = pd.read_csv(name)
    # (group name, column slice) in file order.
    layout = (
        ('u', slice(None, 2)),
        ('d', slice(2, -6)),
        ('J', slice(-6, -5)),
        ('J_chen', slice(-5, -4)),
        ('g', slice(-4, -2)),
        ('gc', slice(-2, None)),
    )
    return {group: table.iloc[:, cols] for group, cols in layout}


if __name__ == '__main__':
    from pathlib import Path

    # Portable output location (the previous '.\\datasets\\...' strings only
    # worked on Windows); create it so to_csv does not fail on a fresh checkout.
    out_dir = Path('.') / 'datasets'
    out_dir.mkdir(parents=True, exist_ok=True)

    # training sets (each size regenerated from the same seed)
    train_sizes = [100, 500, 1000, 2500]
    for samples in train_sizes:
        print('Generating training data....')

        np.random.seed(2025)
        u_span, d_span, J_span, J_span_chen, gradient_span_chen, gradient_span = gen_dataset(samples, 1)
        save_data(out_dir / f'train_constrained_gradient{samples}.csv', u_span, d_span, J_span, J_span_chen,
                  gradient_span_chen, gradient_span)
        print('Done')

    # Test set: same size as the last training set, different seed, and a
    # parameter box widened by a factor of 1.35.
    test_samples = train_sizes[-1]
    print('Generating test data....')
    np.random.seed(2028)
    u_span, d_span, J_span, J_span_chen, gradient_span_chen, gradient_span = gen_dataset(test_samples, 1.35)
    save_data(out_dir / f'test_constrained_gradient{test_samples}.csv', u_span, d_span, J_span, J_span_chen,
              gradient_span_chen, gradient_span)
    print('Done')
