def _init_tensorboard(self):
    r"""Create the tensorboard meters used by this trainer.

    Three groups are created:

    * ``self.meters`` - learning-rate and timing scalars
      (logging frequency: ``self.cfg.logging_iter``).
    * ``self.metric_meters`` - ``FID_a`` / ``FID_b`` and their best values
      (logging frequency: ``self.cfg.snapshot_save_iter``).
    * ``self.image_meter`` - visualization images
      (logging frequency: ``self.cfg.image_display_iter``).
    """
    scalar_tags = ('optim/gen_lr', 'optim/dis_lr',
                   'time/iteration', 'time/epoch')
    self.meters = {tag: Meter(tag) for tag in scalar_tags}

    metric_tags = ('FID_a', 'best_FID_a', 'FID_b', 'best_FID_b')
    self.metric_meters = {tag: Meter(tag) for tag in metric_tags}

    self.image_meter = Meter('images')
def _write_loss_meters(self): r"""Write all loss values to tensorboard.""" for update, losses in self.losses.items(): # update is 'gen_update' or 'dis_update'. assert update == 'gen_update' or update == 'dis_update' for loss_name, loss in losses.items(): full_loss_name = update + '/' + loss_name if full_loss_name not in self.meters.keys(): # Create a new meter if it doesn't exist. self.meters[full_loss_name] = Meter(full_loss_name) self.meters[full_loss_name].write(loss.item())
def _init_tensorboard(self):
    r"""Create the tensorboard meters used by the trainer.

    Different algorithms might require different performance metrics, so
    subclasses may override this. The default creates:

    * ``self.meters`` - learning-rate and timing scalars
      (logging frequency: ``self.cfg.logging_iter``).
    * ``self.metric_meters`` - ``FID`` and ``best_FID``
      (logging frequency: ``self.cfg.snapshot_save_iter``).
    * ``self.image_meter`` - visualization images
      (logging frequency: ``self.cfg.image_display_iter``).
    """
    scalar_tags = ('optim/gen_lr', 'optim/dis_lr',
                   'time/iteration', 'time/epoch')
    self.meters = {tag: Meter(tag) for tag in scalar_tags}

    self.metric_meters = {tag: Meter(tag) for tag in ('FID', 'best_FID')}

    self.image_meter = Meter('images')
def _init_tensorboard(self):
    r"""Create the tensorboard meters for the SPADE trainer.

    FID is tracked for the regular generator (``'FID/regular'``). When
    ``cfg.trainer.model_average`` is enabled, it is also tracked for the
    moving-average generator (``'FID/average'``). An image meter and the
    learning-rate / timing scalar meters are created as well.
    """
    self.regular_fid_meter = Meter('FID/regular')
    if self.cfg.trainer.model_average:
        self.average_fid_meter = Meter('FID/average')
    self.image_meter = Meter('images')
    scalar_tags = ('optim/gen_lr', 'optim/dis_lr',
                   'time/iteration', 'time/epoch')
    self.meters = {tag: Meter(tag) for tag in scalar_tags}
class Trainer(BaseTrainer):
    r"""Initialize SPADE trainer.

    Args:
        cfg (Config): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G, opt_D,
                                      sch_G, sch_D,
                                      train_data_loader, val_data_loader)
        # Video mode is enabled only for the paired-videos dataset type.
        self.video_mode = \
            cfg.data.type == 'imaginaire.datasets.paired_videos'

    def _init_loss(self, cfg):
        r"""Initialize loss terms.

        Registers the GAN, (optional) perceptual, feature matching and
        Gaussian KL losses together with their weights.

        Args:
            cfg (obj): Global configuration.
        """
        self.criteria['GAN'] = GANLoss(cfg.trainer.gan_mode)
        self.weights['GAN'] = cfg.trainer.loss_weight.gan
        # Setup the perceptual loss. Note that perceptual loss can run in
        # fp16 mode for additional speed. We find that running on fp16 mode
        # leads to improve training speed while maintaining the same accuracy.
        if hasattr(cfg.trainer, 'perceptual_loss'):
            self.criteria['Perceptual'] = \
                PerceptualLoss(
                    cfg=cfg,
                    network=cfg.trainer.perceptual_loss.mode,
                    layers=cfg.trainer.perceptual_loss.layers,
                    weights=cfg.trainer.perceptual_loss.weights)
            self.weights['Perceptual'] = cfg.trainer.loss_weight.perceptual
        # Setup the feature matching loss.
        self.criteria['FeatureMatching'] = FeatureMatchingLoss()
        self.weights['FeatureMatching'] = \
            cfg.trainer.loss_weight.feature_matching
        # Setup the Gaussian KL divergence loss.
        self.criteria['GaussianKL'] = GaussianKLLoss()
        self.weights['GaussianKL'] = cfg.trainer.loss_weight.kl

    def _init_tensorboard(self):
        r"""Initialize the tensorboard meters.

        FID is tracked for the regular generator (``'FID/regular'``). When
        ``cfg.trainer.model_average`` is enabled, it is also tracked for the
        moving-average generator (``'FID/average'``).
        """
        self.regular_fid_meter = Meter('FID/regular')
        if self.cfg.trainer.model_average:
            self.average_fid_meter = Meter('FID/average')
        self.image_meter = Meter('images')
        self.meters = {}
        names = [
            'optim/gen_lr', 'optim/dis_lr', 'time/iteration', 'time/epoch'
        ]
        for name in names:
            self.meters[name] = Meter(name)

    def _start_of_iteration(self, data, current_iteration):
        r"""Model specific custom start of iteration process.

        Three things happen here:

        1. If ``data['label']`` is 5-D, axis 1 (presumably the frame axis)
           is folded into the channel axis. The new label is the
           concatenation of all frames' label maps and the images of all
           but the last frame. The last frame's image becomes the target.
        2. All the data is moved to the GPU.
        3. The input is resized so that its spatial size is a multiple of
           ``base`` (``getattr(self.net_G, 'base', 32)``), for bug-free
           convolutional operations.

        Args:
            data (dict): The current batch.
            current_iteration (int): The iteration number of the current
                batch.

        Returns:
            (dict): The processed batch.
        """
        if len(data['label'].size()) == 5:
            label_image_raw = data['images'][:, 0:-1, :, :, :]
            label_image = label_image_raw.reshape([
                label_image_raw.size(0), -1,
                label_image_raw.size(3), label_image_raw.size(4)
            ])
            images = data['images'][:, -1, :, :, :]
            label_label = data['label'].reshape([
                data['label'].size(0), -1,
                data['label'].size(3), data['label'].size(4)
            ])
            label = torch.cat([label_label, label_image], 1)
            data['label'] = label
            data['images'] = images

        data = to_device(data, 'cuda')
        data = self._resize_data(data)
        return data

    def gen_forward(self, data):
        r"""Compute the loss for SPADE generator.

        Args:
            data (dict): Training data at the current iteration.

        Returns:
            (tensor): Weighted sum of all generator losses.
        """
        net_G_output = self.net_G(data)
        net_D_output = self.net_D(data, net_G_output)

        self._time_before_loss()

        output_fake = self._get_outputs(net_D_output, real=False)
        self.gen_losses['GAN'] = \
            self.criteria['GAN'](output_fake, True, dis_update=False)

        self.gen_losses['FeatureMatching'] = self.criteria['FeatureMatching'](
            net_D_output['fake_features'], net_D_output['real_features'])

        if self.net_G_module.use_style_encoder:
            self.gen_losses['GaussianKL'] = \
                self.criteria['GaussianKL'](net_G_output['mu'],
                                            net_G_output['logvar'])
        else:
            # Keep the key present so the weighted sum below can iterate
            # over every registered criterion.
            self.gen_losses['GaussianKL'] = \
                self.gen_losses['GAN'].new_tensor([0])

        if hasattr(self.cfg.trainer, 'perceptual_loss'):
            self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
                net_G_output['fake_images'], data['images'])

        total_loss = self.gen_losses['GAN'].new_tensor([0])
        for key in self.criteria:
            total_loss += self.gen_losses[key] * self.weights[key]

        self.gen_losses['total'] = total_loss
        return total_loss

    def dis_forward(self, data):
        r"""Compute the loss for SPADE discriminator.

        Args:
            data (dict): Training data at the current iteration.

        Returns:
            (tensor): Weighted discriminator GAN loss.
        """
        with torch.no_grad():
            net_G_output = self.net_G(data)
            net_G_output['fake_images'] = net_G_output['fake_images'].detach()
        net_D_output = self.net_D(data, net_G_output)

        self._time_before_loss()

        output_fake = self._get_outputs(net_D_output, real=False)
        output_real = self._get_outputs(net_D_output, real=True)
        fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True)
        true_loss = self.criteria['GAN'](output_real, True, dis_update=True)
        self.dis_losses['GAN/fake'] = fake_loss
        self.dis_losses['GAN/true'] = true_loss
        self.dis_losses['GAN'] = fake_loss + true_loss
        total_loss = self.dis_losses['GAN'] * self.weights['GAN']
        self.dis_losses['total'] = total_loss
        return total_loss

    def _get_visualizations(self, data):
        r"""Compute visualization images.

        The batch statistics of the moving average model are recalculated
        first.

        Args:
            data (dict): The current batch.

        Returns:
            (list of tensors): Real images, colorized segmentation maps,
            generator outputs and, if model averaging is enabled, the
            averaged generator outputs.
        """
        self.recalculate_model_average_batch_norm_statistics(
            self.train_data_loader)
        with torch.no_grad():
            label_lengths = self.train_data_loader.dataset.get_label_lengths()
            labels = split_labels(data['label'], label_lengths)
            # Get visualization of the segmentation mask.
            segmap = tensor2label(labels['seg_maps'],
                                  label_lengths['seg_maps'],
                                  output_normalized_tensor=True)
            segmap = torch.cat([x.unsqueeze(0) for x in segmap], 0)
            net_G_output = self.net_G(data, random_style=True)
            vis_images = [data['images'], segmap, net_G_output['fake_images']]
            if self.cfg.trainer.model_average:
                net_G_model_average_output = \
                    self.net_G.module.averaged_model(data, random_style=True)
                vis_images.append(net_G_model_average_output['fake_images'])
        return vis_images

    def recalculate_model_average_batch_norm_statistics(self, data_loader):
        r"""Update the statistics in the moving average model.

        Does nothing unless model averaging is enabled and
        ``cfg.trainer.model_average_batch_norm_estimation_iteration`` is
        non-zero.

        Args:
            data_loader (pytorch data loader): Data loader for estimating the
                statistics.
        """
        if not self.cfg.trainer.model_average:
            return
        model_average_iteration = \
            self.cfg.trainer.model_average_batch_norm_estimation_iteration
        if model_average_iteration == 0:
            return
        with torch.no_grad():
            # Accumulate bn stats..
            self.net_G.module.averaged_model.train()
            # Reset running stats.
            self.net_G.module.averaged_model.apply(reset_batch_norm)
            for cal_it, cal_data in enumerate(data_loader):
                if cal_it >= model_average_iteration:
                    print('Done with {} iterations of updating batch norm '
                          'statistics'.format(model_average_iteration))
                    break
                # _start_of_iteration moves the batch to the GPU and applies
                # the same reshaping/resizing used during training.
                cal_data = self._start_of_iteration(cal_data, 0)
                # Averaging over all batches
                self.net_G.module.averaged_model.apply(
                    calibrate_batch_norm_momentum)
                self.net_G.module.averaged_model(cal_data)

    def write_metrics(self):
        r"""Compute FID and write it to the tensorboard.

        When the moving average model is enabled, two meters are written:
        one for the regular generator FID and one for the averaged generator
        FID. Otherwise only the regular generator FID is written.
        """
        if self.cfg.trainer.model_average:
            regular_fid, average_fid = self._compute_fid()
            self.regular_fid_meter.write(regular_fid)
            self.average_fid_meter.write(average_fid)
            meters = [self.regular_fid_meter, self.average_fid_meter]
        else:
            regular_fid = self._compute_fid()
            self.regular_fid_meter.write(regular_fid)
            meters = [self.regular_fid_meter]
        for meter in meters:
            meter.flush(self.current_iteration)

    def _compute_fid(self):
        r"""Compute FID on the validation set.

        The regular generator is evaluated after ``self.net_G.eval()``. The
        moving average model is evaluated afterwards under the same call;
        this presumably also puts it in eval mode as a submodule of
        ``net_G`` (to be confirmed).

        Returns:
            The regular FID, or the tuple ``(regular FID, average FID)``
            when model averaging is enabled.
        """
        self.net_G.eval()
        net_G_for_evaluation = \
            functools.partial(self.net_G, random_style=True)
        regular_fid_path = self._get_save_path('regular_fid', 'npy')
        preprocess = \
            functools.partial(self._start_of_iteration, current_iteration=0)

        regular_fid_value = compute_fid(regular_fid_path,
                                        self.val_data_loader,
                                        net_G_for_evaluation,
                                        preprocess=preprocess)
        print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
            self.current_epoch, self.current_iteration, regular_fid_value))
        if self.cfg.trainer.model_average:
            avg_net_G_for_evaluation = \
                functools.partial(self.net_G.module.averaged_model,
                                  random_style=True)
            fid_path = self._get_save_path('average_fid', 'npy')
            fid_value = compute_fid(fid_path, self.val_data_loader,
                                    avg_net_G_for_evaluation,
                                    preprocess=preprocess)
            print('Epoch {:05}, Iteration {:09}, FID {}'.format(
                self.current_epoch, self.current_iteration, fid_value))
            self.net_G.float()
            return regular_fid_value, fid_value
        self.net_G.float()
        return regular_fid_value

    def _resize_data(self, data):
        r"""Resize input label maps and images so that they can be properly
        generated by the generator.

        Each spatial dimension is rounded down to a multiple of ``base``
        (``getattr(self.net_G, 'base', 32)``). It is clamped to at least
        ``base``, so inputs smaller than ``base`` are upsampled instead of
        being resized to an invalid zero size.

        Args:
            data (dict): Input dictionary containing the 'label' and
                (optionally) 'images' fields.

        Returns:
            (dict): The same dictionary with resized tensors.
        """
        base = getattr(self.net_G, 'base', 32)
        height = data['label'].size()[2]
        width = data['label'].size()[3]
        sy = max(base, (height // base) * base)
        sx = max(base, (width // base) * base)
        data['label'] = F.interpolate(data['label'], size=[sy, sx],
                                      mode='nearest')
        if 'images' in data:
            data['images'] = F.interpolate(data['images'], size=[sy, sx],
                                           mode='bicubic')
        return data
class BaseTrainer(object):
    r"""Base trainer. We expect that all trainers inherit this class.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    # (label, attribute) pairs for the per-stage timers accumulated when
    # ``cfg.speed_benchmark`` is enabled. They are initialized, printed and
    # reset in this order.
    _SPEED_BENCHMARK_TIMERS = (
        ('Generator FWD', 'accu_gen_forw_iter_time'),
        ('Generator LOS', 'accu_gen_loss_iter_time'),
        ('Generator BCK', 'accu_gen_back_iter_time'),
        ('Generator STP', 'accu_gen_step_iter_time'),
        ('Generator AVG', 'accu_gen_avg_iter_time'),
        ('Discriminator FWD', 'accu_dis_forw_iter_time'),
        ('Discriminator LOS', 'accu_dis_loss_iter_time'),
        ('Discriminator BCK', 'accu_dis_back_iter_time'),
        ('Discriminator STP', 'accu_dis_step_iter_time'),
    )

    def __init__(self,
                 cfg,
                 net_G,
                 net_D,
                 opt_G,
                 opt_D,
                 sch_G,
                 sch_D,
                 train_data_loader,
                 val_data_loader):
        super(BaseTrainer, self).__init__()
        print('Setup trainer.')

        # Initialize models and data loaders.
        self.cfg = cfg
        self.net_G = net_G
        if cfg.trainer.model_average:
            # Two wrappers (DDP + model average).
            self.net_G_module = self.net_G.module.module
        elif not cfg.trainer.distribute:
            self.net_G_module = self.net_G
        else:
            # One wrapper (DDP)
            self.net_G_module = self.net_G.module
        self.val_data_loader = val_data_loader
        self.is_inference = train_data_loader is None
        self.net_D = net_D
        self.opt_G = opt_G
        self.opt_D = opt_D
        self.sch_G = sch_G
        self.sch_D = sch_D
        self.train_data_loader = train_data_loader

        # Initialize loss functions.
        # All loss names have weights. Some have criterion modules.
        # Mapping from loss names to criterion modules.
        self.criteria = nn.ModuleDict()
        # Mapping from loss names to loss weights.
        self.weights = dict()
        self.losses = dict(gen_update=dict(), dis_update=dict())
        self.gen_losses = self.losses['gen_update']
        self.dis_losses = self.losses['dis_update']
        self._init_loss(cfg)
        for loss_name, loss_weight in self.weights.items():
            print("Loss {:<20} Weight {}".format(loss_name, loss_weight))
            if loss_name in self.criteria.keys() and \
                    self.criteria[loss_name] is not None:
                self.criteria[loss_name].to('cuda')

        if self.is_inference:
            # The initialization steps below can be skipped during inference.
            return

        # Initialize logging attributes.
        self.current_iteration = 0
        self.current_epoch = 0
        self.start_iteration_time = None
        self.start_epoch_time = None
        self.elapsed_iteration_time = 0
        self.time_iteration = -1
        self.time_epoch = -1
        self.best_fid = None
        if getattr(self.cfg, 'speed_benchmark', False):
            for _, attr in self._SPEED_BENCHMARK_TIMERS:
                setattr(self, attr, 0)

        # Initialize tensorboard and hparams.
        self._init_tensorboard()
        self._init_hparams()

    def _init_tensorboard(self):
        r"""Initialize the tensorboard. Different algorithms might require
        different performance metrics. Hence, custom tensorboard
        initialization might be necessary.
        """
        # Logging frequency: self.cfg.logging_iter
        self.meters = {}
        names = [
            'optim/gen_lr', 'optim/dis_lr', 'time/iteration', 'time/epoch'
        ]
        for name in names:
            self.meters[name] = Meter(name)

        # Logging frequency: self.cfg.snapshot_save_iter
        names = ['FID', 'best_FID']
        self.metric_meters = {}
        for name in names:
            self.metric_meters[name] = Meter(name)

        # Logging frequency: self.cfg.image_display_iter
        self.image_meter = Meter('images')

    def _init_hparams(self):
        r"""Initialize a dictionary of hyperparameters that we want to monitor
        in the HParams dashboard in tensorBoard.
        """
        self.hparam_dict = {}

    @staticmethod
    def _get_current_lr(scheduler):
        r"""Return the learning rate of the first param group, as last set
        by ``scheduler``.

        ``get_last_lr()`` returns the value actually in use. ``get_lr()`` is
        the scheduler's internal update rule, is deprecated for this purpose,
        and can return a different (e.g. already-decayed) value when called
        outside ``step()``. ``get_lr()`` is only used as a fallback for
        schedulers that lack ``get_last_lr()`` (older PyTorch versions).

        Args:
            scheduler (obj): Learning rate scheduler.

        Returns:
            (float): Current learning rate of the first param group.
        """
        get_last_lr = getattr(scheduler, 'get_last_lr', None)
        if get_last_lr is not None:
            return get_last_lr()[0]
        return scheduler.get_lr()[0]

    def _write_tensorboard(self):
        r"""Write values to tensorboard. By default, we will log the time used
        per iteration, time used per epoch, generator learning rate, and
        discriminator learning rate. We will log all the losses as well as
        custom meters.
        """
        # Logs that are shared by all models.
        self._write_to_meters(
            {
                'time/iteration': self.time_iteration,
                'time/epoch': self.time_epoch,
                'optim/gen_lr': self._get_current_lr(self.sch_G),
                'optim/dis_lr': self._get_current_lr(self.sch_D)
            }, self.meters)
        # Logs for loss values. Different models have different losses.
        self._write_loss_meters()
        # Other custom logs.
        self._write_custom_meters()
        # Write all logs to tensorboard.
        self._flush_meters(self.meters)

    def _write_loss_meters(self):
        r"""Write all loss values to tensorboard."""
        for update, losses in self.losses.items():
            # update is 'gen_update' or 'dis_update'.
            assert update == 'gen_update' or update == 'dis_update'
            for loss_name, loss in losses.items():
                full_loss_name = update + '/' + loss_name
                if full_loss_name not in self.meters:
                    # Create a new meter if it doesn't exist.
                    self.meters[full_loss_name] = Meter(full_loss_name)
                self.meters[full_loss_name].write(loss.item())

    def _write_custom_meters(self):
        r"""Dummy member function to be overloaded by the child class.
        In the child class, you can write down whatever you want to track.
        """
        pass

    @staticmethod
    def _write_to_meters(data, meters):
        r"""Write values to meters."""
        for key, value in data.items():
            meters[key].write(value)

    def _flush_meters(self, meters):
        r"""Flush all meters using the current iteration."""
        for meter in meters.values():
            meter.flush(self.current_iteration)

    def _pre_save_checkpoint(self):
        r"""Implement the things you want to do before saving a checkpoint.
        For example, you can compute the K-mean features (pix2pixHD) before
        saving the model weights to a checkpoint.
        """
        pass

    def save_checkpoint(self, current_epoch, current_iteration):
        r"""Save network weights, optimizer parameters, scheduler parameters
        to a checkpoint.

        Args:
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current number of iteration.
        """
        self._pre_save_checkpoint()
        _save_checkpoint(self.cfg,
                         self.net_G, self.net_D,
                         self.opt_G, self.opt_D,
                         self.sch_G, self.sch_D,
                         current_epoch, current_iteration)

    def load_checkpoint(self, cfg, checkpoint_path, resume=None):
        r"""Load network weights, optimizer parameters, scheduler parameters
        from a checkpoint.

        Args:
            cfg (obj): Global configuration.
            checkpoint_path (str): Path to the checkpoint. An empty string or
                ``None`` means "not specified".
            resume (bool or None): If not ``None``, will determine whether or
                not to load optimizers in addition to network weights.

        Returns:
            (tuple): ``(current_epoch, current_iteration)`` restored from the
            checkpoint, or ``(0, 0)`` when nothing is resumed.
        """
        if checkpoint_path and os.path.exists(checkpoint_path):
            # If checkpoint_path exists, we will load its weights to
            # initialize our network.
            if resume is None:
                resume = False
        elif os.path.exists(os.path.join(cfg.logdir, 'latest_checkpoint.txt')):
            # This is for resuming the training from the previously saved
            # checkpoint.
            fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt')
            with open(fn, 'r') as f:
                line = f.read().splitlines()
            checkpoint_path = os.path.join(cfg.logdir, line[0].split(' ')[-1])
            if resume is None:
                resume = True
        else:
            # checkpoint not found and not specified. We will train
            # everything from scratch.
            current_epoch = 0
            current_iteration = 0
            print('No checkpoint found.')
            return current_epoch, current_iteration
        # Load checkpoint
        checkpoint = torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage)
        current_epoch = 0
        current_iteration = 0
        if resume:
            self.net_G.load_state_dict(checkpoint['net_G'])
            if not self.is_inference:
                self.net_D.load_state_dict(checkpoint['net_D'])
                if 'opt_G' in checkpoint:
                    self.opt_G.load_state_dict(checkpoint['opt_G'])
                    self.opt_D.load_state_dict(checkpoint['opt_D'])
                    self.sch_G.load_state_dict(checkpoint['sch_G'])
                    self.sch_D.load_state_dict(checkpoint['sch_D'])
                    current_epoch = checkpoint['current_epoch']
                    current_iteration = checkpoint['current_iteration']
                    print('Load from: {}'.format(checkpoint_path))
                else:
                    print('Load network weights only.')
        else:
            self.net_G.load_state_dict(checkpoint['net_G'])
            print('Load generator weights only.')

        print('Done with loading the checkpoint.')
        return current_epoch, current_iteration

    def start_of_epoch(self, current_epoch):
        r"""Things to do before an epoch.

        Args:
            current_epoch (int): Current number of epoch.
        """
        self._start_of_epoch(current_epoch)
        self.current_epoch = current_epoch
        self.start_epoch_time = time.time()

    def start_of_iteration(self, data, current_iteration):
        r"""Things to do before an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current number of iteration.

        Returns:
            (dict): The processed data, moved to the GPU.
        """
        data = self._start_of_iteration(data, current_iteration)
        data = to_cuda(data)
        self.current_iteration = current_iteration
        if not self.is_inference:
            self.net_D.train()
        self.net_G.train()
        self.start_iteration_time = time.time()
        return data

    def _report_speed_benchmark(self, ave_t):
        r"""Print the per-stage timers averaged over the last logging window,
        then reset them to zero.

        Args:
            ave_t (float): Average iteration time over the logging window.
        """
        for label, attr in self._SPEED_BENCHMARK_TIMERS:
            print('\t{} time {:6f}'.format(
                label, getattr(self, attr) / self.cfg.logging_iter))
        print('{:6f}'.format(ave_t))
        for _, attr in self._SPEED_BENCHMARK_TIMERS:
            setattr(self, attr, 0)

    def end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Things to do after an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current number of iteration.
        """
        self.current_iteration = current_iteration
        self.current_epoch = current_epoch
        # Update the learning rate policy for the generator if operating in the
        # iteration mode.
        if self.cfg.gen_opt.lr_policy.iteration_mode:
            self.sch_G.step()
        # Update the learning rate policy for the discriminator if operating
        # in the iteration mode.
        if self.cfg.dis_opt.lr_policy.iteration_mode:
            self.sch_D.step()

        # Accumulate time
        self.elapsed_iteration_time += time.time() - self.start_iteration_time
        # Logging.
        if current_iteration % self.cfg.logging_iter == 0:
            ave_t = self.elapsed_iteration_time / self.cfg.logging_iter
            self.time_iteration = ave_t
            print('Iteration: {}, average iter time: '
                  '{:6f}.'.format(current_iteration, ave_t))
            self.elapsed_iteration_time = 0

            if getattr(self.cfg, 'speed_benchmark', False):
                # Only needed when analyzing the computation bottleneck.
                self._report_speed_benchmark(ave_t)

        self._end_of_iteration(data, current_epoch, current_iteration)
        # Save everything to the checkpoint.
        if current_iteration >= self.cfg.snapshot_save_start_iter and \
                current_iteration % self.cfg.snapshot_save_iter == 0:
            self.save_image(self._get_save_path('images', 'jpg'), data)
            self.save_checkpoint(current_epoch, current_iteration)
            self.write_metrics()
        # Compute image to be saved.
        elif current_iteration % self.cfg.image_save_iter == 0:
            self.save_image(self._get_save_path('images', 'jpg'), data)
        elif current_iteration % self.cfg.image_display_iter == 0:
            image_path = os.path.join(self.cfg.logdir, 'images', 'current.jpg')
            self.save_image(image_path, data)

        if current_iteration % self.cfg.logging_iter == 0:
            self._write_tensorboard()
            print("gen loss: {}".format(self.gen_losses))
            print("dis loss: {}".format(self.dis_losses))

    def end_of_epoch(self, data, current_epoch, current_iteration):
        r"""Things to do after an epoch.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current number of iteration.
        """
        # Update the learning rate policy for the generator if operating in the
        # epoch mode.
        self.current_iteration = current_iteration
        self.current_epoch = current_epoch
        if not self.cfg.gen_opt.lr_policy.iteration_mode:
            self.sch_G.step()
        # Update the learning rate policy for the discriminator if operating
        # in the epoch mode.
        if not self.cfg.dis_opt.lr_policy.iteration_mode:
            self.sch_D.step()
        elapsed_epoch_time = time.time() - self.start_epoch_time
        # Logging.
        print('Epoch: {}, total time: {:6f}.'.format(current_epoch,
                                                     elapsed_epoch_time))
        self.time_epoch = elapsed_epoch_time
        self._end_of_epoch(data, current_epoch, current_iteration)
        # Save everything to the checkpoint.
        if current_epoch >= self.cfg.snapshot_save_start_epoch and \
                current_epoch % self.cfg.snapshot_save_epoch == 0:
            self.save_image(self._get_save_path('images', 'jpg'), data)
            self.save_checkpoint(current_epoch, current_iteration)
            self.write_metrics()

    def pre_process(self, data):
        r"""Custom data pre-processing function. Utilize this function if you
        need to preprocess your data before sending it to the generator and
        discriminator.

        Args:
            data (dict): Data used for the current iteration.
        """

    def recalculate_model_average_batch_norm_statistics(self, data_loader):
        r"""Update the statistics in the moving average model.

        Does nothing unless model averaging is enabled and
        ``cfg.trainer.model_average_batch_norm_estimation_iteration`` is
        non-zero.

        Args:
            data_loader (torch.utils.data.DataLoader): Data loader for
                estimating the statistics.
        """
        if not self.cfg.trainer.model_average:
            return
        model_average_iteration = \
            self.cfg.trainer.model_average_batch_norm_estimation_iteration
        if model_average_iteration == 0:
            return
        with torch.no_grad():
            # Accumulate bn stats..
            self.net_G.module.averaged_model.train()
            # Reset running stats.
            self.net_G.module.averaged_model.apply(reset_batch_norm)
            for cal_it, cal_data in enumerate(data_loader):
                if cal_it >= model_average_iteration:
                    print('Done with {} iterations of updating batch norm '
                          'statistics'.format(model_average_iteration))
                    break
                cal_data = to_device(cal_data, 'cuda')
                # Averaging over all batches
                self.net_G.module.averaged_model.apply(
                    calibrate_batch_norm_momentum)
                self.net_G.module.averaged_model(cal_data)

    def save_image(self, path, data):
        r"""Compute visualization images and save them to the disk.

        Args:
            path (str): Location of the file.
            data (dict): Data used for the current iteration.
        """
        self.net_G.eval()
        vis_images = self._get_visualizations(data)
        if is_master() and vis_images is not None:
            vis_images = torch.cat(vis_images, dim=3).float()
            # Map from [-1, 1] to [0, 1] for saving.
            vis_images = (vis_images + 1) / 2
            print('Save output images to {}'.format(path))
            vis_images.clamp_(0, 1)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            image_grid = torchvision.utils.make_grid(
                vis_images, nrow=1, padding=0, normalize=False)
            if self.cfg.trainer.image_to_tensorboard:
                self.image_meter.write_image(image_grid,
                                             self.current_iteration)
            torchvision.utils.save_image(image_grid, path, nrow=1)

    def write_metrics(self):
        r"""Write metrics to the tensorboard.

        Nothing is written when ``_compute_fid`` returns ``None`` (the
        default implementation).
        """
        cur_fid = self._compute_fid()
        if cur_fid is None:
            return
        if self.best_fid is None:
            self.best_fid = cur_fid
        else:
            self.best_fid = min(self.best_fid, cur_fid)
        metric_dict = {'FID': cur_fid, 'best_FID': self.best_fid}
        self._write_to_meters(metric_dict, self.metric_meters)
        self._flush_meters(self.metric_meters)
        if self.cfg.trainer.hparam_to_tensorboard:
            add_hparams(self.hparam_dict, metric_dict)

    def _get_save_path(self, subdir, ext):
        r"""Get the image save path.

        The sub-directory is created if it does not exist.

        Args:
            subdir (str): Sub-directory under the main directory for saving
                the outputs.
            ext (str): Filename extension for the image (e.g., jpg, png, ...).

        Returns:
            (str): image filename to be used to save the visualization results.
        """
        subdir_path = os.path.join(self.cfg.logdir, subdir)
        os.makedirs(subdir_path, exist_ok=True)
        return os.path.join(
            subdir_path, 'epoch_{:05}_iteration_{:09}.{}'.format(
                self.current_epoch, self.current_iteration, ext))

    def _get_outputs(self, net_D_output, real=True):
        r"""Return output values. Note that when the gan mode is relativistic,
        it will do the difference before returning.

        Args:
            net_D_output (dict):
                real_outputs (tensor): Real output values.
                fake_outputs (tensor): Fake output values.
            real (bool): Return real or fake.
        """

        def _get_difference(a, b):
            r"""Get difference between two lists of tensors or two tensors.

            Args:
                a: list of tensors or tensor
                b: list of tensors or tensor
            """
            out = list()
            for x, y in zip(a, b):
                if isinstance(x, list):
                    res = _get_difference(x, y)
                else:
                    res = x - y
                out.append(res)
            return out

        if real:
            if self.cfg.trainer.gan_relativistic:
                return _get_difference(net_D_output['real_outputs'],
                                       net_D_output['fake_outputs'])
            else:
                return net_D_output['real_outputs']
        else:
            if self.cfg.trainer.gan_relativistic:
                return _get_difference(net_D_output['fake_outputs'],
                                       net_D_output['real_outputs'])
            else:
                return net_D_output['fake_outputs']

    def _start_of_epoch(self, current_epoch):
        r"""Operations to do before starting an epoch.

        Args:
            current_epoch (int): Current number of epoch.
        """
        pass

    def _start_of_iteration(self, data, current_iteration):
        r"""Operations to do before starting an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current iteration number.

        Returns:
            (dict): Data used for the current iteration. They might be
                processed by the custom _start_of_iteration function.
        """
        return data

    def _end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Operations to do after an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current iteration number.
        """
        pass

    def _end_of_epoch(self, data, current_epoch, current_iteration):
        r"""Operations to do after an epoch.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current iteration number.
        """
        pass

    def _get_visualizations(self, data):
        r"""Compute visualization outputs.

        Args:
            data (dict): Data used for the current iteration.
        """
        return None

    def _compute_fid(self):
        r"""FID computation function to be overloaded."""
        return None

    def _init_loss(self, cfg):
        r"""Every trainer should implement its own init loss function."""
        raise NotImplementedError

    def gen_update(self, data):
        r"""Update the generator.

        Args:
            data (dict): Data used for the current iteration.
        """
        self.opt_G.zero_grad()

        # Set requires_grad flags.
        requires_grad(self.net_G_module, True)
        requires_grad(self.net_D, False)

        # Compute the loss.
        self._time_before_forward()
        total_loss = self.gen_forward(data)
        if total_loss is None:
            return

        # Backpropagate the loss.
        self._time_before_backward()
        with amp.scale_loss(total_loss, self.opt_G, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        # Optionally clip gradient norm.
        if hasattr(self.cfg.gen_opt, 'clip_grad_norm'):
            nn.utils.clip_grad_norm_(amp.master_params(self.opt_G),
                                     self.cfg.gen_opt.clip_grad_norm)

        # Perform an optimizer step.
        self._time_before_step()
        self.opt_G.step()

        # Update model average.
        self._time_before_model_avg()
        if self.cfg.trainer.model_average:
            self.net_G.module.update_average()

        self._detach_losses()
        self._time_before_leave_gen()

    def gen_forward(self, data):
        r"""Every trainer should implement its own generator forward."""
        raise NotImplementedError

    def dis_update(self, data):
        r"""Update the discriminator.

        Args:
            data (dict): Data used for the current iteration.
        """
        self.opt_D.zero_grad()

        # Set requires_grad flags.
        requires_grad(self.net_G_module, False)
        requires_grad(self.net_D, True)

        # Compute the loss.
        self._time_before_forward()
        total_loss = self.dis_forward(data)
        if total_loss is None:
            return

        # Backpropagate the loss.
        self._time_before_backward()
        with amp.scale_loss(total_loss, self.opt_D, loss_id=1) as scaled_loss:
            scaled_loss.backward()

        # Perform an optimizer step.
        self._time_before_step()
        self.opt_D.step()

        self._detach_losses()
        self._time_before_leave_dis()

    def dis_forward(self, data):
        r"""Every trainer should implement its own discriminator forward."""
        raise NotImplementedError

    def test(self, data_loader, output_dir, inference_args):
        r"""Compute results images for a batch of input data and save the
        results in the specified folder.

        Args:
            data_loader (torch.utils.data.DataLoader): PyTorch dataloader.
            output_dir (str): Target location for saving the output image.
            inference_args (obj): Namespace-like object whose attributes
                are forwarded as keyword arguments to ``net_G.inference``.
        """
        if self.cfg.trainer.model_average:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G.module
        net_G.eval()

        print('# of samples %d' % len(data_loader))
        for data in tqdm(data_loader):
            data = self.start_of_iteration(data, current_iteration=-1)
            with torch.no_grad():
                output_images, file_names = \
                    net_G.inference(data, **vars(inference_args))
            for output_image, file_name in zip(output_images, file_names):
                fullname = os.path.join(output_dir, file_name + '.jpg')
                output_image = tensor2pilimage(output_image.clamp_(-1, 1),
                                               minus1to1_normalized=True)
                save_pilimage_in_jpeg(fullname, output_image)

    def _get_total_loss(self, gen_forward):
        r"""Return the total loss to be backpropagated.

        Args:
            gen_forward (bool): If ``True``, backpropagates the generator loss,
                otherwise the discriminator loss.
        """
        losses = self.gen_losses if gen_forward else self.dis_losses
        total_loss = torch.tensor(0., device=torch.device('cuda'))
        # Iterates over all possible losses.
        for loss_name in self.weights:
            # If it is for the current model (gen/dis).
            if loss_name in losses:
                # Multiply it with the corresponding weight
                # and add it to the total loss.
                total_loss += losses[loss_name] * self.weights[loss_name]
        losses['total'] = total_loss  # logging purpose
        return total_loss

    def _detach_losses(self):
        r"""Detach all logging variables to prevent potential memory leak."""
        for loss_name in self.gen_losses:
            self.gen_losses[loss_name] = self.gen_losses[loss_name].detach()
        for loss_name in self.dis_losses:
            self.dis_losses[loss_name] = self.dis_losses[loss_name].detach()

    def _time_before_forward(self):
        r"""
        Record time before applying forward.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            self.forw_time = time.time()

    def _time_before_loss(self):
        r"""
        Record time before computing loss.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            self.loss_time = time.time()

    def _time_before_backward(self):
        r"""
        Record time before applying backward.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            self.back_time = time.time()

    def _time_before_step(self):
        r"""
        Record time before updating the weights
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            self.step_time = time.time()

    def _time_before_model_avg(self):
        r"""
        Record time before applying model average.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            self.avg_time = time.time()

    def _time_before_leave_gen(self):
        r"""
        Record forward, backward, loss, and model average time for the
        generator update.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            end_time = time.time()
            self.accu_gen_forw_iter_time += self.loss_time - self.forw_time
            self.accu_gen_loss_iter_time += self.back_time - self.loss_time
            self.accu_gen_back_iter_time += self.step_time - self.back_time
            self.accu_gen_step_iter_time += self.avg_time - self.step_time
            self.accu_gen_avg_iter_time += end_time - self.avg_time

    def _time_before_leave_dis(self):
        r"""
        Record forward, backward, loss time for the discriminator update.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            end_time = time.time()
            self.accu_dis_forw_iter_time += self.loss_time - self.forw_time
            self.accu_dis_loss_iter_time += self.back_time - self.loss_time
            self.accu_dis_back_iter_time += self.step_time - self.back_time
            self.accu_dis_step_iter_time += end_time - self.step_time
class Trainer(BaseTrainer):
    r"""Initialize vid2vid trainer.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
                                      opt_D, sch_G, sch_D,
                                      train_data_loader, val_data_loader)
        # Below is for testing setting, the FID computation during training
        # is just for getting a quick idea of the performance. It does not
        # equal to the final performance evaluation.
        # Below, we will determine how many videos that we want to do
        # evaluation, and the length of each video.
        # It is better to keep the number of videos to be multiple of 8 so
        # that all the GPUs in a node will contribute equally to the
        # evaluation. None of them is idle.
        self.sample_size = (
            getattr(cfg.trainer, 'num_videos_to_test', 64),
            getattr(cfg.trainer, 'num_frames_per_video', 10)
        )

        # Training starts with single frames; see `_start_of_epoch`.
        self.sequence_length = 1
        if not self.is_inference:
            self.train_dataset = self.train_data_loader.dataset
            self.sequence_length_max = \
                min(getattr(cfg.data.train, 'max_sequence_length', 100),
                    self.train_dataset.sequence_length_max)
        self.Tensor = torch.cuda.FloatTensor
        self.has_fg = getattr(cfg.data, 'has_foreground', False)

        # Recurrent state used by `test_single` across consecutive frames.
        self.net_G_output = self.data_prev = None
        self.net_G_module = self.net_G.module
        if self.cfg.trainer.model_average:
            # With model averaging, `net_G.module` is itself a wrapper; unwrap
            # one more level to reach the underlying generator.
            self.net_G_module = self.net_G_module.module

    def _assign_criteria(self, name, criterion, weight):
        r"""Assign training loss terms.

        Args:
            name (str): Loss name
            criterion (obj): Loss object.
            weight (float): Loss weight. It should be non-negative.
        """
        self.criteria[name] = criterion
        self.weights[name] = weight

    def _init_loss(self, cfg):
        r"""Initialize training loss terms. In vid2vid, in addition to
        the GAN loss, feature matching loss, and perceptual loss used in
        pix2pixHD, we also add temporal GAN (and feature matching) loss,
        and flow warping loss. Optionally, we can also add an additional face
        discriminator for the face region.

        Args:
            cfg (obj): Global configuration.
        """
        self.criteria = dict()
        self.weights = dict()
        trainer_cfg = cfg.trainer
        loss_weight = cfg.trainer.loss_weight

        # GAN loss and feature matching loss.
        self._assign_criteria('GAN',
                              GANLoss(trainer_cfg.gan_mode),
                              loss_weight.gan)
        self._assign_criteria('FeatureMatching',
                              FeatureMatchingLoss(),
                              loss_weight.feature_matching)

        # Perceptual loss.
        perceptual_loss = cfg.trainer.perceptual_loss
        self._assign_criteria('Perceptual',
                              PerceptualLoss(
                                  cfg=cfg,
                                  network=perceptual_loss.mode,
                                  layers=perceptual_loss.layers,
                                  weights=perceptual_loss.weights,
                                  num_scales=getattr(perceptual_loss,
                                                     'num_scales', 1)),
                              loss_weight.perceptual)

        # L1 Loss.
        if getattr(loss_weight, 'L1', 0) > 0:
            self._assign_criteria('L1', torch.nn.L1Loss(), loss_weight.L1)

        # Whether to add an additional discriminator for specific regions.
        self.add_dis_cfg = getattr(self.cfg.dis, 'additional_discriminators',
                                   None)
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                add_dis_cfg = self.add_dis_cfg[name]
                self.weights['GAN_' + name] = add_dis_cfg.loss_weight
                self.weights['FeatureMatching_' + name] = \
                    loss_weight.feature_matching

        # Temporal GAN loss.
        self.num_temporal_scales = get_nested_attr(self.cfg.dis,
                                                   'temporal.num_scales', 0)
        for s in range(self.num_temporal_scales):
            self.weights['GAN_T%d' % s] = loss_weight.temporal_gan
            self.weights['FeatureMatching_T%d' % s] = \
                loss_weight.feature_matching

        # Flow loss. It consists of three parts: L1 loss compared to GT,
        # warping loss when used to warp images, and loss on the occlusion mask.
        self.use_flow = hasattr(cfg.gen, 'flow')
        if self.use_flow:
            self.criteria['Flow'] = FlowLoss(cfg)
            self.weights['Flow'] = self.weights['Flow_L1'] = \
                self.weights['Flow_Warp'] = \
                self.weights['Flow_Mask'] = loss_weight.flow

        # Other custom losses.
        self._define_custom_losses()

    def _define_custom_losses(self):
        r"""All other custom losses are defined here."""
        pass

    def _start_of_epoch(self, current_epoch):
        r"""Things to do before an epoch. When current_epoch is smaller than
        $(single_frame_epoch), we only train a single frame and the generator is
        just an image generator. After that, we start doing temporal training
        and train multiple frames. We will double the number of training frames
        every $(num_epochs_temporal_step) epochs.

        Args:
            current_epoch (int): Current number of epoch.
        """
        cfg = self.cfg
        # Only generates one frame at the beginning of training
        if current_epoch < cfg.single_frame_epoch:
            self.train_dataset.sequence_length = 1
        # Then add the temporal network to generator, and train multiple frames.
        elif current_epoch == cfg.single_frame_epoch:
            self.init_temporal_network()

        # Double the length of training sequence every few epochs.
        temp_epoch = current_epoch - cfg.single_frame_epoch
        if temp_epoch > 0:
            sequence_length = \
                cfg.data.train.initial_sequence_length * \
                (2 ** (temp_epoch // cfg.num_epochs_temporal_step))
            sequence_length = min(sequence_length, self.sequence_length_max)
            # Only ever grow the sequence length.
            if sequence_length > self.sequence_length:
                self.sequence_length = sequence_length
                self.train_dataset.set_sequence_length(sequence_length)
                print('------- Updating sequence length to %d -------' %
                      sequence_length)

    def init_temporal_network(self):
        r"""Initialize temporal training when beginning to train multiple
        frames. Set the sequence length to $(initial_sequence_length).
        """
        self.tensorboard_init = False
        # Update training sequence length.
        self.sequence_length = self.cfg.data.train.initial_sequence_length
        if not self.is_inference:
            self.train_dataset.set_sequence_length(self.sequence_length)
            print('------ Now start training %d frames -------' %
                  self.sequence_length)

    def _start_of_iteration(self, data, current_iteration):
        r"""Things to do before an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current number of iteration.
        """
        data = self.pre_process(data)
        return to_cuda(data)

    def pre_process(self, data):
        r"""Do any data pre-processing here.

        Args:
            data (dict): Data used for the current iteration.
        """
        data_cfg = self.cfg.data
        if hasattr(data_cfg, 'for_pose_dataset') and \
                ('pose_maps-densepose' in data_cfg.input_labels):
            pose_cfg = data_cfg.for_pose_dataset
            data['label'] = pre_process_densepose(pose_cfg, data['label'],
                                                  self.is_inference)
        return data

    def post_process(self, data, net_G_output):
        r"""Do any postprocessing of the data / output here.

        Args:
            data (dict): Training data at the current iteration.
            net_G_output (dict): Output of the generator.
        """
        return data, net_G_output

    def gen_update(self, data):
        r"""Update the vid2vid generator. We update in the fashion of
        dis_update (frame 1), gen_update (frame 1),
        dis_update (frame 2), gen_update (frame 2), ... in each iteration.

        Args:
            data (dict): Training data at the current iteration.
        """
        # Whether to reuse generator output for both gen_update and dis_update.
        # It saves time but consumes a bit more memory.
        reuse_gen_output = getattr(self.cfg.trainer, 'reuse_gen_output', True)

        past_frames = [None, None]
        net_G_output = None
        data_prev = None
        for t in range(self.sequence_length):
            data_t = self.get_data_t(data, net_G_output, data_prev, t)
            data_prev = data_t

            # Discriminator update.
            if reuse_gen_output:
                net_G_output = self.net_G(data_t)
            else:
                with torch.no_grad():
                    net_G_output = self.net_G(data_t)
            data_t, net_G_output = self.post_process(data_t, net_G_output)

            # Get losses and update D if image generated by network in training.
            if 'fake_images_source' not in net_G_output:
                net_G_output['fake_images_source'] = 'in_training'
            if net_G_output['fake_images_source'] != 'pretrained':
                # Detached generator output: D's update must not backprop
                # into G.
                net_D_output, _ = self.net_D(data_t, detach(net_G_output),
                                             past_frames)
                self.get_dis_losses(net_D_output)

            # Generator update.
            if not reuse_gen_output:
                net_G_output = self.net_G(data_t)
                data_t, net_G_output = self.post_process(data_t, net_G_output)

            # Get losses and update G if image generated by network in training.
            if 'fake_images_source' not in net_G_output:
                net_G_output['fake_images_source'] = 'in_training'
            if net_G_output['fake_images_source'] != 'pretrained':
                net_D_output, past_frames = \
                    self.net_D(data_t, net_G_output, past_frames)
                self.get_gen_losses(data_t, net_G_output, net_D_output)

            # update average
            if self.cfg.trainer.model_average:
                self.net_G.module.update_average()

    def dis_update(self, data):
        r"""The update is already done in gen_update.

        Args:
            data (dict): Training data at the current iteration.
        """
        pass

    def reset(self):
        r"""Reset the trainer (for inference) at the beginning of a sequence.
        """
        self.net_G_output = self.data_prev = None
        self.t = 0

        # Keep an existing choice (e.g. set by `_compute_fid`); otherwise
        # default to the config's model-average setting.
        self.test_in_model_average_mode = getattr(
            self, 'test_in_model_average_mode', self.cfg.trainer.model_average)
        if self.test_in_model_average_mode:
            net_G_module = self.net_G.module.averaged_model
        else:
            net_G_module = self.net_G.module
        if hasattr(net_G_module, 'reset'):
            net_G_module.reset()

    def create_sequence_output_dir(self, output_dir, key):
        r"""Create output subdir for this sequence.

        Args:
            output_dir (str): Root output dir.
            key (str): LMDB key which contains sequence name and file name.
        Returns:
            output_dir (str): Output subdir for this sequence.
            seq_name (str): Name of this sequence.
        """
        seq_dir = '/'.join(key.split('/')[:-1])
        output_dir = os.path.join(output_dir, seq_dir)
        os.makedirs(output_dir, exist_ok=True)
        seq_name = seq_dir.replace('/', '-')
        return output_dir, seq_name

    def test(self, test_data_loader, root_output_dir, inference_args):
        r"""Run inference on all sequences.

        Each sequence's frames are written as images by `test_single` and
        then collected into an mp4 next to the sequence's output directory.
        Sequences that produce no frames are skipped when writing videos.

        Args:
            test_data_loader (object): Test data loader.
            root_output_dir (str): Location to dump outputs.
            inference_args (optional): Optional args.
        """

        # Go over all sequences.
        loader = test_data_loader
        num_inference_sequences = loader.dataset.num_inference_sequences()
        for sequence_idx in range(num_inference_sequences):
            loader.dataset.set_inference_sequence_idx(sequence_idx)
            print('Seq id: %d, Seq length: %d' %
                  (sequence_idx + 1, len(loader)))

            # Reset model at start of new inference sequence.
            self.reset()
            self.sequence_length = len(loader)

            # Go over all frames of this sequence.
            video = []
            for idx, data in enumerate(tqdm(loader)):
                key = data['key']['images'][0][0]
                filename = key.split('/')[-1]

                # Create output dir for this sequence.
                if idx == 0:
                    output_dir, seq_name = \
                        self.create_sequence_output_dir(root_output_dir, key)
                    video_path = os.path.join(output_dir, '..', seq_name)

                # Get output and save images.
                data['img_name'] = filename
                data = self.start_of_iteration(data, current_iteration=-1)
                output = self.test_single(data, output_dir, inference_args)
                video.append(output)

            # With no frames, `video_path` is either unbound (first sequence)
            # or stale from the previous sequence; writing would crash or
            # clobber the previous sequence's video, so skip it.
            if not video:
                continue

            # Save output as mp4.
            imageio.mimsave(video_path + '.mp4', video, fps=15)

    def test_single(self, data, output_dir=None, inference_args=None):
        r"""The inference function. If output_dir exists, also save the
        output image.

        Args:
            data (dict): Training data at the current iteration.
            output_dir (str): Save image directory.
            inference_args (obj): Inference args.
        Returns:
            The generator output dict when ``output_dir`` is None, otherwise
            the image grid that was written to disk.
        """
        if getattr(inference_args, 'finetune', False):
            if not getattr(self, 'has_finetuned', False):
                self.finetune(data, inference_args)

        net_G = self.net_G
        if self.test_in_model_average_mode:
            net_G = net_G.module.averaged_model
        net_G.eval()

        data_t = self.get_data_t(data, self.net_G_output, self.data_prev, 0)
        # Remember this frame so the next call can condition on it.
        if self.is_inference or self.sequence_length > 1:
            self.data_prev = data_t

        # Generator forward.
        with torch.no_grad():
            self.net_G_output = net_G(data_t)

        if output_dir is None:
            return self.net_G_output

        save_fake_only = getattr(inference_args, 'save_fake_only', False)
        if save_fake_only:
            image_grid = tensor2im(self.net_G_output['fake_images'])[0]
        else:
            vis_images = self.get_test_output_images(data)
            image_grid = np.hstack([np.vstack(im) for im in
                                    vis_images if im is not None])
        if 'img_name' in data:
            save_name = data['img_name'].split('.')[0] + '.jpg'
        else:
            save_name = '%04d.jpg' % self.t
        output_filename = os.path.join(output_dir, save_name)
        os.makedirs(output_dir, exist_ok=True)
        imageio.imwrite(output_filename, image_grid)
        self.t += 1

        return image_grid

    def get_test_output_images(self, data):
        r"""Get the visualization output of test function.

        Args:
            data (dict): Training data at the current iteration.
        """
        vis_images = [
            self.visualize_label(data['label'][:, -1]),
            tensor2im(data['images'][:, -1]),
            tensor2im(self.net_G_output['fake_images']),
        ]
        return vis_images

    def gen_frames(self, data, use_model_average=False):
        r"""Generate a sequence of frames given a sequence of data.

        Args:
            data (dict): Training data at the current iteration.
            use_model_average (bool): Whether to use model average
                for update or not.
        Returns:
            (first_net_G_output, last_net_G_output, all_info) where
            ``all_info`` holds the per-frame inputs and outputs.
        """
        net_G_output = None  # Previous generator output.
        data_prev = None  # Previous data.
        if use_model_average:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G

        # Iterate through the length of sequence.
        all_info = {'inputs': [], 'outputs': []}
        for t in range(self.sequence_length):
            # Get the data at the current time frame.
            data_t = self.get_data_t(data, net_G_output, data_prev, t)
            data_prev = data_t

            # Generator forward.
            with torch.no_grad():
                net_G_output = net_G(data_t)

            # Do any postprocessing if necessary.
            data_t, net_G_output = self.post_process(data_t, net_G_output)

            if t == 0:
                # Get the output at beginning of sequence for visualization.
                first_net_G_output = net_G_output

            all_info['inputs'].append(data_t)
            all_info['outputs'].append(net_G_output)

        return first_net_G_output, net_G_output, all_info

    def get_gen_losses(self, data_t, net_G_output, net_D_output):
        r"""Compute generator losses.

        Args:
            data_t (dict): Training data at the current time t.
            net_G_output (dict): Output of the generator.
            net_D_output (dict): Output of the discriminator.
        """
        self.opt_G.zero_grad()

        # Individual frame GAN loss and feature matching loss.
        self.gen_losses['GAN'], self.gen_losses['FeatureMatching'] = \
            self.compute_GAN_losses(net_D_output['indv'], dis_update=False)

        # Perceptual loss.
        self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
            net_G_output['fake_images'], data_t['image'])

        # L1 loss.
        if getattr(self.cfg.trainer.loss_weight, 'L1', 0) > 0:
            self.gen_losses['L1'] = self.criteria['L1'](
                net_G_output['fake_images'], data_t['image'])

        # Raw (hallucinated) output image losses (GAN and perceptual).
        if 'raw' in net_D_output:
            raw_GAN_losses = self.compute_GAN_losses(net_D_output['raw'],
                                                     dis_update=False)
            fg_mask = get_fg_mask(data_t['label'], self.has_fg)
            raw_perceptual_loss = self.criteria['Perceptual'](
                net_G_output['fake_raw_images'] * fg_mask,
                data_t['image'] * fg_mask)
            self.gen_losses['GAN'] += raw_GAN_losses[0]
            self.gen_losses['FeatureMatching'] += raw_GAN_losses[1]
            self.gen_losses['Perceptual'] += raw_perceptual_loss

        # Additional discriminator losses.
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                self.gen_losses['GAN_' + name], \
                    self.gen_losses['FeatureMatching_' + name] = \
                    self.compute_GAN_losses(net_D_output[name],
                                            dis_update=False)

        # Flow and mask loss.
        if self.use_flow:
            self.gen_losses['Flow_L1'], self.gen_losses['Flow_Warp'], \
                self.gen_losses['Flow_Mask'] = self.criteria['Flow'](
                    data_t, net_G_output, self.current_epoch)

        # Temporal GAN loss and feature matching loss.
        if self.cfg.trainer.loss_weight.temporal_gan > 0:
            if self.sequence_length > 1:
                for s in range(self.num_temporal_scales):
                    loss_GAN, loss_FM = self.compute_GAN_losses(
                        net_D_output['temporal_%d' % s],
                        dis_update=False
                    )
                    self.gen_losses['GAN_T%d' % s] = loss_GAN
                    self.gen_losses['FeatureMatching_T%d' % s] = loss_FM

        # Other custom losses.
        self._get_custom_gen_losses(data_t, net_G_output, net_D_output)

        # Sum all losses together.
        total_loss = self.Tensor(1).fill_(0)
        for key in self.gen_losses:
            if key != 'total':
                total_loss += self.gen_losses[key] * self.weights[key]

        self.gen_losses['total'] = total_loss

        with amp.scale_loss(total_loss, self.opt_G, loss_id=0) as scaled_loss:
            scaled_loss.backward()
        self.opt_G.step()

    def _get_custom_gen_losses(self, data_t, net_G_output, net_D_output):
        r"""All other custom generator losses go here.

        Args:
            data_t (dict): Training data at the current time t.
            net_G_output (dict): Output of the generator.
            net_D_output (dict): Output of the discriminator.
        """
        pass

    def get_dis_losses(self, net_D_output):
        r"""Compute discriminator losses.

        Args:
            net_D_output (dict): Output of the discriminator.
        """
        self.opt_D.zero_grad()

        # Individual frame GAN loss.
        self.dis_losses['GAN'] = self.compute_GAN_losses(net_D_output['indv'],
                                                         dis_update=True)

        # Raw (hallucinated) output image GAN loss.
        if 'raw' in net_D_output:
            raw_loss = self.compute_GAN_losses(net_D_output['raw'],
                                               dis_update=True)
            self.dis_losses['GAN'] += raw_loss

        # Additional GAN loss.
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                self.dis_losses['GAN_' + name] = self.compute_GAN_losses(
                    net_D_output[name], dis_update=True)

        # Temporal GAN loss.
        if self.cfg.trainer.loss_weight.temporal_gan > 0:
            if self.sequence_length > 1:
                for s in range(self.num_temporal_scales):
                    self.dis_losses['GAN_T%d' % s] = \
                        self.compute_GAN_losses(net_D_output['temporal_%d' % s],
                                                dis_update=True)

        # Other custom losses.
        self._get_custom_dis_losses(net_D_output)

        # Sum all losses together.
        total_loss = self.Tensor(1).fill_(0)
        for key in self.dis_losses:
            if key != 'total':
                total_loss += self.dis_losses[key] * self.weights[key]

        self.dis_losses['total'] = total_loss

        with amp.scale_loss(total_loss, self.opt_D, loss_id=1) as scaled_loss:
            scaled_loss.backward()
        self.opt_D.step()

    def _get_custom_dis_losses(self, net_D_output):
        r"""All other custom losses go here.

        Args:
            net_D_output (dict): Output of the discriminator.
        """
        pass

    def compute_GAN_losses(self, net_D_output, dis_update):
        r"""Compute GAN loss and feature matching loss.

        Args:
            net_D_output (dict): Output of the discriminator.
            dis_update (bool): Whether to update discriminator.
        Returns:
            A single GAN loss when ``dis_update`` is True, otherwise a pair
            (GAN loss, feature matching loss). Zeros are returned when the
            discriminator produced no fake prediction.
        """
        if net_D_output['pred_fake'] is None:
            return self.Tensor(1).fill_(0) if dis_update else [
                self.Tensor(1).fill_(0), self.Tensor(1).fill_(0)]
        if dis_update:
            # Get the GAN loss for real/fake outputs.
            GAN_loss = \
                self.criteria['GAN'](net_D_output['pred_fake']['output'], False,
                                     dis_update=True) + \
                self.criteria['GAN'](net_D_output['pred_real']['output'], True,
                                     dis_update=True)
            return GAN_loss
        else:
            # Get the GAN loss and feature matching loss for fake output.
            GAN_loss = self.criteria['GAN'](
                net_D_output['pred_fake']['output'], True, dis_update=False)

            FM_loss = self.criteria['FeatureMatching'](
                net_D_output['pred_fake']['features'],
                net_D_output['pred_real']['features'])
            return GAN_loss, FM_loss

    def get_data_t(self, data, net_G_output, data_prev, t):
        r"""Get data at current time frame given the sequence of data.

        Args:
            data (dict): Training data for current iteration.
            net_G_output (dict): Output of the generator (for previous frame).
            data_prev (dict): Data for previous frame.
            t (int): Current time.
        """
        label = data['label'][:, t]
        image = data['images'][:, t]

        if data_prev is not None:
            # Concat previous labels/fake images to the ones before.
            num_frames_G = self.cfg.data.num_frames_G
            prev_labels = concat_frames(data_prev['prev_labels'],
                                        data_prev['label'], num_frames_G - 1)
            # Detach so gradients do not flow into earlier frames' outputs.
            prev_images = concat_frames(
                data_prev['prev_images'],
                net_G_output['fake_images'].detach(), num_frames_G - 1)
        else:
            prev_labels = prev_images = None

        data_t = dict()
        data_t['label'] = label
        data_t['image'] = image
        data_t['prev_labels'] = prev_labels
        data_t['prev_images'] = prev_images
        data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None
        return data_t

    def _end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Print the errors to console."""
        if not torch.distributed.is_initialized():
            if current_iteration % self.cfg.logging_iter == 0:
                message = '(epoch: %d, iters: %d) ' % (current_epoch,
                                                       current_iteration)
                for k, v in self.gen_losses.items():
                    if k != 'total':
                        message += '%s: %.3f, ' % (k, v)
                message += '\n'
                for k, v in self.dis_losses.items():
                    if k != 'total':
                        message += '%s: %.3f, ' % (k, v)
                print(message)

    def _init_tensorboard(self):
        r"""Initialize the tensorboard.

        Creates a meter for the FID of the regular generator and, when model
        averaging is enabled, a second meter for the FID of the averaged
        generator. Also creates the image meter and the optimizer/timing
        meters.
        """
        self.regular_fid_meter = Meter('FID/regular')
        if self.cfg.trainer.model_average:
            self.average_fid_meter = Meter('FID/average')
        self.image_meter = Meter('images')
        self.meters = {}
        names = [
            'optim/gen_lr', 'optim/dis_lr', 'time/iteration', 'time/epoch'
        ]
        for name in names:
            self.meters[name] = Meter(name)

    def write_metrics(self):
        r"""If moving average model presents, we have two meters one for
        regular FID and one for average FID. If no moving average model,
        we just report average FID.
        """
        if self.cfg.trainer.model_average:
            regular_fid, average_fid = self._compute_fid()
            if regular_fid is None or average_fid is None:
                return
            self.regular_fid_meter.write(regular_fid)
            self.average_fid_meter.write(average_fid)
            meters = [self.regular_fid_meter, self.average_fid_meter]
        else:
            regular_fid = self._compute_fid()
            if regular_fid is None:
                return
            self.regular_fid_meter.write(regular_fid)
            meters = [self.regular_fid_meter]
        for meter in meters:
            meter.flush(self.current_iteration)

    def _compute_fid(self):
        r"""Compute FID values.

        Returns:
            The regular FID, or a pair (regular FID, average FID) when model
            averaging is enabled.
        """
        self.net_G.eval()
        self.net_G_output = None
        few_shot = 'few_shot' in self.cfg.data.type

        # Due to complicated video evaluation procedure we are using, we will
        # pass the trainer to the evaluation code instead of the
        # generator network.
        trainer = self
        self.test_in_model_average_mode = False
        regular_fid_path = self._get_save_path('regular_fid', 'npy')
        regular_fid_value = compute_fid(regular_fid_path, self.val_data_loader,
                                        trainer,
                                        sample_size=self.sample_size,
                                        is_video=True,
                                        few_shot_video=few_shot)
        print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
            self.current_epoch, self.current_iteration, regular_fid_value))
        if self.cfg.trainer.model_average:
            # As above, the trainer itself is passed to the evaluation code;
            # the flag below makes `reset`/`test_single` use the averaged
            # generator.
            trainer_avg_mode = self
            self.test_in_model_average_mode = True
            # NOTE: this flag stays True after this call; it is set back to
            # False at the start of the next `_compute_fid` call.
            fid_path = self._get_save_path('average_fid', 'npy')
            fid_value = compute_fid(fid_path, self.val_data_loader,
                                    trainer_avg_mode,
                                    sample_size=self.sample_size,
                                    is_video=True,
                                    few_shot_video=few_shot)
            print('Epoch {:05}, Iteration {:09}, Average FID {}'.format(
                self.current_epoch, self.current_iteration, fid_value))
            self.net_G.float()
            return regular_fid_value, fid_value
        else:
            self.net_G.float()
            return regular_fid_value

    def visualize_label(self, label):
        r"""Visualize the input label when saving to image.

        Args:
            label (tensor): Input label tensor.
        """
        cfgdata = self.cfg.data
        if hasattr(cfgdata, 'for_pose_dataset'):
            label = tensor2pose(self.cfg, label)
        elif hasattr(cfgdata, 'input_labels') and \
                'seg_maps' in cfgdata.input_labels:
            for input_type in cfgdata.input_types:
                if 'seg_maps' in input_type:
                    num_labels = input_type['seg_maps'].num_channels
            label = tensor2label(label, num_labels)
        elif getattr(cfgdata, 'label_channels', 1) > 3:
            label = tensor2im(label.sum(1, keepdim=True))
        else:
            label = tensor2im(label)
        return label

    def save_image(self, path, data):
        r"""Save the output images to path.
        Note when the generate_raw_output is FALSE. Then,
        first_net_G_output['fake_raw_images'] is None and will not be displayed.
        In model average mode, we will plot the flow visualization twice.

        Args:
            path (str): Save path.
            data (dict): Training data for current iteration.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average:
            self.net_G.module.averaged_model.eval()

        self.net_G_output = None
        with torch.no_grad():
            first_net_G_output, net_G_output, all_info = self.gen_frames(data)
            if self.cfg.trainer.model_average:
                first_net_G_output_avg, net_G_output_avg, _ = self.gen_frames(
                    data, use_model_average=True)

        # Visualize labels. Note: `vis_labels_start` holds the LAST frame's
        # labels (paired with the last-frame row) and `vis_labels_end` holds
        # the FIRST frame's labels (paired with the first-frame row).
        label_lengths = self.train_data_loader.dataset.get_label_lengths()
        labels = split_labels(data['label'], label_lengths)
        vis_labels_start, vis_labels_end = [], []
        for key, value in labels.items():
            if key == 'seg_maps':
                vis_labels_start.append(self.visualize_label(value[:, -1]))
                vis_labels_end.append(self.visualize_label(value[:, 0]))
            else:
                vis_labels_start.append(tensor2im(value[:, -1]))
                vis_labels_end.append(tensor2im(value[:, 0]))

        if is_master():
            vis_images = [
                *vis_labels_start,
                tensor2im(data['images'][:, -1]),
                tensor2im(net_G_output['fake_images']),
                tensor2im(net_G_output['fake_raw_images'])
            ]
            if self.cfg.trainer.model_average:
                vis_images += [
                    tensor2im(net_G_output_avg['fake_images']),
                    tensor2im(net_G_output_avg['fake_raw_images'])
                ]

            if self.sequence_length > 1:
                vis_images_first = [
                    *vis_labels_end,
                    tensor2im(data['images'][:, 0]),
                    tensor2im(first_net_G_output['fake_images']),
                    tensor2im(first_net_G_output['fake_raw_images'])
                ]
                if self.cfg.trainer.model_average:
                    vis_images_first += [
                        tensor2im(first_net_G_output_avg['fake_images']),
                        tensor2im(first_net_G_output_avg['fake_raw_images'])
                    ]

                if self.use_flow:
                    flow_gt, conf_gt = self.criteria['Flow'].flowNet(
                        data['images'][:, -1], data['images'][:, -2])
                    warped_image_gt = resample(data['images'][:, -1], flow_gt)
                    vis_images_first += [
                        tensor2flow(flow_gt),
                        tensor2im(conf_gt, normalize=False),
                        tensor2im(warped_image_gt),
                    ]
                    vis_images += [
                        tensor2flow(net_G_output['fake_flow_maps']),
                        tensor2im(net_G_output['fake_occlusion_masks'],
                                  normalize=False),
                        tensor2im(net_G_output['warped_images']),
                    ]
                    if self.cfg.trainer.model_average:
                        vis_images_first += [
                            tensor2flow(flow_gt),
                            tensor2im(conf_gt, normalize=False),
                            tensor2im(warped_image_gt),
                        ]
                        vis_images += [
                            tensor2flow(net_G_output_avg['fake_flow_maps']),
                            tensor2im(net_G_output_avg['fake_occlusion_masks'],
                                      normalize=False),
                            tensor2im(net_G_output_avg['warped_images'])
                        ]

                # Stack the first-frame row on top of the last-frame row.
                vis_images = [[np.vstack((im_first, im))
                               for im_first, im in zip(imgs_first, imgs)]
                              for imgs_first, imgs in zip(vis_images_first,
                                                          vis_images)
                              if imgs is not None]

            image_grid = np.hstack([np.vstack(im) for im in
                                    vis_images if im is not None])

            print('Save output images to {}'.format(path))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            imageio.imwrite(path, image_grid)

            # Gather all outputs for dumping into video.
            if self.sequence_length > 1:
                output_images = []
                for item in all_info['outputs']:
                    output_images.append(tensor2im(item['fake_images'])[0])

                imageio.mimwrite(os.path.splitext(path)[0] + '.mp4',
                                 output_images, fps=2, macro_block_size=None)

        self.net_G.float()