def test_pg_compilation(self):
    """Test whether PG can be built with all frameworks."""
    config = pg.PGConfig()
    # Test with filter to see whether they work w/o preprocessing.
    config.rollouts(
        num_rollout_workers=1,
        rollout_fragment_length=500,
        observation_filter="MeanStdFilter",
    )
    num_iterations = 1

    image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
    simple_space = Box(-1.0, 1.0, shape=(3,))

    tune.register_env(
        "random_dict_env",
        lambda _: RandomEnv(
            {
                "observation_space": Dict(
                    {
                        "a": simple_space,
                        "b": Discrete(2),
                        "c": image_space,
                    }
                ),
                "action_space": Box(-1.0, 1.0, shape=(1,)),
            }
        ),
    )
    tune.register_env(
        "random_tuple_env",
        lambda _: RandomEnv(
            {
                "observation_space": Tuple(
                    [simple_space, Discrete(2), image_space]
                ),
                "action_space": Box(-1.0, 1.0, shape=(1,)),
            }
        ),
    )

    for _ in framework_iterator(config, with_eager_tracing=True):
        # Test for different env types (discrete w/ and w/o image, + cont).
        for env in [
            "random_dict_env",
            "random_tuple_env",
            "MsPacmanNoFrameskip-v4",
            "CartPole-v0",
            "FrozenLake-v1",
        ]:
            print(f"env={env}")
            trainer = config.build(env=env)
            for i in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)

            check_compute_single_action(
                trainer, include_prev_action_reward=True
            )
def test_space_inference_from_remote_workers(self):
    # Expect to not do space inference if the learner has an env.
    env = gym.make("CartPole-v0")

    config = (
        pg.PGConfig()
        .rollouts(
            num_rollout_workers=1,
            validate_workers_after_construction=False,
        )
        .environment(env="CartPole-v0")
    )

    # No env on driver -> expect longer build time due to space
    # lookup from remote worker.
    t0 = time.time()
    trainer = config.build()
    w_lookup = time.time() - t0
    print(f"No env on learner: {w_lookup}sec")
    trainer.stop()

    # Env on driver -> expect shorter build time due to no space
    # lookup required from remote worker.
    config.create_env_on_local_worker = True
    t0 = time.time()
    trainer = config.build()
    wo_lookup = time.time() - t0
    print(f"Env on learner: {wo_lookup}sec")
    self.assertLess(wo_lookup, w_lookup)
    trainer.stop()

    # Spaces given -> expect shorter build time due to no space
    # lookup required from remote worker. Switch the local-worker env
    # off again so the spaces must come from the config.
    config.create_env_on_local_worker = False
    config.environment(
        observation_space=env.observation_space,
        action_space=env.action_space,
    )
    t0 = time.time()
    trainer = config.build()
    wo_lookup = time.time() - t0
    print(f"Spaces given manually in config: {wo_lookup}sec")
    self.assertLess(wo_lookup, w_lookup)
    trainer.stop()
def test_eval_workers_on_infinite_episodes(self):
    """Tests whether eval workers warn appropriately after some episode timeout."""
    # Create infinitely running episodes, but with horizon setting (RLlib
    # will auto-terminate the episode). However, in the eval workers, don't
    # set a horizon -> Expect warning and no proper evaluation results.
    config = (
        pg.PGConfig()
        .rollouts(num_rollout_workers=2, horizon=100)
        .reporting(metrics_episode_collection_timeout_s=5.0)
        .environment(env=RandomEnv, env_config={"p_done": 0.0})
        .evaluation(
            evaluation_num_workers=2,
            evaluation_interval=1,
            evaluation_sample_timeout_s=5.0,
            evaluation_config={
                "horizon": None,
            },
        )
    )
    algo = config.build()
    results = algo.train()
    self.assertTrue(np.isnan(results["evaluation"]["episode_reward_mean"]))
def test_worker_validation_time(self):
    """Tests the time taken by `validate_workers_after_construction=True`."""
    config = pg.PGConfig().environment(env="CartPole-v0")
    config.validate_workers_after_construction = True

    # Test whether validating one worker takes just as long as validating
    # many (>> 1) workers, i.e. whether validation runs across workers in
    # parallel rather than sequentially.
    config.num_workers = 1
    t0 = time.time()
    trainer = config.build()
    total_time_1 = time.time() - t0
    print(f"Validating w/ 1 worker: {total_time_1}sec")
    trainer.stop()

    config.num_workers = 5
    t0 = time.time()
    trainer = config.build()
    total_time_5 = time.time() - t0
    print(f"Validating w/ 5 workers: {total_time_5}sec")
    trainer.stop()

    # The ratio of the two build times should be close to 1.0 (+/- 1.0).
    check(total_time_5 / total_time_1, 1.0, atol=1.0)
def test_pg_loss_functions(self):
    """Tests the PG loss function math."""
    config = (
        pg.PGConfig()
        .rollouts(num_rollout_workers=0)
        .training(
            gamma=0.99,
            model={
                "fcnet_hiddens": [10],
                "fcnet_activation": "linear",
            },
        )
    )

    # Fake CartPole episode of n time steps.
    train_batch = SampleBatch(
        {
            SampleBatch.OBS: np.array(
                [
                    [0.1, 0.2, 0.3, 0.4],
                    [0.5, 0.6, 0.7, 0.8],
                    [0.9, 1.0, 1.1, 1.2],
                ]
            ),
            SampleBatch.ACTIONS: np.array([0, 1, 1]),
            SampleBatch.REWARDS: np.array([1.0, 1.0, 1.0]),
            SampleBatch.DONES: np.array([False, False, True]),
            SampleBatch.EPS_ID: np.array([1234, 1234, 1234]),
            SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
        }
    )

    for fw, sess in framework_iterator(config, session=True):
        dist_cls = Categorical if fw != "torch" else TorchCategorical
        trainer = config.build(env="CartPole-v0")
        policy = trainer.get_policy()
        vars = policy.model.trainable_variables()
        if sess:
            vars = policy.get_session().run(vars)

        # Post-process (calculate simple (non-GAE) advantages) and attach
        # to train_batch dict.
        # A = [0.99^2 * 1.0 + 0.99 * 1.0 + 1.0, 0.99 * 1.0 + 1.0, 1.0] =
        # [2.9701, 1.99, 1.0]
        train_batch_ = pg.post_process_advantages(policy, train_batch.copy())
        if fw == "torch":
            train_batch_ = policy._lazy_tensor_dict(train_batch_)

        # Check Advantage values.
        check(train_batch_[Postprocessing.ADVANTAGES], [2.9701, 1.99, 1.0])

        # Actual loss results.
        if sess:
            results = policy.get_session().run(
                policy._loss,
                feed_dict=policy._get_loss_inputs_dict(
                    train_batch_, shuffle=False
                ),
            )
        else:
            results = policy.loss(
                policy.model, dist_class=dist_cls, train_batch=train_batch_
            )

        # Calculate expected results.
        if fw != "torch":
            expected_logits = fc(
                fc(train_batch_[SampleBatch.OBS], vars[0], vars[1], framework=fw),
                vars[2],
                vars[3],
                framework=fw,
            )
        else:
            expected_logits = fc(
                fc(train_batch_[SampleBatch.OBS], vars[2], vars[3], framework=fw),
                vars[0],
                vars[1],
                framework=fw,
            )
        expected_logp = dist_cls(expected_logits, policy.model).logp(
            train_batch_[SampleBatch.ACTIONS]
        )
        adv = train_batch_[Postprocessing.ADVANTAGES]
        if sess:
            expected_logp = sess.run(expected_logp)
        elif fw == "torch":
            expected_logp = expected_logp.detach().cpu().numpy()
            adv = adv.detach().cpu().numpy()
        else:
            expected_logp = expected_logp.numpy()
        expected_loss = -np.mean(expected_logp * adv)
        check(results, expected_loss, decimals=4)
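
# A minimal standalone sketch (plain NumPy; the helper name is illustrative,
# not an RLlib API) re-deriving the advantage values hard-coded in
# `test_pg_loss_functions` above: with gamma=0.99, rewards [1.0, 1.0, 1.0],
# and no value-function baseline, the simple (non-GAE) discounted returns
# are [2.9701, 1.99, 1.0], and the PG loss is -mean(logp * advantages).
def _expected_discounted_returns(rewards, gamma=0.99):
    """Computes R_t = r_t + gamma * R_{t+1} by iterating backwards in time."""
    returns = np.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# _expected_discounted_returns(np.array([1.0, 1.0, 1.0]))
# -> array([2.9701, 1.99  , 1.    ])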
def _import_pg():
    import ray.rllib.algorithms.pg as pg

    return pg.PG, pg.PGConfig().to_dict()
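
# Usage sketch (the mapping name and lookup below are illustrative of the
# lazy-import registry pattern, not a guaranteed RLlib API): loaders like
# `_import_pg` are typically kept in a name -> callable mapping so that heavy
# algorithm modules are only imported when first requested:
#
#   ALGORITHMS = {"PG": _import_pg}
#   algo_cls, default_config = ALGORITHMS["PG"]()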