Exemplo n.º 1
0
    def test_load_celeba_dataset(self):
        """Loads CelebA at sizes 64 and 128 and checks the first image's shape."""
        for res in (64, 128):
            celeba = data_utils.load_celeba_dataset(root=self.dataset_dir,
                                                    size=res,
                                                    download=False)

            # Each dataset item is indexed again at 0 to reach the image.
            first_image = celeba[0][0]
            expected_shape = (3, res, res)

            assert first_image.shape == expected_shape
def _subsample_index(index, num_samples):
    """
    Caps a collection of image indices at num_samples entries.

    Args:
        index (array-like): Candidate image indices.
        num_samples (int): Maximum number of indices to keep. A falsy value
            (None or 0) disables subsampling.

    Returns:
        array-like: index unchanged when subsampling is disabled or index
            already has at most num_samples entries, otherwise num_samples
            entries drawn without replacement via np.random.choice.
    """
    if num_samples and len(index) > num_samples:
        return np.random.choice(index, size=num_samples, replace=False)
    return index


def get_celeba_images_with_attr(attr,
                                root='./dataset',
                                size=128,
                                num_samples=None,
                                **kwargs):
    """
    Loads CelebA images split by whether they have a given attribute.

    Args:
        attr (str): The name of the attribute, as understood by
            get_celeba_index_with_attr.
        root (str): The root directory where all datasets are stored.
        size (int): Size of image to resize to.
        num_samples (int): Optional cap on the number of images per group.
            When a group has more images than this, num_samples of them are
            randomly sampled without replacement. None (or 0) keeps all.
        **kwargs: Extra keyword arguments forwarded to
            data_utils.load_celeba_dataset.

    Returns:
        tuple: (attr_images, not_attr_images), the results of
            get_index_images for the indices with and without the attribute.
    """
    dataset = data_utils.load_celeba_dataset(
        root=root,
        size=size,
        transform_data=True,
        convert_tensor=False,  # Prevents normalization.
        **kwargs)

    dataset_dir = os.path.join(root, 'celeba')
    attr_index, not_attr_index = get_celeba_index_with_attr(dataset_dir, attr)

    # Counts before any subsampling.
    print(f'Total number of images with attribute {attr}: {len(attr_index)}')
    print(f'Total number of images without attribute {attr}: '
          f'{len(not_attr_index)}')

    # Sample the "with" group first, then the "without" group, so the order
    # of random draws matches the original implementation.
    attr_index = _subsample_index(attr_index, num_samples)
    not_attr_index = _subsample_index(not_attr_index, num_samples)

    attr_images = get_index_images(dataset, attr_index)
    not_attr_images = get_index_images(dataset, not_attr_index)

    # Counts of the images actually returned.
    print(f'Number of images with attribute {attr}: {len(attr_images)}')
    print(f'Number of images without attribute {attr}: {len(not_attr_images)}')

    return attr_images, not_attr_images
Exemplo n.º 3
0
def get_celeba_images(num_samples, root='./datasets', size=128, **kwargs):
    """
    Draws a random sample of images from the CelebA dataset.

    Args:
        num_samples (int): How many images to randomly sample.
        root (str): Root directory where all datasets are stored.
        size (int): Size of image to resize to.
        **kwargs: Extra keyword arguments forwarded to
            data_utils.load_celeba_dataset.

    Returns:
        The batch of num_samples images produced by get_random_images
        (np array form, per the original documentation).
    """
    # convert_tensor=False prevents normalization of the loaded images.
    celeba = data_utils.load_celeba_dataset(size=size,
                                            root=root,
                                            transform_data=True,
                                            convert_tensor=False,
                                            **kwargs)

    return get_random_images(celeba, num_samples)
def get_celeba_images_with_index(index, root='./dataset', size=128, **kwargs):
    """
    Fetches the CelebA images located at the given indices.

    Args:
        index (ndarray): Indices of the images to fetch.
        root (str): Root directory where all datasets are stored.
        size (int): Size of image to resize to.
        **kwargs: Extra keyword arguments forwarded to
            data_utils.load_celeba_dataset.

    Returns:
        ndarray: Batch of len(index) images in np array form.
    """
    # convert_tensor=False prevents normalization of the loaded images.
    celeba = data_utils.load_celeba_dataset(size=size,
                                            root=root,
                                            transform_data=True,
                                            convert_tensor=False,
                                            **kwargs)

    return get_index_images(celeba, index)
def get_celeba_with_attr(attr,
                         root='./dataset',
                         split='train',
                         size=128,
                         **kwargs):
    """
    Loads CelebA with a single binary attribute as the target.

    Args:
        attr (str): The name of attribute (one of the 40 names listed in
            the table below).
        root (str): The root directory where all datasets are stored.
        split (str): The split of data to use.
        size (int): Size of image to resize to.
        **kwargs: Extra keyword arguments forwarded to
            data_utils.load_celeba_dataset.

    Returns:
        The dataset returned by data_utils.load_celeba_dataset, configured
        with a target transform that yields 1 when the selected attribute
        entry equals 1 and 0 otherwise.

    Raises:
        KeyError: If attr is not a known CelebA attribute name.
    """
    # Position of each attribute within the per-image attribute vector.
    celeba_attr = {
        '5_o_Clock_Shadow': 0,
        'Arched_Eyebrows': 1,
        'Attractive': 2,
        'Bags_Under_Eyes': 3,
        'Bald': 4,
        'Bangs': 5,
        'Big_Lips': 6,
        'Big_Nose': 7,
        'Black_Hair': 8,
        'Blond_Hair': 9,
        'Blurry': 10,
        'Brown_Hair': 11,
        'Bushy_Eyebrows': 12,
        'Chubby': 13,
        'Double_Chin': 14,
        'Eyeglasses': 15,
        'Goatee': 16,
        'Gray_Hair': 17,
        'Heavy_Makeup': 18,
        'High_Cheekbones': 19,
        'Male': 20,
        'Mouth_Slightly_Open': 21,
        'Mustache': 22,
        'Narrow_Eyes': 23,
        'No_Beard': 24,
        'Oval_Face': 25,
        'Pale_Skin': 26,
        'Pointy_Nose': 27,
        'Receding_Hairline': 28,
        'Rosy_Cheeks': 29,
        'Sideburns': 30,
        'Smiling': 31,
        'Straight_Hair': 32,
        'Wavy_Hair': 33,
        'Wearing_Earrings': 34,
        'Wearing_Hat': 35,
        'Wearing_Lipstick': 36,
        'Wearing_Necklace': 37,
        'Wearing_Necktie': 38,
        'Young': 39,
    }

    # Resolve the attribute up front so that a misspelled name fails here,
    # before the dataset is loaded, instead of on the first sample access
    # inside the target transform.
    try:
        attr_idx = celeba_attr[attr]
    except KeyError:
        raise KeyError(
            f'Unknown CelebA attribute {attr!r}; '
            f'expected one of {sorted(celeba_attr)}') from None

    dataset = data_utils.load_celeba_dataset(
        root=root,
        split=split,
        size=size,
        transform_data=True,
        convert_tensor=True,
        target_transform=transforms.Lambda(
            lambda a: 1 if a[attr_idx] == 1 else 0),
        **kwargs)

    return dataset