示例#1
0
    def test_default_sampler_cls(self):
        """Check which sampler DummyAlgo's default ``sampler_cls`` builds.

        With a DummyPolicy the sampler is expected to be an
        OnPolicyVectorizedSampler; with a DummyPolicyWithoutVectorized it is
        expected to be a BatchSampler.
        """
        cases = (
            (DummyPolicy, OnPolicyVectorizedSampler),
            (DummyPolicyWithoutVectorized, BatchSampler),
        )
        for policy_cls, expected_sampler_cls in cases:
            algo = DummyAlgo(policy=policy_cls(env_spec=self.env.spec),
                             baseline=self.baseline)
            sampler = algo.sampler_cls(algo, self.env, {})
            assert isinstance(sampler, expected_sampler_cls)
示例#2
0
 def test_rl2_worker(self):
     """Smoke-test RL2Worker.rollout() on a DummyBoxEnv with obs_dim=(1,).

     Asserts that the rollout's rewards have max_path_length *
     n_paths_per_trial (= 500) entries along axis 0.
     """
     max_path_length = 100
     n_paths_per_trial = 5
     # NOTE(review): this count presumably assumes every one of the
     # n_paths_per_trial paths runs the full max_path_length steps
     # (no early termination in DummyBoxEnv) — confirm.
     expected_steps = max_path_length * n_paths_per_trial

     env = TfEnv(DummyBoxEnv(obs_dim=(1, )))
     policy = DummyPolicy(env_spec=env.spec)
     worker = RL2Worker(seed=1,
                        max_path_length=max_path_length,
                        worker_number=1,
                        n_paths_per_trial=n_paths_per_trial)
     worker.update_agent(policy)
     worker.update_env(env)
     assert worker.rollout().rewards.shape[0] == expected_steps
示例#3
0
def run_task(*_):
    """Train an InstrumentedNOP on a normalized CartpoleEnv with plotting on.

    Any positional arguments are accepted and ignored.
    """
    env = normalize(CartpoleEnv())
    # NOTE(review): the env itself (not env.spec) is passed as env_spec;
    # presumably this works because the env exposes the spaces these
    # consumers read — confirm before changing.
    policy = DummyPolicy(env_spec=env)
    baseline = LinearFeatureBaseline(env_spec=env)

    training_config = dict(
        batch_size=4000,
        max_path_length=100,
        n_itr=4,
        discount=0.99,
        step_size=0.01,
        plot=True,
    )
    InstrumentedNOP(env=env,
                    policy=policy,
                    baseline=baseline,
                    **training_config).train()
示例#4
0
    def test_algo_with_goal_without_es(self):
        """Smoke-test OffPolicyVectorizedSampler without an exploration policy.

        This tests if the sampler works properly when the algorithm includes
        a goal (DummyDictEnv) but is built with ``exploration_strategy=None``.
        It only checks that starting the worker and sampling do not raise.
        """
        env = DummyDictEnv()
        policy = DummyPolicy(env)
        replay_buffer = SimpleReplayBuffer(env_spec=env,
                                           size_in_transitions=int(1e6),
                                           time_horizon=100)
        algo = DummyOffPolicyAlgo(env_spec=env,
                                  qf=None,
                                  replay_buffer=replay_buffer,
                                  policy=policy,
                                  exploration_strategy=None)

        sampler = OffPolicyVectorizedSampler(algo, env, 1, no_reset=True)
        sampler.start_worker()
        try:
            # NOTE(review): positional args presumably (itr, batch_size) — confirm.
            sampler.obtain_samples(0, 30)
        finally:
            # Release the worker started above even if sampling raises,
            # so its resources do not leak into later tests.
            sampler.shutdown_worker()
示例#5
0
 def setup_method(self):
     """Per-test setup: build a GarageEnv and a DummyPolicy on its spec.

     The wrapped DummyBoxEnv uses obs_dim=(4, 4) and action_dim=(2, 2);
     the env is stored on ``self.env`` and the policy on ``self.policy``.
     """
     self.env = GarageEnv(DummyBoxEnv(obs_dim=(4, 4), action_dim=(2, 2)))
     self.policy = DummyPolicy(self.env.spec)