Ejemplo n.º 1
0
def test_dqn_get_state_compare_different_agents():
    """Agents that differ only in ``n_steps`` produce unequal states with the same model name."""
    # Assign: two agents whose only configuration difference is n_steps
    state_size = 3
    action_size = 2
    one_step_agent = DQNAgent(state_size, action_size, device='cpu', n_steps=1)
    two_step_agent = DQNAgent(state_size, action_size, device='cpu', n_steps=2)

    # Act
    one_step_state, two_step_state = one_step_agent.get_state(), two_step_agent.get_state()

    # Assert: the states differ overall, but both describe the same model type
    assert one_step_state != two_step_state
    assert one_step_state.model == two_step_state.model
Ejemplo n.º 2
0
def test_dqn_get_state():
    """``get_state`` captures model name, spaces, config, both networks and buffer settings."""
    # Assign
    state_size = 3
    action_size = 4
    init_config = dict(lr=0.1, gamma=0.6)
    agent = DQNAgent(state_size, action_size, device='cpu', **init_config)

    # Act
    agent_state = agent.get_state()

    # Assert: top-level agent description
    assert isinstance(agent_state, AgentState)
    assert agent_state.model == DQNAgent.name
    assert agent_state.state_space == state_size
    assert agent_state.action_space == action_size
    config = agent_state.config
    assert config == agent._config
    assert config['lr'] == 0.1
    assert config['gamma'] == 0.6

    # Assert: the network snapshot holds exactly the 'net' and 'target_net' entries
    net_state = agent_state.network
    assert isinstance(net_state, NetworkState)
    assert set(net_state.net.keys()) == {'net', 'target_net'}

    # Assert: the buffer snapshot mirrors the live buffer's settings
    buf_state, live_buffer = agent_state.buffer, agent.buffer
    assert isinstance(buf_state, BufferState)
    for attr in ('type', 'batch_size', 'buffer_size'):
        assert getattr(buf_state, attr) == getattr(live_buffer, attr)
Ejemplo n.º 3
0
def test_agent_factory_dqn_agent_from_state_network_buffer_none():
    """AgentFactory rebuilds a DQN agent even when the network and buffer snapshots are None."""
    # Assign: take a fresh agent's state and strip its network and buffer snapshots
    state_size = 10
    action_size = 5
    original = DQNAgent(state_size, action_size, device="cpu")
    stripped_state = original.get_state()
    stripped_state.network = stripped_state.buffer = None

    # Act
    rebuilt = AgentFactory.from_state(stripped_state)

    # Assert: a distinct object configured with the same hyperparameters
    assert rebuilt is not original
    assert rebuilt.hparams == original.hparams
Ejemplo n.º 4
0
def test_agent_factory_dqn_agent_from_state():
    """AgentFactory round-trips a DQN agent through its full state."""
    # Assign
    state_size = 10
    action_size = 5
    original = DQNAgent(state_size, action_size, device="cpu")
    full_state = original.get_state()

    # Act
    rebuilt = AgentFactory.from_state(full_state)

    # Assert: a new object that compares equal to the original, piece by piece
    assert rebuilt is not original
    assert rebuilt == original
    assert rebuilt.name == DQNAgent.name
    assert rebuilt.hparams == original.hparams
    assert rebuilt.buffer == original.buffer
Ejemplo n.º 5
0
def test_serialize_agent_state_actual():
    """Serialized state of an agent that has interacted decodes to JSON with model, buffer data and both nets."""
    from ai_traineree.agents.dqn import DQNAgent

    # Assign: an agent driven through deterministic_interactions(agent, 30)
    num_interactions = 30
    agent = DQNAgent(10, 4)
    deterministic_interactions(agent, num_interactions)
    state = agent.get_state()

    # Act
    serialized = serialize(state)

    # Assert: the output is valid JSON carrying the expected sections
    decoded = json.loads(serialized)
    assert decoded['model'] == DQNAgent.name
    assert len(decoded['buffer']['data']) == num_interactions
    assert set(decoded['network']['net']) == {'net', 'target_net'}
Ejemplo n.º 6
0
def test_dqn_from_state_network_state_none():
    """``DQNAgent.from_state`` succeeds when the network snapshot is missing from the state."""
    # Assign: drop the network snapshot from a fresh agent's state
    state_shape = 10
    action_size = 3
    source_agent = DQNAgent(state_shape, action_size)
    state_without_network = source_agent.get_state()
    state_without_network.network = None

    # Act
    restored_agent = DQNAgent.from_state(state_without_network)

    # Assert: a new DQNAgent with the same hyperparameters and buffer
    assert restored_agent is not source_agent
    assert isinstance(restored_agent, DQNAgent)
    assert restored_agent.hparams == source_agent.hparams
    assert restored_agent.buffer == source_agent.buffer
Ejemplo n.º 7
0
def test_dqn_from_state():
    """Restoring a DQNAgent from its state reproduces hparams, both networks' weights and the buffer."""
    # Assign
    state_shape, action_size = 10, 3
    agent = DQNAgent(state_shape, action_size)
    agent_state = agent.get_state()

    # Act
    new_agent = DQNAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    assert isinstance(new_agent, DQNAgent)
    assert new_agent.hparams == agent.hparams
    for original_net, restored_net in ((agent.net, new_agent.net), (agent.target_net, new_agent.target_net)):
        original_params = list(original_net.parameters())
        restored_params = list(restored_net.parameters())
        # zip() truncates silently and all([]) is vacuously True, so check the count explicitly.
        assert len(original_params) == len(restored_params) > 0
        # torch.equal checks shape as well as values; ``x == y`` would broadcast mismatched shapes.
        assert all(torch.equal(x, y) for x, y in zip(original_params, restored_params))
    assert new_agent.buffer == agent.buffer
Ejemplo n.º 8
0
def test_dqn_from_state_one_updated():
    """A state snapshot taken before further feeding restores an agent that differs from the live one."""
    # Assign: snapshot after the first feed, then feed more so the live agent diverges
    state_shape = 10
    action_size = 3
    agent = DQNAgent(state_shape, action_size)
    feed_agent(agent, 2 * agent.batch_size)
    snapshot = agent.get_state()
    feed_agent(agent, 100)

    # Act
    new_agent = DQNAgent.from_state(snapshot)

    # Assert
    assert new_agent is not agent
    assert isinstance(new_agent, DQNAgent)
    assert new_agent.hparams == agent.hparams
    net_pairs = ((agent.net, new_agent.net), (agent.target_net, new_agent.target_net))
    for live_net, restored_net in net_pairs:
        # At least one parameter tensor must differ between the live and restored network.
        assert any(torch.any(p != q) for p, q in zip(live_net.parameters(), restored_net.parameters()))
    assert new_agent.buffer != agent.buffer