def create_initial_state():
  return reconstruction_utils.ServerState(
      model=reconstruction_utils.get_global_variables(model_fn()),
      optimizer_state=(),
      round_num=tf.constant(0, dtype=tf.int64),
      aggregator_state=(),
  )
def server_init_tf():
  """Initialize the TensorFlow-only portions of the server state."""
  # Round number can be used to parameterize client behavior, e.g. clients
  # can do more local iterations for reconstruction for later rounds.
  round_num = tf.constant(1, dtype=tf.int64)
  model = model_fn()
  server_optimizer = server_optimizer_fn()
  # Create optimizer variables so we have a place to assign the optimizer's
  # state.
  server_optimizer_vars = reconstruction_utils.create_optimizer_vars(
      model, server_optimizer)
  return reconstruction_utils.get_global_variables(
      model), server_optimizer_vars, round_num
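# `reconstruction_utils.create_optimizer_vars` is not shown in this section.
# A minimal sketch of what it plausibly does, given how it is used above (an
# assumption, not the actual implementation): Keras optimizers create their
# slot variables lazily on the first `apply_gradients` call, so variable
# creation is forced by applying an all-zero update, which leaves the model
# weights unchanged but gives `server_optimizer_vars` concrete variables that
# `server_update` can later assign into.
def _sketch_create_optimizer_vars(model, optimizer):  # hypothetical helper
  global_weights = reconstruction_utils.get_global_variables(model)
  zero_gradients = tf.nest.map_structure(tf.zeros_like,
                                         global_weights.trainable)
  optimizer.apply_gradients(zip(zero_gradients, global_weights.trainable))
  return optimizer.variables()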
def server_update(model, server_optimizer, server_optimizer_vars, server_state,
                  weights_delta, aggregator_state):
  """Updates `server_state` based on `weights_delta`.

  Args:
    model: A `ReconstructionModel`.
    server_optimizer: A `tf.keras.optimizers.Optimizer`.
    server_optimizer_vars: A list of variables of `server_optimizer`.
    server_state: A `ServerState`, the state to be updated.
    weights_delta: An update to the trainable variables of the model.
    aggregator_state: The state of the aggregator after performing
      aggregation.

  Returns:
    An updated `ServerState`.
  """
  global_model_weights = reconstruction_utils.get_global_variables(model)
  # Initialize the model with the current state.
  tf.nest.map_structure(lambda a, b: a.assign(b),
                        (global_model_weights, server_optimizer_vars),
                        (server_state.model, server_state.optimizer_state))

  weights_delta, has_non_finite_weight = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))

  # We ignore the update if the weights_delta is non-finite.
  if has_non_finite_weight > 0:
    return tff.utils.update_state(
        server_state,
        model=global_model_weights,
        optimizer_state=server_optimizer_vars,
        round_num=server_state.round_num + 1,
        aggregator_state=aggregator_state)

  # Apply the update to the model.
  grads_and_vars = tf.nest.map_structure(
      lambda x, v: (-1.0 * x, v), tf.nest.flatten(weights_delta),
      tf.nest.flatten(global_model_weights.trainable))
  server_optimizer.apply_gradients(grads_and_vars, name='server_update')

  # Create a new state based on the updated model.
  return tff.utils.update_state(
      server_state,
      model=global_model_weights,
      optimizer_state=server_optimizer_vars,
      round_num=server_state.round_num + 1,
      aggregator_state=aggregator_state,
  )
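# Why `-1.0 * x` above: treating the negated weights delta as a "gradient"
# turns `apply_gradients` into a step *toward* the aggregated update. A
# self-contained illustration (plain TF, not part of this module): with
# vanilla SGD at learning rate 1.0, the server step reduces to
# `v <- v + delta`, i.e. classic FedAvg, while other server optimizers
# (momentum, Adam) then come for free.
import tensorflow as tf

v = tf.Variable([1.0, 2.0])
delta = tf.constant([0.5, -0.5])  # aggregated client update
optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
optimizer.apply_gradients([(-1.0 * delta, v)])
# v - 1.0 * (-delta) == v + delta
assert all(v.numpy() == [1.5, 1.5])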
def test_get_global_variables(self):
  keras_model = tff.simulation.models.mnist.create_keras_model(
      compile_model=False)
  input_spec = _create_input_spec()
  model = keras_utils.from_keras_model(
      keras_model=keras_model,
      global_layers=keras_model.layers[:-1],
      local_layers=keras_model.layers[-1:],
      input_spec=input_spec)

  global_weights = reconstruction_utils.get_global_variables(model)

  self.assertIsInstance(global_weights, tff.learning.ModelWeights)
  # The last layer of the Keras model, which is a local Dense layer, contains
  # 2 trainable variables for the weights and bias.
  self.assertEqual(global_weights.trainable,
                   keras_model.trainable_variables[:-2])
  self.assertEmpty(global_weights.non_trainable)
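# A natural companion check (a sketch, assuming `get_local_variables` mirrors
# `get_global_variables` for the layers passed as `local_layers`): the local
# weights should be exactly the two trainable variables excluded above.
def test_get_local_variables(self):
  keras_model = tff.simulation.models.mnist.create_keras_model(
      compile_model=False)
  input_spec = _create_input_spec()
  model = keras_utils.from_keras_model(
      keras_model=keras_model,
      global_layers=keras_model.layers[:-1],
      local_layers=keras_model.layers[-1:],
      input_spec=input_spec)

  local_weights = reconstruction_utils.get_local_variables(model)

  self.assertIsInstance(local_weights, tff.learning.ModelWeights)
  # The local Dense layer contributes its kernel and bias.
  self.assertEqual(local_weights.trainable,
                   keras_model.trainable_variables[-2:])
  self.assertEmpty(local_weights.non_trainable)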
def build_federated_reconstruction_evaluation(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn],
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None
) -> tff.Computation:
  """Builds a `tff.Computation` for evaluation of a `ReconstructionModel`.

  The returned computation proceeds in two stages: (1) reconstruction and
  (2) evaluation. During the reconstruction stage, local variables are
  reconstructed by freezing global variables and training using
  `reconstruction_optimizer_fn`. During the evaluation stage, the
  reconstructed local variables and global variables are evaluated using the
  provided `loss_fn` and `metrics_fn`.

  Usage of returned computation:
    eval_comp = build_federated_reconstruction_evaluation(...)
    metrics = eval_comp(reconstruction_utils.get_global_variables(model),
                        federated_data)

  Args:
    model_fn: A no-arg function that returns a `ReconstructionModel`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
      evaluate the model. The loss will be applied to the model's outputs
      during the evaluation stage. The final loss metric is the
      example-weighted mean loss across batches (and across clients).
    metrics_fn: A no-arg function returning a list of
      `tf.keras.metrics.Metric`s to evaluate the model. The metrics will be
      applied to the model's outputs during the evaluation stage. Final
      metric values are the example-weighted mean of metric values across
      batches (and across clients). If None, no metrics are applied.
    reconstruction_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` used to reconstruct the local variables
      with the global ones frozen.
    dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` taking in a
      client dataset and round number (always 0 for evaluation) and producing
      two TF datasets. The first is iterated over during reconstruction, and
      the second is iterated over during evaluation. This can be used to
      preprocess datasets to e.g. iterate over them for multiple epochs or
      use disjoint data for reconstruction and evaluation. If None, split
      client data in half for each user, using one half for reconstruction
      and the other for evaluation. See
      `reconstruction_utils.build_dataset_split_fn` for options.

  Raises:
    ValueError: if both `loss_fn` and `metrics_fn` are None.

  Returns:
    A `tff.Computation` that accepts model parameters and federated data and
    returns example-weighted evaluation loss and metrics.
  """
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  with tf.Graph().as_default():
    model = model_fn()
    global_weights = reconstruction_utils.get_global_variables(model)
    model_weights_type = tff.framework.type_from_tensors(global_weights)
    batch_type = tff.to_type(model.input_spec)
    metrics = [keras_utils.MeanLossMetric(loss_fn())]
    if metrics_fn is not None:
      metrics.extend(metrics_fn())
    if not metrics:
      raise ValueError(
          'One or both of metrics_fn and loss_fn should be provided.')
    federated_output_computation = (
        keras_utils.federated_output_computation_from_metrics(metrics))
    # Remove unneeded variables to avoid polluting namespace.
    del model
    del global_weights
    del metrics

  if dataset_split_fn is None:
    dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_computation(incoming_model_weights, client_dataset):
    """Reconstructs and evaluates with `incoming_model_weights`."""
    client_model = model_fn()
    client_global_weights = reconstruction_utils.get_global_variables(
        client_model)
    client_local_weights = reconstruction_utils.get_local_variables(
        client_model)
    metrics = [keras_utils.MeanLossMetric(loss_fn())]
    if metrics_fn is not None:
      metrics.extend(metrics_fn())
    batch_loss_fn = loss_fn()
    reconstruction_optimizer = reconstruction_optimizer_fn()

    @tf.function
    def reconstruction_reduce_fn(num_examples_sum, batch):
      """Runs reconstruction training on local client batch."""
      with tf.GradientTape() as tape:
        output = client_model.forward_pass(batch, training=True)
        batch_loss = batch_loss_fn(
            y_true=output.labels, y_pred=output.predictions)
      gradients = tape.gradient(batch_loss, client_local_weights.trainable)
      reconstruction_optimizer.apply_gradients(
          zip(gradients, client_local_weights.trainable))
      return num_examples_sum + output.num_examples

    @tf.function
    def evaluation_reduce_fn(num_examples_sum, batch):
      """Runs evaluation on client batch without training."""
      output = client_model.forward_pass(batch, training=False)
      # Update each metric.
      for metric in metrics:
        metric.update_state(y_true=output.labels, y_pred=output.predictions)
      return num_examples_sum + output.num_examples

    @tf.function
    def tf_client_computation(incoming_model_weights, client_dataset):
      """Reconstructs and evaluates with `incoming_model_weights`."""
      # Pass in fixed 0 round number during evaluation, since global
      # variables aren't being iteratively updated as in training.
      recon_dataset, eval_dataset = dataset_split_fn(
          client_dataset, tf.constant(0, dtype=tf.int64))

      # Assign incoming global weights to `client_model` before
      # reconstruction.
      tf.nest.map_structure(lambda v, t: v.assign(t), client_global_weights,
                            incoming_model_weights)

      recon_dataset.reduce(tf.constant(0), reconstruction_reduce_fn)
      eval_dataset.reduce(tf.constant(0), evaluation_reduce_fn)

      eval_local_outputs = keras_utils.read_metric_variables(metrics)
      return eval_local_outputs

    return tf_client_computation(incoming_model_weights, client_dataset)

  @tff.federated_computation(
      tff.type_at_server(model_weights_type),
      tff.type_at_clients(tff.SequenceType(batch_type)))
  def server_eval(server_model_weights, federated_dataset):
    client_outputs = tff.federated_map(
        client_computation,
        [tff.federated_broadcast(server_model_weights), federated_dataset])
    return federated_output_computation(client_outputs)

  return server_eval
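# Usage sketch for the evaluation computation (hypothetical driver code; the
# `model_fn` below reuses the MNIST model split from the test above, and
# `_create_input_spec` / `federated_data` are assumed to be supplied by the
# caller):
def model_fn():
  keras_model = tff.simulation.models.mnist.create_keras_model(
      compile_model=False)
  return keras_utils.from_keras_model(
      keras_model=keras_model,
      global_layers=keras_model.layers[:-1],
      local_layers=keras_model.layers[-1:],
      input_spec=_create_input_spec())

eval_comp = build_federated_reconstruction_evaluation(
    model_fn,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy,
    metrics_fn=lambda: [tf.keras.metrics.SparseCategoricalAccuracy()])
# `federated_data` is a list of per-client `tf.data.Dataset`s matching the
# model's input spec; the model weights typically come from a trained
# `ServerState`:
# metrics = eval_comp(state.model, federated_data)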
def build_federated_reconstruction_process(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn] = None,
    server_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 1.0),
    client_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None,
    evaluate_reconstruction: bool = False,
    jointly_train_variables: bool = False,
    client_weight_fn: Optional[ClientWeightFn] = None,
    aggregation_factory: Optional[
        tff.aggregators.WeightedAggregationFactory] = None,
) -> tff.templates.IterativeProcess:
  """Builds the IterativeProcess for optimization using FedRecon.

  Returns a `tff.templates.IterativeProcess` for Federated Reconstruction. On
  the client, computation can be divided into two stages: (1) reconstruction
  of local variables and (2) training of global variables (possibly jointly
  with reconstructed local variables).

  Args:
    model_fn: A no-arg function that returns a `ReconstructionModel`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
      compute local model updates during reconstruction and
      post-reconstruction and evaluate the model during training. The final
      loss metric is the example-weighted mean loss across batches and across
      clients. Depending on whether `evaluate_reconstruction` is True, the
      loss metric may or may not include reconstruction batches in the loss.
    metrics_fn: A no-arg function returning a list of
      `tf.keras.metrics.Metric`s to evaluate the model. Metrics results are
      computed locally as described by the metric, and are aggregated across
      clients as in `federated_aggregate_keras_metric`. If None, no metrics
      are applied. Depending on whether `evaluate_reconstruction` is True,
      metrics may or may not be computed on reconstruction batches as well.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for applying updates to the global
      model on the server.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for local client training after
      reconstruction.
    reconstruction_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` used to reconstruct the local
      variables, with the global ones frozen, i.e. the first stage described
      above.
    dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` taking in a
      client dataset and training round number (1-indexed) and producing two
      TF datasets. The first is iterated over during reconstruction, and the
      second is iterated over post-reconstruction. This can be used to
      preprocess datasets to e.g. iterate over them for multiple epochs or
      use disjoint data for reconstruction and post-reconstruction. If None,
      `reconstruction_utils.simple_dataset_split_fn` is used, which results
      in iterating over the original client data for both phases of training.
      See `reconstruction_utils.build_dataset_split_fn` for options.
    evaluate_reconstruction: If True, metrics (including loss) are computed
      on batches during reconstruction and post-reconstruction. If False,
      metrics are computed on batches only post-reconstruction, when global
      weights are being updated. Note that metrics are aggregated across
      batches as given by the metric (example-weighted mean for the loss).
      Setting this to True includes all local batches in metric calculations.
      Setting this to False brings the interpretation of these metrics closer
      to the interpretation of metrics in FedAvg. Note that this does not
      affect training at all: losses for individual batches are calculated
      and used to update variables regardless.
    jointly_train_variables: Whether to train local variables during the
      second stage described above. If True, global and local variables are
      trained jointly after reconstruction of local variables, using the
      optimizer given by `client_optimizer_fn`. If False, only global
      variables are trained during the second stage with local variables
      frozen, similar to alternating minimization.
    client_weight_fn: Optional function that takes the local model's output
      and returns a tensor that provides the weight in the federated average
      of model deltas. If not provided, the default is the total number of
      examples processed on device during the post-reconstruction phase.
    aggregation_factory: An optional instance of
      `tff.aggregators.WeightedAggregationFactory` determining the method of
      aggregation to perform. If unspecified, uses a default
      `tff.aggregators.MeanFactory` which computes a stateless weighted mean
      across clients.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  with tf.Graph().as_default():
    throwaway_model_for_metadata = model_fn()

  model_weights_type = tff.framework.type_from_tensors(
      reconstruction_utils.get_global_variables(
          throwaway_model_for_metadata))

  aggregation_process = _instantiate_aggregation_process(
      aggregation_factory, model_weights_type, client_weight_fn)
  aggregator_state_type = (
      aggregation_process.initialize.type_signature.result.member)

  server_init_tff = build_server_init_fn(model_fn, server_optimizer_fn,
                                         aggregation_process)
  server_state_type = server_init_tff.type_signature.result.member
  server_update_fn = build_server_update_fn(
      model_fn,
      server_optimizer_fn,
      server_state_type,
      server_state_type.model,
      aggregator_state_type=aggregator_state_type)

  tf_dataset_type = tff.SequenceType(throwaway_model_for_metadata.input_spec)
  if dataset_split_fn is None:
    dataset_split_fn = reconstruction_utils.simple_dataset_split_fn
  client_update_fn = build_client_update_fn(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      tf_dataset_type=tf_dataset_type,
      model_weights_type=server_state_type.model,
      client_optimizer_fn=client_optimizer_fn,
      reconstruction_optimizer_fn=reconstruction_optimizer_fn,
      dataset_split_fn=dataset_split_fn,
      evaluate_reconstruction=evaluate_reconstruction,
      jointly_train_variables=jointly_train_variables,
      client_weight_fn=client_weight_fn)

  federated_server_state_type = tff.type_at_server(server_state_type)
  federated_dataset_type = tff.type_at_clients(tf_dataset_type)

  # Create placeholder metrics to produce a corresponding federated output
  # computation.
  metrics = []
  if metrics_fn is not None:
    metrics.extend(metrics_fn())
  metrics.append(keras_utils.MeanLossMetric(loss_fn()))
  federated_output_computation = (
      keras_utils.federated_output_computation_from_metrics(metrics))

  run_one_round_tff = build_run_one_round_fn(
      server_update_fn,
      client_update_fn,
      federated_output_computation,
      federated_server_state_type,
      federated_dataset_type,
      aggregation_process=aggregation_process,
  )

  iterative_process = tff.templates.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round_tff)

  @tff.tf_computation(server_state_type)
  def get_model_weights(server_state):
    return server_state.model

  iterative_process.get_model_weights = get_model_weights
  return iterative_process
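# Typical driver loop (a sketch; `model_fn`, the loss/metrics choices, and
# `sample_client_datasets` are placeholders to be supplied by the caller):
trainer = build_federated_reconstruction_process(
    model_fn,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy,
    metrics_fn=lambda: [tf.keras.metrics.SparseCategoricalAccuracy()])

state = trainer.initialize()
num_rounds = 5  # placeholder
for round_num in range(1, num_rounds + 1):
  # Placeholder: sample a cohort of per-client `tf.data.Dataset`s.
  sampled_client_data = sample_client_datasets(round_num)
  state, metrics = trainer.next(state, sampled_client_data)
  print(f'Round {round_num}: {metrics}')

# Global model weights for evaluation or checkpointing:
model_weights = trainer.get_model_weights(state)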
def client_update(model, metrics, batch_loss_fn, dataset, initial_weights,
                  client_optimizer, reconstruction_optimizer, round_num):
  """Updates client model.

  Outputted weight deltas represent the difference between final global
  variables and initial ones. The client weight (used in aggregation across
  clients) is the sum of the number of examples across all batches
  post-reconstruction (that is, only the local steps that involve updating
  global variables).

  Args:
    model: A `ReconstructionModel`.
    metrics: A list of `tf.keras.metrics.Metric`s containing metrics to be
      computed and aggregated across clients.
    batch_loss_fn: A `tf.keras.losses.Loss` used to compute batch loss on
      `BatchOutput.predictions` (y_pred) and `BatchOutput.labels` (y_true)
      for each batch during and after reconstruction.
    dataset: A `tf.data.Dataset`.
    initial_weights: A `tff.learning.ModelWeights` containing global
      trainable and non-trainable weights from the server.
    client_optimizer: A `tf.keras.optimizers.Optimizer` for training after
      the reconstruction step.
    reconstruction_optimizer: A `tf.keras.optimizers.Optimizer` for
      reconstruction of local trainable variables.
    round_num: The federated training round number, 1-indexed.

  Returns:
    A `reconstruction_utils.ClientOutput`.
  """
  global_model_weights = reconstruction_utils.get_global_variables(model)
  local_model_weights = reconstruction_utils.get_local_variables(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                        initial_weights)

  @tf.function
  def reconstruction_reduce_fn(num_examples_sum, batch):
    """Runs reconstruction training on local client batch."""
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
      batch_loss = batch_loss_fn(
          y_true=output.labels, y_pred=output.predictions)
    gradients = tape.gradient(batch_loss, local_model_weights.trainable)
    reconstruction_optimizer.apply_gradients(
        zip(gradients, local_model_weights.trainable))
    # Update metrics if needed.
    if evaluate_reconstruction:
      for metric in metrics:
        metric.update_state(y_true=output.labels, y_pred=output.predictions)
    return num_examples_sum + output.num_examples

  @tf.function
  def train_reduce_fn(num_examples_sum, batch):
    """Runs one step of client optimizer on local client batch."""
    if jointly_train_variables:
      all_trainable_variables = (
          global_model_weights.trainable + local_model_weights.trainable)
    else:
      all_trainable_variables = global_model_weights.trainable
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
      batch_loss = batch_loss_fn(
          y_true=output.labels, y_pred=output.predictions)
    gradients = tape.gradient(batch_loss, all_trainable_variables)
    client_optimizer.apply_gradients(zip(gradients, all_trainable_variables))
    # Update each metric.
    for metric in metrics:
      metric.update_state(y_true=output.labels, y_pred=output.predictions)
    return num_examples_sum + output.num_examples

  recon_dataset, post_recon_dataset = dataset_split_fn(dataset, round_num)

  # If needed, do reconstruction, training the local variables while keeping
  # the global ones frozen.
  if local_model_weights.trainable:
    # Ignore output number of examples used in reconstruction, since this
    # isn't included in `client_weight`.
    recon_dataset.reduce(
        initial_state=tf.constant(0), reduce_func=reconstruction_reduce_fn)

  # Train the global variables, possibly jointly with local variables if
  # `jointly_train_variables` is True.
  num_examples_sum = post_recon_dataset.reduce(
      initial_state=tf.constant(0), reduce_func=train_reduce_fn)

  weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                        global_model_weights.trainable,
                                        initial_weights.trainable)

  # We ignore the update if the weights_delta is non-finite.
  weights_delta, has_non_finite_weight = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  model_local_outputs = keras_utils.read_metric_variables(metrics)

  if has_non_finite_weight > 0:
    client_weight = tf.constant(0, dtype=tf.float32)
  elif client_weight_fn is None:
    client_weight = tf.cast(num_examples_sum, dtype=tf.float32)
  else:
    client_weight = client_weight_fn(model_local_outputs)

  return reconstruction_utils.ClientOutput(weights_delta, client_weight,
                                           model_local_outputs)
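# The reconstruction/post-reconstruction split above is controlled entirely
# by `dataset_split_fn`. A minimal custom split with the same signature as a
# `reconstruction_utils.DatasetSplitFn` (a sketch, not one of the library's
# built-in options): route even-indexed batches to reconstruction and
# odd-indexed batches to post-reconstruction training, so the two phases see
# disjoint data.
def even_odd_split_fn(client_dataset, round_num):
  del round_num  # Unused: this split does not depend on the round number.
  recon_dataset = client_dataset.enumerate().filter(
      lambda i, _: i % 2 == 0).map(lambda _, batch: batch)
  post_recon_dataset = client_dataset.enumerate().filter(
      lambda i, _: i % 2 == 1).map(lambda _, batch: batch)
  return recon_dataset, post_recon_dataset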