def test_timesteps(self):
    """Check that `global_timestep` advances correctly across action calls.

    Runs single-action and batched action computations, verifying the
    policy's global timestep after each call, then jumps the counter to a
    huge value and confirms counting continues from there, and that the
    algorithm can still train.
    """
    config = pg.DEFAULT_CONFIG.copy()
    # Run everything locally inside the algorithm process.
    config["num_workers"] = 0
    # Tiny model keeps the test fast.
    config["model"]["fcnet_hiddens"] = [1]
    config["model"]["fcnet_activation"] = None

    single_obs = np.array(1)
    batched_obs = np.array([1])

    for _ in framework_iterator(config):
        trainer = pg.PG(config=config, env=RandomEnv)
        policy = trainer.get_policy()

        # 20 single-action calls -> timestep counts 1..20.
        for expected_ts in range(1, 21):
            trainer.compute_single_action(single_obs)
            check(policy.global_timestep, expected_ts)

        # 20 batched calls (batch size 1) -> timestep counts 21..40.
        for step in range(1, 21):
            policy.compute_actions(batched_obs)
            check(policy.global_timestep, step + 20)

        # Artificially set ts to 100Bio, then keep computing actions and
        # train.
        huge_ts = int(1e11)
        policy.on_global_var_update({"timestep": huge_ts})

        # Run for 10 more ts.
        for step in range(1, 11):
            policy.compute_actions(batched_obs)
            check(policy.global_timestep, step + huge_ts)

        trainer.train()
def test_bad_envs(self):
    """Tests different "bad env" errors.

    Each invalid env string must cause `pg.PG(...)` construction to raise
    an `EnvError` whose message names the offending string. The four cases
    were previously four copy-pasted stanzas; they are now driven from a
    single list so adding a new case is a one-line change.
    """
    config = pg.DEFAULT_CONFIG.copy()
    config["num_workers"] = 0

    bad_env_strings = [
        # Non existing/non-registered gym env string.
        "Alien-Attack-v42",
        # Malformed gym env string (must have v\d at end).
        "Alien-Attack-part-42",
        # Non-existing class in a full-class-path.
        "ray.rllib.examples.env.random_env.RandomEnvThatDoesntExist",
        # Non-existing module inside a full-class-path.
        "ray.rllib.examples.env.module_that_doesnt_exist.SomeEnv",
    ]

    for env in bad_env_strings:
        for _ in framework_iterator(config):
            self.assertRaisesRegex(
                EnvError,
                f"The env string you provided \\('{env}'\\) is",
                # Bind `env` as a default arg so the callable does not
                # depend on the loop variable at call time.
                lambda env=env: pg.PG(config=config, env=env),
            )
def test_validate_config_idempotent(self):
    """Validating COMMON_CONFIG repeatedly must not keep mutating it."""
    base_config = copy.deepcopy(COMMON_CONFIG)
    algo = pg.PG(env="CartPole-v0", config=base_config)

    # First pass: the deprecated, static `Algorithm._validate_config()`.
    algo._validate_config(base_config, algo)
    snapshot_after_static = copy.deepcopy(base_config)

    # Second pass: the new, non-static `Algorithm.validate_config()`.
    algo.validate_config(base_config)
    snapshot_after_instance = copy.deepcopy(base_config)

    # Both snapshots must be identical, i.e. the second validation pass
    # changed nothing.
    self.assertEqual(snapshot_after_static, snapshot_after_instance)
    algo.stop()
def test_add_delete_policy(self):
    """Tests dynamically adding/removing policies on a multi-agent Algorithm.

    Starts with a single policy ("p0"), then repeatedly adds new policies
    ("p1", "p2") via `add_policy`, checkpoints, restores into a fresh
    Algorithm, and finally removes the added policies again via
    `remove_policy`.
    """
    config = pg.DEFAULT_CONFIG.copy()
    config.update({
        "env": MultiAgentCartPole,
        "env_config": {
            "config": {
                "num_agents": 4,
            },
        },
        "num_workers": 2,  # Test on remote workers as well.
        "num_cpus_per_worker": 0.1,
        # Tiny model to keep the test cheap.
        "model": {
            "fcnet_hiddens": [5],
            "fcnet_activation": "linear",
        },
        "train_batch_size": 100,
        "rollout_fragment_length": 50,
        "multiagent": {
            # Start with a single policy.
            "policies": {"p0"},
            "policy_mapping_fn": lambda aid, eps, worker, **kwargs: "p0",
            # And only two policies that can be stored in memory at a
            # time.
            "policy_map_capacity": 2,
        },
        "evaluation_num_workers": 1,
        "evaluation_config": {
            "num_cpus_per_worker": 0.1,
        },
    })

    for _ in framework_iterator(config):
        algo = pg.PG(config=config)
        pol0 = algo.get_policy("p0")
        r = algo.train()
        # The initial policy must show up in the learner stats.
        self.assertTrue("p0" in r["info"][LEARNER_INFO])

        for i in range(1, 3):
            # NOTE(review): this closure reads the loop variable `i` at
            # call time (late binding). It is re-installed on every
            # iteration, so within an iteration it maps agents to the two
            # newest policies — confirm no stale copy survives across
            # iterations on async workers.
            def new_mapping_fn(agent_id, episode, worker, **kwargs):
                return f"p{choice([i, i - 1])}"

            # Add a new policy.
            pid = f"p{i}"
            new_pol = algo.add_policy(
                pid,
                algo.get_default_policy_class(config),
                # Test changing the mapping fn.
                policy_mapping_fn=new_mapping_fn,
                # Change the list of policies to train.
                policies_to_train=[f"p{i}", f"p{i-1}"],
            )
            pol_map = algo.workers.local_worker().policy_map
            # The returned policy object must be a new one, and the local
            # policy map must now contain all policies added so far.
            self.assertTrue(new_pol is not pol0)
            for j in range(i + 1):
                self.assertTrue(f"p{j}" in pol_map)
            self.assertTrue(len(pol_map) == i + 1)
            algo.train()
            checkpoint = algo.save()

            # Test restoring from the checkpoint (which has more policies
            # than what's defined in the config dict).
            test = pg.PG(config=config)
            test.restore(checkpoint)

            # Make sure evaluation worker also gets the restored policy.
            def _has_policy(w):
                return w.get_policy("p0") is not None

            self.assertTrue(
                all(test.evaluation_workers.foreach_worker(_has_policy)))

            # Make sure algorithm can continue training the restored policy.
            pol0 = test.get_policy("p0")
            test.train()
            # Test creating an action with the added (and restored) policy.
            a = test.compute_single_action(np.zeros_like(
                pol0.observation_space.sample()), policy_id=pid)
            self.assertTrue(pol0.action_space.contains(a))
            test.stop()

        # Delete all added policies again from Algorithm.
        for i in range(2, 0, -1):
            algo.remove_policy(
                f"p{i}",
                # Note that the complete signature of a policy_mapping_fn
                # is: `agent_id, episode, worker, **kwargs`.
                # NOTE(review): `i` is captured late here as well; the
                # lambda is replaced each iteration, so at call time it
                # should point at the newest remaining policy — verify.
                policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i - 1}",
                policies_to_train=[f"p{i - 1}"],
            )
        algo.stop()