def test_text_module_not_found_error():
    """Check that ``from_json`` raises ``ModuleNotFoundError`` mentioning the ``[text]`` extra.

    NOTE(review): this presumably only holds when the text extras are not installed;
    any skip guard for that case is not visible here — confirm.
    """
    # ``match`` is applied with ``re.search``: an unescaped "[text]" is a character class
    # that matches any message containing "t", "e" or "x", which makes the check vacuous.
    # Escape the brackets so the literal extra name "[text]" must appear in the message.
    with pytest.raises(ModuleNotFoundError, match=r"\[text\]"):
        TextClassificationData.from_json("sentence", "lab", backbone=TEST_BACKBONE, train_file="", batch_size=1)
def test_from_json_with_field(tmpdir):
    """Load single-label JSON rows nested under ``field="data"`` and inspect one batch per stage.

    The same file is used for train/val/test/predict. For the labelled stages the first batch
    must carry a target whose single value is 0 or 1 and a string input. For the predict
    stage only the string input is checked.
    """
    json_path = json_data_with_field(tmpdir, multilabel=False)
    dm = TextClassificationData.from_json(
        "sentence",
        "lab",
        train_file=json_path,
        val_file=json_path,
        test_file=json_path,
        predict_file=json_path,
        batch_size=1,
        field="data",
    )

    labelled_loaders = (dm.train_dataloader, dm.val_dataloader, dm.test_dataloader)
    for make_loader in labelled_loaders:
        first = next(iter(make_loader()))
        assert first[DataKeys.TARGET].item() in [0, 1]
        assert isinstance(first[DataKeys.INPUT][0], str)

    # Predict batches are only checked for their input, not for a target.
    predict_first = next(iter(dm.predict_dataloader()))
    assert isinstance(predict_first[DataKeys.INPUT][0], str)
def test_from_json_with_field_multilabel(tmpdir):
    """Load multi-label JSON rows nested under ``field="data"`` and inspect one batch per stage.

    Two label columns (``lab1``, ``lab2``) are passed, and the datamodule must report
    ``multi_label``. For train/val/test, every entry of the first sample's target must be
    0 or 1 and the input must be a string. For predict, only the string input is checked.
    """
    json_path = json_data_with_field(tmpdir, multilabel=True)
    dm = TextClassificationData.from_json(
        "sentence",
        ["lab1", "lab2"],
        train_file=json_path,
        val_file=json_path,
        test_file=json_path,
        predict_file=json_path,
        batch_size=1,
        field="data",
    )

    assert dm.multi_label

    for dataloader_fn in (dm.train_dataloader, dm.val_dataloader, dm.test_dataloader):
        batch = next(iter(dataloader_fn()))
        # Generator instead of a list comprehension: ``all`` can short-circuit (ruff C419).
        assert all(label in (0, 1) for label in batch[DataKeys.TARGET][0])
        assert isinstance(batch[DataKeys.INPUT][0], str)

    # Predict batches are only checked for their input, not for a target.
    batch = next(iter(dm.predict_dataloader()))
    assert isinstance(batch[DataKeys.INPUT][0], str)
# Distinct name: reusing ``test_from_json_with_field`` would rebind that module-level name,
# so the earlier test of the same name in this module would never be collected by pytest.
def test_from_json_with_field_and_backbone(tmpdir):
    """Load JSON rows nested under ``field="data"`` with ``backbone=TEST_BACKBONE`` and check a train batch.

    The first training batch must expose a ``"labels"`` entry whose single value is 0 or 1,
    and it must contain an ``"input_ids"`` key.
    """
    json_path = json_data_with_field(tmpdir)
    dm = TextClassificationData.from_json(
        "sentence", "lab", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
    )
    batch = next(iter(dm.train_dataloader()))
    assert batch["labels"].item() in [0, 1]
    assert "input_ids" in batch