Example no. 1
def main(data_path=args['data_path'], train_from=train_from):
    train_gen, test_gen = create_data_generator(data_path)
    valid_x, valid_y = create_valid_data()
    if train_from == 'trained_weights':
        model = load_model_from_trained_weights(imagedims=IMAGE_DIMS, nb_classes=len(train_gen.class_indices),
                                                weights=args['weight_path'],
                                                freeze_until=freeze_until)
    elif train_from == 'trained_model':
        model = load_model_from_trained_model()
    else:
        model = load_models(imagedims=IMAGE_DIMS, nb_classes=len(train_gen.class_indices))
    print('[INFO] compiling model...')
    model.compile(loss="categorical_crossentropy", optimizer=OPT, metrics=["accuracy"])

    checkpoint = ModelCheckpoint(filepath=args['save_model'], monitor='val_loss', verbose=0,
                                 save_best_only=True, save_weights_only=False,
                                 mode='auto', period=1)
    stop_early = EarlyStopping(monitor='val_loss', min_delta=.0, patience=40, verbose=0, mode='auto')
    if lr_finder_from == 'large_range_search':
        '''Exponential LR finder.
           USE THIS FOR A LARGE RANGE SEARCH.
           Passing validation_data slows the search down, but gives a better
           idea of the learning rate.
        '''
        lr_finder = LRFinder(NUM_SAMPLES, BS, minimum_lr=1e-3, maximum_lr=10.,
                             lr_scale='exp',
                             validation_data=(valid_x, valid_y),  # use the validation data for losses
                             validation_sample_rate=5,
                             save_dir='weights/', verbose=True)
    elif lr_finder_from == 'close_range_search':
        '''Linear LR finder.
           USE THIS FOR A CLOSE RANGE SEARCH.
           Passing validation_data slows the search down, but gives a better
           idea of the learning rate.
        '''
        lr_finder = LRFinder(NUM_SAMPLES, BS, minimum_lr=1e-5, maximum_lr=1e-2,
                             lr_scale='linear',
                             validation_data=(valid_x, valid_y),  # use the validation data for losses
                             validation_sample_rate=5,
                             save_dir='weights/', verbose=True)
    else:
        raise ValueError('unknown lr_finder_from: %s' % lr_finder_from)
    callbacks = [checkpoint, stop_early, lr_finder]
    H = model.fit_generator(train_gen,
                            validation_data=(valid_x, valid_y),
                            epochs=EPOCHS,
                            #steps_per_epoch=209,
                            callbacks=callbacks,
                            verbose=1
                            )
    lr_finder.plot_schedule(clip_beginning=10, clip_endding=5)
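# A minimal follow-up sketch (an addition, not part of the original script):
# reload the schedule that the finder saved to 'weights/' and pick a candidate
# LR near the steepest loss descent; the /10 rule of thumb for the minimum is
# the one quoted later in this file.
import numpy as np

losses, lrs = LRFinder.restore_schedule_from_dir('weights/', 10, 5)
idx = int(np.argmin(np.gradient(np.asarray(losses))))  # steepest descent
print('suggested maximum_lr: %.2e, minimum_lr: %.2e' % (lrs[idx], lrs[idx] / 10.))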
# Learning rate finder callback setup
num_samples = X_train.shape[0]

MOMENTUMS = [0.9, 0.95, 0.99]

for momentum in MOMENTUMS:
    K.clear_session()

    # Learning rate range obtained from `find_lr_schedule.py`
    # NOTE: the minimum is 10x smaller than the max found above!
    # NOTE: it is preferable to use the validation data here to get a correct value
    lr_finder = LRFinder(num_samples,
                         batch_size,
                         minimum_lr=1e-3,
                         maximum_lr=1,
                         validation_data=(X_test, Y_test),
                         validation_sample_rate=5,
                         lr_scale='linear',
                         save_dir='weights/momentum/momentum-%s/' %
                         str(momentum),
                         verbose=True)

    model = VGG16Net(img_rows, img_cols, img_channels, 10, 0)
    model.summary()

    # Set the weight decay here!
    # The LR doesn't matter, as it will be overwritten by the callback.
    optimizer = SGD(lr=0.001288, momentum=momentum, nesterov=True)
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['accuracy'])
    # Compute quantities required for featurewise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    # datagen.fit(X_train)

    # Fit the model on the batches generated by datagen.flow().
    # model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size, shuffle=True),
    #                     steps_per_epoch=X_train.shape[0] // batch_size,
    #                     validation_data=(X_test, Y_test),
    #                     epochs=nb_epoch, verbose=1,
    #                     callbacks=[lr_finder])

# From the plots we see the model isn't impacted very much by this
# hyperparameter, so we can use any of these values.

save_dir = Path('./weights')
plt.figure()
for momentum in MOMENTUMS:
    directory = 'weights/momentum/momentum-%s/' % str(momentum)

    losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5)
    plt.plot(lrs, losses, label='momentum=%0.2f' % momentum)

plt.title("Momentum")
plt.xlabel("Learning rate")
plt.ylabel("Validation Loss")
plt.legend()
plt.savefig(str(save_dir / 'momentum_finder_from_file.png'))  # save before show(), or the figure is blank
plt.show()
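# A small added sketch (it assumes the same restore_schedule_from_dir output
# used above) that reports the lowest loss reached per momentum, rather than
# judging the curves purely by eye.
import numpy as np

for momentum in MOMENTUMS:
    directory = 'weights/momentum/momentum-%s/' % str(momentum)
    losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5)
    i = int(np.argmin(np.asarray(losses)))
    print('momentum=%0.2f: min loss %.4f at lr %.2e' % (momentum, losses[i], lrs[i]))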

# Reference: https://github.com/titu1994/keras-one-cycle

# def LRSearch(squeeze_scale_exp, small_filter_rate):
scale = 10 ** float(sys.argv[1])        # squeeze_scale exponent
small_filter_rate = float(sys.argv[2])
batch_size = 2048
minimum_lr = 1e-8
maximum_lr = 1e8
with open('data.p', 'rb') as f:
    (X_train, y_train), (X_test, y_test) = pickle.load(f)  # same format as cifar100.load_data()
num_samples = len(X_train)

lr_callback = LRFinder(
    num_samples,
    batch_size,
    minimum_lr,
    maximum_lr,
    # validation_data=(X_val, Y_val),
    lr_scale='exp',
    save_dir='lr_log')
op = tf.keras.optimizers.SGD(momentum=0.95)  # (previously tried: decay=1e-6, momentum=0.9)
model = squeeze_net(small_filter_rate=small_filter_rate,
                    squeeze_scale=scale,
                    verbose=False)
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
model.compile(loss=loss, optimizer=op, metrics=['acc'])
oh = OneHotEncoder(sparse=False)
oh.fit(y_train)
history = model.fit(X_train / 255.,
                    oh.transform(y_train),
                    epochs=1,
                    batch_size=batch_size,
                    callbacks=[lr_callback])
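
# A hedged follow-up (not in the original snippet): the finder above logs to
# 'lr_log', so the recorded schedule can be plotted afterwards with the same
# class helper used elsewhere in this file.
LRFinder.plot_schedule_from_file('lr_log', clip_beginning=10, clip_endding=5)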
Example no. 5
# model.compile(loss='binary_crossentropy',
#               optimizer=optimizers.RMSprop(lr=2e-5),
#               metrics=['accuracy'])

from clr import LRFinder

num_samples = train_data.shape[0]
batch_size = 32
minimum_lr = 1e-6
maximum_lr = 1e-1

lr_callback = LRFinder(
    num_samples,
    batch_size,
    minimum_lr,
    maximum_lr,
    # validation_data=(X_val, Y_val),
    lr_scale='exp',
    save_dir='/home/esflores/KFT_ISIC2019')

model.compile(loss='binary_crossentropy',
              optimizer=optimizers.SGD(lr=0.1, momentum=0.9, nesterov=True))

model.fit(train_data,
          train_labels,
          epochs=1,
          batch_size=batch_size,
          callbacks=[lr_callback])

# train_labels[i] == 0 => melanoma
# train_labels[i] == 1 => nevi
Example no. 6
print("Channel Mean : ", mean)
print("Channel Std : ", std)

X_train = (X_train - mean) / (std)
X_test = (X_test - mean) / (std)

# Learning rate finder callback setup
num_samples = X_train.shape[0]

# Exponential lr finder
# USE THIS FOR A LARGE RANGE SEARCH
# validation_data is passed below: slower, but gives a better idea of the learning rate
lr_finder = LRFinder(num_samples, batch_size, minimum_lr=1e-5, maximum_lr=10.,
                     lr_scale='exp',
                     validation_data=(X_test, Y_test),  # use the validation data for losses
                     validation_sample_rate=5,
                     save_dir='weights/', verbose=True)

# Linear lr finder
# USE THIS FOR A CLOSE SEARCH
# Uncomment the validation_data flag to reduce speed but get a better idea of the learning rate
# lr_finder = LRFinder(num_samples, batch_size, minimum_lr=5e-4, maximum_lr=1e-2,
#                      lr_scale='linear',
#                      validation_data=(X_test, Y_test),  # use the validation data for losses
#                      validation_sample_rate=5,
#                      save_dir='weights/', verbose=True)

# plot the previous values if present
LRFinder.plot_schedule_from_file('weights/', clip_beginning=10, clip_endding=5)
def main(data_path=args['data_path'], train_from=train_from):
    train_gen, test_gen = create_data_generator(data_path)
    valid_x, valid_y = create_valid_data()
    MOMENTUMS = [0.9, 0.95, 0.99]
    for momentum in MOMENTUMS:
        K.clear_session()
        # Learning rate range obtained from `find_lr_schedule.py`
        # NOTE: the minimum is 10x smaller than the max found above!
        # NOTE: it is preferable to use the validation data here to get a correct value
        lr_finder = LRFinder(NUM_SAMPLES,
                             BS,
                             minimum_lr=0.0001,
                             maximum_lr=0.001,
                             validation_data=(valid_x, valid_y),
                             validation_sample_rate=5,
                             lr_scale='linear',
                             save_dir='weights/momentum/momentum-%s/' %
                             str(momentum),
                             verbose=True)

        if train_from == 'trained_weights':
            model = load_model_from_trained_weights(
                imagedims=IMAGE_DIMS,
                nb_classes=len(train_gen.class_indices),
                weights=args['weight_path'],
                freeze_until=freeze_until)
        elif train_from == 'trained_model':
            model = load_model_from_trained_model()
        else:
            model = load_models(imagedims=IMAGE_DIMS,
                                nb_classes=len(train_gen.class_indices))
        print('[INFO] compiling model...')
        # Set the weight decay in the model here; the LR passed to SGD doesn't
        # matter, as it will be overwritten by the LRFinder callback.
        optimizer = SGD(lr=0.001, momentum=momentum, nesterov=True)
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

        callbacks = [lr_finder]
        H = model.fit_generator(
            train_gen,
            validation_data=(valid_x, valid_y),
            epochs=EPOCHS,
            #steps_per_epoch=209,
            callbacks=callbacks,
            verbose=1)
    for momentum in MOMENTUMS:
        directory = 'weights/momentum/momentum-%s/' % str(momentum)

        losses, lrs = LRFinder.restore_schedule_from_dir(directory, 10, 5)
        plt.plot(lrs, losses, label='momentum=%0.2f' % momentum)

    plt.title("Momentum")
    plt.xlabel("Learning rate")
    plt.ylabel("Validation Loss")
    plt.legend()
    plt.show()
Example no. 8
# Call VGG16Net model
# input image dimensions
data_aug = True
img_rows, img_cols = 224, 224
# The input images are RGB.
img_channels = 3
nb_classes = 17
n_epochs = 1
n_batch = 8

num_sample = X_train.shape[0]

lrf = LRFinder(
    num_sample,
    n_batch,
    minimum_lr=1e-4,
    maximum_lr=1,
    lr_scale='exp',
    #validation_data = (X_test, Y_test),
    validation_sample_rate=1)

VGG16_model = VGG16Net(img_rows, img_cols, img_channels, nb_classes, 1e-6)
VGG16_model.summary()
VGG16_model.compile(SGD(lr=0.01, momentum=0.9, decay=0.00001, nesterov=False),
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])
#clc = CyclicLR(base_lr= 0.01, max_lr=0.1, step_size = 2*X_train.shape[0]//n_batch, mode = 'triangular')

if not data_aug:
    print('Train without data augmentation!')
    VGG16_model.fit(X_train,
                    Y_train,
                    batch_size=n_batch,
                    epochs=n_epochs,
                    callbacks=[lrf],
                    verbose=1)
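
# Hedged sketch (added): once the finder has suggested a range, the CyclicLR
# line commented out above can be enabled; it assumes the CyclicLR class is
# importable, and base_lr/max_lr below are placeholders to be read off the
# finder plot, not values from the original script.
clc = CyclicLR(base_lr=0.01, max_lr=0.1,
               step_size=2 * X_train.shape[0] // n_batch,
               mode='triangular')
# clc would then replace lrf in the callbacks list for the real training run.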
Example no. 9
def classifier_training_main(
        folders,
        val_folders,
        model_name,
        time,
        epochs,
        batch_size,
        opt,
        learn_rate,
        lropf=False,
        sd=False,
        es=False,
        clr=False,
        workers=1,
        test_dirs='',
        load_model=False,
        tb=False,
        intensity_cut=None,
        leakage=0.2,
        gpu_fraction=1,
        train_indexes=None,
        valid_indexes=None,
        clr_values=[5e-5, 5e-3, 4]):
    ###################################
    # TensorFlow wizardry for GPU dynamic memory allocation
    if 0 < gpu_fraction <= 1:
        config = tf.ConfigProto()
        # Don't pre-allocate memory; allocate as-needed
        config.gpu_options.allow_growth = True
        # Only allow a fraction of the GPU memory to be allocated
        config.gpu_options.per_process_gpu_memory_fraction = gpu_fraction
        # Create a session with the above options specified.
        K.tensorflow_backend.set_session(tf.Session(config=config))
    # (an out-of-range gpu_fraction silently falls back to the TF defaults)
    ###################################

    # remove semaphore warnings
    os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    # avoid validation deadlock problem
    mp.set_start_method('spawn', force=True)

    # hard coded parameters
    shuffle = True
    channels = 1
    if time:
        channels = 2

    # early stopping
    md_es = 0.01  # min delta
    p_es = 25  # patience

    # cycle learning rate CLR
    base_lr = clr_values[0]
    max_lr = clr_values[1]
    step_size = clr_values[2]

    # sgd
    lr = 0.01
    decay = 1e-4
    momentum = 0.9
    nesterov = True

    # adam
    # default a_lr should be 0.001
    a_lr = learn_rate
    a_beta_1 = 0.9
    a_beta_2 = 0.999
    a_epsilon = None
    a_decay = 0
    amsgrad = True

    # adabound
    ab_lr = 1e-03
    ab_final_lr = 0.1
    ab_gamma = 1e-03
    ab_weight_decay = 0
    amsbound = True

    # reduce lr on plateau
    f_lrop = 0.1  # factor
    p_lrop = 15  # patience
    md_lrop = 0.005  # min delta
    cd_lrop = 5  # cool down
    mlr_lrop = a_lr / 100  # min lr

    # cuts
    #intensity_cut = 500
    #leakage2_intensity_cut = 0.2

    training_files = get_all_files(folders)
    validation_files = get_all_files(val_folders)

    # generators
    print('Building training generator...')
    feature = 'gammaness'  # hardcoded for now
    '''
    training_generator = DataGeneratorC(training_files,
                                        batch_size=batch_size,
                                        arrival_time=time,
                                        shuffle=shuffle,
                                        intensity=intensity_cut)
    train_idxs = training_generator.get_indexes()
    train_gammas = np.unique(train_idxs[:, 2], return_counts=True)[1][1]
    train_protons = np.unique(train_idxs[:, 2], return_counts=True)[1][0]
    '''
    training_generator = LSTGenerator(training_files,
                                      batch_size=batch_size,
                                      arrival_time=time,
                                      feature=feature,
                                      shuffle=shuffle,
                                      intensity=intensity_cut,
                                      leakage2_intensity=leakage,
                                      load_indexes=train_indexes)
    # get image size (rows and columns)
    img_rows = training_generator.img_rows
    img_cols = training_generator.img_cols
    print("IMG rows: {}, cols: {}".format(img_rows, img_cols))
    # create a folder to keep model & results
    now = datetime.datetime.now()
    root_dir = now.strftime(model_name + '_' + '%Y-%m-%d_%H-%M')
    mkdir(root_dir)
    models_dir = join(root_dir, "models")
    mkdir(models_dir)

    # save data info
    train_idxs = training_generator.get_all_info()
    train_gammas = np.unique(train_idxs['class'], return_counts=True)[1][1]
    train_protons = np.unique(train_idxs['class'], return_counts=True)[1][0]
    train_gamma_frac = training_generator.gamma_fraction()
    if train_indexes is None:
        train_idxs.to_pickle(join(root_dir, "train_indexes.pkl"))

    if len(val_folders) > 0:
        print('Building validation generator...')
        validation_generator = LSTGenerator(validation_files,
                                            batch_size=batch_size,
                                            arrival_time=time,
                                            feature=feature,
                                            shuffle=False,
                                            intensity=intensity_cut,
                                            leakage2_intensity=leakage,
                                            load_indexes=valid_indexes
                                            )
        valid_idxs = validation_generator.get_all_info()
        valid_gammas = np.unique(valid_idxs['class'], return_counts=True)[1][1]
        valid_protons = np.unique(valid_idxs['class'], return_counts=True)[1][0]
        valid_gamma_frac = validation_generator.gamma_fraction()
        if valid_indexes is None:
            valid_idxs.to_pickle(join(root_dir, "valid_indexes.pkl"))

    # class_weight = {0: 1., 1: train_protons/train_gammas}
    # print(class_weight)

    hype_print = '\n' + '======================================HYPERPARAMETERS======================================'

    hype_print += '\n' + 'Image rows: ' + str(img_rows) + ' Image cols: ' + str(img_cols)
    hype_print += '\n' + 'Folders: ' + str(folders)
    hype_print += '\n' + 'Model: ' + str(model_name)
    hype_print += '\n' + 'Use arrival time: ' + str(time)
    hype_print += '\n' + 'Epochs: ' + str(epochs)
    hype_print += '\n' + 'Batch size: ' + str(batch_size)
    hype_print += '\n' + 'Optimizer: ' + str(opt)
    hype_print += '\n' + 'Validation: ' + str(val_folders)
    hype_print += '\n' + 'Test dirs: ' + str(test_dirs)

    hype_print += '\n' + 'intensity_cut: ' + str(intensity_cut)
    hype_print += '\n' + 'leakage2_intensity_cut: ' + str(leakage)

    if clr:
        hype_print += '\n' + '--- Cycle Learning Rate ---'
        hype_print += '\n' + 'Base LR: ' + str(base_lr)
        hype_print += '\n' + 'Max LR: ' + str(max_lr)
        hype_print += '\n' + 'Step size: ' + str(step_size) + ' (' + str(step_size*len(training_generator)) + ')'
    if es:
        hype_print += '\n' + '--- Early stopping ---'
        hype_print += '\n' + 'Min delta: ' + str(md_es)
        hype_print += '\n' + 'Patience: ' + str(p_es)
        hype_print += '\n' + '----------------------'
    if opt == 'sgd':
        hype_print += '\n' + '--- SGD ---'
        hype_print += '\n' + 'Learning rate: ' + str(lr)
        hype_print += '\n' + 'Decay: ' + str(decay)
        hype_print += '\n' + 'Momentum: ' + str(momentum)
        hype_print += '\n' + 'Nesterov: ' + str(nesterov)
        hype_print += '\n' + '-----------'
    elif opt == 'adam':
        hype_print += '\n' + '--- ADAM ---'
        hype_print += '\n' + 'lr: ' + str(a_lr)
        hype_print += '\n' + 'beta_1: ' + str(a_beta_1)
        hype_print += '\n' + 'beta_2: ' + str(a_beta_2)
        hype_print += '\n' + 'epsilon: ' + str(a_epsilon)
        hype_print += '\n' + 'decay: ' + str(a_decay)
        hype_print += '\n' + 'Amsgrad: ' + str(amsgrad)
        hype_print += '\n' + '------------'
    if lropf:
        hype_print += '\n' + '--- Reduce lr on plateau ---'
        hype_print += '\n' + 'lr decrease factor: ' + str(f_lrop)
        hype_print += '\n' + 'Patience: ' + str(p_lrop)
        hype_print += '\n' + 'Min delta: ' + str(md_lrop)
        hype_print += '\n' + 'Cool down: ' + str(cd_lrop)
        hype_print += '\n' + 'Min lr: ' + str(mlr_lrop)
        hype_print += '\n' + '----------------------------'
    if sd:
        hype_print += '\n' + '--- Step decay ---'

    hype_print += '\n' + 'Workers: ' + str(workers)
    hype_print += '\n' + 'Shuffle: ' + str(shuffle)

    hype_print += '\n' + 'Number of training batches: ' + str(len(training_generator))
    hype_print += '\n' + 'Number of training gammas: ' + str(train_gammas)
    hype_print += '\n' + 'Number of training protons: ' + str(train_protons)
    hype_print += '\n' + 'Fraction of gamma in training set: ' + str(train_gamma_frac)
    if len(val_folders) > 0:
        hype_print += '\n' + 'Number of validation batches: ' + str(len(validation_generator))
        hype_print += '\n' + 'Number of validation gammas: ' + str(valid_gammas)
        hype_print += '\n' + 'Number of validation protons: ' + str(valid_protons)
        hype_print += '\n' + 'Fraction of gamma in validation set: ' + str(valid_gamma_frac)

    # keras.backend.set_image_data_format('channels_first')
    if load_model:
        model = keras.models.load_model(model_name)
        model_name = Path(model_name).name
    else:
        model, hype_print = select_classifier(model_name, hype_print, channels, img_rows, img_cols)
    #model = load_model('/home/pgrespan/nick_models/ResNetFSE_49_0.82375_0.80292.h5')

    hype_print += '\n' + '========================================================================================='

    # printing on screen hyperparameters
    print(hype_print)

    # writing hyperparameters on file
    with open(root_dir + '/hyperparameters.txt', 'w') as f:
        f.write(hype_print)

    model.summary()

    callbacks = []

    if len(val_folders) > 0:
        checkpoint = ModelCheckpoint(
            filepath=models_dir + '/' + model_name + '_{epoch:02d}_{acc:.5f}_{val_acc:.5f}.h5', monitor='val_acc',
            save_best_only=False)
    else:
        checkpoint = ModelCheckpoint(
            filepath=models_dir + '/' + model_name + '_{epoch:02d}_{acc:.5f}.h5', monitor='acc',
            save_best_only=True)

    callbacks.append(checkpoint)

    # tensorboard = keras.callbacks.TensorBoard(log_dir=root_dir + "/logs",
    #                                          histogram_freq=5,
    #                                          batch_size=batch_size,
    #                                          write_images=True,
    #                                          update_freq=batch_size * 100)

    history = LossHistoryC()

    csv_callback = keras.callbacks.CSVLogger(root_dir + '/epochs_log.csv', separator=',', append=False)

    callbacks.append(history)
    callbacks.append(csv_callback)

    # callbacks.append(tensorboard)

    # sgd
    optimizer = None
    if opt == 'sgd':
        sgd = optimizers.SGD(lr=lr, decay=decay, momentum=momentum, nesterov=nesterov)
        optimizer = sgd
    elif opt == 'adam':
        adam = optimizers.Adam(lr=a_lr, beta_1=a_beta_1, beta_2=a_beta_2, epsilon=a_epsilon, decay=a_decay,
                               amsgrad=amsgrad)
        optimizer = adam
    '''
    elif opt == 'adabound':
        adabound = AdaBound(lr=ab_lr, final_lr=ab_final_lr, gamma=ab_gamma, weight_decay=ab_weight_decay,
                            amsbound=False)
        optimizer = adabound
    '''
    # reduce lr on plateau
    if lropf:
        lrop = keras.callbacks.ReduceLROnPlateau(monitor='val_acc', factor=f_lrop, patience=p_lrop, verbose=1,
                                                 mode='auto',
                                                 min_delta=md_lrop, cooldown=cd_lrop, min_lr=mlr_lrop)
        callbacks.append(lrop)

    if sd:
        # learning rate schedule
        def step_decay(epoch):
            current = K.eval(model.optimizer.lr)
            lrate = current
            if epoch == 99:  # one-time drop at the 100th epoch
                lrate = current / 10
                print('Reduced learning rate by a factor of 10')
            return lrate

        stepd = LearningRateScheduler(step_decay)
        callbacks.append(stepd)

    if es:
        # early stopping
        early_stopping = EarlyStopping(monitor='val_acc', min_delta=md_es, patience=p_es, verbose=1, mode='max')
        callbacks.append(early_stopping)

    if tb:
        tb_path = os.path.join(root_dir, 'tb')
        if not os.path.exists(tb_path):
            os.mkdir(tb_path)
        #tb_path = os.path.join(tb_path, root_dir)
        # os.mkdir(tb_path)
        tensorboard = TensorBoard(log_dir=tb_path)
        callbacks.append(tensorboard)

    if clr:
        cyclelr = CyclicLR(
            base_lr=base_lr,
            max_lr=max_lr,
            step_size=step_size*len(training_generator)
        )
        callbacks.append(cyclelr)

    lrfinder = False  # set to True to attach an LRFinder callback to this run
    if lrfinder:
        lr_callback = LRFinder(len(training_generator) * batch_size, batch_size,
                               1e-5, 10.,
                               # validation_data=(X_val, Y_val),
                               lr_scale='exp', save_dir=join(root_dir, 'clr'))
        callbacks.append(lr_callback)

    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

    if len(val_folders) > 0:
        model.fit(
            x=training_generator,
            validation_data=validation_generator,
            steps_per_epoch=len(training_generator),
            validation_steps=len(validation_generator),
            epochs=epochs,
            verbose=1,
            max_queue_size=10,
            use_multiprocessing=True,
            workers=workers,
            shuffle=False,
            callbacks=callbacks
        )
    else:
        model.fit(
            x=training_generator,
            steps_per_epoch=len(training_generator),
            epochs=epochs,
            verbose=1,
            max_queue_size=10,
            use_multiprocessing=True,
            workers=workers,
            shuffle=False,
            callbacks=callbacks
        )

    # save results
    train_history = root_dir + '/train-history'
    with open(train_history, 'wb') as file_pi:
        pickle.dump(history.dic, file_pi)

    # post training operations

    # training plots
    train_plots(train_history, False)

    if len(test_dirs) > 0:

        # get the best checkpoint according to (validation) accuracy
        key = 'val_accuracy' if len(val_folders) > 0 else 'accuracy'
        acc = history.dic[key]
        m = acc.index(max(acc))  # index of the epoch with the highest accuracy

        # checkpoints are written to models_dir (see ModelCheckpoint above),
        # so search there rather than in root_dir
        model_checkpoints = [join(models_dir, f) for f in listdir(models_dir) if
                             (isfile(join(models_dir, f)) and f.startswith(
                                 model_name + '_' + '{:02d}'.format(m + 1)))]
        best = model_checkpoints[0]
        print('Best checkpoint: ', best)

        # test plots & results on the provided test data
        csv = tester(test_dirs, best, batch_size, time, workers)
        test_plots(csv)
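
# A hypothetical invocation sketch (added; the paths and values below are
# placeholders, not from the original script), showing how the knobs above
# fit together:
if __name__ == '__main__':
    classifier_training_main(
        folders=['/data/train'],    # placeholder path
        val_folders=['/data/val'],  # placeholder path
        model_name='ResNetFSE',     # any name select_classifier accepts
        time=True,                  # adds the arrival-time channel
        epochs=100,
        batch_size=64,
        opt='adam',
        learn_rate=0.001,
        es=True,                    # early stopping on val_acc
        clr=True,                   # cyclic LR driven by clr_values
        workers=4)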