def test_maml_compilation(self):
    """Test whether a MAMLTrainer can be built with all frameworks.

    For each framework yielded by ``framework_iterator`` (tf and torch),
    builds a MAML trainer on each example env, runs ``num_iterations``
    training iterations, validates the train results and single-action
    computation, and then stops the trainer.
    """
    config = maml.MAMLConfig().rollouts(num_rollout_workers=1, horizon=200)
    num_iterations = 1
    envs = [
        "pendulum_mass.PendulumMassEnv",
        "cartpole_mass.CartPoleMassEnv",
    ]

    # Both tf and torch are exercised; the tf/CartPole-mass combination
    # is skipped below.
    for fw in framework_iterator(config, frameworks=("tf", "torch")):
        for env in envs:
            if fw == "tf" and env.startswith("cartpole"):
                continue
            print("env={}".format(env))
            env_ = "ray.rllib.examples.env.{}".format(env)
            trainer = config.build(env=env_)
            try:
                for _ in range(num_iterations):
                    results = trainer.train()
                    check_train_results(results)
                    print(results)
                check_compute_single_action(
                    trainer, include_prev_action_reward=True
                )
            finally:
                # Always release the trainer, even if an iteration or a
                # check fails, so later test cases don't inherit its
                # resources.
                trainer.stop()
def _import_maml():
    """Import MAML on demand and return its algorithm class and default config.

    The import happens inside the function, so ``ray.rllib.algorithms.maml``
    is only loaded when this function is called.

    Returns:
        A tuple ``(MAML, config_dict)`` where ``config_dict`` is
        ``MAMLConfig().to_dict()``.
    """
    from ray.rllib.algorithms.maml import MAML, MAMLConfig

    default_config = MAMLConfig().to_dict()
    return MAML, default_config