示例#1
0
def test_csv_dataset_size():
    """Check reported dataset size for the full CSV and for one shard of two."""
    full_data = ds.CSVDataset(CSV_FILE,
                              column_defaults=["0", 0, 0.0, "0"],
                              column_names=['1', '2', '3', '4'],
                              shuffle=False)
    assert full_data.get_dataset_size() == 3

    # Same file split into 2 shards: shard 0 should hold 2 of the 3 rows.
    shard_0 = ds.CSVDataset(CSV_FILE,
                            column_defaults=["0", 0, 0.0, "0"],
                            column_names=['1', '2', '3', '4'],
                            shuffle=False,
                            num_shards=2,
                            shard_id=0)
    assert shard_0.get_dataset_size() == 2
示例#2
0
def test_csv_dataset_size():
    """The size.csv fixture should report exactly 5 rows."""
    csv_path = '../data/dataset/testCSV/size.csv'
    dataset = ds.CSVDataset(csv_path,
                            column_defaults=[0.0, 0.0, 0, 0.0],
                            column_names=['col1', 'col2', 'col3', 'col4'],
                            shuffle=False)
    assert dataset.get_dataset_size() == 5
示例#3
0
def test_get_column_name_zip():
    """Zipping two datasets should concatenate their column name lists."""
    cifar = ds.Cifar10Dataset(CIFAR10_DIR)
    assert cifar.get_col_names() == ["image", "label"]

    csv_data = ds.CSVDataset(CSV_DIR)
    assert csv_data.get_col_names() == ["1", "2", "3", "4"]

    zipped = ds.zip((cifar, csv_data))
    assert zipped.get_col_names() == ["image", "label", "1", "2", "3", "4"]
def test_serdes_csv_dataset(remove_json_files=True):
    """
    Serialize a CSVDataset pipeline to JSON, deserialize it, and verify both
    pipelines yield identical data.
    """
    DATA_DIR = "../data/dataset/testCSV/1.csv"
    original = ds.CSVDataset(DATA_DIR,
                             column_defaults=["1", "2", "3", "4"],
                             column_names=['col1', 'col2', 'col3', 'col4'],
                             shuffle=False)
    original = original.project(columns=["col1", "col4", "col2"])
    restored = util_check_serialize_deserialize_file(original,
                                                     "csv_dataset_pipeline",
                                                     remove_json_files)

    # Walk both pipelines in lockstep and compare every projected column.
    row_count = 0
    original_iter = original.create_dict_iterator(num_epochs=1, output_numpy=True)
    restored_iter = restored.create_dict_iterator(num_epochs=1, output_numpy=True)
    for row_a, row_b in zip(original_iter, restored_iter):
        for col in ('col1', 'col2', 'col4'):
            np.testing.assert_array_equal(row_a[col], row_b[col])
        row_count += 1

    assert row_count == 3
示例#5
0
def test_csv_dataset_one_file():
    """Reading a single CSV file should yield 3 rows."""
    dataset = ds.CSVDataset(DATA_FILE,
                            column_defaults=["1", "2", "3", "4"],
                            column_names=['col1', 'col2', 'col3', 'col4'],
                            shuffle=False)
    rows = list(dataset.create_dict_iterator())
    assert len(rows) == 3
示例#6
0
def test_csv_dataset_all_file():
    """Reading two CSV files together should yield the combined 10 rows."""
    second_file = '../data/dataset/testCSV/2.csv'
    dataset = ds.CSVDataset([DATA_FILE, second_file],
                            column_defaults=["1", "2", "3", "4"],
                            column_names=['col1', 'col2', 'col3', 'col4'],
                            shuffle=False)
    rows = list(dataset.create_dict_iterator())
    assert len(rows) == 10
示例#7
0
def test_csv_dataset_type_error():
    """A field that cannot be parsed as the declared type should raise."""
    csv_path = '../data/dataset/testCSV/exception.csv'
    # col2 is declared as int (default 0); the fixture's value is not an int.
    dataset = ds.CSVDataset(csv_path,
                            column_defaults=["", 0, "", ""],
                            column_names=['col1', 'col2', 'col3', 'col4'],
                            shuffle=False)
    with pytest.raises(Exception) as err:
        list(dataset.create_dict_iterator())
    assert "type does not match" in str(err.value)
示例#8
0
def test_csv_dataset_exception():
    """A malformed CSV file should fail to parse at iteration time."""
    csv_path = '../data/dataset/testCSV/exception.csv'
    dataset = ds.CSVDataset(csv_path,
                            column_defaults=["", "", "", ""],
                            column_names=['col1', 'col2', 'col3', 'col4'],
                            shuffle=False)
    with pytest.raises(Exception) as err:
        list(dataset.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert "failed to parse file" in str(err.value)
示例#9
0
def test_csv_dataset_num_samples():
    """num_samples=2 should cap iteration at exactly 2 rows."""
    dataset = ds.CSVDataset(DATA_FILE,
                            column_defaults=["1", "2", "3", "4"],
                            column_names=['col1', 'col2', 'col3', 'col4'],
                            shuffle=False,
                            num_samples=2)
    total = sum(1 for _ in dataset.create_dict_iterator())
    assert total == 2
示例#10
0
def test_csv_dataset_header():
    """Column names omitted: the CSV header row supplies them."""
    csv_path = '../data/dataset/testCSV/header.csv'
    dataset = ds.CSVDataset(csv_path,
                            column_defaults=["", "", "", ""],
                            shuffle=False)
    values = []
    for row in dataset.create_dict_iterator():
        for col in ('col1', 'col2', 'col3', 'col4'):
            values.append(row[col].item().decode("utf8"))
    assert values == ['a', 'b', 'c', 'd']
示例#11
0
def test_csv_dataset_distribution():
    """With 2 shards, shard 0 should iterate over 2 rows."""
    csv_path = '../data/dataset/testCSV/1.csv'
    dataset = ds.CSVDataset(csv_path,
                            column_defaults=["1", "2", "3", "4"],
                            column_names=['col1', 'col2', 'col3', 'col4'],
                            shuffle=False,
                            num_shards=2,
                            shard_id=0)
    total = sum(1 for _ in dataset.create_dict_iterator())
    assert total == 2
示例#12
0
def test_csv_dataset_embedded():
    """Quoted fields with embedded delimiters, quotes, newlines, and spaces."""
    csv_path = '../data/dataset/testCSV/embedded.csv'
    dataset = ds.CSVDataset(csv_path,
                            column_defaults=["", "", "", ""],
                            column_names=['col1', 'col2', 'col3', 'col4'],
                            shuffle=False)
    values = []
    for row in dataset.create_dict_iterator():
        for col in ('col1', 'col2', 'col3', 'col4'):
            values.append(row[col].item().decode("utf8"))
    assert values == ['a,b', 'c"d', 'e\nf', ' g ']
示例#13
0
def test_csv_dataset_quoted():
    """Quoted CSV fields should be unwrapped to their plain values."""
    csv_path = '../data/dataset/testCSV/quoted.csv'
    dataset = ds.CSVDataset(csv_path,
                            column_defaults=["", "", "", ""],
                            column_names=['col1', 'col2', 'col3', 'col4'],
                            shuffle=False)
    values = []
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        for col in ('col1', 'col2', 'col3', 'col4'):
            values.append(row[col].item().decode("utf8"))
    assert values == ['a', 'b', 'c', 'd']
示例#14
0
def test_csv_dataset_duplicate_columns():
    """Duplicate entries in column_names must be rejected with a RuntimeError."""
    dataset = ds.CSVDataset(DATA_FILE,
                            column_defaults=["1", "2", "3", "4"],
                            column_names=[
                                'col1', 'col2', 'col3', 'col4', 'col1', 'col2',
                                'col3', 'col4'
                            ],
                            shuffle=False)
    with pytest.raises(RuntimeError) as info:
        _ = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
    message = str(info.value)
    assert "Invalid parameter, duplicate column names are not allowed: col1" in message
    assert "column_names" in message
示例#15
0
def test_csv_dataset_number():
    """Numeric column defaults should make the fields parse as numbers."""
    csv_path = '../data/dataset/testCSV/number.csv'
    dataset = ds.CSVDataset(csv_path,
                            column_defaults=[0.0, 0.0, 0, 0.0],
                            column_names=['col1', 'col2', 'col3', 'col4'],
                            shuffle=False)
    values = []
    for row in dataset.create_dict_iterator():
        for col in ('col1', 'col2', 'col3', 'col4'):
            values.append(row[col].item())
    assert np.allclose(values, [3.0, 0.3, 4, 55.5])
示例#16
0
def test_csv_dataset_chinese():
    """Non-ASCII (Chinese) field values should round-trip through UTF-8."""
    csv_path = '../data/dataset/testCSV/chinese.csv'
    dataset = ds.CSVDataset(csv_path,
                            column_defaults=["", "", "", "", ""],
                            column_names=['col1', 'col2', 'col3', 'col4', 'col5'],
                            shuffle=False)
    values = []
    for row in dataset.create_dict_iterator():
        for col in ('col1', 'col2', 'col3', 'col4', 'col5'):
            values.append(row[col].item().decode("utf8"))
    assert values == ['大家', '早上好', '中午好', '下午好', '晚上好']
示例#17
0
def test_csv_dataset_basic():
    """
    Test CSV combined with repeat and skip ops: 3 rows repeated twice
    minus the first 2 leaves 4.
    """
    train_file = '../data/dataset/testCSV/1.csv'

    pipeline = ds.CSVDataset(train_file,
                             column_defaults=["0", 0, 0.0, "0"],
                             column_names=['1', '2', '3', '4'],
                             shuffle=False)
    pipeline = pipeline.repeat(2).skip(2)
    rows = list(pipeline.create_dict_iterator())
    assert len(rows) == 4
示例#18
0
def test_csv_dataset_field_delim_none():
    """
    Test CSV with field_delim=None: pipeline still works with repeat/skip,
    yielding 3 rows repeated twice minus 2 skipped = 4.
    """
    train_file = '../data/dataset/testCSV/1.csv'

    pipeline = ds.CSVDataset(train_file,
                             field_delim=None,
                             column_defaults=["0", 0, 0.0, "0"],
                             column_names=['1', '2', '3', '4'],
                             shuffle=False)
    pipeline = pipeline.repeat(2).skip(2)
    rows = list(pipeline.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert len(rows) == 4
示例#19
0
def test_get_column_name_csv():
    """Column names come from the header by default, or from column_names."""
    from_header = ds.CSVDataset(CSV_DIR)
    assert from_header.get_col_names() == ["1", "2", "3", "4"]

    from_arg = ds.CSVDataset(CSV_DIR,
                             column_names=["col1", "col2", "col3", "col4"])
    assert from_arg.get_col_names() == ["col1", "col2", "col3", "col4"]
示例#20
0
def test_csv_dataset_exception():
    """
    Test CSVDataset error reporting.

    1. A malformed CSV file raises "failed to parse file" during iteration.
    2. A map op that raises, attached to any single column, surfaces as a
       RuntimeError mentioning the failed PyFunc.
    """
    TEST_FILE = '../data/dataset/testCSV/exception.csv'
    data = ds.CSVDataset(TEST_FILE,
                         column_defaults=["", "", "", ""],
                         column_names=['col1', 'col2', 'col3', 'col4'],
                         shuffle=False)
    with pytest.raises(Exception) as err:
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "failed to parse file" in str(err.value)

    TEST_FILE1 = '../data/dataset/testCSV/quoted.csv'

    def exception_func(item):
        raise Exception("Error occur!")

    # The failing map op should produce the same RuntimeError no matter
    # which column it is attached to. Using pytest.raises instead of the
    # try/except/assert-False pattern: `assert False` is stripped under -O,
    # which would silently let a non-raising pipeline pass.
    for column in ('col1', 'col2', 'col3', 'col4'):
        data = ds.CSVDataset(TEST_FILE1,
                             column_defaults=["", "", "", ""],
                             column_names=['col1', 'col2', 'col3', 'col4'],
                             shuffle=False)
        data = data.map(operations=exception_func,
                        input_columns=[column],
                        num_parallel_workers=1)
        with pytest.raises(RuntimeError) as info:
            for _ in data.__iter__():
                pass
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(
            info.value)