Example #1
def main():
    args, cfg = parse_args()
    checkpoint_path = find_checkpoint(cfg.MODELS_PATH, args.checkpoint)
    add_new_file_output_to_logger(
        logs_path=Path(cfg.EXPS_PATH) / 'evaluation_logs',
        prefix=f'{Path(checkpoint_path).stem}_',
        only_message=True)
    logger.info(vars(args))

    device = torch.device(f'cuda:{args.gpu}')
    net = load_model(args.model_type, checkpoint_path, verbose=True)
    predictor = Predictor(net, device, with_flip=args.use_flip)

    datasets_names = args.datasets.split(',')
    datasets_metrics = []
    for dataset_indx, dataset_name in enumerate(datasets_names):
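        # keep_background_prob=-1 presumably disables keeping background-only
        # (unchanged) samples during evaluation (assumption; see HDataset)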
        dataset = HDataset(
            cfg.get(f'{dataset_name.upper()}_PATH'),
            split='test',
            augmentator=HCompose([RESIZE_STRATEGIES[args.resize_strategy]]),
            keep_background_prob=-1)

        dataset_metrics = MetricsHub(
            [N(), MSE(), fMSE(), PSNR(), AvgPredictTime()],
            name=dataset_name)

        evaluate_dataset(dataset, predictor, dataset_metrics)
        datasets_metrics.append(dataset_metrics)
        if dataset_indx == 0:
            logger.info(dataset_metrics.get_table_header())
        logger.info(dataset_metrics)

    if len(datasets_metrics) > 1:
        overall_metrics = sum(datasets_metrics, MetricsHub([], 'Overall'))
        logger.info('-' * len(str(overall_metrics)))
        logger.info(overall_metrics)
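
The final Overall row relies on MetricsHub instances supporting +, since sum() folds the per-dataset hubs onto an empty starting hub. A minimal sketch of that accumulator pattern, with a hypothetical records list standing in for whatever state the real class keeps:

class AddableHub:
    # Hypothetical stand-in for MetricsHub: sum(hubs, AddableHub([], 'Overall'))
    # works because '+' merges accumulated state into a new hub.
    def __init__(self, records, name):
        self.records = records
        self.name = name

    def __add__(self, other):
        # keep the left-hand (starting) name, merge the accumulated records
        return AddableHub(self.records + other.records, self.name)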
Example #2
def train(model, cfg, model_cfg, start_epoch=0):
    cfg.batch_size = 16 if cfg.batch_size < 1 else cfg.batch_size
    cfg.val_batch_size = cfg.batch_size

    cfg.input_normalization = model_cfg.input_normalization
    crop_size = model_cfg.crop_size

    loss_cfg = edict()
    loss_cfg.pixel_loss = MaskWeightedMSE()
    loss_cfg.pixel_loss_weight = 1.0

    num_epochs = 180

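    # train-time pipeline: cap the longest side at 1024 px, random horizontal
    # flip, constant-pad (border_mode=0) up to crop_size, then random-crop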
    train_augmentator = HCompose([
        LongestMaxSizeIfLarger(1024),
        HorizontalFlip(),
        PadIfNeeded(min_height=crop_size[0],
                    min_width=crop_size[1],
                    border_mode=0),
        RandomCrop(*crop_size)
    ])

    val_augmentator = HCompose([
        LongestMaxSizeIfLarger(1024),
        PadIfNeeded(min_height=crop_size[0],
                    min_width=crop_size[1],
                    border_mode=0),
        RandomCrop(*crop_size)
    ])
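    # note: validation also random-crops (after padding), so validation
    # metrics are not perfectly deterministic across runs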

    trainset = ComposeDataset(
        [
            HDataset(cfg.HFLICKR_PATH, split='train'),
            HDataset(cfg.HDAY2NIGHT_PATH, split='train'),
            HDataset(cfg.HCOCO_PATH, split='train'),
            HDataset(cfg.HADOBE5K_PATH, split='train'),
        ],
        augmentator=train_augmentator,
        input_transform=model_cfg.input_transform)

    valset = ComposeDataset(
        [
            HDataset(cfg.HFLICKR_PATH, split='test'),
            HDataset(cfg.HDAY2NIGHT_PATH, split='test'),
            HDataset(cfg.HCOCO_PATH, split='test'),
        ],
        augmentator=val_augmentator,
        input_transform=model_cfg.input_transform)

    optimizer_params = {'lr': 1e-3, 'betas': (0.9, 0.999), 'eps': 1e-8}
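    # Adam with default betas/eps; MultiStepLR below decays the learning rate
    # by 10x at epochs 160 and 175 (out of 180 total)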

    lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
                           milestones=[160, 175],
                           gamma=0.1)
    trainer = SimpleHTrainer(
        model,
        cfg,
        model_cfg,
        loss_cfg,
        trainset,
        valset,
        optimizer='adam',
        optimizer_params=optimizer_params,
        lr_scheduler=lr_scheduler,
        metrics=[
            PSNRMetric('images', 'target_images'),
            DenormalizedPSNRMetric(
                'images',
                'target_images',
                mean=torch.tensor(cfg.input_normalization['mean'],
                                  dtype=torch.float32).view(1, 3, 1, 1),
                std=torch.tensor(cfg.input_normalization['std'],
                                 dtype=torch.float32).view(1, 3, 1, 1),
            ),
            DenormalizedMSEMetric(
                'images',
                'target_images',
                mean=torch.tensor(cfg.input_normalization['mean'],
                                  dtype=torch.float32).view(1, 3, 1, 1),
                std=torch.tensor(cfg.input_normalization['std'],
                                 dtype=torch.float32).view(1, 3, 1, 1),
            )
        ],
        checkpoint_interval=5,
        image_dump_interval=500)

    logger.info(f'Starting Epoch: {start_epoch}')
    logger.info(f'Total Epochs: {num_epochs}')
    for epoch in range(start_epoch, num_epochs):
        trainer.training(epoch)
        trainer.validation(epoch)
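
Note that lr_scheduler is passed in as a functools.partial with everything bound except the optimizer, so the trainer can presumably finish constructing it once the optimizer exists. A minimal sketch of that handshake (the SimpleHTrainer internals here are an assumption):

import torch
from functools import partial

lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
                       milestones=[160, 175], gamma=0.1)

# stand-in model parameter and the same Adam settings as above
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
scheduler = lr_scheduler(optimizer)  # the partial supplies milestones/gamma

for epoch in range(180):
    # ... one training epoch ...
    scheduler.step()  # lr: 1e-3 -> 1e-4 after epoch 160 -> 1e-5 after 175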
Example #3
def main():
    args, cfg = parse_args()
    checkpoint_path = find_checkpoint(cfg.MODELS_PATH, args.checkpoint)
    add_new_file_output_to_logger(
        logs_path=Path(cfg.EXPS_PATH) / 'evaluation_results',
        prefix=f'{Path(checkpoint_path).stem}_',
        only_message=True)
    logger.info(vars(args))

    device = torch.device(f'cuda:{args.gpu}')
    net = load_model(args.model_type, checkpoint_path, verbose=True)
    predictor = Predictor(net, device, with_flip=args.use_flip)

    fg_ratio_intervals = [(0.0, 0.05), (0.05, 0.15), (0.15, 1.0), (0.0, 1.0)]
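    # interval bounds are checked inclusively below, so the final (0.0, 1.0)
    # interval is a catch-all bucket covering every sample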

    datasets_names = args.datasets.split(',')
    datasets_metrics = [[] for _ in fg_ratio_intervals]
    for dataset_indx, dataset_name in enumerate(datasets_names):
        dataset = HDataset(
            cfg.get(f'{dataset_name.upper()}_PATH'),
            split='test',
            augmentator=HCompose([RESIZE_STRATEGIES[args.resize_strategy]]),
            keep_background_prob=-1)

        dataset_metrics = []
        for fg_ratio_min, fg_ratio_max in fg_ratio_intervals:
            dataset_metrics.append(MetricsHub(
                [N(), MSE(), fMSE(), PSNR()],
                name=f'{dataset_name} ({fg_ratio_min:.0%}-{fg_ratio_max:.0%})',
                name_width=28))

        for sample_i in trange(len(dataset),
                               desc=f'Testing on {dataset_name}'):
            sample = dataset.get_sample(sample_i)
            sample = dataset.augment_sample(sample)

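            # foreground ratio: fraction of pixels covered by the object mask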
            sample_mask = sample['object_mask']
            sample_fg_ratio = (sample_mask > 0.5).sum() / (
                sample_mask.shape[0] * sample_mask.shape[1])
            pred = predictor.predict(sample['image'],
                                     sample_mask,
                                     return_numpy=False)
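            # return_numpy=False keeps the prediction as a torch tensor,
            # presumably already on predictor.device (assumption)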

            target_image = torch.as_tensor(
                sample['target_image'], dtype=torch.float32).to(predictor.device)
            sample_mask = torch.as_tensor(
                sample_mask, dtype=torch.float32).to(predictor.device)
            with torch.no_grad():
                for metrics_hub, (fg_ratio_min, fg_ratio_max) in zip(
                        dataset_metrics, fg_ratio_intervals):
                    if fg_ratio_min <= sample_fg_ratio <= fg_ratio_max:
                        metrics_hub.compute_and_add(pred, target_image,
                                                    sample_mask)

        for indx, metrics_hub in enumerate(dataset_metrics):
            datasets_metrics[indx].append(metrics_hub)
        if dataset_indx == 0:
            logger.info(dataset_metrics[-1].get_table_header())
        for metrics_hub in dataset_metrics:
            logger.info(metrics_hub)

    if len(datasets_names) > 1:
        overall_metrics = [
            sum(interval_hubs,
                MetricsHub([],
                           f'Overall ({fg_ratio_min:.0%}-{fg_ratio_max:.0%})',
                           name_width=28))
            for interval_hubs, (fg_ratio_min, fg_ratio_max)
            in zip(datasets_metrics, fg_ratio_intervals)
        ]
        logger.info('-' * len(str(overall_metrics[-1])))
        for metrics_hub in overall_metrics:
            logger.info(metrics_hub)
Example #4
def train(model, cfg, model_cfg, start_epoch=0):
    cfg.batch_size = 16 if cfg.batch_size < 1 else cfg.batch_size
    cfg.val_batch_size = cfg.batch_size
    cfg.input_normalization = model_cfg.input_normalization

    loss_cfg = edict()
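    # MaskWeightedMSE presumably averages squared error over the foreground
    # mask, with min_area acting as a lower clamp on the mask-area denominator
    # so tiny objects do not dominate the loss (assumption; see its definition)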
    loss_cfg.pixel_loss = MaskWeightedMSE(min_area=100)
    loss_cfg.pixel_loss_weight = 1.0

    num_epochs = 120
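    # train-time: random 256x256 crops covering 50-100% of the image area,
    # plus horizontal flips; validation uses a deterministic 256x256 resize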
    train_augmentator = HCompose(
        [RandomResizedCrop(256, 256, scale=(0.5, 1.0)),
         HorizontalFlip()])

    val_augmentator = HCompose([Resize(256, 256)])

    trainset = ComposeDataset(
        [
            HDataset(cfg.HFLICKR_PATH, split='train'),
            HDataset(cfg.HDAY2NIGHT_PATH, split='train'),
            HDataset(cfg.HCOCO_PATH, split='train'),
            HDataset(cfg.HADOBE5K_PATH, split='train'),
        ],
        augmentator=train_augmentator,
        input_transform=model_cfg.input_transform,
        keep_background_prob=0.05)

    valset = ComposeDataset(
        [
            HDataset(cfg.HFLICKR_PATH, split='test'),
            HDataset(cfg.HDAY2NIGHT_PATH, split='test'),
            HDataset(cfg.HCOCO_PATH, split='test'),
        ],
        augmentator=val_augmentator,
        input_transform=model_cfg.input_transform,
        keep_background_prob=-1)

    optimizer_params = {'lr': 1e-3, 'betas': (0.9, 0.999), 'eps': 1e-8}

    lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
                           milestones=[105, 115],
                           gamma=0.1)
    trainer = SimpleHTrainer(
        model,
        cfg,
        model_cfg,
        loss_cfg,
        trainset,
        valset,
        optimizer='adam',
        optimizer_params=optimizer_params,
        lr_scheduler=lr_scheduler,
        metrics=[
            DenormalizedPSNRMetric(
                'images',
                'target_images',
                mean=torch.tensor(cfg.input_normalization['mean'],
                                  dtype=torch.float32).view(1, 3, 1, 1),
                std=torch.tensor(cfg.input_normalization['std'],
                                 dtype=torch.float32).view(1, 3, 1, 1),
            ),
            DenormalizedMSEMetric(
                'images',
                'target_images',
                mean=torch.tensor(cfg.input_normalization['mean'],
                                  dtype=torch.float32).view(1, 3, 1, 1),
                std=torch.tensor(cfg.input_normalization['std'],
                                 dtype=torch.float32).view(1, 3, 1, 1),
            )
        ],
        checkpoint_interval=10,
        image_dump_interval=1000)

    logger.info(f'Starting Epoch: {start_epoch}')
    logger.info(f'Total Epochs: {num_epochs}')
    for epoch in range(start_epoch, num_epochs):
        trainer.training(epoch)
        trainer.validation(epoch)
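
Both Denormalized* metrics receive per-channel mean/std tensors shaped (1, 3, 1, 1), which suggests they undo the input normalization by broadcasting before scoring in real pixel space. A minimal sketch of that inverse transform (the metric internals are an assumption; the ImageNet statistics below are only an illustration):

import torch

def denormalize(images, mean, std):
    # images: (N, 3, H, W); mean/std broadcast over batch and spatial dims,
    # inverting x_norm = (x - mean) / std
    return images * std + mean

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
restored = denormalize(torch.randn(2, 3, 256, 256), mean, std)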