Example #1
0
  def test_keras_layer_capturing_other_layer_fails(self):
    """Conversion must reject a layer that holds another layer's variables."""

    class SharedLayer(tf.keras.layers.Layer):
      """Identity layer that re-exposes a `Dense` layer's kernel and bias."""

      def __init__(self, dense_layer: tf.keras.layers.Dense, **kwargs):
        super().__init__(**kwargs)
        self._dense_layer = dense_layer
        # Re-exposing the other layer's variables as attributes of this layer
        # is what makes the two layers share variables.
        self.kernel = dense_layer.kernel
        self.bias = dense_layer.bias

      def call(self, inputs):
        return inputs

      def get_config(self):
        config = super().get_config()
        config.update({'dense_layer': self._dense_layer})
        return config

    model_input = tf.keras.layers.Input(shape=[1])
    dense = tf.keras.layers.Dense(1)
    # `dense` must be called (and thereby built) before SharedLayer reads its
    # kernel and bias in `__init__`.
    hidden = dense(model_input)
    model_output = SharedLayer(dense)(hidden)
    keras_model = tf.keras.Model(inputs=model_input, outputs=model_output)

    features_spec = tf.TensorSpec(shape=[None, 1])
    labels_spec = tf.TensorSpec(shape=[None, 1])
    with self.assertRaisesRegex(functional.KerasFunctionalModelError,
                                'sharing variables across layers'):
      functional.functional_model_from_keras(
          keras_model,
          tf.keras.losses.MeanSquaredError(),
          input_spec=(features_spec, labels_spec))
Example #2
0
  def test_keras_layer_input_other_layer_fails(self):
    # Variant of test_keras_layer_capturing_other_layer_fails: instead of
    # receiving the other layer at construction time, the shared layer is
    # handed that layer as an argument to `call`.

    class SharedLayer(tf.keras.layers.Layer):

      def call(self, inputs: tf.Tensor,
               dense_layer: tf.keras.layers.Dense) -> tf.Tensor:
        return inputs @ dense_layer.kernel + dense_layer.bias

    def create_test_model():
      model_input = tf.keras.layers.Input(shape=[1])
      dense = tf.keras.layers.Dense(1)
      hidden = dense(model_input)
      model_output = SharedLayer()(hidden, dense)
      return tf.keras.Model(inputs=model_input, outputs=model_output)

    def convert(model_or_model_fn):
      # Fresh loss and spec objects for every conversion attempt.
      return functional.functional_model_from_keras(
          model_or_model_fn,
          tf.keras.losses.MeanSquaredError(),
          input_spec=(tf.TensorSpec(shape=[None, 1]),
                      tf.TensorSpec(shape=[None, 1])))

    # An already-constructed model instance must be rejected.
    with self.assertRaisesRegex(
        functional.KerasFunctionalModelError,
        'has a layer that receives inputs from other layers directly'):
      convert(create_test_model())
    # NOTE(review): the factory itself (not a model instance) is passed here
    # and the call is expected to complete without raising; presumably
    # `functional_model_from_keras` also accepts a no-arg model constructor
    # and this path is meant to succeed -- confirm against its signature.
    convert(create_test_model)
Example #3
0
 def test_keras_model_with_batch_normalization_fails(self):
   """Conversion must reject models containing BatchNormalization."""
   layer_stack = [
       tf.keras.layers.InputLayer(input_shape=[10]),
       tf.keras.layers.BatchNormalization(),
   ]
   model = tf.keras.models.Sequential(layer_stack)
   features_spec = tf.TensorSpec(shape=[None, 10])
   labels_spec = tf.TensorSpec(shape=[None, 1])
   with self.assertRaisesRegex(functional.KerasFunctionalModelError,
                               'batch normalization'):
     functional.functional_model_from_keras(
         model,
         tf.keras.losses.MeanSquaredError(),
         input_spec=(features_spec, labels_spec))
Example #4
0
 def test_keras_model_with_non_trainable_variables_fails(self):
   """Conversion must reject models that own non-trainable variables."""
   model_input = tf.keras.layers.Input(shape=[1])
   frozen_dense = tf.keras.layers.Dense(1)
   # Freeze the layer before it is called so its weights are non-trainable.
   frozen_dense.trainable = False
   keras_model = tf.keras.Model(
       inputs=model_input, outputs=frozen_dense(model_input))
   features_spec = tf.TensorSpec(shape=[None, 1])
   labels_spec = tf.TensorSpec(shape=[None, 1])
   with self.assertRaisesRegex(functional.KerasFunctionalModelError,
                               'non-trainable variables'):
     functional.functional_model_from_keras(
         keras_model,
         tf.keras.losses.MeanSquaredError(),
         input_spec=(features_spec, labels_spec))
Example #5
0
 def test_predict_on_batch_keras_outside_graph_fails(self):
   """`predict_on_batch` must raise when called outside a tff.tf_computation."""
   first_batch = next(iter(create_test_dataset()))
   functional_model = functional.functional_model_from_keras(
       keras_model=create_test_keras_model(),
       loss_fn=tf.keras.losses.MeanSquaredError(),
       input_spec=(tf.TensorSpec([None, 1], dtype=tf.float32),
                   tf.TensorSpec([None, 1], dtype=tf.int32)))
   # Called eagerly here, i.e. not inside any graph/TFF computation.
   with self.assertRaisesRegex(functional.KerasFunctionalModelError,
                               'only usable inside a tff.tf_computation'):
     functional_model.predict_on_batch(functional_model.initial_weights,
                                       first_batch[0])
def create_test_keras_functional_model(input_spec):
    """Wrap a zero-initialized single-`Dense` Keras model as a FunctionalModel.

    The Keras model maps 3 input features to 1 output, with kernel and bias
    both initialized to zeros; the loss is mean squared error.

    The functional model that wraps a Keras model must be created in a graph
    context (see the IMPORTANT note in `functional_model_from_keras`),
    otherwise non-model Variables are produced. NOTE(review): this function
    does not enter a graph context itself, so callers are presumably expected
    to invoke it inside one -- confirm at the call sites.

    Args:
        input_spec: Forwarded unchanged as `input_spec` to
            `functional.functional_model_from_keras`.

    Returns:
        The value returned by `functional.functional_model_from_keras`.
    """
    input_layer = tf.keras.layers.InputLayer(input_shape=[3])
    output_layer = tf.keras.layers.Dense(
        1, kernel_initializer='zeros', bias_initializer='zeros')
    keras_model = tf.keras.Sequential([input_layer, output_layer])
    return functional.functional_model_from_keras(
        keras_model,
        loss_fn=tf.keras.losses.MeanSquaredError(),
        input_spec=input_spec)
Example #7
0
 def test_construct_from_keras(self):
   """Values already set on the Keras model must carry into initial_weights."""
   keras_model = create_test_keras_model()
   # Overwrite every variable after initialization, so the check below can
   # tell copied values apart from those of a freshly initialized model.
   for variable in keras_model.variables:
     variable.assign(tf.ones_like(variable))
   functional_model = functional.functional_model_from_keras(
       keras_model=keras_model,
       loss_fn=tf.keras.losses.MeanSquaredError(),
       input_spec=(tf.TensorSpec([None, 1], dtype=tf.float32),
                   tf.TensorSpec([None, 1], dtype=tf.int32)))
   self.assertIsInstance(functional_model, functional.FunctionalModel)
   # Every leaf weight must be all ones, not a newly initialized value.
   for weight in tf.nest.flatten(functional_model.initial_weights):
     self.assertAllClose(weight, tf.ones_like(weight))
Example #8
0
  def test_construct_from_keras_converges(self):
    """SGD on the converted model should drive the loss down on y = 2*x + 1.

    Trains for 30 passes over a 3-example dataset (batch size 1) inside an
    explicit graph. It then checks two losses: the loss on the first batch
    under the initial weights must be large, and the loss of the final
    training step must be small.
    """
    functional_model = functional.functional_model_from_keras(
        keras_model=create_test_keras_model(),
        loss_fn=tf.keras.losses.MeanSquaredError(),
        input_spec=(tf.TensorSpec([None, 1], dtype=tf.float32),
                    tf.TensorSpec([None, 1], dtype=tf.int32)))
    # Build the whole training computation in a dedicated graph; it is
    # evaluated with a TF1 session below.
    with tf.Graph().as_default() as test_graph:
      # Capture all the variables for later initialization in the session,
      # otherwise it's hard to get our hands on the Keras-owned variables.
      with variable_utils.record_variable_creation_scope(
      ) as captured_variables:
        # Create data satisfying y = 2*x + 1
        # NOTE(review): these labels are float32 while `input_spec` above
        # declares int32 labels; the test presumably relies on forward_pass
        # not enforcing the spec dtype -- confirm this is intentional.
        dataset = tf.data.Dataset.from_tensor_slices((
            # Features
            [[1.0], [2.0], [3.0]],
            # Labels.
            [[3.0], [5.0], [7.0]],
        )).batch(1)
        # Mutable copies of the initial weights, with the same nested
        # structure; element 0 is treated as the trainable part below.
        variables = tf.nest.map_structure(tf.Variable,
                                          functional_model.initial_weights)
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

        # Returns (loss on the first batch before any update, loss of the
        # last training step).
        @tf.function
        def train():
          weights = tf.nest.map_structure(lambda v: v.read_value(), variables)
          # `loss` is bound before the loop, presumably so AutoGraph has an
          # initial value for this loop-carried variable when it converts the
          # dataset loop below.
          initial_loss = loss = functional_model.forward_pass(
              weights, next(iter(dataset)), training=True).loss
          trainable = variables[0]
          for batch in dataset.repeat(30):
            with tf.GradientTape() as tape:
              weights = tf.nest.map_structure(lambda v: v.read_value(),
                                              variables)
              # The weights are plain tensors read from the variables, so the
              # tape must be told explicitly to track them.
              tape.watch(weights[0])
              batch_output = functional_model.forward_pass(
                  weights, batch, training=True)
            gradients = tape.gradient(batch_output.loss, weights[0])
            # Gradients w.r.t. the read values are applied to the variables
            # they were read from.
            optimizer.apply_gradients(zip(gradients, trainable))
            loss = batch_output.loss
          return initial_loss, loss

        # Inside `test_graph` this only builds ops; the tensors are evaluated
        # in the session below.
        initial_loss, final_loss = train()
    with tf.compat.v1.Session(graph=test_graph) as sess:
      # Initialize every variable created within the capture scope above
      # before fetching the loss values.
      sess.run(tf.compat.v1.initializers.variables(captured_variables))
      initial_loss, final_loss = sess.run([initial_loss, final_loss])
    # Expect some amount of convergence after 30 epochs over the dataset.
    self.assertGreater(initial_loss, 2.0)
    self.assertLess(final_loss, 0.2)
Example #9
0
    def test_functional_model_matches_model_fn(self, weighting):
        """FunctionalModel and model_fn client updates must agree.

        Runs ten rounds of local SGDM (learning rate 0.1) through both client
        update procedures. Each procedure starts from its own model's initial
        weights. Every round, the updates and update weights must match, and
        the accumulated model weights must match at the end.
        """
        dataset = create_test_dataset()
        input_spec = dataset.element_spec

        # Client update procedure built on a FunctionalModel wrapping Keras.
        functional_model = functional.functional_model_from_keras(
            model_examples.build_linear_regression_keras_functional_model(
                feature_dims=2),
            loss_fn=tf.keras.losses.MeanSquaredError(),
            input_spec=input_spec)

        # The `tf_computation` wrapper is required so that the Keras model
        # wrapped as a FunctionalModel is processed in a graph context.
        @tensorflow_computation.tf_computation
        def client_update_functional_model(model_weights, dataset):
            model_delta_fn = (
                model_delta_client_work.build_functional_model_delta_update(
                    model=functional_model, weighting=weighting))
            return model_delta_fn(sgdm.build_sgdm(learning_rate=0.1),
                                  model_weights, dataset)

        # Reference procedure built from a `model_fn`; it is compared with the
        # FunctionalModel variant above to check both produce the same results.
        def model_fn():
            return keras_utils.from_keras_model(
                model_examples.build_linear_regression_keras_functional_model(
                    feature_dims=2),
                loss=tf.keras.losses.MeanSquaredError(),
                input_spec=input_spec)

        client_update_model_fn = (
            model_delta_client_work.build_model_delta_update_with_tff_optimizer(
                model_fn=model_fn, weighting=weighting))
        model_fn_optimizer = sgdm.build_sgdm(learning_rate=0.1)
        model_fn_weights = model_utils.ModelWeights.from_model(model_fn())
        functional_model_weights = functional_model.initial_weights

        def apply_update(weights, update, update_weight):
            # Add `update`, scaled by `update_weight`, to every weight leaf.
            return tf.nest.map_structure(lambda w, u: w + u * update_weight,
                                         weights, update)

        for _ in range(10):
            model_fn_output, _ = client_update_model_fn(
                model_fn_optimizer, model_fn_weights, dataset)
            functional_model_output, _ = client_update_functional_model(
                functional_model_weights, dataset)
            self.assertAllClose(model_fn_output.update,
                                functional_model_output.update)
            self.assertAllClose(model_fn_output.update_weight,
                                functional_model_output.update_weight)
            model_fn_weights = attr.evolve(
                model_fn_weights,
                trainable=apply_update(model_fn_weights.trainable,
                                       model_fn_output.update,
                                       model_fn_output.update_weight))
            functional_model_weights = (
                apply_update(functional_model_weights[0],
                             functional_model_output.update,
                             functional_model_output.update_weight),
                functional_model_weights[1])
        self.assertAllClose(attr.astuple(model_fn_weights),
                            functional_model_weights)