Example 1
    def run_one_round(server_state, federated_dataset, client_states):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `stateful_fedavg_tf.ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
      client_states: A federated `stateful_fedavg_tf.ClientState`.

    Returns:
      A tuple of the updated `ServerState`, a `tf.Tensor` of average loss, and
      the updated client states placed at `tff.CLIENTS`.
    """
        server_message = tff.federated_map(server_message_fn, server_state)
        server_message_at_client = tff.federated_broadcast(server_message)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_states, server_message_at_client))

        weight_denom = client_outputs.client_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)
        total_iters_count = tff.federated_sum(
            client_outputs.client_state.iters_count)
        server_state = tff.federated_map(
            server_update_fn,
            (server_state, round_model_delta, total_iters_count))
        round_loss_metric = tff.federated_mean(client_outputs.model_output,
                                               weight=weight_denom)

        return server_state, round_loss_metric, client_outputs.client_state
Example 2
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        client_model = tff.federated_broadcast(server_state.model)

        client_outputs = tff.federated_map(client_update_fn,
                                           (federated_dataset, client_model))

        weight_denom = client_outputs.weights_delta_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)

        round_grads_norm = tff.federated_mean(
            client_outputs.optimizer_output.flat_grads_norm_sum,
            weight=weight_denom)

        server_state = tff.federated_map(
            server_update_fn,
            (server_state, round_model_delta, round_grads_norm))

        aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)
        if isinstance(aggregated_outputs.type_signature, tff.StructType):
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
Example 3
    def run_one_round(server_state, client_states, federated_dataset,
                      federated_dataset_single):
        from_server = FromServer(w=server_state.w,
                                 meta_w=server_state.meta_w,
                                 round_num=server_state.round_num)
        from_server = tff.federated_broadcast(from_server)
        control_output = tff.federated_map(
            control_computation,
            (federated_dataset_single, client_states, from_server))
        c = tff.federated_broadcast(
            tff.federated_mean(control_output.w_delta,
                               weight=control_output.client_weight))

        @tff.tf_computation(client_output_type.w_delta,
                            client_output_type.w_delta)
        def compute_control_input(c, c_i):
            correction = tf.nest.map_structure(lambda a, b: a - b, c, c_i)
            # if we are using SCAFFOLD then use the correction, otherwise
            # we just let the correction be zero.
            return tf.cond(
                tf.constant(control, dtype=tf.bool), lambda: correction,
                lambda: tf.nest.map_structure(tf.zeros_like, correction))

        # the collection of gradient corrections
        corrections = tff.federated_map(compute_control_input,
                                        (c, control_output.cs))
        client_outputs = tff.federated_map(
            client_computation,
            (federated_dataset, from_server, client_states, corrections))
        w_delta = tff.federated_mean(client_outputs.w_delta,
                                     weight=client_outputs.client_weight)
        server_state = tff.federated_map(server_computation,
                                         (server_state, w_delta))
        return (server_state, client_outputs.client_state)
Example 4
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
        server_message = tff.federated_map(server_message_fn, server_state)
        server_message_at_client = tff.federated_broadcast(server_message)

        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, server_message_at_client))

        # Model deltas are equally weighted in DP.
        round_model_delta = tff.federated_mean(client_outputs.weights_delta)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, round_model_delta))
        round_loss_metric = tff.federated_mean(client_outputs.model_output)

        return server_state, round_loss_metric
 def robust_aggregation_fn(value, weight):
     aggregate = tff.federated_mean(value, weight=weight)
     for _ in range(num_communication_passes - 1):
         aggregate_at_client = tff.federated_broadcast(aggregate)
         updated_weight = tff.federated_map(
             update_weight_fn, (weight, aggregate_at_client, value))
         aggregate = tff.federated_mean(value, weight=updated_weight)
     return aggregate
 def _robust_aggregation_fn(state, value, weight):
     aggregate = tff.federated_mean(value, weight=weight)
     for _ in range(num_communication_passes - 1):
         aggregate_at_client = tff.federated_broadcast(aggregate)
         updated_weight = tff.federated_map(
             update_weight_fn, (weight, aggregate_at_client, value))
         aggregate = tff.federated_mean(value, weight=updated_weight)
     no_metrics = tff.federated_value((), tff.SERVER)
     return tff.templates.MeasuredProcessOutput(state, aggregate,
                                                no_metrics)
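Both robust aggregation functions above call an `update_weight_fn` that is not shown. Below is a minimal sketch, assuming an RFA-style smoothed-Weiszfeld reweighting in which each client's original weight is divided by its distance from the current aggregate; `VALUE_TYPE`, `EPS`, and the flat float32 vector used as the value type are illustrative assumptions, not taken from the source.

import tensorflow as tf
import tensorflow_federated as tff

# Illustrative types: in real use these would be derived from the structure
# of the model deltas being aggregated.
VALUE_TYPE = tff.TensorType(tf.float32, [10])
EPS = 1e-6  # smoothing constant for the Weiszfeld step (assumed)


@tff.tf_computation(tf.float32, VALUE_TYPE, VALUE_TYPE)
def update_weight_fn(weight, aggregate, value):
    # Distance between this client's value and the broadcast aggregate.
    distance = tf.linalg.global_norm(
        tf.nest.flatten(
            tf.nest.map_structure(lambda a, b: a - b, value, aggregate)))
    # Smoothed Weiszfeld reweighting: clients far from the current aggregate
    # contribute less to the next weighted mean.
    return weight / tf.maximum(distance, EPS)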
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)
        client_optimizer_weights = tff.federated_broadcast(
            server_state.optimizer_state)
        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, client_model,
                               client_round_num, client_optimizer_weights))

        client_weight = client_outputs.client_weight
        model_delta = tff.federated_mean(client_outputs.weights_delta,
                                         weight=client_weight)

        implicit_mean_broad = tff.federated_broadcast(
            tff.federated_mean(client_outputs.weights, weight=client_weight))

        @tff.tf_computation(model_weights_type.trainable,
                            model_weights_type.trainable)
        def individual_drift_fn(client_model, implicit_mean):
            drifts = tf.nest.map_structure(lambda a, b: a - b, client_model,
                                           implicit_mean)
            return drifts

        drifts = tff.federated_map(
            individual_drift_fn, (client_outputs.weights, implicit_mean_broad))

        @tff.tf_computation(model_weights_type.trainable)
        def individual_drift_norm_fn(delta_from_mean):
            drift_norm = tf.reduce_sum(
                tf.nest.flatten(
                    tf.nest.map_structure(lambda a: tf.nn.l2_loss(a),
                                          delta_from_mean)))
            return drift_norm

        drift_norms = tff.federated_map(individual_drift_norm_fn, drifts)
        client_drift = tff.federated_mean(drift_norms, weight=client_weight)

        server_state = tff.federated_map(
            server_update_fn, (server_state, model_delta, client_drift))

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)
        return server_state, aggregated_outputs
Example 8
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        # Run computation on the clients.
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_model, client_round_num))

        # Aggregate model deltas.
        client_weight = client_outputs.client_weight
        if mask_zeros_in_client_updates:
            model_delta = federated_mean_masked(client_outputs.weights_delta,
                                                client_weight)
        else:
            model_delta = tff.federated_mean(client_outputs.weights_delta,
                                             client_weight)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, model_delta))

        # Aggregate model outputs that contain local metrics and various statistics.
        aggregated_outputs = placeholder_model.federated_output_computation(
            client_outputs.model_output)
        additional_outputs = tff.federated_mean(
            client_outputs.additional_output, weight=client_weight)

        @tff.tf_computation(aggregated_outputs.type_signature.member,
                            additional_outputs.type_signature.member)
        def _update_aggregated_outputs(aggregated_outputs, additional_outputs):
            aggregated_outputs.update(additional_outputs)
            return aggregated_outputs

        aggregated_outputs = tff.federated_map(
            _update_aggregated_outputs,
            (aggregated_outputs, additional_outputs))
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
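`federated_mean_masked` above is referenced but not defined. The following is a hypothetical sketch of such a helper, assuming the intended semantics are a weighted mean in which each coordinate is averaged only over the clients that reported a non-zero value for it; the helper's structure, float32 weights, and the masking rule are all assumptions.

import collections

import tensorflow as tf
import tensorflow_federated as tff


def federated_mean_masked(value, weight):
    """Weighted mean of `value@CLIENTS`, ignoring zero entries per coordinate."""
    member_type = value.type_signature.member

    @tff.tf_computation(member_type, tf.float32)
    def scale_and_mask(update, w):
        scaled = tf.nest.map_structure(lambda t: t * w, update)
        # Each coordinate contributes `w` to the denominator only where this
        # client actually sent a non-zero value (float32 tensors assumed).
        mask = tf.nest.map_structure(
            lambda t: tf.cast(tf.not_equal(t, 0.0), tf.float32) * w, update)
        return collections.OrderedDict(scaled=scaled, mask=mask)

    @tff.tf_computation(member_type, member_type)
    def safe_divide(numerator, denominator):
        return tf.nest.map_structure(tf.math.divide_no_nan, numerator,
                                     denominator)

    scaled_and_mask = tff.federated_map(scale_and_mask, (value, weight))
    numerator = tff.federated_sum(scaled_and_mask.scaled)
    denominator = tff.federated_sum(scaled_and_mask.mask)
    return tff.federated_map(safe_divide, (numerator, denominator))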
Example 9
def federated_train(model, learning_rate, data, classes):
    return tff.federated_mean(
        tff.federated_map(local_train, [
            tff.federated_broadcast(model),
            tff.federated_broadcast(learning_rate), data,
            tff.federated_broadcast(classes)
        ]))
Example 10
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, client_model, client_round_num))

    client_weight = client_outputs.client_weight
    model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=client_weight)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, model_delta))

    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    if aggregated_outputs.type_signature.is_struct():
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs
Example 11
  def next_fn(state, deltas, weights):

    @tff.tf_computation(model_update_type)
    def clip_by_global_norm(update):
      clipped_update, global_norm = tf.clip_by_global_norm(
          tf.nest.flatten(update), tf.constant(clip_norm))
      was_clipped = tf.cond(
          tf.greater(global_norm, tf.constant(clip_norm)),
          lambda: tf.constant(1),
          lambda: tf.constant(0),
      )
      clipped_update = tf.nest.pack_sequence_as(update, clipped_update)
      return clipped_update, global_norm, was_clipped

    clipped_deltas, client_norms, client_was_clipped = tff.federated_map(
        clip_by_global_norm, deltas)

    return collections.OrderedDict(
        state=state,
        result=tff.federated_mean(clipped_deltas, weight=weights),
        measurements=tff.federated_zip(
            NormClippedAggregationMetrics(
                max_global_norm=tff.utils.federated_max(client_norms),
                num_clipped=tff.federated_sum(client_was_clipped),
            )))
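`NormClippedAggregationMetrics` is not defined in this snippet; in the original code it is likely an `attr.s` container, but for a sketch a namedtuple with the two fields populated above is enough:

import collections

# Hypothetical metrics container; the field names follow the usage above.
NormClippedAggregationMetrics = collections.namedtuple(
    'NormClippedAggregationMetrics', ['max_global_norm', 'num_clipped'])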
Example 12
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        client_model = tff.federated_broadcast(server_state.model)

        client_outputs = tff.federated_map(client_update_fn,
                                           (federated_dataset, client_model))

        weight_denom = client_outputs.weights_delta_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)

        server_state = tff.federated_apply(server_update_fn,
                                           (server_state, round_model_delta))

        aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
Example 13
    def aggregate_and_clip(global_state, value, weight=None):
        """Compute the weighted mean of values@CLIENTS and clip it and return an aggregated value @SERVER."""

        round_model_delta = tff.federated_mean(value, weight)
        value_type = value.type_signature.member

        @tff.tf_computation(value_type._asdict())
        @tf.function
        def clip_by_norm(gradient, norm=norm_bound):
            """Clip the gradient by a certain l_2 norm."""

            delta_norm = tf.linalg.global_norm(tf.nest.flatten(gradient))

            if delta_norm < tf.cast(norm, tf.float32):
                return gradient
            else:
                delta_mul_factor = tf.math.divide_no_nan(
                    tf.cast(norm, tf.float32), delta_norm)
                nested_mul_factor = collections.OrderedDict([
                    (key, delta_mul_factor) for key in gradient.keys()
                ])
                return tf.nest.map_structure(tf.multiply, nested_mul_factor,
                                             gradient)

        return global_state, tff.federated_map(clip_by_norm, round_model_delta)
def federated_train(model, lr, data):
    # Returns the trained model (the mean of the locally trained client models).
    return tff.federated_mean(
        tff.federated_map(local_train, [
            tff.federated_broadcast(model),
            tff.federated_broadcast(lr), data
        ]))
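A `federated_train` like this one is normally wrapped in a `tff.federated_computation` with explicit federated type annotations, in the style of the classic TFF MNIST walkthrough. The declarations below sketch what those annotations could look like; `MODEL_SPEC`, `BATCH_SPEC`, and the tensor shapes are illustrative, and `local_train` is assumed to be a separately defined `tff.tf_computation`.

import collections

import tensorflow as tf
import tensorflow_federated as tff

BATCH_SPEC = collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    y=tf.TensorSpec(shape=[None], dtype=tf.int32))
BATCH_TYPE = tff.to_type(BATCH_SPEC)

MODEL_SPEC = collections.OrderedDict(
    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),
    bias=tf.TensorSpec(shape=[10], dtype=tf.float32))
MODEL_TYPE = tff.to_type(MODEL_SPEC)

SERVER_MODEL_TYPE = tff.FederatedType(MODEL_TYPE, tff.SERVER)
SERVER_FLOAT_TYPE = tff.FederatedType(tf.float32, tff.SERVER)
CLIENT_DATA_TYPE = tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS)

# With `local_train` defined as a `tff.tf_computation`, the function above is
# decorated as:
#
#   @tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,
#                              CLIENT_DATA_TYPE)
#   def federated_train(model, lr, data):
#       ...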
Example 15
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState` containing the state of the training process
        up to the current round.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS` containing data to train the current round on.

    Returns:
      A tuple of the updated `ServerState` and the aggregated client metrics.
    """
        server_message = tff.federated_map(server_message_fn, server_state)
        server_message_at_client = tff.federated_broadcast(server_message)

        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, server_message_at_client))

        weight_denom = client_outputs.client_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, round_model_delta))
        aggregated_outputs = metrics_aggregation_computation(
            client_outputs.model_output)
        return server_state, aggregated_outputs
Example 16
 def comp(temperatures, threshold):
   return tff.federated_mean(
       tff.federated_map(
           count_over,
           tff.federated_zip(
               [temperatures,
                tff.federated_broadcast(threshold)])),
       tff.federated_map(count_total, temperatures))
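`count_over` and `count_total` are not shown above. A sketch of the two helpers such a computation assumes, so that the weighted `tff.federated_mean` yields the fraction of temperature readings above the broadcast threshold:

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff


@tff.tf_computation(tff.SequenceType(tf.float32), tf.float32)
def count_over(ds, t):
    # Number of readings in this client's dataset that exceed the threshold.
    return ds.reduce(
        np.float32(0),
        lambda n, x: n + tf.cast(tf.greater(x, t), tf.float32))


@tff.tf_computation(tff.SequenceType(tf.float32))
def count_total(ds):
    # Total number of readings on this client.
    return ds.reduce(np.float32(0), lambda n, _: n + 1.0)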
Example 17
def next_fn(server_weights, federated_dataset):
    # Send server weights to clients
    server_weights_to_clients = tff.federated_broadcast(server_weights)

    # Each client computes their updated weights
    client_weights = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_to_clients))

    # The server averages all the client weights.
    mean_client_weights = tff.federated_mean(client_weights)

    # The server updates its model.
    server_weights = tff.federated_map(server_update_fn, mean_client_weights)

    return (server_weights, client_weights)
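A `next_fn` of this shape is normally decorated as a `tff.federated_computation` and paired with an `initialize_fn` inside a `tff.templates.IterativeProcess`. The block below is a minimal, self-contained sketch of that wiring using a toy scalar model; every name, type, and update rule in it is illustrative and not taken from the example above.

import tensorflow as tf
import tensorflow_federated as tff


@tff.tf_computation
def server_init():
    # Hypothetical "model": a single scalar weight initialized to zero.
    return tf.constant(0.0)


@tff.federated_computation
def initialize_fn():
    return tff.federated_value(server_init(), tff.SERVER)


@tff.tf_computation(tff.SequenceType(tf.float32), tf.float32)
def client_update(dataset, server_weight):
    # Toy client update: move the weight halfway toward the local data mean.
    total = dataset.reduce(tf.constant(0.0), lambda s, x: s + x)
    count = dataset.reduce(tf.constant(0.0), lambda s, _: s + 1.0)
    return 0.5 * server_weight + 0.5 * (total / count)


@tff.tf_computation(tf.float32)
def server_update(mean_client_weight):
    # The server simply adopts the averaged client weight.
    return mean_client_weight


federated_server_type = tff.FederatedType(tf.float32, tff.SERVER)
federated_dataset_type = tff.FederatedType(
    tff.SequenceType(tf.float32), tff.CLIENTS)


@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
    server_weights_at_client = tff.federated_broadcast(server_weights)
    client_weights = tff.federated_map(
        client_update, (federated_dataset, server_weights_at_client))
    mean_client_weights = tff.federated_mean(client_weights)
    return tff.federated_map(server_update, mean_client_weights)


iterative_process = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn, next_fn=next_fn)
state = iterative_process.initialize()
# Each training round would then be driven by something like:
#   state = iterative_process.next(state, federated_train_data)
# where `federated_train_data` is a list of per-client `tf.data.Dataset`s.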
Example 18
    def federated_evaluate(model_weights, federated_dataset):
        client_model = tff.federated_broadcast(model_weights)
        client_metrics = tff.federated_map(compute_client_metrics,
                                           (client_model, federated_dataset))
        # Extract the number of examples in order to compute client weights
        num_examples = client_metrics.num_examples
        uniform_weighted_metrics = tff.federated_mean(client_metrics,
                                                      weight=None)
        example_weighted_metrics = tff.federated_mean(client_metrics,
                                                      weight=num_examples)
        # Aggregate the metrics into a single nested dictionary
        aggregate_metrics = collections.OrderedDict()
        aggregate_metrics[AggregationMethods.EXAMPLE_WEIGHTED.
                          value] = example_weighted_metrics
        aggregate_metrics[AggregationMethods.UNIFORM_WEIGHTED.
                          value] = uniform_weighted_metrics

        return aggregate_metrics
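The `AggregationMethods` enum keyed into `aggregate_metrics` is not shown; a hypothetical definition consistent with the attributes used above:

import enum


class AggregationMethods(enum.Enum):
    # Member names follow the attributes referenced above; the string values
    # are illustrative.
    EXAMPLE_WEIGHTED = 'example_weighted'
    UNIFORM_WEIGHTED = 'uniform_weighted'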
    def run_one_round(server_state, client_states):
        """Orchestration logic for one round of federated training computation."""
        # performing the federated averaging of the clients' weights
        mean_client_weights = tff.federated_mean(client_states)
        print(str(mean_client_weights))
        ## SERVER UPDATING STEP
        server_state = tff.federated_apply(server_update_fn,
                                           (server_state, mean_client_weights))

        # returning the new server state
        return server_state
Example 20
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Note that in addition to updating the server weights according to the client
    model weight deltas, we extract metrics (governed by the `monitor` attribute
    of the `client_lr_callback` and `server_lr_callback` attributes of the
    `server_state`) and use these to update the client learning rate callbacks.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation` before and during local
      client training.
    """
        client_model = tff.federated_broadcast(server_state.model)
        client_lr = tff.federated_broadcast(
            server_state.client_lr_callback.learning_rate)

        if dataset_preprocess_comp is not None:
            federated_dataset = tff.federated_map(dataset_preprocess_comp,
                                                  federated_dataset)
        client_outputs = tff.federated_map(
            client_update_fn, (federated_dataset, client_model, client_lr))

        client_weight = client_outputs.client_weight
        aggregated_gradients = tff.federated_mean(
            client_outputs.accumulated_gradients, weight=client_weight)

        initial_aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.initial_model_output)
        if isinstance(initial_aggregated_outputs.type_signature,
                      tff.StructType):
            initial_aggregated_outputs = tff.federated_zip(
                initial_aggregated_outputs)

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if isinstance(aggregated_outputs.type_signature, tff.StructType):
            aggregated_outputs = tff.federated_zip(aggregated_outputs)
        client_monitor_value = initial_aggregated_outputs[client_monitor]
        server_monitor_value = initial_aggregated_outputs[server_monitor]

        server_state = tff.federated_map(
            server_update_fn, (server_state, aggregated_gradients,
                               client_monitor_value, server_monitor_value))

        result = collections.OrderedDict(
            before_training=initial_aggregated_outputs,
            during_training=aggregated_outputs)

        return server_state, result
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
        client_model = tff.federated_broadcast(server_state.model)

        client_outputs = tff.federated_map(client_update_fn,
                                           (federated_dataset, client_model))

        client_weight = client_outputs.client_weight
        model_delta = tff.federated_mean(client_outputs.weights_delta,
                                         weight=client_weight)
        global_cor_states = tff.federated_mean(client_outputs.local_cor_states,
                                               weight=client_weight)
        if correction == 'joint':
            server_state = tff.federated_map(
                server_update_joint_cor_fn,
                (server_state, model_delta, global_cor_states))
        elif correction == 'local':
            server_state = tff.federated_map(server_update_fn,
                                             (server_state, model_delta))
        else:
            raise TypeError('Correction method must be local or joint.')

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
Example 22
def aggregate_metrics_across_clients(metrics):
    global metrics_name
    output = collections.OrderedDict()

    for metric in metrics_name:
        if metric == 'num_examples':
            output[metric] = tff.federated_sum(getattr(metrics, metric))
            output['per_client/' + metric] = tff.federated_collect(
                getattr(metrics, metric))
        else:
            output[metric] = tff.federated_mean(getattr(metrics, metric),
                                                metrics.num_examples)
            output['per_client/' + metric] = tff.federated_collect(
                getattr(metrics, metric))
    return output
Example 23
    def run_one_round(server_state, client_states, federated_dataset):
        """Orchestration logic for one round of computation.

        Args:
        server_state: A `ServerState`.
        client_states: A federated `ClientState` with placement `tff.CLIENTS`.
        federated_dataset: A federated `tf.data.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
        A tuple of the updated `ServerState`, the updated client states, and a
        `tf.Tensor` of average loss.
        """
        # Prepare server_message to be sent to the clients,
        # based on the server_state from previous round
        server_message = tff.federated_map(server_message_fn, server_state)

        # Update the clients with the new server_message and dataset
        client_outputs, new_client_state = tff.federated_map(
            client_update_fn,
            (
                federated_dataset,
                tff.federated_broadcast(server_message),
                client_states,
            )
        )

        round_model_delta = tff.federated_mean(
            client_outputs.weights_delta, weight=client_outputs.client_weight)

        # Update server state given the current round's completion
        server_state = tff.federated_map(
            server_update_fn, (server_state, round_model_delta))

        round_loss_metric = tff.federated_mean(
            client_outputs.model_output, weight=client_outputs.client_weight)

        return server_state, new_client_state, round_loss_metric
Example 24
def next_fn(server_weights, federated_dataset):
    # Broadcast the server model to the clients.
    server_weights_at_client = tff.federated_broadcast(server_weights)

    # Each client computes its local update and returns updated weights and loss.
    client_weights, clients_loss = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client))

    # The server averages the model weights updated by all clients.
    mean_client_weights = tff.federated_mean(client_weights)

    # The server updates its model.
    server_weights = tff.federated_map(server_update_fn, mean_client_weights)

    return server_weights, client_weights, clients_loss
Example 25
    def next_tff(server_state, datasets, client_states):
        message = tff.federated_map(state_to_message_tf, server_state)
        broadcast = tff.federated_broadcast(message)

        outputs = tff.federated_map(update_client_tf,
                                    (datasets, client_states, broadcast))
        weights_delta = tff.federated_mean(outputs.weights_delta,
                                           weight=outputs.client_weight)

        metrics = model.federated_output_computation(outputs.metrics)

        next_state = tff.federated_map(update_server_tf,
                                       (server_state, weights_delta))

        return next_state, metrics, outputs.client_state
Example 26
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """The `tff.Computation` to be returned."""
        # TODO(b/131429028): The federated_zip should be automatic.
        from_server = tff.federated_zip(
            gan_training_tf_fns.FromServer(
                generator_weights=server_state.generator_weights,
                discriminator_weights=server_state.discriminator_weights))
        client_input = tff.federated_broadcast(from_server)
        client_outputs = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, client_input))

        if gan.dp_averaging_fn is None:
            # Not using differential privacy.
            new_dp_averaging_state = server_state.dp_averaging_state
            averaged_discriminator_weights_delta = tff.federated_mean(
                client_outputs.discriminator_weights_delta,
                weight=client_outputs.update_weight)
        else:
            # Using differential privacy. Note that the weight argument is set to None
            # here. This is because the DP aggregation code explicitly does not do
            # weighted aggregation. (If weighted aggregation is desired, differential
            # privacy needs to be turned off.)
            new_dp_averaging_state, averaged_discriminator_weights_delta = (
                gan.dp_averaging_fn(server_state.dp_averaging_state,
                                    client_outputs.discriminator_weights_delta,
                                    weight=None))

        # TODO(b/131085687): Perhaps reconsider the choice to also use
        # ClientOutput to hold the aggregated client output.
        aggregated_client_output = gan_training_tf_fns.ClientOutput(
            discriminator_weights_delta=averaged_discriminator_weights_delta,
            # We don't actually need the aggregated update_weight, but
            # this keeps the types of the non-aggregated and aggregated
            # client_output the same, which is convenient. And I can
            # imagine wanting this.
            update_weight=tff.federated_sum(client_outputs.update_weight),
            counters=tff.federated_sum(client_outputs.counters))

        # TODO(b/131839522): This federated_zip shouldn't be needed.
        aggregated_client_output = tff.federated_zip(aggregated_client_output)

        server_state = tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregated_client_output,
             new_dp_averaging_state))
        return server_state
Example 27
def next_fn(server_weights, federated_dataset):
    print('sw', dir(server_weights))
    # Broadcast the server weights to the clients.
    server_weights_at_client = tff.federated_broadcast(server_weights)
    print(server_weights_at_client)
    # Each client computes their updated weights.
    client_weights = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client))
    print('weights', client_weights)
    # The server averages these updates.
    mean_client_weights = tff.federated_mean(client_weights)

    # The server updates its model.
    server_weights = tff.federated_map(server_update_fn, mean_client_weights)

    return server_weights
Example 28
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_optimizer_state = tff.federated_broadcast(
        server_state.client_optimizer_state)
    client_round_num = tff.federated_broadcast(server_state.round_num)
    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, client_model,
                           client_optimizer_state, client_round_num))

    client_model_weight = client_outputs.client_weight.model_weight
    client_opt_weight = client_outputs.client_weight.optimizer_weight

    model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=client_model_weight)

    # We convert the optimizer state to a float type so that it can be used
    # with things such as `tff.federated_mean`. This is only necessary because
    # `tf.keras.Optimizer` objects have a state with an integer indicating
    # the number of times it has been applied.
    client_optimizer_state_delta = tff.federated_map(
        _convert_opt_state_to_float, client_outputs.optimizer_state_delta)
    client_optimizer_state_delta = optimizer_aggregator(
        client_optimizer_state_delta, weight=client_opt_weight)
    # We convert the optimizer state back into one with an integer round number.
    client_optimizer_state_delta = tff.federated_map(
        _convert_opt_state_to_int, client_optimizer_state_delta)

    server_state = tff.federated_map(
        server_update_fn,
        (server_state, model_delta, client_optimizer_state_delta))

    aggregated_outputs = placeholder_model.federated_output_computation(
        client_outputs.model_output)
    if aggregated_outputs.type_signature.is_struct():
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs
Example 29
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """The `tff.Computation` to be returned."""
        from_server = gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights)
        client_input = tff.federated_broadcast(from_server)
        client_outputs = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, client_input))

        if gan.dp_averaging_fn is None:
            # Not using differential privacy.
            new_dp_averaging_state = server_state.dp_averaging_state
            averaged_discriminator_weights_delta = tff.federated_mean(
                client_outputs.discriminator_weights_delta,
                weight=client_outputs.update_weight)
        else:
            # Using differential privacy. Note that the weight argument is set to
            # a constant 1.0 here, however the underlying AggregationProcess ignores
            # the parameter and performs no weighting.
            ignored_weight = tff.federated_value(1.0, tff.CLIENTS)
            aggregation_output = gan.dp_averaging_fn.next(
                server_state.dp_averaging_state,
                client_outputs.discriminator_weights_delta,
                weight=ignored_weight)
            new_dp_averaging_state = aggregation_output.state
            averaged_discriminator_weights_delta = aggregation_output.result

        # TODO(b/131085687): Perhaps reconsider the choice to also use
        # ClientOutput to hold the aggregated client output.
        aggregated_client_output = gan_training_tf_fns.ClientOutput(
            discriminator_weights_delta=averaged_discriminator_weights_delta,
            # We don't actually need the aggregated update_weight, but
            # this keeps the types of the non-aggregated and aggregated
            # client_output the same, which is convenient. And I can
            # imagine wanting this.
            update_weight=tff.federated_sum(client_outputs.update_weight),
            counters=tff.federated_sum(client_outputs.counters))

        server_computation = build_server_computation(
            gan, server_state.type_signature.member, client_output_type)
        server_state = tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregated_client_output,
             new_dp_averaging_state))
        return server_state
Example 30
    def next_fn(state, deltas, weights=None):
        @tff.tf_computation(deltas.type_signature.member, tf.float32)
        def clip_by_global_norm(delta, clip_norm):
            clipped, global_norm = tf.clip_by_global_norm(
                tf.nest.flatten(delta), clip_norm)
            clipped_deltas = tf.nest.pack_sequence_as(delta, clipped)
            return clipped_deltas, global_norm

        client_clip_norm = tff.federated_broadcast(state.clip_norm)
        clipped_deltas, client_norms = tff.federated_map(
            clip_by_global_norm, (deltas, client_clip_norm))
        # The clip_norm is left unchanged here, but it could be adapted using max_norm.
        next_state = tff.federated_zip(
            ClipNormAggregateState(
                clip_norm=state.clip_norm,
                max_norm=tff.utils.federated_max(client_norms)))
        return next_state, tff.federated_mean(clipped_deltas, weight=weights)
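`ClipNormAggregateState` is not defined in this snippet; a hypothetical container with the two fields the `next_fn` reads and writes:

import collections

# Hypothetical state container; field names match the usage above.
ClipNormAggregateState = collections.namedtuple(
    'ClipNormAggregateState', ['clip_norm', 'max_norm'])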