def __call__(self, dataset, initial_weights):
  # TODO(b/123898430): The control dependencies below have been inserted as a
  # temporary workaround. These control dependencies need to be removed, and
  # defuns and datasets supported together fully.
  model = self._model

  # TODO(b/113112108): Remove this temporary workaround and restore check for
  # `tf.data.Dataset` after subclassing the currently used custom data set
  # representation from it.
  if 'Dataset' not in str(type(dataset)):
    raise TypeError('Expected a data set, found {}.'.format(
        py_typecheck.type_string(type(dataset))))

  # TODO(b/120801384): We should initialize model.local_variables here.
  # Or, we may just need a convention that TFF initializes all variables
  # before invoking the TF function.

  # We must assign to a variable here in order to use control_dependencies.
  dummy_weights = nest.map_structure(tf.assign, model.weights,
                                     initial_weights)

  with tf.control_dependencies(list(dummy_weights.trainable.values())):

    def reduce_fn(dummy_state, batch):
      """Runs `tff.learning.Model.train_on_batch` on local client batch."""
      output = model.train_on_batch(batch)
      tf.assign_add(self._num_examples, tf.shape(output.predictions)[0])
      return dummy_state

    # TODO(b/124477598): Remove dummy_output when b/121400757 fixed.
    dummy_output = dataset.reduce(
        initial_state=tf.constant(0.0), reduce_func=reduce_fn)

  with tf.control_dependencies([dummy_output]):
    weights_delta = nest.map_structure(tf.subtract, model.weights.trainable,
                                       initial_weights.trainable)
    aggregated_outputs = model.report_local_outputs()
    weights_delta_weight = self._client_weight_fn(aggregated_outputs)  # pylint:disable=not-callable

    # TODO(b/122071074): Consider moving this functionality into
    # tff.federated_average?
    weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    # Both branches of the cond must return the same dtype (float32).
    weights_delta_weight = tf.cond(
        tf.equal(has_non_finite_delta, 0), lambda: weights_delta_weight,
        lambda: tf.constant(0.0))

    return optimizer_utils.ClientOutput(
        weights_delta, weights_delta_weight, aggregated_outputs,
        tensor_utils.to_odict({
            'num_examples': self._num_examples.value(),
            'has_non_finite_delta': has_non_finite_delta,
            'workaround for b/121400757': dummy_output,
        }))

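# --- Illustration (not part of the implementation above) ---
# A minimal, self-contained sketch of the `tf.data.Dataset.reduce` idiom the
# client updates in this file rely on: `reduce_func` folds each batch into a
# running state. Here the state is just an example count; in the code above
# the model is also trained on `batch` inside `reduce_func`. The toy dataset
# and names (`toy_batches`, `count_examples`) are assumptions for the sketch.
import tensorflow as tf

toy_batches = tf.data.Dataset.from_tensor_slices(tf.range(10)).batch(3)


@tf.function
def count_examples(num_examples_sum, batch):
  # Accumulate the number of elements seen so far.
  return num_examples_sum + tf.shape(batch)[0]


total = toy_batches.reduce(initial_state=tf.constant(0),
                           reduce_func=count_examples)
print(int(total))  # 10
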
def __call__(self, dataset, initial_weights):
  """Dummy client delta which simply returns 1.0 for all parameters."""
  client_weight = tf.constant(1.0)
  return optimizer_utils.ClientOutput(
      tf.nest.map_structure(tf.ones_like, initial_weights.trainable),
      weights_delta_weight=client_weight,
      model_output=self._model.report_local_outputs(),
      optimizer_output={'client_weight': client_weight})

def __call__(self, dataset, initial_weights):
  # TODO(b/113112108): Remove this temporary workaround and restore check for
  # `tf.data.Dataset` after subclassing the currently used custom data set
  # representation from it.
  if 'Dataset' not in str(type(dataset)):
    raise TypeError('Expected a data set, found {}.'.format(
        py_typecheck.type_string(type(dataset))))

  model = self._model
  dummy_weights = nest.map_structure(tf.assign, model.weights,
                                     initial_weights)

  def reduce_fn(accumulated_grads, batch):
    """Runs forward_pass on batch."""
    with tf.contrib.eager.GradientTape() as tape:
      output = model.forward_pass(batch)

    with tf.control_dependencies(list(output)):
      flat_vars = nest.flatten(model.weights.trainable)
      grads = nest.pack_sequence_as(
          accumulated_grads, tape.gradient(output.loss, flat_vars))

      if self._batch_weight_fn is not None:
        batch_weight = self._batch_weight_fn(batch)
      else:
        batch_weight = tf.cast(tf.shape(output.predictions)[0], tf.float32)

    tf.assign_add(self._batch_weight_sum, batch_weight)
    return nest.map_structure(
        lambda accumulator, grad: accumulator + batch_weight * grad,
        accumulated_grads, grads)

  with tf.control_dependencies(list(dummy_weights.trainable.values())):
    self._grad_sum_vars = dataset.reduce(
        initial_state=self._grad_sum_vars, reduce_func=reduce_fn)

  with tf.control_dependencies(
      [tf.identity(v) for v in self._grad_sum_vars.values()]):
    # For SGD, the delta is just the negative of the average gradient:
    weights_delta = nest.map_structure(
        lambda gradient: -1.0 * gradient / self._batch_weight_sum,
        self._grad_sum_vars)
    weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    weights_delta_weight = tf.cond(
        tf.equal(has_non_finite_delta, 0), lambda: self._batch_weight_sum,
        lambda: tf.constant(0.0))

    return optimizer_utils.ClientOutput(
        weights_delta, weights_delta_weight, model.report_local_outputs(),
        tensor_utils.to_odict({
            'client_weight': weights_delta_weight,
            'has_non_finite_delta': has_non_finite_delta,
        }))

def __call__(self, dataset, initial_weights):
  trainable_weights_delta = nest.map_structure(lambda x: -1.0 * x,
                                               initial_weights.trainable)
  client_weight = tf.constant(3.0)
  return optimizer_utils.ClientOutput(
      trainable_weights_delta,
      weights_delta_weight=client_weight,
      model_output=self._model.report_local_outputs(),
      optimizer_output={'client_weight': client_weight})

def __call__(self, dataset, initial_weights):
  # N.B. When not in eager mode, this code must be wrapped as a defun
  # as it uses program-order semantics to avoid adding many explicit
  # control dependencies.
  model = self._model
  py_typecheck.check_type(dataset, tf.data.Dataset)

  nest.map_structure(tf.assign, model.weights, initial_weights)

  @tf.contrib.eager.function(autograph=False)
  def reduce_fn(dummy_state, batch):
    """Runs forward_pass on batch."""
    with tf.contrib.eager.GradientTape() as tape:
      output = model.forward_pass(batch)

    flat_vars = nest.flatten(model.weights.trainable)
    grads = nest.pack_sequence_as(self._grad_sum_vars,
                                  tape.gradient(output.loss, flat_vars))

    if self._batch_weight_fn is not None:
      batch_weight = self._batch_weight_fn(batch)
    else:
      batch_weight = tf.cast(tf.shape(output.predictions)[0], tf.float32)

    tf.assign_add(self._batch_weight_sum, batch_weight)
    nest.map_structure(
        lambda v, g:  # pylint:disable=g-long-lambda
        tf.assign_add(v, batch_weight * g),
        self._grad_sum_vars,
        grads)
    return dummy_state

  # TODO(b/121400757): Remove dummy_output when bug fixed.
  dummy_output = dataset.reduce(
      initial_state=tf.constant(0.0), reduce_func=reduce_fn)

  # For SGD, the delta is just the negative of the average gradient:
  # TODO(b/109733734): Might be better to send the weighted grad sums
  # and the denominator separately?
  weights_delta = nest.map_structure(
      lambda g: -1.0 * g / self._batch_weight_sum, self._grad_sum_vars)
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  weights_delta_weight = tf.cond(
      tf.equal(has_non_finite_delta, 0), lambda: self._batch_weight_sum,
      lambda: tf.constant(0.0))

  return optimizer_utils.ClientOutput(
      weights_delta, weights_delta_weight, model.report_local_outputs(),
      tensor_utils.to_odict({
          'client_weight': weights_delta_weight,
          'has_non_finite_delta': has_non_finite_delta,
          'workaround for b/121400757': dummy_output,
      }))

def __call__(self, dataset, initial_weights):
  model = self._model
  optimizer = self._optimizer
  tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                        initial_weights)

  def reduce_fn(num_examples_sum, batch):
    """Train `tff.learning.Model` on local client batch."""
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)

    gradients = tape.gradient(output.loss, model.weights.trainable)
    optimizer.apply_gradients(zip(gradients, model.weights.trainable))

    if output.num_examples is None:
      return num_examples_sum + tf.shape(
          output.predictions, out_type=tf.int64)[0]
    else:
      return num_examples_sum + tf.cast(output.num_examples, tf.int64)

  num_examples_sum = self._dataset_reduce_fn(
      reduce_fn,
      dataset,
      initial_state_fn=lambda: tf.zeros(shape=[], dtype=tf.int64))

  weights_delta = tf.nest.map_structure(tf.subtract, model.weights.trainable,
                                        initial_weights.trainable)
  model_output = model.report_local_outputs()

  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  # Zero out the weight if there are any non-finite values.
  if has_non_finite_delta > 0:
    # TODO(b/176171842): Zeroing has no effect with unweighted aggregation.
    weights_delta_weight = tf.constant(0.0)
  elif self._client_weighting is ClientWeighting.NUM_EXAMPLES:
    weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
  elif self._client_weighting is ClientWeighting.UNIFORM:
    weights_delta_weight = tf.constant(1.0)
  else:
    weights_delta_weight = self._client_weighting(model_output)
  # TODO(b/176245976): TFF `ClientOutput` structure names are confusing.
  optimizer_output = collections.OrderedDict(num_examples=num_examples_sum)
  return optimizer_utils.ClientOutput(weights_delta, weights_delta_weight,
                                      model_output, optimizer_output)

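# --- Illustration (not part of the implementation above) ---
# A minimal eager-mode sketch of the same local-training-then-delta pattern,
# using only core TensorFlow on a hypothetical one-layer model. All names
# here (`toy_model`, `local_dataset`, `LEARNING_RATE`) are assumptions for
# the sketch, not TFF's API; the SGD step is applied manually instead of
# through a Keras optimizer.
import tensorflow as tf

LEARNING_RATE = 0.1

toy_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
toy_model.build(input_shape=(None, 2))  # Create variables up front.
local_dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([8, 2]), tf.random.normal([8, 1]))).batch(4)
loss_fn = tf.keras.losses.MeanSquaredError()

# Snapshot the "initial weights" broadcast by the server.
initial_values = [tf.identity(v) for v in toy_model.trainable_variables]

num_examples = 0
for x, y in local_dataset:  # One local pass over the client dataset.
  with tf.GradientTape() as tape:
    loss = loss_fn(y, toy_model(x, training=True))
  grads = tape.gradient(loss, toy_model.trainable_variables)
  for v, g in zip(toy_model.trainable_variables, grads):
    v.assign_sub(LEARNING_RATE * g)  # Plain SGD step.
  num_examples += int(x.shape[0])

# The client reports (trained - initial) weights plus an example count.
weights_delta = tf.nest.map_structure(
    tf.subtract, toy_model.trainable_variables, initial_values)
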
def __call__(self, dataset, initial_weights):
  # TODO(b/113112108): Remove this temporary workaround and restore check for
  # `tf.data.Dataset` after subclassing the currently used custom data set
  # representation from it.
  if 'Dataset' not in str(type(dataset)):
    raise TypeError('Expected a data set, found {}.'.format(
        py_typecheck.type_string(type(dataset))))

  model = self._model
  tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                        initial_weights)

  @tf.function
  def reduce_fn(num_examples_sum, batch):
    """Runs `tff.learning.Model.train_on_batch` on local client batch."""
    output = model.train_on_batch(batch)
    if output.num_examples is None:
      return num_examples_sum + tf.shape(output.predictions)[0]
    else:
      return num_examples_sum + output.num_examples

  num_examples_sum = dataset.reduce(
      initial_state=tf.constant(0), reduce_func=reduce_fn)

  weights_delta = tf.nest.map_structure(tf.subtract, model.weights.trainable,
                                        initial_weights.trainable)
  aggregated_outputs = model.report_local_outputs()

  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))

  if self._client_weight_fn is None:
    weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
  else:
    weights_delta_weight = self._client_weight_fn(aggregated_outputs)
  # Zero out the weight if there are any non-finite values.
  if has_non_finite_delta > 0:
    weights_delta_weight = tf.constant(0.0)

  return optimizer_utils.ClientOutput(
      weights_delta, weights_delta_weight, aggregated_outputs,
      tensor_utils.to_odict({
          'num_examples': num_examples_sum,
          'has_non_finite_delta': has_non_finite_delta,
      }))

def __call__(self, dataset, initial_weights):
  model = self._model
  optimizer = self._optimizer
  tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                        initial_weights)

  @tf.function
  def reduce_fn(num_examples_sum, batch):
    """Runs `tff.learning.Model.train_on_batch` on local client batch."""
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)

    gradients = tape.gradient(output.loss, model.weights.trainable)
    optimizer.apply_gradients(zip(gradients, model.weights.trainable))

    if output.num_examples is None:
      return num_examples_sum + tf.shape(output.predictions)[0]
    else:
      return num_examples_sum + output.num_examples

  num_examples_sum = dataset.reduce(
      initial_state=tf.constant(0), reduce_func=reduce_fn)

  weights_delta = tf.nest.map_structure(tf.subtract, model.weights.trainable,
                                        initial_weights.trainable)
  aggregated_outputs = model.report_local_outputs()

  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  # Zero out the weight if there are any non-finite values.
  if has_non_finite_delta > 0:
    weights_delta_weight = tf.constant(0.0)
  elif self._client_weight_fn is None:
    weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
  else:
    weights_delta_weight = self._client_weight_fn(aggregated_outputs)

  return optimizer_utils.ClientOutput(
      weights_delta, weights_delta_weight, aggregated_outputs,
      collections.OrderedDict(
          num_examples=num_examples_sum,
          has_non_finite_delta=has_non_finite_delta,
      ))

def __call__(self, dataset, initial_weights):

  # Iterate over the dataset to get new metric values.
  def reduce_fn(dummy, batch):
    self._model.train_on_batch(batch)
    return dummy

  dataset.reduce(tf.constant(0.0), reduce_fn)

  # Create some fake weight deltas to send back.
  trainable_weights_delta = tf.nest.map_structure(lambda x: -tf.ones_like(x),
                                                  initial_weights.trainable)
  client_weight = tf.constant(1.0)
  return optimizer_utils.ClientOutput(
      trainable_weights_delta,
      weights_delta_weight=client_weight,
      model_output=self._model.report_local_outputs(),
      optimizer_output={'client_weight': client_weight})

def __call__(self, dataset, initial_weights):
  del initial_weights
  model = self._model

  @tf.function
  def reduce_fn_num_examples(num_examples_sum, batch):
    """Count number of examples."""
    num_examples_in_batch = tf.shape(batch['x'])[0]
    return num_examples_sum + num_examples_in_batch

  @tf.function
  def reduce_fn_dataset_mean(sum_vector, batch):
    """Sum all the examples in the local dataset."""
    sum_batch = tf.reshape(tf.reduce_sum(batch['x'], [0]), (-1, 1))
    return sum_vector + sum_batch

  num_examples_sum = dataset.reduce(
      initial_state=tf.constant(0), reduce_func=reduce_fn_num_examples)

  example_vector_sum = dataset.reduce(
      initial_state=tf.zeros((DIM, 1)), reduce_func=reduce_fn_dataset_mean)

  # Create an ordered dictionary with the same type as model.trainable,
  # containing a mean of all the examples in the local dataset.
  # Note: this works for a linear model only (as in the example above).
  key = list(model.weights.trainable.keys())[0]
  weights_delta = collections.OrderedDict(
      {key: example_vector_sum / tf.cast(num_examples_sum, tf.float32)})

  aggregated_outputs = model.report_local_outputs()
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  weights_delta_weight = tf.cast(num_examples_sum, tf.float32)

  return optimizer_utils.ClientOutput(
      weights_delta, weights_delta_weight, aggregated_outputs,
      collections.OrderedDict([
          ('num_examples', num_examples_sum),
          ('has_non_finite_delta', has_non_finite_delta),
      ]))

def __call__(self, dataset, initial_weights):

  # Iterate over the dataset to get new metric values.
  def reduce_fn(dummy, batch):
    with tf.GradientTape() as tape:
      output = self._model.forward_pass(batch)
    gradients = tape.gradient(output.loss, self._model.trainable_variables)
    self._optimizer.apply_gradients(
        zip(gradients, self._model.trainable_variables))
    return dummy

  dataset.reduce(tf.constant(0.0), reduce_fn)

  # Create some fake weight deltas to send back.
  trainable_weights_delta = tf.nest.map_structure(lambda x: -tf.ones_like(x),
                                                  initial_weights.trainable)
  client_weight = tf.constant(1.0)
  return optimizer_utils.ClientOutput(
      trainable_weights_delta,
      weights_delta_weight=client_weight,
      model_output=self._model.report_local_outputs(),
      optimizer_output=collections.OrderedDict([('client_weight',
                                                 client_weight)]))

def __call__(self, dataset, initial_weights):

  # Iterate over the dataset to get new metric values.
  def reduce_fn(whimsy, batch):
    with tf.GradientTape() as tape:
      output = self._model.forward_pass(batch)
    gradients = tape.gradient(output.loss, self._model.trainable_variables)
    self._optimizer.apply_gradients(
        zip(gradients, self._model.trainable_variables))
    return whimsy

  dataset.reduce(tf.constant(0.0), reduce_fn)

  # Create some fake weight deltas to send back.
  trainable_weights_delta = tf.nest.map_structure(lambda x: -tf.ones_like(x),
                                                  initial_weights.trainable)
  client_weight = tf.constant(1.0)
  model_output = self._model.report_local_unfinalized_metrics()
  return optimizer_utils.ClientOutput(
      weights_delta=trainable_weights_delta,
      weights_delta_weight=client_weight,
      model_output=model_output,
      # We avoid using a `None` as it is unsupported in graph serialization.
      optimizer_output=())

def __call__(self, dataset, initial_weights):
  model = self._model

  # TODO(b/113112108): Remove this temporary workaround and restore check for
  # `tf.data.Dataset` after subclassing the currently used custom data set
  # representation from it.
  if 'Dataset' not in str(type(dataset)):
    raise TypeError('Expected a data set, found {}.'.format(
        py_typecheck.type_string(type(dataset))))

  tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                        initial_weights)
  flat_trainable_weights = tuple(tf.nest.flatten(model.weights.trainable))

  @tf.function
  def reduce_fn(state, batch):
    """Runs forward_pass on batch and sums the weighted gradients."""
    flat_accumulated_grads, batch_weight_sum = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch)
    flat_grads = tape.gradient(output.loss, flat_trainable_weights)

    if self._batch_weight_fn is not None:
      batch_weight = self._batch_weight_fn(batch)
    else:
      batch_weight = tf.cast(tf.shape(output.predictions)[0], tf.float32)

    flat_accumulated_grads = tuple(
        accumulator + batch_weight * grad
        for accumulator, grad in zip(flat_accumulated_grads, flat_grads))

    # The TF team is aware of an optimization in the reduce state to avoid
    # doubling the number of required variables here (e.g. keeping two copies
    # of all gradients). If you're looking to optimize memory usage this might
    # be a place to look.
    return (flat_accumulated_grads, batch_weight_sum + batch_weight)

  def _zero_initial_state():
    """Create a tuple of (tuple of gradient accumulators, batch weight sum)."""
    return (tuple(tf.zeros_like(w) for w in flat_trainable_weights),
            tf.constant(0.0))

  flat_grad_sums, batch_weight_sum = self._dataset_reduce_fn(
      reduce_fn=reduce_fn,
      dataset=dataset,
      initial_state_fn=_zero_initial_state)
  grad_sums = tf.nest.pack_sequence_as(model.weights.trainable,
                                       flat_grad_sums)

  # For SGD, the delta is just the negative of the average gradient:
  weights_delta = tf.nest.map_structure(
      lambda gradient: -1.0 * gradient / batch_weight_sum, grad_sums)
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  if has_non_finite_delta > 0:
    weights_delta_weight = tf.constant(0.0)
  else:
    weights_delta_weight = batch_weight_sum

  return optimizer_utils.ClientOutput(
      weights_delta, weights_delta_weight, model.report_local_outputs(),
      tensor_utils.to_odict({
          'client_weight': weights_delta_weight,
          'has_non_finite_delta': has_non_finite_delta,
      }))

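# --- Illustration (not part of the implementation above) ---
# A minimal sketch of the gradient-accumulation idea above: sum
# example-weighted gradients over all batches without updating the model,
# then report the negative weighted-average gradient as the delta. It assumes
# the `toy_model`, `local_dataset`, and `loss_fn` names from the earlier
# sketch; they are illustrative, not part of the implementation.
import tensorflow as tf

grad_sums = [tf.zeros_like(v) for v in toy_model.trainable_variables]
batch_weight_sum = tf.constant(0.0)

for x, y in local_dataset:
  with tf.GradientTape() as tape:
    loss = loss_fn(y, toy_model(x, training=True))
  grads = tape.gradient(loss, toy_model.trainable_variables)
  batch_weight = tf.cast(tf.shape(x)[0], tf.float32)  # Weight by batch size.
  grad_sums = [s + batch_weight * g for s, g in zip(grad_sums, grads)]
  batch_weight_sum += batch_weight

# For SGD, the delta is the negative of the weighted-average gradient.
weights_delta = [-g / batch_weight_sum for g in grad_sums]
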
def __call__(self, dataset, initial_weights):
  model = self._model
  optimizer = self._optimizer
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                        initial_weights)

  def reduce_fn(state, batch):
    """Train `tff.learning.Model` on local client batch."""
    num_examples_sum, optimizer_state = state

    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)

    gradients = tape.gradient(output.loss, model_weights.trainable)
    optimizer_state, updated_weights = optimizer.next(optimizer_state,
                                                      model_weights.trainable,
                                                      gradients)
    if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
      # Keras optimizer mutates model variables within the `next` step.
      tf.nest.map_structure(lambda a, b: a.assign(b), model_weights.trainable,
                            updated_weights)

    if output.num_examples is None:
      num_examples_sum += tf.shape(output.predictions, out_type=tf.int64)[0]
    else:
      num_examples_sum += tf.cast(output.num_examples, tf.int64)

    return num_examples_sum, optimizer_state

  def initial_state_for_reduce_fn():
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights.trainable)
    return tf.zeros(
        shape=[],
        dtype=tf.int64), optimizer.initialize(trainable_tensor_specs)

  num_examples_sum, _ = self._dataset_reduce_fn(
      reduce_fn, dataset, initial_state_fn=initial_state_for_reduce_fn)

  weights_delta = tf.nest.map_structure(tf.subtract, model_weights.trainable,
                                        initial_weights.trainable)
  model_output = model.report_local_outputs()

  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  # Zero out the weight if there are any non-finite values.
  if has_non_finite_delta > 0:
    # TODO(b/176171842): Zeroing has no effect with unweighted aggregation.
    weights_delta_weight = tf.constant(0.0)
  elif self._client_weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
    weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
  elif self._client_weighting is client_weight_lib.ClientWeighting.UNIFORM:
    weights_delta_weight = tf.constant(1.0)
  else:
    weights_delta_weight = self._client_weighting(model_output)
  # TODO(b/176245976): TFF `ClientOutput` structure names are confusing.
  optimizer_output = collections.OrderedDict(num_examples=num_examples_sum)
  return optimizer_utils.ClientOutput(weights_delta, weights_delta_weight,
                                      model_output, optimizer_output)
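
# --- Illustration (not part of the implementation above) ---
# A minimal sketch of the stateful "functional optimizer" protocol the code
# above relies on: `initialize(specs) -> state` and
# `next(state, weights, gradients) -> (state, updated_weights)`. This `ToySGD`
# class is an illustrative assumption, not TFF's optimizer API.
import tensorflow as tf


class ToySGD:
  """Pure-functional SGD: the state carries only the learning rate."""

  def __init__(self, learning_rate):
    self._learning_rate = learning_rate

  def initialize(self, trainable_tensor_specs):
    del trainable_tensor_specs  # Plain SGD needs no per-variable slots.
    return {'learning_rate': tf.constant(self._learning_rate)}

  def next(self, state, weights, gradients):
    # Return new weight values; the caller assigns them back to variables.
    updated = tf.nest.map_structure(
        lambda w, g: w - state['learning_rate'] * g, weights, gradients)
    return state, updated


# Usage: one step on a scalar weight.
opt = ToySGD(learning_rate=0.1)
state = opt.initialize(tf.TensorSpec(shape=[], dtype=tf.float32))
state, new_w = opt.next(state, tf.constant(2.0), tf.constant(1.0))  # 1.9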