예제 #1
0
 def from_state(state: AgentState) -> AgentBase:
     """Rebuild an agent from an ``AgentState``.

     The concrete agent class is chosen by comparing ``state.model`` with
     each supported class's ``name`` (exact, case-sensitive equality).
     Candidates are tried in order, DQN first, then PPO.

     Raises:
         ValueError: if ``state.model`` matches none of the supported agents.
     """
     for agent_cls in (DQNAgent, PPOAgent):
         if state.model == agent_cls.name:
             return agent_cls.from_state(state)
     raise ValueError(
         f"Agent state contains unsupported model type: '{state.model}'"
     )
예제 #2
0
 def from_state(state: AgentState) -> AgentBase:
     """Rebuild an agent from an ``AgentState``.

     ``state.model`` is matched case-insensitively (both sides are
     upper-cased) against the ``name`` of each supported agent class.
     Candidates are tried in the order DQN, PPO, DDPG, Rainbow.

     Raises:
         ValueError: if ``state.model`` names none of the supported agents.
     """
     wanted = state.model.upper()
     candidates = (DQNAgent, PPOAgent, DDPGAgent, RainbowAgent)
     for agent_cls in candidates:
         if wanted == agent_cls.name.upper():
             return agent_cls.from_state(state)
     raise ValueError(
         f"Agent state contains unsupported model type: {state.model}")
예제 #3
0
def test_dqn_from_state_network_state_none():
    """``DQNAgent.from_state`` accepts a state whose ``network`` entry is ``None``.

    Only the hyperparameters and the buffer are compared. With no network
    state supplied, the restored networks presumably keep a fresh
    initialisation (TODO confirm), so their weights are not expected to match
    the original agent's.
    """
    # Arrange
    state_shape, action_size = 10, 3
    agent = DQNAgent(state_shape, action_size)
    agent_state = agent.get_state()
    agent_state.network = None  # drop the network part of the snapshot

    # Act
    new_agent = DQNAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    assert isinstance(new_agent, DQNAgent)
    assert new_agent.hparams == agent.hparams
    assert new_agent.buffer == agent.buffer
예제 #4
0
def test_dqn_from_state():
    """Restoring a DQN agent from its state reproduces hparams, both networks and the buffer."""
    # Arrange
    state_shape, action_size = 10, 3
    agent = DQNAgent(state_shape, action_size)
    agent_state = agent.get_state()

    # Act
    new_agent = DQNAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    assert isinstance(new_agent, DQNAgent)
    assert new_agent.hparams == agent.hparams
    network_pairs = (
        (agent.net, new_agent.net),
        (agent.target_net, new_agent.target_net),
    )
    for original_net, restored_net in network_pairs:
        original_params = list(original_net.parameters())
        restored_params = list(restored_net.parameters())
        # zip() silently truncates to the shorter input, so compare counts first;
        # otherwise a restored network missing parameters could still pass.
        assert len(original_params) == len(restored_params)
        # torch.equal also requires matching shapes, unlike a broadcasting `==`.
        assert all(torch.equal(x, y) for x, y in zip(original_params, restored_params))
    assert new_agent.buffer == agent.buffer
예제 #5
0
def test_dqn_from_state_one_updated():
    """An agent restored from an earlier snapshot differs from the agent after further feeding.

    The snapshot is taken after a first round of ``feed_agent``. The live agent is
    then fed again, so its networks and buffer are expected to have moved away
    from the snapshot, while hyperparameters stay equal.
    """
    # Arrange
    state_shape, action_size = 10, 3
    agent = DQNAgent(state_shape, action_size)
    feed_agent(agent, 2 * agent.batch_size)  # Feed 1
    agent_state = agent.get_state()
    feed_agent(agent, 100)  # Feed 2: make the live agent diverge from the snapshot

    # Act
    new_agent = DQNAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    assert isinstance(new_agent, DQNAgent)
    assert new_agent.hparams == agent.hparams
    assert any(torch.any(x != y) for x, y in zip(agent.net.parameters(), new_agent.net.parameters()))
    assert any(torch.any(x != y) for x, y in zip(agent.target_net.parameters(), new_agent.target_net.parameters()))
    assert new_agent.buffer != agent.buffer