def assign_and_compute():
  """Resets the model weights, then runs the baseline evaluation.

  This function takes no arguments; `model`, `model_weights`,
  `initial_model_weights`, `baseline_evaluate_fn` and `test_data` must be
  provided by the surrounding scope.

  Returns:
    The value returned by `baseline_evaluate_fn(model, test_data)`.
  """

  def _overwrite(variable, tensor):
    # Each weight variable is paired with the matching initial tensor.
    return variable.assign(tensor)

  tf.nest.map_structure(_overwrite, model_weights, initial_model_weights)
  py_typecheck.check_callable(baseline_evaluate_fn)
  baseline_metrics = baseline_evaluate_fn(model, test_data)
  return baseline_metrics
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]
) -> Tuple[computation_types.Type, bool]:
  """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Children are transformed before their parent. A parent node is rebuilt only
  when at least one of its children reported a mutation; `transform_fn` is then
  applied to the (possibly rebuilt) parent.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Transformation function to apply to each node in the type
      tree of `type_signature`. Must be callable, and must return a pair
      `(type, mutated)` where `mutated` is a `bool` indicating whether the
      returned type differs from its argument.

  Returns:
    A tuple `(transformed_type, mutated)`. `transformed_type` is a possibly
    transformed version of `type_signature`, with each node in its tree the
    result of applying `transform_fn` to the corresponding node in
    `type_signature`. `mutated` is true if `transform_fn` reported a mutation
    for this node or for any of its descendants.

  Raises:
    TypeError: If the types don't match the specification above, or if
      `type_signature` is of a kind this function does not know how to
      traverse.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)
  py_typecheck.check_callable(transform_fn)
  if type_signature.is_federated():
    transformed_member, member_mutated = transform_type_postorder(
        type_signature.member, transform_fn)
    if member_mutated:
      type_signature = computation_types.FederatedType(
          transformed_member, type_signature.placement,
          type_signature.all_equal)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or member_mutated
  elif type_signature.is_sequence():
    transformed_element, element_mutated = transform_type_postorder(
        type_signature.element, transform_fn)
    if element_mutated:
      type_signature = computation_types.SequenceType(transformed_element)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or element_mutated
  elif type_signature.is_function():
    # A function type may have no parameter; there is nothing to walk then.
    if type_signature.parameter is not None:
      transformed_parameter, parameter_mutated = transform_type_postorder(
          type_signature.parameter, transform_fn)
    else:
      transformed_parameter, parameter_mutated = (None, False)
    transformed_result, result_mutated = transform_type_postorder(
        type_signature.result, transform_fn)
    if parameter_mutated or result_mutated:
      type_signature = computation_types.FunctionType(transformed_parameter,
                                                      transformed_result)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, (
        type_signature_mutated or parameter_mutated or result_mutated)
  elif type_signature.is_struct():
    elements = []
    elements_mutated = False
    for element in structure.iter_elements(type_signature):
      transformed_element, element_mutated = transform_type_postorder(
          element[1], transform_fn)
      elements_mutated = elements_mutated or element_mutated
      elements.append((element[0], transformed_element))
    if elements_mutated:
      # Keep the Python container annotation when the struct carries one.
      if type_signature.is_struct_with_python():
        type_signature = computation_types.StructWithPythonType(
            elements, type_signature.python_container)
      else:
        type_signature = computation_types.StructType(elements)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or elements_mutated
  elif (type_signature.is_abstract() or type_signature.is_placement() or
        type_signature.is_tensor()):
    # Leaf kinds: no children to visit.
    return transform_fn(type_signature)
  else:
    # Without this branch the function would fall through and return `None`,
    # which callers (including the recursive calls above) cannot unpack.
    raise TypeError(
        'Unable to transform unsupported type {} of Python type {}.'.format(
            type_signature, type(type_signature)))
def add_measurements(
    inner_agg_factory: factory.AggregationFactory,
    measurement_fn: Callable[..., Dict[str, Any]],
) -> factory.AggregationFactory:
  """Wraps `AggregationFactory` to report additional measurements.

  The function `measurement_fn` is a python callable that will be called on
  `value` (if `inner_agg_factory` is an `UnweightedAggregationFactory`) or
  `(value, weight)` (if `inner_agg_factory` is a `WeightedAggregationFactory`)
  in the `next` function of the `AggregationProcess` produced by the returned
  factory to generate additional measurements. It must be traceable by TFF and
  expect `tff.Value` objects placed at `CLIENTS` as inputs, and return
  `collections.OrderedDict`s mapping string names to tensor values placed at
  `SERVER`, which will be added to the measurement dict produced by the
  `inner_agg_factory`.

  Args:
    inner_agg_factory: The factory to wrap and add measurements.
    measurement_fn: A python callable that will be called on `value` (and/or
      `weight`) provided to the `next` function to compute additional
      measurements.

  Returns:
    An `AggregationFactory` that reports additional measurements. It is a
    `WeightedAggregationFactory` if `inner_agg_factory` is one, and an
    `UnweightedAggregationFactory` otherwise.

  Raises:
    ValueError: If the number of parameters of `measurement_fn` does not match
      the kind of `inner_agg_factory` (one for unweighted, two for weighted).
    TypeError: If `measurement_fn` is not callable, or `inner_agg_factory` is
      neither an `UnweightedAggregationFactory` nor a
      `WeightedAggregationFactory`.
  """
  py_typecheck.check_callable(measurement_fn)

  if isinstance(inner_agg_factory, factory.UnweightedAggregationFactory):
    if len(inspect.signature(measurement_fn).parameters) != 1:
      raise ValueError('`measurement_fn` must take a single parameter if '
                       '`inner_agg_factory` is unweighted.')
  elif isinstance(inner_agg_factory, factory.WeightedAggregationFactory):
    if len(inspect.signature(measurement_fn).parameters) != 2:
      raise ValueError('`measurement_fn` must take a two parameters if '
                       '`inner_agg_factory` is weighted.')
  else:
    # The two fragments previously joined as "or`Weighted...`" because the
    # separating space was missing.
    raise TypeError(
        f'`inner_agg_factory` must be of type `UnweightedAggregationFactory` '
        f'or `WeightedAggregationFactory`. Found {type(inner_agg_factory)}.')

  @computations.tf_computation()
  def dict_update(orig_dict, new_values):
    # An empty set of inner measurements is simply replaced by the new ones.
    if not orig_dict:
      return new_values
    orig_dict.update(new_values)
    return orig_dict

  if isinstance(inner_agg_factory, factory.WeightedAggregationFactory):

    class WeightedWrappedFactory(factory.WeightedAggregationFactory):
      """Wrapper for `WeightedAggregationFactory` that adds new measurements."""

      def create(
          self, value_type: factory.ValueType, weight_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        py_typecheck.check_type(value_type, factory.ValueType.__args__)
        py_typecheck.check_type(weight_type, factory.ValueType.__args__)

        inner_agg_process = inner_agg_factory.create(value_type, weight_type)
        # The wrapper adds no state of its own; initialization is delegated.
        init_fn = inner_agg_process.initialize

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type),
            computation_types.at_clients(weight_type))
        def next_fn(state, value, weight):
          inner_agg_output = inner_agg_process.next(state, value, weight)
          extra_measurements = measurement_fn(value, weight)
          measurements = intrinsics.federated_map(
              dict_update,
              (inner_agg_output.measurements, extra_measurements))
          return measured_process.MeasuredProcessOutput(
              state=inner_agg_output.state,
              result=inner_agg_output.result,
              measurements=measurements)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return WeightedWrappedFactory()
  else:

    class UnweightedWrappedFactory(factory.UnweightedAggregationFactory):
      """Wrapper for `UnweightedAggregationFactory` that adds new measurements."""

      def create(
          self, value_type: factory.ValueType
      ) -> aggregation_process.AggregationProcess:
        py_typecheck.check_type(value_type, factory.ValueType.__args__)

        inner_agg_process = inner_agg_factory.create(value_type)
        # The wrapper adds no state of its own; initialization is delegated.
        init_fn = inner_agg_process.initialize

        @computations.federated_computation(
            init_fn.type_signature.result,
            computation_types.at_clients(value_type))
        def next_fn(state, value):
          inner_agg_output = inner_agg_process.next(state, value)
          extra_measurements = measurement_fn(value)
          measurements = intrinsics.federated_map(
              dict_update,
              (inner_agg_output.measurements, extra_measurements))
          return measured_process.MeasuredProcessOutput(
              state=inner_agg_output.state,
              result=inner_agg_output.result,
              measurements=measurements)

        return aggregation_process.AggregationProcess(init_fn, next_fn)

    return UnweightedWrappedFactory()
def _client_fn(model, initial_model_weights, train_data, test_data,
               personalize_fn_dict, baseline_evaluate_fn, context=None):
  """The main `tf.function` that runs on device.

  The initial model is evaluated first to obtain the baseline metrics. Then,
  for every personalization strategy in `personalize_fn_dict`, the model is
  reset to the same initial weights, the strategy is built and run, and its
  evaluation metrics are recorded.

  Args:
    model: A `tff.learning.Model`.
    initial_model_weights: A `tff.learning.framework.ModelWeights` containing
      `tf.Tensor`s that hold trainable and non-trainable weights.
    train_data: A `tf.data.Dataset` used for training.
    test_data: A `tf.data.Dataset` used for evaluation.
    personalize_fn_dict: This is the same argument specified in the function
      `build_personalization_eval`; see its documentation for details.
    baseline_evaluate_fn: This is the same argument specified in the function
      `build_personalization_eval`; see its documentation for details.
    context: An optional object used in `personalize_fn_dict`. If used, its
      `tff.Type` must be provided by passing the correct `context_tff_type`
      argument to the `build_personalization_eval` function.

  Returns:
    An `OrderedDict` that maps the string 'baseline_metrics' to the evaluation
    metrics of the initial model (computed by `baseline_evaluate_fn`), and
    maps keys (strategy names) in `personalize_fn_dict` to the evaluation
    metrics of the corresponding personalization strategies.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in
      `personalize_fn_dict`.
  """
  # An `EnhancedModel` exposes the model weights through `.weights`.
  enhanced_model = model_utils.enhance(model)

  def _restore_initial_weights():
    tff.utils.assign(enhanced_model.weights, initial_model_weights)

  # Baseline: evaluate the untouched initial model.
  _restore_initial_weights()
  py_typecheck.check_callable(baseline_evaluate_fn)
  metrics_by_strategy = collections.OrderedDict()
  metrics_by_strategy['baseline_metrics'] = baseline_evaluate_fn(
      enhanced_model, test_data)

  py_typecheck.check_type(personalize_fn_dict, collections.OrderedDict)
  if 'baseline_metrics' in personalize_fn_dict:
    raise ValueError('baseline_metrics should not be used as a key in '
                     'personalize_fn_dict.')

  for strategy_name, build_personalize_fn in personalize_fn_dict.items():
    py_typecheck.check_type(strategy_name, str)
    # Every strategy starts from the same initial weights.
    _restore_initial_weights()
    # The `personalize_fn` (and its `tf.Variable`s) is built here, inside this
    # function, rather than ahead of time. Once `_client_fn` is decorated with
    # `tff.tf_computation`, the variables are then created in a scope that TFF
    # controls. This is why `personalize_fn_dict` holds no-argument builders
    # instead of already built `tf.function`s.
    py_typecheck.check_callable(build_personalize_fn)
    personalize_fn = build_personalize_fn()
    py_typecheck.check_callable(personalize_fn)
    metrics_by_strategy[strategy_name] = personalize_fn(
        enhanced_model, train_data, test_data, context)

  return metrics_by_strategy
def serialize_jax_computation(traced_fn, arg_fn, parameter_type, context_stack):
  """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    traced_fn: The Python function containing JAX code to be traced by JAX and
      serialized as a TFF computation containing XLA code.
    arg_fn: An unpacking function that takes a TFF argument, and returns a
      combo of (args, kwargs) to invoke `traced_fn` with (e.g., as the one
      constructed by `function_utils.create_argument_unpacking_fn`).
    parameter_type: An instance of `computation_types.Type` that represents
      the TFF type of the computation parameter, or `None` if the function
      does not take any parameters.
    context_stack: The context stack to use during serialization.

  Returns:
    An instance of `pb.Computation` with the constructed computation.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
  py_typecheck.check_callable(traced_fn)
  py_typecheck.check_callable(arg_fn)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

  packed_arg = None
  if parameter_type is not None:
    parameter_type = computation_types.to_type(parameter_type)
    packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
  args, kwargs = arg_fn(packed_arg)

  # The fake parameters are fed through args/kwargs during serialization, but
  # they may get reordered in the XLA code that is actually generated. The same
  # flattening function that the JAX serializer uses is applied here to
  # determine that ordering, so it can be captured in the parameter binding.
  # Results need no special treatment: multiple results are always returned as
  # a tuple.
  leaves, _ = jax.tree_util.tree_flatten((args, kwargs))
  leaf_tensor_indexes = [leaf.tensor_index for leaf in leaves]
  tensor_indexes = list(np.argsort(leaf_tensor_indexes))

  def _to_py_container(value):
    """Converts `structure.Struct` values to Python containers."""
    if not isinstance(value, structure.Struct):
      return value
    return type_conversions.type_to_py_container(value, value.type_signature)

  args = [_to_py_container(arg) for arg in args]
  kwargs = {key: _to_py_container(val) for key, val in kwargs.items()}

  jax_context = jax_computation_context.JaxComputationContext()
  with context_stack.install(jax_context):
    trace = jax.xla_computation(traced_fn, tuple_args=True, return_shape=True)
    compiled_xla, returned_shape = trace(*args, **kwargs)

  if isinstance(returned_shape, jax.ShapeDtypeStruct):
    returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
  else:
    shape_struct = structure.from_container(returned_shape, recursive=True)
    returned_type_spec = computation_types.to_type(
        structure.map_structure(_jax_shape_dtype_struct_to_tff_tensor,
                                shape_struct))

  computation_type = computation_types.FunctionType(parameter_type,
                                                    returned_type_spec)
  return xla_serialization.create_xla_tff_computation(compiled_xla,
                                                      tensor_indexes,
                                                      computation_type)
def create_binary_operator_with_upcast(
    type_signature: computation_types.StructType,
    operator: Callable[[Any, Any], Any]) -> ProtoAndType:
  """Creates TF computation upcasting its argument and applying `operator`.

  Args:
    type_signature: A `computation_types.StructType` with two elements, both
      only containing structs or tensors in their type tree. The first and
      second element must match in structure, or the second element may be a
      single tensor type that is broadcasted (upcast) to the leaves of the
      structure of the first type.
    operator: Callable defining the operator.

  Returns:
    A `building_blocks.CompiledComputation` encapsulating a function which
    upcasts the second element of its argument and applies the binary
    operator.

  Raises:
    TypeError: If `type_signature` does not have exactly two elements,
      contains a sequence type, pairs two structs of different structure, or
      has a first element that is neither a tensor nor a struct.
  """
  py_typecheck.check_type(type_signature, computation_types.StructType)
  py_typecheck.check_callable(operator)
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if not (type_signature.is_struct() and len(type_signature) == 2):
    raise TypeError(
        f'To apply a binary operator, we must by definition have an '
        f'argument which is a `StructType` with 2 elements; '
        f'asked to create a binary operator for type: {type_signature}')
  if type_analysis.contains(type_signature, lambda t: t.is_sequence()):
    raise TypeError(
        f'Applying binary operators in TensorFlow is only '
        f'supported on Tensors and StructTypes; you '
        f'passed {type_signature} which contains a SequenceType.')

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if type_spec.is_struct():
      packed_elements = [
          (elem_name, _pack_into_type(to_pack, elem_type))
          for elem_name, elem_type in structure.iter_elements(type_spec)
      ]
      return structure.Struct(packed_elements)
    if type_spec.is_tensor():
      return tf.broadcast_to(to_pack, type_spec.shape)
    return None

  first_type = type_signature[0]
  second_type = type_signature[1]

  with tf.Graph().as_default() as graph:
    first_arg, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', first_type, graph)
    operand_2_value, operand_2_binding = (
        tensorflow_utils.stamp_parameter_in_graph('y', second_type, graph))

    both_structs = first_type.is_struct() and second_type.is_struct()
    if both_structs and not structure.is_same_structure(
        first_type, second_type):
      raise TypeError(
          f'Cannot upcast one structure to a different structure. '
          f'{second_type} -> {first_type}')
    if both_structs or first_type.is_equivalent_to(second_type):
      # The second operand already lines up with the first one; the operator
      # is mapped pointwise over the leaves further below.
      second_arg = operand_2_value
    else:
      second_arg = _pack_into_type(operand_2_value, first_type)

    if first_type.is_tensor():
      result_value = operator(first_arg, second_arg)
    elif first_type.is_struct():
      result_value = structure.map_structure(operator, first_arg, second_arg)
    else:
      raise TypeError(
          f'Encountered unexpected type {first_type}; can only handle Tensor '
          f'and StructTypes.')

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  function_type = computation_types.FunctionType(type_signature, result_type)
  operand_bindings = [operand_1_binding, operand_2_binding]
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(element=operand_bindings))
  packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
  tensorflow = pb.TensorFlow(
      graph_def=packed_graph_def,
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, function_type)
def create_binary_operator_with_upcast(
    type_signature: computation_types.NamedTupleType,
    operator: Callable[[Any, Any], Any]) -> pb.Computation:
  """Creates TF computation upcasting its argument and applying `operator`.

  Args:
    type_signature: A `computation_types.NamedTupleType` with two elements,
      both of the same type or the second able to be upcast to the first, as
      explained in `apply_binary_operator_with_upcast`, and both containing
      only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A `pb.Computation` holding a TensorFlow function which upcasts the second
    element of its argument and applies the binary operator.

  Raises:
    TypeError: If `type_signature` does not have exactly two elements,
      contains a sequence type, or has a first element that is neither a
      tensor nor a tuple.
  """
  py_typecheck.check_type(type_signature, computation_types.NamedTupleType)
  py_typecheck.check_callable(operator)
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if not (type_signature.is_tuple() and len(type_signature) == 2):
    wrong_arity_message = (
        'To apply a binary operator, we must by definition have an '
        'argument which is a `NamedTupleType` with 2 elements; '
        'asked to create a binary operator for type: {t}')
    raise TypeError(wrong_arity_message.format(t=type_signature))
  if type_analysis.contains(type_signature, lambda t: t.is_sequence()):
    sequence_message = (
        'Applying binary operators in TensorFlow is only '
        'supported on Tensors and NamedTupleTypes; you '
        'passed {t} which contains a SequenceType.')
    raise TypeError(sequence_message.format(t=type_signature))

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if type_spec.is_tuple():
      packed_elements = [
          (elem_name, _pack_into_type(to_pack, elem_type))
          for elem_name, elem_type in anonymous_tuple.iter_elements(type_spec)
      ]
      return anonymous_tuple.AnonymousTuple(packed_elements)
    if type_spec.is_tensor():
      return tf.broadcast_to(to_pack, type_spec.shape)
    return None

  first_type = type_signature[0]
  second_type = type_signature[1]

  with tf.Graph().as_default() as graph:
    first_arg, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', first_type, graph)
    operand_2_value, operand_2_binding = (
        tensorflow_utils.stamp_parameter_in_graph('y', second_type, graph))

    # Only upcast when the two operand types are not already equivalent.
    if first_type.is_equivalent_to(second_type):
      second_arg = operand_2_value
    else:
      second_arg = _pack_into_type(operand_2_value, first_type)

    if first_type.is_tensor():
      result_value = operator(first_arg, second_arg)
    elif first_type.is_tuple():
      result_value = anonymous_tuple.map_structure(operator, first_arg,
                                                   second_arg)
    else:
      unexpected_message = (
          'Encountered unexpected type {t}; can only handle Tensor '
          'and NamedTupleTypes.')
      raise TypeError(unexpected_message.format(t=first_type))

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  function_type = computation_types.FunctionType(type_signature, result_type)
  operand_bindings = [operand_1_binding, operand_2_binding]
  parameter_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(element=operand_bindings))
  packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
  tensorflow = pb.TensorFlow(
      graph_def=packed_graph_def,
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(function_type),
      tensorflow=tensorflow)
def build_weighted_mime_lite(
    model_fn: Callable[[], model_lib.Model],
    base_optimizer: optimizer_base.Optimizer,
    server_optimizer: optimizer_base.Optimizer = sgdm.build_sgdm(1.0),
    client_weighting: Optional[
        client_weight_lib.
        ClientWeighting] = client_weight_lib.ClientWeighting.NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    full_gradient_aggregator: Optional[
        factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
  """Builds a learning process that performs Mime Lite.

  The result is a `tff.learning.templates.LearningProcess` that runs the Mime
  Lite algorithm on client models. It exposes the following methods inherited
  from `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client
      datasets. The output `L` contains the updated server state, as well as
      aggregated metrics at the server, including client training metrics and
      any other metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose
      type matches the output of `initialize` and `next`, and `M` represents
      the type of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  On every call to `next`, the server model is sent to each client through
  `model_distributor`. Each client trains locally with `base_optimizer`, whose
  state is supplied by the server and is left unchanged during local training.
  That optimizer state is updated only at the server, using the full gradient
  that the clients evaluate at the current server model; those client full
  gradients are combined by the weighted `full_gradient_aggregator`. Each
  client also computes the difference between its model after training and its
  initial model, and these model deltas are combined by the weighted
  `model_aggregator`. Both aggregations are weighted according to
  `client_weighting`. The aggregate model delta is then added to the existing
  server model state.

  The Mime Lite algorithm is based on the paper
  "Breaking the centralized barrier for cross-device federated learning."
    Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank
    Reddi, Sebastian U. Stich, and Ananda Theertha Suresh.
    Advances in Neural Information Processing Systems 34 (2021).
    https://proceedings.neurips.cc/paper/2021/file/f0e6be4ce76ccfa73c5a540d992d0756-Paper.pdf

  Note that Keras optimizers are not supported. Mime Lite applies the optimizer
  at clients without changing its state (the optimizer's `tf.Variable`s in the
  case of Keras), which is not possible with Keras optimizers without reaching
  into private implementation details and incurring additional computation and
  memory cost at clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.
    base_optimizer: A `tff.learning.optimizers.Optimizer` which will be used
      for both creating and updating a global optimizer state, as well as
      optimization at clients given the global state, which is fixed during
      the optimization.
    server_optimizer: A `tff.learning.optimizers.Optimizer` which will be used
      for applying the aggregate model update to the global model weights.
    client_weighting: A member of `tff.learning.ClientWeighting` that
      specifies a built-in weighting method. By default, weighting by number
      of examples is used.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set
      to `tff.aggregators.MeanFactory`.
    full_gradient_aggregator: An optional
      `tff.aggregators.WeightedAggregationFactory` used to aggregate the full
      gradients on client datasets. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
      TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
      and returns a `tff.Computation` for aggregating the unfinalized metrics.
      `None` is forwarded unchanged to the client work builder; the intended
      default is `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `tff.learning.templates.LearningProcess`.
  """
  py_typecheck.check_callable(model_fn)
  for optimizer in (base_optimizer, server_optimizer):
    py_typecheck.check_type(optimizer, optimizer_base.Optimizer)
  py_typecheck.check_type(client_weighting, client_weight_lib.ClientWeighting)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  # Fall back to a plain broadcast when no distributor is supplied.
  if model_distributor is None:
    model_distributor = distributors.build_broadcast_process(
        model_weights_type)

  # Model deltas are aggregated over the trainable weights, with a float32
  # scalar weight per client.
  if model_aggregator is None:
    model_aggregator = mean.MeanFactory()
  py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory)
  weight_type = computation_types.TensorType(tf.float32)
  model_aggregation_process = model_aggregator.create(
      model_weights_type.trainable, weight_type)

  if full_gradient_aggregator is None:
    full_gradient_aggregator = mean.MeanFactory()
  py_typecheck.check_type(full_gradient_aggregator,
                          factory.WeightedAggregationFactory)

  client_work = _build_mime_lite_client_work(
      model_fn=model_fn,
      optimizer=base_optimizer,
      client_weighting=client_weighting,
      full_gradient_aggregator=full_gradient_aggregator,
      metrics_aggregator=metrics_aggregator,
      use_experimental_simulation_loop=use_experimental_simulation_loop)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer, model_weights_type)

  return composers.compose_learning_process(
      initial_model_weights_fn,
      model_distributor,
      client_work,
      model_aggregation_process,
      finalizer)
def build_model_delta_optimizer_process(
    model_fn: _ModelConstructor,
    model_to_client_delta_fn: Callable[[model_lib.Model], ClientDeltaFn],
    server_optimizer_fn: _OptimizerConstructor,
    stateful_delta_aggregate_fn: tff.utils.
    StatefulAggregateFn = build_stateless_mean(),
    stateful_model_broadcast_fn: tff.utils.
    StatefulBroadcastFn = build_stateless_broadcaster(),
) -> tff.templates.IterativeProcess:
  """Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  This wires together the common server logic, which applies aggregated model
  deltas to the server model, with a `ClientDeltaFn` that specifies how
  `weight_deltas` are computed on device.

  Note: Functions are passed in rather than constructed objects so that any
  variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: a no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: a function from a `model_fn` to a
      `ClientDeltaFn`.
    server_optimizer_fn: a no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client
      updates to the server model.
    stateful_delta_aggregate_fn: a `tff.utils.StatefulAggregateFn` where the
      `next_fn` performs a federated aggregation and updates state. That is,
      it has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
      (state@SERVER, aggregate@SERVER)`, where the `value` type is
      `tff.learning.framework.ModelWeights.trainable` corresponding to the
      object returned by `model_fn`.
    stateful_model_broadcast_fn: a `tff.utils.StatefulBroadcastFn` where the
      `next_fn` performs a federated broadcast and updates state. That is, it
      has TFF type `(state@SERVER, value@SERVER) -> (state@SERVER,
      value@CLIENTS)`, where the `value` type is
      `tff.learning.framework.ModelWeights` corresponding to the object
      returned by `model_fn`.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  for fn in (model_fn, model_to_client_delta_fn, server_optimizer_fn):
    py_typecheck.check_callable(fn)
  py_typecheck.check_type(stateful_delta_aggregate_fn,
                          tff.utils.StatefulAggregateFn)
  py_typecheck.check_type(stateful_model_broadcast_fn,
                          tff.utils.StatefulBroadcastFn)

  # Keyword arguments common to both the initialize and the round builders.
  shared_kwargs = dict(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      delta_aggregate_fn=stateful_delta_aggregate_fn,
      model_broadcast_fn=stateful_model_broadcast_fn)

  initialize_computation = _build_initialize_computaiton(**shared_kwargs)

  # The round computation needs the types of the aggregation and broadcast
  # state; they are read off the server state type produced by `initialize`.
  server_state_type = initialize_computation.type_signature.result.member
  run_one_round_computation = _build_one_round_computation(
      model_to_client_delta_fn=model_to_client_delta_fn,
      delta_aggregate_state_type=server_state_type.delta_aggregate_state,
      model_broadcast_state_type=server_state_type.model_broadcast_state,
      **shared_kwargs)

  return tff.templates.IterativeProcess(
      initialize_fn=initialize_computation,
      next_fn=run_one_round_computation)
def build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn,
                                        server_optimizer_fn):
  """Constructs `tff.utils.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a
      `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client
      updates to the server model.

  Returns:
    A `tff.utils.IterativeProcess` whose `initialize_fn` is the server
    initialization computation and whose `next_fn` runs one round.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(model_to_client_delta_fn)
  py_typecheck.check_callable(server_optimizer_fn)

  # TODO(b/122081673): would be nice not to have to construct a throwaway
  # model here just to get the types. After fully moving to TF2.0 and
  # eager-mode, we should re-evaluate what happens here and where `g` is used
  # below.
  with tf.Graph().as_default() as g:
    dummy_model_for_metadata = model_utils.enhance(model_fn())

  @tff.federated_computation
  def server_init_tff():
    """Orchestration logic for server model initialization."""
    no_arg_server_init_fn = lambda: server_init(model_fn, server_optimizer_fn)
    server_init_tf = tff.tf_computation(no_arg_server_init_fn)
    return tff.federated_value(server_init_tf(), tff.SERVER)

  # The server state type is read off the traced initialization computation
  # and reused to declare the parameter types of the computations below.
  federated_server_state_type = server_init_tff.type_signature.result
  server_state_type = federated_server_state_type.member

  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round_tff(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    model_weights_type = federated_server_state_type.member.model

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_delta_tf(tf_dataset, initial_model_weights):
      """Performs client local model optimization.

      Args:
        tf_dataset: a `tf.data.Dataset` that provides training examples.
        initial_model_weights: a `model_utils.ModelWeights` containing the
          starting weights.

      Returns:
        A `ClientOutput` structure.
      """
      client_delta_fn = model_to_client_delta_fn(model_fn)

      # TODO(b/123092620): this can be removed once AnonymousTuple works with
      # tf.contrib.framework.nest, or the following behavior is moved to
      # anonymous_tuple module.
      if isinstance(initial_model_weights, anonymous_tuple.AnonymousTuple):
        initial_model_weights = model_utils.ModelWeights.from_tff_value(
            initial_model_weights)

      client_output = client_delta_fn(tf_dataset, initial_model_weights)
      return client_output

    # Each client receives its own dataset plus a broadcast copy of the
    # current server model weights.
    client_outputs = tff.federated_map(
        client_delta_tf,
        (federated_dataset, tff.federated_broadcast(server_state.model)))

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_model_tf(server_state, model_delta):
      """Converts args to correct python types and calls server_update_model."""
      # We need to convert TFF types to the types server_update_model expects.
      # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
      # pretty, fold this into anonymous_tuple module or get working with
      # tf.contrib.framework.nest.
      py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
      model_delta = anonymous_tuple.to_odict(model_delta)
      py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
      server_state = ServerState(
          model=model_utils.ModelWeights.from_tff_value(server_state.model),
          optimizer_state=list(server_state.optimizer_state))

      return server_update_model(
          server_state,
          model_delta,
          model_fn=model_fn,
          optimizer_fn=server_optimizer_fn)

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
    py_typecheck.check_type(fed_weight_type, tff.TensorType)
    if fed_weight_type.dtype.is_integer:
      # Integer client weights are cast to float32 before being used as the
      # weight of the federated average.

      @tff.tf_computation(fed_weight_type)
      def _cast_to_float(x):
        return tf.cast(x, tf.float32)

      weight_denom = tff.federated_map(_cast_to_float,
                                       client_outputs.weights_delta_weight)
    else:
      weight_denom = client_outputs.weights_delta_weight
    round_model_delta = tff.federated_average(
        client_outputs.weights_delta, weight=weight_denom)

    # TODO(b/123408447): remove tff.federated_apply and call
    # server_update_model_tf directly once T <-> T@SERVER isomorphism is
    # supported.
    server_state = tff.federated_apply(server_update_model_tf,
                                       (server_state, round_model_delta))

    # Re-use graph used to construct `model`, since it has the variables, which
    # need to be read in federated_output_computation to get the correct shapes
    # and types for the federated aggregation.
    with g.as_default():
      aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
          client_outputs.model_output)

    # Promote the FederatedType outside the NamedTupleType
    aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  return tff.utils.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round_tff)
def _build_mime_lite_client_work(
    model_fn: Callable[[], model_lib.Model],
    optimizer: optimizer_base.Optimizer,
    client_weighting: client_weight_lib.ClientWeighting,
    full_gradient_aggregator: Optional[
        factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
  """Builds the Mime Lite `ClientWorkProcess`.

  The process state is a pair of (global optimizer state, full-gradient
  aggregator state). In each round the optimizer state is broadcast to the
  clients, clients run their update, the full gradients they report are
  aggregated, and the aggregate is used to advance the optimizer state.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.
    optimizer: A `tff.learning.optimizers.Optimizer` used both to create and
      update the global optimizer state and for optimization at clients given
      that (fixed) global state.
    client_weighting: A member of `tff.learning.ClientWeighting` that
      specifies a built-in weighting method.
    full_gradient_aggregator: An optional
      `tff.aggregators.WeightedAggregationFactory` used to aggregate the full
      gradients on client datasets. If `None`, `tff.aggregators.MeanFactory`
      is used.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
      TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
      and returns a `tff.Computation` for aggregating the unfinalized metrics.
      If `None`, `tff.learning.metrics.sum_then_finalize` is used.
    use_experimental_simulation_loop: Controls the reduce loop function for
      the input dataset. An experimental reduce loop is used for simulation.
      It is currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `ClientWorkProcess`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_type(optimizer, optimizer_base.Optimizer)
  py_typecheck.check_type(client_weighting, client_weight_lib.ClientWeighting)

  # Fill in defaults for the optional collaborators.
  gradient_factory = full_gradient_aggregator
  if gradient_factory is None:
    gradient_factory = mean.MeanFactory()
  py_typecheck.check_type(gradient_factory, factory.WeightedAggregationFactory)
  if metrics_aggregator is None:
    metrics_aggregator = metric_aggregator.sum_then_finalize

  # The model is built inside a scratch graph so its variables do not leak
  # into the global context; it is only used to derive types.
  with tf.Graph().as_default():
    model = model_fn()
    unfinalized_metrics_type = type_conversions.type_from_tensors(
        model.report_local_unfinalized_metrics())
    metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(),
                                                unfinalized_metrics_type)
  data_type = computation_types.SequenceType(model.input_spec)
  weights_type = model_utils.weights_type_from_model(model)
  weight_tensor_specs = type_conversions.type_to_tf_tensor_specs(weights_type)

  # Aggregation process over trainable-weight-shaped gradients with float32
  # client weights.
  gradient_agg_process = gradient_factory.create(
      weights_type.trainable, computation_types.TensorType(tf.float32))

  @federated_computation.federated_computation
  def init_fn():
    trainable_specs = weight_tensor_specs.trainable
    optimizer_state = intrinsics.federated_eval(
        tensorflow_computation.tf_computation(
            lambda: optimizer.initialize(trainable_specs)), placements.SERVER)
    aggregator_state = gradient_agg_process.initialize()
    return intrinsics.federated_zip((optimizer_state, aggregator_state))

  client_update_fn = _build_client_update_fn_for_mime_lite(
      model_fn, optimizer, client_weighting, use_experimental_simulation_loop)

  state_type_at_server = init_fn.type_signature.result

  @tensorflow_computation.tf_computation(state_type_at_server.member[0],
                                         weights_type.trainable)
  def update_optimizer_state(state, aggregate_gradient):
    # Only the new optimizer state is kept; the weights passed in are
    # zero-valued placeholders and the updated weights are discarded.
    placeholder_weights = tf.nest.map_structure(
        lambda grad: tf.zeros(grad.shape, grad.dtype), aggregate_gradient)
    updated_state, _ = optimizer.next(state, placeholder_weights,
                                      aggregate_gradient)
    return updated_state

  @federated_computation.federated_computation(
      state_type_at_server, computation_types.at_clients(weights_type),
      computation_types.at_clients(data_type))
  def next_fn(state, weights, client_data):
    optimizer_state, aggregator_state = state
    broadcast_optimizer_state = intrinsics.federated_broadcast(optimizer_state)
    client_args = (broadcast_optimizer_state, weights, client_data)
    client_result, model_outputs, full_gradient = intrinsics.federated_map(
        client_update_fn, client_args)

    gradient_agg_output = gradient_agg_process.next(
        aggregator_state, full_gradient, client_result.update_weight)
    next_optimizer_state = intrinsics.federated_map(
        update_optimizer_state, (optimizer_state, gradient_agg_output.result))

    next_state = intrinsics.federated_zip(
        (next_optimizer_state, gradient_agg_output.state))
    aggregated_train_metrics = metrics_aggregation_fn(model_outputs)
    measurements = intrinsics.federated_zip(
        collections.OrderedDict(train=aggregated_train_metrics))
    return measured_process.MeasuredProcessOutput(next_state, client_result,
                                                  measurements)

  return client_works.ClientWorkProcess(init_fn, next_fn)
def build_weighted_fed_avg_with_optimizer_schedule(
    model_fn: Callable[[], model_lib.Model],
    client_learning_rate_fn: Callable[[int], float],
    client_optimizer_fn: Callable[[float], TFFOrKerasOptimizer],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = fed_avg.DEFAULT_SERVER_OPTIMIZER_FN,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
  """Builds a learning process for FedAvg with client optimizer scheduling.

  This function creates a `LearningProcess` that performs federated averaging
  on client models. The iterative process has the following methods inherited
  from `LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a `LearningAlgorithmState` representing
      the initial state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client
      datasets. The output `L` contains the updated server state, as well as
      aggregated metrics at the server, including client training metrics and
      any other metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose
      type matches the output of `initialize` and `next`, and `M` represents
      the type of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is broadcast to each
  client using a broadcast function. For each client, local training is
  performed using `client_optimizer_fn`. Each client computes the difference
  between the client model after training and the initial broadcast model.
  These model deltas are then aggregated at the server using a weighted
  aggregation function. Clients are weighted by the number of examples they
  see throughout local training. The aggregate model delta is applied at the
  server using a server optimizer.

  The primary purpose of this implementation of FedAvg is that it allows the
  client optimizer to be scheduled across rounds. The process keeps track of
  how many iterations of `.next` have occurred (starting at `0`), and for each
  such `round_num`, the clients will use
  `client_optimizer_fn(client_learning_rate_fn(round_num))` to perform local
  optimization. This allows learning rate scheduling (e.g. starting with a
  large learning rate and decaying it over time) as well as optimizer
  scheduling (e.g. switching optimizers as learning progresses).

  Note: the default server optimizer function is `tf.keras.optimizers.SGD`
  with a learning rate of 1.0, which corresponds to adding the model delta to
  the current server model. This recovers the original FedAvg algorithm in
  [McMahan et al., 2017](https://arxiv.org/abs/1602.05629). More
  sophisticated federated averaging procedures may use different learning
  rates or server optimizers.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.
    client_learning_rate_fn: A callable accepting an integer round number and
      returning a float to be used as a learning rate for the optimizer. The
      client work will call
      `client_optimizer_fn(client_learning_rate_fn(round_num))` where
      `round_num` is the integer round number. Note that the round numbers
      supplied will start at `0` and increment by one each time `.next` is
      called on the resulting process. Also note that this function must be
      serializable by TFF.
    client_optimizer_fn: A callable accepting a float learning rate, and
      returning a `tff.learning.optimizers.Optimizer` or a
      `tf.keras.Optimizer`.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. By default, this uses
      `tf.keras.optimizers.SGD` with a learning rate of 1.0.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set
      to `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
      TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
      and returns a `tff.Computation` for aggregating the unfinalized metrics.
      If `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `LearningProcess`.

  Raises:
    TypeError: If `model_aggregator` creates an aggregation process whose
      server-side result type differs from its client-side input type.
  """
  py_typecheck.check_callable(model_fn)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  if model_distributor is None:
    model_distributor = distributors.build_broadcast_process(model_weights_type)

  if model_aggregator is None:
    model_aggregator = mean.MeanFactory()
  py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory)
  aggregator = model_aggregator.create(model_weights_type.trainable,
                                       computation_types.TensorType(tf.float32))

  # The aggregate is fed to the server optimizer as a model delta, so the
  # aggregation process must hand back the same structure it was given.
  process_signature = aggregator.next.type_signature
  input_client_value_type = process_signature.parameter[1]
  result_server_value_type = process_signature.result[1]
  if input_client_value_type.member != result_server_value_type.member:
    # The message names `model_aggregator`, the argument callers actually
    # passed to this function.
    raise TypeError('`model_aggregator` does not produce a '
                    'compatible `AggregationProcess`. The processes must '
                    'retain the type structure of the inputs on the '
                    f'server, but got {input_client_value_type.member} != '
                    f'{result_server_value_type.member}.')

  if metrics_aggregator is None:
    metrics_aggregator = metric_aggregator.sum_then_finalize
  client_work = build_scheduled_client_work(model_fn, client_learning_rate_fn,
                                            client_optimizer_fn,
                                            metrics_aggregator,
                                            use_experimental_simulation_loop)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer_fn, model_weights_type)
  return composers.compose_learning_process(initial_model_weights_fn,
                                            model_distributor, client_work,
                                            aggregator, finalizer)
def extract_nodes_consuming(tree, predicate):
  """Collects the AST nodes of `tree` whose value depends on matching nodes.

  A node that itself satisfies `predicate` is, by convention, part of the
  result.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an
      abstract syntax tree. The result is the set of nodes in this tree whose
      value depends on evaluating nodes matching `predicate`.
    predicate: One-arg callable, accepting arguments of type
      `building_blocks.ComputationBuildingBlock` and returning a `bool`
      indicating match or mismatch with the desired pattern.

  Returns:
    A `set` of `building_blocks.ComputationBuildingBlock` instances
    representing the nodes in `tree` dependent on nodes matching `predicate`.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(predicate)

  # Dependent nodes found so far, keyed by `id()` so that membership is
  # decided by object identity.
  dependent_by_id = {}

  def _is_dependent(node):
    return id(node) in dependent_by_id

  def _reference_is_dependent(comp, symbol_tree):
    """Checks whether the value bound to reference `comp` is dependent."""
    payload = symbol_tree.get_payload_with_name(comp.name)
    if payload is None:
      return False
    # The postorder traversal ensures that we process any bindings before we
    # process the reference to those bindings.
    return _is_dependent(payload.value)

  def _consumes_dependent_node(comp, symbol_tree):
    """Checks whether anything `comp` is built from is already dependent."""
    if (comp.is_intrinsic() or comp.is_data() or comp.is_placement() or
        comp.is_compiled_computation()):
      return False
    if comp.is_lambda():
      return _is_dependent(comp.result)
    if comp.is_block():
      return any(_is_dependent(binding[1])
                 for binding in comp.locals) or _is_dependent(comp.result)
    if comp.is_struct():
      return any(_is_dependent(element) for element in comp)
    if comp.is_selection():
      return _is_dependent(comp.source)
    if comp.is_call():
      return _is_dependent(comp.function) or _is_dependent(comp.argument)
    if comp.is_reference():
      return _reference_is_dependent(comp, symbol_tree)
    return False

  def _record_if_dependent(comp, symbol_tree):
    """Records `comp` when it matches or consumes a dependent node."""
    if predicate(comp) or _consumes_dependent_node(comp, symbol_tree):
      dependent_by_id[id(comp)] = comp
    # The traversal is used for inspection only; nothing is rewritten.
    return comp, False

  symbol_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  transformation_utils.transform_postorder_with_symbol_bindings(
      tree, _record_if_dependent, symbol_tree)
  return set(dependent_by_id.values())
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_clients=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from `tff.SERVER` to
  `tff.CLIENTS`. Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_clients`
  participating clients are collected to the server.

  NOTE: The functions in `personalize_fn_dict` and `baseline_evaluate_fn` are
  expected to take as input *unbatched* datasets, and are responsible for
  applying batching, if any, to the provided input datasets.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy - it accepts a
      `tff.learning.Model` (with weights already initialized to the given
      model weights when users invoke the returned TFF computation), an
      unbatched `tf.data.Dataset` for train, an unbatched `tf.data.Dataset`
      for test, and an arbitrary context object (which is used to hold any
      extra information that a personalization strategy may use), trains a
      personalized model, and returns the evaluation metrics. The evaluation
      metrics are represented as an `OrderedDict` (or a nested `OrderedDict`)
      of `string` metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when
      users invoke the returned TFF computation), and an unbatched
      `tf.data.Dataset`, evaluates the model on the dataset, and returns the
      evaluation metrics. The evaluation metrics are represented as an
      `OrderedDict` (or a nested `OrderedDict`) of `string` metric names to
      scalar `tf.Tensor`s. This function is *only* used to compute the
      baseline metrics of the initial model.
    max_num_clients: A positive `int` specifying the maximum number of clients
      to collect metrics in a round (default is 100). The clients are sampled
      without replacement. For each sampled client, all the personalization
      metrics from this client will be collected. If the number of
      participating clients in a round is smaller than this value, then
      metrics from all clients will be collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the
      training dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` with the functional type signature
    `(<model_weights@SERVER, input@CLIENTS> -> personalization_metrics@SERVER)`:

    *   `model_weights` is a `tff.learning.ModelWeights`.
    *   Each client's input is an `OrderedDict` of two required keys
        `train_data` and `test_data`; each key is mapped to an unbatched
        `tf.data.Dataset`. If extra context (e.g., extra datasets) is used in
        `personalize_fn_dict`, then client input has a third key `context`
        that is mapped to an object whose `tff.Type` is provided by the
        `context_tff_type` argument.
    *   `personalization_metrics` is an `OrderedDict` that maps a key
        'baseline_metrics' to the evaluation metrics of the initial model
        (computed by `baseline_evaluate_fn`), and maps keys (strategy names)
        in `personalize_fn_dict` to the evaluation metrics of the
        corresponding personalization strategies.
    *   Note: only metrics from at most `max_num_clients` participating
        clients (sampled without replacement) are collected to the SERVER. All
        collected metrics are stored in a single `OrderedDict`
        (`personalization_metrics` shown above), where each metric is mapped
        to a list of scalars (each scalar comes from one client). Metric
        values at the same position, e.g., metric_1[i], metric_2[i]..., all
        come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in
      `personalize_fn_dict`.
    ValueError: If `max_num_clients` is not positive.
  """
  py_typecheck.check_callable(model_fn)

  # A throwaway model, built in its own graph, supplies the type metadata.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    model = model_fn()
    model_weights_type = model_utils.weights_type_from_model(model)
    batch_tff_type = computation_types.to_type(model.input_spec)

  # Batching (and any other dataset preprocessing) is the job of each
  # personalization strategy, so the client-side input type is declared in
  # terms of unbatched elements.
  element_tff_type = _remove_batch_dim(batch_tff_type)
  client_input_fields = collections.OrderedDict(
      train_data=computation_types.SequenceType(element_tff_type),
      test_data=computation_types.SequenceType(element_tff_type))
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, computation_types.Type)
    client_input_fields['context'] = context_tff_type
  client_input_type = computation_types.to_type(client_input_fields)

  @computations.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    context = client_input.get('context', None)

    # The baseline metrics of the initial model come first in the output.
    metrics = collections.OrderedDict()
    metrics['baseline_metrics'] = _compute_baseline_metrics(
        model_fn, initial_model_weights, test_data, baseline_evaluate_fn)

    py_typecheck.check_type(personalize_fn_dict, collections.OrderedDict)
    if 'baseline_metrics' in personalize_fn_dict:
      raise ValueError('baseline_metrics should not be used as a key in '
                       'personalize_fn_dict.')

    # One entry per strategy name in `personalize_fn_dict`, holding the
    # evaluation metrics of the corresponding personalized model.
    strategy_metrics = _compute_p13n_metrics(model_fn, initial_model_weights,
                                             train_data, test_data,
                                             personalize_fn_dict, context)
    metrics.update(strategy_metrics)
    return metrics

  py_typecheck.check_type(max_num_clients, int)
  if max_num_clients <= 0:
    raise ValueError('max_num_clients must be a positive integer.')

  client_metrics_type = _client_computation.type_signature.result
  reservoir_sampling_factory = sampling.UnweightedReservoirSamplingFactory(
      sample_size=max_num_clients)
  aggregation_process = reservoir_sampling_factory.create(client_metrics_type)

  @computations.federated_computation(
      computation_types.FederatedType(model_weights_type, placements.SERVER),
      computation_types.FederatedType(client_input_type, placements.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = intrinsics.federated_broadcast(server_model_weights)
    client_final_metrics = intrinsics.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    sampling_output = aggregation_process.next(
        aggregation_process.initialize(),  # No state.
        client_final_metrics)
    # `sampling_output.measurements` is currently empty, so only the sampled
    # result is returned; it may be worth surfacing in the future.
    return sampling_output.result

  return personalization_eval
def zero_or_one_arg_fn_to_building_block(
    fn,
    parameter_name: Optional[str],
    parameter_type: Optional[computation_types.Type],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
) -> Tuple[building_blocks.ComputationBuildingBlock, computation_types.Type]:
  """Converts a zero- or one-argument `fn` into a computation building block.

  Args:
    fn: A function with 0 or 1 arguments that contains orchestration logic,
      i.e., that expects zero or one `values_base.Value` and returns a result
      convertible to the same.
    parameter_name: The name of the parameter, or `None` if there isn't any.
    parameter_type: The `tff.Type` of the parameter, or `None` if there's
      none.
    context_stack: The context stack to use.
    suggested_name: The optional suggested name to use for the federated
      context that will be used to serialize this function's body (ideally the
      name of the underlying Python function). It might be modified to avoid
      conflicts.

  Returns:
    A tuple of `(building_blocks.ComputationBuildingBlock,
    computation_types.Type)`, where the first element contains the logic from
    `fn`, and the second element contains potentially annotated type
    information for the result of `fn`.

  Raises:
    ValueError: If `fn` returns `None`, or (per the original contract) if `fn`
      is incompatible with `parameter_type`.
  """
  py_typecheck.check_callable(fn)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if suggested_name is not None:
    py_typecheck.check_type(suggested_name, str)

  # Nest the new context under the current one only when the current context
  # is itself a federated computation context.
  if isinstance(context_stack.current,
                federated_computation_context.FederatedComputationContext):
    parent_context = context_stack.current
  else:
    parent_context = None
  context = federated_computation_context.FederatedComputationContext(
      context_stack, suggested_name=suggested_name, parent=parent_context)

  if parameter_name is not None:
    py_typecheck.check_type(parameter_name, str)
    # Prefix with the context name so the parameter name is tied to this
    # context.
    parameter_name = '{}_{}'.format(context.name, str(parameter_name))

  with context_stack.install(context):
    if parameter_type is not None:
      result = fn(
          value_impl.ValueImpl(
              building_blocks.Reference(parameter_name, parameter_type),
              context_stack))
    else:
      result = fn()

    if result is None:
      # `fn` may be any callable (e.g. `functools.partial` or an object with
      # `__call__`), which need not carry `__code__`. Fall back to its repr so
      # the intended ValueError is raised rather than an AttributeError.
      code = getattr(fn, '__code__', None)
      if code is not None:
        subject = 'The function defined on line {} of file {}'.format(
            code.co_firstlineno, code.co_filename)
      else:
        subject = 'The callable {!r}'.format(fn)
      raise ValueError(
          subject + ' has returned a '
          '`NoneType`, but all TFF functions must return some non-`None` '
          'value.')

    annotated_result_type = type_conversions.infer_type(result)
    result = value_impl.to_value(result, annotated_result_type, context_stack)
    result_comp = value_impl.ValueImpl.get_comp(result)

    # Symbols bound while tracing `fn` are wrapped around the result as a
    # block of locals.
    symbols_bound_in_context = context_stack.current.symbol_bindings
    if symbols_bound_in_context:
      result_comp = building_blocks.Block(
          local_symbols=symbols_bound_in_context, result=result_comp)

    annotated_type = computation_types.FunctionType(parameter_type,
                                                    annotated_result_type)
    return building_blocks.Lambda(parameter_name, parameter_type,
                                  result_comp), annotated_type
def create_executor_factory(
    executor_stack_fn: Callable[[CardinalitiesType], executor_base.Executor]
) -> ExecutorFactory:
  """Wraps an executor stack function in an `ExecutorFactory`.

  Args:
    executor_stack_fn: A callable that, per its annotation, maps cardinalities
      to an `executor_base.Executor`.

  Returns:
    An `ExecutorFactoryImpl` constructed from `executor_stack_fn`.

  Raises:
    TypeError: Presumably, when `executor_stack_fn` is not callable (raised by
      `py_typecheck.check_callable`).
  """
  py_typecheck.check_callable(executor_stack_fn)
  executor_factory = ExecutorFactoryImpl(executor_stack_fn)
  return executor_factory
def build_jax_federated_averaging_process(batch_type, model_type, loss_fn,
                                          step_size):
  """Constructs an iterative process that implements simple federated averaging.

  Args:
    batch_type: An instance of `tff.Type` that represents the type of a single
      batch of data to use for training. This type should be constructed with
      standard Python containers (such as `collections.OrderedDict`) of the sort
      that are expected as parameters to `loss_fn`.
    model_type: An instance of `tff.Type` that represents the type of the model.
      Similarly to `batch_type`, this type should be constructed with standard
      Python containers (such as `collections.OrderedDict`) of the sort that are
      expected as parameters to `loss_fn`.
    loss_fn: A loss function for the model. Must be a Python function that takes
      two parameters, one of them being the model, and the other being a single
      batch of data (with types matching `batch_type` and `model_type`).
    step_size: The step size to use during training. Either a Python `float`
      or a NumPy floating-point scalar such as `np.float32`.

  Returns:
    An instance of `tff.templates.IterativeProcess` that implements federated
    training in JAX.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  batch_type = computation_types.to_type(batch_type)
  model_type = computation_types.to_type(model_type)

  py_typecheck.check_type(batch_type, computation_types.Type)
  py_typecheck.check_type(model_type, computation_types.Type)
  py_typecheck.check_callable(loss_fn)
  # `np.float` was only an alias of the builtin `float` and has been removed
  # from NumPy; accept builtin floats (as before) plus NumPy floating scalars
  # such as the `np.float32` the docstring asks for.
  py_typecheck.check_type(step_size, (float, np.floating))

  def _tensor_zeros(tensor_type):
    return jax.numpy.zeros(
        tensor_type.shape.dims, dtype=tensor_type.dtype.as_numpy_dtype)

  @experimental_computations.jax_computation
  def _create_zero_model():
    model_zeros = structure.map_structure(_tensor_zeros, model_type)
    return type_conversions.type_to_py_container(model_zeros, model_type)

  @computations.federated_computation
  def _create_zero_model_on_server():
    return intrinsics.federated_eval(_create_zero_model, placements.SERVER)

  def _apply_update(model_param, param_delta):
    # Plain gradient-descent step on a single flattened parameter.
    return model_param - step_size * param_delta

  @experimental_computations.jax_computation(model_type, batch_type)
  def _train_on_one_batch(model, batch):
    params = structure.flatten(
        structure.from_container(model, recursive=True))
    # `jax.grad` is the public entry point; the `jax.api` module is private
    # and is not available in current JAX releases.
    grads = structure.flatten(
        structure.from_container(jax.grad(loss_fn)(model, batch)))
    updated_params = [_apply_update(x, y) for (x, y) in zip(params, grads)]
    trained_model = structure.pack_sequence_as(model_type, updated_params)
    return type_conversions.type_to_py_container(trained_model, model_type)

  local_dataset_type = computation_types.SequenceType(batch_type)

  @computations.federated_computation(model_type, local_dataset_type)
  def _train_on_one_client(model, batches):
    # Fold the per-batch training step over the client's local dataset.
    return intrinsics.sequence_reduce(batches, model, _train_on_one_batch)

  @computations.federated_computation(
      computation_types.FederatedType(model_type, placements.SERVER),
      computation_types.FederatedType(local_dataset_type, placements.CLIENTS))
  def _train_one_round(model, federated_data):
    locally_trained_models = intrinsics.federated_map(
        _train_on_one_client,
        collections.OrderedDict([('model',
                                  intrinsics.federated_broadcast(model)),
                                 ('batches', federated_data)]))
    return intrinsics.federated_mean(locally_trained_models)

  return iterative_process.IterativeProcess(
      initialize_fn=_create_zero_model_on_server, next_fn=_train_one_round)
def __init__(self,
             raw_client_data: client_data.ClientData,
             make_transform_fn: Callable[[str, int], Callable[[Any], Any]],
             num_transformed_clients: Optional[int] = None):
  """Initializes the TransformingClientData.

  Args:
    raw_client_data: A non-empty ClientData to expand.
    make_transform_fn: A function that returns a callable that maps datapoint
      x to a new datapoint x'. make_transform_fn will be called as
      make_transform_fn(raw_client_id, i) where i is an integer index, and
      should return a function fn(x)->x. For example if x is an image, then
      make_transform_fn("client_a", 0)(x) might be the identity, while
      make_transform_fn("client_a", 1)(x) could be a random rotation of the
      image with the angle determined by a hash of "client_a" and "1". If
      make_transform_fn returns `None`, no transformation is performed.
      Typically by convention the index 0 corresponds to the identity function
      if the identity is supported.
    num_transformed_clients: The total number of transformed clients to
      produce. If `None`, it defaults to the number of raw clients, so only
      the original clients will be transformed. If it is an integer multiple k
      of the number of real clients, there will be exactly k pseudo-clients
      per real client, with indices 0...k-1. Any remainder g will be generated
      from the first g real clients and will be given index k.

  Raises:
    ValueError: If `raw_client_data` has no client ids, or if
      `num_transformed_clients` is not positive.
  """
  py_typecheck.check_type(raw_client_data, client_data.ClientData)
  py_typecheck.check_callable(make_transform_fn)

  raw_client_ids = raw_client_data.client_ids
  if not raw_client_ids:
    raise ValueError('`raw_client_data` must be non-empty.')
  num_raw_clients = len(raw_client_ids)

  if num_transformed_clients is not None:
    py_typecheck.check_type(num_transformed_clients, int)
    if num_transformed_clients <= 0:
      raise ValueError('`num_transformed_clients` must be positive.')
  else:
    num_transformed_clients = num_raw_clients

  self._raw_client_data = raw_client_data
  self._make_transform_fn = make_transform_fn
  self._has_pseudo_clients = num_transformed_clients > num_raw_clients

  if not self._has_pseudo_clients:
    client_ids = raw_client_ids
  else:
    # Zero-pad the index so that every generated id has the same width.
    index_width = len(str(num_transformed_clients - 1))
    id_template = '{}_{:0' + str(index_width) + '}'
    per_client, remainder = divmod(num_transformed_clients, num_raw_clients)
    client_ids = [
        id_template.format(raw_id, index)
        for raw_id in raw_client_ids
        for index in range(per_client)
    ]
    # The first `remainder` raw clients each get one extra pseudo-client.
    client_ids.extend(
        id_template.format(raw_client_ids[position], per_client)
        for position in range(remainder))

  # Already sorted if raw_client_data.client_ids are, but just to be sure...
  self._client_ids = sorted(client_ids)
def create_binary_operator(
    operator, operand_type: computation_types.Type) -> pb.Computation:
  """Builds a TensorFlow computation that applies a binary `operator`.

  The returned computation has the type signature `(<T,T> -> U)`, where `T` is
  `operand_type` and `U` is the result of applying the `operator` to a tuple of
  type `<T,T>`.

  Note: If `operand_type` is a `computation_types.NamedTupleType`, then
  `operator` will be applied pointwise. This places the burden on callers of
  this function to construct the correct values to pass into the returned
  function. For example, to divide `[2, 2]` by `2`, first `2` must be packed
  into the data structure `[x, x]`, before the division operator of the
  appropriate type is called.

  Args:
    operator: A callable taking two arguments representing the operation to
      encode. For example: `tf.math.add`, `tf.math.multiply`, and
      `tf.math.divide`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed binary operator; must contain only named tuples and tensor
      types.

  Returns:
    A `pb.Computation` holding the serialized type signature and the
    TensorFlow graph with its parameter and result bindings.

  Raises:
    TypeError: If the constraints of `operand_type` are violated or `operator`
      is not callable.
  """
  if not type_analysis.is_generic_op_compatible_type(operand_type):
    incompatible_msg = (
        'The type {} contains a type other than `computation_types.TensorType` '
        'and `computation_types.NamedTupleType`; this is disallowed in the '
        'generic operators.')
    raise TypeError(incompatible_msg.format(operand_type))
  py_typecheck.check_callable(operator)

  with tf.Graph().as_default() as graph:
    # Stamp two independent placeholders of the same type for the operands.
    lhs_value, lhs_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', operand_type, graph)
    rhs_value, rhs_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', operand_type, graph)

    if operand_type is not None:
      if operand_type.is_tensor():
        result_value = operator(lhs_value, rhs_value)
      elif operand_type.is_tuple():
        # Apply the operator element by element across both structures.
        result_value = anonymous_tuple.map_structure(operator, lhs_value,
                                                     rhs_value)
      else:
        unexpected_msg = (
            'Operand type {} cannot be used in generic operations. The call to '
            '`type_analysis.is_generic_op_compatible_type` has allowed it to '
            'pass, and should be updated.')
        raise TypeError(unexpected_msg.format(operand_type))

  result_type, result_binding = tensorflow_utils.capture_result_from_graph(
      result_value, graph)

  operands_type = computation_types.NamedTupleType(
      (operand_type, operand_type))
  type_signature = computation_types.FunctionType(operands_type, result_type)
  parameter_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(
          element=[lhs_binding, rhs_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow)
def __init__(self, compilation_fn: Callable[[computation_base.Computation],
                                            Any]):
  """Stores `compilation_fn` on the instance after checking it is callable.

  Args:
    compilation_fn: A callable which, per its annotation, accepts a
      `computation_base.Computation`. It is kept as `self._compilation_fn`.
  """
  py_typecheck.check_callable(compilation_fn)
  self._compilation_fn = compilation_fn
def extract_nodes_consuming(tree, predicate):
  """Collects the AST nodes of `tree` that consume nodes matching `predicate`.

  By convention, a node which itself satisfies the predicate is part of the
  returned set.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an
      abstract syntax tree, and construct the set of nodes in this tree having
      a dependency on nodes matching `predicate`; that is, the set of nodes
      whose value depends on evaluating nodes matching `predicate`.
    predicate: One-arg callable, accepting arguments of type
      `building_blocks.ComputationBuildingBlock` and returning a `bool`
      indicating match or mismatch with the desired pattern.

  Returns:
    A `set` of `building_blocks.ComputationBuildingBlock` instances
    representing the nodes in `tree` dependent on nodes matching `predicate`.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(predicate)
  dependent_nodes = set()
  # Building blocks without children can never depend on another node.
  childless_types = (building_blocks.Intrinsic, building_blocks.Data,
                     building_blocks.Placement,
                     building_blocks.CompiledComputation)

  def _binding_is_dependent(ref, symbol_tree):
    """Checks whether the value bound to `ref`'s name is already marked."""
    payload = symbol_tree.get_payload_with_name(ref.name)
    # The postorder traversal ensures that we process any
    # bindings before we process the reference to those bindings
    return payload is not None and payload.value in dependent_nodes

  def _depends_on_marked_child(comp, symbol_tree):
    """Checks if the dependencies of `comp` are present in `dependent_nodes`."""
    if isinstance(comp, childless_types):
      return False
    if isinstance(comp, building_blocks.Lambda):
      return comp.result in dependent_nodes
    if isinstance(comp, building_blocks.Block):
      return (any(value in dependent_nodes for _, value in comp.locals) or
              comp.result in dependent_nodes)
    if isinstance(comp, building_blocks.Tuple):
      return any(element in dependent_nodes for element in comp)
    if isinstance(comp, building_blocks.Selection):
      return comp.source in dependent_nodes
    if isinstance(comp, building_blocks.Call):
      return (comp.function in dependent_nodes or
              comp.argument in dependent_nodes)
    if isinstance(comp, building_blocks.Reference):
      return _binding_is_dependent(comp, symbol_tree)
    # Any other kind of building block yields `None`, which is falsy.
    return None

  def _mark_if_dependent(comp, symbol_tree):
    """Populates `dependent_nodes` with all nodes dependent on `predicate`."""
    if predicate(comp) or _depends_on_marked_child(comp, symbol_tree):
      dependent_nodes.add(comp)
    # The tree itself is never modified by this traversal.
    return comp, False

  symbol_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  transformation_utils.transform_postorder_with_symbol_bindings(
      tree, _mark_if_dependent, symbol_tree)
  return dependent_nodes
def transform_preorder(
    comp: building_blocks.ComputationBuildingBlock,
    transform: Callable[[building_blocks.ComputationBuildingBlock],
                        TransformReturnType]
) -> TransformReturnType:
  """Walks the AST of `comp` preorder, calling `transform` on the way down.

  The walk stops descending as soon as `transform` reports that it modified a
  node; this prevents the caller from unexpectedly kicking off an infinite
  recursion. For this purpose the transform function must identify when it has
  transformed the structure of a building block; if the structure is modified
  but `False` is returned as the second element of the tuple returned by
  `transform`, `transform_preorder` may result in an infinite recursion.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` to be
      transformed in a preorder fashion.
    transform: Transform function to be applied to the nodes of `comp`. Must
      return a two-tuple whose first element is a
      `building_blocks.ComputationBuildingBlock` and whose second element is a
      Boolean. If the computation which is passed to `comp` is returned in a
      modified state, must return `True` for the second element.

  Returns:
    A two-tuple, whose first element is modified version of `comp`, and
    whose second element is a Boolean indicating whether `comp` was transformed
    during the walk.

  Raises:
    TypeError: If the argument types don't match those specified above.
    NotImplementedError: If a building block of an unrecognized kind is
      encountered.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(transform)

  inner_comp, modified = transform(comp)
  if modified:
    # Do not descend into a node that `transform` has just rewritten.
    return inner_comp, modified

  # From here on `modified` is falsy, so only the children decide whether a
  # new node has to be built.
  if (inner_comp.is_compiled_computation() or inner_comp.is_data() or
      inner_comp.is_intrinsic() or inner_comp.is_placement() or
      inner_comp.is_reference()):
    return inner_comp, modified

  if inner_comp.is_lambda():
    new_result, result_changed = transform_preorder(inner_comp.result,
                                                    transform)
    if not result_changed:
      return inner_comp, False
    return building_blocks.Lambda(inner_comp.parameter_name,
                                  inner_comp.parameter_type, new_result), True

  if inner_comp.is_tuple():
    walked_elements = [
        (name, transform_preorder(val, transform))
        for name, val in anonymous_tuple.iter_elements(inner_comp)
    ]
    if not any(changed for _, (_, changed) in walked_elements):
      return inner_comp, False
    new_elements = [(name, new_val) for name, (new_val, _) in walked_elements]
    return building_blocks.Tuple(new_elements), True

  if inner_comp.is_selection():
    new_source, source_changed = transform_preorder(inner_comp.source,
                                                    transform)
    if not source_changed:
      return inner_comp, False
    return building_blocks.Selection(new_source, inner_comp.name,
                                     inner_comp.index), True

  if inner_comp.is_call():
    new_fn, fn_changed = transform_preorder(inner_comp.function, transform)
    if inner_comp.argument is None:
      new_arg, arg_changed = None, False
    else:
      new_arg, arg_changed = transform_preorder(inner_comp.argument, transform)
    if not (fn_changed or arg_changed):
      return inner_comp, False
    return building_blocks.Call(new_fn, new_arg), True

  if inner_comp.is_block():
    # Locals are walked in declaration order, before the block's result.
    walked_locals = [(key, transform_preorder(value, transform))
                     for key, value in inner_comp.locals]
    new_result, result_changed = transform_preorder(inner_comp.result,
                                                    transform)
    locals_changed = any(changed for _, (_, changed) in walked_locals)
    if not (locals_changed or result_changed):
      return inner_comp, False
    new_locals = [(key, new_value) for key, (new_value, _) in walked_locals]
    return building_blocks.Block(new_locals, new_result), True

  raise NotImplementedError(
      'Unrecognized computation building block: {}'.format(str(inner_comp)))
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from SERVER to CLIENTS.
  Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the SERVER.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the provided
      model weights when users invoke the returned TFF computation), a training
      `tf.dataset.Dataset`, a test `tf.dataset.Dataset`, and an arbitrary
      context object (which is used to hold any extra information that a
      personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are usually
      represented as an `OrderedDict` (or a nested `OrderedDict`) of `string`
      metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when
      users invoke the returned TFF computation), and a `tf.dataset.Dataset`,
      evaluates the model on the dataset, and returns the evaluation metrics.
      The evaluation metrics are usually represented as an `OrderedDict` (or a
      nested `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s.
      This function is *only* used to compute the baseline metrics of the
      initial model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in a
      round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the training
      dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` that maps
    < model_weights@SERVER, input@CLIENTS > -> personalization_metrics@SERVER,
    where:
    - model_weights is a `tff.learning.framework.ModelWeights`.
    - each client's input is an `OrderedDict` of at least two keys `train_data`
      and `test_data`, and each key is mapped to a `tf.dataset.Dataset`. If
      context is used in `personalize_fn_dict`, then client input has a third
      key `context` that is mapped to an object whose `tff.Type` is provided by
      the `context_tff_type` argument.
    - personalization_metrics is an `OrderedDict` that maps a key
      'baseline_metrics' to the evaluation metrics of the initial model
      (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
      `personalize_fn_dict` to the evaluation metrics of the corresponding
      personalization strategies.
    - Note: only metrics from at most `max_num_samples` participating clients
      are collected to the SERVER. All collected metrics are stored in a
      single `OrderedDict` (the personalization_metrics shown above), where each
      metric is mapped to a list of scalars (each scalar comes from one
      client). Metric values at the same position, e.g., metric_1[i],
      metric_2[i]..., all come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
      NOTE(review): this check is not performed in the body below; presumably
      it happens inside `_client_fn` -- confirm.
    ValueError: If `max_num_samples` is not positive.
  """
  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    py_typecheck.check_callable(model_fn)
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = tff.to_type(model.input_spec)

  # Define the `tff.Type` of each client's input.
  client_input_type = collections.OrderedDict([
      ('train_data', tff.SequenceType(batch_type)),
      ('test_data', tff.SequenceType(batch_type))
  ])
  # The optional `context` entry is only part of the input type when the
  # caller supplies its type.
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, tff.Type)
    client_input_type['context'] = context_tff_type
  client_input_type = tff.to_type(client_input_type)

  @tff.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    # A fresh model is built here rather than reusing the one constructed
    # above, which was only needed to obtain the types.
    model = model_fn()
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    # `context` is `None` when no context type was declared.
    context = client_input.get('context', None)
    return _client_fn(model, initial_model_weights, train_data, test_data,
                      personalize_fn_dict, baseline_evaluate_fn, context)

  py_typecheck.check_type(max_num_samples, int)
  if max_num_samples <= 0:
    raise ValueError('max_num_samples must be a positive integer.')

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(client_input_type, tff.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = tff.federated_broadcast(server_model_weights)
    client_final_metrics = tff.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
    return results

  return personalization_eval
def build_model_delta_client_work(
    model_fn: Callable[[], model_lib.Model],
    optimizer: Union[optimizer_base.Optimizer,
                     Callable[[], tf.keras.optimizers.Optimizer]],
    client_weighting: client_weight_lib.ClientWeighting,
    delta_l2_regularizer: float = 0.0,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    *,
    use_experimental_simulation_loop: bool = False
) -> client_works.ClientWorkProcess:
  """Creates a `ClientWorkProcess` for federated averaging.

  This client work is constructed in slightly different manners depending on
  whether `optimizer` is a `tff.learning.optimizers.Optimizer`, or a no-arg
  callable returning a `tf.keras.optimizers.Optimizer`.

  If it is a `tff.learning.optimizers.Optimizer`, we avoid creating
  `tf.Variable`s associated with the optimizer state within the scope of the
  client work, as they are not necessary. This also means that the client's
  model weights are updated by computing `optimizer.next` and then assigning
  the result to the model weights (while a `tf.keras.optimizers.Optimizer` will
  modify the model weight in place using `optimizer.apply_gradients`).

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    optimizer: A `tff.learning.optimizers.Optimizer`, or a no-arg callable that
      returns a `tf.keras.Optimizer`.
    client_weighting: A `tff.learning.ClientWeighting` value.
    delta_l2_regularizer: A nonnegative float representing the parameter of the
      L2-regularization term applied to the delta from initial model weights
      during training. Values larger than 0.0 prevent clients from moving too
      far from the server model during local training.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `ClientWorkProcess`.

  Raises:
    TypeError: If `model_fn`, `client_weighting`, `delta_l2_regularizer` or
      `optimizer` are of the wrong types.
    ValueError: If `delta_l2_regularizer` is negative.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_type(client_weighting, client_weight_lib.ClientWeighting)
  py_typecheck.check_type(delta_l2_regularizer, float)
  if delta_l2_regularizer < 0.0:
    # The trailing space matters: adjacent literals are concatenated verbatim.
    raise ValueError('Provided delta_l2_regularizer must be non-negative, '
                     f'but found: {delta_l2_regularizer}')
  if not (isinstance(optimizer, optimizer_base.Optimizer) or
          callable(optimizer)):
    raise TypeError(
        'Provided optimizer must be either a tff.learning.optimizers.Optimizer '
        'or a no-arg callable returning a tf.keras.optimizers.Optimizer.')

  if metrics_aggregator is None:
    metrics_aggregator = aggregator.sum_then_finalize

  with tf.Graph().as_default():
    # Wrap model construction in a graph to avoid polluting the global context
    # with variables created for this model.
    model = model_fn()
    unfinalized_metrics_type = type_conversions.type_from_tensors(
        model.report_local_unfinalized_metrics())
    metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(),
                                                unfinalized_metrics_type)
  data_type = computation_types.SequenceType(model.input_spec)
  weights_type = model_utils.weights_type_from_model(model)

  if isinstance(optimizer, optimizer_base.Optimizer):

    @tensorflow_computation.tf_computation(weights_type, data_type)
    def client_update_computation(initial_model_weights, dataset):
      # The TFF optimizer instance is passed to the update function as is.
      client_update = build_model_delta_update_with_tff_optimizer(
          model_fn=model_fn,
          weighting=client_weighting,
          delta_l2_regularizer=delta_l2_regularizer,
          use_experimental_simulation_loop=use_experimental_simulation_loop)
      return client_update(optimizer, initial_model_weights, dataset)

  else:

    @tensorflow_computation.tf_computation(weights_type, data_type)
    def client_update_computation(initial_model_weights, dataset):
      # The Keras optimizer is constructed inside the traced computation.
      keras_optimizer = optimizer()
      client_update = build_model_delta_update_with_keras_optimizer(
          model_fn=model_fn,
          weighting=client_weighting,
          delta_l2_regularizer=delta_l2_regularizer,
          use_experimental_simulation_loop=use_experimental_simulation_loop)
      return client_update(keras_optimizer, initial_model_weights, dataset)

  @federated_computation.federated_computation
  def init_fn():
    # The process carries no state of its own: an empty tuple at the server.
    return intrinsics.federated_value((), placements.SERVER)

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(weights_type),
      computation_types.at_clients(data_type))
  def next_fn(state, weights, client_data):
    client_result, model_outputs = intrinsics.federated_map(
        client_update_computation, (weights, client_data))
    train_metrics = metrics_aggregation_fn(model_outputs)
    measurements = intrinsics.federated_zip(
        collections.OrderedDict(train=train_metrics))
    return measured_process.MeasuredProcessOutput(state, client_result,
                                                  measurements)

  return client_works.ClientWorkProcess(init_fn, next_fn)
def build_model_delta_optimizer_process(
    model_fn,
    model_to_client_delta_fn,
    server_optimizer_fn,
    stateful_delta_aggregate_fn=build_stateless_mean(),
    stateful_model_broadcast_fn=build_stateless_broadcaster()):
  """Constructs `tff.utils.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  ClientDeltaFn that specifies how weight_deltas are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Note: The defaults for `stateful_delta_aggregate_fn` and
  `stateful_model_broadcast_fn` are call expressions in the signature, so they
  are evaluated once, when this function is defined, and the resulting objects
  are shared by every call that relies on the defaults.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a model_fn to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    stateful_delta_aggregate_fn: A `tff.utils.StatefulAggregateFn` where the
      next_fn performs a federated aggregation and updates state. That is, it
      has TFF type (state@SERVER, value@CLIENTS) -> (state@SERVER,
      aggregate@SERVER).
    stateful_model_broadcast_fn: A `tff.utils.StatefulBroadcastFn` where the
      next_fn performs a federated broadcast and updates state. That is, it has
      TFF type (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS).

  Returns:
    A `tff.utils.IterativeProcess`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(model_to_client_delta_fn)
  py_typecheck.check_callable(server_optimizer_fn)
  py_typecheck.check_type(stateful_delta_aggregate_fn,
                          tff.utils.StatefulAggregateFn)
  py_typecheck.check_type(stateful_model_broadcast_fn,
                          tff.utils.StatefulBroadcastFn)

  # TODO(b/122081673): would be nice not to have to construct a throwaway model
  # here just to get the types. After fully moving to TF2.0 and eager-mode, we
  # should re-evaluate what happens here.
  with tf.Graph().as_default():
    dummy_model_for_metadata = model_utils.enhance(model_fn())

  @tff.tf_computation
  def tf_init_fn():
    return server_init(model_fn, server_optimizer_fn,
                       stateful_delta_aggregate_fn.initialize(),
                       stateful_model_broadcast_fn.initialize())

  @tff.federated_computation
  def server_init_tff():
    """Orchestration logic for server model initialization."""
    return tff.federated_value(tf_init_fn(), tff.SERVER)

  # The type signatures of the initializers determine the parameter types of
  # the round computation defined below.
  federated_server_state_type = server_init_tff.type_signature.result
  server_state_type = tf_init_fn.type_signature.result

  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round_tff(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    model_weights_type = server_state_type.model

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_delta_tf(tf_dataset, initial_model_weights):
      """Performs client local model optimization.

      Args:
        tf_dataset: a `tf.data.Dataset` that provides training examples.
        initial_model_weights: a `model_utils.ModelWeights` containing the
          starting weights.

      Returns:
        A `ClientOutput` structure.
      """
      client_delta_fn = model_to_client_delta_fn(model_fn)
      client_output = client_delta_fn(tf_dataset, initial_model_weights)
      return client_output

    # Broadcast the server model to the clients; the broadcaster also returns
    # its own updated state.
    new_broadcaster_state, client_model = stateful_model_broadcast_fn(
        server_state.model_broadcast_state, server_state.model)

    client_outputs = tff.federated_map(client_delta_tf,
                                       (federated_dataset, client_model))

    @tff.tf_computation(
        server_state_type, model_weights_type.trainable,
        server_state.delta_aggregate_state.type_signature.member,
        server_state.model_broadcast_state.type_signature.member)
    def server_update_tf(server_state, model_delta, new_delta_aggregate_state,
                         new_broadcaster_state):
      """Converts args to correct python types and calls server_update_model."""
      py_typecheck.check_type(server_state, ServerState)
      # Rebuild the state with the refreshed aggregator and broadcaster states;
      # the optimizer state is converted to a plain `list`.
      server_state = ServerState(
          model=server_state.model,
          optimizer_state=list(server_state.optimizer_state),
          delta_aggregate_state=new_delta_aggregate_state,
          model_broadcast_state=new_broadcaster_state)

      return server_update_model(
          server_state,
          model_delta,
          model_fn=model_fn,
          optimizer_fn=server_optimizer_fn)

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
    py_typecheck.check_type(fed_weight_type, tff.TensorType)
    if fed_weight_type.dtype.is_integer:

      @tff.tf_computation(fed_weight_type)
      def _cast_to_float(x):
        return tf.cast(x, tf.float32)

      weight_denom = tff.federated_map(_cast_to_float,
                                       client_outputs.weights_delta_weight)
    else:
      weight_denom = client_outputs.weights_delta_weight
    new_delta_aggregate_state, round_model_delta = stateful_delta_aggregate_fn(
        server_state.delta_aggregate_state,
        client_outputs.weights_delta,
        weight=weight_denom)

    # TODO(b/123408447): remove tff.federated_apply and call
    # server_update_tf directly once T <-> T@SERVER isomorphism is
    # supported.
    server_state = tff.federated_apply(
        server_update_tf,
        (server_state, round_model_delta, new_delta_aggregate_state,
         new_broadcaster_state))

    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)

    # Promote the FederatedType outside the NamedTupleType
    aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  return tff.utils.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round_tff)
def serialize_tf2_as_tf_computation(target, parameter_type, unpack=None):
  """Serializes the 'target' as a TF computation with a given parameter type.

  The target is wrapped in a `tf.function` that returns a flat dict, traced
  into a concrete function, exported with `tf.saved_model.save` into a
  temporary directory, and loaded back so that the resulting graph def and the
  concrete tensor names can be packed into a `pb.Computation`.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function or `tf.function`, with arguments
      matching the 'parameter_type'.
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    unpack: Whether to always unpack the parameter_type. Necessary for support
      of polymorphic tf2_computations.

  Returns:
    A pair `(computation, annotated_type)`, where `computation` is the
    constructed `pb.Computation` instance with the `pb.TensorFlow` variant set,
    and `annotated_type` is the `computation_types.FunctionType` built from
    the parameter type and the result type captured while tracing the target.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  py_typecheck.check_callable(target)
  parameter_type = computation_types.to_type(parameter_type)
  signature = function_utils.get_signature(target)
  if signature.parameters and parameter_type is None:
    raise ValueError(
        'Expected the target to declare no parameters, found {!r}.'.format(
            signature.parameters))

  # In the codepath for TF V1 based serialization (tff.tf_computation),
  # we get the "wrapped" function to serialize. Here, target is the
  # raw function to be wrapped; however, we still need to know if
  # the parameter_type should be unpacked into multiple args and kwargs
  # in order to construct the TensorSpecs to be passed in the call
  # to get_concrete_fn below.
  unpack = function_utils.infer_unpack_needed(target, parameter_type, unpack)
  arg_typespecs, kwarg_typespecs, parameter_binding = (
      tensorflow_utils.get_tf_typespec_and_binding(
          parameter_type,
          arg_names=list(signature.parameters.keys()),
          unpack=unpack))

  # Pseudo-global to be appended to once when target_poly below is traced.
  type_and_binding_slot = []

  # N.B. To serialize a tf.function or eager python code,
  # the return type must be a flat list, tuple, or dict. However, the
  # tff.tf_computation must be able to handle structured inputs and outputs.
  # Thus, we intercept the result of calling the original target fn, introspect
  # its structure to create a result_type and bindings, and then return a
  # flat dict output. It is this new "unpacked" tf.function that we will
  # serialize using tf.saved_model.save.
  #
  # TODO(b/117428091): The return type limitation is primarily a limitation of
  # SignatureDefs and therefore of the signatures argument to
  # tf.saved_model.save. tf.functions attached to objects and loaded back with
  # tf.saved_model.load can take/return nests; this might offer a better
  # approach to the one taken here.

  @tf.function
  def target_poly(*args, **kwargs):
    result = target(*args, **kwargs)
    result_dict, result_type, result_binding = (
        tensorflow_utils.get_tf2_result_dict_and_binding(result))
    assert not type_and_binding_slot
    # A "side channel" python output.
    type_and_binding_slot.append((result_type, result_binding))
    return result_dict

  # Triggers tracing so that type_and_binding_slot is filled.
  cc_fn = target_poly.get_concrete_function(*arg_typespecs, **kwarg_typespecs)
  assert len(type_and_binding_slot) == 1
  result_type, result_binding = type_and_binding_slot[0]

  # N.B. Note that cc_fn does *not* accept the same args and kwargs as the
  # Python target_poly; instead, it must be called with **kwargs based on the
  # unique names embedded in the TensorSpecs inside arg_typespecs and
  # kwarg_typespecs. The (preliminary) parameter_binding tracks the mapping
  # between these tensor names and the components of the (possibly nested) TFF
  # input type. When cc_fn is serialized, concrete tensors for each input are
  # introduced, and the call finalize_binding(parameter_binding,
  # sigs['serving_default'].inputs) updates the bindings to reference these
  # concrete tensors.

  # Associate vars with unique names and explicitly attach to the Checkpoint:
  var_dict = {
      'var{:02d}'.format(i): v for i, v in enumerate(cc_fn.graph.variables)
  }
  saveable = tf.train.Checkpoint(fn=target_poly, **var_dict)

  # The directory is created *before* entering the `try` so that the `finally`
  # clause only ever runs with `outdir` bound. If `mkdtemp` itself fails, its
  # error propagates unchanged instead of being masked by an unbound-variable
  # error raised from the cleanup.
  outdir = tempfile.mkdtemp('savedmodel')
  try:
    # TODO(b/122081673): All we really need is the meta graph def, we could
    # probably just load that directly, e.g., using parse_saved_model from
    # tensorflow/python/saved_model/loader_impl.py, but I'm not sure we want to
    # depend on that presumably non-public symbol. Perhaps TF can expose a way
    # to just get the MetaGraphDef directly without saving to a tempfile? This
    # looks like a small change to v2.saved_model.save().
    tf.saved_model.save(saveable, outdir, signatures=cc_fn)

    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
      mgd = tf.compat.v1.saved_model.load(
          sess, tags=[tf.saved_model.SERVING], export_dir=outdir)
  finally:
    shutil.rmtree(outdir)
  sigs = mgd.signature_def

  # TODO(b/123102455): Figure out how to support the init_op. The meta graph def
  # contains sigs['__saved_model_init_op'].outputs['__saved_model_init_op']. It
  # probably won't do what we want, because it will want to read from
  # Checkpoints, not just run Variable initializers (?). The right solution may
  # be to grab the target_poly.get_initialization_function(), and save a sig for
  # that.

  # Now, traverse the signature from the MetaGraphDef to find the actual tensor
  # names and write them into the bindings.
  finalize_binding(parameter_binding, sigs['serving_default'].inputs)
  finalize_binding(result_binding, sigs['serving_default'].outputs)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(mgd.graph_def),
          parameter=parameter_binding,
          result=result_binding)), annotated_type
def build_model_delta_optimizer_process(
    model_fn,
    model_to_client_delta_fn,
    server_optimizer_fn,
    stateful_delta_aggregate_fn=None,
    stateful_model_broadcast_fn=None):
  """Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    stateful_delta_aggregate_fn: A `tff.utils.StatefulAggregateFn` where the
      `next_fn` performs a federated aggregation and updates state. That is, it
      has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
      (state@SERVER, aggregate@SERVER)`, where the `value` type is
      `tff.learning.framework.ModelWeights.trainable` corresponding to the
      object returned by `model_fn`. If `None` (the default), the result of
      `build_stateless_mean()` is used.
    stateful_model_broadcast_fn: A `tff.utils.StatefulBroadcastFn` where the
      `next_fn` performs a federated broadcast and updates state. That is, it
      has TFF type `(state@SERVER, value@SERVER) -> (state@SERVER,
      value@CLIENTS)`, where the `value` type is
      `tff.learning.framework.ModelWeights` corresponding to the object
      returned by `model_fn`. If `None` (the default), the result of
      `build_stateless_broadcaster()` is used.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(model_to_client_delta_fn)
  py_typecheck.check_callable(server_optimizer_fn)

  # The defaults are constructed here, per call, rather than in the signature:
  # a call expression used as a default value is evaluated only once, when the
  # `def` statement executes, which would build these objects at import time
  # and share them between every process built afterwards.
  if stateful_delta_aggregate_fn is None:
    stateful_delta_aggregate_fn = build_stateless_mean()
  if stateful_model_broadcast_fn is None:
    stateful_model_broadcast_fn = build_stateless_broadcaster()

  py_typecheck.check_type(stateful_delta_aggregate_fn,
                          tff.utils.StatefulAggregateFn)
  py_typecheck.check_type(stateful_model_broadcast_fn,
                          tff.utils.StatefulBroadcastFn)

  # TODO(b/124477628): would be nice not to have to construct a throwaway model
  # here just to get the types. After fully moving to TF2.0 and eager-mode, we
  # should re-evaluate what happens here.
  # TODO(b/144382142): Keras name uniquification is probably the main reason we
  # still need this.
  with tf.Graph().as_default():
    dummy_model_for_metadata = model_utils.enhance(model_fn())

  # ===========================================================================
  # TensorFlow Computations

  @tff.tf_computation
  def tf_init_fn():
    return server_init(model_fn, server_optimizer_fn,
                       stateful_delta_aggregate_fn.initialize(),
                       stateful_model_broadcast_fn.initialize())

  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
  server_state_type = tf_init_fn.type_signature.result

  @tff.tf_computation(tf_dataset_type, server_state_type.model)
  def tf_client_delta(tf_dataset, initial_model_weights):
    """Performs client local model optimization.

    Args:
      tf_dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights.

    Returns:
      A `ClientOutput` structure.
    """
    client_delta_fn = model_to_client_delta_fn(model_fn)
    client_output = client_delta_fn(tf_dataset, initial_model_weights)
    return client_output

  @tff.tf_computation(server_state_type, server_state_type.model.trainable,
                      server_state_type.delta_aggregate_state,
                      server_state_type.model_broadcast_state)
  def tf_server_update(server_state, model_delta, new_delta_aggregate_state,
                       new_broadcaster_state):
    """Converts args to correct python types and calls server_update_model."""
    py_typecheck.check_type(server_state, ServerState)
    server_state = ServerState(
        model=server_state.model,
        optimizer_state=list(server_state.optimizer_state),
        delta_aggregate_state=new_delta_aggregate_state,
        model_broadcast_state=new_broadcaster_state)

    return server_update_model(
        server_state,
        model_delta,
        model_fn=model_fn,
        optimizer_fn=server_optimizer_fn)

  weight_type = tf_client_delta.type_signature.result.weights_delta_weight

  @tff.tf_computation(weight_type)
  def _cast_weight_to_float(x):
    return tf.cast(x, tf.float32)

  # ===========================================================================
  # Federated Computations

  @tff.federated_computation
  def server_init_tff():
    """Orchestration logic for server model initialization."""
    return tff.federated_value(tf_init_fn(), tff.SERVER)

  federated_server_state_type = server_init_tff.type_signature.result
  federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round_tff(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    new_broadcaster_state, client_model = stateful_model_broadcast_fn(
        server_state.model_broadcast_state, server_state.model)

    client_outputs = tff.federated_map(tf_client_delta,
                                       (federated_dataset, client_model))

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    weight_denom = tff.federated_map(_cast_weight_to_float,
                                     client_outputs.weights_delta_weight)

    new_delta_aggregate_state, round_model_delta = stateful_delta_aggregate_fn(
        server_state.delta_aggregate_state,
        client_outputs.weights_delta,
        weight=weight_denom)

    server_state = tff.federated_map(
        tf_server_update, (server_state, round_model_delta,
                           new_delta_aggregate_state, new_broadcaster_state))

    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)
    if isinstance(aggregated_outputs.type_signature, tff.NamedTupleType):
      # Promote the FederatedType outside the NamedTupleType.
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  return tff.templates.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round_tff)
def build_model_delta_optimizer_process(
    model_fn: _ModelConstructor,
    model_to_client_delta_fn: Callable[[model_lib.Model], ClientDeltaFn],
    server_optimizer_fn: _OptimizerConstructor,
    stateful_delta_aggregate_fn: Optional[
        tff.utils.StatefulAggregateFn] = None,
    stateful_model_broadcast_fn: Optional[
        tff.utils.StatefulBroadcastFn] = None,
    *,
    broadcast_process: Optional[tff.templates.MeasuredProcess] = None,
    aggregation_process: Optional[tff.templates.MeasuredProcess] = None,
) -> tff.templates.IterativeProcess:
  """Builds the `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  Wires together the shared server logic, which applies aggregated model deltas
  to the server model, with a `ClientDeltaFn` that determines how the
  `weight_deltas` are computed on device.

  Note: Functions, not constructed objects, are passed in so that any variables
  or ops created in constructors end up in the correct graph.

  For both broadcast and aggregation, either the legacy `tff.utils.StatefulFn`
  argument or the `tff.templates.MeasuredProcess` argument may be given, but
  not both. A legacy argument is wrapped into a measured process; when neither
  is given a stateless default is built.

  Args:
    model_fn: a no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: a function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: a no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    stateful_delta_aggregate_fn: a `tff.utils.StatefulAggregateFn` where the
      `next_fn` performs a federated aggregation and updates state. That is, it
      has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
      (state@SERVER, aggregate@SERVER)`, where the `value` type is
      `tff.learning.framework.ModelWeights.trainable` corresponding to the
      object returned by `model_fn`.
    stateful_model_broadcast_fn: a `tff.utils.StatefulBroadcastFn` where the
      `next_fn` performs a federated broadcast and updates state. That is, it
      has TFF type `(state@SERVER, value@SERVER) -> (state@SERVER,
      value@CLIENTS)`, where the `value` type is
      `tff.learning.framework.ModelWeights` corresponding to the object
      returned by `model_fn`.
    broadcast_process: a `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENTS)`.
    aggregation_process: a `tff.templates.MeasuredProcess` that aggregates the
      model updates on the clients back to the server. It must support the
      signature `({input_values}@CLIENTS -> output_values@SERVER)`.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    DisjointArgumentError: if both the `StatefulFn` and the `MeasuredProcess`
      argument are given for broadcast, or for aggregation.
    ProcessTypeError: if `broadcast_process` or `aggregation_process` do not
      conform to the signature of broadcast (SERVER->CLIENTS) or aggregation
      (CLIENTS->SERVER).
  """
  for required_callable in (model_fn, model_to_client_delta_fn,
                            server_optimizer_fn):
    py_typecheck.check_callable(required_callable)

  # A throwaway model, built in its own graph, is only used to derive the type
  # of the model weights.
  with tf.Graph().as_default():
    throwaway_model = model_fn()
    model_weights_type = tff.framework.type_from_tensors(
        model_utils.ModelWeights.from_model(throwaway_model))

  # TODO(b/159138779): remove the StatefulFn arguments and these validation
  # functions once all callers are migrated.
  def validate_disjoint_optional_arguments(
      stateful_fn: Optional[Union[tff.utils.StatefulBroadcastFn,
                                  tff.utils.StatefulAggregateFn]],
      process: Optional[tff.templates.MeasuredProcess],
      process_input_type: Union[tff.NamedTupleType, tff.TensorType],
  ) -> Optional[tff.templates.MeasuredProcess]:
    """Resolves a legacy `StatefulFn` / `MeasuredProcess` argument pair.

    At most one of `stateful_fn` and `process` may be given. This is a bridge
    used during the transition to `tff.templates.MeasuredProcess`.

    Args:
      stateful_fn: an optional `tff.utils.StatefulFn`; wrapped into a measured
        process when given.
      process: an optional `tff.templates.MeasuredProcess`; handed back as-is
        when `stateful_fn` is absent.
      process_input_type: the input type used when wrapping `stateful_fn`.

    Returns:
      `None` if neither argument is given, otherwise a
      `tff.templates.MeasuredProcess`.

    Raises:
      DisjointArgumentError: if both `stateful_fn` and `process` are not
        `None`.
    """
    if stateful_fn is None:
      # Nothing to wrap: whatever was (or was not) passed as `process` wins.
      return process
    legacy_types = (tff.utils.StatefulBroadcastFn,
                    tff.utils.StatefulAggregateFn)
    py_typecheck.check_type(stateful_fn, legacy_types)
    if process is not None:
      raise DisjointArgumentError(
          'Specifying both arguments is an error. Only one may be used')
    return _wrap_in_measured_process(
        stateful_fn, input_type=process_input_type)

  # The generic error from the helper is re-raised with a message that names
  # the concrete pair of arguments involved.
  try:
    aggregation_process = validate_disjoint_optional_arguments(
        stateful_delta_aggregate_fn, aggregation_process,
        model_weights_type.trainable)
  except DisjointArgumentError as err:
    raise DisjointArgumentError(
        'Specifying both `stateful_delta_aggregate_fn` and '
        '`aggregation_process` is an error. Only one may be used') from err

  try:
    broadcast_process = validate_disjoint_optional_arguments(
        stateful_model_broadcast_fn, broadcast_process, model_weights_type)
  except DisjointArgumentError as err:
    raise DisjointArgumentError(
        'Specifying both `stateful_model_broadcast_fn` and '
        '`broadcast_process` is an error. Only one may be used') from err

  # Fall back to the stateless broadcaster, then validate whichever process
  # ended up being selected.
  if broadcast_process is None:
    broadcast_process = build_stateless_broadcaster(
        model_weights_type=model_weights_type)
  broadcast_is_valid = _is_valid_broadcast_process(broadcast_process)
  if not broadcast_is_valid:
    raise ProcessTypeError(
        'broadcast_process type signature does not conform to expected '
        'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
        ' Got: {t}'.format(t=broadcast_process.next.type_signature))

  # Same for aggregation, with the stateless mean as the fallback.
  if aggregation_process is None:
    aggregation_process = build_stateless_mean(
        model_delta_type=model_weights_type.trainable)
  aggregation_is_valid = _is_valid_aggregation_process(aggregation_process)
  if not aggregation_is_valid:
    raise ProcessTypeError(
        'aggregation_process type signature does not conform to expected '
        'signature (<state@S, input@C> -> <state@S, result@S, measurements@S>).'
        ' Got: {t}'.format(t=aggregation_process.next.type_signature))

  init_computation = _build_initialize_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process)

  round_computation = _build_one_round_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      model_to_client_delta_fn=model_to_client_delta_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process)

  return tff.templates.IterativeProcess(
      initialize_fn=init_computation, next_fn=round_computation)
def build_model_delta_optimizer_process(
    model_fn: _ModelConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    server_optimizer_fn: _OptimizerConstructor,
    *,
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
    model_update_aggregation_factory: Optional[
        factory.AggregationFactory] = None,
) -> iterative_process.IterativeProcess:
  """Builds the `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  Wires together the shared server logic, which applies aggregated model deltas
  to the server model, with a `ClientDeltaFn` that determines how the
  `weight_deltas` are computed on device.

  Note: Functions, not constructed objects, are passed in so that any variables
  or ops created in constructors end up in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENTS)`. If `None`, a stateless
      broadcaster is built.
    aggregation_process: A `tff.templates.MeasuredProcess` that aggregates the
      model updates on the clients back to the server. It must support the
      signature `({input_values}@CLIENTS -> output_values@SERVER)`. Must be
      `None` if `model_update_aggregation_factory` is not `None`.
    model_update_aggregation_factory: An optional
      `tff.aggregators.WeightedAggregationFactory` that constructs the
      `tff.templates.AggregationProcess` used to aggregate the client model
      updates on the server. If `None`, a default constructed
      `tff.aggregators.MeanFactory` is used, creating a stateless mean
      aggregation. Must be `None` if `aggregation_process` is not `None`.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ProcessTypeError: if `broadcast_process` or `aggregation_process` do not
      conform to the signature of broadcast (SERVER->CLIENTS) or aggregation
      (CLIENTS->SERVER).
    DisjointArgumentError: if both `aggregation_process` and
      `model_update_aggregation_factory` are not `None`.
    ValueError: if the `next` function of a supplied `aggregation_process`
      takes neither two nor three arguments.
  """
  for required_callable in (model_fn, model_to_client_delta_fn,
                            server_optimizer_fn):
    py_typecheck.check_callable(required_callable)

  weights_type = model_utils.weights_type_from_model(model_fn)

  # Broadcast: fall back to the stateless broadcaster, then validate whichever
  # process ended up being selected.
  if broadcast_process is None:
    broadcast_process = build_stateless_broadcaster(
        model_weights_type=weights_type)
  broadcast_is_valid = _is_valid_broadcast_process(broadcast_process)
  if not broadcast_is_valid:
    raise ProcessTypeError(
        'broadcast_process type signature does not conform to expected '
        'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
        ' Got: {t}'.format(t=broadcast_process.next.type_signature))

  # Aggregation: an explicit process and a factory are mutually exclusive.
  factory_given = model_update_aggregation_factory is not None
  process_given = aggregation_process is not None
  if factory_given and process_given:
    raise DisjointArgumentError(
        'Must specify only one of `model_update_aggregation_factory` and '
        '`AggregationProcess`.')

  if process_given:
    # A caller-supplied process only gets its `next` arity checked here.
    next_num_args = len(aggregation_process.next.type_signature.parameter)
    if next_num_args not in (2, 3):
      raise ValueError(
          f'`next` function of `aggregation_process` must take two (for '
          f'unweighted aggregation) or three (for weighted aggregation) '
          f'arguments. Found {next_num_args}.')
  else:
    # No process supplied: build one from the given factory or from the
    # default mean factory.
    if factory_given:
      update_factory = model_update_aggregation_factory
    else:
      update_factory = mean_factory.MeanFactory()
    py_typecheck.check_type(update_factory,
                            factory.AggregationFactory.__args__)
    # Weighted factories additionally receive the float32 weight type.
    create_args = [weights_type.trainable]
    if isinstance(update_factory, factory.WeightedAggregationFactory):
      create_args.append(computation_types.TensorType(tf.float32))
    aggregation_process = update_factory.create(*create_args)

  aggregation_is_valid = _is_valid_aggregation_process(aggregation_process)
  if not aggregation_is_valid:
    raise ProcessTypeError(
        'aggregation_process type signature does not conform to expected '
        'signature (<state@S, input@C> -> <state@S, result@S, measurements@S>).'
        ' Got: {t}'.format(t=aggregation_process.next.type_signature))

  init_computation = _build_initialize_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process)

  round_computation = _build_one_round_computation(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      model_to_client_delta_fn=model_to_client_delta_fn,
      broadcast_process=broadcast_process,
      aggregation_process=aggregation_process)

  return iterative_process.IterativeProcess(
      initialize_fn=init_computation, next_fn=round_computation)
def to_representation_for_type(value, type_spec, callable_handler=None):
  """Checks `value` against `type_spec`, converting it when necessary.

  If `value` already is a valid representation of the TFF type `type_spec`, it
  is returned unchanged. If it is not, but can be converted into one, the
  converted representation is returned. Otherwise an error is raised.

  The representations accepted for the various TFF types are:

  * Tensor types: instances of the types listed in
    `tensorflow_utils.TENSOR_REPRESENTATION_TYPES`. When executing eagerly,
    `tf.Tensor` and `tf.Variable` values are first turned into their `numpy()`
    value.

  * Named tuple types: anything `anonymous_tuple.from_container` accepts. A
    `dict` whose keys are exactly the element names of the type is reordered
    to follow the order in which the type declares them. The result is an
    `anonymous_tuple.AnonymousTuple`.

  * Sequence types: Python iterables of elements, returned as a Python list.
    A `tf.data.Dataset` is accepted only when executing eagerly.

  * Functional types: delegated entirely to `callable_handler`, which receives
    `value` and `type_spec`; without a handler functional types are rejected.

  * Abstract types: there is no valid representation. The reference executor
    requires all types in an executable computation to be concrete.

  * Placement types: instances of `placement_literals.PlacementLiteral`.

  * Federated types with `all_equal` set to `True`: the same representation as
    that of the member constituent (so, e.g., a valid representation of
    `int32@SERVER` is the same as that of `int32`). With `all_equal` set to
    `False`, only placement at clients is supported, and the representation is
    a Python list or tuple of member constituents, returned as a list.

  NOTE: This function does not attempt to validate that the sizes of lists
  representing federated values match the corresponding placements. Cardinality
  analysis is a separate step, handled by the reference executor at a different
  point. As long as values can be packed into a Python list, they are accepted
  as they are.

  Args:
    value: The raw representation of a value to compare against `type_spec` and
      potentially to be converted into a canonical form for the given TFF type.
    type_spec: The TFF type, an instance of `tff.Type` or something convertible
      to it, that determines what the valid representation should be.
    callable_handler: The function to invoke to handle TFF functional types. If
      this is `None`, functional types are not supported. The function must
      accept `value` and `type_spec` as arguments and return the converted
      valid representation, just as `to_representation_for_type`.

  Returns:
    Either `value` itself, or `value` converted into a valid representation for
    `type_spec`.

  Raises:
    TypeError: If `value` is not a valid representation for given `type_spec`.
    ValueError: If a `tf.data.Dataset` is supplied outside of eager mode, or if
      a non-`all_equal` federated value is neither a list nor a tuple.
    NotImplementedError: If verification for `type_spec` is not supported.
  """
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  if callable_handler is not None:
    py_typecheck.check_callable(callable_handler)

  # NOTE: `type_utils.infer_type()` is not simply applied to `value`, because
  # the reference executor only represents values with a subset of the Python
  # types that helper recognizes.
  #
  # Every case below ends in a `return` or a `raise`, so the cases are written
  # as independent guard blocks rather than as one long chain.

  if isinstance(type_spec, computation_types.TensorType):
    if tf.executing_eagerly() and isinstance(value, (tf.Tensor, tf.Variable)):
      value = value.numpy()
    py_typecheck.check_type(value, tensorflow_utils.TENSOR_REPRESENTATION_TYPES)
    actual_type = type_utils.infer_type(value)
    if not type_utils.is_assignable_from(type_spec, actual_type):
      raise TypeError(
          'The tensor type {} of the value representation does not match '
          'the type spec {}.'.format(actual_type, type_spec))
    return value

  if isinstance(type_spec, computation_types.NamedTupleType):
    expected_elements = anonymous_tuple.to_elements(type_spec)
    # Unordered dictionaries are special-cased so that their elements can be
    # fed in the order in which the named tuple type defines them.
    if isinstance(value, dict):
      expected_names = {name for name, _ in expected_elements}
      if set(value.keys()) == expected_names:
        value = collections.OrderedDict(
            (name, value[name]) for name, _ in expected_elements)
    value = anonymous_tuple.from_container(value)
    actual_elements = anonymous_tuple.to_elements(value)
    if len(actual_elements) != len(expected_elements):
      raise TypeError(
          'The number of elements {} in the value tuple {} does not match the '
          'number of elements {} in the type spec {}.'.format(
              len(actual_elements), value, len(expected_elements), type_spec))
    converted_elements = []
    # The lengths were just verified to be equal, so pairing is position-wise.
    paired_elements = zip(expected_elements, actual_elements)
    for position, (expected, actual) in enumerate(paired_elements):
      expected_name, expected_type = expected
      actual_name, actual_value = actual
      # An unnamed value element is accepted in place of a named one.
      if actual_name not in (expected_name, None):
        raise TypeError(
            'Found element named `{}` where `{}` was expected at position {} '
            'in the value tuple. Value: {}. Type: {}'.format(
                actual_name, expected_name, position, value, type_spec))
      converted_elements.append(
          (expected_name,
           to_representation_for_type(actual_value, expected_type,
                                      callable_handler)))
    return anonymous_tuple.AnonymousTuple(converted_elements)

  if isinstance(type_spec, computation_types.SequenceType):
    if isinstance(value, tf.data.Dataset):
      dataset_type = computation_types.SequenceType(
          computation_types.to_type(tf.data.experimental.get_structure(value)))
      if not type_utils.is_assignable_from(type_spec, dataset_type):
        raise TypeError(
            'Value of type {!s} not assignable to expected type {!s}'.format(
                dataset_type, type_spec))
      if not tf.executing_eagerly():
        raise ValueError(
            'Processing `tf.data.Datasets` outside of eager mode is not '
            'currently supported.')
    # Datasets (in eager mode) and plain iterables are converted element by
    # element in the same way.
    return [
        to_representation_for_type(element, type_spec.element, callable_handler)
        for element in value
    ]

  if isinstance(type_spec, computation_types.FunctionType):
    if callable_handler is None:
      raise TypeError(
          'Values that are callables have been explicitly disallowed '
          'in this context. If you would like to supply here a function '
          'as a parameter, please construct a computation that contains '
          'this call.')
    return callable_handler(value, type_spec)

  if isinstance(type_spec, computation_types.AbstractType):
    raise TypeError(
        'Abstract types are not supported by the reference executor.')

  if isinstance(type_spec, computation_types.PlacementType):
    py_typecheck.check_type(value, placement_literals.PlacementLiteral)
    return value

  if isinstance(type_spec, computation_types.FederatedType):
    if type_spec.all_equal:
      return to_representation_for_type(value, type_spec.member,
                                        callable_handler)
    if type_spec.placement is not placements.CLIENTS:
      raise TypeError(
          'Unable to determine a valid value representation for a federated '
          'type with non-equal members placed at {}.'.format(
              type_spec.placement))
    if not isinstance(value, (list, tuple)):
      raise ValueError('Please pass a list or tuple to any function that'
                       ' expects a federated type placed at {};'
                       ' you passed {}'.format(type_spec.placement, value))
    return [
        to_representation_for_type(member, type_spec.member, callable_handler)
        for member in value
    ]

  raise NotImplementedError(
      'Unable to determine valid value representation for {} for what '
      'is currently an unsupported TFF type {}.'.format(value, type_spec))