Example #1
import os
from typing import List

import numpy as np
import pytest

import mlio
from mlio.integ.numpy import as_numpy

# Directory holding the test resource files (test.csv, test.pbr, the test
# images, ...); adjust the path to match your own layout.
resources_dir = os.path.join(os.path.dirname(__file__), 'resources')


def _test_dedupe_column_names(tmpdir,
                              input_column_names: List[str],
                              input_data: List[int],
                              expected_column_names: List[str],
                              expected_data: List[int],
                              dedupe_column_names: bool = True,
                              **kwargs) -> None:

    header_str = ','.join(input_column_names)
    data_str = ','.join(str(x) for x in input_data)
    csv_file = tmpdir.join("test.csv")
    csv_file.write(header_str + '\n' + data_str)

    dataset = [mlio.File(str(csv_file))]
    reader_params = mlio.DataReaderParams(dataset=dataset, batch_size=1)
    csv_params = mlio.CsvParams(dedupe_column_names=dedupe_column_names,
                                **kwargs)
    reader = mlio.CsvReader(reader_params, csv_params)

    example = reader.read_example()
    names = [desc.name for desc in example.schema.descriptors]
    assert names == expected_column_names

    record = [as_numpy(feature) for feature in example]
    assert np.all(np.array(record).squeeze() == np.array(expected_data))
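A minimal usage sketch for the helper above, assuming pytest's tmpdir fixture; the test name and column values are illustrative and not part of the original suite:

def test_unique_column_names_are_unchanged(tmpdir):
    # With no duplicate headers the reader is expected to report the
    # column names unchanged, regardless of the dedupe setting.
    _test_dedupe_column_names(tmpdir,
                              input_column_names=['a', 'b', 'c'],
                              input_data=[1, 2, 3],
                              expected_column_names=['a', 'b', 'c'],
                              expected_data=[1, 2, 3])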
Example #2
def test_csv_params():
    filename = os.path.join(resources_dir, 'test.csv')
    dataset = [mlio.File(filename)]
    rdr_prm = mlio.DataReaderParams(dataset=dataset, batch_size=1)
    csv_prm = mlio.CsvParams(header_row_index=None)
    reader = mlio.CsvReader(rdr_prm, csv_prm)

    example = reader.read_example()
    record = [as_numpy(feature) for feature in example]
    assert np.all(np.array(record).squeeze() == np.array([1, 0, 0, 0]))

    # Parameters should be reusable
    reader2 = mlio.CsvReader(rdr_prm, csv_prm)
    assert reader2.peek_example()
Example #3
def test_recordio_protobuf_reader_params():
    filename = os.path.join(resources_dir, 'test.pbr')
    dataset = [mlio.File(filename)]
    rdr_prm = mlio.DataReaderParams(dataset=dataset, batch_size=1)
    reader = mlio.RecordIOProtobufReader(rdr_prm)

    example = reader.read_example()
    record = [as_numpy(feature) for feature in example]
    assert record[0].squeeze() == np.array(1)
    assert np.all(record[1].squeeze() == np.array([0, 0, 0]))

    # Parameters should be reusable
    reader2 = mlio.RecordIOProtobufReader(rdr_prm)
    assert reader2.peek_example()
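If the whole dataset needs to be consumed, the reader can be drained batch by batch. A sketch, assuming read_example() returns None once the data is exhausted (the same end-of-data signal the truthiness checks on peek_example rely on):

def count_batches(reader):
    # Drain the reader; with batch_size=1 this counts individual records.
    num_batches = 0
    example = reader.read_example()
    while example is not None:
        num_batches += 1
        example = reader.read_example()
    return num_batches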
Example #4
def test_image_reader_jpeg_no_resize():
    filename = os.path.join(resources_dir, 'test_image_0.jpg')
    dataset = [mlio.File(filename)]
    rdr_prm = mlio.DataReaderParams(dataset=dataset, batch_size=1)
    img_prm = mlio.ImageReaderParams(img_frame=mlio.ImageFrame.NONE,
                                     image_dimensions=[3, 50, 50],
                                     to_rgb=1)

    reader = mlio.ImageReader(rdr_prm, img_prm)
    example = reader.read_example()
    tensor = example['value']

    assert tensor.shape == (1, 50, 50, 3)
    assert tensor.strides == (7500, 150, 3, 1)
Example #5
def test_image_reader_recordio():
    filename = os.path.join(resources_dir, 'test_image_0.rec')
    dataset = [mlio.File(filename)]
    rdr_prm = mlio.DataReaderParams(dataset=dataset, batch_size=1)
    img_prm = mlio.ImageReaderParams(img_frame=mlio.ImageFrame.RECORDIO,
                                     resize=100,
                                     image_dimensions=[3, 100, 100],
                                     to_rgb=1)

    reader = mlio.ImageReader(rdr_prm, img_prm)
    example = reader.read_example()
    tensor = example['value']

    assert tensor.shape == (1, 100, 100, 3)
    assert tensor.strides == (30000, 300, 3, 1)
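The stride assertions in the two image tests follow directly from the row-major (C-contiguous) layout of a uint8 tensor of shape (batch, height, width, channels); a quick numpy check of the same arithmetic:

import numpy as np

# Byte strides of a C-contiguous uint8 array are the products of the
# trailing dimensions: (H * W * C, W * C, C, 1).
assert np.zeros((1, 50, 50, 3), dtype=np.uint8).strides == (7500, 150, 3, 1)
assert np.zeros((1, 100, 100, 3), dtype=np.uint8).strides == (30000, 300, 3, 1)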
Example #6
def test_csv_nonutf_encoding_with_encoding_param():
    filename = os.path.join(resources_dir, 'test_iso8859_5.csv')
    dataset = [mlio.File(filename)]
    rdr_prm = mlio.DataReaderParams(dataset=dataset,
                                    batch_size=2)
    csv_params = mlio.CsvParams(encoding='ISO-8859-5')

    reader = mlio.CsvReader(rdr_prm, csv_params)
    example = reader.read_example()
    nonutf_feature = example['col_3']

    try:
        # Converting the non-UTF-8 feature must not raise once the encoding
        # is given explicitly.
        as_numpy(nonutf_feature)
    except SystemError:
        pytest.fail("Unexpected exception thrown")
Example #7
def test_data_reader_params_members():
    filename = os.path.join(resources_dir, 'test.pbr')
    dataset = [mlio.File(filename)]
    rdr_prm = mlio.DataReaderParams(dataset=dataset, batch_size=1)

    assert rdr_prm.dataset == dataset
    assert rdr_prm.batch_size == 1
    assert rdr_prm.num_prefetched_batches == 0
    assert rdr_prm.num_parallel_reads == 0
    assert rdr_prm.last_batch_handling == mlio.LastBatchHandling.NONE
    assert rdr_prm.bad_batch_handling == mlio.BadBatchHandling.ERROR
    assert rdr_prm.num_instances_to_skip == 0
    assert rdr_prm.num_instances_to_read is None
    assert rdr_prm.shard_index == 0
    assert rdr_prm.num_shards == 0
    assert rdr_prm.shuffle_instances is False
    assert rdr_prm.shuffle_window == 0
    assert rdr_prm.shuffle_seed is None
    assert rdr_prm.reshuffle_each_epoch is True
    assert rdr_prm.subsample_ratio is None

    rdr_prm.batch_size = 2
    assert rdr_prm.batch_size == 2
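As a short follow-up, mirroring the reuse pattern in the earlier examples, the mutated parameter object can still be handed to a reader over the same test.pbr dataset:

    reader = mlio.RecordIOProtobufReader(rdr_prm)
    assert reader.peek_example()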