예제 #1
0
def test_rainbow_get_state_compare_different_agents():
    """Agents differing only in ``n_steps`` produce unequal states naming the same model."""
    # Assign: same state/action sizes, different n-step horizons.
    n_states, n_actions = 3, 2
    one_step_agent = RainbowAgent(n_states, n_actions, device='cpu', n_steps=1)
    two_step_agent = RainbowAgent(n_states, n_actions, device='cpu', n_steps=2)

    # Act
    one_step_state = one_step_agent.get_state()
    two_step_state = two_step_agent.get_state()

    # Assert: the snapshots are not equal, yet both report the same model.
    assert one_step_state != two_step_state
    assert one_step_state.model == two_step_state.model
예제 #2
0
def test_rainbow_get_state():
    """``get_state`` captures the agent's spaces, config, networks and buffer metadata."""
    # Assign
    n_states, n_actions = 3, 4
    config = {'lr': 0.1, 'gamma': 0.6}
    agent = RainbowAgent(n_states, n_actions, device='cpu', **config)

    # Act
    snapshot = agent.get_state()

    # Assert: top-level metadata and the configuration given at construction.
    assert isinstance(snapshot, AgentState)
    assert snapshot.model == RainbowAgent.name
    assert snapshot.state_space == n_states
    assert snapshot.action_space == n_actions
    assert snapshot.config == agent._config
    for key, expected in (('lr', 0.1), ('gamma', 0.6)):
        assert snapshot.config[key] == expected

    # Assert: both the online and the target network are captured.
    net_snapshot = snapshot.network
    assert isinstance(net_snapshot, NetworkState)
    assert {'net', 'target_net'} == set(net_snapshot.net.keys())

    # Assert: buffer metadata mirrors the live buffer.
    buffer_snapshot = snapshot.buffer
    assert isinstance(buffer_snapshot, BufferState)
    for attr in ('type', 'batch_size', 'buffer_size'):
        assert getattr(buffer_snapshot, attr) == getattr(agent.buffer, attr)
예제 #3
0
def test_rainbow_from_state():
    """Restoring from a fresh agent's state reproduces its hparams, networks and buffer."""
    # Assign
    state_shape, action_size = 10, 3
    agent = RainbowAgent(state_shape, action_size, device='cpu')
    agent_state = agent.get_state()

    # Act
    new_agent = RainbowAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    assert isinstance(new_agent, RainbowAgent)
    assert new_agent.hparams == agent.hparams

    for net_name in ('net', 'target_net'):
        old_params = list(getattr(agent, net_name).parameters())
        new_params = list(getattr(new_agent, net_name).parameters())
        # Materialise and compare lengths: a bare `zip` truncates silently, so an empty
        # or shorter parameter list would otherwise make the check pass vacuously.
        assert len(old_params) > 0
        assert len(old_params) == len(new_params)
        # `torch.equal` also requires identical shapes; `torch.all(x == y)` would broadcast.
        assert all(torch.equal(x, y) for x, y in zip(old_params, new_params))

    assert new_agent.buffer == agent.buffer
예제 #4
0
def test_rainbow_from_state_one_updated():
    """A state taken mid-training restores that snapshot, not the agent's later weights."""
    # Assign
    state_shape, action_size = 10, 3
    agent = RainbowAgent(state_shape, action_size, device='cpu')
    feed_agent(agent, 2*agent.batch_size)  # Feed 1
    agent_state = agent.get_state()
    # Detached copies of the weights at snapshot time. Without them the test would also
    # pass if `from_state` ignored the saved networks (a freshly initialised net differs too).
    snapshot_params = {
        net_name: [p.detach().clone() for p in getattr(agent, net_name).parameters()]
        for net_name in ('net', 'target_net')
    }
    feed_agent(agent, 100)  # Feed 2 - to make different

    # Act
    new_agent = RainbowAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    assert isinstance(new_agent, RainbowAgent)
    assert new_agent.hparams == agent.hparams
    for net_name, expected in snapshot_params.items():
        restored = list(getattr(new_agent, net_name).parameters())
        current = list(getattr(agent, net_name).parameters())
        # Restored weights match the snapshot exactly (same count, shapes and values) ...
        assert len(restored) == len(expected)
        assert all(torch.equal(x, y) for x, y in zip(restored, expected))
        # ... and differ from the agent that kept training after the snapshot.
        assert any(torch.any(x != y) for x, y in zip(current, restored))
    assert new_agent.buffer != agent.buffer