Example #1
0
 def test_done(self):
     """Stepping toward the goal must eventually report ``done``.

     Repeatedly calls ``step`` with the goal position itself as the action,
     for at most 1000 steps, and fails if the 4-tuple step result never has
     a truthy ``done`` flag.
     """
     env = PointEnv()
     for _ in range(1000):
         _, _, done, _ = env.step(env._goal)
         if done:
             break
     else:
         # Explicit raise instead of ``assert False``: bare asserts are
         # stripped under ``python -O``, which would let this test pass
         # without the environment ever reporting done.
         raise AssertionError('Should report done')
Example #2
0
    def test_task(self):
        """sample_tasks yields the requested count; set_task applies each goal."""
        env = PointEnv()
        n_tasks = 5
        sampled = env.sample_tasks(n_tasks)
        assert len(sampled) == n_tasks

        # Each sampled task exposes a 'goal' entry; after set_task the
        # environment's goal must equal it element-wise.
        for t in sampled:
            env.set_task(t)
            goal_matches = env._goal == t['goal']
            assert goal_matches.all()
Example #3
0
 def test_done(self):
     """Stepping toward the goal must eventually report a terminal step.

     Resets the environment, then calls ``step`` with the goal position as
     the action for at most 1000 steps; fails if no step result has a truthy
     ``terminal`` attribute.
     """
     env = PointEnv()
     env.reset()
     for _ in range(1000):
         if env.step(env._goal).terminal:
             break
     else:
         # Explicit raise instead of ``assert False``, which is stripped
         # under ``python -O`` and would let this test pass vacuously.
         raise AssertionError('Should report done')
Example #4
0
    def test_reset(self):
        """The point starts at the origin and reset() returns it there."""
        env = PointEnv()
        origin = np.array([0, 0])

        # Expect a freshly constructed environment to sit at the origin.
        assert (env._point == origin).all()

        env.step(env.action_space.sample())
        env.reset()

        # After a random step, reset must bring the point back to the origin.
        assert (env._point == origin).all()
Example #5
0
    def test_visualization(self):
        """Only 'ascii' rendering is advertised; it reports point and goal."""
        env = PointEnv()
        assert env.render_modes == ['ascii']

        env.reset()
        rendered = env.render('ascii')
        assert rendered == f'Point: {env._point}, Goal: {env._goal}'

        # visualize() followed by a step should run without raising.
        env.visualize()
        env.step(env.action_space.sample())
        env.step(env.action_space.sample())
Example #6
0
 def test_does_not_modify_action(self):
     """step() must not modify the action array passed to it in place."""
     env = PointEnv()
     try:
         a = env.action_space.sample()
         a_copy = a.copy()
         env.reset()
         env.step(a)
         # Compare element-wise: ``a.all() == a_copy.all()`` only compares
         # whether every entry of each array is truthy, so almost any
         # in-place modification would slip through.
         assert (a == a_copy).all()
     finally:
         # Close the environment even when the assertion fails.
         env.close()
Example #7
0
 def test_does_not_modify_action(self):
     """step() must not modify the action array passed to it in place."""
     env = PointEnv()
     try:
         a = env.action_space.sample()
         a_copy = a.copy()
         env.reset()
         env.step(a)
         # Element-wise comparison: ``a.all() == a_copy.all()`` only compared
         # the truthiness of the whole arrays. ``assertEquals`` is also a
         # deprecated alias that was removed in Python 3.12.
         self.assertEqual(a.tolist(), a_copy.tolist())
     finally:
         env.close()
Example #8
0
 def test_pickleable(self):
     """An environment survives a pickle round trip and is usable via step_env."""
     env = PointEnv()
     serialized = pickle.dumps(env)
     restored = pickle.loads(serialized)
     assert restored
     step_env(restored)
Example #9
0
 def test_catch_no_reset(self):
     """Calling step() before reset() must raise RuntimeError naming reset()."""
     env = PointEnv()
     # ``match`` is a regular expression searched in the error message.
     # Escape the parentheses so the literal text "reset()" is required;
     # unescaped, ``()`` is an empty group and any "reset" would match.
     with pytest.raises(RuntimeError, match=r'reset\(\)'):
         env.step(env.action_space.sample())
Example #10
0
File: trpo_point.py  Project: gntoni/garage
from garage.algos import TRPO
from garage.baselines import LinearFeatureBaseline
from garage.envs import normalize
from garage.envs.point_env import PointEnv
from garage.policies import GaussianMLPPolicy
from garage.theano.envs import TheanoEnv

# Build the point environment, wrap it with normalize(), then TheanoEnv.
point_env = PointEnv()
env = TheanoEnv(normalize(point_env))

# Policy and baseline are both configured from the wrapped env's spec.
policy = GaussianMLPPolicy(env_spec=env.spec)
baseline = LinearFeatureBaseline(env_spec=env.spec)

# Train with TRPO using the components above.
algo = TRPO(env=env, policy=policy, baseline=baseline)
algo.train()
Example #11
0
 def test_observation_space(self):
     """Observations returned by step() lie inside the observation space."""
     env = PointEnv()
     space = env.observation_space
     observation, _reward, _done, _info = env.step(env.action_space.sample())
     assert space.contains(observation)