def test_handle_preprocessing_layers(self, outer_dims):
    """Check the critic's output shape when observations are preprocessed.

    The observation is a two-element nest (a length-1 vector and a scalar).
    One preprocessing layer is supplied per element and their outputs are
    merged with a Keras ``Add`` combiner before reaching the critic.

    Args:
        outer_dims: Leading dimensions used when sampling the time step and
            the actions; the first network output is expected to have
            exactly this shape.
    """
    num_actions_dims = 2

    observation_spec = (
        tensor_spec.TensorSpec([1], tf.float32),
        tensor_spec.TensorSpec([], tf.float32),
    )
    sampled_step = tensor_spec.sample_spec_nest(
        ts.time_step_spec(observation_spec), outer_dims=outer_dims)

    action_spec = tensor_spec.BoundedTensorSpec(
        (num_actions_dims,), tf.float32, 2, 3)
    action_batch = tf.random.uniform([*outer_dims, num_actions_dims])

    # One preprocessing branch per observation element; the scalar element
    # is reshaped to (1,) before its Dense layer.
    vector_branch = tf.keras.layers.Dense(4)
    scalar_branch = sequential_layer.SequentialLayer([
        tf.keras.layers.Reshape((1,)),
        tf.keras.layers.Dense(4),
    ])

    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_preprocessing_layers=(vector_branch, scalar_branch),
        observation_preprocessing_combiner=tf.keras.layers.Add())

    q_values, _ = critic_net((sampled_step.observation, action_batch))

    self.assertAllEqual(q_values.shape.as_list(), list(outer_dims))
def test_build(self, outer_dims):
    """Build a default critic network and check its output and variables.

    Feeds random observations and actions with the given leading dimensions
    through a ``CriticNetwork`` constructed with default arguments.

    Args:
        outer_dims: Leading dimensions of the inputs; the first network
            output is expected to have exactly this shape.
    """
    num_obs_dims = 5
    num_actions_dims = 2
    leading_dims = list(outer_dims)

    obs_spec = tensor_spec.TensorSpec([num_obs_dims], tf.float32)
    action_spec = tensor_spec.TensorSpec([num_actions_dims], tf.float32)

    obs_batch = tf.random.uniform(leading_dims + [num_obs_dims])
    action_batch = tf.random.uniform(leading_dims + [num_actions_dims])

    critic_net = critic_network.CriticNetwork((obs_spec, action_spec))
    q_values, _ = critic_net((obs_batch, action_batch))

    self.assertAllEqual(q_values.shape.as_list(), leading_dims)
    # The default configuration is expected to create exactly two
    # trainable variables.
    self.assertLen(critic_net.trainable_variables, 2)
def test_add_obs_fc_layers(self, outer_dims):
    """Check a critic configured with ``observation_fc_layer_params``.

    Uses a rank-3 observation of shape ``[3, 3, num_obs_dims]`` and passes
    ``observation_fc_layer_params=[20, 10]`` to the network.

    Args:
        outer_dims: Leading dimensions of the inputs; the first network
            output is expected to have exactly this shape.
    """
    num_obs_dims = 5
    num_actions_dims = 2
    obs_shape = [3, 3, num_obs_dims]

    obs_spec = tensor_spec.TensorSpec(obs_shape, tf.float32)
    action_spec = tensor_spec.TensorSpec([num_actions_dims], tf.float32)

    critic_net = critic_network.CriticNetwork(
        (obs_spec, action_spec),
        observation_fc_layer_params=[20, 10],
    )

    obs_batch = tf.random.uniform([*outer_dims, *obs_shape])
    action_batch = tf.random.uniform([*outer_dims, num_actions_dims])

    q_values, _ = critic_net((obs_batch, action_batch))

    self.assertAllEqual(q_values.shape.as_list(), list(outer_dims))
    # With the two extra observation layers the network is expected to hold
    # six trainable variables in total.
    self.assertLen(critic_net.trainable_variables, 6)