def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, client_model, client_round_num))

    client_weight = client_outputs.client_weight
    model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=client_weight)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, model_delta))

    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    if aggregated_outputs.type_signature.is_struct():
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs
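The `run_one_round` bodies in this listing are typically wrapped in a `tff.federated_computation` whose decorator and federated parameter types are defined outside the excerpt. Below is a minimal sketch of that wrapping, using hypothetical scalar member types in place of the real `ServerState` and model structures; it is an illustration under those assumptions, not the original implementation.

import tensorflow as tf
import tensorflow_federated as tff

# Hypothetical member types for illustration only: a float "model" on the
# server and client datasets of floats. Real implementations derive these
# types from the model and server-state structures.
federated_server_state_type = tff.FederatedType(tf.float32, tff.SERVER)
federated_dataset_type = tff.FederatedType(
    tff.SequenceType(tf.float32), tff.CLIENTS)

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round_sketch(server_state, federated_dataset):
  # The orchestration logic above (broadcast, client update, aggregation,
  # server update) would go here; returning the unchanged state keeps this
  # stub traceable.
  del federated_dataset  # Unused in this stub.
  return server_state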
Example #2
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
        server_message = tff.federated_map(server_message_fn, server_state)
        server_message_at_client = tff.federated_broadcast(server_message)

        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, server_message_at_client))

        # Model deltas are equally weighted in DP.
        round_model_delta = tff.federated_mean(client_outputs.weights_delta)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, round_model_delta))
        round_loss_metric = tff.federated_mean(client_outputs.model_output)

        return server_state, round_loss_metric
Example #3
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState` containing the state of the training process
        up to the current round.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS` containing data to train the current round on.

    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
        server_message = tff.federated_map(server_message_fn, server_state)
        server_message_at_client = tff.federated_broadcast(server_message)

        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, server_message_at_client))

        weight_denom = client_outputs.client_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, round_model_delta))
        aggregated_outputs = metrics_aggregation_computation(
            client_outputs.model_output)
        return server_state, aggregated_outputs
Example #4
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        client_model = tff.federated_broadcast(server_state.model)

        client_outputs = tff.federated_map(client_update_fn,
                                           (federated_dataset, client_model))

        weight_denom = client_outputs.weights_delta_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)

        round_grads_norm = tff.federated_mean(
            client_outputs.optimizer_output.flat_grads_norm_sum,
            weight=weight_denom)

        server_state = tff.federated_map(
            server_update_fn,
            (server_state, round_model_delta, round_grads_norm))

        aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)
        if isinstance(aggregated_outputs.type_signature, tff.StructType):
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
Example #5
    def run_one_round(server_state, federated_dataset, client_states):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `stateful_fedavg_tf.ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
      client_states: A federated `stateful_fedavg_tf.ClientState`.

    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
        server_message = tff.federated_map(server_message_fn, server_state)
        server_message_at_client = tff.federated_broadcast(server_message)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_states, server_message_at_client))

        weight_denom = client_outputs.client_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)
        total_iters_count = tff.federated_sum(
            client_outputs.client_state.iters_count)
        server_state = tff.federated_map(
            server_update_fn,
            (server_state, round_model_delta, total_iters_count))
        round_loss_metric = tff.federated_mean(client_outputs.model_output,
                                               weight=weight_denom)

        return server_state, round_loss_metric, client_outputs.client_state
Example #6
def train(
    server_state: int, client_data: tf.data.Dataset
) -> Tuple[int, collections.OrderedDict[str, Any]]:
    """Computes the sum of all the integers on the clients.

  Computes the sum of all the integers on the clients, updates the server state,
  and returns the updated server state and the following metrics:

  * `sum_client_data.METRICS_TOTAL_SUM`: The sum of all the client_data on the
    clients.

  Args:
    server_state: The server state.
    client_data: The data on the clients.

  Returns:
    A tuple of the updated server state and the train metrics.
  """
    client_sums = tff.federated_map(_sum_dataset, client_data)
    total_sum = tff.federated_sum(client_sums)
    updated_state = tff.federated_map(_sum_integers, (server_state, total_sum))
    metrics = collections.OrderedDict([
        (METRICS_TOTAL_SUM, total_sum),
    ])
    return updated_state, metrics
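The `train` computation above relies on helpers defined elsewhere in the `sum_client_data` module. A minimal sketch of what they could look like, given how they are called above; the concrete value of `METRICS_TOTAL_SUM` and the int32 element type are assumptions.

import tensorflow as tf
import tensorflow_federated as tff

# Assumed metric key; the real module defines sum_client_data.METRICS_TOTAL_SUM.
METRICS_TOTAL_SUM = 'total_sum'

@tff.tf_computation(tff.SequenceType(tf.int32))
def _sum_dataset(dataset):
  # Reduce a client's dataset of integers to a single sum.
  return dataset.reduce(tf.constant(0, tf.int32), lambda total, x: total + x)

@tff.tf_computation(tf.int32, tf.int32)
def _sum_integers(server_state, round_total):
  # Fold the round total into the running server state.
  return server_state + round_total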
Example #7
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of the updated `ServerState` and an empty server output.
    """
        discovered_prefixes = tff.federated_broadcast(
            server_state.discovered_prefixes)
        round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, discovered_prefixes, round_num))

        accumulated_votes = tff.federated_sum(client_outputs.client_votes)

        accumulated_weights = tff.federated_sum(client_outputs.client_weight)

        server_state = tff.federated_map(
            server_update_fn,
            (server_state, accumulated_votes, accumulated_weights))

        server_output = tff.federated_value([], tff.SERVER)

        return server_state, server_output
Example #8
    def run_one_round(server_state, client_states, federated_dataset,
                      federated_dataset_single):
        from_server = FromServer(w=server_state.w,
                                 meta_w=server_state.meta_w,
                                 round_num=server_state.round_num)
        from_server = tff.federated_broadcast(from_server)
        control_output = tff.federated_map(
            control_computation,
            (federated_dataset_single, client_states, from_server))
        c = tff.federated_broadcast(
            tff.federated_mean(control_output.w_delta,
                               weight=control_output.client_weight))

        @tff.tf_computation(client_output_type.w_delta,
                            client_output_type.w_delta)
        def compute_control_input(c, c_i):
            correction = tf.nest.map_structure(lambda a, b: a - b, c, c_i)
            # If we are using SCAFFOLD, use the correction; otherwise
            # let the correction be zero.
            return tf.cond(
                tf.constant(control, dtype=tf.bool), lambda: correction,
                lambda: tf.nest.map_structure(tf.zeros_like, correction))

        # the collection of gradient corrections
        corrections = tff.federated_map(compute_control_input,
                                        (c, control_output.cs))
        client_outputs = tff.federated_map(
            client_computation,
            (federated_dataset, from_server, client_states, corrections))
        w_delta = tff.federated_mean(client_outputs.w_delta,
                                     weight=client_outputs.client_weight)
        server_state = tff.federated_map(server_computation,
                                         (server_state, w_delta))
        return (server_state, client_outputs.client_state)
Example #9
 def comp(temperatures, threshold):
   return tff.federated_mean(
       tff.federated_map(
           count_over,
           tff.federated_zip(
               [temperatures,
                tff.federated_broadcast(threshold)])),
       tff.federated_map(count_total, temperatures))
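The `comp` above assumes client-side reducers along the lines of the TFF federated-core tutorials. A hedged sketch of what `count_over` and `count_total` could look like, with names and types inferred from how they are called above:

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

@tff.tf_computation(tff.SequenceType(tf.float32), tf.float32)
def count_over(readings, threshold):
  # Number of readings in a client's dataset that exceed the threshold.
  return readings.reduce(
      np.float32(0.0),
      lambda count, x: count + tf.cast(tf.greater(x, threshold), tf.float32))

@tff.tf_computation(tff.SequenceType(tf.float32))
def count_total(readings):
  # Total number of readings in a client's dataset.
  return readings.reduce(np.float32(0.0), lambda count, _: count + 1.0)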
Example #10
 def next_fn(state, value, weight):
     weighted_values = tff.federated_map(_mul, (value, weight))
     summed_value = tff.federated_sum(weighted_values)
     normalized_value = tff.federated_map(_div, (summed_value, state))
     measurements = tff.federated_value((), tff.SERVER)
     return tff.templates.MeasuredProcessOutput(
         state=state,
         result=normalized_value,
         measurements=measurements)
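A `next_fn` of this shape is typically paired with an `initialize_fn` and wrapped in a `tff.templates.MeasuredProcess`. The sketch below is a self-contained illustration using hypothetical scalar types; the real `_mul`/`_div` helpers and state type are defined outside the excerpt.

import tensorflow as tf
import tensorflow_federated as tff

@tff.tf_computation(tf.float32, tf.float32)
def _mul(value, weight):
  return value * weight

@tff.tf_computation(tf.float32, tf.float32)
def _div(summed_value, state):
  return summed_value / state

@tff.federated_computation
def initialize_fn():
  # Hypothetical state: a fixed normalization constant held by the server.
  return tff.federated_value(10.0, tff.SERVER)

@tff.federated_computation(
    tff.FederatedType(tf.float32, tff.SERVER),
    tff.FederatedType(tf.float32, tff.CLIENTS),
    tff.FederatedType(tf.float32, tff.CLIENTS))
def next_fn(state, value, weight):
  # Mirrors the body above, specialized to scalar values.
  weighted_values = tff.federated_map(_mul, (value, weight))
  summed_value = tff.federated_sum(weighted_values)
  normalized_value = tff.federated_map(_div, (summed_value, state))
  measurements = tff.federated_value((), tff.SERVER)
  return tff.templates.MeasuredProcessOutput(
      state=state, result=normalized_value, measurements=measurements)

process = tff.templates.MeasuredProcess(initialize_fn, next_fn)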
Example #11
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)
        client_optimizer_weights = tff.federated_broadcast(
            server_state.optimizer_state)
        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, client_model,
                               client_round_num, client_optimizer_weights))

        client_weight = client_outputs.client_weight
        model_delta = tff.federated_mean(client_outputs.weights_delta,
                                         weight=client_weight)

        implicit_mean_broad = tff.federated_broadcast(
            tff.federated_mean(client_outputs.weights, weight=client_weight))

        @tff.tf_computation(model_weights_type.trainable,
                            model_weights_type.trainable)
        def individual_drift_fn(client_model, implicit_mean):
            drifts = tf.nest.map_structure(lambda a, b: a - b, client_model,
                                           implicit_mean)
            return drifts

        drifts = tff.federated_map(
            individual_drift_fn, (client_outputs.weights, implicit_mean_broad))

        @tff.tf_computation(model_weights_type.trainable)
        def individual_drift_norm_fn(delta_from_mean):
            drift_norm = tf.reduce_sum(
                tf.nest.flatten(
                    tf.nest.map_structure(lambda a: tf.nn.l2_loss(a),
                                          delta_from_mean)))
            return drift_norm

        drift_norms = tff.federated_map(individual_drift_norm_fn, drifts)
        client_drift = tff.federated_mean(drift_norms, weight=client_weight)

        server_state = tff.federated_map(
            server_update_fn, (server_state, model_delta, client_drift))

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)
        return server_state, aggregated_outputs
Example #12
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Note that in addition to updating the server weights according to the client
    model weight deltas, we extract metrics (governed by the `monitor` attribute
    of the `client_lr_callback` and `server_lr_callback` attributes of the
    `server_state`) and use these to update the client learning rate callbacks.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation` before and during local
      client training.
    """
        client_model = tff.federated_broadcast(server_state.model)
        client_lr = tff.federated_broadcast(
            server_state.client_lr_callback.learning_rate)

        if dataset_preprocess_comp is not None:
            federated_dataset = tff.federated_map(dataset_preprocess_comp,
                                                  federated_dataset)
        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, client_model, client_lr))

        client_weight = client_outputs.client_weight
        aggregated_gradients = tff.federated_mean(
            client_outputs.accumulated_gradients, weight=client_weight)

        initial_aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.initial_model_output)
        if isinstance(initial_aggregated_outputs.type_signature,
                      tff.StructType):
            initial_aggregated_outputs = tff.federated_zip(
                initial_aggregated_outputs)

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if isinstance(aggregated_outputs.type_signature, tff.StructType):
            aggregated_outputs = tff.federated_zip(aggregated_outputs)
        client_monitor_value = initial_aggregated_outputs[client_monitor]
        server_monitor_value = initial_aggregated_outputs[server_monitor]

        server_state = tff.federated_map(
            server_update_fn, (server_state, aggregated_gradients,
                               client_monitor_value, server_monitor_value))

        result = collections.OrderedDict(
            before_training=initial_aggregated_outputs,
            during_training=aggregated_outputs)

        return server_state, result
Example #13
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        # Run computation on the clients.
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_model, client_round_num))

        # Aggregate model deltas.
        client_weight = client_outputs.client_weight
        if mask_zeros_in_client_updates:
            model_delta = federated_mean_masked(client_outputs.weights_delta,
                                                client_weight)
        else:
            model_delta = tff.federated_mean(client_outputs.weights_delta,
                                             client_weight)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, model_delta))

        # Aggregate model outputs that contain local metrics and various statistics.
        aggregated_outputs = placeholder_model.federated_output_computation(
            client_outputs.model_output)
        additional_outputs = tff.federated_mean(
            client_outputs.additional_output, weight=client_weight)

        @tff.tf_computation(aggregated_outputs.type_signature.member,
                            additional_outputs.type_signature.member)
        def _update_aggregated_outputs(aggregated_outputs, additional_outputs):
            aggregated_outputs.update(additional_outputs)
            return aggregated_outputs

        aggregated_outputs = tff.federated_map(
            _update_aggregated_outputs,
            (aggregated_outputs, additional_outputs))
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
Example #14
    def run_one_round(server_state, federated_dataset, malicious_dataset,
                      malicious_clients):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
      malicious_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`, consisting of malicious datasets.
      malicious_clients: A federated `tf.bool` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """

        client_model = tff.federated_broadcast(server_state.model)

        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, malicious_dataset,
                               malicious_clients, client_model))

        weight_denom = client_outputs.weights_delta_weight

        # If the aggregation process's `next` function takes three arguments,
        # it is weighted; otherwise it is unweighted. Unfortunately there is no
        # better way to determine this.
        if len(aggregation_process.next.type_signature.parameter) == 3:
            aggregate_output = aggregation_process.next(
                server_state.delta_aggregate_state,
                client_outputs.weights_delta,
                weight=weight_denom)
        else:
            aggregate_output = aggregation_process.next(
                server_state.delta_aggregate_state,
                client_outputs.weights_delta)
        new_delta_aggregate_state = aggregate_output.state
        round_model_delta = aggregate_output.result

        server_state = tff.federated_map(
            server_update_fn,
            (server_state, round_model_delta, new_delta_aggregate_state))

        aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)
        if isinstance(aggregated_outputs.type_signature, tff.StructType):
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
Example #15
    def next_tff(server_state, datasets, client_states):
        message = tff.federated_map(state_to_message_tf, server_state)
        broadcast = tff.federated_broadcast(message)

        outputs = tff.federated_map(update_client_tf,
                                    (datasets, client_states, broadcast))
        weights_delta = tff.federated_mean(outputs.weights_delta,
                                           weight=outputs.client_weight)

        metrics = model.federated_output_computation(outputs.metrics)

        next_state = tff.federated_map(update_server_tf,
                                       (server_state, weights_delta))

        return next_state, metrics, outputs.client_state
Example #16
def next_fn(server_weights, federated_dataset):
    # Broadcast the server model weights to the clients.
    server_weights_at_client = tff.federated_broadcast(server_weights)

    # Each client runs the update computation and updates its parameters.
    client_weights, clients_loss = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client))

    # The server averages the model parameters updated by all clients.
    mean_client_weights = tff.federated_mean(client_weights)

    # The server updates its model.
    server_weights = tff.federated_map(server_update_fn, mean_client_weights)

    return server_weights, client_weights, clients_loss
Example #17
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """The `tff.Computation` to be returned."""
        # TODO(b/131429028): The federated_zip should be automatic.
        from_server = tff.federated_zip(
            gan_training_tf_fns.FromServer(
                generator_weights=server_state.generator_weights,
                discriminator_weights=server_state.discriminator_weights))
        client_input = tff.federated_broadcast(from_server)
        client_outputs = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, client_input))

        if gan.dp_averaging_fn is None:
            # Not using differential privacy.
            new_dp_averaging_state = server_state.dp_averaging_state
            averaged_discriminator_weights_delta = tff.federated_mean(
                client_outputs.discriminator_weights_delta,
                weight=client_outputs.update_weight)
        else:
            # Using differential privacy. Note that the weight argument is set to None
            # here. This is because the DP aggregation code explicitly does not do
            # weighted aggregation. (If weighted aggregation is desired, differential
            # privacy needs to be turned off.)
            new_dp_averaging_state, averaged_discriminator_weights_delta = (
                gan.dp_averaging_fn(server_state.dp_averaging_state,
                                    client_outputs.discriminator_weights_delta,
                                    weight=None))

        # TODO(b/131085687): Perhaps reconsider the choice to also use
        # ClientOutput to hold the aggregated client output.
        aggregated_client_output = gan_training_tf_fns.ClientOutput(
            discriminator_weights_delta=averaged_discriminator_weights_delta,
            # We don't actually need the aggregated update_weight, but
            # this keeps the types of the non-aggregated and aggregated
            # client_output the same, which is convenient. And I can
            # imagine wanting this.
            update_weight=tff.federated_sum(client_outputs.update_weight),
            counters=tff.federated_sum(client_outputs.counters))

        # TODO(b/131839522): This federated_zip shouldn't be needed.
        aggregated_client_output = tff.federated_zip(aggregated_client_output)

        server_state = tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregated_client_output,
             new_dp_averaging_state))
        return server_state
Example #18
def federated_train(model, learning_rate, data, classes):
    return tff.federated_mean(
        tff.federated_map(local_train, [
            tff.federated_broadcast(model),
            tff.federated_broadcast(learning_rate), data,
            tff.federated_broadcast(classes)
        ]))
Example #19
  def next_fn(state, deltas, weights):

    @tff.tf_computation(model_update_type)
    def clip_by_global_norm(update):
      clipped_update, global_norm = tf.clip_by_global_norm(
          tf.nest.flatten(update), tf.constant(clip_norm))
      was_clipped = tf.cond(
          tf.greater(global_norm, tf.constant(clip_norm)),
          lambda: tf.constant(1),
          lambda: tf.constant(0),
      )
      clipped_update = tf.nest.pack_sequence_as(update, clipped_update)
      return clipped_update, global_norm, was_clipped

    clipped_deltas, client_norms, client_was_clipped = tff.federated_map(
        clip_by_global_norm, deltas)

    return collections.OrderedDict(
        state=state,
        result=tff.federated_mean(clipped_deltas, weight=weights),
        measurements=tff.federated_zip(
            NormClippedAggregationMetrics(
                max_global_norm=tff.utils.federated_max(client_norms),
                num_clipped=tff.federated_sum(client_was_clipped),
            )))
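`NormClippedAggregationMetrics` is defined outside the excerpt; one plausible definition (an assumption, not taken from the original source) is a simple named container for the two measurements:

import collections

# Assumed container type; any attrs class or namedtuple with these two
# fields would satisfy the tff.federated_zip call above.
NormClippedAggregationMetrics = collections.namedtuple(
    'NormClippedAggregationMetrics', ['max_global_norm', 'num_clipped'])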
Example #20
def evaluation(
        server_state: int,
        client_data: tf.data.Dataset) -> collections.OrderedDict[str, Any]:
    """Computes the sum of all the integers on the clients.

  Computes the sum of all the integers on the clients and returns the following
  metrics:

  * `sum_client_data.METRICS_TOTAL_SUM`: The sum of all the client_data on the
    clients.

  Args:
    server_state: The server state.
    client_data: The data on the clients.

  Returns:
    The evaluation metrics.
  """
    del server_state  # Unused.
    client_sums = tff.federated_map(_sum_dataset, client_data)
    total_sum = tff.federated_sum(client_sums)
    metrics = collections.OrderedDict([
        (METRICS_TOTAL_SUM, total_sum),
    ])
    return metrics
Example #21
def federated_train(model, learning_rate, data):
    l = tff.federated_map(
        local_train,
        [tff.federated_broadcast(model),
         tff.federated_broadcast(learning_rate),
         data])
    return l
Example #22
def federated_train(model, lr, data):
    # Returns the model after training.
    return tff.federated_mean(
        tff.federated_map(local_train, [
            tff.federated_broadcast(model),
            tff.federated_broadcast(lr), data
        ]))
Example #23
def next_fn(server_weights, federated_dataset):
    print('sw', dir(server_weights))
    # Broadcast the server weights to the clients.
    server_weights_at_client = tff.federated_broadcast(server_weights)
    print(server_weights_at_client)
    # Each client computes their updated weights.
    client_weights = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client))
    print('weights', client_weights)
    # The server averages these updates.
    mean_client_weights = tff.federated_mean(client_weights)

    # The server updates its model.
    server_weights = tff.federated_map(server_update_fn, mean_client_weights)

    return server_weights
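A `next_fn` like this is normally combined with an `initialize_fn` into a `tff.templates.IterativeProcess`. The sketch below illustrates that wiring under the assumption of a scalar model standing in for the real Keras weights; the decoration of `next_fn` itself is only indicated in a comment.

import tensorflow as tf
import tensorflow_federated as tff

# Hypothetical federated types for a scalar model and float client datasets.
model_type = tff.FederatedType(tf.float32, tff.SERVER)
data_type = tff.FederatedType(tff.SequenceType(tf.float32), tff.CLIENTS)

@tff.tf_computation
def server_init():
  # Hypothetical initial model value.
  return tf.constant(0.0)

@tff.federated_computation
def initialize_fn():
  return tff.federated_eval(server_init, tff.SERVER)

# The `next_fn` above, decorated with
# @tff.federated_computation(model_type, data_type), would then be combined as:
# iterative_process = tff.templates.IterativeProcess(initialize_fn, next_fn)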
Example #24
    def validate(weights, datasets, client_states):
        broadcast = tff.federated_broadcast(weights)
        outputs = tff.federated_map(validate_client_tf,
                                    (datasets, client_states, broadcast))
        metrics = model.federated_output_computation(outputs.metrics)

        return metrics
Example #25
  def foo(x):

    @tff.tf_computation(element_type)
    def local_sum(nums):
      return tf.math.reduce_sum(nums)

    return tff.federated_sum(tff.federated_map(local_sum, x))
Example #26
    def aggregate_and_clip(global_state, value, weight=None):
        """Compute the weighted mean of values@CLIENTS and clip it and return an aggregated value @SERVER."""

        round_model_delta = tff.federated_mean(value, weight)
        value_type = value.type_signature.member

        @tff.tf_computation(value_type._asdict())
        @tf.function
        def clip_by_norm(gradient, norm=norm_bound):
            """Clip the gradient by a certain l_2 norm."""

            delta_norm = tf.linalg.global_norm(tf.nest.flatten(gradient))

            if delta_norm < tf.cast(norm, tf.float32):
                return gradient
            else:
                delta_mul_factor = tf.math.divide_no_nan(
                    tf.cast(norm, tf.float32), delta_norm)
                nested_mul_factor = collections.OrderedDict([
                    (key, delta_mul_factor) for key in gradient.keys()
                ])
                return tf.nest.map_structure(tf.multiply, nested_mul_factor,
                                             gradient)

        return global_state, tff.federated_map(clip_by_norm, round_model_delta)
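For reference, the per-key rescaling above is essentially the same operation as clipping with `tf.clip_by_global_norm` (up to the strict-inequality boundary). A hedged sketch of that alternative, assuming `gradient` is an OrderedDict of tensors as in the example:

import tensorflow as tf

@tf.function
def clip_by_norm_alt(gradient, norm_bound):
  # Flatten the structure, clip by the global l2 norm, and pack it back.
  flat = tf.nest.flatten(gradient)
  clipped, _ = tf.clip_by_global_norm(flat, tf.cast(norm_bound, tf.float32))
  return tf.nest.pack_sequence_as(gradient, clipped)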
Example #27
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        client_model = tff.federated_broadcast(server_state.model)

        client_outputs = tff.federated_map(client_update_fn,
                                           (federated_dataset, client_model))

        weight_denom = client_outputs.weights_delta_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)

        server_state = tff.federated_apply(server_update_fn,
                                           (server_state, round_model_delta))

        aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
Example #28
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """The `tff.Computation` to be returned."""
        from_server = gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights)
        client_input = tff.federated_broadcast(from_server)
        client_outputs = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, client_input))

        if gan.dp_averaging_fn is None:
            # Not using differential privacy.
            new_dp_averaging_state = server_state.dp_averaging_state
            averaged_discriminator_weights_delta = tff.federated_mean(
                client_outputs.discriminator_weights_delta,
                weight=client_outputs.update_weight)
        else:
            # Using differential privacy. Note that the weight argument is set
            # to a constant 1.0 here; however, the underlying AggregationProcess
            # ignores the parameter and performs no weighting.
            ignored_weight = tff.federated_value(1.0, tff.CLIENTS)
            aggregation_output = gan.dp_averaging_fn.next(
                server_state.dp_averaging_state,
                client_outputs.discriminator_weights_delta,
                weight=ignored_weight)
            new_dp_averaging_state = aggregation_output.state
            averaged_discriminator_weights_delta = aggregation_output.result

        # TODO(b/131085687): Perhaps reconsider the choice to also use
        # ClientOutput to hold the aggregated client output.
        aggregated_client_output = gan_training_tf_fns.ClientOutput(
            discriminator_weights_delta=averaged_discriminator_weights_delta,
            # We don't actually need the aggregated update_weight, but
            # this keeps the types of the non-aggregated and aggregated
            # client_output the same, which is convenient. And I can
            # imagine wanting this.
            update_weight=tff.federated_sum(client_outputs.update_weight),
            counters=tff.federated_sum(client_outputs.counters))

        server_computation = build_server_computation(
            gan, server_state.type_signature.member, client_output_type)
        server_state = tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregated_client_output,
             new_dp_averaging_state))
        return server_state
Example #29
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_optimizer_state = tff.federated_broadcast(
        server_state.client_optimizer_state)
    client_round_num = tff.federated_broadcast(server_state.round_num)
    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, client_model,
                           client_optimizer_state, client_round_num))

    client_model_weight = client_outputs.client_weight.model_weight
    client_opt_weight = client_outputs.client_weight.optimizer_weight

    model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=client_model_weight)

    # We convert the optimizer state to a float type so that it can be used
    # with things such as `tff.federated_mean`. This is only necessary because
    # `tf.keras.Optimizer` objects have a state with an integer indicating
    # the number of times it has been applied.
    client_optimizer_state_delta = tff.federated_map(
        _convert_opt_state_to_float, client_outputs.optimizer_state_delta)
    client_optimizer_state_delta = optimizer_aggregator(
        client_optimizer_state_delta, weight=client_opt_weight)
    # We convert the optimizer state back into one with an integer round number.
    client_optimizer_state_delta = tff.federated_map(
        _convert_opt_state_to_int, client_optimizer_state_delta)

    server_state = tff.federated_map(
        server_update_fn,
        (server_state, model_delta, client_optimizer_state_delta))

    aggregated_outputs = placeholder_model.federated_output_computation(
        client_outputs.model_output)
    if aggregated_outputs.type_signature.is_struct():
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs
Example #30
 def robust_aggregation_fn(value, weight):
     aggregate = tff.federated_mean(value, weight=weight)
     for _ in range(num_communication_passes - 1):
         aggregate_at_client = tff.federated_broadcast(aggregate)
         updated_weight = tff.federated_map(
             update_weight_fn, (weight, aggregate_at_client, value))
         aggregate = tff.federated_mean(value, weight=updated_weight)
     return aggregate
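The `update_weight_fn` used by this smoothed-Weiszfeld-style loop is defined outside the excerpt. Below is a hedged sketch for the scalar case, following the usual robust-aggregation reweighting (the client's weight divided by its distance from the current aggregate, floored for numerical stability); the epsilon value and scalar types are assumptions.

import tensorflow as tf
import tensorflow_federated as tff

_EPSILON = 1e-6  # Hypothetical smoothing floor.

@tff.tf_computation(tf.float32, tf.float32, tf.float32)
def update_weight_fn(weight, aggregate, value):
  # Rescale a client's weight by the inverse of its distance from the
  # current aggregate, so outliers are down-weighted on the next pass.
  distance = tf.abs(value - aggregate)
  return weight / tf.maximum(distance, _EPSILON)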