コード例 #1
0
def register_lambda_agent_connector(
        name: str, fn: Callable[[Any], Any]) -> Type[AgentConnector]:
    """A util to register any simple transforming function as an AgentConnector.

    The only requirement is that fn should take a single data object and return
    a single data object.

    The generated class is renamed to ``name`` (both ``__name__`` and
    ``__qualname__``), registered under ``name`` via ``register_connector``,
    and serializes itself as ``(name, None)`` in ``to_config``.

    Args:
        name: Name of the resulting agent connector.
        fn: The function that transforms env / agent data.

    Returns:
        A new AgentConnector class that transforms data using fn.
    """
    class LambdaAgentConnector(AgentConnector):
        def transform(
                self,
                ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
            # Only the data payload goes through fn; env_id and agent_id are
            # copied over unchanged.
            return AgentConnectorDataType(ac_data.env_id, ac_data.agent_id,
                                          fn(ac_data.data))

        def to_config(self):
            # Nothing to serialize besides the registered name.
            return name, None

        @staticmethod
        def from_config(ctx: ConnectorContext, params: List[Any]):
            # params is ignored: to_config never stores any.
            return LambdaAgentConnector(ctx)

    # Rename the class so it reports itself as ``name`` instead of the generic
    # "LambdaAgentConnector".
    LambdaAgentConnector.__name__ = name
    LambdaAgentConnector.__qualname__ = name

    register_connector(name, LambdaAgentConnector)

    return LambdaAgentConnector
コード例 #2
0
ファイル: lambdas.py プロジェクト: ray-project/ray
def register_lambda_action_connector(
    name: str, fn: Callable[[TensorStructType, StateBatches, Dict], PolicyOutputType]
) -> Type[ActionConnector]:
    """A util to register any function transforming PolicyOutputType as an ActionConnector.

    The only requirement is that fn should take actions, states, and fetches as input,
    and return transformed actions, states, and fetches.

    The generated class is renamed to ``name`` (both ``__name__`` and
    ``__qualname__``), registered under ``name`` via ``register_connector``,
    and serializes itself as ``(name, None)`` in ``to_config``.

    Args:
        name: Name of the resulting action connector.
        fn: The function that transforms PolicyOutputType.

    Returns:
        A new ActionConnector class that transforms PolicyOutputType using fn.
    """

    class LambdaActionConnector(ActionConnector):
        def transform(
            self, ac_data: ActionConnectorDataType
        ) -> ActionConnectorDataType:
            assert isinstance(
                ac_data.output, tuple
            ), "Action connector requires PolicyOutputType data."

            # The policy output tuple is unpacked and handed to fn as three
            # positional arguments; fn's return value is passed on as-is.
            actions, states, fetches = ac_data.output
            return ActionConnectorDataType(
                ac_data.env_id,
                ac_data.agent_id,
                fn(actions, states, fetches),
            )

        def to_config(self):
            # Nothing to serialize besides the registered name.
            return name, None

        @staticmethod
        def from_config(ctx: ConnectorContext, params: List[Any]):
            # params is ignored: to_config never stores any.
            return LambdaActionConnector(ctx)

    # Rename the class so it reports itself as ``name`` instead of the generic
    # "LambdaActionConnector".
    LambdaActionConnector.__name__ = name
    LambdaActionConnector.__qualname__ = name

    register_connector(name, LambdaActionConnector)

    return LambdaActionConnector
コード例 #3
0
    def __call__(self,
                 ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        """Feed ``ac_data`` through every connector of the pipeline in order.

        Each connector receives the previous connector's output. The last
        output is returned, or ``ac_data`` itself when the pipeline is empty.
        """
        current = ac_data
        for connector in self.connectors:
            current = connector(current)
        return current

    def to_config(self):
        """Serialize as ``(class name, [sub-connector configs])``.

        The list holds each sub-connector's ``to_config()`` result, in
        pipeline order.
        """
        sub_configs = [connector.to_config() for connector in self.connectors]
        return ActionConnectorPipeline.__name__, sub_configs

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Rebuild an ActionConnectorPipeline from serialized params.

        Args:
            ctx: Connector context, passed to every sub-connector and to the
                pipeline itself.
            params: List of ``(name, subparams)`` pairs, one per sub-connector,
                in pipeline order (the shape produced by ``to_config``).

        Returns:
            A new ActionConnectorPipeline holding the rebuilt connectors.
        """
        # isinstance (instead of an exact type() comparison) also accepts
        # list subclasses.
        assert isinstance(
            params, list
        ), "ActionConnectorPipeline takes a list of connector params."
        connectors = [
            get_connector(ctx, name, subparams) for name, subparams in params
        ]
        return ActionConnectorPipeline(ctx, connectors)


# Register the pipeline under its class name (the name to_config() emits);
# presumably this is what lets get_connector() resolve that name back to the
# class — confirm against register_connector.
register_connector(ActionConnectorPipeline.__name__, ActionConnectorPipeline)


@DeveloperAPI
def get_action_connectors_from_trainer_config(
        config: TrainerConfigDict,
        action_space: gym.Space,
        ctx: ConnectorContext = None) -> ActionConnectorPipeline:
    """Build the action connector pipeline for a trainer config.

    No action connectors are derived from ``config`` / ``action_space`` yet,
    so the returned pipeline is empty.

    Args:
        config: Trainer config dict (currently unused).
        action_space: The policy's action space (currently unused).
        ctx: Connector context handed to the pipeline. Optional so existing
            callers that pass only ``config`` and ``action_space`` keep working.

    Returns:
        An ActionConnectorPipeline with no connectors.
    """
    connectors = []
    # ActionConnectorPipeline is constructed as (ctx, connectors), matching
    # ActionConnectorPipeline.from_config. Passing only ``connectors`` would
    # put the list in the ctx slot.
    return ActionConnectorPipeline(ctx, connectors)
コード例 #4
0
        if agent_state.states is not None:
            states = agent_state.states
        else:
            states = self._initial_states
        for i, v in enumerate(states):
            d["state_out_{}".format(i)] = v

        if agent_state.action is not None:
            d[SampleBatch.ACTIONS] = agent_state.action  # Last action
        else:
            # Default zero action.
            d[SampleBatch.ACTIONS] = tree.map_structure(
                lambda s: np.zeros_like(s.sample(), s.dtype)
                if hasattr(s, "dtype") else np.zeros_like(s.sample()),
                self._action_space_struct,
            )

        agent_state.t += 1

        return [ac_data]

    def to_config(self):
        """Serialize as ``(class name, None)``; no parameters are stored."""
        cls_name = StateBufferConnector.__name__
        return cls_name, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Recreate a StateBufferConnector from ``ctx``.

        ``params`` is not used because ``to_config`` serializes none.
        """
        connector = StateBufferConnector(ctx)
        return connector


# Register under the class name that to_config() emits, presumably so the
# connector can be looked up by that name when restoring — confirm against
# register_connector.
register_connector(StateBufferConnector.__name__, StateBufferConnector)
コード例 #5
0
ファイル: clip_reward.py プロジェクト: ray-project/ray
    def transform(self,
                  ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        """Sign-transform or clip the reward column of ``ac_data`` in place.

        If ``self.sign`` is truthy, rewards are replaced by ``np.sign`` of
        themselves. Otherwise, if ``self.limit`` is truthy, rewards are
        clipped to ``[-self.limit, self.limit]``. A falsy limit (``None`` or
        ``0``) leaves rewards unchanged.

        Args:
            ac_data: Agent connector data whose ``data`` is a per-agent dict
                that must contain a ``SampleBatch.REWARDS`` entry.

        Returns:
            The same ``ac_data`` object; its ``data`` dict is mutated in place.
        """
        d = ac_data.data
        # isinstance (instead of an exact type() comparison) also accepts
        # dict subclasses such as OrderedDict or defaultdict.
        assert isinstance(
            d, dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        assert SampleBatch.REWARDS in d, "input data does not have reward column."
        if self.sign:
            d[SampleBatch.REWARDS] = np.sign(d[SampleBatch.REWARDS])
        elif self.limit:
            d[SampleBatch.REWARDS] = np.clip(
                d[SampleBatch.REWARDS],
                a_min=-self.limit,
                a_max=self.limit,
            )
        return ac_data

    def to_config(self):
        """Serialize as ``(class name, {"sign": ..., "limit": ...})``."""
        params = dict(sign=self.sign, limit=self.limit)
        return ClipRewardAgentConnector.__name__, params

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Recreate the connector from ``ctx`` and the params of ``to_config``.

        ``params`` is expanded as keyword arguments, so in practice it is the
        ``{"sign": ..., "limit": ...}`` mapping that ``to_config`` returns,
        despite the ``List[Any]`` annotation.
        """
        connector_cls = ClipRewardAgentConnector
        return connector_cls(ctx, **params)


# Register under the class name that to_config() emits, presumably so the
# connector can be rebuilt from (name, params) — confirm against
# register_connector.
register_connector(ClipRewardAgentConnector.__name__, ClipRewardAgentConnector)
コード例 #6
0
ファイル: obs_preproc.py プロジェクト: vishalbelsare/ray
        self._preprocessor = get_preprocessor(ctx.observation_space)(
            ctx.observation_space, ctx.config.get("model", {}))

    def __call__(
            self,
            ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
        """Apply the observation preprocessor to the obs columns in place.

        ``SampleBatch.OBS`` and ``SampleBatch.NEXT_OBS`` are each replaced by
        ``self._preprocessor.transform(value)`` when present. Other keys are
        left untouched.

        Args:
            ac_data: Agent connector data whose ``data`` is a per-agent dict.

        Returns:
            A single-element list holding the same (mutated) ``ac_data``.
        """
        d = ac_data.data
        # isinstance (instead of an exact type() comparison) also accepts
        # dict subclasses such as OrderedDict or defaultdict.
        assert isinstance(
            d, dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        # Both columns get identical preprocessing; OBS is handled first.
        for col in (SampleBatch.OBS, SampleBatch.NEXT_OBS):
            if col in d:
                d[col] = self._preprocessor.transform(d[col])

        return [ac_data]

    def to_config(self):
        """Serialize as ``(class name, empty params dict)``."""
        return (ObsPreprocessorConnector.__name__, dict())

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Recreate the connector from ``ctx``.

        ``params`` is expanded as keyword arguments. ``to_config`` produces an
        empty mapping, so nothing extra is passed.
        """
        connector_cls = ObsPreprocessorConnector
        return connector_cls(ctx, **params)


# Register under the class name that to_config() emits, presumably so the
# connector can be rebuilt from (name, params) — confirm against
# register_connector.
register_connector(ObsPreprocessorConnector.__name__, ObsPreprocessorConnector)
コード例 #7
0
    def to_config(self):
        """Serialize as ``(class name, [sub-connector configs])``.

        The list holds each sub-connector's ``to_config()`` result, in
        pipeline order.
        """
        sub_configs = [connector.to_config() for connector in self.connectors]
        return AgentConnectorPipeline.__name__, sub_configs

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Rebuild an AgentConnectorPipeline from serialized params.

        Args:
            ctx: Connector context, passed to every sub-connector and to the
                pipeline itself.
            params: List of ``(name, subparams)`` pairs, one per sub-connector,
                in pipeline order (the shape produced by ``to_config``).

        Returns:
            A new AgentConnectorPipeline holding the rebuilt connectors.
        """
        # isinstance (instead of an exact type() comparison) also accepts
        # list subclasses.
        assert isinstance(
            params, list
        ), "AgentConnectorPipeline takes a list of connector params."
        connectors = [
            get_connector(ctx, name, subparams) for name, subparams in params
        ]
        return AgentConnectorPipeline(ctx, connectors)


# Register the pipeline under its class name (the name to_config() emits);
# presumably this is what lets get_connector() resolve that name back to the
# class — confirm against register_connector.
register_connector(AgentConnectorPipeline.__name__, AgentConnectorPipeline)


# TODO(jungong) : finish this.
@DeveloperAPI
def get_agent_connectors_from_config(
        config: TrainerConfigDict,
        obs_space: gym.Space) -> AgentConnectorPipeline:
    """Collect agent connectors for a trainer config (unfinished).

    Starts with a FlattenDataAgentConnector, then appends a
    ClipRewardAgentConnector based on ``config["clip_rewards"]``: ``True``
    selects sign clipping, while a float selects clipping to
    ``[-abs(value), abs(value)]``. ``obs_space`` is not used in the visible
    body.

    NOTE(review): the visible body ends without a ``return``, so as shown it
    returns None despite the ``AgentConnectorPipeline`` annotation. This
    listing may be truncated; confirm against the full file.
    NOTE(review): the connectors here are built without a ``ctx``, while
    ``ClipRewardAgentConnector.from_config`` calls
    ``ClipRewardAgentConnector(ctx, **params)``. Confirm the constructor
    signatures.
    """
    connectors = [FlattenDataAgentConnector()]

    if config["clip_rewards"] is True:
        connectors.append(ClipRewardAgentConnector(sign=True))
    elif type(config["clip_rewards"]) == float:
        # Exact-float check: an int value (e.g. 1) does not match this branch.
        connectors.append(
            ClipRewardAgentConnector(limit=abs(config["clip_rewards"])))
コード例 #8
0
            d_col = np.expand_dims(d[data_col], axis=0)

            if col in agent_batch:
                # Stack along batch dim.
                agent_batch[data_col] = np.vstack(
                    (agent_batch[data_col], d_col))
            else:
                agent_batch[data_col] = d_col
            # Only keep the useful part of the history.
            h = req.shift_from if req.shift_from else -1
            assert h <= 0, "Can use future data to compute action"
            agent_batch[data_col] = agent_batch[data_col][h:]

        sample_batch = self._get_sample_batch_for_action(vr, agent_batch)

        return_data = AgentConnectorDataType(
            env_id, agent_id, AgentConnectorsOutput(training_dict,
                                                    sample_batch))
        return return_data

    def to_config(self):
        """Serialize as ``(class name, None)``; no parameters are stored."""
        cls_name = ViewRequirementAgentConnector.__name__
        return cls_name, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Recreate a ViewRequirementAgentConnector from ``ctx``.

        ``params`` is not used because ``to_config`` serializes none.
        """
        connector = ViewRequirementAgentConnector(ctx)
        return connector


# Register under the class name that to_config() emits, presumably so the
# connector can be looked up by that name when restoring — confirm against
# register_connector.
register_connector(ViewRequirementAgentConnector.__name__,
                   ViewRequirementAgentConnector)
コード例 #9
0
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class ImmutableActionsConnector(ActionConnector):
    """Action connector that applies ``make_action_immutable`` to actions.

    States and fetches are forwarded untouched. The connector stores no
    parameters, so it serializes as ``(class name, None)``.
    """

    def transform(self,
                  ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        policy_output = ac_data.output
        assert isinstance(
            policy_output,
            tuple), "Action connector requires PolicyOutputType data."

        actions, states, fetches = policy_output
        # make_action_immutable is applied to every node of the action
        # structure, bottom-up. The traversal's return value is discarded,
        # so this presumably relies on make_action_immutable modifying the
        # leaves in place — confirm.
        tree.traverse(make_action_immutable, actions, top_down=False)

        return ActionConnectorDataType(
            ac_data.env_id, ac_data.agent_id, (actions, states, fetches))

    def to_config(self):
        cls_name = ImmutableActionsConnector.__name__
        return cls_name, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        # params is ignored: to_config never stores any.
        connector = ImmutableActionsConnector(ctx)
        return connector


# Register under the class name that to_config() emits, presumably so the
# connector can be looked up by that name when restoring — confirm against
# register_connector.
register_connector(ImmutableActionsConnector.__name__,
                   ImmutableActionsConnector)
コード例 #10
0
ファイル: normalize.py プロジェクト: vishalbelsare/ray

@DeveloperAPI
class NormalizeActionsConnector(ActionConnector):
    """Action connector that runs actions through ``unsquash_action``.

    The base struct of ``ctx.action_space`` is computed once at construction
    and passed to ``unsquash_action`` on every call (presumably to map
    normalized actions back into the action space's bounds — confirm).
    States and fetches are forwarded untouched.
    """

    def __init__(self, ctx: ConnectorContext):
        super().__init__(ctx)

        # Cached once per connector instead of being recomputed per call.
        self._action_space_struct = get_base_struct_from_space(ctx.action_space)

    def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        policy_output = ac_data.output
        assert isinstance(
            policy_output, tuple
        ), "Action connector requires PolicyOutputType data."

        actions, states, fetches = policy_output
        unsquashed = unsquash_action(actions, self._action_space_struct)
        return ActionConnectorDataType(
            ac_data.env_id, ac_data.agent_id, (unsquashed, states, fetches)
        )

    def to_config(self):
        cls_name = NormalizeActionsConnector.__name__
        return cls_name, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        # params is ignored: to_config never stores any.
        return NormalizeActionsConnector(ctx=ctx)


# Register under the class name that to_config() emits, presumably so the
# connector can be looked up by that name when restoring — confirm against
# register_connector.
register_connector(NormalizeActionsConnector.__name__, NormalizeActionsConnector)
コード例 #11
0
@DeveloperAPI
class ClipActionsConnector(ActionConnector):
    """Action connector that runs actions through ``clip_action``.

    The base struct of ``ctx.action_space`` is computed once at construction
    and passed to ``clip_action`` on every call. States and fetches are
    forwarded untouched.
    """

    def __init__(self, ctx: ConnectorContext):
        super().__init__(ctx)

        # Cached once per connector instead of being recomputed per call.
        self._action_space_struct = get_base_struct_from_space(
            ctx.action_space)

    def __call__(self,
                 ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        policy_output = ac_data.output
        assert isinstance(
            policy_output,
            tuple), "Action connector requires PolicyOutputType data."

        actions, states, fetches = policy_output
        clipped = clip_action(actions, self._action_space_struct)
        return ActionConnectorDataType(
            ac_data.env_id, ac_data.agent_id, (clipped, states, fetches))

    def to_config(self):
        cls_name = ClipActionsConnector.__name__
        return cls_name, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        # params is ignored: to_config never stores any.
        return ClipActionsConnector(ctx=ctx)


# Register under the class name that to_config() emits, presumably so the
# connector can be looked up by that name when restoring — confirm against
# register_connector.
register_connector(ClipActionsConnector.__name__, ClipActionsConnector)
コード例 #12
0
                SampleBatch.ENV_ID:
                env_id,
                SampleBatch.REWARDS:
                rewards[agent_id],
                # SampleBatch.DONES may be overridden by data from
                # training_episode_infos next.
                SampleBatch.DONES:
                dones[agent_id],
                SampleBatch.NEXT_OBS:
                obs,
            }
            if SampleBatch.INFOS in self._view_requirements:
                input_dict[SampleBatch.INFOS] = infos[agent_id]
            if agent_id in training_episode_infos:
                input_dict.update(training_episode_infos[agent_id])

            per_agent_data.append(
                AgentConnectorDataType(env_id, agent_id, input_dict))

        return per_agent_data

    def to_config(self):
        """Serialize as ``(class name, None)``; no parameters are stored."""
        cls_name = EnvToAgentDataConnector.__name__
        return cls_name, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Recreate an EnvToAgentDataConnector from ``ctx``.

        ``params`` is not used because ``to_config`` serializes none.
        """
        connector = EnvToAgentDataConnector(ctx)
        return connector


# Register under the class name that to_config() emits, presumably so the
# connector can be looked up by that name when restoring — confirm against
# register_connector.
register_connector(EnvToAgentDataConnector.__name__, EnvToAgentDataConnector)