def from_state(state: AgentState) -> AgentBase:
    """Recreate an agent from ``state`` by dispatching on ``state.model``.

    Args:
        state: Agent state whose ``model`` attribute is compared against the
            ``name`` attribute of each supported agent class.

    Returns:
        The object returned by the matching agent class's ``from_state``.

    Raises:
        ValueError: If ``state.model`` equals none of the supported agent names.
    """
    # Candidates are tried in order; the first class whose name matches wins.
    for agent_cls in (DQNAgent, PPOAgent):
        if state.model == agent_cls.name:
            return agent_cls.from_state(state)

    raise ValueError(
        f"Agent state contains unsupported model type: '{state.model}'"
    )
def from_state(state: AgentState) -> AgentBase:
    """Recreate an agent from ``state``, matching the model name case-insensitively.

    Args:
        state: Agent state whose ``model`` attribute is upper-cased and compared
            against the upper-cased ``name`` of each supported agent class.

    Returns:
        The object returned by the matching agent class's ``from_state``.

    Raises:
        ValueError: If ``state.model`` matches none of the supported agent names.
    """
    requested = state.model.upper()

    # Candidates are tried in order; the first class whose name matches wins.
    for agent_cls in (DQNAgent, PPOAgent, DDPGAgent, RainbowAgent):
        if requested == agent_cls.name.upper():
            return agent_cls.from_state(state)

    raise ValueError(
        f"Agent state contains unsupported model type: {state.model}"
    )
def test_dqn_from_state_network_state_none():
    """DQNAgent.from_state accepts a state whose ``network`` entry is None."""
    # Arrange
    state_shape = 10
    action_size = 3
    original = DQNAgent(state_shape, action_size)
    snapshot = original.get_state()
    snapshot.network = None  # drop the network part of the state

    # Act
    restored = DQNAgent.from_state(snapshot)

    # Assert
    # NOTE: whole-agent equality (`restored == original`) is not asserted;
    # hyperparameters and buffer are compared individually instead.
    assert original is not restored
    assert isinstance(restored, DQNAgent)
    assert restored.hparams == original.hparams
    assert restored.buffer == original.buffer
def test_dqn_from_state():
    """DQNAgent.from_state rebuilds an agent equal to the one that produced the state.

    Checks that the restored agent is a distinct DQNAgent instance with equal
    hyperparameters, element-wise equal parameters in both ``net`` and
    ``target_net``, and an equal buffer.
    """
    # Arrange
    state_shape, action_size = 10, 3
    agent = DQNAgent(state_shape, action_size)
    agent_state = agent.get_state()

    # Act
    new_agent = DQNAgent.from_state(agent_state)

    # Assert
    # NOTE: whole-agent equality (`new_agent == agent`) is not asserted;
    # the individual components are compared instead.
    assert new_agent is not agent
    assert isinstance(new_agent, DQNAgent)
    assert new_agent.hparams == agent.hparams

    network_pairs = (
        (agent.net, new_agent.net),
        (agent.target_net, new_agent.target_net),
    )
    for old_net, new_net in network_pairs:
        old_params = list(old_net.parameters())
        new_params = list(new_net.parameters())
        # zip() stops at the shorter input, so without this check a restored
        # network with missing or extra parameters would still pass below.
        assert len(old_params) == len(new_params)
        assert all(torch.all(x == y) for x, y in zip(old_params, new_params))

    assert new_agent.buffer == agent.buffer
def test_dqn_from_state_one_updated():
    """An agent restored from an earlier state differs from the agent fed afterwards.

    The state is captured after a first feed; the original agent is then fed
    again, and the restored agent is expected to differ from it in ``net``,
    ``target_net`` and buffer while keeping equal hyperparameters.
    """
    # Arrange
    state_shape = 10
    action_size = 3
    agent = DQNAgent(state_shape, action_size)
    feed_agent(agent, 2 * agent.batch_size)  # first feed, before the snapshot
    agent_state = agent.get_state()
    feed_agent(agent, 100)  # second feed, to make the agent differ from the snapshot

    # Act
    new_agent = DQNAgent.from_state(agent_state)

    # Assert
    # NOTE: whole-agent equality (`new_agent == agent`) is not asserted;
    # the individual components are compared instead.
    assert agent is not new_agent
    assert isinstance(new_agent, DQNAgent)
    assert new_agent.hparams == agent.hparams

    net_differs = [
        torch.any(old != new)
        for old, new in zip(agent.net.parameters(), new_agent.net.parameters())
    ]
    assert any(net_differs)

    target_differs = [
        torch.any(old != new)
        for old, new in zip(
            agent.target_net.parameters(), new_agent.target_net.parameters()
        )
    ]
    assert any(target_differs)

    assert new_agent.buffer != agent.buffer