def train(model, cfg, model_cfg, start_epoch=0):
    """Run the crop-based harmonization training loop (180 epochs).

    Trains on HFlickr/HDay2Night/HCOCO/HAdobe5k and validates on the first
    three, using a MaskWeightedMSE pixel loss, Adam, and a MultiStepLR
    schedule. The external interface is unchanged from the original.

    Args:
        model: network passed straight through to SimpleHTrainer.
        cfg: run config; batch sizes and input_normalization are written back
            onto it as a side effect (callers may rely on this).
        model_cfg: model config supplying crop_size, input_normalization and
            input_transform.
        start_epoch: epoch index to resume from (default 0).
    """
    # Fall back to a default batch size when the config holds a non-positive one.
    if cfg.batch_size < 1:
        cfg.batch_size = 16
    cfg.val_batch_size = cfg.batch_size
    cfg.input_normalization = model_cfg.input_normalization

    crop_size = model_cfg.crop_size
    num_epochs = 180

    # Single pixel-wise loss term.
    loss_cfg = edict()
    loss_cfg.pixel_loss = MaskWeightedMSE()
    loss_cfg.pixel_loss_weight = 1.0

    train_augmentator = HCompose([
        LongestMaxSizeIfLarger(1024),
        HorizontalFlip(),
        PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1],
                    border_mode=0),
        RandomCrop(*crop_size),
    ])
    # NOTE(review): validation also uses RandomCrop, so val metrics are
    # computed on random crops rather than a deterministic view — presumably
    # intentional here; confirm upstream if exact repeatability matters.
    val_augmentator = HCompose([
        LongestMaxSizeIfLarger(1024),
        PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1],
                    border_mode=0),
        RandomCrop(*crop_size),
    ])

    train_paths = (cfg.HFLICKR_PATH, cfg.HDAY2NIGHT_PATH,
                   cfg.HCOCO_PATH, cfg.HADOBE5K_PATH)
    trainset = ComposeDataset(
        [HDataset(path, split='train') for path in train_paths],
        augmentator=train_augmentator,
        input_transform=model_cfg.input_transform,
    )

    val_paths = (cfg.HFLICKR_PATH, cfg.HDAY2NIGHT_PATH, cfg.HCOCO_PATH)
    valset = ComposeDataset(
        [HDataset(path, split='test') for path in val_paths],
        augmentator=val_augmentator,
        input_transform=model_cfg.input_transform,
    )

    def _norm_stat(key):
        # Per-channel normalization stats reshaped for broadcasting over
        # NCHW image batches.
        return torch.tensor(cfg.input_normalization[key],
                            dtype=torch.float32).view(1, 3, 1, 1)

    trainer = SimpleHTrainer(
        model, cfg, model_cfg, loss_cfg,
        trainset, valset,
        optimizer='adam',
        optimizer_params={'lr': 1e-3, 'betas': (0.9, 0.999), 'eps': 1e-8},
        # Decay LR by 10x near the end of the run.
        lr_scheduler=partial(torch.optim.lr_scheduler.MultiStepLR,
                             milestones=[160, 175], gamma=0.1),
        metrics=[
            PSNRMetric('images', 'target_images'),
            DenormalizedPSNRMetric(
                'images', 'target_images',
                mean=_norm_stat('mean'),
                std=_norm_stat('std'),
            ),
            DenormalizedMSEMetric(
                'images', 'target_images',
                mean=_norm_stat('mean'),
                std=_norm_stat('std'),
            ),
        ],
        checkpoint_interval=5,
        image_dump_interval=500,
    )

    logger.info(f'Starting Epoch: {start_epoch}')
    logger.info(f'Total Epochs: {num_epochs}')
    for epoch in range(start_epoch, num_epochs):
        trainer.training(epoch)
        trainer.validation(epoch)
def train(model, cfg, model_cfg, start_epoch=0):
    """Run the fixed-256 harmonization training loop (120 epochs).

    Trains on HFlickr/HDay2Night/HCOCO/HAdobe5k and validates on the first
    three at a fixed 256x256 resolution, using a min-area MaskWeightedMSE
    pixel loss, Adam, and a MultiStepLR schedule. The external interface is
    unchanged from the original.

    Args:
        model: network passed straight through to SimpleHTrainer.
        cfg: run config; batch sizes and input_normalization are written back
            onto it as a side effect (callers may rely on this).
        model_cfg: model config supplying input_normalization and
            input_transform.
        start_epoch: epoch index to resume from (default 0).
    """
    # Fall back to a default batch size when the config holds a non-positive one.
    if cfg.batch_size < 1:
        cfg.batch_size = 16
    cfg.val_batch_size = cfg.batch_size
    cfg.input_normalization = model_cfg.input_normalization

    num_epochs = 120

    # Single pixel-wise loss term; min_area presumably discards masks whose
    # foreground is too small to weight meaningfully — TODO confirm in
    # MaskWeightedMSE.
    loss_cfg = edict()
    loss_cfg.pixel_loss = MaskWeightedMSE(min_area=100)
    loss_cfg.pixel_loss_weight = 1.0

    # Random scale-crop + flip for training; deterministic resize for val.
    train_augmentator = HCompose([
        RandomResizedCrop(256, 256, scale=(0.5, 1.0)),
        HorizontalFlip(),
    ])
    val_augmentator = HCompose([Resize(256, 256)])

    train_paths = (cfg.HFLICKR_PATH, cfg.HDAY2NIGHT_PATH,
                   cfg.HCOCO_PATH, cfg.HADOBE5K_PATH)
    trainset = ComposeDataset(
        [HDataset(path, split='train') for path in train_paths],
        augmentator=train_augmentator,
        input_transform=model_cfg.input_transform,
        keep_background_prob=0.05,
    )

    val_paths = (cfg.HFLICKR_PATH, cfg.HDAY2NIGHT_PATH, cfg.HCOCO_PATH)
    valset = ComposeDataset(
        [HDataset(path, split='test') for path in val_paths],
        augmentator=val_augmentator,
        input_transform=model_cfg.input_transform,
        # -1 disables background-only sampling for validation.
        keep_background_prob=-1,
    )

    def _norm_stat(key):
        # Per-channel normalization stats reshaped for broadcasting over
        # NCHW image batches.
        return torch.tensor(cfg.input_normalization[key],
                            dtype=torch.float32).view(1, 3, 1, 1)

    trainer = SimpleHTrainer(
        model, cfg, model_cfg, loss_cfg,
        trainset, valset,
        optimizer='adam',
        optimizer_params={'lr': 1e-3, 'betas': (0.9, 0.999), 'eps': 1e-8},
        # Decay LR by 10x near the end of the run.
        lr_scheduler=partial(torch.optim.lr_scheduler.MultiStepLR,
                             milestones=[105, 115], gamma=0.1),
        metrics=[
            DenormalizedPSNRMetric(
                'images', 'target_images',
                mean=_norm_stat('mean'),
                std=_norm_stat('std'),
            ),
            DenormalizedMSEMetric(
                'images', 'target_images',
                mean=_norm_stat('mean'),
                std=_norm_stat('std'),
            ),
        ],
        checkpoint_interval=10,
        image_dump_interval=1000,
    )

    logger.info(f'Starting Epoch: {start_epoch}')
    logger.info(f'Total Epochs: {num_epochs}')
    for epoch in range(start_epoch, num_epochs):
        trainer.training(epoch)
        trainer.validation(epoch)