def count_clients_federated(client_data):
  """Returns the federated sum of a per-client constant `1`.

  A TF function producing a scalar int32 one is bound to `client_data` via
  `_bind_tf_function`, and the resulting values are summed with
  `federated_sum`, so the result is one per contributing member of
  `client_data`.

  Args:
    client_data: The federated value the ones are bound to (presumably
      CLIENTS-placed — TODO confirm against callers).

  Returns:
    The result of `intrinsics.federated_sum` over the bound ones.
  """

  @tf.function
  def client_ones_fn():
    return tf.ones(shape=[], dtype=tf.int32)

  bound_ones = _bind_tf_function(client_data, client_ones_fn)
  return intrinsics.federated_sum(bound_ones)
def next_fn(strings, val):
  """Extends the string state, sums `val`, and reports a constant measurement.

  Args:
    strings: Federated state; each member is concatenated along axis 0 with
      the one-element string tensor `['abc']`.
    val: Value passed to `intrinsics.federated_sum`.

  Returns:
    `MeasuredProcessOutput(new_state, summed_val, server_constant_1)`.
  """
  append_abc = computations.tf_computation()(
      lambda s: tf.concat([s, tf.constant(['abc'])], axis=0))
  extended_strings = intrinsics.federated_map(append_abc, strings)
  summed_val = intrinsics.federated_sum(val)
  constant_measurement = intrinsics.federated_value(1, placements.SERVER)
  return MeasuredProcessOutput(extended_strings, summed_val,
                               constant_measurement)
def count_clients_federated(client_data):
  """Sums a CLIENTS-placed constant `1` after binding it to `client_data`.

  Args:
    client_data: Federated value that `_bind_federated_value` associates with
      the ones; it is declared to the binder as a sequence of `tf.string`.

  Returns:
    The result of `intrinsics.federated_sum` over the bound ones.
  """
  ones_at_clients = intrinsics.federated_value(1, placements.CLIENTS)
  string_sequence_type = computation_types.SequenceType(tf.string)
  bound_ones = _bind_federated_value(client_data, string_sequence_type,
                                     ones_at_clients)
  return intrinsics.federated_sum(bound_ones)
def next_comp(state, value, weight):
  """Increments state, takes a weighted mean, and counts clients.

  Args:
    state: Federated state mapped through `_add_one`.
    value: Values averaged by `intrinsics.federated_mean`.
    weight: Weights for the mean.

  Returns:
    An `OrderedDict` with keys `state`, `result`, and `measurements`; the
    measurements are a zipped `OrderedDict(num_clients=...)` holding the sum
    of a CLIENTS-placed constant 1.
  """
  incremented_state = intrinsics.federated_map(_add_one, state)
  weighted_mean = intrinsics.federated_mean(value, weight)
  client_count = intrinsics.federated_sum(
      intrinsics.federated_value(1, placements.CLIENTS))
  zipped_measurements = intrinsics.federated_zip(
      collections.OrderedDict(num_clients=client_count))
  return collections.OrderedDict(
      state=incremented_state,
      result=weighted_mean,
      measurements=zipped_measurements)
def encoded_mean_fn(state, values, weight):
  """Computes a weighted mean using an encoded sum for the numerator.

  Each value is combined with its weight via `multiply_fn`. That product is
  aggregated by `encoded_sum_fn`, and the result is combined with the plain
  `federated_sum` of the weights via `divide_fn`.

  Args:
    state: State threaded through `encoded_sum_fn`.
    values: Values to average.
    weight: Weights for the average.

  Returns:
    A `(new_state, mean)` tuple.
  """
  weighted_values = intrinsics.federated_map(multiply_fn, [values, weight])
  next_state, numerator = encoded_sum_fn(state, weighted_values)
  denominator = intrinsics.federated_sum(weight)
  mean = intrinsics.federated_map(divide_fn, [numerator, denominator])
  return next_state, mean
def foo(temperatures, threshold):
  """Counts the members of `temperatures` that exceed `threshold`.

  `threshold` is broadcast. Each member compares its scalar float32
  temperature against it, yielding int32 1 when strictly greater and 0
  otherwise. Those indicators are summed.

  Args:
    temperatures: Federated scalar float32 values (presumably CLIENTS-placed,
      since they are zipped with a broadcast value — TODO confirm).
    threshold: Scalar float32 value passed to `federated_broadcast`.

  Returns:
    The federated sum of the per-member indicators.
  """
  # `tf.to_int32` is not part of the TF 2.x API (it only survives as
  # `tf.compat.v1.to_int32`); `tf.cast` performs the same conversion in both.
  exceeds_threshold = computations.tf_computation(
      lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
      [tf.float32, tf.float32])
  return intrinsics.federated_sum(
      intrinsics.federated_map(
          exceeds_threshold,
          [temperatures, intrinsics.federated_broadcast(threshold)]))
def fed_output(local_outputs):
  """Aggregates per-client training outputs.

  Args:
    local_outputs: Structure with `num_examples`, `num_examples_float`, and
      `loss` attributes.

  Returns:
    A dict with `'num_examples'` (federated sum of `num_examples`) and
    `'loss'` (mean of `loss` weighted by `num_examples_float`).
  """
  total_examples = intrinsics.federated_sum(local_outputs.num_examples)
  # TODO(b/124070381): Remove need for using num_examples_float here.
  weighted_loss = intrinsics.federated_mean(
      local_outputs.loss, weight=local_outputs.num_examples_float)
  return dict(num_examples=total_examples, loss=weighted_loss)
def foo(temperatures, threshold):
  """Sums `temperatures > threshold` indicators and checks the result type.

  The enclosing test's `self` is used to assert that the traced sum is a
  `value_base.Value` before it is returned.

  Args:
    temperatures: Federated values compared against the threshold.
    threshold: Value passed to `federated_broadcast`.

  Returns:
    The federated sum of int32 comparison indicators.
  """
  greater_as_int = computations.tf_computation(
      lambda x, y: tf.cast(tf.greater(x, y), tf.int32))
  broadcast_threshold = intrinsics.federated_broadcast(threshold)
  indicators = intrinsics.federated_map(
      greater_as_int, [temperatures, broadcast_threshold])
  total = intrinsics.federated_sum(indicators)
  self.assertIsInstance(total, value_base.Value)
  return total
def next_fn(server_state, client_val):
  """`next` function for `tff.templates.IterativeProcess`.

  Returns:
    A `(server_update, server_output)` pair. `server_update` is the federated
    sum of `client_val`. `server_output` is the result of applying the
    computation produced by `computation_returning_lambda()(1)` to the
    broadcast server state.
  """
  server_update = intrinsics.federated_sum(client_val)
  # NOTE(review): this value is overwritten below before being returned. It is
  # kept because the traced computation may record the intrinsic call — TODO
  # confirm whether it can be dropped.
  server_output = intrinsics.federated_value((), placements.SERVER)
  broadcast_state = intrinsics.federated_broadcast(server_state)
  sum_fn = computation_returning_lambda()(1)
  server_output = sum_fn(broadcast_state)
  return server_update, server_output
def next_fn(server_state, client_data):
  """The `next` function for `computation_utils.IterativeProcess`.

  Pipeline: `prepare` on the server state, broadcast, zip with client data,
  `work` at the clients, `federated_sum` of the updates, then `update` on the
  zipped `(server_state, summed_update)`.

  Returns:
    `(new_server_state, server_output, client_output)`.
  """
  prepared_state = intrinsics.federated_map(prepare, server_state)
  prepared_at_clients = intrinsics.federated_broadcast(prepared_state)
  work_inputs = intrinsics.federated_zip([client_data, prepared_at_clients])
  client_updates, client_output = intrinsics.federated_map(work, work_inputs)
  summed_updates = intrinsics.federated_sum(client_updates)
  update_inputs = intrinsics.federated_zip([server_state, summed_updates])
  new_server_state, server_output = intrinsics.federated_map(
      update, update_inputs)
  return new_server_state, server_output, client_output
def next_fn(server_state, client_data):
  """The `next` function for `computation_utils.IterativeProcess`.

  The server state is discarded. `work` runs directly on `client_data`.
  Element 0 of the updates is summed with `federated_sum` and element 1 with
  `federated_secure_sum` (bitwidth 8). The two sums are zipped and passed
  to `update`.

  Returns:
    `(new_server_state, server_output, client_output)`.
  """
  del server_state  # Unused
  client_updates, client_output = intrinsics.federated_map(work, client_data)
  plain_sum = intrinsics.federated_sum(client_updates[0])
  secure_sum = intrinsics.federated_secure_sum(client_updates[1], 8)
  zipped_sums = intrinsics.federated_zip([plain_sum, secure_sum])
  new_server_state, server_output = intrinsics.federated_map(
      update, zipped_sums)
  return new_server_state, server_output, client_output
def encoded_mean_comp(state, values, weight):
  """Encoded mean federated_computation.

  Weights the values with `multiply_fn` and aggregates them with
  `encoded_sum_fn`. The result is combined with the `federated_sum` of the
  weights via `divide_fn`.

  Args:
    state: State threaded through `encoded_sum_fn`.
    values: Values to average.
    weight: Weights for the average.

  Returns:
    A `measured_process.MeasuredProcessOutput` whose measurements are a
    SERVER-placed empty tuple.
  """
  no_measurements = intrinsics.federated_value((), placements.SERVER)
  weighted_values = intrinsics.federated_map(multiply_fn, (values, weight))
  next_state, numerator = encoded_sum_fn(state, weighted_values)
  denominator = intrinsics.federated_sum(weight)
  mean = intrinsics.federated_map(divide_fn, (numerator, denominator))
  return measured_process.MeasuredProcessOutput(
      state=next_state, result=mean, measurements=no_measurements)
def next_fn(server_state, client_data):
  """The `next` function for `computation_utils.IterativeProcess`.

  Broadcasts the state and runs `work` on the zipped client inputs. Element 0
  of the result is summed with `federated_sum` and element 1 with
  `federated_secure_sum` (bitwidth 8). The zipped sums become the new state.
  No `update` map is applied and no client output is returned.

  Returns:
    `(new_server_state, server_output)`, where `server_output` is a
    SERVER-placed empty list.
  """
  state_at_clients = intrinsics.federated_broadcast(server_state)
  work_inputs = intrinsics.federated_zip([client_data, state_at_clients])
  client_updates = intrinsics.federated_map(work, work_inputs)
  plain_sum = intrinsics.federated_sum(client_updates[0])
  secure_sum = intrinsics.federated_secure_sum(client_updates[1], 8)
  new_server_state = intrinsics.federated_zip([plain_sum, secure_sum])
  server_output = intrinsics.federated_value([], placements.SERVER)
  return new_server_state, server_output
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Runs `work` directly on `client_data`, without a prepare or broadcast
  step. The plain and secure sums (bitwidth 8) of the update elements are
  zipped with the server state and passed to `update`.

  Returns:
    `(new_server_state, server_output, client_output)`.
  """
  # No call to `federated_map` with prepare.
  # No call to `federated_broadcast`.
  client_updates, client_output = intrinsics.federated_map(work, client_data)
  plain_sum = intrinsics.federated_sum(client_updates[0])
  secure_sum = intrinsics.federated_secure_sum(client_updates[1], 8)
  update_inputs = intrinsics.federated_zip(
      [server_state, [plain_sum, secure_sum]])
  new_server_state, server_output = intrinsics.federated_map(
      update, update_inputs)
  return new_server_state, server_output, client_output
def next_fn(state, value):
  """Increments the state and adds 1 to every leaf of the summed `value`.

  Args:
    state: Federated state; each member is mapped through `x + 1`.
    value: Value summed with `federated_sum`. Every leaf of the sum's
      structure is then incremented via `tf.nest.map_structure`.

  Returns:
    A `measured_process.MeasuredProcessOutput` whose measurements are the
    SERVER-placed `MEASUREMENT_CONSTANT`.
  """
  incremented_state = intrinsics.federated_map(
      computations.tf_computation(lambda x: x + 1), state)
  increment_leaves = computations.tf_computation(
      lambda x: tf.nest.map_structure(lambda y: y + 1, x))
  incremented_sum = intrinsics.federated_map(
      increment_leaves, intrinsics.federated_sum(value))
  constant_measurements = intrinsics.federated_value(MEASUREMENT_CONSTANT,
                                                     placements.SERVER)
  return measured_process.MeasuredProcessOutput(
      incremented_state, incremented_sum, constant_measurements)
def next_fn(server_state, client_val):
  """`next` function for `tff.utils.IterativeProcess`.

  Returns:
    `(server_update, server_output)`. `server_update` is a zipped
    `OrderedDict(num_clients=...)` from `count_clients_federated`.
    `server_output` is the federated sum of `tf.timestamp` results bound to
    the broadcast server state.
  """
  server_update = intrinsics.federated_zip(
      collections.OrderedDict(num_clients=count_clients_federated(client_val)))
  # NOTE(review): this value is overwritten below before being returned. It is
  # kept because the traced computation may record the intrinsic call — TODO
  # confirm whether it can be dropped.
  server_output = intrinsics.federated_value((), placements.SERVER)
  broadcast_state = intrinsics.federated_broadcast(server_state)
  timestamps = _bind_tf_function(broadcast_state, tf.timestamp)
  server_output = intrinsics.federated_sum(timestamps)
  return server_update, server_output
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Pipeline: `prepare` on the state, broadcast, zip with client data, and
  `work` at the clients (single output, no client output). Plain and secure
  sums (bitwidth 8) of the update elements follow, then `update` on the
  zipped `(server_state, [plain, secure])`.

  Returns:
    `(new_server_state, server_output)`.
  """
  prepared_state = intrinsics.federated_map(prepare, server_state)
  prepared_at_clients = intrinsics.federated_broadcast(prepared_state)
  work_inputs = intrinsics.federated_zip([client_data, prepared_at_clients])
  client_updates = intrinsics.federated_map(work, work_inputs)
  plain_sum = intrinsics.federated_sum(client_updates[0])
  secure_sum = intrinsics.federated_secure_sum(client_updates[1], 8)
  update_inputs = intrinsics.federated_zip(
      [server_state, [plain_sum, secure_sum]])
  new_server_state, server_output = intrinsics.federated_map(
      update, update_inputs)
  return new_server_state, server_output
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Runs the prepare, broadcast, and work steps. The zipped plain and secure
  sums (bitwidth 8) become the new state directly; no `update` map is
  applied.

  Returns:
    `(new_server_state, server_output, client_output)`, where
    `server_output` is a SERVER-placed empty list.
  """
  prepared_state = intrinsics.federated_map(prepare, server_state)
  prepared_at_clients = intrinsics.federated_broadcast(prepared_state)
  work_inputs = intrinsics.federated_zip([client_data, prepared_at_clients])
  client_updates, client_output = intrinsics.federated_map(work, work_inputs)
  plain_sum = intrinsics.federated_sum(client_updates[0])
  secure_sum = intrinsics.federated_secure_sum(client_updates[1], 8)
  new_server_state = intrinsics.federated_zip([plain_sum, secure_sum])
  # No call to `federated_map` with an `update` function.
  server_output = intrinsics.federated_value([], placements.SERVER)
  return new_server_state, server_output, client_output
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Discards the server state and runs `work` directly on `client_data`. The
  zipped plain and secure sums (bitwidth 8) become the new state. There is no
  prepare, broadcast, or update step and no client output.

  Returns:
    `(new_server_state, server_output)`, where `server_output` is a
    SERVER-placed empty list.
  """
  del server_state  # Unused
  # No call to `federated_map` with prepare.
  # No call to `federated_broadcast`.
  client_updates = intrinsics.federated_map(work, client_data)
  plain_sum = intrinsics.federated_sum(client_updates[0])
  secure_sum = intrinsics.federated_secure_sum(client_updates[1], 8)
  new_server_state = intrinsics.federated_zip([plain_sum, secure_sum])
  # No call to `federated_map` with an `update` function.
  server_output = intrinsics.federated_value([], placements.SERVER)
  return new_server_state, server_output
def next_fn(state, value):
  """Divides the `value_sum_process` result by the number of clients.

  Args:
    state: State forwarded to `value_sum_process.next`.
    value: Value forwarded to `value_sum_process.next`.

  Returns:
    A `measured_process.MeasuredProcessOutput` carrying the inner process's
    state, the `_div` of its result by the client count, and its
    measurements zipped under the key `mean_value`.
  """
  sum_output = value_sum_process.next(state, value)
  num_clients = intrinsics.federated_sum(
      intrinsics.federated_value(1, placements.CLIENTS))
  mean = intrinsics.federated_map(_div, (sum_output.result, num_clients))
  inner_state = sum_output.state
  zipped_measurements = intrinsics.federated_zip(
      collections.OrderedDict(mean_value=sum_output.measurements))
  return measured_process.MeasuredProcessOutput(inner_state, mean,
                                                zipped_measurements)
def next_fn(global_state, value, weight):
  """Runs one round of a zeroing weighted mean with quantile tracking.

  Sample params derived from `global_state` are broadcast and used by
  `preprocess_value` at each member. The four preprocessed outputs are
  summed separately. The mean is `divide_no_nan(value_sum, total_weight)`,
  and `next_quantile` produces the new threshold and global state.

  Args:
    global_state: State fed to `derive_sample_params` and `next_quantile`.
    value: Values to average.
    weight: Weights for the average.

  Returns:
    A `measured_process.MeasuredProcessOutput` whose measurements are a
    zipped `AdaptiveZeroingMetrics(new_threshold, num_zeroed)`.
  """
  derived_params = intrinsics.federated_map(derive_sample_params,
                                            global_state)
  params_at_clients = intrinsics.federated_broadcast(derived_params)
  weighted_value, adj_weight, quantile_record, too_large = (
      intrinsics.federated_map(preprocess_value,
                               (params_at_clients, value, weight)))

  value_sum = intrinsics.federated_sum(weighted_value)
  total_weight = intrinsics.federated_sum(adj_weight)
  quantile_sum = intrinsics.federated_sum(quantile_record)
  num_zeroed = intrinsics.federated_sum(too_large)

  mean_value = intrinsics.federated_map(divide_no_nan,
                                        (value_sum, total_weight))
  new_threshold, new_global_state = intrinsics.federated_map(
      next_quantile, (quantile_sum, global_state))
  zipped_metrics = intrinsics.federated_zip(
      AdaptiveZeroingMetrics(new_threshold, num_zeroed))
  return measured_process.MeasuredProcessOutput(
      state=new_global_state, result=mean_value, measurements=zipped_metrics)
def next_fn(server_state, client_data):
  """Adds 1 to the broadcast state at each client and sums the results.

  `client_data` is paired with the broadcast state but ignored by the local
  transform (declared as a float32 sequence).

  Returns:
    `(aggregate_update, server_output)`, where `server_output` is the
    SERVER-placed constant 1234.
  """
  state_at_clients = intrinsics.federated_broadcast(server_state)

  @computations.tf_computation(tf.int32,
                               computation_types.SequenceType(tf.float32))
  @tf.function
  def some_transform(x, y):
    del y  # Unused
    return x + 1

  transformed = intrinsics.federated_map(some_transform,
                                         (state_at_clients, client_data))
  aggregate_update = intrinsics.federated_sum(transformed)
  server_output = intrinsics.federated_value(1234, placements.SERVER)
  return aggregate_update, server_output
def one_round_computation(server_state, federated_dataset):
  """Orchestration logic for one round of optimization.

  Args:
    server_state: a `tff.learning.framework.ServerState` named tuple.
    federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

  Returns:
    A tuple of updated `tff.learning.framework.ServerState` and the result of
    `tff.learning.Model.federated_output_computation`, both having
    `tff.SERVER` placement.
  """
  broadcast_output = broadcast_process.next(
      server_state.model_broadcast_state, server_state.model)
  client_outputs = intrinsics.federated_map(
      _compute_local_training_and_client_delta,
      (federated_dataset, broadcast_output.result))

  # When the aggregation `next` accepts three parameters, the per-client delta
  # weight is passed as the third positional argument; otherwise only the
  # state and the delta are passed.
  takes_weight = len(aggregation_process.next.type_signature.parameter) == 3
  aggregation_args = [
      server_state.delta_aggregate_state, client_outputs.weights_delta
  ]
  if takes_weight:
    aggregation_args.append(client_outputs.weights_delta_weight)
  aggregation_output = aggregation_process.next(*aggregation_args)

  new_global_model, new_optimizer_state = intrinsics.federated_map(
      server_update, (server_state.model, aggregation_output.result,
                      server_state.optimizer_state))
  new_server_state = intrinsics.federated_zip(
      ServerState(new_global_model, new_optimizer_state,
                  aggregation_output.state, broadcast_output.state))

  train_metrics = dummy_model_for_metadata.federated_output_computation(
      client_outputs.model_output)
  optimizer_stats = intrinsics.federated_sum(client_outputs.optimizer_output)
  measurements = intrinsics.federated_zip(
      collections.OrderedDict(
          broadcast=broadcast_output.measurements,
          aggregation=aggregation_output.measurements,
          train=train_metrics,
          stat=optimizer_stats))
  return new_server_state, measurements
def _sum_securely(self, value, upper_bound, lower_bound):
  """Securely sums `value` placed at CLIENTS.

  In `_Config.INT` mode, the values are shifted at the clients using the
  broadcast bounds. They are then summed with `federated_secure_sum` at
  `self._secagg_bitwidth` and shifted back on the server using `lower_bound`
  and the number of summands. In `_Config.FLOAT` mode, the work is delegated
  to `federated_aggregations.secure_quantized_sum`.

  Args:
    value: CLIENTS-placed value to sum.
    upper_bound: Upper bound; broadcast in INT mode.
    lower_bound: Lower bound; broadcast in INT mode and also used directly by
      the server-side shift.

  Returns:
    The securely summed value.

  Raises:
    ValueError: If `self._config_mode` is neither `_Config.INT` nor
      `_Config.FLOAT`.
  """
  if self._config_mode == _Config.INT:
    upper_at_clients = intrinsics.federated_broadcast(upper_bound)
    lower_at_clients = intrinsics.federated_broadcast(lower_bound)
    shifted = intrinsics.federated_map(
        _client_shift, (value, upper_at_clients, lower_at_clients))
    summed = intrinsics.federated_secure_sum(shifted, self._secagg_bitwidth)
    num_summands = intrinsics.federated_sum(_client_one())
    return intrinsics.federated_map(_server_shift,
                                    (summed, lower_bound, num_summands))

  if self._config_mode == _Config.FLOAT:
    return federated_aggregations.secure_quantized_sum(
        value, lower_bound, upper_bound)

  raise ValueError(
      f'Unexpected internal config type: {self._config_mode}')
def foo(x):
  """Returns `intrinsics.federated_sum(x)`."""
  total = intrinsics.federated_sum(x)
  return total
def foo(x):
  """Sums `x` and asserts, via the enclosing test's `self`, that it is a Value."""
  total = intrinsics.federated_sum(x)
  self.assertIsInstance(total, value_base.Value)
  return total
def namedtuple_next_fn(state, client_values):
  """Maps `sum_sequence` over `client_values`, sums the results, and wraps them.

  Returns:
    `learning_process_output(state, summed_metrics)`, with `state` passed
    through unchanged.
  """
  client_metrics = intrinsics.federated_map(sum_sequence, client_values)
  return learning_process_output(state,
                                 intrinsics.federated_sum(client_metrics))
def odict_next_fn(state, client_values):
  """Maps `sum_sequence` over `client_values` and sums the results.

  Returns:
    `OrderedDict(state=state, metrics=summed_metrics)`, with `state` passed
    through unchanged.
  """
  client_metrics = intrinsics.federated_map(sum_sequence, client_values)
  return collections.OrderedDict(
      state=state, metrics=intrinsics.federated_sum(client_metrics))
def next_fn(state, client_values):
  """Maps `sum_sequence` over `client_values` and sums the results.

  Returns:
    `LearningProcessOutput(state, summed_metrics)`, with `state` passed
    through unchanged.
  """
  client_metrics = intrinsics.federated_map(sum_sequence, client_values)
  return LearningProcessOutput(state,
                               intrinsics.federated_sum(client_metrics))
def next_fn(state, client_values, second_state):  # pylint: disable=unused-argument
  """Like the two-argument variant, with an extra ignored `second_state`.

  Returns:
    `LearningProcessOutput(state, summed_metrics)`, where the metrics are the
    federated sum of `sum_sequence` mapped over `client_values`.
  """
  client_metrics = intrinsics.federated_map(sum_sequence, client_values)
  return LearningProcessOutput(state,
                               intrinsics.federated_sum(client_metrics))