Ejemplo n.º 1
0
    def _init_tensorboard(self):
        r"""Create the meters used for tensorboard logging.

        Three groups of meters are created:

        * ``self.meters`` -- learning rates and timing, intended to be
          logged every ``self.cfg.logging_iter`` iterations.
        * ``self.metric_meters`` -- ``FID_a`` / ``FID_b`` and their best
          values, intended to be logged every
          ``self.cfg.snapshot_save_iter`` iterations.
        * ``self.image_meter`` -- image summaries, intended to be logged
          every ``self.cfg.image_display_iter`` iterations.
        """
        scalar_names = ('optim/gen_lr', 'optim/dis_lr',
                        'time/iteration', 'time/epoch')
        self.meters = {name: Meter(name) for name in scalar_names}

        metric_names = ('FID_a', 'best_FID_a', 'FID_b', 'best_FID_b')
        self.metric_meters = {name: Meter(name) for name in metric_names}

        self.image_meter = Meter('images')
Ejemplo n.º 2
0
 def _write_loss_meters(self):
     r"""Push the latest loss values into their tensorboard meters.

     ``self.losses`` maps an update name (``'gen_update'`` or
     ``'dis_update'``) to a dict of loss tensors. Each loss value (via
     ``.item()``) is written to the meter named ``'<update>/<loss name>'``,
     which is created on first use.
     """
     for update_name, loss_dict in self.losses.items():
         # Only generator and discriminator updates are expected here.
         assert update_name in ('gen_update', 'dis_update')
         for loss_key, loss_value in loss_dict.items():
             meter_key = update_name + '/' + loss_key
             if meter_key not in self.meters:
                 # Lazily create a meter for a loss seen for the first time.
                 self.meters[meter_key] = Meter(meter_key)
             self.meters[meter_key].write(loss_value.item())
Ejemplo n.º 3
0
    def _init_tensorboard(self):
        r"""Create the tensorboard meters. Different algorithms might require
        different performance metrics, so subclasses may override this.

        Creates ``self.meters`` (learning rates and timing, intended for every
        ``self.cfg.logging_iter`` iterations), ``self.metric_meters``
        (``FID`` and ``best_FID``, intended for every
        ``self.cfg.snapshot_save_iter`` iterations) and ``self.image_meter``
        (intended for every ``self.cfg.image_display_iter`` iterations).
        """
        scalar_names = ('optim/gen_lr', 'optim/dis_lr',
                        'time/iteration', 'time/epoch')
        self.meters = {name: Meter(name) for name in scalar_names}

        self.metric_meters = {name: Meter(name)
                              for name in ('FID', 'best_FID')}

        self.image_meter = Meter('images')
Ejemplo n.º 4
0
 def _init_tensorboard(self):
     r"""Create the tensorboard meters used by the SPADE trainer.

     ``FID/regular`` is always created. ``FID/average`` is created only when
     ``self.cfg.trainer.model_average`` is enabled. An image meter and the
     scalar meters for learning rates and timing are created as well.
     """
     self.regular_fid_meter = Meter('FID/regular')
     if self.cfg.trainer.model_average:
         self.average_fid_meter = Meter('FID/average')
     self.image_meter = Meter('images')
     scalar_names = ('optim/gen_lr', 'optim/dis_lr',
                     'time/iteration', 'time/epoch')
     self.meters = {name: Meter(name) for name in scalar_names}
Ejemplo n.º 5
0
class Trainer(BaseTrainer):
    r"""Initialize SPADE trainer.

    Args:
        cfg (Config): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """
    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super(Trainer,
              self).__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                             train_data_loader, val_data_loader)
        # True when the configured dataset is the paired-videos dataset.
        self.video_mode = \
            cfg.data.type == 'imaginaire.datasets.paired_videos'

    def _init_loss(self, cfg):
        r"""Initialize loss terms.

        Registers the GAN, feature-matching and Gaussian-KL losses, plus the
        perceptual loss when ``cfg.trainer.perceptual_loss`` is present,
        together with their weights from ``cfg.trainer.loss_weight``.

        Args:
            cfg (obj): Global configuration.
        """
        self.criteria['GAN'] = GANLoss(cfg.trainer.gan_mode)
        self.weights['GAN'] = cfg.trainer.loss_weight.gan
        # Setup the perceptual loss. Note that perceptual loss can run in
        # fp16 mode for additional speed. We find that running in fp16 mode
        # improves training speed while maintaining the same accuracy.
        if hasattr(cfg.trainer, 'perceptual_loss'):
            self.criteria['Perceptual'] = \
                PerceptualLoss(
                    cfg=cfg,
                    network=cfg.trainer.perceptual_loss.mode,
                    layers=cfg.trainer.perceptual_loss.layers,
                    weights=cfg.trainer.perceptual_loss.weights)
            self.weights['Perceptual'] = cfg.trainer.loss_weight.perceptual
        # Setup the feature matching loss.
        self.criteria['FeatureMatching'] = FeatureMatchingLoss()
        self.weights['FeatureMatching'] = \
            cfg.trainer.loss_weight.feature_matching
        # Setup the Gaussian KL divergence loss.
        self.criteria['GaussianKL'] = GaussianKLLoss()
        self.weights['GaussianKL'] = cfg.trainer.loss_weight.kl

    def _init_tensorboard(self):
        r"""Initialize the tensorboard. For the SPADE model, we record the
        FID of the regular model (``FID/regular``) and, when model averaging
        is enabled, the FID of the averaged model (``FID/average``).
        """
        self.regular_fid_meter = Meter('FID/regular')
        if self.cfg.trainer.model_average:
            self.average_fid_meter = Meter('FID/average')
        self.image_meter = Meter('images')
        self.meters = {}
        names = [
            'optim/gen_lr', 'optim/dis_lr', 'time/iteration', 'time/epoch'
        ]
        for name in names:
            self.meters[name] = Meter(name)

    def _start_of_iteration(self, data, current_iteration):
        r"""Model specific custom start of iteration process. We will do two
        things. First, put all the data to GPU. Second, we will resize the
        input so that it becomes multiple of the factor for bug-free
        convolutional operations (see ``_resize_data``).

        When ``data['label']`` is 5-D, all but the last frame of
        ``data['images']`` are flattened into the channel dimension and
        concatenated after the (likewise flattened) label, and the last
        frame becomes the target image.

        Args:
            data (dict): The current batch.
            current_iteration (int): The iteration number of the current batch.

        Returns:
            (dict): The processed batch, moved to the GPU and resized.
        """
        if len(data['label'].size()) == 5:
            # NOTE(review): dims assumed to be
            # (batch, time, channel, height, width) -- confirm with dataset.
            label_image_raw = data['images'][:, 0:-1, :, :, :]
            label_image = label_image_raw.reshape([
                label_image_raw.size(0), -1,
                label_image_raw.size(3),
                label_image_raw.size(4)
            ])
            images = data['images'][:, -1, :, :, :]
            label_label = data['label'].reshape([
                data['label'].size(0), -1, data['label'].size(3),
                data['label'].size(4)
            ])
            label = torch.cat([label_label, label_image], 1)
            data['label'] = label
            data['images'] = images
        data = to_device(data, 'cuda')
        data = self._resize_data(data)
        return data

    def gen_forward(self, data):
        r"""Compute the loss for SPADE generator.

        Args:
            data (dict): Training data at the current iteration.

        Returns:
            (tensor): Weighted sum of every registered generator loss; also
            stored in ``self.gen_losses['total']``.
        """
        net_G_output = self.net_G(data)
        net_D_output = self.net_D(data, net_G_output)

        self._time_before_loss()

        output_fake = self._get_outputs(net_D_output, real=False)
        self.gen_losses['GAN'] = \
            self.criteria['GAN'](output_fake, True, dis_update=False)

        self.gen_losses['FeatureMatching'] = self.criteria['FeatureMatching'](
            net_D_output['fake_features'], net_D_output['real_features'])

        if self.net_G_module.use_style_encoder:
            self.gen_losses['GaussianKL'] = \
                self.criteria['GaussianKL'](net_G_output['mu'],
                                            net_G_output['logvar'])
        else:
            # No style encoder: the KL term is a zero placeholder so the
            # weighted sum below still finds every criterion key.
            self.gen_losses['GaussianKL'] = \
                self.gen_losses['GAN'].new_tensor([0])

        if hasattr(self.cfg.trainer, 'perceptual_loss'):
            self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
                net_G_output['fake_images'], data['images'])

        total_loss = self.gen_losses['GAN'].new_tensor([0])
        for key in self.criteria:
            total_loss += self.gen_losses[key] * self.weights[key]

        self.gen_losses['total'] = total_loss
        return total_loss

    def dis_forward(self, data):
        r"""Compute the loss for SPADE discriminator.

        Args:
            data (dict): Training data at the current iteration.

        Returns:
            (tensor): Weighted GAN loss (fake + real); also stored in
            ``self.dis_losses['total']``.
        """
        with torch.no_grad():
            net_G_output = self.net_G(data)
            net_G_output['fake_images'] = net_G_output['fake_images'].detach()
        net_D_output = self.net_D(data, net_G_output)

        self._time_before_loss()

        output_fake = self._get_outputs(net_D_output, real=False)
        output_real = self._get_outputs(net_D_output, real=True)
        fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True)
        true_loss = self.criteria['GAN'](output_real, True, dis_update=True)
        self.dis_losses['GAN/fake'] = fake_loss
        self.dis_losses['GAN/true'] = true_loss
        self.dis_losses['GAN'] = fake_loss + true_loss
        total_loss = self.dis_losses['GAN'] * self.weights['GAN']
        self.dis_losses['total'] = total_loss
        return total_loss

    def _get_visualizations(self, data):
        r"""Compute visualization image. We will first recalculate the batch
        statistics for the moving average model.

        Args:
            data (dict): The current batch.

        Returns:
            (list of tensors): Real images, the colorized segmentation map,
            the regular model's output and, when model averaging is
            enabled, the averaged model's output.
        """
        self.recalculate_model_average_batch_norm_statistics(
            self.train_data_loader)
        with torch.no_grad():
            label_lengths = self.train_data_loader.dataset.get_label_lengths()
            labels = split_labels(data['label'], label_lengths)
            # Get visualization of the segmentation mask.
            segmap = tensor2label(labels['seg_maps'],
                                  label_lengths['seg_maps'],
                                  output_normalized_tensor=True)
            segmap = torch.cat([x.unsqueeze(0) for x in segmap], 0)
            net_G_output = self.net_G(data, random_style=True)
            vis_images = [data['images'], segmap, net_G_output['fake_images']]
            if self.cfg.trainer.model_average:
                net_G_model_average_output = \
                    self.net_G.module.averaged_model(data, random_style=True)
                vis_images.append(net_G_model_average_output['fake_images'])
        return vis_images

    def recalculate_model_average_batch_norm_statistics(self, data_loader):
        r"""Update the statistics in the moving average model.

        Batch-norm running stats of the averaged model are reset and then
        re-estimated over at most
        ``cfg.trainer.model_average_batch_norm_estimation_iteration`` batches.
        Does nothing when model averaging is disabled or that count is 0.

        Args:
            data_loader (pytorch data loader): Data loader for estimating the
                statistics.
        """
        if not self.cfg.trainer.model_average:
            return
        model_average_iteration = \
            self.cfg.trainer.model_average_batch_norm_estimation_iteration
        if model_average_iteration == 0:
            return
        with torch.no_grad():
            # Accumulate bn stats..
            self.net_G.module.averaged_model.train()
            # Reset running stats.
            self.net_G.module.averaged_model.apply(reset_batch_norm)
            for cal_it, cal_data in enumerate(data_loader):
                if cal_it >= model_average_iteration:
                    print('Done with {} iterations of updating batch norm '
                          'statistics'.format(model_average_iteration))
                    break
                # Same preprocessing as training (reshaping, GPU, resize).
                cal_data = self._start_of_iteration(cal_data, 0)
                # Averaging over all batches
                self.net_G.module.averaged_model.apply(
                    calibrate_batch_norm_momentum)
                self.net_G.module.averaged_model(cal_data)

    def write_metrics(self):
        r"""Compute FID and write it to tensorboard.

        The FID of the regular model is always written to ``FID/regular``.
        If the moving average model is present, its FID is additionally
        written to ``FID/average``.
        """
        if self.cfg.trainer.model_average:
            regular_fid, average_fid = self._compute_fid()
            self.regular_fid_meter.write(regular_fid)
            self.average_fid_meter.write(average_fid)
            meters = [self.regular_fid_meter, self.average_fid_meter]
        else:
            regular_fid = self._compute_fid()
            self.regular_fid_meter.write(regular_fid)
            meters = [self.regular_fid_meter]
        for meter in meters:
            meter.flush(self.current_iteration)

    def _compute_fid(self):
        r"""Compute FID on the validation set.

        The regular generator and, when enabled, the moving average model
        are both evaluated in eval mode with ``random_style=True``.

        Returns:
            (float or tuple): The regular-model FID, or
            ``(regular_fid, average_fid)`` when model averaging is enabled.
        """
        self.net_G.eval()
        net_G_for_evaluation = \
            functools.partial(self.net_G, random_style=True)
        regular_fid_path = self._get_save_path('regular_fid', 'npy')
        preprocess = \
            functools.partial(self._start_of_iteration, current_iteration=0)

        regular_fid_value = compute_fid(regular_fid_path,
                                        self.val_data_loader,
                                        net_G_for_evaluation,
                                        preprocess=preprocess)
        print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
            self.current_epoch, self.current_iteration, regular_fid_value))
        if self.cfg.trainer.model_average:
            avg_net_G_for_evaluation = \
                functools.partial(self.net_G.module.averaged_model,
                                  random_style=True)
            fid_path = self._get_save_path('average_fid', 'npy')
            fid_value = compute_fid(fid_path,
                                    self.val_data_loader,
                                    avg_net_G_for_evaluation,
                                    preprocess=preprocess)
            print('Epoch {:05}, Iteration {:09}, FID {}'.format(
                self.current_epoch, self.current_iteration, fid_value))
            # NOTE(review): presumably restores fp32 parameters in case
            # evaluation changed precision -- confirm.
            self.net_G.float()
            return regular_fid_value, fid_value
        else:
            self.net_G.float()
            return regular_fid_value

    def _resize_data(self, data):
        r"""Resize input label maps and images so that it can be properly
        generated by the generator.

        Height and width are rounded down to a multiple of ``base``. Labels
        use nearest-neighbour interpolation; images use bicubic.

        Args:
            data (dict): Input dictionary that contains the 'label' and
                (optionally) 'images' fields.

        Returns:
            (dict): ``data`` with the resized tensors.
        """
        # NOTE(review): if net_G is wrapped (e.g. DDP), getattr on the wrapper
        # may not see a ``base`` attribute defined on the inner network.
        base = getattr(self.net_G, 'base', 32)
        sy = math.floor(data['label'].size()[2] * 1.0 // base) * base
        sx = math.floor(data['label'].size()[3] * 1.0 // base) * base
        data['label'] = F.interpolate(data['label'],
                                      size=[sy, sx],
                                      mode='nearest')
        if 'images' in data:
            data['images'] = F.interpolate(data['images'],
                                           size=[sy, sx],
                                           mode='bicubic')
        return data
Ejemplo n.º 6
0
class BaseTrainer(object):
    r"""Base trainer. We expect that all trainers inherit this class.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """
    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super(BaseTrainer, self).__init__()
        print('Setup trainer.')

        # --- Networks, optimizers, schedulers and data loaders. ---
        self.cfg = cfg
        self.net_G = net_G
        # ``net_G_module`` is the generator with its wrappers peeled off.
        if cfg.trainer.model_average:
            # Two wrappers (DDP + model average).
            self.net_G_module = net_G.module.module
        elif cfg.trainer.distribute:
            # One wrapper (DDP).
            self.net_G_module = net_G.module
        else:
            self.net_G_module = net_G
        self.val_data_loader = val_data_loader
        self.is_inference = train_data_loader is None
        self.net_D = net_D
        self.opt_G, self.opt_D = opt_G, opt_D
        self.sch_G, self.sch_D = sch_G, sch_D
        self.train_data_loader = train_data_loader

        # --- Losses. ---
        # Every loss name has a weight; only some also have a criterion module.
        self.criteria = nn.ModuleDict()  # loss name -> criterion module
        self.weights = dict()  # loss name -> loss weight
        self.losses = dict(gen_update=dict(), dis_update=dict())
        self.gen_losses = self.losses['gen_update']
        self.dis_losses = self.losses['dis_update']
        self._init_loss(cfg)
        for loss_name, loss_weight in self.weights.items():
            print("Loss {:<20} Weight {}".format(loss_name, loss_weight))
            has_criterion = loss_name in self.criteria
            if has_criterion and self.criteria[loss_name] is not None:
                self.criteria[loss_name].to('cuda')

        # Everything below is only needed for training.
        if self.is_inference:
            return

        # --- Logging state. ---
        self.current_iteration = 0
        self.current_epoch = 0
        self.start_iteration_time = None
        self.start_epoch_time = None
        self.elapsed_iteration_time = 0
        self.time_iteration = -1
        self.time_epoch = -1
        self.best_fid = None
        if getattr(self.cfg, 'speed_benchmark', False):
            # Per-stage timers accumulated between logging iterations.
            for timer_name in ('accu_gen_forw_iter_time',
                               'accu_gen_loss_iter_time',
                               'accu_gen_back_iter_time',
                               'accu_gen_step_iter_time',
                               'accu_gen_avg_iter_time',
                               'accu_dis_forw_iter_time',
                               'accu_dis_loss_iter_time',
                               'accu_dis_back_iter_time',
                               'accu_dis_step_iter_time'):
                setattr(self, timer_name, 0)

        # Tensorboard meters and the hyper-parameter dictionary.
        self._init_tensorboard()
        self._init_hparams()

    def _init_tensorboard(self):
        r"""Create the tensorboard meters. Different algorithms might require
        different performance metrics, so subclasses may override this.

        Creates:

        * ``self.meters`` -- learning rates and timing (intended frequency:
          ``self.cfg.logging_iter``).
        * ``self.metric_meters`` -- ``FID`` and ``best_FID`` (intended
          frequency: ``self.cfg.snapshot_save_iter``).
        * ``self.image_meter`` -- images (intended frequency:
          ``self.cfg.image_display_iter``).
        """
        scalar_names = ('optim/gen_lr', 'optim/dis_lr',
                        'time/iteration', 'time/epoch')
        self.meters = {name: Meter(name) for name in scalar_names}

        self.metric_meters = {name: Meter(name)
                              for name in ('FID', 'best_FID')}

        self.image_meter = Meter('images')

    def _init_hparams(self):
        r"""Start with an empty dictionary of hyperparameters to monitor in
        the tensorboard HParams dashboard; subclasses may fill it in.
        """
        self.hparam_dict = dict()

    def _write_tensorboard(self):
        r"""Write values to tensorboard. By default, we will log the time used
        per iteration, time used per epoch, generator learning rate, and
        discriminator learning rate. We will log all the losses as well as
        custom meters.

        Learning rates are read with ``scheduler.get_last_lr()``, which
        returns the rate most recently computed by ``step()``.
        ``scheduler.get_lr()`` is used only as a fallback for schedulers
        that do not provide ``get_last_lr``.
        """
        def _current_lr(scheduler):
            # ``get_lr()`` is meant to be called from inside ``step()``;
            # calling it directly warns in recent PyTorch and may return a
            # value that differs from the rate actually in use.
            get_last_lr = getattr(scheduler, 'get_last_lr', None)
            if get_last_lr is None:
                return scheduler.get_lr()[0]
            return get_last_lr()[0]

        # Logs that are shared by all models.
        self._write_to_meters(
            {
                'time/iteration': self.time_iteration,
                'time/epoch': self.time_epoch,
                'optim/gen_lr': _current_lr(self.sch_G),
                'optim/dis_lr': _current_lr(self.sch_D)
            }, self.meters)
        # Logs for loss values. Different models have different losses.
        self._write_loss_meters()
        # Other custom logs.
        self._write_custom_meters()

        # Write all logs to tensorboard.
        self._flush_meters(self.meters)

    def _write_loss_meters(self):
        r"""Push the latest loss values into their tensorboard meters.

        ``self.losses`` maps ``'gen_update'`` / ``'dis_update'`` to dicts of
        loss tensors. Each loss value (via ``.item()``) goes to the meter
        ``'<update>/<loss name>'``, which is created on first use.
        """
        for update_name, loss_dict in self.losses.items():
            # Only generator and discriminator updates are expected here.
            assert update_name in ('gen_update', 'dis_update')
            for loss_key, loss_value in loss_dict.items():
                meter_key = update_name + '/' + loss_key
                if meter_key not in self.meters:
                    # Lazily create a meter for a loss seen for the first time.
                    self.meters[meter_key] = Meter(meter_key)
                self.meters[meter_key].write(loss_value.item())

    def _write_custom_meters(self):
        r"""Hook for subclasses to write any extra values they want to track.

        The base implementation writes nothing.
        """
        return

    @staticmethod
    def _write_to_meters(data, meters):
        r"""Write every ``key -> value`` pair of ``data`` to ``meters[key]``.

        A ``KeyError`` is raised if ``data`` has a key with no meter.
        """
        for meter_name in data:
            meters[meter_name].write(data[meter_name])

    def _flush_meters(self, meters):
        r"""Flush every meter in ``meters``, tagged with the current
        iteration.
        """
        iteration = self.current_iteration
        for name in meters:
            meters[name].flush(iteration)

    def _pre_save_checkpoint(self):
        r"""Hook called right before a checkpoint is written.

        Subclasses can override it to prepare state that should be saved
        (e.g. computing pix2pixHD K-means features). The base version does
        nothing.
        """
        return

    def save_checkpoint(self, current_epoch, current_iteration):
        r"""Save network weights, optimizer parameters and scheduler
        parameters to a checkpoint, after running ``_pre_save_checkpoint``.

        Args:
            current_epoch (int): Epoch number stored in the checkpoint.
            current_iteration (int): Iteration number stored in the
                checkpoint.
        """
        self._pre_save_checkpoint()
        networks = (self.net_G, self.net_D)
        optimizers = (self.opt_G, self.opt_D)
        schedulers = (self.sch_G, self.sch_D)
        _save_checkpoint(self.cfg, *networks, *optimizers, *schedulers,
                         current_epoch, current_iteration)

    def load_checkpoint(self, cfg, checkpoint_path, resume=None):
        r"""Load network weights, optimizer parameters, scheduler parameters
        from a checkpoint.

        The checkpoint is chosen in this order:

        1. ``checkpoint_path``, if it is given and exists (``resume``
           defaults to ``False``: generator weights only);
        2. the file named at the end of the first line of
           ``<cfg.logdir>/latest_checkpoint.txt`` (``resume`` defaults to
           ``True``);
        3. otherwise nothing is loaded and training starts from scratch.

        Args:
            cfg (obj): Global configuration.
            checkpoint_path (str or None): Path to the checkpoint. ``None`` or
                an empty string means "not specified".
            resume (bool or None): If not ``None``, will determine whether or
                not to load optimizers in addition to network weights.

        Returns:
            (tuple): ``(current_epoch, current_iteration)`` restored from the
            checkpoint, or ``(0, 0)`` when they are not restored.
        """
        # Guard against None: os.path.exists(None) raises TypeError.
        if checkpoint_path and os.path.exists(checkpoint_path):
            # If checkpoint_path exists, we will load its weights to
            # initialize our network.
            if resume is None:
                resume = False
        elif os.path.exists(os.path.join(cfg.logdir, 'latest_checkpoint.txt')):
            # This is for resuming the training from the previously saved
            # checkpoint.
            fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt')
            with open(fn, 'r') as f:
                line = f.read().splitlines()
            checkpoint_path = os.path.join(cfg.logdir, line[0].split(' ')[-1])
            if resume is None:
                resume = True
        else:
            # checkpoint not found and not specified. We will train
            # everything from scratch.
            current_epoch = 0
            current_iteration = 0
            print('No checkpoint found.')
            return current_epoch, current_iteration
        # Load checkpoint. The map_location lambda keeps every storage where
        # it was deserialized (CPU) instead of the device it was saved from.
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
        current_epoch = 0
        current_iteration = 0
        if resume:
            self.net_G.load_state_dict(checkpoint['net_G'])
            if not self.is_inference:
                self.net_D.load_state_dict(checkpoint['net_D'])
                if 'opt_G' in checkpoint:
                    self.opt_G.load_state_dict(checkpoint['opt_G'])
                    self.opt_D.load_state_dict(checkpoint['opt_D'])
                    self.sch_G.load_state_dict(checkpoint['sch_G'])
                    self.sch_D.load_state_dict(checkpoint['sch_D'])
                    current_epoch = checkpoint['current_epoch']
                    current_iteration = checkpoint['current_iteration']
                    print('Load from: {}'.format(checkpoint_path))
                else:
                    print('Load network weights only.')
        else:
            self.net_G.load_state_dict(checkpoint['net_G'])
            print('Load generator weights only.')

        print('Done with loading the checkpoint.')
        return current_epoch, current_iteration

    def start_of_epoch(self, current_epoch):
        r"""Run the subclass ``_start_of_epoch`` hook, then record the epoch
        number and the wall-clock time at which it started.

        Args:
            current_epoch (int): Current number of epoch.
        """
        self._start_of_epoch(current_epoch)
        epoch_started_at = time.time()
        self.current_epoch = current_epoch
        self.start_epoch_time = epoch_started_at

    def start_of_iteration(self, data, current_iteration):
        r"""Prepare a batch and the networks for a training iteration.

        Runs the subclass ``_start_of_iteration`` hook and ``to_cuda`` on
        ``data``, puts the networks in train mode and starts the iteration
        timer.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current number of iteration.

        Returns:
            (dict): The prepared batch.
        """
        prepared = to_cuda(self._start_of_iteration(data, current_iteration))
        self.current_iteration = current_iteration
        # Networks may have been switched to eval mode (e.g. by save_image).
        if not self.is_inference:
            self.net_D.train()
        self.net_G.train()
        # NOTE: timing is host-side; GPU work may still be pending unless the
        # device is synchronized first.
        self.start_iteration_time = time.time()
        return prepared

    def end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Bookkeeping after a training iteration.

        Steps the schedulers that run in iteration mode, accumulates and
        (every ``cfg.logging_iter`` iterations) reports the iteration time,
        runs the subclass ``_end_of_iteration`` hook, saves snapshots or
        images on their configured schedules and writes tensorboard logs.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current number of iteration.
        """
        self.current_iteration = current_iteration
        self.current_epoch = current_epoch
        # Schedulers configured in iteration mode advance once per iteration.
        if self.cfg.gen_opt.lr_policy.iteration_mode:
            self.sch_G.step()
        if self.cfg.dis_opt.lr_policy.iteration_mode:
            self.sch_D.step()

        # Accumulate the host-side time spent in this iteration.
        self.elapsed_iteration_time += time.time() - self.start_iteration_time
        if current_iteration % self.cfg.logging_iter == 0:
            mean_iter_time = \
                self.elapsed_iteration_time / self.cfg.logging_iter
            self.time_iteration = mean_iter_time
            print('Iteration: {}, average iter time: '
                  '{:6f}.'.format(current_iteration, mean_iter_time))
            self.elapsed_iteration_time = 0

            if getattr(self.cfg, 'speed_benchmark', False):
                # Per-stage breakdown, only used to find bottlenecks.
                stage_timers = (
                    ('\tGenerator FWD time {:6f}', 'accu_gen_forw_iter_time'),
                    ('\tGenerator LOS time {:6f}', 'accu_gen_loss_iter_time'),
                    ('\tGenerator BCK time {:6f}', 'accu_gen_back_iter_time'),
                    ('\tGenerator STP time {:6f}', 'accu_gen_step_iter_time'),
                    ('\tGenerator AVG time {:6f}', 'accu_gen_avg_iter_time'),
                    ('\tDiscriminator FWD time {:6f}',
                     'accu_dis_forw_iter_time'),
                    ('\tDiscriminator LOS time {:6f}',
                     'accu_dis_loss_iter_time'),
                    ('\tDiscriminator BCK time {:6f}',
                     'accu_dis_back_iter_time'),
                    ('\tDiscriminator STP time {:6f}',
                     'accu_dis_step_iter_time'),
                )
                for template, timer_name in stage_timers:
                    print(template.format(
                        getattr(self, timer_name) / self.cfg.logging_iter))
                print('{:6f}'.format(mean_iter_time))
                for _, timer_name in stage_timers:
                    setattr(self, timer_name, 0)

        self._end_of_iteration(data, current_epoch, current_iteration)

        snapshot_due = (
            current_iteration >= self.cfg.snapshot_save_start_iter and
            current_iteration % self.cfg.snapshot_save_iter == 0)
        if snapshot_due:
            # Full snapshot: images, checkpoint and evaluation metrics.
            self.save_image(self._get_save_path('images', 'jpg'), data)
            self.save_checkpoint(current_epoch, current_iteration)
            self.write_metrics()
        elif current_iteration % self.cfg.image_save_iter == 0:
            self.save_image(self._get_save_path('images', 'jpg'), data)
        elif current_iteration % self.cfg.image_display_iter == 0:
            preview_path = os.path.join(self.cfg.logdir, 'images',
                                        'current.jpg')
            self.save_image(preview_path, data)

        if current_iteration % self.cfg.logging_iter == 0:
            self._write_tensorboard()
            print("gen loss: {}".format(self.gen_losses))
            print("dis loss: {}".format(self.dis_losses))

    def end_of_epoch(self, data, current_epoch, current_iteration):
        r"""Bookkeeping after a training epoch.

        Steps the schedulers that run in epoch mode, records and prints the
        epoch duration, runs the subclass ``_end_of_epoch`` hook and, on the
        configured epoch schedule, saves images, a checkpoint and metrics.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current number of iteration.
        """
        self.current_iteration = current_iteration
        self.current_epoch = current_epoch
        # Schedulers that are *not* in iteration mode advance once per epoch.
        if not self.cfg.gen_opt.lr_policy.iteration_mode:
            self.sch_G.step()
        if not self.cfg.dis_opt.lr_policy.iteration_mode:
            self.sch_D.step()

        epoch_duration = time.time() - self.start_epoch_time
        print('Epoch: {}, total time: {:6f}.'.format(current_epoch,
                                                     epoch_duration))
        self.time_epoch = epoch_duration
        self._end_of_epoch(data, current_epoch, current_iteration)

        snapshot_due = (
            current_epoch >= self.cfg.snapshot_save_start_epoch and
            current_epoch % self.cfg.snapshot_save_epoch == 0)
        if snapshot_due:
            self.save_image(self._get_save_path('images', 'jpg'), data)
            self.save_checkpoint(current_epoch, current_iteration)
            self.write_metrics()

    def pre_process(self, data):
        r"""Hook for custom data pre-processing before the batch is sent to
        the generator and discriminator. The base version does nothing and
        returns ``None``.

        Args:
            data (dict): Data used for the current iteration.
        """
        return None

    def recalculate_model_average_batch_norm_statistics(self, data_loader):
        r"""Re-estimate the batch-norm statistics of the moving average model.

        The averaged model's batch-norm running stats are reset and then
        re-accumulated over at most
        ``cfg.trainer.model_average_batch_norm_estimation_iteration`` batches,
        without gradients. Nothing happens when model averaging is disabled
        or that count is 0.

        Args:
            data_loader (torch.utils.data.DataLoader): Data loader for
                estimating the statistics.
        """
        if not self.cfg.trainer.model_average:
            return
        num_batches = \
            self.cfg.trainer.model_average_batch_norm_estimation_iteration
        if num_batches == 0:
            return
        averaged_model = self.net_G.module.averaged_model
        with torch.no_grad():
            # Train mode so that forward passes update the running stats.
            averaged_model.train()
            averaged_model.apply(reset_batch_norm)
            for batch_index, batch in enumerate(data_loader):
                if batch_index >= num_batches:
                    print('Done with {} iterations of updating batch norm '
                          'statistics'.format(num_batches))
                    break
                batch = to_device(batch, 'cuda')
                averaged_model.apply(calibrate_batch_norm_momentum)
                averaged_model(batch)

    def save_image(self, path, data):
        r"""Render visualizations for ``data`` and write them to ``path``.

        The generator is switched to eval mode first. Only the master process
        writes, and only when ``_get_visualizations`` returns something. The
        visualizations are concatenated along dim 3, shifted with
        ``(x + 1) / 2``, clamped to [0, 1] and tiled into a one-column grid.
        The grid is optionally logged to tensorboard and then saved.

        Args:
            path (str): Location of the file.
            data (dict): Data used for the current iteration.
        """
        self.net_G.eval()
        vis_images = self._get_visualizations(data)
        if not (is_master() and vis_images is not None):
            return

        images = torch.cat(vis_images, dim=3).float()
        images = (images + 1) / 2
        print('Save output images to {}'.format(path))
        images.clamp_(0, 1)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        grid = torchvision.utils.make_grid(
            images, nrow=1, padding=0, normalize=False)
        if self.cfg.trainer.image_to_tensorboard:
            self.image_meter.write_image(grid, self.current_iteration)
        torchvision.utils.save_image(grid, path, nrow=1)

    def write_metrics(self):
        r"""Write metrics to the tensorboard.

        Computes the current FID and updates ``self.best_fid`` (the minimum
        seen so far). Both values are written and flushed to the metric
        meters, and are also logged as hparams when
        ``cfg.trainer.hparam_to_tensorboard`` is set. Nothing is written when
        ``_compute_fid`` returns ``None``.
        """
        cur_fid = self._compute_fid()
        if cur_fid is None:
            return
        if self.best_fid is None:
            self.best_fid = cur_fid
        else:
            self.best_fid = min(self.best_fid, cur_fid)
        metric_dict = {'FID': cur_fid, 'best_FID': self.best_fid}
        self._write_to_meters(metric_dict, self.metric_meters)
        self._flush_meters(self.metric_meters)
        if self.cfg.trainer.hparam_to_tensorboard:
            add_hparams(self.hparam_dict, metric_dict)

    def _get_save_path(self, subdir, ext):
        r"""Get the image save path.

        Args:
            subdir (str): Sub-directory under the main directory for saving
                the outputs.
            ext (str): Filename extension for the image (e.g., jpg, png, ...).
        Return:
            (str): image filename to be used to save the visualization results.
        """
        subdir_path = os.path.join(self.cfg.logdir, subdir)
        if not os.path.exists(subdir_path):
            os.makedirs(subdir_path, exist_ok=True)
        return os.path.join(
            subdir_path,
            'epoch_{:05}_iteration_{:09}.{}'.format(self.current_epoch,
                                                    self.current_iteration,
                                                    ext))

    def _get_outputs(self, net_D_output, real=True):
        r"""Return output values. Note that when the gan mode is relativistic.
        It will do the difference before returning.

        Args:
           net_D_output (dict):
               real_outputs (tensor): Real output values.
               fake_outputs (tensor): Fake output values.
           real (bool): Return real or fake.
        """
        def _get_difference(a, b):
            r"""Get difference between two lists of tensors or two tensors.

            Args:
                a: list of tensors or tensor
                b: list of tensors or tensor
            """
            out = list()
            for x, y in zip(a, b):
                if isinstance(x, list):
                    res = _get_difference(x, y)
                else:
                    res = x - y
                out.append(res)
            return out

        if real:
            if self.cfg.trainer.gan_relativistic:
                return _get_difference(net_D_output['real_outputs'],
                                       net_D_output['fake_outputs'])
            else:
                return net_D_output['real_outputs']
        else:
            if self.cfg.trainer.gan_relativistic:
                return _get_difference(net_D_output['fake_outputs'],
                                       net_D_output['real_outputs'])
            else:
                return net_D_output['fake_outputs']

    def _start_of_epoch(self, current_epoch):
        r"""Operations to do before starting an epoch.

        Args:
            current_epoch (int): Current number of epoch.
        """
        pass

    def _start_of_iteration(self, data, current_iteration):
        r"""Operations to do before starting an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current epoch number.
        Returns:
            (dict): Data used for the current iteration. They might be
                processed by the custom _start_of_iteration function.
        """
        return data

    def _end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Operations to do after an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current epoch number.
        """
        pass

    def _end_of_epoch(self, data, current_epoch, current_iteration):
        r"""Operations to do after an epoch.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current number of epoch.
            current_iteration (int): Current epoch number.
        """
        pass

    def _get_visualizations(self, data):
        r"""Compute visualization outputs.

        Args:
            data (dict): Data used for the current iteration.
        """
        return None

    def _compute_fid(self):
        r"""FID computation function to be overloaded."""
        return None

    def _init_loss(self, cfg):
        r"""Every trainer should implement its own init loss function."""
        raise NotImplementedError

    def gen_update(self, data):
        r"""Update the generator.

        Performs one generator optimization step:
        1. zero the generator gradients;
        2. set the requires_grad flags (generator on, discriminator off);
        3. compute the loss with ``gen_forward``;
        4. backpropagate it through ``amp.scale_loss`` (loss_id=0);
        5. optionally clip the gradient norm;
        6. step ``opt_G``;
        7. update the moving-average model when
           ``cfg.trainer.model_average`` is set;
        8. detach the logged losses.

        If ``gen_forward`` returns ``None``, the method returns right after
        the forward pass.

        Args:
            data (dict): Data used for the current iteration.
        """
        self.opt_G.zero_grad()

        # Set requires_grad flags.
        requires_grad(self.net_G_module, True)
        requires_grad(self.net_D, False)

        # Compute the loss.
        self._time_before_forward()
        total_loss = self.gen_forward(data)
        if total_loss is None:
            # NOTE: this early exit also skips _detach_losses() and the
            # _time_before_leave_gen() bookkeeping below.
            return

        # Backpropagate the loss.
        self._time_before_backward()
        with amp.scale_loss(total_loss, self.opt_G, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        # Optionally clip gradient norm. Clipping is applied to
        # amp.master_params(self.opt_G), i.e. the parameters the amp-wrapped
        # optimizer steps on; only done when gen_opt defines clip_grad_norm.
        if hasattr(self.cfg.gen_opt, 'clip_grad_norm'):
            nn.utils.clip_grad_norm_(amp.master_params(self.opt_G),
                                     self.cfg.gen_opt.clip_grad_norm)

        # Perform an optimizer step.
        self._time_before_step()
        self.opt_G.step()

        # Update model average.
        self._time_before_model_avg()
        if self.cfg.trainer.model_average:
            self.net_G.module.update_average()

        # Drop autograd history from the logged losses.
        self._detach_losses()
        self._time_before_leave_gen()

    def gen_forward(self, data):
        r"""Abstract hook: compute and return the generator's total loss.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def dis_update(self, data):
        r"""Update the discriminator.

        Mirrors ``gen_update`` for the discriminator:
        1. zero the ``opt_D`` gradients;
        2. set the requires_grad flags (generator off, discriminator on);
        3. compute the loss with ``dis_forward``;
        4. backpropagate it through ``amp.scale_loss`` (loss_id=1);
        5. step ``opt_D``;
        6. detach the logged losses.

        Unlike ``gen_update``, no gradient clipping and no model averaging
        happen here. If ``dis_forward`` returns ``None``, the method returns
        right after the forward pass.

        Args:
            data (dict): Data used for the current iteration.
        """
        self.opt_D.zero_grad()

        # Set requires_grad flags.
        requires_grad(self.net_G_module, False)
        requires_grad(self.net_D, True)

        # Compute the loss.
        self._time_before_forward()
        total_loss = self.dis_forward(data)
        if total_loss is None:
            # NOTE: this early exit also skips _detach_losses() and the
            # _time_before_leave_dis() bookkeeping below.
            return

        # Backpropagate the loss (loss_id=1 distinguishes it from the
        # generator loss, which uses loss_id=0).
        self._time_before_backward()
        with amp.scale_loss(total_loss, self.opt_D, loss_id=1) as scaled_loss:
            scaled_loss.backward()

        # Perform an optimizer step.
        self._time_before_step()
        self.opt_D.step()

        # Drop autograd history from the logged losses.
        self._detach_losses()
        self._time_before_leave_dis()

    def dis_forward(self, data):
        r"""Abstract hook: compute and return the discriminator's total loss.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def test(self, data_loader, output_dir, inference_args):
        r"""Compute results images for a batch of input data and save the
        results in the specified folder.

        Uses the moving-average generator when ``cfg.trainer.model_average``
        is set, otherwise the regular one, in eval mode and without
        gradients. Each output image is clamped in place to [-1, 1],
        converted to a PIL image and saved as
        ``<output_dir>/<file_name>.jpg``.

        Args:
            data_loader (torch.utils.data.DataLoader): PyTorch dataloader.
            output_dir (str): Target location for saving the output image.
            inference_args (obj): Extra inference options. Its attributes
                (``vars(inference_args)``) are passed as keyword arguments to
                ``net_G.inference``.
        """
        if self.cfg.trainer.model_average:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G.module
        net_G.eval()

        # NOTE: for a DataLoader, len() counts batches, not individual samples.
        print('# of samples %d' % len(data_loader))
        for data in tqdm(data_loader):
            data = self.start_of_iteration(data, current_iteration=-1)
            with torch.no_grad():
                output_images, file_names = \
                    net_G.inference(data, **vars(inference_args))
            for output_image, file_name in zip(output_images, file_names):
                fullname = os.path.join(output_dir, file_name + '.jpg')
                output_image = tensor2pilimage(output_image.clamp_(-1, 1),
                                               minus1to1_normalized=True)
                save_pilimage_in_jpeg(fullname, output_image)

    def _get_total_loss(self, gen_forward):
        r"""Return the total loss to be backpropagated.
        Args:
            gen_forward (bool): If ``True``, backpropagates the generator loss,
                otherwise the discriminator loss.
        """
        losses = self.gen_losses if gen_forward else self.dis_losses
        total_loss = torch.tensor(0., device=torch.device('cuda'))
        # Iterates over all possible losses.
        for loss_name in self.weights:
            # If it is for the current model (gen/dis).
            if loss_name in losses:
                # Multiply it with the corresponding weight
                # and add it to the total loss.
                total_loss += losses[loss_name] * self.weights[loss_name]
        losses['total'] = total_loss  # logging purpose
        return total_loss

    def _detach_losses(self):
        r"""Detach all logging variables to prevent potential memory leak."""
        for loss_name in self.gen_losses:
            self.gen_losses[loss_name] = self.gen_losses[loss_name].detach()
        for loss_name in self.dis_losses:
            self.dis_losses[loss_name] = self.dis_losses[loss_name].detach()

    def _time_before_forward(self):
        r"""
        Record time before applying forward.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            self.forw_time = time.time()

    def _time_before_loss(self):
        r"""
        Record time before computing loss.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            self.loss_time = time.time()

    def _time_before_backward(self):
        r"""
        Record time before applying backward.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            self.back_time = time.time()

    def _time_before_step(self):
        r"""
        Record time before updating the weights
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            self.step_time = time.time()

    def _time_before_model_avg(self):
        r"""
        Record time before applying model average.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            self.avg_time = time.time()

    def _time_before_leave_gen(self):
        r"""
        Record forward, backward, loss, and model average time for the
        generator update.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            end_time = time.time()
            self.accu_gen_forw_iter_time += self.loss_time - self.forw_time
            self.accu_gen_loss_iter_time += self.back_time - self.loss_time
            self.accu_gen_back_iter_time += self.step_time - self.back_time
            self.accu_gen_step_iter_time += self.avg_time - self.step_time
            self.accu_gen_avg_iter_time += end_time - self.avg_time

    def _time_before_leave_dis(self):
        r"""
        Record forward, backward, loss time for the discriminator update.
        """
        if getattr(self.cfg, 'speed_benchmark', False):
            torch.cuda.synchronize()
            end_time = time.time()
            self.accu_dis_forw_iter_time += self.loss_time - self.forw_time
            self.accu_dis_loss_iter_time += self.back_time - self.loss_time
            self.accu_dis_back_iter_time += self.step_time - self.back_time
            self.accu_dis_step_iter_time += end_time - self.step_time
Ejemplo n.º 7
0
class Trainer(BaseTrainer):
    r"""Initialize vid2vid trainer.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """
    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G, opt_D,
                                      sch_G, sch_D, train_data_loader,
                                      val_data_loader)
        # Evaluation setup. The FID computed during training only gives a
        # quick idea of the performance; it is not the final evaluation.
        # sample_size = (number of videos to evaluate, frames per video).
        # Keeping the number of videos a multiple of 8 lets every GPU in a
        # node contribute equally, so none of them sits idle.
        num_videos = getattr(cfg.trainer, 'num_videos_to_test', 64)
        num_frames = getattr(cfg.trainer, 'num_frames_per_video', 10)
        self.sample_size = (num_videos, num_frames)

        # Current training sequence length (grown in _start_of_epoch).
        self.sequence_length = 1
        if not self.is_inference:
            self.train_dataset = self.train_data_loader.dataset
            configured_max = getattr(cfg.data.train, 'max_sequence_length',
                                     100)
            self.sequence_length_max = min(
                configured_max, self.train_dataset.sequence_length_max)
        self.Tensor = torch.cuda.FloatTensor
        self.has_fg = getattr(cfg.data, 'has_foreground', False)

        # Previous-frame generator output and data, used by test_single.
        self.net_G_output = None
        self.data_prev = None

        # Unwrap the generator; with model averaging there is one more
        # wrapper level around it.
        generator = self.net_G.module
        if self.cfg.trainer.model_average:
            generator = generator.module
        self.net_G_module = generator

    def _assign_criteria(self, name, criterion, weight):
        r"""Assign training loss terms.

        Args:
            name (str): Loss name
            criterion (obj): Loss object.
            weight (float): Loss weight. It should be non-negative.
        """
        self.criteria[name] = criterion
        self.weights[name] = weight

    def _init_loss(self, cfg):
        r"""Initialize training loss terms. In vid2vid, in addition to
        the GAN loss, feature matching loss, and perceptual loss used in
        pix2pixHD, we also add temporal GAN (and feature matching) loss,
        and flow warping loss. Optionally, we can also add an additional
        face discriminator for the face region.

        Populates ``self.criteria`` (loss name -> loss module) and
        ``self.weights`` (loss name -> weight). Some entries only get a
        weight here, with no criterion of the same name: the per-part flow
        weights, the additional-discriminator weights and the temporal GAN
        weights. Also sets ``self.add_dis_cfg``,
        ``self.num_temporal_scales`` and ``self.use_flow``, then calls
        ``_define_custom_losses``.

        Args:
            cfg (obj): Global configuration.
        """
        self.criteria = dict()
        self.weights = dict()
        trainer_cfg = cfg.trainer
        loss_weight = cfg.trainer.loss_weight

        # GAN loss and feature matching loss.
        self._assign_criteria('GAN', GANLoss(trainer_cfg.gan_mode),
                              loss_weight.gan)
        self._assign_criteria('FeatureMatching', FeatureMatchingLoss(),
                              loss_weight.feature_matching)

        # Perceptual loss (num_scales defaults to 1 when not configured).
        perceptual_loss = cfg.trainer.perceptual_loss
        self._assign_criteria(
            'Perceptual',
            PerceptualLoss(cfg=cfg,
                           network=perceptual_loss.mode,
                           layers=perceptual_loss.layers,
                           weights=perceptual_loss.weights,
                           num_scales=getattr(perceptual_loss, 'num_scales',
                                              1)), loss_weight.perceptual)

        # L1 Loss. Only registered when a positive L1 weight is configured.
        if getattr(loss_weight, 'L1', 0) > 0:
            self._assign_criteria('L1', torch.nn.L1Loss(), loss_weight.L1)

        # Whether to add an additional discriminator for specific regions.
        # Each one gets a 'GAN_<name>' weight from its own loss_weight and a
        # 'FeatureMatching_<name>' weight equal to the global one.
        self.add_dis_cfg = getattr(self.cfg.dis, 'additional_discriminators',
                                   None)
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                add_dis_cfg = self.add_dis_cfg[name]
                self.weights['GAN_' + name] = add_dis_cfg.loss_weight
                self.weights['FeatureMatching_' + name] = \
                    loss_weight.feature_matching

        # Temporal GAN loss. One 'GAN_T<s>' / 'FeatureMatching_T<s>' pair per
        # temporal scale; num_scales defaults to 0 (no temporal terms).
        self.num_temporal_scales = get_nested_attr(self.cfg.dis,
                                                   'temporal.num_scales', 0)
        for s in range(self.num_temporal_scales):
            self.weights['GAN_T%d' % s] = loss_weight.temporal_gan
            self.weights['FeatureMatching_T%d' % s] = \
                loss_weight.feature_matching

        # Flow loss. It consists of three parts: L1 loss compared to GT,
        # warping loss when used to warp images, and loss on the occlusion mask.
        # Enabled when cfg.gen has a 'flow' section; all four flow weights
        # share loss_weight.flow.
        self.use_flow = hasattr(cfg.gen, 'flow')
        if self.use_flow:
            self.criteria['Flow'] = FlowLoss(cfg)
            self.weights['Flow'] = self.weights['Flow_L1'] = \
                self.weights['Flow_Warp'] = \
                self.weights['Flow_Mask'] = loss_weight.flow

        # Other custom losses.
        self._define_custom_losses()

    def _define_custom_losses(self):
        r"""All other custom losses are defined here."""
        pass

    def _start_of_epoch(self, current_epoch):
        r"""Things to do before an epoch. When current_epoch is smaller than
        $(single_frame_epoch), we only train a single frame and the generator is
        just an image generator. After that, we start doing temporal training
        and train multiple frames. We will double the number of training frames
        every $(num_epochs_temporal_step) epochs.

        Args:
            current_epoch (int): Current number of epoch.
        """
        cfg = self.cfg
        # Only generates one frame at the beginning of training
        if current_epoch < cfg.single_frame_epoch:
            self.train_dataset.sequence_length = 1
        # Then add the temporal network to generator, and train multiple frames.
        elif current_epoch == cfg.single_frame_epoch:
            self.init_temporal_network()

        # Double the length of training sequence every few epochs.
        temp_epoch = current_epoch - cfg.single_frame_epoch
        if temp_epoch > 0:
            sequence_length = \
                cfg.data.train.initial_sequence_length * \
                (2 ** (temp_epoch // cfg.num_epochs_temporal_step))
            sequence_length = min(sequence_length, self.sequence_length_max)
            if sequence_length > self.sequence_length:
                self.sequence_length = sequence_length
                self.train_dataset.set_sequence_length(sequence_length)
                print('------- Updating sequence length to %d -------' %
                      sequence_length)

    def init_temporal_network(self):
        r"""Switch to multi-frame training.

        Resets ``self.tensorboard_init`` and sets the sequence length to
        ``cfg.data.train.initial_sequence_length``. Outside inference the
        training dataset is also told the new length.
        """
        self.tensorboard_init = False
        self.sequence_length = self.cfg.data.train.initial_sequence_length
        if self.is_inference:
            return
        self.train_dataset.set_sequence_length(self.sequence_length)
        print('------ Now start training %d frames -------' %
              self.sequence_length)

    def _start_of_iteration(self, data, current_iteration):
        r"""Pre-process the batch and move it to CUDA before an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current number of iteration (unused).
        Returns:
            (dict): The pre-processed data on the GPU.
        """
        return to_cuda(self.pre_process(data))

    def pre_process(self, data):
        r"""Apply dataset-specific pre-processing to ``data``.

        When ``cfg.data`` has a ``for_pose_dataset`` section and
        ``'pose_maps-densepose'`` is among the input labels,
        ``data['label']`` is replaced by ``pre_process_densepose(...)``.
        Otherwise ``data`` is returned untouched.

        Args:
            data (dict): Data used for the current iteration.
        Returns:
            (dict): The (possibly modified, same object) data.
        """
        data_cfg = self.cfg.data
        uses_densepose = (hasattr(data_cfg, 'for_pose_dataset') and
                          'pose_maps-densepose' in data_cfg.input_labels)
        if not uses_densepose:
            return data
        data['label'] = pre_process_densepose(data_cfg.for_pose_dataset,
                                              data['label'],
                                              self.is_inference)
        return data

    def post_process(self, data, net_G_output):
        r"""Hook for post-processing the data and generator output; returns
        both unchanged by default.

        Args:
            data (dict): Training data at the current iteration.
            net_G_output (dict): Output of the generator.
        Returns:
            (tuple): ``(data, net_G_output)``.
        """
        result = (data, net_G_output)
        return result

    def gen_update(self, data):
        r"""Update the vid2vid generator and discriminator frame by frame.

        For each frame of the current training sequence, the discriminator
        losses are computed first (``get_dis_losses``) and then the
        generator losses (``get_gen_losses``): dis_update (frame 1),
        gen_update (frame 1), dis_update (frame 2), gen_update (frame 2), ...
        Outputs marked with ``fake_images_source == 'pretrained'`` are
        skipped. The moving-average model is updated once after the loop
        when enabled.

        Args:
            data (dict): Training data at the current iteration.
        """
        # Reusing the generator output for both gen_update and dis_update
        # saves time but consumes a bit more memory.
        reuse_gen_output = getattr(self.cfg.trainer, 'reuse_gen_output', True)

        def _from_training_generator(output):
            # Outputs that do not declare a source are tagged as produced by
            # the generator under training.
            if 'fake_images_source' not in output:
                output['fake_images_source'] = 'in_training'
            return output['fake_images_source'] != 'pretrained'

        past_frames = [None, None]
        net_G_output = None
        data_prev = None
        for t in range(self.sequence_length):
            data_t = self.get_data_t(data, net_G_output, data_prev, t)
            data_prev = data_t

            # --- Discriminator losses ---
            if reuse_gen_output:
                net_G_output = self.net_G(data_t)
            else:
                with torch.no_grad():
                    net_G_output = self.net_G(data_t)
            data_t, net_G_output = self.post_process(data_t, net_G_output)
            if _from_training_generator(net_G_output):
                net_D_output, _ = self.net_D(data_t, detach(net_G_output),
                                             past_frames)
                self.get_dis_losses(net_D_output)

            # --- Generator losses ---
            if not reuse_gen_output:
                # Run the generator again, this time with gradients.
                net_G_output = self.net_G(data_t)
                data_t, net_G_output = self.post_process(data_t, net_G_output)
            if _from_training_generator(net_G_output):
                net_D_output, past_frames = \
                    self.net_D(data_t, net_G_output, past_frames)
                self.get_gen_losses(data_t, net_G_output, net_D_output)

        if self.cfg.trainer.model_average:
            self.net_G.module.update_average()

    def dis_update(self, data):
        r"""No-op: the discriminator is already updated inside
        ``gen_update``, interleaved frame by frame with the generator.

        Args:
            data (dict): Training data at the current iteration.
        """

    def reset(self):
        r"""Reset the trainer (for inference) at the beginning of a sequence.

        Clears the previous-frame state, restarts the frame counter
        ``self.t``, and resolves ``self.test_in_model_average_mode``
        (defaulting to ``cfg.trainer.model_average``). The selected
        generator's own ``reset`` is called when it has one.
        """
        self.net_G_output = None
        self.data_prev = None
        self.t = 0

        default_mode = self.cfg.trainer.model_average
        self.test_in_model_average_mode = getattr(
            self, 'test_in_model_average_mode', default_mode)
        wrapped = self.net_G.module
        generator = (wrapped.averaged_model
                     if self.test_in_model_average_mode else wrapped)
        if hasattr(generator, 'reset'):
            generator.reset()

    def create_sequence_output_dir(self, output_dir, key):
        r"""Create the output sub-directory for the sequence ``key`` belongs
        to.

        Everything before the last '/' in ``key`` is the sequence path. It is
        joined onto ``output_dir`` (and created), and with '/' replaced by
        '-' it becomes the sequence name.

        Args:
            output_dir (str): Root output dir.
            key (str): LMDB key which contains sequence name and file name.
        Returns:
            output_dir (str): Output subdir for this sequence.
            seq_name (str): Name of this sequence.
        """
        seq_dir = key.rpartition('/')[0]
        seq_output_dir = os.path.join(output_dir, seq_dir)
        os.makedirs(seq_output_dir, exist_ok=True)
        return seq_output_dir, seq_dir.replace('/', '-')

    def test(self, test_data_loader, root_output_dir, inference_args):
        r"""Run inference on all sequences.

        For every inference sequence of the dataset, the trainer state is
        reset and each frame is generated with ``test_single``. Frames are
        written into a per-sequence directory, and the generated frames are
        also saved as ``<seq_name>.mp4`` (15 fps) next to that directory.
        Sequences that yield no frames produce no video.

        Args:
            test_data_loader (object): Test data loader.
            root_output_dir (str): Location to dump outputs.
            inference_args (optional): Optional args.
        """

        # Go over all sequences.
        loader = test_data_loader
        num_inference_sequences = loader.dataset.num_inference_sequences()
        for sequence_idx in range(num_inference_sequences):
            loader.dataset.set_inference_sequence_idx(sequence_idx)
            print('Seq id: %d, Seq length: %d' %
                  (sequence_idx + 1, len(loader)))

            # Reset model at start of new inference sequence.
            self.reset()
            self.sequence_length = len(loader)

            # Go over all frames of this sequence. The output locations are
            # derived from the first frame's key.
            video = []
            output_dir = None
            video_path = None
            for idx, data in enumerate(tqdm(loader)):
                key = data['key']['images'][0][0]
                filename = key.split('/')[-1]

                # Create output dir for this sequence.
                if idx == 0:
                    output_dir, seq_name = \
                        self.create_sequence_output_dir(root_output_dir, key)
                    video_path = os.path.join(output_dir, '..', seq_name)

                # Get output and save images.
                data['img_name'] = filename
                data = self.start_of_iteration(data, current_iteration=-1)
                output = self.test_single(data, output_dir, inference_args)
                video.append(output)

            # An empty sequence has no output path of its own. Without this
            # guard it would raise NameError (first sequence) or overwrite
            # the previous sequence's video with an empty one.
            if not video:
                continue

            # Save output as mp4.
            imageio.mimsave(video_path + '.mp4', video, fps=15)

    def test_single(self, data, output_dir=None, inference_args=None):
        r"""The inference function. If output_dir exists, also save the
        output image.

        The previous generator output and data stored on the trainer are
        passed to ``get_data_t``, so consecutive calls form a sequence. The
        state is cleared by ``reset``. Optionally fine-tunes once when
        ``inference_args.finetune`` is set.

        Args:
            data (dict): Training data at the current iteration.
            output_dir (str): Save image directory. When ``None`` nothing is
                written and the generator output dict is returned.
            inference_args (obj): Inference args (``finetune`` and
                ``save_fake_only`` are read when present).
        Returns:
            The generator output dict when ``output_dir`` is ``None``,
            otherwise the image grid that was written to disk.
        """
        if getattr(inference_args, 'finetune', False):
            if not getattr(self, 'has_finetuned', False):
                self.finetune(data, inference_args)

        net_G = self.net_G
        if self.test_in_model_average_mode:
            net_G = net_G.module.averaged_model
        net_G.eval()

        data_t = self.get_data_t(data, self.net_G_output, self.data_prev, 0)
        if self.is_inference or self.sequence_length > 1:
            self.data_prev = data_t

        # Generator forward.
        with torch.no_grad():
            self.net_G_output = net_G(data_t)

        if output_dir is None:
            return self.net_G_output

        save_fake_only = getattr(inference_args, 'save_fake_only', False)
        if save_fake_only:
            image_grid = tensor2im(self.net_G_output['fake_images'])[0]
        else:
            vis_images = self.get_test_output_images(data)
            image_grid = np.hstack(
                [np.vstack(im) for im in vis_images if im is not None])
        if 'img_name' in data:
            # Strip only the final extension. Splitting at the first '.'
            # would map e.g. 'a.001.png' and 'a.002.png' to the same 'a.jpg'.
            save_name = os.path.splitext(data['img_name'])[0] + '.jpg'
        else:
            save_name = '%04d.jpg' % self.t
        output_filename = os.path.join(output_dir, save_name)
        os.makedirs(output_dir, exist_ok=True)
        imageio.imwrite(output_filename, image_grid)
        self.t += 1

        return image_grid

    def get_test_output_images(self, data):
        r"""Collect the images shown side by side at test time.

        Returns, in order, the label visualization and the ground-truth image
        (both taken from the last index of dim 1 of ``data``), followed by
        the generated image from ``self.net_G_output``.

        Args:
            data (dict): Training data at the current iteration.
        """
        label_vis = self.visualize_label(data['label'][:, -1])
        real_vis = tensor2im(data['images'][:, -1])
        fake_vis = tensor2im(self.net_G_output['fake_images'])
        return [label_vis, real_vis, fake_vis]

    def gen_frames(self, data, use_model_average=False):
        r"""Generate a sequence of frames given a sequence of data.

        At every step the previous generator output and the previous
        per-frame data are passed to ``get_data_t``. The generator runs
        under ``torch.no_grad`` and every result goes through
        ``post_process``.

        Args:
            data (dict): Training data at the current iteration.
            use_model_average (bool): Whether to use the moving-average
                generator instead of the regular one.
        Returns:
            (tuple): ``(first_net_G_output, net_G_output, all_info)``:
                the output of the first frame, the output of the last frame,
                and a dict with per-frame ``'inputs'`` and ``'outputs'``
                lists. Both outputs are ``None`` when ``self.sequence_length``
                is 0.
        """
        if use_model_average:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G

        # Defined up front so an empty sequence returns None instead of
        # raising UnboundLocalError.
        first_net_G_output = None
        net_G_output = None  # Previous generator output.
        data_prev = None  # Previous data.
        all_info = {'inputs': [], 'outputs': []}
        for t in range(self.sequence_length):
            # Get the data at the current time frame.
            data_t = self.get_data_t(data, net_G_output, data_prev, t)
            data_prev = data_t

            # Generator forward.
            with torch.no_grad():
                net_G_output = net_G(data_t)

            # Do any postprocessing if necessary.
            data_t, net_G_output = self.post_process(data_t, net_G_output)

            if t == 0:
                # Keep the output at the beginning of the sequence for
                # visualization.
                first_net_G_output = net_G_output

            all_info['inputs'].append(data_t)
            all_info['outputs'].append(net_G_output)

        return first_net_G_output, net_G_output, all_info

    def get_gen_losses(self, data_t, net_G_output, net_D_output):
        r"""Compute generator losses for one time step and update the generator.

        Fills ``self.gen_losses`` with the individual loss terms, stores their
        weighted sum under ``'total'``, backpropagates it through
        ``amp.scale_loss`` (``loss_id=0``) and takes one ``self.opt_G`` step.
        Every key written to ``self.gen_losses`` (other than ``'total'``) must
        have a matching entry in ``self.weights``.

        Args:
            data_t (dict): Training data at the current time t.
            net_G_output (dict): Output of the generator.
            net_D_output (dict): Output of the discriminator.
        """
        self.opt_G.zero_grad()

        # Individual frame GAN loss and feature matching loss.
        self.gen_losses['GAN'], self.gen_losses['FeatureMatching'] = \
            self.compute_GAN_losses(net_D_output['indv'], dis_update=False)

        # Perceptual loss.
        self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
            net_G_output['fake_images'], data_t['image'])

        # L1 loss, only when a positive L1 weight is configured.
        if getattr(self.cfg.trainer.loss_weight, 'L1', 0) > 0:
            self.gen_losses['L1'] = self.criteria['L1'](
                net_G_output['fake_images'], data_t['image'])

        # Raw (hallucinated) output image losses (GAN and perceptual).
        # These are added onto the full-image terms rather than stored under
        # separate keys; the perceptual term is restricted to the mask
        # returned by get_fg_mask.
        if 'raw' in net_D_output:
            raw_GAN_losses = self.compute_GAN_losses(net_D_output['raw'],
                                                     dis_update=False)
            fg_mask = get_fg_mask(data_t['label'], self.has_fg)
            raw_perceptual_loss = self.criteria['Perceptual'](
                net_G_output['fake_raw_images'] * fg_mask,
                data_t['image'] * fg_mask)
            self.gen_losses['GAN'] += raw_GAN_losses[0]
            self.gen_losses['FeatureMatching'] += raw_GAN_losses[1]
            self.gen_losses['Perceptual'] += raw_perceptual_loss

        # Additional discriminator losses: one GAN / feature-matching pair
        # per name in self.add_dis_cfg.
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                self.gen_losses['GAN_' + name], \
                    self.gen_losses['FeatureMatching_' + name] = \
                    self.compute_GAN_losses(net_D_output[name],
                                            dis_update=False)

        # Flow and mask loss; the Flow criterion returns three terms.
        if self.use_flow:
            self.gen_losses['Flow_L1'], self.gen_losses['Flow_Warp'], \
                self.gen_losses['Flow_Mask'] = self.criteria['Flow'](
                data_t, net_G_output, self.current_epoch)

        # Temporal GAN loss and feature matching loss, one pair per temporal
        # scale; only for sequences longer than one frame.
        if self.cfg.trainer.loss_weight.temporal_gan > 0:
            if self.sequence_length > 1:
                for s in range(self.num_temporal_scales):
                    loss_GAN, loss_FM = self.compute_GAN_losses(
                        net_D_output['temporal_%d' % s], dis_update=False)
                    self.gen_losses['GAN_T%d' % s] = loss_GAN
                    self.gen_losses['FeatureMatching_T%d' % s] = loss_FM

        # Other custom losses (subclass hook).
        self._get_custom_gen_losses(data_t, net_G_output, net_D_output)

        # Sum all losses together.
        # NOTE(review): self.gen_losses is not cleared in this method, so a
        # key set on an earlier call but not on this one is still summed
        # here -- confirm the set of keys is stable across calls.
        total_loss = self.Tensor(1).fill_(0)
        for key in self.gen_losses:
            if key != 'total':
                total_loss += self.gen_losses[key] * self.weights[key]

        self.gen_losses['total'] = total_loss
        with amp.scale_loss(total_loss, self.opt_G, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        self.opt_G.step()

    def _get_custom_gen_losses(self, data_t, net_G_output, net_D_output):
        r"""All other custom generator losses go here.

        Args:
            data_t (dict): Training data at the current time t.
            net_G_output (dict): Output of the generator.
            net_D_output (dict): Output of the discriminator.
        """
        pass

    def get_dis_losses(self, net_D_output):
        r"""Compute discriminator losses and update the discriminator.

        Fills ``self.dis_losses`` with the individual loss terms, stores their
        weighted sum under ``'total'``, backpropagates it through
        ``amp.scale_loss`` (``loss_id=1``) and takes one ``self.opt_D`` step.
        Every key written to ``self.dis_losses`` (other than ``'total'``) must
        have a matching entry in ``self.weights``.

        Args:
            net_D_output (dict): Output of the discriminator.
        """
        self.opt_D.zero_grad()

        # Individual frame GAN loss.
        self.dis_losses['GAN'] = self.compute_GAN_losses(net_D_output['indv'],
                                                         dis_update=True)

        # Raw (hallucinated) output image GAN loss, added onto 'GAN' rather
        # than stored under its own key.
        if 'raw' in net_D_output:
            raw_loss = self.compute_GAN_losses(net_D_output['raw'],
                                               dis_update=True)
            self.dis_losses['GAN'] += raw_loss

        # Additional GAN loss, one per name in self.add_dis_cfg.
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                self.dis_losses['GAN_' + name] = self.compute_GAN_losses(
                    net_D_output[name], dis_update=True)

        # Temporal GAN loss, one per temporal scale; only for sequences
        # longer than one frame.
        if self.cfg.trainer.loss_weight.temporal_gan > 0:
            if self.sequence_length > 1:
                for s in range(self.num_temporal_scales):
                    self.dis_losses['GAN_T%d' % s] = \
                        self.compute_GAN_losses(net_D_output['temporal_%d' % s],
                                                dis_update=True)

        # Other custom losses (subclass hook).
        self._get_custom_dis_losses(net_D_output)

        # Sum all losses together.
        # NOTE(review): self.dis_losses is not cleared in this method, so a
        # key set on an earlier call but not on this one is still summed
        # here -- confirm the set of keys is stable across calls.
        total_loss = self.Tensor(1).fill_(0)
        for key in self.dis_losses:
            if key != 'total':
                total_loss += self.dis_losses[key] * self.weights[key]
        self.dis_losses['total'] = total_loss

        with amp.scale_loss(total_loss, self.opt_D, loss_id=1) as scaled_loss:
            scaled_loss.backward()
        self.opt_D.step()

    def _get_custom_dis_losses(self, net_D_output):
        r"""All other custom losses go here.

        Args:
            net_D_output (dict): Output of the discriminator.
        """
        pass

    def compute_GAN_losses(self, net_D_output, dis_update):
        r"""Compute the GAN loss (and, for the generator, feature matching).

        Args:
            net_D_output (dict): Discriminator output with ``'pred_fake'`` and
                ``'pred_real'`` entries, each holding ``'output'`` and
                ``'features'``.
            dis_update (bool): True when computing losses for the
                discriminator update.

        Returns:
            For ``dis_update=True``, the summed fake/real GAN loss. Otherwise a
            ``(GAN_loss, FM_loss)`` tuple for the generator. When
            ``pred_fake`` is None, zero tensor(s) are returned instead: one
            tensor for the discriminator, a two-element list for the generator.
        """
        pred_fake = net_D_output['pred_fake']
        if pred_fake is None:
            if dis_update:
                return self.Tensor(1).fill_(0)
            return [self.Tensor(1).fill_(0), self.Tensor(1).fill_(0)]

        if not dis_update:
            # Generator: make the fake look real, and match real features.
            gan_term = self.criteria['GAN'](
                pred_fake['output'], True, dis_update=False)
            fm_term = self.criteria['FeatureMatching'](
                pred_fake['features'],
                net_D_output['pred_real']['features'])
            return gan_term, fm_term

        # Discriminator: fake should be classified fake, real as real.
        fake_term = self.criteria['GAN'](pred_fake['output'], False,
                                         dis_update=True)
        real_term = self.criteria['GAN'](net_D_output['pred_real']['output'],
                                         True, dis_update=True)
        return fake_term + real_term

    def get_data_t(self, data, net_G_output, data_prev, t):
        r"""Build the generator input for time step ``t``.

        Args:
            data (dict): Training data for current iteration; ``'label'`` and
                ``'images'`` are indexed by time along dim 1.
            net_G_output (dict): Output of the generator (for previous frame).
            data_prev (dict): Data for previous frame, or None at the start.
            t (int): Current time.

        Returns:
            (dict): ``label``, ``image``, ``prev_labels``, ``prev_images`` and
            ``real_prev_image`` (the real frame ``t - 1``, or None at t == 0).
        """
        current_label = data['label'][:, t]
        current_image = data['images'][:, t]

        prev_labels, prev_images = None, None
        if data_prev is not None:
            # Extend the history with the previous label and the previous
            # (detached) generated frame; concat_frames is given
            # num_frames_G - 1 as its frame limit.
            history_len = self.cfg.data.num_frames_G - 1
            prev_labels = concat_frames(data_prev['prev_labels'],
                                        data_prev['label'], history_len)
            prev_images = concat_frames(data_prev['prev_images'],
                                        net_G_output['fake_images'].detach(),
                                        history_len)

        return {
            'label': current_label,
            'image': current_image,
            'prev_labels': prev_labels,
            'prev_images': prev_images,
            'real_prev_image': data['images'][:, t - 1] if t > 0 else None,
        }

    def _end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Print the errors to console."""
        if not torch.distributed.is_initialized():
            if current_iteration % self.cfg.logging_iter == 0:
                message = '(epoch: %d, iters: %d) ' % (current_epoch,
                                                       current_iteration)
                for k, v in self.gen_losses.items():
                    if k != 'total':
                        message += '%s: %.3f,  ' % (k, v)
                message += '\n'
                for k, v in self.dis_losses.items():
                    if k != 'total':
                        message += '%s: %.3f,  ' % (k, v)
                print(message)

    def _init_tensorboard(self):
        r"""Create the tensorboard meters used by this trainer.

        Records the regular-generator FID, plus the averaged-generator FID
        when ``cfg.trainer.model_average`` is enabled. Also creates an image
        meter and scalar meters for learning rates and timing.
        """
        self.regular_fid_meter = Meter('FID/regular')
        if self.cfg.trainer.model_average:
            self.average_fid_meter = Meter('FID/average')
        self.image_meter = Meter('images')
        scalar_names = ('optim/gen_lr', 'optim/dis_lr',
                        'time/iteration', 'time/epoch')
        self.meters = {name: Meter(name) for name in scalar_names}

    def write_metrics(self):
        r"""Compute FID and write it to the tensorboard meters.

        With model averaging enabled, both the regular-generator FID and the
        averaged-generator FID are written; otherwise only the regular FID.
        Nothing is written when a computed FID value is None.
        """
        if self.cfg.trainer.model_average:
            regular_fid, average_fid = self._compute_fid()
            if regular_fid is None or average_fid is None:
                return
            meter_values = [(self.regular_fid_meter, regular_fid),
                            (self.average_fid_meter, average_fid)]
        else:
            regular_fid = self._compute_fid()
            if regular_fid is None:
                return
            meter_values = [(self.regular_fid_meter, regular_fid)]

        # Write every value first, then flush all meters.
        for meter, value in meter_values:
            meter.write(value)
        for meter, _ in meter_values:
            meter.flush(self.current_iteration)

    def _compute_fid(self):
        r"""Compute FID of the regular (and optionally the averaged) generator.

        The video evaluation code is given the trainer itself rather than a
        generator network. ``self.test_in_model_average_mode`` is set before
        each pass, presumably so the trainer's test path picks the matching
        generator. It is False for the regular pass and True for the
        averaged pass. Afterwards it is reset to False, even if that pass
        raises, so the trainer is not left in averaged-model test mode.

        Returns:
            The regular FID when ``cfg.trainer.model_average`` is False,
            otherwise a ``(regular_fid, average_fid)`` tuple.
        """
        self.net_G.eval()
        self.net_G_output = None
        few_shot = 'few_shot' in self.cfg.data.type

        # Regular generator.
        self.test_in_model_average_mode = False
        regular_fid_path = self._get_save_path('regular_fid', 'npy')
        regular_fid_value = compute_fid(regular_fid_path,
                                        self.val_data_loader,
                                        self,
                                        sample_size=self.sample_size,
                                        is_video=True,
                                        few_shot_video=few_shot)
        print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
            self.current_epoch, self.current_iteration, regular_fid_value))

        if not self.cfg.trainer.model_average:
            self.net_G.float()
            return regular_fid_value

        # Moving-average generator. Restore the default test mode afterwards.
        self.test_in_model_average_mode = True
        try:
            fid_path = self._get_save_path('average_fid', 'npy')
            fid_value = compute_fid(fid_path,
                                    self.val_data_loader,
                                    self,
                                    sample_size=self.sample_size,
                                    is_video=True,
                                    few_shot_video=few_shot)
        finally:
            self.test_in_model_average_mode = False
        print('Epoch {:05}, Iteration {:09}, Average FID {}'.format(
            self.current_epoch, self.current_iteration, fid_value))
        self.net_G.float()
        return regular_fid_value, fid_value

    def visualize_label(self, label):
        r"""Convert an input label tensor into a displayable image.

        - Pose datasets (``cfg.data.for_pose_dataset``) use ``tensor2pose``.
        - When ``'seg_maps'`` is in ``cfg.data.input_labels``, ``tensor2label``
          is used with the channel count of the last ``input_types`` entry
          containing ``'seg_maps'``.
        - When ``cfg.data.label_channels`` exceeds 3, the label is summed over
          dim 1 before ``tensor2im``.
        - Otherwise ``tensor2im`` is applied directly.

        Args:
            label (tensor): Input label tensor.
        """
        data_cfg = self.cfg.data
        if hasattr(data_cfg, 'for_pose_dataset'):
            return tensor2pose(self.cfg, label)

        if hasattr(data_cfg, 'input_labels') and \
                'seg_maps' in data_cfg.input_labels:
            for input_type in data_cfg.input_types:
                if 'seg_maps' in input_type:
                    num_labels = input_type['seg_maps'].num_channels
            return tensor2label(label, num_labels)

        if getattr(data_cfg, 'label_channels', 1) > 3:
            return tensor2im(label.sum(1, keepdim=True))
        return tensor2im(label)

    def save_image(self, path, data):
        r"""Save a visualization grid (and, for sequences, a video) to disk.

        Generates the whole sequence with ``self.net_G``, and also with the
        averaged generator when ``cfg.trainer.model_average`` is set. The
        master process then assembles an image grid and writes it to
        ``path``.

        Each list entry below becomes one grid column, in this order:
        - labels, real image, generated image, raw generated image;
        - the averaged model's outputs (when model averaging is on);
        - flow visualizations (when flow is used).

        For sequences longer than one frame, each column shows the first
        frame stacked above the last frame. The generated frames of the
        first sample are also written as an ``.mp4`` next to ``path``.

        Note: when generate_raw_output is FALSE, the raw images
        (``first_net_G_output['fake_raw_images']``) are None and are not
        displayed. This relies on ``tensor2im`` passing None through
        (presumably; it is filtered out below).
        In model average mode, the ground-truth flow visualization is
        plotted twice.

        Args:
            path (str): Save path.
            data (dict): Training data for current iteration.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average:
            self.net_G.module.averaged_model.eval()
        self.net_G_output = None
        with torch.no_grad():
            first_net_G_output, net_G_output, all_info = self.gen_frames(data)
            if self.cfg.trainer.model_average:
                first_net_G_output_avg, net_G_output_avg, _ = self.gen_frames(
                    data, use_model_average=True)

        # Visualize labels.
        # vis_labels_start holds the LAST frame (index -1) and vis_labels_end
        # the FIRST frame (index 0). The names look swapped, but usage is
        # consistent: *_start goes with the last-frame row (vis_images) and
        # *_end with the first-frame row (vis_images_first).
        label_lengths = self.train_data_loader.dataset.get_label_lengths()
        labels = split_labels(data['label'], label_lengths)
        vis_labels_start, vis_labels_end = [], []
        for key, value in labels.items():
            if key == 'seg_maps':
                vis_labels_start.append(self.visualize_label(value[:, -1]))
                vis_labels_end.append(self.visualize_label(value[:, 0]))
            else:
                vis_labels_start.append(tensor2im(value[:, -1]))
                vis_labels_end.append(tensor2im(value[:, 0]))

        if is_master():
            # Columns for the last frame of the sequence.
            vis_images = [
                *vis_labels_start,
                tensor2im(data['images'][:, -1]),
                tensor2im(net_G_output['fake_images']),
                tensor2im(net_G_output['fake_raw_images'])
            ]
            if self.cfg.trainer.model_average:
                vis_images += [
                    tensor2im(net_G_output_avg['fake_images']),
                    tensor2im(net_G_output_avg['fake_raw_images'])
                ]

            if self.sequence_length > 1:
                # Matching columns for the first frame; kept index-aligned
                # with vis_images so they can be zipped below.
                vis_images_first = [
                    *vis_labels_end,
                    tensor2im(data['images'][:, 0]),
                    tensor2im(first_net_G_output['fake_images']),
                    tensor2im(first_net_G_output['fake_raw_images'])
                ]
                if self.cfg.trainer.model_average:
                    vis_images_first += [
                        tensor2im(first_net_G_output_avg['fake_images']),
                        tensor2im(first_net_G_output_avg['fake_raw_images'])
                    ]

                if self.use_flow:
                    # Ground-truth flow between the last two real frames. It
                    # goes in the first-frame row, above the generated flow
                    # of the last frame.
                    flow_gt, conf_gt = self.criteria['Flow'].flowNet(
                        data['images'][:, -1], data['images'][:, -2])
                    warped_image_gt = resample(data['images'][:, -1], flow_gt)
                    vis_images_first += [
                        tensor2flow(flow_gt),
                        tensor2im(conf_gt, normalize=False),
                        tensor2im(warped_image_gt),
                    ]
                    vis_images += [
                        tensor2flow(net_G_output['fake_flow_maps']),
                        tensor2im(net_G_output['fake_occlusion_masks'],
                                  normalize=False),
                        tensor2im(net_G_output['warped_images']),
                    ]
                    if self.cfg.trainer.model_average:
                        # GT flow repeated so the averaged model's flow
                        # columns also have a first-row counterpart.
                        vis_images_first += [
                            tensor2flow(flow_gt),
                            tensor2im(conf_gt, normalize=False),
                            tensor2im(warped_image_gt),
                        ]
                        vis_images += [
                            tensor2flow(net_G_output_avg['fake_flow_maps']),
                            tensor2im(net_G_output_avg['fake_occlusion_masks'],
                                      normalize=False),
                            tensor2im(net_G_output_avg['warped_images'])
                        ]

                # Stack first frame above last frame for every column,
                # element-wise over the entries of each column (assumes
                # tensor2im returns one image per batch sample -- confirm).
                vis_images = [[
                    np.vstack((im_first, im))
                    for im_first, im in zip(imgs_first, imgs)
                ] for imgs_first, imgs in zip(vis_images_first, vis_images)
                              if imgs is not None]

            # Each column is stacked vertically; columns go side by side.
            image_grid = np.hstack(
                [np.vstack(im) for im in vis_images if im is not None])

            print('Save output images to {}'.format(path))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            imageio.imwrite(path, image_grid)

            # Gather all outputs for dumping into video.
            # Only the first sample ([0]) of each step is used.
            if self.sequence_length > 1:
                output_images = []
                for item in all_info['outputs']:
                    output_images.append(tensor2im(item['fake_images'])[0])

                imageio.mimwrite(os.path.splitext(path)[0] + '.mp4',
                                 output_images,
                                 fps=2,
                                 macro_block_size=None)

        self.net_G.float()