import copy
import json

import pytest
import torch

# Module paths below follow the DQNAgent import used later in this file;
# the exact paths for the other agents are assumed.
from ai_traineree.agents.ddpg import DDPGAgent
from ai_traineree.agents.ppo import PPOAgent
from ai_traineree.agents.rainbow import RainbowAgent

# `deterministic_interactions` and `serialize` are test helpers; hedged
# sketches of both are included further down in this file.


def test_ddpg_seed():
    # Assign
    agent_0 = DDPGAgent(4, 2, device='cpu')
    agent_1 = DDPGAgent(4, 2, device='cpu')
    agent_2 = copy.deepcopy(agent_1)

    # Act
    # Make sure agents have the same networks
    # torch.equal avoids false positives from weight differences cancelling in a sum
    assert all(
        torch.equal(l1.weight, l2.weight)
        for l1, l2 in zip(agent_1.actor.layers, agent_2.actor.layers)
    )
    assert all(
        torch.equal(l1.weight, l2.weight)
        for l1, l2 in zip(agent_1.critic.layers, agent_2.critic.layers)
    )

    agent_0.seed(32167)
    actions_0 = deterministic_interactions(agent_0)
    agent_1.seed(0)
    actions_1 = deterministic_interactions(agent_1)
    agent_2.seed(0)
    actions_2 = deterministic_interactions(agent_2)

    # Assert
    # Sanity check: the interactions should produce more than one distinct action
    assert actions_1[0] != actions_1[1]
    assert actions_2[0] != actions_2[1]

    # Different seeds should produce different action sequences
    assert any(a0 != a1 for (a0, a1) in zip(actions_0, actions_1))
    # Identically seeded agents should produce identical actions
    for idx, (a1, a2) in enumerate(zip(actions_1, actions_2)):
        assert a1 == pytest.approx(a2, 1e-4), \
            f"Action mismatch on position {idx}: {a1} != {a2}"
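

# NOTE: `deterministic_interactions` is a helper assumed to live in this test
# module. A minimal sketch, assuming the agent exposes a `state_size`
# attribute plus `act(state)` and `step(state, action, reward, next_state,
# done)` methods (adjust to the actual Agent interface): it walks the agent
# through a fixed synthetic trajectory, so two identically seeded agents see
# exactly the same inputs, and returns the actions taken.
def deterministic_interactions(agent, num_iters=50):
    actions = []
    state = [0.0] * agent.state_size
    for i in range(num_iters):
        action = agent.act(state)
        actions.append(action)
        # Next state, reward and done flag depend only on the loop index,
        # keeping the whole interaction sequence reproducible.
        next_state = [float((i + j) % 5) / 5.0 for j in range(agent.state_size)]
        agent.step(state, action, float(i % 2), next_state, done=(i + 1) % 10 == 0)
        state = next_state
    return actions
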
def test_rainbow_seed():
    # Assign
    agent_0 = RainbowAgent(4, 4, device='cpu')
    agent_1 = RainbowAgent(4, 4, device='cpu')
    agent_2 = copy.deepcopy(agent_1)

    # Act
    # Make sure agents have the same networks
    agent_nets = zip(agent_1.net.value_net.layers, agent_2.net.value_net.layers)
    agent_target_nets = zip(
        agent_1.target_net.value_net.layers, agent_2.target_net.value_net.layers
    )
    # torch.equal avoids false positives from weight differences cancelling in a sum
    assert all(torch.equal(l1.weight, l2.weight) for l1, l2 in agent_nets)
    assert all(torch.equal(l1.weight, l2.weight) for l1, l2 in agent_target_nets)

    agent_0.seed(32167)
    actions_0 = deterministic_interactions(agent_0)
    agent_1.seed(0)
    actions_1 = deterministic_interactions(agent_1)
    agent_2.seed(0)
    actions_2 = deterministic_interactions(agent_2)

    # Assert
    # Different seeds should produce different action sequences
    assert any(a0 != a1 for (a0, a1) in zip(actions_0, actions_1))
    # Identically seeded agents should produce identical actions
    for idx, (a1, a2) in enumerate(zip(actions_1, actions_2)):
        assert a1 == a2, f"Action mismatch on position {idx}: {a1} != {a2}"


def test_ppo_from_state_one_updated():
    # Assign
    state_shape, action_size = 10, 3
    agent = PPOAgent(state_shape, action_size)
    deterministic_interactions(agent, num_iters=100)
    agent_state = agent.get_state()
    deterministic_interactions(agent, num_iters=400)

    # Act
    new_agent = PPOAgent.from_state(agent_state)

    # Assert
    assert id(agent) != id(new_agent)
    # assert new_agent == agent
    assert isinstance(new_agent, PPOAgent)
    # assert any([torch.any(x != y) for (x, y) in zip(agent.policy.parameters(), new_agent.policy.parameters())])
    assert any(
        torch.any(x != y)
        for (x, y) in zip(agent.actor.parameters(), new_agent.actor.parameters())
    )
    assert any(
        torch.any(x != y)
        for (x, y) in zip(agent.critic.parameters(), new_agent.critic.parameters())
    )
    assert new_agent.buffer != agent.buffer


def test_serialize_network_state_actual():
    from ai_traineree.agents.dqn import DQNAgent

    agent = DQNAgent(10, 4)
    deterministic_interactions(agent, 30)
    network_state = agent.get_network_state()

    # Act
    ser = serialize(network_state)

    # Assert
    des = json.loads(ser)
    assert set(des['net'].keys()) == {'net', 'target_net'}


def test_serialize_agent_state_actual():
    from ai_traineree.agents.dqn import DQNAgent

    agent = DQNAgent(10, 4)
    deterministic_interactions(agent, 30)
    state = agent.get_state()

    # Act
    ser = serialize(state)

    # Assert
    des = json.loads(ser)
    assert des['model'] == DQNAgent.name
    assert len(des['buffer']['data']) == 30
    assert set(des['network']['net'].keys()) == {'net', 'target_net'}
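

# NOTE: `serialize` is a helper assumed to live in this test module. A minimal
# sketch, assuming agent/network states are (nested) dict-like objects with
# torch tensors at the leaves; tensors are converted to plain lists so the
# whole structure round-trips through `json.loads` as the tests above expect.
def serialize(obj):
    def _default(value):
        if isinstance(value, torch.Tensor):
            return value.tolist()
        if hasattr(value, "__dict__"):  # e.g. dataclass-style state objects
            return vars(value)
        raise TypeError(f"Cannot serialize object of type {type(value)}")

    return json.dumps(obj, default=_default)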