Code Example #1
def test_py_vision_with_c_transforms():
    """
    Test combining Python vision operations with C++ transforms operations
    """

    ds.config.set_seed(0)

    def test_config(op_list):
        data_dir = "../data/dataset/testImageNetData/train/"
        data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
        data1 = data1.map(operations=op_list, input_columns=["image"])
        transformed_images = []

        for item in data1.create_dict_iterator(num_epochs=1,
                                               output_numpy=True):
            transformed_images.append(item["image"])
        return transformed_images

    # Test with Mask Op
    output_arr = test_config([
        py_vision.Decode(),
        py_vision.CenterCrop(2), np.array,
        c_transforms.Mask(c_transforms.Relational.GE, 100)
    ])

    exp_arr = [
        np.array([[[True, False, False], [True, False, False]],
                  [[True, False, False], [True, False, False]]]),
        np.array([[[True, False, False], [True, False, False]],
                  [[True, False, False], [True, False, False]]])
    ]

    for exp_a, output in zip(exp_arr, output_arr):
        np.testing.assert_array_equal(exp_a, output)

    # Test with Fill Op
    output_arr = test_config([
        py_vision.Decode(),
        py_vision.CenterCrop(4), np.array,
        c_transforms.Fill(10)
    ])

    exp_arr = [np.ones((4, 4, 3)) * 10] * 2
    for exp_a, output in zip(exp_arr, output_arr):
        np.testing.assert_array_equal(exp_a, output)

    # Test with Concatenate Op, which will raise an error since ConcatenateOp only supports rank 1 tensors.
    with pytest.raises(RuntimeError) as error_info:
        test_config([
            py_vision.Decode(),
            py_vision.CenterCrop(2), np.array,
            c_transforms.Concatenate(0)
        ])
    assert "Only 1D tensors supported" in str(error_info.value)
Code Example #2
def test_center_crop_comp(height=375, width=375, plot=False):
    """
    Compare CenterCrop results between the Python and C++ image augmentation implementations
    """
    logger.info("Test CenterCrop")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = vision.Decode()
    center_crop_op = vision.CenterCrop([height, width])
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=center_crop_op, input_columns=["image"])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms = [
        py_vision.Decode(),
        py_vision.CenterCrop([height, width]),
        py_vision.ToTensor()
    ]
    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
    data2 = data2.map(operations=transform, input_columns=["image"])

    image_c_cropped = []
    image_py_cropped = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        c_image = item1["image"]
        py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        # Note: The images aren't exactly the same due to rounding error
        assert diff_mse(py_image, c_image) < 0.001
        image_c_cropped.append(c_image.copy())
        image_py_cropped.append(py_image.copy())
    if plot:
        visualize_list(image_c_cropped, image_py_cropped, visualize_mode=2)
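diff_mse is a shared test helper that is not shown on this page; one plausible definition consistent with its usage here (mean squared error between two equally shaped images) is:

def diff_mse(img1, img2):
    # Assumed helper: mean squared error between two images of the same shape.
    return np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)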
Code Example #3
def test_serdes_pyvision(remove_json_files=True):
    """
    Test serdes on py_transform pipeline.
    """
    data_dir = [
        "../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
    ]
    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    data1 = ds.TFRecordDataset(data_dir,
                               schema_file,
                               columns_list=["image", "label"],
                               shuffle=False)
    transforms = [
        py_vision.Decode(),
        py_vision.CenterCrop([32, 32]),
        py_vision.ToTensor()
    ]
    data1 = data1.map(operations=py.Compose(transforms),
                      input_columns=["image"])
    # Deserialization of Python functions currently fails under pickle, so this
    # testcase is written as an expected-exception case.
    try:
        util_check_serialize_deserialize_file(data1,
                                              "pyvision_dataset_pipeline",
                                              remove_json_files)
        assert False
    except NotImplementedError as e:
        assert "python function is not yet supported" in str(e)
Code Example #4
def test_linear_transformation_exception_02():
    """
    Test LinearTransformation op: mean_vector is not provided
    Expected to raise TypeError
    """
    logger.info("test_linear_transformation_exception_02")

    # Initialize parameters
    height = 50
    width = 50
    dim = 3 * height * width
    transformation_matrix = np.ones([dim, dim])

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)
    try:
        transforms = [
            py_vision.Decode(),
            py_vision.CenterCrop([height, width]),
            py_vision.ToTensor(),
            py_vision.LinearTransformation(transformation_matrix, None)
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(
            transforms)
        data1 = data1.map(operations=transform, input_columns=["image"])
    except TypeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Argument mean_vector with value None is not of type [<class 'numpy.ndarray'>]" in str(
            e)
Code Example #5
def test_random_apply_md5():
    """
    Test RandomApply op with md5 check
    """
    logger.info("test_random_apply_md5")
    original_seed = config_get_set_seed(10)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    # define map operations
    transforms_list = [py_vision.CenterCrop(64), py_vision.RandomRotation(30)]
    transforms = [
        py_vision.Decode(),
        # Note: using default value "prob=0.5"
        py_transforms.RandomApply(transforms_list),
        py_vision.ToTensor()
    ]
    transform = py_transforms.Compose(transforms)

    #  Generate dataset
    data = ds.TFRecordDataset(DATA_DIR,
                              SCHEMA_DIR,
                              columns_list=["image"],
                              shuffle=False)
    data = data.map(operations=transform, input_columns=["image"])

    # check results with md5 comparison
    filename = "random_apply_01_result.npz"
    save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)

    # Restore configuration
    ds.config.set_seed(original_seed)
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
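config_get_set_seed and config_get_set_num_parallel_workers are test utilities that presumably swap in a new configuration value and hand back the old one for later restoration, roughly:

def config_get_set_seed(seed_new):
    # Sketch: set a new seed and return the original so the caller can restore it.
    seed_original = ds.config.get_seed()
    ds.config.set_seed(seed_new)
    return seed_original

def config_get_set_num_parallel_workers(num_new):
    # Sketch: same pattern for the parallel-worker count.
    num_original = ds.config.get_num_parallel_workers()
    ds.config.set_num_parallel_workers(num_new)
    return num_original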
Code Example #6
def test_linear_transformation_md5():
    """
    Test LinearTransformation op: valid params (transformation_matrix, mean_vector)
    Expected to pass
    """
    logger.info("test_linear_transformation_md5")

    # Initialize parameters
    height = 50
    width = 50
    dim = 3 * height * width
    transformation_matrix = np.ones([dim, dim])
    mean_vector = np.zeros(dim)

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)
    transforms = [
        py_vision.Decode(),
        py_vision.CenterCrop([height, width]),
        py_vision.ToTensor(),
        py_vision.LinearTransformation(transformation_matrix, mean_vector)
    ]
    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
    data1 = data1.map(operations=transform, input_columns=["image"])

    # Compare with expected md5 from images
    filename = "linear_transformation_01_result.npz"
    save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
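save_and_check_md5 is likewise an unshown test utility; its contract appears to be hashing the pipeline output and comparing against a stored golden digest. A hedged sketch (the golden-file format is an assumption):

import hashlib

def save_and_check_md5(data, filename, generate_golden=False):
    # Sketch only: hash every output tensor of the pipeline, in order.
    md5 = hashlib.md5()
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        for column in item.values():
            md5.update(column.tobytes())
    digest = md5.hexdigest()
    if generate_golden:
        np.savez(filename, md5=digest)                   # regenerate the golden file
    else:
        assert digest == str(np.load(filename)["md5"])   # compare to the golden digest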
Code Example #7
def test_linear_transformation_exception_04():
    """
    Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix
    Expected to raise ValueError
    """
    logger.info("test_linear_transformation_exception_04")

    # Initialize parameters
    height = 50
    width = 50
    dim = 3 * height * width
    transformation_matrix = np.ones([dim, dim])
    mean_vector = np.zeros(dim - 1)

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)
    try:
        transforms = [
            py_vision.Decode(),
            py_vision.CenterCrop([height, width]),
            py_vision.ToTensor(),
            py_vision.LinearTransformation(transformation_matrix, mean_vector)
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(
            transforms)
        data1 = data1.map(operations=transform, input_columns=["image"])
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "should match" in str(e)
Code Example #8
def test_to_pil_02():
    """
    Test ToPIL Op with md5 comparison: input is not PIL image
    Expected to pass
    """
    logger.info("test_to_pil_02")

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)
    decode_op = c_vision.Decode()
    transforms = [
        # ToPIL is needed because the input here is a NumPy array, not a PIL image.
        py_vision.ToPIL(),
        py_vision.CenterCrop(375),
        py_vision.ToTensor()
    ]
    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=transform, input_columns=["image"])

    # Compare with expected md5 from images
    filename = "to_pil_02_result.npz"
    save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
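ToPIL bridges the two APIs here: c_vision.Decode emits an HWC uint8 NumPy array, and the Python ops that follow expect a PIL image. Conceptually the conversion is close to:

from PIL import Image

decoded = np.zeros((375, 375, 3), dtype=np.uint8)  # shape c_vision.Decode emits
pil_image = Image.fromarray(decoded)               # roughly what ToPIL does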
Code Example #9
File: dataset.py Project: zhangjinrong/mindspore
def create_dataset_py(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    if target == "Ascend":
        device_num = int(os.getenv("RANK_SIZE"))
        rank_id = int(os.getenv("RANK_ID"))
    else:
        init()
        rank_id = get_rank()
        device_num = get_group_size()

    if do_train:
        if device_num == 1:
            ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
        else:
            ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                       num_shards=device_num, shard_id=rank_id)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False)

    image_size = 224

    # define map operations
    decode_op = P.Decode()
    resize_crop_op = P.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333))
    horizontal_flip_op = P.RandomHorizontalFlip(prob=0.5)

    resize_op = P.Resize(256)
    center_crop = P.CenterCrop(image_size)
    to_tensor = P.ToTensor()
    normalize_op = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # select transforms for training or evaluation
    if do_train:
        trans = [decode_op, resize_crop_op, horizontal_flip_op, to_tensor, normalize_op]
    else:
        trans = [decode_op, resize_op, center_crop, to_tensor, normalize_op]

    compose = P2.Compose(trans)
    ds = ds.map(operations=compose, input_columns="image", num_parallel_workers=8, python_multiprocessing=True)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds
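A hypothetical invocation of the function above (the dataset path is a placeholder; RANK_SIZE and RANK_ID must be set in the environment when target is "Ascend"):

import os

os.environ.setdefault("RANK_SIZE", "1")
os.environ.setdefault("RANK_ID", "0")
train_ds = create_dataset_py("/path/to/imagenet/train", do_train=True,
                             repeat_num=1, batch_size=32, target="Ascend")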
Code Example #10
def test_linear_transformation_op(plot=False):
    """
    Test LinearTransformation op: verify if images transform correctly
    """
    logger.info("test_linear_transformation_01")

    # Initialize parameters
    height = 50
    width = 50
    dim = 3 * height * width
    transformation_matrix = np.eye(dim)
    mean_vector = np.zeros(dim)

    # Define operations
    transforms = [
        py_vision.Decode(),
        py_vision.CenterCrop([height, width]),
        py_vision.ToTensor()
    ]
    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)
    data1 = data1.map(operations=transform, input_columns=["image"])
    # Note: if the transformation matrix is an identity matrix (all 1s on the
    #       diagonal), the output is expected to be the same as the input.
    data1 = data1.map(operations=py_vision.LinearTransformation(
        transformation_matrix, mean_vector),
                      input_columns=["image"])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)
    data2 = data2.map(operations=transform, input_columns=["image"])

    image_transformed = []
    image = []
    for item1, item2 in zip(
            data1.create_dict_iterator(num_epochs=1, output_numpy=True),
            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image_transformed.append(image1)
        image.append(image2)

        mse = diff_mse(image1, image2)
        assert mse == 0
    if plot:
        visualize_list(image, image_transformed)
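The identity-matrix assertion above follows from LinearTransformation's contract: flatten each CHW tensor, subtract mean_vector, apply the matrix, and reshape back. A NumPy sketch of that contract (the multiplication order is an assumption):

def linear_transformation_sketch(img_chw, matrix, mean):
    flat = img_chw.reshape(1, -1) - mean  # flatten to a row vector and de-mean
    return np.dot(flat, matrix).reshape(img_chw.shape)

x = np.random.rand(3, 4, 4).astype(np.float32)  # small stand-in for a 3x50x50 image
d = x.size
y = linear_transformation_sketch(x, np.eye(d), np.zeros(d))
assert np.allclose(x, y)  # identity matrix + zero mean is a no-op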
Code Example #11
def test_random_choice_comp(plot=False):
    """
    Test RandomChoice and compare with single CenterCrop results
    """
    logger.info("test_random_choice_comp")
    # define map operations
    transforms_list = [py_vision.CenterCrop(64)]
    transforms1 = [
        py_vision.Decode(),
        py_transforms.RandomChoice(transforms_list),
        py_vision.ToTensor()
    ]
    transform1 = py_transforms.Compose(transforms1)

    transforms2 = [
        py_vision.Decode(),
        py_vision.CenterCrop(64),
        py_vision.ToTensor()
    ]
    transform2 = py_transforms.Compose(transforms2)

    #  First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data1 = data1.map(operations=transform1, input_columns=["image"])
    #  Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(operations=transform2, input_columns=["image"])

    image_choice = []
    image_original = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image_choice.append(image1)
        image_original.append(image2)

        mse = diff_mse(image1, image2)
        assert mse == 0
    if plot:
        visualize_list(image_original, image_choice)
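Note that transforms_list holds a single op, so RandomChoice has exactly one candidate to pick; both pipelines therefore apply the same CenterCrop, and the MSE assertion of 0 is deterministic.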
Code Example #12
def create_dataset_val(batch_size=128,
                       val_data_url='',
                       workers=8,
                       distributed=False,
                       input_size=224):
    """Create ImageNet validation dataset"""
    if not os.path.exists(val_data_url):
        raise ValueError('Path does not exist')
    rank_id = get_rank() if distributed else 0
    rank_size = get_group_size() if distributed else 1
    dataset = ds.ImageFolderDataset(val_data_url,
                                    num_parallel_workers=workers,
                                    num_shards=rank_size,
                                    shard_id=rank_id)
    scale_size = None

    if isinstance(input_size, tuple):
        assert len(input_size) == 2
        if input_size[-1] == input_size[-2]:
            scale_size = int(math.floor(input_size[0] / DEFAULT_CROP_PCT))
        else:
            scale_size = tuple([int(x / DEFAULT_CROP_PCT) for x in input_size])
    else:
        scale_size = int(math.floor(input_size / DEFAULT_CROP_PCT))

    type_cast_op = c_transforms.TypeCast(mstype.int32)
    decode_op = py_vision.Decode()
    resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
    center_crop = py_vision.CenterCrop(size=input_size)
    to_tensor = py_vision.ToTensor()
    normalize_op = py_vision.Normalize(IMAGENET_DEFAULT_MEAN,
                                       IMAGENET_DEFAULT_STD)

    image_ops = py_transforms.Compose(
        [decode_op, resize_op, center_crop, to_tensor, normalize_op])

    dataset = dataset.map(input_columns=["label"],
                          operations=type_cast_op,
                          num_parallel_workers=workers)
    dataset = dataset.map(input_columns=["image"],
                          operations=image_ops,
                          num_parallel_workers=workers)
    dataset = dataset.batch(batch_size,
                            per_batch_map=split_imgs_and_labels,
                            input_columns=["image", "label"],
                            num_parallel_workers=2,
                            drop_remainder=True)
    dataset = dataset.repeat(1)
    return dataset
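split_imgs_and_labels is a per_batch_map callback defined elsewhere in this project. Under MindSpore's per_batch_map contract it receives each input column as a list of per-sample arrays plus a BatchInfo object; a plausible sketch that simply re-stacks the columns:

def split_imgs_and_labels(imgs, labels, batch_info):
    # Sketch: stack the per-sample arrays back into batched arrays.
    return np.array(imgs), np.array(labels)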
Code Example #13
def test_random_apply_op(plot=False):
    """
    Test RandomApply in Python transformations
    """
    logger.info("test_random_apply_op")
    # define map operations
    transforms_list = [py_vision.CenterCrop(64), py_vision.RandomRotation(30)]
    transforms1 = [
        py_vision.Decode(),
        py_transforms.RandomApply(transforms_list, prob=0.6),
        py_vision.ToTensor()
    ]
    transform1 = py_transforms.Compose(transforms1)

    transforms2 = [py_vision.Decode(), py_vision.ToTensor()]
    transform2 = py_transforms.Compose(transforms2)

    #  First dataset
    data1 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)
    data1 = data1.map(operations=transform1, input_columns=["image"])
    #  Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR,
                               SCHEMA_DIR,
                               columns_list=["image"],
                               shuffle=False)
    data2 = data2.map(operations=transform2, input_columns=["image"])

    image_apply = []
    image_original = []
    for item1, item2 in zip(
            data1.create_dict_iterator(num_epochs=1, output_numpy=True),
            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image_apply.append(image1)
        image_original.append(image2)
    if plot:
        visualize_list(image_original, image_apply)
Code Example #14
def test_serdes_pyvision(remove_json_files=True):
    """
    Test serdes on py_transform pipeline.
    """
    data_dir = [
        "../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
    ]
    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    data1 = ds.TFRecordDataset(data_dir,
                               schema_file,
                               columns_list=["image", "label"],
                               shuffle=False)
    transforms = [
        py_vision.Decode(),
        py_vision.CenterCrop([32, 32]),
        py_vision.ToTensor()
    ]
    data1 = data1.map(operations=py.Compose(transforms),
                      input_columns=["image"])
    util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline",
                                          remove_json_files)
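util_check_serialize_deserialize_file is a shared test helper; the round trip it performs presumably maps onto MindSpore's public serialize/deserialize API, roughly:

ds.serialize(data1, json_filepath="pyvision_dataset_pipeline.json")     # pipeline -> JSON
data2 = ds.deserialize(json_filepath="pyvision_dataset_pipeline.json")  # JSON -> pipeline
# The helper is assumed to also compare the two pipelines' outputs and to
# delete the JSON file afterwards when remove_json_files is True.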
Code Example #15
File: pet_dataset.py Project: yrpang/mindspore
def create_dataset(dataset_path,
                   do_train,
                   config,
                   platform,
                   repeat_num=1,
                   batch_size=100):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        config: dataset config; image_height is read here.
        platform(string): the device target, "Ascend" or "GPU".
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 100

    Returns:
        dataset
    """
    if platform == "Ascend":
        rank_size = int(os.getenv("RANK_SIZE"))
        rank_id = int(os.getenv("RANK_ID"))
        if rank_size == 1:
            data_set = ds.MindDataset(dataset_path,
                                      num_parallel_workers=8,
                                      shuffle=True)
        else:
            data_set = ds.MindDataset(dataset_path,
                                      num_parallel_workers=8,
                                      shuffle=True,
                                      num_shards=rank_size,
                                      shard_id=rank_id)
    elif platform == "GPU":
        if do_train:
            from mindspore.communication.management import get_rank, get_group_size
            data_set = ds.MindDataset(dataset_path,
                                      num_parallel_workers=8,
                                      shuffle=True,
                                      num_shards=get_group_size(),
                                      shard_id=get_rank())
        else:
            data_set = ds.MindDataset(dataset_path,
                                      num_parallel_workers=8,
                                      shuffle=False)
    else:
        raise ValueError("Unsupported platform.")

    resize_height = config.image_height
    buffer_size = 1000

    # define map operations
    resize_crop_op = C.RandomCropDecodeResize(resize_height,
                                              scale=(0.08, 1.0),
                                              ratio=(0.75, 1.333))
    horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)

    color_op = C.RandomColorAdjust(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4)
    rescale_op = C.Rescale(1 / 255.0, 0)
    normalize_op = C.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
    change_swap_op = C.HWC2CHW()

    # define python operations
    decode_p = P.Decode()
    resize_p = P.Resize(256, interpolation=Inter.BILINEAR)
    center_crop_p = P.CenterCrop(224)
    totensor = P.ToTensor()
    normalize_p = P.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    composeop = P2.Compose(
        [decode_p, resize_p, center_crop_p, totensor, normalize_p])
    if do_train:
        trans = [
            resize_crop_op, horizontal_flip_op, color_op, rescale_op,
            normalize_op, change_swap_op
        ]
    else:
        trans = composeop
    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(input_columns="image",
                            operations=trans,
                            num_parallel_workers=8)
    data_set = data_set.map(input_columns="label_list",
                            operations=type_cast_op,
                            num_parallel_workers=8)

    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=buffer_size)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)

    return data_set
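A hypothetical call of the function above (the MindRecord path is a placeholder, and config is only assumed to expose image_height):

class Cfg:
    image_height = 224  # the only config field this function reads

eval_ds = create_dataset("/path/to/pet.mindrecord", do_train=False,
                         config=Cfg(), platform="GPU", batch_size=100)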
Code Example #16
def create_dataset_py(dataset_path,
                      do_train,
                      config,
                      device_target,
                      repeat_num=1,
                      batch_size=32):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        config: model config (not used in this function).
        device_target(string): the device target; only "Ascend" is supported here.
        repeat_num(int): the repeat times of dataset. Default: 1.
        batch_size(int): the batch size of dataset. Default: 32.

    Returns:
        dataset
    """
    if device_target == "Ascend":
        rank_size = int(os.getenv("RANK_SIZE"))
        rank_id = int(os.getenv("RANK_ID"))
        if do_train:
            if rank_size == 1:
                data_set = ds.ImageFolderDataset(dataset_path,
                                                 num_parallel_workers=8,
                                                 shuffle=True)
            else:
                data_set = ds.ImageFolderDataset(dataset_path,
                                                 num_parallel_workers=8,
                                                 shuffle=True,
                                                 num_shards=rank_size,
                                                 shard_id=rank_id)
        else:
            data_set = ds.ImageFolderDataset(dataset_path,
                                             num_parallel_workers=8,
                                             shuffle=False)
    else:
        raise ValueError("Unsupported device target.")

    resize_height = 224

    if do_train:
        buffer_size = 20480
        # apply shuffle operations
        data_set = data_set.shuffle(buffer_size=buffer_size)

    # define map operations
    decode_op = P.Decode()
    resize_crop_op = P.RandomResizedCrop(resize_height,
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.333))
    horizontal_flip_op = P.RandomHorizontalFlip(prob=0.5)

    resize_op = P.Resize(256)
    center_crop = P.CenterCrop(resize_height)
    to_tensor = P.ToTensor()
    normalize_op = P.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])

    if do_train:
        trans = [
            decode_op, resize_crop_op, horizontal_flip_op, to_tensor,
            normalize_op
        ]
    else:
        trans = [decode_op, resize_op, center_crop, to_tensor, normalize_op]

    compose = P2.Compose(trans)

    data_set = data_set.map(operations=compose,
                            input_columns="image",
                            num_parallel_workers=8,
                            python_multiprocessing=True)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)

    return data_set