def test_walker_custom(line_graph):
    """Check that ``UnsupervisedSampler`` takes its positive pairs from a supplied walker.

    The sampler is built around a ``CustomWalker`` instance and run with a
    batch size of 2.  The test expects one batch per node of ``line_graph``
    and expects every pair labelled ``1`` to join a node to itself.
    """
    sampler = UnsupervisedSampler(line_graph, walker=CustomWalker())
    batches = sampler.run(2)

    assert len(batches) == line_graph.number_of_nodes()

    # The custom walker was written so that every positive example is a
    # self loop; pairs with any other label are not inspected here.
    for context_pairs, labels in batches:
        positive_pairs = context_pairs[labels == 1]
        for node, neighbour in positive_pairs:
            assert node == neighbour
def test_run_batch_sizes(self, line_graph):
    """Check how many batches ``UnsupervisedSampler.run`` returns and how big each one is.

    Every batch except the last must hold exactly ``batch_size`` ids and
    labels.  The last batch may be shorter, but its ids and labels must
    still have matching lengths.
    """
    batch_size = 4
    length = 2
    number_of_walks = 2
    sampler = UnsupervisedSampler(
        G=line_graph, length=length, number_of_walks=number_of_walks
    )
    batches = sampler.run(batch_size)

    # Number of context pairs expected per node; evaluates to 4 for the
    # walk settings used in this test.
    pairs_per_node = number_of_walks * (length - 1) * 2
    expected_num_batches = np.ceil(
        len(line_graph.nodes()) * pairs_per_node / batch_size
    )
    assert len(batches) == expected_num_batches

    # every batch before the final one must be completely full
    for full_ids, full_labels in batches[:-1]:
        assert len(full_ids) == len(full_labels)
        assert len(full_labels) == batch_size

    # the final batch is allowed to be smaller
    last_ids, last_labels = batches[-1]
    assert len(last_ids) == len(last_labels)
    assert len(last_ids) <= batch_size
def test_walker_uniform_random(line_graph):
    """Check that an explicit ``UniformRandomWalk`` overrides the sampler's own defaults.

    The walker is created with its own walk count and walk length and then
    handed to ``UnsupervisedSampler``.  The number of batches returned must
    follow from the walker's parameters, not from the sampler's defaults.
    """
    batch_size = 4
    walks_per_node = 2
    walk_length = 3

    walker = UniformRandomWalk(line_graph, n=walks_per_node, length=walk_length)
    sampler = UnsupervisedSampler(line_graph, walker=walker)
    batches = sampler.run(batch_size)

    # Expected pair count per node under the walker's parameters.
    # NOTE(review): the factor of 2 presumably counts one negative pair for
    # each positive pair - confirm against UnsupervisedSampler.
    pairs_per_node = walks_per_node * (walk_length - 1) * 2
    total_pairs = line_graph.number_of_nodes() * pairs_per_node
    expected_num_batches = np.ceil(total_pairs / batch_size)

    assert len(batches) == expected_num_batches
def test_run_context_pairs(self, line_graph):
    """Check the context pairs produced by ``UnsupervisedSampler.run``.

    Pairs from all batches are regrouped by their target node.  The test
    then requires that:

    * every node of ``line_graph`` appears as a target,
    * each target has exactly 4 sampled ``(context, label)`` entries, and
    * every context labelled ``1`` is a neighbour of its target in
      ``line_graph``.
    """
    batch_size = 4
    sampler = UnsupervisedSampler(G=line_graph, length=2, number_of_walks=2)
    batches = sampler.run(batch_size)

    grouped_by_target = defaultdict(list)
    for ids, labels in batches:
        for (target, context), label in zip(ids, labels):
            grouped_by_target[target].append((context, label))

    assert len(grouped_by_target) == len(line_graph.nodes())

    for target, sampled in grouped_by_target.items():
        # exactly 2 positive and 2 negative context pairs are expected for
        # each target node; only the total of 4 is asserted here
        assert len(sampled) == 4

        positive_contexts = [context for context, label in sampled if label == 1]
        if not positive_contexts:
            continue

        # Build the neighbour set once per target instead of once per
        # positive pair; it does not change between pairs of one target.
        neighbours = set(line_graph.neighbors(target))

        # since each walk has length = 2, there must be an edge between
        # each positive context pair
        for context in positive_contexts:
            assert context in neighbours