def register_lambda_agent_connector(
    name: str, fn: Callable[[Any], Any]
) -> Type[AgentConnector]:
    """Wrap a plain function in a new AgentConnector class and register it.

    The generated connector calls ``fn`` with the ``data`` field of each
    incoming AgentConnectorDataType. The value returned by ``fn`` becomes the
    ``data`` field of the outgoing object, while ``env_id`` and ``agent_id``
    are copied over unchanged.

    Args:
        name: Name the connector is registered under. It also replaces the
            ``__name__`` and ``__qualname__`` of the generated class.
        fn: Function taking a single data object and returning a single data
            object.

    Returns:
        The generated AgentConnector subclass (already registered).
    """

    class LambdaAgentConnector(AgentConnector):
        def transform(
            self, ac_data: AgentConnectorDataType
        ) -> AgentConnectorDataType:
            env_id, agent_id = ac_data.env_id, ac_data.agent_id
            return AgentConnectorDataType(env_id, agent_id, fn(ac_data.data))

        def to_config(self):
            # Only the registered name is serialized; there are no params.
            return (name, None)

        @staticmethod
        def from_config(ctx: ConnectorContext, params: List[Any]):
            # `params` is accepted for interface parity but not read.
            return LambdaAgentConnector(ctx)

    # Make the generated class report the user-supplied name.
    LambdaAgentConnector.__name__ = LambdaAgentConnector.__qualname__ = name

    register_connector(name, LambdaAgentConnector)
    return LambdaAgentConnector
def register_lambda_action_connector(
    name: str,
    fn: Callable[[TensorStructType, StateBatches, Dict], PolicyOutputType],
) -> Type[ActionConnector]:
    """Wrap a policy-output transforming function in a new ActionConnector.

    The generated connector unpacks ``ac_data.output`` into
    ``(actions, states, fetches)``, calls ``fn`` with those three values and
    stores whatever ``fn`` returns as the output of a new
    ActionConnectorDataType carrying the same ``env_id`` and ``agent_id``.

    Args:
        name: Name the connector is registered under. It also replaces the
            ``__name__`` and ``__qualname__`` of the generated class.
        fn: Function taking actions, states and fetches and returning the
            transformed PolicyOutputType.

    Returns:
        The generated ActionConnector subclass (already registered).
    """

    class LambdaActionConnector(ActionConnector):
        def transform(
            self, ac_data: ActionConnectorDataType
        ) -> ActionConnectorDataType:
            policy_output = ac_data.output
            assert isinstance(
                policy_output, tuple
            ), "Action connector requires PolicyOutputType data."

            actions, states, fetches = policy_output
            transformed = fn(actions, states, fetches)
            return ActionConnectorDataType(
                ac_data.env_id, ac_data.agent_id, transformed
            )

        def to_config(self):
            # Only the registered name is serialized; there are no params.
            return (name, None)

        @staticmethod
        def from_config(ctx: ConnectorContext, params: List[Any]):
            # `params` is accepted for interface parity but not read.
            return LambdaActionConnector(ctx)

    # Make the generated class report the user-supplied name.
    LambdaActionConnector.__name__ = LambdaActionConnector.__qualname__ = name

    register_connector(name, LambdaActionConnector)
    return LambdaActionConnector
# NOTE(review): __call__, to_config and from_config below are methods of
# ActionConnectorPipeline. The class header lies outside this view, so they
# are written at the column this chunk starts on.
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
    """Feed ``ac_data`` through each connector in order; return the result."""
    result = ac_data
    for connector in self.connectors:
        result = connector(result)
    return result


def to_config(self):
    """Return ``(class name, [config of each child connector])``."""
    child_configs = [connector.to_config() for connector in self.connectors]
    return ActionConnectorPipeline.__name__, child_configs


@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
    """Rebuild a pipeline from a list of ``(name, params)`` pairs.

    Each pair is resolved through ``get_connector(ctx, name, subparams)``.
    """
    assert (
        type(params) is list
    ), "ActionConnectorPipeline takes a list of connector params."

    connectors = []
    for name, subparams in params:
        connectors.append(get_connector(ctx, name, subparams))
    return ActionConnectorPipeline(ctx, connectors)


# Module level: make the pipeline class discoverable by name.
register_connector(ActionConnectorPipeline.__name__, ActionConnectorPipeline)


@DeveloperAPI
def get_action_connectors_from_trainer_config(
    config: TrainerConfigDict, action_space: gym.Space
) -> ActionConnectorPipeline:
    """Return an ActionConnectorPipeline built from an empty connector list.

    Neither ``config`` nor ``action_space`` is read by the current body.
    """
    # NOTE(review): from_config() calls ActionConnectorPipeline(ctx, connectors)
    # whereas only the list is passed here -- confirm the constructor signature.
    return ActionConnectorPipeline([])
if agent_state.states is not None: states = agent_state.states else: states = self._initial_states for i, v in enumerate(states): d["state_out_{}".format(i)] = v if agent_state.action is not None: d[SampleBatch.ACTIONS] = agent_state.action # Last action else: # Default zero action. d[SampleBatch.ACTIONS] = tree.map_structure( lambda s: np.zeros_like(s.sample(), s.dtype) if hasattr(s, "dtype") else np.zeros_like(s.sample()), self._action_space_struct, ) agent_state.t += 1 return [ac_data] def to_config(self): return StateBufferConnector.__name__, None @staticmethod def from_config(ctx: ConnectorContext, params: List[Any]): return StateBufferConnector(ctx) register_connector(StateBufferConnector.__name__, StateBufferConnector)
# NOTE(review): transform, to_config and from_config below are methods of
# ClipRewardAgentConnector. The class header lies outside this view, so they
# are written at the column this chunk starts on.
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
    """Clip the reward column of ``ac_data.data`` in place.

    If ``self.sign`` is truthy the rewards are replaced by their sign.
    Otherwise, if ``self.limit`` is truthy, they are clipped to
    ``[-self.limit, self.limit]``. When neither is set the data is untouched.

    Returns:
        The same ``ac_data`` object that was passed in.
    """
    data = ac_data.data
    assert (
        type(data) is dict
    ), "Single agent data must be of type Dict[str, TensorStructType]"
    assert SampleBatch.REWARDS in data, "input data does not have reward column."

    if self.sign:
        data[SampleBatch.REWARDS] = np.sign(data[SampleBatch.REWARDS])
    elif self.limit:
        bound = self.limit
        data[SampleBatch.REWARDS] = np.clip(
            data[SampleBatch.REWARDS], a_min=-bound, a_max=bound
        )
    return ac_data


def to_config(self):
    """Return ``(class name, {"sign": ..., "limit": ...})``."""
    params = {"sign": self.sign, "limit": self.limit}
    return ClipRewardAgentConnector.__name__, params


@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
    """Rebuild the connector from ``params``.

    ``params`` is unpacked as keyword arguments, so despite the ``List[Any]``
    annotation it has to be a mapping such as the one ``to_config`` emits.
    """
    return ClipRewardAgentConnector(ctx, **params)


# Module level: make the connector class discoverable by name.
register_connector(ClipRewardAgentConnector.__name__, ClipRewardAgentConnector)
self._preprocessor = get_preprocessor(ctx.observation_space)( ctx.observation_space, ctx.config.get("model", {})) def __call__( self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]: d = ac_data.data assert ( type(d) == dict ), "Single agent data must be of type Dict[str, TensorStructType]" if SampleBatch.OBS in d: d[SampleBatch.OBS] = self._preprocessor.transform( d[SampleBatch.OBS]) if SampleBatch.NEXT_OBS in d: d[SampleBatch.NEXT_OBS] = self._preprocessor.transform( d[SampleBatch.NEXT_OBS]) return [ac_data] def to_config(self): return ObsPreprocessorConnector.__name__, {} @staticmethod def from_config(ctx: ConnectorContext, params: List[Any]): return ObsPreprocessorConnector(ctx, **params) register_connector(ObsPreprocessorConnector.__name__, ObsPreprocessorConnector)
# NOTE(review): to_config and from_config below are methods of
# AgentConnectorPipeline. The class header lies outside this view, so they
# are written at the column this chunk starts on.
def to_config(self):
    """Return ``(class name, [config of each child connector])``."""
    child_configs = [connector.to_config() for connector in self.connectors]
    return AgentConnectorPipeline.__name__, child_configs


@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
    """Rebuild a pipeline from a list of ``(name, params)`` pairs.

    Each pair is resolved through ``get_connector(ctx, name, subparams)``.
    """
    assert (
        type(params) is list
    ), "AgentConnectorPipeline takes a list of connector params."

    connectors = []
    for name, subparams in params:
        connectors.append(get_connector(ctx, name, subparams))
    return AgentConnectorPipeline(ctx, connectors)


# Module level: make the pipeline class discoverable by name.
register_connector(AgentConnectorPipeline.__name__, AgentConnectorPipeline)


# TODO(jungong) : finish this.
@DeveloperAPI
def get_agent_connectors_from_config(
    config: TrainerConfigDict, obs_space: gym.Space
) -> AgentConnectorPipeline:
    """Start assembling the agent connectors selected by ``config``.

    A FlattenDataAgentConnector is always added first. A
    ClipRewardAgentConnector follows when ``config["clip_rewards"]`` is
    exactly ``True`` (sign clipping) or a ``float`` (limit clipping with its
    absolute value). ``obs_space`` is not read by the statements visible here.
    """
    # NOTE(review): no return statement is visible in this view and the
    # function is marked as unfinished -- confirm where the pipeline is built.
    # NOTE(review): the connectors are created without a ConnectorContext,
    # unlike the from_config() path above -- TODO confirm constructor args.
    connectors = [FlattenDataAgentConnector()]

    clip_rewards = config["clip_rewards"]
    if clip_rewards is True:
        connectors.append(ClipRewardAgentConnector(sign=True))
    elif type(clip_rewards) is float:
        connectors.append(ClipRewardAgentConnector(limit=abs(clip_rewards)))
d_col = np.expand_dims(d[data_col], axis=0) if col in agent_batch: # Stack along batch dim. agent_batch[data_col] = np.vstack( (agent_batch[data_col], d_col)) else: agent_batch[data_col] = d_col # Only keep the useful part of the history. h = req.shift_from if req.shift_from else -1 assert h <= 0, "Can use future data to compute action" agent_batch[data_col] = agent_batch[data_col][h:] sample_batch = self._get_sample_batch_for_action(vr, agent_batch) return_data = AgentConnectorDataType( env_id, agent_id, AgentConnectorsOutput(training_dict, sample_batch)) return return_data def to_config(self): return ViewRequirementAgentConnector.__name__, None @staticmethod def from_config(ctx: ConnectorContext, params: List[Any]): return ViewRequirementAgentConnector(ctx) register_connector(ViewRequirementAgentConnector.__name__, ViewRequirementAgentConnector)
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class ImmutableActionsConnector(ActionConnector):
    """Action connector that applies ``make_action_immutable`` to the actions.

    States and fetches are forwarded untouched.
    """

    def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        """Visit every node of the actions with ``make_action_immutable``.

        Returns:
            A new ActionConnectorDataType with the same ``env_id`` and
            ``agent_id`` and the ``(actions, states, fetches)`` tuple.
        """
        policy_output = ac_data.output
        assert isinstance(
            policy_output, tuple
        ), "Action connector requires PolicyOutputType data."

        actions, states, fetches = policy_output
        # The value returned by tree.traverse is discarded, so the original
        # `actions` object is what gets forwarded.
        # NOTE(review): this presumably relies on make_action_immutable
        # changing the leaves in place -- confirm against its definition.
        tree.traverse(make_action_immutable, actions, top_down=False)

        output = (actions, states, fetches)
        return ActionConnectorDataType(ac_data.env_id, ac_data.agent_id, output)

    def to_config(self):
        """Return ``(class name, None)``; the connector has no parameters."""
        return (ImmutableActionsConnector.__name__, None)

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Build the connector from ``ctx``; ``params`` is not read."""
        return ImmutableActionsConnector(ctx)


# Make the connector class discoverable by name.
register_connector(ImmutableActionsConnector.__name__, ImmutableActionsConnector)
@DeveloperAPI
class NormalizeActionsConnector(ActionConnector):
    """Action connector that passes actions through ``unsquash_action``.

    States and fetches are forwarded untouched.
    """

    def __init__(self, ctx: ConnectorContext):
        super().__init__(ctx)
        # Derived once from the context's action space and reused per call.
        self._action_space_struct = get_base_struct_from_space(ctx.action_space)

    def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        """Return a copy of ``ac_data`` whose actions went through unsquashing.

        Returns:
            A new ActionConnectorDataType with the same ``env_id`` and
            ``agent_id`` and ``(unsquashed actions, states, fetches)``.
        """
        policy_output = ac_data.output
        assert isinstance(
            policy_output, tuple
        ), "Action connector requires PolicyOutputType data."

        actions, states, fetches = policy_output
        unsquashed = unsquash_action(actions, self._action_space_struct)
        return ActionConnectorDataType(
            ac_data.env_id, ac_data.agent_id, (unsquashed, states, fetches)
        )

    def to_config(self):
        """Return ``(class name, None)``; the connector has no parameters."""
        return (NormalizeActionsConnector.__name__, None)

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Build the connector from ``ctx``; ``params`` is not read."""
        return NormalizeActionsConnector(ctx)


# Make the connector class discoverable by name.
register_connector(NormalizeActionsConnector.__name__, NormalizeActionsConnector)
@DeveloperAPI
class ClipActionsConnector(ActionConnector):
    """Action connector that passes actions through ``clip_action``.

    States and fetches are forwarded untouched.
    """

    def __init__(self, ctx: ConnectorContext):
        super().__init__(ctx)
        # Derived once from the context's action space and reused per call.
        self._action_space_struct = get_base_struct_from_space(ctx.action_space)

    def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
        """Return a copy of ``ac_data`` whose actions went through clipping.

        Returns:
            A new ActionConnectorDataType with the same ``env_id`` and
            ``agent_id`` and ``(clipped actions, states, fetches)``.
        """
        policy_output = ac_data.output
        assert isinstance(
            policy_output, tuple
        ), "Action connector requires PolicyOutputType data."

        actions, states, fetches = policy_output
        clipped = clip_action(actions, self._action_space_struct)
        return ActionConnectorDataType(
            ac_data.env_id, ac_data.agent_id, (clipped, states, fetches)
        )

    def to_config(self):
        """Return ``(class name, None)``; the connector has no parameters."""
        return (ClipActionsConnector.__name__, None)

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Build the connector from ``ctx``; ``params`` is not read."""
        return ClipActionsConnector(ctx)


# Make the connector class discoverable by name.
register_connector(ClipActionsConnector.__name__, ClipActionsConnector)
SampleBatch.ENV_ID: env_id, SampleBatch.REWARDS: rewards[agent_id], # SampleBatch.DONES may be overridden by data from # training_episode_infos next. SampleBatch.DONES: dones[agent_id], SampleBatch.NEXT_OBS: obs, } if SampleBatch.INFOS in self._view_requirements: input_dict[SampleBatch.INFOS] = infos[agent_id] if agent_id in training_episode_infos: input_dict.update(training_episode_infos[agent_id]) per_agent_data.append( AgentConnectorDataType(env_id, agent_id, input_dict)) return per_agent_data def to_config(self): return EnvToAgentDataConnector.__name__, None @staticmethod def from_config(ctx: ConnectorContext, params: List[Any]): return EnvToAgentDataConnector(ctx) register_connector(EnvToAgentDataConnector.__name__, EnvToAgentDataConnector)