import os
import sys
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from helpers import build_generator_dataframe, get_file_id
WORKDIR = os.path.abspath(sys.argv[2])
sys.stdout.write('Project directory: %s\n' % WORKDIR)
SRC = os.path.join(WORKDIR, 'src')
DATA = os.path.join(WORKDIR, 'data')
RESULTS = os.path.join(WORKDIR, 'results')
TRAIN_MULTIBAND = os.path.join(DATA, 'train_multiband')
TRAIN_MULTIBAND_AUGMENT = os.path.join(DATA, 'train_multiband_augment')
TEST_MULTIBAND = os.path.join(DATA, 'test_multiband')
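# Assumed project layout, inferred only from the path definitions above:
#   WORKDIR/
#     src/
#     data/
#       train_multiband/
#       train_multiband_augment/
#       test_multiband/
#     results/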

lens_df = pd.read_csv(os.path.join(RESULTS, 'lens_id_labels.csv'), index_col=0)
dataframe_for_generator = build_generator_dataframe(lens_df, TRAIN_MULTIBAND)
# Extract data proportions for loss weighting
n_lens_clean = len(lens_df[lens_df['is_lens'] == True])
n_nolens_clean = len(lens_df[lens_df['is_lens'] == False])
equal_class_coeff = np.array([n_lens_clean / n_nolens_clean, 1])
natural_class_coeff = np.array([1000 * n_lens_clean / n_nolens_clean, 1])
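# Illustration of how these coefficients are consumed further down: they are
# mapped to per-class Keras loss weights (assumed label convention: 0 = non-lens,
# 1 = lens), e.g.
#   class_weights = {0: equal_class_coeff[0], 1: equal_class_coeff[1]}
# so the over-represented non-lens class contributes only n_lens / n_nolens as
# much to the loss, simulating balanced classes.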

batch_size = 32
epochs = 3
IMG_HEIGHT = 200
IMG_WIDTH = 200
data_bias = 'none'

train_df, val_df = train_test_split(dataframe_for_generator,
                                    test_size=0.6,
                                    random_state=42)
# Example 2
def main():
    if len(sys.argv) == 2:
        config_file = 'config_lastro.ini'
        model_name = sys.argv[1]
    elif len(sys.argv) == 3:
        config_file = sys.argv[1]
        model_name = sys.argv[2]
    else:
        sys.exit(
            'ERROR:\tUnexpected number of arguments.\nUSAGE:\t%s [CONFIG_FILE] MODEL_FILENAME'
            % sys.argv[0])
    if not os.path.isfile(config_file):
        sys.exit('ERROR:\tThe config file %s was not found.' % config_file)
    # Avoid using GPU to evaluate models.
    sys.stdout.write('\nNot using GPU.\n')
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # Import configuration file
    config = configparser.ConfigParser()
    config.read(config_file)
    # Extract parameters from model name
    if 'train_multiband_bin' in model_name:
        datadir = 'train_multiband_bin'
    elif 'train_multiband_noclip_bin' in model_name:
        datadir = 'train_multiband_noclip_bin'
    else:
        datadir = 'train_multiband_noclip_bin'

    # Extract bands from filename
    bands = []
    if 'VIS0' in model_name:
        bands.append(False)
    elif 'VIS1' in model_name:
        bands.append(True)
    if 'NIR000' in model_name:
        bands.extend([False] * 3)
    elif 'NIR111' in model_name:
        bands.extend([True] * 3)
    print("The bands are: ", bands)
    # Extract split ratio from filename
    for param in model_name.split('_'):
        if 'ratio' in param:
            ratio = float(param.replace('ratio', ''))

    # Paths
    WORKDIR = config['general']['workdir']
    sys.stdout.write('Project directory: %s\n' % WORKDIR)
    DATA = os.path.join(WORKDIR, 'data')
    RESULTS = os.path.join(WORKDIR, 'results')
    TRAIN_MULTIBAND = os.path.join(DATA, datadir)

    image_catalog = pd.read_csv(os.path.join(
        DATA, 'catalog/image_catalog2.0train.csv'),
                                comment='#',
                                index_col=0)
    print('The shape of the image catalog: ' + str(image_catalog.shape) + "\n")

    lens_df = pd.read_csv(os.path.join(RESULTS, 'lens_id_labels_old.csv'),
                          index_col=0)
    lens_df_new = pd.read_csv(os.path.join(RESULTS, 'lens_id_labels.csv'),
                              index_col=0)
    dataframe_for_generator = build_generator_dataframe(
        lens_df, TRAIN_MULTIBAND)
    # Split the TRAIN_MULTIBAND set into train and validation sets. Set test_size below!
    train_df, val_df = train_test_split(
        dataframe_for_generator,
        test_size=config['trainparams'].getfloat('test_fraction'),
        random_state=42)
    total_train = len(train_df)
    total_val = len(val_df)
    print(train_df)

    print("The number of objects in the whole training sample is: ",
          total_train)
    print("The number of objects in the whole validation sample is: ",
          total_val)
    test_fraction = float(config["trainparams"]["test_fraction"])
    print("The test fraction is: ", test_fraction)
    if config['trainparams']['subsample_train'] == 'total':
        subsample_train = total_train
        subsample_val = total_val
    else:
        try:
            subsample_train = int(config['trainparams']['subsample_train'])
            subsample_val = int(subsample_train * test_fraction /
                                (1. - test_fraction))
        except ValueError:
            raise ValueError('subsample_train should be \'total\' or int.')

    print("The number of objects in the training subsample is: ",
          subsample_train)
    print("The number of objects in the validation subsample is: ",
          subsample_val)

    augment_train_data = bool(int(config['trainparams']['augment_train_data']))
    # Create Tiff Image Data Generator objects for train and validation
    image_data_gen_train = TiffImageDataGenerator(featurewise_center=False,
                                                  rotation_range=0,
                                                  fill_mode='wrap',
                                                  horizontal_flip=True,
                                                  vertical_flip=True,
                                                  preprocessing_function=None,
                                                  data_format='channels_last',
                                                  dtype='float32')
    image_data_gen_val = TiffImageDataGenerator(dtype='float32')

    # Create generators for Images and Labels
    roc_val_data_gen = image_data_gen_val.prop_image_generator_dataframe(
        val_df,
        directory=TRAIN_MULTIBAND,
        x_col='filenames',
        y_col='labels',
        batch_size=subsample_val,
        validation=True,
        ratio=ratio,
        bands=bands,
        binary=True)

    # Obtain model from the saving directory
    model_name_base = os.path.basename(model_name)
    model = tf.keras.models.load_model(model_name)
    model.summary()
    history_path = model_name.replace('h5', 'history')
    # Checkpoints dir
    save_dir = os.path.join(RESULTS, 'checkpoints/lastro_cnn/')
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    filepath = os.path.join(save_dir, model_name_base)

    # Plots
    # History
    if os.path.isfile(history_path):
        with open(history_path, 'rb') as file_pi:
            history = pickle.load(file_pi)
        fig, ax1 = plt.subplots(1, 1, figsize=(10, 5))
        ax2 = ax1.twinx()
        ax1.plot(
            range(len(history['loss'])),
            history['val_loss'],
            label='Validation loss',
            #               marker='o',
            c='b',
            lw=3)
        ax1.plot(
            range(len(history['loss'])),
            history['loss'],
            label='Training loss',
            #               marker='o',
            c='r',
            lw=3)
        ax2.set_ylim([0.5, 1])
        ax2.plot(
            range(len(history['loss'])),
            history['val_acc'],
            label='Validation accuracy',
            #               marker='^',
            c='b',
            ls='--',
            fillstyle='none',
            lw=3)
        ax2.plot(
            range(len(history['loss'])),
            history['acc'],
            label='Training accuracy',
            #               marker='^',
            c='r',
            ls='--',
            fillstyle='none',
            lw=3)
        ax1.set_xlabel('Epoch')
        ax1.legend(loc=(-0.1, 1))
        ax2.legend(loc=(0.9, 1))
        ax1.set_ylabel('Loss')
        ax2.set_ylabel('Accuracy')
        plt.gcf()
        plt.savefig(os.path.join(
            RESULTS, 'plots/' +
            os.path.basename(history_path).replace('.history', '.png')),
                    dpi=200)

    # Roc curve
    images_val, labels_true = next(roc_val_data_gen)
    print(labels_true)
    labels_score = model.predict(images_val,
                                 batch_size=1,
                                 verbose=2,
                                 workers=16,
                                 use_multiprocessing=True)
    fpr, tpr, thresholds = roc_curve(np.ravel(labels_true),
                                     np.ravel(labels_score))
    scores = model.evaluate(images_val,
                            labels_true,
                            batch_size=1,
                            verbose=1,
                            workers=16,
                            use_multiprocessing=True)
    scores_dict = {
        metric: value
        for metric, value in zip(model.metrics_names, scores)
    }
    print(scores)
    print(model.metrics_names)
    acc = scores_dict['acc']
    auc = scores_dict['auc']
    np.savetxt(os.path.join(RESULTS,
                            model_name_base.replace('h5', 'FPRvsTPR.dat')),
               np.array([fpr, tpr]).T,
               header='auc=%.3f\nacc=%.3f' % (auc, acc))
    plt.figure(2)
    plt.plot(fpr,
             tpr,
             label='Validation\nAUC=%.3f\nACC=%.3f' % (auc, acc),
             lw=3)
    plt.plot([0, 1], [0, 1], lw=3)
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.legend()
    plt.savefig(os.path.join(
        RESULTS, 'plots/ROCsklearn_' +
        os.path.basename(model_name).replace('.h5', '.png')),
                dpi=200)
# Example 3
def main():
    """Fit the model.

    If checkpoint does not exist in checkpoints or in the results directory, a new model is created and fitted
    according to parameters set in the config file. If a checkpoint or end model (in results dir) is found, it is loaded
    and the training is resumed according to values defined in the config file. By default, a checkpoint is preferred
    over an end model since it is assumed to be more recent (in case of manual stopping of training)."""

    if len(sys.argv) != 2:
        config_file = 'config_lesta_df.ini'
    else:
        config_file = sys.argv[1]
    if not os.path.isfile(config_file):
        sys.exit('ERROR:\tThe config file %s was not found.' % config_file)

    config = configparser.ConfigParser()
    config.read(config_file)
    print("\nConfiguration file:\n")
    for section in config.sections():
        print("Section: %s" % section)
        for options in config.options(section):
            print("  %s = %s" % (options, config.get(section, options)))
    if not config['general'].getboolean('use_gpu'):
        sys.stdout.write('\nNot using GPU.\n')
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # Paths
    WORKDIR = config['general']['workdir']
    sys.stdout.write('Project directory: %s\n' % WORKDIR)
    SRC = os.path.join(WORKDIR, 'src')
    DATA = os.path.join(WORKDIR, 'data')
    RESULTS = os.path.join(WORKDIR, 'results')
    TRAIN_MULTIBAND = config['general']['train_multiband']
    TEST_MULTIBAND = os.path.join(DATA, 'test_multiband')
    # Catalog
    lens_df = pd.read_csv(os.path.join(RESULTS, 'lens_id_labels.csv'),
                          index_col=0)
    dataframe_for_generator = build_generator_dataframe(
        lens_df, TRAIN_MULTIBAND)
    # Extract data proportions for loss weighting
    n_lens_clean = len(lens_df[lens_df['is_lens'] == True])
    n_nolens_clean = len(lens_df[lens_df['is_lens'] == False])
    equal_class_coeff = np.array([n_lens_clean / n_nolens_clean, 1])
    natural_class_coeff = np.array([1000 * n_lens_clean / n_nolens_clean, 1])
    # Training parameters
    batch_size = config['trainparams'].getint('batch_size')
    epochs = config['trainparams'].getint('epochs')
    data_bias = config['trainparams']['data_bias']
    test_fraction = config['trainparams'].getfloat('test_fraction')
    augment_train_data = bool(int(config['trainparams']['augment_train_data']))
    kernel_size_1 = int(config['trainparams']['kernel_size_1'])
    kernel_size_2 = int(config['trainparams']['kernel_size_2'])
    # Import the dropout kind and check that the value is valid.
    dropout_config = config['trainparams']['dropout_kind']
    if dropout_config == 'dropout':
        dropout_kind = Dropout
    elif dropout_config == 'spatialdropout':
        dropout_kind = SpatialDropout2D
    else:
        raise NotImplementedError(
            'dropout_kind must be \'dropout\' or \'spatialdropout\'\nPlease check config file.'
        )
    pool_size = int(config['trainparams']['pool_size'])

    bands = [
        config['bands'].getboolean('VIS0'), config['bands'].getboolean('NIR1'),
        config['bands'].getboolean('NIR2'), config['bands'].getboolean('NIR3')
    ]
    print("The bands are: ", bands)
    binary = bool(int(config['general']['binary']))
    ratio = float(config['trainparams']['lens_nolens_ratio'])
    # Split catalog in train and test (validation) sets. We used fixed state 42.
    train_df, val_df = train_test_split(dataframe_for_generator,
                                        test_size=test_fraction,
                                        random_state=42)
    total_train = len(train_df)
    total_val = len(val_df)
    print("The number of objects in the whole training sample is: ",
          total_train)
    print("The number of objects in the whole validation sample is: ",
          total_val)
    print("The test fraction is: ", test_fraction)
    # Import the subsample size and check that the values are as expected.
    if config['trainparams']['subsample_train'] == 'total':
        subsample_train = total_train
        subsample_val = total_val
    else:
        try:
            subsample_train = int(config['trainparams']['subsample_train'])
            subsample_val = int(subsample_train * test_fraction /
                                (1. - test_fraction))
        except ValueError:
            raise ValueError('subsample_train should be \'total\' or int.')
    print("The number of objects in the training subsample is: ",
          subsample_train)
    print("The number of objects in the validation subsample is: ",
          subsample_val)
    train_steps_per_epoch = int(subsample_train // batch_size)
    val_steps_per_epoch = int(subsample_val // batch_size)
    print("The number of training steps is: ", train_steps_per_epoch)
    print("The number of validation steps is: ", val_steps_per_epoch)

    # Create TiffImageDataGenerator objects to inherit random transformations from Keras' class.
    image_data_gen_train = TiffImageDataGenerator(featurewise_center=False,
                                                  rotation_range=0,
                                                  fill_mode='wrap',
                                                  horizontal_flip=True,
                                                  vertical_flip=True,
                                                  preprocessing_function=None,
                                                  data_format='channels_last',
                                                  dtype='float32')
    image_data_gen_val = TiffImageDataGenerator(dtype='float32')
    # Create Generator objects from the initialized TiffImageDataGenerators.
    # To train
    train_data_gen = image_data_gen_train.prop_image_generator_dataframe(
        train_df,
        directory=TRAIN_MULTIBAND,
        x_col='filenames',
        y_col='labels',
        batch_size=batch_size,
        validation=not (augment_train_data),
        ratio=ratio,
        bands=bands,
        binary=binary)
    # To validate
    val_data_gen = image_data_gen_val.prop_image_generator_dataframe(
        val_df,
        directory=TRAIN_MULTIBAND,
        x_col='filenames',
        y_col='labels',
        batch_size=batch_size,
        validation=True,
        ratio=ratio,
        bands=bands,
        binary=binary)
    # To predict/evaluate
    roc_val_data_gen = image_data_gen_val.prop_image_generator_dataframe(
        val_df,
        directory=TRAIN_MULTIBAND,
        x_col='filenames',
        y_col='labels',
        batch_size=batch_size,
        validation=True,
        ratio=ratio,
        bands=bands,
        binary=binary)
    # To safely obtain image size
    temp_data_gen = image_data_gen_train.image_generator_dataframe(
        train_df,
        directory=TRAIN_MULTIBAND,
        x_col='filenames',
        y_col='labels',
        batch_size=1,
        validation=True,
        bands=bands,
        binary=binary)

    # Obtain image size
    image, _ = next(temp_data_gen)
    input_shape = image[0].shape
    # Define correct bias to initialize (use if not forcing generator to load equal proportions of data)
    output_bias = tf.keras.initializers.Constant(
        np.log(n_lens_clean / n_nolens_clean))
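    # Minimal sketch of how output_bias would be wired in, assuming the final
    # layer of build_lastro_model is a single sigmoid unit (not done in this
    # script, since the generator is asked for balanced batches instead):
    #   tf.keras.layers.Dense(1, activation='sigmoid',
    #                         bias_initializer=output_bias)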

    # Path to save checkpoints
    model_type = 'lastro_cnn'
    save_dir = os.path.join(RESULTS, 'checkpoints/%s/' % model_type)
    model_name = '%s_Tr%i_Te%i_bs%i_ep%.03d_aug%i_VIS%i_NIR%i%i%i_DB%s_ratio%.01f_ks%i%i_ps%i_%s_%s.h5' % (
        model_type, subsample_train, subsample_val, batch_size, epochs,
        int(augment_train_data), bands[0], bands[1], bands[2], bands[3],
        data_bias, ratio, kernel_size_1, kernel_size_2, pool_size,
        dropout_kind.__name__, os.path.basename(TRAIN_MULTIBAND))
    # Create path of checkpoints if necessary
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    # Checkpoint file path
    checkpoint_filepath = os.path.join(save_dir, model_name)
    # Final model file path
    end_model_name = os.path.join(RESULTS, model_name)
    print("The model name is: ", model_name)
    history_path = os.path.join(RESULTS, model_name.replace('h5', 'history'))

    # Callbacks
    # Checkpoint callback to save every epoch for better resuming training
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_best_only=False,
        verbose=1,
        monitor='val_acc',
        save_freq='epoch')
    # Checkpoint best model
    cp_best_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath.replace('.h5', '_BEST.h5'),
        save_best_only=True,
        verbose=1,
        monitor='val_acc')
    # Early stopping callback (currently using a very high patience to avoid it triggering)
    es_callback = tf.keras.callbacks.EarlyStopping(
        monitor='val_acc',
        #min_delta=0.1,
        patience=30,
        verbose=1,
        mode='auto',
        baseline=None,
        restore_best_weights=True)
    # Learning rate reducer callback.
    lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(factor=np.sqrt(0.1),
                                                      cooldown=0,
                                                      patience=20,
                                                      min_lr=0.5e-6,
                                                      monitor='val_acc',
                                                      verbose=1,
                                                      mode='auto')
    # Callback to save log to csv. It is probably better to use this than save-resume history.
    logger_callback = tf.keras.callbacks.CSVLogger(checkpoint_filepath.replace(
        '.h5', '.log'),
                                                   separator=',',
                                                   append=True)
    # Callback to resume history if history_path exists.
    history_callback = ResumeHistory(history_path)
    # Callback to use Tensorboard
    log_dir = os.path.join(
        RESULTS,
        "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    if not os.path.isdir(log_dir):
        os.mkdir(log_dir)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                          histogram_freq=1,
                                                          profile_batch=0)

    # Define metrics for the model.
    metrics = [
        keras.metrics.TruePositives(name='tp'),
        keras.metrics.FalsePositives(name='fp'),
        keras.metrics.TrueNegatives(name='tn'),
        keras.metrics.FalseNegatives(name='fn'),
        keras.metrics.BinaryAccuracy(name='acc'),
        keras.metrics.AUC(name='auc')
    ]
    # If there are no checkpoints or final models saved, compile a new one.
    if not os.path.isfile(checkpoint_filepath) and not os.path.isfile(
            end_model_name):
        model = build_lastro_model(kernel_size_1, kernel_size_2, pool_size,
                                   input_shape, dropout_kind)
        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=metrics)
    elif not os.path.isfile(checkpoint_filepath) and os.path.isfile(
            end_model_name):
        print('Loading existing model from result.')
        model = tf.keras.models.load_model(end_model_name)
        epochs = int(config['trainparams']['new_epochs'])
        learning_rate = config['trainparams']['learning_rate']
        change_learning_rate(model, learning_rate)
    elif os.path.isfile(checkpoint_filepath):
        print('Loading existing model from checkpoint.')
        model = tf.keras.models.load_model(checkpoint_filepath)
        epochs = int(config['trainparams']['new_epochs'])
        learning_rate = config['trainparams']['learning_rate']
        change_learning_rate(model, learning_rate)
    model.summary()
    # Define class weights for unevenly distributed (biased) dataset.
    if data_bias == 'natural':
        sys.stdout.write(
            'Using natural data bias: 1000x more non lenses than lenses.\n')
        class_coeff = natural_class_coeff
    elif data_bias == 'none':
        sys.stdout.write(
            'Using no data bias (simulate equal proportion among classes).\n')
        class_coeff = equal_class_coeff
    elif data_bias == 'raw':
        sys.stdout.write('Using the raw bias (no weights applied).\n')
        class_coeff = [1, 1]
    else:
        raise NotImplementedError(
            'data_bias must be either natural, none or raw.')
    class_weights = {0: class_coeff[0], 1: class_coeff[1]}
    sys.stdout.write('Using loss weights: %s\n' % class_weights)

    # Fit the model and save the history callback.
    # Use multiprocessing True when using > 1 workers. (Seems to cause problems)
    # Expect tf update to use threadsafe_iter class.
    history = model.fit_generator(
        train_data_gen,
        steps_per_epoch=subsample_train // batch_size,
        epochs=epochs,
        validation_data=val_data_gen,
        validation_steps=subsample_val // batch_size,
        callbacks=[
            cp_callback, es_callback, lr_reducer, cp_best_callback,
            history_callback, logger_callback, tensorboard_callback
        ],
        class_weight=class_weights,
        #    use_multiprocessing=True,
        verbose=1,
        #    workers=16
    )
    # If training finishes normally (is not stopped by the user), save the final model.
    model.save(end_model_name)
    # Save the complete history if the training was resumed.
    if history_callback.use_history_file_flag:
        with open(history_path, 'wb') as file_pi:
            pickle.dump(history_callback.complete_history, file_pi)
    else:
        with open(history_path, 'wb') as file_pi:
            pickle.dump(history.history, file_pi)

    # Score trained model.
    scores = model.evaluate_generator(val_data_gen,
                                      verbose=2,
                                      steps=val_steps_per_epoch)
    images_val, labels_true = next(roc_val_data_gen)
    labels_score = model.predict(images_val, batch_size=batch_size, verbose=2)
    fpr, tpr, thresholds = roc_curve(np.ravel(labels_true),
                                     np.ravel(labels_score))
    auc = history.history['val_auc'][-1]
    acc = history.history['val_acc'][-1]
    # Save TPR and FPR metrics to plot ROC.
    np.savetxt(os.path.join(RESULTS, model_name.replace('h5', 'FPRvsTPR.dat')),
               np.array([fpr, tpr]).T,
               header='auc=%.3f\nacc=%.3f' % (auc, acc))
# Example 4
def main():
    if len(sys.argv) == 2:
        config_file = 'config_lesta_df.ini'
        model_name = sys.argv[1]
    elif len(sys.argv) == 3:
        config_file = sys.argv[1]
        model_name = sys.argv[2]
    else:
        sys.exit(
            'ERROR:\tUnexpected number of arguments.\nUSAGE:\t%s [CONFIG_FILE] MODEL_FILENAME'
            % sys.argv[0])
    if not os.path.isfile(config_file):
        sys.exit('ERROR:\tThe config file %s was not found.' % config_file)
    if not os.path.isfile(model_name):
        sys.exit('ERROR:\tThe model file %s was not found.' % model_name)

    # Import configuration file
    config = configparser.ConfigParser()
    config.read(config_file)
    # Extract parameters from model name
    if 'train_multiband_bin' in model_name:
        datadir = 'train_multiband_bin'
    elif 'train_multiband_noclip_bin' in model_name:
        datadir = 'train_multiband_noclip_bin'
    else:
        datadir = 'train_multiband_noclip_bin'

    # Read the bands from the config file
    bands = [
        bool(int(config['bands']['VIS0'])),
        bool(int(config['bands']['NIR1'])),
        bool(int(config['bands']['NIR2'])),
        bool(int(config['bands']['NIR3']))
    ]
    print("The bands are: ", bands)
    # Paths
    WORKDIR = config['general']['workdir']
    sys.stdout.write('Project directory: %s\n' % WORKDIR)
    DATA = os.path.join(WORKDIR, 'data')
    RESULTS = os.path.join(WORKDIR, 'results')
    TRAIN_MULTIBAND = config['general']['train_multiband']
    TEST_MULTIBAND = TRAIN_MULTIBAND.replace('train', 'test')
    image_catalog = pd.read_csv(os.path.join(
        DATA, 'catalog/image_catalog2.0train.csv'),
                                comment='#',
                                index_col=0)
    print('The shape of the image catalog: ' + str(image_catalog.shape) + "\n")

    lens_df = pd.read_csv(os.path.join(RESULTS, 'lens_id_labels.csv'),
                          index_col=0)
    dataframe_for_generator = build_generator_dataframe(
        lens_df, TRAIN_MULTIBAND)
    print(dataframe_for_generator['filenames'])
    # Split the TRAIN_MULTIBAND set into train and validation sets. Set test_size below!
    train_df, val_df = train_test_split(
        dataframe_for_generator,
        test_size=config['trainparams'].getfloat('test_fraction'),
        random_state=42)
    total_train = len(train_df)
    total_val = len(val_df)
    print("The number of objects in the whole training sample is: ",
          total_train)
    print("The number of objects in the whole validation sample is: ",
          total_val)
    test_fraction = float(config["trainparams"]["test_fraction"])
    print("The test fraction is: ", test_fraction)
    if config['trainparams']['subsample_train'] == 'total':
        subsample_train = total_train
        subsample_val = total_val
    else:
        try:
            subsample_train = int(config['trainparams']['subsample_train'])
            subsample_val = int(subsample_train * test_fraction /
                                (1. - test_fraction))
        except ValueError:
            raise ValueError('subsample_train should be \'total\' or int.')

    print("The number of objects in the training subsample is: ",
          subsample_train)
    print("The number of objects in the validation subsample is: ",
          subsample_val)
    # Create Tiff Image Data Generator objects for train and validation
    image_data_gen_val = TiffImageDataGenerator(dtype='float32')

    # Create generators for Images and Labels
    prediction_ids = []
    test_data_gen = image_data_gen_val.generator_from_directory(
        directory=TEST_MULTIBAND,
        id_logger=prediction_ids,
        batch_size=1,
        bands=bands,
        binary=True)

    # Obtain model from the saving directory
    model_name_base = os.path.basename(model_name)
    model = tf.keras.models.load_model(model_name)
    model.summary()
    # One prediction step per file in the test set.
    len_gen = len(os.listdir(TEST_MULTIBAND))
    print(len_gen)
    predictions = model.predict_generator(test_data_gen,
                                          verbose=1,
                                          steps=len_gen)
    int_prediction_ids = np.array([get_file_id(fn) for fn in prediction_ids],
                                  dtype=int)
    np.savetxt(os.path.join(RESULTS,
                            model_name_base.replace('.h5', 'predictions.dat')),
               np.array([
                   np.squeeze(int_prediction_ids).astype(int),
                   np.squeeze(predictions)
               ]).T,
               fmt='%i %.5f')
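# Usage sketch (the script file name is an assumption; see the argument
# handling at the top of main()):
#   python predict_lastro_cnn.py config_lesta_df.ini results/<model>.h5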