Example #1
0
 def next_fn(server_val, client_val):
   """Builds a broadcast, map, mean round intended to fit CanonicalForm.

   `server_val` is broadcast and zipped with `client_val`; the zipped pairs go
   through `add_two` and the results are reduced with `federated_mean`. A
   constant list placed at the server is returned alongside the mean.

   Args:
     server_val: Value handed to `federated_broadcast`.
     client_val: Value zipped with the broadcast result.

   Returns:
     A pair `(aggregated_result, side_output)`.
   """
   zipped_inputs = intrinsics.federated_zip(
       (client_val, intrinsics.federated_broadcast(server_val)))
   mean_result = intrinsics.federated_mean(
       intrinsics.federated_map(add_two, zipped_inputs))
   constant_at_server = intrinsics.federated_value([1, 2, 3, 4, 5],
                                                   placements.SERVER)
   return mean_result, constant_at_server
Example #2
0
 def foo(temperatures, threshold):
     """Sums a 0/1 indicator of `temperature > threshold` over the mapped values.

     `threshold` is broadcast, paired with `temperatures`, and each pair is
     mapped through a TensorFlow computation that yields `1` when the first
     element is greater than the second and `0` otherwise. The indicators are
     then combined with `federated_sum`.

     Args:
       temperatures: Value mapped together with the broadcast threshold; the
         TensorFlow computation declares it as `tf.float32`.
       threshold: Value handed to `federated_broadcast`; also declared as
         `tf.float32`.

     Returns:
       The result of `federated_sum` over the int32 indicators.
     """
     # `tf.to_int32` is deprecated and not part of the TF 2.x API; `tf.cast`
     # to `tf.int32` is the documented replacement and yields the same values.
     is_over_threshold = computations.tf_computation(
         lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
         [tf.float32, tf.float32])
     return intrinsics.federated_sum(
         intrinsics.federated_map(
             is_over_threshold,
             [temperatures,
              intrinsics.federated_broadcast(threshold)]))
 def comp(temperatures, threshold):
     """Calls `federated_mean` on `count_over` results with `count_total` results.

     Args:
       temperatures: Value zipped with the broadcast threshold and also mapped
         through `count_total`.
       threshold: Value handed to `federated_broadcast`.

     Returns:
       The result of `federated_mean`, whose first argument is the mapped
       `count_over` output and whose second argument is the mapped
       `count_total` output.
     """
     threshold_at_clients = intrinsics.federated_broadcast(threshold)
     paired = intrinsics.federated_zip([temperatures, threshold_at_clients])
     over_counts = intrinsics.federated_map(count_over, paired)
     total_counts = intrinsics.federated_map(count_total, temperatures)
     return intrinsics.federated_mean(over_counts, total_counts)
Example #4
0
 def encoded_broadcast_fn(state, value):
     """Encoded broadcast: encode, broadcast the encoding, decode the broadcast.

     Args:
       state: First element of the pair mapped through `encode_fn`.
       value: Second element of the pair mapped through `encode_fn`.

     Returns:
       A pair `(new_state, client_value)` where `new_state` is the first output
       of `encode_fn` and `client_value` is the broadcast encoding after being
       mapped through `decode_fn`.
     """
     updated_state, payload = intrinsics.federated_map(
         encode_fn, (state, value))
     decoded = intrinsics.federated_map(
         decode_fn, intrinsics.federated_broadcast(payload))
     return updated_state, decoded
Example #5
0
 def next_comp(state, value):
     """Returns a `MeasuredProcessOutput` with a broadcast result.

     Args:
       state: Value mapped through `_add_one` to produce the output state.
       value: Value that is broadcast as the result and also used to compute
         the measurements.

     Returns:
       A `measured_process.MeasuredProcessOutput` whose `state` is the mapped
       state, whose `result` is the broadcast `value`, and whose `measurements`
       is the global norm of the flattened `value` plus 3.0.
     """
     next_state = intrinsics.federated_map(_add_one, state)
     broadcast_value = intrinsics.federated_broadcast(value)
     # Arbitrary metrics for testing.
     norm_plus_three = computations.tf_computation(
         lambda v: tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0)
     metrics = intrinsics.federated_map(norm_plus_three, value)
     return measured_process.MeasuredProcessOutput(
         state=next_state, result=broadcast_value, measurements=metrics)
Example #6
0
 def encoded_broadcast_comp(state, value):
   """Encoded broadcast packaged as a `MeasuredProcessOutput`.

   Args:
     state: First element of the pair mapped through `encode_fn`.
     value: Second element of the pair mapped through `encode_fn`.

   Returns:
     A `measured_process.MeasuredProcessOutput` carrying the new state from
     `encode_fn`, the decoded broadcast value as the result, and an empty
     tuple placed at the server as the measurements.
   """
   no_measurements = intrinsics.federated_value((), placements.SERVER)
   next_state, payload = intrinsics.federated_map(encode_fn, (state, value))
   decoded = intrinsics.federated_map(
       decode_fn, intrinsics.federated_broadcast(payload))
   return measured_process.MeasuredProcessOutput(
       state=next_state, result=decoded, measurements=no_measurements)
Example #7
0
 def comp(temperatures, threshold):
     """Calls `federated_mean` on `count_over` results with `count_total` results.

     Args:
       temperatures: Value zipped with the broadcast threshold and also mapped
         through `count_total`.
       threshold: Value handed to `federated_broadcast`.

     Returns:
       The result of `federated_mean` applied to the two mapped values.
     """
     over_threshold = intrinsics.federated_map(
         count_over,
         intrinsics.federated_zip(
             [temperatures,
              intrinsics.federated_broadcast(threshold)]))
     totals = intrinsics.federated_map(count_total, temperatures)
     return intrinsics.federated_mean(over_threshold, totals)
Example #8
0
 def foo(temperatures, threshold):
     """Sums a 0/1 indicator of `temperature > threshold` and checks its type.

     `self` is not a parameter here; it is taken from the enclosing scope,
     which is not visible in this snippet.

     Args:
       temperatures: Value mapped together with the broadcast threshold.
       threshold: Value handed to `federated_broadcast`.

     Returns:
       The `federated_sum` result, after asserting that it is a
       `value_base.Value`.
     """
     exceeds = computations.tf_computation(
         lambda x, y: tf.cast(tf.greater(x, y), tf.int32))
     indicators = intrinsics.federated_map(
         exceeds, [temperatures,
                   intrinsics.federated_broadcast(threshold)])
     total = intrinsics.federated_sum(indicators)
     self.assertIsInstance(total, value_base.Value)
     return total
def broadcast_next_fn(state, value):
    """Increments `state.call_count` and broadcasts `value`.

    Args:
      state: Object exposing a `call_count` attribute, which is mapped through
        an int32 "add one" TensorFlow computation.
      value: Value handed to `federated_broadcast`.

    Returns:
      A pair: the zipped `OrderedDict` holding the new `call_count`, and the
      broadcast `value`.
    """

    @computations.tf_computation(tf.int32)
    def add_one(value):
        return value + 1

    bumped_count = intrinsics.federated_map(add_one, state.call_count)
    next_state = intrinsics.federated_zip(
        collections.OrderedDict([('call_count', bumped_count)]))
    return next_state, intrinsics.federated_broadcast(value)
Example #10
0
    def next_fn(server_state, client_val):
        """`next` function for `tff.templates.IterativeProcess`.

        Args:
          server_state: Value handed to `federated_broadcast`.
          client_val: Value reduced with `federated_sum`.

        Returns:
          A pair `(server_update, server_output)`: the summed client value and
          the result of applying the callable obtained from
          `computation_returning_lambda()(1)` to the broadcast state.
        """
        summed_clients = intrinsics.federated_sum(client_val)
        # NOTE(review): `server_output` is rebound below before it is read; the
        # call is kept so the sequence of intrinsic invocations is unchanged.
        server_output = intrinsics.federated_value((), placements.SERVER)
        broadcast_state = intrinsics.federated_broadcast(server_state)
        make_sum_fn = computation_returning_lambda()
        server_output = make_sum_fn(1)(broadcast_state)
        return summed_clients, server_output
Example #11
0
 def next_fn(server_state, client_data):
     """The `next` function for `computation_utils.IterativeProcess`.

     Runs `prepare` on the server state, broadcasts the result, runs `work` on
     the zipped client inputs, sums the client updates with `federated_sum`,
     and runs `update` on the zipped server state and summed update.

     Args:
       server_state: Input to `prepare` and first element zipped for `update`.
       client_data: Value zipped with the broadcast output of `prepare`.

     Returns:
       A triple `(new_server_state, server_output, client_output)`.
     """
     prepared = intrinsics.federated_map(prepare, server_state)
     work_input = intrinsics.federated_zip(
         [client_data, intrinsics.federated_broadcast(prepared)])
     client_updates, client_output = intrinsics.federated_map(
         work, work_input)
     update_input = intrinsics.federated_zip(
         [server_state, intrinsics.federated_sum(client_updates)])
     new_server_state, server_output = intrinsics.federated_map(
         update, update_input)
     return new_server_state, server_output, client_output
Example #12
0
 def _sum_securely(self, value, upper_bound, lower_bound):
     """Securely sums `value` placed at CLIENTS.

     Dispatches on `self._config_mode`. In `_Config.FLOAT` mode the work is
     delegated to `federated_aggregations.secure_quantized_sum`. In
     `_Config.INT` mode the value is shifted with `_client_shift` using the
     broadcast bounds, summed with `federated_secure_sum` at
     `self._secagg_bitwidth`, and shifted back with `_server_shift` using the
     lower bound and the number of summands.

     Args:
       value: The value to sum.
       upper_bound: Upper bound; broadcast in INT mode, passed through in
         FLOAT mode.
       lower_bound: Lower bound; broadcast and reused by `_server_shift` in
         INT mode, passed through in FLOAT mode.

     Returns:
       The summed value.

     Raises:
       ValueError: If `self._config_mode` is neither `_Config.INT` nor
         `_Config.FLOAT`.
     """
     mode = self._config_mode
     if mode == _Config.FLOAT:
         return federated_aggregations.secure_quantized_sum(
             value, lower_bound, upper_bound)
     if mode == _Config.INT:
         upper_at_clients = intrinsics.federated_broadcast(upper_bound)
         lower_at_clients = intrinsics.federated_broadcast(lower_bound)
         shifted = intrinsics.federated_map(
             _client_shift, (value, upper_at_clients, lower_at_clients))
         secure_total = intrinsics.federated_secure_sum(
             shifted, self._secagg_bitwidth)
         num_summands = intrinsics.federated_sum(_client_one())
         return intrinsics.federated_map(
             _server_shift, (secure_total, lower_bound, num_summands))
     raise ValueError(
         f'Unexpected internal config type: {self._config_mode}')
 def next_fn(server_state, client_data):
   """The `next` function for `tff.templates.IterativeProcess`.

   Same prepare / broadcast / work / update shape as a regular round, but the
   client updates are combined only with `federated_secure_sum` (bitwidth 8).

   Args:
     server_state: Input to `prepare` and first element zipped for `update`.
     client_data: Value zipped with the broadcast output of `prepare`.

   Returns:
     A triple `(new_server_state, server_output, client_output)`.
   """
   prepared = intrinsics.federated_map(prepare, server_state)
   work_input = intrinsics.federated_zip(
       [client_data, intrinsics.federated_broadcast(prepared)])
   client_updates, client_output = intrinsics.federated_map(work, work_input)
   # No call to `federated_aggregate`.
   update_input = intrinsics.federated_zip(
       [server_state,
        intrinsics.federated_secure_sum(client_updates, 8)])
   new_server_state, server_output = intrinsics.federated_map(
       update, update_input)
   return new_server_state, server_output, client_output
Example #14
0
 def next_fn(server_state, client_data):
     """The `next` function for `computation_utils.IterativeProcess`.

     Broadcasts the server state directly, runs `work` on the zipped client
     inputs, and zips a plain sum of element 0 with a secure sum (bitwidth 8)
     of element 1 to form the new server state.

     Args:
       server_state: Value handed to `federated_broadcast`.
       client_data: Value zipped with the broadcast server state.

     Returns:
       A pair `(new_server_state, server_output)`; the server output is an
       empty list placed at the server.
     """
     work_input = intrinsics.federated_zip(
         [client_data, intrinsics.federated_broadcast(server_state)])
     client_updates = intrinsics.federated_map(work, work_input)
     new_server_state = intrinsics.federated_zip([
         intrinsics.federated_sum(client_updates[0]),
         intrinsics.federated_secure_sum(client_updates[1], 8),
     ])
     empty_output = intrinsics.federated_value([], placements.SERVER)
     return new_server_state, empty_output
 def next_fn(server_state, client_data):
   """The `next` function for `tff.templates.IterativeProcess`.

   The server state is broadcast as-is, `work` runs on the zipped client
   inputs, and `update` receives the server state zipped with a plain sum of
   update element 0 and a secure sum (bitwidth 8) of update element 1.

   Args:
     server_state: Value broadcast to `work` and zipped for `update`.
     client_data: Value zipped with the broadcast server state.

   Returns:
     A triple `(new_server_state, server_output, client_output)`.
   """
   # No call to `federated_map` with a `prepare` function.
   work_input = intrinsics.federated_zip(
       [client_data, intrinsics.federated_broadcast(server_state)])
   client_updates, client_output = intrinsics.federated_map(work, work_input)
   plain_total = intrinsics.federated_sum(client_updates[0])
   secure_total = intrinsics.federated_secure_sum(client_updates[1], 8)
   update_input = intrinsics.federated_zip(
       [server_state, [plain_total, secure_total]])
   new_server_state, server_output = intrinsics.federated_map(
       update, update_input)
   return new_server_state, server_output, client_output
Example #16
0
    def next_fn(server_state, client_val):
        """`next` function for `tff.utils.IterativeProcess`.

        Args:
          server_state: Value handed to `federated_broadcast`.
          client_val: Value passed to `count_clients_federated`.

        Returns:
          A pair `(server_update, server_output)`: a zipped `OrderedDict` with
          key `num_clients`, and the `federated_sum` of the broadcast state
          after `_bind_tf_function(..., tf.timestamp)`.
        """
        client_count = count_clients_federated(client_val)
        server_update = intrinsics.federated_zip(
            collections.OrderedDict(num_clients=client_count))

        # NOTE(review): `server_output` is rebound below before it is read; the
        # call is kept so the sequence of intrinsic invocations is unchanged.
        server_output = intrinsics.federated_value((), placements.SERVER)
        state_at_clients = intrinsics.federated_broadcast(server_state)
        bound = _bind_tf_function(state_at_clients, tf.timestamp)
        server_output = intrinsics.federated_sum(bound)

        return server_update, server_output
Example #17
0
  def next_fn(server_state, client_val):
    """`next` function for `computation_utils.IterativeProcess`.

    Args:
      server_state: Value handed to `federated_broadcast`.
      client_val: Value passed to `count_clients_federated`.

    Returns:
      A pair `(server_update, server_output)`: a zipped `OrderedDict` with key
      `num_clients`, and the result of `_bind_federated_value` applied to the
      broadcast state, `server_state_type`, and an empty tuple placed at the
      server.
    """
    client_count = count_clients_federated(client_val)
    server_update = intrinsics.federated_zip(
        collections.OrderedDict([('num_clients', client_count)]))

    empty_at_server = intrinsics.federated_value((), placements.SERVER)
    state_at_clients = intrinsics.federated_broadcast(server_state)
    server_output = _bind_federated_value(state_at_clients, server_state_type,
                                          empty_at_server)

    return server_update, server_output
Example #18
0
 def next_fn(server_state, client_data):
     """The `next` function for `computation_utils.IterativeProcess`.

     Runs `prepare`, broadcasts, and runs `work`, but discards the first
     output of `work`; `update` instead receives two constant `1` values
     placed at the server.

     Args:
       server_state: Input to `prepare` and first element zipped for `update`.
       client_data: Value zipped with the broadcast output of `prepare`.

     Returns:
       A triple `(new_server_state, server_output, client_output)`.
     """
     prepared = intrinsics.federated_map(prepare, server_state)
     work_input = intrinsics.federated_zip(
         [client_data, intrinsics.federated_broadcast(prepared)])
     _, client_output = intrinsics.federated_map(work, work_input)
     update_input = intrinsics.federated_zip([
         server_state,
         [
             intrinsics.federated_value(1, placements.SERVER),
             intrinsics.federated_value(1, placements.SERVER),
         ],
     ])
     new_server_state, server_output = intrinsics.federated_map(
         update, update_input)
     return new_server_state, server_output, client_output
Example #19
0
 def _compute_measurements(self, upper_bound, lower_bound, value_max,
                           value_min):
     """Creates measurements to be reported. All values are summed securely.

     For each bound, an indicator is computed by comparing the broadcast bound
     with the corresponding client value (`bound < value` for the upper bound,
     `bound > value` for the lower bound), cast to `COUNT_TF_TYPE`, and summed
     with `federated_secure_sum` at bitwidth 1.

     Args:
       upper_bound: Upper threshold; broadcast and also reported.
       lower_bound: Lower threshold; broadcast and also reported.
       value_max: Value compared against the upper bound.
       value_min: Value compared against the lower bound.

     Returns:
       A zipped `OrderedDict` with keys `upper_bound_clipped_count`,
       `lower_bound_clipped_count`, `upper_bound_threshold`, and
       `lower_bound_threshold`.
     """
     exceeds_upper = computations.tf_computation(
         lambda bound, value: tf.cast(bound < value, COUNT_TF_TYPE))
     upper_flags = intrinsics.federated_map(
         exceeds_upper,
         (intrinsics.federated_broadcast(upper_bound), value_max))
     upper_clipped_total = intrinsics.federated_secure_sum(upper_flags,
                                                           bitwidth=1)

     below_lower = computations.tf_computation(
         lambda bound, value: tf.cast(bound > value, COUNT_TF_TYPE))
     lower_flags = intrinsics.federated_map(
         below_lower,
         (intrinsics.federated_broadcast(lower_bound), value_min))
     lower_clipped_total = intrinsics.federated_secure_sum(lower_flags,
                                                           bitwidth=1)

     return intrinsics.federated_zip(
         collections.OrderedDict([
             ('upper_bound_clipped_count', upper_clipped_total),
             ('lower_bound_clipped_count', lower_clipped_total),
             ('upper_bound_threshold', upper_bound),
             ('lower_bound_threshold', lower_bound),
         ]))
Example #20
0
        def next_fn(state, value, weight):
            """Clips and zeroes `value`, aggregates it, and updates all sub-states.

            Args:
              state: A 4-element structure unpacked as the clipping-norm
                state, the inner aggregation state, the clipped-count state,
                and the zeroed-count state.
              value: Value passed through `clip_and_zero` with the broadcast
                clipping and zeroing norms.
              weight: Forwarded to `inner_agg_process.next`.

            Returns:
              A `measured_process.MeasuredProcessOutput` with the zipped new
              state, the inner aggregation result, and zipped measurements.
            """
            norm_state, inner_state, clipped_state, zeroed_state = state

            clipping_norm = self._clipping_norm_process.report(norm_state)
            zeroing_norm = intrinsics.federated_map(self._zeroing_norm_fn,
                                                    clipping_norm)

            clip_inputs = (value,
                           intrinsics.federated_broadcast(clipping_norm),
                           intrinsics.federated_broadcast(zeroing_norm))
            processed, global_norm, was_clipped, was_zeroed = (
                intrinsics.federated_map(clip_and_zero, clip_inputs))

            next_norm_state = self._clipping_norm_process.next(
                norm_state, global_norm)
            agg_output = inner_agg_process.next(inner_state, processed, weight)
            clipped_output = clipped_count_agg_process.next(
                clipped_state, was_clipped)
            zeroed_output = zeroed_count_agg_process.next(
                zeroed_state, was_zeroed)

            next_state = collections.OrderedDict([
                ('clipping_norm', next_norm_state),
                ('inner_agg', agg_output.state),
                ('clipped_count_agg', clipped_output.state),
                ('zeroed_count_agg', zeroed_output.state),
            ])
            metrics = collections.OrderedDict([
                ('agg_process', agg_output.measurements),
                ('clipping_norm', clipping_norm),
                ('zeroing_norm', zeroing_norm),
                ('clipped_count', clipped_output.result),
                ('zeroed_count', zeroed_output.result),
            ])

            return measured_process.MeasuredProcessOutput(
                state=intrinsics.federated_zip(next_state),
                result=agg_output.result,
                measurements=intrinsics.federated_zip(metrics))
    def comp(temperatures, threshold):
      """Maps a per-sequence "count above threshold" over the zipped inputs.

      Args:
        temperatures: Value zipped with the broadcast threshold; the inner
          computation declares its element as a sequence of `tf.float32`.
        threshold: Value handed to `federated_broadcast`; declared as
          `tf.float32` by the inner computation.

      Returns:
        The result of `federated_map` with the inner `count` computation.
      """

      @computations.tf_computation(
          computation_types.SequenceType(tf.float32), tf.float32)
      def count(ds, t):
        # Fold over the dataset, adding 1 for each element greater than `t`.
        return ds.reduce(
            np.int32(0), lambda n, x: n + tf.cast(tf.greater(x, t), tf.int32))

      threshold_at_clients = intrinsics.federated_broadcast(threshold)
      paired = intrinsics.federated_zip([temperatures, threshold_at_clients])
      return intrinsics.federated_map(count, paired)
 def next_fn(server_state, client_data):
   """The `next` function for `tff.templates.IterativeProcess`.

   Runs `prepare`, broadcasts, and runs `work`; the new server state is the
   zip of a plain sum of update element 0 and a secure sum (bitwidth 8) of
   update element 1, with no `update` step.

   Args:
     server_state: Input to `prepare`.
     client_data: Value zipped with the broadcast output of `prepare`.

   Returns:
     A triple `(new_server_state, server_output, client_output)`; the server
     output is an empty list placed at the server.
   """
   prepared = intrinsics.federated_map(prepare, server_state)
   work_input = intrinsics.federated_zip(
       [client_data, intrinsics.federated_broadcast(prepared)])
   client_updates, client_output = intrinsics.federated_map(work, work_input)
   new_server_state = intrinsics.federated_zip([
       intrinsics.federated_sum(client_updates[0]),
       intrinsics.federated_secure_sum(client_updates[1], 8),
   ])
   # No call to `federated_map` with an `update` function.
   empty_output = intrinsics.federated_value([], placements.SERVER)
   return new_server_state, empty_output, client_output
Example #23
0
    def personalization_eval(server_model_weights, federated_client_input):
        """TFF orchestration logic.

        Broadcasts the server model weights, runs `_client_computation` on the
        broadcast weights paired with the client input, and samples the
        resulting metrics with `federated_sample`.

        Args:
          server_model_weights: Value handed to `federated_broadcast`.
          federated_client_input: Second element of the pair mapped through
            `_client_computation`.

        Returns:
          The result of `federated_aggregations.federated_sample` called with
          the client metrics and `max_num_clients`.
        """
        weights_at_clients = intrinsics.federated_broadcast(
            server_model_weights)
        per_client_metrics = intrinsics.federated_map(
            _client_computation, (weights_at_clients, federated_client_input))

        # WARNING: Collecting information from clients can be risky. Users have to
        # make sure that it is proper to collect those metrics from clients.
        # TODO(b/147889283): Add a link to the TFF doc once it exists.
        return federated_aggregations.federated_sample(
            per_client_metrics, max_num_clients)
 def next_fn(server_state, client_data):
     """The `next` function for `tff.utils.IterativeProcess`.

     Runs `prepare`, broadcasts, runs `work`, and feeds `update` the server
     state zipped with a plain sum of update element 0 and a secure sum
     (bitwidth 8) of update element 1. Nothing is returned for the clients.

     Args:
       server_state: Input to `prepare` and first element zipped for `update`.
       client_data: Value zipped with the broadcast output of `prepare`.

     Returns:
       A pair `(new_server_state, server_output)`.
     """
     prepared = intrinsics.federated_map(prepare, server_state)
     work_input = intrinsics.federated_zip(
         [client_data, intrinsics.federated_broadcast(prepared)])
     # No client output.
     client_updates = intrinsics.federated_map(work, work_input)
     plain_total = intrinsics.federated_sum(client_updates[0])
     secure_total = intrinsics.federated_secure_sum(client_updates[1], 8)
     update_input = intrinsics.federated_zip(
         [server_state, [plain_total, secure_total]])
     new_server_state, server_output = intrinsics.federated_map(
         update, update_input)
     return new_server_state, server_output
Example #25
0
 def test_fails_stateful_broadcast_and_process(self):
   """Passing both broadcast arguments raises `DisjointArgumentError`.

   Calls `build_model_delta_optimizer_process` with both
   `stateful_model_broadcast_fn` and `broadcast_process` set and expects
   `optimizer_utils.DisjointArgumentError`.
   """
   weights_type = model_utils.weights_type_from_model(
       model_examples.LinearRegression)

   def _broadcast_next(state, weights):
     return (state, intrinsics.federated_broadcast(weights))

   with self.assertRaises(optimizer_utils.DisjointArgumentError):
     # Both broadcast arguments are built inside the `with` block so any
     # error they raise is still seen by `assertRaises`.
     stateful_broadcast = computation_utils.StatefulBroadcastFn(
         initialize_fn=lambda: (), next_fn=_broadcast_next)
     stateless_process = optimizer_utils.build_stateless_broadcaster(
         model_weights_type=weights_type)
     optimizer_utils.build_model_delta_optimizer_process(
         model_fn=model_examples.LinearRegression,
         model_to_client_delta_fn=DummyClientDeltaFn,
         server_optimizer_fn=tf.keras.optimizers.SGD,
         stateful_model_broadcast_fn=stateful_broadcast,
         broadcast_process=stateless_process)
        def next_fn(state, value, weight):
            """Clips and zeroes `value`, aggregates it, and updates the norm state.

            Args:
              state: A 2-element structure unpacked as the clipping-norm state
                and the inner aggregation state.
              value: Value passed through `clip_and_zero` with the broadcast
                clipping and zeroing norms.
              weight: Forwarded to `inner_agg_process.next`.

            Returns:
              A `measured_process.MeasuredProcessOutput` whose state is the
              zip of the new clipping-norm state and the inner aggregation
              state, and whose result and measurements come from the inner
              aggregation output.
            """
            norm_state, inner_state = state

            clipping_norm = self._clipping_norm_process.report(norm_state)
            zeroing_norm = intrinsics.federated_map(self._zeroing_norm_fn,
                                                    clipping_norm)

            clip_inputs = (value,
                           intrinsics.federated_broadcast(clipping_norm),
                           intrinsics.federated_broadcast(zeroing_norm))
            processed, global_norm = intrinsics.federated_map(
                clip_and_zero, clip_inputs)

            agg_output = inner_agg_process.next(inner_state, processed, weight)
            next_norm_state = self._clipping_norm_process.next(
                norm_state, global_norm)

            next_state = intrinsics.federated_zip(
                (next_norm_state, agg_output.state))
            return measured_process.MeasuredProcessOutput(
                state=next_state,
                result=agg_output.result,
                measurements=agg_output.measurements)
Example #27
0
 def next_fn(server_state, client_data):
     """The `next` function for `tff.templates.IterativeProcess`.

     The output of `prepare` is passed through
     `broadcast_and_return_arg_and_result`; only the second element of its
     result is used, and it is broadcast again before `work` runs. `update`
     receives the server state zipped with a plain sum of update element 0
     and a secure sum (bitwidth 8) of update element 1.

     Args:
       server_state: Input to `prepare` and first element zipped for `update`.
       client_data: Value zipped with the broadcast client input.

     Returns:
       A pair `(new_server_state, server_output)`.
     """
     prepared = intrinsics.federated_map(prepare, server_state)
     _, to_broadcast = broadcast_and_return_arg_and_result(prepared)
     work_input = intrinsics.federated_zip(
         [client_data, intrinsics.federated_broadcast(to_broadcast)])
     client_updates = intrinsics.federated_map(work, work_input)
     plain_total = intrinsics.federated_sum(client_updates[0])
     secure_total = intrinsics.federated_secure_sum(client_updates[1], 8)
     update_input = intrinsics.federated_zip(
         [server_state, [plain_total, secure_total]])
     new_server_state, server_output = intrinsics.federated_map(
         update, update_input)
     return new_server_state, server_output
Example #28
0
  def next_fn(state, value):
    """Encodes `value`, aggregates the encodings, decodes, and updates state.

    Args:
      state: Input to `get_params_fn` and to `update_state_fn`.
      value: Value encoded with `encode_fn` using the broadcast parameters.

    Returns:
      A `measured_process.MeasuredProcessOutput` with the updated state, the
      decoded aggregate as the result, and an empty tuple placed at the server
      as the measurements.
    """
    params = intrinsics.federated_map(get_params_fn, state)
    encode_params, decode_before_sum_params, decode_after_sum_params = params
    encode_at_clients = intrinsics.federated_broadcast(encode_params)
    decode_before_at_clients = intrinsics.federated_broadcast(
        decode_before_sum_params)

    encoded = intrinsics.federated_map(
        encode_fn, [value, encode_at_clients, decode_before_at_clients])

    aggregated = intrinsics.federated_aggregate(
        encoded, zero_fn(), accumulate_fn, merge_fn, report_fn)

    decoded = intrinsics.federated_map(
        decode_after_sum_fn, [aggregated.values, decode_after_sum_params])

    next_state = intrinsics.federated_map(
        update_state_fn, [state, aggregated.state_update_tensors])

    no_measurements = intrinsics.federated_value((), placements.SERVER)
    return measured_process.MeasuredProcessOutput(
        state=next_state, result=decoded, measurements=no_measurements)
Example #29
0
    def next_fn(server_state, client_data):
      """Broadcasts the state, adds one per client, and sums the results.

      Args:
        server_state: Value handed to `federated_broadcast`; the inner
          computation declares it as `tf.int32`.
        client_data: Second element of the mapped pair; declared as a
          sequence of `tf.float32` and ignored by the inner computation.

      Returns:
        A pair `(aggregate_update, server_output)`; the server output is the
        constant `1234` placed at the server.
      """
      state_at_clients = intrinsics.federated_broadcast(server_state)

      @computations.tf_computation(tf.int32,
                                   computation_types.SequenceType(tf.float32))
      @tf.function
      def some_transform(x, y):
        del y  # Unused
        return x + 1

      per_client = intrinsics.federated_map(some_transform,
                                            (state_at_clients, client_data))
      summed = intrinsics.federated_sum(per_client)
      constant_output = intrinsics.federated_value(1234, placements.SERVER)
      return summed, constant_output
Example #30
0
 def server_eval(server_model_weights, federated_dataset):
     """Evaluates the broadcast model weights on each client dataset.

     When `broadcast_process` is `None` the weights are sent with
     `federated_broadcast`; otherwise the result of `broadcast_process.next`
     on a freshly initialized state is used.

     Args:
       server_model_weights: Model weights to deliver to `client_eval`.
       federated_dataset: Dataset argument paired with the weights.

     Returns:
       The result of `model.federated_output_computation` applied to the
       `local_outputs` of the mapped `client_eval` results.
     """
     if broadcast_process is None:
         eval_args = [
             intrinsics.federated_broadcast(server_model_weights),
             federated_dataset
         ]
     else:
         # TODO(b/179091838): Confirm that the process has no state.
         # TODO(b/179091838): Zip the measurements from the broadcast_process with
         # the result of `model.federated_output_computation` below to avoid
         # dropping these metrics.
         broadcast_output = broadcast_process.next(
             broadcast_process.initialize(), server_model_weights)
         eval_args = (broadcast_output.result, federated_dataset)
     client_outputs = intrinsics.federated_map(client_eval, eval_args)
     return model.federated_output_computation(client_outputs.local_outputs)