Example 1
    def test_keras_model_federated_output_computation(self):
        feature_dims = 3
        num_train_steps = 3

        def _make_keras_model():
            keras_model = model_examples.build_linear_regression_keras_functional_model(
                feature_dims)
            return keras_model

        def _model_fn():
            return keras_utils.from_keras_model(
                keras_model=_make_keras_model(),
                input_spec=_create_dummy_types(feature_dims),
                loss=tf.keras.losses.MeanSquaredError(),
                metrics=[NumBatchesCounter(),
                         NumExamplesCounter()])

        @tff.tf_computation()
        def _train():
            # Create variables outside the tf.function.
            tff_model = _model_fn()
            optimizer = tf.keras.optimizers.SGD(0.1)

            @tf.function
            def _train_loop():
                for _ in range(num_train_steps):
                    with tf.GradientTape() as tape:
                        batch_output = tff_model.forward_pass(
                            collections.OrderedDict(
                                x=np.ones([2, feature_dims], dtype=np.float32),
                                y=np.ones([2, 1], dtype=np.float32)))
                    gradients = tape.gradient(batch_output.loss,
                                              tff_model.trainable_variables)
                    optimizer.apply_gradients(
                        zip(gradients, tff_model.trainable_variables))
                return tff_model.report_local_outputs(), tff_model.weights

            return _train_loop()

        # Simulate 'CLIENT' local training.
        client_local_outputs, tff_weights = _train()

        # Simulate entering the 'SERVER' context.
        tf.keras.backend.clear_session()

        aggregated_outputs = _model_fn().federated_output_computation(
            [client_local_outputs])
        aggregated_outputs = collections.OrderedDict(
            anonymous_tuple.to_elements(aggregated_outputs))
        self.assertEqual(aggregated_outputs['num_batches'], num_train_steps)
        self.assertEqual(aggregated_outputs['num_examples'],
                         2 * num_train_steps)
        self.assertGreater(aggregated_outputs['loss'], 0.0)

        keras_model = _make_keras_model()
        keras_utils.assign_weights_to_keras_model(keras_model, tff_weights)
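
Both examples rely on two Keras metrics, NumBatchesCounter and NumExamplesCounter, that are defined elsewhere in the test module. A minimal sketch of what such counters could look like is given below; the names follow the test code, but the bodies are illustrative assumptions, not the actual TFF test utilities.

import tensorflow as tf


class NumBatchesCounter(tf.keras.metrics.Sum):
    """Counts the number of batches seen."""

    def __init__(self, name='num_batches', dtype=tf.int64):
        super().__init__(name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Every call to update_state corresponds to exactly one batch.
        return super().update_state(1, sample_weight)


class NumExamplesCounter(tf.keras.metrics.Sum):
    """Counts the number of examples seen."""

    def __init__(self, name='num_examples', dtype=tf.int64):
        super().__init__(name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # The leading dimension of the predictions is the batch size.
        return super().update_state(tf.shape(y_pred)[0], sample_weight)

With counters of this shape, num_batches sums to the number of training steps and num_examples to the batch size times the number of steps, which is exactly what the assertions in the tests check.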
Example 2
    def test_keras_model_federated_output_computation(self):
        feature_dims = 3

        def _make_keras_model():
            keras_model = model_examples.build_linear_regression_keras_functional_model(
                feature_dims)
            keras_model.compile(
                optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
                loss=tf.keras.losses.MeanSquaredError(),
                metrics=[NumBatchesCounter(),
                         NumExamplesCounter()])
            return keras_model

        def _model_fn():
            return keras_utils.from_compiled_keras_model(
                keras_model=_make_keras_model(),
                dummy_batch=_create_dummy_batch(feature_dims))

        num_iterations = 3
        # TODO(b/122081673): This should be a @tf.function and the control
        # dependencies can go away (probably nothing blocking this, but it
        # just needs to be done and tested).
        @tff.tf_computation()
        def _train_loop():
            tff_model = _model_fn()
            ops = []
            for _ in range(num_iterations):
                with tf.control_dependencies(ops):
                    batch_output = tff_model.train_on_batch({
                        'x': np.ones([2, feature_dims], dtype=np.float32),
                        'y': np.ones([2, 1], dtype=np.float32)
                    })
                    ops = list(batch_output)
            with tf.control_dependencies(ops):
                return (tff_model.report_local_outputs(), tff_model.weights)

        client_local_outputs, tff_weights = _train_loop()

        # Simulate entering the 'SERVER' context with a new graph.
        tf.keras.backend.clear_session()
        aggregated_outputs = _model_fn().federated_output_computation(
            [client_local_outputs])
        aggregated_outputs = collections.OrderedDict(
            anonymous_tuple.to_elements(aggregated_outputs))
        self.assertEqual(aggregated_outputs['num_batches'], num_iterations)
        self.assertEqual(aggregated_outputs['num_examples'],
                         2 * num_iterations)
        self.assertGreater(aggregated_outputs['loss'], 0.0)

        keras_model = _make_keras_model()
        keras_utils.assign_weights_to_keras_model(keras_model, tff_weights)
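
The helpers _create_dummy_types (Example 1) and _create_dummy_batch (Example 2) are likewise assumed from the surrounding test module. A plausible sketch, under the assumption that they describe the same x/y layout used in the training loops (the exact shapes and dtypes here are illustrative):

import collections

import numpy as np
import tensorflow as tf


def _create_dummy_types(feature_dims):
    """Type specs describing one x/y batch, for from_keras_model."""
    return collections.OrderedDict(
        x=tf.TensorSpec(shape=[None, feature_dims], dtype=tf.float32),
        y=tf.TensorSpec(shape=[None, 1], dtype=tf.float32))


def _create_dummy_batch(feature_dims):
    """A concrete zero-filled batch, for from_compiled_keras_model."""
    return collections.OrderedDict(
        x=np.zeros([1, feature_dims], dtype=np.float32),
        y=np.zeros([1, 1], dtype=np.float32))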
Example 3
    def test_assign_weights_from_odict(self):
        keras_model = model_examples.build_linear_regression_keras_functional_model(
            feature_dims=1)
        # Build an OrderedDict mirroring the model's weight structure, then
        # shift every value by one.
        weight_odict = collections.OrderedDict(
            trainable=keras_model.trainable_weights,
            non_trainable=keras_model.non_trainable_weights)
        weight_odict_plus_1 = tf.nest.map_structure(lambda x: x + 1,
                                                    weight_odict)
        keras_utils.assign_weights_to_keras_model(keras_model,
                                                  weight_odict_plus_1)
        # After assignment, the model's weights should equal the shifted values.
        self.assertAllClose(self.evaluate(keras_model.trainable_weights),
                            weight_odict_plus_1['trainable'])
        self.assertAllClose(self.evaluate(keras_model.non_trainable_weights),
                            weight_odict_plus_1['non_trainable'])