Ejemplo n.º 1
0
def compute_real_dist_stats(num_samples,
                            sess,
                            batch_size,
                            split,
                            dataset=None,
                            stats_file=None,
                            seed=0,
                            verbose=True,
                            log_dir='./log'):
    """
    Reads the image data and computes the FID mean and cov statistics
    for real images, caching them on disk as an ``.npz`` file.

    If ``stats_file`` already exists, the statistics are loaded from it and
    no images are read. Otherwise they are computed and saved to that path.

    Args:
        num_samples (int): Number of real images to compute statistics.
            Also used (in thousands) in the default stats file name.
        sess (Session): TensorFlow session to use.
        batch_size (int): The batch size to feedforward for inference.
        split (str): Dataset split, forwarded to ``get_dataset_images``.
        dataset (str/Dataset): Dataset to load.
        stats_file (str): The statistics file to load from if there is
            already one. When None, a name is derived from ``dataset``,
            ``num_samples`` and ``seed`` under ``log_dir``. Note that
            ``np.savez`` appends ``.npz`` when the name lacks it, so a
            custom path should end with ``.npz`` to be found again later.
        seed (int): Only used as part of the default stats file name.
        verbose (bool): If True, prints progress of computation.
        log_dir (str): Directory where feature statistics can be stored.

    Returns:
        ndarray: Mean features stored as np array.
        ndarray: Covariance of features stored as np array.
    """
    # Create custom stats file name
    if stats_file is None:
        stats_dir = os.path.join(log_dir, 'metrics', 'fid', 'statistics')
        # exist_ok avoids a race between checking for and creating the dir
        os.makedirs(stats_dir, exist_ok=True)

        stats_file = os.path.join(
            stats_dir,
            "fid_stats_{}_{}k_run_{}.npz".format(dataset, num_samples // 1000,
                                                 seed))

    if stats_file and os.path.exists(stats_file):
        print("INFO: Loading existing statistics for real images...")
        # Context manager closes the archive even if a key is missing
        with np.load(stats_file) as f:
            m_real, s_real = f['mu'][:], f['sigma'][:]

    else:
        # Obtain the numpy format data
        print("INFO: Obtaining images...")
        images = get_dataset_images(dataset,
                                    split=split,
                                    num_samples=num_samples)

        # Compute the mean and cov
        print("INFO: Computing statistics for real images...")
        m_real, s_real = fid_utils.calculate_activation_statistics(
            images=images, sess=sess, batch_size=batch_size, verbose=verbose)

        if not os.path.exists(stats_file):
            print("INFO: Saving statistics for real images...")
            np.savez(stats_file, mu=m_real, sigma=s_real)

    return m_real, s_real
Ejemplo n.º 2
0
    def test_calculate_activation_statistics(self):
        """
        Checks the shapes of the statistics returned by
        ``fid_utils.calculate_activation_statistics`` for ``self.images``.
        """
        # Build the inception graph first, pointing it at the model directory
        inception_utils.create_inception_graph('./metrics/inception_model')

        stats = fid_utils.calculate_activation_statistics(images=self.images,
                                                          sess=self.sess)
        mean, cov = stats

        # Expect a 2048-d mean vector and a 2048 x 2048 covariance matrix
        assert mean.shape == (2048, )
        assert cov.shape == (2048, 2048)
Ejemplo n.º 3
0
def compute_gen_dist_stats(netG,
                           num_samples,
                           sess,
                           device,
                           seed,
                           batch_size,
                           print_every=20,
                           verbose=True):
    """
    Directly produces the images and converts them into numpy format without
    saving the images on disk, then computes their FID mean and cov
    statistics.

    Exactly ``num_samples`` images are generated: full batches of
    ``batch_size`` followed by one smaller batch when ``num_samples`` is not
    a multiple of ``batch_size``. ``netG`` is switched to evaluation mode and
    is not switched back by this function.

    Args:
        netG (Module): Torch Module object representing the generator model.
        num_samples (int): The number of fake images for computing statistics.
        sess (Session): TensorFlow session to use.
        device (str): Device identifier to use for computation.
        seed (int): The random seed of the current run. It is only reported
            in the progress log; this function does not seed any generator.
        batch_size (int): The number of samples per batch for inference.
        print_every (int): Interval (in batches) for printing log.
        verbose (bool): If True, prints progress of the statistics
            computation (forwarded to ``fid_utils``).

    Returns:
        ndarray: Mean features stored as np array.
        ndarray: Covariance of features stored as np array.

    Raises:
        ValueError: If ``num_samples`` or ``batch_size`` is not positive.
    """
    # Fail early with a clear message instead of a ZeroDivisionError below
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError(
            "num_samples and batch_size must be positive, got {} and {}.".
            format(num_samples, batch_size))

    with torch.no_grad():
        # Set model to evaluation mode
        netG.eval()

        # Inference variables
        batch_size = min(num_samples, batch_size)

        # Full batches plus a trailing partial batch, so that the number of
        # generated images matches num_samples exactly.
        num_full_batches, remainder = divmod(num_samples, batch_size)
        batch_sizes = [batch_size] * num_full_batches
        if remainder:
            batch_sizes.append(remainder)

        # Collect all samples
        images = []
        num_generated = 0
        num_at_last_log = 0
        start_time = time.time()
        for idx, cur_batch_size in enumerate(batch_sizes):
            # Collect fake image
            fake_images = netG.generate_images(num_images=cur_batch_size,
                                               device=device).detach().cpu()
            images.append(fake_images)
            num_generated += cur_batch_size

            # Print some statistics
            if (idx + 1) % print_every == 0:
                end_time = time.time()
                # Average time per image generated since the previous log
                print(
                    "INFO: Generated image {}/{} [Random Seed {}] ({:.4f} sec/idx)"
                    .format(
                        num_generated, num_samples, seed,
                        (end_time - start_time) /
                        (num_generated - num_at_last_log)))
                start_time = end_time
                num_at_last_log = num_generated

        # Produce images in the required (N, H, W, 3) format for FID computation
        images = torch.cat(images, 0)  # Gives (N, 3, H, W)
        images = _normalize_images(images)  # Gives (N, H, W, 3)

    # Compute the FID
    print("INFO: Computing statistics for fake images...")
    m_fake, s_fake = fid_utils.calculate_activation_statistics(
        images=images, sess=sess, batch_size=batch_size, verbose=verbose)

    return m_fake, s_fake
Ejemplo n.º 4
0
def compute_real_dist_stats_with_attr(attr,
                                      sess,
                                      batch_size,
                                      dataset=None,
                                      stats_file=None,
                                      seed=0,
                                      verbose=True,
                                      log_dir='./log',
                                      name=None):
    """
    Reads the image data and computes the FID mean and cov statistics
    for two groups of real images, as split by
    ``get_dataset_images_with_attr`` for the given attribute.

    If ``stats_file`` already exists, all four statistics are loaded from it
    and no images are read. Otherwise they are computed and saved there.

    Args:
        attr: Attribute forwarded to ``get_dataset_images_with_attr``; also
            part of the default stats file name.
        sess (Session): TensorFlow session to use.
        batch_size (int): The batch size to feedforward for inference.
        dataset (str/Dataset): Dataset to load.
        stats_file (str): The statistics file to load from if there is
            already one. When None, a name is derived from ``name``,
            ``dataset``, ``attr`` and ``seed`` under ``log_dir``. Note that
            ``np.savez`` appends ``.npz`` when the name lacks it, so a
            custom path should end with ``.npz`` to be found again later.
        seed (int): Only used as part of the default stats file name.
        verbose (bool): If True, prints progress of computation.
        log_dir (str): Directory where feature statistics can be stored.
        name (str): Only used as a prefix in the default stats file name.

    Returns:
        ndarray: Mean features of the first (attribute) image group.
        ndarray: Covariance of features of the first (attribute) image group.
        ndarray: Mean features of the second (not-attribute) image group.
        ndarray: Covariance of features of the second (not-attribute) group.
    """
    # Create custom stats file name
    if stats_file is None:
        stats_dir = os.path.join(log_dir, 'metrics', 'fid', 'statistics')
        # exist_ok avoids a race between checking for and creating the dir
        os.makedirs(stats_dir, exist_ok=True)

        stats_file = os.path.join(
            stats_dir,
            "fid_stats_{}_{}_{}_run_{}.npz".format(name, dataset, attr, seed))

    if stats_file and os.path.exists(stats_file):
        print("INFO: Loading existing statistics for real images...")
        # Context manager closes the archive even if a key is missing
        with np.load(stats_file) as f:
            attr_m_real, attr_s_real = f['attr_mu'][:], f['attr_sigma'][:]
            not_attr_m_real = f['not_attr_mu'][:]
            not_attr_s_real = f['not_attr_sigma'][:]

    else:
        # Obtain the numpy format data
        print("INFO: Obtaining images...")
        attr_images, not_attr_images = get_dataset_images_with_attr(dataset,
                                                                    attr=attr)

        # Compute the mean and cov
        print("INFO: Computing statistics for real images with attribute...")
        attr_m_real, attr_s_real = fid_utils.calculate_activation_statistics(
            images=attr_images,
            sess=sess,
            batch_size=batch_size,
            verbose=verbose)
        not_attr_m_real, not_attr_s_real = fid_utils.calculate_activation_statistics(
            images=not_attr_images,
            sess=sess,
            batch_size=batch_size,
            verbose=verbose)

        if not os.path.exists(stats_file):
            print("INFO: Saving statistics for real images...")
            np.savez(stats_file,
                     attr_mu=attr_m_real,
                     attr_sigma=attr_s_real,
                     not_attr_mu=not_attr_m_real,
                     not_attr_sigma=not_attr_s_real)

    return attr_m_real, attr_s_real, not_attr_m_real, not_attr_s_real