Example #1
    def test_timesteps(self):
        """Test whether PG can be built with both frameworks."""
        config = pg.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0  # Run locally.
        config["model"]["fcnet_hiddens"] = [1]
        config["model"]["fcnet_activation"] = None

        obs = np.array(1)
        obs_batch = np.array([1])

        for _ in framework_iterator(config):
            trainer = pg.PG(config=config, env=RandomEnv)
            policy = trainer.get_policy()

            for i in range(1, 21):
                trainer.compute_single_action(obs)
                check(policy.global_timestep, i)
            for i in range(1, 21):
                policy.compute_actions(obs_batch)
                check(policy.global_timestep, i + 20)

            # Artificially set the timestep counter to 100 billion, then keep
            # computing actions and training.
            crazy_timesteps = int(1e11)
            policy.on_global_var_update({"timestep": crazy_timesteps})
            # Run for 10 more ts.
            for i in range(1, 11):
                policy.compute_actions(obs_batch)
                check(policy.global_timestep, i + crazy_timesteps)
            trainer.train()
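
The snippets on this page are test methods lifted out of their class, so they rely on a few RLlib test helpers. Below is a minimal sketch of the imports Example #1 assumes; the exact module paths differ between Ray versions (older releases use ray.rllib.agents.pg, newer ones ray.rllib.algorithms.pg), so treat them as illustrative.

    import numpy as np

    # Paths below follow the Ray 2.x layout; adjust for your Ray version.
    from ray.rllib.algorithms import pg                      # PG + DEFAULT_CONFIG
    from ray.rllib.examples.env.random_env import RandomEnv  # dummy env used above
    from ray.rllib.utils.test_utils import check, framework_iterator
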
Example #2
    def test_bad_envs(self):
        """Tests different "bad env" errors."""
        config = pg.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0

        # Non-existing/non-registered gym env string.
        env = "Alien-Attack-v42"
        for _ in framework_iterator(config):
            self.assertRaisesRegex(
                EnvError,
                f"The env string you provided \\('{env}'\\) is",
                lambda: pg.PG(config=config, env=env),
            )

        # Malformed gym env string (must have v\d at end).
        env = "Alien-Attack-part-42"
        for _ in framework_iterator(config):
            self.assertRaisesRegex(
                EnvError,
                f"The env string you provided \\('{env}'\\) is",
                lambda: pg.PG(config=config, env=env),
            )

        # Non-existing class in a full-class-path.
        env = "ray.rllib.examples.env.random_env.RandomEnvThatDoesntExist"
        for _ in framework_iterator(config):
            self.assertRaisesRegex(
                EnvError,
                f"The env string you provided \\('{env}'\\) is",
                lambda: pg.PG(config=config, env=env),
            )

        # Non-existing module inside a full-class-path.
        env = "ray.rllib.examples.env.module_that_doesnt_exist.SomeEnv"
        for _ in framework_iterator(config):
            self.assertRaisesRegex(
                EnvError,
                f"The env string you provided \\('{env}'\\) is",
                lambda: pg.PG(config=config, env=env),
            )
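
For contrast, an env registered under a custom string resolves without raising an EnvError. A minimal sketch, assuming the RandomEnv example env and register_env from ray.tune.registry:

    from ray.tune.registry import register_env
    from ray.rllib.examples.env.random_env import RandomEnv

    # Register the env under a custom name, then refer to it by that string.
    register_env("my_random_env", lambda env_config: RandomEnv(env_config))
    trainer = pg.PG(config=config, env="my_random_env")
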
Example #3
    def test_validate_config_idempotent(self):
        """
        Asserts that validate_config run multiple
        times on COMMON_CONFIG will be idempotent
        """
        # Given:
        standard_config = copy.deepcopy(COMMON_CONFIG)
        algo = pg.PG(env="CartPole-v0", config=standard_config)

        # When (we validate config 2 times).
        # Try deprecated `Algorithm._validate_config()` method (static).
        algo._validate_config(standard_config, algo)
        config_v1 = copy.deepcopy(standard_config)
        # Try new method: `Algorithm.validate_config()` (non-static).
        algo.validate_config(standard_config)
        config_v2 = copy.deepcopy(standard_config)

        # Make sure nothing changed.
        self.assertEqual(config_v1, config_v2)

        algo.stop()
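
The same check can be written as a small reusable helper: run the validator once so it can fill in defaults, then assert that any further passes leave the config untouched. This is a sketch of the test's logic, not an RLlib API:

    import copy

    def assert_validation_idempotent(validate, config, n=3):
        validate(config)                   # first pass may mutate config in place
        settled = copy.deepcopy(config)
        for _ in range(n):
            validate(config)               # later passes must be no-ops
            assert config == settled
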
Example #4
    def test_add_delete_policy(self):
        """Tests adding, restoring, and removing policies on a running algorithm."""
        config = pg.DEFAULT_CONFIG.copy()
        config.update({
            "env": MultiAgentCartPole,
            "env_config": {
                "config": {
                    "num_agents": 4,
                },
            },
            "num_workers": 2,  # Test on remote workers as well.
            "num_cpus_per_worker": 0.1,
            "model": {
                "fcnet_hiddens": [5],
                "fcnet_activation": "linear",
            },
            "train_batch_size": 100,
            "rollout_fragment_length": 50,
            "multiagent": {
                # Start with a single policy.
                "policies": {"p0"},
                "policy_mapping_fn": lambda aid, eps, worker, **kwargs: "p0",
                # Only two policies can be held in memory at a time.
                "policy_map_capacity": 2,
            },
            "evaluation_num_workers": 1,
            "evaluation_config": {
                "num_cpus_per_worker": 0.1,
            },
        })

        for _ in framework_iterator(config):
            algo = pg.PG(config=config)
            pol0 = algo.get_policy("p0")
            r = algo.train()
            self.assertTrue("p0" in r["info"][LEARNER_INFO])
            for i in range(1, 3):

                def new_mapping_fn(agent_id, episode, worker, **kwargs):
                    return f"p{choice([i, i - 1])}"

                # Add a new policy.
                pid = f"p{i}"
                new_pol = algo.add_policy(
                    pid,
                    algo.get_default_policy_class(config),
                    # Test changing the mapping fn.
                    policy_mapping_fn=new_mapping_fn,
                    # Change the list of policies to train.
                    policies_to_train=[f"p{i}", f"p{i-1}"],
                )
                pol_map = algo.workers.local_worker().policy_map
                self.assertTrue(new_pol is not pol0)
                for j in range(i + 1):
                    self.assertTrue(f"p{j}" in pol_map)
                self.assertTrue(len(pol_map) == i + 1)
                algo.train()
                checkpoint = algo.save()

                # Test restoring from the checkpoint (which has more policies
                # than what's defined in the config dict).
                test = pg.PG(config=config)
                test.restore(checkpoint)

                # Make sure evaluation worker also gets the restored policy.
                def _has_policy(w):
                    return w.get_policy("p0") is not None

                self.assertTrue(
                    all(test.evaluation_workers.foreach_worker(_has_policy)))

                # Make sure algorithm can continue training the restored policy.
                pol0 = test.get_policy("p0")
                test.train()
                # Test creating an action with the added (and restored) policy.
                a = test.compute_single_action(
                    np.zeros_like(pol0.observation_space.sample()),
                    policy_id=pid,
                )
                self.assertTrue(pol0.action_space.contains(a))
                test.stop()

            # Delete all added policies again from Algorithm.
            for i in range(2, 0, -1):
                algo.remove_policy(
                    f"p{i}",
                    # Note that the complete signature of a policy_mapping_fn
                    # is: `agent_id, episode, worker, **kwargs`.
                    policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i - 1}",
                    policies_to_train=[f"p{i - 1}"],
                )

            algo.stop()
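
All of the examples above assume a running Ray instance and the usual unittest scaffolding around the test methods. A minimal sketch of that wrapper, with illustrative class and file names:

    import sys
    import unittest

    import ray

    class TestPGExamples(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            ray.init()

        @classmethod
        def tearDownClass(cls):
            ray.shutdown()

        # Paste the test methods from the examples above here.

    if __name__ == "__main__":
        import pytest
        sys.exit(pytest.main(["-v", __file__]))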