def test_cutmix_batch_success4(plot=False):
    """
    Test CutMixBatch on a dataset where OneHot returns a 2D vector

    Builds a plain decode+batch CelebA pipeline and a decode+OneHot+batch+
    CutMixBatch pipeline over the same data, collects the batched images of
    both and logs the mean MSE between them.

    Args:
        plot (bool): when True, visualize the original and CutMix images.
    """
    logger.info("test_cutmix_batch_success4")

    def _collect_images(dataset):
        # The first column of every row is the batched image. Join all batches
        # with a single concatenate: calling np.append inside the loop copies
        # the whole accumulated array on every iteration (quadratic).
        return np.concatenate([image for image, _ in dataset], axis=0)

    ds_original = ds.CelebADataset(DATA_DIR3, shuffle=False)
    decode_op = vision.Decode()
    ds_original = ds_original.map(input_columns=["image"], operations=[decode_op])
    ds_original = ds_original.batch(2, drop_remainder=True)
    images_original = _collect_images(ds_original)

    # CutMix Images
    data1 = ds.CelebADataset(dataset_dir=DATA_DIR3, shuffle=False)

    decode_op = vision.Decode()
    data1 = data1.map(input_columns=["image"], operations=[decode_op])

    one_hot_op = data_trans.OneHot(num_classes=100)
    data1 = data1.map(input_columns=["attr"], operations=one_hot_op)

    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9)
    data1 = data1.batch(2, drop_remainder=True)
    data1 = data1.map(input_columns=["image", "attr"], operations=cutmix_batch_op)
    images_cutmix = _collect_images(data1)

    if plot:
        visualize_list(images_original, images_cutmix)

    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        mse[i] = diff_mse(images_cutmix[i], images_original[i])
    logger.info("MSE= {}".format(str(np.mean(mse))))
def test_celeba_dataset_distribute():
    """
    Read shard 0 of 2 of the decoded CelebA data and expect exactly one row.
    """
    sharded = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)
    columns_to_log = (("----------image--------", "image"),
                      ("----------attr--------", "attr"))
    rows_read = 0
    for row in sharded.create_dict_iterator():
        for banner, column in columns_to_log:
            logger.info(banner)
            logger.info(row[column])
        rows_read += 1
    assert rows_read == 1
# Example #3
def test_celeba_dataset_exception_file_path():
    """
    An exception raised by a Python map function must surface as a
    RuntimeError that mentions the corresponding data files. Checked for a
    map on the raw "image" column, on the decoded "image" column, and on
    the "attr" column.
    """
    def exception_func(item):
        raise Exception("Error occur!")

    def fail_on_image(pipeline):
        return pipeline.map(operations=exception_func,
                            input_columns=["image"],
                            num_parallel_workers=1)

    def decode_then_fail_on_image(pipeline):
        pipeline = pipeline.map(operations=vision.Decode(),
                                input_columns=["image"],
                                num_parallel_workers=1)
        return fail_on_image(pipeline)

    def fail_on_attr(pipeline):
        return pipeline.map(operations=exception_func,
                            input_columns=["attr"],
                            num_parallel_workers=1)

    expected_fragment = "map operation: [PyFunc] failed. The corresponding data files"
    for build in (fail_on_image, decode_then_fail_on_image, fail_on_attr):
        try:
            data = build(ds.CelebADataset(DATA_DIR, shuffle=False))
            for _ in data.create_dict_iterator():
                pass
            assert False
        except RuntimeError as e:
            assert expected_fragment in str(e)
def test_celeba_sampler_exception():
    """
    Passing an empty string as the sampler must raise a TypeError whose
    message contains "Argument".
    """
    logger.info("Test CelebA with bad sampler input")
    error_text = None
    try:
        bad_data = ds.CelebADataset(DATA_DIR, sampler="")
        for _ in bad_data.create_dict_iterator():
            pass
    except TypeError as err:
        error_text = str(err)
    assert error_text is not None
    assert "Argument" in error_text
# Example #5
def test_celeba_sampler_exception():
    """
    Passing an empty string as the sampler must raise a TypeError that names
    the unsupported sampler type (str).
    """
    logger.info("Test CelebA with bad sampler input")
    error_text = None
    try:
        bad_data = ds.CelebADataset(DATA_DIR, sampler="")
        for _ in bad_data.create_dict_iterator():
            pass
    except TypeError as err:
        error_text = str(err)
    assert error_text is not None
    assert "Unsupported sampler object of type (<class 'str'>)" in error_text
# Example #6
def test_celeba_get_dataset_size():
    """
    Check get_dataset_size() without a usage argument and for each of the
    "train", "valid" and "test" usages.
    """
    cases = (
        ({}, 4),
        ({"usage": "train"}, 2),
        ({"usage": "valid"}, 1),
        ({"usage": "test"}, 1),
    )
    for extra_kwargs, expected_size in cases:
        dataset = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, **extra_kwargs)
        assert dataset.get_dataset_size() == expected_size
def test_celeba_dataset_distribute():
    """
    Test CelebA dataset with distributed options

    Reads shard 0 of 2 for a single epoch and expects two rows.
    """
    logger.info("Test CelebA with sharding")
    sharded = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)
    rows_read = 0
    for row in sharded.create_dict_iterator(num_epochs=1):
        for banner, column in (("----------image--------", "image"),
                               ("----------attr--------", "attr")):
            logger.info(banner)
            logger.info(row[column])
        rows_read += 1
    assert rows_read == 2
def test_celeba_padded():
    """
    Concatenate CelebA with one padded sample, read shard 1 of 2 without
    shuffling, repeat twice and expect four rows in a single epoch.
    """
    celeba = ds.CelebADataset("../data/dataset/testCelebAData/")

    pad_row = {'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}
    combined = celeba + ds.PaddedDataset([pad_row])
    shard_sampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
    combined.use_sampler(shard_sampler)
    combined = combined.repeat(2)

    total = sum(1 for _ in combined.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert total == 4
# Example #9
def test_celeba_dataset_ext():
    """
    Read CelebA restricted to ".JPEG" files and compare the "attr" values of
    each row against the expected labels.
    """
    ext = [".JPEG"]
    data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
    # One list of expected attribute values per row. Written explicitly as a
    # list of lists: a trailing comma after a single list literal would
    # instead produce a 1-tuple that only works by accident.
    expect_labels = [
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0,
         1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1],
    ]
    count = 0
    for item in data.create_dict_iterator():
        logger.info("----------image--------")
        logger.info(item["image"])
        logger.info("----------attr--------")
        logger.info(item["attr"])
        for index, expected_value in enumerate(expect_labels[count]):
            assert item["attr"][index] == expected_value
        count = count + 1
    # Every expected row must have been read (one row here).
    assert count == len(expect_labels)
def test_celeba_dataset_op():
    """
    Repeat CelebA twice, center-crop every image to (80, 80), resize it to
    (24, 24) with Inter.LINEAR, and expect four rows.
    """
    pipeline = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
    pipeline = pipeline.repeat(2)
    # Crop first, then resize; Inter.LINEAR selects bilinear interpolation.
    image_ops = (vision.CenterCrop((80, 80)),
                 vision.Resize((24, 24), Inter.LINEAR))
    for image_op in image_ops:
        pipeline = pipeline.map(input_columns=["image"], operations=image_op)

    rows_seen = 0
    for row in pipeline.create_dict_iterator():
        logger.info("----------image--------")
        logger.info(row["image"])
        rows_seen += 1
    assert rows_seen == 4
def test_celeba_dataset_label():
    """
    Read CelebA in order (no shuffle) and compare each row's "attr" values
    with the expected labels; exactly two rows are expected.
    """
    data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
    expect_labels = [
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0,
         1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
         0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
         0, 0, 0, 1],
    ]
    rows_seen = 0
    for item in data.create_dict_iterator():
        logger.info("----------image--------")
        logger.info(item["image"])
        logger.info("----------attr--------")
        logger.info(item["attr"])
        expected_row = expect_labels[rows_seen]
        for position, expected_value in enumerate(expected_row):
            assert item["attr"][position] == expected_value
        rows_seen += 1
    assert rows_seen == 2
# Example #12
def test_get_column_name_celeba():
    """CelebADataset exposes exactly the columns "image" and "attr", in that order."""
    expected_columns = ["image", "attr"]
    assert ds.CelebADataset(CELEBA_DIR).get_col_names() == expected_columns
# Example #13
def test_celeba_get_dataset_size():
    """Without a usage argument, the decoded, unshuffled CelebA data reports size 2."""
    expected_size = 2
    assert ds.CelebADataset(DATA_DIR, shuffle=False, decode=True).get_dataset_size() == expected_size