예제 #1
0
 def from_state(state: AgentState) -> AgentBase:
     """Rebuild an agent from a saved ``AgentState``.

     ``state.model`` is compared case-insensitively against the ``name`` of each
     supported agent class (DQN, PPO, DDPG, Rainbow, in that order). The first
     match's own ``from_state`` performs the actual reconstruction.

     Raises:
         ValueError: if ``state.model`` matches none of the supported agents.
     """
     requested = state.model.upper()
     for agent_cls in (DQNAgent, PPOAgent, DDPGAgent, RainbowAgent):
         if requested == agent_cls.name.upper():
             return agent_cls.from_state(state)
     raise ValueError(
         f"Agent state contains unsupported model type: {state.model}")
예제 #2
0
def test_rainbow_from_state():
    """Check that a RainbowAgent restored from a fresh state matches the original.

    Compares hyperparameters, online and target network weights, and the
    replay buffer of the original agent and the one rebuilt from its state.
    """
    # Arrange
    state_shape, action_size = 10, 3
    agent = RainbowAgent(state_shape, action_size, device='cpu')
    agent_state = agent.get_state()

    # Act
    new_agent = RainbowAgent.from_state(agent_state)

    # Assert
    assert agent is not new_agent
    # NOTE(review): a direct `new_agent == agent` check was left disabled here;
    # presumably agent-level equality is not supported — confirm.
    assert isinstance(new_agent, RainbowAgent)
    assert new_agent.hparams == agent.hparams

    # Materialise the parameter lists so that a differing parameter count fails
    # loudly instead of being hidden by zip() truncation. torch.equal also
    # requires identical shapes, whereas `x == y` would broadcast.
    net_params = list(agent.net.parameters())
    new_net_params = list(new_agent.net.parameters())
    assert len(net_params) == len(new_net_params)
    assert all(torch.equal(x, y) for x, y in zip(net_params, new_net_params))

    target_params = list(agent.target_net.parameters())
    new_target_params = list(new_agent.target_net.parameters())
    assert len(target_params) == len(new_target_params)
    assert all(torch.equal(x, y) for x, y in zip(target_params, new_target_params))

    assert new_agent.buffer == agent.buffer
예제 #3
0
def test_rainbow_from_state_one_updated():
    """Check that an agent restored from an older snapshot differs from the live agent.

    The state is captured after a first round of feeding. The live agent is then
    fed again, so its networks and buffer should no longer match the agent
    rebuilt from the snapshot, while the hyperparameters still match.
    """
    # Arrange
    state_shape, action_size = 10, 3
    agent = RainbowAgent(state_shape, action_size, device='cpu')
    feed_agent(agent, 2*agent.batch_size)  # Feed 1
    agent_state = agent.get_state()
    feed_agent(agent, 100)  # Feed 2 - to make different

    # Act
    new_agent = RainbowAgent.from_state(agent_state)

    # Assert
    assert agent is not new_agent
    # NOTE(review): a direct `new_agent == agent` check was left disabled here;
    # presumably agent-level equality is not supported — confirm.
    assert isinstance(new_agent, RainbowAgent)
    assert new_agent.hparams == agent.hparams

    # Materialise the parameter lists so that a differing parameter count fails
    # loudly instead of being hidden by zip() truncation. `not torch.equal`
    # avoids the broadcasting that `x != y` would apply to mismatched shapes.
    net_params = list(agent.net.parameters())
    new_net_params = list(new_agent.net.parameters())
    assert len(net_params) == len(new_net_params)
    assert any(not torch.equal(x, y) for x, y in zip(net_params, new_net_params))

    target_params = list(agent.target_net.parameters())
    new_target_params = list(new_agent.target_net.parameters())
    assert len(target_params) == len(new_target_params)
    assert any(not torch.equal(x, y) for x, y in zip(target_params, new_target_params))

    assert new_agent.buffer != agent.buffer