def test_auto_contrast_invalid_cutoff_param_py():
    """
    Test AutoContrast python Op with invalid cutoff parameter

    The Decode -> Resize -> AutoContrast -> ToTensor pipeline is built once
    with cutoff=-10.0 and once with cutoff=120.0. Whenever a ValueError is
    raised, its message is checked against the expected interval text.
    """
    logger.info("Test AutoContrast python Op with invalid cutoff parameter")
    expected_text = "Input cutoff is not within the required interval of (0 to 100)."

    # NOTE(review): if no ValueError is raised the loop body simply completes,
    # so this test cannot fail on a missing exception -- confirm that is intended.
    for bad_cutoff in (-10.0, 120.0):
        try:
            dataset = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
            pipeline = F.ComposeOp([
                F.Decode(),
                F.Resize((224, 224)),
                F.AutoContrast(cutoff=bad_cutoff),
                F.ToTensor()
            ])
            dataset = dataset.map(input_columns=["image"], operations=[pipeline])
        except ValueError as error:
            logger.info("Got an exception in DE: {}".format(str(error)))
            assert expected_text in str(error)
Beispiel #2
0
def test_random_sharpness_py(degrees=(0.7, 0.7), plot=False):
    """
    Test RandomSharpness python op

    Runs the image folder at DATA_DIR through two pipelines, one without and
    one with RandomSharpness, logs the mean of the per-image diff_mse values
    and optionally plots both image sets.

    Args:
        degrees: value forwarded to F.RandomSharpness; when None the op is
            created with its default arguments.
        plot (bool): when true, hand both image sets to visualize_list.
    """
    logger.info("Test RandomSharpness python op")

    def _stack_batches(batched):
        # Axes are permuted with (0, 2, 3, 1) -- presumably NCHW -> NHWC,
        # TODO confirm -- and the batches are joined along axis 0.
        for batch_no, (image, _) in enumerate(batched):
            permuted = np.transpose(image, (0, 2, 3, 1))
            if batch_no == 0:
                stacked = permuted
            else:
                stacked = np.append(stacked, permuted, axis=0)
        return stacked

    # Original Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    baseline_ops = F.ComposeOp([F.Decode(), F.Resize((224, 224)), F.ToTensor()])
    baseline = source.map(input_columns="image", operations=baseline_ops())
    images_original = _stack_batches(baseline.batch(512))

    # Random Sharpness Adjusted Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    sharpness_op = (F.RandomSharpness(degrees)
                    if degrees is not None else F.RandomSharpness())
    sharpness_ops = F.ComposeOp(
        [F.Decode(), F.Resize((224, 224)), sharpness_op, F.ToTensor()])
    sharpened = source.map(input_columns="image", operations=sharpness_ops())
    images_random_sharpness = _stack_batches(sharpened.batch(512))

    # One error value per image, indexed like the original image stack.
    mse = np.zeros(images_original.shape[0])
    for i, reference in enumerate(images_original):
        mse[i] = diff_mse(images_random_sharpness[i], reference)

    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))

    if plot:
        visualize_list(images_original, images_random_sharpness)
def test_auto_contrast_invalid_ignore_param_py():
    """
    Test AutoContrast python Op with invalid ignore parameter

    The Decode -> Resize -> AutoContrast -> ToTensor pipeline is built once
    with ignore=255.5 and once with ignore=(10, 100). Whenever a TypeError is
    raised, its message is checked against the text expected for that value.
    """
    logger.info("Test AutoContrast python Op with invalid ignore parameter")

    # (ignore value, text that must appear in the TypeError message)
    cases = (
        (255.5, "Argument ignore with value 255.5 is not of type"),
        ((10, 100), "Argument ignore with value (10,100) is not of type"),
    )

    # NOTE(review): a case that raises nothing passes silently -- confirm
    # whether every value here is really expected to be rejected.
    for bad_ignore, expected_text in cases:
        try:
            dataset = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
            pipeline = F.ComposeOp([
                F.Decode(),
                F.Resize((224, 224)),
                F.AutoContrast(ignore=bad_ignore),
                F.ToTensor()
            ])
            dataset = dataset.map(input_columns=["image"], operations=[pipeline])
        except TypeError as error:
            logger.info("Got an exception in DE: {}".format(str(error)))
            assert expected_text in str(error)
def test_random_color(degrees=(0.1, 1.9), plot=False):
    """
    Test RandomColor

    Runs the image folder at DATA_DIR through two pipelines, one without and
    one with RandomColor, logs the mean of the per-image mean squared
    differences and optionally plots both image sets.

    Args:
        degrees: value forwarded to F.RandomColor as ``degrees``.
        plot (bool): when true, hand both image sets to visualize.
    """
    logger.info("Test RandomColor")

    def _stack_batches(batched):
        # Axes are permuted with (0, 2, 3, 1) -- presumably NCHW -> NHWC,
        # TODO confirm -- and the batches are joined along axis 0.
        for batch_no, (image, _) in enumerate(batched):
            permuted = np.transpose(image, (0, 2, 3, 1))
            if batch_no == 0:
                stacked = permuted
            else:
                stacked = np.append(stacked, permuted, axis=0)
        return stacked

    # Original Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    baseline_ops = F.ComposeOp([F.Decode(), F.Resize((224, 224)), F.ToTensor()])
    baseline = source.map(input_columns="image", operations=baseline_ops())
    images_original = _stack_batches(baseline.batch(512))

    # Random Color Adjusted Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    color_ops = F.ComposeOp([
        F.Decode(),
        F.Resize((224, 224)),
        F.RandomColor(degrees=degrees),
        F.ToTensor()
    ])
    colored = source.map(input_columns="image", operations=color_ops())
    images_random_color = _stack_batches(colored.batch(512))

    # One error value per image, indexed like the original image stack.
    mse = np.zeros(images_original.shape[0])
    for i, reference in enumerate(images_original):
        mse[i] = np.mean((images_random_color[i] - reference)**2)

    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))

    if plot:
        visualize(images_original, images_random_color)
Beispiel #5
0
def test_equalize_py(plot=False):
    """
    Test Equalize py op

    Runs the image folder at DATA_DIR through two pipelines, one without and
    one with Equalize, logs the mean of the per-image diff_mse values and
    optionally plots both image sets.

    Args:
        plot (bool): when true, hand both image sets to visualize_list.
    """
    logger.info("Test Equalize")

    def _stack_batches(batched):
        # Axes are permuted with (0, 2, 3, 1) -- presumably NCHW -> NHWC,
        # TODO confirm -- and the batches are joined along axis 0.
        for batch_no, (image, _) in enumerate(batched):
            permuted = np.transpose(image, (0, 2, 3, 1))
            if batch_no == 0:
                stacked = permuted
            else:
                stacked = np.append(stacked, permuted, axis=0)
        return stacked

    # Original Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    baseline_ops = F.ComposeOp([F.Decode(), F.Resize((224, 224)), F.ToTensor()])
    baseline = source.map(input_columns="image", operations=baseline_ops())
    images_original = _stack_batches(baseline.batch(512))

    # Color Equalized Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    equalize_ops = F.ComposeOp(
        [F.Decode(), F.Resize((224, 224)), F.Equalize(), F.ToTensor()])
    equalized = source.map(input_columns="image", operations=equalize_ops())
    images_equalize = _stack_batches(equalized.batch(512))

    # One error value per image, indexed like the original image stack.
    mse = np.zeros(images_original.shape[0])
    for i, reference in enumerate(images_original):
        mse[i] = diff_mse(images_equalize[i], reference)

    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))

    if plot:
        visualize_list(images_original, images_equalize)
Beispiel #6
0
def test_auto_contrast(plot=False):
    """
    Test AutoContrast

    Runs the image folder at DATA_DIR through two pipelines, one without and
    one with AutoContrast, logs the mean of the per-image mean squared
    differences and optionally plots both image sets.

    Args:
        plot (bool): when true, hand both image sets to visualize.
    """
    logger.info("Test AutoContrast")

    def _stack_batches(batched):
        # Axes are permuted with (0, 2, 3, 1) -- presumably NCHW -> NHWC,
        # TODO confirm -- and the batches are joined along axis 0.
        for batch_no, (image, _) in enumerate(batched):
            permuted = np.transpose(image, (0, 2, 3, 1))
            if batch_no == 0:
                stacked = permuted
            else:
                stacked = np.append(stacked, permuted, axis=0)
        return stacked

    # Original Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    baseline_ops = F.ComposeOp([F.Decode(), F.Resize((224, 224)), F.ToTensor()])
    baseline = source.map(input_columns="image", operations=baseline_ops())
    images_original = _stack_batches(baseline.batch(512))

    # AutoContrast Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    contrast_ops = F.ComposeOp(
        [F.Decode(), F.Resize((224, 224)), F.AutoContrast(), F.ToTensor()])
    contrasted = source.map(input_columns="image", operations=contrast_ops())
    images_auto_contrast = _stack_batches(contrasted.batch(512))

    # One error value per image, indexed like the original image stack.
    mse = np.zeros(images_original.shape[0])
    for i, reference in enumerate(images_original):
        mse[i] = np.mean((images_auto_contrast[i] - reference)**2)

    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))

    if plot:
        visualize(images_original, images_auto_contrast)
Beispiel #7
0
def test_random_apply_exception_random_crop_badinput():
    """
    Test RandomApply: test invalid input for one of the transform functions,
    expected to raise error

    The seed and the number of parallel workers are overridden for the
    duration of the test and are restored in a ``finally`` clause, so a failed
    assertion or an unexpected exception cannot leak the modified global
    configuration into other tests.
    """
    logger.info("test_random_apply_exception_random_crop_badinput")
    original_seed = config_get_set_seed(200)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    try:
        # define map operations
        transforms_list = [
            py_vision.Resize([32, 32]),
            py_vision.RandomCrop(100),  # crop size > image size
            py_vision.RandomRotation(30)
        ]
        transforms = [
            py_vision.Decode(),
            py_vision.RandomApply(transforms_list, prob=0.6),
            py_vision.ToTensor()
        ]
        transform = py_vision.ComposeOp(transforms)
        #  Generate dataset
        data = ds.TFRecordDataset(DATA_DIR,
                                  SCHEMA_DIR,
                                  columns_list=["image"],
                                  shuffle=False)
        data = data.map(input_columns=["image"], operations=transform())
        # NOTE(review): when no RuntimeError is raised the test still passes;
        # only the message of a raised error is checked.
        try:
            _ = data.create_dict_iterator().get_next()
        except RuntimeError as e:
            logger.info("Got an exception in DE: {}".format(str(e)))
            assert "Crop size" in str(e)
    finally:
        # Restore configuration
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
Beispiel #8
0
def skip_test_random_perspective_md5():
    """
    Test RandomPerspective with md5 comparison

    The seed and the number of parallel workers are overridden for the
    duration of the test and are restored in a ``finally`` clause, so a failed
    md5 comparison cannot leak the modified global configuration into other
    tests.
    """
    logger.info("test_random_perspective_md5")
    original_seed = config_get_set_seed(5)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # define map operations
        transforms = [
            py_vision.Decode(),
            py_vision.RandomPerspective(distortion_scale=0.3,
                                        prob=0.7,
                                        interpolation=Inter.BILINEAR),
            py_vision.Resize(
                1450),  # resize to a smaller size to prevent round-off error
            py_vision.ToTensor()
        ]
        transform = py_vision.ComposeOp(transforms)

        #  Generate dataset
        data = ds.TFRecordDataset(DATA_DIR,
                                  SCHEMA_DIR,
                                  columns_list=["image"],
                                  shuffle=False)
        data = data.map(input_columns=["image"], operations=transform())

        # check results with md5 comparison
        filename = "random_perspective_01_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore configuration
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
Beispiel #9
0
def test_concat_14():
    """
    Test concat: create datasets from different dataset folders, apply the
    same map operation to each, then concatenate them
    """
    logger.info("test_concat_14")
    first_dir = "../data/dataset/testPK/data"
    second_dir = "../data/dataset/testImageNetData/train/"

    data1 = ds.ImageFolderDatasetV2(first_dir, num_samples=3)
    data2 = ds.ImageFolderDatasetV2(second_dir, num_samples=2)

    shared_ops = F.ComposeOp([F.Decode(), F.Resize((224, 224)), F.ToTensor()])

    data1 = data1.map(input_columns=["image"], operations=shared_ops())
    data2 = data2.map(input_columns=["image"], operations=shared_ops())
    data3 = data1 + data2

    # First column of every row: the two inputs in order, then the concatenation.
    expected = [row[0] for row in data1]
    expected.extend(row[0] for row in data2)
    output = [row[0] for row in data3]

    assert len(expected) == len(output)
    # NOTE(review): the comparison result is discarded, so element equality is
    # not actually enforced here -- confirm whether an assert is intended.
    np.array_equal(np.array(output), np.array(expected))

    assert sum(1 for _ in data3) == 5
    assert data3.get_dataset_size() == 5
Beispiel #10
0
def create_dataset_py(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        config: object providing ``image_height``, used as the crop size.
        device_target(string): only "Ascend" is supported.
        repeat_num(int): the repeat times of dataset. Default: 1.
        batch_size(int): the batch size of dataset. Default: 32.

    Returns:
        dataset

    Raises:
        ValueError: if device_target is not "Ascend".
    """
    if device_target != "Ascend":
        raise ValueError("Unsupported device target.")

    # Fall back to a single, unsharded device when the launcher did not export
    # the rank variables, instead of failing with a TypeError on int(None).
    rank_size = int(os.getenv("RANK_SIZE", "1"))
    rank_id = int(os.getenv("RANK_ID", "0"))

    # Training shuffles and, with more than one device, shards the data;
    # evaluation reads the whole folder in order.
    loader_kwargs = {"num_parallel_workers": 8, "shuffle": bool(do_train)}
    if do_train and rank_size != 1:
        loader_kwargs.update(num_shards=rank_size, shard_id=rank_id)
    ds = de.ImageFolderDatasetV2(dataset_path, **loader_kwargs)

    resize_height = config.image_height

    if do_train:
        buffer_size = 20480
        # apply shuffle operations
        ds = ds.shuffle(buffer_size=buffer_size)

    # define map operations
    decode_op = P.Decode()
    resize_crop_op = P.RandomResizedCrop(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
    horizontal_flip_op = P.RandomHorizontalFlip(prob=0.5)

    resize_op = P.Resize(256)
    center_crop = P.CenterCrop(resize_height)
    to_tensor = P.ToTensor()
    normalize_op = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if do_train:
        trans = [decode_op, resize_crop_op, horizontal_flip_op, to_tensor, normalize_op]
    else:
        trans = [decode_op, resize_op, center_crop, to_tensor, normalize_op]

    compose = P.ComposeOp(trans)

    ds = ds.map(input_columns="image", operations=compose(), num_parallel_workers=8, python_multiprocessing=True)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds
Beispiel #11
0
def test_rgb_hsv_pipeline():
    """
    Compare a plain Decode -> Resize -> ToTensor pipeline with one that
    additionally applies RgbToHsv followed by HsvToRgb. For every pair of rows
    the flattened images must match within rtol=1e-5 and have the same shape.
    """

    def _build_pipeline(op_list):
        # Compose the ops and map them over the "image" column of the records.
        composed = vision.ComposeOp(op_list)
        records = ds.TFRecordDataset(DATA_DIR,
                                     SCHEMA_DIR,
                                     columns_list=["image"],
                                     shuffle=False)
        return records.map(input_columns=["image"], operations=composed())

    # First dataset
    plain = _build_pipeline(
        [vision.Decode(), vision.Resize([64, 64]), vision.ToTensor()])

    # Second dataset
    round_trip = _build_pipeline([
        vision.Decode(),
        vision.Resize([64, 64]),
        vision.ToTensor(),
        vision.RgbToHsv(),
        vision.HsvToRgb()
    ])

    num_iter = 0
    row_pairs = zip(plain.create_dict_iterator(),
                    round_trip.create_dict_iterator())
    for plain_row, round_trip_row in row_pairs:
        num_iter += 1
        ori_img = plain_row["image"]
        cvt_img = round_trip_row["image"]
        assert_allclose(ori_img.flatten(),
                        cvt_img.flatten(),
                        rtol=1e-5,
                        atol=0)
        assert ori_img.shape == cvt_img.shape
Beispiel #12
0
def create_dataset(dataset_path,
                   do_train,
                   config,
                   platform,
                   repeat_num=1,
                   batch_size=100):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        config: object providing ``image_height``, used as the training crop size.
        platform(string): "Ascend" or "GPU".
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 100

    Returns:
        dataset

    Raises:
        ValueError: if platform is neither "Ascend" nor "GPU".
    """
    if platform == "Ascend":
        # Fall back to a single, unsharded device when the launcher did not
        # export the rank variables, instead of failing on int(None).
        rank_size = int(os.getenv("RANK_SIZE", "1"))
        rank_id = int(os.getenv("RANK_ID", "0"))
        # On this branch the loader is shuffled regardless of do_train.
        if rank_size == 1:
            ds = de.MindDataset(dataset_path,
                                num_parallel_workers=8,
                                shuffle=True)
        else:
            ds = de.MindDataset(dataset_path,
                                num_parallel_workers=8,
                                shuffle=True,
                                num_shards=rank_size,
                                shard_id=rank_id)
    elif platform == "GPU":
        if do_train:
            from mindspore.communication.management import get_rank, get_group_size
            ds = de.MindDataset(dataset_path,
                                num_parallel_workers=8,
                                shuffle=True,
                                num_shards=get_group_size(),
                                shard_id=get_rank())
        else:
            ds = de.MindDataset(dataset_path,
                                num_parallel_workers=8,
                                shuffle=False)
    else:
        raise ValueError("Unsupport platform.")

    resize_height = config.image_height
    buffer_size = 1000

    # define map operations
    resize_crop_op = C.RandomCropDecodeResize(resize_height,
                                              scale=(0.08, 1.0),
                                              ratio=(0.75, 1.333))
    horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)

    color_op = C.RandomColorAdjust(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4)
    rescale_op = C.Rescale(1 / 255.0, 0)
    normalize_op = C.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
    change_swap_op = C.HWC2CHW()

    # define python operations
    decode_p = P.Decode()
    resize_p = P.Resize(256, interpolation=Inter.BILINEAR)
    center_crop_p = P.CenterCrop(224)
    totensor = P.ToTensor()
    normalize_p = P.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    composeop = P.ComposeOp(
        [decode_p, resize_p, center_crop_p, totensor, normalize_p])

    # Training uses the C ops; evaluation uses the composed python ops.
    if do_train:
        trans = [
            resize_crop_op, horizontal_flip_op, color_op, rescale_op,
            normalize_op, change_swap_op
        ]
    else:
        trans = composeop()
    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image",
                operations=trans,
                num_parallel_workers=8)
    ds = ds.map(input_columns="label_list",
                operations=type_cast_op,
                num_parallel_workers=8)

    # apply shuffle operations
    ds = ds.shuffle(buffer_size=buffer_size)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds
def test_uniform_augment(plot=False, num_ops=2):
    """
    Test UniformAugment

    Runs the image folder at DATA_DIR through two pipelines, one without and
    one with UniformAugment, logs the mean of the per-image mean squared
    differences and optionally plots both image sets.

    Args:
        plot (bool): when true, hand both image sets to visualize.
        num_ops: value forwarded to F.UniformAugment as ``num_ops``.
    """
    logger.info("Test UniformAugment")

    def _stack_batches(batched):
        # Axes are permuted with (0, 2, 3, 1) -- presumably NCHW -> NHWC,
        # TODO confirm -- and the batches are joined along axis 0.
        for batch_no, (image, _) in enumerate(batched):
            permuted = np.transpose(image, (0, 2, 3, 1))
            if batch_no == 0:
                stacked = permuted
            else:
                stacked = np.append(stacked, permuted, axis=0)
        return stacked

    # Original Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    baseline_ops = F.ComposeOp([F.Decode(), F.Resize((224, 224)), F.ToTensor()])
    baseline = source.map(input_columns="image", operations=baseline_ops())
    images_original = _stack_batches(baseline.batch(512))

    # UniformAugment Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)

    # Candidate ops handed to UniformAugment.
    candidate_ops = [
        F.RandomRotation(45),
        F.RandomColor(),
        F.RandomSharpness(),
        F.Invert(),
        F.AutoContrast(),
        F.Equalize()
    ]
    augment_ops = F.ComposeOp([
        F.Decode(),
        F.Resize((224, 224)),
        F.UniformAugment(transforms=candidate_ops, num_ops=num_ops),
        F.ToTensor()
    ])
    augmented = source.map(input_columns="image", operations=augment_ops())
    images_ua = _stack_batches(augmented.batch(512))

    # One error value per image, indexed like the original image stack.
    mse = np.zeros(images_original.shape[0])
    for i, reference in enumerate(images_original):
        mse[i] = np.mean((images_ua[i] - reference)**2)

    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))

    if plot:
        visualize(images_original, images_ua)
Beispiel #14
0
def test_auto_contrast_py(plot=False):
    """
    Test AutoContrast

    Runs the image folder at DATA_DIR through two pipelines, one without and
    one with AutoContrast, logs the mean of the per-image diff_mse values,
    checks the AutoContrast output through save_and_check_md5 and optionally
    plots both image sets.

    Args:
        plot (bool): when true, hand both image sets to visualize_list.
    """
    logger.info("Test AutoContrast Python Op")

    def _stack_batches(batched):
        # Axes are permuted with (0, 2, 3, 1) -- presumably NCHW -> NHWC,
        # TODO confirm -- and the batches are joined along axis 0.
        for batch_no, (image, _) in enumerate(batched):
            permuted = np.transpose(image, (0, 2, 3, 1))
            if batch_no == 0:
                stacked = permuted
            else:
                stacked = np.append(stacked, permuted, axis=0)
        return stacked

    # Original Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    baseline_ops = F.ComposeOp([F.Decode(), F.Resize((224, 224)), F.ToTensor()])
    baseline = source.map(input_columns="image", operations=baseline_ops())
    images_original = _stack_batches(baseline.batch(512))

    # AutoContrast Images
    source = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
    contrast_ops = F.ComposeOp(
        [F.Decode(), F.Resize((224, 224)), F.AutoContrast(), F.ToTensor()])
    ds_auto_contrast = source.map(input_columns="image",
                                  operations=contrast_ops())
    # The batched dataset is kept because the md5 check below consumes it too.
    ds_auto_contrast = ds_auto_contrast.batch(512)
    images_auto_contrast = _stack_batches(ds_auto_contrast)

    # One error value per image, indexed like the original image stack.
    mse = np.zeros(images_original.shape[0])
    for i, reference in enumerate(images_original):
        mse[i] = diff_mse(images_auto_contrast[i], reference)

    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))

    # Compare with expected md5 from images
    golden_file = "autcontrast_01_result_py.npz"
    save_and_check_md5(ds_auto_contrast,
                       golden_file,
                       generate_golden=GENERATE_GOLDEN)

    if plot:
        visualize_list(images_original, images_auto_contrast)
Beispiel #15
0
def create_dataset_py(dataset_path,
                      do_train,
                      repeat_num=1,
                      batch_size=32,
                      target="Ascend"):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    if target == "Ascend":
        # Fall back to a single, unsharded device when the launcher did not
        # export the rank variables, instead of failing on int(None).
        device_num = int(os.getenv("RANK_SIZE", "1"))
        rank_id = int(os.getenv("RANK_ID", "0"))
    else:
        # Any other target initialises the "nccl" backend and queries it.
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()

    if do_train:
        if device_num == 1:
            ds = de.ImageFolderDatasetV2(dataset_path,
                                         num_parallel_workers=8,
                                         shuffle=True)
        else:
            ds = de.ImageFolderDatasetV2(dataset_path,
                                         num_parallel_workers=8,
                                         shuffle=True,
                                         num_shards=device_num,
                                         shard_id=rank_id)
    else:
        ds = de.ImageFolderDatasetV2(dataset_path,
                                     num_parallel_workers=8,
                                     shuffle=False)

    image_size = 224

    # define map operations
    decode_op = P.Decode()
    resize_crop_op = P.RandomResizedCrop(image_size,
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.333))
    horizontal_flip_op = P.RandomHorizontalFlip(prob=0.5)

    resize_op = P.Resize(256)
    center_crop = P.CenterCrop(image_size)
    to_tensor = P.ToTensor()
    normalize_op = P.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])

    # Training uses a random crop and flip; evaluation a resize and center crop.
    if do_train:
        trans = [
            decode_op, resize_crop_op, horizontal_flip_op, to_tensor,
            normalize_op
        ]
    else:
        trans = [decode_op, resize_op, center_crop, to_tensor, normalize_op]

    compose = P.ComposeOp(trans)
    ds = ds.map(input_columns="image",
                operations=compose(),
                num_parallel_workers=8,
                python_multiprocessing=True)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds