Ejemplo n.º 1
0
 def next_fn(state, weights, data):
     """Builds a measured output whose client update is `weights + sum(data)`.

     The per-client sum of `data` (via `tf_data_sum`) is added to the trainable
     weights with `federated_add`; the client weight is `client_one()` and the
     measurements are `server_zero()`. The `ClientResult` is returned without
     being passed through `intrinsics.federated_zip`.
     """
     summed_data = intrinsics.federated_map(tf_data_sum, data)
     updated_weights = federated_add(weights.trainable, summed_data)
     result = client_works.ClientResult(updated_weights, client_one())
     return MeasuredProcessOutput(state, result, server_zero())
Ejemplo n.º 2
0
    def client_update(optimizer, initial_weights, data):
        """Trains the model locally on `data` and returns the model delta.

        The incoming `initial_weights` are assigned onto the model's variables,
        then `dataset_reduce_fn` iterates over `data`, applying `optimizer.next`
        to the trainable weights after every batch.

        Args:
          optimizer: An optimizer exposing `initialize(specs)` and
            `next(state, weights, gradients) -> (state, updated_weights)`.
          initial_weights: The weights to start local training from; assigned
            structurally onto `model_utils.ModelWeights.from_model(model)`.
          data: The client's dataset of batches.

        Returns:
          A tuple `(client_result, model_output)`. `client_result.update` is
          `initial_weights.trainable` minus the trained trainable weights, passed
          through `tensor_utils.zero_all_if_any_non_finite`.
          `client_result.update_weight` is chosen by `_choose_client_weight` from
          `weighting`, the non-finite flag and the int64 example count.
          `model_output` is `model.report_local_unfinalized_metrics()`.
        """
        # Load the incoming weights into the model's variables.
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                              initial_weights)

        def reduce_fn(state, batch):
            """Trains a `tff.learning.Model` on a batch of data."""
            num_examples_sum, optimizer_state = state
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model_weights.trainable)
            if delta_l2_regularizer > 0.0:
                # Gradient of the proximal penalty
                # (delta_l2_regularizer / 2) * ||w - w_initial||^2, which is
                # delta_l2_regularizer * (w - w_initial). The order matters: the
                # reversed sign would push the weights away from
                # `initial_weights` instead of pulling them back.
                proximal_term = tf.nest.map_structure(
                    lambda w, w_initial: delta_l2_regularizer * (w - w_initial),
                    model_weights.trainable, initial_weights.trainable)
                gradients = tf.nest.map_structure(tf.add, gradients,
                                                  proximal_term)
            # Weights and gradients are flattened to tuples to match the flat
            # tuple specs the optimizer state was initialized with below.
            optimizer_state, updated_weights = optimizer.next(
                optimizer_state,
                tuple(tf.nest.flatten(model_weights.trainable)),
                tuple(tf.nest.flatten(gradients)))
            updated_weights = tf.nest.pack_sequence_as(model_weights.trainable,
                                                       updated_weights)
            tf.nest.map_structure(lambda a, b: a.assign(b),
                                  model_weights.trainable, updated_weights)

            # Fall back to the leading dimension of the predictions when the
            # model does not report an example count.
            if output.num_examples is None:
                num_examples_sum += tf.shape(output.predictions,
                                             out_type=tf.int64)[0]
            else:
                num_examples_sum += tf.cast(output.num_examples, tf.int64)

            return num_examples_sum, optimizer_state

        def initial_state_for_reduce_fn():
            # TODO(b/161529310): We flatten and convert the trainable specs to tuple,
            # as "for batch in data:" pattern would try to stack the tensors in list.
            trainable_tensor_specs = tf.nest.map_structure(
                lambda v: tf.TensorSpec(v.shape, v.dtype),
                tuple(tf.nest.flatten(model_weights.trainable)))
            return (tf.zeros(shape=[], dtype=tf.int64),
                    optimizer.initialize(trainable_tensor_specs))

        num_examples, _ = dataset_reduce_fn(
            reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
        # Delta convention used here: initial minus trained weights.
        client_update = tf.nest.map_structure(tf.subtract,
                                              initial_weights.trainable,
                                              model_weights.trainable)
        model_output = model.report_local_unfinalized_metrics()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        client_update, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(client_update))
        client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                              num_examples)

        return client_works.ClientResult(
            update=client_update, update_weight=client_weight), model_output
Ejemplo n.º 3
0
 def next_fn(state, weights, data):
     """Builds a measured output whose zipped client update is `weights + data`.

     `data` is added directly to the trainable weights with `federated_add`;
     the resulting `ClientResult` (client weight `client_one()`) is zipped with
     `intrinsics.federated_zip`, and the measurements are `server_zero()`.
     """
     summed = federated_add(weights.trainable, data)
     zipped_result = intrinsics.federated_zip(
         client_works.ClientResult(summed, client_one()))
     return MeasuredProcessOutput(state, zipped_result, server_zero())
Ejemplo n.º 4
0
    def client_update(initial_weights, dataset):
        """Computes the example-weighted average gradient of the model on `dataset`.

        `initial_weights` are assigned onto the model's variables (they are not
        modified afterwards). Each batch's loss gradient w.r.t. the trainable
        weights is scaled by the batch's example count and summed. The sum is
        then divided by the total example count.

        Args:
          initial_weights: Weights assigned structurally onto
            `model_utils.ModelWeights.from_model(model)` before the pass.
          dataset: The client's dataset of batches.

        Returns:
          A tuple `(client_result, model_output)`. `client_result.update` is the
          average gradient, passed through
          `tensor_utils.zero_all_if_any_non_finite`.
          `client_result.update_weight` is the float32 example count, or
          `0.0` when that helper flags non-finite values. For example, an empty
          dataset gives 0 / 0 = NaN here. `model_output` is
          `model.report_local_unfinalized_metrics()`.
        """
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                              initial_weights)

        def reduce_fn(state, batch):
            """Runs forward_pass on batch and sums the weighted gradients."""
            accumulated_gradients, num_examples_sum = state

            with tf.GradientTape() as tape:
                output = model.forward_pass(batch)
            gradients = tape.gradient(output.loss, model_weights.trainable)
            if output.num_examples is None:
                # The model did not report an example count; fall back to the
                # leading dimension of the predictions rather than failing on
                # `tf.cast(None, ...)`.
                num_examples = tf.cast(tf.shape(output.predictions)[0],
                                       tf.float32)
            else:
                num_examples = tf.cast(output.num_examples, tf.float32)
            accumulated_gradients = tuple(
                accumulator + num_examples * gradient
                for accumulator, gradient in zip(accumulated_gradients,
                                                 gradients))

            # We may be able to optimize the reduce function to avoid doubling the
            # number of required variables here (e.g. keeping two copies of all
            # gradients). If you're looking to optimize memory usage this might be a
            # place to look.
            return (accumulated_gradients, num_examples_sum + num_examples)

        def _zero_initial_state():
            """Create a tuple of (gradient accumulators, num examples)."""
            gradient_accumulators = tuple(
                tf.nest.map_structure(tf.zeros_like, model_weights.trainable))
            return gradient_accumulators, tf.constant(0, dtype=tf.float32)

        gradient_sums, num_examples_sum = dataset_reduce_fn(
            reduce_fn=reduce_fn,
            dataset=dataset,
            initial_state_fn=_zero_initial_state)

        # We now normalize to compute the average gradient over all examples.
        average_gradient = tf.nest.map_structure(
            lambda gradient: gradient / num_examples_sum, gradient_sums)

        model_output = model.report_local_unfinalized_metrics()
        average_gradient, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(average_gradient))
        if has_non_finite_delta > 0:
            client_weight = tf.constant(0.0)
        else:
            client_weight = num_examples_sum

        return client_works.ClientResult(
            update=average_gradient, update_weight=client_weight), model_output
Ejemplo n.º 5
0
    def test_type_properties(self, weighting):
        """Checks the initialize/next type signatures of the Mime Lite client work."""
        model_fn = model_examples.LinearRegression
        optimizer = sgdm.build_sgdm(learning_rate=0.1, momentum=0.9)
        process = mime._build_mime_lite_client_work(model_fn, optimizer,
                                                    weighting)
        self.assertIsInstance(process, client_works.ClientWorkProcess)

        # Trainable: a float32 tensor of shape (2, 1) and a float32 scalar;
        # non-trainable: a single float32 scalar.
        weights_type = model_utils.ModelWeights(
            trainable=computation_types.to_type([(tf.float32, (2, 1)),
                                                 tf.float32]),
            non_trainable=computation_types.to_type([tf.float32]))
        client_weights_type = computation_types.at_clients(weights_type)
        element_type = computation_types.to_type(model_fn().input_spec)
        client_data_type = computation_types.at_clients(
            computation_types.SequenceType(element_type))
        result_type = computation_types.at_clients(
            client_works.ClientResult(
                update=weights_type.trainable,
                update_weight=computation_types.TensorType(tf.float32)))

        trainable_specs = type_conversions.type_to_tf_tensor_specs(
            weights_type.trainable)
        optimizer_state_type = type_conversions.type_from_tensors(
            optimizer.initialize(trainable_specs))
        aggregator_type = computation_types.to_type(
            collections.OrderedDict(value_sum_process=(),
                                    weight_sum_process=()))
        state_type = computation_types.at_server(
            (optimizer_state_type, aggregator_type))
        measurements_type = computation_types.at_server(
            collections.OrderedDict(train=collections.OrderedDict(
                loss=tf.float32, num_examples=tf.int32)))

        # `initialize` takes no parameter and yields the server-placed state.
        initialize_type = computation_types.FunctionType(parameter=None,
                                                         result=state_type)
        initialize_type.check_equivalent_to(process.initialize.type_signature)

        next_parameter = collections.OrderedDict(
            state=state_type,
            weights=client_weights_type,
            client_data=client_data_type)
        next_result = measured_process.MeasuredProcessOutput(
            state_type, result_type, measurements_type)
        next_type = computation_types.FunctionType(parameter=next_parameter,
                                                   result=next_result)
        next_type.check_equivalent_to(process.next.type_signature)
Ejemplo n.º 6
0
    def client_update(optimizer, initial_weights, data):
        """Trains the model locally on `data` and returns the model delta.

        The incoming `initial_weights` are assigned onto the model's variables,
        then `dataset_reduce_fn` iterates over `data`, calling
        `optimizer.apply_gradients` after every batch.

        Args:
          optimizer: An optimizer with an `apply_gradients(grads_and_vars)`
            method (presumably a `tf.keras` optimizer; confirm with callers).
          initial_weights: The weights to start local training from; assigned
            structurally onto `model_utils.ModelWeights.from_model(model)`.
          data: The client's dataset of batches.

        Returns:
          A tuple `(client_result, model_output)`. `client_result.update` is
          `initial_weights.trainable` minus the trained trainable weights, passed
          through `tensor_utils.zero_all_if_any_non_finite`.
          `client_result.update_weight` is chosen by `_choose_client_weight` from
          `weighting`, the non-finite flag and the int64 example count.
          `model_output` is `model.report_local_unfinalized_metrics()`.
        """
        # Load the incoming weights into the model's variables.
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                              initial_weights)

        def reduce_fn(num_examples_sum, batch):
            """Trains a `tff.learning.Model` on a batch of data."""
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model_weights.trainable)
            if delta_l2_regularizer > 0.0:
                # Gradient of the proximal penalty
                # (delta_l2_regularizer / 2) * ||w - w_initial||^2, which is
                # delta_l2_regularizer * (w - w_initial). The order matters: the
                # reversed sign would push the weights away from
                # `initial_weights` instead of pulling them back.
                proximal_term = tf.nest.map_structure(
                    lambda w, w_initial: delta_l2_regularizer * (w - w_initial),
                    model_weights.trainable, initial_weights.trainable)
                gradients = tf.nest.map_structure(tf.add, gradients,
                                                  proximal_term)
            grads_and_vars = zip(gradients, model_weights.trainable)
            optimizer.apply_gradients(grads_and_vars)

            # TODO(b/199782787): Add a unit test for a model that does not compute
            # `num_examples` in its forward pass.
            if output.num_examples is None:
                num_examples_sum += tf.shape(output.predictions,
                                             out_type=tf.int64)[0]
            else:
                num_examples_sum += tf.cast(output.num_examples, tf.int64)

            return num_examples_sum

        def initial_state_for_reduce_fn():
            return tf.zeros(shape=[], dtype=tf.int64)

        num_examples = dataset_reduce_fn(
            reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
        # Delta convention used here: initial minus trained weights.
        client_update = tf.nest.map_structure(tf.subtract,
                                              initial_weights.trainable,
                                              model_weights.trainable)
        model_output = model.report_local_unfinalized_metrics()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        client_update, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(client_update))
        client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                              num_examples)
        return client_works.ClientResult(
            update=client_update, update_weight=client_weight), model_output
Ejemplo n.º 7
0
    def client_update_fn(optimizer, initial_weights, dataset):
        """Trains a functional model locally and returns the trainable delta.

        Weights are carried as tensors through the training loop rather than
        assigned into variables. This is why the tape has to watch them
        explicitly.

        Args:
          optimizer: An optimizer exposing `initialize(specs)` and
            `next(state, weights, gradients) -> (state, updated_weights)`.
          initial_weights: A `(trainable, non_trainable)` pair of weight
            structures to start training from.
          dataset: The client's dataset of batches.

        Returns:
          A tuple `(client_result, model_output)`. `client_result.update` is the
          initial trainable weights minus the trained ones, passed through
          `tensor_utils.zero_all_if_any_non_finite`.
          `client_result.update_weight` is chosen by `_choose_client_weight`.
          `model_output` is currently an empty tuple (metrics not implemented).
        """
        initial_trainable_weights = initial_weights[0]
        trainable_tensor_specs = tf.nest.map_structure(
            tf.TensorSpec.from_tensor,
            tuple(tf.nest.flatten(initial_trainable_weights)))
        optimizer_state = optimizer.initialize(trainable_tensor_specs)

        # Autograph requires we define these variables once outside the loop.
        model_weights = initial_weights
        trainable_weights, non_trainable_weights = model_weights
        num_examples = tf.constant(0, tf.int64)
        for batch in iter(dataset):
            trainable_weights, non_trainable_weights = model_weights
            with tf.GradientTape() as tape:
                # Must explicitly watch non-variable tensors.
                tape.watch(trainable_weights)
                output = model.forward_pass(model_weights,
                                            batch,
                                            training=True)
            gradients = tape.gradient(output.loss, trainable_weights)
            if tf.greater(delta_l2_regularizer, 0.0):
                # Gradient of the proximal penalty
                # (delta_l2_regularizer / 2) * ||w - w_initial||^2, which is
                # delta_l2_regularizer * (w - w_initial). The order matters: the
                # reversed sign would push the weights away from the initial
                # weights instead of pulling them back.
                proximal_term = tf.nest.map_structure(
                    lambda w, w_initial: delta_l2_regularizer * (w - w_initial),
                    trainable_weights, initial_trainable_weights)
                gradients = tf.nest.map_structure(tf.add, gradients,
                                                  proximal_term)
            optimizer_state, trainable_weights = optimizer.next(
                optimizer_state, trainable_weights, gradients)
            num_examples += tf.cast(output.num_examples, tf.int64)
            model_weights = (trainable_weights, non_trainable_weights)
        # After all local batches, compute the delta between the trained model
        # and the initial incoming model weights.
        client_model_update = tf.nest.map_structure(tf.subtract,
                                                    initial_trainable_weights,
                                                    trainable_weights)
        # TODO(b/229612282): Implement metrics.
        model_output = ()
        client_model_update, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(client_model_update))
        client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                              num_examples)
        return client_works.ClientResult(
            update=client_model_update,
            update_weight=client_weight), model_output
Ejemplo n.º 8
0
def _compute_kmeans_step(centroids: tf.Tensor, data: tf.data.Dataset):
    """Performs a k-means step on a dataset.

    For each point in `data`, finds the closest centroid in `centroids` (via
    `_find_closest_centroid`) and accumulates per-cluster statistics.

    Args:
      centroids: A `tf.Tensor` of centroids, indexed by the first axis.
      data: A `tf.data.Dataset` of points, each of which has shape matching that
        of `centroids.shape[1:]`.

    Returns:
      A tuple `(client_result, stat_output)`:
        * `client_result` is a `tff.learning.templates.ClientResult` whose
          `update` is `(cluster_sums, cluster_weights)`. `cluster_sums` has the
          same shape as `centroids`, and `cluster_sums[i, :]` is the sum of all
          points closest to the i-th centroid. `cluster_weights` has shape
          `(num_centroids,)` and dtype `_WEIGHT_DTYPE`; its i-th entry counts
          the points closest to the i-th centroid. `update_weight` is the
          empty tuple `()`.
        * `stat_output` is a `collections.OrderedDict` with a single key,
          `num_examples`: the total number of points in `data`, of dtype
          `_WEIGHT_DTYPE`.
    """
    initial_sums = tf.zeros_like(centroids)
    initial_weights = tf.zeros(shape=(centroids.shape[0], ),
                               dtype=_WEIGHT_DTYPE)
    initial_count = tf.constant(0, dtype=_WEIGHT_DTYPE)

    def reduce_fn(state, point):
        cluster_sums, cluster_weights, num_examples = state
        closest_centroid = _find_closest_centroid(centroids, point)
        # A single-row scatter index selects the closest centroid's row/entry.
        scatter_index = [[closest_centroid]]
        cluster_sums = tf.tensor_scatter_nd_add(cluster_sums, scatter_index,
                                                tf.expand_dims(point, axis=0))
        cluster_weights = tf.tensor_scatter_nd_add(cluster_weights,
                                                   scatter_index, [1])
        num_examples += 1
        return cluster_sums, cluster_weights, num_examples

    cluster_sums, cluster_weights, num_examples = data.reduce(
        initial_state=(initial_sums, initial_weights, initial_count),
        reduce_func=reduce_fn)

    stat_output = collections.OrderedDict(num_examples=num_examples)
    return client_works.ClientResult(update=(cluster_sums, cluster_weights),
                                     update_weight=()), stat_output
Ejemplo n.º 9
0
 def make_result(value, data):
     """Wraps `value.trainable` as the update, weighted by the sum of `data`."""
     update = value.trainable
     total_weight = data.reduce(0.0, lambda running, element: running + element)
     return client_works.ClientResult(update=update, update_weight=total_weight)
Ejemplo n.º 10
0
def test_client_result(weights, data):
    """Returns a zipped client result: trainable weights plus summed data.

    The per-client sum of `data` (via `tf_data_sum`) is added to
    `weights.trainable`; the update weight is `client_one()`.
    """
    summed_data = intrinsics.federated_map(tf_data_sum, data)
    result = client_works.ClientResult(
        update=federated_add(weights.trainable, summed_data),
        update_weight=client_one())
    return intrinsics.federated_zip(result)
Ejemplo n.º 11
0
 def next_fn(state, unused_weights, unused_data):
     """Ignores its inputs and places an empty `ClientResult` at the clients."""
     empty_result = client_works.ClientResult((), ())
     client_placed = intrinsics.federated_value(empty_result, placements.CLIENTS)
     return MeasuredProcessOutput(state, client_placed, server_zero())
Ejemplo n.º 12
0
 def next_fn(state, weights, data):
     """Returns an unplaced output: update `weights.trainable + tf_data_sum(data)`.

     Neither the result nor the (empty) measurements are federated values.
     """
     new_update = weights.trainable + tf_data_sum(data)
     unplaced_result = client_works.ClientResult(new_update, ())
     return MeasuredProcessOutput(state, unplaced_result, ())
Ejemplo n.º 13
0
        def client_update(global_optimizer_state, initial_weights, data):
            """Runs Mime Lite local work: a full-gradient pass, then local training.

            First pass: computes the example-weighted average gradient of the
            loss over `data` at `initial_weights` (no weights are updated).
            Model metrics are then reset via `model.reset_metrics()`. Second
            pass: trains locally with `optimizer.next`, always feeding it the
            fixed `global_optimizer_state` and discarding the returned state.

            Args:
              global_optimizer_state: Optimizer state used unchanged for every
                local step.
              initial_weights: Weights assigned onto
                `model_utils.ModelWeights.from_model(model)` before both passes.
              data: The client's dataset of batches (iterated twice).

            Returns:
              A 3-tuple `(client_result, model_output, full_gradient)`.
              `client_result.update` is the initial minus trained trainable
              weights, passed through `tensor_utils.zero_all_if_any_non_finite`.
              `client_result.update_weight` comes from `_choose_client_weight`.
              `model_output` is `model.report_local_unfinalized_metrics()`, which
              reflects only the second (training) pass. `full_gradient` is a
              flat tuple of average gradients; `tf.math.divide_no_nan` makes it
              all zeros when `num_examples` is 0.
            """
            model_weights = model_utils.ModelWeights.from_model(model)
            tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                                  initial_weights)

            def full_gradient_reduce_fn(state, batch):
                """Sums individual gradients, to be later divided by num_examples."""
                gradient_sum, num_examples_sum = state
                with tf.GradientTape() as tape:
                    output = model.forward_pass(batch, training=True)
                # Fall back to the leading dimension of the predictions when the
                # model does not report an example count.
                if output.num_examples is None:
                    num_examples = tf.shape(output.predictions,
                                            out_type=tf.int64)[0]
                else:
                    num_examples = tf.cast(output.num_examples, tf.int64)
                # TODO(b/161529310): We flatten and convert to tuple, as tf.data
                # iterators would try to stack the tensors in list into a single tensor.
                gradients = tuple(
                    tf.nest.flatten(
                        tape.gradient(output.loss, model_weights.trainable)))
                # Weight each batch gradient by its example count so that the
                # later division yields a per-example average.
                gradient_sum = tf.nest.map_structure(
                    lambda g_sum, g: g_sum + g * tf.cast(
                        num_examples, g.dtype), gradient_sum, gradients)
                num_examples_sum += num_examples
                return gradient_sum, num_examples_sum

            def initial_state_for_full_gradient_reduce_fn():
                # Zero accumulators shaped like the flattened trainable specs.
                # `weight_tensor_specs` comes from the enclosing scope (not
                # visible here).
                initial_gradient_sum = tf.nest.map_structure(
                    lambda spec: tf.zeros(spec.shape, spec.dtype),
                    tuple(tf.nest.flatten(weight_tensor_specs.trainable)))
                initial_num_examples_sum = tf.constant(0, tf.int64)
                return initial_gradient_sum, initial_num_examples_sum

            full_gradient, num_examples = dataset_reduce_fn(
                full_gradient_reduce_fn, data,
                initial_state_for_full_gradient_reduce_fn)
            # Compute the average gradient.
            full_gradient = tf.nest.map_structure(
                lambda g: tf.math.divide_no_nan(
                    g, tf.cast(num_examples, g.dtype)), full_gradient)

            # Resets the local model variables, including metrics states, as we are
            # not interested in metrics based on the full gradient evaluation, only
            # from the subsequent training.
            model.reset_metrics()

            def train_reduce_fn(state, batch):
                """Takes one local step on `batch`; `state` is passed through."""
                with tf.GradientTape() as tape:
                    output = model.forward_pass(batch, training=True)
                gradients = tape.gradient(output.loss, model_weights.trainable)
                # Mime Lite keeps optimizer state unchanged during local training.
                _, updated_weights = optimizer.next(global_optimizer_state,
                                                    model_weights.trainable,
                                                    gradients)
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      model_weights.trainable, updated_weights)
                return state

            # Performs local training, updating `tf.Variable`s in `model_weights`.
            # The reduce state is a dummy empty tensor; only the side effects on
            # the model's variables matter.
            dataset_reduce_fn(train_reduce_fn,
                              data,
                              initial_state_fn=lambda: tf.zeros(shape=[0]))

            # Delta convention used here: initial minus trained weights.
            client_weights_delta = tf.nest.map_structure(
                tf.subtract, initial_weights.trainable,
                model_weights.trainable)
            model_output = model.report_local_unfinalized_metrics()

            # TODO(b/122071074): Consider moving this functionality into aggregation.
            client_weights_delta, has_non_finite_delta = (
                tensor_utils.zero_all_if_any_non_finite(client_weights_delta))
            # `num_examples` is the count from the full-gradient pass.
            client_weight = _choose_client_weight(client_weighting,
                                                  has_non_finite_delta,
                                                  num_examples)
            return client_works.ClientResult(
                update=client_weights_delta,
                update_weight=client_weight), model_output, full_gradient