Beispiel #1
0
    def test_ucb_score_no_children_visited(self):
        """Check ucb_score for unvisited children and a once-visited child.

        The parent has one visit and only action 2 has been visited (once).
        The test expects unvisited children to score exactly their prior and
        the once-visited child to score half of its prior.
        """
        parent = Node(0.5, to_play=1)
        parent.visit_count = 1

        parent.expand([0, 0, 0, 0], 1, [0.25, 0.15, 0.5, 0.1])
        for action, visits in enumerate([0, 0, 1, 0]):
            parent.children[action].visit_count = visits

        scores = [ucb_score(parent, parent.children[a]) for a in range(4)]

        # Children with no visits are expected to score just their prior.
        for a in (0, 1):
            self.assertEqual(scores[a], parent.children[a].prior)
        # A single visit is expected to halve the child's score.
        self.assertEqual(scores[2], parent.children[2].prior / 2)
        self.assertEqual(scores[3], parent.children[3].prior)
Beispiel #2
0
    def test_initialization(self):
        """A fresh Node has no visits, the given prior, no children and value 0."""
        expected_prior = 0.5
        fresh = Node(expected_prior, to_play=1)

        # Nothing has been explored yet.
        self.assertEqual(fresh.visit_count, 0)
        self.assertEqual(fresh.prior, expected_prior)
        # The node has no children before expand() is called.
        self.assertEqual(len(fresh.children), 0)
        self.assertFalse(fresh.expanded())
        # With zero visits the node is expected to report a value of 0.
        self.assertEqual(fresh.value(), 0)
Beispiel #3
0
    def test_expansion(self):
        """expand() creates one child per action, each carrying its prior."""
        root = Node(0.5, to_play=1)
        player = 1
        priors = [0.25, 0.15, 0.5, 0.1]

        root.expand([0, 0, 0, 0], player, priors)

        self.assertEqual(len(root.children), 4)
        self.assertTrue(root.expanded())
        self.assertEqual(root.to_play, player)
        # Compare against literal values rather than `priors`, so the test
        # still catches it if expand() mutates the list it was given.
        expected = (0.25, 0.15, 0.50, 0.10)
        for action, want in enumerate(expected):
            self.assertEqual(root.children[action].prior, want)
Beispiel #4
0
    def test_ucb_score_one_child_visited_twice(self):
        """select_child should move to action 0 once action 2 has two visits.

        Action 2 has the largest prior (0.5) but has already been visited
        twice. The test expects its UCB score to drop below that of the
        unvisited action 0 (prior 0.25), so action 0 is chosen next.
        """
        node = Node(0.5, to_play=1)
        node.visit_count = 2

        state = [0, 0, 0, 0]

        action_probs = [0.25, 0.15, 0.5, 0.1]
        to_play = 1

        node.expand(state, to_play, action_probs)
        node.children[0].visit_count = 0
        node.children[1].visit_count = 0
        node.children[2].visit_count = 2
        node.children[3].visit_count = 0

        # These scores used to be computed and then discarded. Check that the
        # UCB ranking itself favours action 0, which should agree with the
        # choice select_child makes below.
        scores = [ucb_score(node, node.children[a]) for a in range(4)]
        self.assertEqual(scores.index(max(scores)), 0)

        action, child = node.select_child()

        # Action 2 (the third action, 0-based) has been visited twice, so we
        # should end up trying action 0 next.
        self.assertEqual(action, 0)
        # Assumes select_child returns the child object stored under
        # `action` in node.children.
        self.assertIs(child, node.children[0])
Beispiel #5
0
    def test_selection(self):
        """select_action(temperature=0) should choose child 2, the only visited one.

        Presumably temperature 0 means greedy selection by visit count;
        confirm against Node.select_action.
        """
        node = Node(0.5, to_play=1)
        c0 = Node(0.5, to_play=-1)
        c1 = Node(0.5, to_play=-1)
        c2 = Node(0.5, to_play=-1)
        node.visit_count = 1
        # Set every child's visit count explicitly. Before this fix, c2 was
        # assigned twice and c1 was never set.
        c0.visit_count = 0
        c1.visit_count = 0
        c2.visit_count = 1

        node.children = {
            0: c0,
            1: c1,
            2: c2,
        }

        action = node.select_action(temperature=0)
        self.assertEqual(action, 2)