class Validator:
    """Evaluate a saved generator checkpoint on a validation data loader.

    The generator is built from ``config['model']``, its weights are loaded
    from a checkpoint file, and per-batch PSNR/SSIM values returned by
    ``model.get_images_and_metrics`` are accumulated and printed.
    """

    def __init__(self, config, val: DataLoader):
        """Store the configuration and the validation loader.

        Args:
            config: Mapping with at least a ``'model'`` entry; an optional
                ``'val_batches_per_epoch'`` entry caps the number of batches.
            val: Loader that yields validation batches.
        """
        self.config = config
        self.val_dataset = val
        self.metric_counter = MetricCounter()

    def validate(self, weights_path='best_G_fpn.h5'):
        """Load generator weights and run one validation pass.

        Args:
            weights_path: Checkpoint file holding a dict whose ``'model'``
                entry is the generator state dict.
        """
        self._init_params()
        checkpoint = torch.load(weights_path)
        self.netG.load_state_dict(checkpoint['model'])
        # NOTE(review): the generator is deliberately left in training mode
        # here, exactly as before; confirm that normalisation layers are meant
        # to use batch statistics during evaluation.
        self.netG.train(True)
        self._validate()
        torch.cuda.empty_cache()
        print(self.metric_counter.loss_message())

    def _validate(self):
        """Run the generator over the validation loader and print metrics.

        At most ``config['val_batches_per_epoch']`` batches are processed;
        when that entry is missing or falsy the whole loader is used.
        """
        self.metric_counter.clear()
        # Use this instance's config and its own loader: Validator has no
        # ``train_dataset`` attribute and must not depend on a global config.
        epoch_size = (self.config.get('val_batches_per_epoch')
                      or len(self.val_dataset))
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        i = 0
        total_psnr = 0
        total_ssim = 0
        total_samples = 0
        # No gradients are needed for evaluation; skipping autograd
        # bookkeeping keeps memory usage down.
        with torch.no_grad():
            for data in tq:
                inputs, targets = self.model.get_input(data)
                inputs, targets = inputs.cuda(), targets.cuda()
                outputs = self.netG(inputs)
                curr_psnr, curr_ssim, img_for_vis = (
                    self.model.get_images_and_metrics(
                        inputs, outputs, targets))
                # Weight each batch by len(inputs) (presumably the batch
                # size) so the final averages are per sample.
                batch_len = len(inputs)
                total_ssim += curr_ssim * batch_len
                total_psnr += curr_psnr * batch_len
                total_samples += batch_len
                print("Metrcis:", curr_psnr, curr_ssim)
                print("Totals:", total_ssim, total_psnr, total_samples)
                print("nan:", np.isnan(img_for_vis).any())
                print("inf:", np.isinf(img_for_vis).any())
                self.metric_counter.add_metrics(curr_psnr, curr_ssim)
                if not i:
                    # Only the first batch is logged as an image.
                    self.metric_counter.add_image(img_for_vis, tag='val')
                i += 1
                # Stop once exactly ``epoch_size`` batches were processed.
                if i >= epoch_size:
                    break
        self.metric_counter.write_to_tensorboard(i, validation=True)
        if total_samples:
            # Guard against an empty loader (division by zero).
            print("PSNR", total_psnr / total_samples)
            print("SSIM", total_ssim / total_samples)
        tq.close()

    def _init_params(self):
        """Build the generator (moved to CUDA) and the model helper."""
        # get_nets also returns a discriminator, which is not used here.
        self.netG, _ = get_nets(self.config['model'])
        self.netG.cuda()
        self.model = get_model(self.config['model'])
class Trainer:
    """Adversarial training loop for a generator/discriminator pair.

    Networks, losses, optimizers and schedulers are all built from the
    ``config`` mapping in ``_init_params``; ``train`` then alternates a
    training epoch and a validation pass and checkpoints the generator.
    """

    def __init__(self, config, train: DataLoader, val: DataLoader):
        """Store loaders and read the hyper-parameters used by the loop.

        Args:
            config: Mapping with (at least) the keys ``'model'``,
                ``'experiment_desc'``, ``'warmup_num'`` and ``'gpu_id'``.
            train: Loader yielding training batches.
            val: Loader yielding validation batches.
        """
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.adv_lambda = config['model']['adv_lambda']
        self.metric_counter = MetricCounter(config['experiment_desc'])
        self.warmup_epochs = config['warmup_num']
        gpu_id = self.config['gpu_id']
        # GPU indices start at 0, so 0 must select 'cuda:0'; a negative id
        # (or no CUDA support) falls back to the CPU.
        use_cuda = torch.cuda.is_available() and gpu_id >= 0
        self.device = torch.device(
            'cuda:{}'.format(gpu_id) if use_cuda else "cpu")

    def train(self):
        """Run ``config['num_epochs']`` epochs of training and validation.

        After every epoch the generator weights are written to
        ``last_<experiment>.h5``; they are also written to
        ``best_<experiment>.h5`` whenever the metric counter reports a new
        best model.
        """
        self._init_params()
        experiment = self.config['experiment_desc']
        for epoch in range(self.config['num_epochs']):
            if self.warmup_epochs != 0 and epoch == self.warmup_epochs:
                # End of warm-up: unfreeze the generator and rebuild its
                # optimizer/scheduler over all parameters.
                # NOTE(review): ``.module`` assumes netG is wrapped (e.g. by
                # DataParallel) — confirm against get_nets.
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)
            self._run_epoch(epoch)
            self._validate(epoch)
            # NOTE(review): torch's ReduceLROnPlateau.step() expects a metric
            # argument, so the 'plateau' scheduler likely fails here —
            # confirm which metric should be passed.
            self.scheduler_G.step()
            self.scheduler_D.step()

            if self.metric_counter.update_best_model():
                torch.save({'model': self.netG.state_dict()},
                           'best_{}.h5'.format(experiment))
            torch.save({'model': self.netG.state_dict()},
                       'last_{}.h5'.format(experiment))
            print(self.metric_counter.loss_message())
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s",
                          experiment, epoch,
                          self.metric_counter.loss_message())

    def _run_epoch(self, epoch):
        """Train for one epoch and log the collected metrics.

        At most ``config['train_batches_per_epoch']`` batches are used; when
        that entry is missing or falsy the whole loader is used.
        """
        self.metric_counter.clear()
        # The progress bar shows the learning rate of the last param group.
        lr = self.optimizer_G.param_groups[-1]['lr']

        epoch_size = (self.config.get('train_batches_per_epoch')
                      or len(self.train_dataset))
        tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        i = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            # Discriminator step first; it retains the graph so the
            # generator loss below can still backpropagate through outputs.
            loss_D = self._update_d(outputs, targets)
            self.optimizer_G.zero_grad()
            loss_content = self.criterionG(outputs, targets)
            loss_adv = self.adv_trainer.loss_g(outputs, targets)
            loss_G = loss_content + self.adv_lambda * loss_adv
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(
                loss_G.item(), loss_content.item(), loss_D)
            curr_psnr, curr_ssim, img_for_vis = (
                self.model.get_images_and_metrics(inputs, outputs, targets))
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            if not i:
                # Only the first batch of the epoch is logged as an image.
                self.metric_counter.add_image(img_for_vis, tag='train')
            i += 1
            # Stop once exactly ``epoch_size`` batches were processed.
            if i >= epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        """Evaluate the generator on the validation loader for one epoch.

        At most ``config['val_batches_per_epoch']`` batches are used; when
        that entry is missing or falsy the whole loader is used.
        """
        self.metric_counter.clear()
        epoch_size = (self.config.get('val_batches_per_epoch')
                      or len(self.val_dataset))
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        i = 0
        # Nothing is optimised here, so skip autograd bookkeeping.
        with torch.no_grad():
            for data in tq:
                inputs, targets = self.model.get_input(data)
                outputs = self.netG(inputs)
                loss_content = self.criterionG(outputs, targets)
                loss_adv = self.adv_trainer.loss_g(outputs, targets)
                loss_G = loss_content + self.adv_lambda * loss_adv
                self.metric_counter.add_losses(
                    loss_G.item(), loss_content.item())
                curr_psnr, curr_ssim, img_for_vis = (
                    self.model.get_images_and_metrics(
                        inputs, outputs, targets))
                self.metric_counter.add_metrics(curr_psnr, curr_ssim)
                if not i:
                    self.metric_counter.add_image(img_for_vis, tag='val')
                i += 1
                if i >= epoch_size:
                    break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _update_d(self, outputs, targets):
        """Run one discriminator update and return its loss as a float.

        Returns 0 without touching any optimizer when the configured
        discriminator name is ``'no_gan'``.
        """
        if self.config['model']['d_name'] == 'no_gan':
            return 0
        self.optimizer_D.zero_grad()
        loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets)
        # Keep the graph alive: the caller backpropagates the generator loss
        # through the same ``outputs`` afterwards.
        loss_D.backward(retain_graph=True)
        self.optimizer_D.step()
        return loss_D.item()

    def _get_optim(self, params):
        """Build the optimizer named in ``config['optimizer']``.

        Args:
            params: Parameters (or parameter groups) to optimise.

        Raises:
            ValueError: If the optimizer name is not 'adam', 'sgd' or
                'adadelta'.
        """
        name = self.config['optimizer']['name']
        optimizer_classes = {
            'adam': optim.Adam,
            'sgd': optim.SGD,
            'adadelta': optim.Adadelta,
        }
        if name not in optimizer_classes:
            raise ValueError("Optimizer [%s] not recognized." % name)
        return optimizer_classes[name](
            params, lr=self.config['optimizer']['lr'])

    def _get_scheduler(self, optimizer):
        """Build the LR scheduler named in ``config['scheduler']``.

        Raises:
            ValueError: If the scheduler name is not 'plateau', 'sgdr' or
                'linear'.
        """
        scheduler_cfg = self.config['scheduler']
        name = scheduler_cfg['name']
        if name == 'plateau':
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='min',
                patience=scheduler_cfg['patience'],
                factor=scheduler_cfg['factor'],
                min_lr=scheduler_cfg['min_lr'])
        elif name == 'sgdr':
            # Select on the *scheduler* name; the optimizer name is
            # unrelated to which schedule is requested.
            scheduler = WarmRestart(optimizer)
        elif name == 'linear':
            scheduler = LinearDecay(
                optimizer,
                min_lr=scheduler_cfg['min_lr'],
                num_epochs=self.config['num_epochs'],
                start_epoch=scheduler_cfg['start_epoch'])
        else:
            raise ValueError("Scheduler [%s] not recognized." % name)
        return scheduler

    @staticmethod
    def _get_adversarial_trainer(d_name, net_d, criterion_d):
        """Create the adversarial helper matching the discriminator name.

        Raises:
            ValueError: If ``d_name`` is not one of 'no_gan', 'patch_gan',
                'multi_scale' or 'double_gan'.
        """
        if d_name == 'no_gan':
            return GANFactory.create_model('NoGAN')
        if d_name in ('patch_gan', 'multi_scale'):
            return GANFactory.create_model('SingleGAN', net_d, criterion_d)
        if d_name == 'double_gan':
            return GANFactory.create_model('DoubleGAN', net_d, criterion_d)
        raise ValueError(
            "Discriminator Network [%s] not recognized." % d_name)

    def _init_params(self):
        """Create losses, networks, optimizers and schedulers from config."""
        model_cfg = self.config['model']
        self.criterionG, criterionD = get_loss(model_cfg)
        self.netG, netD = get_nets(model_cfg)
        self.netG.to(self.device)
        self.adv_trainer = self._get_adversarial_trainer(
            model_cfg['d_name'], netD, criterionD)
        self.model = get_model(model_cfg)
        # Only parameters that currently require gradients are optimised;
        # train() rebuilds this optimizer after the warm-up unfreeze.
        self.optimizer_G = self._get_optim(
            filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
        self.scheduler_D = self._get_scheduler(self.optimizer_D)