def test_r2d2_compilation(self):
    """Test whether R2D2 can be built on all frameworks.

    For every framework yielded by ``framework_iterator``, build an R2D2
    algorithm on CartPole with a small LSTM-wrapped model. Run
    ``num_iterations`` training iterations, checking each result with the
    shared test helpers, then check single-action computation with state.
    Each built algorithm is stopped in a ``finally`` clause, so it is torn
    down before the next framework is tried, even if a check fails.
    """
    config = (
        r2d2.R2D2Config()
        .rollouts(num_rollout_workers=0)
        .training(
            model={
                # Wrap with an LSTM and use a very simple base-model.
                "use_lstm": True,
                "max_seq_len": 20,
                "fcnet_hiddens": [32],
                "lstm_cell_size": 64,
            },
            dueling=False,
            lr=5e-4,
            zero_init_states=True,
            replay_buffer_config={"replay_burn_in": 20},
        )
        .exploration(exploration_config={"epsilon_timesteps": 100000})
    )

    num_iterations = 1

    # Test building an R2D2 agent in all frameworks.
    for _ in framework_iterator(config, with_eager_tracing=True):
        algo = config.build(env="CartPole-v0")
        try:
            for _ in range(num_iterations):
                results = algo.train()
                check_train_results(results)
                check_batch_sizes(results)
                print(results)
            # The model is recurrent (use_lstm=True above), hence
            # include_state=True.
            check_compute_single_action(algo, include_state=True)
        finally:
            # Without this, each framework pass leaves a live algorithm
            # behind; stop it before building the next one.
            algo.stop()
def _import_r2d2():
    """Lazily import R2D2 and return its algorithm class and default config.

    The import happens inside the function, so ``ray.rllib.algorithms.r2d2``
    is only loaded when this helper is called.

    Returns:
        A 2-tuple ``(R2D2, config_dict)``, where ``config_dict`` is a
        freshly built ``R2D2Config()`` converted with ``to_dict()``.
    """
    import ray.rllib.algorithms.r2d2 as r2d2_module

    algorithm_cls = r2d2_module.R2D2
    default_config = r2d2_module.R2D2Config().to_dict()
    return algorithm_cls, default_config