def test_check_matching_networks(self):
    """Two networks wrapping identically configured Dense(3) layers match.

    `common.check_matching_networks` is expected to accept the pair
    without raising.
    """
    # Build both layers before wrapping either one in a network.
    first_layer, second_layer = (
        tf.keras.layers.Dense(3),
        tf.keras.layers.Dense(3),
    )
    obs_spec = self._observation_spec
    act_spec = self._action_spec
    online_net = networks_test_utils.KerasLayersNet(
        obs_spec, act_spec, first_layer)
    target_net = networks_test_utils.KerasLayersNet(
        obs_spec, act_spec, second_layer)
    common.check_matching_networks(online_net, target_net)
def test_check_matching_networks_different_vars(self):
    """Networks whose layers yield different variable lists are rejected.

    One network wraps a Dense(3) layer and the other a GRU(3) layer.
    `common.check_matching_networks` is expected to raise ValueError with
    a "Variables lengths do not match" message.
    """
    layer_1 = tf.keras.layers.Dense(3)
    layer_2 = tf.keras.layers.GRU(3)
    q_net_1 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, layer_1)
    q_net_2 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, layer_2)
    # Variables are created explicitly before the comparison.
    q_net_1.create_variables()
    q_net_2.create_variables()
    # `assertRaisesRegexp` is a deprecated alias that was removed from
    # unittest in Python 3.12; `assertRaisesRegex` is the supported name.
    with self.assertRaisesRegex(ValueError, 'Variables lengths do not match'):
        common.check_matching_networks(q_net_1, q_net_2)
def test_check_matching_networks_different_shape(self):
    """Networks whose layers differ only in output units are rejected.

    One network wraps Dense(3) and the other Dense(4).
    `common.check_matching_networks` is expected to raise ValueError with
    a "Variable dtypes or shapes do not match" message.
    """
    layer_1 = tf.keras.layers.Dense(3)
    layer_2 = tf.keras.layers.Dense(4)
    q_net_1 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, layer_1)
    q_net_2 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, layer_2)
    # `assertRaisesRegexp` was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError,
                                'Variable dtypes or shapes do not match'):
        common.check_matching_networks(q_net_1, q_net_2)
def test_check_matching_networks_different_input_spec(self):
    """Networks built with different observation specs are rejected.

    Both networks wrap Dense(3) layers. The second is built with a
    `TensorSpec([3], tf.float32)` observation spec instead of the fixture's
    `self._observation_spec`. `common.check_matching_networks` is expected
    to raise ValueError reporting mismatched input tensor specs.
    """
    layer_1 = tf.keras.layers.Dense(3)
    layer_2 = tf.keras.layers.Dense(3)
    q_net_1 = networks_test_utils.KerasLayersNet(
        self._observation_spec, self._action_spec, layer_1)
    q_net_2 = networks_test_utils.KerasLayersNet(
        tensor_spec.TensorSpec([3], tf.float32), self._action_spec, layer_2)
    # `assertRaisesRegexp` was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(
            ValueError, 'Input tensor specs of network and target network '
            'do not match'):
        common.check_matching_networks(q_net_1, q_net_2)