def get_agent_connectors_from_config(
    ctx: ConnectorContext,
    config: TrainerConfigDict,
) -> AgentConnectorPipeline:
    """Build the default agent-connector pipeline described by a trainer config.

    Connectors are appended in this order:
      1. Optional reward clipping, driven by ``config["clip_rewards"]``:
         ``True`` adds a sign-clipping connector; a float ``x`` adds a
         connector with ``limit=abs(x)``; any other value adds nothing.
      2. An observation preprocessor, unless
         ``config["_disable_preprocessor_api"]`` is truthy.
      3. State-buffer, view-requirement and flatten-data connectors (always).

    Args:
        ctx: Connector context handed to every connector constructed here.
        config: Trainer config dict. The ``"clip_rewards"`` and
            ``"_disable_preprocessor_api"`` keys are read directly, so a
            missing key raises ``KeyError``.

    Returns:
        An ``AgentConnectorPipeline`` wrapping the connectors listed above.
    """
    clip_rewards = config["clip_rewards"]
    connectors = []
    if clip_rewards is True:
        connectors.append(ClipRewardAgentConnector(ctx, sign=True))
    elif isinstance(clip_rewards, float):
        # isinstance (instead of an exact type() comparison) also accepts
        # float subclasses such as numpy.float64. bool is not a float
        # subclass, so clip_rewards=False still adds no clipping connector.
        connectors.append(
            ClipRewardAgentConnector(ctx, limit=abs(clip_rewards))
        )

    if not config["_disable_preprocessor_api"]:
        connectors.append(ObsPreprocessorConnector(ctx))

    connectors.extend(
        [
            StateBufferConnector(ctx),
            ViewRequirementAgentConnector(ctx),
            FlattenDataAgentConnector(ctx),  # Creates batch dimension.
        ]
    )

    return AgentConnectorPipeline(ctx, connectors)
def get_agent_connectors_from_config(
    config: TrainerConfigDict, obs_space: gym.Space
) -> AgentConnectorPipeline:
    """Build a minimal agent-connector pipeline from a trainer config.

    The pipeline always starts with a ``FlattenDataAgentConnector``. A
    reward-clipping connector is appended when ``config["clip_rewards"]`` is
    ``True`` (sign clipping) or a float ``x`` (``limit=abs(x)``). Any other
    value adds no clipping.

    Args:
        config: Trainer config dict. ``"clip_rewards"`` is read directly, so
            a missing key raises ``KeyError``.
        obs_space: Accepted for interface compatibility; not used here.

    Returns:
        An ``AgentConnectorPipeline`` wrapping the connectors above.
    """
    # NOTE(review): this (config, obs_space) signature differs from the
    # ctx-based get_agent_connectors_from_config; if both end up in one
    # module, whichever is defined later shadows the other — confirm intent.
    connectors = [FlattenDataAgentConnector()]
    clip_rewards = config["clip_rewards"]
    if clip_rewards is True:
        connectors.append(ClipRewardAgentConnector(sign=True))
    elif isinstance(clip_rewards, float):
        # isinstance also accepts float subclasses (e.g. numpy.float64);
        # bool is not a float subclass, so False still means "no clipping".
        connectors.append(ClipRewardAgentConnector(limit=abs(clip_rewards)))
    return AgentConnectorPipeline(connectors)
def test_connector_pipeline(self):
    """A pipeline serialized with to_config() is rebuilt by get_connector()."""
    ctx = ConnectorContext()
    # Positional args kept as in the original test; presumably
    # (sign=False, limit=1.0) — TODO confirm against the constructor.
    connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
    pipeline = AgentConnectorPipeline(ctx, connectors)

    name, params = pipeline.to_config()
    restored = get_connector(ctx, name, params)

    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(restored, AgentConnectorPipeline)
    # The single connector put in must come back out, and nothing else.
    self.assertEqual(len(restored.connectors), 1)
    self.assertIsInstance(restored.connectors[0], ClipRewardAgentConnector)
def test_clip_reward_connector(self):
    """ClipRewardAgentConnector round-trips via to_config() and clips rewards."""
    ctx = ConnectorContext()
    c = ClipRewardAgentConnector(ctx, limit=2.0)

    name, params = c.to_config()
    self.assertEqual(name, "ClipRewardAgentConnector")
    self.assertAlmostEqual(params["limit"], 2.0)

    restored = get_connector(ctx, name, params)
    self.assertIsInstance(restored, ClipRewardAgentConnector)

    # Exercise both sides of the clip range. Out-of-range rewards should
    # land exactly on +/- limit, so exact equality is appropriate.
    for reward, expected in ((5.8, 2.0), (-5.8, -2.0)):
        with self.subTest(reward=reward):
            d = AgentConnectorDataType(
                0,
                1,
                {
                    SampleBatch.REWARDS: reward,
                },
            )
            clipped = restored(ac_data=d)
            self.assertEqual(len(clipped), 1)
            self.assertEqual(clipped[0].data[SampleBatch.REWARDS], expected)