示例#1
0
def test_cut_out_op_multicut():
    """
    Test Cutout

    Builds two pipelines over the same TFRecord "image" column and walks
    them in lockstep, logging the shape and dtype of each image pair:

    * pipeline 1: f.Decode -> f.ToTensor -> f.RandomErasing(value='random'),
      wrapped in f.ComposeOp
    * pipeline 2: c.Decode -> c.CutOut(80, num_patches=10)

    No assertion is made on the image contents; only metadata is logged.
    """
    logger.info("test_cut_out")

    # Pipeline 1: `f` transforms composed into a single callable.
    py_data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    py_compose = f.ComposeOp(
        [f.Decode(), f.ToTensor(), f.RandomErasing(value='random')]
    )
    py_data = py_data.map(input_columns=["image"], operations=py_compose())

    # Pipeline 2: `c` transforms passed to map() as a plain list.
    c_data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    c_ops = [c.Decode(), c.CutOut(80, num_patches=10)]
    c_data = c_data.map(input_columns=["image"], operations=c_ops)

    num_iter = 0
    paired_rows = zip(py_data.create_dict_iterator(),
                      c_data.create_dict_iterator())
    for py_row, c_row in paired_rows:
        num_iter += 1

        # Pipeline 1 output: move axis 0 to the end and rescale by 255 into
        # uint8 (presumably CHW float -> HWC uint8; confirm against ToTensor).
        py_tensor = py_row["image"]
        image_1 = (py_tensor.transpose(1, 2, 0) * 255).astype(np.uint8)
        # C image doesn't require transpose
        image_2 = c_row["image"]

        logger.info("shape of image_1: {}".format(image_1.shape))
        logger.info("shape of image_2: {}".format(image_2.shape))

        logger.info("dtype of image_1: {}".format(image_1.dtype))
        logger.info("dtype of image_2: {}".format(image_2.dtype))
示例#2
0
def test_random_erasing_md5():
    """
    Test RandomErasing with md5 check

    Runs Decode -> ToTensor -> RandomErasing(value='random') over the
    unshuffled TFRecord "image" column and hands the resulting dataset to
    save_and_check_md5 together with the golden file name
    "random_erasing_01_result.npz".

    The seed is set to 5 and the number of parallel workers to 1 for the
    duration of the test. The values returned by the config_get_set_*
    helpers are treated as the previous settings and are written back in a
    ``finally`` clause, so a failing md5 check (or any other exception) does
    not leak the overridden configuration into subsequent tests.
    """
    logger.info("Test RandomErasing with md5 check")
    original_seed = config_get_set_seed(5)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # Generate dataset
        data = ds.TFRecordDataset(DATA_DIR,
                                  SCHEMA_DIR,
                                  columns_list=["image"],
                                  shuffle=False)
        transforms_1 = [
            vision.Decode(),
            vision.ToTensor(),
            vision.RandomErasing(value='random')
        ]
        transform_1 = vision.ComposeOp(transforms_1)
        data = data.map(input_columns=["image"], operations=transform_1())
        # Compare with expected md5 from images
        filename = "random_erasing_01_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore configuration even when the check above raises.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
def create_imagenet_dataset(imagenet_dir):
    """
    Build an image-folder dataset pipeline rooted at `imagenet_dir`.

    The "image" column is mapped through a composed transform made of
    Decode, RandomHorizontalFlip(0.5), ToTensor, Normalize and
    RandomErasing; the result is then shuffled with a buffer of 5 and
    repeated 3 times.

    Args:
        imagenet_dir: directory passed straight to de.ImageFolderDatasetV2.

    Returns:
        The dataset object produced by the final repeat(3) call.
    """
    dataset = de.ImageFolderDatasetV2(imagenet_dir)

    image_ops = [
        F.Decode(),
        F.RandomHorizontalFlip(0.5),
        F.ToTensor(),
        F.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        F.RandomErasing(),
    ]
    composed = F.ComposeOp(image_ops)

    mapped = dataset.map(input_columns="image", operations=composed())
    return mapped.shuffle(buffer_size=5).repeat(3)
示例#4
0
def test_random_erasing_op(plot=False):
    """
    Test RandomErasing and Cutout

    Two unshuffled pipelines read the same TFRecord "image" column:
    one ends in vision.RandomErasing(value='random'), the other in
    vision.Cutout(80). Images are compared pairwise with diff_mse; shapes
    and dtypes are logged.

    Args:
        plot: when truthy, each image pair and its mse are passed to
            visualize_image.
    """
    logger.info("test_random_erasing")

    def to_uint8_image(tensor):
        # Move axis 0 to the end, scale by 255 and cast to uint8
        # (presumably CHW float -> HWC uint8; confirm against ToTensor).
        return (tensor.transpose(1, 2, 0) * 255).astype(np.uint8)

    # First dataset: RandomErasing
    erasing_data = ds.TFRecordDataset(DATA_DIR,
                                      SCHEMA_DIR,
                                      columns_list=["image"],
                                      shuffle=False)
    erasing_compose = vision.ComposeOp([
        vision.Decode(),
        vision.ToTensor(),
        vision.RandomErasing(value='random'),
    ])
    erasing_data = erasing_data.map(input_columns=["image"],
                                    operations=erasing_compose())

    # Second dataset: Cutout
    cutout_data = ds.TFRecordDataset(DATA_DIR,
                                     SCHEMA_DIR,
                                     columns_list=["image"],
                                     shuffle=False)
    cutout_compose = vision.ComposeOp([
        vision.Decode(),
        vision.ToTensor(),
        vision.Cutout(80),
    ])
    cutout_data = cutout_data.map(input_columns=["image"],
                                  operations=cutout_compose())

    num_iter = 0
    paired_rows = zip(erasing_data.create_dict_iterator(),
                      cutout_data.create_dict_iterator())
    for erasing_row, cutout_row in paired_rows:
        num_iter += 1
        image_1 = to_uint8_image(erasing_row["image"])
        image_2 = to_uint8_image(cutout_row["image"])

        logger.info("shape of image_1: {}".format(image_1.shape))
        logger.info("shape of image_2: {}".format(image_2.shape))

        logger.info("dtype of image_1: {}".format(image_1.dtype))
        logger.info("dtype of image_2: {}".format(image_2.dtype))

        mse = diff_mse(image_1, image_2)
        if plot:
            visualize_image(image_1, image_2, mse)