# Run configuration followed by construction of the network `m`.
# NOTE(review): `batch_size`, `X`, `Z`, `TX` and `TZ` are read below but are
# not assigned in this fragment - presumably defined earlier; confirm.
#max_iter = max_passes * X.shape[ 0] / batch_size
# Hard-coded iteration budget; the pass-based formula above is commented out.
max_iter = 95000000
# First dimension of X divided by the batch size; feeds the `pause` criterion.
n_report = X.shape[0] / batch_size

# climin criteria built from the two values above.  Presumably `stop` ends
# training after `max_iter` iterations and `pause` fires every `n_report`
# iterations - confirm against the climin documentation.
stop = climin.stops.AfterNIterations(max_iter)
pause = climin.stops.ModuloNIterations(n_report)

# Optimizer spec as an (identifier, keyword-arguments) tuple; it is handed to
# the model through its `optimizer=` argument.
optimizer = 'gd', {'step_rate': 0.001, 'momentum': 0}

# Selects which network class is built; with 'plain' only the first branch runs.
typ = 'plain'
if typ == 'plain':
    # Leading positional arguments are presumably input size, hidden-layer
    # sizes and output size - TODO confirm against the Mlp signature.
    m = Mlp(2099, [800, 800],
            15,
            X,
            Z,
            hidden_transfers=['tanh', 'tanh'],
            out_transfer='identity',
            loss='squared',
            optimizer=optimizer,
            batch_size=batch_size,
            max_iter=max_iter)
elif typ == 'fd':
    # NOTE(review): this branch passes 14 where the Mlp branch passes 15, and
    # additionally passes TX and TZ positionally - confirm both are intended.
    # NOTE(review): the call below is cut off in the visible source; its
    # remaining arguments and closing parenthesis are missing.
    m = FastDropoutNetwork(2099, [800, 800],
                           14,
                           X,
                           Z,
                           TX,
                           TZ,
                           hidden_transfers=['tanh', 'tanh'],
                           out_transfer='identity',
                           loss='squared',
                           p_dropout_inpt=.1,
# --- Run configuration -------------------------------------------------------
# Mini-batch size and hard-coded iteration budget.  A pass-based budget
# (max_passes * X.shape[0] / batch_size) was used previously and is disabled.
batch_size = 25
max_iter = 75000000
n_report = X.shape[0] / batch_size

# climin criteria derived from the two numbers above.
pause = climin.stops.ModuloNIterations(n_report)
stop = climin.stops.AfterNIterations(max_iter)

# Optimizer spec: (identifier, keyword arguments).
optimizer = ('gd', {'step_rate': 0.01, 'momentum': 0})

# --- Model construction ------------------------------------------------------
# `typ` picks the network class; both variants share layer sizes, transfer
# functions, loss and optimizer settings.
typ = 'fd'
if typ == 'fd':
    m = FastDropoutNetwork(
        X.shape[1],
        [100, 100],
        1,
        X,
        Z,
        hidden_transfers=['tanh', 'tanh'],
        out_transfer='identity',
        loss='squared',
        optimizer=optimizer,
        batch_size=batch_size,
        p_dropout_inpt=.1,
        p_dropout_hiddens=.5,
        max_iter=max_iter,
    )
elif typ == 'plain':
    m = Mlp(
        X.shape[1],
        [100, 100],
        1,
        X,
        Z,
        hidden_transfers=['tanh', 'tanh'],
        out_transfer='identity',
        loss='squared',
        optimizer=optimizer,
        batch_size=batch_size,
        max_iter=max_iter,
    )

m.init_weights()

# --- Test data ---------------------------------------------------------------
# The model-side transform (m.transformedData(TX)) is disabled; TX is stacked
# ten times and averaged over the stacking axis instead.
TX = np.mean(np.array([TX for _ in range(10)]), axis=0)
print(TX.shape)

losses = []
# climin criteria for the training loop; `max_iter` and `n_report` are taken
# from earlier in the script.  Presumably `stop` ends training after
# `max_iter` iterations and `pause` fires every `n_report` iterations -
# confirm against the climin documentation.
stop = climin.stops.AfterNIterations(max_iter)
pause = climin.stops.ModuloNIterations(n_report)


# Optimizer spec as an (identifier, keyword-arguments) tuple; it is handed to
# the model through its `optimizer=` argument.
optimizer = "gd", {"step_rate": 0.001, "momentum": 0}

# Selects which network class is built; with "plain" only the first branch runs.
typ = "plain"
if typ == "plain":
    # Leading positional arguments are presumably input size, hidden-layer
    # sizes and output size - TODO confirm against the Mlp signature.
    m = Mlp(
        2099,
        [400, 400],
        15,
        X,
        Z,
        hidden_transfers=["tanh", "tanh"],
        out_transfer="identity",
        loss="squared",
        optimizer=optimizer,
        batch_size=batch_size,
        max_iter=max_iter,
    )
elif typ == "fd":
    # NOTE(review): this branch uses [800, 800] and 14 where the Mlp branch
    # uses [400, 400] and 15 - confirm the mismatch is intended.
    # NOTE(review): the call below is cut off in the visible source; its
    # remaining arguments and closing parenthesis are missing.
    m = FastDropoutNetwork(
        2099,
        [800, 800],
        14,
        X,
        Z,
        TX,
        TZ,
def do_one_eval(X, Z, TX, TZ, test_labels, train_labels, step_rate, momentum, decay, c_wd, counter, opt):
    """Train one fixed-architecture MLP and return its test MAE.

    An ``Mlp`` is trained with the 'gd' optimizer using the supplied
    ``step_rate``, ``momentum`` and ``decay``.  The optimised loss is the
    model's own loss plus ``c_wd`` times the summed squared weights of the
    first three layers, divided by the first dimension of the input
    expression.  MAE and RMSE are computed on the network output after it has
    been rescaled with the standard deviation and mean of ``train_labels``.

    Side effects:
      * seeds NumPy's global RNG with 3453;
      * appends a header, the hyper-parameters and one row per report to
        ``result.txt``;
      * resumes from ``pars_hp_<opt><counter>.pkl`` if it exists and
        overwrites it after every report;
      * reads and rewrites ``apsis_pars_<opt><counter>.pkl`` after every
        report (the file must already exist, otherwise ``open`` raises).

    Args:
        X, Z: training data handed to ``Mlp`` and used as fit data.
        TX, TZ: evaluation data handed to ``powerfit``; ``TX`` is also the
            input for the test metrics.
        test_labels, train_labels: targets for the test / train metrics;
            ``train_labels`` also provides the rescaling statistics.
        step_rate, momentum, decay: optimizer settings.
        c_wd: coefficient of the weight-decay term.
        counter, opt: combined as ``opt + str(counter)`` to name the
            checkpoint files, so ``opt`` must be a string.

    Returns:
        ``(mae_test, n_iter)`` taken from the last processed report.  If the
        loss became NaN, training stops early and ``n_iter`` is the raw
        iteration count of that report (without the resume offset).
    """
    np.random.seed(3453)
    batch_size = 25
    max_iter = 25000000
    # Floor division: same value as `/` on Python 2 ints, and still an int on
    # Python 3, which the modulo checks below rely on.
    n_report = X.shape[0] // batch_size
    optimizer = 'gd', {'step_rate': step_rate, 'momentum': momentum, 'decay': decay}

    stop = climin.stops.AfterNIterations(max_iter)
    pause = climin.stops.ModuloNIterations(n_report)
    # This defines our NN. Since BayOpt does not support categorical data, we
    # just use a fixed hidden layer length and transfer functions.
    m = Mlp(2100, [400, 100], 1, X, Z, hidden_transfers=['tanh', 'tanh'], out_transfer='identity', loss='squared',
            optimizer=optimizer, batch_size=batch_size, max_iter=max_iter)

    # NOTE(review): all ten stacked copies are the same array; this looks like
    # a leftover from the disabled `m.transformedData(TX)` variant.  Kept
    # unchanged so the numerical result stays exactly the same.
    TX = np.array([TX for _ in range(10)]).mean(axis=0)
    print('max iter %s' % max_iter)

    m.init_weights()

    # L2 penalty over the first three weight matrices, added to the loss that
    # is optimised; the unpenalised loss stays available as 'true_loss'.
    weights = [m.parameters[layer.weights] for layer in m.mlp.layers]
    weight_decay = ((weights[0] ** 2).sum()
                    + (weights[1] ** 2).sum()
                    + (weights[2] ** 2).sum())
    weight_decay /= m.exprs['inpt'].shape[0]
    m.exprs['true_loss'] = m.exprs['loss']
    m.exprs['loss'] = m.exprs['loss'] + c_wd * weight_decay

    # Metrics compare the rescaled network output with the given targets.
    label_std = np.std(train_labels)
    label_mean = np.mean(train_labels)
    rescaled_output = m.exprs['output'] * label_std + label_mean
    mae = T.abs_(rescaled_output - m.exprs['target']).mean()
    f_mae = m.function(['inpt', 'target'], mae)
    rmse = T.sqrt(T.square(rescaled_output - m.exprs['target']).mean())
    f_rmse = m.function(['inpt', 'target'], rmse)

    start = time.time()
    # Set up a nice printout.
    keys = '#', 'seconds', 'loss', 'val loss', 'mae_train', 'rmse_train', 'mae_test', 'rmse_test'
    header = '\t'.join(keys)
    rule = '-' * len(header)
    print(header)
    print(rule)
    with open('result.txt', 'a') as results:
        results.write(header + '\n')
        results.write(rule + '\n')
        results.write("%f %f %f %f %s" % (step_rate, momentum, decay, c_wd, opt))
        results.write('\n')

    # Build both checkpoint names once from the arguments, so nothing loaded
    # from disk later can redirect where checkpoints are written.
    run_tag = opt + str(counter)
    pars_path = "pars_hp_" + run_tag + ".pkl"
    apsis_path = "apsis_pars_" + run_tag + ".pkl"

    # Resume: restore the parameters and remember how many iterations the
    # earlier run had already completed.
    n_iter = 0
    if os.path.isfile(pars_path):
        with open(pars_path, 'rb') as tp:
            print('am here')
            n_iter, best_pars = dill.load(tp)
            m.parameters.data[...] = best_pars

    for info in m.powerfit((X, Z), (TX, TZ), stop, pause):
        if info['n_iter'] % n_report != 0:
            continue
        passed = time.time() - start
        if math.isnan(info['loss']):
            # Training diverged: record the test MAE and stop.
            # NOTE(review): 'n_iter' is returned without the resume offset on
            # this path - confirm that is intended.
            info['mae_test'] = f_mae(TX, test_labels)
            break
        info.update({
            'time': passed,
            'mae_train': f_mae(X, train_labels),
            'rmse_train': f_rmse(X, train_labels),
            'mae_test': f_mae(TX, test_labels),
            'rmse_test': f_rmse(TX, test_labels),
        })
        info['n_iter'] += n_iter
        row = '%(n_iter)i\t%(time)g\t%(loss)f\t%(val_loss)f\t%(mae_train)g\t%(rmse_train)g\t%(mae_test)g\t%(rmse_test)g' % info
        with open('result.txt', 'a') as results:
            print(row)
            results.write(row + '\n')
        with open(pars_path, 'wb') as fp:
            dill.dump((info['n_iter'], info['best_pars']), fp)

        # Refresh the progress fields (last two entries) of the apsis state
        # file.  The stored values are unpacked into their own names instead
        # of overwriting this function's parameters.
        with open(apsis_path, 'rb') as fp:
            (LAss, saved_opt, saved_step_rate, saved_momentum, saved_decay,
             saved_c_wd, saved_counter, _, _) = dill.load(fp)
        with open(apsis_path, 'wb') as fp:
            dill.dump((LAss, saved_opt, saved_step_rate, saved_momentum,
                       saved_decay, saved_c_wd, saved_counter,
                       info['n_iter'], info['mae_test']), fp)

    return info['mae_test'], info['n_iter']