Beispiel #1
0
def test_cut_out_op_multicut():
    """
    Test CutOut with multiple patches.

    Pipeline 1 (python ops): Decode -> ToTensor -> RandomErasing(value='random').
    Pipeline 2 (c ops): Decode -> CutOut(80, num_patches=10).

    Both datasets are read with shuffle=False so the zipped items are paired
    deterministically. The test checks that the pipelines yield at least one
    image and logs the shape/dtype of every pair.
    """
    logger.info("test_cut_out_op_multicut")

    # First dataset: python transforms
    data1 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)

    transforms_1 = [f.Decode(), f.ToTensor(), f.RandomErasing(value='random')]
    transform_1 = f.ComposeOp(transforms_1)
    data1 = data1.map(input_columns=["image"], operations=transform_1())

    # Second dataset: c transforms, 10 cut-out patches of size 80
    data2 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)
    decode_op = c.Decode()
    cut_out_op = c.CutOut(80, num_patches=10)

    transforms_2 = [decode_op, cut_out_op]

    data2 = data2.map(input_columns=["image"], operations=transforms_2)

    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(),
                            data2.create_dict_iterator()):
        num_iter += 1
        # Python output is transposed (1, 2, 0) and rescaled by 255 to uint8;
        # assumes ToTensor yields CHW floats in [0, 1] -- TODO confirm.
        image_1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        # C image doesn't require transpose
        image_2 = item2["image"]

        logger.info("shape of image_1: {}".format(image_1.shape))
        logger.info("shape of image_2: {}".format(image_2.shape))

        logger.info("dtype of image_1: {}".format(image_1.dtype))
        logger.info("dtype of image_2: {}".format(image_2.dtype))

    # Guard against a vacuous pass when the pipelines produce no data.
    assert num_iter > 0, "CutOut pipelines produced no images"
Beispiel #2
0
def test_cut_out_md5():
    """
    Test Cutout with md5 check.

    Runs the c CutOut op and the python Cutout op on unshuffled data and
    compares each result against its golden md5 file. The global seed and
    the number of parallel workers are overridden for the duration of the
    test and restored in a ``finally`` block. This means a failing md5 check
    cannot leak the overridden config into later tests.
    """
    logger.info("test_cut_out_md5")
    # config_get_set_* return the previous value, which is restored below.
    original_seed = config_get_set_seed(2)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # First dataset: c ops
        data1 = ds.TFRecordDataset(DATA_DIR,
                                   SCHEMA_DIR,
                                   columns_list=["image"],
                                   shuffle=False)
        decode_op = c.Decode()
        cut_out_op = c.CutOut(100)
        data1 = data1.map(input_columns=["image"], operations=decode_op)
        data1 = data1.map(input_columns=["image"], operations=cut_out_op)

        # Second dataset: python ops
        data2 = ds.TFRecordDataset(DATA_DIR,
                                   SCHEMA_DIR,
                                   columns_list=["image"],
                                   shuffle=False)
        transforms = [f.Decode(), f.ToTensor(), f.Cutout(100)]
        transform = f.ComposeOp(transforms)
        data2 = data2.map(input_columns=["image"], operations=transform())

        # Compare with expected md5 from images
        filename1 = "cut_out_01_c_result.npz"
        save_and_check_md5(data1, filename1, generate_golden=GENERATE_GOLDEN)
        filename2 = "cut_out_01_py_result.npz"
        save_and_check_md5(data2, filename2, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config even when an md5 check fails.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
Beispiel #3
0
def test_cut_out_comp(plot=False):
    """
    Test Cutout with c++ and python op comparison.

    Runs the same unshuffled TFRecord data through a python pipeline
    (Decode -> ToTensor -> Cutout(200)) and a c++ pipeline
    (Decode -> CutOut(200)). Pixel values are not compared. The test checks
    that the pipelines yield at least one image and that each paired output
    has the same shape.

    Args:
        plot (bool): if True, visualize the python and c++ results side by side.
    """
    logger.info("test_cut_out_comp")

    # First dataset: python transforms
    data1 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)

    transforms_1 = [f.Decode(), f.ToTensor(), f.Cutout(200)]
    transform_1 = f.ComposeOp(transforms_1)
    data1 = data1.map(input_columns=["image"], operations=transform_1())

    # Second dataset: c++ transforms
    data2 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)

    transforms_2 = [c.Decode(), c.CutOut(200)]

    data2 = data2.map(input_columns=["image"], operations=transforms_2)

    num_iter = 0
    image_list_1, image_list_2 = [], []
    for item1, item2 in zip(data1.create_dict_iterator(),
                            data2.create_dict_iterator()):
        num_iter += 1
        # Python output is transposed (1, 2, 0) and rescaled by 255 to uint8;
        # assumes ToTensor yields CHW floats in [0, 1] -- TODO confirm.
        image_1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        # C image doesn't require transpose
        image_2 = item2["image"]
        image_list_1.append(image_1)
        image_list_2.append(image_2)

        logger.info("shape of image_1: {}".format(image_1.shape))
        logger.info("shape of image_2: {}".format(image_2.shape))

        logger.info("dtype of image_1: {}".format(image_1.dtype))
        logger.info("dtype of image_2: {}".format(image_2.dtype))

        # Both datasets are unshuffled, so item pairs come from the same record.
        assert image_1.shape == image_2.shape, \
            "python/c++ Cutout output shapes differ: {} vs {}".format(
                image_1.shape, image_2.shape)

    # Guard against a vacuous pass when the pipelines produce no data.
    assert num_iter > 0, "Cutout pipelines produced no images"

    if plot:
        visualize_list(image_list_1, image_list_2, visualize_mode=2)