コード例 #1
0
    def test_obs_is_image(self):
        """Check the tensor passed to ``CNNModel._build`` for image input.

        ``CNNModel._build`` is patched with an autospecced mock whose
        ``side_effect`` is the real ``_build``, so the real graph is still
        built while every call is recorded. The tensor captured from each
        recorded call must evaluate to all ones when an observation filled
        with 255 is fed in. This is checked twice: once for the call made
        while constructing the policy, and once for the call triggered by
        ``dist_info_sym``.
        """
        env = TfEnv(DummyDiscretePixelEnv(), is_image=True)
        obs_shape = env.spec.observation_space.shape
        # Batch of one observation with every entry set to 255; reused for
        # both checks below.
        fake_obs = [np.full(obs_shape, 255)]

        with mock.patch(('garage.tf.policies.'
                         'categorical_cnn_policy.CNNModel._build'),
                        autospec=True,
                        side_effect=CNNModel._build) as build:
            policy = CategoricalCNNPolicy(env_spec=env.spec,
                                          conv_filters=(32, ),
                                          conv_filter_sizes=(1, ),
                                          conv_strides=(1, ),
                                          conv_pad='VALID',
                                          hidden_sizes=(3, ))
            # First recorded call -> positional args -> second positional
            # argument (index 0 is presumably the model instance, since the
            # patch is autospecced on the class -- confirm).
            normalized_obs = build.call_args_list[0][0][1]

            # The placeholder is looked up by its default graph name.
            # NOTE(review): this relies on it being the first unnamed
            # placeholder in the default graph.
            input_ph = tf.compat.v1.get_default_graph().get_tensor_by_name(
                'Placeholder:0')
            assert (self.sess.run(normalized_obs,
                                  feed_dict={input_ph: fake_obs}) == 1.).all()

            state_input = tf.compat.v1.placeholder(tf.float32,
                                                   shape=(None, ) + obs_shape)
            policy.dist_info_sym(state_input, name='another')
            # Second recorded call: the one triggered by dist_info_sym.
            normalized_obs = build.call_args_list[1][0][1]

            # state_input is fed directly, so no lookup by name is needed.
            assert (self.sess.run(normalized_obs,
                                  feed_dict={state_input:
                                             fake_obs}) == 1.).all()
コード例 #2
0
    def test_dist_info_sym(self, obs_dim, action_dim, filter_dims,
                           filter_sizes, strides, padding, hidden_sizes):
        """Evaluate ``dist_info_sym`` on one observation with stubbed models.

        The policy is constructed while ``MLPModel`` and ``CNNModel`` in
        ``garage.tf.policies.categorical_cnn_policy`` are replaced by
        ``SimpleMLPModel`` and ``SimpleCNNModel``. The ``prob`` entry of the
        symbolic distribution info, evaluated on a single observation
        returned by the environment, must equal 0.5 for every action.
        """
        env = TfEnv(DummyDiscreteEnv(obs_dim=obs_dim, action_dim=action_dim))

        policy_module = 'garage.tf.policies.categorical_cnn_policy'
        mlp_patch = mock.patch(policy_module + '.MLPModel',
                               new=SimpleMLPModel)
        cnn_patch = mock.patch(policy_module + '.CNNModel',
                               new=SimpleCNNModel)
        # Patches are entered left to right, i.e. MLPModel first, then
        # CNNModel; both are undone once the policy has been built.
        with mlp_patch, cnn_patch:
            policy = CategoricalCNNPolicy(env_spec=env.spec,
                                          conv_filters=filter_dims,
                                          conv_filter_sizes=filter_sizes,
                                          conv_strides=strides,
                                          conv_pad=padding,
                                          hidden_sizes=hidden_sizes)

        env.reset()
        # Only the first element of the 4-tuple returned by step is used.
        sample_obs, _, _, _ = env.step(1)

        obs_shape = env.spec.observation_space.shape
        obs_ph = tf.compat.v1.placeholder(tf.float32,
                                          shape=(None, ) + obs_shape)
        dist_info = policy.dist_info_sym(obs_ph, name='policy2')

        probs = self.sess.run(dist_info['prob'],
                              feed_dict={obs_ph: [sample_obs]})
        assert np.array_equal(probs[0], np.full(action_dim, 0.5))