def test_keras_model_federated_output_computation(self):
  """Checks metric aggregation after simulated client-side training.

  A TFF model wrapping a Keras linear-regression model is trained for a few
  SGD steps inside a `tff.tf_computation`. The resulting local outputs are
  aggregated with `federated_output_computation` after the Keras session has
  been cleared, and the aggregated batch/example counts and loss are checked.
  Finally the trained TFF weights are assigned onto a freshly built Keras
  model.

  NOTE(review): another method named
  `test_keras_model_federated_output_computation` appears in this file; if
  both are defined in the same class, the later definition shadows this one
  and this test is never collected. Consider renaming one of them.
  """
  dims = 3
  steps = 3
  batch_size = 2

  def _build_keras_model():
    return model_examples.build_linear_regression_keras_functional_model(dims)

  def _build_tff_model():
    return keras_utils.from_keras_model(
        keras_model=_build_keras_model(),
        input_spec=_create_dummy_types(dims),
        loss=tf.keras.losses.MeanSquaredError(),
        metrics=[NumBatchesCounter(), NumExamplesCounter()])

  @tff.tf_computation()
  def _client_train():
    # The model and optimizer are constructed here, outside the tf.function
    # below, so their variables are not created inside the traced function.
    model = _build_tff_model()
    sgd = tf.keras.optimizers.SGD(0.1)

    @tf.function
    def _run_steps():
      for _ in range(steps):
        with tf.GradientTape() as tape:
          output = model.forward_pass(
              collections.OrderedDict(
                  x=np.ones([batch_size, dims], dtype=np.float32),
                  y=np.ones([batch_size, 1], dtype=np.float32)))
        grads = tape.gradient(output.loss, model.trainable_variables)
        sgd.apply_gradients(zip(grads, model.trainable_variables))
      return model.report_local_outputs(), model.weights

    return _run_steps()

  # Client side: local training produces local outputs and trained weights.
  local_outputs, trained_weights = _client_train()

  # Server side: clear the Keras session, then aggregate with a new model.
  tf.keras.backend.clear_session()
  server_model = _build_tff_model()
  aggregated = server_model.federated_output_computation([local_outputs])
  metrics = collections.OrderedDict(anonymous_tuple.to_elements(aggregated))

  self.assertEqual(metrics['num_batches'], steps)
  self.assertEqual(metrics['num_examples'], batch_size * steps)
  self.assertGreater(metrics['loss'], 0.0)

  # Exercise assigning the trained TFF weights onto a fresh Keras model.
  keras_utils.assign_weights_to_keras_model(_build_keras_model(),
                                            trained_weights)
def test_keras_model_federated_output_computation(self):
  """Checks metric aggregation for a compiled-Keras TFF model.

  Wraps a compiled Keras linear-regression model with
  `keras_utils.from_compiled_keras_model`, runs `train_on_batch` several times
  inside a `tff.tf_computation`, aggregates the resulting local outputs with
  `federated_output_computation` after clearing the Keras session, and checks
  the aggregated batch/example counts and loss. Finally the trained TFF
  weights are assigned onto a freshly built Keras model.

  NOTE(review): this method has the same name as another
  `test_keras_model_federated_output_computation` defined earlier in this
  file. If both are in the same class, this later definition silently
  replaces the earlier one, so only this variant runs. Confirm which one is
  intended and rename or remove the other.
  """
  feature_dims = 3

  def _make_keras_model():
    keras_model = model_examples.build_linear_regression_keras_functional_model(
        feature_dims)
    # Compiled with loss and metrics so from_compiled_keras_model can pick
    # them up from the Keras model.
    keras_model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
        loss=tf.keras.losses.MeanSquaredError(),
        metrics=[NumBatchesCounter(), NumExamplesCounter()])
    return keras_model

  def _model_fn():
    return keras_utils.from_compiled_keras_model(
        keras_model=_make_keras_model(),
        dummy_batch=_create_dummy_batch(feature_dims))

  num_iterations = 3

  # TODO(b/122081673): This should be a @tf.function and the control
  # dependencies can go away (probably nothing blocking this, but it
  # just needs to be done and tested).
  @tff.tf_computation()
  def _train_loop():
    tff_model = _model_fn()
    # Outputs of the previous training step; each new step is created under a
    # control dependency on them so the steps run sequentially in the graph.
    ops = []
    for _ in range(num_iterations):
      with tf.control_dependencies(ops):
        batch_output = tff_model.train_on_batch({
            'x': np.ones([2, feature_dims], dtype=np.float32),
            'y': np.ones([2, 1], dtype=np.float32)
        })
        ops = list(batch_output)
    # Make the returned outputs/weights depend on the last training step.
    with tf.control_dependencies(ops):
      return (tff_model.report_local_outputs(), tff_model.weights)

  client_local_outputs, tff_weights = _train_loop()

  # Simulate entering the 'SERVER' context with a new graph.
  tf.keras.backend.clear_session()
  aggregated_outputs = _model_fn().federated_output_computation(
      [client_local_outputs])
  aggregated_outputs = collections.OrderedDict(
      anonymous_tuple.to_elements(aggregated_outputs))
  # Each train_on_batch call used a batch of 2 examples.
  self.assertEqual(aggregated_outputs['num_batches'], num_iterations)
  self.assertEqual(aggregated_outputs['num_examples'], 2 * num_iterations)
  self.assertGreater(aggregated_outputs['loss'], 0.0)

  # Exercise assigning the trained TFF weights onto a fresh Keras model.
  keras_model = _make_keras_model()
  keras_utils.assign_weights_to_keras_model(keras_model, tff_weights)
def test_assign_weights_from_odict(self):
  """Checks that assigning an OrderedDict of weights updates a Keras model.

  Builds a one-feature linear-regression Keras model and derives a structure
  with every trainable and non-trainable weight incremented by one. It then
  assigns that structure back via `keras_utils.assign_weights_to_keras_model`.
  Finally it checks that both weight groups of the model now match the
  incremented values.
  """
  model = model_examples.build_linear_regression_keras_functional_model(
      feature_dims=1)
  current = collections.OrderedDict(
      trainable=model.trainable_weights,
      non_trainable=model.non_trainable_weights)
  incremented = tf.nest.map_structure(lambda w: w + 1, current)
  keras_utils.assign_weights_to_keras_model(model, incremented)

  # Compare each weight group in the same order as the OrderedDict keys.
  for group in ('trainable', 'non_trainable'):
    actual = self.evaluate(getattr(model, group + '_weights'))
    self.assertAllClose(actual, incremented[group])