def test_generator_parameter(self):
    """Invalid ``batch_size`` values are rejected when the generator is advanced.

    Creating the generator with a bad ``batch_size`` is expected NOT to raise
    (the call sits outside ``pytest.raises``); the error surfaces on the first
    ``next()``. A fresh generator is built for every case because a generator
    that has raised is finished and cannot be reused.
    """
    g = StellarGraph(create_test_graph())
    sampler = UnsupervisedSampler(G=g)

    # Valid batch sizes are positive, even integers: None, a non-integer,
    # zero and an odd value such as 3 must all be rejected.
    invalid_batch_sizes = [
        (None, ValueError),
        ("x", TypeError),
        (0, ValueError),
        (3, ValueError),
    ]
    for batch_size, expected_error in invalid_batch_sizes:
        sample_gen = sampler.generator(batch_size=batch_size)
        with pytest.raises(expected_error):
            next(sample_gen)
def test_generator_multiple_batches(self):
    """The generator can be advanced repeatedly, yielding a full batch each time."""
    n_feat = 4
    batch_size = 4  # even, so each batch should contain exactly batch_size samples
    number_of_batches = 3
    G = example_Graph_2(n_feat)
    sampler = UnsupervisedSampler(G=G)
    sample_gen = sampler.generator(batch_size)

    batches = [next(sample_gen) for _ in range(number_of_batches)]

    # Checking len(batches) alone would be tautological (the comprehension
    # fixes it), so verify that every batch drawn is well-formed instead.
    for samples in batches:
        # Each batch is two parallel sequences: (target, context) pairs and labels.
        assert len(samples) == 2
        # One label per pair, and a full batch of samples.
        assert len(samples[0]) == len(samples[1]) == batch_size
def test_generator_samples(self):
    """A single batch is a (pairs, labels) couple of equal, batch-sized length."""
    feature_count = 4
    size = 4
    graph = example_Graph_2(feature_count)
    sample_gen = UnsupervisedSampler(G=graph).generator(size)
    batch = next(sample_gen)

    # The batch holds two sequences: (target, context) pairs and their 1/0 labels.
    assert len(batch) == 2
    pairs, labels = batch[0], batch[1]
    # Every pair comes with exactly one label.
    assert len(pairs) == len(labels)
    # With an even batch size, exactly that many samples are produced.
    assert len(pairs) == size