Exemple #1
0
 def test_input_spec_batch_types_type_errors(self, input_spec, error_message):
   """Asserts `from_keras_model` raises a `TypeError` matching `error_message`.

   Args:
     input_spec: The input spec handed to `keras_utils.from_keras_model`.
     error_message: Regex the raised `TypeError` message must match.
   """
   linear_model = model_examples.build_linear_regression_keras_functional_model(
       feature_dims=1)
   expect_type_error = self.assertRaisesRegex(TypeError, error_message)
   with expect_type_error:
     mse = tf.keras.losses.MeanSquaredError()
     keras_utils.from_keras_model(
         keras_model=linear_model, input_spec=input_spec, loss=mse)
 def test_input_spec_batch_types(self, input_spec):
     """Wraps a linear-regression Keras model using `input_spec`.

     The wrapped result is expected to be a `model_utils.EnhancedModel`.

     Args:
       input_spec: The input spec handed to `keras_utils.from_keras_model`.
     """
     regression_model = (
         model_examples.build_linear_regression_keras_functional_model(
             feature_dims=1))
     wrapper_kwargs = dict(
         keras_model=regression_model,
         input_spec=input_spec,
         loss=tf.keras.losses.MeanSquaredError())
     wrapped = keras_utils.from_keras_model(**wrapper_kwargs)
     self.assertIsInstance(wrapped, model_utils.EnhancedModel)
Exemple #3
0
  def test_from_compiled_keras_model_fails_on_uncompiled_model(self):
    """Asserts `from_compiled_keras_model` raises `ValueError` here.

    The Keras model is built but `compile` is never called on it in this test.
    """
    uncompiled = model_examples.build_linear_regression_keras_functional_model(
        feature_dims=1)
    expected_message = '`keras_model` must be compiled'

    with self.assertRaisesRegex(ValueError, expected_message):
      batch = _create_dummy_batch(feature_dims=1)
      keras_utils.from_compiled_keras_model(
          keras_model=uncompiled, dummy_batch=batch)
Exemple #4
0
 def model_fn():
     """Returns a 2-feature linear-regression Keras model wrapped for TFF.

     Uses mean squared error as the loss and the enclosing scope's
     `dataset.element_spec` as the input spec.
     """
     linear_model = (
         model_examples.build_linear_regression_keras_functional_model(
             feature_dims=2))
     return keras_utils.from_keras_model(
         linear_model,
         loss=tf.keras.losses.MeanSquaredError(),
         input_spec=dataset.element_spec)
 def model_fn():
     """Returns a linear-regression Keras model wrapped for TFF.

     The model is built with the builder's default arguments; the input spec
     is an ordered mapping with an `x` spec of shape `[None, 2]` and a `y`
     spec of shape `[None, 1]`.
     """
     linear_model = (
         model_examples.build_linear_regression_keras_functional_model())
     mse = tf.keras.losses.MeanSquaredError()
     batch_spec = collections.OrderedDict([
         ('x', tf.TensorSpec(shape=[None, 2])),
         ('y', tf.TensorSpec(shape=[None, 1])),
     ])
     return keras_utils.from_keras_model(
         linear_model, loss=mse, input_spec=batch_spec)
Exemple #6
0
 def _make_keras_model():
     """Returns a compiled linear-regression Keras model.

     The model is built for the enclosing scope's `feature_dims` and compiled
     with SGD (learning rate 0.01), mean squared error, and the batch and
     example counter metrics.
     """
     model = model_examples.build_linear_regression_keras_functional_model(
         feature_dims)
     sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
     mse = tf.keras.losses.MeanSquaredError()
     counters = [NumBatchesCounter(), NumExamplesCounter()]
     model.compile(optimizer=sgd, loss=mse, metrics=counters)
     return model
Exemple #7
0
 def test_input_spec_python_container(self, input_spec):
   """Checks the wrapped model's type and that its spec leaves are `TensorSpec`s.

   Args:
     input_spec: The input spec handed to `keras_utils.from_keras_model`.
   """
   linear_model = model_examples.build_linear_regression_keras_functional_model(
       feature_dims=1)
   wrapped = keras_utils.from_keras_model(
       keras_model=linear_model,
       input_spec=input_spec,
       loss=tf.keras.losses.MeanSquaredError())
   self.assertIsInstance(wrapped, model_utils.EnhancedModel)

   def _assert_is_tensor_spec(leaf):
     self.assertIsInstance(leaf, tf.TensorSpec)

   # Visit every leaf of the (possibly nested) input spec.
   tf.nest.map_structure(_assert_is_tensor_spec, wrapped.input_spec)
Exemple #8
0
 def test_keras_model_and_optimizer(self):
     """Checks the result of `from_keras_model` when an optimizer is supplied.

     The wrapped model must be an `EnhancedTrainableModel` and the underlying
     Keras model must expose an `optimizer` attribute.
     """
     # TFF is expected to compile the Keras model when handed an optimizer.
     linear_model = (
         model_examples.build_linear_regression_keras_functional_model(
             feature_dims=1))
     batch = _create_dummy_batch(1)
     mse = tf.keras.losses.MeanSquaredError()
     sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
     wrapped = keras_utils.from_keras_model(
         keras_model=linear_model,
         dummy_batch=batch,
         loss=mse,
         optimizer=sgd)
     self.assertIsInstance(wrapped, model_utils.EnhancedTrainableModel)
     # pylint: disable=internal-access
     wrapped_keras_model = wrapped._model._keras_model
     self.assertTrue(hasattr(wrapped_keras_model, 'optimizer'))
Exemple #9
0
    def test_from_compiled_keras_model_fails_on_uncompiled_model(self):
        """Asserts an uncompiled model is rejected and a deprecation is reported.

        `keras_utils.from_compiled_keras_model` is called with a Keras model on
        which `compile` was never invoked in this test. The call must raise a
        `ValueError` mentioning that the model must be compiled, and exactly
        one warning, of category `DeprecationWarning`, must be recorded.
        """
        keras_model = model_examples.build_linear_regression_keras_functional_model(
            feature_dims=1)

        with warnings.catch_warnings(record=True) as w:
            # `catch_warnings` only records what the active filters let
            # through, and `DeprecationWarning` is ignored by default outside
            # `__main__`. Force it on so the assertions below do not depend on
            # how the test runner configured the warning filters.
            warnings.simplefilter('always', DeprecationWarning)
            with self.assertRaisesRegex(ValueError,
                                        '`keras_model` must be compiled'):
                keras_utils.from_compiled_keras_model(
                    keras_model=keras_model,
                    dummy_batch=_create_dummy_batch(feature_dims=1))
            self.assertLen(w, 1)
            self.assertTrue(issubclass(w[0].category, DeprecationWarning))
    def test_keras_model_fails_compiled(self):
        """Asserts `from_keras_model` raises `ValueError` for a compiled model.

        The Keras model is compiled with a mean-squared-error loss before it
        is passed in; the error message must match the regex 'compile'.
        """
        num_features = 3
        compiled_model = (
            model_examples.build_linear_regression_keras_functional_model(
                num_features))
        compiled_model.compile(loss=tf.keras.losses.MeanSquaredError())

        with self.assertRaisesRegex(ValueError, 'compile'):
            wrapper_kwargs = dict(
                keras_model=compiled_model,
                input_spec=_create_dummy_types(num_features),
                loss=tf.keras.losses.MeanSquaredError(),
                metrics=[NumBatchesCounter(), NumExamplesCounter()])
            keras_utils.from_keras_model(**wrapper_kwargs)
 def test_assign_weights_from_odict(self):
     """Assigns incremented weights from an ordered dict and reads them back.

     Every weight of the Keras model is increased by one, written to the
     model via `assign_weights_to_keras_model`, and then compared against
     the model's current trainable and non-trainable weights.
     """
     model = model_examples.build_linear_regression_keras_functional_model(
         feature_dims=1)
     original = collections.OrderedDict([
         ('trainable', model.trainable_weights),
         ('non_trainable', model.non_trainable_weights),
     ])
     incremented = tf.nest.map_structure(lambda w: w + 1, original)
     keras_utils.assign_weights_to_keras_model(model, incremented)
     for group in ('trainable', 'non_trainable'):
         current = getattr(model, group + '_weights')
         self.assertAllClose(self.evaluate(current), incremented[group])
Exemple #12
0
 def test_input_spec_struct(self):
   """Checks `from_keras_model` with a `computation_types.StructType` spec.

   The wrapped model must be an `EnhancedModel`, its `input_spec` must be a
   `structure.Struct`, and every leaf of that spec must be a `tf.TensorSpec`.
   """
   linear_model = model_examples.build_linear_regression_keras_functional_model(
       feature_dims=1)
   fields = collections.OrderedDict()
   fields['x'] = tf.TensorSpec(shape=[None, 1], dtype=tf.float32)
   fields['y'] = tf.TensorSpec(shape=[None, 1], dtype=tf.float32)
   struct_spec = computation_types.StructType(fields)
   wrapped = keras_utils.from_keras_model(
       keras_model=linear_model,
       input_spec=struct_spec,
       loss=tf.keras.losses.MeanSquaredError())
   self.assertIsInstance(wrapped, model_utils.EnhancedModel)
   self.assertIsInstance(wrapped.input_spec, structure.Struct)

   def _assert_is_tensor_spec(leaf):
     self.assertIsInstance(leaf, tf.TensorSpec)

   structure.map_structure(_assert_is_tensor_spec, wrapped.input_spec)
    def test_get_set_model_weights_keras_model(self):
        """Round-trips pretrained Keras weights through a FedAvg process state.

        Pretrained weights are written into the process state, read back,
        trained further for three rounds, and finally assigned back onto the
        Keras model they came from.
        """

        def model_fn():
            linear_model = (
                model_examples.build_linear_regression_keras_functional_model())
            mse = tf.keras.losses.MeanSquaredError()
            batch_spec = collections.OrderedDict(
                x=tf.TensorSpec(shape=[None, 2]),
                y=tf.TensorSpec(shape=[None, 1]))
            return keras_utils.from_keras_model(
                linear_model, loss=mse, input_spec=batch_spec)

        process = composers.build_basic_fedavg_process(
            model_fn=model_fn, client_learning_rate=0.1)
        server_state = process.initialize()

        # Pretrain a standalone Keras model outside the federated process.
        local_model = (
            model_examples.build_linear_regression_keras_functional_model())
        local_model.compile(optimizer='adam', loss='mse')
        xy_pairs = self._test_data().map(lambda d: (d['x'], d['y']))
        local_model.fit(xy_pairs)
        pretrained = model_utils.ModelWeights.from_model(local_model)

        # The freshly initialized state must differ from the pretrained model.
        initial = process.get_model_weights(server_state)
        self.assertNotAllClose(
            tf.nest.flatten(pretrained), tf.nest.flatten(initial))

        # Overwrite the state's weights with the pretrained ones and read
        # them back.
        server_state = process.set_model_weights(server_state, pretrained)
        self.assertAllClose(
            tf.nest.flatten(pretrained),
            tf.nest.flatten(process.get_model_weights(server_state)))

        # Run three FedAvg rounds; the three clients share identical data.
        client_data = [self._test_data()] * 3
        for _ in range(3):
            server_state = process.next(server_state, client_data).state

        # Training should have moved the weights away from the pretrained ones.
        self.assertNotAllClose(
            tf.nest.flatten(pretrained),
            tf.nest.flatten(process.get_model_weights(server_state)))

        # Assigning the trained weights back to the Keras model must not raise.
        process.get_model_weights(server_state).assign_weights_to(local_model)
 def _make_keras_model():
     """Returns the linear-regression Keras model built for `feature_dims`.

     `feature_dims` is read from the enclosing scope.
     """
     return model_examples.build_linear_regression_keras_functional_model(
         feature_dims)
Exemple #15
0
    def test_functional_model_matches_model_fn(self, weighting):
        """Checks FunctionalModel and model_fn client updates agree.

        Two client-update procedures are built from the same linear-regression
        Keras architecture, one via `functional_model_from_keras` and one via
        a `model_fn` using `from_keras_model`. Both are run for 10 iterations
        on the same dataset; after each iteration their `update` and
        `update_weight` outputs must be close, and at the end the accumulated
        weights must be close as well.

        Args:
          weighting: Passed through unchanged as the `weighting` argument of
            both client-update builders.
        """
        dataset = create_test_dataset()

        # Build a FunctionalModel based client_model_update procedure. This will
        # be compared to a model_fn based implementation built below.
        keras_model = model_examples.build_linear_regression_keras_functional_model(
            feature_dims=2)
        loss_fn = tf.keras.losses.MeanSquaredError()
        input_spec = dataset.element_spec
        functional_model = functional.functional_model_from_keras(
            keras_model, loss_fn=loss_fn, input_spec=input_spec)

        # Note: we must wrap in a `tf_computation` for the correct graph-context
        # processing of Keras models wrapped as FunctionalModel.
        @tensorflow_computation.tf_computation
        def client_update_functional_model(model_weights, dataset):
            model_delta_fn = model_delta_client_work.build_functional_model_delta_update(
                model=functional_model, weighting=weighting)
            return model_delta_fn(sgdm.build_sgdm(learning_rate=0.1),
                                  model_weights, dataset)

        # Build a model_fn based client_model_update procedure. This will be
        # compared to the FunctionalModel variant built above to ensure they
        # produce the same results.
        def model_fn():
            keras_model = model_examples.build_linear_regression_keras_functional_model(
                feature_dims=2)
            loss_fn = tf.keras.losses.MeanSquaredError()
            input_spec = dataset.element_spec
            return keras_utils.from_keras_model(keras_model,
                                                loss=loss_fn,
                                                input_spec=input_spec)

        client_update_model_fn = model_delta_client_work.build_model_delta_update_with_tff_optimizer(
            model_fn=model_fn, weighting=weighting)
        model_fn_optimizer = sgdm.build_sgdm(learning_rate=0.1)
        model_fn_weights = model_utils.ModelWeights.from_model(model_fn())

        functional_model_weights = functional_model.initial_weights
        for _ in range(10):
            # The lambdas below close over loop-local outputs but are invoked
            # immediately by `map_structure`, hence the pylint suppression.
            # pylint: disable=cell-var-from-loop
            model_fn_output, _ = client_update_model_fn(
                model_fn_optimizer, model_fn_weights, dataset)
            functional_model_output, _ = client_update_functional_model(
                functional_model_weights, dataset)
            self.assertAllClose(model_fn_output.update,
                                functional_model_output.update)
            self.assertAllClose(model_fn_output.update_weight,
                                functional_model_output.update_weight)
            # Apply each variant's update, scaled by its update weight, to
            # that variant's weights so the next iteration starts from them.
            model_fn_weights = attr.evolve(
                model_fn_weights,
                trainable=tf.nest.map_structure(
                    lambda u, v: u + v * model_fn_output.update_weight,
                    model_fn_weights.trainable, model_fn_output.update))
            # Only element 0 of the functional weights tuple is updated;
            # element 1 is carried over unchanged (presumably the tuple is
            # (trainable, non_trainable) -- TODO confirm).
            functional_model_weights = (tf.nest.map_structure(
                lambda u, v: u + v * functional_model_output.update_weight,
                functional_model_weights[0],
                functional_model_output.update), functional_model_weights[1])
            # pylint: enable=cell-var-from-loop
        self.assertAllClose(attr.astuple(model_fn_weights),
                            functional_model_weights)