def run():
    # Restore policy.
    policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id])
    policy = policies[args.policy_id]

    # Adapt policy trained for standard CartPole to the new env.
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx))
    policy.action_connectors.append(V1ToV2ActionConnector(ctx))

    # Run CartPole.
    env = MyCartPole()
    obs = env.reset()
    done = False
    step = 0
    while not done:
        step += 1

        # Use local_policy_inference() to easily run the policy with observations.
        policy_outputs = local_policy_inference(policy, "env_1", "agent_1", obs)
        assert len(policy_outputs) == 1
        actions, _, _ = policy_outputs[0]
        print(f"step {step}", obs, actions)

        obs, _, done, _ = env.step(actions)
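# The example above references V2ToV1ObsAgentConnector and V1ToV2ActionConnector
# without showing their implementations. Below is a minimal, hypothetical sketch
# of the observation adapter, assuming the AgentConnector base class from this
# codebase and the single-item-in, list-out __call__ convention used by the
# agent connector tests below. The slicing is purely illustrative: it assumes
# the new env appends extra dims to the standard 4-dim CartPole observation.
class V2ToV1ObsAgentConnector(AgentConnector):
    """Drops observation dims the checkpointed policy never saw in training."""

    def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
        d = ac_data.data
        # Keep only the first 4 dims that the original CartPole policy expects.
        d[SampleBatch.OBS] = d[SampleBatch.OBS][:4]
        return [ac_data]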
def test_obs_preprocessor_connector(self):
    obs_space = gym.spaces.Dict(
        {
            "a": gym.spaces.Box(low=0, high=1, shape=(1,)),
            "b": gym.spaces.Tuple(
                [gym.spaces.Discrete(2), gym.spaces.MultiDiscrete(nvec=[2, 3])]
            ),
        }
    )
    ctx = ConnectorContext(config={}, observation_space=obs_space)

    c = ObsPreprocessorConnector(ctx)
    name, params = c.to_config()

    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, ObsPreprocessorConnector))

    obs = obs_space.sample()
    # Fake deterministic data.
    obs["a"][0] = 0.5
    obs["b"] = (1, np.array([0, 2]))

    d = AgentConnectorDataType(
        0,
        1,
        {
            SampleBatch.OBS: obs,
        },
    )
    preprocessed = c([d])

    # Obs is completely flattened.
    self.assertTrue(
        (preprocessed[0].data[SampleBatch.OBS] == [0.5, 0, 1, 1, 0, 0, 0, 1]).all()
    )
def test_flatten_data_connector(self):
    ctx = ConnectorContext()

    c = FlattenDataAgentConnector(ctx)

    name, params = c.to_config()
    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, FlattenDataAgentConnector))

    d = AgentConnectorDataType(
        0,
        1,
        {
            SampleBatch.NEXT_OBS: {
                "sensor1": [[1, 1], [2, 2]],
                "sensor2": 8.8,
            },
            SampleBatch.REWARDS: 5.8,
            SampleBatch.ACTIONS: [[1, 1], [2]],
            SampleBatch.INFOS: {"random": "info"},
        },
    )

    flattened = c(d)
    self.assertEqual(len(flattened), 1)

    batch = flattened[0].data
    self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2, 8.8]).all())
    self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
    # Not flattened.
    self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
    self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")
def test_view_requirement_connector(self):
    view_requirements = {
        "obs": ViewRequirement(
            used_for_training=True, used_for_compute_actions=True
        ),
        "prev_actions": ViewRequirement(
            data_col="actions",
            shift=-1,
            used_for_training=True,
            used_for_compute_actions=True,
        ),
    }
    ctx = ConnectorContext(view_requirements=view_requirements)

    c = ViewRequirementAgentConnector(ctx)
    f = FlattenDataAgentConnector(ctx)

    d = AgentConnectorDataType(
        0,
        1,
        {
            SampleBatch.NEXT_OBS: {
                "sensor1": [[1, 1], [2, 2]],
                "sensor2": 8.8,
            },
            SampleBatch.ACTIONS: np.array(0),
        },
    )
    # ViewRequirementAgentConnector, then FlattenDataAgentConnector.
    processed = f(c([d]))
    self.assertTrue("obs" in processed[0].data.for_action)
    self.assertTrue("prev_actions" in processed[0].data.for_action)
def test_connector_pipeline(self):
    ctx = ConnectorContext()
    connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
    pipeline = AgentConnectorPipeline(ctx, connectors)
    name, params = pipeline.to_config()
    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, AgentConnectorPipeline))
    self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))
def test_connector_pipeline(self):
    ctx = ConnectorContext()
    connectors = [ConvertToNumpyConnector(ctx)]
    pipeline = ActionConnectorPipeline(ctx, connectors)
    name, params = pipeline.to_config()
    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, ActionConnectorPipeline))
    self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))
def restore_connectors_for_policy(
    policy: "Policy", connector_config: Tuple[str, Tuple[Any]]
) -> Connector:
    """Util to create a connector for a Policy based on a serialized config.

    Args:
        policy: Policy instance.
        connector_config: Serialized connector config.

    Returns:
        The restored Connector instance.
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    name, params = connector_config
    return get_connector(ctx, name, params)
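# A small usage sketch for the util above: serialize a policy's existing agent
# connector pipeline with to_config(), then rebuild an equivalent pipeline.
# This assumes policy.agent_connectors has already been populated (see
# create_connectors_for_policy below); the function name is illustrative.
def example_connector_roundtrip(policy: "Policy") -> Connector:
    connector_config = policy.agent_connectors.to_config()
    restored = restore_connectors_for_policy(policy, connector_config)
    assert isinstance(restored, AgentConnectorPipeline)
    return restored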
def test_normalize_action_connector(self):
    ctx = ConnectorContext(
        action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
    )
    c = NormalizeActionsConnector(ctx)
    name, params = c.to_config()
    self.assertEqual(name, "NormalizeActionsConnector")

    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, NormalizeActionsConnector))

    ac_data = ActionConnectorDataType(0, 1, (0.5, [], {}))
    normalized = c(ac_data)
    # 0.5 in the normalized range [-1, 1] unsquashes to
    # low + (0.5 + 1) / 2 * (high - low) = 0 + 0.75 * 6 = 4.5.
    self.assertEqual(normalized.output[0], 4.5)
def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict):
    """Util to create agent and action connectors for a Policy.

    Args:
        policy: Policy instance.
        config: Trainer config dict.
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
    policy.action_connectors = get_action_connectors_from_config(ctx, config)

    logger.info("Using connectors:")
    logger.info(policy.agent_connectors.__str__(indentation=4))
    logger.info(policy.action_connectors.__str__(indentation=4))
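# Hypothetical call site for the util above, e.g. right after a Policy is
# built during rollout-worker setup; the wrapper name is illustrative. After
# the call, sampling code can push env data through policy.agent_connectors
# and post-process raw model outputs through policy.action_connectors.
def setup_policy_connectors(policy: "Policy", config: TrainerConfigDict):
    create_connectors_for_policy(policy, config)
    # Custom connectors can then be grafted onto the default pipelines,
    # as the CartPole adaptation example above does with prepend()/append().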
def test_clip_action_connector(self):
    ctx = ConnectorContext(
        action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
    )
    c = ClipActionsConnector(ctx)
    name, params = c.to_config()
    self.assertEqual(name, "ClipActionsConnector")

    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, ClipActionsConnector))

    ac_data = ActionConnectorDataType(0, 1, (8.8, [], {}))
    clipped = c(ac_data)
    self.assertEqual(clipped.output[0], 6.0)
def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict):
    """Util to create agent and action connectors for a Policy.

    Args:
        policy: Policy instance.
        config: Trainer config dict.
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
    policy.action_connectors = get_action_connectors_from_trainer_config(ctx, config)

    print("Connectors enabled:")
    print(policy.agent_connectors.__str__(indentation=4))
    print(policy.action_connectors.__str__(indentation=4))
def test_immutable_action_connector(self):
    ctx = ConnectorContext(
        action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
    )
    c = ImmutableActionsConnector(ctx)
    name, params = c.to_config()
    self.assertEqual(name, "ImmutableActionsConnector")

    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, ImmutableActionsConnector))

    ac_data = ActionConnectorDataType(0, 1, (np.array([8.8]), [], {}))
    immutable = c(ac_data)

    with self.assertRaises(ValueError):
        immutable.output[0][0] = 5
def test_convert_to_numpy_connector(self):
    ctx = ConnectorContext()
    c = ConvertToNumpyConnector(ctx)
    name, params = c.to_config()
    self.assertEqual(name, "ConvertToNumpyConnector")

    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, ConvertToNumpyConnector))

    # Action connector payloads are (actions, state_outs, fetches) triples.
    action = torch.Tensor([8, 9])
    states = torch.Tensor([[1, 1, 1], [2, 2, 2]])
    ac_data = ActionConnectorDataType(0, 1, (action, states, {}))

    converted = c(ac_data)
    self.assertTrue(isinstance(converted.output[0], np.ndarray))
    self.assertTrue(isinstance(converted.output[1], np.ndarray))
def test_flatten_data_connector(self):
    ctx = ConnectorContext()

    c = FlattenDataAgentConnector(ctx)

    name, params = c.to_config()
    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, FlattenDataAgentConnector))

    sample_batch = {
        SampleBatch.NEXT_OBS: {
            "sensor1": [[1, 1], [2, 2]],
            "sensor2": 8.8,
        },
        SampleBatch.REWARDS: 5.8,
        SampleBatch.ACTIONS: [[1, 1], [2]],
        SampleBatch.INFOS: {"random": "info"},
    }

    d = AgentConnectorDataType(
        0,
        1,
        # FlattenDataAgentConnector does NOT touch the for_training dict,
        # so simply pass None here.
        AgentConnectorsOutput(None, sample_batch),
    )

    flattened = c([d])
    self.assertEqual(len(flattened), 1)

    batch = flattened[0].data.for_action
    self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2, 8.8]).all())
    self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
    # Not flattened.
    self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
    self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")
def test_clip_reward_connector(self):
    ctx = ConnectorContext()
    c = ClipRewardAgentConnector(ctx, limit=2.0)
    name, params = c.to_config()
    self.assertEqual(name, "ClipRewardAgentConnector")
    self.assertAlmostEqual(params["limit"], 2.0)

    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, ClipRewardAgentConnector))

    d = AgentConnectorDataType(
        0,
        1,
        {
            SampleBatch.REWARDS: 5.8,
        },
    )
    clipped = restored(ac_data=d)

    self.assertEqual(len(clipped), 1)
    self.assertEqual(clipped[0].data[SampleBatch.REWARDS], 2.0)
def test_unbatch_action_connector(self):
    ctx = ConnectorContext()
    c = UnbatchActionsConnector(ctx)
    name, params = c.to_config()
    self.assertEqual(name, "UnbatchActionsConnector")

    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, UnbatchActionsConnector))

    ac_data = ActionConnectorDataType(
        0,
        1,
        (
            {
                "a": np.array([1, 2, 3]),
                "b": (np.array([4, 5, 6]), np.array([7, 8, 9])),
            },
            [],
            {},
        ),
    )

    unbatched = c(ac_data)
    actions, _, _ = unbatched.output

    self.assertEqual(len(actions), 3)

    self.assertEqual(actions[0]["a"], 1)
    self.assertTrue((actions[0]["b"] == np.array((4, 7))).all())

    self.assertEqual(actions[1]["a"], 2)
    self.assertTrue((actions[1]["b"] == np.array((5, 8))).all())

    self.assertEqual(actions[2]["a"], 3)
    self.assertTrue((actions[2]["b"] == np.array((6, 9))).all())
def test_env_to_per_agent_data_connector(self):
    vrs = {
        "infos": ViewRequirement(
            "infos",
            used_for_training=True,
            used_for_compute_actions=False,
        )
    }
    ctx = ConnectorContext(view_requirements=vrs)

    c = EnvToAgentDataConnector(ctx)
    name, params = c.to_config()
    restored = get_connector(ctx, name, params)
    self.assertTrue(isinstance(restored, EnvToAgentDataConnector))

    d = AgentConnectorDataType(
        0,
        None,
        [
            # obs
            {1: [8, 8], 2: [9, 9]},
            # rewards
            {
                1: 8.8,
                2: 9.9,
            },
            # dones
            {
                1: False,
                2: False,
            },
            # infos
            {
                1: {"random": "info"},
                2: {},
            },
            # training_episode_info
            {
                1: {SampleBatch.DONES: True},
            },
        ],
    )
    per_agent = c(d)
    self.assertEqual(len(per_agent), 2)

    batch1 = per_agent[0].data
    self.assertEqual(batch1[SampleBatch.NEXT_OBS], [8, 8])
    self.assertTrue(batch1[SampleBatch.DONES])  # from training_episode_info
    self.assertTrue(SampleBatch.INFOS in batch1)
    self.assertEqual(batch1[SampleBatch.INFOS]["random"], "info")

    batch2 = per_agent[1].data
    self.assertEqual(batch2[SampleBatch.NEXT_OBS], [9, 9])
    self.assertFalse(batch2[SampleBatch.DONES])