def test_ucb_score_no_children_visited(self):
    """Check ucb_score for unvisited children and for a child visited once.

    Despite the test name, child 2 is given a single visit, so this test
    covers both the unvisited case (children 0, 1, 3) and the visited-once
    case (child 2). The parent node has exactly one visit.
    """
    node = Node(0.5, to_play=1)
    node.visit_count = 1
    state = [0, 0, 0, 0]
    action_probs = [0.25, 0.15, 0.5, 0.1]
    to_play = 1
    node.expand(state, to_play, action_probs)

    # Visit counts per child, indexed by action.
    child_visits = [0, 0, 1, 0]
    for action, visits in enumerate(child_visits):
        node.children[action].visit_count = visits

    # With a single parent visit, an unvisited child is expected to score
    # exactly its prior, and a child visited once to score half its prior.
    # NOTE(review): this assumes ucb_score divides the prior term by
    # (child visits + 1), scales it by a parent-visit factor that is 1 here,
    # and contributes a 0 value term for the visited child (its value_sum is
    # never set) — confirm against ucb_score.
    for action, visits in enumerate(child_visits):
        child = node.children[action]
        expected = child.prior / (visits + 1)
        # assertAlmostEqual: these are float computations, so exact equality
        # would be sensitive to the order of operations inside ucb_score.
        self.assertAlmostEqual(
            ucb_score(node, child), expected, msg=f"action {action}"
        )
def test_ucb_score_one_child_visited_twice(self):
    """select_child should move on from a child that has been visited twice.

    The parent has two visits, and both went to child 2 (the highest prior,
    0.5). Children 0, 1 and 3 are unvisited. The test expects action 0,
    which has the next-highest prior (0.25), to be selected.
    """
    node = Node(0.5, to_play=1)
    node.visit_count = 2
    state = [0, 0, 0, 0]
    action_probs = [0.25, 0.15, 0.5, 0.1]
    to_play = 1
    node.expand(state, to_play, action_probs)

    # Visit counts per child, indexed by action: only action 2 was explored.
    for action, visits in enumerate([0, 0, 2, 0]):
        node.children[action].visit_count = visits

    scores = [
        ucb_score(node, node.children[action])
        for action in range(len(action_probs))
    ]

    action, _ = node.select_child()

    # Having visited action 2 (the third action) twice, we should end up
    # trying action 0 (the first action) next.
    self.assertEqual(action, 0)
    # The selected action should also be the one with the highest UCB score.
    # NOTE(review): assumes select_child ranks children by ucb_score — confirm.
    best_by_score = max(range(len(scores)), key=scores.__getitem__)
    self.assertEqual(action, best_by_score)