results.write("-" * len(header) + "\n")
results.close()


# Checkpoint locations, resolved against the working directory at start-up.
EXP_DIR = os.getcwd()
base_path = os.path.join(EXP_DIR, "pars.pkl")
base_path1 = os.path.join(EXP_DIR, "best_pars.pkl")
n_iter = 0

# Resume from an earlier run: restore the iteration count and copy the saved
# parameters into the model in place.
if os.path.isfile(base_path):
    # Open the very path that was just tested instead of a second, relative
    # spelling of it.
    # NOTE(review): unpickling executes arbitrary code; only load checkpoints
    # written by this script.
    with open(base_path, "rb") as tp:
        n_iter, best_pars = cp.load(tp)
        m.parameters.data[...] = best_pars


for i, info in enumerate(m.powerfit((X, Z), (TX, TZ), stop, pause)):

    # Only report on iterations that are a multiple of ``n_report``.
    if info["n_iter"] % n_report != 0:
        continue

    passed = time.time() - start
    losses.append((info["loss"], info["val_loss"]))

    # Attach wall-clock time and the train/test error metrics to the info
    # dict.  The closing parenthesis of this call was missing, which made the
    # whole module a syntax error.
    info.update(
        {
            "time": passed,
            "mae_train": f_mae(X, train_labels),
            "rmse_train": f_rmse(X, train_labels),
            "mae_test": f_mae(TX, test_labels),
            "rmse_test": f_rmse(TX, test_labels),
        }
    )
# Append the table header to the results log.  The context manager closes the
# file even if one of the writes raises.
with open('result.txt', 'a') as results:
    results.write(header + '\n')
    results.write('-' * len(header) + '\n')

# Checkpoint locations, resolved against the working directory at start-up.
EXP_DIR = os.getcwd()
base_path = os.path.join(EXP_DIR, "pars.pkl")
base_path1 = os.path.join(EXP_DIR, "best_pars.pkl")
n_iter = 0

# Resume from an earlier run: restore the iteration count and copy the saved
# parameters into the model in place.
if os.path.isfile(base_path):
    # Open the very path that was just tested instead of a second, relative
    # spelling of it.
    # NOTE(review): unpickling executes arbitrary code; only load checkpoints
    # written by this script.
    with open(base_path, 'rb') as tp:
        n_iter, best_pars = cp.load(tp)
        m.parameters.data[...] = best_pars

for i, info in enumerate(m.powerfit((X, Z), (TX, TZ), stop, pause)):

    # Only report on iterations that are a multiple of ``n_report``.
    if info['n_iter'] % n_report != 0:
        continue

    passed = time.time() - start
    losses.append((info['loss'], info['val_loss']))

    # Attach wall-clock time and the train/test error metrics to the info dict.
    info.update({
        'time': passed,
        'mae_train': f_mae(X, train_labels),
        'rmse_train': f_rmse(X, train_labels),
        'mae_test': f_mae(TX, test_labels),
        'rmse_test': f_rmse(TX, test_labels)
    })
def _update_apsis_record(path, n_iter, result):
    """Replace the iteration count and result stored in an apsis state file.

    The file at ``path`` holds a 9-tuple pickled with dill.  Its first seven
    entries are written back unchanged; the last two are replaced by
    ``n_iter`` and ``result``.
    """
    with open(path, 'rb') as fp:
        # Unpack into names private to this helper so that the caller's
        # hyper-parameters (and the file names derived from them) are never
        # overwritten by whatever happens to be stored on disk.
        (lass, saved_opt, saved_step_rate, saved_momentum, saved_decay,
         saved_c_wd, saved_counter, _old_n_iter, _old_result) = dill.load(fp)
    with open(path, 'wb') as fp:
        dill.dump((lass, saved_opt, saved_step_rate, saved_momentum,
                   saved_decay, saved_c_wd, saved_counter, n_iter, result), fp)


def do_one_eval(X, Z, TX, TZ, test_labels, train_labels, step_rate, momentum, decay, c_wd, counter, opt):
    """Train one MLP with the given hyper-parameters and report its test MAE.

    A 2100-400-100-1 tanh network is fitted with gradient descent on
    ``(X, Z)`` and validated on ``(TX, TZ)``.  An L2 penalty on the layer
    weights, scaled by ``c_wd``, is added to the loss.  Once per reporting
    interval a row of metrics is printed and appended to ``result.txt``, the
    best parameters are checkpointed to ``pars_hp_<opt><counter>.pkl`` and
    the apsis state file ``apsis_pars_<opt><counter>.pkl`` is updated with
    the current iteration count and test MAE.  If a parameter checkpoint
    already exists, training resumes from it.  Training stops early when the
    loss becomes NaN.

    Parameters
    ----------
    X, Z : training inputs and targets handed to the network.
    TX, TZ : validation inputs and targets handed to ``powerfit``.
    test_labels, train_labels : labels the MAE/RMSE are measured against.
        The network output is rescaled with the mean and standard deviation
        of ``train_labels`` before the comparison.
    step_rate, momentum, decay : settings of the ``'gd'`` optimizer.
    c_wd : coefficient of the weight-decay term.
    counter, opt : used (as ``opt + str(counter)``) to build the checkpoint
        and apsis file names; ``opt`` must be a string.

    Returns
    -------
    (mae_test, n_iter) : test MAE and iteration count of the last processed
        report.
    """
    seed = 3453
    np.random.seed(seed)
    batch_size = 25
    max_iter = 25000000
    # Floor division keeps the report interval an integer regardless of the
    # interpreter's division semantics.
    n_report = X.shape[0] // batch_size
    optimizer = 'gd', {'step_rate': step_rate, 'momentum': momentum, 'decay': decay}

    stop = climin.stops.AfterNIterations(max_iter)
    pause = climin.stops.ModuloNIterations(n_report)
    # This defines our NN. Since BayOpt does not support categorical data, we just
    # use a fixed hidden layer length and transfer functions.
    m = Mlp(2100, [400, 100], 1, X, Z, hidden_transfers=['tanh', 'tanh'], out_transfer='identity', loss='squared',
            optimizer=optimizer, batch_size=batch_size, max_iter=max_iter)

    # NOTE(review): this averages ten identical copies of TX; it looks like a
    # leftover from a stochastic test-data transform -- confirm before removing.
    TX = np.array([TX for _ in range(10)]).mean(axis=0)
    print('max iter %d' % max_iter)

    m.init_weights()

    weights = [m.parameters[layer.weights] for layer in m.mlp.layers]

    # L2 penalty over the three weight matrices, divided by the first
    # dimension of the input expression.
    weight_decay = ((weights[0]**2).sum()
                    + (weights[1]**2).sum()
                    + (weights[2]**2).sum())
    weight_decay /= m.exprs['inpt'].shape[0]
    m.exprs['true_loss'] = m.exprs['loss']
    m.exprs['loss'] = m.exprs['loss'] + c_wd * weight_decay

    # Undo the label normalisation on the network output so that the error
    # metrics are compared against the labels passed in.
    label_std = np.std(train_labels)
    label_mean = np.mean(train_labels)
    residual = (m.exprs['output'] * label_std + label_mean) - m.exprs['target']

    mae = T.abs_(residual).mean()
    f_mae = m.function(['inpt', 'target'], mae)

    rmse = T.sqrt(T.square(residual).mean())
    f_rmse = m.function(['inpt', 'target'], rmse)

    start = time.time()
    # Set up a nice printout.
    keys = '#', 'seconds', 'loss', 'val loss', 'mae_train', 'rmse_train', 'mae_test', 'rmse_test'
    header = '\t'.join(keys)
    print(header)
    print('-' * len(header))
    with open('result.txt', 'a') as results:
        results.write(header + '\n')
        results.write('-' * len(header) + '\n')
        results.write("%f %f %f %f %s" % (step_rate, momentum, decay, c_wd, opt))
        results.write('\n')

    # File names are fixed once, from the arguments this call received.
    pars_path = "pars_hp_" + opt + str(counter) + ".pkl"
    apsis_path = "apsis_pars_" + opt + str(counter) + ".pkl"
    n_iter = 0

    # Resume from an earlier checkpoint of the same trial, if there is one.
    if os.path.isfile(pars_path):
        with open(pars_path, 'rb') as tp:
            print('am here')
            n_iter, best_pars = dill.load(tp)
            m.parameters.data[...] = best_pars

    for info in m.powerfit((X, Z), (TX, TZ), stop, pause):
        if info['n_iter'] % n_report != 0:
            continue
        if math.isnan(info['loss']):
            # Diverged: record the test error and give up on this trial.
            # NOTE(review): on this path the returned iteration count does
            # not include the offset of a resumed run -- confirm intent.
            info.update({'mae_test': f_mae(TX, test_labels)})
            break
        passed = time.time() - start
        info.update({
            'time': passed,
            'mae_train': f_mae(X, train_labels),
            'rmse_train': f_rmse(X, train_labels),
            'mae_test': f_mae(TX, test_labels),
            'rmse_test': f_rmse(TX, test_labels)
        })
        # Continue counting from where a resumed run left off.
        info['n_iter'] += n_iter
        row = '%(n_iter)i\t%(time)g\t%(loss)f\t%(val_loss)f\t%(mae_train)g\t%(rmse_train)g\t%(mae_test)g\t%(rmse_test)g' % info
        print(row)
        with open('result.txt', 'a') as results:
            results.write(row + '\n')
        with open(pars_path, 'wb') as fp:
            dill.dump((info['n_iter'], info['best_pars']), fp)
        _update_apsis_record(apsis_path, info['n_iter'], info['mae_test'])

    return info['mae_test'], info['n_iter']