# Shared variables that hold the current chunk of data on the GPU. They are
# created with placeholder shapes here; real chunk data is loaded into them
# later with set_value.
xs_shared = [theano.shared(np.zeros((1, 1, 1, 1), dtype=theano.config.floatX)) for _ in xrange(num_input_representations)]
y_shared = theano.shared(np.zeros((1, 1), dtype=theano.config.floatX))
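
# Sketch of how a chunk would be pushed to the GPU later on; x_chunks and
# y_chunk are illustrative names standing in for the numpy arrays produced
# by the data loader:
#
#     for x_shared, x_chunk in zip(xs_shared, x_chunks):
#         x_shared.set_value(x_chunk)
#     y_shared.set_value(y_chunk)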

# Keeping the learning rate in a shared variable allows it to be annealed
# during training (following LEARNING_RATE_SCHEDULE) without recompiling the
# Theano functions.
learning_rate = theano.shared(np.array(LEARNING_RATE_SCHEDULE[0], dtype=theano.config.floatX))
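
# Sketch of how the schedule would be applied inside the training loop,
# assuming LEARNING_RATE_SCHEDULE maps epoch/chunk indices to rates
# (e is an illustrative loop variable):
#
#     if e in LEARNING_RATE_SCHEDULE:
#         learning_rate.set_value(LEARNING_RATE_SCHEDULE[e])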

# Symbolic index of the minibatch within the chunk currently on the GPU.
idx = T.lscalar('idx')

# Map the model's symbolic inputs and targets to slices of the shared
# variables: each compiled function selects its batch by index instead of
# receiving the batch data as an argument.
givens = {
    l0.input_var: xs_shared[0][idx * BATCH_SIZE:(idx + 1) * BATCH_SIZE],
    l0_45.input_var: xs_shared[1][idx * BATCH_SIZE:(idx + 1) * BATCH_SIZE],
    l6.target_var: y_shared[idx * BATCH_SIZE:(idx + 1) * BATCH_SIZE],
}
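
# Because all inputs come in through givens, the compiled functions below
# take only the integer batch index; no array data crosses the host/device
# boundary per call, which is the point of staging whole chunks on the GPU.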

# updates = layers.gen_updates(train_loss, all_parameters, learning_rate=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# Nesterov momentum updates with weight decay; the bias parameters are passed
# separately so they can be excluded from the decay term. Two variants are
# built: one for the unnormalised loss and one for the normalised loss.
updates_nonorm = layers.gen_updates_nesterov_momentum_no_bias_decay(
    train_loss_nonorm, all_parameters, all_bias_parameters,
    learning_rate=learning_rate, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
updates = layers.gen_updates_nesterov_momentum_no_bias_decay(
    train_loss, all_parameters, all_bias_parameters,
    learning_rate=learning_rate, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
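
# For reference, a sketch of the Nesterov-momentum rule such a helper
# typically implements for each (non-bias) parameter p with gradient g and
# velocity v (a shared variable initialised to zero):
#
#     g_full = g + WEIGHT_DECAY * p
#     v_new  = MOMENTUM * v - learning_rate * g_full
#     p_new  = p + MOMENTUM * v_new - learning_rate * g_full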

train_nonorm = theano.function([idx], train_loss_nonorm, givens=givens, updates=updates_nonorm)
train_norm = theano.function([idx], train_loss, givens=givens, updates=updates)
compute_loss = theano.function([idx], valid_loss, givens=givens)  # dropout_active=False
# The labels are not used when computing predictions or features, so Theano's
# unused-input warning is silenced with on_unused_input='ignore'.
compute_output = theano.function([idx], l6.predictions(dropout_active=False), givens=givens, on_unused_input='ignore')
compute_features = theano.function([idx], l4.output(dropout_active=False), givens=givens, on_unused_input='ignore')
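
# Usage sketch: once a chunk is resident on the GPU, training iterates over
# batch indices (num_batches_chunk is illustrative: the chunk size divided
# by BATCH_SIZE):
#
#     for b in xrange(num_batches_chunk):
#         loss = train_norm(b)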


print "Train model"
start_time = time.time()
prev_time = start_time

num_batches_valid = x_valid.shape[0] // BATCH_SIZE  # floor division: any partial final batch is dropped
losses_train = []