Exemplo n.º 1
0
 def test_all_does_not_modify_actions(self, domain_name, task_name):
     """Check that ``env.step`` leaves the action passed to it unchanged.

     A copy of a sampled action is taken before stepping; afterwards the
     original action must still equal that copy element-wise.
     """
     # Local import: the file-level import block is not visible here.
     import numpy as np

     env = DmControlEnv(domain_name, task_name)
     a = env.action_space.sample()
     a_copy = copy(a)
     env.step(a)
     # The previous check compared ``a.all()`` with ``a_copy.all()``, i.e. two
     # "every element is truthy" booleans, so most mutations went unnoticed.
     # It also relied on ``collections.Iterable`` (removed in Python 3.10) and
     # ``assertEquals`` (removed in 3.12). An element-wise comparison handles
     # both array and scalar actions.
     np.testing.assert_array_equal(a, a_copy)
Exemplo n.º 2
0
 def test_can_step_and_render(self, domain_name, task_name):
     """Reset, sample and step the env with rendering enabled.

     Verifies that the reset observation lies in the observation space, that
     a sampled action lies in the action space, and that ``step_env`` can
     run 10 rendered steps. The env is closed afterwards, even on failure.
     """
     env = DmControlEnv(domain_name, task_name)
     try:
         ob_space = env.observation_space
         act_space = env.action_space
         ob = env.reset()
         assert ob_space.contains(ob)
         a = act_space.sample()
         assert act_space.contains(a)
         step_env(env, n=10, render=True)
     finally:
         # Rendering may allocate a viewer/graphics context; release it even
         # when an assertion above fails so later parametrized cases do not
         # accumulate open resources.
         env.close()
Exemplo n.º 3
0
def run_task(domain_name, task_name, n_steps=5):
    """Run a short, rendered, random-action rollout in a dm_control task.

    The task is wrapped with ``normalize`` with observation and reward
    normalization disabled.

    Args:
        domain_name: dm_control domain name passed to ``DmControlEnv``.
        task_name: dm_control task name passed to ``DmControlEnv``.
        n_steps: Maximum number of random steps to take. Defaults to 5, the
            previously hard-coded value. The rollout stops early when the
            env reports ``done``.
    """
    print("run: domain %s task %s" % (domain_name, task_name))
    dm_control_env = normalize(
        DmControlEnv(
            domain_name=domain_name,
            task_name=task_name,
            plot=True,
            width=600,
            height=400),
        normalize_obs=False,
        normalize_reward=False)

    try:
        dm_control_env.reset()
        action_space = dm_control_env.action_space
        for _ in range(n_steps):
            dm_control_env.render()
            _, _, done, _ = dm_control_env.step(action_space.sample())
            if done:
                break
    finally:
        # Close the env (and any render window) even if rendering or
        # stepping raises.
        dm_control_env.close()
Exemplo n.º 4
0
    def test_dm_control_theano_policy(self):
        """Smoke-test TRPO training on the first task in ``ALL_TASKS``.

        Builds a Theano-wrapped dm_control env, a Gaussian MLP policy with two
        hidden layers of 32 units and a linear feature baseline, then runs a
        single tiny TRPO iteration.
        """
        first_task = ALL_TASKS[0]
        domain, task_name = first_task[0], first_task[1]
        wrapped_env = TheanoEnv(
            DmControlEnv(domain_name=domain, task_name=task_name))

        gaussian_policy = GaussianMLPPolicy(
            env_spec=wrapped_env.spec, hidden_sizes=(32, 32))
        linear_baseline = LinearFeatureBaseline(env_spec=wrapped_env.spec)

        # Deliberately tiny settings: this only checks that training runs.
        trpo_settings = dict(
            batch_size=10,
            max_path_length=5,
            n_itr=1,
            discount=0.99,
            step_size=0.01,
        )
        trainer = TRPO(
            env=wrapped_env,
            policy=gaussian_policy,
            baseline=linear_baseline,
            **trpo_settings
        )
        trainer.train()
Exemplo n.º 5
0
 def test_pickling(self, domain_name, task_name):
     """Check that the env survives a pickle round trip and can still step."""
     original_env = DmControlEnv(domain_name, task_name)
     serialized = pickle.dumps(original_env)
     restored_env = pickle.loads(serialized)
     assert restored_env
     step_env(restored_env)
Exemplo n.º 6
0
 def __init__(self, method_name='runTest', param=ALL_TASKS[0]):
     """Initialise the test case and build a ``DmControlEnv`` for *param*.

     *param* is indexed as ``(domain_name, task_name)``; it defaults to the
     first entry of ``ALL_TASKS``.
     """
     super().__init__(method_name)
     domain, task = param[0], param[1]
     self.env = DmControlEnv(domain_name=domain, task_name=task)