def test_raises_on_invalid_distributor(self):
  """Checks that FedAvg rejects a distributor that is a plain `IterativeProcess`.

  A real broadcast distributor is built first, and its `initialize`/`next`
  computations are then re-wrapped in a bare `IterativeProcess`. Passing that
  wrapper as `model_distributor` must raise `TypeError`.
  """
  example_weights = model_utils.ModelWeights.from_model(
      model_examples.LinearRegression())
  weights_type = type_conversions.type_from_tensors(example_weights)
  broadcast_process = distributors.build_broadcast_process(weights_type)
  # Same computations as the valid distributor, but wrapped in a plain
  # `IterativeProcess` instead of the distributor returned above.
  not_a_distributor = iterative_process.IterativeProcess(
      broadcast_process.initialize, broadcast_process.next)

  with self.assertRaises(TypeError):
    fed_avg.build_weighted_fed_avg(
        model_fn=model_examples.LinearRegression,
        client_optimizer_fn=sgdm.build_sgdm(1.0),
        model_distributor=not_a_distributor)
def test_raises_on_invalid_distributor(self):
  """Checks that Mime Lite rejects a distributor that is a plain `IterativeProcess`.

  A real broadcast distributor is built first, and its `initialize`/`next`
  computations are then re-wrapped in a bare `IterativeProcess`. Passing that
  wrapper as `model_distributor` must raise `TypeError`.
  """
  example_weights = model_utils.ModelWeights.from_model(
      model_examples.LinearRegression())
  weights_type = type_conversions.type_from_tensors(example_weights)
  broadcast_process = distributors.build_broadcast_process(weights_type)
  # Same computations as the valid distributor, but wrapped in a plain
  # `IterativeProcess` instead of the distributor returned above.
  not_a_distributor = iterative_process.IterativeProcess(
      broadcast_process.initialize, broadcast_process.next)

  with self.assertRaises(TypeError):
    mime.build_weighted_mime_lite(
        model_fn=model_examples.LinearRegression,
        base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
        model_distributor=not_a_distributor)
def build_basic_fedavg_process(model_fn: Callable[[], model_lib.Model],
                               client_learning_rate: float):
  """Builds the vanilla Federated Averaging learning process.

  The returned process is the basic form of Federated Averaging, as proposed in
  Algorithm 1 of
  http://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf, and trains the
  model created by `model_fn`.

  The algorithm in pseudo-code:

  ```
  # Inputs: m: Initial model weights; eta: Client learning rate
  for i in range(num_rounds):
    for c in available_clients_indices:
      delta_m_c, w_c = client_update(m, eta)
    aggregate_model_delta = sum_c(delta_m_c * w_c) / sum_c(w_c)
    m = m - aggregate_model_delta
  return m  # Final trained model.

  def client_update(m, eta):
    initial_m = m
    for batch in client_dataset:
      m = m - eta * grad(m, batch)
    delta_m = initial_m - m
    return delta_m, size(client_dataset)
  ```

  The remaining hyperparameters of the algorithm (batch size, number of local
  epochs) are determined by the data passed to the built process.

  Example usage of the returned `LearningProcess` in simulation:

  ```
  fedavg = build_basic_fedavg_process(model_fn, 0.1)

  # Create a `LearningAlgorithmState` containing the initial model weights for
  # the model returned from `model_fn`.
  state = fedavg.initialize()
  for _ in range(num_rounds):
    client_data = ...  # Preprocessed client datasets
    output = fedavg.next(state, client_data)
    write_round_metrics(output.metrics)
    # The new state contains the updated model weights after this round.
    state = output.state
  ```

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_learning_rate: A float. Learning rate for the SGD at clients.

  Returns:
    A `LearningProcess`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_type(client_learning_rate, float)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  distributor = distributors.build_broadcast_process(model_weights_type)

  client_optimizer = sgdm.build_sgdm(client_learning_rate)
  client_work = model_delta_client_work.build_model_delta_client_work(
      model_fn,
      client_optimizer,
      client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES)

  # The aggregator is typed from what the client work actually emits: the
  # `update` and `update_weight` fields of its per-client result.
  client_result_type = client_work.next.type_signature.result.result.member
  aggregator = mean.MeanFactory().create(client_result_type.update,
                                         client_result_type.update_weight)

  # SGD with learning rate 1.0 on the server matches the
  # `m = m - aggregate_model_delta` step of the pseudo-code above.
  server_optimizer = sgdm.build_sgdm(1.0)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer, model_weights_type)

  return compose_learning_process(initial_model_weights_fn, distributor,
                                  client_work, aggregator, finalizer)
def build_weighted_mime_lite(
    model_fn: Callable[[], model_lib.Model],
    base_optimizer: optimizer_base.Optimizer,
    server_optimizer: optimizer_base.Optimizer = sgdm.build_sgdm(1.0),
    client_weighting: Optional[
        client_weight_lib.
        ClientWeighting] = client_weight_lib.ClientWeighting.NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    full_gradient_aggregator: Optional[
        factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
  """Builds a learning process that performs Mime Lite.

  Creates a `tff.learning.templates.LearningProcess` that runs the Mime Lite
  algorithm on client models. The process exposes the following methods
  inherited from `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  On every call to `next`, the server model is sent to each client through
  `model_distributor`. Each client trains locally with `base_optimizer`, whose
  state is supplied by the server and kept intact during local training. That
  optimizer state is updated only at the server, based on the full gradient
  that clients evaluate at the current server model; these full gradients are
  aggregated by the weighted `full_gradient_aggregator`. Each client also
  computes the difference between its model after training and its initial
  model, and these model deltas are aggregated by the weighted
  `model_aggregator`. Both aggregations are weighted according to
  `client_weighting`. The aggregate model delta is added to the existing server
  model state.

  The Mime Lite algorithm is based on the paper
  "Breaking the centralized barrier for cross-device federated learning."
    Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank
    Reddi, Sebastian U. Stich, and Ananda Theertha Suresh.
    Advances in Neural Information Processing Systems 34 (2021).
    https://proceedings.neurips.cc/paper/2021/file/f0e6be4ce76ccfa73c5a540d992d0756-Paper.pdf

  Note that Keras optimizers are not supported. Mime Lite applies the optimizer
  at clients without changing its state (the optimizer's `tf.Variable`s in the
  case of Keras), which is not possible with Keras optimizers without reaching
  into private implementation details and incurring additional computation and
  memory cost at clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    base_optimizer: A `tff.learning.optimizers.Optimizer` which will be used for
      both creating and updating a global optimizer state, as well as
      optimization at clients given the global state, which is fixed during the
      optimization.
    server_optimizer: A `tff.learning.optimizers.Optimizer` which will be used
      for applying the aggregate model update to the global model weights.
    client_weighting: A member of `tff.learning.ClientWeighting` that specifies
      a built-in weighting method. By default, weighting by number of examples
      is used. NOTE(review): the value is type-checked against
      `ClientWeighting`, so `None` is presumably rejected despite the
      `Optional` annotation - confirm.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    full_gradient_aggregator: An optional
      `tff.aggregators.WeightedAggregationFactory` used to aggregate the full
      gradients on client datasets. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. A
      `None` value is forwarded unchanged to the client-work builder, which is
      expected to fall back to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `tff.learning.templates.LearningProcess`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_type(base_optimizer, optimizer_base.Optimizer)
  py_typecheck.check_type(server_optimizer, optimizer_base.Optimizer)
  py_typecheck.check_type(client_weighting, client_weight_lib.ClientWeighting)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  distributor = model_distributor
  if distributor is None:
    distributor = distributors.build_broadcast_process(model_weights_type)

  # Factory for the model deltas; it is turned into a concrete aggregation
  # process over the trainable weights with float32 client weights.
  delta_factory = (
      mean.MeanFactory() if model_aggregator is None else model_aggregator)
  py_typecheck.check_type(delta_factory, factory.WeightedAggregationFactory)
  delta_aggregator = delta_factory.create(
      model_weights_type.trainable, computation_types.TensorType(tf.float32))

  # The full-gradient factory is handed to the client work as a factory; it is
  # not instantiated here.
  gradient_factory = (
      mean.MeanFactory()
      if full_gradient_aggregator is None else full_gradient_aggregator)
  py_typecheck.check_type(gradient_factory, factory.WeightedAggregationFactory)

  client_work = _build_mime_lite_client_work(
      model_fn=model_fn,
      optimizer=base_optimizer,
      client_weighting=client_weighting,
      full_gradient_aggregator=gradient_factory,
      metrics_aggregator=metrics_aggregator,
      use_experimental_simulation_loop=use_experimental_simulation_loop)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer, model_weights_type)
  return composers.compose_learning_process(initial_model_weights_fn,
                                            distributor, client_work,
                                            delta_aggregator, finalizer)
def build_weighted_fed_avg_with_optimizer_schedule(
    model_fn: Callable[[], model_lib.Model],
    client_learning_rate_fn: Callable[[int], float],
    client_optimizer_fn: Callable[[float], TFFOrKerasOptimizer],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = fed_avg.DEFAULT_SERVER_OPTIMIZER_FN,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
  """Builds a learning process for FedAvg with client optimizer scheduling.

  This function creates a `LearningProcess` that performs federated averaging on
  client models. The iterative process has the following methods inherited from
  `LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a `LearningAlgorithmState` representing the
      initial state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is broadcast to each
  client using a broadcast function. For each client, local training is
  performed using the optimizer produced by `client_optimizer_fn`. Each client
  computes the difference between the client model after training and the
  initial broadcast model. These model deltas are then aggregated at the server
  using a weighted aggregation function. Clients are weighted by the number of
  examples they see throughout local training. The aggregate model delta is
  applied at the server using a server optimizer.

  The primary purpose of this implementation of FedAvg is that it allows the
  client optimizer to be scheduled across rounds. The process keeps track of how
  many iterations of `.next` have occurred (starting at `0`), and for each such
  `round_num`, the clients will use
  `client_optimizer_fn(client_learning_rate_fn(round_num))` to perform local
  optimization. This allows learning rate scheduling (e.g. starting with a large
  learning rate and decaying it over time) as well as switching optimizers as
  learning progresses.

  Note: the default server optimizer function is `tf.keras.optimizers.SGD`
  with a learning rate of 1.0, which corresponds to adding the model delta to
  the current server model. This recovers the original FedAvg algorithm in
  [McMahan et al., 2017](https://arxiv.org/abs/1602.05629). More
  sophisticated federated averaging procedures may use different learning rates
  or server optimizers.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    client_learning_rate_fn: A callable accepting an integer round number and
      returning a float to be used as a learning rate for the optimizer. The
      client work will call
      `client_optimizer_fn(client_learning_rate_fn(round_num))` where
      `round_num` is the integer round number. Note that the round numbers
      supplied will start at `0` and increment by one each time `.next` is
      called on the resulting process. Also note that this function must be
      serializable by TFF.
    client_optimizer_fn: A callable accepting a float learning rate, and
      returning a `tff.learning.optimizers.Optimizer` or a `tf.keras.Optimizer`.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. By default, this uses
      `tf.keras.optimizers.SGD` with a learning rate of 1.0.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `LearningProcess`.

  Raises:
    TypeError: If the process created from `model_aggregator` returns a server
      value whose type differs from the type of the client values it consumes.
  """
  py_typecheck.check_callable(model_fn)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  if model_distributor is None:
    model_distributor = distributors.build_broadcast_process(model_weights_type)

  if model_aggregator is None:
    model_aggregator = mean.MeanFactory()
  py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory)
  aggregator = model_aggregator.create(model_weights_type.trainable,
                                       computation_types.TensorType(tf.float32))

  # The finalizer applies the aggregate to the model weights, so the aggregator
  # must hand back a value of the same type structure that clients put in.
  process_signature = aggregator.next.type_signature
  input_client_value_type = process_signature.parameter[1]
  result_server_value_type = process_signature.result[1]
  if input_client_value_type.member != result_server_value_type.member:
    # The message names `model_aggregator`, the parameter callers actually pass.
    raise TypeError('`model_aggregator` does not produce a '
                    'compatible `AggregationProcess`. The processes must '
                    'retain the type structure of the inputs on the '
                    f'server, but got {input_client_value_type.member} != '
                    f'{result_server_value_type.member}.')

  if metrics_aggregator is None:
    metrics_aggregator = metric_aggregator.sum_then_finalize

  client_work = build_scheduled_client_work(model_fn, client_learning_rate_fn,
                                            client_optimizer_fn,
                                            metrics_aggregator,
                                            use_experimental_simulation_loop)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer_fn, model_weights_type)
  return composers.compose_learning_process(initial_model_weights_fn,
                                            model_distributor, client_work,
                                            aggregator, finalizer)
def build_fed_kmeans(
    num_clusters: int,
    data_shape: Tuple[int, ...],
    random_seed: Optional[Tuple[int, int]] = None,
    distributor: Optional[distributors.DistributionProcess] = None,
    sum_aggregator: Optional[factory.UnweightedAggregationFactory] = None,
) -> learning_process.LearningProcess:
  """Builds a learning process for federated k-means clustering.

  Creates a `tff.learning.templates.LearningProcess` that performs federated
  k-means clustering, specifically mini-batch k-means. Mini-batch k-means only
  processes a mini-batch of the data at each round, and updates clusters in a
  weighted manner based on how many points in the mini-batch were assigned to
  each cluster. In the federated version, clients assign each of their points
  locally, and the server updates the clusters. Conceptually, the "mini-batch"
  is the union of all client datasets involved in a given round.

  The learning process has the following methods inherited from
  `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a `LearningAlgorithmState` representing the
      initial state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `LearningAlgorithmState` whose type matches the output of `initialize`
      and `{B*}@CLIENTS` represents the client datasets. The output `L` is a
      `tff.learning.templates.LearningProcessOutput` containing the state `S`
      and metrics computed during training.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> W)`,
      where `W` represents the current k-means centroids.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` a new set of k-means centroids.

  Here, `S` is a `tff.learning.templates.LearningAlgorithmState`. The centroids
  `W` is a tensor representing the current centroids, and is of shape
  `(num_clusters,) + data_shape`. The datasets `{B*}` must have elements of
  shape `data_shape`, and not employ batching.

  The centroids are updated at each round by assigning all clients' points to
  the nearest centroid, and then summing these points according to these
  centroids. The centroids are then updated at the server based on these points.
  To do so, we keep track of how many points have been assigned to each centroid
  overall, as an integer tensor of shape `(num_clusters,)`. This information can
  be found in `state.finalizer`. Note that we begin with a "pseudo-count" of 1,
  in order to ensure that the centroids do not collapse to zero.

  Args:
    num_clusters: The number of clusters to use.
    data_shape: A tuple of integers specifying the shape of each data point.
      Note that this data shape should be unbatched, as this algorithm does not
      currently support batched data points.
    random_seed: A tuple of two integers used to seed the initialization phase.
      If falsy (e.g. `None`), a seed is derived from the current timestamp.
    distributor: An optional `tff.learning.templates.DistributionProcess` that
      broadcasts the centroids on the server to the clients. If set to `None`,
      the distributor is constructed via
      `tff.learning.templates.build_broadcast_process`.
    sum_aggregator: An optional `tff.aggregators.UnweightedAggregationFactory`
      used to sum updates across clients. If `None`, we use
      `tff.aggregators.SumFactory`.

  Returns:
    A `LearningProcess`.
  """
  centroids_shape = (num_clusters,) + data_shape

  seed = random_seed
  if not seed:
    # No seed supplied: use the current time, scaled by `_MILLIS_PER_SECOND`,
    # as the first seed component and 0 as the second.
    scaled_timestamp = tf.cast(tf.timestamp() * _MILLIS_PER_SECOND, tf.int64)
    seed = (scaled_timestamp.numpy(), 0)

  @tensorflow_computation.tf_computation
  def initialize_centers():
    return tf.random.stateless_normal(
        centroids_shape, seed, dtype=_POINT_DTYPE)

  centroids_type = computation_types.TensorType(_POINT_DTYPE, centroids_shape)
  weights_type = computation_types.TensorType(
      _WEIGHT_DTYPE, shape=(num_clusters,))
  # Client datasets are sequences of single, unbatched points.
  data_type = computation_types.SequenceType(
      computation_types.TensorType(_POINT_DTYPE, shape=data_shape))

  if distributor is None:
    distributor = distributors.build_broadcast_process(centroids_type)

  client_work = _build_kmeans_client_work(centroids_type, data_type)

  unweighted_sum = (
      sum_factory.SumFactory() if sum_aggregator is None else sum_aggregator)
  # The learning process composer expects a weighted aggregator, so the
  # unweighted sum factory is adapted to that interface.
  weighted_sum = factory_utils.as_weighted_aggregator(unweighted_sum)
  aggregator = weighted_sum.create(
      computation_types.to_type((centroids_type, weights_type)),
      computation_types.to_type(()))

  finalizer = _build_kmeans_finalizer(centroids_type, num_clusters)

  return composers.compose_learning_process(initialize_centers, distributor,
                                            client_work, aggregator, finalizer)
def build_fed_sgd(
    model_fn: Callable[[], model_lib.Model],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False,
) -> learning_process.LearningProcess:
  """Builds a learning process that performs federated SGD.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  federated SGD on client models. The learning process has the following methods
  inherited from `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with type signature `( -> S@SERVER)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState`
      representing the initial state of the server.
  *   `next`: A `tff.Computation` with type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)` where `S` is a
      `LearningAlgorithmState` whose type matches that of the output of
      `initialize`, and `{B*}@CLIENTS` represents the client datasets, where `B`
      is the type of a single batch. This computation returns a
      `LearningAlgorithmState` representing the updated server state and the
      metrics during client training and any other metrics from broadcast and
      aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time `next` is called, the server model is broadcast to each client using
  a distributor. Each client sums the gradients for each batch in its local
  dataset (without updating its model), and averages the gradients based on
  their number of examples. These average gradients are then aggregated at the
  server, and are applied at the server using a
  `tf.keras.optimizers.Optimizer`.

  This implements the original FedSGD algorithm in
  [McMahan et al., 2017](https://arxiv.org/abs/1602.05629).

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. The optimizer is used to
      apply client updates to the server model.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A `tff.learning.templates.LearningProcess`.
  """
  py_typecheck.check_callable(model_fn)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  if model_distributor is None:
    model_distributor = distributors.build_broadcast_process(
        model_weights_type)

  if model_aggregator is None:
    model_aggregator = mean.MeanFactory()
  # Validate the factory up front, as the sibling builders in this file do, so
  # a wrong argument fails here rather than inside `create`.
  py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory)
  aggregator = model_aggregator.create(
      model_weights_type.trainable, computation_types.TensorType(tf.float32))

  if metrics_aggregator is None:
    metrics_aggregator = metric_aggregator.sum_then_finalize

  client_work = _build_fed_sgd_client_work(
      model_fn,
      metrics_aggregator,
      use_experimental_simulation_loop=use_experimental_simulation_loop)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer_fn, model_weights_type)
  return composers.compose_learning_process(initial_model_weights_fn,
                                            model_distributor, client_work,
                                            aggregator, finalizer)
def build_weighted_fed_prox(
    model_fn: Callable[[], model_lib.Model],
    proximal_strength: float,
    client_optimizer_fn: Union[optimizer_base.Optimizer,
                               Callable[[], tf.keras.optimizers.Optimizer]],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    client_weighting: Optional[
        client_weight_lib.ClientWeighting] = client_weight_lib.ClientWeighting.
    NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
  """Builds a learning process that performs the FedProx algorithm.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  example-weighted FedProx on client models. This algorithm behaves the same as
  federated averaging, except that it uses a proximal regularization term that
  encourages clients to not drift too far from the server model.

  The iterative process has the following methods inherited from
  `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is communicated to
  each client using the provided `model_distributor`. For each client, local
  training is performed using `client_optimizer_fn`. Each client computes the
  difference between the client model after training and the initial model.
  These model deltas are then aggregated at the server using a weighted
  aggregation function, according to `client_weighting`. The aggregate model
  delta is applied at the server using a server optimizer, as in the FedOpt
  framework proposed in
  [Reddi et al., 2021](https://arxiv.org/abs/2003.00295).

  Note: The default server optimizer function is `tf.keras.optimizers.SGD`
  with a learning rate of 1.0, which corresponds to adding the model delta to
  the current server model. This recovers the original FedProx algorithm in
  [Li et al., 2020](https://arxiv.org/abs/1812.06127). More
  sophisticated federated averaging procedures may use different learning rates
  or server optimizers.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    proximal_strength: A nonnegative float representing the parameter of
      FedProx's regularization term. When set to `0.0`, the algorithm reduces to
      FedAvg. Higher values prevent clients from moving too far from the server
      model during local training. The value must be an actual `float`; an
      integer such as `0` is rejected.
    client_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. By default, this uses
      `tf.keras.optimizers.SGD` with a learning rate of 1.0.
    client_weighting: A member of `tff.learning.ClientWeighting` that specifies
      a built-in weighting method. By default, weighting by number of examples
      is used.
    model_distributor: An optional `DistributionProcess` that broadcasts the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `tff.learning.templates.LearningProcess`.

  Raises:
    ValueError: If `proximal_strength` is not a nonnegative float.
    TypeError: If the process created from `model_aggregator` returns a server
      value whose type differs from the type of the client values it consumes.
  """
  if not isinstance(proximal_strength, float) or proximal_strength < 0.0:
    raise ValueError(
        f'proximal_strength must be a nonnegative float, found '
        f'{proximal_strength}')

  py_typecheck.check_callable(model_fn)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  if model_distributor is None:
    model_distributor = distributors.build_broadcast_process(
        model_weights_type)

  if model_aggregator is None:
    model_aggregator = mean.MeanFactory()
  py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory)
  aggregator = model_aggregator.create(
      model_weights_type.trainable, computation_types.TensorType(tf.float32))

  # The finalizer applies the aggregate to the model weights, so the aggregator
  # must hand back a value of the same type structure that clients put in.
  process_signature = aggregator.next.type_signature
  input_client_value_type = process_signature.parameter[1]
  result_server_value_type = process_signature.result[1]
  if input_client_value_type.member != result_server_value_type.member:
    # The message names `model_aggregator`, the parameter callers actually pass.
    raise TypeError(
        '`model_aggregator` does not produce a '
        'compatible `AggregationProcess`. The processes must '
        'retain the type structure of the inputs on the '
        f'server, but got {input_client_value_type.member} != '
        f'{result_server_value_type.member}.')

  if metrics_aggregator is None:
    metrics_aggregator = metric_aggregator.sum_then_finalize

  # The proximal term is passed to the client work as its
  # `delta_l2_regularizer` argument.
  client_work = model_delta_client_work.build_model_delta_client_work(
      model_fn=model_fn,
      optimizer=client_optimizer_fn,
      client_weighting=client_weighting,
      delta_l2_regularizer=proximal_strength,
      metrics_aggregator=metrics_aggregator,
      use_experimental_simulation_loop=use_experimental_simulation_loop)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer_fn, model_weights_type)
  return composers.compose_learning_process(initial_model_weights_fn,
                                            model_distributor, client_work,
                                            aggregator, finalizer)
def build_weighted_fed_avg(
    model_fn: Optional[Callable[[], model_lib.Model]],
    client_optimizer_fn: Union[optimizer_base.Optimizer,
                               Callable[[], tf.keras.optimizers.Optimizer]],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    *,
    client_weighting: Optional[client_weight_lib.ClientWeighting] = (
        client_weight_lib.ClientWeighting.NUM_EXAMPLES),
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False,
    model: Optional[functional.FunctionalModel] = None,
) -> learning_process.LearningProcess:
  """Builds a learning process that performs federated averaging.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  federated averaging on client models. The iterative process has the following
  methods inherited from `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is communicated to
  each client using the provided `model_distributor`. For each client, local
  training is performed using `client_optimizer_fn`. Each client computes the
  difference between the client model after training and its initial model.
  These model deltas are then aggregated at the server using a weighted
  aggregation function, according to `client_weighting`. The aggregate model
  delta is applied at the server using a server optimizer.

  Note: the default server optimizer function is `tf.keras.optimizers.SGD`
  with a learning rate of 1.0, which corresponds to adding the model delta to
  the current server model. This recovers the original FedAvg algorithm in
  [McMahan et al., 2017](https://arxiv.org/abs/1602.05629). More
  sophisticated federated averaging procedures may use different learning rates
  or server optimizers (this generalized FedAvg algorithm is described in
  [Reddi et al., 2021](https://arxiv.org/abs/2003.00295)).

  Exactly one of `model_fn` and `model` must be provided; the other must be
  `None`. Because `model_fn` is the first positional parameter, callers that
  train a `model` must pass `model_fn=None` explicitly.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`, or `None`
      when `model` is given. This method must *not* capture TensorFlow tensors
      or variables and use them. The model must be constructed entirely from
      scratch on each invocation, returning the same pre-constructed model each
      call will result in an error. Cannot be used with `model` argument.
    client_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. By default, this uses
      `tf.keras.optimizers.SGD` with a learning rate of 1.0.
    client_weighting: A member of `tff.learning.ClientWeighting` that specifies
      a built-in weighting method. By default, weighting by number of examples
      is used.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via
      `tff.learning.templates.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations. This flag is only forwarded to the client work built from
      `model_fn`; it is not passed along when `model` is used.
    model: A `tff.learning.models.FunctionalModel` to train. Cannot be used with
      `model_fn` and must be `None` if `model_fn` is not `None`.

  Returns:
    A `tff.learning.templates.LearningProcess`.

  Raises:
    ValueError: If both `model` and `model_fn` are given, or if both are `None`.
    TypeError: If the `AggregationProcess` created by `model_aggregator` does
      not return, on the server, the same type it accepts from the clients.
  """
  # Exactly one model source is allowed; reject "both" and "neither" up front.
  if model is not None and model_fn is not None:
    raise ValueError('Must specify only one of `model` and `model_fn`, both '
                     'were not `None`.')
  if model is None and model_fn is None:
    raise ValueError('Must specify one of `model` and `model_fn`, both were '
                     '`None`.')

  if model_fn is not None:
    py_typecheck.check_callable(model_fn)
  else:
    py_typecheck.check_type(model, functional.FunctionalModel)
  py_typecheck.check_type(client_weighting, client_weight_lib.ClientWeighting)

  if model is not None:

    @tensorflow_computation.tf_computation()
    def initial_model_weights_fn():
      # `initial_weights` is indexed as (trainable, non_trainable); each group
      # is converted to a tuple of tensors.
      return model_utils.ModelWeights(
          tuple(tf.convert_to_tensor(w) for w in model.initial_weights[0]),
          tuple(tf.convert_to_tensor(w) for w in model.initial_weights[1]))
  else:

    @tensorflow_computation.tf_computation()
    def initial_model_weights_fn():
      return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  if model_distributor is None:
    model_distributor = distributors.build_broadcast_process(model_weights_type)

  if model_aggregator is None:
    model_aggregator = mean.MeanFactory()
  py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory)

  # Only the trainable part of the weights is aggregated as the model update.
  if model is not None:
    model_update_type, _ = model_weights_type
  else:
    model_update_type = model_weights_type.trainable
  aggregator = model_aggregator.create(
      model_update_type, computation_types.TensorType(tf.float32))

  # The aggregate placed on the server must have the same member type as the
  # client updates, since the compatibility check below compares exactly these
  # two types.
  process_signature = aggregator.next.type_signature
  input_client_value_type = process_signature.parameter[1]
  result_server_value_type = process_signature.result[1]
  if input_client_value_type.member != result_server_value_type.member:
    raise TypeError('`model_aggregator` does not produce a compatible '
                    '`AggregationProcess`. The processes must retain the type '
                    'structure of the inputs on the server, but got '
                    f'{input_client_value_type.member} != '
                    f'{result_server_value_type.member}.')

  if metrics_aggregator is None:
    metrics_aggregator = metric_aggregator.sum_then_finalize

  if model is not None:
    # NOTE: `use_experimental_simulation_loop` is not forwarded on this path.
    client_work = (
        model_delta_client_work.build_functional_model_delta_client_work(
            model=model,
            optimizer=client_optimizer_fn,
            client_weighting=client_weighting,
            metrics_aggregator=metrics_aggregator))
  else:
    client_work = model_delta_client_work.build_model_delta_client_work(
        model_fn=model_fn,
        optimizer=client_optimizer_fn,
        client_weighting=client_weighting,
        metrics_aggregator=metrics_aggregator,
        use_experimental_simulation_loop=use_experimental_simulation_loop)

  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer_fn, model_weights_type)
  return composers.compose_learning_process(initial_model_weights_fn,
                                            model_distributor, client_work,
                                            aggregator, finalizer)