Пример #1
0
def test_text_module_not_found_error():
    """Expect ``from_json`` to raise ``ModuleNotFoundError`` mentioning ``[text]``.

    The raised message must contain the literal substring ``[text]``
    (presumably the name of the missing optional extra -- confirm against
    the library's error message).
    """
    # ``match`` is a regular expression applied with ``re.search``. Unescaped,
    # "[text]" is a character class matching any single "t", "e" or "x", so the
    # check would pass for almost any message. Escape the brackets so the
    # literal "[text]" is required.
    with pytest.raises(ModuleNotFoundError, match=r"\[text\]"):
        TextClassificationData.from_json(
            "sentence",
            "lab",
            backbone=TEST_BACKBONE,
            train_file="",
            batch_size=1,
        )
Пример #2
0
def test_from_json_with_field(tmpdir):
    """Build a datamodule from one JSON file via ``field="data"`` and check each split.

    The same file is used for the train, val, test and predict splits. The
    first batch of every labelled split must carry a target of 0 or 1 and a
    string input; the predict split is only checked for a string input.
    """
    data_file = json_data_with_field(tmpdir, multilabel=False)
    datamodule = TextClassificationData.from_json(
        "sentence",
        "lab",
        train_file=data_file,
        val_file=data_file,
        test_file=data_file,
        predict_file=data_file,
        batch_size=1,
        field="data",
    )

    # Labelled splits are visited in the order train -> val -> test.
    for loader_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        loader = getattr(datamodule, loader_name)()
        first_batch = next(iter(loader))
        assert first_batch[DataKeys.TARGET].item() in [0, 1]
        assert isinstance(first_batch[DataKeys.INPUT][0], str)

    # The predict split is only checked for its input, not for a target.
    predict_batch = next(iter(datamodule.predict_dataloader()))
    assert isinstance(predict_batch[DataKeys.INPUT][0], str)
Пример #3
0
def test_from_json_with_field_multilabel(tmpdir):
    """Build a multi-label datamodule from one JSON file via ``field="data"``.

    Two target columns (``lab1``, ``lab2``) are requested, and the datamodule
    must report ``multi_label``. For the train, val and test splits every
    entry of the first sample's target must be 0 or 1 and the input must be a
    string; the predict split is only checked for a string input.
    """
    data_file = json_data_with_field(tmpdir, multilabel=True)
    datamodule = TextClassificationData.from_json(
        "sentence",
        ["lab1", "lab2"],
        train_file=data_file,
        val_file=data_file,
        test_file=data_file,
        predict_file=data_file,
        batch_size=1,
        field="data",
    )

    assert datamodule.multi_label

    # Labelled splits are visited in the order train -> val -> test.
    for loader_name in ("train_dataloader", "val_dataloader", "test_dataloader"):
        loader = getattr(datamodule, loader_name)()
        first_batch = next(iter(loader))
        sample_targets = first_batch[DataKeys.TARGET][0]
        assert all(label in [0, 1] for label in sample_targets)
        assert isinstance(first_batch[DataKeys.INPUT][0], str)

    # The predict split is only checked for its input, not for a target.
    predict_batch = next(iter(datamodule.predict_dataloader()))
    assert isinstance(predict_batch[DataKeys.INPUT][0], str)
Пример #4
0
def test_from_json_with_field(tmpdir):
    """Load the train split from JSON via ``field="data"`` with a backbone and check one batch.

    The first training batch must have a ``"labels"`` value of 0 or 1 and
    contain an ``"input_ids"`` entry.
    """
    source_file = json_data_with_field(tmpdir)
    datamodule = TextClassificationData.from_json(
        "sentence",
        "lab",
        backbone=TEST_BACKBONE,
        train_file=source_file,
        batch_size=1,
        field="data",
    )

    train_loader = datamodule.train_dataloader()
    first_batch = next(iter(train_loader))

    label_value = first_batch["labels"].item()
    assert label_value in [0, 1]
    assert "input_ids" in first_batch