コード例 #1
0
    def test_maml_compilation(self):
        """Test whether a MAMLTrainer can be built with all frameworks.

        For each framework yielded by ``framework_iterator`` (tf and torch),
        builds a trainer from ``config`` on each example mass-env, runs
        ``num_iterations`` training iterations, validates each result with
        ``check_train_results`` and finally checks single-action computation.
        ``trainer.stop()`` is always called, even when a check fails.
        """
        config = maml.MAMLConfig().rollouts(num_rollout_workers=1, horizon=200)

        num_iterations = 1

        # Both tf and torch are exercised; under tf only the pendulum env is
        # run (the cartpole env is skipped below).
        # NOTE(review): the reason tf skips cartpole is not visible here —
        # presumably not supported under tf; confirm.
        for fw in framework_iterator(config, frameworks=("tf", "torch")):
            for env in [
                    "pendulum_mass.PendulumMassEnv",
                    "cartpole_mass.CartPoleMassEnv",
            ]:
                if fw == "tf" and env.startswith("cartpole"):
                    continue
                print("env={}".format(env))
                env_ = "ray.rllib.examples.env.{}".format(env)
                trainer = config.build(env=env_)
                try:
                    for _ in range(num_iterations):
                        results = trainer.train()
                        check_train_results(results)
                        print(results)
                    check_compute_single_action(trainer,
                                                include_prev_action_reward=True)
                finally:
                    # Stop the trainer even if a check above raised, so a
                    # failure in one framework/env combination does not leave
                    # it running for the remaining iterations or tests.
                    trainer.stop()
コード例 #2
0
ファイル: registry.py プロジェクト: parasj/ray
def _import_maml():
    """Return the MAML algorithm class and its default config as a dict.

    Importing inside the function defers loading the MAML module until this
    helper is actually called.

    Returns:
        A 2-tuple ``(algorithm_class, config_dict)`` where ``config_dict`` is
        ``MAMLConfig().to_dict()``.
    """
    import ray.rllib.algorithms.maml as maml_module

    algorithm_cls = maml_module.MAML
    default_config = maml_module.MAMLConfig().to_dict()
    return algorithm_cls, default_config