Example #1
    def finetune(self, data, inference_args):
        r"""Finetune the model for a few iterations on the inference data."""
        # Get the list of params to finetune.
        self.net_G, self.net_D, self.opt_G, self.opt_D = \
            get_optimizer_with_params(self.cfg, self.net_G, self.net_D,
                                      param_names_start_with=[
                                          'weight_generator.fc', 'conv_img',
                                          'up'])
        data_finetune = {k: v for k, v in data.items()}
        ref_labels = data_finetune['few_shot_label']
        ref_images = data_finetune['few_shot_images']

        # Number of iterations to finetune.
        iterations = getattr(inference_args, 'finetune_iter', 100)
        for it in range(1, iterations + 1):
            # Randomly set one of the reference images as target.
            idx = np.random.randint(ref_labels.size(1))
            tgt_label, tgt_image = ref_labels[:, idx], ref_images[:, idx]
            # Randomly shift and flip the target image.
            tgt_label, tgt_image = random_roll([tgt_label, tgt_image])
            data_finetune['label'] = tgt_label.unsqueeze(1)
            data_finetune['images'] = tgt_image.unsqueeze(1)

            self.gen_update(data_finetune)
            self.dis_update(data_finetune)
            if iterations >= 10 and it % (iterations // 10) == 0:
                print('Finetune iteration {}/{}.'.format(it, iterations))

        self.has_finetuned = True
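`random_roll` is defined elsewhere in the library. A minimal sketch of the assumed behavior, applying one shared random spatial shift and horizontal flip to every tensor in the list (the name and semantics are inferred from the call site above; the real implementation may differ):

import numpy as np
import torch


def random_roll(tensors, max_shift_ratio=0.125):
    r"""Apply the same random roll and horizontal flip to all tensors (sketch).

    Tensors are assumed to be shaped (..., H, W); ``max_shift_ratio`` is a
    hypothetical parameter controlling the largest shift.
    """
    h, w = tensors[0].shape[-2:]
    shift_y = np.random.randint(-int(h * max_shift_ratio),
                                int(h * max_shift_ratio) + 1)
    shift_x = np.random.randint(-int(w * max_shift_ratio),
                                int(w * max_shift_ratio) + 1)
    flip = np.random.rand() > 0.5
    outputs = []
    for t in tensors:
        # Same shift and flip for every tensor so labels stay aligned.
        t = torch.roll(t, shifts=(shift_y, shift_x), dims=(-2, -1))
        if flip:
            t = torch.flip(t, dims=[-1])
        outputs.append(t)
    return outputs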
Example #2
    def recalculate_model_average_batch_norm_statistics(self, data_loader):
        r"""Update the statistics in the moving average model.

        Args:
            data_loader (pytorch data loader): Data loader for estimating the
                statistics.
        """
        if not self.cfg.trainer.model_average:
            return
        model_average_iteration = \
            self.cfg.trainer.model_average_batch_norm_estimation_iteration
        if model_average_iteration == 0:
            return
        with torch.no_grad():
            # Accumulate BN statistics.
            self.net_G.module.averaged_model.train()
            # Reset running stats.
            self.net_G.module.averaged_model.apply(reset_batch_norm)
            for cal_it, cal_data in enumerate(data_loader):
                if cal_it >= model_average_iteration:
                    print('Done with {} iterations of updating batch norm '
                          'statistics'.format(model_average_iteration))
                    break
                cal_data = self._start_of_iteration(cal_data, 0)
                # Averaging over all batches
                self.net_G.module.averaged_model.apply(
                    calibrate_batch_norm_momentum)
                self.net_G.module.averaged_model(cal_data)
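`reset_batch_norm` and `calibrate_batch_norm_momentum` are helpers applied through `nn.Module.apply`. A minimal sketch of the assumed behavior: the first resets the running statistics, the second sets each batch norm layer's momentum to 1 / (num_batches_tracked + 1), so after n calibration batches the running statistics equal a plain average over those batches:

import torch.nn as nn


def reset_batch_norm(m):
    r"""Reset the running statistics of a batch norm layer (sketch)."""
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.reset_running_stats()


def calibrate_batch_norm_momentum(m):
    r"""Recompute momentum so running stats average all batches (sketch)."""
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        # With momentum = 1 / (n + 1) at the n-th update, the running mean
        # equals the arithmetic mean of all batches seen so far.
        m.momentum = 1.0 / float(m.num_batches_tracked + 1)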
Example #3
    def _start_of_epoch(self, current_epoch):
        r"""Things to do before an epoch. When current_epoch is smaller than
        $(single_frame_epoch), we only train a single frame and the generator is
        just an image generator. After that, we start doing temporal training
        and train multiple frames. We will double the number of training frames
        every $(num_epochs_temporal_step) epochs.

        Args:
            current_epoch (int): Current epoch number.
        """
        cfg = self.cfg
        # Only generates one frame at the beginning of training
        if current_epoch < cfg.single_frame_epoch:
            self.train_dataset.sequence_length = 1
        # Then add the temporal network to the generator and train multiple
        # frames.
        elif current_epoch == cfg.single_frame_epoch:
            self.init_temporal_network()

        # Double the length of training sequence every few epochs.
        temp_epoch = current_epoch - cfg.single_frame_epoch
        if temp_epoch > 0:
            sequence_length = \
                cfg.data.train.initial_sequence_length * \
                (2 ** (temp_epoch // cfg.num_epochs_temporal_step))
            sequence_length = min(sequence_length, self.sequence_length_max)
            if sequence_length > self.sequence_length:
                self.sequence_length = sequence_length
                self.train_dataset.set_sequence_length(sequence_length)
                print('------- Updating sequence length to %d -------' %
                      sequence_length)
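As a concrete illustration with assumed config values (`single_frame_epoch=10`, `initial_sequence_length=4`, `num_epochs_temporal_step=5`, `sequence_length_max=16`), the schedule produced by the code above is:

# Hypothetical values, for illustration only.
single_frame_epoch = 10
initial_sequence_length = 4
num_epochs_temporal_step = 5
sequence_length_max = 16

for current_epoch in (0, 9, 10, 14, 15, 20, 25):
    if current_epoch < single_frame_epoch:
        length = 1
    else:
        temp_epoch = current_epoch - single_frame_epoch
        length = min(initial_sequence_length *
                     2 ** (temp_epoch // num_epochs_temporal_step),
                     sequence_length_max)
    print(current_epoch, length)
# Prints: 0 1, 9 1, 10 4, 14 4, 15 8, 20 16, 25 16 (capped at the max).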
Example #4
    def _write_tensorboard(self):
        r"""Write values to tensorboard. By default, we will log the time used
        per iteration, time used per epoch, generator learning rate, and
        discriminator learning rate. We will log all the losses as well as
        custom meters.
        """
        # Logs that are shared by all models.
        self._write_to_meters(
            {
                'time/iteration': self.time_iteration,
                'time/epoch': self.time_epoch,
                'optim/gen_lr': self.sch_G.get_last_lr()[0],
                'optim/dis_lr': self.sch_D.get_last_lr()[0]
            }, self.meters)
        # Logs for loss values. Different models have different losses.
        self._write_loss_meters()
        # Other custom logs.
        self._write_custom_meters()

        # Write all logs to tensorboard.
        self._flush_meters(self.meters)
Example #5
    def end_of_epoch(self, data, current_epoch, current_iteration):
        r"""Things to do after an epoch.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current epoch number.
            current_iteration (int): Current iteration number.
        """
        # Update the learning rate policy for the generator if operating in the
        # epoch mode.
        self.current_iteration = current_iteration
        self.current_epoch = current_epoch
        if not self.cfg.gen_opt.lr_policy.iteration_mode:
            self.sch_G.step()
        # Update the learning rate policy for the discriminator if operating
        # in the epoch mode.
        if not self.cfg.dis_opt.lr_policy.iteration_mode:
            self.sch_D.step()
        elapsed_epoch_time = time.time() - self.start_epoch_time
        # Logging.
        print('Epoch: {}, total time: {:6f}.'.format(current_epoch,
                                                     elapsed_epoch_time))
        self.time_epoch = elapsed_epoch_time
        self._end_of_epoch(data, current_epoch, current_iteration)
        # Save everything to the checkpoint.
        if current_epoch >= self.cfg.snapshot_save_start_epoch and \
                current_epoch % self.cfg.snapshot_save_epoch == 0:
            self.save_image(self._get_save_path('images', 'jpg'), data)
            self.save_checkpoint(current_epoch, current_iteration)
            self.write_metrics()
Example #6
    def _compute_fid(self):
        r"""We will compute FID for the regular model using the eval mode.
        For the moving average model, we will use the eval mode.
        """
        self.net_G.eval()
        net_G_for_evaluation = \
            functools.partial(self.net_G, random_style=True)
        regular_fid_path = self._get_save_path('regular_fid', 'npy')
        preprocess = \
            functools.partial(self._start_of_iteration, current_iteration=0)

        regular_fid_value = compute_fid(regular_fid_path,
                                        self.val_data_loader,
                                        net_G_for_evaluation,
                                        preprocess=preprocess)
        print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
            self.current_epoch, self.current_iteration, regular_fid_value))
        if self.cfg.trainer.model_average:
            avg_net_G_for_evaluation = \
                functools.partial(self.net_G.module.averaged_model,
                                  random_style=True)
            fid_path = self._get_save_path('average_fid', 'npy')
            fid_value = compute_fid(fid_path,
                                    self.val_data_loader,
                                    avg_net_G_for_evaluation,
                                    preprocess=preprocess)
            print('Epoch {:05}, Iteration {:09}, FID {}'.format(
                self.current_epoch, self.current_iteration, fid_value))
            self.net_G.float()
            return regular_fid_value, fid_value
        else:
            self.net_G.float()
            return regular_fid_value
Example #7
    def test(self, data_loader, output_dir, inference_args):
        r"""Compute results images for a batch of input data and save the
        results in the specified folder.

        Args:
            data_loader (torch.utils.data.DataLoader): PyTorch dataloader.
            output_dir (str): Target location for saving the output image.
        """
        if self.cfg.trainer.model_average:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G.module
        net_G.eval()

        print('# of batches %d' % len(data_loader))
        for it, data in enumerate(tqdm(data_loader)):
            data = self.start_of_iteration(data, current_iteration=-1)
            with torch.no_grad():
                output_images, file_names = \
                    net_G.inference(data, **vars(inference_args))
            for output_image, file_name in zip(output_images, file_names):
                fullname = os.path.join(output_dir, file_name + '.jpg')
                output_image = tensor2pilimage(output_image.clamp_(-1, 1),
                                               minus1to1_normalized=True)
                save_pilimage_in_jpeg(fullname, output_image)
Example #8
 def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
     super(GANLoss, self).__init__()
     self.real_label = target_real_label
     self.fake_label = target_fake_label
     self.real_label_tensor = None
     self.fake_label_tensor = None
     self.gan_mode = gan_mode
     print('GAN mode: %s' % gan_mode)
Example #9
    def load_checkpoint(self, cfg, checkpoint_path, resume=None):
        r"""Load network weights, optimizer parameters, scheduler parameters
        from a checkpoint.

        Args:
            cfg (obj): Global configuration.
            checkpoint_path (str): Path to the checkpoint.
            resume (bool or None): If not ``None``, will determine whether or
                not to load optimizers in addition to network weights.
        """
        if os.path.exists(checkpoint_path):
            # If checkpoint_path exists, we will load its weights to
            # initialize our network.
            if resume is None:
                resume = False
        elif os.path.exists(os.path.join(cfg.logdir, 'latest_checkpoint.txt')):
            # This is for resuming the training from the previously saved
            # checkpoint.
            fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt')
            with open(fn, 'r') as f:
                line = f.read().splitlines()
            checkpoint_path = os.path.join(cfg.logdir, line[0].split(' ')[-1])
            if resume is None:
                resume = True
        else:
            # checkpoint not found and not specified. We will train
            # everything from scratch.
            current_epoch = 0
            current_iteration = 0
            print('No checkpoint found.')
            return current_epoch, current_iteration
        # Load checkpoint
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
        current_epoch = 0
        current_iteration = 0
        if resume:
            self.net_G.load_state_dict(checkpoint['net_G'])
            if not self.is_inference:
                self.net_D.load_state_dict(checkpoint['net_D'])
                if 'opt_G' in checkpoint:
                    self.opt_G.load_state_dict(checkpoint['opt_G'])
                    self.opt_D.load_state_dict(checkpoint['opt_D'])
                    self.sch_G.load_state_dict(checkpoint['sch_G'])
                    self.sch_D.load_state_dict(checkpoint['sch_D'])
                    current_epoch = checkpoint['current_epoch']
                    current_iteration = checkpoint['current_iteration']
                    print('Load from: {}'.format(checkpoint_path))
                else:
                    print('Load network weights only.')
        else:
            self.net_G.load_state_dict(checkpoint['net_G'])
            print('Load generator weights only.')

        print('Done with loading the checkpoint.')
        return current_epoch, current_iteration
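The format of `latest_checkpoint.txt` is implied by `line[0].split(' ')[-1]`: a single line whose last whitespace-separated token is the checkpoint filename. A hedged sketch of the writer side (the helper name is hypothetical):

import os


def write_latest_checkpoint_file(logdir, checkpoint_filename):
    r"""Record the most recent checkpoint filename (sketch)."""
    fn = os.path.join(logdir, 'latest_checkpoint.txt')
    with open(fn, 'w') as f:
        # e.g. 'latest_checkpoint: epoch_00010_iteration_000020000.pt'
        f.write('latest_checkpoint: {}'.format(checkpoint_filename))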
Example #10
def compute_kid(kid_path,
                data_loader,
                net_G,
                key_real='images',
                key_fake='fake_images',
                sample_size=None,
                preprocess=None,
                is_video=False,
                save_act=True,
                num_subsets=1,
                subset_size=None):
    r"""Compute the kid score.

    Args:
        kid_path (str): Location for storing feature activations.
        data_loader (obj): PyTorch dataloader object.
        net_G (obj): For image generation models, net_G is the PyTorch
            generator network. For video generation models, net_G is the
            trainer because video generation requires more complicated
            processing.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        sample_size (int): How many samples to be used for computing feature
            activations.
        preprocess (func): The preprocess function to be applied to the data.
        is_video (bool): Whether we are handling video sequences.
        save_act (bool): If ``True``, saves real activations to disk and
            reloads them in the future. It might save some computation but
            costs storage.
        num_subsets (int): Number of subsets to sample from all the samples.
        subset_size (int): Number of samples in each subset.
    Returns:
        kid (float): KID value.
    """
    print('Computing KID.')
    with torch.no_grad():
        # Get the fake activations.
        fake_act = load_or_compute_activations(None, data_loader, key_real,
                                               key_fake, net_G, sample_size,
                                               preprocess, is_video)

        # Get the ground truth activations.
        act_path = os.path.join(os.path.dirname(kid_path),
                                'activations.npy') if save_act else None
        real_act = load_or_compute_activations(act_path, data_loader, key_real,
                                               key_fake, None, sample_size,
                                               preprocess, is_video)

    if is_master():
        mmd, mmd_vars = polynomial_mmd_averages(fake_act,
                                                real_act,
                                                num_subsets,
                                                subset_size,
                                                ret_var=True)
        kid = mmd.mean()
        return kid
Example #11
def compute_kid_data(kid_path,
                     data_loader_a,
                     data_loader_b,
                     key_a='images',
                     key_b='images',
                     sample_size=None,
                     is_video=False,
                     num_subsets=1,
                     subset_size=None):
    r"""Compute the kid score between two datasets.

    Args:
        kid_path (str): Location for storing feature activations.
        data_loader_a (obj): PyTorch dataloader object for dataset a.
        data_loader_b (obj): PyTorch dataloader object for dataset b.
        key_a (str): Dictionary key value for images in the dataset a.
        key_b (str): Dictionary key value for images in the dataset b.
        sample_size (int): How many samples to be used for computing the KID.
        is_video (bool): Whether we are handling video sequences.
        num_subsets (int): Number of subsets to sample from the whole data.
        subset_size (int): Number of samples in each subset.
    Returns:
        kid (float): KID value.
    """
    if sample_size is None:
        sample_size = min(len(data_loader_a.dataset),
                          len(data_loader_b.dataset))
    print('Computing KID using {} images from both distributions.'.format(
        sample_size))
    with torch.no_grad():
        path_a = os.path.join(os.path.dirname(kid_path), 'activations_a.npz')
        path_b = os.path.join(os.path.dirname(kid_path), 'activations_b.npz')
        act_a = load_or_compute_activations(path_a,
                                            data_loader_a,
                                            key_a,
                                            key_a,
                                            sample_size=sample_size,
                                            is_video=is_video)
        act_b = load_or_compute_activations(path_b,
                                            data_loader_b,
                                            key_b,
                                            key_b,
                                            sample_size=sample_size,
                                            is_video=is_video)

        if is_master():
            mmd, mmd_vars = polynomial_mmd_averages(act_a,
                                                    act_b,
                                                    num_subsets,
                                                    subset_size,
                                                    ret_var=True)
            kid = mmd.mean()
            return kid
        else:
            return None
Example #12
 def init_temporal_network(self):
     r"""Initialize temporal training when beginning to train multiple
     frames. Set the sequence length to $(initial_sequence_length).
     """
     self.tensorboard_init = False
     # Update training sequence length.
     self.sequence_length = self.cfg.data.train.initial_sequence_length
     if not self.is_inference:
         self.train_dataset.set_sequence_length(self.sequence_length)
         print('------ Now start training %d frames -------' %
               self.sequence_length)
Example #13
def make_logging_dir(logdir):
    r"""Create the logging directory

    Args:
        logdir (str): Log directory name
    """
    print('Make folder {}'.format(logdir))
    os.makedirs(logdir, exist_ok=True)
    tensorboard_dir = os.path.join(logdir, 'tensorboard')
    os.makedirs(tensorboard_dir, exist_ok=True)
    set_summary_writer(tensorboard_dir)
Example #14
def init_cudnn(deterministic, benchmark):
    r"""Initialize the cudnn module. The two things to consider is whether to
    use cudnn benchmark and whether to use cudnn deterministic. If cudnn
    benchmark is set, then the cudnn deterministic is automatically false.

    Args:
        deterministic (bool): Whether to use cudnn deterministic.
        benchmark (bool): Whether to use cudnn benchmark.
    """
    cudnn.deterministic = deterministic
    cudnn.benchmark = benchmark
    print('cudnn benchmark: {}'.format(benchmark))
    print('cudnn deterministic: {}'.format(deterministic))
Example #15
    def flush(self, step):
        r"""Write the value in the tensorboard.

        Args:
            step (int): Epoch or iteration number.
        """
        if not all(math.isfinite(x) for x in self.values):
            print("meter {} contained a nan or inf.".format(self.name))
        filtered_values = [x for x in self.values if math.isfinite(x)]
        if len(filtered_values) != 0:
            value = sum(filtered_values) / len(filtered_values)
            write_summary(self.name, value, step)
        self.reset()
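`flush` belongs to a meter that accumulates scalar values between tensorboard writes. A minimal sketch of the assumed surrounding class, showing only the pieces `flush` relies on:

class Meter(object):
    r"""Accumulates scalar values and flushes their mean (sketch)."""

    def __init__(self, name):
        self.name = name
        self.values = []

    def write(self, value):
        # Record one scalar; flush() later averages the finite ones.
        self.values.append(value)

    def reset(self):
        self.values = []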
Example #16
 def _end_of_iteration(self, data, current_epoch, current_iteration):
     r"""Print the errors to console."""
     if not torch.distributed.is_initialized():
         if current_iteration % self.cfg.logging_iter == 0:
             message = '(epoch: %d, iters: %d) ' % (current_epoch,
                                                    current_iteration)
             for k, v in self.gen_losses.items():
                 if k != 'total':
                     message += '%s: %.3f,  ' % (k, v)
             message += '\n'
             for k, v in self.dis_losses.items():
                 if k != 'total':
                     message += '%s: %.3f,  ' % (k, v)
             print(message)
Example #17
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    r"""Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
    d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Args:
        mu1: The sample mean over activations of a layer of the inception
            net for generated samples.
        mu2: The sample mean over activations, pre-calculated on a
            representative data set.
        sigma1: The covariance matrix over activations for generated samples.
        sigma2: The covariance matrix over activations, pre-calculated on a
            representative data set.
        eps: A value added to the diagonal of cov for numerical stability.
    Returns:
        The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)
    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'
    diff = mu1 - mu2
    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            print('Imaginary component {}'.format(m))
            # raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real
    tr_covmean = np.trace(covmean)
    return (diff.dot(diff) + np.trace(sigma1) + np.trace(
        sigma2) - 2 * tr_covmean)
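A quick sanity check (not from the source): the distance from a distribution to itself is numerically zero, and shifting the mean by a vector v increases the distance by roughly ||v||^2:

import numpy as np

rng = np.random.RandomState(0)
act = rng.randn(1000, 64)
mu, sigma = act.mean(axis=0), np.cov(act, rowvar=False)

# Distance to itself is (numerically) zero.
print(calculate_frechet_distance(mu, sigma, mu, sigma))        # ~0.0
# Shifting the mean by 1 in each of the 64 dims adds ~64.
print(calculate_frechet_distance(mu, sigma, mu + 1.0, sigma))  # ~64.0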
Example #18
def get_paired_input_image_channel_number(data_cfg):
    r"""Get number of channels for the input image.

    Args:
        data_cfg (obj): Data configuration structure.
    Returns:
        num_channels (int): Number of input image channels.
    """
    num_channels = 0
    for data_type in data_cfg.input_types:
        for k in data_type:
            if k in data_cfg.input_image:
                num_channels += data_type[k].num_channels
                print('Concatenate %s for input.' % k)
    print('\tNum. of channels in the input image: %d' % num_channels)
    return num_channels
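For illustration, the structure `data_cfg` is assumed to have can be mocked with simple namespaces (attribute and key names here are hypothetical, inferred from how the function reads the config):

from types import SimpleNamespace

# Hypothetical config mirroring the accesses above.
data_cfg = SimpleNamespace(
    input_types=[{'seg_maps': SimpleNamespace(num_channels=35)},
                 {'edge_maps': SimpleNamespace(num_channels=1)},
                 {'images': SimpleNamespace(num_channels=3)}],
    input_image=['seg_maps', 'edge_maps'])

print(get_paired_input_image_channel_number(data_cfg))  # 35 + 1 = 36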
Example #19
    def set_sequence_length(self, sequence_length):
        r"""Set the length of sequence you want as output from dataloader.

        Args:
            sequence_length (int): Length of output sequences.
        """
        assert isinstance(sequence_length, int)
        if sequence_length > self.sequence_length_max:
            print('Requested sequence length (%d) > ' % (sequence_length) +
                  'max sequence length (%d). ' % (self.sequence_length_max) +
                  'Limiting sequence length to max sequence length.')
            sequence_length = self.sequence_length_max
        self.sequence_length = sequence_length
        # Recalculate mapping as some sequences might no longer be useful.
        self.mapping, self.epoch_length = self._create_mapping()
        print('Epoch length:', self.epoch_length)
Example #20
    def load_pretrained_network(self, pretrained_dict):
        r"""Load a pretrained network."""
        # print(pretrained_dict.keys())
        model_dict = self.state_dict()
        print('Pretrained network has fewer layers; the following are '
              'not initialized:')

        not_initialized = set()
        for k, v in model_dict.items():
            kp = 'module.' + k.replace('global_model.', 'global_model.model.')
            if kp in pretrained_dict and v.size() == pretrained_dict[kp].size():
                model_dict[k] = pretrained_dict[kp]
            else:
                not_initialized.add('.'.join(k.split('.')[:2]))
        print(sorted(not_initialized))
        self.load_state_dict(model_dict)
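The key remapping assumes the pretrained checkpoint was saved from a DataParallel-wrapped model ('module.' prefix) whose global branch carries an extra '.model' level. For a hypothetical key:

k = 'global_model.layer0.weight'
kp = 'module.' + k.replace('global_model.', 'global_model.model.')
print(kp)  # module.global_model.model.layer0.weight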
Example #21
def polynomial_mmd_averages(codes_g,
                            codes_r,
                            n_subsets,
                            subset_size,
                            ret_var=True,
                            **kernel_args):
    r"""Computes MMD between two sets of features using polynomial kernels. It
    performs a number of repetitions of subset sampling without replacement.

    Args:
        codes_g (Tensor): Feature activations of generated images.
        codes_r (Tensor): Feature activations of real images.
        n_subsets (int): The number of subsets.
        subset_size (int): The number of samples in each subset.
        ret_var (bool): If ``True``, returns both mean and variance of MMDs,
            otherwise only returns the mean.
    Returns:
        (tuple):
          - mmds (Tensor): Mean of MMDs.
          - mmd_vars (Tensor): Variance of MMDs.
    """
    codes_g = torch.tensor(codes_g, device=torch.device('cuda'))
    codes_r = torch.tensor(codes_r, device=torch.device('cuda'))
    mmds = np.zeros(n_subsets)
    if ret_var:
        mmd_vars = np.zeros(n_subsets)
    choice = np.random.choice

    if subset_size is None:
        subset_size = min(len(codes_g), len(codes_r))
        print("Subset size not provided, "
              "setting it to the data size ({}).".format(subset_size))
    if subset_size > len(codes_g) or subset_size > len(codes_r):
        subset_size = min(len(codes_g), len(codes_r))
        warnings.warn("Subset size is larger than the actual data size, "
                      "setting it to the data size ({}).".format(subset_size))

    for i in range(n_subsets):
        g = codes_g[choice(len(codes_g), subset_size, replace=False)]
        r = codes_r[choice(len(codes_r), subset_size, replace=False)]
        o = polynomial_mmd(g, r, **kernel_args, ret_var=ret_var)
        if ret_var:
            mmds[i], mmd_vars[i] = o
        else:
            mmds[i] = o
    return (mmds, mmd_vars) if ret_var else mmds
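`polynomial_mmd` is defined elsewhere. A minimal sketch of the standard unbiased MMD^2 estimator with the degree-3 polynomial kernel K(x, y) = (x . y / d + 1)^3 used for KID; the variance estimate that the library's version also returns is omitted here for brevity:

import torch


def polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1,
                   ret_var=False):
    r"""Unbiased MMD^2 with a polynomial kernel (sketch, no variance)."""
    assert not ret_var, 'Variance estimation omitted in this sketch.'
    m, d = codes_g.shape
    n = codes_r.shape[0]
    if gamma is None:
        gamma = 1.0 / d
    k_gg = (gamma * codes_g @ codes_g.t() + coef0) ** degree
    k_rr = (gamma * codes_r @ codes_r.t() + coef0) ** degree
    k_gr = (gamma * codes_g @ codes_r.t() + coef0) ** degree
    # Unbiased estimator: exclude the diagonal of the within-set kernels.
    mmd2 = ((k_gg.sum() - k_gg.diagonal().sum()) / (m * (m - 1)) +
            (k_rr.sum() - k_rr.diagonal().sum()) / (n * (n - 1)) -
            2 * k_gr.mean())
    return mmd2.item()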
Example #22
    def perform_augmentation(self, inputs, paired):
        r"""Entry point for augmentation.

        Args:
            inputs (dict): Keys are from self.augmentable_data_types. Values are
                list of numpy.ndarray (list of images).
            paired (bool): Apply same augmentation to all input keys?
        """
        # Make sure that all inputs are of same size, else trouble will
        # ensue. This is because different images might have different
        # aspect ratios.
        # Check within data type.
        for data_type in inputs:
            if data_type in self.keypoint_data_types or \
                    data_type not in self.image_data_types:
                continue
            for idx in range(len(inputs[data_type])):
                if idx == 0:
                    w, h = inputs[data_type][idx].size
                else:
                    this_w, this_h = inputs[data_type][idx].size
                    if this_w / (1.0 * this_h) != w / (1.0 * h):
                        print('(%d, %d) != (%d, %d)' % (
                            this_w, this_h, w, h))
        # Check across data types.
        if paired and self.resize_smallest_side is not None:
            first_size = None
            for data_type in inputs:
                if data_type in self.keypoint_data_types or \
                        data_type not in self.image_data_types:
                    continue
                if first_size is None:
                    # Remember the size of the first augmentable data type.
                    first_size = inputs[data_type][0].size
                    w, h = first_size
                else:
                    this_w, this_h = inputs[data_type][0].size
                    if this_w / (1.0 * this_h) != w / (1.0 * h):
                        print('(%d, %d) != (%d, %d)' % (
                            this_w, this_h, w, h))
        # Do appropriate augmentation.
        if paired:
            return self._perform_paired_augmentation(inputs)
        else:
            return self._perform_unpaired_augmentation(inputs)
Example #23
    def get_train_params(net, param_names_start_with, param_names_include):
        r"""Get train parameters.

        Args:
            net (obj): Network object.
            param_names_start_with (list of strings): Params whose names
                start with any of the strings will be trained.
            param_names_include (list of strings): Params whose names include
                any of the strings will be trained.
        """
        params_to_train = []
        params_dict = net.state_dict()
        list_of_param_names_to_train = set()
        # Iterate through all params in the network and check if we need to
        # train it.
        for key, value in params_dict.items():
            do_train = False
            # If the param name starts with the target string (excluding
            # the 'module' part etc), we will train this param.
            key_s = key.replace('module.', '').replace('averaged_model.', '')
            for param_name in param_names_start_with:
                if key_s.startswith(param_name):
                    do_train = True
                    list_of_param_names_to_train.add(param_name)

            # Otherwise, if the param name includes the target string,
            # we will also train it.
            if not do_train:
                for param_name in param_names_include:
                    if param_name in key_s:
                        do_train = True
                        full_param_name = \
                            key_s[:(key_s.find(param_name) + len(param_name))]
                        list_of_param_names_to_train.add(full_param_name)

            # If we decide to train the param, add it to the list to train.
            if do_train:
                module = net
                key_list = key.split('.')
                for k in key_list:
                    module = getattr(module, k)
                params_to_train += [module]

        print('Training layers: ', sorted(list_of_param_names_to_train))
        return params_to_train
Example #24
    def test(self, test_data_loader, root_output_dir, inference_args):
        r"""Run inference on all sequences.

        Args:
            test_data_loader (object): Test data loader.
            root_output_dir (str): Location to dump outputs.
            inference_args (optional): Optional args.
        """

        # Go over all sequences.
        loader = test_data_loader
        num_inference_sequences = loader.dataset.num_inference_sequences()
        for sequence_idx in range(num_inference_sequences):
            loader.dataset.set_inference_sequence_idx(sequence_idx)
            print('Seq id: %d, Seq length: %d' %
                  (sequence_idx + 1, len(loader)))

            # Reset model at start of new inference sequence.
            self.reset()
            self.sequence_length = len(loader)

            # Go over all frames of this sequence.
            video = []
            for idx, data in enumerate(tqdm(loader)):
                key = data['key']['images'][0][0]
                filename = key.split('/')[-1]

                # Create output dir for this sequence.
                if idx == 0:
                    output_dir, seq_name = \
                        self.create_sequence_output_dir(root_output_dir, key)
                    video_path = os.path.join(output_dir, '..', seq_name)

                # Get output, and save all vis to all/.
                data['img_name'] = filename
                data = to_cuda(data)
                output = self.test_single(data, output_dir=output_dir + '/all')

                # Dump just the fake image here.
                fake = tensor2im(output['fake_images'])[0]
                video.append(fake)
                imageio.imsave(output_dir + '/fake/%s.jpg' % (filename), fake)

            # Save as mp4.
            imageio.mimsave(video_path + '.mp4', video, fps=15)
Example #25
def _get_train_and_val_dataset_objects(cfg):
    r"""Return dataset objects for the training and validation sets.

    Args:
        cfg (obj): Global configuration file.

    Returns:
        (tuple):
          - train_dataset (obj): PyTorch training dataset object.
          - val_dataset (obj): PyTorch validation dataset object.
    """
    dataset_module = importlib.import_module(cfg.data.type)
    train_dataset = dataset_module.Dataset(cfg, is_inference=False)
    val_in_val = getattr(cfg.data, 'val_in_val', True)
    val_dataset = dataset_module.Dataset(cfg, is_inference=val_in_val)
    print('Train dataset length:', len(train_dataset))
    print('Val dataset length:', len(val_dataset))
    return train_dataset, val_dataset
Example #26
 def _compute_fid(self):
     r"""Compute FID for both domains.
     """
     self.net_G.eval()
     if self.cfg.trainer.model_average:
         net_G_for_evaluation = self.net_G.module.averaged_model
     else:
         net_G_for_evaluation = self.net_G
     fid_a_path = self._get_save_path('fid_a', 'npy')
     fid_b_path = self._get_save_path('fid_b', 'npy')
     fid_value_a = compute_fid(fid_a_path, self.val_data_loader,
                               net_G_for_evaluation, 'images_a', 'images_ba')
     fid_value_b = compute_fid(fid_b_path, self.val_data_loader,
                               net_G_for_evaluation, 'images_b', 'images_ab')
     print('Epoch {:05}, Iteration {:09}, FID a {}, FID b {}'.format(
         self.current_epoch, self.current_iteration,
         fid_value_a, fid_value_b))
     return fid_value_a, fid_value_b
Example #27
def main():
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    # If args.single_gpu is set to True,
    # we will disable distributed data parallel
    if not args.single_gpu:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # Override the number of data loading workers if necessary
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Create log directory for storing training results.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    make_logging_dir(cfg.logdir)

    # Initialize cudnn.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Initialize data loaders and models.
    train_data_loader, val_data_loader = get_train_and_val_dataloader(cfg)
    net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(cfg, net_G, net_D,
                          opt_G, opt_D,
                          sch_G, sch_D,
                          train_data_loader, val_data_loader)

    # Start evaluation.
    checkpoints = \
        sorted(glob.glob('{}/*.pt'.format(args.checkpoint_logdir)))
    for checkpoint in checkpoints:
        current_epoch, current_iteration = \
            trainer.load_checkpoint(cfg, checkpoint, resume=True)
        trainer.current_epoch = current_epoch
        trainer.current_iteration = current_iteration
        trainer.write_metrics()
    print('Done with evaluation!!!')
    return
Example #28
 def _compute_fid(self):
     r"""Compute FID values."""
     self.net_G.eval()
     self.net_G_output = None
     # Due to complicated video evaluation procedure we are using, we will
     # pass the trainer to the evaluation code instead of the
     # generator network.
     # net_G_for_evaluation = self.net_G
     trainer = self
     self.test_in_model_average_mode = False
     regular_fid_path = self._get_save_path('regular_fid', 'npy')
      few_shot = 'few_shot' in self.cfg.data.type
     regular_fid_value = compute_fid(regular_fid_path,
                                     self.val_data_loader,
                                     trainer,
                                     sample_size=self.sample_size,
                                     is_video=True,
                                     few_shot_video=few_shot)
     print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
         self.current_epoch, self.current_iteration, regular_fid_value))
     if self.cfg.trainer.model_average:
         # Due to complicated video evaluation procedure we are using,
         # we will pass the trainer to the evaluation code instead of the
         # generator network.
         # avg_net_G_for_evaluation = self.net_G.module.averaged_model
         trainer_avg_mode = self
         self.test_in_model_average_mode = True
         # The above flag will be reset after computing FID.
         fid_path = self._get_save_path('average_fid', 'npy')
         fid_value = compute_fid(fid_path,
                                 self.val_data_loader,
                                 trainer_avg_mode,
                                 sample_size=self.sample_size,
                                 is_video=True,
                                 few_shot_video=few_shot)
         print('Epoch {:05}, Iteration {:09}, Average FID {}'.format(
             self.current_epoch, self.current_iteration, fid_value))
         self.net_G.float()
         return regular_fid_value, fid_value
     else:
         self.net_G.float()
         return regular_fid_value
Example #29
def compute_fid(fid_path, data_loader, net_G,
                key_real='images', key_fake='fake_images',
                sample_size=None, preprocess=None,
                is_video=False, few_shot_video=False):
    r"""Compute the fid score.

    Args:
        fid_path (str): Location for the numpy file to store or to load the
            statistics.
        data_loader (obj): PyTorch dataloader object.
        net_G (obj): For image generation models, net_G is the PyTorch
            generator network. For video generation models, net_G is the
            trainer because video generation requires more complicated
            processing.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        sample_size (int or tuple): How many samples to be used.
        preprocess (func): The preprocess function to be applied to the data.
        is_video (bool): Whether we are handling video sequences.
        few_shot_video (bool): If ``True``, uses few-shot video synthesis.
    Returns:
        (float): FID value.
    """
    print('Computing FID.')
    with torch.no_grad():
        # Get the fake mean and covariance.
        fake_mean, fake_cov = load_or_compute_stats(fid_path,
                                                    data_loader,
                                                    key_real, key_fake, net_G,
                                                    sample_size, preprocess,
                                                    is_video, few_shot_video)
        # Get the ground truth mean and covariance.
        mean_cov_path = os.path.join(os.path.dirname(fid_path),
                                     'real_mean_cov.npz')
        real_mean, real_cov = load_or_compute_stats(mean_cov_path,
                                                    data_loader,
                                                    key_real, key_fake, None,
                                                    sample_size, preprocess,
                                                    is_video, few_shot_video)

    if is_master():
        fid = calculate_frechet_distance(
            real_mean, real_cov, fake_mean, fake_cov)
        return fid
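`load_or_compute_stats` is defined elsewhere; once the inception activations are in hand, the Gaussian statistics reduce to a mean and covariance (a minimal sketch, assuming an (N, D) activation array):

import numpy as np


def stats_from_activations(act):
    r"""Mean vector and covariance matrix of activations (sketch)."""
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma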
Example #30
    def test(self, test_data_loader, root_output_dir, inference_args):
        r"""Run inference on all sequences.

        Args:
            test_data_loader (object): Test data loader.
            root_output_dir (str): Location to dump outputs.
            inference_args (optional): Optional args.
        """

        # Go over all sequences.
        loader = test_data_loader
        num_inference_sequences = loader.dataset.num_inference_sequences()
        for sequence_idx in range(num_inference_sequences):
            loader.dataset.set_inference_sequence_idx(sequence_idx)
            print('Seq id: %d, Seq length: %d' %
                  (sequence_idx + 1, len(loader)))

            # Reset model at start of new inference sequence.
            self.reset()
            self.sequence_length = len(loader)

            # Go over all frames of this sequence.
            video = []
            for idx, data in enumerate(tqdm(loader)):
                key = data['key']['images'][0][0]
                filename = key.split('/')[-1]

                # Create output dir for this sequence.
                if idx == 0:
                    output_dir, seq_name = \
                        self.create_sequence_output_dir(root_output_dir, key)
                    video_path = os.path.join(output_dir, '..', seq_name)

                # Get output and save images.
                data['img_name'] = filename
                data = self.start_of_iteration(data, current_iteration=-1)
                output = self.test_single(data, output_dir, inference_args)
                video.append(output)

            # Save output as mp4.
            imageio.mimsave(video_path + '.mp4', video, fps=15)