Ejemplo n.º 1
0
def find_best_gan(weights_dir_path: Path):
    """Print the GAN weights file (from a directory) with the best SSIM.

    Builds a GAN model from the ``params.yaml`` configuration and passes it,
    together with the validation dataset, to ``find_best_weights``. It then
    prints the best SSIM and the weights file that produced it.

    Args:
        weights_dir_path: Directory handed to ``find_best_weights`` as the
            place to look for candidate weights files.
    """
    params = make_params(Path('params.yaml'), Models.GAN)
    load_params = params['load']

    # Only the validation split is used for scoring. Building the training
    # split or the test split here would be wasted work, because neither
    # result is consumed.
    val_ds = make_validation_data(load_params['dataset'],
                                  load_params['input_shape'],
                                  load_params['batch_size'],
                                  load_params['validation_split'],
                                  load_params['limit_per_scene'])

    model = make_model(Models.GAN,
                       load_params['input_shape'],
                       use_lr_masks=False)

    best_ssim, best_weights_file_path = find_best_weights(
        model, weights_dir_path, val_ds)

    print(
        f'Best SSIM: {best_ssim}, for weights file: {best_weights_file_path}.')
Ejemplo n.º 2
0
def test(model_type: Models, weights_path: Path, output_dir: Path):
    """Evaluate trained weights on the test split and save figures.

    The model built for ``model_type`` is loaded with ``weights_path`` and
    run over the test dataset. The real LR arrays, the predictions and the LR
    arrays of the artificial datasets are compared through metric heatmaps.
    Two comparison figures are also written for the first sample of the first
    batch: a full one and a border-cropped ("zoomed") one. Every file lands in
    ``output_dir / f'test-{name}'``, where ``name`` comes from
    ``extract_name_from_weights_path(weights_path)``.
    """
    params = make_params(Path('params.yaml'), model_type)
    load_cfg = params['load']
    run_name = extract_name_from_weights_path(weights_path)

    net = make_model(model_type, load_cfg['input_shape'], use_lr_masks=False)
    net.load_weights(weights_path)

    run_dir = output_dir / f'test-{run_name}'
    run_dir.mkdir(parents=True, exist_ok=True)

    test_ds = make_test_data(load_cfg['dataset'], load_cfg['input_shape'],
                             load_cfg['batch_size'])
    real_lr = test_ds.to_lr_array()

    print('Running test inference:')
    predictions = net.predict(test_ds, verbose=1)

    print('Creating artificial datasets.')
    artificial_dss = make_artificial_datasets(load_cfg)
    # One label per InterpolationMode member. This presumably matches the
    # order of the datasets returned above (TODO confirm).
    artificial_labels = [mode.name for mode in InterpolationMode]
    artificial_lrs = [ds.to_lr_array() for ds in artificial_dss]

    labels = ['real', 'pred', *artificial_labels]
    lr_arrays = [real_lr, predictions, *artificial_lrs]

    print('Creating metrics heatmaps.')
    heatmaps = make_heatmaps(lr_arrays, labels)
    for metric_key, heatmap_fig in heatmaps.items():
        heatmap_fig.savefig(run_dir / f'{metric_key}_heatmap-{run_name}.png',
                            dpi=300)

    # Preview sample: first item of the first batch. Predictions are flat, so
    # the batch index is converted into an offset using the batch size.
    batch_idx = sample_idx = 0
    hr_img = test_ds[batch_idx][SampleEl.HR][sample_idx]
    lr_img = test_ds[batch_idx][SampleEl.LR][sample_idx]
    flat_idx = batch_idx * load_cfg['batch_size'] + sample_idx
    pred_img = predictions[flat_idx]

    print('Creating comparison figures.')
    preview = make_comparison_fig(hr_img, lr_img, pred_img,
                                  add_resized_lr=True)
    preview.savefig(run_dir / f'test_preview-{run_name}.png', dpi=300)

    lr_margin = 30
    # The HR border is cropped 3x wider than the LR one (presumably the 3x
    # LR->HR scale factor).
    hr_margin = lr_margin * 3
    zoomed = make_comparison_fig(crop_border(hr_img, hr_margin),
                                 crop_border(lr_img, lr_margin),
                                 crop_border(pred_img, lr_margin),
                                 add_resized_lr=True)
    zoomed.savefig(run_dir / f'test_preview_zoomed-{run_name}.png', dpi=300)
    print('Figures generation done, saving.')
Ejemplo n.º 3
0
def train(model_type: Models, training_name: str):
    """Build a model from ``params.yaml`` and train it.

    Args:
        model_type: Model architecture to build. Its ``name`` prefixes the
            model name.
        training_name: Suffix for the model name
            (``f'{model_type.name}-{training_name}'``).
    """
    params = make_params(Path('params.yaml'), model_type)
    load_params = params['load']
    train_params = params['train']

    model = make_model(model_type,
                       load_params['input_shape'],
                       name=f'{model_type.name}-{training_name}',
                       use_lr_masks=train_params['use_lr_masks'])

    train_ds, val_ds = make_training_data(load_params['dataset'],
                                          load_params['input_shape'],
                                          load_params['batch_size'],
                                          load_params['validation_split'],
                                          load_params['preprocess'],
                                          load_params['limit_per_scene'])

    # GAN models skip the functional summary (presumably they expose no single
    # functional graph). Using isinstance rather than an exact type check means
    # Gan subclasses are skipped as well.
    if not isinstance(model, Gan):
        model.get_functional().summary()

    training = Training(model, train_params['lr'], train_params['loss'])
    training.make_callbacks(train_params['callbacks'])

    training.train(train_ds, val_ds, train_params['epochs'])
Ejemplo n.º 4
0
def main():
    """Command-line entry point: export a transformed PROBA-V dataset or demo.

    Exactly one of ``-s/--simple``, ``-u/--unet`` or ``-g/--gan`` selects the
    model type.

    If ``weights_path`` is given (non-empty), that model is built, loaded with
    the weights, and the weights file stem is used as the output suffix.
    Otherwise ``model_transform_to_lr_bicubic`` is used and the suffix is
    ``'bicubic'``.

    With ``-d/--demo`` a single demo HR image is exported via ``demo_export``.
    Otherwise the whole registered PROBA-V dataset is transformed via
    ``transform_proba_dataset_3xlrs``.
    """
    enable_gpu_if_possible()

    parser = argparse.ArgumentParser(description='Export PROBA-V dataset.')

    # required=True: without a model flag ``model_type`` below would never be
    # assigned and the script would crash with UnboundLocalError.
    model_selection = parser.add_mutually_exclusive_group(required=True)
    model_selection.add_argument('-s', '--simple', action='store_true',
                                 help='Export using simple conv net.')
    model_selection.add_argument('-u', '--unet', action='store_true',
                                 help='Export using Unet net.')
    model_selection.add_argument('-g', '--gan', action='store_true',
                                 help='Export using GAN net.')

    parser.add_argument('-n', '--noise', action='store_true',
                        help='Add noise to HR before augmentation.')
    parser.add_argument('-m', '--match_hist', action='store_true',
                        help='Match HR hist with real LR before augmentation.')
    parser.add_argument('-d', '--demo', action='store_true',
                        help='Don\'t export dataset, demo inference.')
    # Optional so the bicubic fallback can be selected by omitting the
    # argument. Passing an explicit empty string still works as before.
    parser.add_argument('weights_path', nargs='?', default='',
                        help='Model weights file; omit for bicubic export.')

    args = parser.parse_args()

    if args.simple:
        model_type = Models.SIMPLE_CONV
    elif args.unet:
        model_type = Models.UNET
    else:  # the required group guarantees --gan here
        model_type = Models.GAN

    params = make_params(Path('params.yaml'), model_type)
    weights_path = Path(args.weights_path)
    use_weights = args.weights_path != ''
    add_noise = args.noise
    match_hist = args.match_hist

    if use_weights:
        model = make_model(
            model_type,
            params['load']['input_shape'],
            use_lr_masks=False)
        model.load_weights(args.weights_path)
    else:
        # No weights: fall back to model_transform_to_lr_bicubic (presumably a
        # plain bicubic baseline, judging by its name).
        model = model_transform_to_lr_bicubic

    if args.demo:
        demo_hr_path = Path(
            'data/'
            'proba-v_registered_b/'
            'train/'
            'NIR/'
            'imgset0648/'
            'HR000.png')
        demo_export(model, demo_hr_path, add_noise)
    else:
        unregistered_proba_path = Path('data/proba-v')
        registered_proba_path = Path('data/proba-v_registered_b')
        suffix = str(weights_path.stem) if use_weights else 'bicubic'
        transform_proba_dataset_3xlrs(
            model, unregistered_proba_path, registered_proba_path, suffix,
            add_noise, match_hist)
Ejemplo n.º 5
0
def main():
    """Command-line entry point: export a transformed Sentinel-2 dataset or demo.

    Exactly one of ``-s/--simple``, ``-u/--unet`` or ``-g/--gan`` selects the
    model type. The model is built and loaded from the positional
    ``weights_path``.

    With ``-d/--demo`` a single demo HR image is exported via ``demo_export``.
    Otherwise the artificial Sentinel-2 dataset is transformed via
    ``transform_sentinel_dataset_3xlrs``, using the weights file stem as the
    output suffix.
    """
    enable_gpu_if_possible()

    parser = argparse.ArgumentParser(description='Export Sentinel-2 dataset.')

    # required=True: without a model flag ``model_type`` below would never be
    # assigned and the script would crash with UnboundLocalError.
    model_selection = parser.add_mutually_exclusive_group(required=True)
    model_selection.add_argument('-s',
                                 '--simple',
                                 action='store_true',
                                 help='Export using simple conv net.')
    model_selection.add_argument('-u',
                                 '--unet',
                                 action='store_true',
                                 help='Export using Unet net.')
    model_selection.add_argument('-g',
                                 '--gan',
                                 action='store_true',
                                 help='Export using GAN net.')

    parser.add_argument('-r',
                        '--random_translations',
                        action='store_true',
                        help='Generate random LR translations instead of using'
                        ' translations file.')
    parser.add_argument('-d',
                        '--demo',
                        action='store_true',
                        help='Don\'t export dataset, demo inference.')
    parser.add_argument('weights_path')

    args = parser.parse_args()

    if args.simple:
        model_type = Models.SIMPLE_CONV
    elif args.unet:
        model_type = Models.UNET
    else:  # the required group guarantees --gan here
        model_type = Models.GAN

    params = make_params(Path('params.yaml'), model_type)
    weights_path = Path(args.weights_path)

    model = make_model(model_type,
                       params['load']['input_shape'],
                       use_lr_masks=False)
    model.load_weights(args.weights_path)

    if args.demo:
        demo_hr_path = Path(
            'data/'
            'sentinel-2_artificial/'
            'S2B_MSIL1C_20200806T105619_N0209_R094_T30TWP_20200806T121751/'
            '08280x10440/'
            'b8/'
            'hr.png')
        demo_export(model, demo_hr_path)
    else:
        sentinel_root_path = Path('data/sentinel-2_artificial/')
        suffix = str(weights_path.stem)
        transform_sentinel_dataset_3xlrs(model, sentinel_root_path, suffix,
                                         args.random_translations)