Beispiel #1
0
    def next_fn(state, client_data):
        """Runs one round of the composed learning algorithm.

        The four component processes are chained in a fixed order:
        distributor -> client work -> aggregator -> finalizer.

        Args:
          state: The current `LearningAlgorithmState`.
          client_data: Client-placed data handed to `client_work.next`.

        Returns:
          A `learning_process.LearningProcessOutput` holding the zipped new
          `LearningAlgorithmState` and a zipped `OrderedDict` of the
          per-component measurements.
        """
        current_weights = state.global_model_weights

        # Send the current global model weights towards the clients.
        dist_out = model_weights_distributor.next(state.distributor,
                                                  current_weights)
        # Per-client work on the distributed weights and the client data.
        work_out = client_work.next(state.client_work, dist_out.result,
                                    client_data)
        client_result = work_out.result
        # Combine the client updates, passing along their update weights.
        agg_out = model_update_aggregator.next(state.aggregator,
                                               client_result.update,
                                               client_result.update_weight)
        # Apply the aggregated update to the current global weights.
        final_out = model_finalizer.next(state.finalizer, current_weights,
                                         agg_out.result)

        zipped_state = intrinsics.federated_zip(
            LearningAlgorithmState(final_out.result, dist_out.state,
                                   work_out.state, agg_out.state,
                                   final_out.state))
        zipped_metrics = intrinsics.federated_zip(
            collections.OrderedDict(
                distributor=dist_out.measurements,
                client_work=work_out.measurements,
                aggregator=agg_out.measurements,
                finalizer=final_out.measurements))
        return learning_process.LearningProcessOutput(zipped_state,
                                                      zipped_metrics)
Beispiel #2
0
    def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
        """Clips client values to the reported norm, then aggregates them.

        Args:
          state: A 3-element structure unpacked as (clipping-norm state,
            inner aggregation state, clipped-count aggregation state).
          value: Client-placed values to clip.
          clip_fn: Computation mapping `(value, clipping_norm)` to
            `(clipped_value, global_norm, was_clipped)`.
          inner_agg_process: Process that aggregates the clipped values.
          weight: Optional client weights; forwarded to `inner_agg_process`
            only when not `None`.

        Returns:
          A `measured_process.MeasuredProcessOutput` with the zipped new state,
          the inner aggregate as result, and zipped measurements.
        """
        norm_state, inner_state, count_state = state

        current_norm = clipping_norm_process.report(norm_state)
        clipped, global_norm, was_clipped = intrinsics.federated_map(
            clip_fn, (value, intrinsics.federated_broadcast(current_norm)))

        # Advance the clipping-norm process with the clients' global norms.
        next_norm_state = clipping_norm_process.next(norm_state, global_norm)

        # Only pass `weight` through when one was supplied.
        inner_args = [inner_state, clipped]
        if weight is not None:
            inner_args.append(weight)
        inner_output = inner_agg_process.next(*inner_args)

        count_output = clipped_count_agg_process.next(count_state, was_clipped)

        state_entries = [
            (prefix('ing_norm'), next_norm_state),
            ('inner_agg', inner_output.state),
            (prefix('ed_count_agg'), count_output.state),
        ]
        measurement_entries = [
            (prefix('ing'), inner_output.measurements),
            (prefix('ing_norm'), current_norm),
            (prefix('ed_count'), count_output.result),
        ]
        return measured_process.MeasuredProcessOutput(
            state=intrinsics.federated_zip(
                collections.OrderedDict(state_entries)),
            result=inner_output.result,
            measurements=intrinsics.federated_zip(
                collections.OrderedDict(measurement_entries)))
Beispiel #3
0
 def next_fn(state, dataset):
     """Aggregates IBLT-encoded client data and decodes it.

     Args:
       state: A pair of (sketch aggregator state, value-tensor aggregator
         state).
       dataset: Client-placed data passed to `encode_iblt`.

     Returns:
       A `measured_process.MeasuredProcessOutput` whose result is a zipped
       `ServerOutput` built from the outputs of `decode_iblt`.
     """
     sketch_agg_state, tensor_agg_state = state
     client_sketch, client_tensor = intrinsics.federated_map(
         encode_iblt, dataset)

     sketch_out = inner_aggregator_sketch.next(sketch_agg_state, client_sketch)
     tensor_out = inner_aggregator_value_tensor.next(tensor_agg_state,
                                                     client_tensor)

     decoded = intrinsics.federated_map(
         decode_iblt, (sketch_out.result, tensor_out.result))
     output_strings, string_values, num_not_decoded = decoded

     server_output = intrinsics.federated_zip(
         ServerOutput(output_strings=output_strings,
                      string_values=string_values,
                      num_not_decoded=num_not_decoded))
     new_state = intrinsics.federated_zip(
         (sketch_out.state, tensor_out.state))
     measurements = intrinsics.federated_zip(
         collections.OrderedDict(
             num_not_decoded=num_not_decoded,
             sketch=sketch_out.measurements,
             value_tensor=tensor_out.measurements))
     return measured_process.MeasuredProcessOutput(
         new_state, server_output, measurements)
Beispiel #4
0
        def next_fn(state, value):
            """Discretizes client values, aggregates them, then undiscretizes.

            Args:
              state: Mapping with keys `scale_factor`, `prior_norm_bound` and
                `inner_agg_process`.
              value: Client-placed values to aggregate.

            Returns:
              A `measured_process.MeasuredProcessOutput` whose result is the
              inner aggregate mapped through `undiscretize_fn`.
            """
            scale_factor = state['scale_factor']
            clients_scale_factor = intrinsics.federated_broadcast(scale_factor)
            norm_bound = state['prior_norm_bound']
            clients_norm_bound = intrinsics.federated_broadcast(norm_bound)

            # Discretize on the clients using the broadcast parameters.
            discretized = intrinsics.federated_map(
                discretize_fn, (value, clients_scale_factor, clients_norm_bound))

            inner_output = inner_agg_process.next(state['inner_agg_process'],
                                                  discretized)

            # Undo the discretization on the aggregate.
            aggregate = intrinsics.federated_map(
                undiscretize_fn, (inner_output.result, scale_factor))

            # The scale factor and norm bound are carried over unchanged.
            return measured_process.MeasuredProcessOutput(
                state=intrinsics.federated_zip(
                    collections.OrderedDict(
                        scale_factor=scale_factor,
                        prior_norm_bound=norm_bound,
                        inner_agg_process=inner_output.state)),
                result=aggregate,
                measurements=intrinsics.federated_zip(
                    collections.OrderedDict(
                        discretize=inner_output.measurements)))
Beispiel #5
0
    def test_federated_zip_named_n_tuple(self, n, element_type):
        """Zips an n-tuple with all fields named, then with some unnamed."""
        clients_type = computation_types.at_clients(element_type)
        tuple_type = computation_types.to_type([clients_type] * n)
        values = _mock_data_of_type(tuple_type)

        # Every element is named by its index.
        all_named = intrinsics.federated_zip(
            collections.OrderedDict((str(i), values[i]) for i in range(n)))
        self.assertIsInstance(all_named, value_impl.Value)
        type_test_utils.assert_types_identical(
            all_named.type_signature,
            computation_types.at_clients(
                collections.OrderedDict(
                    (str(i), element_type) for i in range(n))))

        # Only even indices are named; odd indices stay unnamed.
        def name_for(i):
            return str(i) if i % 2 == 0 else None

        partly_named = intrinsics.federated_zip(
            structure.Struct((name_for(i), values[i]) for i in range(n)))
        self.assertIsInstance(partly_named, value_impl.Value)
        type_test_utils.assert_types_identical(
            partly_named.type_signature,
            computation_types.at_clients(
                computation_types.StructType(
                    [(name_for(i), element_type) for i in range(n)])))
Beispiel #6
0
        def next_fn(state, value, weight):
            """Computes a weighted mean using two inner sum processes.

            Args:
              state: Mapping with keys `value_sum_process` and
                `weight_sum_process`.
              value: Client-placed values.
              weight: Client-placed weights.

            Returns:
              A `measured_process.MeasuredProcessOutput` whose result is the
              summed weighted values divided by the summed weights, using
              `_div_no_nan` when `self._no_nan_division` is set.
            """
            weighted = intrinsics.federated_map(_mul, (value, weight))

            value_sum = value_sum_process.next(state['value_sum_process'],
                                               weighted)
            weight_sum = weight_sum_process.next(state['weight_sum_process'],
                                                 weight)

            divide = _div_no_nan if self._no_nan_division else _div
            mean = intrinsics.federated_map(
                divide, (value_sum.result, weight_sum.result))

            new_state = collections.OrderedDict(
                value_sum_process=value_sum.state,
                weight_sum_process=weight_sum.state)
            measurements = collections.OrderedDict(
                mean_value=value_sum.measurements,
                mean_weight=weight_sum.measurements)
            return measured_process.MeasuredProcessOutput(
                intrinsics.federated_zip(new_state), mean,
                intrinsics.federated_zip(measurements))
Beispiel #7
0
    def next_fn(state, value):
      """Aggregates values discretized with the server's step size.

      Args:
        state: Mapping with keys `step_size` and `inner_agg_process`.
        value: Client-placed values.

      Returns:
        A `measured_process.MeasuredProcessOutput` whose result is the inner
        aggregate mapped through `undiscretize_fn`. Measurements include
        `distortion` when `self._distortion_aggregation_factory` is set.
      """
      step_size = state['step_size']
      clients_step_size = intrinsics.federated_broadcast(step_size)

      discretized = intrinsics.federated_map(discretize_fn,
                                             (value, clients_step_size))
      inner_output = inner_agg_process.next(state['inner_agg_process'],
                                            discretized)

      aggregate = intrinsics.federated_map(
          undiscretize_fn, (inner_output.result, step_size))

      new_state = collections.OrderedDict(
          step_size=step_size, inner_agg_process=inner_output.state)
      measurements = collections.OrderedDict(
          deterministic_discretization=inner_output.measurements)

      if self._distortion_aggregation_factory is not None:
        # The distortion process starts from a freshly initialized state every
        # round, so nothing about it is carried in `new_state`.
        client_distortions = intrinsics.federated_map(
            distortion_measurement_fn, (value, clients_step_size))
        distortion_output = distortion_aggregation_process.next(
            distortion_aggregation_process.initialize(), client_distortions)
        measurements['distortion'] = distortion_output.result

      return measured_process.MeasuredProcessOutput(
          state=intrinsics.federated_zip(new_state),
          result=aggregate,
          measurements=intrinsics.federated_zip(measurements))
    def next_fn(state,
                unfinalized_metrics) -> measured_process.MeasuredProcessOutput:
      """Sums this round's unfinalized metrics and updates the running totals.

      Args:
        state: A pair of (inner summation state, accumulated unfinalized
          metrics).
        unfinalized_metrics: This round's unfinalized metrics, passed to
          `inner_summation_process`.

      Returns:
        A `measured_process.MeasuredProcessOutput` whose result is a zipped
        pair of (finalized metrics of this round, finalized accumulated
        metrics).
      """
      summation_state, accumulators = state

      summation_output = inner_summation_process.next(summation_state,
                                                      unfinalized_metrics)
      round_sum = summation_output.result
      new_summation_state = summation_output.state

      @tensorflow_computation.tf_computation(local_unfinalized_metrics_type,
                                             local_unfinalized_metrics_type)
      def add_unfinalized_metrics(unfinalized_metrics,
                                  summed_unfinalized_metrics):
        return tf.nest.map_structure(tf.add, unfinalized_metrics,
                                     summed_unfinalized_metrics)

      # Fold this round's sum into the accumulated totals.
      accumulators = intrinsics.federated_map(add_unfinalized_metrics,
                                              (accumulators, round_sum))

      finalize = _build_finalizer_computation(metric_finalizers,
                                              local_unfinalized_metrics_type)
      round_metrics = intrinsics.federated_map(finalize, round_sum)
      accumulated_metrics = intrinsics.federated_map(finalize, accumulators)

      return measured_process.MeasuredProcessOutput(
          state=intrinsics.federated_zip(
              (new_summation_state, accumulators)),
          result=intrinsics.federated_zip(
              (round_metrics, accumulated_metrics)),
          measurements=summation_output.measurements)
 def next_fn(server_state, client_data):
     """The `next` function for `tff.templates.IterativeProcess`.

     Client updates are combined with a plain `federated_sum`. This variant
     deliberately makes no call to `federated_secure_sum_bitwidth`.
     """
     prepared = intrinsics.federated_map(prepare, server_state)
     broadcast_value = intrinsics.federated_broadcast(prepared)
     client_args = intrinsics.federated_zip([client_data, broadcast_value])
     updates = intrinsics.federated_map(work, client_args)
     summed_updates = intrinsics.federated_sum(updates)
     update_args = intrinsics.federated_zip([server_state, summed_updates])
     new_server_state, server_output = intrinsics.federated_map(
         update, update_args)
     return new_server_state, server_output
 def concatenation_next(state, values):
     """Runs each measured process on its own entry of `state` and `values`.

     Returns:
       A `MeasuredProcessOutput` whose state and measurements are zipped
       `OrderedDict`s keyed by process name. The result is an unzipped
       `OrderedDict` of the individual results.
     """
     outputs = collections.OrderedDict(
         (name, process.next(state[name], values[name]))
         for name, process in measured_processes.items())
     new_states = collections.OrderedDict(
         (name, output.state) for name, output in outputs.items())
     results = collections.OrderedDict(
         (name, output.result) for name, output in outputs.items())
     measurements = collections.OrderedDict(
         (name, output.measurements) for name, output in outputs.items())
     return MeasuredProcessOutput(
         state=intrinsics.federated_zip(new_states),
         result=results,
         measurements=intrinsics.federated_zip(measurements))
 def next_fn(server_state, client_data):
   """The `next` function for `tff.templates.IterativeProcess`.

   Broadcasts the second value returned by
   `broadcast_and_return_arg_and_result`. The first client update component is
   combined with `federated_sum` and the second with
   `federated_secure_sum_bitwidth` at bitwidth 8.
   """
   prepared = intrinsics.federated_map(prepare, server_state)
   _, broadcast_arg = broadcast_and_return_arg_and_result(prepared)
   client_input = intrinsics.federated_broadcast(broadcast_arg)
   updates = intrinsics.federated_map(
       work, intrinsics.federated_zip([client_data, client_input]))
   plain_sum = intrinsics.federated_sum(updates[0])
   secure_sum_8bit = intrinsics.federated_secure_sum_bitwidth(updates[1], 8)
   update_args = intrinsics.federated_zip(
       [server_state, [plain_sum, secure_sum_8bit]])
   new_server_state, server_output = intrinsics.federated_map(
       update, update_args)
   return new_server_state, server_output
Beispiel #12
0
 def next_fn(server_state, client_data):
     """The `next` function for `tff.templates.IterativeProcess`.

     This variant never maps an `update` function on the server. The new
     server state is the zipped pair of sums, and the server output is an
     empty server-placed value.
     """
     prepared = intrinsics.federated_map(prepare, server_state)
     client_input = intrinsics.federated_broadcast(prepared)
     updates = intrinsics.federated_map(
         work, intrinsics.federated_zip([client_data, client_input]))
     plain_sum = intrinsics.federated_sum(updates[0])
     secure_sum_8bit = intrinsics.federated_secure_sum_bitwidth(updates[1], 8)
     new_server_state = intrinsics.federated_zip([plain_sum, secure_sum_8bit])
     server_output = intrinsics.federated_value([], placements.SERVER)
     return new_server_state, server_output
Beispiel #13
0
def _calculate_client_update_statistics_with_norm(client_norms,
                                                  client_weights):
    """Computes the weighted mean and the std dev of client update norms.

    Args:
      client_norms: Client-placed norms of the client updates.
      client_weights: Client-placed weights. They are used for the weighted
        means, and their sum and sum of squares are passed to
        `_calculate_unbiased_std_dev`.

    Returns:
      A federated-zipped `collections.OrderedDict` with keys
      `average_client_norm` and `std_dev_client_norm`.
    """
    squared_norms = intrinsics.federated_map(_square_value, client_norms)

    mean_norm = intrinsics.federated_mean(client_norms, client_weights)
    mean_squared_norm = intrinsics.federated_mean(squared_norms,
                                                  client_weights)

    # TODO(b/197972289): Add SecAgg compatibility to these measurements
    weight_total = intrinsics.federated_sum(client_weights)
    squared_weights = intrinsics.federated_map(_square_value, client_weights)
    squared_weight_total = intrinsics.federated_sum(squared_weights)

    std_dev = intrinsics.federated_map(
        _calculate_unbiased_std_dev,
        (mean_norm, mean_squared_norm, weight_total, squared_weight_total))

    return intrinsics.federated_zip(
        collections.OrderedDict(average_client_norm=mean_norm,
                                std_dev_client_norm=std_dev))
Beispiel #14
0
 def test_federated_zip_with_names_client_non_all_equal_int_and_bool(self):
     """Zips a named pair of non-all-equal int32 and all-equal bool values."""
     int_values = _mock_data_of_type(
         computation_types.at_clients(tf.int32, all_equal=False))
     bool_values = _mock_data_of_type(
         computation_types.at_clients(tf.bool, all_equal=True))
     zipped = intrinsics.federated_zip(
         collections.OrderedDict(x=int_values, y=bool_values))
     self.assert_value(zipped, '{<x=int32,y=bool>}@CLIENTS')
Beispiel #15
0
 def init_fn_impl(inner_agg_process):
   """Returns the zipped initial state of the three component processes.

   Entries, in order: the clipping-norm process, the inner aggregation process
   (under key `inner_agg`), and the clipped-count aggregation process.
   """
   components = [
       (prefix('ing_norm'), clipping_norm_process.initialize()),
       ('inner_agg', inner_agg_process.initialize()),
       (prefix('ed_count_agg'), clipped_count_agg_process.initialize()),
   ]
   return intrinsics.federated_zip(collections.OrderedDict(components))
Beispiel #16
0
 def test_federated_zip_with_client_non_all_equal_int_and_bool(self):
     """Zips an unnamed [int32, bool] pair of client values."""
     int_values = _mock_data_of_type(
         computation_types.at_clients(tf.int32, all_equal=False))
     bool_values = _mock_data_of_type(
         computation_types.at_clients(tf.bool, all_equal=True))
     zipped = intrinsics.federated_zip([int_values, bool_values])
     self.assert_value(zipped, '{<int32,bool>}@CLIENTS')
Beispiel #17
0
 def next_fn(state, weights, data):
     """Returns a zipped `ClientResult` of `weights.trainable` plus `data`.

     The update weight is `client_one()`, the measurements are
     `server_zero()`, and `state` is passed through unchanged.
     """
     summed = federated_add(weights.trainable, data)
     client_result = intrinsics.federated_zip(
         client_works.ClientResult(summed, client_one()))
     return MeasuredProcessOutput(state, client_result, server_zero())
 def next_fn(state, weights, update):
     """Returns zipped `ModelWeights` of `weights['trainable']` plus `update`.

     The non-trainable part is empty, the measurements are `server_zero()`,
     and `state` is passed through unchanged.
     """
     new_trainable = federated_add(weights['trainable'], update)
     new_weights = intrinsics.federated_zip(
         model_utils.ModelWeights(new_trainable, ()))
     return MeasuredProcessOutput(state, new_weights, server_zero())
Beispiel #19
0
 def test_federated_zip_with_single_named_bool_clients(self):
     """Zips a one-field named struct of client-placed bools."""
     struct_type = computation_types.StructType(
         [('a', computation_types.at_clients(tf.bool))])
     zipped = intrinsics.federated_zip(_mock_data_of_type(struct_type))
     self.assert_value(zipped, '{<a=bool>}@CLIENTS')
Beispiel #20
0
    def one_round_computation(examples):
        """The TFF computation to compute the aggregated IBLT sketch.

        Args:
          examples: Client-placed data passed to `compute_sketch`.

        Returns:
          A zipped `ServerOutput` with the summed client count, the outputs of
          `decode_heavy_hitters`, and an int64 server-side timestamp.
        """
        use_secure_sum = secure_sum_bitwidth is not None
        # IBLT sketches are decoded modulo the field size, so under secure
        # aggregation they need a secure *modular* sum.
        sketch_sum_fn = (secure_modular_sum
                         if use_secure_sum else intrinsics.federated_sum)
        count_sum_fn = secure_sum if use_secure_sum else intrinsics.federated_sum

        timestamp_fn = tensorflow_computation.tf_computation(
            lambda: tf.cast(tf.timestamp(), tf.int64))
        round_timestamp = intrinsics.federated_eval(timestamp_fn,
                                                    placements.SERVER)
        # Each client contributes 1, so the sum counts participating clients.
        num_clients = count_sum_fn(
            intrinsics.federated_value(1, placements.CLIENTS))

        client_sketch, client_counts = intrinsics.federated_map(
            compute_sketch, examples)
        summed_sketch = sketch_sum_fn(client_sketch)
        summed_counts = count_sum_fn(client_counts)

        decoded = intrinsics.federated_map(decode_heavy_hitters,
                                           (summed_sketch, summed_counts))
        (heavy_hitters, heavy_hitters_unique_counts, heavy_hitters_counts,
         num_not_decoded) = decoded
        return intrinsics.federated_zip(
            ServerOutput(
                clients=num_clients,
                heavy_hitters=heavy_hitters,
                heavy_hitters_unique_counts=heavy_hitters_unique_counts,
                heavy_hitters_counts=heavy_hitters_counts,
                num_not_decoded=num_not_decoded,
                round_timestamp=round_timestamp))
Beispiel #21
0
 def next_comp(state, value, weight):
     """Maps `_add_one` over the state and takes a weighted mean of `value`.

     The measurements report `num_clients`, the sum of a 1 from each client.
     """
     next_state = intrinsics.federated_map(_add_one, state)
     weighted_mean = intrinsics.federated_mean(value, weight)
     client_ones = intrinsics.federated_value(1, placements.CLIENTS)
     client_count = intrinsics.federated_sum(client_ones)
     return measured_process.MeasuredProcessOutput(
         state=next_state,
         result=weighted_mean,
         measurements=intrinsics.federated_zip(
             collections.OrderedDict(num_clients=client_count)))
Beispiel #22
0
 def test_federated_zip_with_single_named_bool_server(self):
     """Zips a one-field named struct of server-placed bools."""
     struct_type = computation_types.StructType(
         [('a', computation_types.at_server(tf.bool))])
     zipped = intrinsics.federated_zip(_mock_data_of_type(struct_type))
     self.assert_value(zipped, '<a=bool>@SERVER')
Beispiel #23
0
 def test_federated_zip_with_names_client_all_equal_int_and_bool(self):
     """Zips a plain dict of all-equal int32 and bool client values."""
     int_value = _mock_data_of_type(
         computation_types.at_clients(tf.int32, all_equal=True))
     bool_value = _mock_data_of_type(
         computation_types.at_clients(tf.bool, all_equal=True))
     zipped = intrinsics.federated_zip(dict(x=int_value, y=bool_value))
     self.assert_value(zipped, '{<x=int32,y=bool>}@CLIENTS')
Beispiel #24
0
 def test_federated_zip_n_tuple(self, n):
     """Zipping n client-placed int32 values yields an n-tuple at clients."""
     clients_int = computation_types.at_clients(tf.int32)
     zipped = intrinsics.federated_zip(_mock_data_of_type([clients_int] * n))
     self.assertIsInstance(zipped, value_impl.Value)
     self.assert_types_identical(
         zipped.type_signature, computation_types.at_clients([tf.int32] * n))
        def next_fn(state, value):
            """Modular-clips values, aggregates them, and re-clips the sum.

            Args:
              state: State passed straight to `inner_agg_next`.
              value: Client-placed values.

            Returns:
              A `measured_process.MeasuredProcessOutput` whose result is the
              inner aggregate clipped again with `modular_clip_by_value_fn`.
              When `self._estimate_stddev` is set, measurements also include
              `estimated_stddev`.
            """
            lower = intrinsics.federated_value(self._clip_range_lower,
                                               placements.SERVER)
            upper = intrinsics.federated_value(self._clip_range_upper,
                                               placements.SERVER)

            # Modular clip on the clients before aggregation.
            client_lower = intrinsics.federated_broadcast(lower)
            client_upper = intrinsics.federated_broadcast(upper)
            clipped_value = intrinsics.federated_map(
                modular_clip_by_value_fn, (value, client_lower, client_upper))

            inner_output = inner_agg_next(state, clipped_value)

            # Clip the aggregate to the same range again (not considering
            # summands).
            clipped_aggregate = intrinsics.federated_map(
                modular_clip_by_value_fn,
                (inner_output.result, lower, upper))

            measurements = collections.OrderedDict(
                modclip=inner_output.measurements)
            if self._estimate_stddev:
                measurements['estimated_stddev'] = intrinsics.federated_map(
                    estimator_fn, (clipped_aggregate, lower, upper))

            return measured_process.MeasuredProcessOutput(
                state=inner_output.state,
                result=clipped_aggregate,
                measurements=intrinsics.federated_zip(measurements))
Beispiel #26
0
 def init_fn():
     """Returns the zipped (optimizer state, aggregator state) pair.

     The optimizer state is created on the server from the trainable weight
     specs.
     """
     trainable_specs = weight_tensor_specs.trainable
     init_optimizer = tensorflow_computation.tf_computation(
         lambda: optimizer.initialize(trainable_specs))
     optimizer_state = intrinsics.federated_eval(init_optimizer,
                                                 placements.SERVER)
     aggregator_state = full_gradient_aggregator.initialize()
     return intrinsics.federated_zip((optimizer_state, aggregator_state))
Beispiel #27
0
 def next_fn(state, weights, data):
     """Returns a client result shaped as a plain `OrderedDict`.

     NOTE(review): the original name `bad_client_result` suggests this is an
     intentionally malformed result (an `OrderedDict` instead of a
     `client_works.ClientResult`), presumably for a negative test. Confirm.
     """
     summed_data = intrinsics.federated_map(tf_data_sum, data)
     bad_result = intrinsics.federated_zip(
         collections.OrderedDict(
             update=federated_add(weights.trainable, summed_data),
             update_weight=client_one()))
     return MeasuredProcessOutput(state, bad_result, server_zero())
Beispiel #28
0
 def next_fn(state, weights, updates):
     """Adds `updates` to the trainable weights and returns zipped weights.

     The non-trainable part is empty, and `state` is passed through unchanged.
     """
     add = tensorflow_computation.tf_computation(lambda x, y: x + y)
     trainable = intrinsics.federated_map(add, (weights.trainable, updates))
     zipped_weights = intrinsics.federated_zip(
         model_utils.ModelWeights(trainable, ()))
     return measured_process.MeasuredProcessOutput(state, zipped_weights,
                                                   empty_at_server())
Beispiel #29
0
 def init_fn():
     """Returns the zipped initial state.

     The scale factor and prior norm bound are placed on the server, next to
     the inner aggregation process's initial state.
     """
     scale_factor = intrinsics.federated_value(self._scale_factor,
                                               placements.SERVER)
     prior_norm_bound = intrinsics.federated_value(self._prior_norm_bound,
                                                   placements.SERVER)
     inner_state = inner_agg_process.initialize()
     return intrinsics.federated_zip(
         collections.OrderedDict(scale_factor=scale_factor,
                                 prior_norm_bound=prior_norm_bound,
                                 inner_agg_process=inner_state))
Beispiel #30
0
 def comp(temperatures, threshold):
     """Returns the `federated_mean` of per-client `count_over` values.

     Each client's value is weighted by its `count_total`. Presumably this is
     the fraction of readings above `threshold`; verify against `count_over`.
     """
     client_threshold = intrinsics.federated_broadcast(threshold)
     over_counts = intrinsics.federated_map(
         count_over,
         intrinsics.federated_zip([temperatures, client_threshold]))
     total_counts = intrinsics.federated_map(count_total, temperatures)
     return intrinsics.federated_mean(over_counts, total_counts)