Example #1
0
def test_ppo_from_state():
    """A PPO agent rebuilt from ``get_state()`` must match the original.

    Checks that a distinct object of the same type is produced, that its
    hyperparameters and buffer are equal, and that every parameter tensor of
    the ``policy``, ``actor`` and ``critic`` networks is identical.
    """
    # Assign
    state_shape, action_size = 10, 3
    agent = PPOAgent(state_shape, action_size)
    agent_state = agent.get_state()

    # Act
    new_agent = PPOAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    assert isinstance(new_agent, PPOAgent)
    assert new_agent.hparams == agent.hparams
    for net_name in ("policy", "actor", "critic"):
        expected = list(getattr(agent, net_name).parameters())
        restored = list(getattr(new_agent, net_name).parameters())
        # zip() stops at the shorter sequence, so a missing or extra parameter
        # would otherwise go unnoticed; compare the counts explicitly.
        assert len(restored) == len(expected), net_name
        # torch.equal also requires identical shapes, unlike
        # `torch.all(x == y)`, which broadcasts and can pass for tensors of
        # different shapes.
        assert all(torch.equal(x, y) for x, y in zip(expected, restored)), net_name
    assert new_agent.buffer == agent.buffer
Example #2
0
def test_ppo_from_state_one_updated():
    """An agent restored from an earlier snapshot must differ from the live one.

    The snapshot is taken after 100 interactions. The original agent then
    runs 400 more, so its actor/critic weights and its buffer are expected to
    no longer match what the snapshot recorded.
    """
    # Assign: snapshot mid-training, then keep training the original.
    original = PPOAgent(10, 3)  # state_shape=10, action_size=3
    deterministic_interactions(original, num_iters=100)
    snapshot = original.get_state()
    deterministic_interactions(original, num_iters=400)

    # Act
    restored = PPOAgent.from_state(snapshot)

    # Assert
    assert restored is not original
    assert isinstance(restored, PPOAgent)
    # NOTE(review): `policy` parameters are intentionally not compared here
    # (that check was commented out upstream) — confirm whether they are
    # expected to change during these interactions.
    for net_name in ("actor", "critic"):
        live_params = getattr(original, net_name).parameters()
        snapshot_params = getattr(restored, net_name).parameters()
        assert any(torch.any(a != b) for a, b in zip(live_params, snapshot_params))
    assert restored.buffer != original.buffer
 def from_state(state: AgentState) -> AgentBase:
     """Create an agent of the type recorded in ``state.model``.

     ``state.model`` must exactly equal ``DQNAgent.name`` or ``PPOAgent.name``
     (checked in that order). Construction is delegated to the matching
     class's own ``from_state``.

     Raises:
         ValueError: If ``state.model`` names neither supported agent.
     """
     for agent_cls in (DQNAgent, PPOAgent):
         if state.model == agent_cls.name:
             return agent_cls.from_state(state)
     raise ValueError(
         f"Agent state contains unsupported model type: '{state.model}'"
     )
Example #4
0
 def from_state(state: AgentState) -> AgentBase:
     """Rebuild an agent from a serialized ``AgentState``.

     ``state.model`` is matched case-insensitively against the ``name`` of
     each supported agent class (DQN, PPO, DDPG, Rainbow, in that order).
     Construction is delegated to the matching class's own ``from_state``.

     Args:
         state: Agent snapshot, typically produced by an agent's ``get_state()``.

     Returns:
         A new agent instance of the class named by ``state.model``.

     Raises:
         ValueError: If ``state.model`` is not a string or does not name a
             supported agent type.
     """
     model = state.model
     if not isinstance(model, str):
         # Without this guard a non-string value (e.g. None) would escape as an
         # AttributeError from the case-folding call, rather than the
         # ValueError this function documents for unsupported models.
         raise ValueError(
             f"Agent state contains unsupported model type: {model}")
     # casefold() is the recommended form for caseless string matching.
     norm_model = model.casefold()
     for agent_cls in (DQNAgent, PPOAgent, DDPGAgent, RainbowAgent):
         if norm_model == agent_cls.name.casefold():
             return agent_cls.from_state(state)
     raise ValueError(
         f"Agent state contains unsupported model type: {model}")
Example #5
0
def test_ppo_from_state_network_state_none():
    """Restoring a PPO agent must work even when the state has no network snapshot.

    With ``network`` cleared, the test only checks that a distinct PPOAgent
    is produced whose hyperparameters and buffer match the original.
    """
    # Assign: snapshot a fresh agent, then drop its network state.
    original = PPOAgent(10, 3)  # state_shape=10, action_size=3
    snapshot = original.get_state()
    snapshot.network = None

    # Act
    restored = PPOAgent.from_state(snapshot)

    # Assert
    assert restored is not original
    assert isinstance(restored, PPOAgent)
    assert restored.hparams == original.hparams
    assert restored.buffer == original.buffer