Example #1
def test(gcp_bucket, dataset_id, model_id, batch_size, trained_thresholds_id, random_module_global_seed,
         numpy_random_global_seed, tf_random_global_seed, message):

    # seed the global random generators if specified; each seed must be an int, or None (the default) for no seeding
    if random_module_global_seed is not None:
        random.seed(random_module_global_seed)
    if numpy_random_global_seed is not None:
        np.random.seed(numpy_random_global_seed)
    if tf_random_global_seed is not None:
        tf_random.set_seed(tf_random_global_seed)

    start_dt = datetime.now()

    assert "gs://" in gcp_bucket

    # clean up the tmp directory
    try:
        shutil.rmtree(tmp_directory.as_posix())
    except FileNotFoundError:
        pass
    tmp_directory.mkdir()

    local_dataset_dir = Path(tmp_directory, 'datasets')
    local_model_dir = Path(tmp_directory, 'models')

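    # download the dataset and the trained model from the GCS bucket into the local tmp directory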
    copy_folder_locally_if_missing(os.path.join(gcp_bucket, 'datasets', dataset_id), local_dataset_dir)
    local_folder_has_files(local_dataset_dir, dataset_id)

    copy_folder_locally_if_missing(os.path.join(gcp_bucket, 'models', model_id), local_model_dir)
    local_folder_has_files(local_model_dir, model_id)

    test_id = "{}_{}".format(model_id, dataset_id)
    test_dir = Path(tmp_directory, 'tests', test_id)
    test_dir.mkdir(parents=True)

    with Path(local_dataset_dir, dataset_id, 'config.yaml').open('r') as f:
        dataset_config = yaml.safe_load(f)['dataset_config']

    with Path(local_model_dir, model_id, 'config.yaml').open('r') as f:
        train_config = yaml.safe_load(f)['train_config']

    if trained_thresholds_id is not None:
        with Path(local_model_dir, model_id, trained_thresholds_id).open('r') as f:
            threshold_output_data = yaml.safe_load(f)

    target_size = dataset_config['target_size']

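    # stream test images and masks from the local dataset copy, rescaled to [0, 1]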
    test_generator = ImagesAndMasksGenerator(
        Path(local_dataset_dir, dataset_id, 'test').as_posix(),
        rescale=1. / 255,
        target_size=target_size,
        batch_size=batch_size,
        seed=train_config.get('test_data_shuffle_seed'))

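    # load the per-class prediction thresholds from a previous threshold-training run, if one was specified;
    # when left as None, the model falls back to the default global threshold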
    optimized_class_thresholds = {}
    if trained_thresholds_id is not None and 'thresholds_training_output' in threshold_output_data['metadata']:
        for i in range(len(test_generator.mask_filenames)):
            class_key = 'class' + str(i)
            class_threshold_output = threshold_output_data['metadata']['thresholds_training_output'][class_key]
            if 'x' in class_threshold_output and class_threshold_output['success']:
                optimized_class_thresholds[class_key] = class_threshold_output['x']
            else:
                raise AssertionError('Unsuccessfully trained threshold attempted to be loaded.')
    else:
        optimized_class_thresholds = None

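    # rebuild the segmentation model described by train_config and load the trained weights from model.hdf5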
    compiled_model = generate_compiled_segmentation_model(
        train_config['segmentation_model']['model_name'],
        train_config['segmentation_model']['model_parameters'],
        len(test_generator.mask_filenames),
        train_config['loss'],
        train_config['optimizer'],
        Path(local_model_dir, model_id, "model.hdf5").as_posix(),
        optimized_class_thresholds=optimized_class_thresholds)

    results = compiled_model.evaluate(test_generator)

    metric_names = [m.name for m in compiled_model.metrics]

    with Path(test_dir, 'metrics_' + test_datetime + '.csv').open('w') as f:
        f.write(','.join(metric_names) + '\n')
        f.write(','.join(map(str, results)))

    metadata_sys = {
        'System_info': getSystemInfo(),
        'Lib_versions_info': getLibVersions()
    }

    metadata = {
        'message': message,
        'gcp_bucket': gcp_bucket,
        'dataset_id': dataset_id,
        'model_id': model_id,
        'trained_thresholds_id': trained_thresholds_id,
        'trained_class_thresholds_loaded': optimized_class_thresholds,  # global thresh used if None
        'default_global_threshold_for_reference': global_threshold,
        'batch_size': batch_size,
        'created_datetime': datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'),
        'git_hash': git.Repo(search_parent_directories=True).head.object.hexsha,
        'elapsed_minutes': round((datetime.now() - start_dt).total_seconds() / 60, 1),
        'dataset_config': dataset_config,
        'train_config': train_config,
        'random-module-global-seed': random_module_global_seed,
        'numpy_random_global_seed': numpy_random_global_seed,
        'tf_random_global_seed': tf_random_global_seed,
        'metadata_system': metadata_sys
    }

    with Path(test_dir, metadata_file_name).open('w') as f:
        yaml.safe_dump(metadata, f)

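    # upload the test results to the bucket; -n avoids overwriting existing objects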
    os.system("gsutil -m cp -n -r '{}' '{}'".format(Path(tmp_directory, 'tests').as_posix(), gcp_bucket))

    print('\n Test Metadata:')
    print(metadata)
    print('\n')

    shutil.rmtree(tmp_directory.as_posix())
Example #2
def test(gcp_bucket, dataset_id, model_id, batch_size):

    start_dt = datetime.now()

    assert "gs://" in gcp_bucket

    # clean up the tmp directory
    try:
        shutil.rmtree(tmp_directory.as_posix())
    except FileNotFoundError:
        pass
    tmp_directory.mkdir()

    local_dataset_dir = Path(tmp_directory, 'datasets')
    local_model_dir = Path(tmp_directory, 'models')

    copy_folder_locally_if_missing(os.path.join(gcp_bucket, 'datasets', dataset_id), local_dataset_dir)

    copy_folder_locally_if_missing(os.path.join(gcp_bucket, 'models', model_id), local_model_dir)

    test_id = "{}_{}_{}".format(model_id, dataset_id, datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'))
    test_dir = Path(tmp_directory, 'tests', test_id)
    test_dir.mkdir(parents=True)

    with Path(local_dataset_dir, dataset_id, 'config.yaml').open('r') as f:
        dataset_config = yaml.safe_load(f)['dataset_config']

    with Path(local_model_dir, model_id, 'config.yaml').open('r') as f:
        train_config = yaml.safe_load(f)['train_config']

    target_size = dataset_config['target_size']

    test_generator = ImagesAndMasksGenerator(
        Path(local_dataset_dir, dataset_id, 'test').as_posix(),
        rescale=1./255,
        target_size=target_size,
        batch_size=batch_size,
        shuffle=True,
        seed=None)

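    # VGG16-backbone U-Net with single-channel input, one output class per mask folder, and no pretrained encoder weights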
    model = Unet('vgg16', input_shape=(None, None, 1), classes=len(test_generator.mask_filenames), encoder_weights=None)

    crossentropy = binary_crossentropy if len(test_generator.mask_filenames) == 1 else categorical_crossentropy
    loss_fn = crossentropy

    model.compile(optimizer=Adam(),
                  loss=loss_fn,
                  metrics=[accuracy, iou_score, jaccard_loss, dice_loss, crossentropy])

    model.load_weights(Path(local_model_dir, model_id, "model.hdf5").as_posix())

    results = model.evaluate_generator(test_generator)

    metric_names = [loss_fn.__name__, 'accuracy', 'iou_score', 'jaccard_loss', 'dice_loss', 'crossentropy']
    with Path(test_dir, 'metrics.csv').open('w') as f:
        f.write(','.join(metric_names) + '\n')
        f.write(','.join(map(str, results)))

    metadata = {
        'gcp_bucket': gcp_bucket,
        'dataset_id': dataset_id,
        'model_id': model_id,
        'batch_size': batch_size,
        'created_datetime': datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'),
        'git_hash': git.Repo(search_parent_directories=True).head.object.hexsha,
        'elapsed_minutes': round((datetime.now() - start_dt).total_seconds() / 60, 1),
        'dataset_config': dataset_config,
        'train_config': train_config
    }

    with Path(test_dir, metadata_file_name).open('w') as f:
        yaml.safe_dump(metadata, f)

    os.system("gsutil -m cp -r '{}' '{}'".format(Path(tmp_directory, 'tests').as_posix(), gcp_bucket))

    shutil.rmtree(tmp_directory.as_posix())
Example #3
def test(gcp_bucket, dataset_id, model_id, batch_size):

    start_dt = datetime.now()

    assert "gs://" in gcp_bucket

    # clean up the tmp directory
    try:
        shutil.rmtree(tmp_directory.as_posix())
    except FileNotFoundError:
        pass
    tmp_directory.mkdir()

    local_dataset_dir = Path(tmp_directory, 'datasets')
    local_model_dir = Path(tmp_directory, 'models')

    copy_folder_locally_if_missing(
        os.path.join(gcp_bucket, 'datasets', dataset_id), local_dataset_dir)

    copy_folder_locally_if_missing(
        os.path.join(gcp_bucket, 'models', model_id), local_model_dir)

    test_id = "{}_{}_{}".format(model_id, dataset_id,
                                datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'))
    test_dir = Path(tmp_directory, 'tests', test_id)
    test_dir.mkdir(parents=True)

    with Path(local_dataset_dir, dataset_id, 'config.yaml').open('r') as f:
        dataset_config = yaml.safe_load(f)['dataset_config']

    with Path(local_model_dir, model_id, 'config.yaml').open('r') as f:
        train_config = yaml.safe_load(f)['train_config']

    target_size = dataset_config['target_size']

    test_generator = ImagesAndMasksGenerator(Path(local_dataset_dir,
                                                  dataset_id,
                                                  'test').as_posix(),
                                             rescale=1. / 255,
                                             target_size=target_size,
                                             batch_size=batch_size)

    compiled_model = generate_compiled_segmentation_model(
        train_config['segmentation_model']['model_name'],
        train_config['segmentation_model']['model_parameters'], 1,
        train_config['loss'], train_config['optimizer'],
        Path(local_model_dir, model_id, "model.hdf5").as_posix())

    compiled_model.summary()
    results = compiled_model.evaluate_generator(test_generator, verbose=1)

    metric_names = [compiled_model.loss.__name__] + [m.name for m in compiled_model.metrics]
    with Path(test_dir, 'metrics.csv').open('w') as f:
        f.write(','.join(metric_names) + '\n')
        f.write(','.join(map(str, results)))

    metadata = {
        'gcp_bucket': gcp_bucket,
        'dataset_id': dataset_id,
        'model_id': model_id,
        'batch_size': batch_size,
        'created_datetime': datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'),
        'git_hash': git.Repo(search_parent_directories=True).head.object.hexsha,
        'elapsed_minutes': round((datetime.now() - start_dt).total_seconds() / 60, 1),
        'dataset_config': dataset_config,
        'train_config': train_config
    }

    with Path(test_dir, metadata_file_name).open('w') as f:
        yaml.safe_dump(metadata, f)

    os.system("gsutil -m cp -r '{}' '{}'".format(
        Path(tmp_directory, 'tests').as_posix(), gcp_bucket))

    shutil.rmtree(tmp_directory.as_posix())
Example #4
def train_segmentation_model_prediction_thresholds(
        gcp_bucket, dataset_directory, model_id, batch_size,
        optimizing_class_metric, dataset_downsample_factor,
        random_module_global_seed, numpy_random_global_seed,
        tf_random_global_seed, message):

    # seed the global random generators if specified; each seed must be an int, or None (the default) for no seeding
    if random_module_global_seed is not None:
        random.seed(random_module_global_seed)
    if numpy_random_global_seed is not None:
        np.random.seed(numpy_random_global_seed)
    if tf_random_global_seed is not None:
        tf_random.set_seed(tf_random_global_seed)

    start_dt = datetime.now()

    assert "gs://" in gcp_bucket

    # clean up the tmp directory
    try:
        shutil.rmtree(tmp_directory.as_posix())
    except FileNotFoundError:
        pass
    tmp_directory.mkdir()

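    # dataset_directory takes the form '<dataset_id>/.../<split>'; the first path component is the dataset id
    # and the last is the split used for threshold training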
    dataset_id = dataset_directory.split('/')[0]
    dataset_type = dataset_directory.split('/')[-1]

    local_dataset_dir = Path(tmp_directory, 'datasets')
    local_model_dir = Path(tmp_directory, 'models')

    copy_folder_locally_if_missing(
        os.path.join(gcp_bucket, 'datasets', dataset_directory),
        Path(local_dataset_dir, dataset_id))

    local_folder_has_files(local_dataset_dir, dataset_id)

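    # the dataset's config.yaml lives at the dataset root, so fetch it separately from the subdirectory copied above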
    copy_file_locally_if_missing(
        os.path.join(gcp_bucket, 'datasets', dataset_id, 'config.yaml'),
        Path(local_dataset_dir, dataset_id, 'config.yaml'))

    copy_folder_locally_if_missing(
        os.path.join(gcp_bucket, 'models', model_id), local_model_dir)
    local_folder_has_files(local_model_dir, model_id)

    train_thresh_id = "{}_{}_{}".format(model_id, dataset_id,
                                        optimizing_class_metric)
    train_thresh_id_dir = Path(tmp_directory,
                               str('train_thresholds_' + train_thresh_id))
    train_thresh_id_dir.mkdir(parents=True)

    with Path(local_dataset_dir, dataset_id, 'config.yaml').open('r') as f:
        dataset_config = yaml.safe_load(f)['dataset_config']

    with Path(local_model_dir, model_id, 'config.yaml').open('r') as f:
        train_config = yaml.safe_load(f)['train_config']

    target_size = dataset_config['target_size']

    if 'validation' in dataset_type:
        gen_seed = train_config.get('validation_data_shuffle_seed')
    elif 'test' in dataset_type:
        gen_seed = train_config.get('test_data_shuffle_seed')
    else:
        gen_seed = 1234

    train_threshold_dataset_generator = ImagesAndMasksGenerator(
        Path(local_dataset_dir, dataset_directory).as_posix(),
        rescale=1. / 255,
        target_size=target_size,
        batch_size=batch_size,
        shuffle=True,
        seed=gen_seed)

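    # train a prediction threshold for each class independently by optimizing the chosen class metric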
    trained_prediction_thresholds = {}
    training_thresholds_output = {}
    opt_config = []
    for i in range(len(train_threshold_dataset_generator.mask_filenames)):
        print('\nTraining class{} prediction threshold...'.format(i))
        training_threshold_output, opt_config = train_prediction_thresholds(
            i, optimizing_class_metric, train_config,
            train_threshold_dataset_generator, dataset_downsample_factor,
            Path(local_model_dir, model_id, "model.hdf5").as_posix())
        if not training_threshold_output.success:
            raise AssertionError(
                'Training prediction thresholds failed. See the function minimization command line output.')

        class_key = 'class' + str(i)
        training_thresholds_output[class_key] = {
            'x': float(training_threshold_output.x),
            'success': training_threshold_output.success,
            'status': training_threshold_output.status,
            'message': training_threshold_output.message,
            'nfev': training_threshold_output.nfev,
            'fun': float(training_threshold_output.fun)
        }
        trained_prediction_thresholds[class_key] = float(training_threshold_output.x)

    metadata = {
        'message': message,
        'gcp_bucket': gcp_bucket,
        'created_datetime': datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'),
        'num_classes': len(train_threshold_dataset_generator.mask_filenames),
        'target_size': target_size,
        'git_hash': git.Repo(search_parent_directories=True).head.object.hexsha,
        'elapsed_minutes': round((datetime.now() - start_dt).total_seconds() / 60, 1),
        'dataset_directory': dataset_directory,
        'model_id': model_id,
        'batch_size': batch_size,
        'dataset_config': dataset_config,
        'train_config': train_config,
        'thresholds_training_configuration': opt_config,
        'thresholds_training_output': training_thresholds_output,
        'thresholds_training_history': thresholds_training_history,
        'random-module-global-seed': random_module_global_seed,
        'numpy_random_global_seed': numpy_random_global_seed,
        'tf_random_global_seed': tf_random_global_seed
    }

    metadata_sys = {
        'System_info': getSystemInfo(),
        'Lib_versions_info': getLibVersions()
    }

    output_data = {
        'final_trained_prediction_thresholds': trained_prediction_thresholds,
        'metadata': metadata,
        'metadata_system': metadata_sys
    }

    with Path(train_thresh_id_dir, output_file_name).open('w') as f:
        yaml.safe_dump(output_data, f)

    # copy without overwrite
    os.system("gsutil -m cp -n -r '{}' '{}'".format(
        Path(train_thresh_id_dir, output_file_name).as_posix(),
        os.path.join(gcp_bucket, 'models', model_id)))

    print('\n Train Prediction Thresholds Results:')
    print(trained_prediction_thresholds)

    print('\n Train Prediction Thresholds Metadata:')
    print(metadata)
    print('\n')

    shutil.rmtree(tmp_directory.as_posix())
Example #5
def train(gcp_bucket, config_file):

    start_dt = datetime.now()

    with Path(config_file).open('r') as f:
        train_config = yaml.safe_load(f)['train_config']

    assert "gs://" in gcp_bucket

    # clean up the tmp directory
    try:
        shutil.rmtree(tmp_directory.as_posix())
    except FileNotFoundError:
        pass
    tmp_directory.mkdir()

    local_dataset_dir = Path(tmp_directory, 'datasets')

    copy_folder_locally_if_missing(
        os.path.join(gcp_bucket, 'datasets', train_config['dataset_id']),
        local_dataset_dir)

    model_id = "{}_{}".format(
        train_config['model_id_prefix'],
        datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'))
    model_dir = Path(tmp_directory, 'models', model_id)
    model_dir.mkdir(parents=True)

    plots_dir = Path(model_dir, 'plots')
    plots_dir.mkdir(parents=True)

    logs_dir = Path(model_dir, 'logs')
    logs_dir.mkdir(parents=True)

    with Path(local_dataset_dir, train_config['dataset_id'],
              'config.yaml').open('r') as f:
        dataset_config = yaml.safe_load(f)['dataset_config']

    with Path(model_dir, 'config.yaml').open('w') as f:
        yaml.safe_dump({'train_config': train_config}, f)

    target_size = dataset_config['target_size']
    batch_size = train_config['batch_size']
    epochs = train_config['epochs']
    augmentation_type = train_config['data_augmentation']['augmentation_type']

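    # two augmentation workflows are supported: the original necstlab generator and a 'bio' pipeline driven by an augmentation parameter dict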
    if augmentation_type == 'necstlab':  # necstlab's workflow
        train_generator = ImagesAndMasksGenerator(
            Path(local_dataset_dir, train_config['dataset_id'], 'train').as_posix(),
            rescale=1. / 255,
            target_size=target_size,
            batch_size=batch_size,
            shuffle=True,
            random_rotation=train_config['data_augmentation']['necstlab_augmentation']['random_90-degree_rotations'],
            seed=train_config['training_data_shuffle_seed'])

        validation_generator = ImagesAndMasksGenerator(
            Path(local_dataset_dir, train_config['dataset_id'], 'validation').as_posix(),
            rescale=1. / 255,
            target_size=target_size,
            batch_size=batch_size)
    elif augmentation_type == 'bio':  # new workflow
        bio_augmentation = train_config['data_augmentation']['bio_augmentation']
        augmentation_dict = dict(
            rotation_range=bio_augmentation['rotation_range'],
            width_shift_range=bio_augmentation['width_shift_range'],
            height_shift_range=bio_augmentation['height_shift_range'],
            shear_range=bio_augmentation['shear_range'],
            zoom_range=bio_augmentation['zoom_range'],
            horizontal_flip=bio_augmentation['horizontal_flip'],
            fill_mode=bio_augmentation['fill_mode'],
            cval=0)
        train_generator = trainGenerator(
            batch_size=batch_size,
            train_path=Path(local_dataset_dir, train_config['dataset_id'],
                            'train').as_posix(),
            image_folder='images',
            mask_folder='masks',
            aug_dict=augmentation_dict,
            target_size=target_size,
            seed=train_config['training_data_shuffle_seed'])

        validation_generator = trainGenerator(
            batch_size=batch_size,
            train_path=Path(local_dataset_dir, train_config['dataset_id'],
                            'validation').as_posix(),
            image_folder='images',
            mask_folder='masks',
            aug_dict=augmentation_dict,
            target_size=target_size,
            seed=train_config['training_data_shuffle_seed'])

    compiled_model = generate_compiled_segmentation_model(
        train_config['segmentation_model']['model_name'],
        train_config['segmentation_model']['model_parameters'], 1,
        train_config['loss'], train_config['optimizer'])

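    # checkpoint the best weights (monitoring training loss) and write TensorBoard logs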
    model_checkpoint_callback = ModelCheckpoint(Path(model_dir,
                                                     'model.hdf5').as_posix(),
                                                monitor='loss',
                                                verbose=1,
                                                save_best_only=True)
    tensorboard_callback = TensorBoard(log_dir=logs_dir.as_posix(),
                                       batch_size=batch_size,
                                       write_graph=True,
                                       write_grads=False,
                                       write_images=True,
                                       update_freq='epoch')

    # n_sample_images = 20
    # train_image_and_mask_paths = sample_image_and_mask_paths(train_generator, n_sample_images)
    # validation_image_and_mask_paths = sample_image_and_mask_paths(validation_generator, n_sample_images)

    # tensorboard_image_callback = TensorBoardImage(
    #     log_dir=logs_dir.as_posix(),
    #     images_and_masks_paths=train_image_and_mask_paths + validation_image_and_mask_paths)

    csv_logger_callback = CSVLogger(Path(model_dir, 'metrics.csv').as_posix(),
                                    append=True)

    if augmentation_type == 'necstlab':
        steps_per_epoch = len(train_generator)
        validation_steps = len(validation_generator)
    else:
        steps_per_epoch = train_config['data_augmentation']['bio_augmentation']['steps_per_epoch']
        validation_steps = train_config['data_augmentation']['bio_augmentation']['validation_steps']

    results = compiled_model.fit_generator(
        train_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=validation_steps,
        callbacks=[model_checkpoint_callback, tensorboard_callback, csv_logger_callback])

    metric_names = ['loss'] + [m.name for m in compiled_model.metrics]

    for metric_name in metric_names:

        fig, ax = plt.subplots()
        for split in ['train', 'validate']:

            key_name = metric_name
            if split == 'validate':
                key_name = 'val_' + key_name

            ax.plot(range(epochs), results.history[key_name], label=split)
        ax.set_xlabel('epochs')
        if metric_name == 'loss':
            ax.set_ylabel(compiled_model.loss.__name__)
        else:
            ax.set_ylabel(metric_name)
        ax.legend()
        if metric_name == 'loss':
            fig.savefig(
                Path(plots_dir,
                     compiled_model.loss.__name__ + '.png').as_posix())
        else:
            fig.savefig(Path(plots_dir, metric_name + '.png').as_posix())

    # mosaic plot
    fig2, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 6))
    counter_m = 0
    counter_n = 0
    for metric_name in metric_names:

        for split in ['train', 'validate']:

            key_name = metric_name
            if split == 'validate':
                key_name = 'val_' + key_name

            axes[counter_m, counter_n].plot(range(epochs),
                                            results.history[key_name],
                                            label=split)
        axes[counter_m, counter_n].set_xlabel('epochs')
        if metric_name == 'loss':
            axes[counter_m, counter_n].set_ylabel(compiled_model.loss.__name__)
        else:
            axes[counter_m, counter_n].set_ylabel(metric_name)
        axes[counter_m, counter_n].legend()

        counter_n += 1
        if counter_n == 3:  # 3 plots per row
            counter_m += 1
            counter_n = 0

    fig2.tight_layout()
    fig2.delaxes(axes[1][2])
    fig2.savefig(Path(plots_dir, 'metrics_mosaic.png').as_posix())

    metadata = {
        'gcp_bucket': gcp_bucket,
        'created_datetime': datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'),
        'num_classes': 1,
        'target_size': target_size,
        'git_hash': git.Repo(search_parent_directories=True).head.object.hexsha,
        'original_config_filename': config_file,
        'elapsed_minutes': round((datetime.now() - start_dt).total_seconds() / 60, 1),
        'dataset_config': dataset_config,
        'train_config': train_config
    }

    with Path(model_dir, metadata_file_name).open('w') as f:
        yaml.safe_dump(metadata, f)

    os.system("gsutil -m cp -r '{}' '{}'".format(
        Path(tmp_directory, 'models').as_posix(), gcp_bucket))

    shutil.rmtree(tmp_directory.as_posix())
Example #6
def train(gcp_bucket, config_file):

    start_dt = datetime.now()

    with Path(config_file).open('r') as f:
        train_config = yaml.safe_load(f)['train_config']

    assert "gs://" in gcp_bucket

    # clean up the tmp directory
    try:
        shutil.rmtree(tmp_directory.as_posix())
    except FileNotFoundError:
        pass
    tmp_directory.mkdir()

    local_dataset_dir = Path(tmp_directory, 'datasets')

    copy_folder_locally_if_missing(
        os.path.join(gcp_bucket, 'datasets', train_config['dataset_id']),
        local_dataset_dir)

    model_id = "{}_{}".format(
        train_config['model_id_prefix'],
        datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'))
    model_dir = Path(tmp_directory, 'models', model_id)
    model_dir.mkdir(parents=True)

    plots_dir = Path(model_dir, 'plots')
    plots_dir.mkdir(parents=True)

    logs_dir = Path(model_dir, 'logs')
    logs_dir.mkdir(parents=True)

    with Path(local_dataset_dir, train_config['dataset_id'],
              'config.yaml').open('r') as f:
        dataset_config = yaml.safe_load(f)['dataset_config']

    with Path(model_dir, 'config.yaml').open('w') as f:
        yaml.safe_dump({'train_config': train_config}, f)

    target_size = dataset_config['target_size']
    batch_size = train_config['batch_size']
    epochs = train_config['epochs']

    train_generator = ImagesAndMasksGenerator(
        Path(local_dataset_dir, train_config['dataset_id'], 'train').as_posix(),
        rescale=1. / 255,
        target_size=target_size,
        batch_size=batch_size,
        shuffle=True,
        random_rotation=train_config['data_augmentation']['random_90-degree_rotations'],
        seed=train_config.get('training_data_shuffle_seed'))

    validation_generator = ImagesAndMasksGenerator(
        Path(local_dataset_dir, train_config['dataset_id'], 'validation').as_posix(),
        rescale=1. / 255,
        target_size=target_size,
        batch_size=batch_size,
        seed=train_config.get('validation_data_shuffle_seed'))

    compiled_model = generate_compiled_segmentation_model(
        train_config['segmentation_model']['model_name'],
        train_config['segmentation_model']['model_parameters'],
        len(train_generator.mask_filenames), train_config['loss'],
        train_config['optimizer'])

    model_checkpoint_callback = ModelCheckpoint(Path(model_dir,
                                                     'model.hdf5').as_posix(),
                                                monitor='loss',
                                                verbose=1,
                                                save_best_only=True)
    # profile_batch=0 is needed until the insufficient-privileges issue with CUPTI is resolved
    #   (https://github.com/tensorflow/tensorflow/issues/35860)
    tensorboard_callback = TensorBoard(log_dir=logs_dir.as_posix(),
                                       write_graph=True,
                                       write_grads=False,
                                       write_images=True,
                                       update_freq='epoch',
                                       profile_batch=0)

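    # sample a handful of image/mask paths from the training and validation splits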
    n_sample_images = 20
    train_image_and_mask_paths = sample_image_and_mask_paths(
        train_generator, n_sample_images)
    validation_image_and_mask_paths = sample_image_and_mask_paths(
        validation_generator, n_sample_images)

    csv_logger_callback = CSVLogger(Path(model_dir, 'metrics.csv').as_posix(),
                                    append=True)

    results = compiled_model.fit(train_generator,
                                 steps_per_epoch=len(train_generator),
                                 epochs=epochs,
                                 validation_data=validation_generator,
                                 validation_steps=len(validation_generator),
                                 callbacks=[
                                     model_checkpoint_callback,
                                     tensorboard_callback, csv_logger_callback
                                 ])

    # individual plots
    metric_names = ['loss'] + [m.name for m in compiled_model.metrics]
    for metric_name in metric_names:
        fig, ax = plt.subplots()
        for split in ['train', 'validate']:
            key_name = metric_name
            if split == 'validate':
                key_name = 'val_' + key_name
            ax.plot(range(epochs), results.history[key_name], label=split)
        ax.set_xlabel('epochs')
        if metric_name == 'loss' and hasattr(compiled_model.loss, '__name__'):
            ax.set_ylabel(compiled_model.loss.__name__)
        elif metric_name == 'loss' and hasattr(compiled_model.loss, 'name'):
            ax.set_ylabel(compiled_model.loss.name)
        else:
            ax.set_ylabel(metric_name)
        ax.legend()
        if metric_name == 'loss' and hasattr(compiled_model.loss, '__name__'):
            fig.savefig(
                Path(plots_dir,
                     compiled_model.loss.__name__ + '.png').as_posix())
        elif metric_name == 'loss' and hasattr(compiled_model.loss, 'name'):
            fig.savefig(
                Path(plots_dir, compiled_model.loss.name + '.png').as_posix())
        else:
            fig.savefig(Path(plots_dir, metric_name + '.png').as_posix())
    plt.close()

    # mosaic of subplot
    if len(train_generator.mask_filenames) == 1:
        num_rows = 2
    else:  # 1 row for all classes, 1 row for each of n classes
        num_rows = len(train_generator.mask_filenames) + 1
    num_cols = np.ceil(len(metric_names) / num_rows).astype(int)
    fig2, axes = plt.subplots(nrows=num_rows,
                              ncols=num_cols,
                              figsize=(num_cols * 3.25, num_rows * 3.25))
    counter_m = 0
    counter_n = 0
    for metric_name in metric_names:
        for split in ['train', 'validate']:
            key_name = metric_name
            if split == 'validate':
                key_name = 'val_' + key_name
            axes[counter_m, counter_n].plot(range(epochs),
                                            results.history[key_name],
                                            label=split)
        axes[counter_m, counter_n].set_xlabel('epochs')
        if metric_name == 'loss' and hasattr(compiled_model.loss, '__name__'):
            axes[counter_m, counter_n].set_ylabel(compiled_model.loss.__name__)
        elif metric_name == 'loss' and hasattr(compiled_model.loss, 'name'):
            axes[counter_m, counter_n].set_ylabel(compiled_model.loss.name)
        else:
            axes[counter_m, counter_n].set_ylabel(metric_name)
        axes[counter_m, counter_n].legend()
        counter_n += 1
        if counter_n == num_cols:  # plots per row
            counter_m += 1
            counter_n = 0
    fig2.tight_layout()
    fig2.savefig(Path(plots_dir, 'metrics_mosaic.png').as_posix())
    plt.close()

    metadata = {
        'gcp_bucket': gcp_bucket,
        'created_datetime': datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'),
        'num_classes': len(train_generator.mask_filenames),
        'target_size': target_size,
        'git_hash': git.Repo(search_parent_directories=True).head.object.hexsha,
        'original_config_filename': config_file,
        'elapsed_minutes': round((datetime.now() - start_dt).total_seconds() / 60, 1),
        'dataset_config': dataset_config,
        'global_threshold_for_metrics': global_threshold,
    }

    with Path(model_dir, metadata_file_name).open('w') as f:
        yaml.safe_dump(metadata, f)

    os.system("gsutil -m cp -r '{}' '{}'".format(
        Path(tmp_directory, 'models').as_posix(), gcp_bucket))

    print('\n Train/Val Metadata:')
    print(metadata)
    print('\n')

    shutil.rmtree(tmp_directory.as_posix())
Example #7
def train(gcp_bucket, config_file, random_module_global_seed,
          numpy_random_global_seed, tf_random_global_seed, pretrained_model_id,
          message, metric_modelcheckpoint):

    # seed the global random generators if specified; each seed must be an int, or None (the default) for no seeding
    if random_module_global_seed is not None:
        random.seed(random_module_global_seed)
    if numpy_random_global_seed is not None:
        np.random.seed(numpy_random_global_seed)
    if tf_random_global_seed is not None:
        tf_random.set_seed(tf_random_global_seed)

    start_dt = datetime.now()

    with Path(config_file).open('r') as f:
        train_config = yaml.safe_load(f)['train_config']

    assert "gs://" in gcp_bucket

    # clean up the tmp directory
    try:
        shutil.rmtree(tmp_directory.as_posix())
    except FileNotFoundError:
        pass
    tmp_directory.mkdir()

    local_dataset_dir = Path(tmp_directory, 'datasets')

    copy_folder_locally_if_missing(
        os.path.join(gcp_bucket, 'datasets', train_config['dataset_id']),
        local_dataset_dir)

    local_folder_has_files(local_dataset_dir, train_config['dataset_id'])

    model_id = "{}_{}".format(
        train_config['model_id_prefix'],
        datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'))
    model_dir = Path(tmp_directory, 'models', model_id)
    model_dir.mkdir(parents=True)

    plots_dir = Path(model_dir, 'plots')
    plots_dir.mkdir(parents=True)

    logs_dir = Path(model_dir, 'logs')
    logs_dir.mkdir(parents=True)

    with Path(local_dataset_dir, train_config['dataset_id'],
              'config.yaml').open('r') as f:
        dataset_config = yaml.safe_load(f)['dataset_config']

    with Path(model_dir, 'config.yaml').open('w') as f:
        yaml.safe_dump({'train_config': train_config}, f)

    target_size = dataset_config['target_size']
    batch_size = train_config['batch_size']
    epochs = train_config['epochs']

    train_generator = ImagesAndMasksGenerator(
        Path(local_dataset_dir, train_config['dataset_id'], 'train').as_posix(),
        rescale=1. / 255,
        target_size=target_size,
        batch_size=batch_size,
        shuffle=True,
        random_rotation=train_config['data_augmentation']['random_90-degree_rotations'],
        seed=train_config.get('training_data_shuffle_seed'))

    validation_generator = ImagesAndMasksGenerator(
        Path(local_dataset_dir, train_config['dataset_id'], 'validation').as_posix(),
        rescale=1. / 255,
        target_size=target_size,
        batch_size=batch_size,
        seed=train_config.get('validation_data_shuffle_seed'))

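    # optionally warm-start from a pretrained model: fetch its weights, config, and metadata from the bucket
    # and check compatibility with the current training and dataset configs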
    if pretrained_model_id is not None:
        # load pretrained metadata
        local_pretrained_model_dir = Path(tmp_directory, 'pretrained_models')
        copy_folder_locally_if_missing(
            os.path.join(gcp_bucket, 'models', pretrained_model_id),
            local_pretrained_model_dir)

        local_folder_has_files(local_pretrained_model_dir, pretrained_model_id)

        path_pretrained_model = Path(local_pretrained_model_dir,
                                     pretrained_model_id,
                                     "model.hdf5").as_posix()

        with Path(local_pretrained_model_dir, pretrained_model_id,
                  'config.yaml').open('r') as f:
            pretrained_model_config = yaml.safe_load(f)['train_config']

        with Path(local_pretrained_model_dir, pretrained_model_id,
                  'metadata.yaml').open('r') as f:
            pretrained_model_metadata = yaml.safe_load(f)

        pretrained_info = {
            'pretrained_model_id': pretrained_model_id,
            'pretrained_config': pretrained_model_config,
            'pretrained_metadata': pretrained_model_metadata
        }

        check_pretrained_model_compatibility(pretrained_model_config,
                                             pretrained_model_metadata,
                                             train_config, dataset_config,
                                             train_generator)

    else:
        path_pretrained_model = None
        pretrained_info = None

    compiled_model = generate_compiled_segmentation_model(
        train_config['segmentation_model']['model_name'],
        train_config['segmentation_model']['model_parameters'],
        len(train_generator.mask_filenames), train_config['loss'],
        train_config['optimizer'], path_pretrained_model)

    model_checkpoint_callback = ModelCheckpoint(Path(model_dir,
                                                     'model.hdf5').as_posix(),
                                                monitor=metric_modelcheckpoint,
                                                verbose=1,
                                                save_best_only=True)
    # profile_batch=0 is needed until the insufficient-privileges issue with CUPTI is resolved
    #   (https://github.com/tensorflow/tensorflow/issues/35860)
    tensorboard_callback = TensorBoard(log_dir=logs_dir.as_posix(),
                                       write_graph=True,
                                       write_grads=False,
                                       write_images=True,
                                       update_freq='epoch',
                                       profile_batch=0)

    n_sample_images = 20
    train_image_and_mask_paths = sample_image_and_mask_paths(
        train_generator, n_sample_images)
    validation_image_and_mask_paths = sample_image_and_mask_paths(
        validation_generator, n_sample_images)

    csv_logger_callback = CSVLogger(Path(model_dir, 'metrics.csv').as_posix(),
                                    append=True)
    time_callback = timecallback()  # model_dir, plots_dir, 'metrics_epochtime.csv'

    results = compiled_model.fit(train_generator,
                                 steps_per_epoch=len(train_generator),
                                 epochs=epochs,
                                 validation_data=validation_generator,
                                 validation_steps=len(validation_generator),
                                 callbacks=[
                                     model_checkpoint_callback,
                                     tensorboard_callback, time_callback,
                                     csv_logger_callback
                                 ])

    metric_names = ['epoch_time_in_sec', 'total_elapsed_time_in_sec'] + [m.name for m in compiled_model.metrics]

    # define number of columns and rows for the mosaic plot
    if len(train_generator.mask_filenames) == 1:
        num_rows = 2
    else:  # 1 row for all classes, 1 row for each of n classes
        num_rows = len(train_generator.mask_filenames) + 1
    num_cols = np.ceil(len(metric_names) / num_rows).astype(int)

    # generate individual plots
    generate_plots(metric_names,
                   range(epochs),
                   results.history,
                   plots_dir,
                   num_rows=1,
                   num_cols=1)

    # generate mosaic plot
    generate_plots(metric_names,
                   range(epochs),
                   results.history,
                   plots_dir,
                   num_rows=num_rows,
                   num_cols=num_cols)

    metadata_sys = {
        'System_info': getSystemInfo(),
        'Lib_versions_info': getLibVersions()
    }

    metadata = {
        'message': message,
        'gcp_bucket': gcp_bucket,
        'created_datetime': datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'),
        'num_classes': len(train_generator.mask_filenames),
        'target_size': target_size,
        'git_hash': git.Repo(search_parent_directories=True).head.object.hexsha,
        'original_config_filename': config_file,
        'elapsed_minutes': round((datetime.now() - start_dt).total_seconds() / 60, 1),
        'dataset_config': dataset_config,
        'global_threshold_for_metrics': global_threshold,
        'random-module-global-seed': random_module_global_seed,
        'numpy_random_global_seed': numpy_random_global_seed,
        'tf_random_global_seed': tf_random_global_seed,
        'metric_modelcheckpoint': metric_modelcheckpoint,
        'pretrained_model_info': pretrained_info,
        'metadata_system': metadata_sys
    }

    with Path(model_dir, metadata_file_name).open('w') as f:
        yaml.safe_dump(metadata, f)

    os.system("gsutil -m cp -r '{}' '{}'".format(
        Path(tmp_directory, 'models').as_posix(), gcp_bucket))

    print('\n Train/Val Metadata:')
    print(metadata)
    print('\n')

    shutil.rmtree(tmp_directory.as_posix())
Example #8
def train(gcp_bucket, config_file):

    start_dt = datetime.now()

    with Path(config_file).open('r') as f:
        train_config = yaml.safe_load(f)['train_config']

    assert "gs://" in gcp_bucket

    # clean up the tmp directory
    try:
        shutil.rmtree(tmp_directory.as_posix())
    except FileNotFoundError:
        pass
    tmp_directory.mkdir()

    local_dataset_dir = Path(tmp_directory, 'datasets')

    copy_folder_locally_if_missing(os.path.join(gcp_bucket, 'datasets', train_config['dataset_id']),
                                   local_dataset_dir)

    model_id = "{}_{}".format(train_config['model_id_prefix'], datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'))
    model_dir = Path(tmp_directory, 'models', model_id)
    model_dir.mkdir(parents=True)

    plots_dir = Path(model_dir, 'plots')
    plots_dir.mkdir(parents=True)

    logs_dir = Path(model_dir, 'logs')
    logs_dir.mkdir(parents=True)

    with Path(local_dataset_dir, train_config['dataset_id'], 'config.yaml').open('r') as f:
        dataset_config = yaml.safe_load(f)['dataset_config']

    with Path(model_dir, 'config.yaml').open('w') as f:
        yaml.safe_dump({'train_config': train_config}, f)

    target_size = dataset_config['target_size']
    batch_size = train_config['batch_size']
    epochs = train_config['epochs']

    train_generator = ImagesAndMasksGenerator(
        Path(local_dataset_dir, train_config['dataset_id'], 'train').as_posix(),
        rescale=1./255,
        target_size=target_size,
        batch_size=batch_size,
        shuffle=True,
        random_rotation=train_config['augmentation']['random_90-degree_rotations'],
        seed=None)

    validation_generator = ImagesAndMasksGenerator(
        Path(local_dataset_dir, train_config['dataset_id'],
             'validation').as_posix(),
        rescale=1./255,
        target_size=target_size,
        batch_size=batch_size,
        shuffle=True,
        seed=None)

    model = Unet('vgg16', input_shape=(None, None, 1), classes=len(train_generator.mask_filenames), encoder_weights=None)

    crossentropy = binary_crossentropy if len(train_generator.mask_filenames) == 1 else categorical_crossentropy
    loss_fn = crossentropy

    model.compile(optimizer=Adam(),
                  loss=loss_fn,
                  metrics=[accuracy, iou_score, jaccard_loss, dice_loss, crossentropy])

    model_checkpoint_callback = ModelCheckpoint(Path(model_dir, 'model.hdf5').as_posix(),
                                                monitor='loss', verbose=1, save_best_only=True)
    tensorboard_callback = TensorBoard(log_dir=logs_dir.as_posix(), batch_size=batch_size, write_graph=True,
                                       write_grads=False, write_images=True, update_freq='epoch')

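    # log a sample of training and validation images and their masks to TensorBoard via the image callback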
    n_sample_images = 20
    train_image_and_mask_paths = sample_image_and_mask_paths(train_generator, n_sample_images)
    validation_image_and_mask_paths = sample_image_and_mask_paths(validation_generator, n_sample_images)

    tensorboard_image_callback = TensorBoardImage(
        log_dir=logs_dir.as_posix(),
        images_and_masks_paths=train_image_and_mask_paths + validation_image_and_mask_paths)

    csv_logger_callback = CSVLogger(Path(model_dir, 'metrics.csv').as_posix(), append=True)

    results = model.fit_generator(
        train_generator,
        steps_per_epoch=len(train_generator),
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=len(validation_generator),
        callbacks=[model_checkpoint_callback, tensorboard_callback, tensorboard_image_callback, csv_logger_callback])

    metric_names = ['loss', 'accuracy', 'iou_score', 'jaccard_loss', 'dice_loss']
    for metric_name in metric_names:

        fig, ax = plt.subplots()
        for split in ['train', 'validate']:

            key_name = metric_name
            if split == 'validate':
                key_name = 'val_' + key_name

            ax.plot(range(epochs), results.history[key_name], label=split)
        ax.set_xlabel('epochs')
        if metric_name == 'loss':
            ax.set_ylabel(loss_fn.__name__)
        else:
            ax.set_ylabel(metric_name)
        ax.legend()
        if metric_name == 'loss':
            fig.savefig(Path(plots_dir, loss_fn.__name__ + '.png').as_posix())
        else:
            fig.savefig(Path(plots_dir, metric_name + '.png').as_posix())

    metadata = {
        'gcp_bucket': gcp_bucket,
        'created_datetime': datetime.now(pytz.UTC).strftime('%Y%m%dT%H%M%SZ'),
        'num_classes': len(train_generator.mask_filenames),
        'target_size': target_size,
        'git_hash': git.Repo(search_parent_directories=True).head.object.hexsha,
        'original_config_filename': config_file,
        'elapsed_minutes': round((datetime.now() - start_dt).total_seconds() / 60, 1)
    }

    with Path(model_dir, metadata_file_name).open('w') as f:
        yaml.safe_dump(metadata, f)

    os.system("gsutil -m cp -r '{}' '{}'".format(Path(tmp_directory, 'models').as_posix(), gcp_bucket))

    shutil.rmtree(tmp_directory.as_posix())