Пример #1
0
def test_persona_deepwalk_dataset():
    """Check that ``PersonaDeepWalkDataset`` yields fixed-size integer samples.

    Builds the persona graph of ``graph_ab`` and verifies two things:
    - a dataset created with ``dataset_size=25`` reports that length;
    - every sample is a length-3 sequence of plain ``int`` values.
    """
    dataset_size = 25
    persona_ab = persona_graph(graph_ab)
    forward_persona, _ = lookup_tables(persona_ab)
    forward, _ = lookup_tables(graph_ab)
    dataset = PersonaDeepWalkDataset(
        graph=persona_ab,
        window_size=1,
        walk_length=10,
        dataset_size=dataset_size,
        forward_lookup_persona=forward_persona,
        forward_lookup=forward,
    )
    # Check the reported size up front so the loop below iterates over the
    # same range the dataset claims to cover.
    assert len(dataset) == dataset_size
    for index in range(dataset_size):
        # Fetch each sample exactly once: if __getitem__ is stochastic,
        # repeated ``dataset[index]`` calls could each return a different
        # sample, and the checks below would not all apply to the same object.
        sample = dataset[index]
        assert len(sample) == 3
        for position, value in enumerate(sample):
            assert isinstance(value, int), (
                f"sample {index}[{position}] is {type(value).__name__}, not int"
            )
Пример #2
0
def test_train():
    """Smoke-test ``train`` for one epoch on the karate club graph.

    This does not check learning quality. It only checks that training on a
    real example runs without raising, and that afterwards:
    - both embedding matrices have the expected shapes;
    - the scheduler was stepped exactly once;
    - the epoch callback was called exactly once.
    """
    dimension = 100
    base_graph = nx.karate_club_graph()
    persona = persona_graph(base_graph)
    persona_lookup, _ = lookup_tables(persona)
    node_lookup, _ = lookup_tables(base_graph)
    node_total = base_graph.number_of_nodes()
    persona_total = persona.number_of_nodes()

    model = SplitterEmbedding(
        node_count=node_total,
        persona_node_count=persona_total,
        embedding_dimension=dimension,
    )
    sgd = SGD(model.parameters(), lr=0.01)
    samples = PersonaDeepWalkDataset(
        graph=persona,
        window_size=3,
        walk_length=20,
        dataset_size=10,
        forward_lookup_persona=persona_lookup,
        forward_lookup=node_lookup,
    )
    # Mocks stand in for the scheduler and callback so their calls can be counted.
    scheduler_mock, callback_mock = Mock(), Mock()

    train(
        dataset=samples,
        model=model,
        scheduler=scheduler_mock,
        epochs=1,
        batch_size=100,
        optimizer=sgd,
        epoch_callback=callback_mock,
        cuda=False,
    )

    assert model.persona_embedding.weight.shape == (persona_total, dimension)
    assert model.embedding.weight.shape == (node_total, dimension)
    assert scheduler_mock.step.call_count == 1
    assert callback_mock.call_count == 1
Пример #3
0
def test_embedding_graph():
    """Run a forward pass of ``SplitterEmbedding`` seeded with DeepWalk vectors.

    An initial embedding matrix is built from 10 random walks over
    ``graph_abcd``. The model is then called on a batch of five persona
    indices (all equal to 1), and the output must be a (5, 100) tensor.
    """
    dim = 100
    forward, reverse = lookup_tables(graph_abcd)
    sampled_walks = take(10, iter_random_walks(graph_abcd, 5))
    warm_start = to_embedding_matrix(
        initial_deepwalk_embedding(sampled_walks, forward, dim), dim, reverse
    )
    # NOTE(review): node_count / persona_node_count are fixed literals rather
    # than derived from graph_abcd; confirm they are meant to be independent of
    # the row count of ``warm_start``.
    model = SplitterEmbedding(
        node_count=10,
        persona_node_count=15,
        embedding_dimension=dim,
        initial_embedding=warm_start,
    )
    batch_size = 5
    indices = torch.ones(batch_size, dtype=torch.long)
    assert model(indices).shape == (batch_size, dim)
Пример #4
0
def test_predict():
    """Check the shape of ``predict``'s output on an untrained karate model.

    ``predict`` returns four parallel lists:
    - persona nodes;
    - original nodes;
    - indices;
    - persona embeddings.

    Each list must have one entry per entry of the reverse persona lookup.
    Together they must cover every original and every persona node, and the
    indices must include at least 0 and 1.
    """
    base_graph = nx.karate_club_graph()
    persona = persona_graph(base_graph)
    _, persona_reverse = lookup_tables(persona)
    model = SplitterEmbedding(
        node_count=base_graph.number_of_nodes(),
        persona_node_count=persona.number_of_nodes(),
        embedding_dimension=100,
    )
    persona_nodes, nodes, indices, persona_vectors = predict(persona_reverse, model)

    expected_len = len(persona_reverse)
    for column in (persona_nodes, nodes, indices, persona_vectors):
        assert len(column) == expected_len
    assert set(nodes) == set(base_graph.nodes)
    assert set(persona_nodes) == set(persona.nodes)
    assert {0, 1} <= set(indices)
Пример #5
0
def test_to_embedding_matrix():
    """Check that ``to_embedding_matrix`` turns ``graph_ab`` vectors into a 2x10 matrix.

    ``graph_ab`` has the two nodes ``"a"`` and ``"b"``. The node vectors come
    from ``initial_deepwalk_embedding`` over 100 random walks.
    """
    dimension = 10
    forward, reverse = lookup_tables(graph_ab)
    vectors = initial_deepwalk_embedding(
        take(100, iter_random_walks(graph_ab, 2)), forward, dimension
    )
    matrix = to_embedding_matrix(vectors, dimension, reverse)
    assert matrix.shape == (2, dimension)
Пример #6
0
def test_basic_initial_deepwalk_embedding_oov():
    """Nodes that never appear in any walk must still get an embedding.

    All 100 walks consist solely of ``"a"``, yet the result must hold an
    entry for both ``"a"`` and the unseen ``"b"``.
    """
    expected_nodes = {"a", "b"}
    forward, _ = lookup_tables(graph_ab)
    only_a_walks = take(100, cycle([["a"]]))
    vectors = initial_deepwalk_embedding(only_a_walks, forward, 10)
    assert len(vectors) == len(expected_nodes)
    assert set(vectors.keys()) == expected_nodes
Пример #7
0
def test_basic_initial_deepwalk_embedding():
    """Check that DeepWalk initialisation returns an entry for every node.

    The embedding is built from 100 walks produced by
    ``iter_random_walks(graph_ab, 2)``. It must be keyed by exactly the two
    nodes ``"a"`` and ``"b"``.
    """
    expected_nodes = {"a", "b"}
    forward, _ = lookup_tables(graph_ab)
    random_walks = take(100, iter_random_walks(graph_ab, 2))
    vectors = initial_deepwalk_embedding(random_walks, forward, 10)
    assert len(vectors) == len(expected_nodes)
    assert set(vectors.keys()) == expected_nodes
Пример #8
0
def test_basic_lookup_tables():
    """Check that ``lookup_tables`` builds mutually inverse node/index maps.

    For ``graph_ab``, the forward table must map the nodes ``"a"`` and ``"b"``
    onto the indices ``{0, 1}``. The reverse table must undo that mapping
    exactly.
    """
    forward, reverse = lookup_tables(graph_ab)
    assert set(forward.keys()) == {"a", "b"}
    assert set(reverse.values()) == {"a", "b"}
    assert set(forward.values()) == {0, 1}
    assert set(reverse.keys()) == {0, 1}
    # The set checks above also pass for inconsistent tables such as
    # forward={"a": 0, "b": 1} with reverse={0: "b", 1: "a"}. Pin the round
    # trip so reverse is guaranteed to be the exact inverse of forward.
    assert {index: node for node, index in forward.items()} == dict(reverse)
Пример #9
0
# Print the summed link-prediction scores for the baseline (non-persona)
# graph. Both score lists are computed earlier in the script.
print(sum(positive_scores_non_persona))
print(sum(negative_scores_non_persona))

# ROC AUC of the baseline scores: positive samples are labelled 1 and negative
# samples 0. NOTE(review): assumes each score list is aligned one-to-one and in
# the same order as its sample list; confirm where the scores are built.
print(
    roc_auc_score(
        [1] * len(positive_samples) + [0] * len(negative_samples),
        positive_scores_non_persona + negative_scores_non_persona,
    )
)

print("Constructing persona graph.")
PG = persona_graph(G_original)

print("Constructing lookups.")
# NOTE(review): the persona graph is built from G_original, but the node lookup
# below is built from G. Confirm the two graphs share the same node set.
# These lookup tables are not referenced by the scoring code that follows;
# confirm they are still needed.
forward_persona, reverse_persona = lookup_tables(PG)
forward, reverse = lookup_tables(G)

# Group persona nodes by their ``.node`` attribute. The (key, seq) argument
# order matches toolz.groupby, not itertools.groupby. Presumably this yields a
# dict from each original node to its persona nodes; verify the import.
groups = groupby(operator.attrgetter("node"), PG.nodes())

# For each positive sample pair, keep the highest Jaccard-coefficient score
# yielded by iter_get_scores_networkx over the persona graph. max() raises
# ValueError on an empty iterable; the excepts wrapper (assumed to be
# toolz.excepts(exc, func, handler)) maps that case to a score of 0.0.
positive_scores_persona = [
    excepts(ValueError, max, lambda _: 0.0)(
        iter_get_scores_networkx(groups, node1, node2, PG, jaccard_coefficient)
    )
    for (node1, node2) in positive_samples
]
negative_scores_persona = [
    excepts(ValueError, max, lambda _: 0.0)(
        iter_get_scores_networkx(groups, node1, node2, PG, jaccard_coefficient)
    )
    for (node1, node2) in negative_samples