Ejemplo n.º 1
0
    def count_clients_federated(client_data):
        """Returns the federated sum of a scalar int32 `1` bound per client.

        NOTE(review): `_bind_tf_function` is defined elsewhere; presumably it
        evaluates `client_ones_fn` once per client of `client_data`, making
        the sum a client count — confirm.
        """

        @tf.function
        def client_ones_fn():
            # Scalar (shape []) int32 one: the per-client contribution.
            return tf.ones(shape=[], dtype=tf.int32)

        return intrinsics.federated_sum(
            _bind_tf_function(client_data, client_ones_fn))
Ejemplo n.º 2
0
 def next_fn(strings, val):
     """Builds a `MeasuredProcessOutput` for one step.

     Returns:
       `MeasuredProcessOutput(new_state, summed_val, measurement)` where
       `new_state` is `strings` with the constant `'abc'` concatenated on
       axis 0, `summed_val` is `federated_sum(val)`, and `measurement` is the
       constant `1` placed at SERVER.
     """
     append_abc = computations.tf_computation()(
         lambda s: tf.concat([s, tf.constant(['abc'])], axis=0))
     new_state = intrinsics.federated_map(append_abc, strings)
     summed_val = intrinsics.federated_sum(val)
     measurement = intrinsics.federated_value(1, placements.SERVER)
     return MeasuredProcessOutput(new_state, summed_val, measurement)
Ejemplo n.º 3
0
    def count_clients_federated(client_data):
        """Returns the federated sum of a CLIENTS-placed constant `1`.

        NOTE(review): `_bind_federated_value` is defined elsewhere; presumably
        it ties the ones to `client_data` (typed here as a sequence of
        strings) so the sum counts participating clients — confirm.
        """
        one_per_client = intrinsics.federated_value(1, placements.CLIENTS)
        dataset_type = computation_types.SequenceType(tf.string)
        bound_ones = _bind_federated_value(client_data, dataset_type,
                                           one_per_client)
        return intrinsics.federated_sum(bound_ones)
Ejemplo n.º 4
0
 def next_comp(state, value, weight):
     """Returns an OrderedDict with `state`, `result` and `measurements`.

     `state` has `_add_one` mapped over it, `result` is the weighted federated
     mean of `value`, and `measurements` zips the number of clients (a
     federated sum of CLIENTS-placed ones).
     """
     incremented_state = intrinsics.federated_map(_add_one, state)
     weighted_mean = intrinsics.federated_mean(value, weight)
     client_count = intrinsics.federated_sum(
         intrinsics.federated_value(1, placements.CLIENTS))
     zipped_measurements = intrinsics.federated_zip(
         collections.OrderedDict(num_clients=client_count))
     return collections.OrderedDict(
         state=incremented_state,
         result=weighted_mean,
         measurements=zipped_measurements)
Ejemplo n.º 5
0
 def encoded_mean_fn(state, values, weight):
   """Weights `values`, sums them via `encoded_sum_fn`, divides by total weight.

   NOTE(review): `multiply_fn`, `divide_fn` and `encoded_sum_fn` are defined
   elsewhere; presumably they implement value*weight, sum/weight and an
   encoded federated sum respectively — confirm.

   Returns:
     A `(updated_state, decoded_values)` tuple.
   """
   weighted = intrinsics.federated_map(multiply_fn, [values, weight])
   new_state, weighted_total = encoded_sum_fn(state, weighted)
   weight_total = intrinsics.federated_sum(weight)
   mean = intrinsics.federated_map(divide_fn, [weighted_total, weight_total])
   return new_state, mean
Ejemplo n.º 6
0
 def foo(temperatures, threshold):
     """Sums, across clients, whether each temperature exceeds `threshold`.

     `threshold` is broadcast to the clients, each client compares its
     float32 `temperatures` value against it, and the int32 0/1 indicators
     are federated-summed.

     `tf.to_int32` no longer exists in TensorFlow 2.x; `tf.cast(..., tf.int32)`
     is the supported equivalent.
     """
     exceeds_threshold = computations.tf_computation(
         lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
         [tf.float32, tf.float32])
     return intrinsics.federated_sum(
         intrinsics.federated_map(
             exceeds_threshold,
             [temperatures,
              intrinsics.federated_broadcast(threshold)]))
Ejemplo n.º 7
0
 def fed_output(local_outputs):
   """Aggregates per-client outputs.

   Returns:
     A dict with 'num_examples' (federated sum of
     `local_outputs.num_examples`) and 'loss' (federated mean of
     `local_outputs.loss` weighted by `local_outputs.num_examples_float`).
   """
   # TODO(b/124070381): Remove need for using num_examples_float here.
   total_examples = intrinsics.federated_sum(local_outputs.num_examples)
   mean_loss = intrinsics.federated_mean(
       local_outputs.loss, weight=local_outputs.num_examples_float)
   return {'num_examples': total_examples, 'loss': mean_loss}
Ejemplo n.º 8
0
 def foo(temperatures, threshold):
     """Sums per-client `x > threshold` indicators and checks the result type.

     `self` comes from the enclosing test method's scope.
     """
     exceeds = computations.tf_computation(
         lambda x, y: tf.cast(tf.greater(x, y), tf.int32))
     threshold_at_clients = intrinsics.federated_broadcast(threshold)
     total = intrinsics.federated_sum(
         intrinsics.federated_map(exceeds,
                                  [temperatures, threshold_at_clients]))
     self.assertIsInstance(total, value_base.Value)
     return total
Ejemplo n.º 9
0
    def next_fn(server_state, client_val):
        """`next` function for `tff.templates.IterativeProcess`.

        Returns:
          `(server_update, server_output)`: the federated sum of `client_val`,
          and the result of applying `computation_returning_lambda()(1)` to
          the broadcast `server_state`.
        """
        server_update = intrinsics.federated_sum(client_val)
        state_at_clients = intrinsics.federated_broadcast(server_state)
        # NOTE(review): `computation_returning_lambda` is defined elsewhere;
        # presumably it returns a factory whose result is applied to the
        # broadcast state — confirm.
        lambda_returning_sum = computation_returning_lambda()
        sum_fn = lambda_returning_sum(1)
        server_output = sum_fn(state_at_clients)

        return server_update, server_output
Ejemplo n.º 10
0
 def next_fn(server_state, client_data):
     """The `next` function for `computation_utils.IterativeProcess`.

     Pipeline: `prepare` on the server, broadcast, `work` on each client
     (zipped with its data), federated sum of client updates, then `update`
     on the server.

     Returns:
       `(new_server_state, server_output, client_output)`.
     """
     prepared = intrinsics.federated_map(prepare, server_state)
     broadcast_input = intrinsics.federated_broadcast(prepared)
     work_input = intrinsics.federated_zip([client_data, broadcast_input])
     client_updates, client_output = intrinsics.federated_map(work, work_input)
     summed_updates = intrinsics.federated_sum(client_updates)
     update_input = intrinsics.federated_zip([server_state, summed_updates])
     new_server_state, server_output = intrinsics.federated_map(
         update, update_input)
     return new_server_state, server_output, client_output
Ejemplo n.º 11
0
 def next_fn(server_state, client_data):
     """The `next` function for `computation_utils.IterativeProcess`.

     Ignores `server_state`. Runs `work` on `client_data`, then combines a
     plain federated sum of the first update component with a secure sum
     (bitwidth 8) of the second, and passes both to `update`.

     Returns:
       `(new_server_state, server_output, client_output)`.
     """
     del server_state  # Unused
     updates, client_output = intrinsics.federated_map(work, client_data)
     plain_sum = intrinsics.federated_sum(updates[0])
     secure_sum = intrinsics.federated_secure_sum(updates[1], 8)
     new_server_state, server_output = intrinsics.federated_map(
         update, intrinsics.federated_zip([plain_sum, secure_sum]))
     return new_server_state, server_output, client_output
Ejemplo n.º 12
0
 def encoded_mean_comp(state, values, weight):
   """Encoded mean federated_computation.

   Weights `values` with `multiply_fn`, sums them with `encoded_sum_fn`, and
   divides by the federated sum of `weight` using `divide_fn`.

   Returns:
     A `MeasuredProcessOutput` whose measurements are an empty SERVER tuple.
   """
   no_metrics = intrinsics.federated_value((), placements.SERVER)
   weighted = intrinsics.federated_map(multiply_fn, (values, weight))
   new_state, weighted_total = encoded_sum_fn(state, weighted)
   weight_total = intrinsics.federated_sum(weight)
   mean = intrinsics.federated_map(divide_fn, (weighted_total, weight_total))
   return measured_process.MeasuredProcessOutput(
       state=new_state, result=mean, measurements=no_metrics)
Ejemplo n.º 13
0
 def next_fn(server_state, client_data):
     """The `next` function for `computation_utils.IterativeProcess`.

     Broadcasts `server_state`, runs `work` on each client, and zips the
     plain sum of the first update component with the secure sum
     (bitwidth 8) of the second as the new state. No `update` step.

     Returns:
       `(new_server_state, server_output)` with an empty-list server output.
     """
     state_at_clients = intrinsics.federated_broadcast(server_state)
     client_updates = intrinsics.federated_map(
         work, intrinsics.federated_zip([client_data, state_at_clients]))
     plain_sum = intrinsics.federated_sum(client_updates[0])
     secure_sum = intrinsics.federated_secure_sum(client_updates[1], 8)
     new_server_state = intrinsics.federated_zip([plain_sum, secure_sum])
     return new_server_state, intrinsics.federated_value([], placements.SERVER)
 def next_fn(server_state, client_data):
   """The `next` function for `tff.templates.IterativeProcess`.

   Runs `work` directly on `client_data`, then feeds `server_state` plus
   the [plain sum, secure sum (bitwidth 8)] of the update components to
   `update`.

   Returns:
     `(new_server_state, server_output, client_output)`.
   """
   # No call to `federated_map` with prepare.
   # No call to `federated_broadcast`.
   updates, client_output = intrinsics.federated_map(work, client_data)
   summed = [
       intrinsics.federated_sum(updates[0]),
       intrinsics.federated_secure_sum(updates[1], 8),
   ]
   new_server_state, server_output = intrinsics.federated_map(
       update, intrinsics.federated_zip([server_state, summed]))
   return new_server_state, server_output, client_output
Ejemplo n.º 15
0
 def next_fn(state, value):
     """Increments `state`, sums `value` and adds one to every leaf of the sum.

     Returns:
       `MeasuredProcessOutput(new_state, result, measurements)` where
       `measurements` is `MEASUREMENT_CONSTANT` placed at SERVER.
     """
     increment = computations.tf_computation(lambda x: x + 1)
     new_state = intrinsics.federated_map(increment, state)
     increment_leaves = computations.tf_computation(
         lambda x: tf.nest.map_structure(lambda y: y + 1, x))
     total = intrinsics.federated_sum(value)
     result = intrinsics.federated_map(increment_leaves, total)
     measurements = intrinsics.federated_value(MEASUREMENT_CONSTANT,
                                               placements.SERVER)
     return measured_process.MeasuredProcessOutput(
         new_state, result, measurements)
Ejemplo n.º 16
0
    def next_fn(server_state, client_val):
        """`next` function for `tff.utils.IterativeProcess`.

        Returns:
          `(server_update, server_output)` where `server_update` is the
          zipped `OrderedDict(num_clients=count_clients_federated(client_val))`
          and `server_output` is the federated sum of `tf.timestamp` bound to
          the broadcast `server_state`.
        """
        server_update = intrinsics.federated_zip(
            collections.OrderedDict(
                num_clients=count_clients_federated(client_val)))

        # NOTE(review): `_bind_tf_function` is defined elsewhere; presumably it
        # evaluates `tf.timestamp` once per client — confirm.
        server_output = intrinsics.federated_sum(
            _bind_tf_function(intrinsics.federated_broadcast(server_state),
                              tf.timestamp))

        return server_update, server_output
Ejemplo n.º 17
0
 def next_fn(server_state, client_data):
     """The `next` function for `tff.templates.IterativeProcess`.

     `prepare` on the server, broadcast, `work` per client, then `update` on
     `server_state` plus [plain sum, secure sum (bitwidth 8)] of the two
     update components.

     Returns:
       `(new_server_state, server_output)`.
     """
     prepared = intrinsics.federated_map(prepare, server_state)
     work_input = intrinsics.federated_zip(
         [client_data, intrinsics.federated_broadcast(prepared)])
     updates = intrinsics.federated_map(work, work_input)
     summed = [
         intrinsics.federated_sum(updates[0]),
         intrinsics.federated_secure_sum(updates[1], 8),
     ]
     new_server_state, server_output = intrinsics.federated_map(
         update, intrinsics.federated_zip([server_state, summed]))
     return new_server_state, server_output
 def next_fn(server_state, client_data):
   """The `next` function for `tff.templates.IterativeProcess`.

   `prepare` on the server, broadcast, `work` per client; the new state is
   the zip of the plain sum and secure sum (bitwidth 8) of the two update
   components.

   Returns:
     `(new_server_state, server_output, client_output)` with an empty-list
     server output.
   """
   prepared = intrinsics.federated_map(prepare, server_state)
   broadcast_input = intrinsics.federated_broadcast(prepared)
   updates, client_output = intrinsics.federated_map(
       work, intrinsics.federated_zip([client_data, broadcast_input]))
   new_server_state = intrinsics.federated_zip([
       intrinsics.federated_sum(updates[0]),
       intrinsics.federated_secure_sum(updates[1], 8),
   ])
   # No call to `federated_map` with an `update` function.
   server_output = intrinsics.federated_value([], placements.SERVER)
   return new_server_state, server_output, client_output
 def next_fn(server_state, client_data):
   """The `next` function for `tff.templates.IterativeProcess`.

   Ignores `server_state`; runs `work` on `client_data` and zips the plain
   sum with the secure sum (bitwidth 8) of the two update components.

   Returns:
     `(new_server_state, server_output)` with an empty-list server output.
   """
   del server_state  # Unused
   # No call to `federated_map` with prepare.
   # No call to `federated_broadcast`.
   updates = intrinsics.federated_map(work, client_data)
   new_server_state = intrinsics.federated_zip([
       intrinsics.federated_sum(updates[0]),
       intrinsics.federated_secure_sum(updates[1], 8),
   ])
   # No call to `federated_map` with an `update` function.
   return new_server_state, intrinsics.federated_value([], placements.SERVER)
Ejemplo n.º 20
0
        def next_fn(state, value):
            """Divides `value_sum_process`'s result by the number of clients.

            Returns:
              `MeasuredProcessOutput` carrying the sum process's state, the
              `_div` of its result by the client count, and its measurements
              zipped under the key 'mean_value'.
            """
            sum_output = value_sum_process.next(state, value)
            num_clients = intrinsics.federated_sum(
                intrinsics.federated_value(1, placements.CLIENTS))

            mean = intrinsics.federated_map(
                _div, (sum_output.result, num_clients))
            measurements = intrinsics.federated_zip(
                collections.OrderedDict(mean_value=sum_output.measurements))

            return measured_process.MeasuredProcessOutput(
                sum_output.state, mean, measurements)
Ejemplo n.º 21
0
  def next_fn(global_state, value, weight):
    """One round of the adaptive-zeroing weighted mean.

    Sample params derived from `global_state` are broadcast; each client runs
    `preprocess_value`, and the four outputs are federated-summed. The mean
    is `divide_no_nan(value_sum, total_weight)`; `next_quantile` produces the
    new threshold and state.

    NOTE(review): `too_large` presumably flags values that were zeroed, so
    `num_zeroed` counts them — confirm against `preprocess_value`.

    Returns:
      `MeasuredProcessOutput(state, result, measurements)` with
      `AdaptiveZeroingMetrics(new_threshold, num_zeroed)` as measurements.
    """
    server_params = intrinsics.federated_map(derive_sample_params,
                                             global_state)
    sample_params = intrinsics.federated_broadcast(server_params)
    preprocessed = intrinsics.federated_map(preprocess_value,
                                            (sample_params, value, weight))
    weighted_value, adj_weight, quantile_record, too_large = preprocessed

    value_sum, total_weight, quantile_sum, num_zeroed = [
        intrinsics.federated_sum(component)
        for component in (weighted_value, adj_weight, quantile_record,
                          too_large)
    ]

    mean_value = intrinsics.federated_map(divide_no_nan,
                                          (value_sum, total_weight))
    new_threshold, new_global_state = intrinsics.federated_map(
        next_quantile, (quantile_sum, global_state))
    metrics = intrinsics.federated_zip(
        AdaptiveZeroingMetrics(new_threshold, num_zeroed))

    return measured_process.MeasuredProcessOutput(
        state=new_global_state, result=mean_value, measurements=metrics)
Ejemplo n.º 22
0
    def next_fn(server_state, client_data):
      """Adds one to the broadcast int32 state on each client and sums it.

      `client_data` (a float32 sequence per the type signature below) is
      passed to each client but not used by the transform.

      Returns:
        `(aggregate_update, server_output)` with constant server output 1234.
      """
      state_at_clients = intrinsics.federated_broadcast(server_state)

      @computations.tf_computation(tf.int32,
                                   computation_types.SequenceType(tf.float32))
      @tf.function
      def some_transform(x, y):
        del y  # Unused
        return x + 1

      per_client = intrinsics.federated_map(some_transform,
                                            (state_at_clients, client_data))
      total = intrinsics.federated_sum(per_client)
      return total, intrinsics.federated_value(1234, placements.SERVER)
Ejemplo n.º 23
0
    def one_round_computation(server_state, federated_dataset):
        """Orchestration logic for one round of optimization.

        Args:
          server_state: a `tff.learning.framework.ServerState` named tuple.
          federated_dataset: a federated `tf.Dataset` with placement
            tff.CLIENTS.

        Returns:
          A tuple `(new_server_state, measurements)`. `new_server_state` is a
          zipped `ServerState`; `measurements` is a zipped OrderedDict with
          keys 'broadcast', 'aggregation', 'train' (the result of
          `federated_output_computation`) and 'stat' (the federated sum of
          the clients' optimizer outputs). Both are presumably placed at
          `tff.SERVER`.
        """
        broadcast_output = broadcast_process.next(
            server_state.model_broadcast_state, server_state.model)
        client_outputs = intrinsics.federated_map(
            _compute_local_training_and_client_delta,
            (federated_dataset, broadcast_output.result))

        # Weighted aggregators take (state, value, weight); unweighted ones
        # take only (state, value).
        num_agg_params = len(aggregation_process.next.type_signature.parameter)
        aggregation_args = [server_state.delta_aggregate_state,
                            client_outputs.weights_delta]
        if num_agg_params == 3:
            aggregation_args.append(client_outputs.weights_delta_weight)
        aggregation_output = aggregation_process.next(*aggregation_args)

        new_global_model, new_optimizer_state = intrinsics.federated_map(
            server_update, (server_state.model, aggregation_output.result,
                            server_state.optimizer_state))
        new_server_state = intrinsics.federated_zip(
            ServerState(new_global_model, new_optimizer_state,
                        aggregation_output.state, broadcast_output.state))

        train_metrics = dummy_model_for_metadata.federated_output_computation(
            client_outputs.model_output)
        optimizer_stats = intrinsics.federated_sum(
            client_outputs.optimizer_output)
        measurements = intrinsics.federated_zip(
            collections.OrderedDict(
                broadcast=broadcast_output.measurements,
                aggregation=aggregation_output.measurements,
                train=train_metrics,
                stat=optimizer_stats))
        return new_server_state, measurements
Ejemplo n.º 24
0
 def _sum_securely(self, value, upper_bound, lower_bound):
     """Securely sums `value` placed at CLIENTS.

     INT mode: each client value is passed through `_client_shift` with the
     broadcast bounds, securely summed at `self._secagg_bitwidth`, and then
     passed through `_server_shift` with `lower_bound` and the number of
     summands. FLOAT mode delegates to
     `federated_aggregations.secure_quantized_sum`.

     Raises:
       ValueError: if `self._config_mode` is neither `_Config.INT` nor
         `_Config.FLOAT`.
     """
     if self._config_mode == _Config.INT:
         upper_at_clients = intrinsics.federated_broadcast(upper_bound)
         lower_at_clients = intrinsics.federated_broadcast(lower_bound)
         # NOTE(review): presumably shifts values into a non-negative integer
         # range suitable for secure sum — confirm `_client_shift`.
         shifted = intrinsics.federated_map(
             _client_shift, (value, upper_at_clients, lower_at_clients))
         summed = intrinsics.federated_secure_sum(shifted,
                                                  self._secagg_bitwidth)
         num_summands = intrinsics.federated_sum(_client_one())
         return intrinsics.federated_map(
             _server_shift, (summed, lower_bound, num_summands))
     if self._config_mode == _Config.FLOAT:
         return federated_aggregations.secure_quantized_sum(
             value, lower_bound, upper_bound)
     raise ValueError(
         f'Unexpected internal config type: {self._config_mode}')
Ejemplo n.º 25
0
 def foo(x):
     """Returns the federated sum of `x`."""
     summed = intrinsics.federated_sum(x)
     return summed
Ejemplo n.º 26
0
 def foo(x):
     """Returns the federated sum of `x` after checking it is a `Value`.

     `self` comes from the enclosing test method's scope.
     """
     total = intrinsics.federated_sum(x)
     self.assertIsInstance(total, value_base.Value)
     return total
Ejemplo n.º 27
0
 def namedtuple_next_fn(state, client_values):
     """Returns `learning_process_output(state, metrics)`.

     `metrics` is the federated sum of `sum_sequence` mapped over
     `client_values`.
     """
     per_client = intrinsics.federated_map(sum_sequence, client_values)
     return learning_process_output(state, intrinsics.federated_sum(per_client))
Ejemplo n.º 28
0
 def odict_next_fn(state, client_values):
     """Returns `OrderedDict(state=state, metrics=...)`.

     `metrics` is the federated sum of `sum_sequence` mapped over
     `client_values`.
     """
     per_client = intrinsics.federated_map(sum_sequence, client_values)
     return collections.OrderedDict(
         state=state, metrics=intrinsics.federated_sum(per_client))
Ejemplo n.º 29
0
 def next_fn(state, client_values):
     """Returns `LearningProcessOutput(state, metrics)`.

     `metrics` is the federated sum of `sum_sequence` mapped over
     `client_values`.
     """
     per_client = intrinsics.federated_map(sum_sequence, client_values)
     return LearningProcessOutput(state, intrinsics.federated_sum(per_client))
Ejemplo n.º 30
0
 def next_fn(state, client_values, second_state):  # pylint: disable=unused-argument
     """Returns `LearningProcessOutput(state, metrics)`; ignores `second_state`.

     `metrics` is the federated sum of `sum_sequence` mapped over
     `client_values`.
     """
     per_client = intrinsics.federated_map(sum_sequence, client_values)
     return LearningProcessOutput(state, intrinsics.federated_sum(per_client))