Example #1
def test_random_color_c_errors():
    """
    Test that the C++ RandomColor op errors with bad input
    """
    with pytest.raises(TypeError) as error_info:
        vision.RandomColor((12))
    assert "degrees must be either a tuple or a list." in str(error_info.value)

    with pytest.raises(TypeError) as error_info:
        vision.RandomColor(("col", 3))
    assert "Argument degrees[0] with value col is not of type (<class 'int'>, <class 'float'>)." in str(
        error_info.value)

    with pytest.raises(ValueError) as error_info:
        vision.RandomColor((0.9, 0.1))
    assert "degrees should be in (min,max) format. Got (max,min)." in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        vision.RandomColor((0.9,))
    assert "degrees must be a sequence with length 2." in str(error_info.value)

    # The C++ RandomColor op fails on one-channel input
    mnist_ds = ds.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
    mnist_ds = mnist_ds.map(operations=vision.RandomColor(), input_columns="image")

    with pytest.raises(RuntimeError) as error_info:
        for _ in enumerate(mnist_ds):
            pass
    assert "Invalid number of channels in input image" in str(error_info.value)
Example #2
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
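A sketch of the imports these create_dataset examples assume (module paths vary across MindSpore releases, so treat the exact paths as an assumption), followed by a hypothetical call:

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# hypothetical usage, assuming a local ./MNIST/train directory:
train_ds = create_dataset("./MNIST/train", batch_size=32, repeat_size=1)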
Example #3
def create_dataset(data_path,
                   batch_size=32,
                   repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width))  # Bilinear mode
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()

    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="image",
                            operations=resize_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image",
                            operations=rescale_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image",
                            operations=hwc2chw_op,
                            num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    mnist_ds = mnist_ds.batch(batch_size)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
Example #4
def test_mnist_sampler_chain():
    """
    Test Mnist sampler chain
    """
    logger.info("test_mnist_sampler_chain")

    sampler = ds.DistributedSampler(num_shards=1,
                                    shard_id=0,
                                    shuffle=False,
                                    num_samples=3,
                                    offset=1)
    child_sampler = ds.RandomSampler(replacement=True, num_samples=4)
    sampler.add_child(child_sampler)
    data1 = ds.MnistDataset(MNIST_DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 3
    # Verify number of rows
    assert sum([1 for _ in data1]) == 3

    # Verify dataset contents
    res = []
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(item))
        res.append(item)
    logger.info("dataset: {}".format(res))
Example #5
def create_mnist_dataset(mnist_dir, num_parallel_workers=1):
    """create mnist dataset method"""
    ds = de.MnistDataset(mnist_dir)

    # apply map operations on images
    ds = ds.map(operations=C.TypeCast(mstype.int32), input_columns="label")
    ds = ds.map(operations=VC.Resize((MNIST_CONFIG.image_height, MNIST_CONFIG.image_width),
                                     interpolation=Inter.LINEAR),
                input_columns="image",
                num_parallel_workers=num_parallel_workers)
    ds = ds.map(operations=VC.Rescale(1.0 / 255.0, 0.0),
                input_columns="image",
                num_parallel_workers=num_parallel_workers)
    ds = ds.map(operations=VC.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081),
                input_columns="image",
                num_parallel_workers=num_parallel_workers)
    ds = ds.map(operations=VC.HWC2CHW(),
                input_columns="image",
                num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    ds = ds.shuffle(buffer_size=MNIST_CONFIG.buffer_size)  # 10000 as in LeNet train script
    ds = ds.batch(MNIST_CONFIG.batch_size, drop_remainder=True)
    ds = ds.repeat(MNIST_CONFIG.repeat_size)

    return ds
Example #6
def test_mnist_dataset(remove_json_files=True):
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, num_samples=100)
    one_hot_encode = c.OneHot(10)  # num_classes is input argument
    data1 = data1.map(input_columns="label", operations=one_hot_encode)

    # batch_size is input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for item1, item2, item3 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                   data3.create_dict_iterator()):
        assert np.array_equal(item1['image'], item2['image'])
        assert np.array_equal(item1['image'], item3['image'])
        assert np.array_equal(item1['label'], item2['label'])
        assert np.array_equal(item1['label'], item3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()
Example #7
def test_mnist_sequential_sampler():
    """
    Test MnistDataset with SequentialSampler
    """
    logger.info("Test MnistDataset Op with SequentialSampler")
    num_samples = 50
    sampler = ds.SequentialSampler(num_samples=num_samples)
    data1 = ds.MnistDataset(DATA_DIR, sampler=sampler)
    data2 = ds.MnistDataset(DATA_DIR, shuffle=False, num_samples=num_samples)
    label_list1, label_list2 = [], []
    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
        label_list1.append(item1["label"].asnumpy())
        label_list2.append(item2["label"].asnumpy())
        num_iter += 1
    np.testing.assert_array_equal(label_list1, label_list2)
    assert num_iter == num_samples
Example #8
def create_dataset(data_path,
                   batch_size=32,
                   repeat_size=1,
                   num_parallel_workers=1):
    """ create dataset for train or test
    Args:
        data_path: Data path
        batch_size: The number of data records in each group
        repeat_size: The number of replicated data records
        num_parallel_workers: The number of parallel workers
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # define operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)  # resize images to (32, 32)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)  # normalize images
    rescale_op = CV.Rescale(rescale, shift)  # rescale images
    hwc2chw_op = CV.HWC2CHW()  # change shape from (height, width, channel) to (channel, height, width) to fit the network
    type_cast_op = C.TypeCast(mstype.int32)  # cast labels to int32 to fit the network

    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="label",
                            operations=type_cast_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image",
                            operations=resize_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image",
                            operations=rescale_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image",
                            operations=rescale_nml_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image",
                            operations=hwc2chw_op,
                            num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
Example #9
def create_dataset(data_path,
                   batch_size=32,
                   repeat_size=1,
                   num_parallel_workers=1):
    """
       create dataset for train or test

       Args:
           data_path (str): Data path
           batch_size (int): The number of data records in each group
           repeat_size (int): The number of replicated data records
           num_parallel_workers (int): The number of parallel workers
    """

    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # define the parameters needed for data augmentation and normalization
    resize_height = 32
    resize_width = 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1.0 * 0.1307 / 0.3081

    # generate the data augmentation operations from the parameters above
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # using map to apply operations to dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op,
                            input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)

    # process the generated dataset
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
Example #10
def test_config(usage, mnist_path=None):
    mnist_path = DATA_DIR if mnist_path is None else mnist_path
    try:
        data = ds.MnistDataset(mnist_path, usage=usage, shuffle=False)
        num_rows = 0
        for _ in data.create_dict_iterator():
            num_rows += 1
    except (ValueError, TypeError, RuntimeError) as e:
        return str(e)
    return num_rows
Example #11
def test_mnist_usage():
    """
    Validate the MnistDataset usage flag
    """
    logger.info("Test MnistDataset usage flag")

    def test_config(usage, mnist_path=None):
        mnist_path = DATA_DIR if mnist_path is None else mnist_path
        try:
            data = ds.MnistDataset(mnist_path, usage=usage, shuffle=False)
            num_rows = 0
            for _ in data.create_dict_iterator(num_epochs=1,
                                               output_numpy=True):
                num_rows += 1
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return num_rows

    assert test_config("test") == 10000
    assert test_config("all") == 10000
    assert " no valid data matching the dataset API MnistDataset" in test_config(
        "train")
    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config(
        "invalid")
    assert "Argument usage with value ['list'] is not of type (<class 'str'>,)" in test_config(
        ["list"])

    # change this directory to the folder that contains all mnist files
    all_files_path = None
    # the following tests on the entire datasets
    if all_files_path is not None:
        assert test_config("train", all_files_path) == 60000
        assert test_config("test", all_files_path) == 10000
        assert test_config("all", all_files_path) == 70000
        assert ds.MnistDataset(all_files_path,
                               usage="train").get_dataset_size() == 60000
        assert ds.MnistDataset(all_files_path,
                               usage="test").get_dataset_size() == 10000
        assert ds.MnistDataset(all_files_path,
                               usage="all").get_dataset_size() == 70000
Example #12
def generate_mnist_dataset(data_path,
                           batch_size=32,
                           repeat_size=1,
                           samples=None,
                           num_parallel_workers=1,
                           sparse=True):
    """
    create dataset for training or testing
    """
    # define dataset
    ds1 = ds.MnistDataset(data_path, num_samples=samples)

    # define operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    if not sparse:
        one_hot_enco = C.OneHot(10)
        ds1 = ds1.map(input_columns="label",
                      operations=one_hot_enco,
                      num_parallel_workers=num_parallel_workers)
        type_cast_op = C.TypeCast(mstype.float32)
    ds1 = ds1.map(input_columns="label",
                  operations=type_cast_op,
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image",
                  operations=resize_op,
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image",
                  operations=rescale_op,
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image",
                  operations=hwc2chw_op,
                  num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    ds1 = ds1.shuffle(buffer_size=buffer_size)
    ds1 = ds1.batch(batch_size, drop_remainder=True)
    ds1 = ds1.repeat(repeat_size)

    return ds1
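When sparse=False, the labels are one-hot encoded and cast to float32, which suits soft-label losses; a hypothetical pair of calls (paths are assumptions):

sparse_ds = generate_mnist_dataset("./MNIST/train", batch_size=32)  # int32 class indices
dense_ds = generate_mnist_dataset("./MNIST/test", batch_size=32, sparse=False)  # float32 one-hot labels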
Example #13
def test_random_affine_py_exception_non_pil_images():
    """
    Test RandomAffine: input img is ndarray and not PIL, expected to raise RuntimeError
    """
    logger.info("test_random_affine_exception_negative_degrees")
    dataset = ds.MnistDataset(MNIST_DATA_DIR, num_samples=3, num_parallel_workers=3)
    try:
        transform = mindspore.dataset.transforms.py_transforms.Compose([py_vision.ToTensor(),
                                                                        py_vision.RandomAffine(degrees=(15, 15))])
        dataset = dataset.map(operations=transform, input_columns=["image"], num_parallel_workers=3)
        for _ in dataset.create_dict_iterator(num_epochs=1):
            pass
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Pillow image" in str(e)
Example #14
def test_random_affine_py_exception_non_pil_images():
    """
    Test RandomAffine: input img is ndarray and not PIL, expected to raise RuntimeError
    """
    logger.info("test_random_affine_exception_negative_degrees")
    dataset = ds.MnistDataset(MNIST_DATA_DIR, num_parallel_workers=3)
    try:
        transform = py_vision.ComposeOp([py_vision.ToTensor(),
                                         py_vision.RandomAffine(degrees=(15, 15))])
        dataset = dataset.map(input_columns=["image"], operations=transform(), num_parallel_workers=3,
                              python_multiprocessing=True)
        for _ in dataset.create_dict_iterator():
            break
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Pillow image" in str(e)
Example #15
def test_mnist_pk_sampler():
    """
    Test MnistDataset with PKSampler
    """
    logger.info("Test MnistDataset Op with PKSampler")
    golden = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
              5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]
    sampler = ds.PKSampler(3)
    data = ds.MnistDataset(DATA_DIR, sampler=sampler)
    num_iter = 0
    label_list = []
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        label_list.append(item["label"])
        num_iter += 1
    np.testing.assert_array_equal(golden, label_list)
    assert num_iter == 30
Example #16
def test_mnist_content_check():
    """
    Validate MnistDataset image readings
    """
    logger.info("Test MnistDataset Op with content check")
    data1 = ds.MnistDataset(DATA_DIR, num_samples=100, shuffle=False)
    images, labels = load_mnist(DATA_DIR)
    num_iter = 0
    # in this example, each dictionary has keys "image" and "label"
    image_list, label_list = [], []
    for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image_list.append(data["image"])
        label_list.append("label {}".format(data["label"]))
        np.testing.assert_array_equal(data["image"], images[i])
        np.testing.assert_array_equal(data["label"], labels[i])
        num_iter += 1
    assert num_iter == 100
Example #17
def sharding_config(num_shards,
                    shard_id,
                    num_samples,
                    shuffle,
                    repeat_cnt=1):
    data1 = ds.MnistDataset(mnist_dir,
                            num_shards=num_shards,
                            shard_id=shard_id,
                            num_samples=num_samples,
                            shuffle=shuffle)
    data1 = data1.repeat(repeat_cnt)
    res = []
    for item in data1.create_dict_iterator():  # each data is a dictionary
        res.append(item["label"].item())
    if print_res:
        logger.info("labels of dataset: {}".format(res))
    return res
Example #18
def create_dataset(data_path,
                   batch_size=32,
                   repeat_size=1,
                   num_parallel_workers=1):
    """ create dataset for train or test
    Args:
        data_path (str): Data path
        batch_size (int): The number of data records in each group
        repeat_size (int): The number of replicated data records
        num_parallel_workers (int): The number of parallel workers
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # define operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    type_cast_op = C.TypeCast(mstype.int32)
    c_trans = [
        CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR),
        CV.Rescale(rescale, shift),
        CV.Rescale(rescale_nml, shift_nml),
        CV.HWC2CHW()
    ]

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op,
                            input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=c_trans,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
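Passing a list of operations to a single map applies them in order on each row, so the c_trans list above is equivalent to the separate map calls used in the other examples; a sketch of the expanded form under that assumption:

mnist_ds = mnist_ds.map(operations=CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR), input_columns="image")
mnist_ds = mnist_ds.map(operations=CV.Rescale(rescale, shift), input_columns="image")
mnist_ds = mnist_ds.map(operations=CV.Rescale(rescale_nml, shift_nml), input_columns="image")
mnist_ds = mnist_ds.map(operations=CV.HWC2CHW(), input_columns="image")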
Example #19
def create_mnist_dataset(mode='train', num_samples=2, batch_size=2):
    """create dataset for train or test"""
    mnist_path = '/home/workspace/mindspore_dataset/mnist'
    num_parallel_workers = 1

    # define dataset
    mnist_ds = ds.MnistDataset(os.path.join(mnist_path, mode),
                               num_samples=num_samples,
                               shuffle=False)

    resize_height, resize_width = 32, 32

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
    rescale_op = CV.Rescale(1.0 / 255.0, shift=0.0)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op,
                            input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    mnist_ds = mnist_ds.batch(batch_size=batch_size, drop_remainder=True)

    return mnist_ds
Example #20
def create_dataset(data_path, batch_size=1, num_parallel_workers=1):
    """create data"""
    # define the dataset
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define the map operations needed
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # use map to apply the operations to the dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op,
                            input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op,
                            input_columns="image",
                            num_parallel_workers=num_parallel_workers)

    # apply shuffle and batch operations
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)

    return mnist_ds
Example #21
def test_mnist_visualize(plot=False):
    """
    Visualize MnistDataset results
    """
    logger.info("Test MnistDataset visualization")

    data1 = ds.MnistDataset(DATA_DIR, num_samples=10, shuffle=False)
    num_iter = 0
    image_list, label_list = [], []
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = item["image"]
        label = item["label"]
        image_list.append(image)
        label_list.append("label {}".format(label))
        assert isinstance(image, np.ndarray)
        assert image.shape == (28, 28, 1)
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        num_iter += 1
    assert num_iter == 10
    if plot:
        visualize_dataset(image_list, label_list)
Example #22
def test_mnist_dataset_size():
    ds_total = ds.MnistDataset(MNIST_DATA_DIR)
    assert ds_total.get_dataset_size() == 10000

    # test get dataset_size with the usage arg
    test_size = ds.MnistDataset(MNIST_DATA_DIR, usage="test").get_dataset_size()
    assert test_size == 10000
    train_size = ds.MnistDataset(MNIST_DATA_DIR, usage="train").get_dataset_size()
    assert train_size == 0
    all_size = ds.MnistDataset(MNIST_DATA_DIR, usage="all").get_dataset_size()
    assert all_size == 10000

    ds_shard_1_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=1, shard_id=0)
    assert ds_shard_1_0.get_dataset_size() == 10000

    ds_shard_2_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=2, shard_id=0)
    assert ds_shard_2_0.get_dataset_size() == 5000

    ds_shard_3_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=3, shard_id=0)
    assert ds_shard_3_0.get_dataset_size() == 3334
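The last assertion follows from shard sizes rounding up so every shard yields the same number of rows: with 3 shards over 10000 images, each shard draws ceil(10000 / 3) rows.

import math
assert math.ceil(10000 / 3) == 3334  # matches ds_shard_3_0.get_dataset_size()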
Example #23
def test_get_column_name_mnist():
    data = ds.MnistDataset(MNIST_DIR)
    assert data.get_col_names() == ["image", "label"]
Example #24
def get_test_set(self):
    dataset = ds.MnistDataset(dataset_dir=self.test_path)
    return DataSource.transform(dataset)
Example #25
def test_mnist_exception():
    """
    Test error cases for MnistDataset
    """
    logger.info("Test error cases for MnistDataset")
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.MnistDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.MnistDataset(DATA_DIR,
                        sampler=ds.PKSampler(3),
                        num_shards=2,
                        shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.MnistDataset(DATA_DIR, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.MnistDataset(DATA_DIR, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.MnistDataset(DATA_DIR, num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.MnistDataset(DATA_DIR, num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.MnistDataset(DATA_DIR, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)

    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.MnistDataset(DATA_DIR, num_shards=2, shard_id="0")

    def exception_func(item):
        raise Exception("Error occurred!")

    error_msg_8 = "The corresponding data files"
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.MnistDataset(DATA_DIR)
        data = data.map(operations=exception_func,
                        input_columns=["image"],
                        num_parallel_workers=1)
        for _ in data.__iter__():
            pass
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.MnistDataset(DATA_DIR)
        data = data.map(operations=vision.Decode(),
                        input_columns=["image"],
                        num_parallel_workers=1)
        data = data.map(operations=exception_func,
                        input_columns=["image"],
                        num_parallel_workers=1)
        for _ in data.__iter__():
            pass
    with pytest.raises(RuntimeError, match=error_msg_8):
        data = ds.MnistDataset(DATA_DIR)
        data = data.map(operations=exception_func,
                        input_columns=["label"],
                        num_parallel_workers=1)
        for _ in data.__iter__():
            pass
Example #26
def test_mnist_exception():
    """
    Test error cases for MnistDataset
    """
    logger.info("Test error cases for MnistDataset")
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.MnistDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.MnistDataset(DATA_DIR,
                        sampler=ds.PKSampler(3),
                        num_shards=2,
                        shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.MnistDataset(DATA_DIR, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.MnistDataset(DATA_DIR, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.MnistDataset(DATA_DIR, num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.MnistDataset(DATA_DIR, num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.MnistDataset(DATA_DIR, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=65)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)

    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.MnistDataset(DATA_DIR, num_shards=2, shard_id="0")
Example #27
def get_train_set(self, sampler):
    dataset = ds.MnistDataset(dataset_dir=self.train_path,
                              sampler=sampler.get())
    return DataSource.transform(dataset)