Ejemplo n.º 1
0
def test_action_masking_continuous():
    """A continuous-action group spec must produce no action mask."""
    agent_count = 10
    obs_shapes = [(3, ), (4, )]
    spec = AgentGroupSpec(obs_shapes, ActionType.CONTINUOUS, 10)
    step_result = batched_step_result_from_proto(
        generate_list_agent_proto(agent_count, obs_shapes), spec
    )
    # Masks only make sense for discrete actions, so none is expected here.
    assert step_result.action_mask is None
Ejemplo n.º 2
0
def test_action_masking_discrete_1():
    """A discrete spec with one branch of size 10 yields a single (agents, 10) mask."""
    agent_count = 10
    obs_shapes = [(3, ), (4, )]
    spec = AgentGroupSpec(obs_shapes, ActionType.DISCRETE, (10, ))
    protos = generate_list_agent_proto(agent_count, obs_shapes)
    action_mask = batched_step_result_from_proto(protos, spec).action_mask

    # The (10,) spec declares one branch, so exactly one mask array is expected.
    assert isinstance(action_mask, list)
    assert len(action_mask) == 1

    first_branch = action_mask[0]
    assert first_branch.shape == (agent_count, 10)
    # The entry for the first agent / first action must be truthy.
    assert first_branch[0, 0]
Ejemplo n.º 3
0
def test_batched_step_result_from_proto():
    """The batched step result must mirror the generated agent protos."""
    agent_count = 10
    obs_shapes = [(3, ), (4, )]
    spec = AgentGroupSpec(obs_shapes, ActionType.CONTINUOUS, 3)
    protos = generate_list_agent_proto(agent_count, obs_shapes)
    step_result = batched_step_result_from_proto(protos, spec)

    # Rewards and agent ids are both expected to equal each agent's index.
    expected_ids = list(range(agent_count))
    assert list(step_result.reward) == expected_ids
    assert list(step_result.agent_id) == expected_ids

    # Even-indexed agents are expected done; odd-indexed ones at max step.
    for idx in range(agent_count):
        is_even = idx % 2 == 0
        assert step_result.done[idx] == is_even
        assert step_result.max_step[idx] == (not is_even)

    # Each observation gains a leading axis of size agent_count.
    for obs_idx, shape in enumerate(obs_shapes):
        assert list(step_result.obs[obs_idx].shape) == [agent_count, *shape]
Ejemplo n.º 4
0
 def _update_state(self, output: UnityRLOutputProto) -> None:
     """
     Collects experience information from all external brains in environment at current step.

     For every brain registered in ``self._env_specs``, the entry in
     ``self._env_state`` is replaced. If ``output.agentInfos`` has an entry for
     that brain, the new value is built with ``batched_step_result_from_proto``
     from that entry's ``.value``. Otherwise it is ``BatchedStepResult.empty``
     for the brain's spec. Side-channel data carried by ``output`` is then
     passed to ``self._parse_side_channel_message``.

     :param output: The RL output message received for the current step.
     """
     agent_infos = output.agentInfos
     # Iterate name/spec pairs directly instead of re-indexing the dict per key.
     for brain_name, group_spec in self._env_specs.items():
         if brain_name not in agent_infos:
             # output carries no agent infos for this brain at this step.
             self._env_state[brain_name] = BatchedStepResult.empty(group_spec)
             continue
         self._env_state[brain_name] = batched_step_result_from_proto(
             agent_infos[brain_name].value, group_spec
         )
     self._parse_side_channel_message(self.side_channels, output.side_channel)