def _valid_epoch(self, epoch):
    """
    Run one validation pass after training an epoch.

    :param epoch: Current training epoch (1-based); used to compute the
        tensorboard global step.
    :return: A log dict containing 'val_loss' and 'val_metrics'.

    Note:
        The validation metrics in the log must have the key 'val_metrics'.
    """
    # Put both networks in evaluation mode for the validation pass.
    self.generator.eval()
    self.discriminator.eval()

    n_batches = len(self.valid_data_loader)
    loss_sum = 0
    metric_sums = np.zeros(len(self.metrics))

    with torch.no_grad():
        for batch_idx, sample in enumerate(self.valid_data_loader):
            blurred = sample['blurred'].to(self.device)
            sharp = sample['sharp'].to(self.device)

            deblurred = self.generator(blurred)
            fake_out = self.discriminator(deblurred)

            # Generator objective: adversarial term + weighted content term.
            content_weight = self.config['others']['content_loss_lambda']
            adversarial_loss_g = self.adversarial_loss(
                'G', deblurred_discriminator_out=fake_out)
            content_loss_g = self.content_loss(deblurred, sharp) * content_weight
            loss_g = adversarial_loss_g + content_loss_g

            self.writer.set_step((epoch - 1) * n_batches + batch_idx, 'valid')
            self.writer.add_scalar('adversarial_loss_g',
                                   adversarial_loss_g.item())
            self.writer.add_scalar('content_loss_g', content_loss_g.item())
            self.writer.add_scalar('loss_g', loss_g.item())

            loss_sum += loss_g.item()
            metric_sums += self._eval_metrics(denormalize(deblurred),
                                              denormalize(sharp))

    # add histogram of model parameters to the tensorboard
    for name, p in self.generator.named_parameters():
        self.writer.add_histogram(name, p, bins='auto')

    return {
        'val_loss': loss_sum / n_batches,
        'val_metrics': (metric_sums / n_batches).tolist()
    }
def main(blurred_dir, deblurred_dir, resume):
    """Deblur every image in *blurred_dir* with a trained generator.

    :param blurred_dir: directory holding the blurred input images.
    :param deblurred_dir: directory the deblurred outputs are written to.
    :param resume: path of the checkpoint file to restore weights from.
    """
    # Restore the training config stored alongside the weights.
    checkpoint = torch.load(resume)
    config = checkpoint['config']

    # setup data_loader instances
    data_loader = CustomDataLoader(data_dir=blurred_dir)

    # Instantiate the generator architecture named in the config.
    generator = getattr(module_arch, config['generator']['type'])(
        **config['generator']['args'])

    # prepare model for deblurring
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    generator.to(device)
    generator.load_state_dict(checkpoint['generator'])
    generator.eval()

    # start to deblur
    with torch.no_grad():
        for batch_idx, sample in enumerate(tqdm(data_loader, ascii=True)):
            blurred = sample['blurred'].to(device)
            image_name = sample['image_name'][0]

            deblurred = generator(blurred)
            deblurred_img = to_pil_image(
                denormalize(deblurred).squeeze().cpu())
            deblurred_img.save(
                os.path.join(deblurred_dir, 'deblurred ' + image_name))
def fix_image(img):
    """Prepare *img* for logging according to the logger options in self.opt."""
    def flag(name):
        # NOTE(review): closure over `self` from the enclosing scope.
        return opt_get(self.opt, ['logger', name], False)

    if flag('is_mel_spectrogram'):
        # Spectrograms need a channel axis; normalize so the
        # spectrogram is easier to view.
        img = img.unsqueeze(dim=1)
        img = (img - img.mean()) / img.std()
    if img.shape[1] > 3:
        # Keep at most three channels so the result is displayable.
        img = img[:, :3, :, :]
    if flag('reverse_n1_to_1'):
        # Map data in [-1, 1] back into [0, 1].
        img = (img + 1) / 2
    if flag('reverse_imagenet_norm'):
        img = denormalize(img)
    return img
def main(blurred_image, resume):
    """Deblur a single image and save the result to ./deblurred.png.

    :param blurred_image: path of the blurred input image.
    :param resume: path of the checkpoint file to restore weights from.
    """
    # load checkpoint (holds both the weights and the training config)
    checkpoint = torch.load(resume)
    config = checkpoint['config']

    # build model architecture
    generator_class = getattr(module_arch, config['generator']['type'])
    generator = generator_class(**config['generator']['args'])

    # prepare model for deblurring
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    generator.to(device)
    if config['n_gpu'] > 1:
        generator = torch.nn.DataParallel(generator)
    generator.load_state_dict(checkpoint['generator'])
    generator.eval()

    # start to deblur
    with torch.no_grad():
        blurred = Image.open(blurred_image).convert('RGB')

        # Round each side up to the next multiple of 4 — presumably the
        # generator's down/upsampling requires dims divisible by 4
        # (TODO confirm against the architecture).
        w, h = blurred.size
        new_h = -(-h // 4) * 4  # ceil(h / 4) * 4
        new_w = -(-w // 4) * 4
        blurred = transforms.Resize([new_h, new_w], Image.BICUBIC)(blurred)

        transform = transforms.Compose([
            transforms.ToTensor(),  # convert to tensor
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Add the batch dimension and move to the compute device.
        blurred = transform(blurred).unsqueeze(0).to(device)

        deblurred = generator(blurred)
        deblurred_img = to_pil_image(denormalize(deblurred).squeeze().cpu())
        deblurred_img.save("./deblurred.png")
def _train_epoch(self, epoch):
    """
    Training logic for an epoch

    :param epoch: Current training epoch.
    :return: A log that contains all information you want to save.

    Note:
        If you have additional information to record, for example:
            > additional_log = {"x": x, "y": y}
        merge it with log before return. i.e.
            > log = {**log, **additional_log}
            > return log
        The metrics in log must have the key 'metrics'.
    """
    # set models to train mode
    self.generator.train()
    self.discriminator.train()

    total_generator_loss = 0
    total_discriminator_loss = 0
    total_metrics = np.zeros(len(self.metrics))

    for batch_idx, sample in enumerate(self.data_loader):
        self.writer.set_step((epoch - 1) * len(self.data_loader) + batch_idx)

        # get data and send them to GPU
        blurred = sample['blurred'].to(self.device)
        sharp = sample['sharp'].to(self.device)

        # get G's output
        deblurred = self.generator(blurred)

        # denormalize (for logging/metrics only, so no grad is needed)
        with torch.no_grad():
            denormalized_blurred = denormalize(blurred)
            denormalized_sharp = denormalize(sharp)
            denormalized_deblurred = denormalize(deblurred)

        if batch_idx % 100 == 0:
            # save blurred, sharp and deblurred image
            self.writer.add_image('blurred',
                                  make_grid(denormalized_blurred.cpu()))
            self.writer.add_image('sharp',
                                  make_grid(denormalized_sharp.cpu()))
            self.writer.add_image('deblurred',
                                  make_grid(denormalized_deblurred.cpu()))

        # get D's output
        sharp_discriminator_out = self.discriminator(sharp)
        deblurred_discriminator_out = self.discriminator(deblurred)

        # set critic_updates: WGAN-GP trains the critic several times per
        # generator step; plain GAN trains it once
        if self.config['loss']['adversarial'] == 'wgan_gp_loss':
            critic_updates = 5
        else:
            critic_updates = 1

        # train discriminator
        discriminator_loss = 0
        for i in range(critic_updates):
            self.discriminator_optimizer.zero_grad()

            # train discriminator on real and fake
            if self.config['loss']['adversarial'] == 'wgan_gp_loss':
                gp_lambda = self.config['others']['gp_lambda']
                # one scalar alpha for the whole batch — the interpolated
                # sample used by the gradient penalty term
                alpha = random.random()
                interpolates = alpha * sharp + (1 - alpha) * deblurred
                interpolates_discriminator_out = self.discriminator(
                    interpolates)
                kwargs = {
                    'gp_lambda': gp_lambda,
                    'interpolates': interpolates,
                    'interpolates_discriminator_out':
                        interpolates_discriminator_out,
                    'sharp_discriminator_out': sharp_discriminator_out,
                    'deblurred_discriminator_out': deblurred_discriminator_out
                }
                wgan_loss_d, gp_d = self.adversarial_loss('D', **kwargs)
                discriminator_loss_per_update = wgan_loss_d + gp_d

                self.writer.add_scalar('wgan_loss_d', wgan_loss_d.item())
                self.writer.add_scalar('gp_d', gp_d.item())
            elif self.config['loss']['adversarial'] == 'gan_loss':
                kwargs = {
                    'sharp_discriminator_out': sharp_discriminator_out,
                    'deblurred_discriminator_out': deblurred_discriminator_out
                }
                gan_loss_d = self.adversarial_loss('D', **kwargs)
                discriminator_loss_per_update = gan_loss_d

                self.writer.add_scalar('gan_loss_d', gan_loss_d.item())
            else:
                # add other loss if you like
                raise NotImplementedError

            # retain_graph=True: the same D outputs are reused across
            # critic updates and later for the generator loss
            discriminator_loss_per_update.backward(retain_graph=True)
            self.discriminator_optimizer.step()
            discriminator_loss += discriminator_loss_per_update.item()

        discriminator_loss /= critic_updates
        self.writer.add_scalar('discriminator_loss', discriminator_loss)
        total_discriminator_loss += discriminator_loss

        # train generator
        self.generator_optimizer.zero_grad()

        content_loss_lambda = self.config['others']['content_loss_lambda']
        kwargs = {
            'deblurred_discriminator_out': deblurred_discriminator_out
        }
        adversarial_loss_g = self.adversarial_loss('G', **kwargs)
        content_loss_g = self.content_loss(deblurred,
                                           sharp) * content_loss_lambda
        generator_loss = adversarial_loss_g + content_loss_g

        self.writer.add_scalar('adversarial_loss_g',
                               adversarial_loss_g.item())
        self.writer.add_scalar('content_loss_g', content_loss_g.item())
        self.writer.add_scalar('generator_loss', generator_loss.item())

        generator_loss.backward()
        self.generator_optimizer.step()
        total_generator_loss += generator_loss.item()

        # calculate the metrics
        total_metrics += self._eval_metrics(denormalized_deblurred,
                                            denormalized_sharp)

        if self.verbosity >= 2 and batch_idx % self.log_step == 0:
            self.logger.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)] generator_loss: {:.6f} discriminator_loss: {:.6f}'
                .format(
                    epoch,
                    batch_idx * self.data_loader.batch_size,
                    self.data_loader.n_samples,
                    100.0 * batch_idx / len(self.data_loader),
                    generator_loss.item(
                    ),  # it's a tensor, so we call .item() method
                    discriminator_loss  # just a num
                ))

    log = {
        'generator_loss': total_generator_loss / len(self.data_loader),
        'discriminator_loss':
            total_discriminator_loss / len(self.data_loader),
        'metrics': (total_metrics / len(self.data_loader)).tolist()
    }

    if self.do_validation:
        val_log = self._valid_epoch(epoch)
        log = {**log, **val_log}

    self.generator_lr_scheduler.step()
    self.discriminator_lr_scheduler.step()

    return log
def main(resume):
    """Evaluate a trained checkpoint on the full dataset and print the log.

    :param resume: path of the checkpoint file to restore weights from.
    """
    checkpoint = torch.load(resume)
    config = checkpoint['config']

    # Rebuild the data loader over the whole dataset: no shuffling, no
    # validation split, large batch size / worker count for throughput.
    loader_cls = getattr(module_data, config['data_loader']['type'])
    data_loader = loader_cls(
        data_dir=config['data_loader']['args']['data_dir'],
        batch_size=16,       # use large batch_size
        shuffle=False,       # do not shuffle
        validation_split=0.0,  # do not split, just use the full dataset
        num_workers=16,      # use large num_workers
    )

    # Instantiate both networks from the architecture names in the config.
    generator = getattr(module_arch, config['generator']['type'])(
        **config['generator']['args'])
    discriminator = getattr(module_arch, config['discriminator']['type'])(
        **config['discriminator']['args'])

    # Resolve loss and metric callables by name.
    loss_fn = {k: getattr(module_loss, v) for k, v in config['loss'].items()}
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]

    # prepare model for testing
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    if config['n_gpu'] > 1:
        generator = torch.nn.DataParallel(generator)
        discriminator = torch.nn.DataParallel(discriminator)
    generator.load_state_dict(checkpoint['generator'])
    discriminator.load_state_dict(checkpoint['discriminator'])
    generator.eval()
    discriminator.eval()

    # These are invariant across batches, so resolve them once.
    adversarial_loss_fn = loss_fn['adversarial']
    content_loss_fn = loss_fn['content']
    content_loss_lambda = config['others']['content_loss_lambda']

    total_loss = 0.0
    total_metrics = np.zeros(len(metric_fns))
    with torch.no_grad():
        for batch_idx, sample in enumerate(tqdm(data_loader, ascii=True)):
            blurred = sample['blurred'].to(device)
            sharp = sample['sharp'].to(device)

            deblurred = generator(blurred)
            deblurred_discriminator_out = discriminator(deblurred)

            denormalized_deblurred = denormalize(deblurred)
            denormalized_sharp = denormalize(sharp)

            # Same generator objective as training: adversarial + content.
            loss = adversarial_loss_fn(
                'G', deblurred_discriminator_out=deblurred_discriminator_out
            ) + content_loss_fn(deblurred, sharp) * content_loss_lambda
            total_loss += loss.item()

            for i, metric in enumerate(metric_fns):
                total_metrics[i] += metric(denormalized_deblurred,
                                           denormalized_sharp)

    n_samples = len(data_loader)
    log = {'loss': total_loss / n_samples}
    log.update({
        met.__name__: total_metrics[i].item() / n_samples
        for i, met in enumerate(metric_fns)
    })
    print(log)