def broadcast_next_fn(state, value):
    """Increments the call counter in `state` and broadcasts `value`.

    Args:
      state: A federated structure exposing a `call_count` member.
      value: The value handed to `intrinsics.federated_broadcast`.

    Returns:
      A 2-tuple of the zipped new state (an `OrderedDict` holding the
      incremented `call_count`) and the broadcast `value`.
    """

    @computations.tf_computation(tf.int32)
    def add_one(count):
        return count + 1

    incremented = intrinsics.federated_map(add_one, state.call_count)
    new_state = intrinsics.federated_zip(
        collections.OrderedDict(call_count=incremented))
    broadcast_value = intrinsics.federated_broadcast(value)
    return new_state, broadcast_value
Example #2
0
 def initialize_computation():
     """Orchestration logic for server model initialization.

     Evaluates `server_init` at `placements.SERVER` to get the starting model
     and optimizer state. It then initializes the aggregation and broadcast
     processes and zips everything into a single `ServerState`.

     Returns:
       The `ServerState` zipped by `intrinsics.federated_zip`.
     """
     model_and_optimizer = intrinsics.federated_eval(server_init,
                                                     placements.SERVER)
     global_model, global_optimizer_state = model_and_optimizer
     aggregate_state = aggregation_process.initialize()
     broadcast_state = broadcast_process.initialize()
     server_state = ServerState(
         model=global_model,
         optimizer_state=global_optimizer_state,
         delta_aggregate_state=aggregate_state,
         model_broadcast_state=broadcast_state)
     return intrinsics.federated_zip(server_state)
Example #3
0
 def comp(temperatures, threshold):
     """Averages per-client `count_over` results, weighted by `count_total`.

     `threshold` is broadcast and zipped with `temperatures`. `count_over` is
     mapped over each pair, and the results are combined with
     `federated_mean`, weighted by `count_total` mapped over `temperatures`.
     """
     broadcast_threshold = intrinsics.federated_broadcast(threshold)
     zipped = intrinsics.federated_zip([temperatures, broadcast_threshold])
     over_counts = intrinsics.federated_map(count_over, zipped)
     total_counts = intrinsics.federated_map(count_total, temperatures)
     return intrinsics.federated_mean(over_counts, total_counts)
 def next_computation(arg):
   """The logic of a single MapReduce processing round.

   Args:
     arg: A 2-element structure. `arg[0]` is used as the server state (`s1`)
       and `arg[1]` as the client data (`c1`).

   Returns:
     A tuple `(s8, s9)`: the two elements of the result of mapping
     `cf.update` over the zipped server state and aggregates.
   """
   # Naming convention: `s*` values are server-placed (they are broadcast or
   # produced by aggregation); `c*` values are client-placed.
   s1 = arg[0]
   c1 = arg[1]
   # Run `cf.prepare` on the server state, then broadcast it to clients.
   s2 = intrinsics.federated_map(cf.prepare, s1)
   c2 = intrinsics.federated_broadcast(s2)
   # Pair each client's data with the broadcast value and run `cf.work`.
   c3 = intrinsics.federated_zip([c1, c2])
   c4 = intrinsics.federated_map(cf.work, c3)
   # `cf.work` yields two parts: one for ordinary aggregation and one for
   # secure summation.
   c5 = c4[0]
   c6 = c4[1]
   s3 = intrinsics.federated_aggregate(c5, cf.zero(), cf.accumulate, cf.merge,
                                       cf.report)
   s4 = intrinsics.federated_secure_sum(c6, cf.bitwidth())
   # Combine both aggregates with the original server state and run
   # `cf.update`, whose result has two parts returned separately.
   s5 = intrinsics.federated_zip([s3, s4])
   s6 = intrinsics.federated_zip([s1, s5])
   s7 = intrinsics.federated_map(cf.update, s6)
   s8 = s7[0]
   s9 = s7[1]
   return s8, s9
 def comp():
     """Zips two client-placed constants into one federated structure.

     Returns:
       The result of `intrinsics.federated_zip` over an `OrderedDict`. Key
       `'A'` holds 10 and key `'B'` holds 20, each placed at
       `placement_literals.CLIENTS`.
     """
     value_a = intrinsics.federated_value(10, placement_literals.CLIENTS)
     value_b = intrinsics.federated_value(20, placement_literals.CLIENTS)
     named_values = collections.OrderedDict()
     named_values['A'] = value_a
     named_values['B'] = value_b
     return intrinsics.federated_zip(named_values)
Example #6
0
    def one_round_computation(server_state, federated_dataset):
        """Orchestration logic for one round of optimization.

        Args:
          server_state: a `tff.learning.framework.ServerState` named tuple.
          federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

        Returns:
          A tuple of:
          - the updated `tff.learning.framework.ServerState`, zipped;
          - a zipped `OrderedDict` of measurements with keys `broadcast` and
            `aggregation` (the respective processes' measurements), `train`
            (the output of
            `dummy_model_for_metadata.federated_output_computation` over the
            clients' model outputs) and `stat` (the federated sum of the
            clients' optimizer outputs).
        """
        # Distribute the current model via the broadcast process.
        broadcast_output = broadcast_process.next(
            server_state.model_broadcast_state, server_state.model)
        # Local training on each client against the broadcast model.
        client_outputs = intrinsics.federated_map(
            _compute_local_training_and_client_delta,
            (federated_dataset, broadcast_output.result))
        # Support both aggregation processes whose `next` takes
        # (state, value, weight) and those taking only (state, value). The
        # arity is read from the process's type signature.
        if len(aggregation_process.next.type_signature.parameter) == 3:
            aggregation_output = aggregation_process.next(
                server_state.delta_aggregate_state,
                client_outputs.weights_delta,
                client_outputs.weights_delta_weight)
        else:
            aggregation_output = aggregation_process.next(
                server_state.delta_aggregate_state,
                client_outputs.weights_delta)
        # Apply the aggregated delta to the global model.
        new_global_model, new_optimizer_state = intrinsics.federated_map(
            server_update, (server_state.model, aggregation_output.result,
                            server_state.optimizer_state))
        new_server_state = intrinsics.federated_zip(
            ServerState(new_global_model, new_optimizer_state,
                        aggregation_output.state, broadcast_output.state))
        aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)
        optimizer_outputs = intrinsics.federated_sum(
            client_outputs.optimizer_output)
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(
                broadcast=broadcast_output.measurements,
                aggregation=aggregation_output.measurements,
                train=aggregated_outputs,
                stat=optimizer_outputs))
        return new_server_state, measurements
Example #7
0
 def next_fn(server_state, client_data):
     """The `next` function for `computation_utils.IterativeProcess`.

     Ignores `server_state` and maps `work` over `client_data`. The first
     client update is summed normally and the second with a secure sum of
     bitwidth 8. `update` is then mapped over the zipped pair of sums.

     Returns:
       A 3-tuple `(new_server_state, server_output, client_output)`.
     """
     del server_state  # Unused
     work_result = intrinsics.federated_map(work, client_data)
     client_updates, client_output = work_result
     plain_sum = intrinsics.federated_sum(client_updates[0])
     secured_sum = intrinsics.federated_secure_sum(client_updates[1], 8)
     summed = intrinsics.federated_zip([plain_sum, secured_sum])
     update_result = intrinsics.federated_map(update, summed)
     new_server_state, server_output = update_result
     return new_server_state, server_output, client_output
Example #8
0
 def next_fn(server_val, client_val):
     """Defines a series of federated computations compatible with CanonicalForm.

     Broadcasts `server_val`, zips it with `client_val`, maps `add_two` over
     the pairs, and averages the results with `federated_mean`.

     Returns:
       A 2-tuple of the averaged result and a server-placed constant
       `[1, 2, 3, 4, 5]` as a side output.
     """
     on_clients = intrinsics.federated_broadcast(server_val)
     paired = intrinsics.federated_zip((client_val, on_clients))
     mapped = intrinsics.federated_map(add_two, paired)
     mean_result = intrinsics.federated_mean(mapped)
     extra_output = intrinsics.federated_value([1, 2, 3, 4, 5],
                                               placements.SERVER)
     return mean_result, extra_output
 def next_fn(server_state, client_data):
   """The `next` function for `tff.templates.IterativeProcess`.

   Intentionally skips the `prepare` mapping and the `federated_broadcast`
   step: clients run `work` directly on `client_data`.

   Returns:
     A 3-tuple `(new_server_state, server_output, client_output)`.
   """
   updates, client_output = intrinsics.federated_map(work, client_data)
   summed = intrinsics.federated_sum(updates[0])
   secure_summed = intrinsics.federated_secure_sum(updates[1], 8)
   update_input = intrinsics.federated_zip(
       [server_state, [summed, secure_summed]])
   new_server_state, server_output = intrinsics.federated_map(
       update, update_input)
   return new_server_state, server_output, client_output
Example #10
0
 def next_fn(server_state, client_data):
     """The `next` function for `tff.templates.IterativeProcess`.

     Ignores `client_data` and makes no aggregation calls. Both updates are
     replaced by the server-placed constant 1 before `update` is mapped over
     the zipped state.

     Returns:
       A 2-tuple `(new_server_state, server_output)`.
     """
     del client_data
     # Stand-in for what `federated_aggregate` would have produced.
     plain_update = intrinsics.federated_value(1, placements.SERVER)
     # Stand-in for what `federated_secure_sum` would have produced.
     secure_stand_in = intrinsics.federated_value(1, placements.SERVER)
     zipped = intrinsics.federated_zip(
         [server_state, [plain_update, secure_stand_in]])
     new_server_state, server_output = intrinsics.federated_map(update, zipped)
     return new_server_state, server_output
Example #11
0
 def initialize_computation():
     """Builds the initial, zipped `ServerState`.

     Evaluates `server_init_tf` at `tff.SERVER` for the starting model and
     optimizer state. `round_num` is set to the server-placed value 0.0, and
     the aggregation process state comes from its `initialize()`.
     """
     # NOTE(review): `model` is never read below. Either `model_fn()` is
     # called for a side effect or this is a leftover; confirm before removing.
     model = model_fn()
     initial_global_model, initial_global_optimizer_state = intrinsics.federated_eval(
         server_init_tf, tff.SERVER)
     return intrinsics.federated_zip(
         ServerState(
             model=initial_global_model,
             optimizer_state=initial_global_optimizer_state,
             round_num=tff.federated_value(0.0, tff.SERVER),
             aggregation_state=aggregation_process.initialize(),
         ))
Example #12
0
  def next_fn(server_state, client_val):
    """`next` function for `computation_utils.IterativeProcess`.

    Returns:
      A 2-tuple:
      - the zipped `OrderedDict` holding `num_clients`, computed by
        `count_clients_federated(client_val)`;
      - the result of `_bind_federated_value`, applied to the broadcast
        `server_state`, `server_state_type` and an empty server-placed tuple.
    """
    client_count = count_clients_federated(client_val)
    server_update = intrinsics.federated_zip(
        collections.OrderedDict([('num_clients', client_count)]))
    empty_output = intrinsics.federated_value((), placements.SERVER)
    broadcast_state = intrinsics.federated_broadcast(server_state)
    server_output = _bind_federated_value(broadcast_state, server_state_type,
                                          empty_output)
    return server_update, server_output
Example #13
0
    def next_fn(server_state, client_val):
        """`next` function for `tff.utils.IterativeProcess`.

        Returns:
          A 2-tuple:
          - the zipped `OrderedDict` holding `num_clients`, computed by
            `count_clients_federated(client_val)`;
          - the federated sum of `_bind_tf_function` applied to the
            broadcast `server_state` with `tf.timestamp`.
        """
        server_update = intrinsics.federated_zip(
            collections.OrderedDict(
                num_clients=count_clients_federated(client_val)))

        # NOTE(review): this value is overwritten on the next statement
        # without being read. It looks like a leftover, but removing it may
        # change the traced computation's structure; confirm before deleting.
        server_output = intrinsics.federated_value((), placements.SERVER)
        server_output = intrinsics.federated_sum(
            _bind_tf_function(intrinsics.federated_broadcast(server_state),
                              tf.timestamp))

        return server_update, server_output
Example #14
0
        def next_fn(state, value, weight):
            """Clips and zeroes `value` against current norms, then aggregates it.

            Args:
              state: a 4-tuple `(clipping_norm_state, agg_state,
                clipped_count_state, zeroed_count_state)`.
              value: federated values passed to `clip_and_zero`.
              weight: federated weights forwarded to `inner_agg_process`.

            Returns:
              A `measured_process.MeasuredProcessOutput` with three parts:
              - state: the four updated sub-states, zipped;
              - result: the inner aggregator's result;
              - measurements: the inner measurements, both norms, and the
                clipped/zeroed counts, zipped.
            """
            (clipping_norm_state, agg_state, clipped_count_state,
             zeroed_count_state) = state

            clipping_norm = self._clipping_norm_process.report(
                clipping_norm_state)
            # The zeroing norm is derived from the current clipping norm.
            zeroing_norm = intrinsics.federated_map(self._zeroing_norm_fn,
                                                    clipping_norm)

            # Both norms are broadcast so `clip_and_zero` can use them
            # alongside each value. It also reports a norm and two flags
            # saying whether the value was clipped or zeroed.
            (zeroed_and_clipped, global_norm,
             was_clipped, was_zeroed) = intrinsics.federated_map(
                 clip_and_zero,
                 (value, intrinsics.federated_broadcast(clipping_norm),
                  intrinsics.federated_broadcast(zeroing_norm)))

            # The norm reported by `clip_and_zero` feeds the clipping-norm
            # process's next state.
            new_clipping_norm_state = self._clipping_norm_process.next(
                clipping_norm_state, global_norm)
            agg_output = inner_agg_process.next(agg_state, zeroed_and_clipped,
                                                weight)
            clipped_count_output = clipped_count_agg_process.next(
                clipped_count_state, was_clipped)
            zeroed_count_output = zeroed_count_agg_process.next(
                zeroed_count_state, was_zeroed)

            new_state = collections.OrderedDict(
                clipping_norm=new_clipping_norm_state,
                inner_agg=agg_output.state,
                clipped_count_agg=clipped_count_output.state,
                zeroed_count_agg=zeroed_count_output.state)
            measurements = collections.OrderedDict(
                agg_process=agg_output.measurements,
                clipping_norm=clipping_norm,
                zeroing_norm=zeroing_norm,
                clipped_count=clipped_count_output.result,
                zeroed_count=zeroed_count_output.result)

            return measured_process.MeasuredProcessOutput(
                state=intrinsics.federated_zip(new_state),
                result=agg_output.result,
                measurements=intrinsics.federated_zip(measurements))
 def initialize_computation():
     """Builds the initial, zipped `ServerState`.

     Evaluates `server_init_tf` at `placements.SERVER` for the starting model
     and optimizer state. `round_num` is set to the server-placed value 0.0.
     `effective_num_clients` comes from evaluating `get_effective_num_clients`
     at the server, and the aggregation state from `initialize()`.
     """
     # NOTE(review): `model` is never read below. Either `model_fn()` is
     # called for a side effect or this is a leftover; confirm before removing.
     model = model_fn()
     initial_global_model, initial_global_optimizer_state = intrinsics.federated_eval(
         server_init_tf, placements.SERVER)
     return intrinsics.federated_zip(
         ServerState(
             model=initial_global_model,
             optimizer_state=initial_global_optimizer_state,
             # NOTE(review): uses `tff.*` while the rest of the block uses
             # `intrinsics`/`placements`; consider unifying.
             round_num=tff.federated_value(0.0, tff.SERVER),
             effective_num_clients=intrinsics.federated_eval(
                 get_effective_num_clients, placements.SERVER),
             delta_aggregate_state=aggregation_process.initialize(),
         ))
    def comp(temperatures, threshold):
      """Counts, per client, the elements of `temperatures` above `threshold`.

      Args:
        temperatures: federated `tf.float32` sequences. This is assumed from
          the `SequenceType(tf.float32)` signature of the inner computation.
        threshold: a `tf.float32` value, broadcast before being paired with
          each client's sequence.

      Returns:
        The federated per-client `tf.int32` counts.
      """

      @computations.tf_computation(
          computation_types.SequenceType(tf.float32), tf.float32)
      def count(ds, t):

        def accumulate(total, element):
          # Adds 1 for every element strictly greater than the threshold.
          return total + tf.cast(tf.greater(element, t), tf.int32)

        return ds.reduce(np.int32(0), accumulate)

      threshold_on_clients = intrinsics.federated_broadcast(threshold)
      paired = intrinsics.federated_zip([temperatures, threshold_on_clients])
      return intrinsics.federated_map(count, paired)
 def next_fn(server_state, client_data):
   """The `next` function for `tff.templates.IterativeProcess`.

   Ignores `server_state` and intentionally skips the `prepare` mapping, the
   broadcast and the `update` mapping. The zipped pair of client sums is
   returned directly as the new server state, alongside an empty
   server-placed list.
   """
   del server_state  # Unused
   updates = intrinsics.federated_map(work, client_data)
   summed = intrinsics.federated_sum(updates[0])
   secure_summed = intrinsics.federated_secure_sum(updates[1], 8)
   new_server_state = intrinsics.federated_zip([summed, secure_summed])
   empty_output = intrinsics.federated_value([], placements.SERVER)
   return new_server_state, empty_output
Example #18
0
        def next_fn(state, value):
            """Runs one round of query-based aggregation of `value`.

            Args:
              state: a pair `(query_state, agg_state)`.
              value: the federated values turned into query records.

            Returns:
              A `measured_process.MeasuredProcessOutput` with three parts:
              - state: the zipped `(new_query_state, new_agg_state)`;
              - result: the first output of `get_noised_result`;
              - measurements: a zipped `OrderedDict` with keys
                `dp_query_metrics` and `dp`.
            """
            query_state, agg_state = state

            sample_params = intrinsics.federated_map(derive_sample_params,
                                                     query_state)
            params = intrinsics.federated_broadcast(sample_params)
            record = intrinsics.federated_map(get_query_record,
                                              (params, value))

            agg_output = record_agg_process.next(agg_state, record)
            new_agg_state, agg_result, agg_measurements = agg_output

            result, new_query_state = intrinsics.federated_map(
                get_noised_result, (agg_result, query_state))
            query_metrics = intrinsics.federated_map(derive_metrics,
                                                     new_query_state)

            return measured_process.MeasuredProcessOutput(
                intrinsics.federated_zip((new_query_state, new_agg_state)),
                result,
                intrinsics.federated_zip(
                    collections.OrderedDict(dp_query_metrics=query_metrics,
                                            dp=agg_measurements)))
Example #19
0
        def next_fn(state, value):
            """Averages `value` by dividing its sum by the number of clients.

            The sum comes from `value_sum_process`. The count is a federated
            sum of the constant 1 placed at `placements.CLIENTS`.

            Returns:
              A `measured_process.MeasuredProcessOutput` with three parts:
              - state: the sum process's new state;
              - result: the mean;
              - measurements: the sum process's measurements, zipped under
                the key `mean_value`.
            """
            sum_output = value_sum_process.next(state, value)
            ones = intrinsics.federated_value(1, placements.CLIENTS)
            client_count = intrinsics.federated_sum(ones)
            mean_value = intrinsics.federated_map(
                _div, (sum_output.result, client_count))
            new_state = sum_output.state
            measurements = intrinsics.federated_zip(
                collections.OrderedDict(mean_value=sum_output.measurements))
            return measured_process.MeasuredProcessOutput(
                new_state, mean_value, measurements)
Example #20
0
        def next_fn(state, value, weight):
            """Computes the weighted mean of `value` using inner sum processes.

            Clients compute `value * weight` via `_mul`. The products are
            summed by `value_sum_process` and the weights by
            `weight_sum_process`. The two sums are then divided with `_div`.

            Args:
              state: a mapping with keys `'value_sum_process'` and
                `'weight_sum_process'` holding each inner process's state.
              value: the federated values to average.
              weight: the federated weights.

            Returns:
              A `measured_process.MeasuredProcessOutput` whose state and
              measurements are zipped `OrderedDict`s keyed like `state`.
            """
            products = intrinsics.federated_map(_mul, (value, weight))

            value_output = value_sum_process.next(state['value_sum_process'],
                                                  products)
            weight_output = weight_sum_process.next(
                state['weight_sum_process'], weight)

            weighted_mean = intrinsics.federated_map(
                _div, (value_output.result, weight_output.result))

            new_state = collections.OrderedDict(
                value_sum_process=value_output.state,
                weight_sum_process=weight_output.state)
            measurements = collections.OrderedDict(
                value_sum_process=value_output.measurements,
                weight_sum_process=weight_output.measurements)
            zipped_state = intrinsics.federated_zip(new_state)
            zipped_measurements = intrinsics.federated_zip(measurements)
            return measured_process.MeasuredProcessOutput(
                zipped_state, weighted_mean, zipped_measurements)
        def next_fn(state, value, weight):
            """Zeroes `value` against the current norm, then runs the inner aggregator.

            Args:
              state: a pair `(zeroing_norm_state, agg_state)`.
              value: federated values processed with `zero`.
              weight: federated weights forwarded to `inner_agg_process`.

            Returns:
              A `measured_process.MeasuredProcessOutput` with three parts:
              - state: the updated zeroing-norm state zipped with the inner
                aggregator's state;
              - result and measurements: taken from the inner aggregator.
            """
            norm_state, inner_state = state

            threshold = self._zeroing_norm_process.report(norm_state)
            threshold_on_clients = intrinsics.federated_broadcast(threshold)
            zeroed, observed_norm = intrinsics.federated_map(
                zero, (value, threshold_on_clients))

            inner_output = inner_agg_process.next(inner_state, zeroed, weight)
            updated_norm_state = self._zeroing_norm_process.next(
                norm_state, observed_norm)

            combined_state = intrinsics.federated_zip(
                (updated_norm_state, inner_output.state))
            return measured_process.MeasuredProcessOutput(
                state=combined_state,
                result=inner_output.result,
                measurements=inner_output.measurements)
        def next_fn(state, value, weight):
            """Clips `value` to the current norm, then runs the inner aggregator.

            Args:
              state: a pair `(clipping_norm_state, agg_state)`.
              value: federated values processed with `clip`.
              weight: federated weights forwarded to `inner_agg_process`.

            Returns:
              A `measured_process.MeasuredProcessOutput` with three parts:
              - state: the updated clipping-norm state zipped with the inner
                aggregator's state;
              - result and measurements: taken from the inner aggregator.
            """
            norm_state, inner_state = state

            clip_norm = self._clipping_norm_process.report(norm_state)
            clip_norm_on_clients = intrinsics.federated_broadcast(clip_norm)
            clipped, observed_norm = intrinsics.federated_map(
                clip, (value, clip_norm_on_clients))

            inner_output = inner_agg_process.next(inner_state, clipped, weight)
            updated_norm_state = self._clipping_norm_process.next(
                norm_state, observed_norm)

            combined_state = intrinsics.federated_zip(
                (updated_norm_state, inner_output.state))
            return measured_process.MeasuredProcessOutput(
                state=combined_state,
                result=inner_output.result,
                measurements=inner_output.measurements)
Example #23
0
    def next_fn(state, value):
      """Updates the quantile-query state from one round of `value`.

      Args:
        state: a pair `(quantile_query_state, agg_state)`.
        value: federated values turned into quantile records.

      Returns:
        The zipped `(new_quantile_query_state, new_agg_state)`.
      """
      query_state, agg_state = state

      sample_params = intrinsics.federated_map(derive_sample_params,
                                               query_state)
      params = intrinsics.federated_broadcast(sample_params)
      records = intrinsics.federated_map(get_quantile_record, (params, value))

      # The record aggregator is expected to be something simple, like a
      # basic sum, so its measurements are intentionally discarded.
      new_agg_state, agg_result, _ = quantile_agg_process.next(
          agg_state, records)

      _, new_query_state = intrinsics.federated_map(
          get_noised_result, (agg_result, query_state))

      return intrinsics.federated_zip((new_query_state, new_agg_state))
Example #24
0
 def _compute_measurements(self, upper_bound, lower_bound, value_max,
                           value_min):
     """Creates measurements to be reported. All values are summed securely.

     Two flags are computed and secure-summed with bitwidth 1:
     - whether `value_max` exceeds the broadcast `upper_bound`;
     - whether `value_min` is below the broadcast `lower_bound`.

     Returns:
       A zipped `OrderedDict` with the two clipped counts and the two
       thresholds.
     """

     def above_bound(bound, value):
         return tf.cast(bound < value, COUNT_TF_TYPE)

     def below_bound(bound, value):
         return tf.cast(bound > value, COUNT_TF_TYPE)

     is_max_clipped = intrinsics.federated_map(
         computations.tf_computation(above_bound),
         (intrinsics.federated_broadcast(upper_bound), value_max))
     max_clipped_count = intrinsics.federated_secure_sum(is_max_clipped,
                                                         bitwidth=1)

     is_min_clipped = intrinsics.federated_map(
         computations.tf_computation(below_bound),
         (intrinsics.federated_broadcast(lower_bound), value_min))
     min_clipped_count = intrinsics.federated_secure_sum(is_min_clipped,
                                                         bitwidth=1)

     return intrinsics.federated_zip(
         collections.OrderedDict(
             upper_bound_clipped_count=max_clipped_count,
             lower_bound_clipped_count=min_clipped_count,
             upper_bound_threshold=upper_bound,
             lower_bound_threshold=lower_bound))
Example #25
0
  def next_fn(global_state, value, weight):
    """Computes a weighted mean, presumably zeroing values flagged too large.

    `preprocess_value` is mapped over the broadcast sample parameters,
    `value` and `weight`, and its four outputs are federated-summed. The mean
    is `divide_no_nan` of the value sum over the weight sum. `next_quantile`
    produces the new threshold and global state from the summed quantile
    records.

    Returns:
      A `measured_process.MeasuredProcessOutput` with three parts:
      - state: the new global state;
      - result: the mean;
      - measurements: the zipped
        `AdaptiveZeroingMetrics(new_threshold, num_zeroed)`.
    """
    derived_params = intrinsics.federated_map(derive_sample_params,
                                              global_state)
    sample_params = intrinsics.federated_broadcast(derived_params)
    weighted_value, adj_weight, quantile_record, too_large = (
        intrinsics.federated_map(preprocess_value,
                                 (sample_params, value, weight)))

    summed_value = intrinsics.federated_sum(weighted_value)
    summed_weight = intrinsics.federated_sum(adj_weight)
    summed_quantile = intrinsics.federated_sum(quantile_record)
    zeroed_count = intrinsics.federated_sum(too_large)

    mean = intrinsics.federated_map(divide_no_nan,
                                    (summed_value, summed_weight))
    new_threshold, new_global_state = intrinsics.federated_map(
        next_quantile, (summed_quantile, global_state))

    metrics = AdaptiveZeroingMetrics(new_threshold, zeroed_count)
    return measured_process.MeasuredProcessOutput(
        state=new_global_state,
        result=mean,
        measurements=intrinsics.federated_zip(metrics))
Example #26
0
 def baz(x):
     """Zips `x` and checks that the result is a `value_base.Value`."""
     zipped = intrinsics.federated_zip(x)
     self.assertIsInstance(zipped, value_base.Value)
     return zipped
Example #27
0
 def bar(x):
     """Zips an `n`-element `structure.Struct` built from `x`; checks its type."""
     elements = [_make_test_tuple(x, k) for k in range(n)]
     zipped = intrinsics.federated_zip(structure.Struct(elements))
     self.assertIsInstance(zipped, value_base.Value)
     return zipped
Example #28
0
 def foo(x):
     """Zips a dict mapping `str(k)` to `x[k]` for each `k` in `range(n)`."""
     named = {}
     for k in range(n):
         named[str(k)] = x[k]
     zipped = intrinsics.federated_zip(named)
     self.assertIsInstance(zipped, value_base.Value)
     return zipped
Example #29
0
 def _(arg):
     """Returns `intrinsics.federated_zip(arg)`."""
     zipped = intrinsics.federated_zip(arg)
     return zipped
Example #30
0
 def foo(arg):
     """Zips `arg` and checks that the result is a `value_base.Value`."""
     zipped = intrinsics.federated_zip(arg)
     self.assertIsInstance(zipped, value_base.Value)
     return zipped