Example #1
    def server_update(global_model, mean_model_delta, optimizer_state):
        """Updates the global model with the mean model update from clients."""
        with tf.init_scope():
            # Create a structure of variables that the server optimizer can update.
            model_variables = tf.nest.map_structure(
                lambda t: tf.Variable(initial_value=tf.zeros(t.shape, t.dtype)),
                global_model)
            optimizer = keras_optimizer.build_or_verify_tff_optimizer(
                server_optimizer_fn,
                model_variables.trainable,
                disjoint_init_and_next=True)

        # Set the variables to the current global model, the optimizer will
        # update these variables.
        tf.nest.map_structure(lambda a, b: a.assign(b), model_variables,
                              global_model)
        # The mean update may contain NaNs, e.g. if every client processed had
        # no data, making the denominator of the federated_mean zero. If any
        # value is non-finite, zero out the whole update.
        # TODO(b/124538167): Increment a server counter to track the fact that
        # a non-finite weights_delta was encountered.
        finite_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
            mean_model_delta)
        # Update the global model variables with the delta as a pseudo-gradient.
        negative_weights_delta = tf.nest.map_structure(lambda w: -1.0 * w,
                                                       finite_weights_delta)
        optimizer_state, updated_weights = optimizer.next(
            optimizer_state, model_variables.trainable, negative_weights_delta)
        # Keras optimizers mutate the model variables within the `next` step
        # above, so we skip the assignment for those optimizers.
        if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
            tf.nest.map_structure(lambda a, b: a.assign(b),
                                  model_variables.trainable, updated_weights)
        return model_variables, optimizer_state
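
Note: the server treats the negated mean client delta as a pseudo-gradient, so
any first-order optimizer can serve as the server optimizer. A minimal
standalone sketch (assuming only TensorFlow is installed) of why this works:
with plain SGD at learning rate 1.0 it reduces to vanilla FedAvg.

import tensorflow as tf

w = tf.Variable([1.0, 2.0])                       # current global weights
delta = tf.constant([0.5, -0.5])                  # mean client model delta
opt = tf.keras.optimizers.SGD(learning_rate=1.0)
# Applying -delta as the "gradient" yields w - lr * (-delta) = w + delta.
opt.apply_gradients([(-delta, w)])
print(w.numpy())                                  # [1.5, 1.5]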
Example #2
 def test_build_tff_optimizer_keras(self, specs, disjoint_init_and_next):
     optimizer_fn = lambda: tf.keras.optimizers.SGD(0.1)
     variables = tf.nest.map_structure(
         lambda s: tf.Variable(tf.ones(s.shape, s.dtype)), specs)
     optimizer = keras_optimizer.build_or_verify_tff_optimizer(
         optimizer_fn, variables, disjoint_init_and_next)
     self.assertIsInstance(optimizer, optimizer_base.Optimizer)
Example #3
 def init_fn():
     tensor_specs = type_conversions.type_to_tf_tensor_specs(
         model_weights_type.trainable)
     model_variables = tf.nest.map_structure(
         lambda s: tf.Variable(initial_value=tf.zeros(s.shape, s.dtype)),
         tensor_specs)
     optimizer = keras_optimizer.build_or_verify_tff_optimizer(
         optimizer_fn, model_variables, disjoint_init_and_next=True)
     return optimizer.initialize(tensor_specs)
Example #4
 def server_init_tf():
     """Initialize the TensorFlow-only portions of the server state."""
     model_weights = reconstruction_utils.get_global_variables(model_fn())
     optimizer = keras_optimizer.build_or_verify_tff_optimizer(
         server_optimizer_fn,
         model_weights.trainable,
         disjoint_init_and_next=True)
     trainable_tensor_specs = tf.nest.map_structure(
         lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights.trainable)
     optimizer_state = optimizer.initialize(trainable_tensor_specs)
     return model_weights, optimizer_state
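
With `disjoint_init_and_next=True` the optimizer is used functionally: state
is built once from tensor specs via `initialize` and then threaded through
`next`. A minimal sketch of that contract (assuming tensorflow_federated is
installed; `tff.learning.optimizers.build_sgdm` is the public counterpart of
the internal `sgdm.build_sgdm` used in Example #13):

import tensorflow as tf
import tensorflow_federated as tff

opt = tff.learning.optimizers.build_sgdm(learning_rate=0.1, momentum=0.9)
specs = (tf.TensorSpec([2], tf.float32),)
state = opt.initialize(specs)         # hyperparameters + momentum accumulator
weights = (tf.constant([1.0, 2.0]),)
grads = (tf.constant([0.1, 0.1]),)
state, weights = opt.next(state, weights, grads)  # pure: nothing is mutated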
Example #5
    def server_update(server_state, weights_delta, aggregator_state,
                      broadcaster_state):
        """Updates the `server_state` based on `weights_delta`.

        Args:
          server_state: A `tff.learning.framework.ServerState`, the state to be
            updated.
          weights_delta: The model delta in global trainable variables from
            clients.
          aggregator_state: The state of the aggregator after performing
            aggregation.
          broadcaster_state: The state of the broadcaster after broadcasting.

        Returns:
          The updated `tff.learning.framework.ServerState`.
        """
        with tf.init_scope():
            model = model_fn()
        global_model_weights = reconstruction_utils.get_global_variables(model)
        optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            server_optimizer_fn,
            global_model_weights.trainable,
            disjoint_init_and_next=True)
        optimizer_state = server_state.optimizer_state

        # Initialize the model with the current state.
        tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                              server_state.model)

        weights_delta, has_non_finite_weight = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))

        # We ignore the update if the weights_delta is non-finite.
        if tf.equal(has_non_finite_weight, 0):
            negative_weights_delta = tf.nest.map_structure(
                lambda w: -1.0 * w, weights_delta)
            optimizer_state, updated_weights = optimizer.next(
                optimizer_state, global_model_weights.trainable,
                negative_weights_delta)
            if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
                # Keras optimizer mutates model variables within the `next` step.
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      global_model_weights.trainable,
                                      updated_weights)

        # Create a new state based on the updated model.
        return structure.update_struct(
            server_state,
            model=global_model_weights,
            optimizer_state=optimizer_state,
            model_broadcast_state=broadcaster_state,
            delta_aggregate_state=aggregator_state,
        )
Example #6
 def model_and_optimizer_init_fn(
 ) -> Tuple[model_utils.ModelWeights, List[tf.Variable]]:
     """Returns initial model weights and state of the global optimizer."""
     model_variables = model_utils.ModelWeights.from_model(model_fn())
     optimizer = keras_optimizer.build_or_verify_tff_optimizer(
         server_optimizer_fn,
         model_variables.trainable,
         disjoint_init_and_next=True)
     trainable_tensor_specs = tf.nest.map_structure(
         lambda v: tf.TensorSpec(v.shape, v.dtype),
         model_variables.trainable)
     optimizer_state = optimizer.initialize(trainable_tensor_specs)
     return model_variables, optimizer_state
Example #7
    def server_init() -> Tuple[model_utils.ModelWeights, List[tf.Variable]]:
        """Returns initial `tff.learning.framework.ServerState`.

        Returns:
          A `tuple` of `tff.learning.framework.ModelWeights` and a `list` of
          `tf.Variable`s for the global optimizer state.
        """
        model_variables = model_utils.ModelWeights.from_model(model_fn())
        optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            server_optimizer_fn,
            model_variables.trainable,
            disjoint_init_and_next=True)
        trainable_tensor_specs = tf.nest.map_structure(
            lambda v: tf.TensorSpec(v.shape, v.dtype),
            model_variables.trainable)
        optimizer_state = optimizer.initialize(trainable_tensor_specs)
        return model_variables, optimizer_state
Example #8
    def next_fn(optimizer_state, trainable_weights, update):
        with tf.init_scope():
            # Create a structure of variables that the server optimizer can update.
            trainable_variables = tf.nest.map_structure(
                lambda t: tf.Variable(initial_value=tf.zeros(t.shape, t.dtype)),
                trainable_weights)
            optimizer = keras_optimizer.build_or_verify_tff_optimizer(
                optimizer_fn, trainable_variables, disjoint_init_and_next=True)

        tf.nest.map_structure(lambda a, b: a.assign(b), trainable_variables,
                              trainable_weights)
        optimizer_state, updated_weights = optimizer.next(
            optimizer_state, trainable_variables, update)
        # Keras optimizers mutate the model variables within the `next` step
        # above, so we skip the assignment for those optimizers.
        if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
            tf.nest.map_structure(lambda a, b: a.assign(b),
                                  trainable_variables, updated_weights)
        return optimizer_state, trainable_variables
Example #9
    def __init__(
            self,
            model: model_lib.Model,
            optimizer: Union[optimizer_base.Optimizer,
                             Callable[[], tf.keras.optimizers.Optimizer]],
            client_weighting: client_weight_lib.ClientWeightType = (
                client_weight_lib.ClientWeighting.NUM_EXAMPLES),
            use_experimental_simulation_loop: bool = False):
        """Creates the client computation for Federated Averaging.

        Note: All variable creation required for the client computation (e.g.
        model variable creation) must occur during construction, not during
        `__call__`.

        Args:
          model: A `tff.learning.Model` instance.
          optimizer: An `optimizer_base.Optimizer` instance, or a no-arg
            callable that returns a `tf.keras.optimizers.Optimizer` instance.
          client_weighting: A value of `tff.learning.ClientWeighting` that
            specifies a built-in weighting method, or a callable that takes the
            output of `model.report_local_outputs` and returns a tensor that
            provides the weight in the federated average of model deltas.
          use_experimental_simulation_loop: Controls the reduce loop function
            for the input dataset. An experimental reduce loop is used for
            simulation.
        """
        py_typecheck.check_type(model, model_lib.Model)
        self._model = model
        self._optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            optimizer,
            model_utils.ModelWeights.from_model(self._model).trainable,
            disjoint_init_and_next=False)
        client_weight_lib.check_is_client_weighting_or_callable(
            client_weighting)
        self._client_weighting = client_weighting
        self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
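
Note the pattern across these examples: server-side call sites pass
`disjoint_init_and_next=True` because optimizer state is created in a separate
initialization computation and threaded into later rounds (Examples #3, #4,
#6, #7), while client-side call sites such as this one and Example #14 pass
`disjoint_init_and_next=False`, since `initialize` and `next` run inside the
same client computation.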
Example #10
def _build_one_round_computation(
    *,
    model_fn: _ModelConstructor,
    server_optimizer_fn: _OptimizerConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    broadcast_process: measured_process.MeasuredProcess,
    aggregation_process: measured_process.MeasuredProcess,
) -> computation_base.Computation:
    """Builds the `next` computation for a model delta averaging process.

    Args:
      model_fn: A no-argument callable that constructs and returns a
        `tff.learning.Model`. *Must* construct and return a new model when
        called. Returning captured models from other scopes will raise errors.
      server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a
        no-argument callable that constructs and returns a
        `tf.keras.optimizers.Optimizer`. The callable *must* construct and
        return a new Keras optimizer when called. Returning captured optimizers
        from other scopes will raise errors.
      model_to_client_delta_fn: A callable that takes a single no-arg callable
        that returns a `tff.learning.Model` as an argument and returns a
        `ClientDeltaFn` which performs the local training loop and model delta
        computation.
      broadcast_process: A `tff.templates.MeasuredProcess` to broadcast the
        global model to the clients.
      aggregation_process: A `tff.templates.MeasuredProcess` to aggregate
        client model deltas.

    Returns:
      A `tff.Computation` for one round of the process. The computation takes
      a tuple of `(ServerState@SERVER, tf.data.Dataset@CLIENTS)` as its
      argument, and returns a tuple of `(ServerState@SERVER, metrics@SERVER)`.
    """
    # TODO(b/124477628): It would be nice not to have to construct a throwaway
    # model here just to get the types. After fully moving to TF2.0 and
    # eager-mode, we should re-evaluate what happens here.
    # TODO(b/144382142): Keras name uniquification is probably the main reason we
    # still need this.
    with tf.Graph().as_default():
        whimsy_model_for_metadata = model_fn()
        model_weights = model_utils.ModelWeights.from_model(
            whimsy_model_for_metadata)
        model_weights_type = type_conversions.type_from_tensors(model_weights)

        optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            server_optimizer_fn,
            model_weights.trainable,
            disjoint_init_and_next=True)
        trainable_tensor_specs = tf.nest.map_structure(
            lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights.trainable)
        optimizer_state_type = type_conversions.type_from_tensors(
            optimizer.initialize(trainable_tensor_specs))

    @computations.tf_computation(model_weights_type,
                                 model_weights_type.trainable,
                                 optimizer_state_type)
    @tf.function
    def server_update(global_model, mean_model_delta, optimizer_state):
        """Updates the global model with the mean model update from clients."""
        with tf.init_scope():
            # Create a structure of variables that the server optimizer can update.
            model_variables = tf.nest.map_structure(
                lambda t: tf.Variable(initial_value=tf.zeros(t.shape, t.dtype)),
                global_model)
            optimizer = keras_optimizer.build_or_verify_tff_optimizer(
                server_optimizer_fn,
                model_variables.trainable,
                disjoint_init_and_next=True)

        # Set the variables to the current global model, the optimizer will
        # update these variables.
        tf.nest.map_structure(lambda a, b: a.assign(b), model_variables,
                              global_model)
        # The mean update may contain NaNs, e.g. if every client processed had
        # no data, making the denominator of the federated_mean zero. If any
        # value is non-finite, zero out the whole update.
        # TODO(b/124538167): Increment a server counter to track the fact that
        # a non-finite weights_delta was encountered.
        finite_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
            mean_model_delta)
        # Update the global model variables with the delta as a pseudo-gradient.
        negative_weights_delta = tf.nest.map_structure(lambda w: -1.0 * w,
                                                       finite_weights_delta)
        optimizer_state, updated_weights = optimizer.next(
            optimizer_state, model_variables.trainable, negative_weights_delta)
        # Keras optimizers mutate the model variables within the `next` step
        # above, so we skip the assignment for those optimizers.
        if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
            tf.nest.map_structure(lambda a, b: a.assign(b),
                                  model_variables.trainable, updated_weights)
        return model_variables, optimizer_state

    dataset_type = computation_types.SequenceType(
        whimsy_model_for_metadata.input_spec)

    @computations.tf_computation(dataset_type, model_weights_type)
    @tf.function
    def _compute_local_training_and_client_delta(dataset,
                                                 initial_model_weights):
        """Performs client local model optimization.

        Args:
          dataset: A `tf.data.Dataset` that provides training examples.
          initial_model_weights: A `model_utils.ModelWeights` containing the
            starting weights.

        Returns:
          A `ClientOutput` structure.
        """
        with tf.init_scope():
            client_delta_fn = model_to_client_delta_fn(model_fn)
        client_output = client_delta_fn(dataset, initial_model_weights)
        return client_output

    broadcast_state = broadcast_process.initialize.type_signature.result.member
    aggregation_state = aggregation_process.initialize.type_signature.result.member

    server_state_type = ServerState(model=model_weights_type,
                                    optimizer_state=optimizer_state_type,
                                    delta_aggregate_state=aggregation_state,
                                    model_broadcast_state=broadcast_state)

    @computations.federated_computation(
        computation_types.FederatedType(server_state_type, placements.SERVER),
        computation_types.FederatedType(dataset_type, placements.CLIENTS))
    def one_round_computation(server_state, federated_dataset):
        """Orchestration logic for one round of optimization.

        Args:
          server_state: A `tff.learning.framework.ServerState` named tuple.
          federated_dataset: A federated `tf.data.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `tff.learning.framework.ServerState` and the
          result of `tff.learning.Model.federated_output_computation`, both
          having `tff.SERVER` placement.
        """
        broadcast_output = broadcast_process.next(
            server_state.model_broadcast_state, server_state.model)
        client_outputs = intrinsics.federated_map(
            _compute_local_training_and_client_delta,
            (federated_dataset, broadcast_output.result))
        if aggregation_process.is_weighted:
            aggregation_output = aggregation_process.next(
                server_state.delta_aggregate_state,
                client_outputs.weights_delta,
                client_outputs.weights_delta_weight)
        else:
            aggregation_output = aggregation_process.next(
                server_state.delta_aggregate_state,
                client_outputs.weights_delta)
        new_global_model, new_optimizer_state = intrinsics.federated_map(
            server_update, (server_state.model, aggregation_output.result,
                            server_state.optimizer_state))
        new_server_state = intrinsics.federated_zip(
            ServerState(new_global_model, new_optimizer_state,
                        aggregation_output.state, broadcast_output.state))
        aggregated_outputs = whimsy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)
        optimizer_outputs = intrinsics.federated_sum(
            client_outputs.optimizer_output)
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(
                broadcast=broadcast_output.measurements,
                aggregation=aggregation_output.measurements,
                train=aggregated_outputs,
                stat=optimizer_outputs))
        return new_server_state, measurements

    return one_round_computation
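
For orientation, `one_round_computation` above wires the pieces together in
the usual federated averaging order: broadcast the global model to the
clients, run local training to produce per-client model deltas, aggregate the
deltas (weighted or unweighted), apply the mean delta on the server via
`server_update`, and zip the per-stage measurements into the returned metrics.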
Example #11
 def test_build_tff_optimizer_arg_callable(self, disjoint_init_and_next):
     with self.assertRaises(TypeError):
         keras_optimizer.build_or_verify_tff_optimizer(
             optimizer_fn=lambda x: x,
             trainable_weights=None,
             disjoint_init_and_next=disjoint_init_and_next)
Example #12
 def test_build_tff_optimizer_raise(self, disjoint_init_and_next):
     with self.assertRaisesRegex(TypeError,
                                 '`optimizer_fn` must be a callable or '):
         keras_optimizer.build_or_verify_tff_optimizer(
             None, None, disjoint_init_and_next)
Example #13
 def test_build_tff_optimizer_tff(self):
     optimizer = sgdm.build_sgdm()
     optimizer2 = keras_optimizer.build_or_verify_tff_optimizer(optimizer)
     self.assertIs(optimizer, optimizer2)
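
This test documents the "verify" half of `build_or_verify_tff_optimizer`: when
the argument is already a TFF optimizer (here built by `sgdm.build_sgdm`), the
same instance is returned unchanged rather than being wrapped, as the
`assertIs` check shows.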
Example #14
    def client_update(dataset, initial_model_weights):
        """Performs client local model optimization.

        Args:
          dataset: A `tf.data.Dataset` that provides training examples.
          initial_model_weights: A `tff.learning.ModelWeights` containing the
            starting global trainable and non-trainable weights.

        Returns:
          A `ClientOutput`.
        """
        with tf.init_scope():
            model = model_fn()

            metrics = []
            if metrics_fn is not None:
                metrics.extend(metrics_fn())
            # To be used to calculate example-weighted mean across batches and
            # clients.
            metrics.append(keras_utils.MeanLossMetric(loss_fn()))
            # To be used to calculate batch loss for model updates.
            client_loss = loss_fn()

        global_model_weights = reconstruction_utils.get_global_variables(model)
        local_model_weights = reconstruction_utils.get_local_variables(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                              initial_model_weights)
        client_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            client_optimizer_fn,
            global_model_weights.trainable,
            disjoint_init_and_next=False)
        reconstruction_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            reconstruction_optimizer_fn,
            local_model_weights.trainable,
            disjoint_init_and_next=False)

        @tf.function
        def reconstruction_reduce_fn(state, batch):
            """Runs reconstruction training on local client batch."""
            num_examples_sum, optimizer_state = state
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)
                batch_loss = client_loss(y_true=output.labels,
                                         y_pred=output.predictions)

            gradients = tape.gradient(batch_loss,
                                      local_model_weights.trainable)
            optimizer_state, updated_weights = reconstruction_optimizer.next(
                optimizer_state, local_model_weights.trainable, gradients)
            if not isinstance(reconstruction_optimizer,
                              keras_optimizer.KerasOptimizer):
                # Keras optimizer mutates model variables within the `next` step.
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      local_model_weights.trainable,
                                      updated_weights)

            return num_examples_sum + output.num_examples, optimizer_state

        @tf.function
        def train_reduce_fn(state, batch):
            """Runs one step of client optimizer on local client batch."""
            num_examples_sum, optimizer_state = state
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)
                batch_loss = client_loss(y_true=output.labels,
                                         y_pred=output.predictions)

            gradients = tape.gradient(batch_loss,
                                      global_model_weights.trainable)
            optimizer_state, updated_weights = client_optimizer.next(
                optimizer_state, global_model_weights.trainable, gradients)
            if not isinstance(client_optimizer,
                              keras_optimizer.KerasOptimizer):
                # Keras optimizer mutates model variables within the `next` step.
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      global_model_weights.trainable,
                                      updated_weights)

            # Update each metric.
            for metric in metrics:
                metric.update_state(y_true=output.labels,
                                    y_pred=output.predictions)

            return num_examples_sum + output.num_examples, optimizer_state

        recon_dataset, post_recon_dataset = dataset_split_fn(dataset)

        # If needed, do reconstruction, training the local variables while keeping
        # the global ones frozen.
        if local_model_weights.trainable:
            # Ignore output number of examples used in reconstruction, since this
            # isn't included in `client_weight`.
            def initial_state_reconstruction_reduce():
                trainable_tensor_specs = tf.nest.map_structure(
                    lambda v: tf.TensorSpec(v.shape, v.dtype),
                    local_model_weights.trainable)
                return tf.constant(0), reconstruction_optimizer.initialize(
                    trainable_tensor_specs)

            recon_dataset.reduce(
                initial_state=initial_state_reconstruction_reduce(),
                reduce_func=reconstruction_reduce_fn)

        # Train the global variables, keeping local variables frozen.
        def initial_state_train_reduce():
            trainable_tensor_specs = tf.nest.map_structure(
                lambda v: tf.TensorSpec(v.shape, v.dtype),
                global_model_weights.trainable)
            return tf.constant(0), client_optimizer.initialize(
                trainable_tensor_specs)

        num_examples_sum, _ = post_recon_dataset.reduce(
            initial_state=initial_state_train_reduce(),
            reduce_func=train_reduce_fn)

        weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                              global_model_weights.trainable,
                                              initial_model_weights.trainable)

        # We ignore the update if the weights_delta is non-finite.
        weights_delta, has_non_finite_weight = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))

        model_local_outputs = keras_utils.read_metric_variables(metrics)

        if has_non_finite_weight > 0:
            client_weight = tf.constant(0.0, dtype=tf.float32)
        elif client_weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
            client_weight = tf.cast(num_examples_sum, dtype=tf.float32)
        elif client_weighting is client_weight_lib.ClientWeighting.UNIFORM:
            client_weight = tf.constant(1.0, dtype=tf.float32)
        else:
            client_weight = client_weighting(model_local_outputs)

        return ClientOutput(weights_delta, client_weight, model_local_outputs)
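
`tensor_utils.zero_all_if_any_non_finite` is a TFF-internal helper. A rough
standalone equivalent, consistent with how the callers above use it (zero the
whole structure and return a numeric flag if any entry is NaN or Inf), might
look like this sketch; the real implementation may differ in details:

import tensorflow as tf

def zero_all_if_any_non_finite(structure):
    """Zeroes every tensor in `structure` if any entry is NaN or Inf."""
    flat = tf.nest.flatten(structure)
    any_non_finite = tf.reduce_any(
        tf.stack([tf.reduce_any(~tf.math.is_finite(t)) for t in flat]))
    zeroed = tf.nest.map_structure(
        lambda t: tf.where(any_non_finite, tf.zeros_like(t), t), structure)
    # Callers above treat the flag numerically (e.g. `has_non_finite > 0`).
    return zeroed, tf.cast(any_non_finite, tf.int32)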