def test_construction_with_aggregation_process(self):
  """Checks that a custom aggregation process is threaded through the state.

  The aggregation process's state type must appear as `delta_aggregate_state`
  in the server state produced by `initialize`, consumed by `next`, and
  returned by `next`; its measurements must appear under the `aggregation`
  key of the metrics returned by `next`.
  """
  with tf.Graph().as_default():
    trainable_weights = model_utils.ModelWeights.from_model(
        model_examples.LinearRegression()).trainable
    update_type = tff.framework.type_from_tensors(trainable_weights)
  measured_mean = _build_test_measured_mean(update_type)
  process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.LinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      aggregation_process=measured_mean)

  expected_state = measured_mean.initialize.type_signature.result
  init_sig = process.initialize.type_signature
  next_sig = process.next.type_signature
  # Server state as produced by `initialize`, accepted by `next`, and
  # returned by `next`.
  for server_state in (init_sig.result.member, next_sig.parameter[0].member,
                       next_sig.result[0].member):
    self.assertEqual(
        tff.FederatedType(server_state.delta_aggregate_state, tff.SERVER),
        expected_state)

  expected_measurements = measured_mean.next.type_signature.result.measurements
  self.assertEqual(
      tff.FederatedType(next_sig.result[1].member.aggregation, tff.SERVER),
      expected_measurements)
def test_construction_with_broadcast_process(self):
  """Checks that a custom broadcast process's state lands in server state.

  The broadcast process's state type must appear as `model_broadcast_state`
  in the server state produced by `initialize`, consumed by `next`, and
  returned by `next`.
  """
  with tf.Graph().as_default():
    weights = model_utils.ModelWeights.from_model(
        model_examples.LinearRegression())
    weights_type = tff.framework.type_from_tensors(weights)
  measured_broadcast = _build_test_measured_broadcast(weights_type)
  process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.LinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      broadcast_process=measured_broadcast)

  expected_state = measured_broadcast.initialize.type_signature.result
  init_sig = process.initialize.type_signature
  next_sig = process.next.type_signature
  for server_state in (init_sig.result.member, next_sig.parameter[0].member,
                       next_sig.result[0].member):
    self.assertEqual(
        tff.FederatedType(server_state.model_broadcast_state, tff.SERVER),
        expected_state)
def build_federated_evaluation(model_fn):
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics
    as aggregated by `tff.learning.Model.federated_output_computation`.
  """
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    model = model_utils.enhance(model_fn())
    # Weight type: one TensorType per variable, built from the variable's base
    # dtype and static shape.
    model_weights_type = tff.to_type(
        tf.nest.map_structure(
            lambda v: tff.TensorType(v.dtype.base_dtype, v.shape),
            model.weights))
    batch_type = tff.to_type(model.input_spec)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs after evaluating `incoming_model_weights` on `dataset`."""
    # A fresh model is built inside this computation's graph; the throwaway
    # model above is only used for type information and the final
    # `federated_output_computation`.
    model = model_utils.enhance(model_fn())

    # TODO(b/124477598): Remove dummy when b/121400757 has been fixed.
    @tf.function
    def reduce_fn(dummy, batch):
      # Runs a non-training forward pass; the float64 running loss sum only
      # serves as a value the control dependencies below can hang on.
      model_output = model.forward_pass(batch, training=False)
      return dummy + tf.cast(model_output.loss, tf.float64)

    # TODO(b/123898430): The control dependencies below have been inserted as a
    # temporary workaround. These control dependencies need to be removed, and
    # defuns and datasets supported together fully.
    # Force the weight assignment to happen before the dataset is reduced.
    with tf.control_dependencies(
        [tff.utils.assign(model.weights, incoming_model_weights)]):
      dummy = dataset.reduce(tf.constant(0.0, dtype=tf.float64), reduce_fn)

    # Force the reduction to finish before local outputs are read; `dummy` is
    # also returned so it stays part of the computation's result.
    with tf.control_dependencies([dummy]):
      return collections.OrderedDict([
          ('local_outputs', model.report_local_outputs()),
          ('workaround for b/121400757', dummy)
      ])

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
  def server_eval(server_model_weights, federated_dataset):
    # Broadcast the server weights, evaluate on every client, then aggregate
    # only the `local_outputs` part (the workaround value is dropped).
    client_outputs = tff.federated_map(
        client_eval,
        [tff.federated_broadcast(server_model_weights), federated_dataset])
    return model.federated_output_computation(client_outputs.local_outputs)

  return server_eval
def build_federated_evaluation(model_fn):
  """Creates a federated computation that evaluates a model across clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.

  Returns:
    A federated computation (an instance of `tff.Computation`) that takes
    server-placed model weights and client-placed datasets, evaluates the
    model on every client, and returns the metrics aggregated by
    `tff.learning.Model.federated_output_computation`.
  """
  # A throwaway model is stamped into a scratch graph purely to read off the
  # weight and batch types needed by the computations below.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    template_model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(
        template_model.weights)
    batch_type = tff.to_type(template_model.input_spec)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_eval(incoming_model_weights, dataset):
    """Evaluates `incoming_model_weights` on `dataset`; returns local outputs."""
    local_model = model_utils.enhance(model_fn())

    def _accumulate_loss(running_loss, batch):
      # The reduced loss value is discarded below; the reduction exists to
      # call `forward_pass` on every batch.
      outputs = local_model.forward_pass(batch, training=False)
      return running_loss + tf.cast(outputs.loss, tf.float64)

    @tf.function
    def _evaluate(weights, client_data):
      tff.utils.assign(local_model.weights, weights)
      client_data.reduce(tf.constant(0.0, dtype=tf.float64), _accumulate_loss)
      return collections.OrderedDict(
          local_outputs=local_model.report_local_outputs())

    return _evaluate(incoming_model_weights, dataset)

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
  def server_eval(server_model_weights, federated_dataset):
    weights_at_clients = tff.federated_broadcast(server_model_weights)
    client_outputs = tff.federated_map(
        client_eval, [weights_at_clients, federated_dataset])
    return template_model.federated_output_computation(
        client_outputs.local_outputs)

  return server_eval
def test_mutates_iterproc_accepting_dataset_in_second_index_of_next(self):
  """Composing swaps `next`'s int64-sequence input for a client string."""
  original = _create_stateless_int_dataset_reduction_iterative_process()
  composed = iterative_process_compositions.compose_dataset_computation(
      int_dataset_computation, original)

  server_int = tff.FederatedType(tf.int64, tff.SERVER)
  client_string = tff.FederatedType(tf.string, tff.CLIENTS)
  expected_next = tff.FunctionType([server_int, client_string], server_int)
  self.assertTrue(
      expected_next.is_equivalent_to(composed.next.type_signature))
def test_orchestration_type_signature(self):
  """Checks `initialize`/`next` signatures of a delta-optimizer process."""
  process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.TrainableLinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=lambda: gradient_descent.SGD(learning_rate=1.0))

  weights_type = model_utils.ModelWeights(
      collections.OrderedDict(
          a=tff.TensorType(tf.float32, [2, 1]), b=tf.float32),
      collections.OrderedDict(c=tf.float32))
  # Only the model part of ServerState is pinned down; the optimizer state
  # comes from TensorFlow, so it and the remaining fields are AnyType.
  server_state = tff.FederatedType(
      optimizer_utils.ServerState(weights_type, test.AnyType(),
                                  test.AnyType(), test.AnyType()),
      placement=tff.SERVER,
      all_equal=True)
  client_datasets = tff.FederatedType(
      tff.SequenceType(model_examples.TrainableLinearRegression().input_spec),
      tff.CLIENTS,
      all_equal=False)
  train_outputs = tff.FederatedType(
      collections.OrderedDict(
          loss=tff.TensorType(tf.float32, []),
          num_examples=tff.TensorType(tf.int32, [])),
      tff.SERVER,
      all_equal=True)

  # initialize: () -> ServerState
  self.assertEqual(
      tff.FunctionType(parameter=None, result=server_state),
      process.initialize.type_signature)
  # next: (ServerState, client datasets) -> (ServerState, model outputs)
  self.assertEqual(
      tff.FunctionType(
          parameter=[server_state, client_datasets],
          result=(server_state, train_outputs)),
      process.next.type_signature)
def build_stateless_robust_aggregation(model_type,
                                       num_communication_passes=5,
                                       tolerance=1e-6):
  """Create TFF function for robust aggregation.

  The robust aggregate is an approximate geometric median computed via the
  smoothed Weiszfeld algorithm.

  Args:
    model_type: tff typespec of quantity to be aggregated. The per-tensor
      squared distances are combined via `tf.nest.flatten`, so mapping,
      list/tuple, and single-tensor structures are all handled.
    num_communication_passes: number of communication rounds in the smoothed
      Weiszfeld algorithm (min. 1).
    tolerance: smoothing parameter of smoothed Weiszfeld algorithm. Default
      1e-6.

  Returns:
    An instance of `tff.utils.StatefulAggregateFn` which implements a
    (stateless) robust aggregate.

  Raises:
    ValueError: if `num_communication_passes` is less than 1.
  """
  py_typecheck.check_type(num_communication_passes, int)
  if num_communication_passes < 1:
    raise ValueError('Aggregation requires num_communication_passes >= 1')

  # client weights have been hardcoded as float32, this needs to be
  # parameterized.

  @tff.tf_computation(tf.float32, model_type, model_type)
  def update_weight_fn(weight, server_model, client_model):
    # Squared Euclidean distance, per tensor, between the current aggregate
    # and this client's value.
    sqnorms = tf.nest.map_structure(lambda a, b: tf.norm(a - b)**2,
                                    server_model, client_model)
    # `tf.nest.flatten` handles any nested structure. The previous
    # `six.itervalues` only worked for mappings and raised AttributeError
    # for list/tuple structures or a bare tensor.
    sqnorm = tf.reduce_sum(tf.nest.flatten(sqnorms))
    # Weiszfeld reweighting: w / max(tolerance, ||aggregate - client||).
    return weight / tf.math.maximum(tolerance, tf.math.sqrt(sqnorm))

  client_model_type = tff.FederatedType(model_type, tff.CLIENTS)
  client_weight_type = tff.FederatedType(tf.float32, tff.CLIENTS)

  @tff.federated_computation(client_model_type, client_weight_type)
  def robust_aggregation_fn(value, weight):
    # First pass: the plain weighted mean.
    aggregate = tff.federated_mean(value, weight=weight)
    # Each further pass broadcasts the current estimate, rescales each
    # client's original weight by its inverse distance to the estimate, and
    # recomputes the weighted mean.
    for _ in range(num_communication_passes - 1):
      aggregate_at_client = tff.federated_broadcast(aggregate)
      updated_weight = tff.federated_map(
          update_weight_fn, (weight, aggregate_at_client, value))
      aggregate = tff.federated_mean(value, weight=updated_weight)
    return aggregate

  def _stateless_next(state, value, weight):
    # Nothing is tracked across rounds: pass the state through unchanged.
    return state, robust_aggregation_fn(value, weight)

  return tff.utils.StatefulAggregateFn(
      initialize_fn=lambda: (), next_fn=_stateless_next)
def test_orchestration_typecheck(self):
  """Checks the `initialize`/`next` signatures of the federated SGD process."""
  process = federated_sgd.build_federated_sgd_process(
      model_fn=model_examples.LinearRegression)

  weights_type = model_utils.ModelWeights(
      collections.OrderedDict(
          a=tff.TensorType(tf.float32, [2, 1]), b=tf.float32),
      collections.OrderedDict(c=tf.float32))
  # The optimizer state inside ServerState is provided by TensorFlow; its
  # exact type does not matter here, hence AnyType.
  server_state = tff.FederatedType(
      optimizer_utils.ServerState(weights_type, test.AnyType()),
      placement=tff.SERVER,
      all_equal=True)
  batch = model_examples.LinearRegression.make_batch(
      tff.TensorType(tf.float32, [None, 2]),
      tff.TensorType(tf.float32, [None, 1]))
  client_datasets = tff.FederatedType(
      tff.SequenceType(batch), tff.CLIENTS, all_equal=False)
  train_outputs = tff.FederatedType(
      collections.OrderedDict(
          loss=tff.TensorType(tf.float32, []),
          num_examples=tff.TensorType(tf.int32, [])),
      tff.SERVER,
      all_equal=True)

  # initialize: () -> ServerState
  self.assertEqual(
      tff.FunctionType(parameter=None, result=server_state),
      process.initialize.type_signature)
  # next: (ServerState, client datasets) -> (ServerState, model outputs)
  self.assertEqual(
      tff.FunctionType(
          parameter=[server_state, client_datasets],
          result=(server_state, train_outputs)),
      process.next.type_signature)
def _wrap_in_measured_process(
    stateful_fn: Union[tff.utils.StatefulBroadcastFn,
                       tff.utils.StatefulAggregateFn],
    input_type: tff.Type) -> tff.templates.MeasuredProcess:
  """Adapts a legacy `tff.utils.StatefulFn` into a `MeasuredProcess`.

  Args:
    stateful_fn: The `tff.utils.StatefulBroadcastFn` or
      `tff.utils.StatefulAggregateFn` to wrap.
    input_type: The unplaced type of the value handed to `stateful_fn`. It is
      placed at `tff.SERVER` for a broadcast and at `tff.CLIENTS` for an
      aggregation (which also receives float32 client weights).

  Returns:
    A `tff.templates.MeasuredProcess` carrying `stateful_fn`'s state and
    always reporting an empty tuple as measurements.

  Raises:
    TypeError: If `stateful_fn` is not one of the two supported types.
  """
  py_typecheck.check_type(
      stateful_fn,
      (tff.utils.StatefulBroadcastFn, tff.utils.StatefulAggregateFn))

  @tff.federated_computation()
  def initialize_comp():
    init_fn = stateful_fn.initialize
    # Plain Python initializers are lifted into a TF computation first.
    if not isinstance(init_fn, tff.Computation):
      init_fn = tff.tf_computation(init_fn)
    return tff.federated_eval(init_fn, tff.SERVER)

  state_type = initialize_comp.type_signature.result
  is_broadcast = isinstance(stateful_fn, tff.utils.StatefulBroadcastFn)
  is_aggregate = isinstance(stateful_fn, tff.utils.StatefulAggregateFn)

  if not (is_broadcast or is_aggregate):
    raise TypeError(
        'Received a {t}, expected either a tff.utils.StatefulAggregateFn or a '
        'tff.utils.StatefulBroadcastFn.'.format(t=type(stateful_fn)))

  if is_broadcast:

    @tff.federated_computation(state_type,
                               tff.FederatedType(input_type, tff.SERVER))
    def next_comp(state, value):
      empty_metrics = tff.federated_value((), tff.SERVER)
      new_state, result = stateful_fn(state, value)
      return collections.OrderedDict(
          state=new_state, result=result, measurements=empty_metrics)

  else:

    @tff.federated_computation(state_type,
                               tff.FederatedType(input_type, tff.CLIENTS),
                               tff.FederatedType(tf.float32, tff.CLIENTS))
    def next_comp(state, value, weight):
      empty_metrics = tff.federated_value((), tff.SERVER)
      new_state, result = stateful_fn(state, value, weight)
      return collections.OrderedDict(
          state=new_state, result=result, measurements=empty_metrics)

  return tff.templates.MeasuredProcess(
      initialize_fn=initialize_comp, next_fn=next_comp)
def build_stateless_mean(
    *, model_delta_type: Union[tff.NamedTupleType, tff.TensorType]
) -> tff.templates.MeasuredProcess:
  """Builds a `MeasuredProcess` that wraps `tff.federated_mean`.

  `next` passes its `NONE_SERVER_TYPE` state through unchanged, returns the
  float32-weighted mean of the client values, and reports an empty tuple as
  measurements.
  """
  client_value_type = tff.FederatedType(model_delta_type, tff.CLIENTS)
  client_weight_type = tff.FederatedType(tf.float32, tff.CLIENTS)

  @tff.federated_computation(NONE_SERVER_TYPE, client_value_type,
                             client_weight_type)
  def stateless_mean(state, value, weight):
    no_measurements = tff.federated_value((), tff.SERVER)
    weighted_mean = tff.federated_mean(value, weight=weight)
    return collections.OrderedDict(
        state=state, result=weighted_mean, measurements=no_measurements)

  return tff.templates.MeasuredProcess(
      initialize_fn=_empty_server_initialization, next_fn=stateless_mean)
def test_construction(self):
  """Checks the exact type signatures of a default delta-optimizer process."""
  process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.LinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  weights = model_utils.ModelWeights(
      trainable=[
          tff.TensorType(tf.float32, [2, 1]),
          tff.TensorType(tf.float32),
      ],
      non_trainable=[tff.TensorType(tf.float32)])
  server_state_type = tff.FederatedType(
      optimizer_utils.ServerState(
          model=weights,
          optimizer_state=[tf.int64],
          delta_aggregate_state=(),
          model_broadcast_state=()),
      tff.SERVER)
  expected_initialize = tff.FunctionType(
      parameter=None, result=server_state_type)
  self.assertEqual(
      str(process.initialize.type_signature), str(expected_initialize))

  batch = collections.OrderedDict(
      x=tff.TensorType(tf.float32, [None, 2]),
      y=tff.TensorType(tf.float32, [None, 1]))
  dataset_type = tff.FederatedType(tff.SequenceType(batch), tff.CLIENTS)
  train_metrics = collections.OrderedDict(
      loss=tff.TensorType(tf.float32),
      num_examples=tff.TensorType(tf.int32))
  metrics_type = tff.FederatedType(
      collections.OrderedDict(
          broadcast=(), aggregation=(), train=train_metrics),
      tff.SERVER)
  expected_next = tff.FunctionType(
      parameter=(server_state_type, dataset_type),
      result=(server_state_type, metrics_type))
  self.assertEqual(str(process.next.type_signature), str(expected_next))
def _build_test_measured_mean(
    model_update_type: tff.NamedTupleType
) -> tff.templates.MeasuredProcess:
  """Builds a test `MeasuredProcess` that has state and metrics.

  The state is a server-placed value starting at 0 (typed `tf.int32` in
  `next`) that is mapped through `_add_one` on every call to `next`. The
  result is the weighted mean of the client updates, and the measurements
  report `num_clients`, the federated sum of a 1 from every client.
  """
  server_counter_type = tff.FederatedType(tf.int32, tff.SERVER)
  client_update_type = tff.FederatedType(model_update_type, tff.CLIENTS)
  client_weight_type = tff.FederatedType(tf.float32, tff.CLIENTS)

  @tff.federated_computation()
  def initialize_comp():
    return tff.federated_value(0, tff.SERVER)

  @tff.federated_computation(server_counter_type, client_update_type,
                             client_weight_type)
  def next_comp(state, value, weight):
    next_state = tff.federated_map(_add_one, state)
    mean_update = tff.federated_mean(value, weight)
    participants = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
    measurements = tff.federated_zip(
        collections.OrderedDict(num_clients=participants))
    return collections.OrderedDict(
        state=next_state, result=mean_update, measurements=measurements)

  return tff.templates.MeasuredProcess(
      initialize_fn=initialize_comp, next_fn=next_comp)
def _build_test_measured_broadcast(
    model_weights_type: tff.NamedTupleType
) -> tff.templates.MeasuredProcess:
  """Builds a test `MeasuredProcess` that has state and metrics.

  The state is a server-placed value starting at 0 (typed `tf.int32` in
  `next`) that is mapped through `_add_one` on every call to `next`. The
  result is the broadcast of the server value. The measurement is an
  arbitrary server-side value: the global norm of the weights plus 3.
  """
  server_counter_type = tff.FederatedType(tf.int32, tff.SERVER)
  server_weights_type = tff.FederatedType(model_weights_type, tff.SERVER)

  # Arbitrary metrics for testing.
  global_norm_plus_three = tff.tf_computation(
      lambda v: tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0)

  @tff.federated_computation()
  def initialize_comp():
    return tff.federated_value(0, tff.SERVER)

  @tff.federated_computation(server_counter_type, server_weights_type)
  def next_comp(state, value):
    next_state = tff.federated_map(_add_one, state)
    broadcast_value = tff.federated_broadcast(value)
    measurements = tff.federated_map(global_norm_plus_three, value)
    return collections.OrderedDict(
        state=next_state, result=broadcast_value, measurements=measurements)

  return tff.templates.MeasuredProcess(
      initialize_fn=initialize_comp, next_fn=next_comp)
def test_returns_iterproc_accepting_dataset_in_third_index_of_next(self):
  """Composition also works when the dataset is the third `next` argument."""
  base = _create_stateless_int_dataset_reduction_iterative_process()
  base_params = base.next.type_signature.parameter
  # Insert an extra int32 argument between the state and the dataset.
  widened_params = tff.StructType(
      [base_params[0], tf.int32, base_params[1]])

  @tff.federated_computation(widened_params)
  def new_next(param):
    # Drop the inserted int32 and forward (state, dataset) to the base `next`.
    return base.next([param[0], param[2]])

  widened = tff.templates.IterativeProcess(base.initialize, new_next)
  composed = iterative_process_compositions.compose_dataset_computation(
      int_dataset_computation, widened)

  server_int = tff.FederatedType(tf.int64, tff.SERVER)
  client_string = tff.FederatedType(tf.string, tff.CLIENTS)
  expected_next = tff.FunctionType([server_int, tf.int32, client_string],
                                   server_int)
  self.assertTrue(
      expected_next.is_equivalent_to(composed.next.type_signature))
def federated_output_computation(self):
  """Builds a federated computation that aggregates client metric variables.

  Returns:
    A `tff.federated_computation` whose argument is the client-placed value
    of `report_local_outputs()` and whose result is an `OrderedDict` mapping
    each metric's name to the aggregate produced by
    `federated_aggregate_keras_metric` for that metric.
  """
  local_output_specs = nest.map_structure(tf.TensorSpec.from_tensor,
                                          self.report_local_outputs())
  clients_outputs_type = tff.FederatedType(
      local_output_specs, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(clients_outputs_type)
  def federated_output(local_outputs):
    # Metrics and their variable groups are paired positionally.
    return collections.OrderedDict(
        (metric.name,
         federated_aggregate_keras_metric(
             type(metric), metric.get_config(), variables))
        for metric, variables in zip(self.get_metrics(), local_outputs))

  return federated_output
def build_stateless_broadcaster(
    *, model_weights_type: Union[tff.NamedTupleType, tff.TensorType]
) -> tff.templates.MeasuredProcess:
  """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`.

  `next` passes its `NONE_SERVER_TYPE` state through unchanged, broadcasts
  the server value to the clients, and reports an empty tuple as
  measurements.
  """
  server_weights_type = tff.FederatedType(model_weights_type, tff.SERVER)

  @tff.federated_computation(NONE_SERVER_TYPE, server_weights_type)
  def stateless_broadcast(state, value):
    no_measurements = tff.federated_value((), tff.SERVER)
    broadcast_value = tff.federated_broadcast(value)
    return collections.OrderedDict(
        state=state, result=broadcast_value, measurements=no_measurements)

  return tff.templates.MeasuredProcess(
      initialize_fn=_empty_server_initialization,
      next_fn=stateless_broadcast)
def __init__(self, keras_model: tf.keras.Model, input_spec,
             loss_fns: List[tf.keras.losses.Loss], loss_weights: List[float],
             metrics: List[tf.keras.metrics.Metric]):
  """Wraps a Keras model together with its losses and metrics.

  Args:
    keras_model: The `tf.keras.Model` being wrapped.
    input_spec: Specification of the input batches; stored as given.
    loss_fns: Loss callables. A single loss is applied to the whole model
      output; otherwise loss `i` is applied to output `i`.
    loss_weights: Per-loss multipliers, only used when there are several
      losses.
    metrics: Additional Keras metrics; stored as given.
  """
  self._keras_model = keras_model
  self._input_spec = input_spec
  self._loss_fns = loss_fns
  self._loss_weights = loss_weights
  self._metrics = metrics

  # Defined here so that it closes over `loss_fns` and `loss_weights`.
  class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super().__init__(name, dtype)
      self._loss_fns = loss_fns
      self._loss_weights = loss_weights

    def update_state(self, y_true, y_pred, sample_weight=None):
      # `sample_weight` is ignored; the batch size is used as the weight of
      # the running mean instead.
      if len(self._loss_fns) != 1:
        # Multi-output: weighted sum of the per-output losses, with the
        # batch size read from the first output.
        batch_size = tf.shape(y_pred[0])[0]
        batch_loss = tf.zeros(())
        for idx, loss_fn in enumerate(self._loss_fns):
          batch_loss += self._loss_weights[idx] * loss_fn(
              y_true[idx], y_pred[idx])
      else:
        batch_size = tf.shape(y_pred)[0]
        batch_loss = self._loss_fns[0](y_true, y_pred)
      return super().update_state(batch_loss, batch_size)

  self._loss_metric = _WeightedMeanLossMetric()

  local_output_specs = tf.nest.map_structure(tf.TensorSpec.from_tensor,
                                             self.report_local_outputs())
  clients_outputs_type = tff.FederatedType(local_output_specs, tff.CLIENTS)

  def federated_output(local_outputs):
    return federated_aggregate_keras_metric(self.get_metrics(), local_outputs)

  self._federated_output_computation = tff.federated_computation(
      federated_output, clients_outputs_type)
def _create_stateless_int_dataset_reduction_iterative_process():
  """Builds an iterative process that sums every client's int64 dataset.

  `initialize` produces a server-placed int64 zero. `next` ignores its state
  argument and returns the federated sum of every client's dataset total.
  """

  @tff.tf_computation()
  def make_zero():
    return tf.cast(0, tf.int64)

  @tff.federated_computation()
  def init():
    return tff.federated_eval(make_zero, tff.SERVER)

  def _add(total, element):
    return total + element

  @tff.tf_computation(tff.SequenceType(tf.int64))
  def reduce_dataset(x):
    return x.reduce(tf.cast(0, tf.int64), _add)

  server_state_type = init.type_signature.result
  client_data_type = tff.FederatedType(
      tff.SequenceType(tf.int64), tff.CLIENTS)

  @tff.federated_computation((server_state_type, client_data_type))
  def next_fn(empty_tup, x):
    del empty_tup  # Unused; the process keeps no meaningful state.
    per_client_totals = tff.federated_map(reduce_dataset, x)
    return tff.federated_sum(per_client_totals)

  return tff.templates.IterativeProcess(initialize_fn=init, next_fn=next_fn)
def __init__(self,
             inner_model,
             dummy_batch,
             loss_fns,
             loss_weights=None,
             metrics=None):
  """Wraps a (possibly multi-output) Keras model for use with TFF.

  Args:
    inner_model: The `tf.keras.Model` to wrap. It is called once on the input
      part of `dummy_batch` so that sub-classed models create their variables.
    dummy_batch: An example batch. If it is a mapping, the model input is read
      from key `'x'`; otherwise from index `0`. Its structure, with the
      leading (batch) dimension made unknown, becomes the input spec.
    loss_fns: Loss callables. A single loss is applied to the whole output;
      otherwise loss `i` is applied to output `i`.
    loss_weights: Optional per-loss multipliers. Either a mapping keyed by
      `inner_model.output_names` (reordered into output order) or a sequence
      aligned with `loss_fns`. Defaults to `1.0` for every loss.
    metrics: Optional list of additional Keras metrics.

  Raises:
    KeyError: If `loss_weights` is a mapping without an entry for one of
      `inner_model.output_names`.
  """
  # `collections.Mapping` was removed in Python 3.10; the abstract base
  # classes live in `collections.abc`.
  import collections.abc  # pylint: disable=g-import-not-at-top

  # NOTE: sub-classed `tf.keras.Model`s do not have fully initialized
  # variables until they are called on input. We force that here.
  if isinstance(dummy_batch, collections.abc.Mapping):
    inner_model(dummy_batch['x'])
  else:
    inner_model(dummy_batch[0])

  def _tensor_spec_with_undefined_batch_dim(tensor):
    # Remove the batch dimension and leave it unspecified.
    return tf.TensorSpec(
        shape=[None] + tensor.shape.dims[1:], dtype=tensor.dtype)

  self._input_spec = tf.nest.map_structure(
      _tensor_spec_with_undefined_batch_dim, dummy_batch)

  self._keras_model = inner_model
  self._loss_fns = loss_fns

  if isinstance(loss_weights, collections.abc.Mapping):
    # Reorder the named weights to follow the model's output order.
    self._loss_weights = []
    for name in inner_model.output_names:
      if name not in loss_weights:
        raise KeyError(
            'Output missing from loss_weights dictionary'
            '\nloss_weights: {}\noutputs: {}'.format(
                list(loss_weights.keys()), inner_model.output_names))
      self._loss_weights.append(loss_weights[name])
  elif loss_weights is None:
    self._loss_weights = [1.0] * len(loss_fns)
  else:
    self._loss_weights = loss_weights

  # Rebind so the metric class below closes over the normalized weights.
  loss_weights = self._loss_weights
  self._metrics = metrics if metrics is not None else []

  # This is defined here so that it closes over the `loss_fn`.
  class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super(_WeightedMeanLossMetric, self).__init__(name, dtype)
      self._loss_fns = loss_fns
      self._loss_weights = loss_weights

    def update_state(self, y_true, y_pred, sample_weight=None):
      # Labels and predictions are cast to the metric dtype before the loss
      # is computed; the batch size is the weight of the running mean.
      if len(self._loss_fns) == 1:
        batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        batch_loss = self._loss_fns[0](y_true, y_pred)
      else:
        batch_loss = tf.zeros(())
        for i in range(len(self._loss_fns)):
          y_t = tf.cast(y_true[i], self._dtype)
          y_p = tf.cast(y_pred[i], self._dtype)
          batch_loss += self._loss_weights[i] * self._loss_fns[i](y_t, y_p)
        batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

      return super(_WeightedMeanLossMetric,
                   self).update_state(batch_loss, batch_size)

  class _TrainingTimeHistory(tf.keras.metrics.Sum):
    """Sum metric fed only through `log_time`, not through `update_state`."""

    def update_state(self, y_true, y_pred, sample_weight=None):
      # Deliberately a no-op: regular (y_true, y_pred) updates are ignored.
      pass

    def log_time(self, time_value):
      return super(_TrainingTimeHistory, self).update_state(values=time_value)

  self._loss_metric = _WeightedMeanLossMetric()
  self._training_timing = _TrainingTimeHistory(name='training_time_sec')

  metric_variable_type_dict = tf.nest.map_structure(
      tf.TensorSpec.from_tensor, self.report_local_outputs())
  federated_local_outputs_type = tff.FederatedType(metric_variable_type_dict,
                                                   tff.CLIENTS)

  def federated_output(local_outputs):
    results = collections.OrderedDict()
    for metric, variables in zip(self.get_metrics(), local_outputs):
      results[metric.name] = federated_aggregate_keras_metric(
          type(metric), metric.get_config(), variables)
    return results

  self._federated_output_computation = tff.federated_computation(
      federated_output, federated_local_outputs_type)

  # Keras creates variables that are not added to any collection, making it
  # impossible for TFF to extract them and create the appropriate initializer
  # before call a tff.Computation. Here we store them in a TFF specific
  # collection so that they can be retrieved later.
  # TODO(b/122081673): this likely goes away in TF2.0
  for variable in itertools.chain(self.trainable_variables,
                                  self.non_trainable_variables,
                                  self.local_variables):
    tf.compat.v1.add_to_collection(
        graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE, variable)
Args: process: A measured process to validate. Returns: `True` iff the process is a validate aggregation process, otherwise `False`. """ next_type = process.next.type_signature return (isinstance(process, tff.templates.MeasuredProcess) and _is_valid_stateful_process(process) and next_type.parameter[1].placement is tff.CLIENTS and next_type.result.result.placement is tff.SERVER) # ============================================================================ NONE_SERVER_TYPE = tff.FederatedType((), tff.SERVER) def _wrap_in_measured_process( stateful_fn: Union[tff.utils.StatefulBroadcastFn, tff.utils.StatefulAggregateFn], input_type: tff.Type) -> tff.templates.MeasuredProcess: """Converts a `tff.utils.StatefulFn` to a `tff.templates.MeasuredProcess`.""" py_typecheck.check_type( stateful_fn, (tff.utils.StatefulBroadcastFn, tff.utils.StatefulAggregateFn)) @tff.federated_computation() def initialize_comp(): if not isinstance(stateful_fn.initialize, tff.Computation): initialize = tff.tf_computation(stateful_fn.initialize)
def benchmark_fc_api_mnist(self):
  """Code adapted from FC API tutorial ipynb.

  Reports two benchmarks: the wall time to construct the federated
  computations, and the mean (with std-dev as an extra) per-round wall time
  of `n_rounds` rounds of federated training on `generate_fake_mnist_data()`.
  """
  n_rounds = 10

  # One batch: flattened 784-pixel images with int32 labels.
  batch_type = tff.NamedTupleType([
      ("x", tff.TensorType(tf.float32, [None, 784])),
      ("y", tff.TensorType(tf.int32, [None]))
  ])

  # A single dense softmax layer mapping 784 inputs to 10 classes.
  model_type = tff.NamedTupleType([
      ("weights", tff.TensorType(tf.float32, [784, 10])),
      ("bias", tff.TensorType(tf.float32, [10]))
  ])

  local_data_type = tff.SequenceType(batch_type)

  server_model_type = tff.FederatedType(model_type, tff.SERVER, all_equal=True)
  client_data_type = tff.FederatedType(local_data_type, tff.CLIENTS)

  server_float_type = tff.FederatedType(tf.float32, tff.SERVER, all_equal=True)

  # Everything from here to `computation_building_stop` counts as
  # computation-building time.
  computation_building_start = time.time()

  # NOTE(review): `tf.log`, `reduction_indices` and
  # `tf.train.GradientDescentOptimizer` below are TF 1.x APIs (TF 2 only has
  # them under `tf.compat.v1`).

  # pylint: disable=missing-docstring
  @tff.tf_computation(model_type, batch_type)
  def batch_loss(model, batch):
    # Mean cross-entropy between softmax predictions and one-hot labels.
    # NOTE(review): taking `tf.log` of a softmax can yield -inf/NaN if a
    # probability underflows to 0.
    predicted_y = tf.nn.softmax(
        tf.matmul(batch.x, model.weights) + model.bias)
    return -tf.reduce_mean(
        tf.reduce_sum(tf.one_hot(batch.y, 10) * tf.log(predicted_y),
                      reduction_indices=[1]))

  initial_model = {
      "weights": np.zeros([784, 10], dtype=np.float32),
      "bias": np.zeros([10], dtype=np.float32)
  }

  @tff.tf_computation(model_type, batch_type, tf.float32)
  def batch_train(initial_model, batch, learning_rate):
    # One SGD step on a single batch, starting from `initial_model`.
    model_vars = tff.utils.get_variables("v", model_type)
    init_model = tff.utils.assign(model_vars, initial_model)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    # Order: assign the incoming weights, then minimize, then read back.
    with tf.control_dependencies([init_model]):
      train_model = optimizer.minimize(batch_loss(model_vars, batch))
    with tf.control_dependencies([train_model]):
      return tff.utils.identity(model_vars)

  @tff.federated_computation(model_type, tf.float32, local_data_type)
  def local_train(initial_model, learning_rate, all_batches):
    # Folds `batch_train` over the client's batches, threading the model
    # through as the accumulator.

    @tff.federated_computation(model_type, batch_type)
    def batch_fn(model, batch):
      return batch_train(model, batch, learning_rate)

    return tff.sequence_reduce(all_batches, initial_model, batch_fn)

  @tff.federated_computation(server_model_type, server_float_type,
                             client_data_type)
  def federated_train(model, learning_rate, data):
    # Broadcast model and learning rate, train locally on every client, and
    # average the resulting models back on the server.
    return tff.federated_average(
        tff.federated_map(local_train, [
            tff.federated_broadcast(model),
            tff.federated_broadcast(learning_rate), data
        ]))

  computation_building_stop = time.time()
  building_time = computation_building_stop - computation_building_start
  self.report_benchmark(
      name="computation_building_time, FC API",
      wall_time=building_time,
      iters=1)

  model = initial_model
  learning_rate = 0.1

  federated_data = generate_fake_mnist_data()

  # Time each round separately so both the mean and spread can be reported.
  execution_array = []
  for _ in range(n_rounds):
    execution_start = time.time()
    model = federated_train(model, learning_rate, federated_data)
    execution_stop = time.time()
    execution_array.append(execution_stop - execution_start)

  self.report_benchmark(
      name="Average per round execution time, FC API",
      wall_time=np.mean(execution_array),
      iters=n_rounds,
      extras={"std_dev": np.std(execution_array)})
def __init__(self, inner_model, dummy_batch, loss_fn, metrics):
  """Wraps a single-output Keras model for use with TFF.

  Args:
    inner_model: The `tf.keras.Model` to wrap. It is called once on
      `dummy_batch['x']` so that sub-classed models create their variables.
    dummy_batch: A mapping example batch whose `'x'` entry is the model
      input. Its structure, with the leading (batch) dimension made unknown,
      becomes the input spec.
    loss_fn: Loss callable applied as `loss_fn(y_true, y_pred)`.
    metrics: A list of additional Keras metrics, or `None` for none.
  """
  # NOTE: sub-classed `tf.keras.Model`s do not have fully initialized
  # variables until they are called on input. We force that here.
  inner_model(dummy_batch['x'])

  def _tensor_spec_with_undefined_batch_dim(tensor):
    # Remove the batch dimension and leave it unspecified.
    spec = tf.TensorSpec(
        shape=[None] + tensor.shape.dims[1:], dtype=tensor.dtype)
    return spec

  self._input_spec = tf.nest.map_structure(
      _tensor_spec_with_undefined_batch_dim, dummy_batch)

  self._keras_model = inner_model
  self._loss_fn = loss_fn
  self._metrics = metrics if metrics is not None else []

  # This is defined here so that it closes over the `loss_fn`.
  class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super(_WeightedMeanLossMetric, self).__init__(name, dtype)
      self._loss_fn = loss_fn

    def update_state(self, y_true, y_pred, sample_weight=None):
      # Cast to the metric dtype; the batch size weights the running mean.
      y_true = tf.cast(y_true, self._dtype)
      y_pred = tf.cast(y_pred, self._dtype)
      batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
      batch_loss = self._loss_fn(y_true, y_pred)
      return super(_WeightedMeanLossMetric,
                   self).update_state(batch_loss, batch_size)

  self._loss_metric = _WeightedMeanLossMetric()

  metric_variable_type_dict = tf.nest.map_structure(
      tf.TensorSpec.from_tensor, self.report_local_outputs())
  federated_local_outputs_type = tff.FederatedType(metric_variable_type_dict,
                                                   tff.CLIENTS)

  def federated_output(local_outputs):
    results = collections.OrderedDict()
    for metric, variables in zip(self.get_metrics(), local_outputs):
      results[metric.name] = federated_aggregate_keras_metric(
          type(metric), metric.get_config(), variables)
    return results

  self._federated_output_computation = tff.federated_computation(
      federated_output, federated_local_outputs_type)

  # Keras creates variables that are not added to any collection, making it
  # impossible for TFF to extract them and create the appropriate initializer
  # before call a tff.Computation. Here we store them in a TFF specific
  # collection so that they can be retrieved later.
  # TODO(b/122081673): this likely goes away in TF2.0
  # `tf.add_to_collection` does not exist in TF 2.x; the `tf.compat.v1`
  # spelling works in both and matches the rest of this file.
  for variable in itertools.chain(self.trainable_variables,
                                  self.non_trainable_variables,
                                  self.local_variables):
    tf.compat.v1.add_to_collection(
        graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE, variable)
def __init__(self,
             inner_model,
             input_spec,
             loss_fns,
             loss_weights=None,
             metrics=None):
  """Wraps a (possibly multi-output) Keras model for use with TFF.

  Args:
    inner_model: The `tf.keras.Model` to wrap.
    input_spec: Specification of the input batches; stored as given.
    loss_fns: Non-empty list of loss callables. A single loss requires a
      single-tensor model output; several losses require a non-tensor output,
      and if that output is a list its length must equal `len(loss_fns)`.
    loss_weights: Optional sequence of per-loss multipliers, same length as
      `loss_fns`. Defaults to `1.0` for every loss.
    metrics: Optional list of additional Keras metrics.

  Raises:
    ValueError: If `loss_fns` is empty, does not match the model outputs, or
      `loss_weights` has a different length than `loss_fns`.
  """
  # `collections.Sequence` was removed in Python 3.10; the abstract base
  # classes live in `collections.abc`.
  import collections.abc  # pylint: disable=g-import-not-at-top

  self._input_spec = input_spec

  if not loss_fns:
    raise ValueError(
        'Must specify at least one loss_fns, got: {l}'.format(l=loss_fns))
  # One loss <=> one tensor output; a list output must have one entry per
  # loss.
  if ((len(loss_fns) == 1) != tf.is_tensor(inner_model.output) or
      (isinstance(inner_model.output, list) and
       len(loss_fns) != len(inner_model.output))):
    raise ValueError(
        'Must specify the same number of loss_fns as model '
        'outputs.\nloss_fns: {l}\nmodel outputs: {o}'.format(
            l=loss_fns, o=inner_model.output))
  self._loss_fns = loss_fns

  if loss_weights is None:
    loss_weights = [1.0] * len(loss_fns)
  else:
    py_typecheck.check_type(loss_weights, collections.abc.Sequence)
    if len(loss_weights) != len(loss_fns):
      raise ValueError(
          'Must specify the same number of '
          'loss_weights (got {llw}) as loss_fns (got {llf}).\n'
          'loss_weights: {lw}\nloss_fns: {lf}'.format(
              lw=loss_weights,
              llw=len(loss_weights),
              lf=loss_fns,
              llf=len(loss_fns)))
  self._loss_weights = loss_weights
  self._keras_model = inner_model
  self._metrics = metrics if metrics is not None else []

  # This is defined here so that it closes over the `loss_fn`.
  class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super().__init__(name, dtype)
      self._loss_fns = loss_fns
      self._loss_weights = loss_weights

    def update_state(self, y_true, y_pred, sample_weight=None):
      # Labels and predictions are cast to the metric dtype before the loss
      # is computed; the batch size is the weight of the running mean.
      if len(self._loss_fns) == 1:
        batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        batch_loss = self._loss_fns[0](y_true, y_pred)
      else:
        batch_loss = tf.zeros(())
        for i in range(len(self._loss_fns)):
          y_t = tf.cast(y_true[i], self._dtype)
          y_p = tf.cast(y_pred[i], self._dtype)
          batch_loss += self._loss_weights[i] * self._loss_fns[i](y_t, y_p)
        batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

      return super().update_state(batch_loss, batch_size)

  self._loss_metric = _WeightedMeanLossMetric()

  metric_variable_type_dict = tf.nest.map_structure(
      tf.TensorSpec.from_tensor, self.report_local_outputs())
  federated_local_outputs_type = tff.FederatedType(metric_variable_type_dict,
                                                   tff.CLIENTS)

  def federated_output(local_outputs):
    results = collections.OrderedDict()
    for metric, variables in zip(self.get_metrics(), local_outputs):
      results[metric.name] = federated_aggregate_keras_metric(
          type(metric), metric.get_config(), variables)
    return results

  self._federated_output_computation = tff.federated_computation(
      federated_output, federated_local_outputs_type)

  # Keras creates variables that are not added to any collection, making it
  # impossible for TFF to extract them and create the appropriate initializer
  # before call a tff.Computation. Here we store them in a TFF specific
  # collection so that they can be retrieved later.
  # TODO(b/122081673): this likely goes away in TF2.0
  for variable in itertools.chain(self.trainable_variables,
                                  self.non_trainable_variables,
                                  self.local_variables):
    tf.compat.v1.add_to_collection(
        graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE, variable)
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from SERVER to CLIENTS.
  Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the SERVER.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the provided
      model weights when users invoke the returned TFF computation), a
      training `tf.data.Dataset`, a test `tf.data.Dataset`, and an arbitrary
      context object (which is used to hold any extra information that a
      personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are usually
      represented as an `OrderedDict` (or a nested `OrderedDict`) of `string`
      metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when
      users invoke the returned TFF computation), and a `tf.data.Dataset`,
      evaluates the model on the dataset, and returns the evaluation metrics.
      The evaluation metrics are usually represented as an `OrderedDict` (or a
      nested `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s.
      This function is *only* used to compute the baseline metrics of the
      initial model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in
      a round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the
      training dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` that maps
    < model_weights@SERVER, input@CLIENTS > -> personalization_metrics@SERVER,
    where:
    - model_weights is a `tff.learning.framework.ModelWeights`.
    - each client's input is an `OrderedDict` of at least two keys `train_data`
      and `test_data`, and each key is mapped to a `tf.data.Dataset`. If
      context is used in `personalize_fn_dict`, then client input has a third
      key `context` that is mapped to an object whose `tff.Type` is provided by
      the `context_tff_type` argument.
    - personalization_metrics is an `OrderedDict` that maps a key
      'baseline_metrics' to the evaluation metrics of the initial model
      (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
      `personalize_fn_dict` to the evaluation metrics of the corresponding
      personalization strategies.
    - Note: only metrics from at most `max_num_samples` participating clients
      are collected to the SERVER. All collected metrics are stored in a
      single `OrderedDict` (the personalization_metrics shown above), where
      each metric is mapped to a list of scalars (each scalar comes from one
      client). Metric values at the same position, e.g., metric_1[i],
      metric_2[i]..., all come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in `personalize_fn_dict`.
    ValueError: If `max_num_samples` is not positive.
  """
  # Validate the cheap arguments first, so bad inputs are rejected before a
  # throwaway model is constructed and the client computation is traced.
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_type(max_num_samples, int)
  if max_num_samples <= 0:
    raise ValueError('max_num_samples must be a positive integer.')
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, tff.Type)

  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = tff.to_type(model.input_spec)

  # Define the `tff.Type` of each client's input.
  client_input_type = collections.OrderedDict([
      ('train_data', tff.SequenceType(batch_type)),
      ('test_data', tff.SequenceType(batch_type))
  ])
  if context_tff_type is not None:
    client_input_type['context'] = context_tff_type
  client_input_type = tff.to_type(client_input_type)

  @tff.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    model = model_fn()
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    context = client_input.get('context', None)
    return _client_fn(model, initial_model_weights, train_data, test_data,
                      personalize_fn_dict, baseline_evaluate_fn, context)

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(client_input_type, tff.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = tff.federated_broadcast(server_model_weights)
    client_final_metrics = tff.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
    return results

  return personalization_eval
def __init__(self, inner_model, dummy_batch, loss_fn, metrics):
  """Wraps a Keras model for use with TFF (graph-mode variant).

  Args:
    inner_model: the Keras model being wrapped. It is called once on
      `dummy_batch['x']` to force variable creation.
    dummy_batch: a mapping of example inputs (a namedtuple is converted via
      `_asdict()`); must contain an `'x'` entry. Its values are converted to
      tensors and used to derive `input_spec` with an unspecified batch
      dimension.
    loss_fn: a callable `(y_true, y_pred)`. The loss metric below expects it to
      return the scalar mean loss over the batch.
    metrics: optional list of Keras metrics; `None` is treated as empty.
  """

  # TODO(b/124477598): the following set_session() should be removed in the
  # future. This is a workaround for Keras' caching sessions in a way that
  # isn't compatible with TFF. This is already fixed in TF master, but not as
  # of v1.13.1.
  #
  # We do not use .clear_session() because it blows away the graph stack by
  # resetting the default graph.
  tf.keras.backend.set_session(None)

  # Accept namedtuples by converting them to a dict first.
  if hasattr(dummy_batch, '_asdict'):
    dummy_batch = dummy_batch._asdict()
  # Convert input to tensors, possibly from nested lists that need to be
  # converted to a single top-level tensor.
  dummy_tensors = collections.OrderedDict([
      (k, tf.convert_to_tensor_or_sparse_tensor(v))
      for k, v in six.iteritems(dummy_batch)
  ])
  # NOTE: sub-classed `tf.keras.Model`s do not have fully initialized
  # variables until they are called on input. We forced that here.
  inner_model(dummy_tensors['x'])

  def _tensor_spec_with_undefined_batch_dim(tensor):
    # Remove the batch dimension and leave it unspecified.
    spec = tf.TensorSpec(
        shape=[None] + tensor.shape.dims[1:], dtype=tensor.dtype)
    return spec

  self._input_spec = nest.map_structure(_tensor_spec_with_undefined_batch_dim,
                                        dummy_tensors)

  self._keras_model = inner_model
  self._loss_fn = loss_fn
  self._metrics = metrics if metrics is not None else []

  # This is defined here so that it closes over the `loss_fn`.
  class _WeightedMeanLossMetric(keras_metrics.Metric):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super(_WeightedMeanLossMetric, self).__init__(name, dtype)
      # Running sum of (mean batch loss * batch size) and of batch sizes; the
      # result is their ratio, i.e. an example-weighted mean loss.
      self._total_loss = self.add_weight('total_loss', initializer='zeros')
      self._total_weight = self.add_weight(
          'total_weight', initializer='zeros')
      self._loss_fn = loss_fn

    def update_state(self, y_true, y_pred, sample_weight=None):
      # NOTE: `sample_weight` is accepted for API compatibility but unused.
      y_true = tf.cast(y_true, self._dtype)
      y_pred = tf.cast(y_pred, self._dtype)

      # _loss_fn is expected to return the scalar mean loss, so we multiply by
      # the batch_size to get back to total loss.
      batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
      batch_total_loss = self._loss_fn(y_true, y_pred) * batch_size

      # The control dependency orders the two updates in graph mode: the
      # weight is incremented only after the loss total has been updated.
      op = self._total_loss.assign_add(batch_total_loss)
      with tf.control_dependencies([op]):
        return self._total_weight.assign_add(batch_size)

    def result(self):
      # div_no_nan yields 0 when no examples have been seen yet.
      return tf.div_no_nan(self._total_loss, self._total_weight)

  self._loss_metric = _WeightedMeanLossMetric()

  metric_variable_type_dict = nest.map_structure(tf.TensorSpec.from_tensor,
                                                 self.report_local_outputs())
  federated_local_outputs_type = tff.FederatedType(
      metric_variable_type_dict, tff.CLIENTS, all_equal=False)

  def federated_output(local_outputs):
    # Aggregate each metric's variables across clients, keyed by metric name.
    results = collections.OrderedDict()
    for metric, variables in zip(self.get_metrics(), local_outputs):
      results[metric.name] = federated_aggregate_keras_metric(
          type(metric), metric.get_config(), variables)
    return results

  self._federated_output_computation = tff.federated_computation(
      federated_output, federated_local_outputs_type)

  # Keras creates variables that are not added to any collection, making it
  # impossible for TFF to extract them and create the appropriate initializer
  # before calling a tff.Computation. Here we store them in a TFF specific
  # collection so that they can be retrieved later.
  # TODO(b/122081673): this likely goes away in TF2.0
  for variable in itertools.chain(self.trainable_variables,
                                  self.non_trainable_variables,
                                  self.local_variables):
    tf.add_to_collection(graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE,
                         variable)
Args: process: A measured process to validate. Returns: `True` iff the process is a validate aggregation process, otherwise `False`. """ next_type = process.next.type_signature return (isinstance(process, tff.templates.MeasuredProcess) and _is_valid_stateful_process(process) and next_type.parameter[1].placement is tff.CLIENTS and next_type.result.result.placement is tff.SERVER) # ============================================================================ NONE_SERVER_TYPE = tff.FederatedType(tff.NamedTupleType([]), tff.SERVER) def _wrap_in_measured_process( stateful_fn: Union[tff.utils.StatefulBroadcastFn, tff.utils.StatefulAggregateFn], input_type: tff.Type) -> tff.templates.MeasuredProcess: """Converts a `tff.utils.StatefulFn` to a `tff.templates.MeasuredProcess`.""" py_typecheck.check_type( stateful_fn, (tff.utils.StatefulBroadcastFn, tff.utils.StatefulAggregateFn)) @tff.federated_computation() def initialize_comp(): if not isinstance(stateful_fn.initialize, tff.Computation): initialize = tff.tf_computation(stateful_fn.initialize)
def build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn,
                                        server_optimizer_fn):
  """Constructs `tff.utils.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  ClientDeltaFn that specifies how weight_deltas are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a model_fn to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client
      updates to the server model.

  Returns:
    A `tff.utils.IterativeProcess` whose `initialize` builds the server state
    and whose `next` runs one round: broadcast the model, compute client
    deltas, average them (weighted by `weights_delta_weight`), apply the
    average on the server, and aggregate the model's local outputs.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(model_to_client_delta_fn)
  py_typecheck.check_callable(server_optimizer_fn)

  # TODO(b/122081673): would be nice not to have to construct a throwaway model
  # here just to get the types. After fully moving to TF2.0 and eager-mode, we
  # should re-evaluate what happens here and where `g` is used below.
  with tf.Graph().as_default() as g:
    dummy_model_for_metadata = model_utils.enhance(model_fn())

  @tff.federated_computation
  def server_init_tff():
    """Orchestration logic for server model initialization."""
    no_arg_server_init_fn = lambda: server_init(model_fn, server_optimizer_fn)
    server_init_tf = tff.tf_computation(no_arg_server_init_fn)
    return tff.federated_value(server_init_tf(), tff.SERVER)

  # The server state type is inferred from the initialization computation.
  federated_server_state_type = server_init_tff.type_signature.result
  server_state_type = federated_server_state_type.member

  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round_tff(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    model_weights_type = federated_server_state_type.member.model

    @tff.tf_computation(tf_dataset_type, model_weights_type)
    def client_delta_tf(tf_dataset, initial_model_weights):
      """Performs client local model optimization.

      Args:
        tf_dataset: a `tf.data.Dataset` that provides training examples.
        initial_model_weights: a `model_utils.ModelWeights` containing the
          starting weights.

      Returns:
        A `ClientOutput` structure.
      """
      client_delta_fn = model_to_client_delta_fn(model_fn)

      # TODO(b/123092620): this can be removed once AnonymousTuple works with
      # tf.contrib.framework.nest, or the following behavior is moved to
      # anonymous_tuple module.
      if isinstance(initial_model_weights, anonymous_tuple.AnonymousTuple):
        initial_model_weights = model_utils.ModelWeights.from_tff_value(
            initial_model_weights)

      client_output = client_delta_fn(tf_dataset, initial_model_weights)
      return client_output

    client_outputs = tff.federated_map(
        client_delta_tf,
        (federated_dataset, tff.federated_broadcast(server_state.model)))

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_model_tf(server_state, model_delta):
      """Converts args to correct python types and calls server_update_model."""
      # We need to convert TFF types to the types server_update_model expects.
      # TODO(b/123092620): Mixing AnonymousTuple with other nested types is not
      # pretty, fold this into anonymous_tuple module or get working with
      # tf.contrib.framework.nest.
      py_typecheck.check_type(model_delta, anonymous_tuple.AnonymousTuple)
      model_delta = anonymous_tuple.to_odict(model_delta)
      py_typecheck.check_type(server_state, anonymous_tuple.AnonymousTuple)
      server_state = ServerState(
          model=model_utils.ModelWeights.from_tff_value(server_state.model),
          optimizer_state=list(server_state.optimizer_state))

      return server_update_model(
          server_state,
          model_delta,
          model_fn=model_fn,
          optimizer_fn=server_optimizer_fn)

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    fed_weight_type = client_outputs.weights_delta_weight.type_signature.member
    py_typecheck.check_type(fed_weight_type, tff.TensorType)
    # Integer weights are cast to float32; non-integer weights are used as is.
    if fed_weight_type.dtype.is_integer:

      @tff.tf_computation(fed_weight_type)
      def _cast_to_float(x):
        return tf.cast(x, tf.float32)

      weight_denom = tff.federated_map(_cast_to_float,
                                       client_outputs.weights_delta_weight)
    else:
      weight_denom = client_outputs.weights_delta_weight
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    # TODO(b/123408447): remove tff.federated_apply and call
    # server_update_model_tf directly once T <-> T@SERVER isomorphism is
    # supported.
    server_state = tff.federated_apply(server_update_model_tf,
                                       (server_state, round_model_delta))

    # Re-use graph used to construct `model`, since it has the variables, which
    # need to be read in federated_output_computation to get the correct shapes
    # and types for the federated aggregation.
    with g.as_default():
      aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
          client_outputs.model_output)

    # Promote the FederatedType outside the NamedTupleType
    aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  return tff.utils.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round_tff)
def _build_one_round_computation(
    *,
    model_fn: _ModelConstructor,
    server_optimizer_fn: _OptimizerConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    broadcast_process: tff.templates.MeasuredProcess,
    aggregation_process: tff.templates.MeasuredProcess,
) -> tff.Computation:
  """Builds the `next` computation for a model delta averaging process.

  Args:
    model_fn: a no-argument callable that constructs and returns a
      `tff.learning.Model`. *Must* construct and return a new model when
      called. Returning captured models from other scopes will raise errors.
    server_optimizer_fn: a no-argument callable that constructs and returns a
      `tf.keras.optimizers.Optimizer`. *Must* construct and return a new
      optimizer when called. Returning captured optimizers from other scopes
      will raise errors.
    model_to_client_delta_fn: a callable that takes a single no-arg callable
      that returns `tff.learning.Model` as an argument and returns a
      `ClientDeltaFn` which performs the local training loop and model delta
      computation.
    broadcast_process: a `tff.templates.MeasuredProcess` to broadcast the
      global model to the clients.
    aggregation_process: a `tff.templates.MeasuredProcess` to aggregate client
      model deltas.

  Returns:
    A `tff.Computation` that runs one round of the process. The computation
    takes a tuple of `(ServerState@SERVER, tf.data.Dataset@CLIENTS)` argument,
    and returns a tuple of `(ServerState@SERVER, metrics@SERVER)`, where
    `metrics` has the keys `broadcast`, `aggregation` and `train`.
  """
  # TODO(b/124477628): would be nice not to have to construct a throwaway model
  # here just to get the types. After fully moving to TF2.0 and eager-mode, we
  # should re-evaluate what happens here.
  # TODO(b/144382142): Keras name uniquification is probably the main reason we
  # still need this.
  with tf.Graph().as_default():
    dummy_model_for_metadata = model_fn()
    model_weights_type = tff.framework.type_from_tensors(
        model_utils.ModelWeights.from_model(dummy_model_for_metadata))

    dummy_optimizer = server_optimizer_fn()
    # We must force variable creation for momentum and adaptive optimizers.
    _eagerly_create_optimizer_variables(
        model=dummy_model_for_metadata, optimizer=dummy_optimizer)
    optimizer_variable_type = tff.framework.type_from_tensors(
        dummy_optimizer.variables())

  @tff.tf_computation(model_weights_type, model_weights_type.trainable,
                      optimizer_variable_type)
  def server_update(global_model, model_delta, optimizer_state):
    """Converts args to correct python types and calls server_update_model."""
    # Construct variables first.
    model = model_fn()
    optimizer = server_optimizer_fn()
    # We must force variable creation for momentum and adaptive optimizers.
    _eagerly_create_optimizer_variables(model=model, optimizer=optimizer)

    @tf.function
    def update_model_inner(weights_delta):
      """Applies the update to the global model."""
      model_variables = model_utils.ModelWeights.from_model(model)
      optimizer_variables = optimizer.variables()
      # We might have a NaN value e.g. if all of the clients processed
      # had no data, so the denominator in the federated_mean is zero.
      # If we see any NaNs, zero out the whole update.
      no_nan_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
          weights_delta)

      # TODO(b/124538167): We should increment a server counter to
      # track the fact a non-finite weights_delta was encountered.

      # Set the variables to the current global model (before update).
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            (model_variables, optimizer_variables),
                            (global_model, optimizer_state))
      # Update the variables with the delta, and return the new global model.
      _apply_delta(optimizer=optimizer, model=model, delta=no_nan_weights_delta)
      return model_variables, optimizer_variables

    return update_model_inner(model_delta)

  dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)

  @tff.tf_computation(dataset_type, model_weights_type)
  def _compute_local_training_and_client_delta(dataset, initial_model_weights):
    """Performs client local model optimization.

    Args:
      dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights.

    Returns:
      A `ClientOutput` structure.
    """
    client_delta_fn = model_to_client_delta_fn(model_fn)
    client_output = client_delta_fn(dataset, initial_model_weights)
    return client_output

  # The broadcast/aggregation state types come from each process's
  # `initialize` result (with the SERVER placement stripped via `.member`).
  broadcast_state = broadcast_process.initialize.type_signature.result.member
  aggregation_state = aggregation_process.initialize.type_signature.result.member

  server_state_type = ServerState(
      model=model_weights_type,
      optimizer_state=optimizer_variable_type,
      delta_aggregate_state=aggregation_state,
      model_broadcast_state=broadcast_state)

  @tff.federated_computation(
      tff.FederatedType(server_state_type, tff.SERVER),
      tff.FederatedType(dataset_type, tff.CLIENTS))
  def one_round_computation(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`, both having
      `tff.SERVER` placement.
    """
    broadcast_output = broadcast_process.next(
        server_state.model_broadcast_state, server_state.model)
    client_outputs = tff.federated_map(
        _compute_local_training_and_client_delta,
        (federated_dataset, broadcast_output.result))
    aggregation_output = aggregation_process.next(
        server_state.delta_aggregate_state, client_outputs.weights_delta,
        client_outputs.weights_delta_weight)
    new_global_model, new_optimizer_state = tff.federated_map(
        server_update, (server_state.model, aggregation_output.result,
                        server_state.optimizer_state))
    # Positional order matches the `ServerState` construction above: model,
    # optimizer_state, delta_aggregate_state, model_broadcast_state.
    new_server_state = tff.federated_zip(
        ServerState(new_global_model, new_optimizer_state,
                    aggregation_output.state, broadcast_output.state))
    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)
    measurements = tff.federated_zip(
        collections.OrderedDict(
            broadcast=broadcast_output.measurements,
            aggregation=aggregation_output.measurements,
            train=aggregated_outputs))
    return new_server_state, measurements

  return one_round_computation
def build_model_delta_optimizer_process(
    model_fn,
    model_to_client_delta_fn,
    server_optimizer_fn,
    stateful_delta_aggregate_fn=build_stateless_mean(),
    stateful_model_broadcast_fn=build_stateless_broadcaster()):
  """Constructs `tff.utils.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a
      `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client
      updates to the server model.
    stateful_delta_aggregate_fn: A `tff.utils.StatefulAggregateFn` where the
      `next_fn` performs a federated aggregation and updates state. That is,
      it has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
      (state@SERVER, aggregate@SERVER)`, where the `value` type is
      `tff.learning.framework.ModelWeights.trainable` corresponding to the
      object returned by `model_fn`. The default object is built once, when
      this function definition is executed, and is shared by all calls that
      omit this argument.
    stateful_model_broadcast_fn: A `tff.utils.StatefulBroadcastFn` where the
      `next_fn` performs a federated broadcast and updates state. That is, it
      has TFF type `(state@SERVER, value@SERVER) -> (state@SERVER,
      value@CLIENTS)`, where the `value` type is
      `tff.learning.framework.ModelWeights` corresponding to the object
      returned by `model_fn`. As above, the default object is built once at
      definition time.

  Returns:
    A `tff.utils.IterativeProcess`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(model_to_client_delta_fn)
  py_typecheck.check_callable(server_optimizer_fn)
  py_typecheck.check_type(stateful_delta_aggregate_fn,
                          tff.utils.StatefulAggregateFn)
  py_typecheck.check_type(stateful_model_broadcast_fn,
                          tff.utils.StatefulBroadcastFn)

  # TODO(b/122081673): would be nice not to have to construct a throwaway model
  # here just to get the types. After fully moving to TF2.0 and eager-mode, we
  # should re-evaluate what happens here.
  with tf.Graph().as_default():
    dummy_model_for_metadata = model_utils.enhance(model_fn())

  # ===========================================================================
  # TensorFlow Computations

  @tff.tf_computation
  def tf_init_fn():
    return server_init(model_fn, server_optimizer_fn,
                       stateful_delta_aggregate_fn.initialize(),
                       stateful_model_broadcast_fn.initialize())

  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
  # The server state type is inferred from the initialization computation.
  server_state_type = tf_init_fn.type_signature.result

  @tff.tf_computation(tf_dataset_type, server_state_type.model)
  def tf_client_delta(tf_dataset, initial_model_weights):
    """Performs client local model optimization.

    Args:
      tf_dataset: a `tf.data.Dataset` that provides training examples.
      initial_model_weights: a `model_utils.ModelWeights` containing the
        starting weights.

    Returns:
      A `ClientOutput` structure.
    """
    client_delta_fn = model_to_client_delta_fn(model_fn)
    client_output = client_delta_fn(tf_dataset, initial_model_weights)
    return client_output

  @tff.tf_computation(server_state_type, server_state_type.model.trainable,
                      server_state_type.delta_aggregate_state,
                      server_state_type.model_broadcast_state)
  def tf_server_update(server_state, model_delta, new_delta_aggregate_state,
                       new_broadcaster_state):
    """Converts args to correct python types and calls server_update_model."""
    py_typecheck.check_type(server_state, ServerState)
    # Rebuild the state with the freshly computed aggregator/broadcaster
    # states; the model and optimizer state are updated by
    # `server_update_model`.
    server_state = ServerState(
        model=server_state.model,
        optimizer_state=list(server_state.optimizer_state),
        delta_aggregate_state=new_delta_aggregate_state,
        model_broadcast_state=new_broadcaster_state)

    return server_update_model(
        server_state,
        model_delta,
        model_fn=model_fn,
        optimizer_fn=server_optimizer_fn)

  weight_type = tf_client_delta.type_signature.result.weights_delta_weight

  @tff.tf_computation(weight_type)
  def _cast_weight_to_float(x):
    return tf.cast(x, tf.float32)

  # ===========================================================================
  # Federated Computations

  @tff.federated_computation
  def server_init_tff():
    """Orchestration logic for server model initialization."""
    return tff.federated_value(tf_init_fn(), tff.SERVER)

  federated_server_state_type = server_init_tff.type_signature.result
  federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round_tff(server_state, federated_dataset):
    """Orchestration logic for one round of optimization.

    Args:
      server_state: a `tff.learning.framework.ServerState` named tuple.
      federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

    Returns:
      A tuple of updated `tff.learning.framework.ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    new_broadcaster_state, client_model = stateful_model_broadcast_fn(
        server_state.model_broadcast_state, server_state.model)

    client_outputs = tff.federated_map(tf_client_delta,
                                       (federated_dataset, client_model))

    # TODO(b/124070381): We hope to remove this explicit cast once we have a
    # full solution for type analysis in multiplications and divisions
    # inside TFF
    weight_denom = tff.federated_map(_cast_weight_to_float,
                                     client_outputs.weights_delta_weight)
    new_delta_aggregate_state, round_model_delta = stateful_delta_aggregate_fn(
        server_state.delta_aggregate_state,
        client_outputs.weights_delta,
        weight=weight_denom)

    # TODO(b/123408447): remove tff.federated_apply and call
    # tf_server_update directly once T <-> T@SERVER isomorphism is
    # supported.
    server_state = tff.federated_apply(
        tf_server_update, (server_state, round_model_delta,
                           new_delta_aggregate_state, new_broadcaster_state))

    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)
    # Promote the FederatedType outside the NamedTupleType
    aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  return tff.utils.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round_tff)
def __init__(self,
             inner_model,
             input_spec,
             loss_fns,
             loss_weights=None,
             metrics=None):
  """Wraps a Keras model whose loss is a weighted sum of per-output losses.

  Args:
    inner_model: the Keras model being wrapped. The flattened
      `inner_model.output` must have one entry per loss function.
    input_spec: the input spec exposed by this wrapper; stored as-is.
    loss_fns: non-empty sequence of callables `(y_true, y_pred)`, one per
      model output.
    loss_weights: optional sequence of per-loss weights with the same length
      as `loss_fns`. Defaults to `1.0` for every loss.
    metrics: optional list of Keras metrics; `None` is treated as empty.

  Raises:
    ValueError: if `loss_fns` is empty, does not match the number of model
      outputs, or `loss_weights` has a different length than `loss_fns`.
    TypeError: if `loss_weights` is given but is not a sequence.
  """
  # `collections.Sequence` was removed in Python 3.10; use the `abc` module.
  from collections import abc as collections_abc

  self._input_spec = input_spec

  if not loss_fns:
    raise ValueError(
        'Must specify at least one loss_fns, got: {l}'.format(l=loss_fns))
  if (len(tf.nest.flatten(loss_fns)) != len(
      tf.nest.flatten(inner_model.output))):
    raise ValueError(
        'Must specify the same number of loss_fns as model '
        'outputs.\nloss_fns: {l}\nmodel outputs: {o}'.format(
            l=loss_fns, o=inner_model.output))
  self._loss_fns = loss_fns

  if loss_weights is None:
    loss_weights = [1.0] * len(loss_fns)
  else:
    py_typecheck.check_type(loss_weights, collections_abc.Sequence)
    if len(loss_weights) != len(loss_fns):
      raise ValueError(
          'Must specify the same number of '
          'loss_weights (got {llw}) as loss_fns (got {llf}).\n'
          'loss_weights: {lw}\nloss_fns: {lf}'.format(
              lw=loss_weights,
              llw=len(loss_weights),
              lf=loss_fns,
              llf=len(loss_fns)))
  self._loss_weights = loss_weights
  self._keras_model = inner_model
  self._metrics = metrics if metrics is not None else []

  # This is defined here so that it closes over `loss_fns` and `loss_weights`.
  class _WeightedMeanLossMetric(tf.keras.metrics.Mean):
    """A `tf.keras.metrics.Metric` wrapper for the loss function."""

    def __init__(self, name='loss', dtype=tf.float32):
      super().__init__(name, dtype)
      self._loss_fns = loss_fns
      self._loss_weights = loss_weights

    def update_state(self, y_true, y_pred, sample_weight=None):
      # `sample_weight` is ignored; each batch is weighted by its size.
      if len(self._loss_fns) == 1:
        batch_size = tf.cast(tf.shape(y_pred)[0], self._dtype)
        y_true = tf.cast(y_true, self._dtype)
        y_pred = tf.cast(y_pred, self._dtype)
        batch_loss = self._loss_fns[0](y_true, y_pred)
      else:
        # Accumulate in the metric's dtype so it matches the casts below.
        batch_loss = tf.zeros((), dtype=self._dtype)
        for i, (loss_fn, loss_weight) in enumerate(
            zip(self._loss_fns, self._loss_weights)):
          y_t = tf.cast(y_true[i], self._dtype)
          y_p = tf.cast(y_pred[i], self._dtype)
          batch_loss += loss_weight * loss_fn(y_t, y_p)
        # The batch size is read from the first output's leading dimension.
        batch_size = tf.cast(tf.shape(y_pred[0])[0], self._dtype)

      return super().update_state(batch_loss, batch_size)

  self._loss_metric = _WeightedMeanLossMetric()

  metric_variable_type_dict = tf.nest.map_structure(
      tf.TensorSpec.from_tensor, self.report_local_outputs())
  federated_local_outputs_type = tff.FederatedType(
      metric_variable_type_dict, tff.CLIENTS)

  def federated_output(local_outputs):
    return federated_aggregate_keras_metric(self.get_metrics(), local_outputs)

  self._federated_output_computation = tff.federated_computation(
      federated_output, federated_local_outputs_type)