Example 1
    def test_view_requirement_connector(self):
        view_requirements = {
            "obs":
            ViewRequirement(used_for_training=True,
                            used_for_compute_actions=True),
            "prev_actions":
            ViewRequirement(
                data_col="actions",
                shift=-1,
                used_for_training=True,
                used_for_compute_actions=True,
            ),
        }
        ctx = ConnectorContext(view_requirements=view_requirements)

        c = ViewRequirementAgentConnector(ctx)
        f = FlattenDataAgentConnector(ctx)

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.NEXT_OBS: {
                    "sensor1": [[1, 1], [2, 2]],
                    "sensor2": 8.8,
                },
                SampleBatch.ACTIONS: np.array(0),
            },
        )
        # Run ViewRequirementAgentConnector, then FlattenDataAgentConnector.
        processed = f(c([d]))

        self.assertTrue("obs" in processed[0].data.for_action)
        self.assertTrue("prev_actions" in processed[0].data.for_action)
Example 2
    def test_flatten_data_connector(self):
        ctx = ConnectorContext()

        c = FlattenDataAgentConnector(ctx)

        name, params = c.to_config()
        restored = get_connector(ctx, name, params)
        self.assertTrue(isinstance(restored, FlattenDataAgentConnector))

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.NEXT_OBS: {
                    "sensor1": [[1, 1], [2, 2]],
                    "sensor2": 8.8,
                },
                SampleBatch.REWARDS: 5.8,
                SampleBatch.ACTIONS: [[1, 1], [2]],
                SampleBatch.INFOS: {
                    "random": "info"
                },
            },
        )

        flattened = c(d)
        self.assertEqual(len(flattened), 1)

        batch = flattened[0].data
        self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2,
                                                         8.8]).all())
        self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
        # Not flattened.
        self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
        self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")
Example 3
    def test_obs_preprocessor_connector(self):
        obs_space = gym.spaces.Dict(
            {
                "a": gym.spaces.Box(low=0, high=1, shape=(1,)),
                "b": gym.spaces.Tuple(
                    [gym.spaces.Discrete(2), gym.spaces.MultiDiscrete(nvec=[2, 3])]
                ),
            }
        )
        ctx = ConnectorContext(config={}, observation_space=obs_space)

        c = ObsPreprocessorConnector(ctx)
        name, params = c.to_config()

        restored = get_connector(ctx, name, params)
        self.assertTrue(isinstance(restored, ObsPreprocessorConnector))

        obs = obs_space.sample()
        # Fake deterministic data.
        obs["a"][0] = 0.5
        obs["b"] = (1, np.array([0, 2]))

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.OBS: obs,
            },
        )
        preprocessed = c([d])

        # obs is completely flattened.
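        # Expected layout (assuming the default Dict preprocessing: Box values
        # kept as-is, Discrete/MultiDiscrete one-hot encoded):
        #   [0.5] from "a", [0, 1] from Discrete(2)=1,
        #   [1, 0] and [0, 0, 1] from MultiDiscrete([2, 3])=[0, 2].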
        self.assertTrue(
            (preprocessed[0].data[SampleBatch.OBS] == [0.5, 0, 1, 1, 0, 0, 0, 1]).all()
        )
Example 4
    def __call__(
            self, ac_data: AgentConnectorDataType
    ) -> List[AgentConnectorDataType]:
        d = ac_data.data
        return [
            AgentConnectorDataType(ac_data.env_id, ac_data.agent_id, fn(d))
        ]
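
The `fn` referenced above is a free variable closed over by an enclosing factory that turns a plain function into an agent connector class. Below is a minimal sketch of how such a lambda connector might be created; the `register_lambda_agent_connector` helper and its import path are assumptions inferred from this API, not verified here:

    # Sketch only: wrap a plain per-agent transformation function into an
    # agent connector whose __call__ looks like the one above.
    # NOTE: the helper name and import path are assumptions.
    from ray.rllib.connectors.agent.lambdas import register_lambda_agent_connector

    def double_rewards(data):
        # `data` is the per-agent input dict, e.g. {SampleBatch.REWARDS: 1.0, ...}.
        return {k: (v * 2.0 if k == "rewards" else v) for k, v in data.items()}

    DoubleRewardAgentConnector = register_lambda_agent_connector(
        "DoubleRewardAgentConnector", double_rewards)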
Example 5
def local_policy_inference(
    policy: "Policy",
    env_id: str,
    agent_id: str,
    obs: TensorStructType,
) -> TensorStructType:
    """Run a connector enabled policy using environment observation.

    policy_inference manages policy and agent/action connectors,
    so the user does not have to care about RNN state buffering or
    extra fetch dictionaries.
    Note that connectors are intentionally run separately from
    compute_actions_from_input_dict(), so we can have the option
    of running per-user connectors on the client side in a
    server-client deployment.

    Args:
        policy: Policy.
        env_id: Environment ID.
        agent_id: Agent ID.
        obs: Env observation.

    Returns:
        List of outputs from policy forward pass.
    """
    assert (policy.agent_connectors
            ), "local_policy_inference only works with connector-enabled policies."

    # TODO(jungong) : support multiple env, multiple agent inference.
    input_dict = {SampleBatch.NEXT_OBS: obs}
    acd_list: List[AgentConnectorDataType] = [
        AgentConnectorDataType(env_id, agent_id, input_dict)
    ]
    ac_outputs: List[AgentConnectorDataType] = policy.agent_connectors(acd_list)
    outputs = []
    for ac in ac_outputs:
        policy_output = policy.compute_actions_from_input_dict(
            ac.data.for_action)

        if policy.action_connectors:
            acd = ActionConnectorDataType(env_id, agent_id, policy_output)
            acd = policy.action_connectors(acd)
            actions = acd.output
        else:
            actions = policy_output[0]

        outputs.append(actions)

        # Notify agent connectors with this new policy output.
        # Necessary for state buffering agent connectors, for example.
        policy.agent_connectors.on_policy_output(
            ActionConnectorDataType(env_id, agent_id, policy_output))
    return outputs
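
A minimal usage sketch for local_policy_inference (not part of the example above; the CartPole env and the assumption that `policy` was restored elsewhere are for illustration only):

    # Hypothetical single-agent driver loop. `policy` is assumed to be a
    # connector-enabled RLlib Policy restored elsewhere (e.g. from a checkpoint).
    import gym

    env = gym.make("CartPole-v1")
    obs = env.reset()
    done = False
    while not done:
        outputs = local_policy_inference(
            policy, env_id="env_0", agent_id="agent_0", obs=obs)
        # Each entry of `outputs` holds a batch of actions; step with the first.
        obs, reward, done, info = env.step(outputs[0][0])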
Example 6
    def __call__(
            self,
            ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
        if ac_data.agent_id is not None:
            # Data is already for a single agent.
            return [ac_data]

        assert isinstance(ac_data.data, (tuple, list)) and len(
            ac_data.data) == 5, (
                "EnvToPerAgentDataConnector expects a tuple of " +
                "(obs, rewards, dones, infos, episode_infos).")
        # episode_infos contains additional training related data bits
        # for each agent, such as SampleBatch.T, SampleBatch.AGENT_INDEX,
        # SampleBatch.ACTIONS, SampleBatch.DONES (if hitting horizon),
        # and is usually empty in inference mode.
        obs, rewards, dones, infos, training_episode_infos = ac_data.data
        for var, name in zip(
            (obs, rewards, dones, infos, training_episode_infos),
            ("obs", "rewards", "dones", "infos", "training_episode_infos"),
        ):
            assert isinstance(
                var, dict), (f"EnvToPerAgentDataConnector expects {name} " +
                             "to be a MultiAgentDict.")

        env_id = ac_data.env_id
        per_agent_data = []
        for agent_id, agent_obs in obs.items():
            input_dict = {
                SampleBatch.ENV_ID:
                env_id,
                SampleBatch.REWARDS:
                rewards[agent_id],
                # SampleBatch.DONES may be overridden by data from
                # training_episode_infos next.
                SampleBatch.DONES:
                dones[agent_id],
                SampleBatch.NEXT_OBS:
                agent_obs,
            }
            if SampleBatch.INFOS in self._view_requirements:
                input_dict[SampleBatch.INFOS] = infos[agent_id]
            if agent_id in training_episode_infos:
                input_dict.update(training_episode_infos[agent_id])

            per_agent_data.append(
                AgentConnectorDataType(env_id, agent_id, input_dict))

        return per_agent_data
Example 7
    def test_flatten_data_connector(self):
        ctx = ConnectorContext()

        c = FlattenDataAgentConnector(ctx)

        name, params = c.to_config()
        restored = get_connector(ctx, name, params)
        self.assertTrue(isinstance(restored, FlattenDataAgentConnector))

        sample_batch = {
            SampleBatch.NEXT_OBS: {
                "sensor1": [[1, 1], [2, 2]],
                "sensor2": 8.8,
            },
            SampleBatch.REWARDS: 5.8,
            SampleBatch.ACTIONS: [[1, 1], [2]],
            SampleBatch.INFOS: {
                "random": "info"
            },
        }

        d = AgentConnectorDataType(
            0,
            1,
            # FlattenDataAgentConnector does NOT touch for_training dict,
            # so simply pass None here.
            AgentConnectorsOutput(None, sample_batch),
        )

        flattened = c([d])
        self.assertEqual(len(flattened), 1)

        batch = flattened[0].data.for_action
        self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2,
                                                         8.8]).all())
        self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
        # Not flattened.
        self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
        self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")
Example 8
    def test_clip_reward_connector(self):
        ctx = ConnectorContext()

        c = ClipRewardAgentConnector(ctx, limit=2.0)
        name, params = c.to_config()

        self.assertEqual(name, "ClipRewardAgentConnector")
        self.assertAlmostEqual(params["limit"], 2.0)

        restored = get_connector(ctx, name, params)
        self.assertTrue(isinstance(restored, ClipRewardAgentConnector))

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.REWARDS: 5.8,
            },
        )
        clipped = restored(ac_data=d)

        self.assertEqual(len(clipped), 1)
        self.assertEqual(clipped[0].data[SampleBatch.REWARDS], 2.0)
Example 9
    def test_env_to_per_agent_data_connector(self):
        vrs = {
            "infos":
            ViewRequirement(
                "infos",
                used_for_training=True,
                used_for_compute_actions=False,
            )
        }
        ctx = ConnectorContext(view_requirements=vrs)

        c = EnvToAgentDataConnector(ctx)

        name, params = c.to_config()
        restored = get_connector(ctx, name, params)
        self.assertTrue(isinstance(restored, EnvToAgentDataConnector))

        d = AgentConnectorDataType(
            0,
            None,
            [
                # obs
                {
                    1: [8, 8],
                    2: [9, 9]
                },
                # rewards
                {
                    1: 8.8,
                    2: 9.9,
                },
                # dones
                {
                    1: False,
                    2: False,
                },
                # infos
                {
                    1: {
                        "random": "info"
                    },
                    2: {},
                },
                # training_episode_info
                {
                    1: {
                        SampleBatch.DONES: True
                    },
                },
            ],
        )
        per_agent = c(d)

        self.assertEqual(len(per_agent), 2)

        batch1 = per_agent[0].data
        self.assertEqual(batch1[SampleBatch.NEXT_OBS], [8, 8])
        self.assertTrue(
            batch1[SampleBatch.DONES])  # from training_episode_info
        self.assertTrue(SampleBatch.INFOS in batch1)
        self.assertEqual(batch1[SampleBatch.INFOS]["random"], "info")

        batch2 = per_agent[1].data
        self.assertEqual(batch2[SampleBatch.NEXT_OBS], [9, 9])
        self.assertFalse(batch2[SampleBatch.DONES])
Example 10
    def _handle_done_episode(
        self,
        env_id: EnvID,
        env_obs: MultiAgentDict,
        is_done: bool,
        hit_horizon: bool,
        to_eval: Dict[PolicyID, List[_PolicyEvalData]],
        outputs: List[SampleBatchType],
    ) -> None:
        """Handle an all-finished episode.

        Add collected SampleBatch to batch builder. Reset corresponding env, etc.

        Args:
            env_id: Environment ID.
            env_obs: Last per-environment observation.
            is_done: If all agents are done.
            hit_horizon: Whether the episode ended because it hit horizon.
            to_eval: Output container for policy eval data.
            outputs: Output container for collected sample batches.
        """
        check_dones = is_done and not self._no_done_at_end

        episode: EpisodeV2 = self._active_episodes[env_id]
        batch_builder = self._batch_builders[env_id]
        episode.postprocess_episode(
            batch_builder=batch_builder,
            is_done=is_done or (hit_horizon and not self._soft_horizon),
            check_dones=check_dones,
        )

        # If we are not allowed to pack the next episode into the same
        # SampleBatch (batch_mode=complete_episodes) -> Build the
        # MultiAgentBatch from a single episode and add it to "outputs".
        # Otherwise, just postprocess and continue collecting across
        # episodes.
        if not self._multiple_episodes_in_batch:
            ma_sample_batch = _build_multi_agent_batch(
                episode.episode_id,
                batch_builder,
                self._large_batch_threshold,
                self._multiple_episodes_in_batch,
            )
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # SampleBatch built from data collected by batch_builder.
            # Clean up and delete the batch_builder.
            del self._batch_builders[env_id]

            # Call each (in-memory) policy's Exploration.on_episode_end
            # method.
            # Note: This may break the exploration (e.g. ParameterNoise) of
            # policies in the `policy_map` that have not been recently used
            # (and are therefore stashed to disk). However, we certainly do not
            # want to loop through all (even stashed) policies here as that
            # would counter the purpose of the LRU policy caching.
            for p in self._worker.policy_map.cache.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(
                        policy=p,
                        environment=self._base_env,
                        episode=episode,
                        tf_sess=p.get_session(),
                    )
            # Call custom on_episode_end callback.
            self._callbacks.on_episode_end(
                worker=self._worker,
                base_env=self._base_env,
                policies=self._worker.policy_map,
                episode=episode,
                env_index=env_id,
            )

        # Clean up and delete the post-processed episode now that we have collected
        # its data.
        self.end_episode(env_id, episode)

        # Horizon hit and we have a soft horizon (no hard env reset).
        if hit_horizon and self._soft_horizon:
            resetted_obs: Dict[EnvID, Dict[AgentID, EnvObsType]] = {
                env_id: env_obs
            }
            # Do not reset connector state if this is a soft reset.
            # Basically carry RNN and other buffered state to the
            # next episode from the same env.
        else:
            resetted_obs: Dict[EnvID, Dict[
                AgentID, EnvObsType]] = self._base_env.try_reset(env_id)
            # Reset connector state if this is a hard reset.
            for p in self._worker.policy_map.cache.values():
                p.agent_connectors.reset(env_id)
        # Reset not supported, drop this env from the ready list.
        if resetted_obs is None:
            if self._horizon != float("inf"):
                raise ValueError(
                    "Setting episode horizon requires reset() support "
                    "from the environment.")
        # Create a new episode if this is not an async reset return.
        # If reset is async, we will get its result in some future poll.
        elif resetted_obs != ASYNC_RESET_RETURN:
            new_episode: EpisodeV2 = self._active_episodes[env_id]
            per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
            # types: AgentID, EnvObsType
            for agent_id, raw_obs in resetted_obs[env_id].items():
                policy_id: PolicyID = new_episode.policy_for(agent_id)
                per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))

            processed = []
            for policy_id, agents_obs in per_policy_resetted_obs.items():
                policy = self._worker.policy_map[policy_id]
                acd_list: List[AgentConnectorDataType] = [
                    AgentConnectorDataType(
                        env_id,
                        agent_id,
                        {
                            SampleBatch.T: new_episode.length - 1,
                            SampleBatch.NEXT_OBS: obs,
                        },
                    ) for agent_id, obs in agents_obs
                ]
                # Call agent connectors on these initial obs.
                processed.extend(policy.agent_connectors(acd_list))

            for d in processed:
                # Add initial obs to buffer.
                new_episode.add_init_obs(
                    d.agent_id,
                    d.data.for_training[SampleBatch.T],
                    d.data.for_training[SampleBatch.NEXT_OBS],
                )
                item = _PolicyEvalData(d.env_id, d.agent_id, d.data.for_action)
                to_eval[policy_id].append(item)
Example 11
    def _process_observations(
        self,
        unfiltered_obs: MultiEnvDict,
        rewards: MultiEnvDict,
        dones: MultiEnvDict,
        infos: MultiEnvDict,
    ) -> Tuple[Dict[PolicyID, List[_PolicyEvalData]], List[Union[
            RolloutMetrics, SampleBatchType]], ]:
        """Process raw obs from env.

        Group data for active agents by policy. Reset environments that are done.

        Args:
            unfiltered_obs: Per-environment, per-agent raw observations.
            rewards: Per-environment, per-agent rewards.
            dones: Per-environment, per-agent done flags.
            infos: Per-environment, per-agent info dicts.

        Returns:
            A tuple of:
                _PolicyEvalData for active agents for policy evaluation.
                SampleBatches and RolloutMetrics for completed agents for output.
        """
        # Output objects.
        to_eval: Dict[PolicyID, List[_PolicyEvalData]] = defaultdict(list)
        outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

        # For each (vectorized) sub-environment.
        # types: EnvID, Dict[AgentID, EnvObsType]
        for env_id, env_obs in unfiltered_obs.items():
            # Check for env_id having returned an error instead of a multi-agent
            # obs dict. This is how our BaseEnv can tell the caller to `poll()` that
            # one of its sub-environments is faulty and should be restarted (and the
            # ongoing episode should not be used for training).
            if isinstance(env_obs, Exception):
                assert dones[env_id]["__all__"] is True, (
                    f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
                    "as observation, the dones[__all__] flag must also be set to True!"
                )
                # all_agents_obs is an Exception here.
                # Drop this episode and skip to next.
                self.end_episode(env_id, env_obs)
                continue

            episode: EpisodeV2 = self._active_episodes[env_id]

            # Episode length after this step.
            # If this is a brand new episode, this step is adding init_obs.
            # So env_steps will stay at 0. Otherwise, env_steps will advance by 1.
            next_episode_length = episode.length + 1 if episode.has_init_obs else 0
            # Check episode termination conditions.
            if dones[env_id]["__all__"] or next_episode_length >= self._horizon:
                hit_horizon = (next_episode_length >= self._horizon
                               and not dones[env_id]["__all__"])
                all_agents_done = True
                # Add rollout metrics.
                outputs.extend(self._get_rollout_metrics(episode))
            else:
                hit_horizon = False
                all_agents_done = False

            # Special handling of common info dict.
            episode.set_last_info("__common__",
                                  infos[env_id].get("__common__", {}))

            # Agent sample batches grouped by policy. Each set of sample batches will
            # go through agent connectors together.
            sample_batches_by_policy = defaultdict(list)
            # Whether an agent is done, regardless of no_done_at_end or soft_horizon.
            agent_dones = {}
            for agent_id, obs in env_obs.items():
                assert agent_id != "__all__"

                policy_id: PolicyID = episode.policy_for(agent_id)

                agent_done = bool(all_agents_done
                                  or dones[env_id].get(agent_id))
                agent_dones[agent_id] = agent_done

                # A completely new agent is already done -> Skip entirely.
                if not episode.has_init_obs and agent_done:
                    continue

                values_dict = {
                    SampleBatch.T:
                    episode.length - 1,
                    SampleBatch.ENV_ID:
                    env_id,
                    SampleBatch.AGENT_INDEX:
                    episode.agent_index(agent_id),
                    # Last action (SampleBatch.ACTIONS) column will be populated by
                    # StateBufferConnector.
                    # Reward received after taking action at timestep t.
                    SampleBatch.REWARDS:
                    rewards[env_id].get(agent_id, 0.0),
                    # After taking action=a, did we reach terminal?
                    SampleBatch.DONES:
                    (False if
                     (self._no_done_at_end or
                      (hit_horizon and self._soft_horizon)) else agent_done),
                    SampleBatch.INFOS:
                    infos[env_id].get(agent_id, {}),
                    SampleBatch.NEXT_OBS:
                    obs,
                }

                # Queue this obs sample for connector preprocessing.
                sample_batches_by_policy[policy_id].append(
                    (agent_id, values_dict))

            # The entire episode is done.
            if all_agents_done:
                # Check whether there are any agents that haven't received the
                # last "done" obs yet. If there are, we have to create fake last
                # observations for them (the environment is not required to do
                # so if dones[__all__]=True).
                for agent_id in episode.get_agents():
                    # If the latest obs we got for this agent is done, or if its
                    # episode state is already done, nothing to do.
                    if agent_dones.get(agent_id,
                                       False) or episode.is_done(agent_id):
                        continue

                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy = self._worker.policy_map[policy_id]

                    # Create a fake (all-0s) observation.
                    obs_space = policy.observation_space
                    obs_space = getattr(obs_space, "original_space", obs_space)
                    values_dict = {
                        SampleBatch.T:
                        episode.length - 1,
                        SampleBatch.ENV_ID:
                        env_id,
                        SampleBatch.AGENT_INDEX:
                        episode.agent_index(agent_id),
                        SampleBatch.REWARDS:
                        0.0,
                        SampleBatch.DONES:
                        True,
                        SampleBatch.INFOS: {},
                        SampleBatch.NEXT_OBS:
                        tree.map_structure(np.zeros_like, obs_space.sample()),
                    }

                    # Queue these fake obs for connector preprocessing too.
                    sample_batches_by_policy[policy_id].append(
                        (agent_id, values_dict))

            # Run agent connectors.
            processed = []
            for policy_id, batches in sample_batches_by_policy.items():
                policy: Policy = self._worker.policy_map[policy_id]
                # Collected full MultiAgentDicts for this environment.
                # Run agent connectors.
                assert (policy.agent_connectors
                        ), "EnvRunnerV2 requires agent connectors to work."

                acd_list: List[AgentConnectorDataType] = [
                    AgentConnectorDataType(env_id, agent_id, data)
                    for agent_id, data in batches
                ]
                processed.extend(policy.agent_connectors(acd_list))

            is_initial_obs = not episode.has_init_obs
            for d in processed:
                # Record transition info if applicable.
                if is_initial_obs:
                    episode.add_init_obs(
                        d.agent_id,
                        d.data.for_training[SampleBatch.T],
                        d.data.for_training[SampleBatch.NEXT_OBS],
                    )
                else:
                    episode.add_action_reward_done_next_obs(
                        d.agent_id, d.data.for_training)

                if not agent_dones[d.agent_id]:
                    item = _PolicyEvalData(d.env_id, d.agent_id,
                                           d.data.for_action)
                    to_eval[policy_id].append(item)

            # Exception: The very first env.poll() call causes the env to get reset
            # (no step taken yet, just a single starting observation logged).
            # We need to skip this callback in this case.
            if not is_initial_obs:
                # Finished advancing episode by 1 step, mark it so.
                episode.step()

                # Invoke the `on_episode_step` callback after the step is logged
                # to the episode.
                self._callbacks.on_episode_step(
                    worker=self._worker,
                    base_env=self._base_env,
                    policies=self._worker.policy_map,
                    episode=episode,
                    env_index=env_id,
                )

            # Episode is done for all agents (dones[__all__] == True)
            # or we hit the horizon.
            if all_agents_done:
                is_done = dones[env_id]["__all__"]
                # _handle_done_episode will build a MultiAgentBatch for all
                # the agents that are done during this step of rollout in
                # the case of _multiple_episodes_in_batch=False.
                self._handle_done_episode(env_id, env_obs, is_done,
                                          hit_horizon, to_eval, outputs)

            # Try to build a truncated-episode multi-agent batch.
            if self._multiple_episodes_in_batch:
                sample_batch = self._try_build_truncated_episode_multi_agent_batch(
                    self._batch_builders[env_id], episode)
                if sample_batch:
                    outputs.append(sample_batch)

                    # SampleBatch built from data collected by batch_builder.
                    # Clean up and delete the batch_builder.
                    del self._batch_builders[env_id]

        return to_eval, outputs
Example 12
    def __call__(
            self,
            ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
        d = ac_data.data
        assert (
            type(d) == dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        env_id = ac_data.env_id
        agent_id = ac_data.agent_id
        assert env_id is not None and agent_id is not None, (
            "ViewRequirementAgentConnector requires env_id and agent_id")

        vr = self._view_requirements
        assert vr, "ViewRequirements required by ViewRequirementConnector"

        training_dict = {}
        # We construct a proper per-timeslice dict in training mode,
        # for Sampler to construct a complete episode for back propagation.
        if self.is_training:
            # Filter columns that are not needed for training.
            for col, req in vr.items():
                # Not used for training.
                if not req.used_for_training:
                    continue

                # Create the batch of data from the different buffers.
                data_col = req.data_col or col
                if data_col not in d:
                    continue

                training_dict[data_col] = d[data_col]

        # Agent batch is our buffer of necessary history for computing
        # a SampleBatch for policy forward pass.
        # This is used by both training and inference.
        agent_batch = self._agent_data[env_id][agent_id]
        for col, req in vr.items():
            # Not used for action computation.
            if not req.used_for_compute_actions:
                continue

            # Create the batch of data from the different buffers.
            data_col = req.data_col or col
            if data_col not in d:
                continue

            # Add batch dim to this data_col.
            d_col = np.expand_dims(d[data_col], axis=0)

            if data_col in agent_batch:
                # Stack along batch dim.
                agent_batch[data_col] = np.vstack(
                    (agent_batch[data_col], d_col))
            else:
                agent_batch[data_col] = d_col
            # Only keep the useful part of the history.
            h = req.shift_from if req.shift_from else -1
            assert h <= 0, "Can use future data to compute action"
            agent_batch[data_col] = agent_batch[data_col][h:]

        sample_batch = self._get_sample_batch_for_action(vr, agent_batch)

        return_data = AgentConnectorDataType(
            env_id, agent_id, AgentConnectorsOutput(training_dict,
                                                    sample_batch))
        return [return_data]
Example 13
    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        assert isinstance(ac_data.data, AgentConnectorsOutput), (
            "ViewRequirementAgentConnector operates on raw input dict and its"
            "flattened SampleBatch."
        )

        d = ac_data.data.for_training
        f = ac_data.data.for_action
        assert (
            type(d) == dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        env_id = ac_data.env_id
        agent_id = ac_data.agent_id
        assert env_id is not None and agent_id is not None, (
            f"ViewRequirementAgentConnector requires env_id({env_id}) "
            "and agent_id({agent_id})"
        )

        vr = self._view_requirements
        assert vr, "ViewRequirements required by ViewRequirementConnector"

        training_dict = None
        # We construct a proper per-timeslice dict in training mode,
        # for env runner to construct a complete episode.
        if self.is_training:
            # Note(jungong) : we need to keep the entire input dict here.
            # A column may be used by postprocessing (GAE) even if its
            # view_requirement.used_for_training is False.
            training_dict = d

        # Agent batch is our buffer of necessary history for computing
        # a SampleBatch for policy forward pass.
        # This is used by both training and inference.
        agent_batch = self._agent_data[env_id][agent_id]
        for col, req in vr.items():
            # Not used for action computation.
            if not req.used_for_compute_actions:
                continue

            # Create the batch of data from the different buffers.
            if col == SampleBatch.OBS:
                # NEXT_OBS from the training sample is the current OBS
                # to run Policy with.
                data_col = SampleBatch.NEXT_OBS
            else:
                data_col = req.data_col or col
            if data_col not in d:
                continue

            if col in agent_batch:
                # Stack along batch dim.
                agent_batch[col] = np.vstack((agent_batch[col], f[data_col]))
            else:
                agent_batch[col] = f[data_col]
            # Only keep the useful part of the history.
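            # E.g. a view requirement with shift_from=-3 (stacking the last 3
            # observations) keeps the 3 most recent rows; the default (None)
            # keeps only the latest row.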
            h = req.shift_from if req.shift_from else -1
            assert h <= 0, "Can use future data to compute action"
            agent_batch[col] = agent_batch[col][h:]

        sample_batch = self._get_sample_batch_for_action(vr, agent_batch)

        return_data = AgentConnectorDataType(
            env_id, agent_id, AgentConnectorsOutput(training_dict, sample_batch)
        )
        return return_data
Example 14
    def transform(
            self,
            ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        return AgentConnectorDataType(ac_data.env_id, ac_data.agent_id,
                                      fn(ac_data.data))