def _calculate_client_update_statistics_with_norm(client_norms, client_weights):
  """Summarizes client update norms as a weighted mean and a std deviation.

  Args:
    client_norms: Client-placed federated value holding each client's update
      norm.
    client_weights: Client-placed federated value holding each client's
      weight. It is used for the weighted means and is also summed (plain and
      squared) for the standard-deviation computation.

  Returns:
    The federated zip of an `OrderedDict` with keys `average_client_norm` and
    `std_dev_client_norm`.
  """
  square_of_norms = intrinsics.federated_map(_square_value, client_norms)
  mean_norm = intrinsics.federated_mean(client_norms, client_weights)
  mean_squared_norm = intrinsics.federated_mean(square_of_norms, client_weights)

  # `_calculate_unbiased_std_dev` also receives the sum of the weights and the
  # sum of their squares (presumably for a weighted unbiased-variance
  # correction -- confirm against the helper).
  # TODO(b/197972289): Add SecAgg compatibility to these measurements
  total_weight = intrinsics.federated_sum(client_weights)
  squared_weights = intrinsics.federated_map(_square_value, client_weights)
  total_squared_weight = intrinsics.federated_sum(squared_weights)

  std_dev_inputs = (mean_norm, mean_squared_norm, total_weight,
                    total_squared_weight)
  std_dev = intrinsics.federated_map(_calculate_unbiased_std_dev,
                                     std_dev_inputs)

  summary = collections.OrderedDict(
      average_client_norm=mean_norm,
      std_dev_client_norm=std_dev,
  )
  return intrinsics.federated_zip(summary)
def _train_one_round(model, federated_data):
  """Runs one round of local training and averages the trained models.

  The server `model` is broadcast to the clients. Each client runs
  `_train_on_one_client` on that model and its own `federated_data`. The
  unweighted federated mean of the resulting models is returned.
  """
  client_inputs = collections.OrderedDict(
      model=intrinsics.federated_broadcast(model),
      batches=federated_data,
  )
  client_models = intrinsics.federated_map(_train_on_one_client, client_inputs)
  return intrinsics.federated_mean(client_models)
def next_comp(state, value, weight):
  """Applies `_add_one` to `state` and returns the weighted mean of `value`.

  The measurements hold `num_clients`: the federated sum of a constant `1`
  placed at every client.
  """
  updated_state = intrinsics.federated_map(_add_one, state)
  weighted_mean = intrinsics.federated_mean(value, weight)
  ones_at_clients = intrinsics.federated_value(1, placements.CLIENTS)
  client_count = intrinsics.federated_sum(ones_at_clients)
  measurements = intrinsics.federated_zip(
      collections.OrderedDict(num_clients=client_count))
  return measured_process.MeasuredProcessOutput(
      state=updated_state,
      result=weighted_mean,
      measurements=measurements)
def comp(temperatures, threshold):
  """Averages per-client `count_over` results, weighted by `count_total`.

  The server `threshold` is broadcast and zipped with each client's
  `temperatures` before `count_over` is applied. The result is presumably the
  fraction of readings above the threshold (confirm against `count_over` /
  `count_total`).
  """
  threshold_at_clients = intrinsics.federated_broadcast(threshold)
  over_counts = intrinsics.federated_map(
      count_over,
      intrinsics.federated_zip([temperatures, threshold_at_clients]))
  total_counts = intrinsics.federated_map(count_total, temperatures)
  return intrinsics.federated_mean(over_counts, total_counts)
def fed_output(local_outputs):
  """Aggregates per-client outputs into federated metrics.

  Returns:
    An `OrderedDict` with `loss` (mean of `local_outputs.loss` weighted by
    `local_outputs.num_examples_float`) followed by `num_examples` (federated
    sum of `local_outputs.num_examples`).
  """
  # TODO(b/124070381): Remove need for using num_examples_float here.
  weighted_loss = intrinsics.federated_mean(
      local_outputs.loss, weight=local_outputs.num_examples_float)
  example_count = intrinsics.federated_sum(local_outputs.num_examples)
  return collections.OrderedDict(loss=weighted_loss, num_examples=example_count)
def fed_output(local_outputs):
  """Aggregates per-client outputs into a plain dict of federated metrics.

  Returns:
    A `dict` with `num_examples` (federated sum of
    `local_outputs.num_examples`) followed by `loss` (mean of
    `local_outputs.loss` weighted by `local_outputs.num_examples_float`).
  """
  metrics = {}
  metrics['num_examples'] = intrinsics.federated_sum(local_outputs.num_examples)
  # TODO(b/124070381): Remove need for using num_examples_float here.
  metrics['loss'] = intrinsics.federated_mean(
      local_outputs.loss, weight=local_outputs.num_examples_float)
  return metrics
def test_federated_mean_with_client_tuple_with_int32_weight(self):
  """A float64 struct averaged with int32 weights keeps its float64 fields."""
  struct_at_clients = computation_types.at_clients(
      collections.OrderedDict(x=tf.float64, y=tf.float64))
  values = _mock_data_of_type(struct_at_clients)
  int32_at_clients = computation_types.at_clients(tf.int32)
  weights = _mock_data_of_type(int32_at_clients)
  result = intrinsics.federated_mean(values, weights)
  self.assert_value(result, '<x=float64,y=float64>@SERVER')
def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
  """Runs one step of a clipping aggregation process.

  Reads the current clipping norm from `clipping_norm_process` and broadcasts
  it to clients. Each client `value` is clipped with `clip_fn`. The norm state
  is then advanced from the clients' `global_norm` values. Finally, the
  clipped values (plus `weight`, when given) are aggregated with
  `inner_agg_process`, and the per-client `was_clipped` results with
  `clipped_count_agg_process`.

  `clipping_norm_process`, `clipped_count_agg_process` and `prefix` are free
  names resolved from an enclosing scope not shown here.

  Args:
    state: A 3-tuple unpacked as
      `(clipping_norm_state, agg_state, clipped_count_state)`.
    value: Client values to be clipped and aggregated.
    clip_fn: Function mapped at clients over `(value, clipping_norm)`; must
      return `(clipped_value, global_norm, was_clipped)`.
    inner_agg_process: Process whose `next` aggregates the clipped values.
    weight: Optional weights; when not `None`, forwarded to
      `inner_agg_process.next`.

  Returns:
    A `measured_process.MeasuredProcessOutput` with the zipped new state, the
    inner aggregation result, and zipped measurements (inner measurements,
    clipping norm and clipped count).
  """
  clipping_norm_state, agg_state, clipped_count_state = state
  clipping_norm = clipping_norm_process.report(clipping_norm_state)
  clients_clipping_norm = intrinsics.federated_broadcast(clipping_norm)
  # TODO(b/163880757): Remove this when server-only metrics are supported.
  # NOTE(review): rebinds `clipping_norm` to the mean of its broadcast copy,
  # apparently so the reported measurement is derived from a client-side value
  # rather than the server-only one -- confirm with the linked bug.
  clipping_norm = intrinsics.federated_mean(clients_clipping_norm)
  clipped_value, global_norm, was_clipped = intrinsics.federated_map(
      clip_fn, (value, clients_clipping_norm))
  new_clipping_norm_state = clipping_norm_process.next(
      clipping_norm_state, global_norm)
  # Only pass `weight` through when one was supplied (unweighted inner
  # processes take two arguments).
  if weight is None:
    agg_output = inner_agg_process.next(agg_state, clipped_value)
  else:
    agg_output = inner_agg_process.next(agg_state, clipped_value, weight)
  clipped_count_output = clipped_count_agg_process.next(
      clipped_count_state, was_clipped)
  # NOTE(review): `prefix` presumably prepends a word stem (e.g. 'clipp' /
  # 'zero') to form keys such as 'clipping_norm' -- confirm in enclosing scope.
  new_state = collections.OrderedDict([
      (prefix('ing_norm'), new_clipping_norm_state),
      ('inner_agg', agg_output.state),
      (prefix('ed_count_agg'), clipped_count_output.state)
  ])
  measurements = collections.OrderedDict([
      (prefix('ing'), agg_output.measurements),
      (prefix('ing_norm'), clipping_norm),
      (prefix('ed_count'), clipped_count_output.result)
  ])
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(new_state),
      result=agg_output.result,
      measurements=intrinsics.federated_zip(measurements))
def comp(x):
  """Returns the unweighted federated mean of client-placed `x`."""
  unweighted_mean = intrinsics.federated_mean(x)
  return unweighted_mean
def test_federated_mean_with_string_weight_fails(self):
  """`federated_mean` rejects string-typed weights with a `TypeError`."""
  float_values = _mock_data_of_type(computation_types.at_clients(tf.float32))
  string_weights = _mock_data_of_type(computation_types.at_clients(tf.string))
  self.assertRaises(TypeError, intrinsics.federated_mean, float_values,
                    string_weights)
def test_federated_mean_with_client_int32_fails(self):
  """`federated_mean` rejects unweighted client int32 values with `TypeError`."""
  int_values = _mock_data_of_type(computation_types.at_clients(tf.int32))
  self.assertRaises(TypeError, intrinsics.federated_mean, int_values)
def test_federated_mean_with_all_equal_client_float32_with_weight(self):
  """An all-equal client float32 used as value and weight gives float32@SERVER."""
  all_equal_type = computation_types.FederatedType(
      tf.float32, placements.CLIENTS, all_equal=True)
  value_and_weight = _mock_data_of_type(all_equal_type)
  self.assert_value(
      intrinsics.federated_mean(value_and_weight, value_and_weight),
      'float32@SERVER')
def test_federated_mean_with_client_float32_without_weight(self):
  """An unweighted mean of client float32 values yields float32@SERVER."""
  client_floats = _mock_data_of_type(computation_types.at_clients(tf.float32))
  self.assert_value(intrinsics.federated_mean(client_floats), 'float32@SERVER')
def stateless_mean(state, value, weight):
  """Returns the weighted mean of `value`, passing `state` through untouched.

  The measurements are an empty tuple placed at the server.
  """
  no_measurements = intrinsics.federated_value((), placements.SERVER)
  weighted_mean = intrinsics.federated_mean(value, weight=weight)
  return measured_process.MeasuredProcessOutput(
      state=state, result=weighted_mean, measurements=no_measurements)