Beispiel #1
0
 def test_from_keras_model_succeeds_from_set(self):
     """Checks `from_keras_model` accepts layer groups given as sets (no error)."""
     keras_model = _create_keras_model()
     # Every layer goes into the global set; the local set is empty.
     layer_groups = dict(global_layers=set(keras_model.layers),
                         local_layers=set())
     keras_utils.from_keras_model(keras_model=keras_model,
                                  input_spec=_create_input_spec(),
                                  **layer_groups)
Beispiel #2
0
    def test_from_keras_model_fails_missing_variables(self):
        """Expects a `ValueError` mentioning 'variables' when a layer is unassigned."""
        keras_model = _create_keras_model()
        input_spec = _create_input_spec()
        # The last layer is left out of `global_layers`, and `local_layers` is
        # empty, so that layer is assigned to neither group.
        all_but_last = keras_model.layers[:-1]

        with self.assertRaisesRegex(ValueError, 'variables'):
            keras_utils.from_keras_model(
                keras_model=keras_model,
                global_layers=all_but_last,
                local_layers=[],
                input_spec=input_spec)
Beispiel #3
0
    def test_from_keras_model_fails_bad_input_spec(self):
        """Expects a `ValueError` mentioning 'input_spec' for an `x`-only spec."""
        keras_model = _create_keras_model()
        # The spec intentionally has a single `x` field and nothing else.
        x_only_batch = collections.namedtuple('Batch', ['x'])
        input_spec = x_only_batch(
            x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32))

        with self.assertRaisesRegex(ValueError, 'input_spec'):
            keras_utils.from_keras_model(
                keras_model=keras_model,
                global_layers=keras_model.layers,
                local_layers=[],
                input_spec=input_spec)
Beispiel #4
0
    def test_from_keras_model_fails_compiled(self):
        """Expects a `ValueError` mentioning 'compiled' for a compiled model."""
        keras_model = _create_keras_model()
        # Compile the model first; the conversion below must then reject it.
        loss = tf.keras.losses.SparseCategoricalCrossentropy()
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        keras_model.compile(loss=loss, optimizer=optimizer)
        input_spec = _create_input_spec()

        with self.assertRaisesRegex(ValueError, 'compiled'):
            keras_utils.from_keras_model(
                keras_model=keras_model,
                global_layers=keras_model.layers,
                local_layers=[],
                input_spec=input_spec)
Beispiel #5
0
    def test_from_keras_model_forward_pass(self):
        """Checks `forward_pass` output for two batches of different sizes."""
        keras_model = _create_keras_model()
        recon_model = keras_utils.from_keras_model(
            keras_model=keras_model,
            global_layers=keras_model.layers[:-1],  # all but the last layer
            local_layers=keras_model.layers[-1:],  # only the last layer
            input_spec=_create_input_spec())
        batch_type = collections.namedtuple('Batch', ['x', 'y'])

        # Each case is (num_examples, feature constructor, label constructor).
        # The second case changes both the batch size and the label values.
        cases = [(10, tf.ones, tf.zeros), (5, tf.zeros, tf.ones)]
        for num_examples, make_x, make_y in cases:
            batch_input = batch_type(
                x=make_x(shape=[num_examples, 784], dtype=tf.float32),
                y=make_y(shape=[num_examples, 1], dtype=tf.int32))

            batch_output = recon_model.forward_pass(batch_input)

            self.assertIsInstance(batch_output, model_lib.BatchOutput)
            self.assertEqual(batch_output.num_examples, num_examples)
            self.assertAllEqual(batch_output.labels,
                                make_y(shape=[num_examples, 1], dtype=tf.int32))
 def test_has_only_global_variables_true(self):
     """`has_only_global_variables` is truthy when every layer is global."""
     keras_model = _create_keras_model()
     model = keras_utils.from_keras_model(
         keras_model=keras_model,
         global_layers=keras_model.layers,
         local_layers=[],
         input_spec=_create_input_spec())
     only_global = reconstruction_utils.has_only_global_variables(model)
     self.assertTrue(only_global)
def local_recon_model_fn():
    """Builds the Keras MNIST reconstruction model with a local final layer.

    Every layer except the last is global; the last layer is local.
    """
    keras_model = _create_keras_model()
    layers = keras_model.layers
    return keras_utils.from_keras_model(
        keras_model=keras_model,
        global_layers=layers[:-1],
        local_layers=layers[-1:],
        input_spec=_create_input_spec())
def global_recon_model_fn():
    """Builds the Keras MNIST reconstruction model with every layer global."""
    keras_model = _create_keras_model()
    no_local_layers = []
    return keras_utils.from_keras_model(
        keras_model=keras_model,
        global_layers=keras_model.layers,
        local_layers=no_local_layers,
        input_spec=_create_input_spec())
Beispiel #9
0
def keras_linear_model_fn():
    """Builds a Keras linear model intended to match `LinearModel`.

    The model applies a bias-free, zero-initialized 1-unit Dense layer to a
    1-feature input, then a `BiasLayer`. Every layer except the last is
    global; the last layer (expected to be the `BiasLayer`) is local.
    """
    features = tf.keras.layers.Input(shape=[1])
    kernel_only = tf.keras.layers.Dense(
        1, use_bias=False, kernel_initializer='zeros')
    scaled = kernel_only(features)
    biased = BiasLayer()(scaled)
    keras_model = tf.keras.Model(inputs=features, outputs=biased)
    layers = keras_model.layers
    # NOTE(review): `_create_input_spec()` is defined elsewhere; confirm its
    # `x` spec matches this model's 1-feature input.
    return keras_utils.from_keras_model(
        keras_model=keras_model,
        global_layers=layers[:-1],
        local_layers=layers[-1:],
        input_spec=_create_input_spec())
Beispiel #10
0
    def test_from_keras_model_forward_pass_fails_bad_input_keys(self):
        """Expects a `KeyError` mentioning 'keys' for a batch with fields a/b."""
        keras_model = _create_keras_model()
        recon_model = keras_utils.from_keras_model(
            keras_model=keras_model,
            global_layers=keras_model.layers,
            local_layers=[],
            input_spec=_create_input_spec())

        # Same tensor shapes as a normal batch, but under field names `a`/`b`.
        misnamed_batch = collections.namedtuple('Batch', ['a', 'b'])
        batch_input = misnamed_batch(
            a=tf.ones(shape=[10, 784], dtype=tf.float32),
            b=tf.zeros(shape=[10, 1], dtype=tf.int32))

        with self.assertRaisesRegex(KeyError, 'keys'):
            recon_model.forward_pass(batch_input)
    def test_get_local_variables(self):
        """`get_local_variables` returns only the final layer's trainable vars."""
        keras_model = _create_keras_model()
        input_spec = _create_input_spec()
        layers = keras_model.layers
        model = keras_utils.from_keras_model(
            keras_model=keras_model,
            global_layers=layers[:-1],
            local_layers=layers[-1:],
            input_spec=input_spec)

        local_weights = reconstruction_utils.get_local_variables(model)
        self.assertIsInstance(local_weights, model_utils.ModelWeights)

        # The final (local) Dense layer holds a weight and a bias, which are the
        # last two trainable variables of the Keras model.
        expected_trainable = keras_model.trainable_variables[-2:]
        self.assertEqual(local_weights.trainable, expected_trainable)
        self.assertEmpty(local_weights.non_trainable)
Beispiel #12
0
    def test_from_keras_model_properties(self):
        """With no local layers, all variables are global and none are local."""
        keras_model = _create_keras_model()
        input_spec = _create_input_spec()
        recon_model = keras_utils.from_keras_model(
            keras_model=keras_model,
            global_layers=keras_model.layers,
            local_layers=[],
            input_spec=input_spec)

        # The global groups mirror the Keras model's full variable lists.
        global_pairs = [
            (recon_model.global_trainable_variables,
             keras_model.trainable_variables),
            (recon_model.global_non_trainable_variables,
             keras_model.non_trainable_variables),
        ]
        for actual, expected in global_pairs:
            self.assertEqual(actual, expected)

        # Nothing was assigned to the local groups.
        self.assertEmpty(recon_model.local_trainable_variables)
        self.assertEmpty(recon_model.local_non_trainable_variables)

        # The model exposes an input spec equal to the one passed in.
        self.assertEqual(input_spec, recon_model.input_spec)
Beispiel #13
0
    def test_from_keras_model_local_layers_properties(self):
        """With a local final Dense layer, its two variables are local."""
        keras_model = _create_keras_model()
        input_spec = _create_input_spec()
        # The last Dense layer is local; every earlier layer is global.
        recon_model = keras_utils.from_keras_model(
            keras_model=keras_model,
            global_layers=keras_model.layers[:-1],
            local_layers=keras_model.layers[-1:],
            input_spec=input_spec)

        # The final Dense layer's weights and bias are the last two trainable
        # variables: they should be local, and everything else global.
        trainable = keras_model.trainable_variables
        self.assertEqual(recon_model.global_trainable_variables, trainable[:-2])
        self.assertEqual(recon_model.global_non_trainable_variables,
                         keras_model.non_trainable_variables)
        self.assertEqual(recon_model.local_trainable_variables, trainable[-2:])
        self.assertEmpty(recon_model.local_non_trainable_variables)
        self.assertEqual(input_spec, recon_model.input_spec)
Beispiel #14
0
    def test_from_keras_model_forward_pass_list_input(self):
        """Forward pass also accepts a plain two-element list as the batch."""
        keras_model = _create_keras_model()
        recon_model = keras_utils.from_keras_model(
            keras_model=keras_model,
            global_layers=keras_model.layers[:-1],
            local_layers=keras_model.layers[-1:],
            input_spec=_create_input_spec())

        # The list's first element holds the inputs, the second the labels.
        features = tf.ones(shape=[10, 784], dtype=tf.float32)
        labels = tf.zeros(shape=[10, 1], dtype=tf.int32)
        batch_output = recon_model.forward_pass([features, labels])

        self.assertIsInstance(batch_output, model_lib.BatchOutput)
        self.assertEqual(batch_output.num_examples, 10)
        self.assertAllEqual(batch_output.labels, labels)