def test_mbmpo_compilation(self):
    """Test whether an MBMPO trainer can be built and trained with torch.

    Builds a trainer from an ``MBMPOConfig`` (2 rollout workers, horizon
    of 200, dynamics-model ensemble of size 2, ``CartPoleWrapper`` env),
    runs ``num_iterations`` training iterations, validates each result
    with ``check_train_results`` and then exercises single-action
    computation via ``check_compute_single_action``.
    """
    config = (
        mbmpo.MBMPOConfig()
        .rollouts(num_rollout_workers=2, horizon=200)
        .training(dynamics_model={"ensemble_size": 2})
        .environment(env="ray.rllib.examples.env.mbmpo_env.CartPoleWrapper")
    )
    num_iterations = 1

    # Test for torch framework (tf not implemented yet).
    for _ in framework_iterator(config, frameworks="torch"):
        trainer = config.build()
        # Stop the trainer even when training or one of the checks raises,
        # so a failing assertion does not skip the trainer's cleanup.
        try:
            for _iteration in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)

            check_compute_single_action(
                trainer, include_prev_action_reward=False
            )
        finally:
            trainer.stop()
def _import_mbmpo():
    """Import MBMPO on demand and return its class and default config.

    The import happens inside the function body, so the ``mbmpo`` module
    is only loaded when this function is called.

    Returns:
        A 2-tuple ``(mbmpo.MBMPO, config_dict)`` where ``config_dict`` is
        the result of ``mbmpo.MBMPOConfig().to_dict()``.
    """
    import ray.rllib.algorithms.mbmpo as mbmpo

    algo_cls = mbmpo.MBMPO
    default_config = mbmpo.MBMPOConfig().to_dict()
    return algo_cls, default_config