Example #1
def cross_validate_inmemory(model_name, **kwargs):
    """
    StateFarm competition:
    Training set has 26 unique drivers. We do 26-fold CV where
    each driver is in turn singled out to form the validation set.

    Load the whole training set in memory for faster operations.

    args: model_name (str) name of the Keras model to load
          **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    nb_classes = kwargs["nb_classes"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    prob = kwargs["prob"]
    do_plot = kwargs["do_plot"]
    data_file = kwargs["data_file"]
    semi_super_file = kwargs["semi_super_file"]
    pretr_weights_file = kwargs["pretr_weights_file"]
    normalisation_style = kwargs["normalisation_style"]
    weak_labels = kwargs["weak_labels"]
    objective = kwargs["objective"]
    experiment = kwargs["experiment"]
    start_fold = kwargs["start_fold"]

    # Load env variables (from the .env file at the root of the project)
    load_dotenv(find_dotenv())

    # Load env variables
    model_dir = os.path.expanduser(os.environ.get("MODEL_DIR"))
    data_dir = os.path.expanduser(os.environ.get("DATA_DIR"))

    # Output path where we store experiment log and weights
    model_dir = os.path.join(model_dir, model_name)
    # Create if it does not exist
    general_utils.create_dir(model_dir)
    # Automatically determine experiment name
    list_exp = glob.glob(model_dir + "/*")
    # Create the experiment dir and weights dir
    if experiment:
        exp_dir = os.path.join(model_dir, experiment)
    else:
        exp_dir = os.path.join(model_dir, "Experiment_%s" % len(list_exp))
    general_utils.create_dir(exp_dir)

    # Optimizer (the model is compiled separately for each fold below)
    # opt = RMSprop(lr=5E-6, rho=0.9, epsilon=1e-06)
    opt = SGD(lr=5e-4, decay=1e-6, momentum=0.9, nesterov=True)
    # opt = Adam(lr=1E-5, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # Batch generator
    DataAug = batch_utils.AugDataGenerator(data_file,
                                           batch_size=batch_size,
                                           prob=prob,
                                           dset="train",
                                           maxproc=4,
                                           num_cached=60,
                                           random_augm=False,
                                           hdf5_file_semi=semi_super_file)
    DataAug.add_transform("h_flip")
    # DataAug.add_transform("v_flip")
    # DataAug.add_transform("fixed_rot", angle=40)
    DataAug.add_transform("random_rot", angle=40)
    # DataAug.add_transform("fixed_tr", tr_x=40, tr_y=40)
    DataAug.add_transform("random_tr", tr_x=40, tr_y=40)
    # DataAug.add_transform("fixed_blur", kernel_size=5)
    DataAug.add_transform("random_blur", kernel_size=5)
    # DataAug.add_transform("fixed_erode", kernel_size=4)
    DataAug.add_transform("random_erode", kernel_size=3)
    # DataAug.add_transform("fixed_dilate", kernel_size=4)
    DataAug.add_transform("random_dilate", kernel_size=3)
    # DataAug.add_transform("fixed_crop", pos_x=10, pos_y=10, crop_size_x=200, crop_size_y=200)
    DataAug.add_transform("random_crop", min_crop_size=140, max_crop_size=160)
    # DataAug.add_transform("hist_equal")
    # DataAug.add_transform("random_occlusion", occ_size_x=100, occ_size_y=100)

    epoch_size = n_batch_per_epoch * batch_size

    general_utils.pretty_print("Load all data...")

    with h5py.File(data_file, "r") as hf:
        X = hf["train_data"][:, :, :, :]
        y = hf["train_label"][:].astype(np.uint8)
        y = np_utils.to_categorical(y, nb_classes=nb_classes)  # Format for keras

        try:
            for fold in range(start_fold, 8):
                # for fold in np.random.permutation(26):

                min_valid_loss = 100

                # Save losses
                list_train_loss = []
                list_valid_loss = []

                # Load valid data in memory for fast error evaluation
                idx_valid = hf["valid_fold%s" % fold][:]
                idx_train = hf["train_fold%s" % fold][:]
                X_valid = X[idx_valid]
                y_valid = y[idx_valid]

                # Normalise
                X_valid = normalisation(X_valid, normalisation_style)

                # Compile model
                general_utils.pretty_print("Compiling...")
                model = models.load(model_name,
                                    nb_classes,
                                    X_valid.shape[-3:],
                                    pretr_weights_file=pretr_weights_file)
                model.compile(optimizer=opt, loss=objective)

                # Save architecture
                json_string = model.to_json()
                with open(os.path.join(data_dir, '%s_archi.json' % model.name), 'w') as f:
                    f.write(json_string)

                for e in range(nb_epoch):
                    # Initialize progbar and batch counter
                    progbar = generic_utils.Progbar(epoch_size)
                    batch_counter = 1
                    l_train_loss = []
                    start = time.time()

                    for X_train, y_train in DataAug.gen_batch_inmemory(X, y, idx_train=idx_train):
                        if do_plot:
                            general_utils.plot_batch(X_train, np.argmax(y_train, 1), batch_size)

                        # Normalise
                        X_train = normalisation(X_train, normalisation_style)

                        train_loss = model.train_on_batch(X_train, y_train)
                        l_train_loss.append(train_loss)
                        batch_counter += 1
                        progbar.add(batch_size, values=[("train loss", train_loss)])
                        if batch_counter >= n_batch_per_epoch:
                            break
                    print("")
                    print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))
                    y_valid_pred = model.predict(X_valid, verbose=0, batch_size=16)
                    train_loss = float(np.mean(l_train_loss))  # use float to make it json saveable
                    valid_loss = log_loss(y_valid, y_valid_pred)
                    print("Train loss:", train_loss, "valid loss:", valid_loss)
                    list_train_loss.append(train_loss)
                    list_valid_loss.append(valid_loss)

                    # Record experimental data in a dict
                    d_log = {}
                    d_log["fold"] = fold
                    d_log["nb_classes"] = nb_classes
                    d_log["batch_size"] = batch_size
                    d_log["n_batch_per_epoch"] = n_batch_per_epoch
                    d_log["nb_epoch"] = nb_epoch
                    d_log["epoch_size"] = epoch_size
                    d_log["prob"] = prob
                    d_log["optimizer"] = opt.get_config()
                    d_log["augmentator_config"] = DataAug.get_config()
                    d_log["train_loss"] = list_train_loss
                    d_log["valid_loss"] = list_valid_loss

                    json_file = os.path.join(exp_dir, 'experiment_log_fold%s.json' % fold)
                    general_utils.save_exp_log(json_file, d_log)

                    # Only save the best epoch
                    if valid_loss < min_valid_loss:
                        min_valid_loss = valid_loss
                        trained_weights_path = os.path.join(exp_dir, '%s_weights_fold%s.h5' % (model.name, fold))
                        model.save_weights(trained_weights_path, overwrite=True)

        except KeyboardInterrupt:
            pass
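
A minimal invocation sketch for the function above, assembling the keyword arguments it unpacks. All key names are taken from the code; the model name, paths and values below are placeholders, not the project's actual settings.

# Hypothetical usage sketch -- values and paths are illustrative only.
params = {
    "nb_classes": 10,
    "batch_size": 32,
    "n_batch_per_epoch": 100,
    "nb_epoch": 20,
    "prob": 0.5,
    "do_plot": False,
    "data_file": "train_data.h5",          # HDF5 file with train_data/train_label and fold index datasets
    "semi_super_file": None,
    "pretr_weights_file": None,
    "normalisation_style": "standard",     # placeholder style name
    "weak_labels": False,
    "objective": "categorical_crossentropy",
    "experiment": "baseline_cv",
    "start_fold": 0,
}
cross_validate_inmemory("VGG16", **params)  # "VGG16" is a placeholder model name
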
Example #2
def cross_validate_inmemory(model_name, **kwargs):
    """
    StateFarm competition:
    Training set has 26 unique drivers. We do 26-fold CV where
    each driver is in turn singled out to form the validation set.

    Load the whole training set in memory for faster operations.

    args: model_name (str) name of the Keras model to load
          **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    nb_classes = kwargs["nb_classes"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    prob = kwargs["prob"]
    do_plot = kwargs["do_plot"]
    data_file = kwargs["data_file"]
    semi_super_file = kwargs["semi_super_file"]
    pretr_weights_file = kwargs["pretr_weights_file"]
    normalisation_style = kwargs["normalisation_style"]
    weak_labels = kwargs["weak_labels"]
    objective = kwargs["objective"]
    experiment = kwargs["experiment"]
    start_fold = kwargs["start_fold"]

    # Load env variables (from the .env file at the root of the project)
    load_dotenv(find_dotenv())

    # Load env variables
    model_dir = os.path.expanduser(os.environ.get("MODEL_DIR"))
    data_dir = os.path.expanduser(os.environ.get("DATA_DIR"))

    # Output path where we store experiment log and weights
    model_dir = os.path.join(model_dir, model_name)
    # Create if it does not exist
    general_utils.create_dir(model_dir)
    # Automatically determine experiment name
    list_exp = glob.glob(model_dir + "/*")
    # Create the experiment dir and weights dir
    if experiment:
        exp_dir = os.path.join(model_dir, experiment)
    else:
        exp_dir = os.path.join(model_dir, "Experiment_%s" % len(list_exp))
    general_utils.create_dir(exp_dir)

    # Optimizer (the model is compiled separately for each fold below)
    # opt = RMSprop(lr=5E-6, rho=0.9, epsilon=1e-06)
    opt = SGD(lr=5e-4, decay=1e-6, momentum=0.9, nesterov=True)
    # opt = Adam(lr=1E-5, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # Batch generator
    DataAug = batch_utils.AugDataGenerator(data_file,
                                           batch_size=batch_size,
                                           prob=prob,
                                           dset="train",
                                           maxproc=4,
                                           num_cached=60,
                                           random_augm=False,
                                           hdf5_file_semi=semi_super_file)
    DataAug.add_transform("h_flip")
    # DataAug.add_transform("v_flip")
    # DataAug.add_transform("fixed_rot", angle=40)
    DataAug.add_transform("random_rot", angle=40)
    # DataAug.add_transform("fixed_tr", tr_x=40, tr_y=40)
    DataAug.add_transform("random_tr", tr_x=40, tr_y=40)
    # DataAug.add_transform("fixed_blur", kernel_size=5)
    DataAug.add_transform("random_blur", kernel_size=5)
    # DataAug.add_transform("fixed_erode", kernel_size=4)
    DataAug.add_transform("random_erode", kernel_size=3)
    # DataAug.add_transform("fixed_dilate", kernel_size=4)
    DataAug.add_transform("random_dilate", kernel_size=3)
    # DataAug.add_transform("fixed_crop", pos_x=10, pos_y=10, crop_size_x=200, crop_size_y=200)
    DataAug.add_transform("random_crop", min_crop_size=140, max_crop_size=160)
    # DataAug.add_transform("hist_equal")
    # DataAug.add_transform("random_occlusion", occ_size_x=100, occ_size_y=100)

    epoch_size = n_batch_per_epoch * batch_size

    general_utils.pretty_print("Load all data...")

    with h5py.File(data_file, "r") as hf:
        X = hf["train_data"][:, :, :, :]
        y = hf["train_label"][:].astype(np.uint8)
        y = np_utils.to_categorical(y,
                                    nb_classes=nb_classes)  # Format for keras

        try:
            for fold in range(start_fold, 8):
                # for fold in np.random.permutation(26):

                min_valid_loss = 100

                # Save losses
                list_train_loss = []
                list_valid_loss = []

                # Load valid data in memory for fast error evaluation
                idx_valid = hf["valid_fold%s" % fold][:]
                idx_train = hf["train_fold%s" % fold][:]
                X_valid = X[idx_valid]
                y_valid = y[idx_valid]

                # Normalise
                X_valid = normalisation(X_valid, normalisation_style)

                # Compile model
                general_utils.pretty_print("Compiling...")
                model = models.load(model_name,
                                    nb_classes,
                                    X_valid.shape[-3:],
                                    pretr_weights_file=pretr_weights_file)
                model.compile(optimizer=opt, loss=objective)

                # Save architecture
                json_string = model.to_json()
                with open(os.path.join(data_dir, '%s_archi.json' % model.name),
                          'w') as f:
                    f.write(json_string)

                for e in range(nb_epoch):
                    # Initialize progbar and batch counter
                    progbar = generic_utils.Progbar(epoch_size)
                    batch_counter = 1
                    l_train_loss = []
                    start = time.time()

                    for X_train, y_train in DataAug.gen_batch_inmemory(
                            X, y, idx_train=idx_train):
                        if do_plot:
                            general_utils.plot_batch(X_train,
                                                     np.argmax(y_train, 1),
                                                     batch_size)

                        # Normalise
                        X_train = normalisation(X_train, normalisation_style)

                        train_loss = model.train_on_batch(X_train, y_train)
                        l_train_loss.append(train_loss)
                        batch_counter += 1
                        progbar.add(batch_size,
                                    values=[("train loss", train_loss)])
                        if batch_counter >= n_batch_per_epoch:
                            break
                    print("")
                    print('Epoch %s/%s, Time: %s' %
                          (e + 1, nb_epoch, time.time() - start))
                    y_valid_pred = model.predict(X_valid,
                                                 verbose=0,
                                                 batch_size=16)
                    train_loss = float(np.mean(
                        l_train_loss))  # use float to make it json saveable
                    valid_loss = log_loss(y_valid, y_valid_pred)
                    print("Train loss:", train_loss, "valid loss:", valid_loss)
                    list_train_loss.append(train_loss)
                    list_valid_loss.append(valid_loss)

                    # Record experimental data in a dict
                    d_log = {}
                    d_log["fold"] = fold
                    d_log["nb_classes"] = nb_classes
                    d_log["batch_size"] = batch_size
                    d_log["n_batch_per_epoch"] = n_batch_per_epoch
                    d_log["nb_epoch"] = nb_epoch
                    d_log["epoch_size"] = epoch_size
                    d_log["prob"] = prob
                    d_log["optimizer"] = opt.get_config()
                    d_log["augmentator_config"] = DataAug.get_config()
                    d_log["train_loss"] = list_train_loss
                    d_log["valid_loss"] = list_valid_loss

                    json_file = os.path.join(
                        exp_dir, 'experiment_log_fold%s.json' % fold)
                    general_utils.save_exp_log(json_file, d_log)

                    # Only save the best epoch
                    if valid_loss < min_valid_loss:
                        min_valid_loss = valid_loss
                        trained_weights_path = os.path.join(
                            exp_dir,
                            '%s_weights_fold%s.h5' % (model.name, fold))
                        model.save_weights(trained_weights_path,
                                           overwrite=True)

        except KeyboardInterrupt:
            pass
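
The normalisation helper called in the two examples above is not shown. A minimal sketch of what it might look like, assuming normalisation_style simply selects between rescaling to [0, 1] and per-sample standardisation (both the style names and the logic are assumptions, not the project's actual implementation):

import numpy as np

def normalisation(X, normalisation_style):
    """Hypothetical sketch of the helper used above: rescale or standardise a batch of images."""
    X = X.astype(np.float32)
    if normalisation_style == "rescale":
        # Map pixel values from [0, 255] to [0, 1]
        return X / 255.0
    elif normalisation_style == "standard":
        # Zero mean, unit variance per sample
        mean = X.mean(axis=(1, 2, 3), keepdims=True)
        std = X.std(axis=(1, 2, 3), keepdims=True) + 1e-8
        return (X - mean) / std
    raise ValueError("Unknown normalisation_style: %s" % normalisation_style)
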
Example #3
def cross_validate_inmemory(**kwargs):
    """
    StateFarm competition:
    Training set has 26 unique drivers. We do 26-fold CV where
    each driver is in turn singled out to form the validation set.

    Load the whole training set in memory for faster operations.

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    nb_classes = kwargs["nb_classes"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    prob = kwargs["prob"]
    data_file = kwargs["data_file"]
    semi_super_file = kwargs["semi_super_file"]
    list_folds = kwargs["list_folds"]
    weak_labels = kwargs["weak_labels"]
    experiment = kwargs["experiment"]

    # Load env variables (from the .env file at the root of the project)
    load_dotenv(find_dotenv())

    # Load env variables
    model_dir = os.path.expanduser(os.environ.get("MODEL_DIR"))
    data_dir = os.path.expanduser(os.environ.get("DATA_DIR"))

    mean_values = np.load("../../data/external/resnet_mean_values.npy")

    # Output path where we store experiment log and weights
    model_dir = os.path.join(model_dir, "ResNet")
    # Create if it does not exist
    general_utils.create_dir(model_dir)
    # Automatically determine experiment name
    list_exp = glob.glob(model_dir + "/*")
    # Create the experiment dir and weights dir
    if experiment:
        exp_dir = os.path.join(model_dir, experiment)
    else:
        exp_dir = os.path.join(model_dir, "Experiment_%s" % len(list_exp))
    general_utils.create_dir(exp_dir)

    # Batch generator
    DataAug = batch_utils.AugDataGenerator(data_file,
                                           batch_size=batch_size,
                                           prob=prob,
                                           dset="train",
                                           maxproc=4,
                                           num_cached=60,
                                           random_augm=False,
                                           hdf5_file_semi=semi_super_file)
    DataAug.add_transform("h_flip")
    DataAug.add_transform("random_rot", angle=40)
    DataAug.add_transform("random_tr", tr_x=40, tr_y=40)
    DataAug.add_transform("random_blur", kernel_size=5)
    DataAug.add_transform("random_crop", min_crop_size=140, max_crop_size=160)

    epoch_size = n_batch_per_epoch * batch_size

    general_utils.pretty_print("Load all data...")

    with h5py.File(data_file, "r") as hf:
        X = hf["train_data"][:, :, :, :]
        y = hf["train_label"][:].astype(np.int32)

        try:
            for fold in list_folds:

                min_valid_loss = 100

                # Save losses
                list_train_loss = []
                list_valid_loss = []

                # Load valid data in memory for fast error evaluation
                idx_valid = hf["valid_fold%s" % fold][:]
                idx_train = hf["train_fold%s" % fold][:]
                X_valid = X[idx_valid]
                y_valid = y[idx_valid]

                # Normalise: reverse the channel order and subtract the ResNet mean image
                X_valid = X_valid[:, ::-1, :, :]
                X_valid = (X_valid - mean_values).astype(np.float32)

                # Define model
                input_var = T.tensor4('inputs')
                target_var = T.matrix('targets')

                network = build_model(input_var, usage="train")

                prediction = lasagne.layers.get_output(network)
                loss = lasagne.objectives.categorical_crossentropy(
                    prediction, target_var)
                loss = loss.mean()

                params = lasagne.layers.get_all_params(network, trainable=True)

                updates = lasagne.updates.nesterov_momentum(loss,
                                                            params,
                                                            learning_rate=5E-4,
                                                            momentum=0.9)
                train_fn = theano.function([input_var, target_var],
                                           loss,
                                           updates=updates)

                test_prediction = lasagne.layers.get_output(network,
                                                            deterministic=True)
                test_loss = lasagne.objectives.categorical_crossentropy(
                    test_prediction, target_var)
                test_loss = test_loss.mean()

                val_fn = theano.function([input_var, target_var], test_loss)

                # Loop over epochs
                for e in range(nb_epoch):
                    # Initialize progbar and batch counter
                    progbar = generic_utils.Progbar(epoch_size)
                    batch_counter = 1
                    l_train_loss = []
                    l_valid_loss = []
                    start = time.time()

                    for X_train, y_train in DataAug.gen_batch_inmemory(
                            X, y, idx_train=idx_train):
                        if True:  # debug flag: plots every augmented batch
                            general_utils.plot_batch(X_train,
                                                     np.argmax(y_train, 1),
                                                     batch_size)

                        # Normalise: reverse the channel order and subtract the mean image
                        X_train = X_train[:, ::-1, :, :]
                        X_train = (X_train - mean_values).astype(np.float32)
                        # Train
                        train_loss = train_fn(X_train,
                                              y_train.astype(np.float32))

                        l_train_loss.append(train_loss)
                        batch_counter += 1
                        progbar.add(batch_size,
                                    values=[("train loss", train_loss)])
                        if batch_counter >= n_batch_per_epoch:
                            break
                    print("")
                    print('Epoch %s/%s, Time: %s' %
                          (e + 1, nb_epoch, time.time() - start))

                    # Split image list into num_chunks chunks
                    chunk_size = batch_size
                    num_imgs = X_valid.shape[0]
                    num_chunks = num_imgs // chunk_size  # integer division so np.array_split gets an int
                    list_chunks = np.array_split(np.arange(num_imgs),
                                                 num_chunks)

                    # Loop over chunks
                    for chunk_idx in list_chunks:
                        X_b, y_b = X_valid[chunk_idx].astype(
                            np.float32), y_valid[chunk_idx]
                        y_b = np_utils.to_categorical(
                            y_b, nb_classes=nb_classes).astype(np.float32)
                        valid_loss = val_fn(X_b, y_b)
                        l_valid_loss.append(valid_loss)

                    train_loss = float(np.mean(
                        l_train_loss))  # use float to make it json saveable
                    valid_loss = float(np.mean(
                        l_valid_loss))  # use float to make it json saveable
                    print("Train loss:", train_loss, "valid loss:", valid_loss)
                    list_train_loss.append(train_loss)
                    list_valid_loss.append(valid_loss)

                    # Record experimental data in a dict
                    d_log = {}
                    d_log["fold"] = fold
                    d_log["nb_classes"] = nb_classes
                    d_log["batch_size"] = batch_size
                    d_log["n_batch_per_epoch"] = n_batch_per_epoch
                    d_log["nb_epoch"] = nb_epoch
                    d_log["epoch_size"] = epoch_size
                    d_log["prob"] = prob
                    d_log["augmentator_config"] = DataAug.get_config()
                    d_log["train_loss"] = list_train_loss
                    d_log["valid_loss"] = list_valid_loss

                    json_file = os.path.join(
                        exp_dir, 'experiment_log_fold%s.json' % fold)
                    general_utils.save_exp_log(json_file, d_log)

                    # Only save the best epoch
                    if valid_loss < min_valid_loss:
                        min_valid_loss = valid_loss
                        trained_weights_path = os.path.join(
                            exp_dir, 'resnet_weights_fold%s.pickle' % fold)
                        model = {
                            'values':
                            lasagne.layers.get_all_param_values(network),
                            'mean_image': mean_values
                        }
                        pickle.dump(model,
                                    open(trained_weights_path, 'wb'),
                                    protocol=-1)

        except KeyboardInterrupt:
            pass
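
The best weights above are pickled as a dict holding the network parameter values and the mean image. A sketch of how they could be restored into a freshly built network, using standard Lasagne calls (the file path is a placeholder):

import pickle
import theano.tensor as T
import lasagne

# Hypothetical reload sketch: rebuild the graph, then push the saved values back in.
input_var = T.tensor4('inputs')
network = build_model(input_var, usage="train")       # same builder as in the example above

with open("resnet_weights_fold0.pickle", "rb") as f:  # placeholder path
    saved = pickle.load(f)

lasagne.layers.set_all_param_values(network, saved["values"])
mean_values = saved["mean_image"]                     # reuse the stored mean image for preprocessing
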
Example #4
def cross_validate_inmemory(**kwargs):
    """
    StateFarm competition:
    Training set has 26 unique drivers. We do 26-fold CV where
    each driver is in turn singled out to form the validation set.

    Load the whole training set in memory for faster operations.

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    nb_classes = kwargs["nb_classes"]
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    prob = kwargs["prob"]
    data_file = kwargs["data_file"]
    semi_super_file = kwargs["semi_super_file"]
    list_folds = kwargs["list_folds"]
    weak_labels = kwargs["weak_labels"]
    experiment = kwargs["experiment"]

    # Load env variables (from the .env file at the root of the project)
    load_dotenv(find_dotenv())

    # Load env variables
    model_dir = os.path.expanduser(os.environ.get("MODEL_DIR"))
    data_dir = os.path.expanduser(os.environ.get("DATA_DIR"))

    mean_values = np.load("../../data/external/resnet_mean_values.npy")

    # Output path where we store experiment log and weights
    model_dir = os.path.join(model_dir, "ResNet")
    # Create if it does not exist
    general_utils.create_dir(model_dir)
    # Automatically determine experiment name
    list_exp = glob.glob(model_dir + "/*")
    # Create the experiment dir and weights dir
    if experiment:
        exp_dir = os.path.join(model_dir, experiment)
    else:
        exp_dir = os.path.join(model_dir, "Experiment_%s" % len(list_exp))
    general_utils.create_dir(exp_dir)

    # Batch generator
    DataAug = batch_utils.AugDataGenerator(data_file,
                                           batch_size=batch_size,
                                           prob=prob,
                                           dset="train",
                                           maxproc=4,
                                           num_cached=60,
                                           random_augm=False,
                                           hdf5_file_semi=semi_super_file)
    DataAug.add_transform("h_flip")
    DataAug.add_transform("random_rot", angle=40)
    DataAug.add_transform("random_tr", tr_x=40, tr_y=40)
    DataAug.add_transform("random_blur", kernel_size=5)
    DataAug.add_transform("random_crop", min_crop_size=140, max_crop_size=160)

    epoch_size = n_batch_per_epoch * batch_size

    general_utils.pretty_print("Load all data...")

    with h5py.File(data_file, "r") as hf:
        X = hf["train_data"][:, :, :, :]
        y = hf["train_label"][:].astype(np.int32)

        try:
            for fold in list_folds:

                min_valid_loss = 100

                # Save losses
                list_train_loss = []
                list_valid_loss = []

                # Load valid data in memory for fast error evaluation
                idx_valid = hf["valid_fold%s" % fold][:]
                idx_train = hf["train_fold%s" % fold][:]
                X_valid = X[idx_valid]
                y_valid = y[idx_valid]

                # Normalise: reverse the channel order and subtract the ResNet mean image
                X_valid = X_valid[:, ::-1, :, :]
                X_valid = (X_valid - mean_values).astype(np.float32)

                # Define model
                input_var = T.tensor4('inputs')
                target_var = T.matrix('targets')

                network = build_model(input_var, usage="train")

                prediction = lasagne.layers.get_output(network)
                loss = lasagne.objectives.categorical_crossentropy(prediction, target_var)
                loss = loss.mean()

                params = lasagne.layers.get_all_params(network, trainable=True)

                updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=5E-4, momentum=0.9)
                train_fn = theano.function([input_var, target_var], loss, updates=updates)

                test_prediction = lasagne.layers.get_output(network, deterministic=True)
                test_loss = lasagne.objectives.categorical_crossentropy(test_prediction, target_var)
                test_loss = test_loss.mean()

                val_fn = theano.function([input_var, target_var], test_loss)

                # Loop over epochs
                for e in range(nb_epoch):
                    # Initialize progbar and batch counter
                    progbar = generic_utils.Progbar(epoch_size)
                    batch_counter = 1
                    l_train_loss = []
                    l_valid_loss = []
                    start = time.time()

                    for X_train, y_train in DataAug.gen_batch_inmemory(X, y, idx_train=idx_train):
                        if True:  # debug flag: plots every augmented batch
                            general_utils.plot_batch(X_train, np.argmax(y_train, 1), batch_size)

                        # Normalise: reverse the channel order and subtract the mean image
                        X_train = X_train[:, ::-1, :, :]
                        X_train = (X_train - mean_values).astype(np.float32)
                        # Train
                        train_loss = train_fn(X_train, y_train.astype(np.float32))

                        l_train_loss.append(train_loss)
                        batch_counter += 1
                        progbar.add(batch_size, values=[("train loss", train_loss)])
                        if batch_counter >= n_batch_per_epoch:
                            break
                    print("")
                    print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

                    # Split image list into num_chunks chunks
                    chunk_size = batch_size
                    num_imgs = X_valid.shape[0]
                    num_chunks = num_imgs // chunk_size  # integer division so np.array_split gets an int
                    list_chunks = np.array_split(np.arange(num_imgs), num_chunks)

                    # Loop over chunks
                    for chunk_idx in list_chunks:
                        X_b, y_b = X_valid[chunk_idx].astype(np.float32), y_valid[chunk_idx]
                        y_b = np_utils.to_categorical(y_b, nb_classes=nb_classes).astype(np.float32)
                        valid_loss = val_fn(X_b, y_b)
                        l_valid_loss.append(valid_loss)

                    train_loss = float(np.mean(l_train_loss))  # use float to make it json saveable
                    valid_loss = float(np.mean(l_valid_loss))  # use float to make it json saveable
                    print("Train loss:", train_loss, "valid loss:", valid_loss)
                    list_train_loss.append(train_loss)
                    list_valid_loss.append(valid_loss)

                    # Record experimental data in a dict
                    d_log = {}
                    d_log["fold"] = fold
                    d_log["nb_classes"] = nb_classes
                    d_log["batch_size"] = batch_size
                    d_log["n_batch_per_epoch"] = n_batch_per_epoch
                    d_log["nb_epoch"] = nb_epoch
                    d_log["epoch_size"] = epoch_size
                    d_log["prob"] = prob
                    d_log["augmentator_config"] = DataAug.get_config()
                    d_log["train_loss"] = list_train_loss
                    d_log["valid_loss"] = list_valid_loss

                    json_file = os.path.join(exp_dir, 'experiment_log_fold%s.json' % fold)
                    general_utils.save_exp_log(json_file, d_log)

                    # Only save the best epoch
                    if valid_loss < min_valid_loss:
                        min_valid_loss = valid_loss
                        trained_weights_path = os.path.join(exp_dir, 'resnet_weights_fold%s.pickle' % fold)
                        model = {
                            'values': lasagne.layers.get_all_param_values(network),
                            'mean_image': mean_values
                        }
                        pickle.dump(model, open(trained_weights_path, 'wb'), protocol=-1)

        except KeyboardInterrupt:
            pass
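
The validation pass above splits the image indices into roughly batch-sized chunks with np.array_split. A tiny standalone illustration of that chunking (the sizes are arbitrary):

import numpy as np

# 10 validation images split into 10 // 4 = 2 chunks
num_imgs, chunk_size = 10, 4
num_chunks = num_imgs // chunk_size
list_chunks = np.array_split(np.arange(num_imgs), num_chunks)
print([c.tolist() for c in list_chunks])  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
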
def train(**kwargs):
    """
    Train model

    args: **kwargs (dict) keyword arguments that specify the model hyperparameters
    """

    # Roll out the parameters
    batch_size = kwargs["batch_size"]
    n_batch_per_epoch = kwargs["n_batch_per_epoch"]
    nb_epoch = kwargs["nb_epoch"]
    data_file = kwargs["data_file"]
    nb_neighbors = kwargs["nb_neighbors"]
    model_name = kwargs["model_name"]
    training_mode = kwargs["training_mode"]
    epoch_size = n_batch_per_epoch * batch_size
    img_size = int(os.path.basename(data_file).split("_")[1])

    # Setup directories to save model, architecture etc
    general_utils.setup_logging(model_name)

    # Create a batch generator for the color data
    DataGen = batch_utils.DataGenerator(data_file,
                                        batch_size=batch_size,
                                        dset="training")
    c, h, w = DataGen.get_config()["data_shape"][1:]

    # Load the array of quantized ab value
    q_ab = np.load("../../data/processed/pts_in_hull.npy")
    nb_q = q_ab.shape[0]
    # Fit a NN to q_ab
    nn_finder = nn.NearestNeighbors(n_neighbors=nb_neighbors, algorithm='ball_tree').fit(q_ab)

    # Load the color prior factor that encourages rare colors
    prior_factor = np.load("../../data/processed/CelebA_%s_prior_factor.npy" % img_size)

    # Load and rescale data
    if training_mode == "in_memory":
        with h5py.File(data_file, "r") as hf:
            X_train = hf["training_lab_data"][:]

    # Remove possible previous figures to avoid confusion
    for f in glob.glob("../../figures/*.png"):
        os.remove(f)

    try:

        # Create optimizers
        opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # Load colorizer model
        color_model = models.load(model_name, nb_q, (1, h, w), batch_size)
        color_model.compile(loss='categorical_crossentropy_color', optimizer=opt)  # project-specific custom loss, not a built-in Keras objective

        color_model.summary()
        from keras.utils.visualize_util import plot
        plot(color_model, to_file='../../figures/colorful.png', show_shapes=True, show_layer_names=True)

        # Actual training loop
        for epoch in range(nb_epoch):

            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            # Choose Batch Generation mode
            if training_mode == "in_memory":
                BatchGen = DataGen.gen_batch_in_memory(X_train, nn_finder, nb_q, prior_factor)
            else:
                BatchGen = DataGen.gen_batch(nn_finder, nb_q, prior_factor)

            for batch in BatchGen:

                X_batch_black, X_batch_color, Y_batch = batch

                train_loss = color_model.train_on_batch(X_batch_black / 100., Y_batch)

                batch_counter += 1
                progbar.add(batch_size, values=[("loss", train_loss)])

                if batch_counter >= n_batch_per_epoch:
                    break

            print("")
            print('Epoch %s/%s, Time: %s' % (epoch + 1, nb_epoch, time.time() - start))

            # Plot some data with original, b and w and colorized versions side by side
            general_utils.plot_batch(color_model, q_ab, X_batch_black, X_batch_color,
                                     batch_size, h, w, nb_q, epoch)

            # Save weights every 5 epoch
            if epoch % 5 == 0:
                weights_path = '../../models/%s/%s_weights_epoch%s.h5' % (model_name, model_name, epoch)
                color_model.save_weights(weights_path, overwrite=True)

    except KeyboardInterrupt:
        pass
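
A minimal invocation sketch for train(), mirroring the keyword arguments it reads. The values and paths are placeholders; note that the data file name must carry the image size as its second underscore-separated token (e.g. CelebA_64_data.h5), since train() parses it with split("_")[1].

# Hypothetical usage sketch -- values and paths are illustrative only.
params = {
    "batch_size": 32,
    "n_batch_per_epoch": 100,
    "nb_epoch": 50,
    "data_file": "../../data/processed/CelebA_64_data.h5",  # size token parsed from the file name
    "nb_neighbors": 10,
    "model_name": "simple_colorful",                         # placeholder model name
    "training_mode": "in_memory",
}
train(**params)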