Beispiel #1
0
def _calculate_client_update_statistics_with_norm(client_norms,
                                                  client_weights):
    """Computes weighted statistics of the per-client norms.

    Args:
      client_norms: Clients-placed norm values; averaged with
        `federated_mean`, so they must be placed at CLIENTS.
      client_weights: Clients-placed weights used for both weighted means and
        summed (plain and squared) as inputs to `_calculate_unbiased_std_dev`.

    Returns:
      The `federated_zip` of an `OrderedDict` with keys `average_client_norm`
      (weighted mean of the norms) and `std_dev_client_norm` (the output of
      `_calculate_unbiased_std_dev`).
    """
    squared_norms = intrinsics.federated_map(_square_value, client_norms)

    mean_norm = intrinsics.federated_mean(client_norms, client_weights)
    mean_squared_norm = intrinsics.federated_mean(squared_norms,
                                                  client_weights)

    # TODO(b/197972289): Add SecAgg compatibility to these measurements
    total_weight = intrinsics.federated_sum(client_weights)
    total_squared_weight = intrinsics.federated_sum(
        intrinsics.federated_map(_square_value, client_weights))

    std_dev_inputs = (mean_norm, mean_squared_norm, total_weight,
                      total_squared_weight)
    std_dev = intrinsics.federated_map(_calculate_unbiased_std_dev,
                                       std_dev_inputs)

    statistics = collections.OrderedDict(
        average_client_norm=mean_norm, std_dev_client_norm=std_dev)
    return intrinsics.federated_zip(statistics)
Beispiel #2
0
 def _train_one_round(model, federated_data):
     """Runs one round of federated training and averages the client models.

     `model` is broadcast to the clients and paired with `federated_data`
     under the keys `model` and `batches`. `_train_on_one_client` is mapped
     over those pairs, and the resulting client models are combined with an
     unweighted `federated_mean`.
     """
     client_inputs = collections.OrderedDict(
         model=intrinsics.federated_broadcast(model),
         batches=federated_data)
     trained_client_models = intrinsics.federated_map(_train_on_one_client,
                                                      client_inputs)
     return intrinsics.federated_mean(trained_client_models)
Beispiel #3
0
 def next_comp(state, value, weight):
     """Advances `state` and returns the weighted mean of `value`.

     The new state is `_add_one` mapped over `state`. The single measurement,
     `num_clients`, is the federated sum of a constant 1 placed at CLIENTS.
     """
     new_state = intrinsics.federated_map(_add_one, state)
     weighted_mean = intrinsics.federated_mean(value, weight)
     client_count = intrinsics.federated_sum(
         intrinsics.federated_value(1, placements.CLIENTS))
     measurements = intrinsics.federated_zip(
         collections.OrderedDict(num_clients=client_count))
     return measured_process.MeasuredProcessOutput(
         state=new_state, result=weighted_mean, measurements=measurements)
Beispiel #4
0
 def comp(temperatures, threshold):
     """Weighted mean of per-client `count_over` results.

     `threshold` is broadcast to the clients and zipped with `temperatures`
     as the input to `count_over`. Each client's result is weighted by
     `count_total` applied to that client's `temperatures`.
     """
     client_threshold = intrinsics.federated_broadcast(threshold)
     paired_inputs = intrinsics.federated_zip([temperatures, client_threshold])
     over_counts = intrinsics.federated_map(count_over, paired_inputs)
     total_counts = intrinsics.federated_map(count_total, temperatures)
     return intrinsics.federated_mean(over_counts, total_counts)
 def fed_output(local_outputs):
     """Aggregates client outputs into a weighted loss and an example total.

     Returns an `OrderedDict` with `loss` (the mean of `local_outputs.loss`
     weighted by `num_examples_float`) followed by `num_examples` (the
     federated sum of `local_outputs.num_examples`).
     """
     # TODO(b/124070381): Remove need for using num_examples_float here.
     weighted_loss = intrinsics.federated_mean(
         local_outputs.loss, weight=local_outputs.num_examples_float)
     total_examples = intrinsics.federated_sum(local_outputs.num_examples)
     return collections.OrderedDict(
         loss=weighted_loss, num_examples=total_examples)
Beispiel #6
0
 def fed_output(local_outputs):
     """Aggregates client outputs into an example total and a weighted loss.

     Returns a plain dict with `num_examples` (federated sum of
     `local_outputs.num_examples`) followed by `loss` (mean of
     `local_outputs.loss` weighted by `num_examples_float`).
     """
     total_examples = intrinsics.federated_sum(local_outputs.num_examples)
     # TODO(b/124070381): Remove need for using num_examples_float here.
     weighted_loss = intrinsics.federated_mean(
         local_outputs.loss, weight=local_outputs.num_examples_float)
     return dict(num_examples=total_examples, loss=weighted_loss)
Beispiel #7
0
 def test_federated_mean_with_client_tuple_with_int32_weight(self):
     """A float64 struct averaged with int32 weights is typed at SERVER."""
     value_type = computation_types.at_clients(
         collections.OrderedDict(x=tf.float64, y=tf.float64))
     struct_values = _mock_data_of_type(value_type)
     int_weights = _mock_data_of_type(computation_types.at_clients(tf.int32))
     averaged = intrinsics.federated_mean(struct_values, int_weights)
     self.assert_value(averaged, '<x=float64,y=float64>@SERVER')
Beispiel #8
0
  def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
    """Clips client values, aggregates them, and advances norm/count state.

    `clipping_norm_process`, `clipped_count_agg_process` and `prefix` are not
    parameters; they are resolved from the enclosing scope.

    Args:
      state: A 3-tuple unpacked as (clipping-norm process state, inner
        aggregation state, clipped-count aggregation state).
      value: Clients-placed values; mapped through `clip_fn` together with
        the broadcast clipping norm.
      clip_fn: Function mapped at the clients over `(value, clipping_norm)`;
        its result is unpacked as `(clipped_value, global_norm, was_clipped)`.
      inner_agg_process: Process whose `next` aggregates the clipped values.
      weight: Optional weight forwarded to `inner_agg_process.next`; when
        `None`, `next` is called without a weight argument.

    Returns:
      A `measured_process.MeasuredProcessOutput` whose `state` and
      `measurements` are `federated_zip`s of `OrderedDict`s and whose
      `result` is the inner aggregation's result.
    """
    clipping_norm_state, agg_state, clipped_count_state = state

    # Current clipping norm as reported by the norm process.
    clipping_norm = clipping_norm_process.report(clipping_norm_state)

    # Make the norm available at the clients so clip_fn can use it.
    clients_clipping_norm = intrinsics.federated_broadcast(clipping_norm)

    # TODO(b/163880757): Remove this when server-only metrics are supported.
    # NOTE(review): averaging the broadcast norm presumably reproduces the
    # same value at SERVER, but derived from a clients-placed value. This
    # rebinds `clipping_norm`, which is the value reported in measurements.
    clipping_norm = intrinsics.federated_mean(clients_clipping_norm)

    clipped_value, global_norm, was_clipped = intrinsics.federated_map(
        clip_fn, (value, clients_clipping_norm))

    # The norm process state is advanced using the clients' global_norm.
    new_clipping_norm_state = clipping_norm_process.next(
        clipping_norm_state, global_norm)

    # Only pass a weight to the inner process when one was supplied.
    if weight is None:
      agg_output = inner_agg_process.next(agg_state, clipped_value)
    else:
      agg_output = inner_agg_process.next(agg_state, clipped_value, weight)

    # Aggregate the `was_clipped` output of clip_fn.
    clipped_count_output = clipped_count_agg_process.next(
        clipped_count_state, was_clipped)

    # NOTE(review): `prefix` presumably prepends a name stem to these
    # suffixes (e.g. 'clipp' -> 'clipping_norm', 'clipped_count_agg');
    # confirm in the enclosing scope.
    new_state = collections.OrderedDict([
        (prefix('ing_norm'), new_clipping_norm_state),
        ('inner_agg', agg_output.state),
        (prefix('ed_count_agg'), clipped_count_output.state)
    ])
    measurements = collections.OrderedDict([
        (prefix('ing'), agg_output.measurements),
        (prefix('ing_norm'), clipping_norm),
        (prefix('ed_count'), clipped_count_output.result)
    ])

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(new_state),
        result=agg_output.result,
        measurements=intrinsics.federated_zip(measurements))
 def comp(x):
     """Returns the unweighted `federated_mean` of `x`."""
     averaged = intrinsics.federated_mean(x)
     return averaged
Beispiel #10
0
 def test_federated_mean_with_string_weight_fails(self):
     """Averaging with string-typed weights must raise TypeError."""
     float_values = _mock_data_of_type(computation_types.at_clients(tf.float32))
     string_weights = _mock_data_of_type(
         computation_types.at_clients(tf.string))
     self.assertRaises(TypeError, intrinsics.federated_mean, float_values,
                       string_weights)
Beispiel #11
0
 def test_federated_mean_with_client_int32_fails(self):
     """Averaging int32 client values must raise TypeError."""
     int_values = _mock_data_of_type(computation_types.at_clients(tf.int32))
     self.assertRaises(TypeError, intrinsics.federated_mean, int_values)
Beispiel #12
0
 def test_federated_mean_with_all_equal_client_float32_with_weight(self):
     """An all-equal float32 client value used as its own weight -> SERVER."""
     all_equal_type = computation_types.FederatedType(
         tf.float32, placements.CLIENTS, all_equal=True)
     mocked = _mock_data_of_type(all_equal_type)
     averaged = intrinsics.federated_mean(mocked, mocked)
     self.assert_value(averaged, 'float32@SERVER')
Beispiel #13
0
 def test_federated_mean_with_client_float32_without_weight(self):
     """Unweighted mean of float32 client values is float32 at SERVER."""
     client_floats = computation_types.at_clients(tf.float32)
     averaged = intrinsics.federated_mean(_mock_data_of_type(client_floats))
     self.assert_value(averaged, 'float32@SERVER')
Beispiel #14
0
 def stateless_mean(state, value, weight):
     """Weighted mean of `value`; `state` is passed through unchanged.

     Measurements are an empty tuple placed at SERVER.
     """
     no_measurements = intrinsics.federated_value((), placements.SERVER)
     weighted_mean = intrinsics.federated_mean(value, weight=weight)
     return measured_process.MeasuredProcessOutput(
         state=state, result=weighted_mean, measurements=no_measurements)