Exemple #1
0
def test_dataset_transform_mixed_multiple_named_cols():
    """Columns may be referenced by name, by index, or by a mix of both.

    Three transforms each consume two columns ('text'/'label', 0/1 and
    0/'label'); the train split is expected to expose three columns, one per
    transform.
    """
    rows = (
        ("Lorem ipsum dolor sit amet", "POSITIVE"),
        ("Sed ut perspiciatis unde", "NEGATIVE"),
    )

    class DummyField(Field):
        """Field stub that ignores its inputs and always emits ``tensor(0)``."""

        def setup(self, *data: np.ndarray) -> None:
            return None

        def process(self, ex1, ex2):
            # One positional argument per column listed in the transform.
            return torch.tensor(0)

    # Transform name -> the two columns it reads (names, indices, or mixed).
    column_specs = {
        "text": ['text', 'label'],
        "other": [0, 1],
        "other2": [0, 'label'],
    }
    transform = {
        name: {"field": DummyField(), "columns": cols}
        for name, cols in column_specs.items()
    }

    dataset = TabularDataset(rows,
                             transform=transform,
                             named_columns=['text', 'label'])
    assert dataset.train.cols() == 3
Exemple #2
0
def test_invalid_dataset2():
    """Test dataset is invalid as different splits contain different columns.

    The train rows have three columns while the single validation row has
    only two, so constructing the dataset is expected to raise ``ValueError``.
    """
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 4, 5.5))
    val = (("ipsum quia dolor sit", 3.5), )
    with pytest.raises(ValueError):
        # Construction itself must fail; the result is never used.
        TabularDataset(train, val)
Exemple #3
0
def test_incomplete_dataset():
    """Test a dataset built from a train split only.

    When neither ``val`` nor ``test`` is supplied, both of those splits should
    still be present and empty, while the train split keeps every row.
    """
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 4, 5.5))
    t = TabularDataset(train)

    assert len(t.train) == len(train)
    assert len(t.val) == 0
    assert len(t.test) == 0
Exemple #4
0
def test_dataset_transform_8():
    """A LabelField mapped onto two columns is expected to raise ``TypeError``.

    The test does not pin which step raises, so both dataset construction and
    the first inspection of the train split sit inside ``pytest.raises``.
    """
    samples = (
        ("Lorem ipsum dolor sit amet", "POSITIVE"),
        ("Sed ut perspiciatis unde", "NEGATIVE"),
    )
    two_column_label = {"field": LabelField(), "columns": [0, 1]}

    with pytest.raises(TypeError):
        dataset = TabularDataset(samples, transform={"tx": two_column_label})
        dataset.train.cols()
Exemple #5
0
def test_dataset_transform_with_invalid_named_cols():
    """Referencing a column name absent from ``named_columns`` raises ``ValueError``."""
    samples = (
        ("Lorem ipsum dolor sit amet", "POSITIVE"),
        ("Sed ut perspiciatis unde", "NEGATIVE"),
    )
    # 'none_existent' is deliberately not one of the declared column names.
    bad_transform = {"tx": {"field": LabelField(), "columns": 'none_existent'}}
    column_names = ['text', 'label']

    with pytest.raises(ValueError):
        TabularDataset(samples, transform=bad_transform, named_columns=column_names)
Exemple #6
0
def test_dataset_transform_with_named_cols():
    """A transform may select its input column by name via ``named_columns``."""
    samples = (
        ("Lorem ipsum dolor sit amet", "POSITIVE"),
        ("Sed ut perspiciatis unde", "NEGATIVE"),
    )
    label_only = {"tx": {"field": LabelField(), "columns": 'label'}}

    dataset = TabularDataset(samples,
                             transform=label_only,
                             named_columns=['text', 'label'])

    # With a single transform, each example is expected to hold one value.
    first_example = dataset.train[0]
    assert len(first_example) == 1
Exemple #7
0
def test_dataset_transform():
    """Fields given directly in ``transform`` are exposed as dataset attributes.

    Each field is reachable under its transform key, and the test expects
    specific vocabulary sizes after building from the training rows.
    """
    samples = (
        ("Lorem ipsum dolor sit amet", "POSITIVE"),
        ("Sed ut perspiciatis unde", "NEGATIVE"),
    )
    dataset = TabularDataset(samples,
                             transform={"text": TextField(), "label": LabelField()})

    for attr in ("text", "label"):
        assert hasattr(dataset, attr)

    # Two distinct labels: POSITIVE and NEGATIVE.
    assert dataset.label.vocab_size == 2
    # NOTE(review): the text column has 9 distinct words; the expected 11
    # presumably includes the field's special tokens — confirm.
    assert dataset.text.vocab_size == 11
Exemple #8
0
def test_dataset_transform_3():
    """A transform entry without a ``"field"`` key is expected to raise ``ValueError``."""
    samples = (
        ("Lorem ipsum dolor sit amet", "POSITIVE"),
        ("Sed ut perspiciatis unde", "NEGATIVE"),
    )
    # "text" names its column but omits the field; this is the invalid part.
    missing_field = {"columns": 0}
    label_spec = {"field": LabelField(), "columns": 1}
    transform = {"text": missing_field, "label": label_spec}

    with pytest.raises(ValueError):
        TabularDataset(samples, transform=transform)
Exemple #9
0
def test_dataset_transform_5():
    """Several transforms may read the same source column.

    Two TextFields both consume column 0, and the train split is expected to
    expose two columns, one per transform.
    """
    samples = (
        ("Lorem ipsum dolor sit amet", "POSITIVE"),
        ("Sed ut perspiciatis unde", "NEGATIVE"),
    )
    transform = {}
    for key in ("t1", "t2"):
        transform[key] = {"field": TextField(), "columns": 0}

    dataset = TabularDataset(samples, transform=transform)
    assert dataset.train.cols() == 2
Exemple #10
0
def test_cache_dataset():
    """Test caching the dataset.

    With ``cache=True`` nothing is cached up front; the cache is expected to
    grow by exactly one entry per example as the train split is iterated.
    """
    pair = (("Lorem ipsum dolor sit amet", 3, 4.5),
            ("Sed ut perspiciatis unde", 5, 5.5))
    # Ten rows alternating between the two samples above.
    train = pair * 5

    dataset = TabularDataset(train, cache=True)

    assert len(dataset.train.cached_data) == 0
    for seen, _ in enumerate(dataset.train, start=1):
        assert len(dataset.train.cached_data) == seen
Exemple #11
0
def test_valid_dataset():
    """Test trivial dataset build process.

    Builds a dataset from train/val/test splits, checks the total and
    per-split lengths, then compares every value of every row against the
    source tuples.
    """
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 5, 5.5))
    val = (("ipsum quia dolor sit", 10, 3.5),)
    test = (("Ut enim ad minima veniam", 100, 35),)

    dataset = TabularDataset(train, val, test)

    assert len(dataset) == 4
    assert len(dataset.train) == 2
    assert len(dataset.val) == 1
    assert len(dataset.test) == 1

    def assert_split_matches(expected_rows, split):
        """Assert the first three values of each row in ``split`` match."""
        for idx, expected in enumerate(expected_rows):
            first, second, third = expected
            assert split[idx][0] == first
            assert split[idx][1] == second
            assert split[idx][2] == third

    assert_split_matches(train, dataset.train)
    assert_split_matches(val, dataset.val)
    assert_split_matches(test, dataset.test)
Exemple #12
0
def test_invalid_columns():
    """Test that ``named_columns`` must name every column.

    The rows have two columns but only one name is supplied, so construction
    is expected to raise ``ValueError``.
    """
    train = (("Lorem ipsum dolor sit amet", 3),
             ("Sed ut perspiciatis unde", 5.5))
    with pytest.raises(ValueError):
        TabularDataset(train, named_columns=['some_random_col'])
Exemple #13
0
def test_named_columns():
    """Test a dataset built with one name per column is accepted.

    Construction must not raise, and every train row must be kept.
    """
    train = (("Lorem ipsum dolor sit amet", 3),
             ("Sed ut perspiciatis unde", 5.5))
    t = TabularDataset(train, named_columns=['col1', 'col2'])

    assert len(t.train) == len(train)
Exemple #14
0
def test_invalid_dataset():
    """Rows of unequal length within one split are expected to raise ``ValueError``."""
    three_cols = ("Lorem ipsum dolor sit amet", 3, 4.5)
    two_cols = ("Sed ut perspiciatis unde", 5.5)
    with pytest.raises(ValueError):
        TabularDataset((three_cols, two_cols))