Esempio n. 1
0
def test_emmental_dataset(caplog):
    """Exercise EmmentalDataset construction plus feature/label updates.

    Builds a five-example dataset with one ragged feature and one label,
    then adds a second feature, adds a second label, and removes the first
    label, checking the first example (or ``Y_dict``) after each step.
    """
    caplog.set_level(logging.INFO)

    lengths = range(1, 6)

    # Ragged feature: the example of length n is the float tensor [1, ..., n].
    x1 = [torch.Tensor(list(range(1, n + 1))) for n in lengths]
    y1 = torch.Tensor([0] * len(x1))

    dataset = EmmentalDataset(
        name="new_data",
        X_dict={"data1": x1},
        Y_dict={"label1": y1},
    )

    # The first example must expose the feature and label it was built with.
    first_example = dataset[0]
    assert torch.equal(first_example[0]["data1"], x1[0])
    assert torch.equal(first_example[1]["label1"], y1[0])

    # Same tensors as x1, but ordered from longest to shortest.
    x2 = [torch.Tensor(list(range(1, n + 1))) for n in reversed(lengths)]
    dataset.add_features(X_dict={"data2": x2})

    # The freshly added feature is visible on the first example.
    assert torch.equal(dataset[0][0]["data2"], x2[0])

    y2 = torch.Tensor([1] * len(x2))
    dataset.add_labels(Y_dict={"label2": y2})

    # The freshly added label is visible on the first example.
    assert torch.equal(dataset[0][1]["label2"], y2[0])

    # After removal the label name must be gone from the label dictionary.
    dataset.remove_label(label_name="label1")
    assert "label1" not in dataset.Y_dict
Esempio n. 2
0
def test_emmental_dataset(caplog):
    """Unit test of emmental dataset.

    Initializes emmental in ``temp_test_data`` and then checks, on a
    five-example dataset:

    * construction with features and labels, and item access;
    * adding, removing and re-adding a feature;
    * adding a label, and that ``add_labels`` with a list of tensors raises
      ``ValueError``;
    * removing a label;
    * that construction with ``uid="ids"`` raises ``ValueError`` while
      construction with a ``"_uids_"`` feature does not;
    * the same feature/label operations on a dataset built without labels.

    The working directory is removed in a ``finally`` clause so a failing
    assertion does not leave it behind.
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_data"

    Meta.reset()
    emmental.init(dirpath)

    # Everything after init runs under try/finally so that ``dirpath`` is
    # removed even when an assertion below fails.
    try:
        x1 = [
            torch.Tensor([1]),
            torch.Tensor([1, 2]),
            torch.Tensor([1, 2, 3]),
            torch.Tensor([1, 2, 3, 4]),
            torch.Tensor([1, 2, 3, 4, 5]),
        ]

        y1 = torch.Tensor([0, 0, 0, 0, 0])

        dataset = EmmentalDataset(
            X_dict={"data1": x1}, Y_dict={"label1": y1}, name="new_data"
        )

        # Check if the dataset is correctly constructed
        assert torch.equal(dataset[0][0]["data1"], x1[0])
        assert torch.equal(dataset[0][1]["label1"], y1[0])

        x2 = [
            torch.Tensor([1, 2, 3, 4, 5]),
            torch.Tensor([1, 2, 3, 4]),
            torch.Tensor([1, 2, 3]),
            torch.Tensor([1, 2]),
            torch.Tensor([1]),
        ]

        dataset.add_features(X_dict={"data2": x2})

        # Check remove a feature from dataset
        dataset.remove_feature("data2")
        assert "data2" not in dataset.X_dict

        dataset.add_features(X_dict={"data2": x2})

        # Check add one more feature to dataset
        assert torch.equal(dataset[0][0]["data2"], x2[0])

        y2 = torch.Tensor([1, 1, 1, 1, 1])

        dataset.add_labels(Y_dict={"label2": y2})

        # Passing a list of tensors as a label must be rejected.
        # NOTE(review): presumably because labels must be a single tensor;
        # confirm against EmmentalDataset.add_labels.
        with pytest.raises(ValueError):
            dataset.add_labels(Y_dict={"label2": x2})

        # Check add one more label to dataset
        assert torch.equal(dataset[0][1]["label2"], y2[0])

        dataset.remove_label(label_name="label1")

        # Check remove one label from dataset
        assert "label1" not in dataset.Y_dict

        # NOTE(review): presumably rejected because "ids" is not a key of
        # X_dict; confirm against EmmentalDataset.__init__.
        with pytest.raises(ValueError):
            EmmentalDataset(
                X_dict={"data1": x1},
                Y_dict={"label1": y1},
                name="new_data",
                uid="ids",
            )

        # Construction with a "_uids_" feature is expected not to raise; the
        # resulting dataset itself is not inspected.
        EmmentalDataset(
            X_dict={"_uids_": x1}, Y_dict={"label1": y1}, name="new_data"
        )

        # Dataset built without labels: indexing yields the feature dict only.
        dataset = EmmentalDataset(X_dict={"data1": x1}, name="new_data")

        # Check if the dataset is correctly constructed
        assert torch.equal(dataset[0]["data1"], x1[0])

        dataset.add_features(X_dict={"data2": x2})

        dataset.remove_feature("data2")
        assert "data2" not in dataset.X_dict

        dataset.add_features(X_dict={"data2": x2})

        # Check add one more feature to dataset
        assert torch.equal(dataset[0]["data2"], x2[0])

        y2 = torch.Tensor([1, 1, 1, 1, 1])

        dataset.add_labels(Y_dict={"label2": y2})

        # Check add one more label to dataset; once labels exist, indexing
        # yields a (features, labels) pair again.
        assert torch.equal(dataset[0][1]["label2"], y2[0])
    finally:
        shutil.rmtree(dirpath)