def test_from_keras_model_succeeds_from_set(self): keras_model = _create_keras_model() input_spec = _create_input_spec() keras_utils.from_keras_model(keras_model=keras_model, global_layers=set(keras_model.layers), local_layers=set(), input_spec=input_spec)
def test_from_keras_model_fails_missing_variables(self): """Ensures failure if global/local layers are missing variables.""" keras_model = _create_keras_model() input_spec = _create_input_spec() with self.assertRaisesRegex(ValueError, 'variables'): keras_utils.from_keras_model(keras_model=keras_model, global_layers=keras_model.layers[:-1], local_layers=[], input_spec=input_spec)
def test_from_keras_model_fails_bad_input_spec(self): keras_model = _create_keras_model() input_spec = collections.namedtuple( 'Batch', ['x'])(x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32)) with self.assertRaisesRegex(ValueError, 'input_spec'): keras_utils.from_keras_model(keras_model=keras_model, global_layers=keras_model.layers, local_layers=[], input_spec=input_spec)
def test_from_keras_model_fails_compiled(self): keras_model = _create_keras_model() keras_model.compile( loss=tf.keras.losses.SparseCategoricalCrossentropy(), optimizer=tf.keras.optimizers.SGD(learning_rate=0.1)) input_spec = _create_input_spec() with self.assertRaisesRegex(ValueError, 'compiled'): keras_utils.from_keras_model(keras_model=keras_model, global_layers=keras_model.layers, local_layers=[], input_spec=input_spec)
def test_from_keras_model_forward_pass(self): keras_model = _create_keras_model() input_spec = _create_input_spec() recon_model = keras_utils.from_keras_model( keras_model=keras_model, global_layers=keras_model.layers[:-1], local_layers=keras_model.layers[-1:], input_spec=input_spec) batch_input = collections.namedtuple('Batch', ['x', 'y'])( x=tf.ones(shape=[10, 784], dtype=tf.float32), y=tf.zeros(shape=[10, 1], dtype=tf.int32)) batch_output = recon_model.forward_pass(batch_input) self.assertIsInstance(batch_output, model_lib.BatchOutput) self.assertEqual(batch_output.num_examples, 10) self.assertAllEqual(batch_output.labels, tf.zeros(shape=[10, 1], dtype=tf.int32)) # Change num_examples and labels. batch_input = collections.namedtuple('Batch', ['x', 'y'])( x=tf.zeros(shape=[5, 784], dtype=tf.float32), y=tf.ones(shape=[5, 1], dtype=tf.int32)) batch_output = recon_model.forward_pass(batch_input) self.assertIsInstance(batch_output, model_lib.BatchOutput) self.assertEqual(batch_output.num_examples, 5) self.assertAllEqual(batch_output.labels, tf.ones(shape=[5, 1], dtype=tf.int32))
def test_has_only_global_variables_true(self): keras_model = _create_keras_model() input_spec = _create_input_spec() model = keras_utils.from_keras_model(keras_model=keras_model, global_layers=keras_model.layers, local_layers=[], input_spec=input_spec) self.assertTrue(reconstruction_utils.has_only_global_variables(model))
def local_recon_model_fn(): """Keras MNIST model with final dense layer local.""" keras_model = _create_keras_model() input_spec = _create_input_spec() return keras_utils.from_keras_model(keras_model=keras_model, global_layers=keras_model.layers[:-1], local_layers=keras_model.layers[-1:], input_spec=input_spec)
def global_recon_model_fn(): """Keras MNIST model with no local variables.""" keras_model = _create_keras_model() input_spec = _create_input_spec() return keras_utils.from_keras_model(keras_model=keras_model, global_layers=keras_model.layers, local_layers=[], input_spec=input_spec)
def keras_linear_model_fn(): """Should produce the same results as `LinearModel`.""" inputs = tf.keras.layers.Input(shape=[1]) scaled_input = tf.keras.layers.Dense(1, use_bias=False, kernel_initializer='zeros')(inputs) outputs = BiasLayer()(scaled_input) keras_model = tf.keras.Model(inputs=inputs, outputs=outputs) input_spec = _create_input_spec() return keras_utils.from_keras_model(keras_model=keras_model, global_layers=keras_model.layers[:-1], local_layers=keras_model.layers[-1:], input_spec=input_spec)
def test_from_keras_model_forward_pass_fails_bad_input_keys(self): keras_model = _create_keras_model() input_spec = _create_input_spec() recon_model = keras_utils.from_keras_model( keras_model=keras_model, global_layers=keras_model.layers, local_layers=[], input_spec=input_spec) batch_input = collections.namedtuple('Batch', ['a', 'b'])( a=tf.ones(shape=[10, 784], dtype=tf.float32), b=tf.zeros(shape=[10, 1], dtype=tf.int32)) with self.assertRaisesRegex(KeyError, 'keys'): recon_model.forward_pass(batch_input)
def test_get_local_variables(self): keras_model = _create_keras_model() input_spec = _create_input_spec() model = keras_utils.from_keras_model( keras_model=keras_model, global_layers=keras_model.layers[:-1], local_layers=keras_model.layers[-1:], input_spec=input_spec) local_weights = reconstruction_utils.get_local_variables(model) self.assertIsInstance(local_weights, model_utils.ModelWeights) # The last layer of the Keras model, which is a local Dense layer, contains # 2 trainable variables for the weights and bias. self.assertEqual(local_weights.trainable, keras_model.trainable_variables[-2:]) self.assertEmpty(local_weights.non_trainable)
def test_from_keras_model_properties(self): keras_model = _create_keras_model() input_spec = _create_input_spec() recon_model = keras_utils.from_keras_model( keras_model=keras_model, global_layers=keras_model.layers, local_layers=[], input_spec=input_spec) # Global trainable/non_trainable should include all the variables, and # local should be empty. self.assertEqual(recon_model.global_trainable_variables, keras_model.trainable_variables) self.assertEqual(recon_model.global_non_trainable_variables, keras_model.non_trainable_variables) self.assertEmpty(recon_model.local_trainable_variables) self.assertEmpty(recon_model.local_non_trainable_variables) self.assertEqual(input_spec, recon_model.input_spec)
def test_from_keras_model_local_layers_properties(self): keras_model = _create_keras_model() input_spec = _create_input_spec() recon_model = keras_utils.from_keras_model( keras_model=keras_model, global_layers=keras_model. layers[:-1], # Last Dense layer is local. local_layers=keras_model.layers[-1:], input_spec=input_spec) # Expect last two variables, the weights and bias for the final Dense layer, # to be local trainable, and the rest global. self.assertEqual(recon_model.global_trainable_variables, keras_model.trainable_variables[:-2]) self.assertEqual(recon_model.global_non_trainable_variables, keras_model.non_trainable_variables) self.assertEqual(recon_model.local_trainable_variables, keras_model.trainable_variables[-2:]) self.assertEmpty(recon_model.local_non_trainable_variables) self.assertEqual(input_spec, recon_model.input_spec)
def test_from_keras_model_forward_pass_list_input(self): """Forward pass still works with a 2-element list batch input.""" keras_model = _create_keras_model() input_spec = _create_input_spec() recon_model = keras_utils.from_keras_model( keras_model=keras_model, global_layers=keras_model.layers[:-1], local_layers=keras_model.layers[-1:], input_spec=input_spec) batch_input = [ tf.ones(shape=[10, 784], dtype=tf.float32), tf.zeros(shape=[10, 1], dtype=tf.int32) ] batch_output = recon_model.forward_pass(batch_input) self.assertIsInstance(batch_output, model_lib.BatchOutput) self.assertEqual(batch_output.num_examples, 10) self.assertAllEqual(batch_output.labels, tf.zeros(shape=[10, 1], dtype=tf.int32))