def test_expand():
    """Expanding the root is expected to create one child per action (0 and 1).

    Each child should record its action and point back at the root as parent.
    """
    root = BanditNode()
    state = FakeGameState()
    root.expand(state)

    children = root.child_nodes()
    assert len(children) == 2
    first, second = children
    assert first.action == 0
    assert second.action == 1
    assert first.parent == root
    assert second.parent == root
def test_backup_with_value():
    """Backed-up values should reach the parent with the sign flipped.

    A backup of +1 from the first grandchild is expected to give that
    grandchild value 1 and its parent value -1. A later backup of -1 from the
    second grandchild should bring the parent's value back to 0. Nodes that
    were never backed up keep value 0.
    """
    root = BanditNode()
    state = FakeGameState()
    root.expand(state)
    children = root.child_nodes()
    first, second = children[0], children[1]

    # Descend into the first child so it can be expanded from the new state.
    state.play(first.action)
    first.expand(state)

    first.child_nodes()[0].backup(1)
    assert first.value() == -1
    assert second.value() == 0
    assert first.child_nodes()[0].value() == 1
    assert first.child_nodes()[1].value() == 0

    first.child_nodes()[1].backup(-1)
    assert first.value() == 0
    assert first.child_nodes()[0].value() == 1
    assert first.child_nodes()[1].value() == -1
def test_is_root():
    """Only the node with no parent reports itself as the root."""
    root = BanditNode()
    state = FakeGameState()
    root.expand(state)

    assert root.is_root()
    assert not any(child.is_root() for child in root.child_nodes())
def test_ucb_initial_explore():
    """Freshly expanded (unvisited) children are expected to have infinite UCB.

    This means each one is explored before any visited sibling.
    """
    root = BanditNode()
    state = FakeGameState()
    root.expand(state)

    infinity = float('inf')
    assert all(
        BanditNode.ucb_value(child, 1) == infinity
        for child in root.child_nodes()
    )
def test_num_nodes():
    """num_nodes() counts the node itself plus all of its descendants.

    After one expansion the root subtree has 3 nodes and each leaf has 1.
    """
    root = BanditNode()
    state = FakeGameState()
    root.expand(state)

    assert root.num_nodes() == 3
    assert all(child.num_nodes() == 1 for child in root.child_nodes())
def test_child_nodes():
    """After one expansion the root has two children, and neither has children."""
    root = BanditNode()
    state = FakeGameState()
    root.expand(state)

    assert len(root.child_nodes()) == 2
    assert all(not child.child_nodes() or len(child.child_nodes()) == 0
               for child in root.child_nodes()) and all(
        len(child.child_nodes()) == 0 for child in root.child_nodes())
def test_backup_with_ucb():
    """Pin UCB values (exploration constant 1) after two opposing backups.

    The first child is expanded, and its two children are backed up with
    -1 and +1. The expected UCB scores for the first child and both
    grandchildren are fixed constants.

    The scores are computed floats, so they are compared with a relative
    tolerance instead of exact ``==``. Exact equality would break on any
    harmless change in floating-point evaluation order or platform libm.
    """
    import math

    root = BanditNode()
    state = FakeGameState()
    root.expand(state)
    children = root.child_nodes()
    first = children[0]

    state.play(first.action)
    first.expand(state)
    first.child_nodes()[0].backup(-1)
    first.child_nodes()[1].backup(1)

    assert math.isclose(BanditNode.ucb_value(first, 1), 0.8325546111576977)
    assert math.isclose(
        BanditNode.ucb_value(first.child_nodes()[0], 1), 0.17741002251547466)
    assert math.isclose(
        BanditNode.ucb_value(first.child_nodes()[1], 1), 2.177410022515475)
def test_info_strings_to_json():
    """Check the nested info dict built by ``info_strings_to_dict()``.

    The root's entry summarises average reward and visit count. Each child
    entry also carries its player and action.
    """
    root = BanditNode()
    state = FakeGameState()
    root.expand(state)
    first, second = root.child_nodes()[0], root.child_nodes()[1]
    first.backup(1)
    second.backup(-1)

    info = root.info_strings_to_dict()
    child_infos = info["children"]
    assert info["info"] == "avg_reward: 0.0 num_visits: 2"
    assert child_infos[0]["info"] == (
        "player: 0 action: 0 | avg_reward: 1.0 num_visits: 1")
    assert child_infos[1]["info"] == (
        "player: 0 action: 1 | avg_reward: -1.0 num_visits: 1")
def test_is_leaf():
    """A new node is a leaf until it has been expanded."""
    node = BanditNode()
    assert node.is_leaf()

    game = FakeGameState()
    node.expand(game)
    assert not node.is_leaf()
def test_child_nodes_before_expand():
    """A node that has never been expanded has an empty child list.

    Previously this test reused the name ``test_child_nodes``. Redefining a
    module-level function rebinds the name, so pytest only collected this
    copy and silently dropped the earlier post-expansion test. The distinct
    name lets both tests run.
    """
    root = BanditNode()
    assert root.child_nodes() == []
def test_str():
    """str() of a fresh node is the pretty-printed info dict with zero stats.

    The original built a ``FakeGameState`` that was never used; that
    unused local has been removed.
    """
    root = BanditNode()
    assert str(root) == '{\n "info": "avg_reward: 0 num_visits: 0"\n}'