Example #1
0
    def test_dict_space(self):
        """Check akro.Dict behavior through the garage.tf env wrapper."""
        ext.set_seed(0)

        # A raw dict env to compare against.
        raw_env = DummyDictEnv()
        raw_act_space = raw_env.action_space
        raw_sample = raw_act_space.sample()

        # The same env wrapped by garage.tf.
        wrapped_env = TfEnv(DummyDictEnv())
        act_space = wrapped_env.action_space
        obs_space = wrapped_env.observation_space

        # With the same seed, the wrapped space samples identically.
        assert act_space.sample() == raw_sample

        # flat_dim equals the length of a flattened sample.
        assert act_space.flat_dim == act_space.flatten(raw_sample).shape[-1]

        # flat_dim_with_keys over every key equals flat_dim.
        all_keys = ["achieved_goal", "desired_goal", "observation"]
        assert obs_space.flat_dim == obs_space.flat_dim_with_keys(
            iter(all_keys))

        # flatten/unflatten round-trips a single sample...
        assert act_space.unflatten(
            act_space.flatten(raw_sample)) == raw_sample

        # ...and a batch of samples.
        batch = [raw_act_space.sample() for _ in range(10)]
        assert act_space.unflatten_n(act_space.flatten_n(batch)) == batch

        # Key-based flatten/unflatten yields a truthy result.
        assert act_space.unflatten_with_keys(
            act_space.flatten_with_keys(raw_sample, iter(["action"])),
            iter(["action"]))
 def test_not_box(self):
     """ContinuousCNNQFunction must reject a non-Box observation space.

     Raises:
         Only the q-function construction is expected to raise ValueError,
         so the env setup is kept outside the ``raises`` block — extra code
         inside it could mask the assertion we actually care about.
     """
     dict_env = TfEnv(DummyDictEnv())
     with pytest.raises(ValueError):
         ContinuousCNNQFunction(env_spec=dict_env.spec,
                                filter_dims=(3, ),
                                num_filters=(5, ),
                                strides=(1, ))
Example #3
0
 def test_invalid_action_spaces(self):
     """Test that policy raises error if passed a box obs space."""
     box_act_env = GymEnv(DummyDictEnv(act_space_type='box'))
     with pytest.raises(ValueError):
         CategoricalCNNPolicy(env=box_act_env,
                              hidden_channels=(3, ),
                              kernel_sizes=(3, ))
Example #4
0
 def test_does_not_support_dict_obs_space(self):
     """Test that policy raises error if passed a dict obs space."""
     env = GymEnv(DummyDictEnv(act_space_type='discrete'))
     # Either the q-function or the policy built on it should reject a
     # dict observation space, so both stay inside the raises block.
     with pytest.raises(ValueError):
         DiscreteQFArgmaxPolicy(
             env_spec=env.spec,
             qf=SimpleQFunction(env.spec,
                                name='does_not_support_dict_obs_space'))
 def setup_method(self):
     """Create a dict env and a small HER replay buffer before each test."""
     self.env = DummyDictEnv()
     self.obs = self.env.reset()
     self._replay_k = 4
     buffer_kwargs = dict(env_spec=self.env.spec,
                          capacity_in_transitions=10,
                          replay_k=self._replay_k,
                          reward_fn=self.env.compute_reward)
     self.replay_buffer = HERReplayBuffer(**buffer_kwargs)
Example #6
0
 def test_does_not_support_dict_obs_space(self):
     """Test that policy raises error if passed a dict obs space."""
     env = GymEnv(DummyDictEnv(act_space_type='discrete'))
     # Regex matched against the real error message; kept byte-identical.
     expected_message = ('CNN policies do not support '
                         'with akro.Dict observation spaces.')
     with pytest.raises(ValueError, match=expected_message):
         CategoricalCNNPolicy(env=env,
                              kernel_sizes=(3, ),
                              hidden_channels=(3, ))
Example #7
0
 def setup_method(self):
     """Build a dict env and a tiny HER replay buffer for each test."""
     self.env = DummyDictEnv()
     self.obs = self.env.reset()
     her_kwargs = dict(env_spec=self.env.spec,
                       size_in_transitions=3,
                       time_horizon=1,
                       replay_k=0.4,
                       reward_fun=self.env.compute_reward)
     self.replay_buffer = HerReplayBuffer(**her_kwargs)
 def test_does_not_support_dict_obs_space(self, filters, strides, padding,
                                          hidden_sizes):
     """Test that policy raises error if passed a dict obs space."""
     env = GymEnv(DummyDictEnv(act_space_type='discrete'))
     policy_kwargs = dict(env_spec=env.spec,
                          filters=filters,
                          strides=strides,
                          padding=padding,
                          hidden_sizes=hidden_sizes)
     with pytest.raises(ValueError):
         CategoricalCNNPolicy(**policy_kwargs)
    def test_get_action_dict_space(self):
        """Actions sampled via a dict-obs env stay within the action space."""
        env = GymEnv(DummyDictEnv(obs_space_type='box', act_space_type='box'))
        policy = GaussianMLPPolicy(env_spec=env.spec)
        first_obs = env.reset()[0]

        single_action, _ = policy.get_action(first_obs)
        assert env.action_space.contains(single_action)

        batch_actions, _ = policy.get_actions([first_obs, first_obs])
        assert all(env.action_space.contains(a) for a in batch_actions)
Example #10
0
 def setup_method(self):
     """Build env, HER buffer, and split the reset observation into parts."""
     self.env = DummyDictEnv()
     obs = self.env.reset()
     self.replay_buffer = HerReplayBuffer(
         env_spec=self.env.spec,
         size_in_transitions=3,
         time_horizon=1,
         replay_k=0.4,
         reward_fun=self.env.compute_reward)
     # Split the dict observation into goal / achieved-goal / observation.
     self.d_g, self.a_g, self.obs = (obs['desired_goal'],
                                     obs['achieved_goal'],
                                     obs['observation'])
Example #11
0
    def test_get_action_dict_space(self):
        """GRU policy actions from a dict-obs env stay in the action space."""
        env = GymEnv(DummyDictEnv(obs_space_type='box', act_space_type='box'))
        policy = GaussianGRUPolicy(env_spec=env.spec,
                                   hidden_dim=4,
                                   state_include_action=False)
        policy.reset(do_resets=None)
        first_obs = env.reset()[0]

        single_action, _ = policy.get_action(first_obs)
        assert env.action_space.contains(single_action)

        batch_actions, _ = policy.get_actions([first_obs, first_obs])
        assert all(env.action_space.contains(a) for a in batch_actions)
Example #12
0
    def test_get_action_dict_space(self):
        """Test if observations from dict obs spaces are properly flattened."""
        env = GymEnv(DummyDictEnv(obs_space_type='box', act_space_type='box'))
        policy = TanhGaussianMLPPolicy(env_spec=env.spec,
                                       hidden_nonlinearity=None,
                                       hidden_sizes=(1, ),
                                       hidden_w_init=nn.init.ones_,
                                       output_w_init=nn.init.ones_)
        first_obs = env.reset()[0]

        single_action, _ = policy.get_action(first_obs)
        assert single_action.shape == env.action_space.shape

        batch_actions, _ = policy.get_actions(np.array([first_obs, first_obs]))
        assert all(a.shape == env.action_space.shape for a in batch_actions)
Example #13
0
    def test_algo_with_goal_without_es(self):
        """Sampler handles a goal-based algo with no exploration policy."""
        goal_env = DummyDictEnv()
        dummy_policy = DummyPolicy(goal_env)
        buf = SimpleReplayBuffer(env_spec=goal_env,
                                 size_in_transitions=int(1e6),
                                 time_horizon=100)
        algo = DummyOffPolicyAlgo(env_spec=goal_env,
                                  qf=None,
                                  replay_buffer=buf,
                                  policy=dummy_policy,
                                  exploration_strategy=None)

        # Sampling should complete without error despite the missing
        # exploration strategy.
        sampler = OffPolicyVectorizedSampler(algo, goal_env, 1, no_reset=True)
        sampler.start_worker()
        sampler.obtain_samples(0, 30)
    def test_get_action(self, obs_dim, action_dim, obs_type):
        """get_action/get_actions produce actions inside the action space."""
        assert obs_type in ['discrete', 'dict']
        if obs_type == 'discrete':
            inner = DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim)
        else:
            inner = DummyDictEnv(obs_space_type='box',
                                 act_space_type='discrete')
        env = GymEnv(inner)
        policy = CategoricalMLPPolicy(env_spec=env.spec)
        obs = env.reset()[0]
        # Discrete observations are flattened before feeding the policy.
        if obs_type == 'discrete':
            obs = obs.flatten()
        single_action, _ = policy.get_action(obs)
        assert env.action_space.contains(single_action)

        batch_actions, _ = policy.get_actions([obs, obs, obs])
        assert all(env.action_space.contains(a) for a in batch_actions)
Example #15
0
    def test_q_vals_goal_conditioned(self):
        """Q-values over goal-conditioned observations match the stub model."""
        env = GarageEnv(DummyDictEnv())
        with mock.patch(('garage.tf.q_functions.'
                         'continuous_mlp_q_function.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunction(env_spec=env.spec)
        env.reset()
        obs, _, _, _ = env.step(1)
        # Concatenate every dict component into one flat observation.
        flat_obs = np.concatenate(
            (obs['observation'], obs['desired_goal'], obs['achieved_goal']),
            axis=-1)
        action = np.full((1, ), 0.5).flatten()

        expected = np.full((1, ), 0.5)

        single_out = qf.get_qval([flat_obs], [action])
        assert np.array_equal(single_out[0], expected)

        batch_out = qf.get_qval([flat_obs] * 3, [action] * 3)
        assert all(np.array_equal(out, expected) for out in batch_out)
    def test_q_vals_input_include_goal(self):
        """Q-values with ``input_include_goal=True`` match the stub model.

        The observation fed to the q-function is the concatenation of the
        raw observation and the desired goal.
        """
        env = TfEnv(DummyDictEnv())
        with mock.patch(('garage.tf.q_functions.'
                         'continuous_mlp_q_function_with_model.MLPMergeModel'),
                        new=SimpleMLPMergeModel):
            qf = ContinuousMLPQFunctionWithModel(env_spec=env.spec,
                                                 input_include_goal=True)
        env.reset()
        obs, _, _, _ = env.step(1)
        obs = np.concatenate((obs['observation'], obs['desired_goal']),
                             axis=-1)
        act = np.full((1, ), 0.5).flatten()

        expected_output = np.full((1, ), 0.5)
        # NOTE(review): removed `obs_ph, act_ph = qf.inputs` — the locals
        # were never used (dead code).

        outputs = qf.get_qval([obs], [act])
        assert np.array_equal(outputs[0], expected_output)

        outputs = qf.get_qval([obs, obs, obs], [act, act, act])
        for output in outputs:
            assert np.array_equal(output, expected_output)
Example #17
0
    def test_get_action(self, obs_dim, action_dim, obs_type):
        """Test get_action method"""
        assert obs_type in ['box', 'dict']
        if obs_type == 'box':
            inner = DummyBoxEnv(obs_dim=obs_dim, action_dim=action_dim)
        else:
            inner = DummyDictEnv(obs_space_type='box', act_space_type='box')
        env = GymEnv(inner)

        policy = ContinuousMLPPolicy(env_spec=env.spec)

        env.reset()
        obs = env.step(1).observation
        # Box observations are flattened before feeding the policy.
        if obs_type == 'box':
            obs = obs.flatten()

        single_action, _ = policy.get_action(obs)

        assert env.action_space.contains(single_action)

        batch_actions, _ = policy.get_actions([obs, obs, obs])
        assert all(env.action_space.contains(a) for a in batch_actions)
Example #18
0
 def test_invalid_action_spaces(self):
     """Test that policy raises error if passed a dict obs space."""
     env = GymEnv(DummyDictEnv(act_space_type='box'))
     # Either the q-function or the policy built on it should reject the
     # invalid space, so both stay inside the raises block.
     with pytest.raises(ValueError):
         DiscreteQFArgmaxPolicy(env_spec=env.spec,
                                qf=SimpleQFunction(env.spec))
Example #19
0
 def test_not_box(self):
     """ContinuousCNNQFunction must reject a non-Box observation space.

     Raises:
         Only the q-function construction is expected to raise ValueError,
         so the env setup is kept outside the ``raises`` block — extra code
         inside it could mask the assertion we actually care about.
     """
     dict_env = GarageEnv(DummyDictEnv())
     with pytest.raises(ValueError):
         ContinuousCNNQFunction(env_spec=dict_env.spec,
                                filters=((5, (3, 3)), ),
                                strides=(1, ))