def test_dm_control_tf_policy(self):
    """Smoke-test TRPO with a Gaussian MLP policy on the first dm_control task.

    Builds the full training stack inside this test's TF graph and runs a
    single tiny iteration; success means no component raised.
    """
    domain_task = ALL_TASKS[0]
    with self.graph.as_default():
        env = TfEnv(DmControlEnv.from_suite(*domain_task))
        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))
        baseline = LinearFeatureBaseline(env_spec=env.spec)
        # Deliberately minuscule batch/path/iteration sizes: this is a
        # wiring check, not a learning check.
        algo = TRPO(env=env,
                    policy=policy,
                    baseline=baseline,
                    batch_size=10,
                    max_path_length=5,
                    n_itr=1,
                    discount=0.99,
                    step_size=0.01)
        algo.train()
        env.close()
def test_dm_control_tf_policy(self):
    """Train TRPO for one epoch on a dm_control task via LocalTFRunner.

    A smoke test: tiny batch and path length, one epoch — it only verifies
    that the runner/algo/env plumbing fits together without raising.
    """
    domain_task = ALL_TASKS[0]
    with LocalTFRunner(snapshot_config, sess=self.sess) as runner:
        env = TfEnv(DmControlEnv.from_suite(*domain_task))
        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))
        baseline = LinearFeatureBaseline(env_spec=env.spec)
        algo = TRPO(env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=5,
                    discount=0.99,
                    max_kl_step=0.01)
        runner.setup(algo, env)
        runner.train(n_epochs=1, batch_size=10)
        env.close()
def test_dm_control_tf_policy(self):
    """Run one tiny TRPO training epoch on a dm_control task via LocalRunner."""
    domain_task = ALL_TASKS[0]
    env = TfEnv(DmControlEnv.from_suite(*domain_task))
    policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))
    baseline = LinearFeatureBaseline(env_spec=env.spec)
    # NOTE(review): batch_size/n_itr are passed both to TRPO and to
    # runner.train below — kept as-is to preserve behavior exactly.
    algo = TRPO(env=env,
                policy=policy,
                baseline=baseline,
                batch_size=10,
                max_path_length=5,
                n_itr=1,
                discount=0.99,
                step_size=0.01)
    runner = LocalRunner(self.sess)
    runner.setup(algo, env)
    runner.train(n_epochs=1, batch_size=10)
    env.close()
def test_all_pickleable(self, domain_name, task_name):
    """Check a DmControlEnv survives a pickle round trip and still steps.

    Args:
        domain_name: dm_control suite domain to load.
        task_name: dm_control suite task to load.
    """
    original = DmControlEnv.from_suite(domain_name, task_name)
    restored = pickle.loads(pickle.dumps(original))
    assert restored
    # Rendering is skipped because it causes TravisCI to run out of memory.
    step_env(restored, render=False)
    restored.close()
    original.close()
def test_all_does_not_modify_actions(self, domain_name, task_name):
    """Verify that stepping the env leaves the sampled action unmodified.

    Args:
        domain_name: dm_control suite domain to load.
        task_name: dm_control suite task to load.
    """
    env = DmControlEnv.from_suite(domain_name, task_name)
    a = env.action_space.sample()
    a_copy = copy(a)
    env.step(a)
    # collections.Iterable was removed in Python 3.10; the ABCs live in
    # collections.abc.
    if isinstance(a, collections.abc.Iterable):
        # Element-wise comparison. The previous `a.all() == a_copy.all()`
        # only compared the truthiness reduction of each array, which
        # passes even when the arrays hold different values.
        assert (a == a_copy).all()
    else:
        assert a == a_copy
    env.close()
def test_all_can_step(self, domain_name, task_name):
    """Reset the env, sanity-check its spaces, and step it.

    Args:
        domain_name: dm_control suite domain to load.
        task_name: dm_control suite task to load.
    """
    env = DmControlEnv.from_suite(domain_name, task_name)
    observation_space = env.observation_space
    action_space = env.action_space
    first_obs = env.reset()
    assert observation_space.contains(first_obs)
    sampled_action = action_space.sample()
    assert action_space.contains(sampled_action)
    # Rendering is skipped because it causes TravisCI to run out of memory.
    step_env(env, render=False)
    env.close()
def test_does_not_modify_actions(self):
    """Verify that stepping the env leaves the sampled action unmodified."""
    domain_name, task_name = dm_control.suite.ALL_TASKS[0]
    env = DmControlEnv.from_suite(domain_name, task_name)
    a = env.action_space.sample()
    a_copy = copy(a)
    env.step(a)
    # collections.Iterable was removed in Python 3.10; the ABCs live in
    # collections.abc.
    if isinstance(a, collections.abc.Iterable):
        # Element-wise comparison. The previous
        # `assertEqual(a.all(), a_copy.all())` only compared the truthiness
        # reduction of each array, which passes even when values differ.
        self.assertTrue((a == a_copy).all())
    else:
        self.assertEqual(a, a_copy)
    env.close()
def run_task(*_):
    """Train TRPO on dm_control's cartpole-balance task with live plotting."""
    with LocalRunner() as runner:
        env = normalize(DmControlEnv.from_suite('cartpole', 'balance'))
        policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=(32, 32))
        baseline = LinearFeatureBaseline(env_spec=env.spec)
        algo = TRPO(env=env,
                    policy=policy,
                    baseline=baseline,
                    max_path_length=100,
                    discount=0.99,
                    max_kl_step=0.01)
        runner.setup(algo, env)
        runner.train(n_epochs=400, batch_size=4000, plot=True)
"""Example of how to load, step, and visualize an environment.

This example requires that garage[dm_control] be installed.
"""
import argparse

from garage.envs.dm_control import DmControlEnv

parser = argparse.ArgumentParser()
parser.add_argument('--n_steps',
                    type=int,
                    default=1000,
                    help='Number of steps to run')
args = parser.parse_args()

# Construct the environment
env = DmControlEnv.from_suite('walker', 'run')

# Reset the environment and launch the viewer
env.reset()
env.render()

# Take random actions for the requested number of steps
for _ in range(args.n_steps):
    env.step(env.action_space.sample())
    env.render()