示例#1
0
def test_missing_datatype_exception(data, fields, tmpdir):
    """A column whose values are all missing (datatype cannot be inferred)
    must raise a RuntimeError when materializing a DiskBackedDataset."""
    rows_with_null = [(*row, None) for row in data]
    missing_field = Field(
        "null_field", keep_raw=True, allow_missing_data=True, numericalizer=Vocab()
    )
    all_fields = [*fields, missing_field]

    factory = ExampleFactory(all_fields)
    example_iter = (factory.from_list(row) for row in rows_with_null)

    with pytest.raises(RuntimeError):
        DiskBackedDataset.from_examples(all_fields, example_iter, cache_path=tmpdir)
示例#2
0
def test_delete_cache(data, fields):
    """delete_cache() must remove the dataset's on-disk cache directory."""
    cache_dir = tempfile.mkdtemp()

    factory = ExampleFactory(fields)
    dataset = DiskBackedDataset.from_examples(
        fields, (factory.from_list(row) for row in data), cache_path=cache_dir
    )

    assert os.path.exists(cache_dir)
    dataset.delete_cache()
    assert not os.path.exists(cache_dir)
示例#3
0
def _assert_roundtrip_equal(reference, loaded, vocab_owner):
    """Assert that `loaded` matches `reference` example-by-example and that
    the "tokens" vocab of `vocab_owner` survived the dump/load round trip."""
    assert len(loaded) == len(reference)
    for ex_original, ex_loaded in zip(reference, loaded):
        assert ex_original["number"] == ex_loaded["number"]
        assert ex_original["tokens"] == ex_loaded["tokens"]
    assert (
        vocab_owner.field_dict["tokens"].vocab.stoi
        == loaded.field_dict["tokens"].vocab.stoi
    )


def test_dump_and_load(pyarrow_dataset, tmpdir):
    """dump_cache/load_cache round trip in three scenarios: the full dataset
    with an auto-created cache dir, a sliced view, and an explicit cache path.

    The repeated per-example/vocab checks are factored into
    `_assert_roundtrip_equal` (they were copy-pasted three times before).
    """
    # 1) full dataset, cache directory chosen by dump_cache itself
    cache_dir = pyarrow_dataset.dump_cache(cache_path=None)
    loaded_dataset = DiskBackedDataset.load_cache(cache_dir)
    _assert_roundtrip_equal(pyarrow_dataset, loaded_dataset, pyarrow_dataset)
    loaded_dataset.delete_cache()

    # 2) a reversed, strided slice of the dataset
    dataset_sliced = pyarrow_dataset[8:2:-2]
    cache_dir_sliced = dataset_sliced.dump_cache(cache_path=None)
    loaded_dataset_sliced = DiskBackedDataset.load_cache(cache_dir_sliced)
    # NOTE: the vocab comparison deliberately uses the *original* dataset's
    # field_dict, as in the pre-refactor test
    _assert_roundtrip_equal(dataset_sliced, loaded_dataset_sliced, pyarrow_dataset)
    loaded_dataset_sliced.delete_cache()

    # 3) dump into an explicitly provided directory (tmpdir fixture)
    pyarrow_dataset.dump_cache(cache_path=tmpdir)
    loaded_dataset = DiskBackedDataset.load_cache(tmpdir)
    _assert_roundtrip_equal(pyarrow_dataset, loaded_dataset, pyarrow_dataset)
    loaded_dataset.delete_cache()
示例#4
0
def test_from_dataset(data, fields):
    """Converting an in-memory Dataset to a DiskBackedDataset must preserve
    every example's field values."""
    factory = ExampleFactory(fields)
    built_examples = [factory.from_list(row) for row in data]
    in_memory = Dataset(built_examples, fields)
    on_disk = DiskBackedDataset.from_dataset(in_memory)

    for mem_ex, disk_ex in zip(in_memory, on_disk):
        assert mem_ex.number == disk_ex.number
        assert mem_ex.tokens == disk_ex.tokens

    on_disk.delete_cache()
示例#5
0
def test_from_tabular(data, fields, tmpdir):
    """Loading a CSV file via from_tabular_file must yield the written rows."""
    csv_path = os.path.join(tmpdir, "test.csv")
    with open(csv_path, "w", newline="") as out_file:
        csv.writer(out_file).writerows(data)

    dataset = DiskBackedDataset.from_tabular_file(csv_path, "csv", fields)
    for example, row in zip(dataset, data):
        # raw values come back as strings; the number column needs a cast
        assert int(example.number[0]) == row[0]
        assert example.tokens[0] == row[1]

    dataset.delete_cache()
示例#6
0
def test_from_pandas_index(data):
    """from_pandas with an `index_field` must expose the DataFrame index as a
    regular field alongside the data columns."""
    import pandas as pd

    df = pd.DataFrame([[x[1]] for x in data], index=[x[0] for x in data])
    fields = [Field("text_field", keep_raw=True, tokenizer="split")]

    ds = DiskBackedDataset.from_pandas(
        df, fields, index_field=Field("number_field", tokenizer=None, keep_raw=True)
    )

    assert set(ds.field_dict) == set(["text_field", "number_field"])
    for original, (raw, _) in zip(data, ds.number_field):
        assert original[0] == raw

    # clean up the on-disk cache, consistent with the other tests in this file
    ds.delete_cache()
示例#7
0
def test_from_pandas_field_dict(data):
    """from_pandas must accept a column-name -> Field mapping and map each
    DataFrame column onto its field."""
    import pandas as pd

    df = pd.DataFrame(data, columns=["number", "text"])
    fields = {
        "number": Field("number", tokenizer=None),
        "text": Field("text", keep_raw=True, tokenizer="split"),
    }

    ds = DiskBackedDataset.from_pandas(df, fields)

    for original, (raw, _) in zip(data, ds.text):
        assert original[1] == raw

    # clean up the on-disk cache, consistent with the other tests in this file
    ds.delete_cache()
示例#8
0
def test_from_examples(data, fields):
    """from_examples must store raw and tokenized values for every field."""
    factory = ExampleFactory(fields)
    built = [factory.from_list(row) for row in data]
    dataset = DiskBackedDataset.from_examples(fields, built)

    # number field: tokenizer is None, so tokenized is the original object
    for (raw, tokenized), (num, _) in zip(dataset.number, data):
        assert raw == num
        assert tokenized is num

    # tokens field: raw is kept, tokenized is the split string
    for (raw, tokenized), (_, tok) in zip(dataset.tokens, data):
        assert raw == tok
        assert tokenized == tok.split(" ")

    dataset.delete_cache()
示例#9
0
def test_datatype_definition(data, fields):
    """Explicit pyarrow data_types for an all-missing column must allow the
    dataset to be built (no datatype inference needed)."""
    rows_with_null = [(*row, None) for row in data]
    missing_field = Field(
        "null_field", keep_raw=True, allow_missing_data=True, numericalizer=Vocab()
    )
    all_fields = [*fields, missing_field]

    factory = ExampleFactory(all_fields)
    example_iter = (factory.from_list(row) for row in rows_with_null)

    datatypes = {"null_field": (pa.string(), pa.list_(pa.string()))}
    dataset = DiskBackedDataset.from_examples(
        all_fields, example_iter, data_types=datatypes
    )

    for example, row in zip(dataset, rows_with_null):
        assert int(example["number"][0]) == row[0]
        assert example["tokens"][0] == row[1]

    dataset.delete_cache()
示例#10
0
def test_slice_view_to_dataset(dataset, tmp_path):
    """Casting a DatasetSlicedView to Dataset and to DiskBackedDataset must
    preserve its length and every example's field values.

    The identical per-example comparison loop (previously duplicated for the
    two casts) is factored into a local helper.
    """

    def assert_matches_view(view, cast_ds):
        # every example in the cast dataset equals its counterpart in the view
        assert len(cast_ds) == len(view)
        for ex_view, ex_cast in zip(view, cast_ds):
            for f in dataset.fields:
                assert ex_view[f.name] == ex_cast[f.name]

    start, stop, step = 3, 8, 2
    dataset_view = DatasetSlicedView(dataset, s=slice(start, stop, step))

    # cast to Dataset
    ds = Dataset.from_dataset(dataset_view)
    assert isinstance(ds, Dataset)
    assert_matches_view(dataset_view, ds)

    # cast to DiskBackedDataset
    ds = DiskBackedDataset.from_dataset(dataset_view, cache_path=tmp_path)
    assert isinstance(ds, DiskBackedDataset)
    assert_matches_view(dataset_view, ds)
示例#11
0
def pyarrow_dataset(data, fields):
    """Build a DiskBackedDataset from the raw `data` rows (lazy example stream)."""
    factory = ExampleFactory(fields)
    return DiskBackedDataset.from_examples(
        fields, (factory.from_list(row) for row in data)
    )