Esempio n. 1
0
    def next_fn(state, value):
        """Runs one round of encoded aggregation.

        Derives the encoding and decoding parameters from the server `state`.
        Each client `value` is encoded and the encoded values are aggregated
        with `federated_aggregate`. The aggregate is decoded on the server,
        and the state is updated from the aggregated state-update tensors.
        Measurements are an empty server-placed tuple.
        """
        all_params = intrinsics.federated_map(get_params_fn, state)
        encode_params, decode_before_sum_params, decode_after_sum_params = (
            all_params)

        # Clients need the encoding params and the pre-sum decoding params.
        # The post-sum decoding params are used only on the server (below).
        client_encode_params = intrinsics.federated_broadcast(encode_params)
        client_decode_before_sum_params = intrinsics.federated_broadcast(
            decode_before_sum_params)

        client_encoded = intrinsics.federated_map(
            encode_fn,
            [value, client_encode_params, client_decode_before_sum_params])

        aggregate = intrinsics.federated_aggregate(
            client_encoded, zero_fn(), accumulate_fn, merge_fn, report_fn)

        result = intrinsics.federated_map(
            decode_after_sum_fn, [aggregate.values, decode_after_sum_params])
        new_state = intrinsics.federated_map(
            update_state_fn, [state, aggregate.state_update_tensors])

        return measured_process.MeasuredProcessOutput(
            state=new_state,
            result=result,
            measurements=intrinsics.federated_value((), placements.SERVER))
Esempio n. 2
0
        def next_fn(state, value):
            """Discretizes client values, aggregates them, and undiscretizes.

            The scale factor and prior norm bound are read from the server
            `state` and broadcast to clients for `discretize_fn`. The
            discretized values go through `inner_agg_process`. The aggregate
            is then passed with the server scale factor to `undiscretize_fn`.
            The scale factor and prior norm bound carry over unchanged into
            the new state.
            """
            scale_factor = state['scale_factor']
            clients_scale_factor = intrinsics.federated_broadcast(scale_factor)
            norm_bound = state['prior_norm_bound']
            clients_norm_bound = intrinsics.federated_broadcast(norm_bound)

            discretized = intrinsics.federated_map(
                discretize_fn,
                (value, clients_scale_factor, clients_norm_bound))

            inner_output = inner_agg_process.next(state['inner_agg_process'],
                                                  discretized)

            result = intrinsics.federated_map(
                undiscretize_fn, (inner_output.result, scale_factor))

            return measured_process.MeasuredProcessOutput(
                state=intrinsics.federated_zip(
                    collections.OrderedDict(
                        scale_factor=scale_factor,
                        prior_norm_bound=norm_bound,
                        inner_agg_process=inner_output.state)),
                result=result,
                measurements=intrinsics.federated_zip(
                    collections.OrderedDict(
                        discretize=inner_output.measurements)))
Esempio n. 3
0
        def next_fn(state, value):
            """Modular-clips client values, aggregates them, and clips the sum.

            Both clip bounds are placed on the server. They are broadcast for
            the client-side modular clip and reused on the server to clip the
            inner aggregate. When `self._estimate_stddev` is set, the
            measurements also include the output of `estimator_fn`.
            """
            lower = intrinsics.federated_value(self._clip_range_lower,
                                               placements.SERVER)
            upper = intrinsics.federated_value(self._clip_range_upper,
                                               placements.SERVER)

            # Modular clip values before aggregation.
            client_lower = intrinsics.federated_broadcast(lower)
            client_upper = intrinsics.federated_broadcast(upper)
            clipped = intrinsics.federated_map(
                modular_clip_by_value_fn, (value, client_lower, client_upper))

            inner_output = inner_agg_next(state, clipped)

            # Clip the aggregate to the same range again (summands are not
            # taken into account).
            result = intrinsics.federated_map(
                modular_clip_by_value_fn, (inner_output.result, lower, upper))

            measurements = collections.OrderedDict(
                modclip=inner_output.measurements)
            if self._estimate_stddev:
                measurements['estimated_stddev'] = intrinsics.federated_map(
                    estimator_fn, (result, lower, upper))

            return measured_process.MeasuredProcessOutput(
                state=inner_output.state,
                result=result,
                measurements=intrinsics.federated_zip(measurements))
Esempio n. 4
0
 def _sum_securely(self, value, upper_bound, lower_bound):
   """Securely sums `value` placed at CLIENTS.

   Args:
     value: The CLIENTS-placed value to sum.
     upper_bound: Server-placed upper bound (broadcast in the INT mode).
     lower_bound: Server-placed lower bound (broadcast in the INT mode).

   Returns:
     The server-placed secure sum.

   Raises:
     ValueError: If `self._config_mode` is neither `_Config.INT` nor
       `_Config.FLOAT`.
   """
   mode = self._config_mode
   if mode == _Config.INT:
     # Shift on clients with `_client_shift`, then secure-sum with the
     # configured bitwidth. `_server_shift` then undoes this on the server
     # using the lower bound and the number of summands.
     client_upper = intrinsics.federated_broadcast(upper_bound)
     client_lower = intrinsics.federated_broadcast(lower_bound)
     shifted = intrinsics.federated_map(_client_shift,
                                        (value, client_upper, client_lower))
     summed = intrinsics.federated_secure_sum(shifted, self._secagg_bitwidth)
     num_summands = intrinsics.federated_sum(_client_one())
     return intrinsics.federated_map(_server_shift,
                                     (summed, lower_bound, num_summands))
   if mode == _Config.FLOAT:
     return primitives.secure_quantized_sum(value, lower_bound, upper_bound)
   raise ValueError(f'Unexpected internal config type: {self._config_mode}')
Esempio n. 5
0
    def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
        """Clips `value`, aggregates it, and tracks clipping statistics.

        Args:
          state: A 3-tuple of (clipping-norm process state, inner aggregation
            state, clipped-count aggregation state).
          value: The client-placed value to clip and aggregate.
          clip_fn: Mapped over (value, broadcast clipping norm). It yields a
            (clipped_value, global_norm, was_clipped) triple.
          inner_agg_process: Process that aggregates the clipped values.
          weight: Optional weight forwarded to `inner_agg_process.next`.

        Returns:
          A `MeasuredProcessOutput` with the zipped new state, the inner
          aggregate as result, and the zipped measurements.
        """
        norm_state, inner_state, count_state = state

        clipping_norm = clipping_norm_process.report(norm_state)
        client_norm = intrinsics.federated_broadcast(clipping_norm)
        clipped_value, global_norm, was_clipped = intrinsics.federated_map(
            clip_fn, (value, client_norm))

        next_norm_state = clipping_norm_process.next(norm_state, global_norm)

        # Only pass `weight` through when the caller supplied one.
        inner_args = (inner_state, clipped_value)
        if weight is not None:
            inner_args += (weight,)
        agg_output = inner_agg_process.next(*inner_args)

        count_output = clipped_count_agg_process.next(count_state, was_clipped)

        new_state = collections.OrderedDict([
            (prefix('ing_norm'), next_norm_state),
            ('inner_agg', agg_output.state),
            (prefix('ed_count_agg'), count_output.state),
        ])
        measurements = collections.OrderedDict([
            (prefix('ing'), agg_output.measurements),
            (prefix('ing_norm'), clipping_norm),
            (prefix('ed_count'), count_output.result),
        ])
        return measured_process.MeasuredProcessOutput(
            state=intrinsics.federated_zip(new_state),
            result=agg_output.result,
            measurements=intrinsics.federated_zip(measurements))
Esempio n. 6
0
 def add_server_number_plus_one(server_number, client_numbers):
     """Adds `server_number + 1`, computed on the server, to each client number."""
     one_at_server = intrinsics.federated_value(1, placements.SERVER)
     incremented = intrinsics.federated_map(add, (one_at_server, server_number))
     return intrinsics.federated_map(
         add, (intrinsics.federated_broadcast(incremented), client_numbers))
Esempio n. 7
0
 def _train_one_round(model, federated_data):
     """Trains the broadcast `model` on every client and averages the results."""
     client_args = collections.OrderedDict(
         model=intrinsics.federated_broadcast(model),
         batches=federated_data)
     client_models = intrinsics.federated_map(_train_on_one_client, client_args)
     return intrinsics.federated_mean(client_models)
Esempio n. 8
0
 def stateful_broadcast(state, value):
     """Broadcasts `value`, passing `state` through unchanged.

     Measurements are a constant 3.0 placed at SERVER.
     """
     measurements = intrinsics.federated_value(3.0, placements.SERVER)
     broadcast_value = intrinsics.federated_broadcast(value)
     return measured_process_lib.MeasuredProcessOutput(
         state=state, result=broadcast_value, measurements=measurements)
Esempio n. 9
0
    def next_fn(state, value):
      """Discretizes client values with a step size, aggregates, undiscretizes.

      The server step size from `state` is broadcast for `discretize_fn`. The
      discretized values are aggregated by `inner_agg_process`, and the
      aggregate is passed with the server step size to `undiscretize_fn`.
      When a distortion aggregation factory is configured, an aggregated
      distortion measurement is also reported.
      """
      step_size = state['step_size']
      clients_step_size = intrinsics.federated_broadcast(step_size)

      discretized = intrinsics.federated_map(discretize_fn,
                                             (value, clients_step_size))
      inner_output = inner_agg_process.next(state['inner_agg_process'],
                                            discretized)
      result = intrinsics.federated_map(undiscretize_fn,
                                        (inner_output.result, step_size))

      measurements = collections.OrderedDict(
          deterministic_discretization=inner_output.measurements)
      if self._distortion_aggregation_factory is not None:
        distortions = intrinsics.federated_map(distortion_measurement_fn,
                                               (value, clients_step_size))
        # The distortion process starts from a fresh `initialize()` each
        # round; its state is not stored in `state`.
        measurements['distortion'] = distortion_aggregation_process.next(
            distortion_aggregation_process.initialize(), distortions).result

      new_state = collections.OrderedDict(
          step_size=step_size, inner_agg_process=inner_output.state)
      return measured_process.MeasuredProcessOutput(
          state=intrinsics.federated_zip(new_state),
          result=result,
          measurements=intrinsics.federated_zip(measurements))
Esempio n. 10
0
 def comp(temperatures, threshold):
     """Weighted mean of per-client `count_over` values.

     `count_over` receives each client's temperatures zipped with the
     broadcast threshold. The weights are `count_total(temperatures)`.
     """
     client_threshold = intrinsics.federated_broadcast(threshold)
     over = intrinsics.federated_map(
         count_over, intrinsics.federated_zip([temperatures, client_threshold]))
     total = intrinsics.federated_map(count_total, temperatures)
     return intrinsics.federated_mean(over, total)
Esempio n. 11
0
 def encoded_broadcast_comp(state, value):
   """Encoded broadcast federated_computation.

   Encodes `value` on the server with `encode_fn`, which also yields the new
   state. The encoded value is broadcast and decoded on each client with
   `decode_fn`. Measurements are an empty SERVER-placed tuple.
   """
   measurements = intrinsics.federated_value((), placements.SERVER)
   encoded = intrinsics.federated_map(encode_fn, (state, value))
   new_state, encoded_value = encoded
   decoded = intrinsics.federated_map(
       decode_fn, intrinsics.federated_broadcast(encoded_value))
   return measured_process.MeasuredProcessOutput(
       state=new_state, result=decoded, measurements=measurements)
Esempio n. 12
0
 def next_comp(state, value):
     """Increments `state` with `_add_one` and broadcasts `value`.

     Measurements are an arbitrary test metric: the global norm of `value`
     plus 3.0.
     """
     new_state = intrinsics.federated_map(_add_one, state)
     broadcast_value = intrinsics.federated_broadcast(value)
     # Arbitrary metrics for testing.
     norm_plus_three = tensorflow_computation.tf_computation(
         lambda v: tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0)
     metrics = intrinsics.federated_map(norm_plus_three, value)
     return measured_process.MeasuredProcessOutput(
         state=new_state, result=broadcast_value, measurements=metrics)
Esempio n. 13
0
 def computation(arg):
     """Broadcasts a server-computed context and processes it on each client.

     `arg` is a (server_data, client_data) pair. The context comes from
     `bf.compute_server_context`, and each client runs
     `bf.client_processing` on (context, client data).
     """
     server_data, client_data = arg
     server_context = intrinsics.federated_map(bf.compute_server_context,
                                               server_data)
     client_args = intrinsics.federated_zip(
         (intrinsics.federated_broadcast(server_context), client_data))
     return intrinsics.federated_map(bf.client_processing, client_args)
Esempio n. 14
0
    def next_fn(server_state, client_val):
        """`next` function for `tff.templates.IterativeProcess`.

        Returns:
          A `(server_update, server_output)` pair. `server_update` is the
          federated sum of `client_val`. `server_output` is the result of
          applying `computation_returning_lambda()(1)` to the broadcast
          server state.
        """
        server_update = intrinsics.federated_sum(client_val)
        state_at_clients = intrinsics.federated_broadcast(server_state)
        # `computation_returning_lambda()` returns a callable. Calling it
        # with 1 gives the function that is applied to the broadcast state.
        lambda_returning_sum = computation_returning_lambda()
        sum_fn = lambda_returning_sum(1)
        server_output = sum_fn(state_at_clients)

        return server_update, server_output
 def next_fn(state, weights, client_data):
   """Runs client updates with the broadcast round number and aggregates metrics.

   Args:
     state: Server-placed state. It is broadcast to clients as the round
       number and incremented on the server with `add_one`.
     weights: Passed to `client_update_computation` together with the data.
     client_data: Client-placed data for `client_update_computation`.

   Returns:
     A `MeasuredProcessOutput` with the incremented state, the client results,
     and the zipped `train` metrics from `metrics_aggregation_fn`.
   """
   client_round_num = intrinsics.federated_broadcast(state)
   client_result, model_outputs = intrinsics.federated_map(
       client_update_computation, (weights, client_data, client_round_num))
   next_state = intrinsics.federated_map(add_one, state)
   metrics = intrinsics.federated_zip(
       collections.OrderedDict(train=metrics_aggregation_fn(model_outputs)))
   return measured_process.MeasuredProcessOutput(next_state, client_result,
                                                 metrics)
 def aggregation_comp(server_arg, client_arg):
     """Sums client values twice, broadcasting the first sum in between.

     The broadcast of the first aggregate is packaged alongside `client_sums`.
     Only element 0 of that package feeds the second `federated_sum`
     (presumably `client_sums` itself; confirm against
     `package_args_as_tuple`). The second sum is then combined with
     `server_arg` by `compute_sum`.
     """
     client_sums = intrinsics.federated_map(compute_tuple_sum, client_arg)
     summed_client_value = intrinsics.federated_sum(client_sums)
     broadcast_sum = intrinsics.federated_broadcast(summed_client_value)
     # Adding a function call here requires normalization into CDF before
     # checking the aggregation-dependence condition.
     client_tuple = package_args_as_tuple(client_sums, broadcast_sum)
     summed_client_value = intrinsics.federated_sum(client_tuple[0])
     return intrinsics.federated_map(compute_sum,
                                     (server_arg, summed_client_value))
Esempio n. 17
0
 def next_comp(server_state, client_data):
   """`next` fixture that broadcasts the result of an aggregation.

   Both arguments are ignored. A CLIENTS-placed 0 is summed, that sum is
   broadcast, and the broadcast value is summed again.

   Returns:
     The second sum and an empty SERVER-placed tuple.
   """
   del server_state, client_data
   client_val = intrinsics.federated_value(0, placements.CLIENTS)
   server_agg = intrinsics.federated_sum(client_val)
   # This broadcast is dependent on the result of the above aggregation,
   # which is not supported by MapReduce form.
   broadcasted = intrinsics.federated_broadcast(server_agg)
   server_agg_again = intrinsics.federated_sum(broadcasted)
   # `next` must return two values.
   return server_agg_again, intrinsics.federated_value((), placements.SERVER)
Esempio n. 18
0
 def next_fn(server_state, client_data):
     """The `next` function for `tff.templates.IterativeProcess`.

     The round follows prepare -> broadcast -> work -> sum -> update. Client
     updates are aggregated only with an unsecured `federated_sum`.

     Returns:
       The `(new_server_state, server_output)` pair produced by `update`.
     """
     s2 = intrinsics.federated_map(prepare, server_state)
     client_input = intrinsics.federated_broadcast(s2)
     c3 = intrinsics.federated_zip([client_data, client_input])
     client_updates = intrinsics.federated_map(work, c3)
     unsecure_update = intrinsics.federated_sum(client_updates)
     # No call to `federated_secure_sum_bitwidth`.
     s6 = intrinsics.federated_zip([server_state, unsecure_update])
     new_server_state, server_output = intrinsics.federated_map(update, s6)
     return new_server_state, server_output
Esempio n. 19
0
 def _compute_measurements(self, upper_bound, lower_bound, value_max,
                           value_min):
     """Creates measurements to be reported. All values are summed securely.

     Args:
       upper_bound: Server-placed upper threshold, broadcast to clients.
       lower_bound: Server-placed lower threshold, broadcast to clients.
       value_max: Client-placed values compared with `upper_bound`.
       value_min: Client-placed values compared with `lower_bound`.

     Returns:
       A zipped `OrderedDict` containing the securely summed (bitwidth 1)
       indicators of `upper_bound < value_max` and `lower_bound > value_min`,
       plus both thresholds.
     """
     upper_indicator = intrinsics.federated_map(
         computations.tf_computation(
             lambda bound, value: tf.cast(bound < value, COUNT_TF_TYPE)),
         (intrinsics.federated_broadcast(upper_bound), value_max))
     upper_count = intrinsics.federated_secure_sum(upper_indicator, bitwidth=1)

     lower_indicator = intrinsics.federated_map(
         computations.tf_computation(
             lambda bound, value: tf.cast(bound > value, COUNT_TF_TYPE)),
         (intrinsics.federated_broadcast(lower_bound), value_min))
     lower_count = intrinsics.federated_secure_sum(lower_indicator, bitwidth=1)

     return intrinsics.federated_zip(
         collections.OrderedDict(
             secure_upper_clipped_count=upper_count,
             secure_lower_clipped_count=lower_count,
             secure_upper_threshold=upper_bound,
             secure_lower_threshold=lower_bound))
Esempio n. 20
0
    def next_fn(server_state, client_val):
        """`next` function for `tff.templates.IterativeProcess`.

        Returns:
          A `(server_update, server_output)` pair. `server_update` is a zipped
          `OrderedDict` holding the client count. `server_output` is the
          federated sum of the values that `_bind_tf_function` produces from
          the broadcast server state and `tf.timestamp`.
        """
        server_update = intrinsics.federated_zip(
            collections.OrderedDict(
                num_clients=count_clients_federated(client_val)))

        # NOTE(review): `_bind_tf_function` presumably applies `tf.timestamp`
        # on clients to the broadcast state; confirm against its definition.
        server_output = intrinsics.federated_sum(
            _bind_tf_function(intrinsics.federated_broadcast(server_state),
                              tf.timestamp))

        return server_update, server_output
Esempio n. 21
0
 def next_fn(server_state, client_data):
   """The `next` function for `tff.templates.IterativeProcess`.

   `broadcast_and_return_arg_and_result` is applied to the prepared state.
   Its second output is broadcast and its first output is unused. Element 0
   of the client updates is summed with `federated_sum`, and element 1 with
   an 8-bit `federated_secure_sum_bitwidth`. Both sums and the server state
   are passed to `update`.
   """
   s2 = intrinsics.federated_map(prepare, server_state)
   unused_client_input, to_broadcast = broadcast_and_return_arg_and_result(s2)
   client_input = intrinsics.federated_broadcast(to_broadcast)
   c3 = intrinsics.federated_zip([client_data, client_input])
   client_updates = intrinsics.federated_map(work, c3)
   unsecure_update = intrinsics.federated_sum(client_updates[0])
   secure_update = intrinsics.federated_secure_sum_bitwidth(
       client_updates[1], 8)
   s6 = intrinsics.federated_zip(
       [server_state, [unsecure_update, secure_update]])
   new_server_state, server_output = intrinsics.federated_map(update, s6)
   return new_server_state, server_output
Esempio n. 22
0
 def next_fn(server_state, client_data):
     """The `next` function for `tff.templates.IterativeProcess`.

     This is a prepare/broadcast/work/aggregate round with no `update`
     mapping. The new server state is just the zipped pair of the unsecured
     sum (element 0 of the client updates) and the 8-bit secure sum
     (element 1). The server output is an empty SERVER-placed list.
     """
     s2 = intrinsics.federated_map(prepare, server_state)
     client_input = intrinsics.federated_broadcast(s2)
     c3 = intrinsics.federated_zip([client_data, client_input])
     client_updates = intrinsics.federated_map(work, c3)
     unsecure_update = intrinsics.federated_sum(client_updates[0])
     secure_update = intrinsics.federated_secure_sum_bitwidth(
         client_updates[1], 8)
     new_server_state = intrinsics.federated_zip(
         [unsecure_update, secure_update])
     # No call to `federated_map` with an `update` function.
     server_output = intrinsics.federated_value([], placements.SERVER)
     return new_server_state, server_output
Esempio n. 23
0
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic.

    Broadcasts the server model weights and runs `_client_computation` on
    each client. Returns the result of a single `aggregation_process.next`
    call on the client metrics, started from a fresh `initialize()` state.
    """
    client_metrics = intrinsics.federated_map(
        _client_computation,
        (intrinsics.federated_broadcast(server_model_weights),
         federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    fresh_state = aggregation_process.initialize()  # No state.
    sampling_output = aggregation_process.next(fresh_state, client_metrics)
    # In the future we may want to output `sampling_output.measurements` also
    # but currently it is empty.
    return sampling_output.result
Esempio n. 24
0
 def server_eval(server_model_weights, federated_dataset):
     """Evaluates `server_model_weights` on each client's dataset.

     The weights reach the clients through `broadcast_process` when one is
     configured, and through a plain `federated_broadcast` otherwise. The
     clients' `local_outputs` are passed to `model.federated_output_computation`.
     """
     if broadcast_process is None:
         client_outputs = intrinsics.federated_map(client_eval, [
             intrinsics.federated_broadcast(server_model_weights),
             federated_dataset
         ])
     else:
         # TODO(b/179091838): Confirm that the process has no state.
         # TODO(b/179091838): Zip the measurements from the broadcast_process with
         # the result of `model.federated_output_computation` below to avoid
         # dropping these metrics.
         broadcast_result = broadcast_process.next(
             broadcast_process.initialize(), server_model_weights).result
         client_outputs = intrinsics.federated_map(
             client_eval, (broadcast_result, federated_dataset))
     return model.federated_output_computation(client_outputs.local_outputs)
Esempio n. 25
0
        def next_fn(state, value):
            """Updates the quantile-estimation state from client values.

            Sample params are derived on the server from the quantile query
            state and broadcast. Clients build quantile records from them,
            and the records are aggregated with `quantile_agg_process`.
            `get_noised_result` applied to the aggregate yields the new query
            state.

            Returns:
              The zipped `(new_quantile_query_state, new_agg_state)` pair.
            """
            query_state, agg_state = state

            server_params = intrinsics.federated_map(derive_sample_params,
                                                     query_state)
            client_params = intrinsics.federated_broadcast(server_params)
            records = intrinsics.federated_map(get_quantile_record,
                                               (client_params, value))

            agg_output = quantile_agg_process.next(agg_state, records)

            # Only the middle of the three outputs (the new query state) is used.
            _, next_query_state, _ = intrinsics.federated_map(
                get_noised_result, (agg_output.result, query_state))

            return intrinsics.federated_zip(
                (next_query_state, agg_output.state))
Esempio n. 26
0
    def next_fn(state, value):
      """Updates the quantile-estimation state from client values.

      Sample params derived from the quantile query state are broadcast.
      Clients build quantile records, which `quantile_agg_process` aggregates.
      `get_noised_result` on the aggregate gives the new query state.

      Returns:
        The zipped `(new_quantile_query_state, new_agg_state)` pair.
      """
      query_state, agg_state = state

      server_params = intrinsics.federated_map(derive_sample_params,
                                               query_state)
      client_params = intrinsics.federated_broadcast(server_params)
      records = intrinsics.federated_map(get_quantile_record,
                                         (client_params, value))

      (next_agg_state, agg_result,
       unused_measurements) = quantile_agg_process.next(agg_state, records)

      # We expect the quantile record aggregation process to be something simple
      # like basic sum, so we won't surface its measurements.
      del unused_measurements

      _, next_query_state = intrinsics.federated_map(
          get_noised_result, (agg_result, query_state))

      return intrinsics.federated_zip((next_query_state, next_agg_state))
Esempio n. 27
0
 def next_computation(arg):
     """The logic of a single MapReduce processing round.

     `arg` is a pair: server state `s1` and client data `c1`. The state
     prepared by `mrf.prepare` is broadcast, zipped with the client data, and
     mapped with `mrf.work`. Element 0 of the work output is aggregated with
     `mrf.zero`/`accumulate`/`merge`/`report`. Element 1 is secure-summed
     with bitwidth `mrf.bitwidth()`. `mrf.update` then receives the server
     state and both aggregates.

     Returns:
       The two elements of `mrf.update`'s output (presumably the new server
       state and the server output).
     """
     s1 = arg[0]
     c1 = arg[1]
     s2 = intrinsics.federated_map(mrf.prepare, s1)
     c2 = intrinsics.federated_broadcast(s2)
     c3 = intrinsics.federated_zip([c1, c2])
     c4 = intrinsics.federated_map(mrf.work, c3)
     c5 = c4[0]
     c6 = c4[1]
     s3 = intrinsics.federated_aggregate(c5, mrf.zero(), mrf.accumulate,
                                         mrf.merge, mrf.report)
     s4 = intrinsics.federated_secure_sum_bitwidth(c6, mrf.bitwidth())
     s5 = intrinsics.federated_zip([s3, s4])
     s6 = intrinsics.federated_zip([s1, s5])
     s7 = intrinsics.federated_map(mrf.update, s6)
     s8 = s7[0]
     s9 = s7[1]
     return s8, s9
Esempio n. 28
0
  def next_fn_impl(state, value):
    """Transforms client values, aggregates them, and post-processes the sum.

    `state` is an (inner aggregation state, own state) pair. The own state is
    broadcast for `client_transform`. The inner aggregate is passed to
    `server_transform` with the server-side own state. The own state is
    advanced with `update_my_state`. Inner measurements are reported under
    the key `name`.
    """
    inner_state, transform_state = state
    client_transform_state = intrinsics.federated_broadcast(transform_state)
    transformed = intrinsics.federated_map(client_transform,
                                           (value, client_transform_state))

    inner_output = inner_agg_process.next(inner_state, transformed)

    result = intrinsics.federated_map(server_transform,
                                      (inner_output.result, transform_state))
    next_transform_state = intrinsics.federated_map(update_my_state,
                                                    transform_state)

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(
            (inner_output.state, next_transform_state)),
        result=result,
        measurements=intrinsics.federated_zip(
            collections.OrderedDict([(name, inner_output.measurements)])))
Esempio n. 29
0
 def server_eval(server_model_weights, federated_dataset):
     """Evaluates `server_model_weights` on each client's dataset.

     A local evaluation computation is built, and the weights reach clients
     via `broadcast_process` if one is set, or via `federated_broadcast`
     otherwise. Client `local_outputs` are aggregated with
     `metrics_aggregation_computation` and returned under the key `eval`.
     """
     client_eval = build_local_evaluation(model_fn, model_weights_type,
                                          batch_type,
                                          use_experimental_simulation_loop)
     if broadcast_process is None:
         client_outputs = intrinsics.federated_map(client_eval, [
             intrinsics.federated_broadcast(server_model_weights),
             federated_dataset
         ])
     else:
         # TODO(b/179091838): Zip the measurements from the broadcast_process with
         # the result of `model_metrics` below to avoid dropping these metrics.
         broadcast_result = broadcast_process.next(
             broadcast_process.initialize(), server_model_weights).result
         client_outputs = intrinsics.federated_map(
             client_eval, (broadcast_result, federated_dataset))
     return intrinsics.federated_zip(
         collections.OrderedDict(
             eval=metrics_aggregation_computation(
                 client_outputs.local_outputs)))
Esempio n. 30
0
 def server_eval(server_model_weights, federated_dataset):
     """Evaluates `server_model_weights` on clients and reports metrics.

     The weights reach clients via `broadcast_process` if one is set, or via
     `federated_broadcast` otherwise. The result zips the model metrics (key
     `eval`) with the federated sum of `num_examples` (key `stat`).
     """
     if broadcast_process is None:
         client_outputs = intrinsics.federated_map(client_eval, [
             intrinsics.federated_broadcast(server_model_weights),
             federated_dataset
         ])
     else:
         # TODO(b/179091838): Zip the measurements from the broadcast_process with
         # the result of `model.federated_output_computation` below to avoid
         # dropping these metrics.
         broadcast_result = broadcast_process.next(
             broadcast_process.initialize(), server_model_weights).result
         client_outputs = intrinsics.federated_map(
             client_eval, (broadcast_result, federated_dataset))
     eval_metrics = model.federated_output_computation(
         client_outputs.local_outputs)
     total_examples = intrinsics.federated_sum(client_outputs.num_examples)
     return intrinsics.federated_zip(
         collections.OrderedDict(
             eval=eval_metrics,
             stat=collections.OrderedDict(num_examples=total_examples)))