Beispiel #1
0
    def run_one_round(server_state, federated_dataset):
        """Runs one round of adaptive-learning-rate federated training.

        The aggregated client gradients are handed to `server_update_fn`
        together with the metrics named by `client_monitor` and
        `server_monitor`. Those metrics are taken from the model outputs
        measured *before* local training in this round, and drive the
        learning-rate callbacks.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple `(server_state, metrics)`. `server_state` is the updated
          `ServerState`. `metrics` is an `OrderedDict` holding the result of
          `tff.learning.Model.federated_output_computation` before
          (`before_training`) and during (`during_training`) local client
          training.
        """

        def _zip_if_struct(value):
            # federated_output_computation may yield a struct of federated
            # values; zip it into a single federated value when it does.
            if isinstance(value.type_signature, tff.StructType):
                return tff.federated_zip(value)
            return value

        broadcast_model = tff.federated_broadcast(server_state.model)
        broadcast_lr = tff.federated_broadcast(
            server_state.client_lr_callback.learning_rate)

        datasets = federated_dataset
        if dataset_preprocess_comp is not None:
            datasets = tff.federated_map(dataset_preprocess_comp, datasets)

        outputs = tff.federated_map(
            client_update_fn, (datasets, broadcast_model, broadcast_lr))

        weight = outputs.client_weight
        mean_gradients = tff.federated_mean(
            outputs.accumulated_gradients, weight=weight)

        metrics_before = _zip_if_struct(
            dummy_model.federated_output_computation(
                outputs.initial_model_output))
        metrics_during = _zip_if_struct(
            dummy_model.federated_output_computation(outputs.model_output))

        # The learning-rate monitors read the pre-training metrics.
        updated_state = tff.federated_map(
            server_update_fn,
            (server_state, mean_gradients,
             metrics_before[client_monitor], metrics_before[server_monitor]))

        round_metrics = collections.OrderedDict(
            before_training=metrics_before,
            during_training=metrics_during)
        return updated_state, round_metrics
Beispiel #2
0
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """Runs one round of federated GAN training.

        Steps:
        - Broadcast the server's generator and discriminator weights.
        - Run `client_computation` on every client.
        - Aggregate the discriminator deltas, using a weighted mean or
          `gan.dp_averaging_fn` when one is configured.
        - Apply the aggregate with `server_computation`.

        Args:
          server_state: Current server state (generator, discriminator and
            DP-averaging state are read from it).
          server_gen_inputs: Generator inputs passed to `server_computation`.
          client_gen_inputs: Per-client generator inputs.
          client_real_data: Per-client real data.

        Returns:
          The updated server state.
        """
        # TODO(b/131429028): The federated_zip should be automatic.
        server_weights = tff.federated_zip(
            gan_training_tf_fns.FromServer(
                generator_weights=server_state.generator_weights,
                discriminator_weights=server_state.discriminator_weights))
        outputs = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data,
             tff.federated_broadcast(server_weights)))

        deltas = outputs.discriminator_weights_delta
        if gan.dp_averaging_fn is not None:
            # DP aggregation is deliberately unweighted (weight=None). If
            # weighted aggregation is desired, DP must be turned off.
            dp_state, mean_delta = gan.dp_averaging_fn(
                server_state.dp_averaging_state, deltas, weight=None)
        else:
            # No differential privacy: plain weighted mean.
            dp_state = server_state.dp_averaging_state
            mean_delta = tff.federated_mean(
                deltas, weight=outputs.update_weight)

        # TODO(b/131085687): Perhaps reconsider the choice to also use
        # ClientOutput to hold the aggregated client output.
        # The summed update_weight is not needed downstream. It keeps the
        # per-client and aggregated ClientOutput types identical.
        summed_weight = tff.federated_sum(outputs.update_weight)
        summed_counters = tff.federated_sum(outputs.counters)

        # TODO(b/131839522): This federated_zip shouldn't be needed.
        aggregate = tff.federated_zip(
            gan_training_tf_fns.ClientOutput(
                discriminator_weights_delta=mean_delta,
                update_weight=summed_weight,
                counters=summed_counters))

        return tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregate, dp_state))
Beispiel #3
0
 def server_init():
     """Returns the initial `ServerState`, zipped into one federated value.

     `server_init_tf` is evaluated at `tff.SERVER`. It yields the model and
     the server optimizer state. The aggregator state comes from
     `aggregation_process_init()`.
     """
     init_values = tff.federated_eval(server_init_tf, tff.SERVER)
     model, optimizer_state = init_values
     aggregator_state = aggregation_process_init()
     state = ServerState(
         model=model,
         optimizer_state=optimizer_state,
         delta_aggregate_state=aggregator_state)
     return tff.federated_zip(state)
Beispiel #4
0
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

        Broadcasts the server model, runs `client_update_fn` on every client,
        applies the `weights_delta_weight`-weighted mean of the client deltas
        with `server_update_fn`, and aggregates the client model outputs.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.
        """
        client_model = tff.federated_broadcast(server_state.model)

        client_outputs = tff.federated_map(client_update_fn,
                                           (federated_dataset, client_model))

        weight_denom = client_outputs.weights_delta_weight
        round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                               weight=weight_denom)

        server_state = tff.federated_apply(server_update_fn,
                                           (server_state, round_model_delta))

        aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)
        # `federated_zip` only accepts a struct of federated values. The
        # output computation may already return a single federated value, in
        # which case zipping must be skipped.
        if isinstance(aggregated_outputs.type_signature, tff.StructType):
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
Beispiel #5
0
    def run_one_round(server_state, federated_dataset):
        """Runs one federated round and also aggregates the gradient norm.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple `(server_state, metrics)`: the updated `ServerState` and the
          result of `tff.learning.Model.federated_output_computation`.
        """
        outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, tff.federated_broadcast(server_state.model)))
        weight = outputs.weights_delta_weight

        # Both aggregates share the same per-client weight.
        mean_delta = tff.federated_mean(outputs.weights_delta, weight=weight)
        mean_grads_norm = tff.federated_mean(
            outputs.optimizer_output.flat_grads_norm_sum, weight=weight)

        new_state = tff.federated_map(
            server_update_fn, (server_state, mean_delta, mean_grads_norm))

        metrics = dummy_model_for_metadata.federated_output_computation(
            outputs.model_output)
        if isinstance(metrics.type_signature, tff.StructType):
            metrics = tff.federated_zip(metrics)
        return new_state, metrics
Beispiel #6
0
  def next_fn(state, deltas, weights):
    """Clips each client update to `clip_norm`, then takes a weighted mean.

    Args:
      state: Aggregator state; returned unchanged.
      deltas: Federated client updates of type `model_update_type`.
      weights: Federated client weights used for the mean.

    Returns:
      An `OrderedDict` with `state`, the weighted mean of the clipped updates
      as `result`, and `measurements` holding the largest client global norm
      and the number of clients whose update was clipped.
    """

    @tff.tf_computation(model_update_type)
    def clip_by_global_norm(update):
      threshold = tf.constant(clip_norm)
      flat_clipped, norm = tf.clip_by_global_norm(
          tf.nest.flatten(update), threshold)
      clipped = tf.nest.pack_sequence_as(update, flat_clipped)
      # int32 flag: 1 when this update exceeded the threshold, 0 otherwise.
      clipped_flag = tf.cast(tf.greater(norm, threshold), tf.int32)
      return clipped, norm, clipped_flag

    clipped_deltas, client_norms, client_was_clipped = tff.federated_map(
        clip_by_global_norm, deltas)
    mean_delta = tff.federated_mean(clipped_deltas, weight=weights)
    metrics = NormClippedAggregationMetrics(
        max_global_norm=tff.utils.federated_max(client_norms),
        num_clipped=tff.federated_sum(client_was_clipped),
    )
    return collections.OrderedDict(
        state=state,
        result=mean_delta,
        measurements=tff.federated_zip(metrics))
Beispiel #7
0
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of TrieHH computation.

    Broadcasts the discovered prefixes and the round number, runs
    `client_update_fn` on every client, sums the client votes and updates
    the server state with them.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple `(server_state, server_output)`. `server_state` is the updated
      `ServerState`. `server_output` is an empty value placed at
      `tff.SERVER`, since this round reports no metrics.
    """
    discovered_prefixes = tff.federated_broadcast(
        server_state.discovered_prefixes)
    round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        tff.federated_zip([federated_dataset, discovered_prefixes, round_num]))

    accumulated_votes = tff.federated_sum(client_outputs.client_votes)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, accumulated_votes))

    server_output = tff.federated_value([], tff.SERVER)

    return server_state, server_output
Beispiel #8
0
  def run_one_round(server_state, federated_dataset):
    """Runs one round of federated averaging with a broadcast round number.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple `(server_state, metrics)`: the updated `ServerState` and the
      result of `tff.learning.Model.federated_output_computation`.
    """
    broadcast_model = tff.federated_broadcast(server_state.model)
    broadcast_round = tff.federated_broadcast(server_state.round_num)
    outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, broadcast_model, broadcast_round))

    mean_delta = tff.federated_mean(
        outputs.weights_delta, weight=outputs.client_weight)
    new_state = tff.federated_map(server_update_fn, (server_state, mean_delta))

    metrics = dummy_model.federated_output_computation(outputs.model_output)
    if metrics.type_signature.is_struct():
      metrics = tff.federated_zip(metrics)
    return new_state, metrics
Beispiel #9
0
 def comp(temperatures, threshold):
   """Weighted mean of the per-client `count_over` results.

   Each client applies `count_over` to its temperatures and the broadcast
   threshold. Those values are averaged using each client's `count_total`
   as its weight.
   """
   client_threshold = tff.federated_broadcast(threshold)
   over_counts = tff.federated_map(
       count_over, tff.federated_zip([temperatures, client_threshold]))
   total_counts = tff.federated_map(count_total, temperatures)
   return tff.federated_mean(over_counts, weight=total_counts)
Beispiel #10
0
 def fed_server_initial_state():
   """Returns the initial GAN `ServerState`, zipped into one federated value.

   Weights and counters come from evaluating the initial-state computation
   at `tff.SERVER`. The aggregation state comes from
   `gan.aggregation_process.initialize()`.
   """
   initial = tff.federated_eval(build_server_initial_state_comp(gan), tff.SERVER)
   return tff.federated_zip(
       gan_training_tf_fns.ServerState(
           initial.generator_weights,
           initial.discriminator_weights,
           initial.counters,
           aggregation_state=gan.aggregation_process.initialize()))
 def server_init_tff():
     """Returns a `reconstruction_utils.ServerState` placed at `tff.SERVER`.

     The first three outputs of `server_init_tf` (evaluated at the server)
     become the model, optimizer state and round number. The aggregator
     state comes from `aggregation_process.initialize()`.
     """
     init_outputs = tff.federated_eval(server_init_tf, tff.SERVER)
     aggregator_state = aggregation_process.initialize()
     model, optimizer_state, round_num = (
         init_outputs[0], init_outputs[1], init_outputs[2])
     state = reconstruction_utils.ServerState(
         model=model,
         optimizer_state=optimizer_state,
         round_num=round_num,
         aggregator_state=aggregator_state)
     return tff.federated_zip(state)
    def run_one_round(server_state, federated_dataset):
        """Runs one round of training and measures client drift.

        Client drift is the `client_weight`-weighted mean, over clients, of
        the sum of `tf.nn.l2_loss` across the trainable tensors of
        (client weights - weighted mean of client weights). That is half the
        squared L2 distance of each client from the mean.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple `(server_state, metrics)`: the updated `ServerState` and the
          result of `tff.learning.Model.federated_output_computation`.
        """
        broadcast_model = tff.federated_broadcast(server_state.model)
        broadcast_round = tff.federated_broadcast(server_state.round_num)
        broadcast_opt_state = tff.federated_broadcast(
            server_state.optimizer_state)
        outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, broadcast_model, broadcast_round,
             broadcast_opt_state))

        weight = outputs.client_weight
        mean_delta = tff.federated_mean(outputs.weights_delta, weight=weight)

        # Weighted mean of the clients' weights, sent back to every client so
        # each can compute its own drift from it.
        mean_weights_at_clients = tff.federated_broadcast(
            tff.federated_mean(outputs.weights, weight=weight))

        trainable_type = model_weights_type.trainable

        @tff.tf_computation(trainable_type, trainable_type)
        def individual_drift_fn(client_model, implicit_mean):
            return tf.nest.map_structure(tf.subtract, client_model,
                                         implicit_mean)

        per_client_drift = tff.federated_map(
            individual_drift_fn, (outputs.weights, mean_weights_at_clients))

        @tff.tf_computation(trainable_type)
        def individual_drift_norm_fn(delta_from_mean):
            # tf.nn.l2_loss(t) is sum(t ** 2) / 2 for each tensor.
            per_tensor = tf.nest.map_structure(tf.nn.l2_loss, delta_from_mean)
            return tf.reduce_sum(tf.nest.flatten(per_tensor))

        drift_norms = tff.federated_map(individual_drift_norm_fn,
                                        per_client_drift)
        mean_drift = tff.federated_mean(drift_norms, weight=weight)

        new_state = tff.federated_map(
            server_update_fn, (server_state, mean_delta, mean_drift))

        metrics = dummy_model.federated_output_computation(outputs.model_output)
        if metrics.type_signature.is_struct():
            metrics = tff.federated_zip(metrics)
        return new_state, metrics
Beispiel #13
0
 def fed_server_initial_state():
   """Returns the initial GAN `ServerState`, zipped into one federated value.

   Without DP (`gan.dp_averaging_fn is None`) the DP-averaging state is
   taken from the evaluated initial state. Otherwise it comes from
   `gan.dp_averaging_fn.initialize()`.
   """
   initial = tff.federated_eval(build_server_initial_state_comp(gan), tff.SERVER)
   if gan.dp_averaging_fn is None:
     dp_state = initial.dp_averaging_state
   else:
     dp_state = gan.dp_averaging_fn.initialize()
   return tff.federated_zip(
       gan_training_tf_fns.ServerState(
           initial.generator_weights,
           initial.discriminator_weights,
           initial.counters,
           dp_averaging_state=dp_state))
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`. The second
          element is extended with the `client_weight`-weighted mean of the
          clients' `additional_output`.
        """
        # Run computation on the clients.
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_model, client_round_num))

        # Aggregate model deltas.
        client_weight = client_outputs.client_weight
        if mask_zeros_in_client_updates:
            model_delta = federated_mean_masked(client_outputs.weights_delta,
                                                client_weight)
        else:
            model_delta = tff.federated_mean(client_outputs.weights_delta,
                                             client_weight)

        server_state = tff.federated_map(server_update_fn,
                                         (server_state, model_delta))

        # Aggregate model outputs that contain local metrics and various statistics.
        aggregated_outputs = placeholder_model.federated_output_computation(
            client_outputs.model_output)
        # Zip *before* reading `.type_signature.member` below. A struct of
        # federated values has no `member`. After the `federated_map` further
        # down the value is always a single federated value, so a zip check
        # placed there could never fire.
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)
        additional_outputs = tff.federated_mean(
            client_outputs.additional_output, weight=client_weight)

        @tff.tf_computation(aggregated_outputs.type_signature.member,
                            additional_outputs.type_signature.member)
        def _update_aggregated_outputs(aggregated_outputs, additional_outputs):
            # Merge the additional statistics into the metrics mapping; keys
            # present in both take the `additional_outputs` value.
            aggregated_outputs.update(additional_outputs)
            return aggregated_outputs

        aggregated_outputs = tff.federated_map(
            _update_aggregated_outputs,
            (aggregated_outputs, additional_outputs))

        return server_state, aggregated_outputs
    def run_one_round(server_state, federated_dataset, malicious_dataset,
                      malicious_clients):
        """Runs one round of training in which some clients may be malicious.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.
          malicious_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`, consisting of malicious datasets.
          malicious_clients: A federated `tf.bool` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.
        """
        broadcast_model = tff.federated_broadcast(server_state.model)
        outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, malicious_dataset, malicious_clients,
             broadcast_model))
        weight_denom = outputs.weights_delta_weight

        # A three-parameter `next` is the weighted variant; anything else is
        # unweighted. There is no better way to tell the two apart here.
        if len(aggregation_process.next.type_signature.parameter) == 3:
            weight_kwargs = dict(weight=weight_denom)
        else:
            weight_kwargs = {}
        aggregate = aggregation_process.next(
            server_state.delta_aggregate_state, outputs.weights_delta,
            **weight_kwargs)

        new_state = tff.federated_map(
            server_update_fn,
            (server_state, aggregate.result, aggregate.state))

        metrics = dummy_model_for_metadata.federated_output_computation(
            outputs.model_output)
        if isinstance(metrics.type_signature, tff.StructType):
            metrics = tff.federated_zip(metrics)
        return new_state, metrics
Beispiel #16
0
  def run_one_round(server_state, federated_dataset):
    """Runs one round that also aggregates the client optimizer state.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset,
         tff.federated_broadcast(server_state.model),
         tff.federated_broadcast(server_state.client_optimizer_state),
         tff.federated_broadcast(server_state.round_num)))

    weights = outputs.client_weight
    mean_delta = tff.federated_mean(
        outputs.weights_delta, weight=weights.model_weight)

    # The optimizer state is converted to floats so that it can be used with
    # things such as `tff.federated_mean`. This is only necessary because
    # `tf.keras.Optimizer` state holds an integer counting how many times the
    # optimizer has been applied. The result is converted back to integers.
    opt_delta = tff.federated_map(_convert_opt_state_to_float,
                                  outputs.optimizer_state_delta)
    opt_delta = optimizer_aggregator(opt_delta, weight=weights.optimizer_weight)
    opt_delta = tff.federated_map(_convert_opt_state_to_int, opt_delta)

    new_state = tff.federated_map(
        server_update_fn, (server_state, mean_delta, opt_delta))

    metrics = placeholder_model.federated_output_computation(
        outputs.model_output)
    if metrics.type_signature.is_struct():
      metrics = tff.federated_zip(metrics)
    return new_state, metrics
Beispiel #17
0
    def next_fn(state, deltas, weights=None):
        """Clips each client delta to `state.clip_norm` and averages them.

        Args:
          state: Aggregator state; its `clip_norm` is broadcast to clients.
          deltas: Federated client deltas.
          weights: Optional federated weights for the mean.

        Returns:
          A tuple `(next_state, mean_delta)`. `next_state` keeps `clip_norm`
          and records the largest client global norm as `max_norm`.
        """
        @tff.tf_computation(deltas.type_signature.member, tf.float32)
        def clip_by_global_norm(delta, clip_norm):
            flat_clipped, norm = tf.clip_by_global_norm(
                tf.nest.flatten(delta), clip_norm)
            return tf.nest.pack_sequence_as(delta, flat_clipped), norm

        clipped_deltas, client_norms = tff.federated_map(
            clip_by_global_norm,
            (deltas, tff.federated_broadcast(state.clip_norm)))
        # clip_norm gets a no-op update here, though it could be set from
        # max_norm.
        max_norm = tff.utils.federated_max(client_norms)
        new_state = tff.federated_zip(
            ClipNormAggregateState(clip_norm=state.clip_norm,
                                   max_norm=max_norm))
        return new_state, tff.federated_mean(clipped_deltas, weight=weights)
Beispiel #18
0
    def run_one_round(server_state, federated_dataset, malicious_dataset,
                      malicious_clients):
        """Runs one round of training in which some clients may be malicious.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.
          malicious_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`, consisting of malicious datasets.
          malicious_clients: A federated `tf.bool` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.
        """
        broadcast_model = tff.federated_broadcast(server_state.model)
        outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, malicious_dataset, malicious_clients,
             broadcast_model))
        weight_denom = outputs.weights_delta_weight

        # Only weighted aggregation processes accept a `weight` argument.
        weight_kwargs = ({'weight': weight_denom}
                         if aggregation_process.is_weighted else {})
        aggregate = aggregation_process.next(
            server_state.delta_aggregate_state, outputs.weights_delta,
            **weight_kwargs)

        new_state = tff.federated_map(
            server_update_fn,
            (server_state, aggregate.result, aggregate.state))

        metrics = dummy_model_for_metadata.federated_output_computation(
            outputs.model_output)
        if isinstance(metrics.type_signature, tff.StructType):
            metrics = tff.federated_zip(metrics)
        return new_state, metrics
Beispiel #19
0
    def next_fn(state, deltas, weights=None):
        """Clips each client delta to `state.clip_norm` and averages them.

        Args:
          state: Aggregator state; its `clip_norm` is broadcast to clients.
          deltas: Federated client deltas.
          weights: Optional federated weights for the mean.

        Returns:
          A tuple `(next_state, mean_delta)`. `next_state` keeps `clip_norm`
          and records the largest client global norm as `max_norm`.
        """
        @tff.tf_computation(deltas.type_signature.member, tf.float32)
        def clip_by_global_norm(delta, clip_norm):
            # TODO(b/123092620): Replace anonymous_tuple with tf.nest.
            structured = anonymous_tuple.from_container(delta)
            flat_clipped, norm = tf.clip_by_global_norm(
                anonymous_tuple.flatten(structured), clip_norm)
            clipped = anonymous_tuple.pack_sequence_as(structured, flat_clipped)
            return clipped, norm

        clipped_deltas, client_norms = tff.federated_map(
            clip_by_global_norm,
            (deltas, tff.federated_broadcast(state.clip_norm)))
        # clip_norm gets a no-op update here, though it could be set from
        # max_norm.
        max_norm = tff.utils.federated_max(client_norms)
        new_state = tff.federated_zip(
            ClipNormAggregateState(clip_norm=state.clip_norm,
                                   max_norm=max_norm))
        return new_state, tff.federated_mean(clipped_deltas, weight=weights)
    def run_one_round(server_state, federated_dataset):
        """Orchestration logic for one round of computation.

        The two `correction` modes:
        - `'joint'`: the clients' `local_cor_states` are averaged (weighted
          by `client_weight`) and passed, with the model delta, to
          `server_update_joint_cor_fn`.
        - `'local'`: only the model delta is passed to `server_update_fn`.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.

        Returns:
          A tuple of updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.

        Raises:
          TypeError: If `correction` is neither `'local'` nor `'joint'`.
        """
        # Validate before building any federated computation.
        # NOTE(review): ValueError would describe this better, but TypeError
        # is what callers currently receive.
        if correction not in ('joint', 'local'):
            raise TypeError('Correction method must be local or joint.')

        client_model = tff.federated_broadcast(server_state.model)

        client_outputs = tff.federated_map(client_update_fn,
                                           (federated_dataset, client_model))

        client_weight = client_outputs.client_weight
        model_delta = tff.federated_mean(client_outputs.weights_delta,
                                         weight=client_weight)
        if correction == 'joint':
            # Only the joint update consumes the averaged correction states,
            # so that aggregation is skipped in 'local' mode.
            global_cor_states = tff.federated_mean(
                client_outputs.local_cor_states, weight=client_weight)
            server_state = tff.federated_map(
                server_update_joint_cor_fn,
                (server_state, model_delta, global_cor_states))
        else:
            server_state = tff.federated_map(server_update_fn,
                                             (server_state, model_delta))

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs
        def next_fn(state, value):
            """Modular-clips client values, aggregates them, clips the result.

            Args:
              state: State of the inner aggregation process.
              value: Federated client values to aggregate.

            Returns:
              A `tff.templates.MeasuredProcessOutput` with these fields:
              - `state`: the inner aggregation state.
              - `result`: the inner aggregate, clipped again to the same range.
              - `measurements`: the inner measurements under `agg_process`.
            """
            lower, upper = self._get_clip_range()

            # Modular clip values before aggregation.
            clipped_value = tff.federated_map(
                modular_clip_by_value_tff,
                (value, tff.federated_broadcast(lower),
                 tff.federated_broadcast(upper)))

            inner_output = inner_agg_next(state, clipped_value)
            inner_state, inner_result, inner_measurements = inner_output

            # Clip the aggregate to the same range again (not considering
            # summands).
            clipped_result = tff.federated_map(
                modular_clip_by_value_tff, (inner_result, lower, upper))

            return tff.templates.MeasuredProcessOutput(
                state=inner_state,
                result=clipped_result,
                measurements=tff.federated_zip(
                    collections.OrderedDict(agg_process=inner_measurements)))
Beispiel #22
0
 def foo(x):
   """Zips the struct of federated values `x` into a single federated value."""
   zipped = tff.federated_zip(x)
   return zipped
def mean_over_threshold(temperatures, threshold):
    """Weighted mean of the per-client `count_over` results.

    Each client evaluates `count_over` on its temperatures and the broadcast
    threshold. Those values are averaged using each client's `count_total`
    as the weight.
    """
    return tff.federated_mean(
        tff.federated_map(
            count_over,
            tff.federated_zip(
                [temperatures, tff.federated_broadcast(threshold)])),
        weight=tff.federated_map(count_total, temperatures))
    def run_one_round(server_state, federated_dataset, ids):
        """Orchestration logic for one round of loss-based client selection.

        Steps:
        - Scatter every client's weight and summed loss into dense
          `[total_clients, 1]` tensors at the server, at position
          `client_id`.
        - Pass both tensors and `server_state.effective_num_clients` to
          `zero_small_loss_clients`.
        - Broadcast the resulting selection and map it to per-client weights
          with `select_weight_fn`.
        - Aggregate the client deltas with those weights.

        Args:
          server_state: A `ServerState`.
          federated_dataset: A federated `tf.Dataset` with placement
            `tff.CLIENTS`.
          ids: Federated per-client identifiers, passed to `client_update_fn`
            and `select_weight_fn`.

        Returns:
          A tuple of updated `ServerState` and the result of
          `tff.learning.Model.federated_output_computation`.
        """
        client_model = tff.federated_broadcast(server_state.model)
        client_round_num = tff.federated_broadcast(server_state.round_num)

        client_outputs = tff.federated_map(
            client_update_fn,
            (federated_dataset, client_model, client_round_num, ids))

        @computations.tf_computation
        def zeros_fn():
            return tf.zeros(shape=[total_clients, 1], dtype=tf.float32)

        # All-zero initial accumulator for the per-client reductions below.
        zero = zeros_fn()

        at_server_type = tff.TensorType(shape=[total_clients, 1],
                                        dtype=tf.float32)
        client_output_type = client_update_fn.type_signature.result

        @computations.tf_computation(at_server_type, client_output_type)
        def accumulate_weight(u, t):
            # Write this client's weight into `u` at `t.client_id`.
            # NOTE(review): unlike `accumulate_loss`, the value is not
            # reshaped to [1, 1]; assumes `client_weight` already matches the
            # update shape `tensor_scatter_nd_update` expects. Confirm.
            value = t.client_weight
            index = t.client_id
            new_u = tf.tensor_scatter_nd_update(u, index, value)
            return new_u

        @computations.tf_computation(at_server_type, client_output_type)
        def accumulate_loss(u, t):
            # Write this client's summed loss, as a [1, 1] update, into `u`
            # at `t.client_id`.
            value = tf.reshape(tf.math.reduce_sum(t.model_output['loss']),
                               shape=[1, 1])
            index = t.client_id
            new_u = tf.tensor_scatter_nd_update(u, index, value)
            return new_u

        weights_at_server = tff.federated_reduce(client_outputs, zero,
                                                 accumulate_weight)
        losses_at_server = tff.federated_reduce(client_outputs, zero,
                                                accumulate_loss)

        selected_clients_weights = tff.federated_map(
            zero_small_loss_clients, (losses_at_server, weights_at_server,
                                      server_state.effective_num_clients))

        selected_clients_weights_broadcast = tff.federated_broadcast(
            selected_clients_weights)

        # Presumably each client picks its own entry out of the broadcast
        # selection via its id; see `select_weight_fn`.
        selected_clients_weights_at_client = tff.federated_map(
            select_weight_fn, (selected_clients_weights_broadcast, ids))

        aggregation_output = aggregation_process.next(
            server_state.delta_aggregate_state, client_outputs.weights_delta,
            selected_clients_weights_at_client)

        # NOTE(review): `aggregation_output.state` is discarded. Because
        # `server_update_fn` only receives the aggregated result,
        # `delta_aggregate_state` cannot advance across rounds. Fixing this
        # requires a matching change to `server_update_fn`.
        server_state = tff.federated_map(
            server_update_fn, (server_state, aggregation_output.result))

        aggregated_outputs = dummy_model.federated_output_computation(
            client_outputs.model_output)
        if aggregated_outputs.type_signature.is_struct():
            aggregated_outputs = tff.federated_zip(aggregated_outputs)

        return server_state, aggregated_outputs