def test_rainbow_get_state_compare_different_agents():
    """States of agents that differ only in ``n_steps`` are unequal but report the same model."""
    # Assign
    state_size, action_size = 3, 2
    agents = [
        RainbowAgent(state_size, action_size, device='cpu', n_steps=steps)
        for steps in (1, 2)
    ]

    # Act
    first_state, second_state = (agent.get_state() for agent in agents)

    # Assert
    assert first_state != second_state
    assert first_state.model == second_state.model
def test_rainbow_get_state():
    """``get_state`` captures model name, spaces, config, network state and buffer state.

    The agent is built with a custom ``lr``/``gamma`` so the test can confirm
    those values are carried into the returned ``AgentState.config``.
    """
    # Assign
    state_size, action_size = 3, 4
    expected_config = {'lr': 0.1, 'gamma': 0.6}
    agent = RainbowAgent(state_size, action_size, device='cpu', **expected_config)

    # Act
    snapshot = agent.get_state()

    # Assert - top-level description of the agent
    assert isinstance(snapshot, AgentState)
    assert snapshot.model == RainbowAgent.name
    assert snapshot.state_space == state_size
    assert snapshot.action_space == action_size
    assert snapshot.config == agent._config
    for key, value in expected_config.items():
        assert snapshot.config[key] == value

    # Assert - both networks are present in the network snapshot
    net_snapshot = snapshot.network
    assert isinstance(net_snapshot, NetworkState)
    assert set(net_snapshot.net.keys()) == {'net', 'target_net'}

    # Assert - buffer snapshot mirrors the live buffer's settings
    buffer_snapshot = snapshot.buffer
    assert isinstance(buffer_snapshot, BufferState)
    for attr in ('type', 'batch_size', 'buffer_size'):
        assert getattr(buffer_snapshot, attr) == getattr(agent.buffer, attr)
def test_rainbow_from_state():
    """Restoring from a fresh agent's state reproduces its hparams, networks and buffer.

    ``RainbowAgent.from_state`` must return a new ``RainbowAgent`` object whose
    ``net`` and ``target_net`` parameters are identical (same count, shapes and
    values) to the source agent's, and whose buffer compares equal.
    """
    # Assign
    state_shape, action_size = 10, 3
    agent = RainbowAgent(state_shape, action_size, device='cpu')
    agent_state = agent.get_state()

    # Act
    new_agent = RainbowAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    assert isinstance(new_agent, RainbowAgent)
    assert new_agent.hparams == agent.hparams

    for network_name in ('net', 'target_net'):
        original_params = list(getattr(agent, network_name).parameters())
        restored_params = list(getattr(new_agent, network_name).parameters())
        # zip() stops at the shorter sequence, so compare counts first; otherwise a
        # restored network with missing parameters would pass vacuously.
        assert len(original_params) == len(restored_params), network_name
        # torch.equal also checks shapes, unlike the broadcasting `x == y`.
        assert all(
            torch.equal(x, y) for x, y in zip(original_params, restored_params)
        ), network_name

    assert new_agent.buffer == agent.buffer
def test_rainbow_from_state_one_updated():
    """An agent restored from an older snapshot differs from the further-trained original.

    The state is captured after a first round of feeding. The original agent
    is then fed more data, so its ``net``, ``target_net`` and buffer are expected
    to diverge from the agent restored from the earlier snapshot.
    """
    # Assign
    state_shape, action_size = 10, 3
    agent = RainbowAgent(state_shape, action_size, device='cpu')
    feed_agent(agent, 2 * agent.batch_size)  # Feed 1 - state captured after this
    agent_state = agent.get_state()
    feed_agent(agent, 100)  # Feed 2 - moves the original agent past the snapshot

    # Act
    new_agent = RainbowAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    assert isinstance(new_agent, RainbowAgent)
    assert new_agent.hparams == agent.hparams

    for network_name in ('net', 'target_net'):
        original_params = list(getattr(agent, network_name).parameters())
        restored_params = list(getattr(new_agent, network_name).parameters())
        # Same architecture is expected; only the parameter values should differ.
        assert len(original_params) == len(restored_params), network_name
        # torch.equal checks shape and values without broadcasting.
        assert any(
            not torch.equal(x, y) for x, y in zip(original_params, restored_params)
        ), network_name

    assert new_agent.buffer != agent.buffer