class AdaptiveSamplerTest(unittest.TestCase):
    """Unit tests for ``AdaptiveSampler`` on a small, fixed directed tree.

    Fixture tree built in ``setUp`` (edges point parent -> child)::

        0 -> 1 -> 2 -> 4
                  2 -> 5
             1 -> 3
        0 -> 6 -> 7
        0 -> 8

    Every edge gets ``c = 1`` and every node ``r = 1`` (see
    ``assign_g_attrs``). Each node's ``datetime`` equals its depth:
    0 for node 0; 1 for nodes 1, 6, 8; 2 for nodes 2, 3, 7; 3 for nodes 4, 5.
    """

    def setUp(self):
        """Seed ``random``, build the fixture tree and create the sampler.

        The seed makes the ``random_action()`` results asserted below
        reproducible (assumes the sampler draws from the global ``random``
        module -- confirm in AdaptiveSampler).
        """
        random.seed(123456)

        tree = nx.DiGraph()
        tree.add_edges_from([
            (0, 1), (1, 2), (1, 3),
            (2, 4), (2, 5),
            (0, 6), (6, 7),
            (0, 8),
        ])
        self.assign_g_attrs(tree)
        self.tree = tree

        # Timestamp every node with its level in the tree.
        for t, nodes in enumerate([(0, ), (1, 6, 8), (2, 3, 7), (4, 5)]):
            for n in nodes:
                tree.node[n]['datetime'] = t

        self.s = AdaptiveSampler(self.tree, B=3, timespan_secs=1,
                                 node_score_func=lambda p, c: p**2 / c)

    def assign_g_attrs(self, tree):
        """Set ``c = 1`` on every edge and ``r = 1`` on every node of *tree*, in place."""
        for s, t in tree.edges_iter():
            tree[s][t]['c'] = 1
        for n in tree.nodes_iter():
            tree.node[n]['r'] = 1

    def _make_partial_result_tree(self):
        """Return the sub-tree with edges 0->1, 0->6 and 1->3, with unit attributes."""
        result_tree = nx.DiGraph()
        result_tree.add_edges_from([(0, 1), (0, 6), (1, 3)])
        self.assign_g_attrs(result_tree)
        return result_tree

    def test_sampler_init(self):
        # NOTE(review): this ordering appears to follow per-root upper bounds
        # of {0: 4, 1: 3, 2: 3, 6: 2}, i.e. the number of nodes reachable
        # within ``timespan_secs``. Confirm against AdaptiveSampler, including
        # how the tie between roots 1 and 2 is broken.
        assert_equal(
            [0, 2, 1, 6],
            self.s.roots_sorted_by_upperbound
        )
        assert_equal(1.0, self.s.explore_proba)
        assert_equal(4, self.s.n_nodes_to_cover)

    def test_update(self):
        result_tree = self._make_partial_result_tree()
        self.s.update(0, result_tree)

        assert_equal(0.5, self.s.explore_proba)
        assert_equal({0, 1}, self.s.covered_nodes)
        assert_equal({1: 2 ** 2}, self.s.node2score)

        # case: score of node 1 increases
        result_tree.add_edge(1, 2)
        self.assign_g_attrs(result_tree)
        self.s.update(0, result_tree)

        # Same expression form as ``node_score_func`` (p**2 / c), so the
        # expected value follows the interpreter's division semantics.
        assert_equal({1: 3 ** 2 / 2}, self.s.node2score)

    def test_update_border_case(self):
        # Feeding the whole fixture tree covers every node the sampler
        # needs to cover, so nothing is left to explore.
        self.s.update(0, self.tree)
        assert_equal({0, 1, 2, 6}, self.s.covered_nodes)
        assert_equal(0, self.s.explore_proba)

    def test_explore_proba(self):
        assert_equal(1, self.s.explore_proba)

        result_tree = self._make_partial_result_tree()
        self.s.update(0, result_tree)

        assert_almost_equal(2 / 4., self.s.explore_proba)

    def test_take_via_explore(self):
        # Order of calls matters: each take()/random_action() consumes the
        # seeded random sequence set up in setUp.
        r, tree = self.s.take()
        assert_equal('explore', self.s.random_action())
        assert_equal(0, r)
        assert_equal(
            sorted([(0, 1), (0, 6), (0, 8)]),
            sorted(tree.edges())
        )

        # on and on
        r, tree = self.s.take()
        assert_equal(2, r)

        r, tree = self.s.take()
        assert_equal(1, r)

    def test_take_via_exploit(self):
        # round 1
        self.s.update(0, self.tree)
        assert_equal('exploit', self.s.random_action())
        r, tree = self.s.take()
        assert_equal(1, r)

        # round 2: once the taken root's result is fed back, it should no
        # longer carry a score.
        self.s.update(r, tree)
        self.assertNotIn(r, self.s.node2score)