Пример #1
0
def test_ppo_get_state_compare_different_agents():
    """States of agents built with different ``n_steps`` differ but share ``model``."""
    # Assign
    state_size = 3
    action_size = 2
    first_agent = PPOAgent(state_size, action_size, device='cpu', n_steps=1)
    second_agent = PPOAgent(state_size, action_size, device='cpu', n_steps=2)

    # Act
    first_state, second_state = first_agent.get_state(), second_agent.get_state()

    # Assert
    # The states as a whole must be unequal, while the model field matches.
    assert first_state != second_state
    assert first_state.model == second_state.model
Пример #2
0
def test_ppo_from_state():
    """Check that ``PPOAgent.from_state`` rebuilds a distinct but matching agent.

    The restored agent must be a different object of type ``PPOAgent`` with the
    same ``hparams``, the same ``buffer`` and identical parameters in its
    ``policy``, ``actor`` and ``critic`` networks.
    """
    # Assign
    state_shape, action_size = 10, 3
    agent = PPOAgent(state_shape, action_size)
    agent_state = agent.get_state()

    # Act
    new_agent = PPOAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    # NOTE(review): whole-agent equality (`new_agent == agent`) was left
    # disabled in the original test and is still not asserted here.
    assert isinstance(new_agent, PPOAgent)
    assert new_agent.hparams == agent.hparams
    for net_name in ("policy", "actor", "critic"):
        old_params = list(getattr(agent, net_name).parameters())
        new_params = list(getattr(new_agent, net_name).parameters())
        # zip() stops at the shorter input, so a missing parameter would go
        # unnoticed unless the counts are compared explicitly.
        assert len(old_params) == len(new_params), net_name
        # torch.equal requires matching sizes as well as matching values,
        # whereas `x == y` would broadcast mismatched shapes.
        assert all(
            torch.equal(old, new) for old, new in zip(old_params, new_params)
        ), net_name
    assert new_agent.buffer == agent.buffer
Пример #3
0
def test_ppo_get_state():
    """``get_state`` returns an ``AgentState`` describing config, networks and buffer."""
    # Assign
    state_size = 3
    action_size = 4
    overrides = {'actor_lr': 0.1, 'gamma': 0.6}
    agent = PPOAgent(state_size, action_size, device='cpu', **overrides)

    # Act
    snapshot = agent.get_state()

    # Assert
    assert isinstance(snapshot, AgentState)
    assert snapshot.model == PPOAgent.name
    assert snapshot.state_space == state_size
    assert snapshot.action_space == action_size
    assert snapshot.config == agent._config
    # Values passed at construction time must show up in the stored config.
    for key, expected in overrides.items():
        assert snapshot.config[key] == expected

    net_snapshot = snapshot.network
    assert isinstance(net_snapshot, NetworkState)
    assert set(net_snapshot.net.keys()) == {'actor', 'critic', 'policy'}

    buf_snapshot = snapshot.buffer
    assert isinstance(buf_snapshot, BufferState)
    # Buffer metadata in the snapshot mirrors the live buffer's attributes.
    for attr in ("type", "batch_size", "buffer_size"):
        assert getattr(buf_snapshot, attr) == getattr(agent.buffer, attr)
Пример #4
0
def test_ppo_from_state_one_updated():
    """An agent restored from an earlier snapshot differs from the live agent.

    The snapshot is taken after 100 iterations of ``deterministic_interactions``;
    the original agent then runs 400 more before the comparison.
    """
    # Assign
    state_shape, action_size = 10, 3
    agent = PPOAgent(state_shape, action_size)
    deterministic_interactions(agent, num_iters=100)
    agent_state = agent.get_state()
    deterministic_interactions(agent, num_iters=400)

    # Act
    new_agent = PPOAgent.from_state(agent_state)

    # Assert
    assert new_agent is not agent
    assert isinstance(new_agent, PPOAgent)
    # NOTE: whole-agent equality and the policy-parameter comparison were
    # disabled in the original test and remain unasserted.
    actor_pairs = zip(agent.actor.parameters(), new_agent.actor.parameters())
    actor_diffs = [torch.any(old != new) for old, new in actor_pairs]
    assert any(actor_diffs)
    critic_pairs = zip(agent.critic.parameters(), new_agent.critic.parameters())
    critic_diffs = [torch.any(old != new) for old, new in critic_pairs]
    assert any(critic_diffs)
    assert new_agent.buffer != agent.buffer
Пример #5
0
def test_agent_factory_ppo_agent_from_state_network_buffer_none():
    """``AgentFactory.from_state`` works when ``network`` and ``buffer`` are ``None``."""
    # Assign
    state_size = 10
    action_size = 5
    agent = PPOAgent(state_size, action_size, device="cpu")
    stripped_state = agent.get_state()
    # Drop the optional parts of the state before restoring.
    for field_name in ("network", "buffer"):
        setattr(stripped_state, field_name, None)

    # Act
    new_agent = AgentFactory.from_state(stripped_state)

    # Assert
    assert new_agent is not agent
    assert new_agent.hparams == agent.hparams
def test_agent_factory_ppo_agent_from_state():
    """``AgentFactory.from_state`` restores a PPO agent equal to the original."""
    # Assign
    state_size = 10
    action_size = 5
    agent = PPOAgent(state_size, action_size, device="cpu")
    saved_state = agent.get_state()

    # Act
    restored = AgentFactory.from_state(saved_state)

    # Assert
    # Distinct object, yet equal as a whole and in name, hparams and buffer.
    assert restored is not agent
    assert restored == agent
    assert restored.name == PPOAgent.name
    assert restored.hparams == agent.hparams
    assert restored.buffer == agent.buffer
Пример #7
0
def test_ppo_from_state_network_state_none():
    """``PPOAgent.from_state`` accepts a state whose ``network`` is ``None``."""
    # Assign
    state_shape = 10
    action_size = 3
    agent = PPOAgent(state_shape, action_size)
    state_without_network = agent.get_state()
    state_without_network.network = None

    # Act
    new_agent = PPOAgent.from_state(state_without_network)

    # Assert
    assert new_agent is not agent
    # NOTE: whole-agent equality was disabled in the original test and is
    # still not asserted here.
    assert isinstance(new_agent, PPOAgent)
    assert new_agent.hparams == agent.hparams
    assert new_agent.buffer == agent.buffer