def test_ppo_from_state():
    """Rebuilding an untrained PPOAgent from its state gives an equivalent, distinct agent.

    Checks that the hyperparameters, the buffer, and every parameter tensor of
    the ``policy``, ``actor`` and ``critic`` networks match the original agent.
    """
    # Arrange
    state_shape, action_size = 10, 3
    agent = PPOAgent(state_shape, action_size)
    agent_state = agent.get_state()

    # Act
    new_agent = PPOAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    # Whole-agent equality (``new_agent == agent``) is not asserted; the
    # components are compared individually instead.
    assert isinstance(new_agent, PPOAgent)
    assert new_agent.hparams == agent.hparams
    for net_name in ("policy", "actor", "critic"):
        expected = list(getattr(agent, net_name).parameters())
        restored = list(getattr(new_agent, net_name).parameters())
        # zip() stops at the shorter sequence, so without this length check a
        # restored network with missing parameters would pass vacuously.
        assert len(restored) == len(expected), net_name
        # torch.equal also requires identical shapes, whereas ``x == y`` would
        # broadcast and could hide a shape mismatch.
        assert all(torch.equal(x, y) for x, y in zip(expected, restored)), net_name
    assert new_agent.buffer == agent.buffer
def test_ppo_from_state_one_updated():
    """A state snapshot taken mid-training must not follow later training.

    The agent is trained, snapshotted, then trained further. The agent rebuilt
    from the snapshot is expected to differ from the live agent in its actor
    and critic parameters and in its buffer.
    """
    # Arrange: train, take a snapshot, then keep training past the snapshot.
    state_shape, action_size = 10, 3
    agent = PPOAgent(state_shape, action_size)
    deterministic_interactions(agent, num_iters=100)
    agent_state = agent.get_state()
    deterministic_interactions(agent, num_iters=400)

    # Act
    new_agent = PPOAgent.from_state(agent_state)

    # Assert
    assert agent is not new_agent
    assert isinstance(new_agent, PPOAgent)
    # Policy parameters are deliberately not compared here.
    actor_diffs = [
        torch.any(live != restored)
        for live, restored in zip(agent.actor.parameters(), new_agent.actor.parameters())
    ]
    assert any(actor_diffs)
    critic_diffs = [
        torch.any(live != restored)
        for live, restored in zip(agent.critic.parameters(), new_agent.critic.parameters())
    ]
    assert any(critic_diffs)
    assert new_agent.buffer != agent.buffer
def from_state(state: AgentState) -> AgentBase:
    """Rebuild an agent from ``state`` by dispatching on ``state.model``.

    ``state.model`` is compared exactly (case-sensitively) against the
    ``name`` of DQNAgent and then PPOAgent. The first match's ``from_state``
    builds the agent.

    Raises:
        ValueError: if ``state.model`` matches neither agent's ``name``.
    """
    model_name = state.model
    for agent_cls in (DQNAgent, PPOAgent):
        if model_name == agent_cls.name:
            return agent_cls.from_state(state)
    raise ValueError(
        f"Agent state contains unsupported model type: '{state.model}'"
    )
def from_state(state: AgentState) -> AgentBase:
    """Rebuild an agent from ``state`` by dispatching on ``state.model``.

    The match is case-insensitive: ``state.model`` and each candidate's
    ``name`` are upper-cased before comparing. Candidates are tried in order:
    DQNAgent, PPOAgent, DDPGAgent, RainbowAgent. ``state.model`` is expected
    to be a string, since ``.upper()`` is called on it.

    Raises:
        ValueError: if ``state.model`` matches none of the candidates.
    """
    requested = state.model.upper()
    for agent_cls in (DQNAgent, PPOAgent, DDPGAgent, RainbowAgent):
        if requested == agent_cls.name.upper():
            return agent_cls.from_state(state)
    raise ValueError(
        f"Agent state contains unsupported model type: {state.model}")
def test_ppo_from_state_network_state_none():
    """PPOAgent.from_state must cope with a state whose ``network`` is None.

    The rebuilt agent should still be a distinct PPOAgent with the same
    hyperparameters and buffer as the original.
    """
    # Arrange: snapshot a fresh agent, then drop its network state.
    state_shape, action_size = 10, 3
    agent = PPOAgent(state_shape, action_size)
    agent_state = agent.get_state()
    agent_state.network = None

    # Act
    new_agent = PPOAgent.from_state(agent_state)

    # Assert
    assert agent is not new_agent
    assert isinstance(new_agent, PPOAgent)
    assert new_agent.hparams == agent.hparams
    assert new_agent.buffer == agent.buffer