def test_from_hf_datasets_multilabel():
    """Build ``TextClassificationData`` from a multi-label HF dataset and check every split.

    The same in-memory dataset backs the train, val, test and predict splits.
    For train/val/test, the first sample of the first batch must have only
    0/1 targets and a raw ``str`` input. The predict split is checked only for
    the input type, since no target assertion is made for it.
    """
    hf_dataset = Dataset.from_pandas(TEST_DATA_FRAME_DATA_MULTILABEL)
    dm = TextClassificationData.from_hf_datasets(
        "sentence",
        ["lab1", "lab2"],
        train_hf_dataset=hf_dataset,
        val_hf_dataset=hf_dataset,
        test_hf_dataset=hf_dataset,
        predict_hf_dataset=hf_dataset,
        batch_size=1,
    )

    # With two target columns ("lab1", "lab2"), the datamodule is expected to be multi-label.
    assert dm.multi_label

    # Labelled splits are checked in the original order: train, val, test.
    for make_dataloader in (dm.train_dataloader, dm.val_dataloader, dm.test_dataloader):
        batch = next(iter(make_dataloader()))
        assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0])
        assert isinstance(batch[DataKeys.INPUT][0], str)

    batch = next(iter(dm.predict_dataloader()))
    assert isinstance(batch[DataKeys.INPUT][0], str)
def test_from_hf_datasets():
    """Build ``TextClassificationData`` from a single-label HF dataset and check every split.

    The same in-memory dataset backs the train, val, test and predict splits.
    For train/val/test, the single-element batch target must be 0 or 1 and the
    input a raw ``str``. The predict split is checked only for the input type.
    """
    hf_dataset = Dataset.from_pandas(TEST_DATA_FRAME_DATA)
    dm = TextClassificationData.from_hf_datasets(
        "sentence",
        "lab1",
        train_hf_dataset=hf_dataset,
        val_hf_dataset=hf_dataset,
        test_hf_dataset=hf_dataset,
        predict_hf_dataset=hf_dataset,
        batch_size=1,
    )

    # Labelled splits are checked in the original order: train, val, test.
    for make_dataloader in (dm.train_dataloader, dm.val_dataloader, dm.test_dataloader):
        batch = next(iter(make_dataloader()))
        # batch_size=1, so .item() extracts the lone target value.
        assert batch[DataKeys.TARGET].item() in [0, 1]
        assert isinstance(batch[DataKeys.INPUT][0], str)

    batch = next(iter(dm.predict_dataloader()))
    assert isinstance(batch[DataKeys.INPUT][0], str)