def test_clients_placed(self):
  """`sequence_map` over client-placed int32 sequences yields bool sequences."""
  client_seq_type = computation_types.at_clients(
      computation_types.SequenceType(tf.int32))
  mock_sequence = _mock_data_of_type(client_seq_type)
  mapped = intrinsics.sequence_map(self.over_ten_fn(), mock_sequence)
  self.assert_value(mapped, '{bool*}@CLIENTS')
def create(
    self,
    value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
  """Creates an `AggregationProcess` that discretizes values before aggregating.

  Each round, the server-held `step_size` is broadcast to clients. Client values
  are discretized with it, then aggregated by the process created from
  `self._inner_agg_factory`. The aggregate is then undiscretized using the
  server copy of `step_size`.

  If `self._distortion_aggregation_factory` is not `None`, a per-client mean
  squared reconstruction error (discretize, then undiscretize, compared with
  the input) is also aggregated. It is reported under the `'distortion'`
  measurement key.

  Args:
    value_type: The TFF type of the client values. Must be a `TensorType` or a
      `StructWithPythonType` containing only `TensorType`s, all of float dtype.

  Returns:
    A `tff.templates.AggregationProcess`. Its state is an `OrderedDict` with
    keys `step_size` and `inner_agg_process`.

  Raises:
    TypeError: If `value_type` is not a tensor or a structure of tensors, or if
      any component dtype is not a float.
  """
  # Validate input args and value_type and parse out the TF dtypes.
  if value_type.is_tensor():
    tf_dtype = value_type.dtype
  elif (value_type.is_struct_with_python() and
        type_analysis.is_structure_of_tensors(value_type)):
    tf_dtype = type_conversions.structure_from_tensor_type_tree(
        lambda x: x.dtype, value_type)
  else:
    raise TypeError('Expected `value_type` to be `TensorType` or '
                    '`StructWithPythonType` containing only `TensorType`. '
                    f'Found type: {repr(value_type)}')

  # Check that all values are floats.
  if not type_analysis.is_structure_of_floats(value_type):
    raise TypeError('Component dtypes of `value_type` must all be floats. '
                    f'Found {repr(value_type)}.')

  if self._distortion_aggregation_factory is not None:
    # `distortion_measurement_fn` below returns one float32 scalar per client,
    # so the distortion aggregator is built for `tf.float32`.
    distortion_aggregation_process = self._distortion_aggregation_factory.create(
        computation_types.to_type(tf.float32))

  @tensorflow_computation.tf_computation(value_type, tf.float32)
  def discretize_fn(value, step_size):
    return _discretize_struct(value, step_size)

  @tensorflow_computation.tf_computation(discretize_fn.type_signature.result,
                                         tf.float32)
  def undiscretize_fn(value, step_size):
    # `tf_dtype` (parsed from `value_type` above) is passed through.
    # Presumably this restores the original dtype(s) -- confirm in
    # `_undiscretize_struct`.
    return _undiscretize_struct(value, step_size, tf_dtype)

  @tensorflow_computation.tf_computation(value_type, tf.float32)
  def distortion_measurement_fn(value, step_size):
    # Round-trip the value through discretization and measure the error.
    reconstructed_value = undiscretize_fn(
        discretize_fn(value, step_size), step_size)
    err = tf.nest.map_structure(tf.subtract, reconstructed_value, value)
    squared_err = tf.nest.map_structure(tf.square, err)
    # Flatten every component to 1-D float32 so all components can be
    # concatenated and averaged together regardless of their shapes.
    flat_squared_errs = [
        tf.cast(tf.reshape(t, [-1]), tf.float32)
        for t in tf.nest.flatten(squared_err)
    ]
    all_squared_errs = tf.concat(flat_squared_errs, axis=0)
    mean_squared_err = tf.reduce_mean(all_squared_errs)
    return mean_squared_err

  inner_agg_process = self._inner_agg_factory.create(
      discretize_fn.type_signature.result)

  @federated_computation.federated_computation()
  def init_fn():
    state = collections.OrderedDict(
        step_size=intrinsics.federated_value(self._step_size,
                                             placements.SERVER),
        inner_agg_process=inner_agg_process.initialize())
    return intrinsics.federated_zip(state)

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(value_type))
  def next_fn(state, value):
    server_step_size = state['step_size']
    client_step_size = intrinsics.federated_broadcast(server_step_size)

    discretized_value = intrinsics.federated_map(discretize_fn,
                                                 (value, client_step_size))

    inner_state = state['inner_agg_process']
    inner_agg_output = inner_agg_process.next(inner_state, discretized_value)

    # Undiscretize the aggregate with the server copy of the step size.
    undiscretized_agg_value = intrinsics.federated_map(
        undiscretize_fn, (inner_agg_output.result, server_step_size))

    # The step size is carried forward unchanged; only the inner state moves.
    new_state = collections.OrderedDict(
        step_size=server_step_size,
        inner_agg_process=inner_agg_output.state)
    measurements = collections.OrderedDict(
        deterministic_discretization=inner_agg_output.measurements)

    if self._distortion_aggregation_factory is not None:
      distortions = intrinsics.federated_map(distortion_measurement_fn,
                                             (value, client_step_size))
      # The distortion aggregator is re-initialized every round. Its state is
      # not kept in `new_state`, and only its `result` is reported.
      aggregate_distortion = distortion_aggregation_process.next(
          distortion_aggregation_process.initialize(), distortions).result
      measurements['distortion'] = aggregate_distortion

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(new_state),
        result=undiscretized_agg_value,
        measurements=intrinsics.federated_zip(measurements))

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def create(self, value_type):
  """Creates an `AggregationProcess` that scales and discretizes values.

  Each round, the server-held `scale_factor` and `prior_norm_bound` are
  broadcast to clients. Client values are discretized by the computation from
  `_build_discretize_fn` (configured with `self._stochastic` and `self._beta`),
  then aggregated by the process created from `self._inner_agg_factory`. The
  aggregate is then undiscretized using the server copy of `scale_factor`.

  Args:
    value_type: The TFF type of the client values. Must be a `TensorType` or a
      `StructWithPythonType` containing only `TensorType`s, all of float dtype.
      A structure is only accepted when `prior_norm_bound` is not set.

  Returns:
    A `tff.templates.AggregationProcess`. Its state is an `OrderedDict` with
    keys `scale_factor`, `prior_norm_bound` and `inner_agg_process`. Its
    measurements have the key `discretize`.

  Raises:
    TypeError: If `value_type` is not a tensor or a structure of tensors; if it
      is a structure while `prior_norm_bound` is set; or if any component dtype
      is not a float.
  """
  # Validate input args and value_type and parse out the TF dtypes.
  if value_type.is_tensor():
    tf_dtype = value_type.dtype
  elif (value_type.is_struct_with_python() and
        type_analysis.is_structure_of_tensors(value_type)):
    # NOTE(review): this is a truthiness check, so a bound of 0 / 0.0 is
    # treated the same as "not set" here. Confirm that is intended.
    if self._prior_norm_bound:
      raise TypeError(
          'If `prior_norm_bound` is specified, `value_type` must '
          f'be `TensorType`. Found type: {repr(value_type)}.')
    tf_dtype = type_conversions.structure_from_tensor_type_tree(
        lambda x: x.dtype, value_type)
  else:
    raise TypeError(
        'Expected `value_type` to be `TensorType` or '
        '`StructWithPythonType` containing only `TensorType`. '
        f'Found type: {repr(value_type)}')

  # Check that all values are floats.
  if not type_analysis.is_structure_of_floats(value_type):
    raise TypeError(
        'Component dtypes of `value_type` must all be floats. '
        f'Found {repr(value_type)}.')

  # Called below as discretize_fn(value, scale_factor, prior_norm_bound).
  discretize_fn = _build_discretize_fn(value_type, self._stochastic,
                                       self._beta)

  @tensorflow_computation.tf_computation(
      discretize_fn.type_signature.result, tf.float32)
  def undiscretize_fn(value, scale_factor):
    # `tf_dtype` (parsed from `value_type` above) is passed through.
    # Presumably this restores the original dtype(s) -- confirm in
    # `_undiscretize_struct`.
    return _undiscretize_struct(value, scale_factor, tf_dtype)

  inner_value_type = discretize_fn.type_signature.result
  inner_agg_process = self._inner_agg_factory.create(inner_value_type)

  @federated_computation.federated_computation()
  def init_fn():
    # NOTE(review): `self._prior_norm_bound` is placed at the server as-is. If
    # it can be None, presumably the constructor normalizes it -- confirm.
    state = collections.OrderedDict(
        scale_factor=intrinsics.federated_value(
            self._scale_factor, placements.SERVER),
        prior_norm_bound=intrinsics.federated_value(
            self._prior_norm_bound, placements.SERVER),
        inner_agg_process=inner_agg_process.initialize())
    return intrinsics.federated_zip(state)

  @federated_computation.federated_computation(
      init_fn.type_signature.result,
      computation_types.at_clients(value_type))
  def next_fn(state, value):
    server_scale_factor = state['scale_factor']
    client_scale_factor = intrinsics.federated_broadcast(
        server_scale_factor)
    server_prior_norm_bound = state['prior_norm_bound']
    prior_norm_bound = intrinsics.federated_broadcast(
        server_prior_norm_bound)

    discretized_value = intrinsics.federated_map(
        discretize_fn, (value, client_scale_factor, prior_norm_bound))

    inner_state = state['inner_agg_process']
    inner_agg_output = inner_agg_process.next(inner_state, discretized_value)

    # Undiscretize the aggregate with the server copy of the scale factor.
    undiscretized_agg_value = intrinsics.federated_map(
        undiscretize_fn, (inner_agg_output.result, server_scale_factor))

    # Scale factor and prior norm bound are carried forward unchanged.
    new_state = collections.OrderedDict(
        scale_factor=server_scale_factor,
        prior_norm_bound=server_prior_norm_bound,
        inner_agg_process=inner_agg_output.state)
    measurements = collections.OrderedDict(
        discretize=inner_agg_output.measurements)

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(new_state),
        result=undiscretized_agg_value,
        measurements=intrinsics.federated_zip(measurements))

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def create_whimsy_intrinsic_def_federated_sum():
  """Returns the `FEDERATED_SUM` intrinsic paired with a float32 signature.

  The signature maps float32 values at the clients to a float32 at the server.
  """
  clients_float = computation_types.at_clients(tf.float32)
  server_float = computation_types.at_server(tf.float32)
  signature = computation_types.FunctionType(clients_float, server_float)
  return intrinsic_defs.FEDERATED_SUM, signature
def create_whimsy_value_at_clients(number_of_clients: int = 3):
  """Returns a Python value and federated type at clients.

  The value is a list of `number_of_clients` floats counting up from 10.0, one
  per client. The type is a float32 placed at the clients.
  """
  first = 10
  per_client_values = list(map(float, range(first, first + number_of_clients)))
  return per_client_values, computation_types.at_clients(tf.float32)
def test_roundtrip_with_nonempty_tuple_clients_argument(self):
  """A ten-element tuple at the clients survives a round trip unchanged."""
  clients_value = tuple(range(10))
  self.assertRoundTripEqual(
      clients_value, computation_types.at_clients(tf.int32), clients_value)
def create_whimsy_intrinsic_def_federated_broadcast():
  """Returns the `FEDERATED_BROADCAST` intrinsic paired with a float32 signature.

  The signature maps a float32 at the server to an all-equal float32 at the
  clients.
  """
  server_float = computation_types.at_server(tf.float32)
  clients_float = computation_types.at_clients(tf.float32, all_equal=True)
  signature = computation_types.FunctionType(server_float, clients_float)
  return intrinsic_defs.FEDERATED_BROADCAST, signature
def build_federated_evaluation(
    model_fn: Callable[[], model_lib.Model],
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False,
) -> computation_base.Computation:
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENTS)` and have empty state. If
      set to default None, the server model is broadcast to the clients using
      the default tff.federated_broadcast. Note that the measurements of this
      process are currently dropped (see the TODO in the body).
    metrics_aggregator: An optional function that takes in the metric finalizers
      (i.e., `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a federated TFF computation of the following type signature
      `local_unfinalized_metrics@CLIENTS -> aggregated_metrics@SERVER`. If
      `None`, uses `tff.learning.metrics.sum_then_finalize`, which returns a
      federated TFF computation that sums the unfinalized metrics from
      `CLIENTS`, and then applies the corresponding metric finalizers at
      `SERVER`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics
    as an `OrderedDict` with the single key `eval`.

  Raises:
    ValueError: If `broadcast_process` is given but is not a
      `MeasuredProcess`, or if it is stateful (has non-empty state).
  """
  if broadcast_process is not None:
    if not isinstance(broadcast_process, measured_process.MeasuredProcess):
      raise ValueError(
          '`broadcast_process` must be a `MeasuredProcess`, got '
          f'{type(broadcast_process)}.')
    if iterative_process.is_stateful(broadcast_process):
      raise ValueError(
          'Cannot create a federated evaluation with a stateful '
          'broadcast process, must be stateless (have empty state), has state: '
          f'{broadcast_process.initialize.type_signature.result!r}')
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    model = model_fn()
    model_weights_type = model_utils.weights_type_from_model(model)
    batch_type = computation_types.to_type(model.input_spec)
    unfinalized_metrics_type = type_conversions.type_from_tensors(
        model.report_local_unfinalized_metrics())
    if metrics_aggregator is not None:
      metrics_aggregation_computation = metrics_aggregator(
          model.metric_finalizers(), unfinalized_metrics_type)
    else:
      metrics_aggregation_computation = aggregator.sum_then_finalize(
          model.metric_finalizers(), unfinalized_metrics_type)

  @federated_computation.federated_computation(
      computation_types.at_server(model_weights_type),
      computation_types.at_clients(SequenceType(batch_type)))
  def server_eval(server_model_weights, federated_dataset):
    client_eval = build_local_evaluation(model_fn, model_weights_type,
                                         batch_type,
                                         use_experimental_simulation_loop)
    if broadcast_process is not None:
      # TODO(b/179091838): Zip the measurements from the broadcast_process with
      # the result of `model_metrics` below to avoid dropping these metrics.
      # The process is stateless (checked above), so it is initialized fresh
      # on every call.
      broadcast_output = broadcast_process.next(
          broadcast_process.initialize(), server_model_weights)
      client_outputs = intrinsics.federated_map(
          client_eval, (broadcast_output.result, federated_dataset))
    else:
      client_outputs = intrinsics.federated_map(client_eval, [
          intrinsics.federated_broadcast(server_model_weights),
          federated_dataset
      ])
    # Only the `local_outputs` field of the client outputs is aggregated.
    model_metrics = metrics_aggregation_computation(
        client_outputs.local_outputs)
    return intrinsics.federated_zip(
        collections.OrderedDict(eval=model_metrics))

  return server_eval
# # @federated_computation # def federated_aggregate(x, zero, accumulate, merge, report): # a = generic_partial_reduce(x, zero, accumulate, INTERMEDIATE_AGGREGATORS) # b = generic_reduce(a, zero, merge, SERVER) # c = generic_map(report, b) # return c # # Actual implementations might vary. # # Type signature: <{T}@CLIENTS,U,(<U,T>->U),(<U,U>->U),(U->R)> -> R@SERVER FEDERATED_AGGREGATE = IntrinsicDef( 'FEDERATED_AGGREGATE', 'federated_aggregate', computation_types.FunctionType(parameter=[ computation_types.at_clients(computation_types.AbstractType('T')), computation_types.AbstractType('U'), type_factory.reduction_op(computation_types.AbstractType('U'), computation_types.AbstractType('T')), type_factory.binary_op(computation_types.AbstractType('U')), computation_types.FunctionType(computation_types.AbstractType('U'), computation_types.AbstractType('R')) ], result=computation_types.at_server( computation_types.AbstractType('R'))), aggregation_kind=AggregationKind.DEFAULT) # Applies a given function to a value on the server. # # Type signature: <(T->U),T@SERVER> -> U@SERVER FEDERATED_APPLY = IntrinsicDef(
def build_federated_evaluation(
    model_fn: training_process.ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: training_process.LossFn,
    metrics_fn: Optional[training_process.MetricsFn] = None,
    reconstruction_optimizer_fn: training_process.OptimizerFn = functools.partial(tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None,
    broadcast_process: Optional[measured_process_lib.MeasuredProcess] = None,
) -> computation_base.Computation:
  """Builds a `tff.Computation` for evaluating a reconstruction `Model`.

  The returned computation proceeds in two stages: (1) reconstruction and (2)
  evaluation. During the reconstruction stage, local variables are
  reconstructed by freezing global variables and training using
  `reconstruction_optimizer_fn`. During the evaluation stage, the
  reconstructed local variables and global variables are evaluated using the
  provided `loss_fn` and `metrics_fn`.

  Usage of returned computation:
    eval_comp = build_federated_evaluation(...)
    metrics = eval_comp(tff.learning.reconstruction.get_global_variables(model),
                        federated_data)

  Args:
    model_fn: A no-arg function that returns a
      `tff.learning.reconstruction.Model`. This method must *not* capture
      TensorFlow tensors or variables and use them. Must be constructed
      entirely from scratch on each invocation, returning the same
      pre-constructed model each call will result in an error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
      reconstruct and evaluate the model. The loss will be applied to the
      model's outputs during the evaluation stage. The final loss metric is the
      example-weighted mean loss across batches (and across clients).
    metrics_fn: A no-arg function returning a list of `tf.keras.metrics.Metric`s
      to evaluate the model. The metrics will be applied to the model's outputs
      during the evaluation stage. Final metric values are the example-weighted
      mean of metric values across batches (and across clients). If None, no
      metrics are applied.
    reconstruction_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` used to reconstruct the local variables
      with the global ones frozen.
    dataset_split_fn: A `tff.learning.reconstruction.DatasetSplitFn` taking in a
      single TF dataset and producing two TF datasets. The first is iterated
      over during reconstruction, and the second is iterated over during
      evaluation. This can be used to preprocess datasets to e.g. iterate over
      them for multiple epochs or use disjoint data for reconstruction and
      evaluation. If None, split client data in half for each user, using one
      half for reconstruction and the other for evaluation. See
      `tff.learning.reconstruction.build_dataset_split_fn` for options.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENTS)` and have empty state. If
      set to default None, the server model is broadcast to the clients using
      the default `tff.federated_broadcast`.

  Raises:
    TypeError: if `broadcast_process` does not have the expected signature or
      has non-empty state.

  Returns:
    A `tff.Computation` that accepts global model parameters and federated data
    and returns example-weighted evaluation loss and metrics. The result is an
    `OrderedDict` with keys `broadcast` (the broadcast process measurements)
    and `eval` (the aggregated evaluation outputs).
  """
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  with tf.Graph().as_default():
    model = model_fn()
    global_weights = reconstruction_utils.get_global_variables(model)
    # Only the global variables are exchanged with the server.
    model_weights_type = type_conversions.type_from_tensors(global_weights)
    batch_type = computation_types.to_type(model.input_spec)

    # The loss is always tracked as the first metric.
    metrics = [keras_utils.MeanLossMetric(loss_fn())]
    if metrics_fn is not None:
      metrics.extend(metrics_fn())
    federated_output_computation = (
        keras_utils.federated_output_computation_from_metrics(metrics))
    # Remove unneeded variables to avoid polluting namespace.
    del model
    del global_weights
    del metrics

  if dataset_split_fn is None:
    dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

  if broadcast_process is None:
    broadcast_process = optimizer_utils.build_stateless_broadcaster(
        model_weights_type=model_weights_type)
  if not optimizer_utils.is_valid_broadcast_process(broadcast_process):
    raise TypeError(
        'broadcast_process type signature does not conform to expected '
        'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
        ' Got: {t}'.format(t=broadcast_process.next.type_signature))
  if iterative_process.is_stateful(broadcast_process):
    raise TypeError(f'Eval broadcast_process must be stateless (have an empty '
                    'state), has state '
                    f'{broadcast_process.initialize.type_signature.result!r}')

  @tensorflow_computation.tf_computation(
      model_weights_type, computation_types.SequenceType(batch_type))
  def client_computation(incoming_model_weights: computation_types.Type,
                         client_dataset: computation_types.SequenceType):
    """Reconstructs and evaluates with `incoming_model_weights`."""
    # A fresh model, metrics and optimizer are created on every invocation.
    client_model = model_fn()
    client_global_weights = reconstruction_utils.get_global_variables(
        client_model)
    client_local_weights = reconstruction_utils.get_local_variables(
        client_model)

    metrics = [keras_utils.MeanLossMetric(loss_fn())]
    if metrics_fn is not None:
      metrics.extend(metrics_fn())

    client_loss = loss_fn()
    reconstruction_optimizer = reconstruction_optimizer_fn()

    @tf.function
    def reconstruction_reduce_fn(num_examples_sum, batch):
      """Runs reconstruction training on local client batch."""
      with tf.GradientTape() as tape:
        output = client_model.forward_pass(batch, training=True)
        batch_loss = client_loss(y_true=output.labels, y_pred=output.predictions)

      # Gradients are taken only with respect to the trainable local
      # variables; global variables are not updated here.
      gradients = tape.gradient(batch_loss, client_local_weights.trainable)
      reconstruction_optimizer.apply_gradients(
          zip(gradients, client_local_weights.trainable))

      return num_examples_sum + output.num_examples

    @tf.function
    def evaluation_reduce_fn(num_examples_sum, batch):
      """Runs evaluation on client batch without training."""
      output = client_model.forward_pass(batch, training=False)
      # Update each metric.
      for metric in metrics:
        metric.update_state(y_true=output.labels, y_pred=output.predictions)
      return num_examples_sum + output.num_examples

    @tf.function
    def tf_client_computation(incoming_model_weights, client_dataset):
      """Reconstructs and evaluates with `incoming_model_weights`."""
      recon_dataset, eval_dataset = dataset_split_fn(client_dataset)

      # Assign incoming global weights to `client_model` before reconstruction.
      tf.nest.map_structure(lambda v, t: v.assign(t), client_global_weights,
                            incoming_model_weights)

      # The example counts returned by `reduce` are not used; the reduces are
      # run for their side effects on the model variables and metrics.
      recon_dataset.reduce(tf.constant(0), reconstruction_reduce_fn)
      eval_dataset.reduce(tf.constant(0), evaluation_reduce_fn)

      eval_local_outputs = keras_utils.read_metric_variables(metrics)
      return eval_local_outputs

    return tf_client_computation(incoming_model_weights, client_dataset)

  @federated_computation.federated_computation(
      computation_types.at_server(model_weights_type),
      computation_types.at_clients(
          computation_types.SequenceType(batch_type)))
  def server_eval(server_model_weights: computation_types.FederatedType,
                  federated_dataset: computation_types.FederatedType):
    # The broadcast process is stateless (checked above), so it is initialized
    # fresh on every call.
    broadcast_output = broadcast_process.next(
        broadcast_process.initialize(), server_model_weights)
    client_outputs = intrinsics.federated_map(
        client_computation, [broadcast_output.result, federated_dataset])
    aggregated_client_outputs = federated_output_computation(
        client_outputs)
    measurements = intrinsics.federated_zip(
        collections.OrderedDict(broadcast=broadcast_output.measurements,
                                eval=aggregated_client_outputs))
    return measurements

  return server_eval
def _build_mime_lite_client_work(
    model_fn: Callable[[], model_lib.Model],
    optimizer: optimizer_base.Optimizer,
    client_weighting: client_weight_lib.ClientWeighting,
    full_gradient_aggregator: Optional[
        factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
  """Creates a `ClientWorkProcess` for Mime Lite.

  The process state is a server-placed pair (optimizer state, full-gradient
  aggregator state). Each round, the optimizer state is broadcast to clients
  for local training. The per-client full gradients are aggregated with the
  clients' update weights, and the aggregate is used to update the optimizer
  state at the server.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    optimizer: A `tff.learning.optimizers.Optimizer` which will be used for both
      creating and updating a global optimizer state, as well as optimization at
      clients given the global state, which is fixed during the optimization.
    client_weighting: A member of `tff.learning.ClientWeighting` that specifies
      a built-in weighting method.
    full_gradient_aggregator: An optional
      `tff.aggregators.WeightedAggregationFactory` used to aggregate the full
      gradients on client datasets. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `ClientWorkProcess`.

  Raises:
    TypeError: Presumably raised by the `py_typecheck` checks when
      `model_fn` is not callable or `optimizer`, `client_weighting` or
      `full_gradient_aggregator` has the wrong type.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_type(optimizer, optimizer_base.Optimizer)
  py_typecheck.check_type(client_weighting, client_weight_lib.ClientWeighting)
  if full_gradient_aggregator is None:
    full_gradient_aggregator = mean.MeanFactory()
  py_typecheck.check_type(full_gradient_aggregator,
                          factory.WeightedAggregationFactory)
  if metrics_aggregator is None:
    metrics_aggregator = metric_aggregator.sum_then_finalize

  with tf.Graph().as_default():
    # Wrap model construction in a graph to avoid polluting the global context
    # with variables created for this model.
    model = model_fn()
    unfinalized_metrics_type = type_conversions.type_from_tensors(
        model.report_local_unfinalized_metrics())
    metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(),
                                                unfinalized_metrics_type)
  data_type = computation_types.SequenceType(model.input_spec)
  weights_type = model_utils.weights_type_from_model(model)
  weight_tensor_specs = type_conversions.type_to_tf_tensor_specs(weights_type)

  # Rebinds the name from the factory to the created aggregation process.
  # Gradients are aggregated over the trainable weights, with a float32 weight.
  full_gradient_aggregator = full_gradient_aggregator.create(
      weights_type.trainable, computation_types.TensorType(tf.float32))

  @federated_computation.federated_computation
  def init_fn():
    specs = weight_tensor_specs.trainable
    optimizer_state = intrinsics.federated_eval(
        tensorflow_computation.tf_computation(
            lambda: optimizer.initialize(specs)), placements.SERVER)
    aggregator_state = full_gradient_aggregator.initialize()
    return intrinsics.federated_zip((optimizer_state, aggregator_state))

  client_update_fn = _build_client_update_fn_for_mime_lite(
      model_fn, optimizer, client_weighting, use_experimental_simulation_loop)

  @tensorflow_computation.tf_computation(
      init_fn.type_signature.result.member[0], weights_type.trainable)
  def update_optimizer_state(state, aggregate_gradient):
    # Only the new optimizer state is needed. Zero-valued placeholder weights
    # are passed to `optimizer.next` and the returned weights are discarded.
    whimsy_weights = tf.nest.map_structure(
        lambda g: tf.zeros(g.shape, g.dtype), aggregate_gradient)
    updated_state, _ = optimizer.next(state, whimsy_weights,
                                      aggregate_gradient)
    return updated_state

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(weights_type),
      computation_types.at_clients(data_type))
  def next_fn(state, weights, client_data):
    optimizer_state, aggregator_state = state
    optimizer_state_at_clients = intrinsics.federated_broadcast(
        optimizer_state)
    client_result, model_outputs, full_gradient = (
        intrinsics.federated_map(
            client_update_fn,
            (optimizer_state_at_clients, weights, client_data)))
    full_gradient_agg_output = full_gradient_aggregator.next(
        aggregator_state, full_gradient, client_result.update_weight)
    updated_optimizer_state = intrinsics.federated_map(
        update_optimizer_state,
        (optimizer_state, full_gradient_agg_output.result))

    new_state = intrinsics.federated_zip(
        (updated_optimizer_state, full_gradient_agg_output.state))
    train_metrics = metrics_aggregation_fn(model_outputs)
    measurements = intrinsics.federated_zip(
        collections.OrderedDict(train=train_metrics))
    return measured_process.MeasuredProcessOutput(new_state, client_result,
                                                  measurements)

  return client_works.ClientWorkProcess(init_fn, next_fn)
def build_scheduled_client_work(
    model_fn: Callable[[], model_lib.Model],
    learning_rate_fn: Callable[[int], float],
    optimizer_fn: Callable[[float], TFFOrKerasOptimizer],
    metrics_aggregator: Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation],
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
  """Creates a `ClientWorkProcess` for federated averaging.

  This `ClientWorkProcess` creates a state containing the current round number
  (starting at 0), which is incremented at each call to
  `ClientWorkProcess.next`. This integer round number is broadcast to the
  clients and used to call `optimizer_fn(learning_rate_fn(round_num))`, in
  order to construct the proper optimizer. Client updates are weighted by
  number of examples (`ClientWeighting.NUM_EXAMPLES`).

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    learning_rate_fn: A callable accepting an integer round number and returning
      a float to be used as a learning rate for the optimizer. That is, the
      client work will call `optimizer_fn(learning_rate_fn(round_num))` where
      `round_num` is the integer round number.
    optimizer_fn: A callable accepting a float learning rate, and returning a
      `tff.learning.optimizers.Optimizer` or a `tf.keras.Optimizer`. It is also
      called once with `1.0` to determine which of the two kinds it returns.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `ClientWorkProcess`.
  """
  with tf.Graph().as_default():
    # Wrap model construction in a graph to avoid polluting the global context
    # with variables created for this model.
    whimsy_model = model_fn()
    whimsy_optimizer = optimizer_fn(1.0)
    unfinalized_metrics_type = type_conversions.type_from_tensors(
        whimsy_model.report_local_unfinalized_metrics())
    metrics_aggregation_fn = metrics_aggregator(
        whimsy_model.metric_finalizers(), unfinalized_metrics_type)
  data_type = computation_types.SequenceType(whimsy_model.input_spec)
  weights_type = model_utils.weights_type_from_model(whimsy_model)

  # `whimsy_optimizer` is only used to pick the client update builder that
  # matches the kind of optimizer `optimizer_fn` returns.
  if isinstance(whimsy_optimizer, optimizer_base.Optimizer):
    build_client_update_fn = model_delta_client_work.build_model_delta_update_with_tff_optimizer
  else:
    build_client_update_fn = model_delta_client_work.build_model_delta_update_with_keras_optimizer

  @tensorflow_computation.tf_computation(weights_type, data_type, tf.int32)
  def client_update_computation(initial_model_weights, dataset, round_num):
    # The optimizer is rebuilt inside each invocation from this round's
    # learning rate.
    learning_rate = learning_rate_fn(round_num)
    optimizer = optimizer_fn(learning_rate)
    client_update = build_client_update_fn(
        model_fn=model_fn,
        weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
    return client_update(optimizer, initial_model_weights, dataset)

  @federated_computation.federated_computation
  def init_fn():
    # The state is the round number, placed at the server.
    return intrinsics.federated_value(0, placements.SERVER)

  @tensorflow_computation.tf_computation(tf.int32)
  @tf.function
  def add_one(x):
    return x + 1

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(weights_type),
      computation_types.at_clients(data_type))
  def next_fn(state, weights, client_data):
    round_num_at_clients = intrinsics.federated_broadcast(state)
    client_result, model_outputs = intrinsics.federated_map(
        client_update_computation, (weights, client_data, round_num_at_clients))
    updated_state = intrinsics.federated_map(add_one, state)
    train_metrics = metrics_aggregation_fn(model_outputs)
    measurements = intrinsics.federated_zip(
        collections.OrderedDict(train=train_metrics))
    return measured_process.MeasuredProcessOutput(updated_state, client_result,
                                                  measurements)

  return client_works.ClientWorkProcess(init_fn, next_fn)
def create(
    self,
    metric_finalizers: model_lib.MetricFinalizersType,
    local_unfinalized_metrics_type: computation_types.StructWithPythonType,
    initial_unfinalized_metrics: Optional[OrderedDict[str, Any]] = None
) -> aggregation_process.AggregationProcess:
  """Creates a `tff.templates.AggregationProcess` for metrics aggregation.

  Each round, client unfinalized metrics are summed. The round's sum is added
  to running accumulators kept in the process state. The process result is a
  pair: (metrics finalized from this round's sum, metrics finalized from the
  accumulated totals over all rounds).

  Args:
    metric_finalizers: An `OrderedDict` of metric names to finalizers, should
      have same keys as the unfinalized metrics. A finalizer is a function
      (typically a `tf.function` decorated callable or a `tff.tf_computation`
      decorated TFF Computation) that takes in a metric's unfinalized values,
      and returns the finalized metric values. This can be obtained from
      `tff.learning.Model.metric_finalizers()`.
    local_unfinalized_metrics_type: A `tff.types.StructWithPythonType` (with
      `OrderedDict` as the Python container) of a client's local unfinalized
      metrics. Let `local_unfinalized_metrics` be the output of
      `tff.learning.Model.report_local_unfinalized_metrics()`, its type can be
      obtained by `tff.framework.type_from_tensors(local_unfinalized_metrics)`.
    initial_unfinalized_metrics: Optional. An `OrderedDict` of metric names to
      the initial values of local unfinalized metrics, its structure should
      match that of `local_unfinalized_metrics_type`. If not specified, defaults
      to zero.

  Returns:
    An instance of `tff.templates.AggregationProcess`.

  Raises:
    TypeError: If any argument type mismatches; if the metric finalizers
      mismatch the type of local unfinalized metrics; if the initial unfinalized
      metrics mismatch the type of local unfinalized metrics.
  """
  aggregator.check_metric_finalizers(metric_finalizers)
  aggregator.check_local_unfinalzied_metrics_type(
      local_unfinalized_metrics_type)
  aggregator.check_finalizers_matches_unfinalized_metrics(
      metric_finalizers, local_unfinalized_metrics_type)

  inner_summation_process = sum_factory_lib.SumFactory().create(
      local_unfinalized_metrics_type)

  @federated_computation.federated_computation
  def init_fn():
    unfinalized_metrics_accumulators = (
        _intialize_unfinalized_metrics_accumulators(
            local_unfinalized_metrics_type, initial_unfinalized_metrics))
    # State: (inner summation state, running unfinalized-metric totals).
    return intrinsics.federated_zip((inner_summation_process.initialize(),
                                     unfinalized_metrics_accumulators))

  @federated_computation.federated_computation(
      init_fn.type_signature.result,
      computation_types.at_clients(local_unfinalized_metrics_type))
  def next_fn(state,
              unfinalized_metrics) -> measured_process.MeasuredProcessOutput:
    inner_summation_state, unfinalized_metrics_accumulators = state

    inner_summation_output = inner_summation_process.next(
        inner_summation_state, unfinalized_metrics)
    summed_unfinalized_metrics = inner_summation_output.result
    inner_summation_state = inner_summation_output.state

    @tensorflow_computation.tf_computation(local_unfinalized_metrics_type,
                                           local_unfinalized_metrics_type)
    def add_unfinalized_metrics(unfinalized_metrics,
                                summed_unfinalized_metrics):
      return tf.nest.map_structure(tf.add, unfinalized_metrics,
                                   summed_unfinalized_metrics)

    # Fold this round's sum into the running totals.
    unfinalized_metrics_accumulators = intrinsics.federated_map(
        add_unfinalized_metrics,
        (unfinalized_metrics_accumulators, summed_unfinalized_metrics))

    finalizer_computation = _build_finalizer_computation(
        metric_finalizers, local_unfinalized_metrics_type)

    current_round_metrics = intrinsics.federated_map(
        finalizer_computation, summed_unfinalized_metrics)
    total_rounds_metrics = intrinsics.federated_map(
        finalizer_computation, unfinalized_metrics_accumulators)

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(
            (inner_summation_state, unfinalized_metrics_accumulators)),
        result=intrinsics.federated_zip(
            (current_round_metrics, total_rounds_metrics)),
        measurements=inner_summation_output.measurements)

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def test_clients_placed(self):
  """`sequence_sum` over client-placed int32 sequences yields int32@CLIENTS."""
  clients_int_sequence_type = computation_types.at_clients(
      computation_types.SequenceType(tf.int32))
  mocked_sequences = _mock_data_of_type(clients_int_sequence_type)
  summed = intrinsics.sequence_sum(mocked_sequences)
  self.assert_value(summed, '{int32}@CLIENTS')
async def compute_federated_select(
    self, arg: FederatedResolvingStrategyValue
) -> FederatedResolvingStrategyValue:
  """Resolves a `federated_select` by running `select_fn` on the server.

  `arg` carries four parts: `(client_keys, max_key, server_val, select_fn)`.
  For each client, the client's keys tensor is copied to the server. For
  every key index, one key is extracted there with an indexing operator, and
  `select_fn(server_val, key)` is called. The computed results are then
  embedded back into that client's executor as a sequence.

  Args:
    arg: A `FederatedResolvingStrategyValue` whose internal representation is
      a `structure.Struct` holding a list of per-client key values, the
      (unused) max key, a single-element list with the server value, and the
      `select_fn` computation proto.

  Returns:
    A `FederatedResolvingStrategyValue` holding one value per client, typed
    as `SequenceType(select_fn_type.result)` placed at `CLIENTS`.

  Raises:
    TypeError: If the client keys are not a rank-1 `tf.int32` tensor with a
      statically-known length.
  """
  client_keys_type, max_key_type, server_val_type, select_fn_type = (
      arg.type_signature)
  py_typecheck.check_type(arg.internal_representation, structure.Struct)
  client_keys, max_key, server_val, select_fn = arg.internal_representation
  # We slice up the value as-needed, so `max_key` is not used.
  del max_key, max_key_type
  del server_val_type  # unused
  py_typecheck.check_type(client_keys, list)
  py_typecheck.check_type(server_val, list)
  server_val_at_server = server_val[0]
  py_typecheck.check_type(server_val_at_server,
                          executor_value_base.ExecutorValue)
  py_typecheck.check_type(select_fn, pb.Computation)
  server = self._target_executors[placements.SERVER][0]
  clients = self._target_executors[placements.CLIENTS]
  single_key_type = computation_types.TensorType(tf.int32)
  client_keys_type.member.check_tensor()
  if (client_keys_type.member.dtype != tf.int32 or
      client_keys_type.member.shape.rank != 1):
    raise TypeError(f'Unexpected `client_keys_type`: {client_keys_type}')
  num_keys_per_client: int = client_keys_type.member.shape.dims[0].value
  if num_keys_per_client is None:
    # Without a static key count, `range(num_keys_per_client)` below would
    # fail later, inside the per-client tasks, after server values had
    # already been created. Reject it up front with a clear message instead.
    raise TypeError(
        'Expected `client_keys_type` to have a statically-known number of '
        f'keys per client, found: {client_keys_type}')
  unplaced_result_type = computation_types.SequenceType(select_fn_type.result)
  select_fn_at_server = await server.create_value(select_fn, select_fn_type)
  index_fn_at_server = await executor_utils.embed_indexing_operator(
      server, client_keys_type.member, single_key_type)

  async def select_single_key(keys_at_server, key_index):
    # Grab the `key_index`th key from the keys tensor.
    index_arg = await server.create_struct(
        structure.Struct([
            (None, keys_at_server),
            (None, await server.create_value(key_index, single_key_type)),
        ]))
    key_at_server = await server.create_call(index_fn_at_server, index_arg)
    select_fn_arg = await server.create_struct(
        structure.Struct([
            (None, server_val_at_server),
            (None, key_at_server),
        ]))
    selected = await server.create_call(select_fn_at_server, select_fn_arg)
    return await selected.compute()

  async def select_single_client(client, keys_at_client):
    # Move this client's keys to the server so they can be indexed there.
    keys_at_server = await server.create_value(
        await keys_at_client.compute(), client_keys_type.member)
    unplaced_values = await asyncio.gather(*[
        select_single_key(keys_at_server, i)
        for i in range(num_keys_per_client)
    ])
    return await client.create_value(unplaced_values, unplaced_result_type)

  return FederatedResolvingStrategyValue(
      list(await asyncio.gather(*[
          select_single_client(client, keys_at_client)
          for client, keys_at_client in zip(clients, client_keys)
      ])), computation_types.at_clients(unplaced_result_type))
def build_model_delta_client_work(
    model_fn: Callable[[], model_lib.Model],
    optimizer: Union[optimizer_base.Optimizer,
                     Callable[[], tf.keras.optimizers.Optimizer]],
    client_weighting: client_weight_lib.ClientWeighting,
    delta_l2_regularizer: float = 0.0,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    *,
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
  """Creates a `ClientWorkProcess` for federated averaging.

  This client work is constructed in slightly different manners depending on
  whether `optimizer` is a `tff.learning.optimizers.Optimizer`, or a no-arg
  callable returning a `tf.keras.optimizers.Optimizer`.

  If it is a `tff.learning.optimizers.Optimizer`, we avoid creating
  `tf.Variable`s associated with the optimizer state within the scope of the
  client work, as they are not necessary. This also means that the client's
  model weights are updated by computing `optimizer.next` and then assigning
  the result to the model weights (while a `tf.keras.optimizers.Optimizer` will
  modify the model weight in place using `optimizer.apply_gradients`).

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    optimizer: A `tff.learning.optimizers.Optimizer`, or a no-arg callable that
      returns a `tf.keras.Optimizer`.
    client_weighting: A `tff.learning.ClientWeighting` value.
    delta_l2_regularizer: A nonnegative float representing the parameter of the
      L2-regularization term applied to the delta from initial model weights
      during training. Values larger than 0.0 prevent clients from moving too
      far from the server model during local training.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `ClientWorkProcess`. Its state is an empty tuple at `SERVER`, and its
    measurements are `OrderedDict(train=...)` holding the aggregated training
    metrics.

  Raises:
    TypeError: If `optimizer` is neither a `tff.learning.optimizers.Optimizer`
      nor callable. The `py_typecheck` checks on `model_fn`,
      `client_weighting` and `delta_l2_regularizer` may also raise.
    ValueError: If `delta_l2_regularizer` is negative.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_type(client_weighting, client_weight_lib.ClientWeighting)
  py_typecheck.check_type(delta_l2_regularizer, float)
  if delta_l2_regularizer < 0.0:
    raise ValueError(f'Provided delta_l2_regularizer must be non-negative, '
                     f'but found: {delta_l2_regularizer}')
  if not (isinstance(optimizer, optimizer_base.Optimizer) or
          callable(optimizer)):
    raise TypeError(
        'Provided optimizer must be either a '
        'tff.learning.optimizers.Optimizer or a no-arg callable returning a '
        'tf.keras.optimizers.Optimizer.')

  if metrics_aggregator is None:
    metrics_aggregator = aggregator.sum_then_finalize

  with tf.Graph().as_default():
    # Wrap model construction in a graph to avoid polluting the global context
    # with variables created for this model.
    model = model_fn()
    unfinalized_metrics_type = type_conversions.type_from_tensors(
        model.report_local_unfinalized_metrics())
    metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(),
                                                unfinalized_metrics_type)
  data_type = computation_types.SequenceType(model.input_spec)
  weights_type = model_utils.weights_type_from_model(model)

  if isinstance(optimizer, optimizer_base.Optimizer):

    @tensorflow_computation.tf_computation(weights_type, data_type)
    def client_update_computation(initial_model_weights, dataset):
      client_update = build_model_delta_update_with_tff_optimizer(
          model_fn=model_fn,
          weighting=client_weighting,
          delta_l2_regularizer=delta_l2_regularizer,
          use_experimental_simulation_loop=use_experimental_simulation_loop)
      return client_update(optimizer, initial_model_weights, dataset)

  else:

    @tensorflow_computation.tf_computation(weights_type, data_type)
    def client_update_computation(initial_model_weights, dataset):
      # The Keras optimizer is constructed inside the TF computation so its
      # variables belong to this computation's graph.
      keras_optimizer = optimizer()
      client_update = build_model_delta_update_with_keras_optimizer(
          model_fn=model_fn,
          weighting=client_weighting,
          delta_l2_regularizer=delta_l2_regularizer,
          use_experimental_simulation_loop=use_experimental_simulation_loop)
      return client_update(keras_optimizer, initial_model_weights, dataset)

  @federated_computation.federated_computation
  def init_fn():
    return intrinsics.federated_value((), placements.SERVER)

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(weights_type),
      computation_types.at_clients(data_type))
  def next_fn(state, weights, client_data):
    client_result, model_outputs = intrinsics.federated_map(
        client_update_computation, (weights, client_data))
    train_metrics = metrics_aggregation_fn(model_outputs)
    measurements = intrinsics.federated_zip(
        collections.OrderedDict(train=train_metrics))
    return measured_process.MeasuredProcessOutput(state, client_result,
                                                  measurements)

  return client_works.ClientWorkProcess(init_fn, next_fn)
def build_federated_evaluation(
    model_fn: Callable[[], model_lib.Model],
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    use_experimental_simulation_loop: bool = False,
) -> computation_base.Computation:
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENTS)` and have empty state. If
      set to default None, the server model is broadcast to the clients using
      the default tff.federated_broadcast.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model weights at `SERVER` and sequences of batches at `CLIENTS`. It
    returns the federated zip of
    `OrderedDict(eval=..., stat=OrderedDict(num_examples=...))`. Here `eval`
    holds the metrics aggregated by
    `tff.learning.Model.federated_output_computation`, and `num_examples` is
    the federated sum of the per-client example counts.

  Raises:
    ValueError: If `broadcast_process` is given but is not a
      `MeasuredProcess`, or if it is stateful.
  """
  if broadcast_process is not None:
    if not isinstance(broadcast_process, measured_process.MeasuredProcess):
      raise ValueError('`broadcast_process` must be a `MeasuredProcess`, got '
                       f'{type(broadcast_process)}.')
    if optimizer_utils.is_stateful_process(broadcast_process):
      raise ValueError(
          'Cannot create a federated evaluation with a stateful '
          'broadcast process, must be stateless, has state: '
          f'{broadcast_process.initialize.type_signature.result!r}')
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    model = model_fn()
    model_weights_type = model_utils.weights_type_from_model(model)
    batch_type = computation_types.to_type(model.input_spec)

  @computations.tf_computation(model_weights_type, SequenceType(batch_type))
  @tf.function
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs after evaluating `incoming_model_weights` on `dataset`."""
    # Create the model's variables outside the traced function body.
    with tf.init_scope():
      model = model_fn()
    model_weights = model_utils.ModelWeights.from_model(model)
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                          incoming_model_weights)

    def reduce_fn(num_examples, batch):
      # Runs one inference-mode forward pass and accumulates an int64 count.
      model_output = model.forward_pass(batch, training=False)
      if model_output.num_examples is None:
        # Compute shape from the size of the predictions if model didn't use the
        # batch size.
        return num_examples + tf.shape(
            model_output.predictions, out_type=tf.int64)[0]
      else:
        return num_examples + tf.cast(model_output.num_examples, tf.int64)

    dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
        use_experimental_simulation_loop)
    num_examples = dataset_reduce_fn(
        reduce_fn=reduce_fn,
        dataset=dataset,
        initial_state_fn=lambda: tf.zeros([], dtype=tf.int64))
    return collections.OrderedDict(
        local_outputs=model.report_local_outputs(), num_examples=num_examples)

  @computations.federated_computation(
      computation_types.at_server(model_weights_type),
      computation_types.at_clients(SequenceType(batch_type)))
  def server_eval(server_model_weights, federated_dataset):
    if broadcast_process is not None:
      # TODO(b/179091838): Zip the measurements from the broadcast_process with
      # the result of `model.federated_output_computation` below to avoid
      # dropping these metrics.
      # The process was checked to be stateless above, so a freshly
      # initialized state is passed on every call.
      broadcast_output = broadcast_process.next(
          broadcast_process.initialize(), server_model_weights)
      client_outputs = intrinsics.federated_map(
          client_eval, (broadcast_output.result, federated_dataset))
    else:
      client_outputs = intrinsics.federated_map(client_eval, [
          intrinsics.federated_broadcast(server_model_weights),
          federated_dataset
      ])
    # `model` here is the throwaway model built above, used only for its
    # federated output aggregation.
    model_metrics = model.federated_output_computation(
        client_outputs.local_outputs)
    statistics = collections.OrderedDict(
        num_examples=intrinsics.federated_sum(client_outputs.num_examples))
    return intrinsics.federated_zip(
        collections.OrderedDict(eval=model_metrics, stat=statistics))

  return server_eval
def build_functional_model_delta_client_work(
    *,
    model: functional.FunctionalModel,
    optimizer: optimizer_base.Optimizer,
    client_weighting: client_weight_lib.ClientWeighting,
    delta_l2_regularizer: float = 0.0,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
) -> client_works.ClientWorkProcess:
  """Creates a `ClientWorkProcess` for federated averaging.

  This differs from `tff.learning.templates.build_model_delta_client_work` in
  that it only accepts `tff.learning.models.FunctionalModel` and
  `tff.learning.optimizers.Optimizer` type arguments, resulting in TensorFlow
  graphs that do not contain `tf.Variable` operations.

  Args:
    model: A `tff.learning.models.FunctionalModel` to train.
    optimizer: A `tff.learning.optimizers.Optimizer` to use for local, on-client
      optimization.
    client_weighting: A `tff.learning.ClientWeighting` value.
    delta_l2_regularizer: A nonnegative float representing the parameter of the
      L2-regularization term applied to the delta from initial model weights
      during training. Values larger than 0.0 prevent clients from moving too
      far from the server model during local training.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`. Note:
      metrics are not yet implemented here (see the TODOs), so this argument
      is currently unused, and the measurements are an empty tuple at
      `SERVER`.

  Returns:
    A `ClientWorkProcess` with empty (`()`) state at `SERVER`.

  Raises:
    ValueError: If `delta_l2_regularizer` is negative. The `py_typecheck`
      checks on the other arguments may also raise.
  """
  py_typecheck.check_type(model, functional.FunctionalModel)
  py_typecheck.check_type(optimizer, optimizer_base.Optimizer)
  py_typecheck.check_type(client_weighting, client_weight_lib.ClientWeighting)
  py_typecheck.check_type(delta_l2_regularizer, float)
  if delta_l2_regularizer < 0.0:
    raise ValueError(f'Provided delta_l2_regularizer must be non-negative, '
                     f'but found: {delta_l2_regularizer}')
  if metrics_aggregator is None:
    metrics_aggregator = aggregator.sum_then_finalize
  # TODO(b/229612282): Add metrics implementation.

  data_type = computation_types.SequenceType(model.input_spec)

  def ndarray_to_tensorspec(ndarray):
    return tf.TensorSpec(
        shape=ndarray.shape, dtype=tf.dtypes.as_dtype(ndarray.dtype))

  # Wrap in a `ModelWeights` structure that is required by the `finalizer`.
  weights_type = model_utils.ModelWeights(
      tuple(ndarray_to_tensorspec(w) for w in model.initial_weights[0]),
      tuple(ndarray_to_tensorspec(w) for w in model.initial_weights[1]))

  @tensorflow_computation.tf_computation(weights_type, data_type)
  def client_update_computation(initial_model_weights, dataset):
    # Switch to the tuple expected by FunctionalModel.
    initial_model_weights = (initial_model_weights.trainable,
                             initial_model_weights.non_trainable)
    client_update = build_functional_model_delta_update(
        model=model,
        weighting=client_weighting,
        delta_l2_regularizer=delta_l2_regularizer)
    return client_update(optimizer, initial_model_weights, dataset)

  @federated_computation.federated_computation
  def init_fn():
    # Empty tuple means "no state" / stateless.
    return intrinsics.federated_value((), placements.SERVER)

  @federated_computation.federated_computation(
      computation_types.at_server(()),
      computation_types.at_clients(weights_type),
      computation_types.at_clients(data_type))
  def next_fn(state, weights, client_data):
    client_result, model_outputs = intrinsics.federated_map(
        client_update_computation, (weights, client_data))
    # TODO(b/229612282): Add metrics computations
    del model_outputs
    measurements = intrinsics.federated_value((), placements.SERVER)
    return measured_process.MeasuredProcessOutput(state, client_result,
                                                  measurements)

  return client_works.ClientWorkProcess(init_fn, next_fn)
type_signature = computation_types.FunctionType( computation_types.at_clients(tf.float32), computation_types.at_server(tf.float32)) return value, type_signature def create_whimsy_intrinsic_def_federated_secure_sum_bitwidth(): value = intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH type_signature = computation_types.FunctionType([ computation_types.at_clients(tf.int32), tf.int32, ], computation_types.at_server(tf.int32)) return value, type_signature _WHIMSY_SELECT_CLIENT_KEYS_TYPE = computation_types.at_clients( computation_types.TensorType(tf.int32, [3])) _WHIMSY_SELECT_MAX_KEY_TYPE = computation_types.at_server(tf.int32) _WHIMSY_SELECT_SERVER_STATE_TYPE = computation_types.at_server(tf.string) _WHIMSY_SELECTED_TYPE = computation_types.to_type((tf.string, tf.int32)) _WHIMSY_SELECT_SELECT_FN_TYPE = computation_types.FunctionType( (tf.string, tf.int32), _WHIMSY_SELECTED_TYPE) _WHIMSY_SELECT_RESULT_TYPE = computation_types.at_clients( computation_types.SequenceType(_WHIMSY_SELECTED_TYPE)) _WHIMSY_SELECT_TYPE = computation_types.FunctionType([ _WHIMSY_SELECT_CLIENT_KEYS_TYPE, _WHIMSY_SELECT_MAX_KEY_TYPE, _WHIMSY_SELECT_SERVER_STATE_TYPE, _WHIMSY_SELECT_SELECT_FN_TYPE, ], _WHIMSY_SELECT_RESULT_TYPE) _WHIMSY_SELECT_NUM_CLIENTS = 3
def test_type_properties(self, value_type, mechanism):
  """Compares the `_make_test_factory` process against a hand-built chain.

  The expected chain is concat -> Hadamard rotation -> L2 clipping ->
  discretization -> DP query -> secure modular sum. The test checks that the
  `initialize` and `next` type signatures match, and that `next` contains
  only secure aggregation.
  """
  ddp_factory = _make_test_factory(mechanism=mechanism)
  self.assertIsInstance(ddp_factory, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = ddp_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The state is a nested object with component factory states. Construct
  # test factories directly and compare the signatures.
  modsum_f = secure.SecureModularSumFactory(2**15, True)

  if mechanism == 'distributed_dgauss':
    dp_query = tfp.DistributedDiscreteGaussianSumQuery(
        l2_norm_bound=10.0, local_stddev=10.0)
  else:
    dp_query = tfp.DistributedSkellamSumQuery(
        l1_norm_bound=10.0, l2_norm_bound=10.0, local_stddev=10.0)

  dp_f = differential_privacy.DifferentiallyPrivateFactory(dp_query, modsum_f)
  discrete_f = discretization.DiscretizationFactory(dp_f)
  l2clip_f = robust.clipping_factory(
      clipping_norm=10.0, inner_agg_factory=discrete_f)
  rot_f = rotation.HadamardTransformFactory(inner_agg_factory=l2clip_f)
  expected_process = concat.concat_factory(rot_f).create(value_type)

  # Check init_fn/state.
  expected_init_type = expected_process.initialize.type_signature
  expected_state_type = expected_init_type.result
  actual_init_type = process.initialize.type_signature
  self.assertTrue(actual_init_type.is_equivalent_to(expected_init_type))

  # Check next_fn/measurements.
  tensor2type = type_conversions.type_from_tensors
  discrete_state = discrete_f.create(
      computation_types.to_type(tf.float32)).initialize()
  dp_query_state = dp_query.initial_global_state()
  dp_query_metrics_type = tensor2type(dp_query.derive_metrics(dp_query_state))
  expected_measurements_type = collections.OrderedDict(
      l2_clip=robust.NORM_TF_TYPE,
      scale_factor=tensor2type(discrete_state['scale_factor']),
      scaled_inflated_l2=tensor2type(dp_query_state.l2_norm_bound),
      scaled_local_stddev=tensor2type(dp_query_state.local_stddev),
      actual_num_clients=tf.int32,
      padded_dim=tf.int32,
      dp_query_metrics=dp_query_metrics_type)
  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=expected_state_type,
          result=computation_types.at_server(value_type),
          measurements=computation_types.at_server(
              expected_measurements_type)))
  actual_next_type = process.next.type_signature
  self.assertTrue(actual_next_type.is_equivalent_to(expected_next_type))

  # Catch `Exception` rather than everything: a bare `except:` would also
  # turn KeyboardInterrupt/SystemExit into an ordinary test failure.
  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  except Exception:  # pylint: disable=broad-except
    self.fail('Factory returned an AggregationProcess containing '
              'non-secure aggregation.')
def create_whimsy_intrinsic_def_federated_eval_at_clients():
  """Returns `FEDERATED_EVAL_AT_CLIENTS` paired with a whimsy signature.

  The signature takes a no-argument function producing `tf.float32` and yields
  a `tf.float32` placed at `CLIENTS`.
  """
  thunk_type = computation_types.FunctionType(None, tf.float32)
  float_at_clients = computation_types.at_clients(tf.float32)
  intrinsic = intrinsic_defs.FEDERATED_EVAL_AT_CLIENTS
  return intrinsic, computation_types.FunctionType(thunk_type,
                                                   float_at_clients)
def _clipped_sum(clip=2.0):
  """Returns `robust.clipping_factory(clip, SumFactory())` for the tests."""
  return robust.clipping_factory(clip, sum_factory.SumFactory())


def _zeroed_mean(clip=2.0, norm_order=2.0):
  """Returns `robust.zeroing_factory(clip, MeanFactory(), norm_order)`."""
  return robust.zeroing_factory(clip, mean.MeanFactory(), norm_order)


def _zeroed_sum(clip=2.0, norm_order=2.0):
  """Returns `robust.zeroing_factory(clip, SumFactory(), norm_order)`."""
  return robust.zeroing_factory(clip, sum_factory.SumFactory(), norm_order)


# Scalar float32 types placed at the server / at the clients.
_float_at_server = computation_types.at_server(tf.float32)
_float_at_clients = computation_types.at_clients(tf.float32)


@computations.federated_computation()
def _test_init_fn():
  """Test initializer: returns the float `1.` placed at `SERVER`."""
  return intrinsics.federated_value(1., placements.SERVER)


@computations.federated_computation(_float_at_server, _float_at_clients)
def _test_next_fn(state, value):
  """Test `next`: ignores `value` and returns `state + 1.` via `federated_map`."""
  del value
  return intrinsics.federated_map(
      computations.tf_computation(lambda x: x + 1., tf.float32), state)


@computations.federated_computation(_float_at_server)
def create_whimsy_intrinsic_def_federated_value_at_clients():
  """Returns `FEDERATED_VALUE_AT_CLIENTS` paired with a whimsy signature.

  The signature maps a plain `tf.float32` to an all-equal `tf.float32` placed
  at `CLIENTS`.
  """
  all_equal_float_at_clients = computation_types.at_clients(
      tf.float32, all_equal=True)
  return (intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS,
          computation_types.FunctionType(tf.float32,
                                         all_equal_float_at_clients))
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import errors
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.templates import client_works

# Federated types shared by the tests in this file.
SERVER_INT = computation_types.FederatedType(tf.int32, placements.SERVER)
SERVER_FLOAT = computation_types.FederatedType(tf.float32, placements.SERVER)
CLIENTS_FLOAT_SEQUENCE = computation_types.FederatedType(
    computation_types.SequenceType(tf.float32), placements.CLIENTS)
CLIENTS_FLOAT = computation_types.FederatedType(tf.float32, placements.CLIENTS)
CLIENTS_INT = computation_types.FederatedType(tf.int32, placements.CLIENTS)
# `ModelWeights` with a float32 trainable part and empty non-trainable part,
# placed at clients.
MODEL_WEIGHTS_TYPE = computation_types.at_clients(
    computation_types.to_type(model_utils.ModelWeights(tf.float32, ())))
MeasuredProcessOutput = measured_process.MeasuredProcessOutput


def server_zero():
  """Returns the integer `0` placed at `SERVER`."""
  return intrinsics.federated_value(0, placements.SERVER)


def client_one():
  """Returns the float `1.0` placed at `CLIENTS`."""
  return intrinsics.federated_value(1.0, placements.CLIENTS)


def federated_add(a, b):
  """Maps a TF computation computing `x + y` over the pair `(a, b)`."""
  return intrinsics.federated_map(
      tensorflow_computation.tf_computation(lambda x, y: x + y), (a, b))
def create_whimsy_value_at_clients_all_equal():
  """Builds a whimsy `(value, type)` pair: `10.0` as an all-equal client float."""
  all_equal_float_at_clients = computation_types.at_clients(
      tf.float32, all_equal=True)
  return 10.0, all_equal_float_at_clients
def build_model_delta_client_work(model_fn: Callable[[], model_lib.Model],
                                  optimizer: optimizer_base.Optimizer):
  """Builds `ClientWorkProcess` returning change to the trained model weights.

  The created `ClientWorkProcess` expects model weights that can be assigned to
  the model created by `model_fn`, and will apply `optimizer` to optimize the
  model using the client data. The returned `ClientResult` will contain the
  difference between the initial and trained trainable model weights (computed
  as `initial - trained`, aka "model delta") as update. The update_weight will
  be the number of examples used in training, counted from the leading
  dimension of each batch's predictions. The type signature for client data
  is derived from the input spec of the model.

  This method is the recommended starting point for forking a custom
  implementation of the `ClientWorkProcess`.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    optimizer: A `tff.learning.optimizers.Optimizer`.

  Returns:
    A `ClientWorkProcess` with empty (`()`) state and empty measurements, both
    at `SERVER`.

  Raises:
    Whatever `py_typecheck.check_callable` / `py_typecheck.check_type` raise
    (presumably `TypeError`) if `model_fn` is not callable or `optimizer` is
    not an `optimizer_base.Optimizer`.
  """
  py_typecheck.check_callable(model_fn)
  # TODO(b/190334722): Include support for Keras optimizers via
  # tff.learning.optimizers.KerasOptimizer when ready.
  py_typecheck.check_type(optimizer, optimizer_base.Optimizer)
  weights_type, data_type = _weights_and_data_type_from_model_fn(model_fn)
  # TODO(b/161529310): We flatten and convert the trainable specs to tuple, as
  # "for batch in data:" pattern would try to stack the tensors in a list.
  optimizer_tensor_specs = _flat_tuple(
      type_conversions.type_to_tf_tensor_specs(weights_type.trainable))

  @computations.tf_computation(weights_type, data_type)
  @tf.function
  def local_update(initial_weights, data):
    # TODO(b/190334722): Restructure so that model_fn only needs to be invoked
    # once.
    # Build the model's variables outside the traced function.
    with tf.init_scope():
      model = model_fn()
    model_weights = model_utils.ModelWeights.from_model(model)
    # Load the incoming weights into the freshly built model.
    tf.nest.map_structure(lambda weight, value: weight.assign(value),
                          model_weights, initial_weights)
    num_examples = tf.constant(0, tf.int32)
    optimizer_state = optimizer.initialize(optimizer_tensor_specs)

    # TODO(b/161529310): Different from creating an iterator using iter(data).
    for batch in data:
      with tf.GradientTape() as tape:
        outputs = model.forward_pass(batch)
      gradients = tape.gradient(outputs.loss, model_weights.trainable)
      num_examples += tf.shape(outputs.predictions)[0]

      # The optimizer works on flat tuples; restore the original nesting
      # before assigning the new values back into the model variables.
      optimizer_state, updated_weights = optimizer.next(
          optimizer_state, _flat_tuple(model_weights.trainable),
          _flat_tuple(gradients))
      updated_weights = tf.nest.pack_sequence_as(model_weights.trainable,
                                                 updated_weights)
      tf.nest.map_structure(lambda weight, value: weight.assign(value),
                            model_weights.trainable, updated_weights)

    # Delta is `initial - trained`.
    model_delta = tf.nest.map_structure(lambda x, y: x - y,
                                        initial_weights.trainable,
                                        model_weights.trainable)
    return ClientResult(
        update=model_delta, update_weight=tf.cast(num_examples, tf.float32))

  @computations.federated_computation
  def init_fn():
    return intrinsics.federated_value((), placements.SERVER)

  @computations.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(weights_type),
      computation_types.at_clients(data_type))
  def next_fn(state, weights, client_data):
    client_result = intrinsics.federated_map(local_update,
                                             (weights, client_data))
    empty_measurements = intrinsics.federated_value((), placements.SERVER)
    return measured_process.MeasuredProcessOutput(state, client_result,
                                                  empty_measurements)

  return ClientWorkProcess(init_fn, next_fn)
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import keras_utils
from tensorflow_federated.python.learning import model_examples
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.optimizers import sgdm
from tensorflow_federated.python.learning.templates import client_works
from tensorflow_federated.python.learning.templates import composers
from tensorflow_federated.python.learning.templates import distributors
from tensorflow_federated.python.learning.templates import finalizers
from tensorflow_federated.python.learning.templates import learning_process

# Scalar float32 tensor type and derived types used by the tests below.
FLOAT_TYPE = computation_types.TensorType(tf.float32)
# `ModelWeights` with a scalar float32 trainable part and no non-trainable part.
MODEL_WEIGHTS_TYPE = computation_types.to_type(
    model_utils.ModelWeights(FLOAT_TYPE, ()))
CLIENTS_SEQUENCE_FLOAT_TYPE = computation_types.at_clients(
    computation_types.SequenceType(FLOAT_TYPE))


def empty_at_server():
  """Returns an empty tuple placed at `SERVER`."""
  return intrinsics.federated_value((), placements.SERVER)


@federated_computation.federated_computation()
def empty_init_fn():
  """Federated computation returning an empty tuple at `SERVER`."""
  return empty_at_server()


@tensorflow_computation.tf_computation()
def test_init_model_weights_fn():
  """TF computation returning `ModelWeights(trainable=1.0, non_trainable=())`."""
  return model_utils.ModelWeights(
      trainable=tf.constant(1.0), non_trainable=())
def test_init_fn():
  # NOTE(review): `test_init_fn.type_signature` is read below, so this function
  # is presumably decorated as a federated computation on a preceding line.
  return intrinsics.federated_value(0, placements.SERVER)


test_state_type = test_init_fn.type_signature.result


@computations.tf_computation
def sum_sequence(s):
  """Reduces dataset `s` with `tf.add`, starting from zeros of its spec."""
  # Assumes `s.element_spec` is a single tensor spec, since `.shape`/`.dtype`
  # are read from it directly.
  spec = s.element_spec
  return s.reduce(
      tf.zeros(spec.shape, spec.dtype),
      lambda s, t: tf.nest.map_structure(tf.add, s, t))


# Sequences of int32 placed at clients.
ClientIntSequenceType = computation_types.at_clients(
    computation_types.SequenceType(tf.int32))


def build_next_fn(server_init_fn):
  """Builds a `next_fn` that sums each client's sequence, then sums clients.

  The returned computation passes the state through unchanged and reports the
  federated sum of the per-client sequence sums as metrics.
  """

  @computations.federated_computation(server_init_fn.type_signature.result,
                                      ClientIntSequenceType)
  def next_fn(state, client_values):
    metrics = intrinsics.federated_map(sum_sequence, client_values)
    metrics = intrinsics.federated_sum(metrics)
    return LearningProcessOutput(state, metrics)

  return next_fn


def build_report_fn(server_init_fn):
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor,
    arg: executor_value_base.ExecutorValue,
    local_computation_factory: local_computation_factory_base.
    LocalComputationFactory = tensorflow_computation_factory.
    TensorFlowComputationFactory()
) -> executor_value_base.ExecutorValue:
  """Computes a federated weighted mean on the given `executor`.

  The mean is assembled from lower-level intrinsics, in order:

  1. Zip the client values and weights at `CLIENTS`.
  2. Multiply each value by its weight with `federated_map`.
  3. Reduce the products and the weights separately with `federated_sum`.
  4. Zip the two sums at `SERVER`.
  5. Divide them with `federated_apply`.

  Independent steps are scheduled concurrently via `asyncio.gather`.

  Args:
    executor: The executor to use.
    arg: The argument embedded in `executor`. Its type is validated by
      `type_analysis.check_valid_federated_weighted_mean_argument_tuple_type`.
      Element 0 holds the client values and element 1 the client weights.
    local_computation_factory: An instance of `LocalComputationFactory` to use
      to construct local computations used as parameters in certain federated
      operators (such as `tff.federated_sum`, etc.). Defaults to a TensorFlow
      computation factory that generates TensorFlow code. Only the multiply
      operator is built with it; the divide operator is always built via
      `building_block_factory.create_tensorflow_binary_operator_with_upcast`.
      NOTE(review): the default instance is created once, at function
      definition time, and is shared by all calls that omit this argument.

  Returns:
    The result embedded in `executor`.
  """
  type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
      arg.type_signature)
  # <values@CLIENTS, weights@CLIENTS> -> <value, weight>@CLIENTS
  zip1_type = computation_types.FunctionType(
      computation_types.StructType([
          computation_types.at_clients(arg.type_signature[0].member),
          computation_types.at_clients(arg.type_signature[1].member)
      ]),
      computation_types.at_clients(
          computation_types.StructType(
              [arg.type_signature[0].member, arg.type_signature[1].member])))

  operand_type = zip1_type.result.member[0]
  scalar_type = zip1_type.result.member[1]
  multiply_comp_pb, multiply_comp_type = local_computation_factory.create_scalar_multiply_operator(
      operand_type, scalar_type)
  multiply_blk = building_blocks.CompiledComputation(
      multiply_comp_pb, type_signature=multiply_comp_type)
  # federated_map(multiply, <value, weight>@CLIENTS) -> product@CLIENTS
  map_type = computation_types.FunctionType(
      computation_types.StructType(
          [multiply_blk.type_signature, zip1_type.result]),
      computation_types.at_clients(multiply_blk.type_signature.result))

  # Sum of products: product@CLIENTS -> product@SERVER
  sum1_type = computation_types.FunctionType(
      computation_types.at_clients(map_type.result.member),
      computation_types.at_server(map_type.result.member))

  # Total weight: weight@CLIENTS -> weight@SERVER
  sum2_type = computation_types.FunctionType(
      computation_types.at_clients(arg.type_signature[1].member),
      computation_types.at_server(arg.type_signature[1].member))

  # <sum1@SERVER, sum2@SERVER> -> <sum1, sum2>@SERVER
  zip2_type = computation_types.FunctionType(
      computation_types.StructType([sum1_type.result, sum2_type.result]),
      computation_types.at_server(
          computation_types.StructType(
              [sum1_type.result.member, sum2_type.result.member])))

  divide_blk = building_block_factory.create_tensorflow_binary_operator_with_upcast(
      zip2_type.result.member, tf.divide)

  async def _compute_multiply_fn():
    return await executor.create_value(multiply_blk.proto,
                                       multiply_blk.type_signature)

  async def _compute_multiply_arg():
    zip1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS,
                                      zip1_type)
    zip_fn = await executor.create_value(zip1_comp, zip1_type)
    return await executor.create_call(zip_fn, arg)

  async def _compute_product_fn():
    map_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_MAP, map_type)
    return await executor.create_value(map_comp, map_type)

  async def _compute_product_arg():
    multiply_fn, multiply_arg = await asyncio.gather(_compute_multiply_fn(),
                                                     _compute_multiply_arg())
    return await executor.create_struct((multiply_fn, multiply_arg))

  async def _compute_products():
    product_fn, product_arg = await asyncio.gather(_compute_product_fn(),
                                                   _compute_product_arg())
    return await executor.create_call(product_fn, product_arg)

  async def _compute_total_weight():
    sum2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum2_type)
    # Element 1 of `arg` is the client weights.
    sum2_fn, sum2_arg = await asyncio.gather(
        executor.create_value(sum2_comp, sum2_type),
        executor.create_selection(arg, 1))
    return await executor.create_call(sum2_fn, sum2_arg)

  async def _compute_sum_of_products():
    sum1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM, sum1_type)
    sum1_fn, products = await asyncio.gather(
        executor.create_value(sum1_comp, sum1_type), _compute_products())
    return await executor.create_call(sum1_fn, products)

  async def _compute_zip2_fn():
    zip2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_ZIP_AT_SERVER,
                                      zip2_type)
    return await executor.create_value(zip2_comp, zip2_type)

  async def _compute_zip2_arg():
    sum_of_products, total_weight = await asyncio.gather(
        _compute_sum_of_products(), _compute_total_weight())
    return await executor.create_struct([sum_of_products, total_weight])

  async def _compute_divide_fn():
    return await executor.create_value(divide_blk.proto,
                                       divide_blk.type_signature)

  async def _compute_divide_arg():
    zip_fn, zip_arg = await asyncio.gather(_compute_zip2_fn(),
                                           _compute_zip2_arg())
    return await executor.create_call(zip_fn, zip_arg)

  async def _compute_apply_fn():
    # federated_apply(divide, <sum1, sum2>@SERVER) -> quotient@SERVER
    apply_type = computation_types.FunctionType(
        computation_types.StructType(
            [divide_blk.type_signature, zip2_type.result]),
        computation_types.at_server(divide_blk.type_signature.result))
    apply_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY,
                                       apply_type)
    return await executor.create_value(apply_comp, apply_type)

  async def _compute_apply_arg():
    divide_fn, divide_arg = await asyncio.gather(_compute_divide_fn(),
                                                 _compute_divide_arg())
    return await executor.create_struct([divide_fn, divide_arg])

  async def _compute_divided():
    apply_fn, apply_arg = await asyncio.gather(_compute_apply_fn(),
                                               _compute_apply_arg())
    return await executor.create_call(apply_fn, apply_arg)

  return await _compute_divided()
def test_errors_on_client_int(self):
  """`federated_broadcast` raises `TypeError` for a client-placed int32."""
  # Build the input outside `assertRaises` so that only the broadcast call can
  # satisfy the expected `TypeError`. A failure while mocking the input then
  # surfaces as a test error instead of a false pass.
  x = _mock_data_of_type(
      computation_types.at_clients(tf.int32, all_equal=True))
  with self.assertRaises(TypeError):
    intrinsics.federated_broadcast(x)