Example #1
def cluster_features(cfg, train_data_loader, net_E,
                     preprocess=None, small_ratio=0.0625, is_cityscapes=True):
    r"""Use clustering to compute the features.

    Args:
        cfg (obj): Global configuration file.
        train_data_loader (obj): Dataloader for iterating through the
            training set.
        net_E (nn.Module): Pytorch network.
        preprocess (function): Pre-processing function.
        small_ratio (float): We only consider instances that occupy at least
            $(small_ratio) of the image area.
        is_cityscapes (bool): Whether this is the Cityscapes dataset. In the
            Cityscapes dataset, the instance labels for cars start with 26001,
            26002, ...

    Returns:
        (num_labels x num_cluster_centers x feature_dims): Cluster centers.
    """
    # Encode features.
    label_nc = get_paired_input_label_channel_number(cfg.data)
    feat_nc = cfg.gen.enc.num_feat_channels
    n_clusters = getattr(cfg.gen.enc, 'num_clusters', 10)
    # Compute features.
    features = {}
    for label in range(label_nc):
        features[label] = np.zeros((0, feat_nc + 1))
    for data in train_data_loader:
        if preprocess is not None:
            data = preprocess(data)
        feat = encode_features(net_E, feat_nc, label_nc,
                               data['images'], data['instance_maps'],
                               is_cityscapes)
        # We only collect the feature vectors for the master GPU.
        if is_master():
            for label in range(label_nc):
                features[label] = np.append(
                    features[label], feat[label], axis=0)
    # Clustering.
    # We only perform clustering for the master GPU.
    if is_master():
        for label in range(label_nc):
            feat = features[label]
            # We only consider segments that are greater than a pre-set
            # threshold.
            feat = feat[feat[:, -1] > small_ratio, :-1]
            if feat.shape[0]:
                # Use a local count so a label with few segments does not
                # shrink n_clusters for the remaining labels.
                k = min(feat.shape[0], n_clusters)
                kmeans = KMeans(n_clusters=k, random_state=0).fit(feat)
                n, d = kmeans.cluster_centers_.shape
                this_cluster = getattr(net_E, 'cluster_%d' % label)
                this_cluster[0:n, :] = torch.Tensor(
                    kmeans.cluster_centers_).float()
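
A minimal, self-contained sketch of the filtering-and-clustering step above, using random vectors in place of real encoder features (all shapes here are illustrative assumptions):

import numpy as np
from sklearn.cluster import KMeans

feat_nc, n_clusters, small_ratio = 3, 10, 0.0625
# Dummy features: each row is feat_nc feature dims plus one area-ratio column.
feat = np.random.rand(100, feat_nc + 1)
# Keep only segments whose area ratio exceeds the threshold, then drop it.
feat = feat[feat[:, -1] > small_ratio, :-1]
if feat.shape[0]:
    k = min(feat.shape[0], n_clusters)
    centers = KMeans(n_clusters=k, random_state=0).fit(feat).cluster_centers_
    print(centers.shape)  # (k, feat_nc)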
Example #2
    def _compute_fid(self):
        r"""Compute FID. We will compute a FID value per test class. That is
        if you have 30 test classes, we will compute 30 different FID values.
        We will then report the mean of the FID values as the final
        performance number as described in the FUNIT paper.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average:
            net_G_for_evaluation = self.net_G.module.averaged_model
        else:
            net_G_for_evaluation = self.net_G

        all_fid_values = []
        num_test_classes = self.val_data_loader.dataset.num_style_classes
        for class_idx in range(num_test_classes):
            fid_path = self._get_save_path(os.path.join('fid', str(class_idx)),
                                           'npy')
            self.val_data_loader.dataset.set_sample_class_idx(class_idx)

            fid_value = compute_fid(fid_path, self.val_data_loader,
                                    net_G_for_evaluation, 'images_style',
                                    'images_trans')
            all_fid_values.append(fid_value)

        if is_master():
            mean_fid = np.mean(all_fid_values)
            print('Epoch {:05}, Iteration {:09}, Mean FID {}'.format(
                self.current_epoch, self.current_iteration, mean_fid))
            return mean_fid
        else:
            return None
Example #3
def compute_kid_data(kid_path,
                     data_loader_a,
                     data_loader_b,
                     key_a='images',
                     key_b='images',
                     sample_size=None,
                     is_video=False,
                     num_subsets=1,
                     subset_size=None):
    r"""Compute the kid score between two datasets.

    Args:
        kid_path (str): Location for storing feature activations.
        data_loader_a (obj): PyTorch dataloader object for dataset a.
        data_loader_b (obj): PyTorch dataloader object for dataset b.
        key_a (str): Dictionary key value for images in dataset a.
        key_b (str): Dictionary key value for images in dataset b.
        sample_size (int): How many samples to be used for computing the KID.
        is_video (bool): Whether we are handling video sequences.
        num_subsets (int): Number of subsets to sample from the whole data.
        subset_size (int): Number of samples in each subset.
    Returns:
        kid (float): KID value.
    """
    if sample_size is None:
        sample_size = min(len(data_loader_a.dataset),
                          len(data_loader_b.dataset))
    print('Computing KID using {} images from both distributions.'.format(
        sample_size))
    with torch.no_grad():
        path_a = os.path.join(os.path.dirname(kid_path), 'activations_a.npz')
        path_b = os.path.join(os.path.dirname(kid_path), 'activations_b.npz')
        act_a = load_or_compute_activations(path_a,
                                            data_loader_a,
                                            key_a,
                                            key_a,
                                            sample_size=sample_size,
                                            is_video=is_video)
        act_b = load_or_compute_activations(path_b,
                                            data_loader_b,
                                            key_b,
                                            key_b,
                                            sample_size=sample_size,
                                            is_video=is_video)

        if is_master():
            mmd, mmd_vars = polynomial_mmd_averages(act_a,
                                                    act_b,
                                                    num_subsets,
                                                    subset_size,
                                                    ret_var=True)
            kid = mmd.mean()
            return kid
        else:
            return None
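
polynomial_mmd_averages is defined elsewhere; assuming it follows the standard KID recipe (Bińkowski et al., 2018), a self-contained sketch of the unbiased polynomial MMD^2 averaged over random subsets could look like this:

import numpy as np

def polynomial_kernel(x, y):
    # KID's default kernel: k(a, b) = (a.b / d + 1) ** 3.
    d = x.shape[1]
    return (x @ y.T / d + 1.0) ** 3

def polynomial_mmd2(x, y):
    # Unbiased estimator of the squared MMD between two samples.
    m, n = len(x), len(y)
    k_xx, k_yy, k_xy = (polynomial_kernel(x, x), polynomial_kernel(y, y),
                        polynomial_kernel(x, y))
    return ((k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
            + (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
            - 2.0 * k_xy.mean())

def kid_from_activations(act_a, act_b, num_subsets, subset_size, seed=0):
    # Average MMD^2 over random subsets, then report the mean as the KID.
    rng = np.random.default_rng(seed)
    mmds = [polynomial_mmd2(
        act_a[rng.choice(len(act_a), subset_size, replace=False)],
        act_b[rng.choice(len(act_b), subset_size, replace=False)])
        for _ in range(num_subsets)]
    return float(np.mean(mmds))

rng = np.random.default_rng(1)
print(kid_from_activations(rng.normal(size=(256, 64)),
                           rng.normal(size=(256, 64)),
                           num_subsets=10, subset_size=100))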
Example #4
def compute_kid(kid_path,
                data_loader,
                net_G,
                key_real='images',
                key_fake='fake_images',
                sample_size=None,
                preprocess=None,
                is_video=False,
                save_act=True,
                num_subsets=1,
                subset_size=None):
    r"""Compute the kid score.

    Args:
        kid_path (str): Location for storing feature activations.
        data_loader (obj): PyTorch dataloader object.
        net_G (obj): For image generation models, net_G is the PyTorch
            generator network. For video generation models, net_G is the
            trainer because video generation requires more complicated
            processing.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        sample_size (int): How many samples to be used for computing feature
            activations.
        preprocess (func): The preprocess function to be applied to the data.
        is_video (bool): Whether we are handling video sequences.
        save_act (bool): If ``True``, saves the real activations to disk and
            reloads them in the future. This might save some computation at
            the cost of storage.
        num_subsets (int): Number of subsets to sample from all the samples.
        subset_size (int): Number of samples in each subset.
    Returns:
        kid (float): KID value.
    """
    print('Computing KID.')
    with torch.no_grad():
        # Get the fake activations.
        fake_act = load_or_compute_activations(None, data_loader, key_real,
                                               key_fake, net_G, sample_size,
                                               preprocess, is_video)

        # Get the ground truth activations.
        act_path = os.path.join(os.path.dirname(kid_path),
                                'activations.npy') if save_act else None
        real_act = load_or_compute_activations(act_path, data_loader, key_real,
                                               key_fake, None, sample_size,
                                               preprocess, is_video)

    if is_master():
        mmd, mmd_vars = polynomial_mmd_averages(fake_act,
                                                real_act,
                                                num_subsets,
                                                subset_size,
                                                ret_var=True)
        kid = mmd.mean()
        return kid
Example #5
def compute_fid(fid_path, data_loader, net_G,
                key_real='images', key_fake='fake_images',
                sample_size=None, preprocess=None,
                is_video=False, few_shot_video=False):
    r"""Compute the fid score.

    Args:
        fid_path (str): Location for the numpy file to store or to load the
            statistics.
        data_loader (obj): PyTorch dataloader object.
        net_G (obj): For image generation models, net_G is the PyTorch
            generator network. For video generation models, net_G is the
            trainer because video generation requires more complicated
            processing.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        sample_size (int or tuple): How many samples to be used.
        preprocess (func): The preprocess function to be applied to the data.
        is_video (bool): Whether we are handling video sequences.
        few_shot_video (bool): If ``True``, uses few-shot video synthesis.
    Returns:
        (float): FID value.
    """
    print('Computing FID.')
    with torch.no_grad():
        # Get the fake mean and covariance.
        fake_mean, fake_cov = load_or_compute_stats(fid_path,
                                                    data_loader,
                                                    key_real, key_fake, net_G,
                                                    sample_size, preprocess,
                                                    is_video, few_shot_video)
        # Get the ground truth mean and covariance.
        mean_cov_path = os.path.join(os.path.dirname(fid_path),
                                     'real_mean_cov.npz')
        real_mean, real_cov = load_or_compute_stats(mean_cov_path,
                                                    data_loader,
                                                    key_real, key_fake, None,
                                                    sample_size, preprocess,
                                                    is_video, few_shot_video)

    if is_master():
        fid = calculate_frechet_distance(
            real_mean, real_cov, fake_mean, fake_cov)
        return fid
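
calculate_frechet_distance is defined elsewhere; a minimal stand-in implementing the standard Fréchet distance between two Gaussians (the usual FID formula) might look like this:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, cov1, mu2, cov2):
    # FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 * sqrtm(C1 @ C2)).
    diff = mu1 - mu2
    covmean = linalg.sqrtm(cov1 @ cov2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts
    return diff @ diff + np.trace(cov1 + cov2 - 2.0 * covmean)

rng = np.random.default_rng(0)
a, b = rng.normal(size=(500, 8)), rng.normal(1.0, 1.0, size=(500, 8))
print(frechet_distance(a.mean(0), np.cov(a, rowvar=False),
                       b.mean(0), np.cov(b, rowvar=False)))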
Example #6
def compute_prdc(cfg,
                 data_loader,
                 net_G,
                 key_real='images',
                 key_fake='fake_images',
                 k=10):
    r"""Compute precision diversity curve

    Args:

    """
    y_real = get_activations(data_loader, key_real, key_fake, generator=None)
    y_fake = get_activations(data_loader, key_real, key_fake, generator=net_G)
    if is_master():
        print("Computing density and coverage.")
        prdc_data = get_prdc(y_real, y_fake, k)
        return prdc_data['density'], prdc_data['coverage']
    else:
        return None, None
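
get_prdc is assumed here to follow the density and coverage definitions of Naeem et al. (2020); a small self-contained sketch on random features:

import numpy as np

def density_coverage(real, fake, k):
    # Radius of each real point: distance to its k-th nearest real neighbour
    # (column 0 of the sorted distances is the zero self-distance).
    d_rr = np.linalg.norm(real[:, None] - real[None], axis=-1)
    radii = np.sort(d_rr, axis=1)[:, k]
    # Which fake samples fall inside each real sample's ball.
    d_rf = np.linalg.norm(real[:, None] - fake[None], axis=-1)
    inside = d_rf < radii[:, None]
    density = inside.sum() / (k * fake.shape[0])
    coverage = inside.any(axis=1).mean()
    return density, coverage

rng = np.random.default_rng(0)
print(density_coverage(rng.normal(size=(200, 4)),
                       rng.normal(size=(300, 4)), k=10))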
Example #7
def compute_fid_data(fid_path, data_loader_a, data_loader_b,
                     key_a='images', key_b='images', sample_size=None,
                     is_video=False, few_shot_video=False):
    r"""Compute the fid score between two datasets.

    Args:
        fid_path (str): Location for the numpy file to store or to load the
            statistics.
        data_loader_a (obj): PyTorch dataloader object for dataset a.
        data_loader_b (obj): PyTorch dataloader object for dataset b.
        key_a (str): Dictionary key value for images in dataset a.
        key_b (str): Dictionary key value for images in dataset b.
        sample_size (int): How many samples to be used for computing the FID.
        is_video (bool): Whether we are handling video sequences.
        few_shot_video (bool): If ``True``, uses few-shot video synthesis.
    Returns:
        (float): FID value.
    """
    if sample_size is None:
        sample_size = min(len(data_loader_a.dataset),
                          len(data_loader_b.dataset))
    print('Computing FID using {} images from both distributions.'.
          format(sample_size))
    with torch.no_grad():
        path_a = os.path.join(os.path.dirname(fid_path),
                              'mean_cov_a.npz')
        path_b = os.path.join(os.path.dirname(fid_path),
                              'mean_cov_b.npz')
        mean_a, cov_a = load_or_compute_stats(path_a, data_loader_a,
                                              key_a, key_a,
                                              sample_size=sample_size,
                                              is_video=is_video)
        mean_b, cov_b = load_or_compute_stats(path_b, data_loader_b,
                                              key_b, key_b,
                                              sample_size=sample_size,
                                              is_video=is_video)
    if is_master():
        fid = calculate_frechet_distance(mean_b, cov_b, mean_a, cov_a)
        return fid
Example #8
    def save_image(self, path, data):
        r"""Compute visualization images and save them to the disk.

        Args:
            path (str): Location of the file.
            data (dict): Data used for the current iteration.
        """
        self.net_G.eval()
        vis_images = self._get_visualizations(data)
        if is_master() and vis_images is not None:
            vis_images = torch.cat(vis_images, dim=3).float()
            vis_images = (vis_images + 1) / 2
            print('Save output images to {}'.format(path))
            vis_images.clamp_(0, 1)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            image_grid = torchvision.utils.make_grid(vis_images,
                                                     nrow=1,
                                                     padding=0,
                                                     normalize=False)
            if self.cfg.trainer.image_to_tensorboard:
                self.image_meter.write_image(image_grid,
                                             self.current_iteration)
            torchvision.utils.save_image(image_grid, path, nrow=1)
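
A tiny standalone sketch of the grid-and-save step above, with random tensors standing in for generator outputs in [-1, 1] (file name is made up):

import torch
import torchvision

vis_images = torch.rand(4, 3, 64, 64) * 2 - 1     # dummy outputs in [-1, 1]
vis_images = ((vis_images + 1) / 2).clamp_(0, 1)  # map to [0, 1]
image_grid = torchvision.utils.make_grid(vis_images, nrow=1, padding=0,
                                         normalize=False)
torchvision.utils.save_image(image_grid, 'preview.png', nrow=1)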
Example #9
def get_checkpoint(checkpoint_path, url=''):
    r"""Get the checkpoint path. If it does not exist yet, download it from
    the url.

    Args:
        checkpoint_path (str): Checkpoint path.
        url (str): URL to download checkpoint.
    Returns:
        (str): Full checkpoint path.
    """
    if 'TORCH_HOME' not in os.environ:
        os.environ['TORCH_HOME'] = os.getcwd()
    save_dir = os.path.join(os.environ['TORCH_HOME'], 'checkpoints')
    os.makedirs(save_dir, exist_ok=True)
    full_checkpoint_path = os.path.join(save_dir, checkpoint_path)
    if not os.path.exists(full_checkpoint_path):
        os.makedirs(os.path.dirname(full_checkpoint_path), exist_ok=True)
        if is_master():
            print('Download {}'.format(url))
            download_file_from_google_drive(url, full_checkpoint_path)
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    return full_checkpoint_path
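
A hypothetical call (both the relative path and the Google Drive id below are made up for illustration):

checkpoint_path = get_checkpoint('pretrained/model.pt', url='<drive-file-id>')
state_dict = torch.load(checkpoint_path)

Because only the master rank downloads, the dist.barrier() call keeps the other ranks waiting until the file exists before every rank returns the same path.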
Example #10
def load_or_compute_activations(act_path,
                                data_loader,
                                key_real,
                                key_fake,
                                generator=None,
                                sample_size=None,
                                preprocess=None,
                                is_video=False):
    r"""Load mean and covariance from saved npy file if exists. Otherwise,
    compute the mean and covariance.

    Args:
        act_path (str or None): Location for the numpy file to store or to load
            the statistics.
        data_loader (obj): PyTorch dataloader object.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        generator (obj): PyTorch trainer network.
        sample_size (int): How many samples to be used for computing the
            activations.
        preprocess (func): The preprocess function to be applied to the data.
        is_video (bool): Whether we are handling video sequences.
    Returns:
        act (np.ndarray): Inception feature activations.
    """
    if is_video:
        raise NotImplementedError("Video KID is not currently supported.")
    if act_path is not None and os.path.exists(act_path):
        print('Load Inception activations from {}'.format(act_path))
        act = np.load(act_path)
    else:
        act = get_activations(data_loader, key_real, key_fake, generator,
                              sample_size, preprocess)
        if act_path is not None and is_master():
            print('Save Inception activations to {}'.format(act_path))
            np.save(act_path, act)
    return act
Example #11
def get_inception_mean_cov(data_loader, key_real, key_fake, generator,
                           sample_size, preprocess,
                           is_video=False, few_shot_video=False):
    r"""Load mean and covariance from saved npy file if exists. Otherwise,
    compute the mean and covariance.

    Args:
        data_loader (obj): PyTorch dataloader object.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        generator (obj): PyTorch trainer network.
        sample_size (int or tuple): How many samples to be used.
        preprocess (func): The preprocess function to be applied to the data.
        is_video (bool): Whether we are handling video sequences.
        few_shot_video (bool): If ``True``, uses few-shot video synthesis.
    Returns:
        mean (np.ndarray): Mean vector.
        cov (np.ndarray): Covariance matrix.
    """
    print('Extract mean and covariance.')
    if is_video:
        y = get_video_activations(data_loader, key_real, key_fake, generator,
                                  sample_size, preprocess, few_shot_video)
    else:
        y = get_activations(data_loader, key_real, key_fake, generator,
                            sample_size, preprocess)
    if is_master():
        m = np.mean(y, axis=0)
        s = np.cov(y, rowvar=False)
    else:
        m = None
        s = None
    return m, s
Example #12
def load_or_compute_stats(fid_path, data_loader, key_real, key_fake,
                          generator=None, sample_size=None, preprocess=None,
                          is_video=False, few_shot_video=False):
    r"""Load mean and covariance from saved npy file if exists. Otherwise,
    compute the mean and covariance.

    Args:
        fid_path (str): Location for the numpy file to store or to load the
            statistics.
        data_loader (obj): PyTorch dataloader object.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        generator (obj): PyTorch trainer network.
        sample_size (int or tuple): How many samples to be used.
        preprocess (func): The preprocess function to be applied to the data.
        is_video (bool): Whether we are handling video sequences.
        few_shot_video (bool): If ``True``, uses few-shot video synthesis.
    Returns:
        mean (np.ndarray): Mean vector.
        cov (np.ndarray): Covariance matrix.
    """
    if os.path.exists(fid_path):
        print('Load FID mean and cov from {}'.format(fid_path))
        npz_file = np.load(fid_path)
        mean = npz_file['mean']
        cov = npz_file['cov']
    else:
        print('Get FID mean and cov and save to {}'.format(fid_path))
        mean, cov = get_inception_mean_cov(data_loader, key_real, key_fake,
                                           generator, sample_size, preprocess,
                                           is_video, few_shot_video)
        os.makedirs(os.path.dirname(fid_path), exist_ok=True)
        if is_master():
            np.savez(fid_path, mean=mean, cov=cov)
    return mean, cov
Example #13
def encode_features(net_E, feat_nc, label_nc, image, inst,
                    is_cityscapes=True):
    r"""Compute feature embeddings for an image image.
    TODO(Ting-Chun): To make this funciton dataset independent.

    Args:
        net_E (nn.Module): The encoder network.
        feat_nc (int): Number of feature dimensions.
        label_nc (int): Number of segmentation labels.
        image (tensor): Input image tensor.
        inst (tensor): Input instance map.
        is_cityscapes (bool): Whether this is the Cityscapes dataset. In the
            Cityscapes dataset, the instance labels for cars start with 26001,
            26002, ...
    Returns:
        (dict of numpy arrays): A dictionary with $(label_nc) entries. Each
            entry records the feature vectors collected for that label; each
            vector has dimension $(feat_nc+1), where the first $(feat_nc)
            dimensions are the representative feature of an instance and the
            last dimension is the proportion of the image it occupies.
    """
    # h, w = inst.size()[2:]
    feat_map = net_E(image, inst)
    feature_map_gather = dist_all_gather_tensor(feat_map)
    inst_gathered = dist_all_gather_tensor(inst)
    # Initialize the per-label feature collections.
    # For each feature vector,
    #   dimensions 0:feat_nc hold the feature vector, and
    #   dimension feat_nc records the percentage of the image occupied by
    #   the instance.
    feature = {}
    for i in range(label_nc):
        feature[i] = np.zeros((0, feat_nc + 1))
    if is_master():
        all_feat_map = torch.cat(feature_map_gather, 0)
        all_inst_map = torch.cat(inst_gathered, 0)
        # Scan through the batches.
        for n in range(all_feat_map.size()[0]):
            feat_map = all_feat_map[n:(n + 1), :, :, :]
            inst = all_inst_map[n:(n + 1), :, :, :]
            fh, fw = feat_map.size()[2:]
            inst_np = inst.cpu().numpy().astype(int)
            for i in np.unique(inst_np):
                if is_cityscapes:
                    label = i if i < 1000 else i // 1000
                else:
                    label = i
                idx = (inst == int(i)).nonzero()
                num = idx.size()[0]
                # We will just pick the middle pixel as its representative
                # feature.
                idx = idx[num // 2, :]
                val = np.zeros((1, feat_nc + 1))
                for k in range(feat_nc):
                    # We expect idx[0] = 0 and idx[1] = 0, since we process
                    # one sample at a time (so the batch index is 0) and the
                    # instance map has a single channel.
                    val[0, k] = feat_map[
                        idx[0], idx[1] + k, idx[2], idx[3]].item()
                val[0, feat_nc] = float(num) / (fh * fw)
                feature[label] = np.append(feature[label], val, axis=0)
        return feature
    else:
        return feature
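
A small worked example of the Cityscapes instance-id convention used above: ids of 1000 and above encode the semantic label in the digits left of the last three, so integer division by 1000 recovers it:

for i in [7, 26001, 26002, 33005]:
    label = i if i < 1000 else i // 1000
    print(i, '->', label)  # 7 -> 7, 26001 -> 26, 26002 -> 26, 33005 -> 33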
Example #14
def get_activations(data_loader, key_real, key_fake,
                    generator=None, sample_size=None, preprocess=None):
    r"""Compute activation values and pack them in a list.

    Args:
        data_loader (obj): PyTorch dataloader object.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        generator (obj): PyTorch trainer network.
        sample_size (int): How many samples to use for FID.
        preprocess (func): Pre-processing function to use.
    Returns:
        batch_y (tensor): Inception features of the current batch. Note that
            only the master gpu will get it.
    """
    # Load pretrained inception_v3 network and set it in GPU evaluation mode.
    inception = inception_v3(pretrained=True, transform_input=False,
                             init_weights=False)
    inception = inception.to('cuda').eval()

    # Disable the fully connected layer in the output.
    inception.fc = torch.nn.Sequential()

    world_size = get_world_size()
    batch_y = []
    # Iterate through the dataset to compute the activation.
    for it, data in enumerate(data_loader):
        data = to_cuda(data)
        # Preprocess the data if preprocess is not None.
        if preprocess is not None:
            data = preprocess(data)
        # Load real data if a generator is not specified.
        if generator is None:
            images = data[key_real]
        else:
            # Compute the generated image.
            net_G_output = generator(data)
            images = net_G_output[key_fake]
        # Clamp the image for models that do not constrain the output to
        # [-1, 1]. For models that employ tanh, this has no effect.
        images.clamp_(-1, 1)
        images = apply_imagenet_normalization(images)
        images = F.interpolate(images, size=(299, 299),
                               mode='bilinear', align_corners=True)
        y = inception(images)
        batch_y.append(y)
        if sample_size is not None and \
                data_loader.batch_size * world_size * (it + 1) >= sample_size:
            # Reach the number of samples we need.
            break

    batch_y = torch.cat(batch_y)
    batch_y = dist_all_gather_tensor(batch_y)
    if is_master():
        batch_y = torch.cat(batch_y).cpu().data.numpy()
        if sample_size is not None:
            batch_y = batch_y[:sample_size]
        print(batch_y.shape)
        return batch_y
    else:
        return None
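
The feature-extraction core of the loop above, as a standalone sketch on a random batch; apply_imagenet_normalization is assumed to map inputs from [-1, 1] to ImageNet-normalized space, so that step is inlined here:

import torch
import torch.nn.functional as F
from torchvision.models import inception_v3

inception = inception_v3(pretrained=True, transform_input=False).eval()
inception.fc = torch.nn.Sequential()  # expose the 2048-d pooled features

images = torch.rand(2, 3, 256, 256) * 2 - 1   # dummy batch in [-1, 1]
images = (images.clamp(-1, 1) + 1) / 2        # map to [0, 1]
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
images = (images - mean) / std                # ImageNet statistics
images = F.interpolate(images, size=(299, 299), mode='bilinear',
                       align_corners=True)
with torch.no_grad():
    y = inception(images)
print(y.shape)  # torch.Size([2, 2048])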
Example #15
    def save_image(self, path, data):
        r"""Save the output images to path.
        Note when the generate_raw_output is FALSE. Then,
        first_net_G_output['fake_raw_images'] is None and will not be displayed.
        In model average mode, we will plot the flow visualization twice.

        Args:
            path (str): Save path.
            data (dict): Training data for current iteration.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average:
            self.net_G.module.averaged_model.eval()

        self.net_G_output = None
        with torch.no_grad():
            first_net_G_output, last_net_G_output, _ = self.gen_frames(data)
            if self.cfg.trainer.model_average:
                first_net_G_output_avg, last_net_G_output_avg, _ = \
                    self.gen_frames(data, use_model_average=True)

        def get_images(data, net_G_output, return_first_frame=True,
                       for_model_average=False):
            r"""Get the ourput images to save.

            Args:
                data (dict): Training data for current iteration.
                net_G_output (dict): Generator output.
                return_first_frame (bool): Return output for the first frame
                    in the sequence.
                for_model_average (bool): For model average output.
            Returns:
                vis_images (list of numpy arrays): Visualization images.
            """
            frame_idx = 0 if return_first_frame else -1
            warped_idx = 0 if return_first_frame else 1
            vis_images = []
            if not for_model_average:
                vis_images += [
                    tensor2im(data['few_shot_images'][:, frame_idx]),
                    self.visualize_label(data['label'][:, frame_idx]),
                    tensor2im(data['images'][:, frame_idx])
                ]
            vis_images += [
                tensor2im(net_G_output['fake_images']),
                tensor2im(net_G_output['fake_raw_images'])]
            if not for_model_average:
                vis_images += [
                    tensor2im(net_G_output['warped_images'][warped_idx]),
                    tensor2flow(net_G_output['fake_flow_maps'][warped_idx]),
                    tensor2im(net_G_output['fake_occlusion_masks'][warped_idx],
                              normalize=False)
                ]
            return vis_images

        if is_master():
            vis_images_first = get_images(data, first_net_G_output)
            if self.cfg.trainer.model_average:
                vis_images_first += get_images(data, first_net_G_output_avg,
                                               for_model_average=True)
            if self.sequence_length > 1:
                vis_images_last = get_images(data, last_net_G_output,
                                             return_first_frame=False)
                if self.cfg.trainer.model_average:
                    vis_images_last += get_images(data, last_net_G_output_avg,
                                                  return_first_frame=False,
                                                  for_model_average=True)

                # If generating a video, the first row of each batch will be
                # the first generated frame and the flow/mask for warping the
                # reference image, and the second row will be the last
                # generated frame and the flow/mask for warping the previous
                # frame. If using model average, the frames generated by model
                # average will be at the rightmost columns.
                vis_images = [[np.vstack((im_first, im_last))
                               for im_first, im_last in
                               zip(imgs_first, imgs_last)]
                              for imgs_first, imgs_last in zip(vis_images_first,
                                                               vis_images_last)
                              if imgs_first is not None]
            else:
                vis_images = vis_images_first

            image_grid = np.hstack([np.vstack(im) for im in vis_images
                                    if im is not None])

            print('Save output images to {}'.format(path))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            imageio.imwrite(path, image_grid)
Example #16
    def save_image(self, path, data):
        r"""Save the output images to path.
        Note when the generate_raw_output is FALSE. Then,
        first_net_G_output['fake_raw_images'] is None and will not be displayed.
        In model average mode, we will plot the flow visualization twice.

        Args:
            path (str): Save path.
            data (dict): Training data for current iteration.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average:
            self.net_G.module.averaged_model.eval()
        self.net_G_output = None
        with torch.no_grad():
            first_net_G_output, net_G_output, all_info = self.gen_frames(data)
            if self.cfg.trainer.model_average:
                first_net_G_output_avg, net_G_output_avg, _ = self.gen_frames(
                    data, use_model_average=True)

        # Visualize labels.
        label_lengths = self.train_data_loader.dataset.get_label_lengths()
        labels = split_labels(data['label'], label_lengths)
        vis_labels_start, vis_labels_end = [], []
        for key, value in labels.items():
            if 'seg_maps' in key:
                vis_labels_start.append(self.visualize_label(value[:, -1]))
                vis_labels_end.append(self.visualize_label(value[:, 0]))
            else:
                normalize = self.train_data_loader.dataset.normalize[key]
                vis_labels_start.append(
                    tensor2im(value[:, -1], normalize=normalize))
                vis_labels_end.append(
                    tensor2im(value[:, 0], normalize=normalize))

        if is_master():
            vis_images = [
                *vis_labels_start,
                tensor2im(data['images'][:, -1]),
                tensor2im(net_G_output['fake_images']),
                tensor2im(net_G_output['fake_raw_images'])
            ]
            if self.cfg.trainer.model_average:
                vis_images += [
                    tensor2im(net_G_output_avg['fake_images']),
                    tensor2im(net_G_output_avg['fake_raw_images'])
                ]

            if self.sequence_length > 1:
                if net_G_output['guidance_images_and_masks'] is not None:
                    guidance_image = tensor2im(
                        net_G_output['guidance_images_and_masks'][:, :3])
                    guidance_mask = tensor2im(
                        net_G_output['guidance_images_and_masks'][:, 3:4],
                        normalize=False)
                else:
                    im = tensor2im(data['images'][:, -1])
                    guidance_image = [np.zeros_like(item) for item in im]
                    guidance_mask = [np.zeros_like(item) for item in im]
                vis_images += [guidance_image, guidance_mask]

                vis_images_first = [
                    *vis_labels_end,
                    tensor2im(data['images'][:, 0]),
                    tensor2im(first_net_G_output['fake_images']),
                    tensor2im(first_net_G_output['fake_raw_images']),
                    [np.zeros_like(item) for item in guidance_image],
                    [np.zeros_like(item) for item in guidance_mask]
                ]
                if self.cfg.trainer.model_average:
                    vis_images_first += [
                        tensor2im(first_net_G_output_avg['fake_images']),
                        tensor2im(first_net_G_output_avg['fake_raw_images'])
                    ]

                if self.use_flow:
                    flow_gt, conf_gt = self.criteria['Flow'].flowNet(
                        data['images'][:, -1], data['images'][:, -2])
                    warped_image_gt = resample(data['images'][:, -1], flow_gt)
                    vis_images_first += [
                        tensor2flow(flow_gt),
                        tensor2im(conf_gt, normalize=False),
                        tensor2im(warped_image_gt),
                    ]
                    vis_images += [
                        tensor2flow(net_G_output['fake_flow_maps']),
                        tensor2im(net_G_output['fake_occlusion_masks'],
                                  normalize=False),
                        tensor2im(net_G_output['warped_images']),
                    ]
                    if self.cfg.trainer.model_average:
                        vis_images_first += [
                            tensor2flow(flow_gt),
                            tensor2im(conf_gt, normalize=False),
                            tensor2im(warped_image_gt),
                        ]
                        vis_images += [
                            tensor2flow(net_G_output_avg['fake_flow_maps']),
                            tensor2im(net_G_output_avg['fake_occlusion_masks'],
                                      normalize=False),
                            tensor2im(net_G_output_avg['warped_images'])
                        ]

                vis_images = [[
                    np.vstack((im_first, im))
                    for im_first, im in zip(imgs_first, imgs)
                ] for imgs_first, imgs in zip(vis_images_first, vis_images)
                              if imgs is not None]

            image_grid = np.hstack(
                [np.vstack(im) for im in vis_images if im is not None])

            print('Save output images to {}'.format(path))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            imageio.imwrite(path, image_grid)

            # Gather all inputs and outputs for dumping into video.
            if self.sequence_length > 1:
                input_images, output_images, output_guidance = [], [], []
                for item in all_info['inputs']:
                    input_images.append(tensor2im(item['image'])[0])
                for item in all_info['outputs']:
                    output_images.append(tensor2im(item['fake_images'])[0])
                    if item['guidance_images_and_masks'] is not None:
                        output_guidance.append(
                            tensor2im(
                                item['guidance_images_and_masks'][:, :3])[0])
                    else:
                        output_guidance.append(np.zeros_like(
                            output_images[-1]))

                imageio.mimwrite(os.path.splitext(path)[0] + '.mp4',
                                 output_images,
                                 fps=2,
                                 macro_block_size=None)
                imageio.mimwrite(os.path.splitext(path)[0] + '_guidance.mp4',
                                 output_guidance,
                                 fps=2,
                                 macro_block_size=None)

            # for idx, item in enumerate(output_guidance):
            #     imageio.imwrite(os.path.splitext(
            #         path)[0] + '_guidance_%d.jpg' % (idx), item)
            # for idx, item in enumerate(input_images):
            #     imageio.imwrite(os.path.splitext(
            #         path)[0] + '_input_%d.jpg' % (idx), item)

        self.net_G.float()
Example #17
def get_video_activations(data_loader, key_real, key_fake, trainer=None,
                          sample_size=None, preprocess=None, few_shot=False):
    r"""Compute activation values and pack them in a list. We do not do all
    reduce here.

    Args:
        data_loader (obj): PyTorch dataloader object.
        key_real (str): Dictionary key value for the real data.
        key_fake (str): Dictionary key value for the fake data.
        trainer (obj): Trainer. Video generation is more involved, so we rely
            on the trainer's "reset" and "test" functions to conduct the
            evaluation.
        sample_size (tuple): (num_videos_to_test, num_frames_per_video) to be
            used for computing the video activations.
        preprocess (func): The preprocess function to be applied to the data.
        few_shot (bool): If ``True``, uses the few-shot setting.
    Returns:
        batch_y (tensor): Inception features of the current batch. Note that
            only the master gpu will get it.
    """
    inception = inception_v3(pretrained=True, transform_input=False)
    inception = inception.to('cuda')
    inception.eval()
    inception.fc = torch.nn.Sequential()
    batch_y = []

    # We divide video sequences to different GPUs for testing.
    num_sequences = data_loader.dataset.num_inference_sequences()
    if sample_size is None:
        num_videos_to_test = 10
        num_frames_per_video = 5
    else:
        num_videos_to_test, num_frames_per_video = sample_size
    if num_videos_to_test == -1:
        num_videos_to_test = num_sequences
    else:
        num_videos_to_test = min(num_videos_to_test, num_sequences)
    print('Number of videos used for evaluation: {}'.format(
        num_videos_to_test))
    print('Number of frames per video used for evaluation: {}'.format(
        num_frames_per_video))

    world_size = get_world_size()
    if num_videos_to_test < world_size:
        seq_to_run = [get_rank() % num_videos_to_test]
    else:
        num_videos_to_test = num_videos_to_test // world_size * world_size
        seq_to_run = range(get_rank(), num_videos_to_test, world_size)

    for sequence_idx in seq_to_run:
        data_loader = set_sequence_idx(few_shot, data_loader, sequence_idx)
        if trainer is not None:
            trainer.reset()
        for it, data in enumerate(data_loader):
            if it >= num_frames_per_video:
                break

            # Preprocess the data if preprocess is not None.
            if trainer is not None:
                data = trainer.pre_process(data)
            elif preprocess is not None:
                data = preprocess(data)
            data = to_cuda(data)

            if trainer is None:
                images = data[key_real][:, -1]
            else:
                net_G_output = trainer.test_single(data)
                images = net_G_output[key_fake]
            images.clamp_(-1, 1)
            images = apply_imagenet_normalization(images)
            images = F.interpolate(images, size=(299, 299),
                                   mode='bilinear', align_corners=True)
            y = inception(images)
            batch_y += [y]

    batch_y = torch.cat(batch_y)
    batch_y = dist_all_gather_tensor(batch_y)
    if is_master():
        batch_y = torch.cat(batch_y).cpu().data.numpy()
    return batch_y
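
A quick illustration of the round-robin split of sequences across ranks performed above (the world size and video count here are made-up values):

world_size, num_videos_to_test = 4, 10
# Truncate to a multiple of the world size, then stride by rank.
num_videos_to_test = num_videos_to_test // world_size * world_size  # -> 8
for rank in range(world_size):
    print(rank, list(range(rank, num_videos_to_test, world_size)))
# 0 [0, 4]   1 [1, 5]   2 [2, 6]   3 [3, 7]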