Example #1
0
def run():
    """Restore a checkpointed policy and run it on ``MyCartPole`` for one episode.

    The policy identified by ``args.policy_id`` is loaded from
    ``args.checkpoint_file`` and adapted to the new environment:
    ``V2ToV1ObsAgentConnector`` is prepended to its agent connectors and
    ``V1ToV2ActionConnector`` is appended to its action connectors. The
    step number, observation and chosen action are printed at every step
    until the environment reports ``done``.
    """
    # Restore policy.
    policies = load_policies_from_checkpoint(args.checkpoint_file,
                                             [args.policy_id])
    policy = policies[args.policy_id]

    # Adapt policy trained for standard CartPole to the new env.
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx))
    policy.action_connectors.append(V1ToV2ActionConnector(ctx))

    # Run CartPole.
    env = MyCartPole()
    obs = env.reset()
    done = False
    step = 0
    while not done:
        step += 1

        # Use local_policy_inference() to easily run policy with observations.
        policy_outputs = local_policy_inference(policy, "env_1", "agent_1",
                                                obs)
        # Only one env id / agent id is queried, so exactly one output is
        # expected back.
        assert len(policy_outputs) == 1
        actions, _, _ = policy_outputs[0]
        print(f"step {step}", obs, actions)

        obs, _, done, _ = env.step(actions)
Example #2
0
    def test_obs_preprocessor_connector(self):
        """ObsPreprocessorConnector round-trips through its config and
        flattens a nested Dict/Tuple observation into one flat vector."""
        obs_space = gym.spaces.Dict(
            {
                "a": gym.spaces.Box(low=0, high=1, shape=(1,)),
                "b": gym.spaces.Tuple(
                    [gym.spaces.Discrete(2), gym.spaces.MultiDiscrete(nvec=[2, 3])]
                ),
            }
        )
        ctx = ConnectorContext(config={}, observation_space=obs_space)

        c = ObsPreprocessorConnector(ctx)
        name, params = c.to_config()

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ObsPreprocessorConnector)

        obs = obs_space.sample()
        # Fake deterministic data.
        obs["a"][0] = 0.5
        obs["b"] = (1, np.array([0, 2]))

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.OBS: obs,
            },
        )
        preprocessed = c([d])

        # obs is completely flattened. Expected layout:
        #   "a"                    -> [0.5]
        #   Discrete(2) value 1    -> [0, 1]
        #   MultiDiscrete([2, 3])  -> [1, 0] + [0, 0, 1] for value [0, 2]
        np.testing.assert_array_equal(
            preprocessed[0].data[SampleBatch.OBS], [0.5, 0, 1, 1, 0, 0, 0, 1]
        )
Example #3
0
    def test_flatten_data_connector(self):
        """FlattenDataAgentConnector round-trips through its config, flattens
        NEXT_OBS and REWARDS, and leaves ACTIONS and INFOS untouched."""
        ctx = ConnectorContext()

        c = FlattenDataAgentConnector(ctx)

        name, params = c.to_config()
        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, FlattenDataAgentConnector)

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.NEXT_OBS: {
                    "sensor1": [[1, 1], [2, 2]],
                    "sensor2": 8.8,
                },
                SampleBatch.REWARDS: 5.8,
                SampleBatch.ACTIONS: [[1, 1], [2]],
                SampleBatch.INFOS: {
                    "random": "info"
                },
            },
        )

        flattened = c(d)
        self.assertEqual(len(flattened), 1)

        batch = flattened[0].data
        self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2,
                                                         8.8]).all())
        # The scalar reward comes back indexable (wrapped).
        self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
        # Not flattened.
        self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
        self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")
Example #4
0
    def test_view_requirement_connector(self):
        """ViewRequirementAgentConnector followed by FlattenDataAgentConnector
        exposes every compute-action view requirement in ``for_action``."""
        view_requirements = {
            "obs":
            ViewRequirement(used_for_training=True,
                            used_for_compute_actions=True),
            "prev_actions":
            ViewRequirement(
                data_col="actions",
                shift=-1,
                used_for_training=True,
                used_for_compute_actions=True,
            ),
        }
        ctx = ConnectorContext(view_requirements=view_requirements)

        c = ViewRequirementAgentConnector(ctx)
        f = FlattenDataAgentConnector(ctx)

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.NEXT_OBS: {
                    "sensor1": [[1, 1], [2, 2]],
                    "sensor2": 8.8,
                },
                SampleBatch.ACTIONS: np.array(0),
            },
        )
        # ViewRequirementAgentConnector then FlattenAgentConnector.
        processed = f(c([d]))

        self.assertIn("obs", processed[0].data.for_action)
        self.assertIn("prev_actions", processed[0].data.for_action)
Example #5
0
 def test_connector_pipeline(self):
     """An AgentConnectorPipeline round-trips through its config and keeps
     its member connectors' types."""
     ctx = ConnectorContext()
     connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
     pipeline = AgentConnectorPipeline(ctx, connectors)
     name, params = pipeline.to_config()
     restored = get_connector(ctx, name, params)
     self.assertIsInstance(restored, AgentConnectorPipeline)
     self.assertIsInstance(restored.connectors[0], ClipRewardAgentConnector)
Example #6
0
 def test_connector_pipeline(self):
     """An ActionConnectorPipeline round-trips through its config and keeps
     its member connectors' types."""
     ctx = ConnectorContext()
     connectors = [ConvertToNumpyConnector(ctx)]
     pipeline = ActionConnectorPipeline(ctx, connectors)
     name, params = pipeline.to_config()
     restored = get_connector(ctx, name, params)
     self.assertIsInstance(restored, ActionConnectorPipeline)
     self.assertIsInstance(restored.connectors[0], ConvertToNumpyConnector)
Example #7
0
def restore_connectors_for_policy(
    policy: "Policy", connector_config: Tuple[str, Any]
) -> Connector:
    """Util to create connector for a Policy based on serialized config.

    Args:
        policy: Policy instance. A ConnectorContext is derived from it via
            ``ConnectorContext.from_policy`` and handed to the rebuilt
            connector.
        connector_config: Serialized connector config, i.e. the
            ``(name, params)`` pair returned by a connector's ``to_config()``.
            The shape of ``params`` depends on the connector, so it is typed
            as ``Any``.

    Returns:
        The connector instance reconstructed by ``get_connector``.
    """
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    name, params = connector_config
    return get_connector(ctx, name, params)
Example #8
0
    def test_normalize_action_connector(self):
        """NormalizeActionsConnector round-trips through its config and maps
        a normalized action back into the Box action space's range."""
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1]))
        c = NormalizeActionsConnector(ctx)

        name, params = c.to_config()
        self.assertEqual(name, "NormalizeActionsConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, NormalizeActionsConnector)

        ac_data = ActionConnectorDataType(0, 1, (0.5, [], {}))

        normalized = c(ac_data)
        # 0.5 in [-1, 1] maps to 0 + (0.5 + 1) * (6 - 0) / 2 = 4.5 in [0, 6].
        self.assertEqual(normalized.output[0], 4.5)
Example #9
0
def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict):
    """Build agent and action connector pipelines and attach them to a Policy.

    The pipelines are created from ``config`` using a ConnectorContext
    derived from ``policy``, assigned to ``policy.agent_connectors`` and
    ``policy.action_connectors``, and then logged at INFO level.

    Args:
        policy: Policy instance that receives the connectors.
        config: Trainer config dict.
    """
    context: ConnectorContext = ConnectorContext.from_policy(policy)

    policy.agent_connectors = get_agent_connectors_from_config(context, config)
    policy.action_connectors = get_action_connectors_from_config(context, config)

    logger.info("Using connectors:")
    # Agent pipeline first, then action pipeline, each rendered indented.
    for pipeline in (policy.agent_connectors, policy.action_connectors):
        logger.info(pipeline.__str__(indentation=4))
Example #10
0
    def test_clip_action_connector(self):
        """ClipActionsConnector round-trips through its config and clips an
        out-of-range action to the Box action space's upper bound."""
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1]))
        c = ClipActionsConnector(ctx)

        name, params = c.to_config()
        self.assertEqual(name, "ClipActionsConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ClipActionsConnector)

        ac_data = ActionConnectorDataType(0, 1, (8.8, [], {}))

        clipped = c(ac_data)
        # 8.8 exceeds high=6.0, so it is clipped down to 6.0.
        self.assertEqual(clipped.output[0], 6.0)
Example #11
0
def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict):
    """Util to create agent and action connectors for a Policy.

    The created pipelines are assigned to ``policy.agent_connectors`` and
    ``policy.action_connectors`` and then reported via ``logging`` at INFO
    level.

    Args:
        policy: Policy instance.
        config: Trainer config dict.
    """
    import logging

    ctx: ConnectorContext = ConnectorContext.from_policy(policy)

    policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
    policy.action_connectors = get_action_connectors_from_trainer_config(
        ctx, config)

    # Library code should not write straight to stdout; route the report
    # through logging so applications can filter or redirect it.
    logger = logging.getLogger(__name__)
    logger.info("Connectors enabled:")
    logger.info(policy.agent_connectors.__str__(indentation=4))
    logger.info(policy.action_connectors.__str__(indentation=4))
Example #12
0
    def test_immutable_action_connector(self):
        """ImmutableActionsConnector round-trips through its config and makes
        the output action array reject in-place writes."""
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1]))
        c = ImmutableActionsConnector(ctx)

        name, params = c.to_config()
        self.assertEqual(name, "ImmutableActionsConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ImmutableActionsConnector)

        ac_data = ActionConnectorDataType(0, 1, (np.array([8.8]), [], {}))

        immutable = c(ac_data)

        # Writing into a read-only numpy array raises ValueError.
        with self.assertRaises(ValueError):
            immutable.output[0][0] = 5
Example #13
0
    def test_convert_to_numpy_connector(self):
        """ConvertToNumpyConnector round-trips through its config and converts
        torch action and state tensors into numpy arrays."""
        ctx = ConnectorContext()
        c = ConvertToNumpyConnector(ctx)

        name, params = c.to_config()

        self.assertEqual(name, "ConvertToNumpyConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ConvertToNumpyConnector)

        action = torch.Tensor([8, 9])
        states = torch.Tensor([[1, 1, 1], [2, 2, 2]])
        ac_data = ActionConnectorDataType(0, 1, (action, states, {}))

        converted = c(ac_data)
        self.assertIsInstance(converted.output[0], np.ndarray)
        self.assertIsInstance(converted.output[1], np.ndarray)
Example #14
0
    def test_flatten_data_connector(self):
        """FlattenDataAgentConnector round-trips through its config and
        flattens the ``for_action`` batch: NEXT_OBS and REWARDS are flattened,
        while ACTIONS and INFOS are left as-is."""
        ctx = ConnectorContext()

        c = FlattenDataAgentConnector(ctx)

        name, params = c.to_config()
        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, FlattenDataAgentConnector)

        sample_batch = {
            SampleBatch.NEXT_OBS: {
                "sensor1": [[1, 1], [2, 2]],
                "sensor2": 8.8,
            },
            SampleBatch.REWARDS: 5.8,
            SampleBatch.ACTIONS: [[1, 1], [2]],
            SampleBatch.INFOS: {
                "random": "info"
            },
        }

        d = AgentConnectorDataType(
            0,
            1,
            # FlattenDataAgentConnector does NOT touch for_training dict,
            # so simply pass None here.
            AgentConnectorsOutput(None, sample_batch),
        )

        flattened = c([d])
        self.assertEqual(len(flattened), 1)

        batch = flattened[0].data.for_action
        self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2,
                                                         8.8]).all())
        # The scalar reward comes back indexable (wrapped).
        self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
        # Not flattened.
        self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
        self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")
Example #15
0
    def test_clip_reward_connector(self):
        """ClipRewardAgentConnector serializes its ``limit``, round-trips
        through its config, and the restored connector clips rewards."""
        ctx = ConnectorContext()

        c = ClipRewardAgentConnector(ctx, limit=2.0)
        name, params = c.to_config()

        self.assertEqual(name, "ClipRewardAgentConnector")
        self.assertAlmostEqual(params["limit"], 2.0)

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ClipRewardAgentConnector)

        d = AgentConnectorDataType(
            0,
            1,
            {
                SampleBatch.REWARDS: 5.8,
            },
        )
        # Exercise the *restored* connector so the round trip is verified.
        clipped = restored(ac_data=d)

        self.assertEqual(len(clipped), 1)
        self.assertEqual(clipped[0].data[SampleBatch.REWARDS], 2.0)
Example #16
0
    def test_unbatch_action_connector(self):
        """UnbatchActionsConnector round-trips through its config and splits
        a batched, nested action structure into one action per batch item."""
        ctx = ConnectorContext()
        c = UnbatchActionsConnector(ctx)

        name, params = c.to_config()

        self.assertEqual(name, "UnbatchActionsConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, UnbatchActionsConnector)

        ac_data = ActionConnectorDataType(
            0,
            1,
            (
                {
                    "a": np.array([1, 2, 3]),
                    "b": (np.array([4, 5, 6]), np.array([7, 8, 9])),
                },
                [],
                {},
            ),
        )

        unbatched = c(ac_data)
        actions, _, _ = unbatched.output

        self.assertEqual(len(actions), 3)
        # Item i holds the i-th element of every leaf: "a" from its array,
        # "b" from the i-th element of each of the two tuple members.
        expected = [(1, (4, 7)), (2, (5, 8)), (3, (6, 9))]
        for i, (expected_a, expected_b) in enumerate(expected):
            with self.subTest(i=i):
                self.assertEqual(actions[i]["a"], expected_a)
                self.assertTrue(
                    (actions[i]["b"] == np.array(expected_b)).all())
Example #17
0
    def test_env_to_per_agent_data_connector(self):
        """EnvToAgentDataConnector round-trips through its config and splits
        multi-agent env data (obs, rewards, dones, infos, training episode
        info) into one data item per agent."""
        vrs = {
            "infos":
            ViewRequirement(
                "infos",
                used_for_training=True,
                used_for_compute_actions=False,
            )
        }
        ctx = ConnectorContext(view_requirements=vrs)

        c = EnvToAgentDataConnector(ctx)

        name, params = c.to_config()
        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, EnvToAgentDataConnector)

        d = AgentConnectorDataType(
            0,
            None,
            [
                # obs
                {
                    1: [8, 8],
                    2: [9, 9]
                },
                # rewards
                {
                    1: 8.8,
                    2: 9.9,
                },
                # dones
                {
                    1: False,
                    2: False,
                },
                # infos
                {
                    1: {
                        "random": "info"
                    },
                    2: {},
                },
                # training_episode_info
                {
                    1: {
                        SampleBatch.DONES: True
                    },
                },
            ],
        )
        per_agent = c(d)

        self.assertEqual(len(per_agent), 2)

        batch1 = per_agent[0].data
        self.assertEqual(batch1[SampleBatch.NEXT_OBS], [8, 8])
        # DONES is True here even though the env's dones dict says False:
        # the value comes from training_episode_info.
        self.assertTrue(
            batch1[SampleBatch.DONES])  # from training_episode_info
        self.assertIn(SampleBatch.INFOS, batch1)
        self.assertEqual(batch1[SampleBatch.INFOS]["random"], "info")

        batch2 = per_agent[1].data
        self.assertEqual(batch2[SampleBatch.NEXT_OBS], [9, 9])
        self.assertFalse(batch2[SampleBatch.DONES])