Example #1
0
    def test_mbmpo_compilation(self):
        """Test whether an MBMPOTrainer can be built with all frameworks.

        Builds an MBMPO config (2 rollout workers, horizon 200, dynamics
        model ensemble of size 2) on the CartPoleWrapper example env, runs
        ``num_iterations`` training iterations, validates each result dict
        with ``check_train_results`` and then runs
        ``check_compute_single_action`` on the trainer.
        """
        config = (
            mbmpo.MBMPOConfig()
            .rollouts(num_rollout_workers=2, horizon=200)
            .training(dynamics_model={"ensemble_size": 2})
            .environment(
                env="ray.rllib.examples.env.mbmpo_env.CartPoleWrapper")
        )
        num_iterations = 1

        # Test for torch framework (tf not implemented yet).
        for _ in framework_iterator(config, frameworks="torch"):
            trainer = config.build()
            # Stop the trainer even when training or one of the checks
            # raises, so a failing test does not leave it running.
            try:
                for _iteration in range(num_iterations):
                    results = trainer.train()
                    check_train_results(results)
                    print(results)

                check_compute_single_action(
                    trainer, include_prev_action_reward=False)
            finally:
                trainer.stop()
Example #2
0
def _import_mbmpo():
    """Import MBMPO on demand and return it with its default config.

    The import happens inside the function, so the ``mbmpo`` module is
    only loaded when this function is actually called.

    Returns:
        A 2-tuple ``(MBMPO, config_dict)`` where ``config_dict`` is the
        result of ``MBMPOConfig().to_dict()``.
    """
    from ray.rllib.algorithms import mbmpo

    algo_cls = mbmpo.MBMPO
    default_config = mbmpo.MBMPOConfig().to_dict()
    return algo_cls, default_config