Code example #1
 def test_invalid_action_spaces(self):
     """Test that policy raises error if passed a box obs space."""
     env = GymEnv(DummyDictEnv(act_space_type='box'))
     with pytest.raises(ValueError):
         CategoricalCNNPolicy(env=env,
                              kernel_sizes=(3, ),
                              hidden_channels=(3, ))
Code example #2
File: test_dict_space.py Project: Mee321/HAPG_exp
    def test_dict_space(self):
        ext.set_seed(0)

        # A dummy dict env
        dummy_env = DummyDictEnv()
        dummy_act = dummy_env.action_space
        dummy_act_sample = dummy_act.sample()

        # A dummy dict env wrapped by garage.tf
        tf_env = TfEnv(dummy_env)
        tf_act = tf_env.action_space
        tf_obs = tf_env.observation_space

        # flat_dim
        assert tf_act.flat_dim == tf_act.flatten(dummy_act_sample).shape[-1]

        # flat_dim_with_keys
        assert tf_obs.flat_dim == tf_obs.flat_dim_with_keys(
            iter(["achieved_goal", "desired_goal", "observation"]))

        # un/flatten
        assert tf_act.unflatten(
            tf_act.flatten(dummy_act_sample)) == dummy_act_sample

        # un/flatten_n
        samples = [dummy_act.sample() for _ in range(10)]
        assert tf_act.unflatten_n(tf_act.flatten_n(samples)) == samples

        # un/flatten_with_keys
        assert tf_act.unflatten_with_keys(
            tf_act.flatten_with_keys(dummy_act_sample, iter(["action"])),
            iter(["action"]))
Code example #3
 def setup_method(self):
     self.env = GymEnv(DummyDictEnv())
     self.obs = self.env.reset()[0]
     self._replay_k = 4
     self.replay_buffer = HERReplayBuffer(env_spec=self.env.spec,
                                          capacity_in_transitions=10,
                                          replay_k=self._replay_k,
                                          reward_fn=self.env.compute_reward)
Code example #4
 def test_does_not_support_dict_obs_space(self):
     """Test that policy raises error if passed a dict obs space."""
     env = GymEnv(DummyDictEnv(act_space_type='discrete'))
     with pytest.raises(ValueError,
                        match=('CNN policies do not support '
                               'with akro.Dict observation spaces.')):
         CategoricalCNNPolicy(env=env,
                              kernel_sizes=(3, ),
                              hidden_channels=(3, ))
Code example #5
 def test_does_not_support_dict_obs_space(self, filters, strides, padding,
                                          hidden_sizes):
     """Test that policy raises error if passed a dict obs space."""
     env = GymEnv(DummyDictEnv(act_space_type='discrete'))
     with pytest.raises(ValueError):
         CategoricalCNNPolicy(env_spec=env.spec,
                              filters=filters,
                              strides=strides,
                              padding=padding,
                              hidden_sizes=hidden_sizes)
Code example #6
    def test_get_action_dict_space(self):
        env = GymEnv(DummyDictEnv(obs_space_type='box', act_space_type='box'))
        policy = GaussianMLPPolicy(env_spec=env.spec)
        obs = env.reset()[0]

        action, _ = policy.get_action(obs)
        assert env.action_space.contains(action)

        actions, _ = policy.get_actions([obs, obs])
        for action in actions:
            assert env.action_space.contains(action)
Code example #7
    def test_get_action_dict_space(self):
        env = GymEnv(DummyDictEnv(obs_space_type='box', act_space_type='box'))
        policy = GaussianGRUPolicy(env_spec=env.spec,
                                   hidden_dim=4,
                                   state_include_action=False)
        policy.reset(do_resets=None)
        obs = env.reset()[0]

        action, _ = policy.get_action(obs)
        assert env.action_space.contains(action)

        actions, _ = policy.get_actions([obs, obs])
        for action in actions:
            assert env.action_space.contains(action)
Code example #8
    def test_get_action_dict_space(self):
        """Test if observations from dict obs spaces are properly flattened."""
        env = GymEnv(DummyDictEnv(obs_space_type='box', act_space_type='box'))
        policy = TanhGaussianMLPPolicy(env_spec=env.spec,
                                       hidden_nonlinearity=None,
                                       hidden_sizes=(1, ),
                                       hidden_w_init=nn.init.ones_,
                                       output_w_init=nn.init.ones_)
        obs = env.reset()[0]

        action, _ = policy.get_action(obs)
        assert env.action_space.shape == action.shape

        actions, _ = policy.get_actions(np.array([obs, obs]))
        for action in actions:
            assert env.action_space.shape == action.shape
Code example #9
    def test_algo_with_goal_without_es(self):
        # This tests that the sampler works properly when the algorithm
        # uses a goal-based env but has no exploration policy.
        env = DummyDictEnv()
        policy = DummyPolicy(env)
        replay_buffer = SimpleReplayBuffer(env_spec=env,
                                           size_in_transitions=int(1e6),
                                           time_horizon=100)
        algo = DummyOffPolicyAlgo(env_spec=env,
                                  qf=None,
                                  replay_buffer=replay_buffer,
                                  policy=policy,
                                  exploration_strategy=None)

        sampler = OffPolicyVectorizedSampler(algo, env, 1, no_reset=True)
        sampler.start_worker()
        sampler.obtain_samples(0, 30)
Code example #10
    def test_get_action(self, obs_dim, action_dim, obs_type):
        assert obs_type in ['discrete', 'dict']
        if obs_type == 'discrete':
            env = GymEnv(
                DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))
        else:
            env = GymEnv(
                DummyDictEnv(obs_space_type='box', act_space_type='discrete'))
        policy = CategoricalMLPPolicy(env_spec=env.spec)
        obs = env.reset()[0]
        if obs_type == 'discrete':
            obs = obs.flatten()
        action, _ = policy.get_action(obs)
        assert env.action_space.contains(action)

        actions, _ = policy.get_actions([obs, obs, obs])
        for action in actions:
            assert env.action_space.contains(action)
Code example #11
    def test_q_vals_goal_conditioned(self):
        env = GarageEnv(DummyDictEnv())
        with mock.patch(('garage.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = np.concatenate(
            (obs['observation'], obs['desired_goal'], obs['achieved_goal']),
            axis=-1)
        act = np.full((1, ), 0.5).flatten()

        expected_output = np.full((1, ), 0.5)

        outputs = qf.get_qval([obs], [act])
        assert np.array_equal(outputs[0], expected_output)

        outputs = qf.get_qval([obs, obs, obs], [act, act, act])
        for output in outputs:
            assert np.array_equal(output, expected_output)
Code example #12
    def test_q_vals_input_include_goal(self):
        env = TfEnv(DummyDictEnv())
        with mock.patch(('garage.tf.q_functions.'
                         'continuous_mlp_q_function_with_model.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunctionWithModel(env_spec=env.spec,
                                                 input_include_goal=True)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = np.concatenate((obs['observation'], obs['desired_goal']),
                             axis=-1)
        act = np.full((1, ), 0.5).flatten()

        expected_output = np.full((1, ), 0.5)
        obs_ph, act_ph = qf.inputs

        outputs = qf.get_qval([obs], [act])
        assert np.array_equal(outputs[0], expected_output)

        outputs = qf.get_qval([obs, obs, obs], [act, act, act])
        for output in outputs:
            assert np.array_equal(output, expected_output)
Code example #13
    def test_get_action(self, obs_dim, action_dim, obs_type):
        """Test get_action method"""
        assert obs_type in ['box', 'dict']
        if obs_type == 'box':
            env = GymEnv(DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim))
        else:
            env = GymEnv(
                DummyDictEnv(obs_space_type='box', act_space_type='box'))

        policy = ContinuousMLPPolicy(env_spec=env.spec)

        env.reset()
        obs = env.step(1).observation
        if obs_type == 'box':
            obs = obs.flatten()

        action, _ = policy.get_action(obs)

        assert env.action_space.contains(action)

        actions, _ = policy.get_actions([obs, obs, obs])
        for action in actions:
            assert env.action_space.contains(action)
Code example #14
class TestHerReplayBuffer:
    def setup_method(self):
        self.env = DummyDictEnv()
        self.obs = self.env.reset()
        self.replay_buffer = HerReplayBuffer(
            env_spec=self.env.spec,
            size_in_transitions=3,
            time_horizon=1,
            replay_k=0.4,
            reward_fun=self.env.compute_reward)

    def _add_single_transition(self):
        self.replay_buffer.add_transition(
            observation=self.obs,
            action=self.env.action_space.sample(),
            terminal=False,
            next_observation=self.obs)

    def _add_transitions(self):
        self.replay_buffer.add_transitions(
            observation=[self.obs],
            action=[self.env.action_space.sample()],
            terminal=[False],
            next_observation=[self.obs])

    def test_add_transition_dtype(self):
        self._add_single_transition()
        sample = self.replay_buffer.sample(1)

        assert sample['observation'].dtype == self.env.observation_space[
            'observation'].dtype
        assert sample['achieved_goal'].dtype == self.env.observation_space[
            'achieved_goal'].dtype
        assert sample['goal'].dtype == self.env.observation_space[
            'desired_goal'].dtype
        assert sample['action'].dtype == self.env.action_space.dtype

    def test_add_transitions_dtype(self):
        self._add_transitions()
        sample = self.replay_buffer.sample(1)

        assert sample['observation'].dtype == self.env.observation_space[
            'observation'].dtype
        assert sample['achieved_goal'].dtype == self.env.observation_space[
            'achieved_goal'].dtype
        assert sample['goal'].dtype == self.env.observation_space[
            'desired_goal'].dtype
        assert sample['action'].dtype == self.env.action_space.dtype

    def test_eviction_policy(self):
        self.replay_buffer.add_transitions(
            observation=[self.obs, self.obs],
            next_observation=[self.obs, self.obs],
            terminal=[False, False],
            action=[1, 2])
        assert not self.replay_buffer.full
        self.replay_buffer.add_transitions(
            observation=[self.obs, self.obs],
            next_observation=[self.obs, self.obs],
            terminal=[False, False],
            action=[3, 4])
        assert self.replay_buffer.full
        self.replay_buffer.add_transitions(
            observation=[self.obs, self.obs],
            next_observation=[self.obs, self.obs],
            terminal=[False, False],
            action=[5, 6])
        self.replay_buffer.add_transitions(
            observation=[self.obs, self.obs],
            next_observation=[self.obs, self.obs],
            terminal=[False, False],
            action=[7, 8])

        assert np.array_equal(self.replay_buffer._buffer['action'],
                              [[7], [8], [6]])
        assert self.replay_buffer.n_transitions_stored == 3

    def test_pickleable(self):
        self._add_transitions()
        replay_buffer_pickled = pickle.loads(pickle.dumps(self.replay_buffer))
        assert replay_buffer_pickled._buffer.keys(
        ) == self.replay_buffer._buffer.keys()
        for k in replay_buffer_pickled._buffer:
            assert replay_buffer_pickled._buffer[
                k].shape == self.replay_buffer._buffer[k].shape
        sample = self.replay_buffer.sample(1)
        sample2 = replay_buffer_pickled.sample(1)
        for k in self.replay_buffer._buffer:
            assert sample[k].shape == sample2[k].shape
Code example #15
 def test_invalid_action_spaces(self):
     """Test that policy raises error if passed a dict obs space."""
     env = GymEnv(DummyDictEnv(act_space_type='box'))
     with pytest.raises(ValueError):
         qf = SimpleQFunction(env.spec)
         DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
Code example #16
class TestHerReplayBuffer:
    def setup_method(self):
        self.env = DummyDictEnv()
        self.obs = self.env.reset()
        self._replay_k = 4
        self.replay_buffer = HERReplayBuffer(env_spec=self.env.spec,
                                             capacity_in_transitions=10,
                                             replay_k=self._replay_k,
                                             reward_fn=self.env.compute_reward)

    def test_replay_k(self):
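        # As exercised below, an integer replay_k of 0 is accepted, while the
        # fractional value 0.2 is expected to raise ValueError.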
        self.replay_buffer = HERReplayBuffer(env_spec=self.env.spec,
                                             capacity_in_transitions=10,
                                             replay_k=0,
                                             reward_fn=self.env.compute_reward)

        with pytest.raises(ValueError):
            self.replay_buffer = HERReplayBuffer(
                env_spec=self.env.spec,
                capacity_in_transitions=10,
                replay_k=0.2,
                reward_fn=self.env.compute_reward)

    def _add_one_path(self):
        path = dict(
            observations=np.asarray([self.obs, self.obs]),
            actions=np.asarray([
                self.env.action_space.sample(),
                self.env.action_space.sample()
            ]),
            rewards=np.asarray([[1], [1]]),
            terminals=np.asarray([[False], [False]]),
            next_observations=np.asarray([self.obs, self.obs]),
        )
        self.replay_buffer.add_path(path)

    def test_add_path(self):
        self._add_one_path()

        # HER buffer should add replay_k + 1 transitions to the buffer
        # for each transition in the given path. This doesn't apply to
        # the last transition, where only that transition gets added.
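        # For example, with replay_k = 4 and a two-step path, this works out
        # to (4 + 1) * (2 - 1) + 1 = 6 stored transitions.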

        path_len = 2
        total_expected_transitions = sum(
            [self._replay_k + 1 for _ in range(path_len - 1)]) + 1
        assert (self.replay_buffer.n_transitions_stored ==
                total_expected_transitions)
        assert (len(
            self.replay_buffer._path_segments) == total_expected_transitions -
                1)
        # check that buffer has the correct keys
        assert {
            'observations', 'next_observations', 'actions', 'rewards',
            'terminals'
        } <= set(self.replay_buffer._buffer)

        # check that dict obses are flattened
        obs = self.replay_buffer._buffer['observations'][0]
        next_obs = self.replay_buffer._buffer['next_observations'][0]
        assert obs.shape == (self.env.spec.observation_space.flat_dim, )
        assert next_obs.shape == (self.env.spec.observation_space.flat_dim, )

    def test_pickleable(self):
        self._add_one_path()
        replay_buffer_pickled = pickle.loads(pickle.dumps(self.replay_buffer))
        assert (replay_buffer_pickled._buffer.keys() ==
                self.replay_buffer._buffer.keys())
        for k in replay_buffer_pickled._buffer:
            assert replay_buffer_pickled._buffer[
                k].shape == self.replay_buffer._buffer[k].shape
        sample = self.replay_buffer.sample_transitions(1)
        sample2 = replay_buffer_pickled.sample_transitions(1)
        for k in sample.keys():
            assert sample[k].shape == sample2[k].shape
        assert len(sample) == len(sample2)
Code example #17
 def test_not_box(self):
     with pytest.raises(ValueError):
         dict_env = GarageEnv(DummyDictEnv())
         ContinuousCNNQFunction(env_spec=dict_env.spec,
                                filters=((5, (3, 3)), ),
                                strides=(1, ))