Ejemplo n.º 1
0
 def _generate_step_input(
     self, vector_action: Dict[str, ActionTuple]
 ) -> UnityInputProto:
     """Assemble the step request holding one action per agent and behavior.

     For every behavior name in ``vector_action`` the agent count is read as
     the length of the first element of ``self._env_state[name]``. Behaviors
     with zero agents are skipped entirely. Each agent's continuous and/or
     discrete action row is written both to its dedicated proto field and to
     the deprecated combined ``vector_actions_deprecated`` field (continuous
     values first, then discrete).

     :param vector_action: Mapping from behavior name to the actions for
         that behavior; the ``continuous`` and ``discrete`` members may each
         be ``None``.
     :return: The result of ``self._wrap_unity_input`` applied to the
         populated ``UnityRLInputProto``.
     """
     step_request = UnityRLInputProto()
     for behavior_name, behavior_actions in vector_action.items():
         agent_count = len(self._env_state[behavior_name][0])
         if agent_count == 0:
             continue

         # The command is only marked as STEP when at least one agent
         # actually receives an action.
         step_request.command = STEP

         continuous = behavior_actions.continuous
         discrete = behavior_actions.discrete
         for agent_index in range(agent_count):
             agent_action = AgentActionProto()
             if continuous is not None:
                 continuous_row = continuous[agent_index]
                 agent_action.vector_actions_deprecated.extend(continuous_row)
                 agent_action.continuous_actions.extend(continuous_row)
             if discrete is not None:
                 discrete_row = discrete[agent_index]
                 agent_action.vector_actions_deprecated.extend(discrete_row)
                 agent_action.discrete_actions.extend(discrete_row)
             step_request.agent_actions[behavior_name].value.extend(
                 [agent_action]
             )

     # Queued side-channel messages are sent along with the step request.
     pending_messages = self._side_channel_manager.generate_side_channel_messages()
     step_request.side_channel = bytes(pending_messages)
     return self._wrap_unity_input(step_request)
def proto_from_steps_and_action(
    decision_steps: DecisionSteps,
    terminal_steps: TerminalSteps,
    continuous_actions: np.ndarray,
    discrete_actions: np.ndarray,
) -> List[AgentInfoActionPairProto]:
    """Pair each agent-info proto with an action proto built from the arrays.

    The agent-info protos come from ``proto_from_steps``. One
    ``AgentActionProto`` is created per row of ``continuous_actions`` (or of
    ``discrete_actions`` when ``continuous_actions`` is ``None``). Each row is
    copied to its dedicated field and also appended to the deprecated
    combined ``vector_actions_deprecated`` field, continuous values first.

    :param decision_steps: Passed through to ``proto_from_steps``.
    :param terminal_steps: Passed through to ``proto_from_steps``.
    :param continuous_actions: Per-agent continuous action rows, or ``None``.
    :param discrete_actions: Per-agent discrete action rows, or ``None``.
    :return: Info/action pairs, zipped positionally; ``zip`` stops at the
        shorter of the two sequences.
    :raises TypeError: If both action arguments are ``None`` (``len(None)``).
    """
    info_protos = proto_from_steps(decision_steps, terminal_steps)

    has_continuous = continuous_actions is not None
    has_discrete = discrete_actions is not None
    if has_continuous:
        agent_count = len(continuous_actions)
    else:
        agent_count = len(discrete_actions)

    action_protos = []
    for agent_index in range(agent_count):
        action_proto = AgentActionProto()
        if has_continuous:
            continuous_row = continuous_actions[agent_index]
            action_proto.continuous_actions.extend(continuous_row)
            action_proto.vector_actions_deprecated.extend(continuous_row)
        if has_discrete:
            discrete_row = discrete_actions[agent_index]
            action_proto.discrete_actions.extend(discrete_row)
            action_proto.vector_actions_deprecated.extend(discrete_row)
        action_protos.append(action_proto)

    pairs = []
    for info_proto, action_proto in zip(info_protos, action_protos):
        pairs.append(
            AgentInfoActionPairProto(
                agent_info=info_proto, action_info=action_proto
            )
        )
    return pairs
Ejemplo n.º 3
0
 def _generate_step_input(
         self, vector_action: Dict[str, np.ndarray]) -> UnityInputProto:
     """Assemble the step request holding one action per agent and behavior.

     For every behavior name in ``vector_action`` the agent count is read as
     the length of the first element of ``self._env_state[name]``; behaviors
     with zero agents are skipped. Row ``i`` of the behavior's action array
     becomes the ``vector_actions`` of agent ``i``.

     :param vector_action: Mapping from behavior name to that behavior's
         action array, indexed by agent.
     :return: The result of ``self.wrap_unity_input`` applied to the
         populated ``UnityRLInputProto``.
     """
     step_request = UnityRLInputProto()
     for behavior_name, behavior_actions in vector_action.items():
         agent_count = len(self._env_state[behavior_name][0])
         if agent_count == 0:
             continue

         # The command is only marked as STEP when at least one agent
         # actually receives an action.
         step_request.command = STEP

         for agent_index in range(agent_count):
             agent_action = AgentActionProto(
                 vector_actions=behavior_actions[agent_index]
             )
             step_request.agent_actions[behavior_name].value.extend(
                 [agent_action]
             )

     # Side-channel data is sent along with the step request.
     side_channel_payload = self._generate_side_channel_data(self.side_channels)
     step_request.side_channel = bytes(side_channel_payload)
     return self.wrap_unity_input(step_request)
Ejemplo n.º 4
0
def proto_from_steps_and_action(
        decision_steps: DecisionSteps, terminal_steps: TerminalSteps,
        actions: np.ndarray) -> List[AgentInfoActionPairProto]:
    """Pair each agent-info proto with an action proto built from ``actions``.

    The agent-info protos come from ``proto_from_steps``; one
    ``AgentActionProto`` is created per element of ``actions`` with that
    element as its ``vector_actions``.

    :param decision_steps: Passed through to ``proto_from_steps``.
    :param terminal_steps: Passed through to ``proto_from_steps``.
    :param actions: Iterable of per-agent action vectors.
    :return: Info/action pairs, zipped positionally; ``zip`` stops at the
        shorter of the two sequences.
    """
    info_protos = proto_from_steps(decision_steps, terminal_steps)

    action_protos = []
    for agent_action in actions:
        action_protos.append(AgentActionProto(vector_actions=agent_action))

    pairs = []
    for info_proto, action_proto in zip(info_protos, action_protos):
        pairs.append(
            AgentInfoActionPairProto(
                agent_info=info_proto, action_info=action_proto
            )
        )
    return pairs
Ejemplo n.º 5
0
def proto_from_batched_step_result_and_action(
    batched_step_result: BatchedStepResult, actions: np.ndarray
) -> List[AgentInfoActionPairProto]:
    """Pair each agent-info proto with an action proto built from ``actions``.

    The agent-info protos come from ``proto_from_batched_step_result``; one
    ``AgentActionProto`` is created per element of ``actions`` with that
    element as its ``vector_actions``.

    :param batched_step_result: Passed through to
        ``proto_from_batched_step_result``.
    :param actions: Iterable of per-agent action vectors.
    :return: Info/action pairs, zipped positionally; ``zip`` stops at the
        shorter of the two sequences.
    """
    info_protos = proto_from_batched_step_result(batched_step_result)

    action_protos = []
    for agent_action in actions:
        action_protos.append(AgentActionProto(vector_actions=agent_action))

    pairs = []
    for info_proto, action_proto in zip(info_protos, action_protos):
        pairs.append(
            AgentInfoActionPairProto(
                agent_info=info_proto, action_info=action_proto
            )
        )
    return pairs