def test_cutmix_batch_success4(plot=False):
    """
    Test CutMixBatch on a dataset where OneHot returns a 2D vector.
    """
    logger.info("test_cutmix_batch_success4")

    # Baseline pipeline: decode and batch only, no CutMix applied.
    baseline = ds.CelebADataset(DATA_DIR3, shuffle=False)
    baseline = baseline.map(input_columns=["image"], operations=[vision.Decode()])
    baseline = baseline.batch(2, drop_remainder=True)

    images_original = None
    for idx, (image, _) in enumerate(baseline):
        images_original = image if idx == 0 else np.append(images_original, image, axis=0)

    # CutMix pipeline: decode, one-hot the attr column, batch, then CutMixBatch.
    mixed = ds.CelebADataset(dataset_dir=DATA_DIR3, shuffle=False)
    mixed = mixed.map(input_columns=["image"], operations=[vision.Decode()])
    mixed = mixed.map(input_columns=["attr"], operations=data_trans.OneHot(num_classes=100))
    mixed = mixed.batch(2, drop_remainder=True)
    mixed = mixed.map(input_columns=["image", "attr"],
                      operations=vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9))

    images_cutmix = None
    for idx, (image, _) in enumerate(mixed):
        images_cutmix = image if idx == 0 else np.append(images_cutmix, image, axis=0)

    if plot:
        visualize_list(images_original, images_cutmix)

    # Log the mean per-sample MSE between the mixed and original images.
    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        mse[i] = diff_mse(images_cutmix[i], images_original[i])
    logger.info("MSE= {}".format(str(np.mean(mse))))
def test_celeba_dataset_distribute():
    # NOTE(review): another test with this exact name appears later in the
    # file (asserting count == 2); at import time the later definition
    # shadows this one, so this version never runs under pytest.
    data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)
    count = 0
    for row in data.create_dict_iterator():
        logger.info("----------image--------")
        logger.info(row["image"])
        logger.info("----------attr--------")
        logger.info(row["attr"])
        count += 1
    assert count == 1
def test_celeba_dataset_exception_file_path():
    """Map pipelines containing a raising PyFunc must fail with RuntimeError."""

    def raise_error(_):
        raise Exception("Error occur!")

    def run_and_expect_failure(ops_per_column):
        # ops_per_column: list of (operation, column) applied in order.
        try:
            data = ds.CelebADataset(DATA_DIR, shuffle=False)
            for ops, column in ops_per_column:
                data = data.map(operations=ops, input_columns=[column],
                                num_parallel_workers=1)
            for _ in data.create_dict_iterator():
                pass
            assert False
        except RuntimeError as e:
            assert "map operation: [PyFunc] failed. The corresponding data files" in str(
                e)

    # Failure directly on the image column.
    run_and_expect_failure([(raise_error, "image")])
    # Failure after a successful Decode on the image column.
    run_and_expect_failure([(vision.Decode(), "image"), (raise_error, "image")])
    # Failure on the attr column.
    run_and_expect_failure([(raise_error, "attr")])
def test_celeba_sampler_exception():
    """
    Test CelebA with bad sampler input
    """
    logger.info("Test CelebA with bad sampler input")
    # NOTE(review): a later definition with this same name shadows this one.
    try:
        bad = ds.CelebADataset(DATA_DIR, sampler="")
        for _ in bad.create_dict_iterator():
            pass
        assert False
    except TypeError as e:
        assert "Argument" in str(e)
def test_celeba_sampler_exception():
    """
    Test CelebA with bad sampler input
    """
    logger.info("Test CelebA with bad sampler input")
    try:
        bad = ds.CelebADataset(DATA_DIR, sampler="")
        for _ in bad.create_dict_iterator():
            pass
        assert False
    except TypeError as e:
        # The error message must name the offending sampler type.
        assert "Unsupported sampler object of type (<class 'str'>)" in str(e)
def test_celeba_get_dataset_size():
    """Check get_dataset_size() for the default usage and each split."""
    # NOTE(review): a later definition with this same name (asserting size 2)
    # shadows this one at import time.
    full = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
    assert full.get_dataset_size() == 4

    for usage, expected in (("train", 2), ("valid", 1), ("test", 1)):
        split = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage=usage)
        assert split.get_dataset_size() == expected
def test_celeba_dataset_distribute():
    """ Test CelebA dataset with distributed options """
    logger.info("Test CelebA with sharding")
    data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)
    count = 0
    for row in data.create_dict_iterator(num_epochs=1):
        logger.info("----------image--------")
        logger.info(row["image"])
        logger.info("----------attr--------")
        logger.info(row["attr"])
        count += 1
    assert count == 2
def test_celeba_padded():
    """CelebA concatenated with a PaddedDataset, sharded and repeated."""
    celeba = ds.CelebADataset("../data/dataset/testCelebAData/")
    # One all-zero padded sample appended after the real data.
    padded = ds.PaddedDataset(
        [{'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}])
    combined = celeba + padded
    combined.use_sampler(
        ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False,
                              num_samples=None))
    combined = combined.repeat(2)
    count = sum(
        1 for _ in combined.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert count == 4
def test_celeba_dataset_ext():
    """
    Test CelebADataset with the ``extensions`` filter: only the single
    ".JPEG" sample should be read, and its attr vector must match the
    expected labels.
    """
    ext = [".JPEG"]
    data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
    # Bug fix: a stray trailing comma previously made expect_labels a 1-tuple
    # wrapping a flat label list; it is meant to be a list of per-sample
    # label lists (same shape as in test_celeba_dataset_label).
    expect_labels = [
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
         0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1],
    ]
    count = 0
    for item in data.create_dict_iterator():
        logger.info("----------image--------")
        logger.info(item["image"])
        logger.info("----------attr--------")
        logger.info(item["attr"])
        for index in range(len(expect_labels[count])):
            assert item["attr"][index] == expect_labels[count][index]
        count = count + 1
    # Only one file carries the .JPEG extension.
    assert count == 1
def test_celeba_dataset_op():
    """Apply CenterCrop and bilinear Resize to CelebA images, repeated twice."""
    data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
    data = data.repeat(2)
    # Center-crop to 80x80, then bilinear-resize down to 24x24.
    data = data.map(input_columns=["image"],
                    operations=vision.CenterCrop((80, 80)))
    data = data.map(input_columns=["image"],
                    operations=vision.Resize((24, 24), Inter.LINEAR))
    count = 0
    for item in data.create_dict_iterator():
        logger.info("----------image--------")
        logger.info(item["image"])
        count += 1
    assert count == 4
def test_celeba_dataset_label():
    """Verify the attr label vectors of both CelebA samples, in file order."""
    data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
    expect_labels = [
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
         0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
         0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
    ]
    count = 0
    for item in data.create_dict_iterator():
        logger.info("----------image--------")
        logger.info(item["image"])
        logger.info("----------attr--------")
        logger.info(item["attr"])
        for index, expected in enumerate(expect_labels[count]):
            assert item["attr"][index] == expected
        count += 1
    assert count == 2
def test_get_column_name_celeba():
    """CelebADataset must expose exactly the image and attr columns."""
    dataset = ds.CelebADataset(CELEBA_DIR)
    expected_columns = ["image", "attr"]
    assert dataset.get_col_names() == expected_columns
def test_celeba_get_dataset_size():
    # NOTE(review): a test with this same name appears earlier in the file
    # and asserts size == 4; Python keeps only this later definition, so the
    # earlier one never runs. The two should be reconciled or renamed.
    dataset = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
    assert dataset.get_dataset_size() == 2