Example #1
def test_emmental_dataset(caplog):
    """Unit test of emmental dataset"""

    caplog.set_level(logging.INFO)

    x1 = [
        torch.Tensor([1]),
        torch.Tensor([1, 2]),
        torch.Tensor([1, 2, 3]),
        torch.Tensor([1, 2, 3, 4]),
        torch.Tensor([1, 2, 3, 4, 5]),
    ]

    y1 = torch.Tensor([0, 0, 0, 0, 0])

    dataset = EmmentalDataset(X_dict={"data1": x1},
                              Y_dict={"label1": y1},
                              name="new_data")

    # Check if the dataset is correctly constructed
    assert torch.equal(dataset[0][0]["data1"], x1[0])
    assert torch.equal(dataset[0][1]["label1"], y1[0])

    x2 = [
        torch.Tensor([1, 2, 3, 4, 5]),
        torch.Tensor([1, 2, 3, 4]),
        torch.Tensor([1, 2, 3]),
        torch.Tensor([1, 2]),
        torch.Tensor([1]),
    ]

    dataset.add_features(X_dict={"data2": x2})

    # Check adding another feature to the dataset
    assert torch.equal(dataset[0][0]["data2"], x2[0])

    y2 = torch.Tensor([1, 1, 1, 1, 1])

    dataset.add_labels(Y_dict={"label2": y2})

    # Check adding another label to the dataset
    assert torch.equal(dataset[0][1]["label2"], y2[0])

    dataset.remove_label(label_name="label1")

    # Check removing a label from the dataset
    assert "label1" not in dataset.Y_dict
Example #2
def test_emmental_dataset(caplog):
    """Unit test of emmental dataset."""
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_data"

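    # Re-initialize the Emmental workspace in a temporary directory (removed at the end of the test)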
    Meta.reset()
    emmental.init(dirpath)

    x1 = [
        torch.Tensor([1]),
        torch.Tensor([1, 2]),
        torch.Tensor([1, 2, 3]),
        torch.Tensor([1, 2, 3, 4]),
        torch.Tensor([1, 2, 3, 4, 5]),
    ]

    y1 = torch.Tensor([0, 0, 0, 0, 0])

    dataset = EmmentalDataset(X_dict={"data1": x1},
                              Y_dict={"label1": y1},
                              name="new_data")

    # Check if the dataset is correctly constructed
    assert torch.equal(dataset[0][0]["data1"], x1[0])
    assert torch.equal(dataset[0][1]["label1"], y1[0])

    x2 = [
        torch.Tensor([1, 2, 3, 4, 5]),
        torch.Tensor([1, 2, 3, 4]),
        torch.Tensor([1, 2, 3]),
        torch.Tensor([1, 2]),
        torch.Tensor([1]),
    ]

    dataset.add_features(X_dict={"data2": x2})

    dataset.remove_feature("data2")
    assert "data2" not in dataset.X_dict

    dataset.add_features(X_dict={"data2": x2})

    # Check adding another feature to the dataset
    assert torch.equal(dataset[0][0]["data2"], x2[0])

    y2 = torch.Tensor([1, 1, 1, 1, 1])

    dataset.add_labels(Y_dict={"label2": y2})

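    # Adding a label whose values are not a torch.Tensor should raise a ValueError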
    with pytest.raises(ValueError):
        dataset.add_labels(Y_dict={"label2": x2})

    # Check adding another label to the dataset
    assert torch.equal(dataset[0][1]["label2"], y2[0])

    dataset.remove_label(label_name="label1")

    # Check removing a label from the dataset
    assert "label1" not in dataset.Y_dict

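    # Specifying a uid field that is not present in X_dict should raise a ValueError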
    with pytest.raises(ValueError):
        dataset = EmmentalDataset(X_dict={"data1": x1},
                                  Y_dict={"label1": y1},
                                  name="new_data",
                                  uid="ids")

    dataset = EmmentalDataset(X_dict={"_uids_": x1},
                              Y_dict={"label1": y1},
                              name="new_data")

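    # A dataset can also be constructed without labels (no Y_dict)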
    dataset = EmmentalDataset(X_dict={"data1": x1}, name="new_data")

    # Check if the dataset is correctly constructed
    assert torch.equal(dataset[0]["data1"], x1[0])

    dataset.add_features(X_dict={"data2": x2})

    dataset.remove_feature("data2")
    assert "data2" not in dataset.X_dict

    dataset.add_features(X_dict={"data2": x2})

    # Check adding another feature to the dataset
    assert torch.equal(dataset[0]["data2"], x2[0])

    y2 = torch.Tensor([1, 1, 1, 1, 1])

    dataset.add_labels(Y_dict={"label2": y2})

    # Check adding another label to the dataset
    assert torch.equal(dataset[0][1]["label2"], y2[0])

    shutil.rmtree(dirpath)
Example #3
def load_data_from_db(postgres_db_name,
                      postgres_db_location,
                      label_dict,
                      char_dict=None,
                      clobber_label=True):
    """Load data from database.
    """

    print(f"Loading data from db {postgres_db_name}")
    # Start DB connection
    conn_string = os.path.join(postgres_db_location, postgres_db_name)
    session = Meta.init(conn_string).Session()

    # Print the number of documents
    print("==============================")
    print(f"DB contents for {postgres_db_name}:")
    print(f"Number of documents: {session.query(Document).count()}")
    print("==============================")

    docs = session.query(Document).all()

    uid_field = []
    text_field = []
    label_field = []
    missed_ids = 0

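    # Regex heuristic matching location/city fields and age mentions (e.g. "22yo", "age 22") in the posting HTML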
    term = r"([Ll]ocation:[\w\W]{1,200}</.{0,20}>|\W[cC]ity:[\w\W]{1,200}</.{0,20}>|\d\dyo\W|\d\d.{0,10}\Wyo\W|\d\d.{0,10}\Wold\W|\d\d.{0,10}\Wyoung\W|\Wage\W.{0,10}\d\d)"

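    # Keep documents with a known label; with clobber_label, keep every document and assign the placeholder label -1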
    for doc in docs:
        if (doc.name in label_dict) or clobber_label:
            uid_field.append(doc.name)
            text_field.append(get_posting_html_fast(doc.text, term))
            if not clobber_label:
                label_field.append(label_dict[doc.name])
            else:
                label_field.append(-1)
        else:
            missed_ids += 1

    # Print data stats
    print("==============================")
    print(f"Loaded {len(uid_field)} ids")
    print(f"Loaded {len(text_field)} text")
    print(f"Loaded {len(label_field)} labels")
    print(f"Missed {missed_ids} samples")

    X_dict = {"text": text_field, "uid": uid_field}
    Y_dict = {"label": torch.from_numpy(np.array(label_field))}

    dataset = EmmentalDataset(name="HT",
                              X_dict=X_dict,
                              Y_dict=Y_dict,
                              uid="uid")

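    # Map each character of every text to an integer id via char_dict.lookup and store the result as an "emb" feature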
    emb_field = []
    for i in range(len(dataset)):
        char_ids = np.array([char_dict.lookup(c) for c in dataset[i][0]["text"]])
        emb_field.append(torch.from_numpy(char_ids))
    dataset.add_features({"emb": emb_field})
    return dataset
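
load_data_from_db assumes that char_dict exposes a lookup method mapping a single character to an integer id, and that it is always passed explicitly despite its None default. A minimal hypothetical stand-in, purely for illustration (the class name, the alphabet, and the commented-out call with placeholder connection values are assumptions, not part of the original code):

class SimpleCharDict:
    """Toy character dictionary; id 0 is reserved for unknown characters."""

    def __init__(self, alphabet="abcdefghijklmnopqrstuvwxyz0123456789 "):
        self.char_to_id = {c: i + 1 for i, c in enumerate(alphabet)}

    def lookup(self, char):
        return self.char_to_id.get(char.lower(), 0)

# Example call with placeholder connection values:
# dataset = load_data_from_db("my_db", "postgresql://localhost:5432", {}, char_dict=SimpleCharDict())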