def test_appo_compilation(self):
    """Test whether APPO can be built with both frameworks."""
    config = appo.APPOConfig().rollouts(num_rollout_workers=1)
    num_iterations = 2

    for _ in framework_iterator(config, with_eager_tracing=True):
        print("w/o v-trace")
        config.vtrace = False
        algo = config.build(env="CartPole-v0")
        for i in range(num_iterations):
            results = algo.train()
            check_train_results(results)
            print(results)
        check_compute_single_action(algo)
        algo.stop()

        print("w/ v-trace")
        config.vtrace = True
        algo = config.build(env="CartPole-v0")
        for i in range(num_iterations):
            results = algo.train()
            check_train_results(results)
            print(results)
        check_compute_single_action(algo)
        algo.stop()
def test_ddpg_compilation(self):
    """Test whether a DDPGTrainer can be built with both frameworks."""
    config = ddpg.DDPGConfig()
    config.seed = 42
    config.num_workers = 0
    config.num_envs_per_worker = 2
    config.replay_buffer_config["learning_starts"] = 0
    # Build the updated exploration config explicitly (`dict.update()`
    # mutates in place and returns None, so its return value must not be
    # passed on).
    explore = dict(config.exploration_config, random_timesteps=100)
    config.exploration(exploration_config=explore)
    num_iterations = 1

    # Test against all frameworks.
    for _ in framework_iterator(config, with_eager_tracing=True):
        trainer = config.build(env="Pendulum-v1")
        for i in range(num_iterations):
            results = trainer.train()
            check_train_results(results)
            print(results)
        check_compute_single_action(trainer)

        # Ensure apply_gradient_fn is being called and updating global_step.
        pol = trainer.get_policy()
        if config.framework_str == "tf":
            a = pol.get_session().run(pol.global_step)
        else:
            a = pol.global_step
        check(a, 500)

        trainer.stop()
def test_dqn_compilation(self):
    """Test whether a DQNTrainer can be built on all frameworks."""
    num_iterations = 1
    config = dqn.dqn.DQNConfig().rollouts(num_rollout_workers=2)

    for _ in framework_iterator(config, with_eager_tracing=True):
        # Double-dueling DQN.
        print("Double-dueling")
        plain_config = deepcopy(config)
        trainer = dqn.DQNTrainer(config=plain_config, env="CartPole-v0")
        for i in range(num_iterations):
            results = trainer.train()
            check_train_results(results)
            print(results)

        check_compute_single_action(trainer)
        trainer.stop()

        # Rainbow.
        print("Rainbow")
        rainbow_config = deepcopy(config).training(
            num_atoms=10, noisy=True, double_q=True, dueling=True, n_step=5
        )
        trainer = dqn.DQNTrainer(config=rainbow_config, env="CartPole-v0")
        for i in range(num_iterations):
            results = trainer.train()
            check_train_results(results)
            print(results)

        check_compute_single_action(trainer)
        trainer.stop()
def test_ars_compilation(self):
    """Test whether an ARSTrainer can be built on all frameworks."""
    config = ars.ARSConfig()
    # Keep it simple.
    config.training(
        model={
            "fcnet_hiddens": [10],
            "fcnet_activation": None,
        },
        noise_size=2500000,
    )
    # Test eval workers ("normal" WorkerSet, unlike ARS' list of
    # RolloutWorkers used for collecting train batches).
    config.evaluation(evaluation_interval=1, evaluation_num_workers=1)

    num_iterations = 2

    for _ in framework_iterator(config):
        trainer = config.build(env="CartPole-v0")
        for i in range(num_iterations):
            results = trainer.train()
            print(results)

        check_compute_single_action(trainer)
        trainer.stop()
def test_preprocessing_disabled(self):
    config = ppo.DEFAULT_CONFIG.copy()
    config["env"] = "ray.rllib.examples.env.random_env.RandomEnv"
    config["env_config"] = {
        "config": {
            "observation_space": Dict({
                "a": Discrete(5),
                "b": Dict({
                    "ba": Discrete(4),
                    "bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32),
                }),
                "c": Tuple((MultiDiscrete([2, 3]), Discrete(1))),
                "d": Box(-1.0, 1.0, (1,), dtype=np.int32),
            }),
        },
    }
    # Set this to True to enforce no preprocessors being used.
    # Complex observations now arrive directly in the model as
    # structures of batches, e.g. {"a": tensor, "b": [tensor, tensor]}
    # for obs-space=Dict(a=..., b=Tuple(..., ...)).
    config["_disable_preprocessor_api"] = True

    num_iterations = 1
    # Only supported for tf so far.
    for _ in framework_iterator(config):
        trainer = ppo.PPOTrainer(config=config)
        for i in range(num_iterations):
            results = trainer.train()
            check_train_results(results)
            print(results)
        check_compute_single_action(trainer)
        trainer.stop()
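# For reference, a minimal standalone sketch (not part of the test above) of
# the nested observation structure the model sees when the preprocessor API
# is disabled. It samples a space like the test's via plain gym calls; the
# space definition here is just an illustration mirroring the test.
import numpy as np
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple

space = Dict({
    "a": Discrete(5),
    "b": Dict({"ba": Discrete(4), "bb": Box(-1.0, 1.0, (2, 3), dtype=np.float32)}),
    "c": Tuple((MultiDiscrete([2, 3]), Discrete(1))),
})
obs = space.sample()
# `obs` keeps the nested structure, e.g.:
# {"a": 3, "b": {"ba": 1, "bb": <2x3 float32 array>}, "c": (array([0, 2]), 0)}
# With `_disable_preprocessor_api=True`, the model receives the same nested
# structure, just with batched tensors as leaves instead of single samples.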
def test_r2d2_compilation(self):
    """Test whether R2D2 can be built on all frameworks."""
    config = (
        r2d2.R2D2Config()
        .rollouts(num_rollout_workers=0)
        .training(
            model={
                # Wrap with an LSTM and use a very simple base-model.
                "use_lstm": True,
                "max_seq_len": 20,
                "fcnet_hiddens": [32],
                "lstm_cell_size": 64,
            },
            dueling=False,
            lr=5e-4,
            zero_init_states=True,
            replay_buffer_config={"replay_burn_in": 20},
        )
        .exploration(exploration_config={"epsilon_timesteps": 100000})
    )

    num_iterations = 1

    # Test building an R2D2 agent in all frameworks.
    for _ in framework_iterator(config, with_eager_tracing=True):
        algo = config.build(env="CartPole-v0")
        for i in range(num_iterations):
            results = algo.train()
            check_train_results(results)
            check_batch_sizes(results)
            print(results)

        check_compute_single_action(algo, include_state=True)
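# Side note on `replay_burn_in=20` above: per the R2D2 paper, each replayed
# sequence is preceded by a burn-in slice that is unrolled only to rebuild the
# LSTM state; the loss is computed on the remaining timesteps. A minimal,
# hypothetical illustration (plain numpy, not RLlib code):
import numpy as np

stored = np.arange(40)                              # one stored timestep sequence
burn_in, max_seq_len = 20, 20
warmup = stored[:burn_in]                           # only re-derives LSTM state
train_seq = stored[burn_in:burn_in + max_seq_len]   # loss is computed here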
def test_impala_compilation(self):
    """Test whether an ImpalaTrainer can be built with both frameworks."""
    config = impala.DEFAULT_CONFIG.copy()
    config["num_gpus"] = 0
    config["model"]["lstm_use_prev_action"] = True
    config["model"]["lstm_use_prev_reward"] = True
    num_iterations = 1
    env = "CartPole-v0"

    for _ in framework_iterator(config, with_eager_tracing=True):
        local_cfg = config.copy()
        for lstm in [False, True]:
            local_cfg["num_aggregation_workers"] = 0 if not lstm else 1
            local_cfg["model"]["use_lstm"] = lstm
            print("lstm={} aggregation-workers={}".format(
                lstm, local_cfg["num_aggregation_workers"]))
            # Test with and w/o aggregation workers (this has nothing
            # to do with LSTMs, though).
            trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
            for i in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            check_compute_single_action(
                trainer,
                include_state=lstm,
                include_prev_action_reward=lstm,
            )
            trainer.stop()
def test_dqn_compilation(self):
    """Test whether a DQNTrainer can be built on all frameworks."""
    config = dqn.DEFAULT_CONFIG.copy()
    config["num_workers"] = 2
    num_iterations = 1

    for _ in framework_iterator(config):
        # Double-dueling DQN.
        print("Double-dueling")
        plain_config = config.copy()
        trainer = dqn.DQNTrainer(config=plain_config, env="CartPole-v0")
        for i in range(num_iterations):
            results = trainer.train()
            print(results)

        check_compute_single_action(trainer)
        trainer.stop()

        # Rainbow.
        print("Rainbow")
        rainbow_config = config.copy()
        rainbow_config["num_atoms"] = 10
        rainbow_config["noisy"] = True
        rainbow_config["double_q"] = True
        rainbow_config["dueling"] = True
        rainbow_config["n_step"] = 5
        trainer = dqn.DQNTrainer(config=rainbow_config, env="CartPole-v0")
        for i in range(num_iterations):
            results = trainer.train()
            print(results)

        check_compute_single_action(trainer)
        trainer.stop()
def test_dqn_compilation(self):
    """Test whether a DQNTrainer can be built on all frameworks."""
    config = dqn.DEFAULT_CONFIG.copy()
    config["num_workers"] = 2
    num_iterations = 1

    for fw in framework_iterator(config):
        # Double-dueling DQN.
        plain_config = config.copy()
        trainer = dqn.DQNTrainer(config=plain_config, env="CartPole-v0")
        for i in range(num_iterations):
            results = trainer.train()
            print(results)

        check_compute_single_action(trainer)

        # Rainbow.
        # TODO(sven): Add torch once DQN-torch supports distributional-Q.
        if fw == "torch":
            continue
        rainbow_config = config.copy()
        rainbow_config["num_atoms"] = 10
        rainbow_config["noisy"] = True
        rainbow_config["double_q"] = True
        rainbow_config["dueling"] = True
        rainbow_config["n_step"] = 5
        trainer = dqn.DQNTrainer(config=rainbow_config, env="CartPole-v0")
        for i in range(num_iterations):
            results = trainer.train()
            print(results)

        check_compute_single_action(trainer)
def test_marwil_compilation_and_learning_from_offline_file(self):
    """Test whether a MARWILTrainer can be built with all frameworks.

    Learns from a historic-data file. To generate this data, first run:
    $ ./train.py --run=PPO --env=CartPole-v0 \
      --stop='{"timesteps_total": 50000}' \
      --config='{"output": "/tmp/out", "batch_mode": "complete_episodes"}'
    """
    rllib_dir = Path(__file__).parent.parent.parent.parent
    print("rllib dir={}".format(rllib_dir))
    data_file = os.path.join(rllib_dir, "tests/data/cartpole/large.json")
    print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))

    config = (
        marwil.MARWILConfig()
        .rollouts(num_rollout_workers=2)
        .environment(env="CartPole-v0")
        .evaluation(
            evaluation_interval=3,
            evaluation_num_workers=1,
            evaluation_duration=5,
            evaluation_parallel_to_training=True,
            evaluation_config={"input": "sampler"},
        )
        .offline_data(input_=[data_file])
    )

    num_iterations = 350
    min_reward = 70.0

    # Test for all frameworks.
    for _ in framework_iterator(config, frameworks=("tf", "torch")):
        trainer = marwil.MARWILTrainer(config=config)
        learnt = False
        for i in range(num_iterations):
            results = trainer.train()
            check_train_results(results)
            print(results)

            eval_results = results.get("evaluation")
            if eval_results:
                print("iter={} R={}".format(i, eval_results["episode_reward_mean"]))
                # Learn until some reward is reached on an actual live env.
                if eval_results["episode_reward_mean"] > min_reward:
                    print("learnt!")
                    learnt = True
                    break

        if not learnt:
            raise ValueError(
                "MARWILTrainer did not reach {} reward from expert "
                "offline data!".format(min_reward)
            )

        check_compute_single_action(trainer, include_prev_action_reward=True)

        trainer.stop()
def test_impala_compilation(self):
    """Test whether an ImpalaTrainer can be built with both frameworks."""
    config = (
        impala.ImpalaConfig()
        .resources(num_gpus=0)
        .training(
            model={
                "lstm_use_prev_action": True,
                "lstm_use_prev_reward": True,
            }
        )
    )
    env = "CartPole-v0"
    num_iterations = 2

    for _ in framework_iterator(config, with_eager_tracing=True):
        for lstm in [False, True]:
            config.num_aggregation_workers = 0 if not lstm else 1
            config.model["use_lstm"] = lstm
            print("lstm={} aggregation-workers={}".format(
                lstm, config.num_aggregation_workers))
            # Test with and w/o aggregation workers (this has nothing
            # to do with LSTMs, though).
            trainer = config.build(env=env)
            for i in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            check_compute_single_action(
                trainer,
                include_state=lstm,
                include_prev_action_reward=lstm,
            )
            trainer.stop()
def test_maml_compilation(self):
    """Test whether a MAMLTrainer can be built with all frameworks."""
    config = maml.DEFAULT_CONFIG.copy()
    config["num_workers"] = 1
    config["horizon"] = 200
    num_iterations = 1

    # Test for tf and torch frameworks.
    for fw in framework_iterator(config, frameworks=("tf", "torch")):
        for env in [
            "pendulum_mass.PendulumMassEnv",
            "cartpole_mass.CartPoleMassEnv",
        ]:
            if fw == "tf" and env.startswith("cartpole"):
                continue
            print("env={}".format(env))
            env_ = "ray.rllib.examples.env.{}".format(env)
            trainer = maml.MAMLTrainer(config=config, env=env_)
            for i in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            check_compute_single_action(trainer, include_prev_action_reward=True)
            trainer.stop()
def test_a2c_exec_impl(ray_start_regular):
    config = {"min_iter_time_s": 0}
    for _ in framework_iterator(config):
        trainer = a3c.A2CTrainer(env="CartPole-v0", config=config)
        assert isinstance(trainer.train(), dict)
        check_compute_single_action(trainer)
        trainer.stop()
def test_cql_compilation(self):
    """Test whether a CQLTrainer can be built with all frameworks."""
    # Learns from a historic-data file.
    # To generate this data, first run:
    # $ ./train.py --run=SAC --env=Pendulum-v0 \
    #   --stop='{"timesteps_total": 50000}' \
    #   --config='{"output": "/tmp/out"}'
    rllib_dir = Path(__file__).parent.parent.parent.parent
    print("rllib dir={}".format(rllib_dir))
    data_file = os.path.join(rllib_dir, "tests/data/pendulum/small.json")
    print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))

    config = cql.CQL_DEFAULT_CONFIG.copy()
    config["env"] = "Pendulum-v0"
    config["input"] = [data_file]
    config["num_workers"] = 0  # Run locally.
    config["twin_q"] = True
    config["clip_actions"] = False
    config["normalize_actions"] = True
    config["learning_starts"] = 0
    config["rollout_fragment_length"] = 1
    config["train_batch_size"] = 10

    num_iterations = 2

    # Test for torch framework only (tf not implemented yet).
    for _ in framework_iterator(config, frameworks=("torch",)):
        trainer = cql.CQLTrainer(config=config)
        for i in range(num_iterations):
            trainer.train()
        check_compute_single_action(trainer)
        trainer.stop()
def test_sac_compilation(self):
    """Tests whether an SACTrainer can be built with all frameworks."""
    config = sac.DEFAULT_CONFIG.copy()
    config["num_workers"] = 0  # Run locally.
    config["twin_q"] = True
    config["soft_horizon"] = True
    config["clip_actions"] = False
    config["normalize_actions"] = True
    config["learning_starts"] = 0
    config["prioritized_replay"] = True
    num_iterations = 1

    for _ in framework_iterator(config):
        # Test for different env types (discrete w/ and w/o image, + cont).
        for env in ["Pendulum-v0", "MsPacmanNoFrameskip-v4", "CartPole-v0"]:
            print("Env={}".format(env))
            config["use_state_preprocessor"] = env == "MsPacmanNoFrameskip-v4"
            trainer = sac.SACTrainer(config=config, env=env)
            for i in range(num_iterations):
                results = trainer.train()
                print(results)
            check_compute_single_action(trainer)
            trainer.stop()
def test_r2d2_compilation(self):
    """Test whether a R2D2Trainer can be built on all frameworks."""
    config = dqn.R2D2_DEFAULT_CONFIG.copy()
    config["num_workers"] = 0  # Run locally.
    # Wrap with an LSTM and use a very simple base-model.
    config["model"]["use_lstm"] = True
    config["model"]["max_seq_len"] = 20
    config["model"]["fcnet_hiddens"] = [32]
    config["model"]["lstm_cell_size"] = 64
    config["replay_buffer_config"]["replay_burn_in"] = 20
    config["zero_init_states"] = True
    config["dueling"] = False
    config["lr"] = 5e-4
    config["exploration_config"]["epsilon_timesteps"] = 100000

    num_iterations = 1

    # Test building an R2D2 agent in all frameworks.
    for _ in framework_iterator(config, with_eager_tracing=True):
        trainer = dqn.R2D2Trainer(config=config, env="CartPole-v0")
        for i in range(num_iterations):
            results = trainer.train()
            check_train_results(results)
            check_batch_sizes(results)
            print(results)

        check_compute_single_action(trainer, include_state=True)
def test_alpha_star_compilation(self):
    """Test whether AlphaStar can be built with all frameworks."""
    config = (
        alpha_star.AlphaStarConfig()
        .environment(env="connect_four")
        .training(
            gamma=1.0,
            model={"fcnet_hiddens": [256, 256, 256]},
            vf_loss_coeff=0.01,
            entropy_coeff=0.004,
            league_builder_config={
                "win_rate_threshold_for_new_snapshot": 0.8,
                "num_random_policies": 2,
                "num_learning_league_exploiters": 1,
                "num_learning_main_exploiters": 1,
            },
            grad_clip=10.0,
            replay_buffer_capacity=10,
            replay_buffer_replay_ratio=0.0,
            use_kl_loss=True,
        )
        .rollouts(num_rollout_workers=4, num_envs_per_worker=5)
        .resources(num_gpus=4, _fake_gpus=True)
    )

    num_iterations = 2

    for _ in framework_iterator(config, with_eager_tracing=True):
        trainer = config.build()
        for i in range(num_iterations):
            results = trainer.train()
            print(results)
            check_train_results(results)
        check_compute_single_action(trainer)
        trainer.stop()
def test_ddpg_compilation(self):
    """Test whether a DDPGTrainer can be built with both frameworks."""
    config = ddpg.DEFAULT_CONFIG.copy()
    config["num_workers"] = 1
    config["num_envs_per_worker"] = 2
    config["learning_starts"] = 0
    config["exploration_config"]["random_timesteps"] = 100

    num_iterations = 1

    # Test against all frameworks.
    for _ in framework_iterator(config):
        trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
        for i in range(num_iterations):
            results = trainer.train()
            print(results)
        check_compute_single_action(trainer)

        # Ensure apply_gradient_fn is being called and updating global_step.
        if config["framework"] == "tf":
            a = trainer.get_policy().global_step.eval(
                trainer.get_policy().get_session())
        else:
            a = trainer.get_policy().global_step
        check(a, 500)

        trainer.stop()
def test_es_compilation(self):
    """Test whether an ESAlgorithm can be built on all frameworks."""
    ray.init(num_cpus=4)
    config = es.ESConfig()
    # Keep it simple.
    config.training(
        model={
            "fcnet_hiddens": [10],
            "fcnet_activation": None,
        },
        noise_size=2500000,
        episodes_per_batch=10,
        train_batch_size=100,
    )
    config.rollouts(num_rollout_workers=1)
    # Test eval workers ("normal" WorkerSet, unlike ES' list of
    # RolloutWorkers used for collecting train batches).
    config.evaluation(evaluation_interval=1, evaluation_num_workers=2)

    num_iterations = 1

    for _ in framework_iterator(config):
        for env in ["CartPole-v0", "Pendulum-v1"]:
            algo = config.build(env=env)
            for i in range(num_iterations):
                results = algo.train()
                print(results)
            check_compute_single_action(algo)
            algo.stop()
    ray.shutdown()
def test_es_compilation(self):
    """Test whether an ESTrainer can be built on all frameworks."""
    ray.init(num_cpus=4)
    config = es.DEFAULT_CONFIG.copy()
    # Keep it simple.
    config["model"]["fcnet_hiddens"] = [10]
    config["model"]["fcnet_activation"] = None
    config["noise_size"] = 2500000
    config["num_workers"] = 1
    config["episodes_per_batch"] = 10
    config["train_batch_size"] = 100
    # Test eval workers ("normal" Trainer eval WorkerSet).
    config["evaluation_interval"] = 1
    config["evaluation_num_workers"] = 2

    num_iterations = 1

    for _ in framework_iterator(config):
        for env in ["CartPole-v0", "Pendulum-v0"]:
            plain_config = config.copy()
            trainer = es.ESTrainer(config=plain_config, env=env)
            for i in range(num_iterations):
                results = trainer.train()
                print(results)
            check_compute_single_action(trainer)
            trainer.stop()
    ray.shutdown()
def test_ppo_compilation(self):
    """Test whether a PPOTrainer can be built with all frameworks."""
    config = copy.deepcopy(ppo.DEFAULT_CONFIG)
    config["num_workers"] = 1
    config["num_sgd_iter"] = 2
    # Settings in case we use an LSTM.
    config["model"]["lstm_cell_size"] = 10
    config["model"]["max_seq_len"] = 20
    config["train_batch_size"] = 128

    num_iterations = 2

    for _ in framework_iterator(config):
        for env in ["CartPole-v0", "MsPacmanNoFrameskip-v4"]:
            print("Env={}".format(env))
            for lstm in [True, False]:
                print("LSTM={}".format(lstm))
                config["model"]["use_lstm"] = lstm
                config["model"]["lstm_use_prev_action_reward"] = lstm
                trainer = ppo.PPOTrainer(config=config, env=env)
                for i in range(num_iterations):
                    trainer.train()
                check_compute_single_action(
                    trainer,
                    include_prev_action_reward=True,
                    include_state=lstm,
                )
                trainer.stop()
def test_appo_compilation(self):
    """Test whether an APPOTrainer can be built with both frameworks."""
    config = ppo.appo.DEFAULT_CONFIG.copy()
    config["num_workers"] = 1
    num_iterations = 2

    for _ in framework_iterator(config):
        print("w/o v-trace")
        _config = config.copy()
        _config["vtrace"] = False
        trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0")
        for i in range(num_iterations):
            results = trainer.train()
            check_train_results(results)
            print(results)
        check_compute_single_action(trainer)
        trainer.stop()

        print("w/ v-trace")
        _config = config.copy()
        _config["vtrace"] = True
        trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0")
        for i in range(num_iterations):
            results = trainer.train()
            check_train_results(results)
            print(results)
        check_compute_single_action(trainer)
        trainer.stop()
def test_impala_compilation(self):
    """Test whether an ImpalaTrainer can be built with both frameworks."""
    config = impala.DEFAULT_CONFIG.copy()
    num_iterations = 1

    for _ in framework_iterator(config):
        local_cfg = config.copy()
        for env in ["Pendulum-v0", "CartPole-v0"]:
            print("Env={}".format(env))
            print("w/o LSTM")
            # Test w/o LSTM.
            local_cfg["model"]["use_lstm"] = False
            local_cfg["num_aggregation_workers"] = 0
            trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
            for i in range(num_iterations):
                print(trainer.train())
            check_compute_single_action(trainer)
            trainer.stop()

            # Test w/ LSTM.
            print("w/ LSTM")
            local_cfg["model"]["use_lstm"] = True
            local_cfg["model"]["lstm_use_prev_action"] = True
            local_cfg["model"]["lstm_use_prev_reward"] = True
            local_cfg["num_aggregation_workers"] = 1
            trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
            for i in range(num_iterations):
                print(trainer.train())
            check_compute_single_action(
                trainer, include_state=True, include_prev_action_reward=True)
            trainer.stop()
def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
    """Test whether an APEX-DQNTrainer can be built on all frameworks."""
    config = apex.APEX_DEFAULT_CONFIG.copy()
    config["num_workers"] = 3
    config["num_gpus"] = 0
    config["learning_starts"] = 1000
    config["prioritized_replay"] = True
    config["timesteps_per_iteration"] = 100
    config["min_iter_time_s"] = 1
    config["optimizer"]["num_replay_buffer_shards"] = 1

    for _ in framework_iterator(config, with_eager_tracing=True):
        plain_config = config.copy()
        trainer = apex.ApexTrainer(config=plain_config, env="CartPole-v0")

        # Test per-worker epsilon distribution.
        infos = trainer.workers.foreach_policy(
            lambda p, _: p.get_exploration_state())
        expected = [0.4, 0.016190862, 0.00065536]
        check([i["cur_epsilon"] for i in infos], [0.0] + expected)

        check_compute_single_action(trainer)

        for i in range(2):
            results = trainer.train()
            check_train_results(results)
            print(results)

        # Test again per-worker epsilon distribution
        # (should not have changed).
        infos = trainer.workers.foreach_policy(
            lambda p, _: p.get_exploration_state())
        check([i["cur_epsilon"] for i in infos], [0.0] + expected)

        trainer.stop()
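# The `expected` values above match the Ape-X per-worker exploration schedule
# from the paper, eps_i = eps ** (1 + i * alpha / (N - 1)) with eps = 0.4 and
# alpha = 7 (the local worker stays at 0.0). A quick standalone sanity check,
# assuming that schedule:
N = 3
expected = [0.4 ** (1 + 7 * i / (N - 1)) for i in range(N)]
# -> [0.4, 0.016190862..., 0.00065536]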
def test_marwil_compilation_and_learning_from_offline_file(self):
    """Test whether a MARWILTrainer can be built with all frameworks.

    And learns from a historic-data file.
    """
    rllib_dir = Path(__file__).parent.parent.parent.parent
    print("rllib dir={}".format(rllib_dir))
    data_file = os.path.join(rllib_dir, "tests/data/cartpole/large.json")
    print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))

    config = marwil.DEFAULT_CONFIG.copy()
    config["num_workers"] = 0  # Run locally.
    config["evaluation_num_workers"] = 1
    config["evaluation_interval"] = 1
    config["evaluation_config"] = {"input": "sampler"}
    config["input"] = [data_file]
    num_iterations = 300

    # Test for all frameworks.
    for _ in framework_iterator(config):
        trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
        for i in range(num_iterations):
            eval_results = trainer.train()["evaluation"]
            print("iter={} R={}".format(
                i, eval_results["episode_reward_mean"]))
            # Learn until some reward is reached on an actual live env.
            if eval_results["episode_reward_mean"] > 60.0:
                print("learnt!")
                break
        check_compute_single_action(trainer, include_prev_action_reward=True)
        trainer.stop()
def test_ppo_compilation_and_schedule_mixins(self):
    """Test whether a PPOTrainer can be built with all frameworks."""

    # Build a PPOConfig object.
    config = (
        ppo.PPOConfig()
        .training(
            num_sgd_iter=2,
            # Setup lr schedule for testing.
            lr_schedule=[[0, 5e-5], [128, 0.0]],
            # Set entropy_coeff to a faulty value to prove that it'll get
            # overridden by the schedule below (which is expected).
            entropy_coeff=100.0,
            entropy_coeff_schedule=[[0, 0.1], [256, 0.0]],
        )
        .rollouts(
            num_rollout_workers=1,
            # Test with compression.
            compress_observations=True,
        )
        .training(
            train_batch_size=128,
            model=dict(
                # Settings in case we use an LSTM.
                lstm_cell_size=10,
                max_seq_len=20,
            ),
        )
        # For checking lr-schedule correctness.
        .callbacks(MyCallbacks)
    )

    num_iterations = 2

    for fw in framework_iterator(config, with_eager_tracing=True):
        for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
            print("Env={}".format(env))
            for lstm in [True, False]:
                print("LSTM={}".format(lstm))
                config.training(
                    model=dict(
                        use_lstm=lstm,
                        lstm_use_prev_action=lstm,
                        lstm_use_prev_reward=lstm,
                    )
                )

                trainer = config.build(env=env)
                policy = trainer.get_policy()
                entropy_coeff = trainer.get_policy().entropy_coeff
                lr = policy.cur_lr
                if fw == "tf":
                    entropy_coeff, lr = policy.get_session().run(
                        [entropy_coeff, lr])
                check(entropy_coeff, 0.1)
                check(lr, config.lr)

                for i in range(num_iterations):
                    results = trainer.train()
                    check_train_results(results)
                    print(results)

                check_compute_single_action(
                    trainer,
                    include_prev_action_reward=True,
                    include_state=lstm,
                )
                trainer.stop()
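# For reference, a minimal sketch of how a two-point [[timestep, value], ...]
# schedule like the ones above interpolates, assuming RLlib's default linear
# interpolation between entries (the helper name is ours, not an RLlib API):
def interp_two_point_schedule(schedule, t):
    """Linearly interpolate a two-point [[timestep, value], ...] schedule."""
    (t0, v0), (t1, v1) = schedule[0], schedule[-1]
    if t >= t1:
        return v1  # Hold the final value past the last schedule point.
    return v0 + (v1 - v0) * (t - t0) / (t1 - t0)

interp_two_point_schedule([[0, 5e-5], [128, 0.0]], 64)  # -> 2.5e-05 (half-decayed lr)
interp_two_point_schedule([[0, 0.1], [256, 0.0]], 0)    # -> 0.1 (the checked value)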
def test_a2c_exec_impl(self):
    config = a3c.A2CConfig().reporting(min_time_s_per_reporting=0)

    for _ in framework_iterator(config):
        trainer = a3c.A2CTrainer(env="CartPole-v0", config=config)
        results = trainer.train()
        check_train_results(results)
        print(results)
        check_compute_single_action(trainer)
        trainer.stop()
def test_a2c_exec_impl_microbatch(ray_start_regular):
    config = {
        "min_iter_time_s": 0,
        "microbatch_size": 10,
    }
    for _ in framework_iterator(config, ("tf", "torch")):
        trainer = a3c.A2CTrainer(env="CartPole-v0", config=config)
        assert isinstance(trainer.train(), dict)
        check_compute_single_action(trainer)
def test_a2c_exec_impl(ray_start_regular):
    config = {"min_time_s_per_reporting": 0}
    for _ in framework_iterator(config):
        trainer = a3c.A2CTrainer(env="CartPole-v0", config=config)
        results = trainer.train()
        check_train_results(results)
        print(results)
        check_compute_single_action(trainer)
        trainer.stop()
def test_crr_compilation(self):
    """Test whether a CRR algorithm can be built with all supported frameworks."""
    # TODO: terrible asset management style
    rllib_dir = Path(__file__).parent.parent.parent.parent
    print("rllib dir={}".format(rllib_dir))
    data_file = os.path.join(rllib_dir, "tests/data/pendulum/large.json")
    print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))

    config = (
        CRRConfig()
        .environment(env="Pendulum-v1", clip_actions=True)
        .framework("torch")
        .offline_data(input_=[data_file], actions_in_input_normalized=True)
        .training(
            twin_q=True,
            train_batch_size=256,
            replay_buffer_config={
                "type": MultiAgentReplayBuffer,
                "learning_starts": 0,
                "capacity": 100000,
            },
            weight_type="bin",
            advantage_type="mean",
            n_action_sample=4,
            target_update_grad_intervals=10000,
            tau=1.0,
        )
        .evaluation(
            evaluation_interval=2,
            evaluation_num_workers=2,
            evaluation_duration=10,
            evaluation_duration_unit="episodes",
            evaluation_parallel_to_training=True,
            evaluation_config={"input": "sampler", "explore": False},
        )
        .rollouts(num_rollout_workers=0)
    )

    num_iterations = 4

    for _ in ["torch"]:
        algorithm = config.build()
        # Check that `num_iterations` training iterations run without errors.
        for i in range(num_iterations):
            results = algorithm.train()
            check_train_results(results)
            print(results)
            if (i + 1) % 2 == 0:
                # Evaluation happens every 2 iterations.
                eval_results = results["evaluation"]
                print(
                    f"iter={algorithm.iteration} "
                    f"R={eval_results['episode_reward_mean']}"
                )
        check_compute_single_action(algorithm)
        algorithm.stop()