def _test_dedupe_column_names(tmpdir, input_column_names: List[str], input_data: List[int], expected_column_names: List[str], expected_data: List[int], dedupe_column_names: bool = True, **kwargs) -> None:
    """Write a one-row CSV file, read it back with mlio, and check the names and values.

    Args:
        tmpdir: pytest ``tmpdir`` fixture; ``test.csv`` is created inside it.
        input_column_names: Header fields written as the first CSV line.
        input_data: Values written as the single data line.
        expected_column_names: Names expected in the reader's schema descriptors.
        expected_data: Values expected in the example that is read back.
        dedupe_column_names: Forwarded to ``mlio.CsvParams``.
        **kwargs: Any further keyword arguments, forwarded to ``mlio.CsvParams``.
    """
    csv_path = tmpdir.join("test.csv")
    header_line = ','.join(input_column_names)
    data_line = ','.join(map(str, input_data))
    csv_path.write('\n'.join([header_line, data_line]))

    reader_params = mlio.DataReaderParams(dataset=[mlio.File(str(csv_path))], batch_size=1)
    csv_options = mlio.CsvParams(dedupe_column_names=dedupe_column_names, **kwargs)
    reader = mlio.CsvReader(reader_params, csv_options)
    example = reader.read_example()

    actual_names = [descriptor.name for descriptor in example.schema.descriptors]
    assert actual_names == expected_column_names

    # Stack every feature of the example, then drop singleton axes before comparing.
    values = np.array([as_numpy(feature) for feature in example]).squeeze()
    assert np.all(values == np.array(expected_data))
def test_csv_params():
    """Read the first example of ``test.csv`` and check the CSV params can be reused."""
    csv_path = os.path.join(resources_dir, 'test.csv')
    reader_params = mlio.DataReaderParams(dataset=[mlio.File(csv_path)], batch_size=1)
    # header_row_index=None: presumably no line is treated as a header; confirm against mlio docs.
    csv_options = mlio.CsvParams(header_row_index=None)

    first_reader = mlio.CsvReader(reader_params, csv_options)
    example = first_reader.read_example()
    values = np.array([as_numpy(feature) for feature in example]).squeeze()
    assert np.all(values == np.array([1, 0, 0, 0]))

    # The same parameter objects must be usable to build a second reader.
    second_reader = mlio.CsvReader(reader_params, csv_options)
    assert second_reader.peek_example()
def test_recordio_protobuf_reader_params():
    """Read one record from ``test.pbr`` and check the reader params can be reused."""
    pbr_path = os.path.join(resources_dir, 'test.pbr')
    reader_params = mlio.DataReaderParams(dataset=[mlio.File(pbr_path)], batch_size=1)

    first_reader = mlio.RecordIOProtobufReader(reader_params)
    example = first_reader.read_example()
    arrays = [as_numpy(feature) for feature in example]

    # The first feature squeezes to the scalar 1; the second feature squeezes to [0, 0, 0].
    assert arrays[0].squeeze() == np.array(1)
    assert np.all(arrays[1].squeeze() == np.array([0, 0, 0]))

    # Parameters should be reusable
    second_reader = mlio.RecordIOProtobufReader(reader_params)
    assert second_reader.peek_example()
def test_image_reader_jpeg_no_resize():
    """Decode ``test_image_0.jpg`` without resizing and check the tensor shape and strides."""
    side = 50
    channels = 3
    jpeg_path = os.path.join(resources_dir, 'test_image_0.jpg')
    reader_params = mlio.DataReaderParams(dataset=[mlio.File(jpeg_path)], batch_size=1)
    image_options = mlio.ImageReaderParams(
        img_frame=mlio.ImageFrame.NONE,
        image_dimensions=[channels, side, side],
        to_rgb=1,
    )
    reader = mlio.ImageReader(reader_params, image_options)
    example = reader.read_example()
    value = example['value']

    assert value.shape == (1, side, side, channels)
    # Strides equal (7500, 150, 3, 1): a contiguous layout with a 1-unit innermost step.
    assert value.strides == (side * side * channels, side * channels, channels, 1)
def test_image_reader_recordio():
    """Decode a RecordIO-framed image with ``resize=100`` and check the tensor shape and strides."""
    side = 100
    channels = 3
    rec_path = os.path.join(resources_dir, 'test_image_0.rec')
    reader_params = mlio.DataReaderParams(dataset=[mlio.File(rec_path)], batch_size=1)
    image_options = mlio.ImageReaderParams(
        img_frame=mlio.ImageFrame.RECORDIO,
        resize=side,
        image_dimensions=[channels, side, side],
        to_rgb=1,
    )
    reader = mlio.ImageReader(reader_params, image_options)
    example = reader.read_example()
    value = example['value']

    assert value.shape == (1, side, side, channels)
    # Strides equal (30000, 300, 3, 1): a contiguous layout with a 1-unit innermost step.
    assert value.strides == (side * side * channels, side * channels, channels, 1)
def test_csv_nonutf_encoding_with_encoding_param():
    """Check that a non-UTF-8 column converts to NumPy when the CSV encoding is given.

    ``test_iso8859_5.csv`` is read with ``encoding='ISO-8859-5'``. Converting its
    ``col_3`` feature with ``as_numpy`` must not raise ``SystemError``.
    """
    filename = os.path.join(resources_dir, 'test_iso8859_5.csv')
    dataset = [mlio.File(filename)]
    rdr_prm = mlio.DataReaderParams(dataset=dataset, batch_size=2)
    csv_params = mlio.CsvParams(encoding='ISO-8859-5')
    reader = mlio.CsvReader(rdr_prm, csv_params)
    example = reader.read_example()
    nonutf_feature = example['col_3']
    try:
        # Only the conversion is under test; its result is not otherwise checked.
        as_numpy(nonutf_feature)
    except SystemError as err:
        # Include the underlying error so a failure is diagnosable from the report.
        pytest.fail("Unexpected exception thrown: {!r}".format(err))
def test_data_reader_params_members():
    """Check DataReaderParams member values after construction and that batch_size is writable."""
    dataset = [mlio.File(os.path.join(resources_dir, 'test.pbr'))]
    params = mlio.DataReaderParams(dataset=dataset, batch_size=1)

    # Members compared with ==.
    expected_equal = {
        'dataset': dataset,
        'batch_size': 1,
        'num_prefetched_batches': 0,
        'num_parallel_reads': 0,
        'last_batch_handling': mlio.LastBatchHandling.NONE,
        'bad_batch_handling': mlio.BadBatchHandling.ERROR,
        'num_instances_to_skip': 0,
        'shard_index': 0,
        'num_shards': 0,
        'shuffle_window': 0,
    }
    for member, expected in expected_equal.items():
        assert getattr(params, member) == expected

    # Members compared by identity against the None/True/False singletons.
    expected_identical = {
        'num_instances_to_read': None,
        'shuffle_instances': False,
        'shuffle_seed': None,
        'reshuffle_each_epoch': True,
        'subsample_ratio': None,
    }
    for member, expected in expected_identical.items():
        assert getattr(params, member) is expected

    params.batch_size = 2
    assert params.batch_size == 2