def test_emmental_dataset(caplog):
    """Check EmmentalDataset construction plus adding/removing features and labels."""
    # NOTE(review): another function with this exact name is defined later in
    # this file; that later definition rebinds the name at import time, so
    # pytest never collects this one. Consider renaming or deleting this copy.
    caplog.set_level(logging.INFO)

    # Five ragged 1-D tensors of growing length: [1], [1, 2], ..., [1, ..., 5].
    x1 = [torch.Tensor(list(range(1, size + 1))) for size in range(1, 6)]
    y1 = torch.Tensor([0] * 5)

    dataset = EmmentalDataset(
        X_dict={"data1": x1}, Y_dict={"label1": y1}, name="new_data"
    )

    # The first sample must expose the matching feature and label values.
    assert torch.equal(dataset[0][0]["data1"], x1[0])
    assert torch.equal(dataset[0][1]["label1"], y1[0])

    # Same lengths in reverse order, built as fresh tensor objects.
    x2 = [torch.Tensor(list(range(1, size + 1))) for size in range(5, 0, -1)]
    dataset.add_features(X_dict={"data2": x2})
    # A newly added feature is visible on the first sample.
    assert torch.equal(dataset[0][0]["data2"], x2[0])

    y2 = torch.Tensor([1] * 5)
    dataset.add_labels(Y_dict={"label2": y2})
    # A newly added label is visible on the first sample.
    assert torch.equal(dataset[0][1]["label2"], y2[0])

    # Removing a label drops it from the dataset's label dict.
    dataset.remove_label(label_name="label1")
    assert "label1" not in dataset.Y_dict
def test_emmental_dataset(caplog):
    """Unit test of emmental dataset.

    Exercises EmmentalDataset construction, adding/removing features and
    labels, the ValueError paths for an unknown ``uid`` and non-tensor labels,
    and a dataset built without any labels. The working directory handed to
    ``emmental.init`` is removed even when an assertion fails.
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_data"

    Meta.reset()
    emmental.init(dirpath)

    # Clean up in ``finally`` so a failing assertion does not leave
    # ``temp_test_data`` behind to pollute later test runs.
    try:
        x1 = [
            torch.Tensor([1]),
            torch.Tensor([1, 2]),
            torch.Tensor([1, 2, 3]),
            torch.Tensor([1, 2, 3, 4]),
            torch.Tensor([1, 2, 3, 4, 5]),
        ]
        y1 = torch.Tensor([0, 0, 0, 0, 0])

        dataset = EmmentalDataset(
            X_dict={"data1": x1}, Y_dict={"label1": y1}, name="new_data"
        )

        # Check if the dataset is correctly constructed
        assert torch.equal(dataset[0][0]["data1"], x1[0])
        assert torch.equal(dataset[0][1]["label1"], y1[0])

        x2 = [
            torch.Tensor([1, 2, 3, 4, 5]),
            torch.Tensor([1, 2, 3, 4]),
            torch.Tensor([1, 2, 3]),
            torch.Tensor([1, 2]),
            torch.Tensor([1]),
        ]

        # Adding then removing a feature leaves no trace of it.
        dataset.add_features(X_dict={"data2": x2})
        dataset.remove_feature("data2")
        assert "data2" not in dataset.X_dict

        dataset.add_features(X_dict={"data2": x2})

        # Check add one more feature to dataset
        assert torch.equal(dataset[0][0]["data2"], x2[0])

        y2 = torch.Tensor([1, 1, 1, 1, 1])
        dataset.add_labels(Y_dict={"label2": y2})

        # A plain list of tensors (x2) is rejected as a label.
        with pytest.raises(ValueError):
            dataset.add_labels(Y_dict={"label2": x2})

        # Check add one more label to dataset
        assert torch.equal(dataset[0][1]["label2"], y2[0])

        dataset.remove_label(label_name="label1")

        # Check remove one more label to dataset
        assert "label1" not in dataset.Y_dict

        # A uid key that is not present in X_dict is rejected.
        with pytest.raises(ValueError):
            EmmentalDataset(
                X_dict={"data1": x1},
                Y_dict={"label1": y1},
                name="new_data",
                uid="ids",
            )

        # Constructing with an explicit "_uids_" feature must not raise;
        # presumably "_uids_" is the default uid key (confirm in EmmentalDataset).
        dataset = EmmentalDataset(
            X_dict={"_uids_": x1}, Y_dict={"label1": y1}, name="new_data"
        )

        # A dataset without labels: indexing yields the feature dict directly.
        dataset = EmmentalDataset(X_dict={"data1": x1}, name="new_data")

        # Check if the dataset is correctly constructed
        assert torch.equal(dataset[0]["data1"], x1[0])

        dataset.add_features(X_dict={"data2": x2})
        dataset.remove_feature("data2")
        assert "data2" not in dataset.X_dict

        dataset.add_features(X_dict={"data2": x2})

        # Check add one more feature to dataset
        assert torch.equal(dataset[0]["data2"], x2[0])

        y2 = torch.Tensor([1, 1, 1, 1, 1])
        dataset.add_labels(Y_dict={"label2": y2})

        # Once labels exist, indexing yields (features, labels).
        assert torch.equal(dataset[0][1]["label2"], y2[0])
    finally:
        shutil.rmtree(dirpath)