def test_backup_with_ucb():
    """Check ``BanditNode.ucb_value`` after two opposite rewards are backed up.

    The root is expanded, the first child's action is played on the state and
    that child is expanded in turn. A reward of -1 is backed up from its first
    child and +1 from its second; the UCB value of the first child and of both
    of its children is then compared against the expected numbers.
    """
    # Function-scope import keeps the test self-contained.
    import math

    root = BanditNode()
    state = FakeGameState()
    root.expand(state)
    children = root.child_nodes()

    state.play(children[0].action)
    children[0].expand(state)

    children[0].child_nodes()[0].backup(-1)
    children[0].child_nodes()[1].backup(1)

    # Compare with a tolerance: exact `==` on computed floats can fail on a
    # last-bit rounding difference between platforms or math libraries.
    assert math.isclose(
        BanditNode.ucb_value(children[0], 1), 0.8325546111576977)
    assert math.isclose(
        BanditNode.ucb_value(children[0].child_nodes()[0], 1),
        0.17741002251547466)
    assert math.isclose(
        BanditNode.ucb_value(children[0].child_nodes()[1], 1),
        2.177410022515475)
def test_backup_with_ucb_explore():
    """Verify ``UctNode`` values while rewards are backed up one leaf at a time.

    After the first backup, the nodes that did not take part in it are
    expected to report an infinite value. After the second backup, each
    checked value is expected to be strictly greater than the bound asserted
    for it below.
    """
    tree_root = UctNode(1)
    game = FakeGameState()
    tree_root.expand(game)
    top_level = tree_root.child_nodes()
    first, second = top_level[0], top_level[1]

    # Descend one level: play the first child's action, then expand it.
    game.play(first.action)
    first.expand(game)
    grandchildren = first.child_nodes()
    leaf_a, leaf_b = grandchildren[0], grandchildren[1]

    never_backed_up = float("inf")

    leaf_a.backup(1)
    assert first.value() == -1
    assert second.value() == never_backed_up
    assert leaf_a.value() == 1
    assert leaf_b.value() == never_backed_up

    leaf_b.backup(-1)
    assert first.value() > 0
    assert leaf_a.value() > 1
    assert leaf_b.value() > -1
def test_backup_with_value():
    """Verify ``BanditNode.value()`` as rewards are backed up leaf by leaf.

    A reward of +1 backed up from the first grandchild is expected to show up
    as -1 on its parent, while nodes that took no part in the backup report 0.
    A following -1 from the second grandchild is expected to bring the parent
    back to 0.

    NOTE(review): this file defines ``test_backup_with_value`` a second time
    further down; within a single module the later definition replaces this
    one, so only one of the two would run.
    """
    tree_root = BanditNode()
    game = FakeGameState()
    tree_root.expand(game)
    top_level = tree_root.child_nodes()
    first, second = top_level[0], top_level[1]

    # Descend one level: play the first child's action, then expand it.
    game.play(first.action)
    first.expand(game)
    grandchildren = first.child_nodes()
    leaf_a, leaf_b = grandchildren[0], grandchildren[1]

    leaf_a.backup(1)
    assert first.value() == -1
    assert second.value() == 0
    assert leaf_a.value() == 1
    assert leaf_b.value() == 0

    leaf_b.backup(-1)
    assert first.value() == 0
    assert leaf_a.value() == 1
    assert leaf_b.value() == -1
# Example no. 4
# 0
def test_backup_with_value():
    """Check ``RaveNode`` visit counters and values after two backups.

    The root is expanded, the first child's action is played and that child
    is expanded. Rewards of -1 and then +1 are backed up from its two
    children together with ``rave_moves``; after each backup the
    ``rave_num_visits`` counters and ``value()`` results are checked.

    NOTE(review): this file also defines ``test_backup_with_value`` earlier;
    within a single module this later definition replaces the earlier one, so
    only one of the two would run.
    """
    # Function-scope import keeps the test self-contained.
    import math

    rave_moves = {0: [0]}
    root = RaveNode(1, 300)
    state = FakeGameState()
    root.expand(state)

    children = root.child_nodes()
    state.play(children[0].action)
    children[0].expand(state)

    children[0].child_nodes()[0].backup(-1, rave_moves)
    assert children[0].child_nodes()[0].rave_num_visits == 1
    assert children[0].child_nodes()[1].rave_num_visits == 0
    # Computed floats are compared with a tolerance: exact `==` can fail on a
    # last-bit rounding difference. Integers, -1 and INF stay exact.
    assert math.isclose(children[0].value(), 0.0033333333333332993)
    assert children[1].value() == INF
    assert children[0].child_nodes()[0].value() == -1
    assert children[0].child_nodes()[1].value() == INF

    children[0].child_nodes()[1].backup(1, rave_moves)
    assert children[0].child_nodes()[0].rave_num_visits == 2
    assert children[0].child_nodes()[1].rave_num_visits == 0
    assert math.isclose(children[0].value(), 0.8325546111576977)
    assert math.isclose(
        children[0].child_nodes()[0].value(), 1.1740766891821415)
    assert math.isclose(
        children[0].child_nodes()[1].value(), 1.1807433558488079)
# Example no. 5
# 0
def test_backup():
    """Verify ``RaveNode.avg_reward()`` after backups with RAVE enabled.

    A reward of -1 backed up from the first grandchild is expected to appear
    as an average of +1 on its parent; a following +1 from the second
    grandchild is expected to bring the parent's average to 0. The root's
    second child is never expanded and is expected to stay childless with an
    average reward of 0 throughout.
    """
    moves = {0: [0]}
    tree_root = RaveNode(1, 300)
    game = FakeGameState()
    RaveNode.enable_rave(game)
    tree_root.expand(game)
    top_level = tree_root.child_nodes()
    first, second = top_level[0], top_level[1]

    # Descend one level: play the first child's action, then expand it.
    game.play(first.action)
    first.expand(game)
    grandchildren = first.child_nodes()
    leaf_a, leaf_b = grandchildren[0], grandchildren[1]

    leaf_a.backup(-1, moves)
    assert leaf_a.avg_reward() == -1
    assert leaf_b.avg_reward() == 0
    assert first.avg_reward() == 1
    assert len(second.child_nodes()) == 0
    assert second.avg_reward() == 0

    leaf_b.backup(1, moves)
    assert leaf_a.avg_reward() == -1
    assert leaf_b.avg_reward() == 1
    assert first.avg_reward() == 0
    assert len(second.child_nodes()) == 0
    assert second.avg_reward() == 0
# Example no. 6
# 0
def test_roll_out():
    """Check the scores produced by a fixed sequence of ``RaveAgent`` roll-outs.

    ``random`` is seeded with 0 and handed to the agent, so the expected
    scores below depend on the calls being made in exactly this order.
    """
    random.seed(0)

    game = FakeGameState()
    RaveNode.enable_rave(game)
    agent = RaveAgent(random)

    # The first roll-out is also checked for its recorded RAVE moves.
    first = agent.roll_out(game, 0)
    assert first['score'] == 2
    assert len(first['rave_moves'][0]) == 2

    # (second argument to roll_out, expected score), in call order.
    remaining = [
        (0, -3),
        (0, -3),
        (1, -2),
        (1, 3),
        (1, 3),
        (1, -2),
    ]
    for roll_out_arg, expected_score in remaining:
        result = agent.roll_out(game, roll_out_arg)
        assert result['score'] == expected_score
# Example no. 7
# 0
def test_select_action(player):
    """Smoke test: ``select_action`` runs on a fresh state without raising.

    ``player`` is supplied by the caller (presumably a pytest fixture defined
    elsewhere; confirm). The result of the call is not inspected.
    """
    fresh_state = FakeGameState()
    player.select_action(fresh_state, time_allowed_s=0.001)