Exemplo n.º 1
0
def test_walker_custom(line_graph):
    """Check that UnsupervisedSampler uses a caller-supplied walker.

    ``CustomWalker`` is defined in this test module so that every positive
    (label 1) context pair it leads to is a self loop. The test verifies that
    the sampler's output reflects that walker rather than a default one.
    """
    sampler = UnsupervisedSampler(line_graph, walker=CustomWalker())
    batches = sampler.run(2)

    # One batch per node is expected for this walker with batch size 2.
    assert len(batches) == line_graph.number_of_nodes()

    # By construction of the custom walker, positives must be self loops.
    for pairs, pair_labels in batches:
        positives = pairs[pair_labels == 1]
        for head, tail in positives:
            assert head == tail
Exemplo n.º 2
0
    def test_run_batch_sizes(self, line_graph):
        """Check the number and sizes of batches produced by ``run``.

        The expected batch count is the total number of context pairs divided
        by ``batch_size``, rounded up. Each node contributes
        ``number_of_walks * (length - 1)`` positive pairs plus an equal number
        of negative pairs. Every batch except the last must be full; the last
        may be smaller.
        """
        batch_size = 4
        length = 2
        number_of_walks = 2
        sampler = UnsupervisedSampler(
            G=line_graph, length=length, number_of_walks=number_of_walks
        )
        batches = sampler.run(batch_size)

        # Derive pairs-per-node from the sampler parameters instead of a bare
        # literal; the previous hard-coded 4 merely coincided with batch_size.
        pairs_per_node = number_of_walks * (length - 1) * 2
        expected_num_batches = np.ceil(
            len(line_graph.nodes()) * pairs_per_node / batch_size
        )

        # check batch sizes
        assert len(batches) == expected_num_batches
        for ids, labels in batches[:-1]:
            assert len(ids) == len(labels) == batch_size

        # last batch can be smaller
        ids, labels = batches[-1]
        assert len(ids) == len(labels)
        assert len(ids) <= batch_size
Exemplo n.º 3
0
def test_walker_uniform_random(line_graph):
    """Check that batches follow an explicitly supplied walker's parameters.

    When a ``UniformRandomWalk`` is passed via ``walker=``, the batch count
    must be derived from that walker's ``n`` and ``length``. The sampler's
    own defaults must not be used.
    """
    batch_size = 4
    number_of_walks = 2
    length = 3

    sampler = UnsupervisedSampler(
        line_graph,
        walker=UniformRandomWalk(line_graph, n=number_of_walks, length=length),
    )
    batches = sampler.run(batch_size)

    # The expected count assumes each walk yields (length - 1) positive pairs
    # plus an equal number of negative pairs, hence the factor of 2.
    pairs_per_node = number_of_walks * (length - 1) * 2
    total_pairs = line_graph.number_of_nodes() * pairs_per_node
    assert len(batches) == np.ceil(total_pairs / batch_size)
Exemplo n.º 4
0
    def test_run_context_pairs(self, line_graph):
        """Check the per-target composition of the sampled context pairs.

        With ``length=2`` and ``number_of_walks=2``:

        - every node must appear as a target;
        - each target must have exactly 2 positive and 2 negative context pairs;
        - each positive context must be a graph neighbour of its target.
        """
        batch_size = 4
        sampler = UnsupervisedSampler(G=line_graph, length=2, number_of_walks=2)
        batches = sampler.run(batch_size)

        grouped_by_target = defaultdict(list)
        for ids, labels in batches:
            for (target, context), label in zip(ids, labels):
                grouped_by_target[target].append((context, label))

        assert len(grouped_by_target) == len(line_graph.nodes())

        for target, sampled in grouped_by_target.items():
            # exactly 2 positive and 2 negative context pairs for each target
            # node -- check the split, not just the total, as promised above
            num_positive = sum(1 for _, label in sampled if label == 1)
            assert len(sampled) == 4
            assert num_positive == 2
            assert len(sampled) - num_positive == 2

            # since each walk has length = 2, there must be an edge between
            # each positive context pair; build the neighbour set once per
            # target instead of once per pair
            neighbours = set(line_graph.neighbors(target))
            for context, label in sampled:
                if label == 1:
                    assert context in neighbours