示例#1
0
def test_adds_correct_args():
    """Check that ``DeepWalk.add_args`` registers its integer options on the parser.

    ``parser.add_argument`` is patched out. The test then asserts that it received,
    consecutively and in this order, one call per flag, each with ``type=int``
    (``default`` and ``help`` may be anything).
    """
    flags = ["--" + name for name in ("walk-length", "walk-num", "window-size", "worker", "iteration")]
    expected_calls = [call(flag, type=int, default=mock.ANY, help=mock.ANY) for flag in flags]

    arg_parser = ArgumentParser()
    with patch.object(arg_parser, "add_argument", return_value=None) as add_argument_mock:
        DeepWalk.add_args(arg_parser)
        add_argument_mock.assert_has_calls(expected_calls)
示例#2
0
def test_will_return_computed_embeddings_for_simple_fully_connected_graph():
    """Run the model on a two-node graph joined by the single edge 0 -> 1.

    Expects exactly two embeddings back, equal to the module-level fixtures
    ``embed_1`` and ``embed_2``. NOTE(review): those fixtures are presumably what
    the stub ``creator`` yields; confirm against their definitions.
    """
    model: DeepWalk = DeepWalk.build_model_from_args(get_args())
    edge_sources, edge_targets = torch.LongTensor([0]), torch.LongTensor([1])
    graph = Graph(edge_index=(edge_sources, edge_targets))

    embeddings = model(graph, creator)

    assert len(embeddings) == 2
    for position, expected in enumerate((embed_1, embed_2)):
        np.testing.assert_array_equal(embeddings[position], expected)
示例#3
0
def test_correctly_builds():
    """Check that ``build_model_from_args`` copies each hyper-parameter onto the model."""
    args = get_args()
    model = DeepWalk.build_model_from_args(args)

    # (model attribute, args attribute); note that ``dimension`` comes from ``hidden_size``.
    attribute_pairs = (
        ("dimension", "hidden_size"),
        ("walk_length", "walk_length"),
        ("walk_num", "walk_num"),
        ("window_size", "window_size"),
        ("worker", "worker"),
        ("iteration", "iteration"),
    )
    for model_attr, args_attr in attribute_pairs:
        assert getattr(model, model_attr) == getattr(args, args_attr)
示例#4
0
def test_will_return_computed_embeddings_for_simple_fully_connected_graph():
    """Train on a networkx graph with nodes 1 and 2 joined by one edge.

    Expects ``model.train`` to return two embeddings, equal to the module-level
    fixtures ``embed_1`` and ``embed_2`` in that order.
    """
    model: DeepWalk = DeepWalk.build_model_from_args(get_args())

    graph = nx.Graph()
    graph.add_nodes_from([1, 2])
    graph.add_edge(1, 2)

    embeddings = model.train(graph, creator)

    assert len(embeddings) == 2
    for position, expected in enumerate((embed_1, embed_2)):
        np.testing.assert_array_equal(embeddings[position], expected)
示例#5
0
def test_will_pass_correct_number_of_walks():
    """Check that ``model.train`` hands ``walk_num * node_count`` walks to the creator.

    A wrapper around ``creator`` records ``len(walks)`` on each call; the first
    recorded count is compared with the expected total.
    """
    args = get_args()
    args.walk_num = 2
    model: DeepWalk = DeepWalk.build_model_from_args(args)

    graph = nx.Graph()
    graph.add_nodes_from([1, 2, 3])
    walk_counts = []

    # Parameter names mirror the real creator's, in case the model passes them by keyword.
    def creator_mocked(walks, size, window, min_count, sg, workers, iter):
        walk_counts.append(len(walks))
        return creator(walks, size, window, min_count, sg, workers, iter)

    model.train(graph, creator_mocked)

    first_count = walk_counts[0]
    assert first_count == args.walk_num * len(graph)
示例#6
0
def test_will_pass_correct_number_of_walks():
    """Check that the model hands ``walk_num * num_nodes`` walks to the creator.

    The graph is the path 0 -> 1 -> 2 given as an ``edge_index`` pair. A wrapper
    around ``creator`` records ``len(walks)``; the first recorded count is checked.
    """
    args = get_args()
    args.walk_num = 2
    model: DeepWalk = DeepWalk.build_model_from_args(args)

    edge_sources = torch.LongTensor([0, 1])
    edge_targets = torch.LongTensor([1, 2])
    graph = Graph(edge_index=(edge_sources, edge_targets))
    walk_counts = []

    # Parameter names mirror the real creator's, in case the model passes them by keyword.
    def creator_mocked(walks, size, window, min_count, sg, workers, iter):
        walk_counts.append(len(walks))
        return creator(walks, size, window, min_count, sg, workers, iter)

    model(graph, creator_mocked)

    # num_nodes is read only after the model call, as in the original ordering.
    first_count = walk_counts[0]
    assert first_count == args.walk_num * graph.num_nodes