def test_ppo_get_state_compare_different_agents():
    """Agents differing only in ``n_steps`` give unequal states of the same model."""
    # Assign
    state_size, action_size = 3, 2
    agents = [
        PPOAgent(state_size, action_size, device='cpu', n_steps=steps)
        for steps in (1, 2)
    ]

    # Act
    first_state, second_state = (a.get_state() for a in agents)

    # Assert
    assert first_state != second_state
    assert first_state.model == second_state.model
def test_ppo_from_state():
    """Restoring a fresh PPOAgent from its own state reproduces it.

    The restored agent must be a distinct ``PPOAgent`` instance with the same
    hyperparameters, the same policy/actor/critic parameters and an equal
    buffer.
    """
    # Assign
    state_shape, action_size = 10, 3
    agent = PPOAgent(state_shape, action_size)
    agent_state = agent.get_state()

    # Act
    new_agent = PPOAgent.from_state(agent_state)

    # Assert
    assert id(agent) != id(new_agent)
    # NOTE(review): a direct `new_agent == agent` check was left disabled here,
    # while the AgentFactory round-trip test asserts it; confirm and re-enable.
    assert isinstance(new_agent, PPOAgent)
    assert new_agent.hparams == agent.hparams

    network_pairs = (
        ("policy", agent.policy, new_agent.policy),
        ("actor", agent.actor, new_agent.actor),
        ("critic", agent.critic, new_agent.critic),
    )
    for net_name, original_net, restored_net in network_pairs:
        original_params = list(original_net.parameters())
        restored_params = list(restored_net.parameters())
        # zip() silently stops at the shorter sequence, so a restored network
        # missing parameter tensors would otherwise pass the equality check.
        assert len(original_params) == len(restored_params), net_name
        assert all(
            torch.all(x == y) for x, y in zip(original_params, restored_params)
        ), net_name

    assert new_agent.buffer == agent.buffer
def test_ppo_get_state():
    """``get_state`` records model name, spaces, config, networks and buffer."""
    # Assign
    state_size, action_size = 3, 4
    init_config = {'actor_lr': 0.1, 'gamma': 0.6}
    agent = PPOAgent(state_size, action_size, device='cpu', **init_config)

    # Act
    snapshot = agent.get_state()

    # Assert: top-level metadata
    assert isinstance(snapshot, AgentState)
    assert snapshot.model == PPOAgent.name
    assert snapshot.state_space == state_size
    assert snapshot.action_space == action_size

    # Assert: config is the agent's own and carries the constructor overrides
    cfg = snapshot.config
    assert cfg == agent._config
    for key, expected in init_config.items():
        assert cfg[key] == expected

    # Assert: actor, critic and policy networks are all captured
    net_state = snapshot.network
    assert isinstance(net_state, NetworkState)
    assert set(net_state.net.keys()) == {'actor', 'critic', 'policy'}

    # Assert: buffer state mirrors the live buffer's settings
    buf_state, live_buffer = snapshot.buffer, agent.buffer
    assert isinstance(buf_state, BufferState)
    assert buf_state.type == live_buffer.type
    assert buf_state.batch_size == live_buffer.batch_size
    assert buf_state.buffer_size == live_buffer.buffer_size
def test_ppo_from_state_one_updated():
    """An agent restored from an older snapshot differs from one trained further."""
    # Assign: snapshot after 100 interactions, then keep training the original
    state_shape, action_size = 10, 3
    agent = PPOAgent(state_shape, action_size)
    deterministic_interactions(agent, num_iters=100)
    agent_state = agent.get_state()
    deterministic_interactions(agent, num_iters=400)

    # Act
    restored = PPOAgent.from_state(agent_state)

    # Assert
    assert agent is not restored
    assert isinstance(restored, PPOAgent)
    # NOTE(review): policy-parameter divergence is not checked; that assertion
    # was left commented out in this test. Confirm whether it should hold.
    compared_networks = (
        (agent.actor, restored.actor),
        (agent.critic, restored.critic),
    )
    for trained_net, restored_net in compared_networks:
        assert any(
            torch.any(p != q)
            for p, q in zip(trained_net.parameters(), restored_net.parameters())
        )
    assert restored.buffer != agent.buffer
def test_agent_factory_ppo_agent_from_state_network_buffer_none():
    """AgentFactory rebuilds a PPOAgent from a state with no network or buffer."""
    # Assign
    state_size, action_size = 10, 5
    original = PPOAgent(state_size, action_size, device="cpu")
    snapshot = original.get_state()
    # Strip both the network and the buffer snapshots before restoring
    snapshot.network, snapshot.buffer = None, None

    # Act
    rebuilt = AgentFactory.from_state(snapshot)

    # Assert
    assert rebuilt is not original
    assert rebuilt.hparams == original.hparams
def test_agent_factory_ppo_agent_from_state():
    """AgentFactory reconstructs an equal PPOAgent from its full state."""
    # Assign
    n_states, n_actions = 10, 5
    original = PPOAgent(n_states, n_actions, device="cpu")
    snapshot = original.get_state()

    # Act
    rebuilt = AgentFactory.from_state(snapshot)

    # Assert: a new object that nonetheless compares equal to the original
    assert rebuilt is not original
    assert rebuilt == original
    assert rebuilt.name == PPOAgent.name
    assert rebuilt.hparams == original.hparams
    assert rebuilt.buffer == original.buffer
def test_ppo_from_state_network_state_none():
    """PPOAgent.from_state works when the network snapshot has been removed."""
    # Assign
    n_states, n_actions = 10, 3
    original = PPOAgent(n_states, n_actions)
    snapshot = original.get_state()
    snapshot.network = None  # restore without any serialized network weights

    # Act
    rebuilt = PPOAgent.from_state(snapshot)

    # Assert
    assert rebuilt is not original
    assert isinstance(rebuilt, PPOAgent)
    assert rebuilt.hparams == original.hparams
    assert rebuilt.buffer == original.buffer