Example #1
0
    def test_handle_preprocessing_layers(self, outer_dims):
        """Checks CriticNetwork with per-element observation preprocessing.

        The observation is a tuple of a [1]-shaped float spec and a scalar
        float spec. Each element gets its own preprocessing layer; the scalar
        is reshaped to (1,) before its Dense layer. The two outputs are merged
        with an `Add` combiner. The resulting Q-values must have exactly
        `outer_dims` as their shape.

        Args:
          outer_dims: Outer (batch) dimensions prepended to every sampled
            tensor. Presumably supplied by a parameterized-test decorator
            outside this view; TODO confirm.
        """
        num_actions_dims = 2

        observation_spec = (tensor_spec.TensorSpec([1], tf.float32),
                            tensor_spec.TensorSpec([], tf.float32))
        time_step_spec = ts.time_step_spec(observation_spec)
        time_step = tensor_spec.sample_spec_nest(
            time_step_spec, outer_dims=outer_dims)

        # Tie the action shape to num_actions_dims instead of repeating the
        # literal, and sample actions from the spec so they respect its
        # [2, 3] bounds (tf.random.uniform would draw from [0, 1) instead).
        action_spec = tensor_spec.BoundedTensorSpec(
            (num_actions_dims,), tf.float32, 2, 3)
        actions = tensor_spec.sample_spec_nest(
            action_spec, outer_dims=outer_dims)

        # Both branches end in Dense(4), so their outputs have matching
        # shapes and can be combined element-wise by the Add layer below.
        preprocessing_layers = (tf.keras.layers.Dense(4),
                                sequential_layer.SequentialLayer([
                                    tf.keras.layers.Reshape((1, )),
                                    tf.keras.layers.Dense(4)
                                ]))

        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_preprocessing_layers=preprocessing_layers,
            observation_preprocessing_combiner=tf.keras.layers.Add())

        q_values, _ = critic_net((time_step.observation, actions))
        self.assertAllEqual(q_values.shape.as_list(), list(outer_dims))
Example #2
0
    def test_build(self, outer_dims):
        """Verifies a default CriticNetwork built from flat specs.

        Feeds uniformly random observations of width 5 and actions of width 2.
        Checks that the Q-values have shape `outer_dims` and that the network
        owns exactly two trainable variables (presumably the output layer's
        kernel and bias; confirm against CriticNetwork defaults).
        """
        obs_width, act_width = 5, 2
        batch_shape = list(outer_dims)

        specs = (tensor_spec.TensorSpec([obs_width], tf.float32),
                 tensor_spec.TensorSpec([act_width], tf.float32))

        # Random inputs shaped outer_dims + [feature width].
        inputs = (tf.random.uniform(batch_shape + [obs_width]),
                  tf.random.uniform(batch_shape + [act_width]))
        net = critic_network.CriticNetwork(specs)

        q_values, _ = net(inputs)
        self.assertAllEqual(q_values.shape.as_list(), batch_shape)
        self.assertLen(net.trainable_variables, 2)
Example #3
0
    def test_add_obs_fc_layers(self, outer_dims):
        """Checks CriticNetwork with fully connected observation layers.

        Each observation has shape [3, 3, 5] and passes through observation
        FC layers of sizes 20 and 10. The Q-values must have shape
        `outer_dims`. The network must own six trainable variables
        (presumably a kernel and bias for each of the two FC layers plus the
        output layer; confirm against CriticNetwork).
        """
        obs_feature_shape = [3, 3, 5]
        act_width = 2
        batch_shape = list(outer_dims)

        net = critic_network.CriticNetwork(
            (tensor_spec.TensorSpec(obs_feature_shape, tf.float32),
             tensor_spec.TensorSpec([act_width], tf.float32)),
            observation_fc_layer_params=[20, 10])

        observations = tf.random.uniform(batch_shape + obs_feature_shape)
        actions = tf.random.uniform(batch_shape + [act_width])
        q_values, _ = net((observations, actions))

        self.assertAllEqual(q_values.shape.as_list(), batch_shape)
        self.assertLen(net.trainable_variables, 6)