def test_maml_trpo_dummy_named_env():
    """Test with dummy environment that has env_name."""
    # Older garage API: GarageEnv wrapper around the normalized dummy env,
    # LocalRunner, and max_episode_length passed directly to the algorithm.
    env = GarageEnv(
        normalize(DummyMultiTaskBoxEnv(), expected_action_scale=10.))
    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(64, 64),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )
    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=(32, 32))
    rollouts_per_task = 2
    max_episode_length = 100

    runner = LocalRunner(snapshot_config)
    algo = MAMLTRPO(env=env,
                    policy=policy,
                    value_function=value_function,
                    max_episode_length=max_episode_length,
                    meta_batch_size=5,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,
                    num_grad_updates=1)

    runner.setup(algo, env, sampler_cls=LocalSampler)
    runner.train(n_epochs=2,
                 batch_size=rollouts_per_task * max_episode_length)
def test_maml_trpo_dummy_named_env():
    """Test with dummy environment that has env_name."""
    # Newer garage API: GymEnv + Trainer, with the sampler and task sampler
    # constructed explicitly and passed to the algorithm.
    env = normalize(GymEnv(DummyMultiTaskBoxEnv(), max_episode_length=100),
                    expected_action_scale=10.)
    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(64, 64),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )
    value_function = GaussianMLPValueFunction(env_spec=env.spec,
                                              hidden_sizes=(32, 32))

    task_sampler = SetTaskSampler(
        DummyMultiTaskBoxEnv,
        wrapper=lambda env, _: normalize(GymEnv(env, max_episode_length=100),
                                         expected_action_scale=10.))

    episodes_per_task = 2
    max_episode_length = env.spec.max_episode_length

    sampler = LocalSampler(agents=policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length)

    trainer = Trainer(snapshot_config)
    algo = MAMLTRPO(env=env,
                    policy=policy,
                    sampler=sampler,
                    task_sampler=task_sampler,
                    value_function=value_function,
                    meta_batch_size=5,
                    discount=0.99,
                    gae_lambda=1.,
                    inner_lr=0.1,
                    num_grad_updates=1)

    trainer.setup(algo, env)
    trainer.train(n_epochs=2,
                  batch_size=episodes_per_task * max_episode_length)
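# Module-level imports assumed by the newer-API variant above. This is a
# sketch based on the garage package layout that provides GymEnv and Trainer;
# the exact paths are assumptions and may differ between garage versions, and
# snapshot_config / DummyMultiTaskBoxEnv are fixtures from garage's own test
# suite rather than the installed package.
#
# import torch
#
# from garage.envs import GymEnv, normalize
# from garage.experiment.task_sampler import SetTaskSampler
# from garage.sampler import LocalSampler
# from garage.torch.algos import MAMLTRPO
# from garage.torch.policies import GaussianMLPPolicy
# from garage.torch.value_functions import GaussianMLPValueFunction
# from garage.trainer import Trainer
#
# from tests.fixtures import snapshot_config
# from tests.fixtures.envs.dummy import DummyMultiTaskBoxEnv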