def next_fn(state, weights, data):
  reduced_data = intrinsics.federated_map(tf_data_sum, data)
  return MeasuredProcessOutput(
      state,
      client_works.ClientResult(
          federated_add(weights.trainable, reduced_data), client_one()),
      server_zero())
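# The helpers used in these test stubs (`tf_data_sum`, `federated_add`,
# `client_one`, `server_zero`) are free names in this section. Below is a
# minimal sketch consistent with how they are used here; the exact
# definitions live in the surrounding test module, so treat these as
# illustrative assumptions rather than the canonical implementations.
@tensorflow_computation.tf_computation(
    computation_types.SequenceType(tf.float32))
def tf_data_sum(data):
  # Sums a sequence of floats into a single scalar.
  return data.reduce(0.0, lambda x, y: x + y)


def server_zero():
  # A zero measurement placed at the server.
  return intrinsics.federated_value(0, placements.SERVER)


def client_one():
  # A constant update weight of 1.0 placed at clients.
  return intrinsics.federated_value(1.0, placements.CLIENTS)


def federated_add(a, b):
  # Pointwise addition of two federated values of matching placement.
  return intrinsics.federated_map(
      tensorflow_computation.tf_computation(lambda x, y: x + y),
      intrinsics.federated_zip((a, b)))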
def client_update(optimizer, initial_weights, data):
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                        initial_weights)

  def reduce_fn(state, batch):
    """Trains a `tff.learning.Model` on a batch of data."""
    num_examples_sum, optimizer_state = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)

    gradients = tape.gradient(output.loss, model_weights.trainable)
    if delta_l2_regularizer > 0.0:
      proximal_term = tf.nest.map_structure(
          lambda x, y: delta_l2_regularizer * (y - x),
          model_weights.trainable, initial_weights.trainable)
      gradients = tf.nest.map_structure(tf.add, gradients, proximal_term)
    optimizer_state, updated_weights = optimizer.next(
        optimizer_state, tuple(tf.nest.flatten(model_weights.trainable)),
        tuple(tf.nest.flatten(gradients)))
    updated_weights = tf.nest.pack_sequence_as(model_weights.trainable,
                                               updated_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), model_weights.trainable,
                          updated_weights)

    if output.num_examples is None:
      num_examples_sum += tf.shape(output.predictions, out_type=tf.int64)[0]
    else:
      num_examples_sum += tf.cast(output.num_examples, tf.int64)
    return num_examples_sum, optimizer_state

  def initial_state_for_reduce_fn():
    # TODO(b/161529310): We flatten and convert the trainable specs to tuple,
    # as "for batch in data:" pattern would try to stack the tensors in list.
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype),
        tuple(tf.nest.flatten(model_weights.trainable)))
    return (tf.zeros(shape=[], dtype=tf.int64),
            optimizer.initialize(trainable_tensor_specs))

  num_examples, _ = dataset_reduce_fn(
      reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
  client_update = tf.nest.map_structure(tf.subtract,
                                        initial_weights.trainable,
                                        model_weights.trainable)
  model_output = model.report_local_unfinalized_metrics()
  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  client_update, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(client_update))
  client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                        num_examples)
  return client_works.ClientResult(
      update=client_update, update_weight=client_weight), model_output
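# `_choose_client_weight` is referenced by several client updates in this
# section but never defined here. A minimal sketch consistent with its uses,
# assuming `weighting` is a `tff.learning.ClientWeighting` enum exposed via a
# `client_weight_lib` module (an assumption): the weight is zeroed whenever
# non-finite values were found in the update, so aggregation ignores that
# client's contribution.
def _choose_client_weight(weighting, has_non_finite_delta, num_examples):
  if has_non_finite_delta > 0:
    # Any NaN/Inf in the delta: drop this client's contribution entirely.
    return tf.constant(0.0, tf.float32)
  if weighting == client_weight_lib.ClientWeighting.NUM_EXAMPLES:
    return tf.cast(num_examples, tf.float32)
  elif weighting == client_weight_lib.ClientWeighting.UNIFORM:
    return tf.constant(1.0, tf.float32)
  else:
    raise ValueError(f'Unexpected weighting argument: {weighting}')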
def next_fn(state, weights, data):
  return MeasuredProcessOutput(
      state,
      intrinsics.federated_zip(
          client_works.ClientResult(
              federated_add(weights.trainable, data), client_one())),
      server_zero())
def client_update(initial_weights, dataset):
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                        initial_weights)

  def reduce_fn(state, batch):
    """Runs forward_pass on batch and sums the weighted gradients."""
    accumulated_gradients, num_examples_sum = state

    with tf.GradientTape() as tape:
      output = model.forward_pass(batch)
    gradients = tape.gradient(output.loss, model_weights.trainable)
    num_examples = tf.cast(output.num_examples, tf.float32)
    accumulated_gradients = tuple(
        accumulator + num_examples * gradient
        for accumulator, gradient in zip(accumulated_gradients, gradients))

    # We may be able to optimize the reduce function to avoid doubling the
    # number of required variables here (e.g. keeping two copies of all
    # gradients). If you're looking to optimize memory usage this might be a
    # place to look.
    return (accumulated_gradients, num_examples_sum + num_examples)

  def _zero_initial_state():
    """Create a tuple of (gradient accumulators, num examples)."""
    return tuple(
        tf.nest.map_structure(tf.zeros_like,
                              model_weights.trainable)), tf.constant(
                                  0, dtype=tf.float32)

  gradient_sums, num_examples_sum = dataset_reduce_fn(
      reduce_fn=reduce_fn,
      dataset=dataset,
      initial_state_fn=_zero_initial_state)

  # We now normalize to compute the average gradient over all examples.
  average_gradient = tf.nest.map_structure(
      lambda gradient: gradient / num_examples_sum, gradient_sums)

  model_output = model.report_local_unfinalized_metrics()
  average_gradient, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(average_gradient))
  if has_non_finite_delta > 0:
    client_weight = tf.constant(0.0)
  else:
    client_weight = num_examples_sum

  return client_works.ClientResult(
      update=average_gradient, update_weight=client_weight), model_output
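# `dataset_reduce_fn` is another free name shared by the client updates in
# this section. A minimal sketch, assuming it simply wraps
# `tf.data.Dataset.reduce`; TFF also builds a Python for-loop variant for
# runtimes where dataset reductions are unavailable (e.g. some GPU setups),
# which is not shown here.
def dataset_reduce_fn(reduce_fn, dataset, initial_state_fn):
  # Folds `reduce_fn` over the dataset, starting from a freshly built state.
  return dataset.reduce(
      initial_state=initial_state_fn(), reduce_func=reduce_fn)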
def test_type_properties(self, weighting):
  model_fn = model_examples.LinearRegression
  optimizer = sgdm.build_sgdm(learning_rate=0.1, momentum=0.9)
  client_work_process = mime._build_mime_lite_client_work(
      model_fn, optimizer, weighting)
  self.assertIsInstance(client_work_process, client_works.ClientWorkProcess)

  mw_type = model_utils.ModelWeights(
      trainable=computation_types.to_type([(tf.float32, (2, 1)), tf.float32]),
      non_trainable=computation_types.to_type([tf.float32]))
  expected_param_model_weights_type = computation_types.at_clients(mw_type)
  expected_param_data_type = computation_types.at_clients(
      computation_types.SequenceType(
          computation_types.to_type(model_fn().input_spec)))
  expected_result_type = computation_types.at_clients(
      client_works.ClientResult(
          update=mw_type.trainable,
          update_weight=computation_types.TensorType(tf.float32)))
  expected_optimizer_state_type = type_conversions.type_from_tensors(
      optimizer.initialize(
          type_conversions.type_to_tf_tensor_specs(mw_type.trainable)))
  expected_aggregator_type = computation_types.to_type(
      collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
  expected_state_type = computation_types.at_server(
      (expected_optimizer_state_type, expected_aggregator_type))
  expected_measurements_type = computation_types.at_server(
      collections.OrderedDict(
          train=collections.OrderedDict(loss=tf.float32,
                                        num_examples=tf.int32)))

  expected_initialize_type = computation_types.FunctionType(
      parameter=None, result=expected_state_type)
  expected_initialize_type.check_equivalent_to(
      client_work_process.initialize.type_signature)

  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          weights=expected_param_model_weights_type,
          client_data=expected_param_data_type),
      result=measured_process.MeasuredProcessOutput(
          expected_state_type, expected_result_type,
          expected_measurements_type))
  expected_next_type.check_equivalent_to(
      client_work_process.next.type_signature)
def client_update(optimizer, initial_weights, data):
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                        initial_weights)

  def reduce_fn(num_examples_sum, batch):
    """Trains a `tff.learning.Model` on a batch of data."""
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)

    gradients = tape.gradient(output.loss, model_weights.trainable)
    if delta_l2_regularizer > 0.0:
      proximal_term = tf.nest.map_structure(
          lambda x, y: delta_l2_regularizer * (y - x),
          model_weights.trainable, initial_weights.trainable)
      gradients = tf.nest.map_structure(tf.add, gradients, proximal_term)
    grads_and_vars = zip(gradients, model_weights.trainable)
    optimizer.apply_gradients(grads_and_vars)

    # TODO(b/199782787): Add a unit test for a model that does not compute
    # `num_examples` in its forward pass.
    if output.num_examples is None:
      num_examples_sum += tf.shape(output.predictions, out_type=tf.int64)[0]
    else:
      num_examples_sum += tf.cast(output.num_examples, tf.int64)
    return num_examples_sum

  def initial_state_for_reduce_fn():
    return tf.zeros(shape=[], dtype=tf.int64)

  num_examples = dataset_reduce_fn(
      reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
  client_update = tf.nest.map_structure(tf.subtract,
                                        initial_weights.trainable,
                                        model_weights.trainable)
  model_output = model.report_local_unfinalized_metrics()
  # TODO(b/122071074): Consider moving this functionality into
  # tff.federated_mean?
  client_update, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(client_update))
  client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                        num_examples)
  return client_works.ClientResult(
      update=client_update, update_weight=client_weight), model_output
def client_update_fn(optimizer, initial_weights, dataset):
  initial_trainable_weights = initial_weights[0]
  trainable_tensor_specs = tf.nest.map_structure(
      tf.TensorSpec.from_tensor,
      tuple(tf.nest.flatten(initial_trainable_weights)))
  optimizer_state = optimizer.initialize(trainable_tensor_specs)

  # Autograph requires we define these variables once outside the loop.
  model_weights = initial_weights
  trainable_weights, non_trainable_weights = model_weights
  num_examples = tf.constant(0, tf.int64)
  for batch in iter(dataset):
    trainable_weights, non_trainable_weights = model_weights
    with tf.GradientTape() as tape:
      # Must explicitly watch non-variable tensors.
      tape.watch(trainable_weights)
      output = model.forward_pass(model_weights, batch, training=True)

    gradients = tape.gradient(output.loss, trainable_weights)
    if tf.greater(delta_l2_regularizer, 0.0):
      proximal_term = tf.nest.map_structure(
          lambda x, y: delta_l2_regularizer * (y - x), trainable_weights,
          initial_trainable_weights)
      gradients = tf.nest.map_structure(tf.add, gradients, proximal_term)
    optimizer_state, trainable_weights = optimizer.next(
        optimizer_state, trainable_weights, gradients)
    num_examples += tf.cast(output.num_examples, tf.int64)
    model_weights = (trainable_weights, non_trainable_weights)

  # After all local batches, compute the delta between the trained model
  # and the initial incoming model weights.
  client_model_update = tf.nest.map_structure(tf.subtract,
                                              initial_trainable_weights,
                                              trainable_weights)
  # TODO(b/229612282): Implement metrics.
  model_output = ()
  client_model_update, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(client_model_update))
  client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                        num_examples)
  return client_works.ClientResult(
      update=client_model_update, update_weight=client_weight), model_output
def _compute_kmeans_step(centroids: tf.Tensor, data: tf.data.Dataset):
  """Performs a k-means step on a dataset.

  This method finds, for each point in `data`, the closest centroid in
  `centroids`. It returns a structure `tff.learning.templates.ClientResult`
  whose `update` attribute is a tuple `(cluster_sums, cluster_weights)`.
  Here, `cluster_sums` is a tensor of shape matching `centroids`, where
  `cluster_sums[i, :]` is the sum of all points closest to the i-th centroid,
  and `cluster_weights` is a `(num_centroids,)` dimensional tensor whose i-th
  component is the number of points closest to the i-th centroid. The
  `ClientResult.update_weight` attribute is left empty.

  Args:
    centroids: A `tf.Tensor` of centroids, indexed by the first axis.
    data: A `tf.data.Dataset` of points, each of which has shape matching
      that of `centroids.shape[1:]`.

  Returns:
    A `tff.learning.templates.ClientResult`.
  """
  cluster_sums = tf.zeros_like(centroids)
  cluster_weights = tf.zeros(shape=(centroids.shape[0],), dtype=_WEIGHT_DTYPE)
  num_examples = tf.constant(0, dtype=_WEIGHT_DTYPE)

  def reduce_fn(state, point):
    cluster_sums, cluster_weights, num_examples = state
    closest_centroid = _find_closest_centroid(centroids, point)
    scatter_index = [[closest_centroid]]
    cluster_sums = tf.tensor_scatter_nd_add(cluster_sums, scatter_index,
                                            tf.expand_dims(point, axis=0))
    cluster_weights = tf.tensor_scatter_nd_add(cluster_weights,
                                               scatter_index, [1])
    num_examples += 1
    return cluster_sums, cluster_weights, num_examples

  cluster_sums, cluster_weights, num_examples = data.reduce(
      initial_state=(cluster_sums, cluster_weights, num_examples),
      reduce_func=reduce_fn)

  stat_output = collections.OrderedDict(num_examples=num_examples)
  return client_works.ClientResult(
      update=(cluster_sums, cluster_weights), update_weight=()), stat_output
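# Usage sketch for `_compute_kmeans_step`. `_WEIGHT_DTYPE` and
# `_find_closest_centroid` are free names above; the definitions below are
# illustrative assumptions (integer counts, argmin over squared Euclidean
# distances), not the canonical ones.
_WEIGHT_DTYPE = tf.int32


def _find_closest_centroid(centroids, point):
  # Index of the centroid with minimal squared distance to `point`.
  squared_distances = tf.reduce_sum(
      tf.square(centroids - tf.expand_dims(point, axis=0)), axis=1)
  return tf.argmin(squared_distances)


centroids = tf.constant([[0.0, 0.0], [10.0, 10.0]])
points = tf.data.Dataset.from_tensor_slices(
    [[1.0, 1.0], [9.0, 9.0], [11.0, 11.0]])
result, stats = _compute_kmeans_step(centroids, points)
# Under these assumptions:
# result.update[0] (cluster_sums):    [[1., 1.], [20., 20.]]
# result.update[1] (cluster_weights): [1, 2]
# stats['num_examples']:              3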
def make_result(value, data):
  return client_works.ClientResult(
      update=value.trainable,
      update_weight=data.reduce(0.0, lambda x, y: x + y))
def test_client_result(weights, data):
  reduced_data = intrinsics.federated_map(tf_data_sum, data)
  return intrinsics.federated_zip(
      client_works.ClientResult(
          update=federated_add(weights.trainable, reduced_data),
          update_weight=client_one()))
def next_fn(state, unused_weights, unused_data):
  return MeasuredProcessOutput(
      state,
      intrinsics.federated_value(
          client_works.ClientResult((), ()), placements.CLIENTS),
      server_zero())
def next_fn(state, weights, data):
  return MeasuredProcessOutput(
      state,
      client_works.ClientResult(weights.trainable + tf_data_sum(data), ()),
      ())
def client_update(global_optimizer_state, initial_weights, data):
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                        initial_weights)

  def full_gradient_reduce_fn(state, batch):
    """Sums individual gradients, to be later divided by num_examples."""
    gradient_sum, num_examples_sum = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
    if output.num_examples is None:
      num_examples = tf.shape(output.predictions, out_type=tf.int64)[0]
    else:
      num_examples = tf.cast(output.num_examples, tf.int64)
    # TODO(b/161529310): We flatten and convert to tuple, as tf.data
    # iterators would try to stack the tensors in list into a single tensor.
    gradients = tuple(
        tf.nest.flatten(tape.gradient(output.loss, model_weights.trainable)))
    gradient_sum = tf.nest.map_structure(
        lambda g_sum, g: g_sum + g * tf.cast(num_examples, g.dtype),
        gradient_sum, gradients)
    num_examples_sum += num_examples
    return gradient_sum, num_examples_sum

  def initial_state_for_full_gradient_reduce_fn():
    initial_gradient_sum = tf.nest.map_structure(
        lambda spec: tf.zeros(spec.shape, spec.dtype),
        tuple(tf.nest.flatten(weight_tensor_specs.trainable)))
    initial_num_examples_sum = tf.constant(0, tf.int64)
    return initial_gradient_sum, initial_num_examples_sum

  full_gradient, num_examples = dataset_reduce_fn(
      full_gradient_reduce_fn, data,
      initial_state_for_full_gradient_reduce_fn)
  # Compute the average gradient.
  full_gradient = tf.nest.map_structure(
      lambda g: tf.math.divide_no_nan(g, tf.cast(num_examples, g.dtype)),
      full_gradient)

  # Resets the local model variables, including metrics states, as we are
  # not interested in metrics based on the full gradient evaluation, only
  # from the subsequent training.
  model.reset_metrics()

  def train_reduce_fn(state, batch):
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
    gradients = tape.gradient(output.loss, model_weights.trainable)
    # Mime Lite keeps optimizer state unchanged during local training.
    _, updated_weights = optimizer.next(global_optimizer_state,
                                        model_weights.trainable, gradients)
    tf.nest.map_structure(lambda a, b: a.assign(b), model_weights.trainable,
                          updated_weights)
    return state

  # Performs local training, updating `tf.Variable`s in `model_weights`.
  dataset_reduce_fn(
      train_reduce_fn, data, initial_state_fn=lambda: tf.zeros(shape=[0]))

  client_weights_delta = tf.nest.map_structure(
      tf.subtract, initial_weights.trainable, model_weights.trainable)
  model_output = model.report_local_unfinalized_metrics()

  # TODO(b/122071074): Consider moving this functionality into aggregation.
  client_weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(client_weights_delta))
  client_weight = _choose_client_weight(client_weighting,
                                        has_non_finite_delta, num_examples)
  return client_works.ClientResult(
      update=client_weights_delta,
      update_weight=client_weight), model_output, full_gradient
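# Note on the extra `full_gradient` return value: in Mime Lite, each client's
# average full-batch gradient is aggregated and used on the server to advance
# the shared optimizer state, while the local steps above apply that state
# without modifying it. A hedged sketch of the server-side state update,
# assuming the averaged gradients arrive as `aggregate_gradient` (the name
# and the zero-valued placeholder weights are illustrative):
def update_optimizer_state(optimizer_state, aggregate_gradient):
  # The weights argument is only needed to satisfy the optimizer signature;
  # we keep the updated state and discard the returned weights.
  placeholder_weights = tf.nest.map_structure(
      lambda g: tf.zeros(g.shape, g.dtype), aggregate_gradient)
  updated_state, _ = optimizer.next(
      optimizer_state, placeholder_weights, aggregate_gradient)
  return updated_state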