Пример #1
0
def local_policy_inference(
    policy: "Policy",
    env_id: str,
    agent_id: str,
    obs: TensorStructType,
) -> TensorStructType:
    """Compute actions for one observation with a connector-enabled policy.

    The observation is wrapped for the policy's agent connectors, each
    connector output is fed to `compute_actions_from_input_dict()`, and
    the result is optionally passed through the policy's action
    connectors. After every forward pass the agent connectors are
    notified of the raw policy output via `on_policy_output()`.

    Args:
        policy: Policy to run. Its `agent_connectors` must be truthy.
        env_id: Environment ID attached to the connector data.
        agent_id: Agent ID attached to the connector data.
        obs: Env observation.

    Returns:
        A list with one entry per agent connector output. If the policy
        has action connectors, an entry is the `output` of the action
        connector result; otherwise it is element 0 of the policy output.

    Raises:
        AssertionError: If `policy.agent_connectors` is falsy.
    """
    assert (
        policy.agent_connectors
    ), "policy_inference only works with connector enabled policies."

    # TODO(jungong) : support multiple env, multiple agent inference.
    # The observation is handed to the agent connectors under the
    # NEXT_OBS key.
    connector_inputs = [
        AgentConnectorDataType(env_id, agent_id, {SampleBatch.NEXT_OBS: obs})
    ]

    results = []
    for connector_out in policy.agent_connectors(connector_inputs):
        raw_output = policy.compute_actions_from_input_dict(
            connector_out.data.for_action
        )

        if not policy.action_connectors:
            # No post-processing configured: keep only the first element
            # of the policy output.
            chosen = raw_output[0]
        else:
            wrapped = ActionConnectorDataType(env_id, agent_id, raw_output)
            chosen = policy.action_connectors(wrapped).output
        results.append(chosen)

        # Agent connectors always receive the raw (un-post-processed)
        # policy output, e.g. so state buffering connectors can record it.
        policy.agent_connectors.on_policy_output(
            ActionConnectorDataType(env_id, agent_id, raw_output)
        )
    return results
Пример #2
0
    def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        """Run the actions of a policy output through `unsquash_action`.

        Args:
            ac_data: Connector data whose `output` is an
                `(actions, states, fetches)` tuple.

        Returns:
            A new ActionConnectorDataType carrying the same env and agent
            ids. Only the actions element is replaced, by the result of
            `unsquash_action(actions, self._action_space_struct)`; states
            and fetches are passed through untouched.

        Raises:
            AssertionError: If `ac_data.output` is not a tuple.
        """
        policy_output = ac_data.output
        assert isinstance(
            policy_output, tuple
        ), "Action connector requires PolicyOutputType data."

        raw_actions, rnn_states, extra_fetches = policy_output
        env_id, agent_id = ac_data.env_id, ac_data.agent_id
        converted = unsquash_action(raw_actions, self._action_space_struct)
        return ActionConnectorDataType(
            env_id, agent_id, (converted, rnn_states, extra_fetches)
        )
Пример #3
0
        def transform(
            self, ac_data: ActionConnectorDataType
        ) -> ActionConnectorDataType:
            """Replace the policy output with the result of `fn`.

            `fn` is taken from the enclosing scope (not shown in this
            block) and is called with the unpacked actions, states and
            fetches.

            Args:
                ac_data: Connector data whose `output` is an
                    `(actions, states, fetches)` tuple.

            Returns:
                A new ActionConnectorDataType with the same env and agent
                ids and whatever `fn` returned as its output.

            Raises:
                AssertionError: If `ac_data.output` is not a tuple.
            """
            policy_output = ac_data.output
            assert isinstance(
                policy_output, tuple
            ), "Action connector requires PolicyOutputType data."

            acts, rnn_states, extras = policy_output
            env_id, agent_id = ac_data.env_id, ac_data.agent_id
            transformed = fn(acts, rnn_states, extras)
            return ActionConnectorDataType(env_id, agent_id, transformed)
Пример #4
0
    def transform(self,
                  ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        """Apply `make_action_immutable` across the action structure.

        Args:
            ac_data: Connector data whose `output` is an
                `(actions, states, fetches)` tuple.

        Returns:
            A new ActionConnectorDataType with the same env and agent ids
            and the same actions, states and fetches objects.

        Raises:
            AssertionError: If `ac_data.output` is not a tuple.
        """
        policy_output = ac_data.output
        assert isinstance(
            policy_output,
            tuple), "Action connector requires PolicyOutputType data."

        acts, rnn_states, extras = policy_output
        # The traversal's return value is discarded; the very same `acts`
        # object is forwarded below, so any effect of
        # make_action_immutable has to happen in place.
        tree.traverse(make_action_immutable, acts, top_down=False)

        repacked = (acts, rnn_states, extras)
        return ActionConnectorDataType(
            ac_data.env_id, ac_data.agent_id, repacked)
Пример #5
0
    def test_normalize_action_connector(self):
        """Check config round-trip and output of NormalizeActionsConnector."""
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1]))
        c = NormalizeActionsConnector(ctx)

        # The connector must report its own class name in its config ...
        name, params = c.to_config()
        self.assertEqual(name, "NormalizeActionsConnector")

        # ... and that config must rebuild a connector of the same type.
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, NormalizeActionsConnector)

        # Arguments: env id 0, agent id 1, (action, states, fetches).
        ac_data = ActionConnectorDataType(0, 1, (0.5, [], {}))

        # With a Box of low=0.0 and high=6.0, action 0.5 is expected to
        # come out as 4.5.
        normalized = c(ac_data)
        self.assertEqual(normalized.output[0], 4.5)
Пример #6
0
    def test_clip_action_connector(self):
        """Check config round-trip and output of ClipActionsConnector."""
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1]))
        c = ClipActionsConnector(ctx)

        # The connector must report its own class name in its config ...
        name, params = c.to_config()
        self.assertEqual(name, "ClipActionsConnector")

        # ... and that config must rebuild a connector of the same type.
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ClipActionsConnector)

        # Arguments: env id 0, agent id 1, (action, states, fetches).
        # 8.8 lies above the Box's upper bound of 6.0.
        ac_data = ActionConnectorDataType(0, 1, (8.8, [], {}))

        # The out-of-range action is expected to come back as 6.0.
        clipped = c(ac_data)
        self.assertEqual(clipped.output[0], 6.0)
Пример #7
0
    def test_immutable_action_connector(self):
        """Check config round-trip and effect of ImmutableActionsConnector."""
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1]))
        c = ImmutableActionsConnector(ctx)

        # The connector must report its own class name in its config ...
        name, params = c.to_config()
        self.assertEqual(name, "ImmutableActionsConnector")

        # ... and that config must rebuild a connector of the same type.
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ImmutableActionsConnector)

        # Arguments: env id 0, agent id 1, (action, states, fetches).
        ac_data = ActionConnectorDataType(0, 1, (np.array([8.8]), [], {}))

        immutable = c(ac_data)

        # Writing into the returned action array must be rejected.
        with self.assertRaises(ValueError):
            immutable.output[0][0] = 5
Пример #8
0
    def test_convert_to_numpy_connector(self):
        """Check config round-trip and output of ConvertToNumpyConnector."""
        ctx = ConnectorContext()
        c = ConvertToNumpyConnector(ctx)

        # The connector must report its own class name in its config ...
        name, params = c.to_config()

        self.assertEqual(name, "ConvertToNumpyConnector")

        # ... and that config must rebuild a connector of the same type.
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ConvertToNumpyConnector)

        # Both the action and the states are torch tensors on input.
        action = torch.Tensor([8, 9])
        states = torch.Tensor([[1, 1, 1], [2, 2, 2]])
        ac_data = ActionConnectorDataType(0, 1, (action, states, {}))

        # After the connector, both must be numpy arrays.
        converted = c(ac_data)
        self.assertIsInstance(converted.output[0], np.ndarray)
        self.assertIsInstance(converted.output[1], np.ndarray)
Пример #9
0
    def test_unbatch_action_connector(self):
        """Check config round-trip and output of UnbatchActionsConnector."""
        ctx = ConnectorContext()
        c = UnbatchActionsConnector(ctx)

        # The connector must report its own class name in its config ...
        name, params = c.to_config()

        self.assertEqual(name, "UnbatchActionsConnector")

        # ... and that config must rebuild a connector of the same type.
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, UnbatchActionsConnector)

        # A nested action struct in which every leaf holds three values.
        ac_data = ActionConnectorDataType(
            0,
            1,
            (
                {
                    "a": np.array([1, 2, 3]),
                    "b": (np.array([4, 5, 6]), np.array([7, 8, 9])),
                },
                [],
                {},
            ),
        )

        unbatched = c(ac_data)
        actions, _, _ = unbatched.output

        # Expect three per-row structs; row i holds the i-th value of
        # every leaf of the batched input.
        self.assertEqual(len(actions), 3)
        self.assertEqual(actions[0]["a"], 1)
        self.assertTrue((actions[0]["b"] == np.array((4, 7))).all())

        self.assertEqual(actions[1]["a"], 2)
        self.assertTrue((actions[1]["b"] == np.array((5, 8))).all())

        self.assertEqual(actions[2]["a"], 3)
        self.assertTrue((actions[2]["b"] == np.array((6, 9))).all())
Пример #10
0
    def _process_policy_eval_results(
        self,
        to_eval: Dict[PolicyID, List[_PolicyEvalData]],
        eval_results: Dict[PolicyID, PolicyOutputType],
        off_policy_actions: MultiEnvDict,
    ):
        """Process the output of policy neural network evaluation.

        Runs every per-agent policy output through the policy's action
        connectors, records the result into the agent connectors, and
        returns replies to send back to agents in the env.

        Args:
            to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects.
            eval_results: Mapping of policy IDs to a sequence of
                (actions, rnn-out states, extra-action-fetches dict).
            off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
                off-policy-action, returned by a `BaseEnv.poll()` call.

        Returns:
            Nested dict of env id -> agent id -> actions to be sent to
            Env (np.ndarrays). Every env id that appears in `to_eval` is
            present, with at least an empty dict.

        Raises:
            AssertionError: If a policy lacks agent or action connectors,
                or if the same (env id, agent id) pair is produced twice.
        """
        actions_to_send: Dict[EnvID, Dict[AgentID,
                                          EnvActionType]] = defaultdict(dict)
        for eval_data in to_eval.values():
            for d in eval_data:
                actions_to_send[d.env_id] = {}  # at minimum send empty dict

        # types: PolicyID, List[_PolicyEvalData]
        for policy_id, eval_data in to_eval.items():
            actions: TensorStructType = eval_results[policy_id][0]
            actions = convert_to_numpy(actions)

            rnn_out: StateBatches = eval_results[policy_id][1]
            extra_action_out: dict = eval_results[policy_id][2]

            # In case actions is a list (representing the 0th dim of a batch of
            # primitive actions), try converting it first.
            if isinstance(actions, list):
                actions = np.array(actions)
            # Split action-component batches into single action rows.
            actions: List[EnvActionType] = unbatch(actions)

            policy: Policy = _get_or_raise(self._worker.policy_map, policy_id)
            assert (policy.agent_connectors and policy.action_connectors
                    ), "EnvRunnerV2 requires action connectors to work."

            # types: int, EnvActionType
            for i, action in enumerate(actions):
                # Row i of the batch belongs to the i-th eval data entry.
                env_id: int = eval_data[i].env_id
                agent_id: AgentID = eval_data[i].agent_id

                # Slice out row i of every state batch and every fetch.
                rnn_states: List[StateBatches] = [c[i] for c in rnn_out]
                fetches: Dict = {k: v[i] for k, v in extra_action_out.items()}

                # Post-process policy output by running them through action connectors.
                ac_data = ActionConnectorDataType(
                    env_id, agent_id, (action, rnn_states, fetches))
                action_to_send, rnn_states, fetches = policy.action_connectors(
                    ac_data).output

                # An off-policy action for this env/agent, if present, is
                # what gets buffered; the env still receives the policy's
                # own (post-processed) action.
                if (env_id in off_policy_actions
                        and agent_id in off_policy_actions[env_id]):
                    action_to_buffer = off_policy_actions[env_id][agent_id]
                else:
                    action_to_buffer = action_to_send

                # Notify agent connectors with this new policy output.
                # Necessary for state buffering agent connectors, for example.
                policy_output_data: ActionConnectorDataType = (
                    ActionConnectorDataType(
                        env_id, agent_id,
                        (action_to_buffer, rnn_states, fetches)))
                policy.agent_connectors.on_policy_output(policy_output_data)

                # Each agent may receive at most one action per env.
                assert agent_id not in actions_to_send[env_id]
                actions_to_send[env_id][agent_id] = action_to_send

        return actions_to_send