def test_evaluation_computation_custom_stateless_broadcaster(self, model_fn):
  """Evaluation accepts a custom stateless broadcaster and reports its metric.

  The custom broadcaster keeps an empty tuple as server state and emits the
  constant 3.0 as its measurement. The test checks the type signature of the
  built evaluation computation and that the `broadcast` entry of the result
  equals that constant.
  """

  def loss_fn():
    return tf.keras.losses.MeanSquaredError()

  def metrics_fn():
    return [counters.NumExamplesCounter(), NumOverCounter(5.0)]

  global_weights = reconstruction_utils.get_global_variables(model_fn())
  model_weights_type = type_conversions.type_from_tensors(global_weights)

  def build_custom_stateless_broadcaster(
      model_weights_type) -> measured_process_lib.MeasuredProcess:
    """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`."""
    empty_state_type = computation_types.FederatedType((), placements.SERVER)
    weights_at_server_type = computation_types.FederatedType(
        model_weights_type, placements.SERVER)

    @federated_computation.federated_computation()
    def test_server_initialization():
      # The empty tuple is what makes this broadcaster stateless.
      return intrinsics.federated_value((), placements.SERVER)

    @federated_computation.federated_computation(empty_state_type,
                                                 weights_at_server_type)
    def stateless_broadcast(state, value):
      test_metrics = intrinsics.federated_value(3.0, placements.SERVER)
      return measured_process_lib.MeasuredProcessOutput(
          state=state,
          result=intrinsics.federated_broadcast(value),
          measurements=test_metrics)

    return measured_process_lib.MeasuredProcess(
        initialize_fn=test_server_initialization,
        next_fn=stateless_broadcast)

  broadcaster = build_custom_stateless_broadcaster(
      model_weights_type=model_weights_type)
  evaluate = evaluation_computation.build_federated_evaluation(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
      broadcast_process=broadcaster)

  expected_signature = (
      '(<server_model_weights=<trainable=<float32[1,1]>,'
      'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
      'y=float32[?,1]>*}@CLIENTS> -> <broadcast=float32,eval='
      '<loss=float32,num_examples=int64,num_over=float32>>@SERVER)')
  self.assertEqual(str(evaluate.type_signature), expected_signature)

  server_weights = collections.OrderedDict([
      ('trainable', [[[1.0]]]),
      ('non_trainable', []),
  ])
  result = evaluate(server_weights, create_client_data())

  # The broadcaster's measurement must surface under the `broadcast` key.
  self.assertEqual(result['broadcast'], 3.0)
def client_computation(incoming_model_weights: computation_types.Type,
                       client_dataset: computation_types.SequenceType):
  """Reconstructs and evaluates with `incoming_model_weights`.

  A fresh model is built, its global variables are overwritten with the
  incoming weights, its local variables are trained on the reconstruction
  split, and the metrics are then updated on the evaluation split.

  Args:
    incoming_model_weights: Global model weights to load into the model.
    client_dataset: The client's dataset, split by `dataset_split_fn` into a
      reconstruction part and an evaluation part.

  Returns:
    The metric variables read via `keras_utils.read_metric_variables`.
  """
  model = model_fn()
  global_weights = reconstruction_utils.get_global_variables(model)
  local_weights = reconstruction_utils.get_local_variables(model)

  # The mean-loss metric always comes first, followed by any user metrics.
  eval_metrics = [keras_utils.MeanLossMetric(loss_fn())]
  if metrics_fn is not None:
    eval_metrics.extend(metrics_fn())

  batch_loss_fn = loss_fn()
  recon_optimizer = reconstruction_optimizer_fn()

  @tf.function
  def reconstruction_reduce_fn(num_examples_sum, batch):
    """Runs reconstruction training on local client batch."""
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
      batch_loss = batch_loss_fn(
          y_true=output.labels, y_pred=output.predictions)

    # Only the local trainable variables are updated here.
    local_trainable = local_weights.trainable
    gradients = tape.gradient(batch_loss, local_trainable)
    recon_optimizer.apply_gradients(zip(gradients, local_trainable))
    return num_examples_sum + output.num_examples

  @tf.function
  def evaluation_reduce_fn(num_examples_sum, batch):
    """Runs evaluation on client batch without training."""
    output = model.forward_pass(batch, training=False)
    for metric in eval_metrics:
      metric.update_state(y_true=output.labels, y_pred=output.predictions)
    return num_examples_sum + output.num_examples

  @tf.function
  def tf_client_computation(incoming_model_weights, client_dataset):
    """Reconstructs and evaluates with `incoming_model_weights`."""
    recon_dataset, eval_dataset = dataset_split_fn(client_dataset)

    # Load the incoming global weights into the model before reconstruction.
    tf.nest.map_structure(lambda variable, tensor: variable.assign(tensor),
                          global_weights, incoming_model_weights)

    recon_dataset.reduce(tf.constant(0), reconstruction_reduce_fn)
    eval_dataset.reduce(tf.constant(0), evaluation_reduce_fn)

    return keras_utils.read_metric_variables(eval_metrics)

  return tf_client_computation(incoming_model_weights, client_dataset)
def server_init_tf():
  """Initialize the TensorFlow-only portions of the server state.

  Returns:
    A pair of the model's global weights and the initial state of the server
    optimizer, where the optimizer is initialized from `tf.TensorSpec`s that
    mirror the shape and dtype of each trainable global variable.
  """
  global_weights = reconstruction_utils.get_global_variables(model_fn())
  trainable = global_weights.trainable
  server_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
      server_optimizer_fn, trainable, disjoint_init_and_next=True)

  def _spec_of(variable):
    return tf.TensorSpec(variable.shape, variable.dtype)

  initial_optimizer_state = server_optimizer.initialize(
      tf.nest.map_structure(_spec_of, trainable))
  return global_weights, initial_optimizer_state
def server_update(server_state, weights_delta, aggregator_state,
                  broadcaster_state):
  """Updates the `server_state` based on `weights_delta`.

  The current server model is loaded into a freshly built model, the negated
  delta is handed to the server optimizer as a gradient, and a new server
  state is assembled from the result. The optimizer step is skipped when
  `weights_delta` is reported as non-finite.

  Args:
    server_state: A `tff.learning.framework.ServerState`, the state to be
      updated.
    weights_delta: The model delta in global trainable variables from clients.
    aggregator_state: The state of the aggregator after performing
      aggregation.
    broadcaster_state: The state of the broadcaster after broadcasting.

  Returns:
    The updated `tff.learning.framework.ServerState`.
  """
  with tf.init_scope():
    model = model_fn()
  current_weights = reconstruction_utils.get_global_variables(model)
  server_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
      server_optimizer_fn,
      current_weights.trainable,
      disjoint_init_and_next=True)
  # Keras optimizers mutate the model variables inside `next`; any other
  # optimizer returns the new weights, which must then be assigned by hand.
  updates_variables_in_place = isinstance(server_optimizer,
                                          keras_optimizer.KerasOptimizer)
  optimizer_state = server_state.optimizer_state

  def _assign(variable, value):
    return variable.assign(value)

  def _negate(tensor):
    return -1.0 * tensor

  # Initialize the model with the current state.
  tf.nest.map_structure(_assign, current_weights, server_state.model)

  weights_delta, has_non_finite_weight = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))

  # We ignore the update if the weights_delta is non finite.
  if tf.equal(has_non_finite_weight, 0):
    negative_weights_delta = tf.nest.map_structure(_negate, weights_delta)
    optimizer_state, updated_weights = server_optimizer.next(
        optimizer_state, current_weights.trainable, negative_weights_delta)
    if not updates_variables_in_place:
      tf.nest.map_structure(_assign, current_weights.trainable,
                            updated_weights)

  # Create a new state based on the updated model.
  return structure.update_struct(
      server_state,
      model=current_weights,
      optimizer_state=optimizer_state,
      model_broadcast_state=broadcaster_state,
      delta_aggregate_state=aggregator_state,
  )
def test_get_global_variables(self):
  """`get_global_variables` returns only the variables of the global layers."""
  keras_model = _create_keras_model()
  layers = keras_model.layers
  model = keras_utils.from_keras_model(
      keras_model=keras_model,
      global_layers=layers[:-1],
      local_layers=layers[-1:],
      input_spec=_create_input_spec())

  global_weights = reconstruction_utils.get_global_variables(model)

  self.assertIsInstance(global_weights, model_utils.ModelWeights)
  # The final (local) Dense layer owns the last two trainable variables, its
  # weights and bias, so everything before them belongs to global layers.
  expected_trainable = keras_model.trainable_variables[:-2]
  self.assertEqual(global_weights.trainable, expected_trainable)
  self.assertEmpty(global_weights.non_trainable)
def test_evaluation_computation_custom_stateful_broadcaster_fails(
    self, model_fn):
  """Building evaluation with a stateful broadcaster raises `TypeError`.

  The custom broadcaster carries a float32 server state (initialized to 2.0),
  so `build_federated_evaluation` is expected to reject it with a message
  containing 'must be stateless'.
  """

  def loss_fn():
    return tf.keras.losses.MeanSquaredError()

  def metrics_fn():
    return [counters.NumExamplesCounter(), NumOverCounter(5.0)]

  global_weights = reconstruction_utils.get_global_variables(model_fn())
  model_weights_type = type_conversions.type_from_tensors(global_weights)

  def build_custom_stateful_broadcaster(
      model_weights_type) -> measured_process_lib.MeasuredProcess:
    """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`."""
    float_state_type = computation_types.FederatedType(
        tf.float32, placements.SERVER)
    weights_at_server_type = computation_types.FederatedType(
        model_weights_type, placements.SERVER)

    @federated_computation.federated_computation()
    def test_server_initialization():
      # A non-empty state value is what makes this broadcaster stateful.
      return intrinsics.federated_value(2.0, placements.SERVER)

    @federated_computation.federated_computation(float_state_type,
                                                 weights_at_server_type)
    def stateful_broadcast(state, value):
      test_metrics = intrinsics.federated_value(3.0, placements.SERVER)
      return measured_process_lib.MeasuredProcessOutput(
          state=state,
          result=intrinsics.federated_broadcast(value),
          measurements=test_metrics)

    return measured_process_lib.MeasuredProcess(
        initialize_fn=test_server_initialization,
        next_fn=stateful_broadcast)

  with self.assertRaisesRegex(TypeError, 'must be stateless'):
    evaluation_computation.build_federated_evaluation(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
        broadcast_process=build_custom_stateful_broadcaster(
            model_weights_type=model_weights_type))
def build_federated_evaluation(
    model_fn: training_process.ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: training_process.LossFn,
    metrics_fn: Optional[training_process.MetricsFn] = None,
    reconstruction_optimizer_fn: training_process.OptimizerFn = (
        functools.partial(tf.keras.optimizers.SGD, 0.1)),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None,
    broadcast_process: Optional[measured_process_lib.MeasuredProcess] = None,
) -> computation_base.Computation:
  """Builds a `tff.Computation` for evaluating a reconstruction `Model`.

  The returned computation runs two stages on every client. In the
  reconstruction stage the global variables are frozen and the local variables
  are trained with `reconstruction_optimizer_fn`. In the evaluation stage the
  reconstructed local variables together with the global variables are
  evaluated using `loss_fn` and `metrics_fn`.

  Usage of returned computation:
    eval_comp = build_federated_evaluation(...)
    metrics = eval_comp(tff.learning.reconstruction.get_global_variables(model),
                        federated_data)

  Args:
    model_fn: A no-arg function that returns a
      `tff.learning.reconstruction.Model`. It must *not* capture and use
      TensorFlow tensors or variables. The model must be constructed entirely
      from scratch on each invocation; returning the same pre-constructed
      model on each call results in an error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` used to
      reconstruct and evaluate the model. The loss is applied to the model's
      outputs during the evaluation stage. The final loss metric is the
      example-weighted mean loss across batches (and across clients).
    metrics_fn: A no-arg function returning a list of
      `tf.keras.metrics.Metric`s to evaluate the model. The metrics are applied
      to the model's outputs during the evaluation stage. Final metric values
      are the example-weighted mean of metric values across batches (and across
      clients). If None, no metrics are applied.
    reconstruction_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` used to reconstruct the local variables
      with the global ones frozen.
    dataset_split_fn: A `tff.learning.reconstruction.DatasetSplitFn` taking a
      single TF dataset and producing two TF datasets. The first is iterated
      over during reconstruction and the second during evaluation. This can be
      used to preprocess datasets, e.g. to iterate for multiple epochs or to
      use disjoint data for reconstruction and evaluation. If None, client data
      is split in half for each user, one half for reconstruction and the other
      for evaluation. See `tff.learning.reconstruction.build_dataset_split_fn`
      for options.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENTS)` and have empty state. If
      None, the server model is broadcast to the clients using the default
      `tff.federated_broadcast`.

  Raises:
    TypeError: if `broadcast_process` does not have the expected signature or
      has non-empty state.

  Returns:
    A `tff.Computation` that accepts global model parameters and federated data
    and returns example-weighted evaluation loss and metrics.
  """
  # A throwaway model, built in its own graph, supplies the type metadata
  # needed to declare the computations below.
  with tf.Graph().as_default():
    metadata_model = model_fn()
    metadata_weights = reconstruction_utils.get_global_variables(
        metadata_model)
    model_weights_type = type_conversions.type_from_tensors(metadata_weights)
    batch_type = computation_types.to_type(metadata_model.input_spec)
    placeholder_metrics = [keras_utils.MeanLossMetric(loss_fn())]
    if metrics_fn is not None:
      placeholder_metrics.extend(metrics_fn())
    federated_output_computation = (
        keras_utils.federated_output_computation_from_metrics(
            placeholder_metrics))
    # Drop the temporaries so later code cannot use them by accident.
    del metadata_model, metadata_weights, placeholder_metrics

  if dataset_split_fn is None:
    dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

  if broadcast_process is None:
    broadcast_process = optimizer_utils.build_stateless_broadcaster(
        model_weights_type=model_weights_type)
  if not optimizer_utils.is_valid_broadcast_process(broadcast_process):
    raise TypeError(
        'broadcast_process type signature does not conform to expected '
        'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
        ' Got: {t}'.format(t=broadcast_process.next.type_signature))
  if iterative_process.is_stateful(broadcast_process):
    raise TypeError(
        f'Eval broadcast_process must be stateless (have an empty '
        'state), has state '
        f'{broadcast_process.initialize.type_signature.result!r}')

  client_dataset_type = computation_types.SequenceType(batch_type)

  @tensorflow_computation.tf_computation(model_weights_type,
                                         client_dataset_type)
  def client_computation(incoming_model_weights: computation_types.Type,
                         client_dataset: computation_types.SequenceType):
    """Reconstructs and evaluates with `incoming_model_weights`."""
    client_model = model_fn()
    client_global_weights = reconstruction_utils.get_global_variables(
        client_model)
    client_local_weights = reconstruction_utils.get_local_variables(
        client_model)
    eval_metrics = [keras_utils.MeanLossMetric(loss_fn())]
    if metrics_fn is not None:
      eval_metrics.extend(metrics_fn())
    client_loss = loss_fn()
    reconstruction_optimizer = reconstruction_optimizer_fn()

    @tf.function
    def reconstruction_reduce_fn(num_examples_sum, batch):
      """Runs reconstruction training on local client batch."""
      with tf.GradientTape() as tape:
        output = client_model.forward_pass(batch, training=True)
        batch_loss = client_loss(
            y_true=output.labels, y_pred=output.predictions)

      # Gradients are taken for, and applied to, the local variables only.
      local_trainable = client_local_weights.trainable
      gradients = tape.gradient(batch_loss, local_trainable)
      reconstruction_optimizer.apply_gradients(
          zip(gradients, local_trainable))
      return num_examples_sum + output.num_examples

    @tf.function
    def evaluation_reduce_fn(num_examples_sum, batch):
      """Runs evaluation on client batch without training."""
      output = client_model.forward_pass(batch, training=False)
      for metric in eval_metrics:
        metric.update_state(y_true=output.labels, y_pred=output.predictions)
      return num_examples_sum + output.num_examples

    @tf.function
    def tf_client_computation(incoming_model_weights, client_dataset):
      """Reconstructs and evaluates with `incoming_model_weights`."""
      recon_dataset, eval_dataset = dataset_split_fn(client_dataset)

      # Load the incoming global weights into `client_model` before
      # reconstruction starts.
      tf.nest.map_structure(lambda v, t: v.assign(t), client_global_weights,
                            incoming_model_weights)

      recon_dataset.reduce(tf.constant(0), reconstruction_reduce_fn)
      eval_dataset.reduce(tf.constant(0), evaluation_reduce_fn)

      return keras_utils.read_metric_variables(eval_metrics)

    return tf_client_computation(incoming_model_weights, client_dataset)

  @federated_computation.federated_computation(
      computation_types.at_server(model_weights_type),
      computation_types.at_clients(client_dataset_type))
  def server_eval(server_model_weights: computation_types.FederatedType,
                  federated_dataset: computation_types.FederatedType):
    # The broadcaster is stateless, so a fresh initial state is used per call.
    broadcast_output = broadcast_process.next(broadcast_process.initialize(),
                                              server_model_weights)
    client_outputs = intrinsics.federated_map(
        client_computation, [broadcast_output.result, federated_dataset])
    aggregated_client_outputs = federated_output_computation(client_outputs)
    return intrinsics.federated_zip(
        collections.OrderedDict(
            broadcast=broadcast_output.measurements,
            eval=aggregated_client_outputs))

  return server_eval
def build_training_process(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn] = None,
    server_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 1.0),
    client_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None,
    client_weighting: Optional[client_weight_lib.ClientWeightType] = None,
    broadcast_process: Optional[measured_process_lib.MeasuredProcess] = None,
    aggregation_factory: Optional[AggregationFactory] = None,
) -> iterative_process_lib.IterativeProcess:
  """Builds the IterativeProcess for optimization using FedRecon.

  Returns a `tff.templates.IterativeProcess` for Federated Reconstruction. On
  the client, computation is divided into two stages: (1) reconstruction of
  local variables and (2) training of global variables.

  Args:
    model_fn: A no-arg function that returns a
      `tff.learning.reconstruction.Model`. It must *not* capture and use
      TensorFlow tensors or variables. The model must be constructed entirely
      from scratch on each invocation; returning the same pre-constructed
      model on each call results in an error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` used to
      compute local model updates during reconstruction and
      post-reconstruction, and to evaluate the model during training. The
      final loss metric is the example-weighted mean loss across batches and
      across clients. Reconstruction batches are not included in the loss
      metric.
    metrics_fn: A no-arg function returning a list of
      `tf.keras.metrics.Metric`s to evaluate the model. Metric results are
      computed locally as described by the metric and aggregated across
      clients as in `federated_aggregate_keras_metric`. If None, no metrics
      are applied. Metrics are not computed on reconstruction batches.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      function that returns a `tf.keras.optimizers.Optimizer`, for applying
      updates to the global model on the server.
    client_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      function that returns a `tf.keras.optimizers.Optimizer`, for local
      client training after reconstruction.
    reconstruction_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a
      no-arg function that returns a `tf.keras.optimizers.Optimizer`, used to
      reconstruct the local variables with the global ones frozen (the first
      stage described above).
    dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` taking a single
      TF dataset and producing two TF datasets. The first is iterated over
      during reconstruction and the second post-reconstruction. This can be
      used to preprocess datasets, e.g. to iterate for multiple epochs or to
      use disjoint data for reconstruction and post-reconstruction. If None,
      client data is split in half for each user, one half for reconstruction
      and the other for evaluation. See
      `reconstruction_utils.build_dataset_split_fn` for options.
    client_weighting: A value of `tff.learning.ClientWeighting` that
      specifies a built-in weighting method, or a callable that takes the
      local metrics of the model and returns a tensor providing the weight in
      the federated average of model deltas. If None, defaults to weighting
      by number of examples.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the
      signature `(input_values@SERVER -> output_values@CLIENT)`. If None, the
      server model is broadcast to the clients using the default
      `tff.federated_broadcast`.
    aggregation_factory: An optional instance of
      `tff.aggregators.WeightedAggregationFactory` or
      `tff.aggregators.UnweightedAggregationFactory` determining the method
      of aggregation to perform. If unspecified, uses a default
      `tff.aggregators.MeanFactory` which computes a stateless mean across
      clients (weighted depending on `client_weighting`).

  Raises:
    TypeError: If `broadcast_process` does not have the expected signature.
    TypeError: If `aggregation_factory` does not have the expected signature.
    ValueError: If `aggregation_factory` is not a
      `tff.aggregators.WeightedAggregationFactory` or a
      `tff.aggregators.UnweightedAggregationFactory`.
    ValueError: If `aggregation_factory` is a
      `tff.aggregators.UnweightedAggregationFactory` but `client_weighting`
      is not `tff.learning.ClientWeighting.UNIFORM`.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # A throwaway model, built in its own graph, is used only for metadata.
  with tf.Graph().as_default():
    metadata_model = model_fn()

  model_weights_type = type_conversions.type_from_tensors(
      reconstruction_utils.get_global_variables(metadata_model))

  if client_weighting is None:
    client_weighting = client_weight_lib.ClientWeighting.NUM_EXAMPLES

  uses_unweighted_aggregator = isinstance(
      aggregation_factory, factory.UnweightedAggregationFactory)
  if (uses_unweighted_aggregator and
      client_weighting is not client_weight_lib.ClientWeighting.UNIFORM):
    raise ValueError(
        f'Expected `tff.learning.ClientWeighting.UNIFORM` client '
        f'weighting with unweighted aggregator, instead got '
        f'{client_weighting}')

  if broadcast_process is None:
    broadcast_process = optimizer_utils.build_stateless_broadcaster(
        model_weights_type=model_weights_type)
  if not _is_valid_broadcast_process(broadcast_process):
    raise TypeError(
        'broadcast_process type signature does not conform to expected '
        'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
        ' Got: {t}'.format(t=broadcast_process.next.type_signature))
  broadcaster_state_type = (
      broadcast_process.initialize.type_signature.result.member)

  aggregation_process = _instantiate_aggregation_process(
      aggregation_factory, model_weights_type)
  aggregator_state_type = (
      aggregation_process.initialize.type_signature.result.member)

  server_init_tff = _build_server_init_fn(model_fn, server_optimizer_fn,
                                          aggregation_process.initialize,
                                          broadcast_process.initialize)
  server_state_type = server_init_tff.type_signature.result.member

  server_update_fn = _build_server_update_fn(
      model_fn,
      server_optimizer_fn,
      server_state_type,
      server_state_type.model,
      aggregator_state_type=aggregator_state_type,
      broadcaster_state_type=broadcaster_state_type)

  dataset_type = computation_types.SequenceType(metadata_model.input_spec)
  if dataset_split_fn is None:
    dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

  client_update_fn = _build_client_update_fn(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      dataset_type=dataset_type,
      model_weights_type=server_state_type.model,
      client_optimizer_fn=client_optimizer_fn,
      reconstruction_optimizer_fn=reconstruction_optimizer_fn,
      dataset_split_fn=dataset_split_fn,
      client_weighting=client_weighting)

  federated_server_state_type = computation_types.at_server(server_state_type)
  federated_dataset_type = computation_types.at_clients(dataset_type)

  # Placeholder metrics exist only to derive the matching federated output
  # computation; the mean-loss metric is appended after any user metrics.
  placeholder_metrics = []
  if metrics_fn is not None:
    placeholder_metrics.extend(metrics_fn())
  placeholder_metrics.append(keras_utils.MeanLossMetric(loss_fn()))
  federated_output_computation = (
      keras_utils.federated_output_computation_from_metrics(
          placeholder_metrics))

  run_one_round_tff = _build_run_one_round_fn(
      server_update_fn,
      client_update_fn,
      federated_output_computation,
      federated_server_state_type,
      federated_dataset_type,
      aggregation_process=aggregation_process,
      broadcast_process=broadcast_process,
  )

  process = iterative_process_lib.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round_tff)

  @computations.tf_computation(server_state_type)
  def get_model_weights(server_state):
    return server_state.model

  # Expose the model weights of a server state as an attribute of the process.
  process.get_model_weights = get_model_weights
  return process
def client_update(dataset, initial_model_weights):
  """Performs client local model optimization.

  The dataset is split with `dataset_split_fn`. If the model has trainable
  local variables, they are first trained on the reconstruction split while
  the global variables stay untouched. The global trainable variables are
  then trained on the post-reconstruction split, and the difference from the
  initial global weights is returned as the client's update.

  Args:
    dataset: A `tf.data.Dataset` that provides training examples.
    initial_model_weights: A `tff.learning.ModelWeights` containing the
      starting global trainable and non-trainable weights.

  Returns:
    A `ClientOutput` built from the trainable weights delta, the client
    weight, and the local metric outputs.
  """
  with tf.init_scope():
    model = model_fn()
    metrics = []
    if metrics_fn is not None:
      metrics.extend(metrics_fn())
    # To be used to calculate example-weighted mean across batches and
    # clients.
    metrics.append(keras_utils.MeanLossMetric(loss_fn()))
    # To be used to calculate batch loss for model updates.
    client_loss = loss_fn()

  global_model_weights = reconstruction_utils.get_global_variables(model)
  local_model_weights = reconstruction_utils.get_local_variables(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                        initial_model_weights)
  client_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
      client_optimizer_fn,
      global_model_weights.trainable,
      disjoint_init_and_next=False)
  reconstruction_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
      reconstruction_optimizer_fn,
      local_model_weights.trainable,
      disjoint_init_and_next=False)

  @tf.function
  def reconstruction_reduce_fn(state, batch):
    """Runs reconstruction training on local client batch."""
    num_examples_sum, optimizer_state = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
      batch_loss = client_loss(
          y_true=output.labels, y_pred=output.predictions)

    gradients = tape.gradient(batch_loss, local_model_weights.trainable)
    optimizer_state, updated_weights = reconstruction_optimizer.next(
        optimizer_state, local_model_weights.trainable, gradients)
    if not isinstance(reconstruction_optimizer,
                      keras_optimizer.KerasOptimizer):
      # Keras optimizer mutates model variables within the `next` step.
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            local_model_weights.trainable, updated_weights)

    # `Dataset.reduce` needs the new state to mirror `initial_state`, so the
    # optimizer state is returned alongside the running example count.
    return num_examples_sum + output.num_examples, optimizer_state

  @tf.function
  def train_reduce_fn(state, batch):
    """Runs one step of client optimizer on local client batch."""
    num_examples_sum, optimizer_state = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
      batch_loss = client_loss(
          y_true=output.labels, y_pred=output.predictions)

    gradients = tape.gradient(batch_loss, global_model_weights.trainable)
    optimizer_state, updated_weights = client_optimizer.next(
        optimizer_state, global_model_weights.trainable, gradients)
    if not isinstance(client_optimizer, keras_optimizer.KerasOptimizer):
      # Keras optimizer mutates model variables within the `next` step.
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            global_model_weights.trainable, updated_weights)

    # Update each metric.
    for metric in metrics:
      metric.update_state(y_true=output.labels, y_pred=output.predictions)

    # Return the full (count, optimizer_state) pair: the caller unpacks the
    # reduced result as a 2-tuple and the next batch needs the updated state.
    return num_examples_sum + output.num_examples, optimizer_state

  recon_dataset, post_recon_dataset = dataset_split_fn(dataset)

  # If needed, do reconstruction, training the local variables while keeping
  # the global ones frozen.
  if local_model_weights.trainable:
    # Ignore output number of examples used in reconstruction, since this
    # isn't included in `client_weight`.
    def initial_state_reconstruction_reduce():
      trainable_tensor_specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype),
          local_model_weights.trainable)
      return tf.constant(0), reconstruction_optimizer.initialize(
          trainable_tensor_specs)

    recon_dataset.reduce(
        initial_state=initial_state_reconstruction_reduce(),
        reduce_func=reconstruction_reduce_fn)

  # Train the global variables, keeping local variables frozen.
  def initial_state_train_reduce():
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype),
        global_model_weights.trainable)
    return tf.constant(0), client_optimizer.initialize(trainable_tensor_specs)

  num_examples_sum, _ = post_recon_dataset.reduce(
      initial_state=initial_state_train_reduce(),
      reduce_func=train_reduce_fn)

  weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                        global_model_weights.trainable,
                                        initial_model_weights.trainable)

  # We ignore the update if the weights_delta is non finite.
  weights_delta, has_non_finite_weight = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))

  model_local_outputs = keras_utils.read_metric_variables(metrics)

  if has_non_finite_weight > 0:
    # A non-finite update gets zero weight in the aggregate.
    client_weight = tf.constant(0.0, dtype=tf.float32)
  elif client_weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
    client_weight = tf.cast(num_examples_sum, dtype=tf.float32)
  elif client_weighting is client_weight_lib.ClientWeighting.UNIFORM:
    client_weight = tf.constant(1.0, dtype=tf.float32)
  else:
    # Otherwise `client_weighting` is called on the local metric outputs.
    client_weight = client_weighting(model_local_outputs)

  return ClientOutput(weights_delta, client_weight, model_local_outputs)
def test_keras_local_layer_custom_broadcaster(self):
  """Training accepts a custom stateful broadcaster and reports its outputs.

  The broadcaster carries a float32 server state initialized to 2.0 and emits
  1.0 as its measurement. The test checks that the state shows up in the
  initial server state and that one round yields the expected metric keys and
  broadcast measurement.
  """

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        NumExamplesCounter(),
        NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  global_weights = reconstruction_utils.get_global_variables(
      local_recon_model_fn())
  model_weights_type = type_conversions.type_from_tensors(global_weights)

  def build_custom_stateful_broadcaster(
      model_weights_type) -> measured_process_lib.MeasuredProcess:
    """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`."""
    float_state_type = computation_types.FederatedType(
        tf.float32, placements.SERVER)
    weights_at_server_type = computation_types.FederatedType(
        model_weights_type, placements.SERVER)

    @computations.federated_computation()
    def test_server_initialization():
      return intrinsics.federated_value(2.0, placements.SERVER)

    @computations.federated_computation(float_state_type,
                                        weights_at_server_type)
    def stateful_broadcast(state, value):
      empty_metrics = intrinsics.federated_value(1.0, placements.SERVER)
      return measured_process_lib.MeasuredProcessOutput(
          state=state,
          result=intrinsics.federated_broadcast(value),
          measurements=empty_metrics)

    return measured_process_lib.MeasuredProcess(
        initialize_fn=test_server_initialization,
        next_fn=stateful_broadcast)

  broadcaster = build_custom_stateful_broadcaster(
      model_weights_type=model_weights_type)
  it_process = training_process.build_training_process(
      local_recon_model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      client_optimizer_fn=_get_keras_optimizer_fn(0.001),
      reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.001),
      dataset_split_fn=reconstruction_utils.simple_dataset_split_fn,
      broadcast_process=broadcaster)

  server_state = it_process.initialize()

  # The broadcaster's initial state must be carried in the server state.
  self.assertEqual(server_state.model_broadcast_state, 2.0)

  client_data = create_emnist_client_data()
  federated_data = [client_data(), client_data()]

  server_state, output = it_process.next(server_state, federated_data)

  self.assertCountEqual(output.keys(), ['broadcast', 'aggregation', 'train'])
  self.assertCountEqual(output['train'].keys(), [
      'sparse_categorical_accuracy', 'loss', 'num_examples_total',
      'num_batches_total'
  ])
  self.assertEqual(output['broadcast'], 1.0)