def client_update(optimizer, initial_weights, data):
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                        initial_weights)

  def reduce_fn(state, batch):
    """Trains a `tff.learning.Model` on a batch of data."""
    num_examples_sum, optimizer_state = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)

    gradients = tape.gradient(output.loss, model_weights.trainable)
    if delta_l2_regularizer > 0.0:
      proximal_term = tf.nest.map_structure(
          lambda x, y: delta_l2_regularizer * (y - x),
          model_weights.trainable, initial_weights.trainable)
      gradients = tf.nest.map_structure(tf.add, gradients, proximal_term)
    optimizer_state, updated_weights = optimizer.next(
        optimizer_state, tuple(tf.nest.flatten(model_weights.trainable)),
        tuple(tf.nest.flatten(gradients)))
    updated_weights = tf.nest.pack_sequence_as(model_weights.trainable,
                                               updated_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), model_weights.trainable,
                          updated_weights)

    if output.num_examples is None:
      num_examples_sum += tf.shape(output.predictions, out_type=tf.int64)[0]
    else:
      num_examples_sum += tf.cast(output.num_examples, tf.int64)
    return num_examples_sum, optimizer_state

  def initial_state_for_reduce_fn():
    # TODO(b/161529310): We flatten and convert the trainable specs to tuple,
    # as "for batch in data:" pattern would try to stack the tensors in list.
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype),
        tuple(tf.nest.flatten(model_weights.trainable)))
    return (tf.zeros(shape=[], dtype=tf.int64),
            optimizer.initialize(trainable_tensor_specs))

  num_examples, _ = dataset_reduce_fn(
      reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
  client_update = tf.nest.map_structure(tf.subtract,
                                        initial_weights.trainable,
                                        model_weights.trainable)
  model_output = model.report_local_unfinalized_metrics()

  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  client_update, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(client_update))
  client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                        num_examples)
  return client_works.ClientResult(
      update=client_update, update_weight=client_weight), model_output
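Several of the client updates in this section end by calling a helper `_choose_client_weight(weighting, has_non_finite_delta, num_examples)` that is never shown. A minimal sketch of the dispatch it plausibly performs, assuming `weighting` is a `ClientWeighting`-style enum with `NUM_EXAMPLES` and `UNIFORM` members; this is an illustrative reconstruction, not the canonical implementation.

def _choose_client_weight(weighting, has_non_finite_delta, num_examples):
  """Hypothetical sketch: picks the client's weight for the federated mean."""
  if has_non_finite_delta > 0:
    # The update was zeroed out for containing non-finite values, so weight
    # this client out of the mean as well.
    return tf.constant(0.0, tf.float32)
  elif weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
    return tf.cast(num_examples, tf.float32)
  else:
    # Assumed to be ClientWeighting.UNIFORM.
    return tf.constant(1.0, tf.float32)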
def server_update(model, server_optimizer, server_state, weights_delta):
  """Updates `server_state` based on `weights_delta`, increasing the round number.

  Args:
    model: A `tff.learning.Model`.
    server_optimizer: A `tf.keras.optimizers.Optimizer`.
    server_state: A `ServerState`, the state to be updated.
    weights_delta: An update to the trainable variables of the model.

  Returns:
    An updated `ServerState`.
  """
  model_weights = _get_weights(model)
  tff.utils.assign(model_weights, server_state.model)
  # Server optimizer variables must be initialized prior to invoking this
  # function.
  tff.utils.assign(server_optimizer.variables(), server_state.optimizer_state)

  weights_delta, has_non_finite_weight = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  if has_non_finite_weight > 0:
    return server_state

  # Apply the update to the model. We must multiply weights_delta by -1.0 to
  # view it as a gradient that should be applied to the server_optimizer.
  grads_and_vars = [
      (-1.0 * x, v) for x, v in zip(weights_delta, model_weights.trainable)
  ]
  server_optimizer.apply_gradients(grads_and_vars)

  # Create a new state based on the updated model.
  return tff.utils.update_state(
      server_state,
      model=model_weights,
      optimizer_state=server_optimizer.variables(),
      round_num=server_state.round_num + 1.0)
def server_update(model, server_optimizer, server_optimizer_vars, server_state,
                  weights_delta, grads_norm):
  """Updates `server_state` based on `weights_delta`.

  Args:
    model: A `tff.learning.Model`.
    server_optimizer: A `tf.keras.optimizers.Optimizer`.
    server_optimizer_vars: A list of previous variables of server_optimizer.
    server_state: A `ServerState` namedtuple, the state to be updated.
    weights_delta: An update to the trainable variables of the model.
    grads_norm: Summation of the norm of gradients from clients.

  Returns:
    An updated `ServerState`.
  """
  model_weights = tff.learning.framework.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda v, t: v.assign(t),
                        (model_weights, server_optimizer_vars),
                        (server_state.model, server_state.optimizer_state))

  # Zero out the weight if there are any non-finite values.
  weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(weights_delta)

  grads_and_vars = tf.nest.map_structure(
      lambda x, v: (-1.0 * x, v), tf.nest.flatten(weights_delta),
      tf.nest.flatten(model_weights.trainable))
  server_optimizer.update_grads_norm(
      tf.nest.flatten(model_weights.trainable), grads_norm)
  server_optimizer.apply_gradients(grads_and_vars, name='server_update')

  return tff.utils.update_state(
      server_state,
      model=model_weights,
      optimizer_state=server_optimizer_vars)
def server_update(global_model, mean_model_delta, optimizer_state):
  """Updates the global model with the mean model update from clients."""
  with tf.init_scope():
    # Create a structure of variables that the server optimizer can update.
    model_variables = tf.nest.map_structure(
        lambda t: tf.Variable(initial_value=tf.zeros(t.shape, t.dtype)),
        global_model)
    optimizer = keras_optimizer.build_or_verify_tff_optimizer(
        server_optimizer_fn,
        model_variables.trainable,
        disjoint_init_and_next=True)

  # Set the variables to the current global model; the optimizer will
  # update these variables.
  tf.nest.map_structure(lambda a, b: a.assign(b), model_variables,
                        global_model)

  # We might have a NaN value e.g. if all of the clients processed had no
  # data, so the denominator in the federated_mean is zero. If we see any
  # NaNs, zero out the whole update.
  # TODO(b/124538167): We should increment a server counter to
  # track the fact a non-finite weights_delta was encountered.
  finite_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
      mean_model_delta)

  # Update the global model variables with the delta as a pseudo-gradient.
  negative_weights_delta = tf.nest.map_structure(lambda w: -1.0 * w,
                                                 finite_weights_delta)
  optimizer_state, updated_weights = optimizer.next(optimizer_state,
                                                    model_variables.trainable,
                                                    negative_weights_delta)
  # Keras optimizers mutate model variables within the `next` step above, so
  # we skip calling the assignment for those optimizers.
  if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
    tf.nest.map_structure(lambda a, b: a.assign(b), model_variables.trainable,
                          updated_weights)
  return model_variables, optimizer_state
def client_computation(
    # Tensor/Dataset arguments that will be supplied by TFF:
    gen_inputs_ds: tf.data.Dataset,
    real_data_ds: tf.data.Dataset,
    from_server: FromServer,
    # Python arguments to be bound at TFF computation construction time:
    generator: tf.keras.Model,
    discriminator: tf.keras.Model,
    train_discriminator_fn) -> ClientOutput:
  """The computation to run on the client, training the discriminator.

  Args:
    gen_inputs_ds: A `tf.data.Dataset` of generator_inputs.
    real_data_ds: A `tf.data.Dataset` of data from the real distribution.
    from_server: A `FromServer` object, including the current model weights.
    generator: The generator.
    discriminator: The discriminator.
    train_discriminator_fn: A function which takes the two networks, generator
      input, and real data and trains the discriminator.

  Returns:
    A `ClientOutput` object.
  """
  tf.nest.map_structure(lambda a, b: a.assign(b), generator.weights,
                        from_server.generator_weights)
  tf.nest.map_structure(lambda a, b: a.assign(b), discriminator.weights,
                        from_server.discriminator_weights)

  num_examples = tf.constant(0)
  gen_inputs_and_real_data = tf.data.Dataset.zip((gen_inputs_ds, real_data_ds))
  for gen_inputs, real_data in gen_inputs_and_real_data:
    # It's possible that real_data and gen_inputs have different batch sizes.
    # For calculating the discriminator loss, it's desirable to have
    # equal-sized contributions from both the real and fake data. Also, it's
    # necessary if using the Wasserstein gradient penalty (where a difference
    # is taken between the real and fake data). So here we reduce to the min
    # batch size. This also ensures num_examples properly reflects the amount
    # of data trained on.
    min_batch_size = tf.minimum(
        tf.shape(real_data)[0], tf.shape(gen_inputs)[0])
    real_data = real_data[0:min_batch_size]
    gen_inputs = gen_inputs[0:min_batch_size]
    num_examples += train_discriminator_fn(generator, discriminator,
                                           gen_inputs, real_data)

  weights_delta = tf.nest.map_structure(tf.subtract, discriminator.weights,
                                        from_server.discriminator_weights)
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  update_weight = tf.cast(num_examples, tf.float32)
  # Zero out the weight if there are any non-finite values.
  # TODO(b/122071074): federated_mean might not do the right thing if
  # all clients have zero weight.
  update_weight = tf.cond(
      tf.equal(has_non_finite_delta, 0), lambda: update_weight,
      lambda: tf.constant(0.0))

  return ClientOutput(
      discriminator_weights_delta=weights_delta,
      update_weight=update_weight,
      counters={'num_discriminator_train_examples': num_examples})
def server_update(global_model, mean_model_delta, optimizer_state):
  """Updates the global model with the mean model update from clients."""
  with tf.init_scope():
    model = model_fn()
    optimizer = server_optimizer_fn()
    # We must force variable creation for momentum and adaptive optimizers.
    _eagerly_create_optimizer_variables(model=model, optimizer=optimizer)
  model_variables = model_utils.ModelWeights.from_model(model)
  optimizer_variables = optimizer.variables()

  # Set the variables to the current global model; the optimizer will
  # update these variables.
  tf.nest.map_structure(lambda a, b: a.assign(b),
                        (model_variables, optimizer_variables),
                        (global_model, optimizer_state))

  # We might have a NaN value e.g. if all of the clients processed had no
  # data, so the denominator in the federated_mean is zero. If we see any
  # NaNs, zero out the whole update.
  # TODO(b/124538167): We should increment a server counter to
  # track the fact a non-finite weights_delta was encountered.
  finite_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
      mean_model_delta)

  # Update the global model variables with the delta as a pseudo-gradient.
  _apply_delta(optimizer=optimizer, model=model, delta=finite_weights_delta)
  return model_variables, optimizer_variables
def __call__(self, dataset, initial_weights):
  # TODO(b/123898430): The control dependencies below have been inserted as a
  # temporary workaround. These control dependencies need to be removed, and
  # defuns and datasets supported together fully.
  model = self._model

  # TODO(b/113112108): Remove this temporary workaround and restore check for
  # `tf.data.Dataset` after subclassing the currently used custom data set
  # representation from it.
  if 'Dataset' not in str(type(dataset)):
    raise TypeError('Expected a data set, found {}.'.format(
        py_typecheck.type_string(type(dataset))))

  # TODO(b/120801384): We should initialize model.local_variables here.
  # Or, we may just need a convention that TFF initializes all variables
  # before invoking the TF function.

  # We must assign to a variable here in order to use control_dependencies.
  dummy_weights = nest.map_structure(tf.assign, model.weights, initial_weights)

  with tf.control_dependencies(list(dummy_weights.trainable.values())):

    def reduce_fn(dummy_state, batch):
      """Runs `tff.learning.Model.train_on_batch` on local client batch."""
      output = model.train_on_batch(batch)
      tf.assign_add(self._num_examples, tf.shape(output.predictions)[0])
      return dummy_state

    # TODO(b/124477598): Remove dummy_output when b/121400757 fixed.
    dummy_output = dataset.reduce(
        initial_state=tf.constant(0.0), reduce_func=reduce_fn)

  with tf.control_dependencies([dummy_output]):
    weights_delta = nest.map_structure(tf.subtract, model.weights.trainable,
                                       initial_weights.trainable)
    aggregated_outputs = model.report_local_outputs()
    weights_delta_weight = self._client_weight_fn(aggregated_outputs)  # pylint:disable=not-callable

    # TODO(b/122071074): Consider moving this functionality into
    # tff.federated_average?
    weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    weights_delta_weight = tf.cond(
        tf.equal(has_non_finite_delta, 0), lambda: weights_delta_weight,
        lambda: tf.constant(0))

    return optimizer_utils.ClientOutput(
        weights_delta, weights_delta_weight, aggregated_outputs,
        tensor_utils.to_odict({
            'num_examples': self._num_examples.value(),
            'has_non_finite_delta': has_non_finite_delta,
            'workaround for b/121400757': dummy_output,
        }))
def __call__(self, dataset, initial_weights):
  # TODO(b/113112108): Remove this temporary workaround and restore check for
  # `tf.data.Dataset` after subclassing the currently used custom data set
  # representation from it.
  if 'Dataset' not in str(type(dataset)):
    raise TypeError('Expected a data set, found {}.'.format(
        py_typecheck.type_string(type(dataset))))

  model = self._model
  dummy_weights = nest.map_structure(tf.assign, model.weights, initial_weights)

  def reduce_fn(accumulated_grads, batch):
    """Runs forward_pass on batch."""
    with tf.contrib.eager.GradientTape() as tape:
      output = model.forward_pass(batch)

    with tf.control_dependencies(list(output)):
      flat_vars = nest.flatten(model.weights.trainable)
      grads = nest.pack_sequence_as(accumulated_grads,
                                    tape.gradient(output.loss, flat_vars))

      if self._batch_weight_fn is not None:
        batch_weight = self._batch_weight_fn(batch)
      else:
        batch_weight = tf.cast(tf.shape(output.predictions)[0], tf.float32)

    tf.assign_add(self._batch_weight_sum, batch_weight)
    return nest.map_structure(
        lambda accumulator, grad: accumulator + batch_weight * grad,
        accumulated_grads, grads)

  with tf.control_dependencies(list(dummy_weights.trainable.values())):
    self._grad_sum_vars = dataset.reduce(
        initial_state=self._grad_sum_vars, reduce_func=reduce_fn)

  with tf.control_dependencies(
      [tf.identity(v) for v in self._grad_sum_vars.values()]):
    # For SGD, the delta is just the negative of the average gradient:
    weights_delta = nest.map_structure(
        lambda gradient: -1.0 * gradient / self._batch_weight_sum,
        self._grad_sum_vars)
    weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    weights_delta_weight = tf.cond(
        tf.equal(has_non_finite_delta, 0), lambda: self._batch_weight_sum,
        lambda: tf.constant(0.0))
    return optimizer_utils.ClientOutput(
        weights_delta, weights_delta_weight, model.report_local_outputs(),
        tensor_utils.to_odict({
            'client_weight': weights_delta_weight,
            'has_non_finite_delta': has_non_finite_delta,
        }))
def client_update(model,
                  dataset,
                  initial_weights,
                  client_optimizer,
                  client_id,
                  client_weight_fn=None):
  """Updates client model.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    initial_weights: A `tff.learning.ModelWeights` from server.
    client_optimizer: A `tf.keras.optimizer.Optimizer` object.
    client_id: The identifier of this client.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the
      weight in the federated average of model deltas. If not provided, the
      default is the total number of examples processed on device.

  Returns:
    A `ClientOutput`.
  """
  model_weights = _get_weights(model)
  tff.utils.assign(model_weights, initial_weights)

  num_examples = tf.constant(0, dtype=tf.int32)
  for batch in dataset:
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch)
    grads = tape.gradient(output.loss, model_weights.trainable)
    # grads = tf.nest.map_structure(
    #     lambda g: clip_ops.clip_by_norm(g, 5.0), grads)
    grads, _ = tf.clip_by_global_norm(grads, 1.0)
    grads_and_vars = zip(grads, model_weights.trainable)
    client_optimizer.apply_gradients(grads_and_vars)
    num_examples += tf.shape(output.predictions)[0]

  aggregated_outputs = model.report_local_outputs()
  weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                        model_weights.trainable,
                                        initial_weights.trainable)
  weights_delta, has_non_finite_weight = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))

  if has_non_finite_weight > 0:
    client_weight = tf.constant([[0]], dtype=tf.float32)
  else:
    client_weight = tf.cast([[num_examples]], dtype=tf.float32)
  # else:
  #   client_weight = client_weight_fn(aggregated_outputs)

  # weights_delta_encoded = tf.nest.map_structure(mean_encoder_fn,
  #                                               weights_delta)

  return ClientOutput(
      weights_delta, client_weight, aggregated_outputs,
      collections.OrderedDict([('num_examples', num_examples)]), client_id)
def __call__(self, dataset, initial_weights):
  # N.B. When not in eager mode, this code must be wrapped as a defun
  # as it uses program-order semantics to avoid adding many explicit
  # control dependencies.
  model = self._model
  py_typecheck.check_type(dataset, tf.data.Dataset)
  nest.map_structure(tf.assign, model.weights, initial_weights)

  @tf.contrib.eager.function(autograph=False)
  def reduce_fn(dummy_state, batch):
    """Runs forward_pass on batch."""
    with tf.contrib.eager.GradientTape() as tape:
      output = model.forward_pass(batch)

    flat_vars = nest.flatten(model.weights.trainable)
    grads = nest.pack_sequence_as(self._grad_sum_vars,
                                  tape.gradient(output.loss, flat_vars))

    if self._batch_weight_fn is not None:
      batch_weight = self._batch_weight_fn(batch)
    else:
      batch_weight = tf.cast(tf.shape(output.predictions)[0], tf.float32)

    tf.assign_add(self._batch_weight_sum, batch_weight)
    nest.map_structure(
        lambda v, g: tf.assign_add(v, batch_weight * g),  # pylint:disable=g-long-lambda
        self._grad_sum_vars, grads)
    return dummy_state

  # TODO(b/121400757): Remove dummy_output when bug fixed.
  dummy_output = dataset.reduce(
      initial_state=tf.constant(0.0), reduce_func=reduce_fn)

  # For SGD, the delta is just the negative of the average gradient:
  # TODO(b/109733734): Might be better to send the weighted grad sums
  # and the denominator separately?
  weights_delta = nest.map_structure(
      lambda g: -1.0 * g / self._batch_weight_sum, self._grad_sum_vars)
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  weights_delta_weight = tf.cond(
      tf.equal(has_non_finite_delta, 0), lambda: self._batch_weight_sum,
      lambda: tf.constant(0.0))

  return optimizer_utils.ClientOutput(
      weights_delta, weights_delta_weight, model.report_local_outputs(),
      tensor_utils.to_odict({
          'client_weight': weights_delta_weight,
          'has_non_finite_delta': has_non_finite_delta,
          'workaround for b/121400757': dummy_output,
      }))
def expect_zeros(structure, expected):
  with tf.Graph().as_default():
    result, error = tensor_utils.zero_all_if_any_non_finite(structure)
    with self.session() as sess:
      result, error = sess.run((result, error))
  try:
    tf.nest.map_structure(np.testing.assert_allclose, result, expected)
  except AssertionError:
    self.fail('Expected to get zeros, but instead got {}'.format(result))
  self.assertEqual(error, 1)
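The test helper above pins down the contract of `tensor_utils.zero_all_if_any_non_finite`, which every snippet in this section relies on but none defines. Below is a minimal sketch consistent with that contract, assuming the `(structure, error_flag)` return convention the call sites imply; it is an illustrative reconstruction, not the canonical TFF implementation.

def zero_all_if_any_non_finite(structure):
  """Sketch: zeroes all entries of `structure` if any value is non-finite.

  Returns:
    A tuple `(output, error)`: `output` equals `structure` and `error` is 0
    when every element is finite; otherwise `output` is an all-zeros
    structure of the same shape and `error` is 1.
  """
  flat = tf.nest.flatten(structure)
  # True only if no element anywhere in the structure is NaN or +/-Inf.
  all_finite = tf.reduce_all(
      [tf.reduce_all(tf.math.is_finite(t)) for t in flat])
  return tf.cond(
      all_finite,
      lambda: (structure, tf.constant(0)),
      lambda: (tf.nest.map_structure(tf.zeros_like, structure),
               tf.constant(1)))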
def server_update(server_state, weights_delta, aggregator_state,
                  broadcaster_state):
  """Updates the `server_state` based on `weights_delta`.

  Args:
    server_state: A `tff.learning.framework.ServerState`, the state to be
      updated.
    weights_delta: The model delta in global trainable variables from clients.
    aggregator_state: The state of the aggregator after performing aggregation.
    broadcaster_state: The state of the broadcaster after broadcasting.

  Returns:
    The updated `tff.learning.framework.ServerState`.
  """
  with tf.init_scope():
    model = model_fn()
  global_model_weights = reconstruction_utils.get_global_variables(model)
  optimizer = keras_optimizer.build_or_verify_tff_optimizer(
      server_optimizer_fn,
      global_model_weights.trainable,
      disjoint_init_and_next=True)
  optimizer_state = server_state.optimizer_state

  # Initialize the model with the current state.
  tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                        server_state.model)

  weights_delta, has_non_finite_weight = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))

  # We ignore the update if the weights_delta is non-finite.
  if tf.equal(has_non_finite_weight, 0):
    negative_weights_delta = tf.nest.map_structure(lambda w: -1.0 * w,
                                                   weights_delta)
    optimizer_state, updated_weights = optimizer.next(
        optimizer_state, global_model_weights.trainable,
        negative_weights_delta)
    if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
      # Keras optimizer mutates model variables within the `next` step.
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            global_model_weights.trainable, updated_weights)

  # Create a new state based on the updated model.
  return structure.update_struct(
      server_state,
      model=global_model_weights,
      optimizer_state=optimizer_state,
      model_broadcast_state=broadcaster_state,
      delta_aggregate_state=aggregator_state,
  )
def __call__(self, dataset, initial_weights):
  model = self._model
  optimizer = self._optimizer
  tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                        initial_weights)

  def reduce_fn(num_examples_sum, batch):
    """Train `tff.learning.Model` on local client batch."""
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
    gradients = tape.gradient(output.loss, model.weights.trainable)
    optimizer.apply_gradients(zip(gradients, model.weights.trainable))

    if output.num_examples is None:
      return num_examples_sum + tf.shape(
          output.predictions, out_type=tf.int64)[0]
    else:
      return num_examples_sum + tf.cast(output.num_examples, tf.int64)

  num_examples_sum = self._dataset_reduce_fn(
      reduce_fn,
      dataset,
      initial_state_fn=lambda: tf.zeros(shape=[], dtype=tf.int64))

  weights_delta = tf.nest.map_structure(tf.subtract, model.weights.trainable,
                                        initial_weights.trainable)
  model_output = model.report_local_outputs()

  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  # Zero out the weight if there are any non-finite values.
  if has_non_finite_delta > 0:
    # TODO(b/176171842): Zeroing has no effect with unweighted aggregation.
    weights_delta_weight = tf.constant(0.0)
  elif self._client_weighting is ClientWeighting.NUM_EXAMPLES:
    weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
  elif self._client_weighting is ClientWeighting.UNIFORM:
    weights_delta_weight = tf.constant(1.0)
  else:
    weights_delta_weight = self._client_weighting(model_output)
  # TODO(b/176245976): TFF `ClientOutput` structure names are confusing.
  optimizer_output = collections.OrderedDict(num_examples=num_examples_sum)
  return optimizer_utils.ClientOutput(weights_delta, weights_delta_weight,
                                      model_output, optimizer_output)
def server_update_model(
    server_state: ServerState,
    weights_delta: Collection[tf.Tensor],
    model_fn: _ModelConstructor,
    optimizer_fn: _OptimizerConstructor,
) -> ServerState:
  """Updates `server_state` based on `weights_delta`.

  Args:
    server_state: A `tff.learning.framework.ServerState` namedtuple, the state
      to be updated.
    weights_delta: An update to the trainable variables of the model.
    model_fn: A no-arg function that returns a `tff.learning.Model`. Passing in
      a function ensures any variables are created when server_update_model is
      called, so they can be captured in a specific graph or other context.
    optimizer_fn: A no-arg function that returns a `tf.train.Optimizer`. As
      with model_fn, we pass in a function to control when variables are
      created.

  Returns:
    An updated `tff.learning.framework.ServerState`.
  """
  py_typecheck.check_type(server_state, ServerState)
  py_typecheck.check_type(weights_delta, collections.abc.Collection)
  model = model_utils.enhance(model_fn())
  optimizer = optimizer_fn()
  apply_delta_fn, optimizer_vars = _build_server_optimizer(model, optimizer)

  # We might have a NaN value e.g. if all of the clients processed
  # had no data, so the denominator in the federated_mean is zero.
  # If we see any NaNs, zero out the whole update.
  no_nan_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
      weights_delta)
  # TODO(b/124538167): We should increment a server counter to
  # track the fact a non-finite weights_delta was encountered.

  @tf.function
  def update_model_inner():
    """Applies the update."""
    tf.nest.map_structure(lambda a, b: a.assign(b),
                          (model.weights, optimizer_vars),
                          (server_state.model, server_state.optimizer_state))
    apply_delta_fn(no_nan_weights_delta)
    return model.weights, optimizer_vars

  model_weights, optimizer_vars = update_model_inner()
  # TODO(b/123092620): We must do this outside of the above tf.function,
  # because there could be an AnonymousTuple hiding in server_state, and
  # tf.functions can't return AnonymousTuples.
  return tff.utils.update_state(
      server_state, model=model_weights, optimizer_state=optimizer_vars)
def client_update(initial_weights, dataset):
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                        initial_weights)

  def reduce_fn(state, batch):
    """Runs forward_pass on batch and sums the weighted gradients."""
    accumulated_gradients, num_examples_sum = state

    with tf.GradientTape() as tape:
      output = model.forward_pass(batch)
    gradients = tape.gradient(output.loss, model_weights.trainable)
    num_examples = tf.cast(output.num_examples, tf.float32)
    accumulated_gradients = tuple(
        accumulator + num_examples * gradient
        for accumulator, gradient in zip(accumulated_gradients, gradients))

    # We may be able to optimize the reduce function to avoid doubling the
    # number of required variables here (e.g. keeping two copies of all
    # gradients). If you're looking to optimize memory usage this might be a
    # place to look.
    return (accumulated_gradients, num_examples_sum + num_examples)

  def _zero_initial_state():
    """Create a tuple of (gradient accumulators, num examples)."""
    return tuple(
        tf.nest.map_structure(tf.zeros_like,
                              model_weights.trainable)), tf.constant(
                                  0, dtype=tf.float32)

  gradient_sums, num_examples_sum = dataset_reduce_fn(
      reduce_fn=reduce_fn,
      dataset=dataset,
      initial_state_fn=_zero_initial_state)

  # We now normalize to compute the average gradient over all examples.
  average_gradient = tf.nest.map_structure(
      lambda gradient: gradient / num_examples_sum, gradient_sums)

  model_output = model.report_local_unfinalized_metrics()
  average_gradient, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(average_gradient))
  if has_non_finite_delta > 0:
    client_weight = tf.constant(0.0)
  else:
    client_weight = num_examples_sum

  return client_works.ClientResult(
      update=average_gradient, update_weight=client_weight), model_output
def client_update(optimizer, initial_weights, data):
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                        initial_weights)

  def reduce_fn(num_examples_sum, batch):
    """Trains a `tff.learning.Model` on a batch of data."""
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)

    gradients = tape.gradient(output.loss, model_weights.trainable)
    if delta_l2_regularizer > 0.0:
      proximal_term = tf.nest.map_structure(
          lambda x, y: delta_l2_regularizer * (y - x),
          model_weights.trainable, initial_weights.trainable)
      gradients = tf.nest.map_structure(tf.add, gradients, proximal_term)
    grads_and_vars = zip(gradients, model_weights.trainable)
    optimizer.apply_gradients(grads_and_vars)

    # TODO(b/199782787): Add a unit test for a model that does not compute
    # `num_examples` in its forward pass.
    if output.num_examples is None:
      num_examples_sum += tf.shape(output.predictions, out_type=tf.int64)[0]
    else:
      num_examples_sum += tf.cast(output.num_examples, tf.int64)
    return num_examples_sum

  def initial_state_for_reduce_fn():
    return tf.zeros(shape=[], dtype=tf.int64)

  num_examples = dataset_reduce_fn(
      reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
  client_update = tf.nest.map_structure(tf.subtract,
                                        initial_weights.trainable,
                                        model_weights.trainable)
  model_output = model.report_local_unfinalized_metrics()

  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  client_update, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(client_update))
  client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                        num_examples)
  return client_works.ClientResult(
      update=client_update, update_weight=client_weight), model_output
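Many of these client updates delegate iteration to a `dataset_reduce_fn` callable rather than calling `tf.data.Dataset.reduce` directly. A minimal sketch of such a helper, inferred from the call sites above; the builder name, the `simulation` flag, and the eager loop variant are assumptions for illustration, not the canonical TFF code.

def build_dataset_reduce_fn(simulation=False):
  """Sketch: returns a reduce helper over a `tf.data.Dataset`."""

  if not simulation:

    def dataset_reduce_fn(reduce_fn, dataset, initial_state_fn):
      # `initial_state_fn` is invoked lazily so the initial state is created
      # inside the enclosing `tf.function`.
      return dataset.reduce(
          initial_state=initial_state_fn(), reduce_func=reduce_fn)

    return dataset_reduce_fn

  def dataset_reduce_fn_eager(reduce_fn, dataset, initial_state_fn):
    # In simulation, a plain Python loop over the dataset can be preferable
    # (e.g. for GPU performance) to `Dataset.reduce`.
    state = initial_state_fn()
    for batch in iter(dataset):
      state = reduce_fn(state, batch)
    return state

  return dataset_reduce_fn_eager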
def __call__(self, dataset, initial_weights):
  # TODO(b/113112108): Remove this temporary workaround and restore check for
  # `tf.data.Dataset` after subclassing the currently used custom data set
  # representation from it.
  if 'Dataset' not in str(type(dataset)):
    raise TypeError('Expected a data set, found {}.'.format(
        py_typecheck.type_string(type(dataset))))

  model = self._model
  tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                        initial_weights)

  @tf.function
  def reduce_fn(num_examples_sum, batch):
    """Runs `tff.learning.Model.train_on_batch` on local client batch."""
    output = model.train_on_batch(batch)
    if output.num_examples is None:
      return num_examples_sum + tf.shape(output.predictions)[0]
    else:
      return num_examples_sum + output.num_examples

  num_examples_sum = dataset.reduce(
      initial_state=tf.constant(0), reduce_func=reduce_fn)

  weights_delta = tf.nest.map_structure(tf.subtract, model.weights.trainable,
                                        initial_weights.trainable)
  aggregated_outputs = model.report_local_outputs()

  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  if self._client_weight_fn is None:
    weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
  else:
    weights_delta_weight = self._client_weight_fn(aggregated_outputs)
  # Zero out the weight if there are any non-finite values.
  if has_non_finite_delta > 0:
    weights_delta_weight = tf.constant(0.0)

  return optimizer_utils.ClientOutput(
      weights_delta, weights_delta_weight, aggregated_outputs,
      tensor_utils.to_odict({
          'num_examples': num_examples_sum,
          'has_non_finite_delta': has_non_finite_delta,
      }))
def __call__(self, dataset, initial_weights):
  model = self._model
  optimizer = self._optimizer
  tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                        initial_weights)

  @tf.function
  def reduce_fn(num_examples_sum, batch):
    """Trains `tff.learning.Model` on a local client batch."""
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
    gradients = tape.gradient(output.loss, model.weights.trainable)
    optimizer.apply_gradients(zip(gradients, model.weights.trainable))

    if output.num_examples is None:
      return num_examples_sum + tf.shape(output.predictions)[0]
    else:
      return num_examples_sum + output.num_examples

  num_examples_sum = dataset.reduce(
      initial_state=tf.constant(0), reduce_func=reduce_fn)

  weights_delta = tf.nest.map_structure(tf.subtract, model.weights.trainable,
                                        initial_weights.trainable)
  aggregated_outputs = model.report_local_outputs()

  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  # Zero out the weight if there are any non-finite values.
  if has_non_finite_delta > 0:
    weights_delta_weight = tf.constant(0.0)
  elif self._client_weight_fn is None:
    weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
  else:
    weights_delta_weight = self._client_weight_fn(aggregated_outputs)

  return optimizer_utils.ClientOutput(
      weights_delta, weights_delta_weight, aggregated_outputs,
      collections.OrderedDict(
          num_examples=num_examples_sum,
          has_non_finite_delta=has_non_finite_delta,
      ))
def client_update(model, dataset, initial_weights, client_optimizer):
  """Updates client model.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    initial_weights: A `tff.learning.ModelWeights` from server.
    client_optimizer: A `tf.keras.optimizer.Optimizer` object.

  Returns:
    A `ClientOutput`.
  """
  model_weights = _get_weights(model)
  tff.utils.assign(model_weights, initial_weights)

  num_examples = tf.constant(0, dtype=tf.int32)
  for batch in dataset:
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch)
    grads = tape.gradient(output.loss, model_weights.trainable)
    grads_and_vars = zip(grads, model_weights.trainable)
    client_optimizer.apply_gradients(grads_and_vars)
    num_examples += tf.shape(output.predictions)[0]

  aggregated_outputs = model.report_local_outputs()
  weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                        model_weights.trainable,
                                        initial_weights.trainable)
  weights_delta, has_non_finite_weight = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))

  if has_non_finite_weight > 0:
    client_weight = tf.constant(0, dtype=tf.float32)
  else:
    client_weight = tf.constant(1, dtype=tf.float32)
  # elif client_weight_fn is None:
  #   client_weight = tf.cast(num_examples, dtype=tf.float32)
  # else:
  #   client_weight = client_weight_fn(aggregated_outputs)

  return ClientOutput(
      weights_delta, client_weight, aggregated_outputs,
      collections.OrderedDict([('num_examples', num_examples)]))
def client_update_fn(optimizer, initial_weights, dataset):
  initial_trainable_weights = initial_weights[0]
  trainable_tensor_specs = tf.nest.map_structure(
      tf.TensorSpec.from_tensor,
      tuple(tf.nest.flatten(initial_trainable_weights)))
  optimizer_state = optimizer.initialize(trainable_tensor_specs)

  # Autograph requires we define these variables once outside the loop.
  model_weights = initial_weights
  trainable_weights, non_trainable_weights = model_weights
  num_examples = tf.constant(0, tf.int64)
  for batch in iter(dataset):
    trainable_weights, non_trainable_weights = model_weights
    with tf.GradientTape() as tape:
      # Must explicitly watch non-variable tensors.
      tape.watch(trainable_weights)
      output = model.forward_pass(model_weights, batch, training=True)
    gradients = tape.gradient(output.loss, trainable_weights)
    if tf.greater(delta_l2_regularizer, 0.0):
      proximal_term = tf.nest.map_structure(
          lambda x, y: delta_l2_regularizer * (y - x), trainable_weights,
          initial_trainable_weights)
      gradients = tf.nest.map_structure(tf.add, gradients, proximal_term)
    optimizer_state, trainable_weights = optimizer.next(
        optimizer_state, trainable_weights, gradients)
    num_examples += tf.cast(output.num_examples, tf.int64)
    model_weights = (trainable_weights, non_trainable_weights)

  # After all local batches, compute the delta between the trained model
  # and the initial incoming model weights.
  client_model_update = tf.nest.map_structure(tf.subtract,
                                              initial_trainable_weights,
                                              trainable_weights)
  # TODO(b/229612282): Implement metrics.
  model_output = ()
  client_model_update, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(client_model_update))
  client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                        num_examples)
  return client_works.ClientResult(
      update=client_model_update, update_weight=client_weight), model_output
def update_model_inner(weights_delta):
  """Applies the update to the global model."""
  model_variables = model_utils.ModelWeights.from_model(model)
  optimizer_variables = optimizer.variables()
  # We might have a NaN value e.g. if all of the clients processed
  # had no data, so the denominator in the federated_mean is zero.
  # If we see any NaNs, zero out the whole update.
  no_nan_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
      weights_delta)
  # TODO(b/124538167): We should increment a server counter to
  # track the fact a non-finite weights_delta was encountered.

  # Set the variables to the current global model (before update).
  tf.nest.map_structure(lambda a, b: a.assign(b),
                        (model_variables, optimizer_variables),
                        (global_model, optimizer_state))
  # Update the variables with the delta, and return the new global model.
  _apply_delta(optimizer=optimizer, model=model, delta=no_nan_weights_delta)
  return model_variables, optimizer_variables
def __call__(self, dataset, initial_weights):
  del initial_weights
  model = self._model

  @tf.function
  def reduce_fn_num_examples(num_examples_sum, batch):
    """Count number of examples."""
    num_examples_in_batch = tf.shape(batch['x'])[0]
    return num_examples_sum + num_examples_in_batch

  @tf.function
  def reduce_fn_dataset_mean(sum_vector, batch):
    """Sum all the examples in the local dataset."""
    sum_batch = tf.reshape(tf.reduce_sum(batch['x'], [0]), (-1, 1))
    return sum_vector + sum_batch

  num_examples_sum = dataset.reduce(
      initial_state=tf.constant(0), reduce_func=reduce_fn_num_examples)
  example_vector_sum = dataset.reduce(
      initial_state=tf.zeros((DIM, 1)), reduce_func=reduce_fn_dataset_mean)

  # Create an ordered dictionary with the same type as model.trainable,
  # containing a mean of all the examples in the local dataset.
  # Note: this works for a linear model only (as in the example above).
  key = list(model.weights.trainable.keys())[0]
  weights_delta = collections.OrderedDict(
      {key: example_vector_sum / tf.cast(num_examples_sum, tf.float32)})

  aggregated_outputs = model.report_local_outputs()
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  weights_delta_weight = tf.cast(num_examples_sum, tf.float32)

  return tff.learning.framework.ClientOutput(
      weights_delta, weights_delta_weight, aggregated_outputs,
      collections.OrderedDict([
          ('num_examples', num_examples_sum),
          ('has_non_finite_delta', has_non_finite_delta),
      ]))
def server_update_model(current_server_state, weights_delta, model_fn,
                        optimizer_fn):
  """Updates `server_state` based on `weights_delta`.

  Args:
    current_server_state: A `tff.learning.framework.ServerState` namedtuple.
    weights_delta: An update to the trainable variables of the model.
    model_fn: A no-arg function that returns a `tff.learning.Model`. Passing in
      a function ensures any variables are created when server_update_model is
      called, so they can be captured in a specific graph or other context.
    optimizer_fn: A no-arg function that returns a `tf.train.Optimizer`. As
      with model_fn, we pass in a function to control when variables are
      created.

  Returns:
    An updated `tff.learning.framework.ServerState`.
  """
  py_typecheck.check_type(current_server_state, ServerState)
  py_typecheck.check_type(weights_delta, collections.OrderedDict)
  model = model_utils.enhance(model_fn())
  optimizer = optimizer_fn()
  apply_delta_fn, server_state_vars = _create_optimizer_and_server_state(
      model, optimizer)

  # We might have a NaN value e.g. if all of the clients processed
  # had no data, so the denominator in the federated_mean is zero.
  # If we see any NaNs, zero out the whole update.
  no_nan_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
      weights_delta)
  # TODO(b/124538167): We should increment a server counter to
  # track the fact a non-finite weights_delta was encountered.

  @tf.contrib.eager.function(autograph=False)
  def update_model_inner():
    """Applies the update."""
    nest.map_structure(tf.assign, server_state_vars, current_server_state)
    apply_delta_fn(no_nan_weights_delta)
    return server_state_vars

  return update_model_inner()
def __call__(self, dataset, initial_weights):
  model = self._model

  # TODO(b/113112108): Remove this temporary workaround and restore check for
  # `tf.data.Dataset` after subclassing the currently used custom data set
  # representation from it.
  if 'Dataset' not in str(type(dataset)):
    raise TypeError('Expected a data set, found {}.'.format(
        py_typecheck.type_string(type(dataset))))

  tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                        initial_weights)
  flat_trainable_weights = tuple(tf.nest.flatten(model.weights.trainable))

  @tf.function
  def reduce_fn(state, batch):
    """Runs forward_pass on batch and sums the weighted gradients."""
    flat_accumulated_grads, batch_weight_sum = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch)
    flat_grads = tape.gradient(output.loss, flat_trainable_weights)
    if self._batch_weight_fn is not None:
      batch_weight = self._batch_weight_fn(batch)
    else:
      batch_weight = tf.cast(tf.shape(output.predictions)[0], tf.float32)
    flat_accumulated_grads = tuple(
        accumulator + batch_weight * grad
        for accumulator, grad in zip(flat_accumulated_grads, flat_grads))

    # The TF team is aware of an optimization in the reduce state to avoid
    # doubling the number of required variables here (e.g. keeping two copies
    # of all gradients). If you're looking to optimize memory usage this might
    # be a place to look.
    return (flat_accumulated_grads, batch_weight_sum + batch_weight)

  def _zero_initial_state():
    """Create a tuple of (tuple of gradient accumulators, batch weight sum)."""
    return (tuple(tf.zeros_like(w) for w in flat_trainable_weights),
            tf.constant(0.0))

  flat_grad_sums, batch_weight_sum = self._dataset_reduce_fn(
      reduce_fn=reduce_fn,
      dataset=dataset,
      initial_state_fn=_zero_initial_state)
  grad_sums = tf.nest.pack_sequence_as(model.weights.trainable, flat_grad_sums)

  # For SGD, the delta is just the negative of the average gradient:
  weights_delta = tf.nest.map_structure(
      lambda gradient: -1.0 * gradient / batch_weight_sum, grad_sums)
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  if has_non_finite_delta > 0:
    weights_delta_weight = tf.constant(0.0)
  else:
    weights_delta_weight = batch_weight_sum

  return optimizer_utils.ClientOutput(
      weights_delta, weights_delta_weight, model.report_local_outputs(),
      tensor_utils.to_odict({
          'client_weight': weights_delta_weight,
          'has_non_finite_delta': has_non_finite_delta,
      }))
def client_update(dataset, initial_model_weights):
  """Performs client local model optimization.

  Args:
    dataset: A `tf.data.Dataset` that provides training examples.
    initial_model_weights: A `tff.learning.ModelWeights` containing the
      starting global trainable and non-trainable weights.

  Returns:
    A `ClientOutput`.
  """
  with tf.init_scope():
    model = model_fn()

    metrics = []
    if metrics_fn is not None:
      metrics.extend(metrics_fn())
    # To be used to calculate example-weighted mean across batches and
    # clients.
    metrics.append(keras_utils.MeanLossMetric(loss_fn()))
    # To be used to calculate batch loss for model updates.
    client_loss = loss_fn()

  global_model_weights = reconstruction_utils.get_global_variables(model)
  local_model_weights = reconstruction_utils.get_local_variables(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                        initial_model_weights)

  client_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
      client_optimizer_fn,
      global_model_weights.trainable,
      disjoint_init_and_next=False)
  reconstruction_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
      reconstruction_optimizer_fn,
      local_model_weights.trainable,
      disjoint_init_and_next=False)

  @tf.function
  def reconstruction_reduce_fn(state, batch):
    """Runs reconstruction training on local client batch."""
    num_examples_sum, optimizer_state = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
      batch_loss = client_loss(
          y_true=output.labels, y_pred=output.predictions)
    gradients = tape.gradient(batch_loss, local_model_weights.trainable)
    optimizer_state, updated_weights = reconstruction_optimizer.next(
        optimizer_state, local_model_weights.trainable, gradients)
    if not isinstance(reconstruction_optimizer,
                      keras_optimizer.KerasOptimizer):
      # Keras optimizer mutates model variables within the `next` step.
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            local_model_weights.trainable, updated_weights)
    return num_examples_sum + output.num_examples, optimizer_state

  @tf.function
  def train_reduce_fn(state, batch):
    """Runs one step of client optimizer on local client batch."""
    num_examples_sum, optimizer_state = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
      batch_loss = client_loss(
          y_true=output.labels, y_pred=output.predictions)
    gradients = tape.gradient(batch_loss, global_model_weights.trainable)
    optimizer_state, updated_weights = client_optimizer.next(
        optimizer_state, global_model_weights.trainable, gradients)
    if not isinstance(client_optimizer, keras_optimizer.KerasOptimizer):
      # Keras optimizer mutates model variables within the `next` step.
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            global_model_weights.trainable, updated_weights)
    # Update each metric.
    for metric in metrics:
      metric.update_state(y_true=output.labels, y_pred=output.predictions)
    return num_examples_sum + output.num_examples, optimizer_state

  recon_dataset, post_recon_dataset = dataset_split_fn(dataset)

  # If needed, do reconstruction, training the local variables while keeping
  # the global ones frozen.
  if local_model_weights.trainable:
    # Ignore output number of examples used in reconstruction, since this
    # isn't included in `client_weight`.

    def initial_state_reconstruction_reduce():
      trainable_tensor_specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype),
          local_model_weights.trainable)
      return tf.constant(0), reconstruction_optimizer.initialize(
          trainable_tensor_specs)

    recon_dataset.reduce(
        initial_state=initial_state_reconstruction_reduce(),
        reduce_func=reconstruction_reduce_fn)

  # Train the global variables, keeping local variables frozen.
  def initial_state_train_reduce():
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype),
        global_model_weights.trainable)
    return tf.constant(0), client_optimizer.initialize(trainable_tensor_specs)

  num_examples_sum, _ = post_recon_dataset.reduce(
      initial_state=initial_state_train_reduce(),
      reduce_func=train_reduce_fn)

  weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                        global_model_weights.trainable,
                                        initial_model_weights.trainable)

  # We ignore the update if the weights_delta is non-finite.
  weights_delta, has_non_finite_weight = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))

  model_local_outputs = keras_utils.read_metric_variables(metrics)

  if has_non_finite_weight > 0:
    client_weight = tf.constant(0.0, dtype=tf.float32)
  elif client_weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
    client_weight = tf.cast(num_examples_sum, dtype=tf.float32)
  elif client_weighting is client_weight_lib.ClientWeighting.UNIFORM:
    client_weight = tf.constant(1.0, dtype=tf.float32)
  else:
    client_weight = client_weighting(model_local_outputs)

  return ClientOutput(weights_delta, client_weight, model_local_outputs)
def client_update(model, optimizer, dataset, initial_weights):
  """Updates client model.

  Args:
    model: A `tff.learning.Model`.
    optimizer: A `tf.keras.optimizers.Optimizer`.
    dataset: A `tf.data.Dataset`.
    initial_weights: A `tff.learning.Model.weights` from server.

  Returns:
    A `ClientOutput`.
  """
  model_weights = tff.learning.framework.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                        initial_weights)
  flat_trainable_weights = tuple(tf.nest.flatten(model_weights.trainable))

  @tf.function
  def reduce_fn(state, batch):
    """Trains on a local client batch, summing gradients and gradient norms."""
    flat_accumulated_grads, flat_accumulated_grads_norm, batch_weight_sum = state
    # Unlike the FedAvg client update, we need to capture the gradients during
    # training so we can send back the norms to the server.
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch)
    flat_grads = tape.gradient(output.loss, flat_trainable_weights)
    optimizer.apply_gradients(zip(flat_grads, flat_trainable_weights))
    batch_weight = tf.cast(tf.shape(output.predictions)[0], dtype=tf.float32)
    flat_accumulated_grads = tuple(
        accumulator + batch_weight * grad
        for accumulator, grad in zip(flat_accumulated_grads, flat_grads))
    flat_accumulated_grads_norm = tuple(
        norm_accumulator + batch_weight * tf.norm(grad)
        for norm_accumulator, grad in zip(flat_accumulated_grads_norm,
                                          flat_grads))
    return (flat_accumulated_grads, flat_accumulated_grads_norm,
            batch_weight_sum + batch_weight)

  def _zero_initial_state():
    """Create a tuple of (gradient accumulators, norm accumulators, batch weight sum)."""
    return (
        tuple(tf.zeros_like(w) for w in flat_trainable_weights),
        tuple(tf.constant(0, dtype=w.dtype) for w in flat_trainable_weights),
        tf.constant(0, dtype=tf.float32),
    )

  flat_grads_sum, flat_grads_norm_sum, batch_weight_sum = dataset.reduce(
      initial_state=_zero_initial_state(), reduce_func=reduce_fn)
  grads_sum = tf.nest.pack_sequence_as(model_weights.trainable, flat_grads_sum)
  weights_delta = tf.nest.map_structure(
      lambda gradient: -1.0 * gradient / batch_weight_sum, grads_sum)
  flat_grads_norm_sum = tf.nest.map_structure(
      lambda grad_norm: grad_norm / batch_weight_sum, flat_grads_norm_sum)

  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  # Zero out the weight if there are any non-finite values.
  if has_non_finite_delta > 0:
    weights_delta_weight = tf.constant(0.0)
  else:
    weights_delta_weight = batch_weight_sum

  return ClientOutput(
      weights_delta,
      weights_delta_weight,
      model_output=model.report_local_outputs(),
      optimizer_output=collections.OrderedDict(
          num_examples=batch_weight_sum,
          flat_grads_norm_sum=flat_grads_norm_sum))
def client_update(global_optimizer_state, initial_weights, data):
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                        initial_weights)

  def full_gradient_reduce_fn(state, batch):
    """Sums individual gradients, to be later divided by num_examples."""
    gradient_sum, num_examples_sum = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
    if output.num_examples is None:
      num_examples = tf.shape(output.predictions, out_type=tf.int64)[0]
    else:
      num_examples = tf.cast(output.num_examples, tf.int64)
    # TODO(b/161529310): We flatten and convert to tuple, as tf.data
    # iterators would try to stack the tensors in list into a single tensor.
    gradients = tuple(
        tf.nest.flatten(tape.gradient(output.loss, model_weights.trainable)))
    gradient_sum = tf.nest.map_structure(
        lambda g_sum, g: g_sum + g * tf.cast(num_examples, g.dtype),
        gradient_sum, gradients)
    num_examples_sum += num_examples
    return gradient_sum, num_examples_sum

  def initial_state_for_full_gradient_reduce_fn():
    initial_gradient_sum = tf.nest.map_structure(
        lambda spec: tf.zeros(spec.shape, spec.dtype),
        tuple(tf.nest.flatten(weight_tensor_specs.trainable)))
    initial_num_examples_sum = tf.constant(0, tf.int64)
    return initial_gradient_sum, initial_num_examples_sum

  full_gradient, num_examples = dataset_reduce_fn(
      full_gradient_reduce_fn, data, initial_state_for_full_gradient_reduce_fn)
  # Compute the average gradient.
  full_gradient = tf.nest.map_structure(
      lambda g: tf.math.divide_no_nan(g, tf.cast(num_examples, g.dtype)),
      full_gradient)

  # Resets the local model variables, including metrics states, as we are
  # not interested in metrics based on the full gradient evaluation, only
  # from the subsequent training.
  model.reset_metrics()

  def train_reduce_fn(state, batch):
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
    gradients = tape.gradient(output.loss, model_weights.trainable)
    # Mime Lite keeps optimizer state unchanged during local training.
    _, updated_weights = optimizer.next(global_optimizer_state,
                                        model_weights.trainable, gradients)
    tf.nest.map_structure(lambda a, b: a.assign(b), model_weights.trainable,
                          updated_weights)
    return state

  # Performs local training, updating `tf.Variable`s in `model_weights`.
  dataset_reduce_fn(
      train_reduce_fn, data, initial_state_fn=lambda: tf.zeros(shape=[0]))

  client_weights_delta = tf.nest.map_structure(tf.subtract,
                                               initial_weights.trainable,
                                               model_weights.trainable)
  model_output = model.report_local_unfinalized_metrics()

  # TODO(b/122071074): Consider moving this functionality into aggregation.
  client_weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(client_weights_delta))
  client_weight = _choose_client_weight(client_weighting,
                                        has_non_finite_delta, num_examples)
  return client_works.ClientResult(
      update=client_weights_delta,
      update_weight=client_weight), model_output, full_gradient
def __call__(self, dataset, initial_weights):
  model = self._model
  optimizer = self._optimizer
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                        initial_weights)

  def reduce_fn(state, batch):
    """Train `tff.learning.Model` on local client batch."""
    num_examples_sum, optimizer_state = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
    gradients = tape.gradient(output.loss, model_weights.trainable)
    optimizer_state, updated_weights = optimizer.next(
        optimizer_state, model_weights.trainable, gradients)
    if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
      # Keras optimizer mutates model variables within the `next` step.
      tf.nest.map_structure(lambda a, b: a.assign(b), model_weights.trainable,
                            updated_weights)

    if output.num_examples is None:
      num_examples_sum += tf.shape(output.predictions, out_type=tf.int64)[0]
    else:
      num_examples_sum += tf.cast(output.num_examples, tf.int64)
    return num_examples_sum, optimizer_state

  def initial_state_for_reduce_fn():
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights.trainable)
    return (tf.zeros(shape=[], dtype=tf.int64),
            optimizer.initialize(trainable_tensor_specs))

  num_examples_sum, _ = self._dataset_reduce_fn(
      reduce_fn, dataset, initial_state_fn=initial_state_for_reduce_fn)

  weights_delta = tf.nest.map_structure(tf.subtract, model_weights.trainable,
                                        initial_weights.trainable)
  model_output = model.report_local_outputs()

  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  # Zero out the weight if there are any non-finite values.
  if has_non_finite_delta > 0:
    # TODO(b/176171842): Zeroing has no effect with unweighted aggregation.
    weights_delta_weight = tf.constant(0.0)
  elif self._client_weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
    weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
  elif self._client_weighting is client_weight_lib.ClientWeighting.UNIFORM:
    weights_delta_weight = tf.constant(1.0)
  else:
    weights_delta_weight = self._client_weighting(model_output)
  # TODO(b/176245976): TFF `ClientOutput` structure names are confusing.
  optimizer_output = collections.OrderedDict(num_examples=num_examples_sum)
  return optimizer_utils.ClientOutput(weights_delta, weights_delta_weight,
                                      model_output, optimizer_output)