コード例 #1
0
ファイル: fed_avg_schedule.py プロジェクト: wlj0417/federated
  def run_one_round(server_state, federated_dataset):
    """Runs one round: broadcast, client update, aggregation, server update.

    Args:
      server_state: A `ServerState`; its `model` and `round_num` fields are
        broadcast to the clients.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of the updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    # Ship the current model and round counter to every client.
    model_at_clients = tff.federated_broadcast(server_state.model)
    round_num_at_clients = tff.federated_broadcast(server_state.round_num)

    per_client = tff.federated_map(
        client_update_fn,
        (federated_dataset, model_at_clients, round_num_at_clients))

    # Average the client deltas, weighting each by its reported client weight.
    weights = per_client.client_weight
    averaged_delta = tff.federated_mean(per_client.weights_delta, weight=weights)

    updated_state = tff.federated_map(
        server_update_fn, (server_state, averaged_delta))

    metrics = dummy_model.federated_output_computation(per_client.model_output)
    # A struct of separate federated values is zipped into one federated value.
    if metrics.type_signature.is_struct():
      metrics = tff.federated_zip(metrics)

    return updated_state, metrics
コード例 #2
0
    def run_one_round(server_state, federated_dataset):
        """Runs one round of TrieHH computation.

        Args:
          server_state: A `ServerState`; its `discovered_prefixes` and
            `round_num` fields are broadcast to the clients.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of the updated `ServerState` and an empty value placed at
          `tff.SERVER`.
        """
        prefixes_at_clients = tff.federated_broadcast(
            server_state.discovered_prefixes)
        round_num_at_clients = tff.federated_broadcast(server_state.round_num)

        per_client = tff.federated_map(
            client_update_fn,
            (federated_dataset, prefixes_at_clients, round_num_at_clients))

        # Votes and weights are summed (not averaged) across the clients.
        vote_totals = tff.federated_sum(per_client.client_votes)
        weight_totals = tff.federated_sum(per_client.client_weight)

        updated_state = tff.federated_map(
            server_update_fn, (server_state, vote_totals, weight_totals))

        # This round produces no server-side output beyond the state itself.
        empty_output = tff.federated_value([], tff.SERVER)

        return updated_state, empty_output
コード例 #3
0
def federated_train(model, learning_rate, data, classes):
    """Maps `local_train` over the clients and averages the results.

    `model`, `learning_rate` and `classes` are broadcast before the map;
    `data` is passed through unchanged.

    Returns:
      The `tff.federated_mean` of the per-client `local_train` results.
    """
    client_args = [
        tff.federated_broadcast(model),
        tff.federated_broadcast(learning_rate),
        data,
        tff.federated_broadcast(classes),
    ]
    locally_trained = tff.federated_map(local_train, client_args)
    return tff.federated_mean(locally_trained)
コード例 #4
0
    def run_one_round(server_state, client_states, federated_dataset,
                      federated_dataset_single):
        """Orchestration logic for one round with optional control corrections.

        Args:
          server_state: State whose `w`, `meta_w` and `round_num` fields are
            packed into a `FromServer` message and broadcast to the clients.
          client_states: Per-client state handed to both `control_computation`
            and `client_computation`.
          federated_dataset: Client data consumed by `client_computation`.
          federated_dataset_single: Client data consumed by
            `control_computation`.

        Returns:
          A tuple of the updated server state and the `client_state` field of
          the `client_computation` outputs.
        """
        from_server = FromServer(w=server_state.w,
                                 meta_w=server_state.meta_w,
                                 round_num=server_state.round_num)
        from_server = tff.federated_broadcast(from_server)
        control_output = tff.federated_map(
            control_computation,
            (federated_dataset_single, client_states, from_server))
        # Weighted mean of the control deltas, sent back to every client.
        c = tff.federated_broadcast(
            tff.federated_mean(control_output.w_delta,
                               weight=control_output.client_weight))

        @tff.tf_computation(client_output_type.w_delta,
                            client_output_type.w_delta)
        def compute_control_input(c, c_i):
            """Returns `c - c_i` when `control` is set, otherwise zeros."""
            correction = tf.nest.map_structure(lambda a, b: a - b, c, c_i)
            # if we are using SCAFFOLD then use the correction, otherwise
            # we just let the correction be zero.
            return tf.cond(
                tf.constant(control, dtype=tf.bool), lambda: correction,
                lambda: tf.nest.map_structure(tf.zeros_like, correction))

        # the collection of gradient corrections
        corrections = tff.federated_map(compute_control_input,
                                        (c, control_output.cs))
        client_outputs = tff.federated_map(
            client_computation,
            (federated_dataset, from_server, client_states, corrections))
        w_delta = tff.federated_mean(client_outputs.w_delta,
                                     weight=client_outputs.client_weight)
        server_state = tff.federated_map(server_computation,
                                         (server_state, w_delta))
        return (server_state, client_outputs.client_state)
コード例 #5
0
def federated_train(model, learning_rate, data):
    """Maps `local_train` over the clients without aggregating.

    `model` and `learning_rate` are broadcast before the map; `data` is passed
    through unchanged.

    Returns:
      The per-client results of `local_train`.
    """
    model_at_clients = tff.federated_broadcast(model)
    lr_at_clients = tff.federated_broadcast(learning_rate)
    return tff.federated_map(local_train,
                             [model_at_clients, lr_at_clients, data])
コード例 #6
0
def federated_train(model, lr, data):
    """Returns the trained model.

    `model` and `lr` are broadcast to the clients, `local_train` is mapped
    over them together with `data`, and the per-client results are averaged
    with `tff.federated_mean`.
    """
    model_at_clients = tff.federated_broadcast(model)
    lr_at_clients = tff.federated_broadcast(lr)
    locally_trained = tff.federated_map(
        local_train, [model_at_clients, lr_at_clients, data])
    return tff.federated_mean(locally_trained)
コード例 #7
0
    def run_one_round(server_state, federated_dataset):
        """Runs one round and additionally measures client drift.

        Args:
          server_state: A `ServerState`; its `model`, `round_num` and
            `optimizer_state` fields are broadcast to the clients.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of the updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.
        """
        model_at_clients = tff.federated_broadcast(server_state.model)
        round_num_at_clients = tff.federated_broadcast(server_state.round_num)
        optimizer_state_at_clients = tff.federated_broadcast(
            server_state.optimizer_state)
        per_client = tff.federated_map(
            client_update_fn,
            (federated_dataset, model_at_clients, round_num_at_clients,
             optimizer_state_at_clients))

        weights = per_client.client_weight
        averaged_delta = tff.federated_mean(per_client.weights_delta,
                                            weight=weights)

        # Weighted mean of the client weights, sent back so that each client
        # can be compared against it.
        mean_weights = tff.federated_mean(per_client.weights, weight=weights)
        mean_weights_at_clients = tff.federated_broadcast(mean_weights)

        @tff.tf_computation(model_weights_type.trainable,
                            model_weights_type.trainable)
        def individual_drift_fn(client_model, implicit_mean):
            """Element-wise difference between client weights and the mean."""
            return tf.nest.map_structure(lambda x, y: x - y, client_model,
                                         implicit_mean)

        per_client_drift = tff.federated_map(
            individual_drift_fn, (per_client.weights, mean_weights_at_clients))

        @tff.tf_computation(model_weights_type.trainable)
        def individual_drift_norm_fn(delta_from_mean):
            """Sums `tf.nn.l2_loss` over every tensor of the drift structure."""
            per_tensor_loss = tf.nest.map_structure(tf.nn.l2_loss,
                                                    delta_from_mean)
            return tf.reduce_sum(tf.nest.flatten(per_tensor_loss))

        per_client_drift_norm = tff.federated_map(individual_drift_norm_fn,
                                                  per_client_drift)
        mean_drift = tff.federated_mean(per_client_drift_norm, weight=weights)

        updated_state = tff.federated_map(
            server_update_fn, (server_state, averaged_delta, mean_drift))

        metrics = dummy_model.federated_output_computation(
            per_client.model_output)
        if metrics.type_signature.is_struct():
            metrics = tff.federated_zip(metrics)
        return updated_state, metrics
コード例 #8
0
    def run_one_round(server_state, federated_dataset):
        """Runs one round and feeds monitored metrics to the server update.

        Besides aggregating the clients' accumulated gradients, the metrics
        computed before local training are indexed by `client_monitor` and
        `server_monitor`, and both values are passed to `server_update_fn`.

        Args:
          server_state: A `ServerState`; its `model` and
            `client_lr_callback.learning_rate` are broadcast to the clients.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of the updated `ServerState` and an ordered dict with the
          results of `tff.learning.Model.federated_output_computation` under
          the keys `before_training` and `during_training`.
        """
        model_at_clients = tff.federated_broadcast(server_state.model)
        lr_at_clients = tff.federated_broadcast(
            server_state.client_lr_callback.learning_rate)

        # Optional per-client preprocessing of the raw datasets.
        if dataset_preprocess_comp is not None:
            federated_dataset = tff.federated_map(dataset_preprocess_comp,
                                                  federated_dataset)
        per_client = tff.federated_map(
            client_update_fn,
            (federated_dataset, model_at_clients, lr_at_clients))

        weights = per_client.client_weight
        mean_gradients = tff.federated_mean(per_client.accumulated_gradients,
                                            weight=weights)

        # Metrics computed before local training.
        metrics_before = dummy_model.federated_output_computation(
            per_client.initial_model_output)
        if isinstance(metrics_before.type_signature, tff.StructType):
            metrics_before = tff.federated_zip(metrics_before)

        # Metrics accumulated during local training.
        metrics_during = dummy_model.federated_output_computation(
            per_client.model_output)
        if isinstance(metrics_during.type_signature, tff.StructType):
            metrics_during = tff.federated_zip(metrics_during)

        monitored_for_client = metrics_before[client_monitor]
        monitored_for_server = metrics_before[server_monitor]

        updated_state = tff.federated_map(
            server_update_fn,
            (server_state, mean_gradients, monitored_for_client,
             monitored_for_server))

        round_metrics = collections.OrderedDict(
            before_training=metrics_before, during_training=metrics_during)

        return updated_state, round_metrics
コード例 #9
0
    def run_one_round(server_state, federated_dataset):
        """Runs one round with an optionally masked mean of the model deltas.

        Args:
          server_state: A `ServerState`; its `model` and `round_num` fields
            are broadcast to the clients.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of the updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`, updated with the
          weighted mean of the clients' `additional_output`.
        """
        # Client-side computation.
        model_at_clients = tff.federated_broadcast(server_state.model)
        round_num_at_clients = tff.federated_broadcast(server_state.round_num)

        per_client = tff.federated_map(
            client_update_fn,
            (federated_dataset, model_at_clients, round_num_at_clients))

        # Pick the aggregation for the model deltas.
        weights = per_client.client_weight
        mean_fn = (federated_mean_masked
                   if mask_zeros_in_client_updates else tff.federated_mean)
        averaged_delta = mean_fn(per_client.weights_delta, weights)

        updated_state = tff.federated_map(server_update_fn,
                                          (server_state, averaged_delta))

        # Aggregate the model outputs and the additional client outputs.
        metrics = placeholder_model.federated_output_computation(
            per_client.model_output)
        extra_metrics = tff.federated_mean(per_client.additional_output,
                                           weight=weights)

        @tff.tf_computation(metrics.type_signature.member,
                            extra_metrics.type_signature.member)
        def _update_aggregated_outputs(aggregated_outputs, additional_outputs):
            """Merges the additional outputs into the aggregated outputs."""
            aggregated_outputs.update(additional_outputs)
            return aggregated_outputs

        metrics = tff.federated_map(_update_aggregated_outputs,
                                    (metrics, extra_metrics))
        if metrics.type_signature.is_struct():
            metrics = tff.federated_zip(metrics)

        return updated_state, metrics
コード例 #10
0
    def run_one_round(server_state, federated_dataset):
        """Runs one round of training driven by a server message.

        Args:
          server_state: A `ServerState` containing the state of the training
            process up to the current round.
          federated_dataset: A federated `tf.data.Dataset` with placement
            `tff.CLIENTS` containing data to train the current round on.

        Returns:
          A tuple of the updated `ServerState` and the result of
          `metrics_aggregation_computation` on the clients' model outputs.
        """
        # Derive the message from the state, then send it to every client.
        message = tff.federated_map(server_message_fn, server_state)
        message_at_clients = tff.federated_broadcast(message)

        per_client = tff.federated_map(
            client_update_fn, (federated_dataset, message_at_clients))

        weights = per_client.client_weight
        averaged_delta = tff.federated_mean(per_client.weights_delta,
                                            weight=weights)

        updated_state = tff.federated_map(server_update_fn,
                                          (server_state, averaged_delta))
        metrics = metrics_aggregation_computation(per_client.model_output)
        return updated_state, metrics
コード例 #11
0
    def run_one_round(server_state, federated_dataset):
        """Runs one round with an unweighted mean of the model deltas.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.data.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of the updated `ServerState` and the unweighted mean of the
          clients' `model_output`.
        """
        message = tff.federated_map(server_message_fn, server_state)
        message_at_clients = tff.federated_broadcast(message)

        per_client = tff.federated_map(
            client_update_fn, (federated_dataset, message_at_clients))

        # Under DP every client's delta counts equally, so no weight is given.
        averaged_delta = tff.federated_mean(per_client.weights_delta)

        updated_state = tff.federated_map(server_update_fn,
                                          (server_state, averaged_delta))
        mean_loss = tff.federated_mean(per_client.model_output)

        return updated_state, mean_loss
コード例 #12
0
ファイル: flars_fedavg.py プロジェクト: zhuchen03/federated
    def run_one_round(server_state, federated_dataset):
        """Runs one round and also aggregates the clients' gradient norms.

        Args:
          server_state: A `ServerState`; its `model` field is broadcast to the
            clients.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of the updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.
        """
        model_at_clients = tff.federated_broadcast(server_state.model)

        per_client = tff.federated_map(
            client_update_fn, (federated_dataset, model_at_clients))

        # The same per-client weight is used for both weighted means below.
        weights = per_client.weights_delta_weight
        averaged_delta = tff.federated_mean(per_client.weights_delta,
                                            weight=weights)
        averaged_grads_norm = tff.federated_mean(
            per_client.optimizer_output.flat_grads_norm_sum, weight=weights)

        updated_state = tff.federated_map(
            server_update_fn,
            (server_state, averaged_delta, averaged_grads_norm))

        metrics = dummy_model_for_metadata.federated_output_computation(
            per_client.model_output)
        if isinstance(metrics.type_signature, tff.StructType):
            metrics = tff.federated_zip(metrics)

        return updated_state, metrics
コード例 #13
0
    def run_one_round(server_state, federated_dataset):
        """Runs one round: broadcast, client update, weighted mean, server update.

        Args:
          server_state: A `ServerState`; its `model` field is broadcast to the
            clients.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of the updated `ServerState` and the zipped result of
          `tff.learning.Model.federated_output_computation`.
        """
        model_at_clients = tff.federated_broadcast(server_state.model)

        per_client = tff.federated_map(
            client_update_fn, (federated_dataset, model_at_clients))

        weights = per_client.weights_delta_weight
        averaged_delta = tff.federated_mean(per_client.weights_delta,
                                            weight=weights)

        updated_state = tff.federated_apply(server_update_fn,
                                            (server_state, averaged_delta))

        # The aggregated outputs are always zipped here, without a type check.
        metrics = tff.federated_zip(
            dummy_model_for_metadata.federated_output_computation(
                per_client.model_output))

        return updated_state, metrics
コード例 #14
0
    def run_one_round(server_state, federated_dataset, client_states):
        """Runs one round in which clients carry their own state.

        Args:
          server_state: A `stateful_fedavg_tf.ServerState`.
          federated_dataset: A federated `tf.data.Dataset` with placement
            `tff.CLIENTS`.
          client_states: A federated `stateful_fedavg_tf.ClientState`.

        Returns:
          A tuple of the updated `ServerState`, the weighted mean of the
          clients' `model_output`, and the clients' new `client_state`.
        """
        message = tff.federated_map(server_message_fn, server_state)
        message_at_clients = tff.federated_broadcast(message)

        per_client = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_states, message_at_clients))

        weights = per_client.client_weight
        averaged_delta = tff.federated_mean(per_client.weights_delta,
                                            weight=weights)
        # Sum of the `iters_count` field over all client states.
        iters_total = tff.federated_sum(per_client.client_state.iters_count)
        updated_state = tff.federated_map(
            server_update_fn, (server_state, averaged_delta, iters_total))
        mean_loss = tff.federated_mean(per_client.model_output, weight=weights)

        return updated_state, mean_loss, per_client.client_state
コード例 #15
0
ファイル: __init__.py プロジェクト: Lando-L/ocd-detection
    def validate(weights, datasets, client_states):
        """Runs `validate_client_tf` on the clients and aggregates metrics.

        `weights` is broadcast to the clients and mapped together with
        `datasets` and `client_states`.

        Returns:
          The result of `model.federated_output_computation` applied to the
          clients' `metrics`.
        """
        weights_at_clients = tff.federated_broadcast(weights)
        per_client = tff.federated_map(
            validate_client_tf, (datasets, client_states, weights_at_clients))
        return model.federated_output_computation(per_client.metrics)
コード例 #16
0
  def run_one_round(server_state, federated_dataset):
    """Runs one round that aggregates both model and optimizer-state deltas.

    Args:
      server_state: A `ServerState`; its `model`, `client_optimizer_state` and
        `round_num` fields are broadcast to the clients.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of the updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    model_at_clients = tff.federated_broadcast(server_state.model)
    opt_state_at_clients = tff.federated_broadcast(
        server_state.client_optimizer_state)
    round_num_at_clients = tff.federated_broadcast(server_state.round_num)
    per_client = tff.federated_map(
        client_update_fn,
        (federated_dataset, model_at_clients, opt_state_at_clients,
         round_num_at_clients))

    # Model and optimizer state are aggregated with separate client weights.
    model_weight = per_client.client_weight.model_weight
    optimizer_weight = per_client.client_weight.optimizer_weight

    averaged_delta = tff.federated_mean(
        per_client.weights_delta, weight=model_weight)

    # `tf.keras.Optimizer` state carries an integer count of how often it has
    # been applied. It is converted to float so that aggregators such as
    # `tff.federated_mean` can handle it.
    opt_state_delta = tff.federated_map(
        _convert_opt_state_to_float, per_client.optimizer_state_delta)
    opt_state_delta = optimizer_aggregator(
        opt_state_delta, weight=optimizer_weight)
    # Convert the aggregated state back to one with an integer round number.
    opt_state_delta = tff.federated_map(
        _convert_opt_state_to_int, opt_state_delta)

    updated_state = tff.federated_map(
        server_update_fn, (server_state, averaged_delta, opt_state_delta))

    metrics = placeholder_model.federated_output_computation(
        per_client.model_output)
    if metrics.type_signature.is_struct():
      metrics = tff.federated_zip(metrics)

    return updated_state, metrics
コード例 #17
0
 def train_one_round(model, federated_data):
     """Trains on every client and folds the results with a custom aggregate.

     `model` is broadcast and paired with `federated_data` under the keys
     `model` and `batches` before `train_on_one_client` is mapped over the
     clients.

     Returns:
       The result of `tff.federated_aggregate` over the locally trained models.
     """
     client_inputs = collections.OrderedDict()
     client_inputs['model'] = tff.federated_broadcast(model)
     client_inputs['batches'] = federated_data

     locally_trained_models = tff.federated_map(train_on_one_client,
                                                client_inputs)
     zero = make_zero_model_and_count()
     return tff.federated_aggregate(locally_trained_models, zero, accumulate,
                                    merge, report)
コード例 #18
0
 def comp(temperatures, threshold):
   """Weighted mean of `count_over` results, weighted by `count_total`.

   `threshold` is broadcast and zipped with `temperatures` before `count_over`
   is mapped over the clients; `count_total` of `temperatures` supplies the
   weights.
   """
   threshold_at_clients = tff.federated_broadcast(threshold)
   paired = tff.federated_zip([temperatures, threshold_at_clients])
   over_counts = tff.federated_map(count_over, paired)
   totals = tff.federated_map(count_total, temperatures)
   return tff.federated_mean(over_counts, totals)
コード例 #19
0
 def robust_aggregation_fn(value, weight):
     """Weighted mean refined over several communication passes.

     The first pass is a plain weighted mean. Each of the remaining
     `num_communication_passes - 1` passes broadcasts the current aggregate,
     recomputes the weights with `update_weight_fn`, and averages again.

     Returns:
       The aggregate produced by the last pass.
     """
     current = tff.federated_mean(value, weight=weight)
     extra_passes = num_communication_passes - 1
     for _ in range(extra_passes):
         current_at_clients = tff.federated_broadcast(current)
         # New weights are always derived from the original `weight`.
         reweighted = tff.federated_map(
             update_weight_fn, (weight, current_at_clients, value))
         current = tff.federated_mean(value, weight=reweighted)
     return current
コード例 #20
0
 def _robust_aggregation_fn(state, value, weight):
     """Multi-pass weighted mean wrapped as a measured-process output.

     The first pass is a plain weighted mean. Each of the remaining
     `num_communication_passes - 1` passes broadcasts the current aggregate,
     recomputes the weights with `update_weight_fn`, and averages again.

     Returns:
       A `tff.templates.MeasuredProcessOutput` carrying `state` unchanged, the
       final aggregate, and an empty server-placed measurements value.
     """
     current = tff.federated_mean(value, weight=weight)
     extra_passes = num_communication_passes - 1
     for _ in range(extra_passes):
         current_at_clients = tff.federated_broadcast(current)
         reweighted = tff.federated_map(
             update_weight_fn, (weight, current_at_clients, value))
         current = tff.federated_mean(value, weight=reweighted)
     # There are no measurements to report.
     empty_measurements = tff.federated_value((), tff.SERVER)
     return tff.templates.MeasuredProcessOutput(state, current,
                                                empty_measurements)
コード例 #21
0
ファイル: __init__.py プロジェクト: Lando-L/ocd-detection
    def evaluate(weights, datasets, client_states):
        """Runs `evaluate_client_tf` on the clients and gathers the results.

        `weights` is broadcast to the clients and mapped together with
        `datasets` and `client_states`.

        Returns:
          A tuple of the summed confusion matrix, the metrics aggregated by
          `model.federated_output_computation`, and the per-client metrics
          collected with `tff.federated_collect`.
        """
        weights_at_clients = tff.federated_broadcast(weights)
        per_client = tff.federated_map(
            evaluate_client_tf, (datasets, client_states, weights_at_clients))

        summed_confusion = tff.federated_sum(per_client.confusion_matrix)
        metrics_aggregated = model.federated_output_computation(
            per_client.metrics)
        metrics_collected = tff.federated_collect(per_client.metrics)

        return summed_confusion, metrics_aggregated, metrics_collected
コード例 #22
0
    def run_one_round(server_state, federated_dataset):
        """Runs one round using `aggregation_process` for the model deltas.

        Args:
          server_state: A `ServerState`; its `model` and `round_num` fields
            are broadcast, and its `aggregator_state` is passed to the
            aggregation process.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `ServerState` and aggregated metrics.
        """
        model_at_clients = tff.federated_broadcast(server_state.model)
        round_num_at_clients = tff.federated_broadcast(server_state.round_num)

        per_client = tff.federated_map(
            client_update_fn,
            (federated_dataset, model_at_clients, round_num_at_clients))

        # A three-parameter `next` is treated as a weighted aggregator.
        is_weighted = len(aggregation_process.next.type_signature.parameter) == 3
        if is_weighted:
            aggregated = aggregation_process.next(
                server_state.aggregator_state,
                per_client.weights_delta,
                weight=per_client.client_weight)
        else:
            aggregated = aggregation_process.next(
                server_state.aggregator_state, per_client.weights_delta)

        averaged_delta = aggregated.result

        updated_state = tff.federated_map(
            server_update_fn, (server_state, averaged_delta, aggregated.state))

        metrics = federated_output_computation(per_client.model_output)

        # The `measurements` part of the aggregation output is intentionally
        # dropped; only the state and the result are used.
        return updated_state, metrics
コード例 #23
0
def next_fn(server_weights, federated_dataset):
    """Runs one round: broadcast, client update, federated mean, server update.

    Args:
      server_weights: Value passed to `tff.federated_broadcast`.
      federated_dataset: Per-client data handed to `client_update`.

    Returns:
      The result of `server_update` applied to the averaged client weights.
    """
    # Broadcast the server weights to the clients.
    server_weights_at_client = tff.federated_broadcast(server_weights)

    # Each client computes their updated weights.
    client_weights = client_update(federated_dataset, server_weights_at_client)

    # The server averages these updates. A federated operator is needed here:
    # NumPy cannot reduce a client-placed federated value.
    mean_client_weights = tff.federated_mean(client_weights)

    # The server updates its model.
    # NOTE(review): assumes `client_update` and `server_update` accept and
    # return federated values — confirm against their definitions.
    server_weights = server_update(mean_client_weights)

    return server_weights
コード例 #24
0
    def encoded_broadcast_fn(state, value):
        """Encodes `value`, broadcasts the encoding, and decodes it on clients.

        Meant to be wrapped as a federated computation.

        Returns:
          A tuple of the new encoder state and the decoded client value.
        """
        # Encode/decode computations are built from the member types.
        encode_fn, decode_fn = _build_encode_decode_tf_computations_for_broadcast(
            state.type_signature.member, value.type_signature.member, encoders)

        encoded = tff.federated_apply(encode_fn, (state, value))
        new_state, encoded_value = encoded
        encoded_at_clients = tff.federated_broadcast(encoded_value)
        decoded_at_clients = tff.federated_map(decode_fn, encoded_at_clients)
        return new_state, decoded_at_clients
コード例 #25
0
    def run_one_round(server_state, federated_dataset, malicious_dataset,
                      malicious_clients):
        """Runs one round in which some clients may be malicious.

        Args:
          server_state: A `ServerState`; its `model` field is broadcast and
            its `delta_aggregate_state` is passed to the aggregation process.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.
          malicious_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`, consisting of malicious datasets.
          malicious_clients: A federated `tf.bool` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of the updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.
        """
        model_at_clients = tff.federated_broadcast(server_state.model)

        per_client = tff.federated_map(
            client_update_fn,
            (federated_dataset, malicious_dataset, malicious_clients,
             model_at_clients))

        weights = per_client.weights_delta_weight

        # A `next` function with three parameters is treated as weighted and
        # any other as unweighted; there is no better way to tell them apart.
        is_weighted = len(aggregation_process.next.type_signature.parameter) == 3
        if is_weighted:
            aggregated = aggregation_process.next(
                server_state.delta_aggregate_state,
                per_client.weights_delta,
                weight=weights)
        else:
            aggregated = aggregation_process.next(
                server_state.delta_aggregate_state, per_client.weights_delta)
        aggregator_state = aggregated.state
        averaged_delta = aggregated.result

        updated_state = tff.federated_map(
            server_update_fn, (server_state, averaged_delta, aggregator_state))

        metrics = dummy_model_for_metadata.federated_output_computation(
            per_client.model_output)
        if isinstance(metrics.type_signature, tff.StructType):
            metrics = tff.federated_zip(metrics)

        return updated_state, metrics
コード例 #26
0
ファイル: __init__.py プロジェクト: Lando-L/ocd-detection
    def next_tff(server_state, datasets, client_states):
        """Runs one training round with per-client state.

        Returns:
          A tuple of the next server state, the metrics aggregated by
          `model.federated_output_computation`, and the clients' new
          `client_state`.
        """
        # Derive the message from the server state and send it to the clients.
        message_at_clients = tff.federated_broadcast(
            tff.federated_map(state_to_message_tf, server_state))

        per_client = tff.federated_map(
            update_client_tf, (datasets, client_states, message_at_clients))
        averaged_delta = tff.federated_mean(per_client.weights_delta,
                                            weight=per_client.client_weight)

        aggregated_metrics = model.federated_output_computation(
            per_client.metrics)

        updated_state = tff.federated_map(update_server_tf,
                                          (server_state, averaged_delta))

        return updated_state, aggregated_metrics, per_client.client_state
コード例 #27
0
        def run_gradient_computation_round(server_state, federated_dataset):
            """Runs one round of gradient computation on the clients.

            Args:
              server_state: A `ServerState`.
              federated_dataset: A federated `tf.data.Dataset` with placement
                `tff.CLIENTS`.

            Returns:
              A tuple of the clients' initial probabilities, as produced by
              `scale_on_clients`, and the `ClientOutput`.
            """
            message = tff.federated_map(server_message_fn, server_state)
            message_at_clients = tff.federated_broadcast(message)

            per_client = tff.federated_map(
                client_update_fn, (federated_dataset, message_at_clients))

            # Sum the weighted update norms and send the total back to the
            # clients, where it is combined with each client's own norm.
            norm_total = tff.federated_sum(per_client.update_norm_weighted)
            norm_total_at_clients = tff.federated_broadcast(norm_total)

            initial_probabilities = scale_on_clients(
                per_client.update_norm_weighted, norm_total_at_clients)
            return initial_probabilities, per_client
コード例 #28
0
        def next_fn(state, value):
            """Clips `value`, runs the inner aggregation, and clips the result.

            The clip range comes from `self._get_clip_range()` and is used
            both before and after `inner_agg_next`.

            Returns:
              A `tff.templates.MeasuredProcessOutput` with the inner
              aggregator's state, the clipped aggregate, and the inner
              measurements under the key `agg_process`.
            """
            lower, upper = self._get_clip_range()

            # Modular-clip every client value before it is aggregated.
            lower_at_clients = tff.federated_broadcast(lower)
            upper_at_clients = tff.federated_broadcast(upper)
            clipped_value = tff.federated_map(
                modular_clip_by_value_tff,
                (value, lower_at_clients, upper_at_clients))

            inner_state, inner_result, inner_measurements = inner_agg_next(
                state, clipped_value)

            # Clip the aggregate into the same range (summands not considered).
            clipped_result = tff.federated_map(
                modular_clip_by_value_tff, (inner_result, lower, upper))

            measurements = collections.OrderedDict(
                agg_process=inner_measurements)

            return tff.templates.MeasuredProcessOutput(
                state=inner_state,
                result=clipped_result,
                measurements=tff.federated_zip(measurements))
コード例 #29
0
def next_fn(server_weights, federated_dataset):
    """Runs one round and also returns the per-client weights and losses.

    Returns:
      A tuple of the updated server weights, the client weights, and the
      client losses.
    """
    # Broadcast the server model to the clients.
    weights_at_clients = tff.federated_broadcast(server_weights)

    # Each client runs its update; the result unpacks into weights and loss.
    per_client = tff.federated_map(
        client_update_fn, (federated_dataset, weights_at_clients))
    client_weights, clients_loss = per_client

    # The server averages the model parameters updated by all clients.
    averaged_weights = tff.federated_mean(client_weights)

    # The server updates its own model.
    updated_weights = tff.federated_map(server_update_fn, averaged_weights)

    return updated_weights, client_weights, clients_loss
コード例 #30
0
ファイル: tff_gans.py プロジェクト: uu0316/federated
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """The `tff.Computation` to be returned.

        Broadcasts the generator and discriminator weights, runs
        `client_computation` on the clients, averages the discriminator
        deltas (through `gan.dp_averaging_fn` when one is set), and applies
        `server_computation`.

        Returns:
          The updated server state.
        """
        # TODO(b/131429028): The federated_zip should be automatic.
        server_message = tff.federated_zip(
            gan_training_tf_fns.FromServer(
                generator_weights=server_state.generator_weights,
                discriminator_weights=server_state.discriminator_weights))
        message_at_clients = tff.federated_broadcast(server_message)
        per_client = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, message_at_clients))

        if gan.dp_averaging_fn is None:
            # No differential privacy: plain weighted mean, and the averaging
            # state is carried over unchanged.
            dp_state = server_state.dp_averaging_state
            mean_disc_delta = tff.federated_mean(
                per_client.discriminator_weights_delta,
                weight=per_client.update_weight)
        else:
            # Differential privacy: the weight is explicitly None because the
            # DP aggregation code does not do weighted aggregation. (Turn DP
            # off if weighted aggregation is desired.)
            dp_result = gan.dp_averaging_fn(
                server_state.dp_averaging_state,
                per_client.discriminator_weights_delta,
                weight=None)
            dp_state, mean_disc_delta = dp_result

        # TODO(b/131085687): Perhaps reconsider the choice to also use
        # ClientOutput to hold the aggregated client output.
        # The summed update_weight is not strictly needed, but including it
        # keeps the aggregated and non-aggregated client outputs the same
        # type.
        summed_update_weight = tff.federated_sum(per_client.update_weight)
        summed_counters = tff.federated_sum(per_client.counters)
        aggregated = gan_training_tf_fns.ClientOutput(
            discriminator_weights_delta=mean_disc_delta,
            update_weight=summed_update_weight,
            counters=summed_counters)

        # TODO(b/131839522): This federated_zip shouldn't be needed.
        aggregated = tff.federated_zip(aggregated)

        return tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregated, dp_state))