def test_dqn_get_state_compare_different_agents():
    """States of agents built with different ``n_steps`` differ but share ``model``."""
    # Assign
    n_states, n_actions = 3, 2
    one_step_agent = DQNAgent(n_states, n_actions, device='cpu', n_steps=1)
    two_step_agent = DQNAgent(n_states, n_actions, device='cpu', n_steps=2)

    # Act
    first_state = one_step_agent.get_state()
    second_state = two_step_agent.get_state()

    # Assert
    # The states as a whole must be unequal, while the `model` field matches.
    assert first_state != second_state
    assert first_state.model == second_state.model
def test_dqn_get_state():
    """``DQNAgent.get_state`` returns an ``AgentState`` mirroring the agent.

    Checks the top-level fields (model, spaces, config) and the nested
    network and buffer states against the agent they were taken from.
    """
    # Assign
    n_states, n_actions = 3, 4
    overrides = {'lr': 0.1, 'gamma': 0.6}
    agent = DQNAgent(n_states, n_actions, device='cpu', **overrides)

    # Act
    snapshot = agent.get_state()

    # Assert: top-level fields
    assert isinstance(snapshot, AgentState)
    assert snapshot.model == DQNAgent.name
    assert snapshot.state_space == n_states
    assert snapshot.action_space == n_actions
    assert snapshot.config == agent._config
    # Every value passed to the constructor must show up in the config.
    for key, expected in overrides.items():
        assert snapshot.config[key] == expected

    # Assert: network state
    net_snapshot = snapshot.network
    assert isinstance(net_snapshot, NetworkState)
    assert set(net_snapshot.net.keys()) == {'net', 'target_net'}

    # Assert: buffer state
    buf_snapshot = snapshot.buffer
    source_buffer = agent.buffer
    assert isinstance(buf_snapshot, BufferState)
    assert buf_snapshot.type == source_buffer.type
    assert buf_snapshot.batch_size == source_buffer.batch_size
    assert buf_snapshot.buffer_size == source_buffer.buffer_size
def test_agent_factory_dqn_agent_from_state_network_buffer_none():
    """``AgentFactory.from_state`` works when network and buffer are ``None``.

    The rebuilt agent must be a distinct object with the same ``hparams``.
    """
    # Assign
    n_states, n_actions = 10, 5
    original = DQNAgent(n_states, n_actions, device="cpu")
    snapshot = original.get_state()
    # Clear the optional parts of the state before rebuilding.
    snapshot.network = None
    snapshot.buffer = None

    # Act
    rebuilt = AgentFactory.from_state(snapshot)

    # Assert
    assert rebuilt is not original
    assert rebuilt.hparams == original.hparams
def test_agent_factory_dqn_agent_from_state():
    """``AgentFactory.from_state`` rebuilds an equal but distinct DQN agent."""
    # Assign
    n_states, n_actions = 10, 5
    original = DQNAgent(n_states, n_actions, device="cpu")
    snapshot = original.get_state()

    # Act
    rebuilt = AgentFactory.from_state(snapshot)

    # Assert
    # A different object that still compares equal to the source agent.
    assert rebuilt is not original
    assert rebuilt == original
    assert rebuilt.name == DQNAgent.name
    assert rebuilt.hparams == original.hparams
    assert rebuilt.buffer == original.buffer
def test_serialize_agent_state_actual():
    """Serialize the state of a real DQN agent and inspect the decoded JSON."""
    from ai_traineree.agents.dqn import DQNAgent

    # Assign
    interactions = 30
    agent = DQNAgent(10, 4)
    deterministic_interactions(agent, interactions)
    snapshot = agent.get_state()

    # Act
    serialized = serialize(snapshot)

    # Assert
    decoded = json.loads(serialized)
    assert decoded['model'] == DQNAgent.name
    # The serialized buffer is expected to hold as many entries as the
    # number passed to `deterministic_interactions`.
    assert len(decoded['buffer']['data']) == interactions
    assert set(decoded['network']['net'].keys()) == {'target_net', 'net'}
def test_dqn_from_state_network_state_none():
    """``DQNAgent.from_state`` accepts a state whose network is ``None``.

    The rebuilt agent must be a distinct ``DQNAgent`` with the same
    ``hparams`` and an equal buffer.
    """
    # Assign
    n_states, n_actions = 10, 3
    original = DQNAgent(n_states, n_actions)
    snapshot = original.get_state()
    snapshot.network = None

    # Act
    rebuilt = DQNAgent.from_state(snapshot)

    # Assert
    assert original is not rebuilt
    # NOTE(review): a full `new_agent == agent` check was disabled in the
    # original; presumably equality is not expected once the network state
    # is dropped - confirm before re-enabling.
    assert isinstance(rebuilt, DQNAgent)
    assert rebuilt.hparams == original.hparams
    assert rebuilt.buffer == original.buffer
def test_dqn_from_state():
    """``DQNAgent.from_state`` restores hparams, network parameters and buffer."""
    # Assign
    n_states, n_actions = 10, 3
    original = DQNAgent(n_states, n_actions)
    snapshot = original.get_state()

    # Act
    rebuilt = DQNAgent.from_state(snapshot)

    # Assert
    assert original is not rebuilt
    # NOTE(review): a full `new_agent == agent` check was disabled in the
    # original; the reason is not visible here.
    assert isinstance(rebuilt, DQNAgent)
    assert rebuilt.hparams == original.hparams

    # Parameters of both networks must match element-wise.
    net_pairs = zip(original.net.parameters(), rebuilt.net.parameters())
    assert all(torch.all(src == dst) for src, dst in net_pairs)

    target_pairs = zip(original.target_net.parameters(), rebuilt.target_net.parameters())
    assert all(torch.all(src == dst) for src, dst in target_pairs)

    assert rebuilt.buffer == original.buffer
def test_dqn_from_state_one_updated():
    """An agent restored from an earlier state differs from the updated source.

    The state is captured between two feeding rounds, so the rebuilt agent
    is expected to differ from the source in network parameters and buffer
    while keeping the same ``hparams``.
    """
    # Assign
    n_states, n_actions = 10, 3
    original = DQNAgent(n_states, n_actions)
    feed_agent(original, 2*original.batch_size)  # first feed, before the snapshot
    snapshot = original.get_state()
    feed_agent(original, 100)  # second feed, so the source moves on from the snapshot

    # Act
    rebuilt = DQNAgent.from_state(snapshot)

    # Assert
    assert original is not rebuilt
    # NOTE(review): a full `new_agent == agent` check was disabled in the
    # original; the assertions below expect the two agents to differ.
    assert isinstance(rebuilt, DQNAgent)
    assert rebuilt.hparams == original.hparams

    # At least one parameter tensor must differ in each network.
    net_pairs = zip(original.net.parameters(), rebuilt.net.parameters())
    assert any(torch.any(src != dst) for src, dst in net_pairs)

    target_pairs = zip(original.target_net.parameters(), rebuilt.target_net.parameters())
    assert any(torch.any(src != dst) for src, dst in target_pairs)

    assert rebuilt.buffer != original.buffer