def local_policy_inference(
    policy: "Policy",
    env_id: str,
    agent_id: str,
    obs: TensorStructType,
) -> TensorStructType:
    """Run a connector enabled policy using environment observation.

    policy_inference manages policy and agent/action connectors,
    so the user does not have to care about RNN state buffering or
    extra fetch dictionaries.
    Note that connectors are intentionally run separately from
    compute_actions_from_input_dict(), so we can have the option
    of running per-user connectors on the client side in a
    server-client deployment.

    Args:
        policy: Policy. Must have agent connectors attached; action
            connectors are optional.
        env_id: Environment ID.
        agent_id: Agent ID.
        obs: Env observation.

    Returns:
        List of outputs from policy forward pass, one entry per agent
        connector output (post-processed by action connectors when present).
    """
    # Agent connectors are mandatory here: they own RNN-state buffering
    # that compute_actions_from_input_dict() would otherwise need.
    assert (
        policy.agent_connectors
    ), "policy_inference only works with connector enabled policies."

    # TODO(jungong) : support multiple env, multiple agent inference.
    input_dict = {SampleBatch.NEXT_OBS: obs}
    acd_list: List[AgentConnectorDataType] = [
        AgentConnectorDataType(env_id, agent_id, input_dict)
    ]
    # Run the raw observation through the agent connector pipeline to get
    # a policy-ready input dict (e.g. with buffered RNN states attached).
    ac_outputs: List[AgentConnectorsOutput] = policy.agent_connectors(acd_list)
    outputs = []
    for ac in ac_outputs:
        policy_output = policy.compute_actions_from_input_dict(ac.data.for_action)

        if policy.action_connectors:
            # Action connectors consume and return the full
            # (actions, states, fetches) tuple.
            acd = ActionConnectorDataType(env_id, agent_id, policy_output)
            acd = policy.action_connectors(acd)
            actions = acd.output
        else:
            # No action connectors: return just the raw actions component.
            actions = policy_output[0]
        outputs.append(actions)

        # Notify agent connectors with this new policy output.
        # Necessary for state buffering agent connectors, for example.
        # NOTE: the *raw* policy_output (pre action-connector) is what gets
        # buffered, not the transformed actions.
        policy.agent_connectors.on_policy_output(
            ActionConnectorDataType(env_id, agent_id, policy_output)
        )
    return outputs
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
    """Unsquash the action component of a policy output tuple.

    Args:
        ac_data: Connector data whose ``output`` is a
            (actions, states, fetches) policy output tuple.

    Returns:
        A new ActionConnectorDataType with the actions unsquashed into the
        connector's action space; states and fetches pass through untouched.
    """
    payload = ac_data.output
    assert isinstance(payload, tuple), "Action connector requires PolicyOutputType data."

    raw_actions, rnn_states, extra_fetches = payload
    unsquashed = unsquash_action(raw_actions, self._action_space_struct)
    return ActionConnectorDataType(
        ac_data.env_id,
        ac_data.agent_id,
        (unsquashed, rnn_states, extra_fetches),
    )
def transform(
    self, ac_data: ActionConnectorDataType
) -> ActionConnectorDataType:
    """Apply a user-supplied function to an (actions, states, fetches) tuple.

    Args:
        ac_data: Connector data whose ``output`` must be a
            (actions, states, fetches) policy output tuple.

    Returns:
        A new ActionConnectorDataType whose output is whatever the wrapped
        function returns for the unpacked tuple.
    """
    assert isinstance(
        ac_data.output, tuple
    ), "Action connector requires PolicyOutputType data."

    actions, states, fetches = ac_data.output
    # NOTE(review): `fn` is not defined in this block; presumably it is a
    # user lambda captured from an enclosing registration helper's scope —
    # confirm against the surrounding file.
    return ActionConnectorDataType(
        ac_data.env_id,
        ac_data.agent_id,
        fn(actions, states, fetches),
    )
def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
    """Mark every action array in the policy output as immutable.

    Args:
        ac_data: Connector data whose ``output`` is a
            (actions, states, fetches) policy output tuple.

    Returns:
        A new ActionConnectorDataType carrying the same tuple; the action
        structure has been made read-only in place.
    """
    payload = ac_data.output
    assert isinstance(payload, tuple), "Action connector requires PolicyOutputType data."

    actions, states, fetches = payload
    # make_action_immutable mutates each leaf in place, so re-packing the
    # same `actions` structure below is intentional.
    tree.traverse(make_action_immutable, actions, top_down=False)
    return ActionConnectorDataType(
        ac_data.env_id,
        ac_data.agent_id,
        (actions, states, fetches),
    )
def test_normalize_action_connector(self):
    """NormalizeActionsConnector round-trips via config and unsquashes actions."""
    ctx = ConnectorContext(
        action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
    )
    connector = NormalizeActionsConnector(ctx)

    # Serialize and restore; the restored object keeps its concrete type.
    name, params = connector.to_config()
    self.assertEqual(name, "NormalizeActionsConnector")
    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, NormalizeActionsConnector))

    # A normalized action of 0.5 maps to 4.5 inside the [0, 6] box.
    result = connector(ActionConnectorDataType(0, 1, (0.5, [], {})))
    self.assertEqual(result.output[0], 4.5)
def test_clip_action_connector(self):
    """ClipActionsConnector round-trips via config and clips out-of-range actions."""
    ctx = ConnectorContext(
        action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
    )
    connector = ClipActionsConnector(ctx)

    # Serialize and restore; the restored object keeps its concrete type.
    name, params = connector.to_config()
    self.assertEqual(name, "ClipActionsConnector")
    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, ClipActionsConnector))

    # 8.8 exceeds the box's high bound, so it clips to 6.0.
    result = connector(ActionConnectorDataType(0, 1, (8.8, [], {})))
    self.assertEqual(result.output[0], 6.0)
def test_immutable_action_connector(self):
    """ImmutableActionsConnector round-trips via config and freezes action arrays."""
    ctx = ConnectorContext(
        action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
    )
    connector = ImmutableActionsConnector(ctx)

    # Serialize and restore; the restored object keeps its concrete type.
    name, params = connector.to_config()
    self.assertEqual(name, "ImmutableActionsConnector")
    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, ImmutableActionsConnector))

    frozen = connector(ActionConnectorDataType(0, 1, (np.array([8.8]), [], {})))

    # Writing into the returned array must now raise.
    with self.assertRaises(ValueError):
        frozen.output[0][0] = 5
def test_convert_to_numpy_connector(self):
    """ConvertToNumpyConnector round-trips via config and converts torch tensors."""
    ctx = ConnectorContext()
    connector = ConvertToNumpyConnector(ctx)

    # Serialize and restore; the restored object keeps its concrete type.
    name, params = connector.to_config()
    self.assertEqual(name, "ConvertToNumpyConnector")
    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, ConvertToNumpyConnector))

    # Both actions and states start out as torch tensors.
    ac_data = ActionConnectorDataType(
        0,
        1,
        (torch.Tensor([8, 9]), torch.Tensor([[1, 1, 1], [2, 2, 2]]), {}),
    )
    converted = connector(ac_data)

    # After conversion, both components are numpy arrays.
    self.assertTrue(isinstance(converted.output[0], np.ndarray))
    self.assertTrue(isinstance(converted.output[1], np.ndarray))
def test_unbatch_action_connector(self):
    """UnbatchActionsConnector round-trips via config and splits batched actions."""
    ctx = ConnectorContext()
    connector = UnbatchActionsConnector(ctx)

    # Serialize and restore; the restored object keeps its concrete type.
    name, params = connector.to_config()
    self.assertEqual(name, "UnbatchActionsConnector")
    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, UnbatchActionsConnector))

    # A batch of 3 structured actions: dict with a scalar leaf "a" and a
    # tuple leaf "b" of two components.
    batched = {
        "a": np.array([1, 2, 3]),
        "b": (np.array([4, 5, 6]), np.array([7, 8, 9])),
    }
    result = connector(ActionConnectorDataType(0, 1, (batched, [], {})))
    actions, _, _ = result.output

    # Each row keeps the structure but holds a single sample.
    self.assertEqual(len(actions), 3)
    for idx, (a_val, b_val) in enumerate(
        [(1, (4, 7)), (2, (5, 8)), (3, (6, 9))]
    ):
        self.assertEqual(actions[idx]["a"], a_val)
        self.assertTrue((actions[idx]["b"] == np.array(b_val)).all())
def _process_policy_eval_results(
    self,
    to_eval: Dict[PolicyID, List[_PolicyEvalData]],
    eval_results: Dict[PolicyID, PolicyOutputType],
    off_policy_actions: MultiEnvDict,
):
    """Process the output of policy neural network evaluation.

    Records policy evaluation results into agent connectors and
    returns replies to send back to agents in the env.

    Args:
        to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects.
        eval_results: Mapping of policy IDs to list of
            actions, rnn-out states, extra-action-fetches dicts.
        off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
            off-policy-action, returned by a `BaseEnv.poll()` call.

    Returns:
        Nested dict of env id -> agent id -> actions to be sent to
        Env (np.ndarrays).
    """
    actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)

    # Pre-seed every polled env so envs whose agents produced no actions
    # still receive a reply this step.
    for eval_data in to_eval.values():
        for d in eval_data:
            actions_to_send[d.env_id] = {}  # at minimum send empty dict

    # types: PolicyID, List[_PolicyEvalData]
    for policy_id, eval_data in to_eval.items():
        # eval_results[policy_id] is a (actions, rnn_out, fetches) triple.
        actions: TensorStructType = eval_results[policy_id][0]
        actions = convert_to_numpy(actions)

        rnn_out: StateBatches = eval_results[policy_id][1]
        extra_action_out: dict = eval_results[policy_id][2]

        # In case actions is a list (representing the 0th dim of a batch of
        # primitive actions), try converting it first.
        if isinstance(actions, list):
            actions = np.array(actions)

        # Split action-component batches into single action rows.
        actions: List[EnvActionType] = unbatch(actions)

        policy: Policy = _get_or_raise(self._worker.policy_map, policy_id)
        # Both connector pipelines are required: action connectors
        # post-process each row below, agent connectors get notified of it.
        assert (
            policy.agent_connectors and policy.action_connectors
        ), "EnvRunnerV2 requires action connectors to work."

        # types: int, EnvActionType
        for i, action in enumerate(actions):
            # eval_data rows are index-aligned with the unbatched actions.
            env_id: int = eval_data[i].env_id
            agent_id: AgentID = eval_data[i].agent_id

            # Slice out this row's RNN states and extra fetches.
            rnn_states: List[StateBatches] = [c[i] for c in rnn_out]
            fetches: Dict = {k: v[i] for k, v in extra_action_out.items()}

            # Post-process policy output by running them through action connectors.
            ac_data = ActionConnectorDataType(
                env_id, agent_id, (action, rnn_states, fetches)
            )
            action_to_send, rnn_states, fetches = policy.action_connectors(
                ac_data
            ).output

            # If the env already acted off-policy for this agent, buffer the
            # env's action instead of the policy's, so connector state
            # reflects what actually happened in the env.
            action_to_buffer = (
                action_to_send
                if env_id not in off_policy_actions
                or agent_id not in off_policy_actions[env_id]
                else off_policy_actions[env_id][agent_id]
            )

            # Notify agent connectors with this new policy output.
            # Necessary for state buffering agent connectors, for example.
            ac_data: AgentConnectorDataType = ActionConnectorDataType(
                env_id, agent_id, (action_to_buffer, rnn_states, fetches)
            )
            policy.agent_connectors.on_policy_output(ac_data)

            # Each (env, agent) pair must appear at most once per step.
            assert agent_id not in actions_to_send[env_id]
            actions_to_send[env_id][agent_id] = action_to_send

    return actions_to_send