Code Example #1
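An early graph-mode FedAvg client update from TensorFlow Federated: the broadcast `initial_weights` are assigned into the model, `model.train_on_batch` is run over the client dataset via `dataset.reduce`, and the resulting weight deltas (zeroed if any value is non-finite) are returned together with the local example count. The control dependencies and dummy values are temporary workarounds tracked by the inline TODOs.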
    def __call__(self, dataset, initial_weights):
        # TODO(b/123898430): The control dependencies below have been inserted as a
        # temporary workaround. These control dependencies need to be removed, and
        # defuns and datasets supported together fully.
        model = self._model

        # TODO(b/113112108): Remove this temporary workaround and restore check for
        # `tf.data.Dataset` after subclassing the currently used custom data set
        # representation from it.
        if 'Dataset' not in str(type(dataset)):
            raise TypeError('Expected a data set, found {}.'.format(
                py_typecheck.type_string(type(dataset))))

        # TODO(b/120801384): We should initialize model.local_variables here.
        # Or, we may just need a convention that TFF initializes all variables
        # before invoking the TF function.

        # We must assign to a variable here in order to use control_dependencies.
        dummy_weights = nest.map_structure(tf.assign, model.weights,
                                           initial_weights)

        with tf.control_dependencies(list(dummy_weights.trainable.values())):

            def reduce_fn(dummy_state, batch):
                """Runs `tff.learning.Model.train_on_batch` on local client batch."""
                output = model.train_on_batch(batch)
                tf.assign_add(self._num_examples,
                              tf.shape(output.predictions)[0])
                return dummy_state

            # TODO(b/124477598): Remove dummy_output when b/121400757 fixed.
            dummy_output = dataset.reduce(initial_state=tf.constant(0.0),
                                          reduce_func=reduce_fn)

        with tf.control_dependencies([dummy_output]):
            weights_delta = nest.map_structure(tf.subtract,
                                               model.weights.trainable,
                                               initial_weights.trainable)
            aggregated_outputs = model.report_local_outputs()
            weights_delta_weight = self._client_weight_fn(aggregated_outputs)  # pylint:disable=not-callable

            # TODO(b/122071074): Consider moving this functionality into
            # tff.federated_average?
            weights_delta, has_non_finite_delta = (
                tensor_utils.zero_all_if_any_non_finite(weights_delta))
            weights_delta_weight = tf.cond(tf.equal(has_non_finite_delta, 0),
                                           lambda: weights_delta_weight,
                                           lambda: tf.constant(0.0))

            return optimizer_utils.ClientOutput(
                weights_delta, weights_delta_weight, aggregated_outputs,
                tensor_utils.to_odict({
                    'num_examples':
                    self._num_examples.value(),
                    'has_non_finite_delta':
                    has_non_finite_delta,
                    'workaround for b/121400757':
                    dummy_output,
                }))
Code Example #2
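A stub `ClientDeltaFn` apparently used for testing: it ignores the dataset and reports an all-ones delta for every trainable parameter with a constant client weight of 1.0.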
    def __call__(self, dataset, initial_weights):
        """Dummy client delta which simply returns 1.0 for all parameters."""
        client_weight = tf.constant(1.0)
        return optimizer_utils.ClientOutput(
            tf.nest.map_structure(tf.ones_like, initial_weights.trainable),
            weights_delta_weight=client_weight,
            model_output=self._model.report_local_outputs(),
            optimizer_output={'client_weight': client_weight})
Code Example #3
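A graph-mode federated SGD client: batch-weighted gradients are accumulated over the dataset inside `dataset.reduce`, and the delta is the negative weighted-average gradient, zeroed (along with the client weight) if any value is non-finite.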
    def __call__(self, dataset, initial_weights):
        # TODO(b/113112108): Remove this temporary workaround and restore check for
        # `tf.data.Dataset` after subclassing the currently used custom data set
        # representation from it.
        if 'Dataset' not in str(type(dataset)):
            raise TypeError('Expected a data set, found {}.'.format(
                py_typecheck.type_string(type(dataset))))

        model = self._model
        dummy_weights = nest.map_structure(tf.assign, model.weights,
                                           initial_weights)

        def reduce_fn(accumulated_grads, batch):
            """Runs forward_pass on batch."""
            with tf.contrib.eager.GradientTape() as tape:
                output = model.forward_pass(batch)

            with tf.control_dependencies(list(output)):
                flat_vars = nest.flatten(model.weights.trainable)
                grads = nest.pack_sequence_as(
                    accumulated_grads, tape.gradient(output.loss, flat_vars))

                if self._batch_weight_fn is not None:
                    batch_weight = self._batch_weight_fn(batch)
                else:
                    batch_weight = tf.cast(
                        tf.shape(output.predictions)[0], tf.float32)

            tf.assign_add(self._batch_weight_sum, batch_weight)
            return nest.map_structure(
                lambda accumulator, grad: accumulator + batch_weight * grad,
                accumulated_grads, grads)

        with tf.control_dependencies(list(dummy_weights.trainable.values())):
            self._grad_sum_vars = dataset.reduce(
                initial_state=self._grad_sum_vars, reduce_func=reduce_fn)

        with tf.control_dependencies(
            [tf.identity(v) for v in self._grad_sum_vars.values()]):
            # For SGD, the delta is just the negative of the average gradient:
            weights_delta = nest.map_structure(
                lambda gradient: -1.0 * gradient / self._batch_weight_sum,
                self._grad_sum_vars)
            weights_delta, has_non_finite_delta = (
                tensor_utils.zero_all_if_any_non_finite(weights_delta))
            weights_delta_weight = tf.cond(tf.equal(has_non_finite_delta, 0),
                                           lambda: self._batch_weight_sum,
                                           lambda: tf.constant(0.0))

            return optimizer_utils.ClientOutput(
                weights_delta, weights_delta_weight,
                model.report_local_outputs(),
                tensor_utils.to_odict({
                    'client_weight':
                    weights_delta_weight,
                    'has_non_finite_delta':
                    has_non_finite_delta,
                }))
Code Example #4
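Another test stub: the delta is simply the negation of the initial trainable weights, reported with a constant client weight of 3.0.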
    def __call__(self, dataset, initial_weights):
        trainable_weights_delta = nest.map_structure(lambda x: -1.0 * x,
                                                     initial_weights.trainable)
        client_weight = tf.constant(3.0)
        return optimizer_utils.ClientOutput(
            trainable_weights_delta,
            weights_delta_weight=client_weight,
            model_output=self._model.report_local_outputs(),
            optimizer_output={'client_weight': client_weight})
Code Example #5
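A federated SGD client written for eager execution: gradient sums are accumulated into `self._grad_sum_vars` by side effect inside a `tf.contrib.eager.function`, and the delta is again the negative average gradient. The dummy reduce output exists only to force execution, per the referenced bug.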
    def __call__(self, dataset, initial_weights):
        # N.B. When not in eager mode, this code must be wrapped as a defun
        # as it uses program-order semantics to avoid adding many explicit
        # control dependencies.
        model = self._model
        py_typecheck.check_type(dataset, tf.data.Dataset)

        nest.map_structure(tf.assign, model.weights, initial_weights)

        @tf.contrib.eager.function(autograph=False)
        def reduce_fn(dummy_state, batch):
            """Runs forward_pass on batch."""
            with tf.contrib.eager.GradientTape() as tape:
                output = model.forward_pass(batch)

            flat_vars = nest.flatten(model.weights.trainable)
            grads = nest.pack_sequence_as(
                self._grad_sum_vars, tape.gradient(output.loss, flat_vars))

            if self._batch_weight_fn is not None:
                batch_weight = self._batch_weight_fn(batch)
            else:
                batch_weight = tf.cast(
                    tf.shape(output.predictions)[0], tf.float32)

            tf.assign_add(self._batch_weight_sum, batch_weight)
            nest.map_structure(
                lambda v, g:  # pylint:disable=g-long-lambda
                tf.assign_add(v, batch_weight * g),
                self._grad_sum_vars,
                grads)

            return dummy_state

        # TODO(b/121400757): Remove dummy_output when bug fixed.
        dummy_output = dataset.reduce(initial_state=tf.constant(0.0),
                                      reduce_func=reduce_fn)

        # For SGD, the delta is just the negative of the average gradient:
        # TODO(b/109733734): Might be better to send the weighted grad sums
        # and the denominator separately?
        weights_delta = nest.map_structure(
            lambda g: -1.0 * g / self._batch_weight_sum, self._grad_sum_vars)
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        weights_delta_weight = tf.cond(tf.equal(has_non_finite_delta, 0),
                                       lambda: self._batch_weight_sum,
                                       lambda: tf.constant(0.0))

        return optimizer_utils.ClientOutput(
            weights_delta, weights_delta_weight, model.report_local_outputs(),
            tensor_utils.to_odict({
                'client_weight': weights_delta_weight,
                'has_non_finite_delta': has_non_finite_delta,
                'workaround for b/121400757': dummy_output,
            }))
Code Example #6
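A TF2-era FedAvg client: it trains with `tf.GradientTape` and a Keras-style optimizer inside the reduce function, counts examples in `tf.int64`, and chooses the client weight via a `ClientWeighting` policy (number of examples, uniform, or a custom callable on the model output).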
    def __call__(self, dataset, initial_weights):
        model = self._model
        optimizer = self._optimizer
        tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                              initial_weights)

        def reduce_fn(num_examples_sum, batch):
            """Train `tff.learning.Model` on local client batch."""
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model.weights.trainable)
            optimizer.apply_gradients(zip(gradients, model.weights.trainable))

            if output.num_examples is None:
                return num_examples_sum + tf.shape(output.predictions,
                                                   out_type=tf.int64)[0]
            else:
                return num_examples_sum + tf.cast(output.num_examples,
                                                  tf.int64)

        num_examples_sum = self._dataset_reduce_fn(
            reduce_fn,
            dataset,
            initial_state_fn=lambda: tf.zeros(shape=[], dtype=tf.int64))

        weights_delta = tf.nest.map_structure(tf.subtract,
                                              model.weights.trainable,
                                              initial_weights.trainable)
        model_output = model.report_local_outputs()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        # Zero out the weight if there are any non-finite values.
        if has_non_finite_delta > 0:
            # TODO(b/176171842): Zeroing has no effect with unweighted aggregation.
            weights_delta_weight = tf.constant(0.0)
        elif self._client_weighting is ClientWeighting.NUM_EXAMPLES:
            weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
        elif self._client_weighting is ClientWeighting.UNIFORM:
            weights_delta_weight = tf.constant(1.0)
        else:
            weights_delta_weight = self._client_weighting(model_output)
        # TODO(b/176245976): TFF `ClientOutput` structure names are confusing.
        optimizer_output = collections.OrderedDict(
            num_examples=num_examples_sum)
        return optimizer_utils.ClientOutput(weights_delta,
                                            weights_delta_weight, model_output,
                                            optimizer_output)
Code Example #7
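A transitional TF2 FedAvg client that still trains through `model.train_on_batch` inside a `tf.function` reduce; the client weight defaults to the example count unless a `client_weight_fn` is supplied, and is zeroed when the delta contains non-finite values.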
    def __call__(self, dataset, initial_weights):
        # TODO(b/113112108): Remove this temporary workaround and restore check for
        # `tf.data.Dataset` after subclassing the currently used custom data set
        # representation from it.
        if 'Dataset' not in str(type(dataset)):
            raise TypeError('Expected a data set, found {}.'.format(
                py_typecheck.type_string(type(dataset))))

        model = self._model
        tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                              initial_weights)

        @tf.function
        def reduce_fn(num_examples_sum, batch):
            """Runs `tff.learning.Model.train_on_batch` on local client batch."""
            output = model.train_on_batch(batch)
            if output.num_examples is None:
                return num_examples_sum + tf.shape(output.predictions)[0]
            else:
                return num_examples_sum + output.num_examples

        num_examples_sum = dataset.reduce(initial_state=tf.constant(0),
                                          reduce_func=reduce_fn)

        weights_delta = tf.nest.map_structure(tf.subtract,
                                              model.weights.trainable,
                                              initial_weights.trainable)
        aggregated_outputs = model.report_local_outputs()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        if self._client_weight_fn is None:
            weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
        else:
            weights_delta_weight = self._client_weight_fn(aggregated_outputs)
        # Zero out the weight if there are any non-finite values.
        if has_non_finite_delta > 0:
            weights_delta_weight = tf.constant(0.0)

        return optimizer_utils.ClientOutput(
            weights_delta, weights_delta_weight, aggregated_outputs,
            tensor_utils.to_odict({
                'num_examples': num_examples_sum,
                'has_non_finite_delta': has_non_finite_delta,
            }))
Code Example #8
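Similar to the previous example, but training is expressed explicitly with `tf.GradientTape` and `optimizer.apply_gradients` instead of `train_on_batch`.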
    def __call__(self, dataset, initial_weights):
        model = self._model
        optimizer = self._optimizer
        tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                              initial_weights)

        @tf.function
        def reduce_fn(num_examples_sum, batch):
            """Runs `tff.learning.Model.train_on_batch` on local client batch."""
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model.weights.trainable)
            optimizer.apply_gradients(zip(gradients, model.weights.trainable))

            if output.num_examples is None:
                return num_examples_sum + tf.shape(output.predictions)[0]
            else:
                return num_examples_sum + output.num_examples

        num_examples_sum = dataset.reduce(initial_state=tf.constant(0),
                                          reduce_func=reduce_fn)

        weights_delta = tf.nest.map_structure(tf.subtract,
                                              model.weights.trainable,
                                              initial_weights.trainable)
        aggregated_outputs = model.report_local_outputs()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        # Zero out the weight if there are any non-finite values.
        if has_non_finite_delta > 0:
            weights_delta_weight = tf.constant(0.0)
        elif self._client_weight_fn is None:
            weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
        else:
            weights_delta_weight = self._client_weight_fn(aggregated_outputs)

        return optimizer_utils.ClientOutput(
            weights_delta, weights_delta_weight, aggregated_outputs,
            collections.OrderedDict(
                num_examples=num_examples_sum,
                has_non_finite_delta=has_non_finite_delta,
            ))
Code Example #9
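A test stub that does run `train_on_batch` over the dataset (so local metrics are updated) but then sends back a fake delta of all negative ones with client weight 1.0.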
    def __call__(self, dataset, initial_weights):
        # Iterate over the dataset to get new metric values.
        def reduce_fn(dummy, batch):
            self._model.train_on_batch(batch)
            return dummy

        dataset.reduce(tf.constant(0.0), reduce_fn)

        # Create some fake weight deltas to send back.
        trainable_weights_delta = tf.nest.map_structure(
            lambda x: -tf.ones_like(x), initial_weights.trainable)
        client_weight = tf.constant(1.0)
        return optimizer_utils.ClientOutput(
            trainable_weights_delta,
            weights_delta_weight=client_weight,
            model_output=self._model.report_local_outputs(),
            optimizer_output={'client_weight': client_weight})
Code Example #10
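A custom client computation for a linear model: the dataset is reduced twice, once to count examples and once to sum the example vectors, and the mean example vector is returned as the "delta", weighted by the example count. `DIM` is assumed to be defined in the enclosing module.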
    def __call__(self, dataset, initial_weights):
        del initial_weights
        model = self._model

        @tf.function
        def reduce_fn_num_examples(num_examples_sum, batch):
            """Count number of examples."""
            num_examples_in_batch = tf.shape(batch['x'])[0]
            return num_examples_sum + num_examples_in_batch

        @tf.function
        def reduce_fn_dataset_mean(sum_vector, batch):
            """Sum all the examples in the local dataset."""
            sum_batch = tf.reshape(tf.reduce_sum(batch['x'], [0]), (-1, 1))
            return sum_vector + sum_batch

        num_examples_sum = dataset.reduce(initial_state=tf.constant(0),
                                          reduce_func=reduce_fn_num_examples)

        example_vector_sum = dataset.reduce(initial_state=tf.zeros((DIM, 1)),
                                            reduce_func=reduce_fn_dataset_mean)

        # Create an ordered dictionary with the same type as model.trainable,
        # containing the mean of all the examples in the local dataset.
        # Note: this works for a linear model only (as in the example above).
        key = list(model.weights.trainable.keys())[0]
        weights_delta = collections.OrderedDict(
            {key: example_vector_sum / tf.cast(num_examples_sum, tf.float32)})

        aggregated_outputs = model.report_local_outputs()
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))

        weights_delta_weight = tf.cast(num_examples_sum, tf.float32)

        return optimizer_utils.ClientOutput(
            weights_delta, weights_delta_weight, aggregated_outputs,
            collections.OrderedDict([
                ('num_examples', num_examples_sum),
                ('has_non_finite_delta', has_non_finite_delta),
            ]))
Code Example #11
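A test stub that performs real gradient steps with the configured optimizer but, like Code Example #9, returns a fake all-negative-ones delta.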
  def __call__(self, dataset, initial_weights):
    # Iterate over the dataset to get new metric values.
    def reduce_fn(dummy, batch):
      with tf.GradientTape() as tape:
        output = self._model.forward_pass(batch)
      gradients = tape.gradient(output.loss, self._model.trainable_variables)
      self._optimizer.apply_gradients(
          zip(gradients, self._model.trainable_variables))
      return dummy

    dataset.reduce(tf.constant(0.0), reduce_fn)

    # Create some fake weight deltas to send back.
    trainable_weights_delta = tf.nest.map_structure(lambda x: -tf.ones_like(x),
                                                    initial_weights.trainable)
    client_weight = tf.constant(1.0)
    return optimizer_utils.ClientOutput(
        trainable_weights_delta,
        weights_delta_weight=client_weight,
        model_output=self._model.report_local_outputs(),
        optimizer_output=collections.OrderedDict([('client_weight',
                                                   client_weight)]))
Code Example #12
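A later revision of the same test stub: `dummy` has been renamed `whimsy`, metrics are reported through `report_local_unfinalized_metrics`, and the optimizer output is an empty tuple because `None` cannot be serialized into the graph.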
    def __call__(self, dataset, initial_weights):
        # Iterate over the dataset to get new metric values.
        def reduce_fn(whimsy, batch):
            with tf.GradientTape() as tape:
                output = self._model.forward_pass(batch)
            gradients = tape.gradient(output.loss,
                                      self._model.trainable_variables)
            self._optimizer.apply_gradients(
                zip(gradients, self._model.trainable_variables))
            return whimsy

        dataset.reduce(tf.constant(0.0), reduce_fn)

        # Create some fake weight deltas to send back.
        trainable_weights_delta = tf.nest.map_structure(
            lambda x: -tf.ones_like(x), initial_weights.trainable)
        client_weight = tf.constant(1.0)
        model_output = self._model.report_local_unfinalized_metrics()
        return optimizer_utils.ClientOutput(
            weights_delta=trainable_weights_delta,
            weights_delta_weight=client_weight,
            model_output=model_output,
            # We avoid using `None` as it is unsupported in graph serialization.
            optimizer_output=())
Code Example #13
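A TF2 federated SGD client: the gradient sums and the batch-weight sum are threaded through the reduce as a flat tuple of tensors rather than variables, then packed back into the trainable-weights structure to form the negative-average-gradient delta.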
    def __call__(self, dataset, initial_weights):
        model = self._model

        # TODO(b/113112108): Remove this temporary workaround and restore check for
        # `tf.data.Dataset` after subclassing the currently used custom data set
        # representation from it.
        if 'Dataset' not in str(type(dataset)):
            raise TypeError('Expected a data set, found {}.'.format(
                py_typecheck.type_string(type(dataset))))

        tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                              initial_weights)
        flat_trainable_weights = tuple(tf.nest.flatten(
            model.weights.trainable))

        @tf.function
        def reduce_fn(state, batch):
            """Runs forward_pass on batch and sums the weighted gradients."""
            flat_accumulated_grads, batch_weight_sum = state

            with tf.GradientTape() as tape:
                output = model.forward_pass(batch)
            flat_grads = tape.gradient(output.loss, flat_trainable_weights)

            if self._batch_weight_fn is not None:
                batch_weight = self._batch_weight_fn(batch)
            else:
                batch_weight = tf.cast(
                    tf.shape(output.predictions)[0], tf.float32)

            flat_accumulated_grads = tuple(
                accumulator + batch_weight * grad for accumulator, grad in zip(
                    flat_accumulated_grads, flat_grads))

            # The TF team is aware of an optimization in the reduce state that
            # would avoid doubling the number of required variables here (i.e.,
            # keeping two copies of all gradients). If you are looking to
            # optimize memory usage, this might be a place to look.
            return (flat_accumulated_grads, batch_weight_sum + batch_weight)

        def _zero_initial_state():
            """Create a tuple of (tuple of gradient accumulators, batch weight sum)."""
            return (tuple(tf.zeros_like(w)
                          for w in flat_trainable_weights), tf.constant(0.0))

        flat_grad_sums, batch_weight_sum = self._dataset_reduce_fn(
            reduce_fn=reduce_fn,
            dataset=dataset,
            initial_state_fn=_zero_initial_state)
        grad_sums = tf.nest.pack_sequence_as(model.weights.trainable,
                                             flat_grad_sums)

        # For SGD, the delta is just the negative of the average gradient:
        weights_delta = tf.nest.map_structure(
            lambda gradient: -1.0 * gradient / batch_weight_sum, grad_sums)
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        if has_non_finite_delta > 0:
            weights_delta_weight = tf.constant(0.0)
        else:
            weights_delta_weight = batch_weight_sum
        return optimizer_utils.ClientOutput(
            weights_delta, weights_delta_weight, model.report_local_outputs(),
            tensor_utils.to_odict({
                'client_weight': weights_delta_weight,
                'has_non_finite_delta': has_non_finite_delta,
            }))
Code Example #14
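The most recent FedAvg variant shown here: it obtains weights via `model_utils.ModelWeights.from_model` and threads the state of a TFF optimizer through the reduce with `optimizer.next`; for Keras optimizers, which mutate the variables in place, the explicit weight assignment is skipped.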
    def __call__(self, dataset, initial_weights):
        model = self._model
        optimizer = self._optimizer
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                              initial_weights)

        def reduce_fn(state, batch):
            """Train `tff.learning.Model` on local client batch."""
            num_examples_sum, optimizer_state = state

            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model_weights.trainable)
            optimizer_state, updated_weights = optimizer.next(
                optimizer_state, model_weights.trainable, gradients)
            if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
                # A Keras optimizer mutates the model variables within `next`,
                # so this explicit assignment is only needed for other
                # optimizers.
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      model_weights.trainable, updated_weights)

            if output.num_examples is None:
                num_examples_sum += tf.shape(output.predictions,
                                             out_type=tf.int64)[0]
            else:
                num_examples_sum += tf.cast(output.num_examples, tf.int64)

            return num_examples_sum, optimizer_state

        def initial_state_for_reduce_fn():
            trainable_tensor_specs = tf.nest.map_structure(
                lambda v: tf.TensorSpec(v.shape, v.dtype),
                model_weights.trainable)
            return tf.zeros(
                shape=[],
                dtype=tf.int64), optimizer.initialize(trainable_tensor_specs)

        num_examples_sum, _ = self._dataset_reduce_fn(
            reduce_fn, dataset, initial_state_fn=initial_state_for_reduce_fn)

        weights_delta = tf.nest.map_structure(tf.subtract,
                                              model_weights.trainable,
                                              initial_weights.trainable)
        model_output = model.report_local_outputs()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        # Zero out the weight if there are any non-finite values.
        if has_non_finite_delta > 0:
            # TODO(b/176171842): Zeroing has no effect with unweighted aggregation.
            weights_delta_weight = tf.constant(0.0)
        elif self._client_weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
            weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
        elif self._client_weighting is client_weight_lib.ClientWeighting.UNIFORM:
            weights_delta_weight = tf.constant(1.0)
        else:
            weights_delta_weight = self._client_weighting(model_output)
        # TODO(b/176245976): TFF `ClientOutput` structure names are confusing.
        optimizer_output = collections.OrderedDict(
            num_examples=num_examples_sum)
        return optimizer_utils.ClientOutput(weights_delta,
                                            weights_delta_weight, model_output,
                                            optimizer_output)
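
All fourteen variants follow the same contract: copy the broadcast weights into the local model, make one pass over the client dataset, and package the result for aggregation. Below is a minimal sketch of that shared pattern in TF2 style, assuming a TFF-style `model` exposing `weights.trainable` and `forward_pass`; the helper name `client_update` and the bare tuple return are illustrative, not part of any TFF release.

import tensorflow as tf

def client_update(model, optimizer, dataset, initial_weights):
    """Illustrative sketch of the shared client-update pattern."""
    # 1. Copy the broadcast weights into the local model variables.
    tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                          initial_weights)

    # 2. One training pass over the local dataset, tracking the example count.
    def reduce_fn(num_examples, batch):
        with tf.GradientTape() as tape:
            output = model.forward_pass(batch, training=True)
        grads = tape.gradient(output.loss, model.weights.trainable)
        optimizer.apply_gradients(zip(grads, model.weights.trainable))
        return num_examples + tf.shape(output.predictions)[0]

    num_examples = dataset.reduce(tf.constant(0), reduce_fn)

    # 3. The "delta" is the trained weights minus the broadcast weights.
    weights_delta = tf.nest.map_structure(tf.subtract,
                                          model.weights.trainable,
                                          initial_weights.trainable)

    # 4. Weight the client's update by its local example count.
    return weights_delta, tf.cast(num_examples, tf.float32)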