def test_random_color_c_errors():
    """
    Check that the Cpp RandomColor op reports errors for bad input.

    The first part feeds invalid ``degrees`` arguments to the constructor;
    the second part runs the op over MNIST images and expects the pipeline
    to fail while iterating.
    """
    # (expected exception type, degrees argument, expected message fragment)
    bad_degrees_cases = (
        # a bare int, not a sequence
        (TypeError, 12, "degrees must be either a tuple or a list."),
        (TypeError, ("col", 3),
         "Argument degrees[0] with value col is not of type (<class 'int'>, <class 'float'>)."),
        (ValueError, (0.9, 0.1), "degrees should be in (min,max) format. Got (max,min)."),
        (ValueError, (0.9,), "degrees must be a sequence with length 2."),
    )
    for expected_exc, degrees, fragment in bad_degrees_cases:
        with pytest.raises(expected_exc) as error_info:
            vision.RandomColor(degrees)
        assert fragment in str(error_info.value)

    # RandomColor Cpp Op will fail with one channel input
    one_channel_ds = ds.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
    one_channel_ds = one_channel_ds.map(operations=vision.RandomColor(), input_columns="image")

    with pytest.raises(RuntimeError) as error_info:
        for _ in enumerate(one_channel_ds):
            pass
    assert "Invalid number of channels in input image" in str(error_info.value)
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """
    Build the MNIST dataset pipeline for train or test.

    Args:
        data_path: Path handed to ``ds.MnistDataset``.
        batch_size: Batch size; incomplete trailing batches are dropped.
        repeat_size: Count passed to ``repeat`` after batching.
        num_parallel_workers: Worker count used by every ``map`` call.

    Returns:
        The dataset after the map, shuffle, batch and repeat stages.
    """
    pipeline = ds.MnistDataset(data_path)

    target_size = (32, 32)
    pixel_scale, pixel_shift = 1.0 / 255.0, 0.0
    norm_scale = 1 / 0.3081
    norm_shift = -1 * 0.1307 / 0.3081

    # Image transforms in the order they are mapped onto the "image" column.
    image_ops = (
        CV.Resize(target_size, interpolation=Inter.LINEAR),  # Bilinear mode
        CV.Rescale(pixel_scale, pixel_shift),
        CV.Rescale(norm_scale, norm_shift),
        CV.HWC2CHW(),
    )

    # Labels are cast to int32 before any image transform is attached.
    pipeline = pipeline.map(operations=C.TypeCast(mstype.int32), input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    for image_op in image_ops:
        pipeline = pipeline.map(operations=image_op, input_columns="image",
                                num_parallel_workers=num_parallel_workers)

    # Shuffle buffer of 10000 as in the LeNet train script.
    pipeline = pipeline.shuffle(buffer_size=10000)
    pipeline = pipeline.batch(batch_size, drop_remainder=True)
    return pipeline.repeat(repeat_size)
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """
    Build an MNIST dataset pipeline for train or test (images only, no shuffle).

    Args:
        data_path: Path handed to ``ds.MnistDataset``.
        batch_size: Batch size; the remainder batch is kept.
        repeat_size: Count passed to ``repeat`` after batching.
        num_parallel_workers: Worker count used by every ``map`` call.

    Returns:
        The dataset after the map, batch and repeat stages.
    """
    pipeline = ds.MnistDataset(data_path)

    # Resize to 32x32, rescale by 1/255 with zero shift, then HWC -> CHW.
    image_ops = [
        CV.Resize((32, 32)),
        CV.Rescale(1.0 / 255.0, 0.0),
        CV.HWC2CHW(),
    ]
    for image_op in image_ops:
        pipeline = pipeline.map(input_columns="image", operations=image_op,
                                num_parallel_workers=num_parallel_workers)

    batched = pipeline.batch(batch_size)
    return batched.repeat(repeat_size)
def test_mnist_sampler_chain():
    """
    Test Mnist with a DistributedSampler that has a RandomSampler child.

    Checks the reported dataset size and the number of rows produced, then
    logs every row; row contents are not asserted.
    """
    logger.info("test_mnist_sampler_chain")

    parent_sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
    parent_sampler.add_child(ds.RandomSampler(replacement=True, num_samples=4))
    chained = ds.MnistDataset(MNIST_DATA_DIR, sampler=parent_sampler)

    # Size reported by the pipeline
    reported_size = chained.get_dataset_size()
    logger.info("dataset size is: {}".format(reported_size))
    assert reported_size == 3

    # Rows actually produced by iterating
    row_count = 0
    for _ in chained:
        row_count += 1
    assert row_count == 3

    # Dump the contents to the log
    res = []
    for row in chained.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(row))
        res.append(row)
    logger.info("dataset: {}".format(res))
def create_mnist_dataset(mnist_dir, num_parallel_workers=1):
    """
    Create the MNIST dataset pipeline configured by ``MNIST_CONFIG``.

    Args:
        mnist_dir: Path handed to ``de.MnistDataset``.
        num_parallel_workers: Worker count for the image ``map`` calls
            (the label cast uses the ``map`` default).

    Returns:
        The dataset after the map, shuffle, batch and repeat stages.
    """
    dataset = de.MnistDataset(mnist_dir)

    # Label column: cast to int32.
    label_cast = C.TypeCast(mstype.int32)
    dataset = dataset.map(operations=label_cast, input_columns="label")

    # Image column: resize, two rescale passes, then HWC -> CHW.
    resize = VC.Resize((MNIST_CONFIG.image_height, MNIST_CONFIG.image_width), interpolation=Inter.LINEAR)
    dataset = dataset.map(operations=resize, input_columns="image",
                          num_parallel_workers=num_parallel_workers)

    normalize = VC.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
    dataset = dataset.map(operations=normalize, input_columns="image",
                          num_parallel_workers=num_parallel_workers)

    to_unit_range = VC.Rescale(1.0 / 255.0, 0.0)
    dataset = dataset.map(operations=to_unit_range, input_columns="image",
                          num_parallel_workers=num_parallel_workers)

    to_chw = VC.HWC2CHW()
    dataset = dataset.map(operations=to_chw, input_columns="image",
                          num_parallel_workers=num_parallel_workers)

    # Dataset-level stages; the shuffle buffer size comes from MNIST_CONFIG.
    shuffled = dataset.shuffle(buffer_size=MNIST_CONFIG.buffer_size)
    batched = shuffled.batch(MNIST_CONFIG.batch_size, drop_remainder=True)
    return batched.repeat(MNIST_CONFIG.repeat_size)
def test_mnist_dataset(remove_json_files=True):
    """
    Round-trip an MNIST pipeline through serialize/deserialize and compare outputs.

    The pipeline (MnistDataset -> OneHot on "label" -> batch of 10) is
    serialized to JSON, deserialized, serialized again, and deserialized a
    second time. The two JSON files must be identical and all three pipelines
    must yield the same images and labels for 10 batches.

    Args:
        remove_json_files: When True, ``delete_json_files()`` is called once the
            checks finish, whether they passed or failed, so a failing run does
            not leave generated files behind.
    """
    data_dir = "../data/dataset/testMnistData"
    first_json = "mnist_dataset_pipeline.json"
    second_json = "mnist_dataset_pipeline_1.json"
    ds.config.set_seed(1)

    try:
        # NOTE(review): the positional 100 is presumably the sample count
        # (10 batches of 10 are asserted below) - confirm against the API.
        data1 = ds.MnistDataset(data_dir, 100)
        data1 = data1.map(input_columns="label", operations=c.OneHot(10))
        data1 = data1.batch(batch_size=10, drop_remainder=True)

        ds.serialize(data1, first_json)
        assert validate_jsonfile(first_json) is True

        data2 = ds.deserialize(json_filepath=first_json)
        ds.serialize(data2, second_json)
        assert validate_jsonfile(second_json) is True
        assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

        data3 = ds.deserialize(json_filepath=second_json)

        num = 0
        # Distinct row names keep the dataset variables from being rebound by the loop.
        for row1, row2, row3 in zip(data1.create_dict_iterator(),
                                    data2.create_dict_iterator(),
                                    data3.create_dict_iterator()):
            assert np.array_equal(row1['image'], row2['image'])
            assert np.array_equal(row1['image'], row3['image'])
            assert np.array_equal(row1['label'], row2['label'])
            assert np.array_equal(row1['label'], row3['label'])
            num += 1

        logger.info("mnist total num samples is {}".format(str(num)))
        assert num == 10
    finally:
        if remove_json_files:
            delete_json_files()
def test_mnist_sequential_sampler():
    """
    Test MnistDataset with SequentialSampler.

    A dataset driven by ``SequentialSampler(num_samples=50)`` must yield the
    same labels as one built with ``shuffle=False, num_samples=50``.
    """
    logger.info("Test MnistDataset Op with SequentialSampler")
    num_samples = 50

    seq_sampler = ds.SequentialSampler(num_samples=num_samples)
    via_sampler = ds.MnistDataset(DATA_DIR, sampler=seq_sampler)
    via_kwargs = ds.MnistDataset(DATA_DIR, shuffle=False, num_samples=num_samples)

    sampler_labels = []
    kwargs_labels = []
    paired_rows = zip(via_sampler.create_dict_iterator(num_epochs=1),
                      via_kwargs.create_dict_iterator(num_epochs=1))
    for sampler_row, kwargs_row in paired_rows:
        sampler_labels.append(sampler_row["label"].asnumpy())
        kwargs_labels.append(kwargs_row["label"].asnumpy())

    np.testing.assert_array_equal(sampler_labels, kwargs_labels)
    # One label is recorded per iteration, so the list length is the row count.
    assert len(sampler_labels) == num_samples
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """
    Create the MNIST dataset pipeline for train or test.

    Args:
        data_path: Data path
        batch_size: The number of data records in each group
        repeat_size: The number of replicated data records
        num_parallel_workers: The number of parallel workers

    Returns:
        The dataset after the map, shuffle, batch and repeat stages.
    """
    dataset = ds.MnistDataset(data_path)

    # Operation parameters
    image_size = (32, 32)
    unit_scale, unit_shift = 1.0 / 255.0, 0.0
    nml_scale = 1 / 0.3081
    nml_shift = -1 * 0.1307 / 0.3081

    # (column, operation) pairs, applied top to bottom.
    stages = [
        ("label", C.TypeCast(mstype.int32)),  # label dtype -> int32 to fit network
        ("image", CV.Resize(image_size, interpolation=Inter.LINEAR)),  # resize to (32, 32)
        ("image", CV.Rescale(unit_scale, unit_shift)),  # rescale images
        ("image", CV.Rescale(nml_scale, nml_shift)),  # normalize images
        ("image", CV.HWC2CHW()),  # (height, width, channel) -> (channel, height, width)
    ]
    for column, operation in stages:
        dataset = dataset.map(input_columns=column, operations=operation,
                              num_parallel_workers=num_parallel_workers)

    # Dataset-level operations; buffer of 10000 as in the LeNet train script.
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat(repeat_size)
    return dataset
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """
    Create the MNIST dataset pipeline for train or test.

    Args:
        data_path (str): Data path
        batch_size (int): The number of data records in each group
        repeat_size (int): The number of replicated data records
        num_parallel_workers (int): The number of parallel workers

    Returns:
        The dataset after the map, shuffle, batch and repeat stages.
    """

    def _mapped(dataset, operation, column):
        """Attach one operation to one column using the caller's worker count."""
        return dataset.map(operations=operation, input_columns=column,
                           num_parallel_workers=num_parallel_workers)

    mnist = ds.MnistDataset(data_path)

    # Parameters for the resize and the two rescale passes
    height = 32
    width = 32
    scale_unit = 1.0 / 255.0
    shift_unit = 0.0
    scale_nml = 1 / 0.3081
    shift_nml = -1.0 * 0.1307 / 0.3081

    # Build the operations from those parameters
    resize = CV.Resize((height, width), interpolation=Inter.LINEAR)
    rescale_nml = CV.Rescale(scale_nml, shift_nml)
    rescale_unit = CV.Rescale(scale_unit, shift_unit)
    to_chw = CV.HWC2CHW()
    label_cast = C.TypeCast(mstype.int32)

    # Attach them: label cast first, then the image chain
    mnist = _mapped(mnist, label_cast, "label")
    mnist = _mapped(mnist, resize, "image")
    mnist = _mapped(mnist, rescale_unit, "image")
    mnist = _mapped(mnist, rescale_nml, "image")
    mnist = _mapped(mnist, to_chw, "image")

    # Post-process the generated dataset
    mnist = mnist.shuffle(buffer_size=10000)
    mnist = mnist.batch(batch_size, drop_remainder=True)
    return mnist.repeat(repeat_size)
def test_config(usage, mnist_path=None):
    """
    Count the rows MnistDataset yields for ``usage``, or report the error text.

    Args:
        usage: Value forwarded as the ``usage`` argument of ``ds.MnistDataset``.
        mnist_path: Dataset directory; ``DATA_DIR`` is used when None.

    Returns:
        The number of rows iterated, or ``str(e)`` when a ValueError,
        TypeError or RuntimeError is raised while building or iterating.
    """
    if mnist_path is None:
        mnist_path = DATA_DIR
    try:
        data = ds.MnistDataset(mnist_path, usage=usage, shuffle=False)
        return sum(1 for _ in data.create_dict_iterator())
    except (ValueError, TypeError, RuntimeError) as e:
        return str(e)
def test_mnist_usage():
    """
    Validate the ``usage`` flag of MnistDataset.

    Valid usages are checked by row count, invalid ones by the error text.
    The full-dataset checks at the end only run when ``all_files_path`` is
    edited to point at a folder containing all MNIST files.
    """
    logger.info("Test MnistDataset usage flag")

    def test_config(usage, mnist_path=None):
        """Return the iterated row count, or the error message on failure."""
        if mnist_path is None:
            mnist_path = DATA_DIR
        try:
            data = ds.MnistDataset(mnist_path, usage=usage, shuffle=False)
            return sum(1 for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True))
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)

    assert test_config("test") == 10000
    assert test_config("all") == 10000

    train_result = test_config("train")
    assert " no valid data matching the dataset API MnistDataset" in train_result

    invalid_result = test_config("invalid")
    assert "usage is not within the valid set of ['train', 'test', 'all']" in invalid_result

    list_result = test_config(["list"])
    assert "Argument usage with value ['list'] is not of type (<class 'str'>,)" in list_result

    # change this directory to the folder that contains all mnist files
    all_files_path = None
    # the following tests run on the entire datasets
    if all_files_path is not None:
        full_counts = (("train", 60000), ("test", 10000), ("all", 70000))
        for usage, expected in full_counts:
            assert test_config(usage, all_files_path) == expected
        for usage, expected in full_counts:
            assert ds.MnistDataset(all_files_path, usage=usage).get_dataset_size() == expected
def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, samples=None, num_parallel_workers=1, sparse=True):
    """
    Create the MNIST dataset pipeline for training or testing.

    Args:
        data_path: Path handed to ``ds.MnistDataset``.
        batch_size: Batch size; incomplete trailing batches are dropped.
        repeat_size: Count passed to ``repeat`` after batching.
        samples: Forwarded as ``num_samples`` to ``ds.MnistDataset``.
        num_parallel_workers: Worker count used by every ``map`` call.
        sparse: When True labels are cast to int32. When False labels are
            first one-hot encoded with ``OneHot(10)`` and then cast to float32.

    Returns:
        The dataset after the map, shuffle, batch and repeat stages.
    """
    pipeline = ds.MnistDataset(data_path, num_samples=samples)

    # Image operations: resize to 32x32, rescale by 1/255, HWC -> CHW.
    image_ops = (
        CV.Resize((32, 32), interpolation=Inter.LINEAR),
        CV.Rescale(1.0 / 255.0, 0.0),
        CV.HWC2CHW(),
    )

    # Label handling depends on whether sparse labels are requested.
    label_cast = C.TypeCast(mstype.int32)
    if not sparse:
        pipeline = pipeline.map(input_columns="label", operations=C.OneHot(10),
                                num_parallel_workers=num_parallel_workers)
        label_cast = C.TypeCast(mstype.float32)
    pipeline = pipeline.map(input_columns="label", operations=label_cast,
                            num_parallel_workers=num_parallel_workers)

    for image_op in image_ops:
        pipeline = pipeline.map(input_columns="image", operations=image_op,
                                num_parallel_workers=num_parallel_workers)

    # Dataset-level operations
    pipeline = pipeline.shuffle(buffer_size=10000)
    pipeline = pipeline.batch(batch_size, drop_remainder=True)
    return pipeline.repeat(repeat_size)
def test_random_affine_py_exception_non_pil_images():
    """
    Test RandomAffine: input img is ndarray and not PIL, expected to raise RuntimeError.

    ``ToTensor`` runs before ``RandomAffine`` in the composed transform, so
    RandomAffine does not receive a PIL image. The test fails if iterating
    the pipeline completes without a RuntimeError, instead of passing silently.
    """
    logger.info("test_random_affine_py_exception_non_pil_images")
    dataset = ds.MnistDataset(MNIST_DATA_DIR, num_samples=3, num_parallel_workers=3)
    try:
        transform = mindspore.dataset.transforms.py_transforms.Compose(
            [py_vision.ToTensor(), py_vision.RandomAffine(degrees=(15, 15))])
        dataset = dataset.map(operations=transform, input_columns=["image"], num_parallel_workers=3)
        for _ in dataset.create_dict_iterator(num_epochs=1):
            pass
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Pillow image" in str(e)
    else:
        # Reaching this point means the expected error never happened.
        raise AssertionError("RandomAffine on a non-PIL image did not raise RuntimeError")
def test_random_affine_py_exception_non_pil_images():
    """
    Test RandomAffine: input img is ndarray and not PIL, expected to raise RuntimeError.

    ``ToTensor`` runs before ``RandomAffine`` in the composed transform, so
    RandomAffine does not receive a PIL image. The test fails if fetching the
    first row completes without a RuntimeError, instead of passing silently.
    """
    logger.info("test_random_affine_py_exception_non_pil_images")
    dataset = ds.MnistDataset(MNIST_DATA_DIR, num_parallel_workers=3)
    try:
        transform = py_vision.ComposeOp([py_vision.ToTensor(), py_vision.RandomAffine(degrees=(15, 15))])
        dataset = dataset.map(input_columns=["image"], operations=transform(), num_parallel_workers=3,
                              python_multiprocessing=True)
        # Only the first row is fetched; the loop stops right after it.
        for _ in dataset.create_dict_iterator():
            break
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Pillow image" in str(e)
    else:
        # Reaching this point means the expected error never happened.
        raise AssertionError("RandomAffine on a non-PIL image did not raise RuntimeError")
def test_mnist_pk_sampler():
    """
    Test MnistDataset with PKSampler.

    With ``PKSampler(3)`` the labels are expected to be 0..9, each repeated
    three times in a row, for 30 rows in total.
    """
    logger.info("Test MnistDataset Op with PKSampler")

    # 0, 0, 0, 1, 1, 1, ..., 9, 9, 9
    golden = [digit for digit in range(10) for _ in range(3)]

    data = ds.MnistDataset(DATA_DIR, sampler=ds.PKSampler(3))
    label_list = [row["label"] for row in data.create_dict_iterator(num_epochs=1, output_numpy=True)]

    np.testing.assert_array_equal(golden, label_list)
    assert len(label_list) == 30
def test_mnist_content_check():
    """
    Validate MnistDataset image readings.

    The first 100 unshuffled rows from ``ds.MnistDataset`` are compared,
    row by row, with the images and labels returned by ``load_mnist`` for the
    same directory.
    """
    logger.info("Test MnistDataset Op with content check")
    data1 = ds.MnistDataset(DATA_DIR, num_samples=100, shuffle=False)
    images, labels = load_mnist(DATA_DIR)

    num_iter = 0
    # each row is a dictionary with keys "image" and "label"
    for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(data["image"], images[i])
        np.testing.assert_array_equal(data["label"], labels[i])
        num_iter += 1
    assert num_iter == 100
def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
    """
    Collect the labels produced by one MNIST shard configuration.

    ``mnist_dir`` and ``print_res`` are read from the surrounding scope.

    Args:
        num_shards: Forwarded to ``ds.MnistDataset``.
        shard_id: Forwarded to ``ds.MnistDataset``.
        num_samples: Forwarded to ``ds.MnistDataset``.
        shuffle: Forwarded to ``ds.MnistDataset``.
        repeat_cnt: Count passed to ``repeat``.

    Returns:
        A list with ``item["label"].item()`` for every row iterated.
    """
    sharded = ds.MnistDataset(mnist_dir, num_shards=num_shards, shard_id=shard_id,
                              num_samples=num_samples, shuffle=shuffle)
    sharded = sharded.repeat(repeat_cnt)

    # each row is a dictionary
    res = [row["label"].item() for row in sharded.create_dict_iterator()]
    if print_res:
        logger.info("labels of dataset: {}".format(res))
    return res
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """
    Create the MNIST dataset pipeline for train or test.

    The four image operations are passed to a single ``map`` call as a list.

    Args:
        data_path (str): Data path
        batch_size (int): The number of data records in each group
        repeat_size (int): The number of replicated data records
        num_parallel_workers (int): The number of parallel workers

    Returns:
        The dataset after the map, shuffle, batch and repeat stages.
    """
    dataset = ds.MnistDataset(data_path)

    # Operation parameters
    image_size = (32, 32)
    unit_scale = 1.0 / 255.0
    unit_shift = 0.0
    nml_scale = 1 / 0.3081
    nml_shift = -1 * 0.1307 / 0.3081

    # Map operations
    label_cast = C.TypeCast(mstype.int32)
    image_chain = []
    image_chain.append(CV.Resize(image_size, interpolation=Inter.LINEAR))
    image_chain.append(CV.Rescale(nml_scale, nml_shift))
    image_chain.append(CV.Rescale(unit_scale, unit_shift))
    image_chain.append(CV.HWC2CHW())

    # Attach the operations to their columns
    dataset = dataset.map(operations=label_cast, input_columns="label",
                          num_parallel_workers=num_parallel_workers)
    dataset = dataset.map(operations=image_chain, input_columns="image",
                          num_parallel_workers=num_parallel_workers)

    # Dataset-level operations; buffer of 10000 as in the LeNet train script.
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.repeat(repeat_size)
def create_mnist_dataset(mode='train', num_samples=2, batch_size=2):
    """
    Create a small, unshuffled MNIST dataset for train or test.

    Args:
        mode: Sub-directory of the fixed MNIST root that is loaded.
        num_samples: Forwarded as ``num_samples`` to ``ds.MnistDataset``.
        batch_size: Batch size; incomplete trailing batches are dropped.

    Returns:
        The dataset after the map and batch stages.
    """
    mnist_root = '/home/workspace/mindspore_dataset/mnist'
    workers = 1

    dataset = ds.MnistDataset(os.path.join(mnist_root, mode), num_samples=num_samples, shuffle=False)

    # (column, operation) pairs, applied top to bottom.
    stages = (
        ("label", C.TypeCast(mstype.int32)),
        ("image", CV.Resize((32, 32), interpolation=Inter.LINEAR)),  # Bilinear mode
        ("image", CV.Rescale(1.0 / 255.0, shift=0.0)),
        ("image", CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)),
        ("image", CV.HWC2CHW()),
    )
    for column, operation in stages:
        dataset = dataset.map(operations=operation, input_columns=column, num_parallel_workers=workers)

    return dataset.batch(batch_size=batch_size, drop_remainder=True)
def create_dataset(data_path, batch_size=1, num_parallel_workers=1):
    """
    Create the MNIST data pipeline (no repeat stage).

    Args:
        data_path: Path handed to ``ds.MnistDataset``.
        batch_size: Batch size; incomplete trailing batches are dropped.
        num_parallel_workers: Worker count used by every ``map`` call.

    Returns:
        The dataset after the map, shuffle and batch stages.
    """
    # Define the dataset
    dataset = ds.MnistDataset(data_path)

    image_size = (32, 32)
    unit_scale, unit_shift = 1.0 / 255.0, 0.0
    nml_scale = 1 / 0.3081
    nml_shift = -1 * 0.1307 / 0.3081

    # Define the map operations that are needed
    label_cast = C.TypeCast(mstype.int32)
    image_ops = [
        CV.Resize(image_size, interpolation=Inter.LINEAR),
        CV.Rescale(unit_scale, unit_shift),
        CV.Rescale(nml_scale, nml_shift),
        CV.HWC2CHW(),
    ]

    # Use map to apply the operations to the dataset
    dataset = dataset.map(operations=label_cast, input_columns="label",
                          num_parallel_workers=num_parallel_workers)
    for image_op in image_ops:
        dataset = dataset.map(operations=image_op, input_columns="image",
                              num_parallel_workers=num_parallel_workers)

    # Shuffle, then batch
    dataset = dataset.shuffle(buffer_size=10000)
    return dataset.batch(batch_size, drop_remainder=True)
def test_mnist_visualize(plot=False):
    """
    Check the type, shape and dtype of MnistDataset rows and optionally plot them.

    Args:
        plot: When True, the collected images and label captions are passed
            to ``visualize_dataset`` after the checks.
    """
    logger.info("Test MnistDataset visualization")

    unshuffled = ds.MnistDataset(DATA_DIR, num_samples=10, shuffle=False)
    image_list = []
    label_list = []
    for row in unshuffled.create_dict_iterator(num_epochs=1, output_numpy=True):
        image, label = row["image"], row["label"]
        image_list.append(image)
        label_list.append("label {}".format(label))
        assert isinstance(image, np.ndarray)
        assert image.shape == (28, 28, 1)
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32

    # One image is recorded per row, so the list length is the row count.
    assert len(image_list) == 10
    if plot:
        visualize_dataset(image_list, label_list)
def test_mnist_dataset_size():
    """
    Check ``get_dataset_size`` of MnistDataset for several usage and shard settings.
    """
    # (extra keyword arguments for MnistDataset, expected dataset size)
    size_cases = [
        ({}, 10000),
        # get dataset_size with the usage arg
        ({"usage": "test"}, 10000),
        ({"usage": "train"}, 0),
        ({"usage": "all"}, 10000),
        # get dataset_size with sharding
        ({"num_shards": 1, "shard_id": 0}, 10000),
        ({"num_shards": 2, "shard_id": 0}, 5000),
        ({"num_shards": 3, "shard_id": 0}, 3334),
    ]
    for extra_kwargs, expected_size in size_cases:
        dataset = ds.MnistDataset(MNIST_DATA_DIR, **extra_kwargs)
        assert dataset.get_dataset_size() == expected_size
def test_get_column_name_mnist():
    """
    Check that MnistDataset reports the columns "image" and "label", in that order.
    """
    column_names = ds.MnistDataset(MNIST_DIR).get_col_names()
    assert column_names == ["image", "label"]
def get_test_set(self):
    """
    Load the MNIST data under ``self.test_path`` and return it transformed.

    Returns:
        The result of ``DataSource.transform`` applied to the loaded dataset.
    """
    raw_test_data = ds.MnistDataset(dataset_dir=self.test_path)
    transformed = DataSource.transform(raw_test_data)
    return transformed
def test_mnist_exception():
    """
    Test error cases for MnistDataset.

    The first part covers invalid constructor arguments; the second part
    covers a Python map function that raises while the pipeline is iterated.
    """
    logger.info("Test error cases for MnistDataset")

    # sampler combined with shuffle or with sharding
    with pytest.raises(RuntimeError, match="sampler and shuffle cannot be specified at the same time"):
        ds.MnistDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))

    with pytest.raises(RuntimeError, match="sampler and sharding cannot be specified at the same time"):
        ds.MnistDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

    # num_shards and shard_id given one without the other
    with pytest.raises(RuntimeError, match="num_shards is specified and currently requires shard_id as well"):
        ds.MnistDataset(DATA_DIR, num_shards=10)

    with pytest.raises(RuntimeError, match="shard_id is specified but num_shards is not"):
        ds.MnistDataset(DATA_DIR, shard_id=0)

    # shard_id outside the range allowed by num_shards
    shard_msg = "Input shard_id is not within the required interval"
    for shard_total, shard_index in ((5, -1), (5, 5), (2, 5)):
        with pytest.raises(ValueError, match=shard_msg):
            ds.MnistDataset(DATA_DIR, num_shards=shard_total, shard_id=shard_index)

    # num_parallel_workers values that are rejected
    workers_msg = "num_parallel_workers exceeds"
    for worker_count in (0, 256, -2):
        with pytest.raises(ValueError, match=workers_msg):
            ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=worker_count)

    # shard_id of the wrong type
    with pytest.raises(TypeError, match="Argument shard_id"):
        ds.MnistDataset(DATA_DIR, num_shards=2, shard_id="0")

    def exception_func(item):
        raise Exception("Error occur!")

    # (column the failing function is mapped on, whether Decode is mapped on "image" first)
    failing_map_cases = (("image", False), ("image", True), ("label", False))
    files_msg = "The corresponding data files"
    for column, decode_first in failing_map_cases:
        with pytest.raises(RuntimeError, match=files_msg):
            data = ds.MnistDataset(DATA_DIR)
            if decode_first:
                data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
            data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
            for _ in iter(data):
                pass
def test_mnist_exception():
    """
    Test error cases for MnistDataset constructor arguments.
    """
    logger.info("Test error cases for MnistDataset")

    # sampler combined with shuffle or with sharding
    with pytest.raises(RuntimeError, match="sampler and shuffle cannot be specified at the same time"):
        ds.MnistDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))

    with pytest.raises(RuntimeError, match="sampler and sharding cannot be specified at the same time"):
        ds.MnistDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

    # num_shards and shard_id given one without the other
    with pytest.raises(RuntimeError, match="num_shards is specified and currently requires shard_id as well"):
        ds.MnistDataset(DATA_DIR, num_shards=10)

    with pytest.raises(RuntimeError, match="shard_id is specified but num_shards is not"):
        ds.MnistDataset(DATA_DIR, shard_id=0)

    # shard_id outside the range allowed by num_shards
    shard_msg = "Input shard_id is not within the required interval"
    for shard_total, shard_index in ((5, -1), (5, 5), (2, 5)):
        with pytest.raises(ValueError, match=shard_msg):
            ds.MnistDataset(DATA_DIR, num_shards=shard_total, shard_id=shard_index)

    # num_parallel_workers values that are rejected
    workers_msg = "num_parallel_workers exceeds"
    for worker_count in (0, 65, -2):
        with pytest.raises(ValueError, match=workers_msg):
            ds.MnistDataset(DATA_DIR, shuffle=False, num_parallel_workers=worker_count)

    # shard_id of the wrong type
    with pytest.raises(TypeError, match="Argument shard_id"):
        ds.MnistDataset(DATA_DIR, num_shards=2, shard_id="0")
def get_train_set(self, sampler):
    """
    Load the MNIST data under ``self.train_path`` with a sampler and return it transformed.

    Args:
        sampler: Object whose ``get()`` result is passed as the dataset sampler.

    Returns:
        The result of ``DataSource.transform`` applied to the loaded dataset.
    """
    raw_train_data = ds.MnistDataset(dataset_dir=self.train_path, sampler=sampler.get())
    transformed = DataSource.transform(raw_train_data)
    return transformed