from model_zoo import bn_feature_net_multires_61x61 as the_model

import os
import datetime
import numpy as np

from keras.optimizers import SGD
from cnn_functions import rate_scheduler, train_model_sample

batch_size = 256
n_classes = 3
n_epoch = 25

dataset = "HeLa_all_stdnorm_61x61"
direc_save = "/home/nquach/DeepCell2/trained_networks/"
direc_data = "/home/nquach/training_data_npz/"
optimizer = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
lr_sched = rate_scheduler(lr=0.01, decay=0.95)
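# rate_scheduler (imported above) returns a per-epoch learning rate function
# that is later wrapped in Keras's LearningRateScheduler callback. A minimal
# sketch of what it is assumed to look like (simple exponential decay):
#
#     def rate_scheduler(lr=0.001, decay=0.95):
#         def output_fn(epoch):
#             return lr * (decay ** epoch)  # lr shrinks by `decay` each epoch
#         return output_fn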
expt = "bn_multires_stdnorm"

for iterate in range(5):
    # Re-instantiate the network on every iteration so that each of the five
    # replicate networks is trained from scratch rather than continuing from
    # the previous iteration's weights.
    model = the_model(n_channels=2, n_features=3, reg=1e-5)

    train_model_sample(model=model,
                       dataset=dataset,
                       optimizer=optimizer,
                       expt=expt,
                       it=iterate,
                       batch_size=batch_size,
                       n_epoch=n_epoch,
                       direc_save=direc_save,
                       direc_data=direc_data,
                       lr_sched=lr_sched,
                       rotate=True,
                       flip=True)
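
# ---------------------------------------------------------------------------
# The block below sketches what train_model_sample does internally (in
# DeepCell's cnn_functions.py): one-hot encode the labels, compile the
# network, and fit it with real-time data augmentation. Names such as
# train_dict, X_test, Y_test, nb_classes, nb_epoch, file_name_save, and
# file_name_save_loss are arguments or locals of that function; np_utils,
# ModelCheckpoint, and LearningRateScheduler come from keras.utils and
# keras.callbacks respectively.
# ---------------------------------------------------------------------------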
Y_test = np_utils.to_categorical(Y_test, nb_classes)

model = bn_feature_net_61x61(n_channels=2, reg=1e-5)

# Train the model using SGD with momentum.
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
# rmsprop = RMSprop(lr = 0.001, rho = 0.95, epsilon = 1e-8)

model.compile(loss='categorical_crossentropy',
              optimizer=sgd,
              metrics=['accuracy'])

print('Using real-time data augmentation.')

# this will do preprocessing and realtime data augmentation
datagen = ImageDataGenerator(
    rotate=True,           # randomly rotate images by 90 degrees
    shear_range=0,         # randomly shear images within (-shear_range, shear_range), in radians
    horizontal_flip=True,  # randomly flip images horizontally
    vertical_flip=True)    # randomly flip images vertically
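
# Note: ImageDataGenerator here is DeepCell's customized generator, not the
# stock Keras class (which has no `rotate` flag and no `sample_flow` method).
# sample_flow is assumed to yield batches of 61x61 windows sampled around
# annotated pixels in train_dict, paired with their one-hot labels.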

# fit the model on the batches generated by datagen.sample_flow()
loss_history = model.fit_generator(
    datagen.sample_flow(train_dict, batch_size=batch_size),
    samples_per_epoch=len(train_dict["labels"]),
    nb_epoch=nb_epoch,
    validation_data=(X_test, Y_test),
    callbacks=[
        ModelCheckpoint(file_name_save, monitor='val_loss', verbose=0,
                        save_best_only=True, mode='auto'),
        LearningRateScheduler(rate_scheduler(lr=0.01, decay=0.95))
    ])

# Note: this overwrites the best-val_loss checkpoint written by
# ModelCheckpoint above with the final-epoch weights.
model.save_weights(file_name_save)
# Save the per-epoch metrics dict (History.history), not the History object.
np.savez(file_name_save_loss, loss_history=loss_history.history)
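
# To inspect the saved training curves later (a minimal sketch; np.savez
# appends ".npz" to the filename if it is missing):
#
#     f = np.load(file_name_save_loss + ".npz")
#     history = f["loss_history"].item()  # recover the metrics dict
#     print(history["val_loss"])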