def test_csv_dataset_size_with_shards():
    """
    Test CSVDataset get_dataset_size, both unsharded and with num_shards=2.

    NOTE(review): this test was previously also named test_csv_dataset_size,
    colliding with the other definition of that name in this file; the later
    definition shadowed this one, so pytest only collected one of them.
    Renamed so both tests run.
    """
    # Full dataset: 3 rows expected in CSV_FILE.
    dataset = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"],
                            column_names=['1', '2', '3', '4'], shuffle=False)
    assert dataset.get_dataset_size() == 3

    # Sharded into 2: shard 0 gets ceil(3 / 2) == 2 rows.
    dataset_shard_2_0 = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"],
                                      column_names=['1', '2', '3', '4'], shuffle=False,
                                      num_shards=2, shard_id=0)
    assert dataset_shard_2_0.get_dataset_size() == 2
def test_csv_dataset_size():
    """
    Test that get_dataset_size reports the expected row count of size.csv.
    """
    TEST_FILE = '../data/dataset/testCSV/size.csv'
    sized_dataset = ds.CSVDataset(
        TEST_FILE,
        column_defaults=[0.0, 0.0, 0, 0.0],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    assert sized_dataset.get_dataset_size() == 5
def test_get_column_name_zip():
    """
    Test get_col_names on a zip of a Cifar10Dataset and a CSVDataset.
    """
    cifar_data = ds.Cifar10Dataset(CIFAR10_DIR)
    assert cifar_data.get_col_names() == ["image", "label"]

    csv_data = ds.CSVDataset(CSV_DIR)
    assert csv_data.get_col_names() == ["1", "2", "3", "4"]

    # Zipping concatenates the column lists in dataset order.
    zipped = ds.zip((cifar_data, csv_data))
    assert zipped.get_col_names() == ["image", "label", "1", "2", "3", "4"]
def test_serdes_csv_dataset(remove_json_files=True):
    """
    Test serdes on Csvdataset pipeline.
    """
    DATA_DIR = "../data/dataset/testCSV/1.csv"
    original = ds.CSVDataset(DATA_DIR,
                             column_defaults=["1", "2", "3", "4"],
                             column_names=['col1', 'col2', 'col3', 'col4'],
                             shuffle=False)
    original = original.project(columns=["col1", "col4", "col2"])

    # Serialize and deserialize the pipeline, then compare both row by row.
    restored = util_check_serialize_deserialize_file(
        original, "csv_dataset_pipeline", remove_json_files)

    row_count = 0
    for orig_row, restored_row in zip(
            original.create_dict_iterator(num_epochs=1, output_numpy=True),
            restored.create_dict_iterator(num_epochs=1, output_numpy=True)):
        for col in ('col1', 'col2', 'col4'):
            np.testing.assert_array_equal(orig_row[col], restored_row[col])
        row_count += 1
    assert row_count == 3
def test_csv_dataset_one_file():
    """
    Test reading a single CSV file and counting its rows.
    """
    data = ds.CSVDataset(DATA_FILE,
                         column_defaults=["1", "2", "3", "4"],
                         column_names=['col1', 'col2', 'col3', 'col4'],
                         shuffle=False)
    buffer = []
    # num_epochs=1 / output_numpy=True so the iterator terminates cleanly,
    # consistent with the other tests in this file.
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    assert len(buffer) == 3
def test_csv_dataset_all_file():
    """
    Test reading a list of CSV files and counting the combined rows.
    """
    APPEND_FILE = '../data/dataset/testCSV/2.csv'
    data = ds.CSVDataset([DATA_FILE, APPEND_FILE],
                         column_defaults=["1", "2", "3", "4"],
                         column_names=['col1', 'col2', 'col3', 'col4'],
                         shuffle=False)
    buffer = []
    # num_epochs=1 / output_numpy=True so the iterator terminates cleanly,
    # consistent with the other tests in this file.
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    assert len(buffer) == 10
def test_csv_dataset_type_error():
    """
    Test that a CSV field whose type mismatches column_defaults raises.
    """
    TEST_FILE = '../data/dataset/testCSV/exception.csv'
    bad_data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", 0, "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    # Consuming the pipeline is what triggers parsing, hence the loop.
    with pytest.raises(Exception) as err:
        for _ in bad_data.create_dict_iterator():
            pass
    assert "type does not match" in str(err.value)
def test_csv_dataset_parse_exception():
    """
    Test that a malformed CSV file fails to parse.

    NOTE(review): this test was previously named test_csv_dataset_exception,
    colliding with the later definition of that name in this file; the later
    definition shadowed this one, so pytest never ran it. Renamed so it is
    collected.
    """
    TEST_FILE = '../data/dataset/testCSV/exception.csv'
    data = ds.CSVDataset(TEST_FILE,
                         column_defaults=["", "", "", ""],
                         column_names=['col1', 'col2', 'col3', 'col4'],
                         shuffle=False)
    # Parsing happens lazily, so the file is only read during iteration.
    with pytest.raises(Exception) as err:
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "failed to parse file" in str(err.value)
def test_csv_dataset_num_samples():
    """
    Test that num_samples limits the number of rows produced.
    """
    data = ds.CSVDataset(DATA_FILE,
                         column_defaults=["1", "2", "3", "4"],
                         column_names=['col1', 'col2', 'col3', 'col4'],
                         shuffle=False, num_samples=2)
    count = 0
    # num_epochs=1 / output_numpy=True so the iterator terminates cleanly,
    # consistent with the other tests in this file.
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2
def test_csv_dataset_header():
    """
    Test that column names are taken from the CSV header row when
    column_names is not given.
    """
    TEST_FILE = '../data/dataset/testCSV/header.csv'
    data = ds.CSVDataset(TEST_FILE,
                         column_defaults=["", "", "", ""],
                         shuffle=False)
    buffer = []
    # num_epochs=1 / output_numpy=True so rows come back as numpy arrays
    # (required for .item().decode()), consistent with the quoted-field test.
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8")])
    assert buffer == ['a', 'b', 'c', 'd']
def test_csv_dataset_distribution():
    """
    Test sharded reading: shard 0 of 2 yields half the rows (rounded up).
    """
    TEST_FILE = '../data/dataset/testCSV/1.csv'
    data = ds.CSVDataset(TEST_FILE,
                         column_defaults=["1", "2", "3", "4"],
                         column_names=['col1', 'col2', 'col3', 'col4'],
                         shuffle=False, num_shards=2, shard_id=0)
    count = 0
    # num_epochs=1 / output_numpy=True so the iterator terminates cleanly,
    # consistent with the other tests in this file.
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2
def test_csv_dataset_embedded():
    """
    Test parsing of fields with embedded delimiters, quotes, newlines and
    surrounding spaces.
    """
    TEST_FILE = '../data/dataset/testCSV/embedded.csv'
    data = ds.CSVDataset(TEST_FILE,
                         column_defaults=["", "", "", ""],
                         column_names=['col1', 'col2', 'col3', 'col4'],
                         shuffle=False)
    buffer = []
    # num_epochs=1 / output_numpy=True so rows come back as numpy arrays
    # (required for .item().decode()), consistent with the quoted-field test.
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8")])
    assert buffer == ['a,b', 'c"d', 'e\nf', ' g ']
def test_csv_dataset_quoted():
    """
    Test parsing of quoted CSV fields.
    """
    TEST_FILE = '../data/dataset/testCSV/quoted.csv'
    quoted_data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    decoded = []
    for row in quoted_data.create_dict_iterator(num_epochs=1, output_numpy=True):
        decoded += [row[col].item().decode("utf8")
                    for col in ('col1', 'col2', 'col3', 'col4')]
    assert decoded == ['a', 'b', 'c', 'd']
def test_csv_dataset_duplicate_columns():
    """
    Test that duplicate entries in column_names raise RuntimeError.
    """
    duplicated_names = ['col1', 'col2', 'col3', 'col4',
                        'col1', 'col2', 'col3', 'col4']
    data = ds.CSVDataset(DATA_FILE,
                         column_defaults=["1", "2", "3", "4"],
                         column_names=duplicated_names,
                         shuffle=False)
    # Creating the iterator is enough to trigger the validation error.
    with pytest.raises(RuntimeError) as info:
        _ = data.create_dict_iterator(num_epochs=1, output_numpy=True)
    message = str(info.value)
    assert "Invalid parameter, duplicate column names are not allowed: col1" in message
    assert "column_names" in message
def test_csv_dataset_number():
    """
    Test numeric column parsing against float/int column_defaults.
    """
    TEST_FILE = '../data/dataset/testCSV/number.csv'
    data = ds.CSVDataset(TEST_FILE,
                         column_defaults=[0.0, 0.0, 0, 0.0],
                         column_names=['col1', 'col2', 'col3', 'col4'],
                         shuffle=False)
    buffer = []
    # num_epochs=1 / output_numpy=True so the iterator terminates cleanly,
    # consistent with the other tests in this file.
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item(),
                       d['col2'].item(),
                       d['col3'].item(),
                       d['col4'].item()])
    assert np.allclose(buffer, [3.0, 0.3, 4, 55.5])
def test_csv_dataset_chinese():
    """
    Test parsing of non-ASCII (Chinese) UTF-8 string fields.
    """
    TEST_FILE = '../data/dataset/testCSV/chinese.csv'
    data = ds.CSVDataset(TEST_FILE,
                         column_defaults=["", "", "", "", ""],
                         column_names=['col1', 'col2', 'col3', 'col4', 'col5'],
                         shuffle=False)
    buffer = []
    # num_epochs=1 / output_numpy=True so rows come back as numpy arrays
    # (required for .item().decode()), consistent with the quoted-field test.
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8"),
                       d['col5'].item().decode("utf8")])
    assert buffer == ['大家', '早上好', '中午好', '下午好', '晚上好']
def test_csv_dataset_basic():
    """
    Test CSV with repeat, skip and so on
    """
    TRAIN_FILE = '../data/dataset/testCSV/1.csv'
    buffer = []
    data = ds.CSVDataset(TRAIN_FILE,
                         column_defaults=["0", 0, 0.0, "0"],
                         column_names=['1', '2', '3', '4'],
                         shuffle=False)
    data = data.repeat(2)
    data = data.skip(2)
    # num_epochs=1 / output_numpy=True so the iterator terminates cleanly —
    # the field_delim=None variant of this test already uses them.
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    # 3 rows repeated twice, minus 2 skipped -> 4 rows.
    assert len(buffer) == 4
def test_csv_dataset_field_delim_none():
    """
    Test CSV with field_delim=None
    """
    TRAIN_FILE = '../data/dataset/testCSV/1.csv'
    pipeline = ds.CSVDataset(TRAIN_FILE,
                             field_delim=None,
                             column_defaults=["0", 0, 0.0, "0"],
                             column_names=['1', '2', '3', '4'],
                             shuffle=False)
    # 3 rows repeated twice, minus 2 skipped -> 4 rows expected.
    pipeline = pipeline.repeat(2).skip(2)
    rows = [row for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True)]
    assert len(rows) == 4
def test_get_column_name_csv():
    """
    Test get_col_names with default and user-supplied column names.
    """
    default_data = ds.CSVDataset(CSV_DIR)
    assert default_data.get_col_names() == ["1", "2", "3", "4"]

    explicit_names = ["col1", "col2", "col3", "col4"]
    named_data = ds.CSVDataset(CSV_DIR, column_names=explicit_names)
    assert named_data.get_col_names() == explicit_names
def test_csv_dataset_exception():
    """
    Test error handling of CSVDataset.

    First checks that a malformed file fails to parse, then checks that a
    user map function raising inside the pipeline surfaces as RuntimeError
    for each of the four columns.
    """
    TEST_FILE = '../data/dataset/testCSV/exception.csv'
    data = ds.CSVDataset(TEST_FILE,
                         column_defaults=["", "", "", ""],
                         column_names=['col1', 'col2', 'col3', 'col4'],
                         shuffle=False)
    with pytest.raises(Exception) as err:
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "failed to parse file" in str(err.value)

    TEST_FILE1 = '../data/dataset/testCSV/quoted.csv'

    def exception_func(item):
        raise Exception("Error occur!")

    # The original repeated this try/except block verbatim for each column;
    # the loop performs the identical checks without the duplication.
    for column in ('col1', 'col2', 'col3', 'col4'):
        data = ds.CSVDataset(TEST_FILE1,
                             column_defaults=["", "", "", ""],
                             column_names=['col1', 'col2', 'col3', 'col4'],
                             shuffle=False)
        data = data.map(operations=exception_func,
                        input_columns=[column],
                        num_parallel_workers=1)
        with pytest.raises(RuntimeError) as err:
            for _ in data.__iter__():
                pass
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(err.value)