Example #1
0
def test_multinetworkbody_num_agents(with_actions):
    """The final hidden unit of the MultiAgentNetworkBody encoding should
    report the observed agent count, rescaled by the running maximum into
    the [-1, 1] range."""
    torch.manual_seed(0)
    num_act = 2
    num_obs = 4
    settings = NetworkSettings()
    spec = ActionSpec(num_act, (num_act,) * num_act)
    body = MultiAgentNetworkBody(
        create_observation_specs_with_shapes([(num_obs,)]), settings, spec
    )
    obs_sample = [[0.1 * torch.ones((1, num_obs))]]
    # simulate baseline in POCA
    act_sample = [
        AgentAction(
            0.1 * torch.ones((1, 2)), [0.1 * torch.ones(1) for _ in range(num_act)]
        )
    ]
    cases = [(1, 1), (5, 5), (4, 5), (10, 10), (5, 10), (1, 10)]
    for count, running_max in cases:
        if with_actions:
            enc, _ = body(
                obs_only=obs_sample * (count - 1), obs=obs_sample, actions=act_sample
            )
        else:
            enc, _ = body(obs_only=obs_sample * count, obs=[], actions=[])
        # the last hidden value encodes the agent count, normalized to [-1, 1]
        expected = (count / running_max) * 2 - 1
        assert abs(enc[0, -1].item() - expected) < 1e-6
        assert enc[0, -1].item() <= 1
        assert enc[0, -1].item() >= -1
Example #2
0
    def _sample_action(self, dists: DistInstances) -> AgentAction:
        """
        Draws an action from each distribution in a DistInstances tuple.

        When ``self._deterministic`` is set, the mode of each distribution is
        taken instead of a random sample.
        :params dists: The DistInstances tuple
        :return: An AgentAction holding the sampled continuous/discrete actions
        """
        cont: Optional[torch.Tensor] = None
        disc: Optional[List[torch.Tensor]] = None
        deterministic = self._deterministic
        # None-checks kept explicit to satisfy mypy's Optional narrowing
        if dists.continuous is not None:
            cont = (
                dists.continuous.deterministic_sample()
                if deterministic
                else dists.continuous.sample()
            )
        if dists.discrete is not None:
            if deterministic:
                disc = [branch.deterministic_sample() for branch in dists.discrete]
            else:
                disc = [branch.sample() for branch in dists.discrete]
        return AgentAction(cont, disc)
Example #3
0
def test_to_flat():
    """AgentAction.to_flat should concatenate the continuous values with a
    one-hot encoding of each discrete branch."""
    # Both continuous and discrete
    action = AgentAction(
        torch.tensor([[1.0, 1.0, 1.0]]), [torch.tensor([2]), torch.tensor([1])]
    )
    flat = action.to_flat([3, 3])
    expected = torch.tensor([[1, 1, 1, 0, 0, 1, 0, 1, 0]])
    assert torch.eq(flat, expected).all()

    # Continuous only
    action = AgentAction(torch.tensor([[1.0, 1.0, 1.0]]), None)
    flat = action.to_flat([])
    assert torch.eq(flat, torch.tensor([1, 1, 1])).all()

    # Discrete only
    action = AgentAction(torch.tensor([]), [torch.tensor([2]), torch.tensor([1])])
    flat = action.to_flat([3, 3])
    assert torch.eq(flat, torch.tensor([0, 0, 1, 0, 1, 0])).all()
Example #4
0
def test_slice():
    """Slicing an AgentAction should slice both the continuous tensor and
    every discrete branch along the batch dimension."""
    cont = torch.tensor([[1.0], [1.0], [1.0]])
    disc = [torch.tensor([2, 1, 0]), torch.tensor([1, 2, 0])]
    action = AgentAction(cont, disc)
    sliced = action.slice(0, 2)
    assert sliced.continuous_tensor.shape == (2, 1)
    assert sliced.discrete_tensor.shape == (2, 2)
Example #5
0
def test_multinetworkbody_lstm(with_actions):
    """Train a MultiAgentNetworkBody with LSTM memories toward a constant
    target and verify the encoding converges to it."""
    torch.manual_seed(0)
    num_obs = 4
    num_act = 2
    seq_len = 16
    n_agents = 3
    settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=12)
    )

    spec = ActionSpec(num_act, (num_act,) * num_act)
    body = MultiAgentNetworkBody(
        create_observation_specs_with_shapes([(num_obs,)]), settings, spec
    )
    optimizer = torch.optim.Adam(body.parameters(), lr=3e-4)
    obs_sample = [[0.1 * torch.ones((seq_len, num_obs))] for _ in range(n_agents)]
    # simulate baseline in POCA
    act_sample = [
        AgentAction(
            0.1 * torch.ones((seq_len, 2)),
            [0.1 * torch.ones(seq_len) for _ in range(num_act)],
        )
        for _ in range(n_agents - 1)
    ]

    for _ in range(300):
        if with_actions:
            encoded, _ = body(
                obs_only=obs_sample[:1],
                obs=obs_sample[1:],
                actions=act_sample,
                memories=torch.ones(1, 1, 12),
                sequence_length=seq_len,
            )
        else:
            encoded, _ = body(
                obs_only=obs_sample,
                obs=[],
                actions=[],
                memories=torch.ones(1, 1, 12),
                sequence_length=seq_len,
            )
        # Push every element of the encoding toward 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # After training, the outputs should all be close to the target
    for value in encoded.flatten().tolist():
        assert value == pytest.approx(1.0, abs=0.1)
Example #6
0
def test_multinetworkbody_visual(with_actions):
    """Train a MultiAgentNetworkBody on mixed vector + visual observations
    toward a constant target and verify the encoding converges to it."""
    torch.manual_seed(0)
    num_act = 2
    n_agents = 3
    num_obs = 4
    vis_shape = (84, 84, 3)
    settings = NetworkSettings()
    spec = ActionSpec(num_act, (num_act,) * num_act)
    body = MultiAgentNetworkBody(
        create_observation_specs_with_shapes([(num_obs,), vis_shape]), settings, spec
    )
    optimizer = torch.optim.Adam(body.parameters(), lr=3e-3)
    obs_sample = [
        [0.1 * torch.ones((1, num_obs)), 0.1 * torch.ones((1, 84, 84, 3))]
        for _ in range(n_agents)
    ]
    # simulate baseline in POCA
    act_sample = [
        AgentAction(
            0.1 * torch.ones((1, 2)), [0.1 * torch.ones(1) for _ in range(num_act)]
        )
        for _ in range(n_agents - 1)
    ]
    for _ in range(300):
        if with_actions:
            encoded, _ = body(
                obs_only=obs_sample[:1], obs=obs_sample[1:], actions=act_sample
            )
        else:
            encoded, _ = body(obs_only=obs_sample, obs=[], actions=[])

        assert encoded.shape == (1, settings.hidden_units)
        # Push every element of the encoding toward 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # After training, the outputs should all be close to the target
    for value in encoded.flatten().tolist():
        assert value == pytest.approx(1.0, abs=0.1)
def test_get_probs_and_entropy():
    """Check log-prob and entropy shapes/values from
    ActionModel._get_probs_and_entropy for a standard-normal continuous
    distribution and two identical categorical distributions."""
    inp_size = 4
    act_size = 2
    # The mask output of create_action_model is not used by this test.
    action_model, _ = create_action_model(inp_size, act_size)

    _continuous_dist = GaussianDistInstance(torch.zeros((1, 2)),
                                            torch.ones((1, 2)))
    # Peaked categorical: most mass on index 0, the remainder spread evenly.
    test_prob = torch.tensor([[1.0 - 0.1 * (act_size - 1)] + [0.1] *
                              (act_size - 1)])
    _discrete_dist_list = [
        CategoricalDistInstance(test_prob),
        CategoricalDistInstance(test_prob),
    ]
    dist_tuple = DistInstances(_continuous_dist, _discrete_dist_list)

    agent_action = AgentAction(torch.zeros(
        (1, 2)), [torch.tensor([0]), torch.tensor([1])])

    log_probs, entropies = action_model._get_probs_and_entropy(
        agent_action, dist_tuple)

    # One log-prob per continuous dimension, one per discrete branch.
    assert log_probs.continuous_tensor.shape == (1, 2)
    assert len(log_probs.discrete_list) == 2
    for _disc in log_probs.discrete_list:
        assert _disc.shape == (1, )
    assert len(log_probs.all_discrete_list) == 2
    for _disc in log_probs.all_discrete_list:
        assert _disc.shape == (1, 2)

    for clp in log_probs.continuous_tensor[0]:
        # Log density of a standard normal at its mean: -0.5*ln(2*pi) ~ -0.919
        assert clp == pytest.approx(-0.919, abs=0.01)

    # Index 0 carries more probability mass than index 1, so its log-prob
    # must be larger.
    assert log_probs.discrete_list[0] > log_probs.discrete_list[1]

    # Expected entropies: presumably Gaussian first, then the two
    # categorical branches — NOTE(review): confirm ordering against
    # _get_probs_and_entropy.
    for ent, val in zip(entropies[0], [1.4189, 0.6191, 0.6191]):
        assert ent == pytest.approx(val, abs=0.01)