def next_comp(state, value, weight):
  """Advances `state` and computes the weighted federated mean of `value`.

  Args:
    state: Value mapped through `_add_one` with `intrinsics.federated_map`.
    value: Value averaged with `intrinsics.federated_mean`.
    weight: Weight forwarded to `intrinsics.federated_mean`.

  Returns:
    A `measured_process.MeasuredProcessOutput` whose `measurements` is a
    zipped `OrderedDict` with a single `num_clients` entry.
  """
  new_state = intrinsics.federated_map(_add_one, state)
  mean = intrinsics.federated_mean(value, weight)
  # Summing a constant 1 contributed at every client yields the client count.
  one_per_client = intrinsics.federated_value(1, placements.CLIENTS)
  client_count = intrinsics.federated_sum(one_per_client)
  metrics = intrinsics.federated_zip(
      collections.OrderedDict(num_clients=client_count))
  return measured_process.MeasuredProcessOutput(
      state=new_state, result=mean, measurements=metrics)
def _train_one_round(model, federated_data):
  """Runs `_train_on_one_client` at every client and averages the results.

  The server `model` is broadcast. It is paired with `federated_data` in an
  `OrderedDict` with keys 'model' and 'batches'. The per-client outputs are
  combined with `intrinsics.federated_mean` (no weight argument).
  """
  broadcast_model = intrinsics.federated_broadcast(model)
  client_inputs = collections.OrderedDict([
      ('model', broadcast_model),
      ('batches', federated_data),
  ])
  client_models = intrinsics.federated_map(_train_on_one_client, client_inputs)
  return intrinsics.federated_mean(client_models)
def next_fn(server_val, client_val):
  """Chains broadcast, map and mean steps in a CanonicalForm-compatible shape.

  `server_val` is broadcast and zipped (as a tuple) with `client_val`. The
  pair is mapped through `add_two`, and the mapped values are averaged.

  Returns:
    A 2-tuple: the federated mean of the mapped values, and the constant
    list [1, 2, 3, 4, 5] placed at `placements.SERVER`.
  """
  server_val_at_clients = intrinsics.federated_broadcast(server_val)
  zipped = intrinsics.federated_zip((client_val, server_val_at_clients))
  mapped = intrinsics.federated_map(add_two, zipped)
  mean = intrinsics.federated_mean(mapped)
  constant_output = intrinsics.federated_value([1, 2, 3, 4, 5],
                                               placements.SERVER)
  return mean, constant_output
def comp(temperatures, threshold):
  """Averages `count_over` results, weighted by `count_total` results.

  `threshold` is broadcast and zipped (as a list) with `temperatures`, and
  each pair is mapped through `count_over`. `temperatures` alone is mapped
  through `count_total` to produce the weights.
  """
  threshold_at_clients = intrinsics.federated_broadcast(threshold)
  paired = intrinsics.federated_zip([temperatures, threshold_at_clients])
  over_counts = intrinsics.federated_map(count_over, paired)
  total_counts = intrinsics.federated_map(count_total, temperatures)
  return intrinsics.federated_mean(over_counts, total_counts)
def comp(temperatures, threshold):
  """Returns the mean of `count_over` outputs weighted by `count_total` outputs.

  The numerator inputs pair `temperatures` with the broadcast `threshold`
  (zipped as a list). The weights are computed from `temperatures` alone.
  """
  return intrinsics.federated_mean(
      intrinsics.federated_map(
          count_over,
          intrinsics.federated_zip(
              [temperatures,
               intrinsics.federated_broadcast(threshold)])),
      intrinsics.federated_map(count_total, temperatures))
def fed_output(local_outputs):
  """Aggregates per-client outputs into a dict of federated metrics.

  Returns:
    A dict with 'num_examples', the federated sum of
    `local_outputs.num_examples`, and 'loss', the federated mean of
    `local_outputs.loss` weighted by `local_outputs.num_examples_float`.
  """
  total_examples = intrinsics.federated_sum(local_outputs.num_examples)
  # TODO(b/124070381): Remove need for using num_examples_float here.
  mean_loss = intrinsics.federated_mean(
      local_outputs.loss, weight=local_outputs.num_examples_float)
  return dict(num_examples=total_examples, loss=mean_loss)
def agg_next_fn(state, value, weight):
  """Increments `state.call_count` and computes the weighted mean of `value`.

  Returns:
    A 2-tuple. The first element is a zipped `OrderedDict` holding the
    incremented `call_count`. The second is
    `intrinsics.federated_mean(value, weight)`.
  """

  @computations.tf_computation(tf.int32)
  def add_one(value):
    return value + 1

  next_call_count = intrinsics.federated_map(add_one, state.call_count)
  new_state = intrinsics.federated_zip(
      collections.OrderedDict(call_count=next_call_count))
  return new_state, intrinsics.federated_mean(value, weight)
def test_fails_stateful_aggregate_and_process(self):
  """Asserts that passing both aggregation arguments raises an error.

  The builder receives both `stateful_delta_aggregate_fn` and
  `aggregation_process`. The test expects
  `optimizer_utils.DisjointArgumentError`.
  """
  model_weights_type = model_utils.weights_type_from_model(
      model_examples.LinearRegression)

  def mean_next_fn(state, value, weight=None):
    # Pass `state` through unchanged, alongside the mean of `value`.
    return state, intrinsics.federated_mean(value, weight)

  with self.assertRaises(optimizer_utils.DisjointArgumentError):
    optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_examples.LinearRegression,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        stateful_delta_aggregate_fn=computation_utils.StatefulAggregateFn(
            initialize_fn=lambda: (), next_fn=mean_next_fn),
        aggregation_process=optimizer_utils.build_stateless_mean(
            model_delta_type=model_weights_type.trainable))
def foo(x):
  """Returns `intrinsics.federated_mean` of `x`, called with no weight."""
  averaged = intrinsics.federated_mean(x)
  return averaged
def foo(v, w):
  """Returns the federated mean of `v` weighted by `w`."""
  weighted = intrinsics.federated_mean(v, weight=w)
  return weighted
def stateless_mean(state, value, weight):
  """Computes a weighted federated mean and leaves `state` unchanged.

  Returns:
    An `OrderedDict` with three keys: 'state' (the input `state`), 'result'
    (the mean of `value` weighted by `weight`), and 'measurements' (an empty
    tuple placed at `placements.SERVER`).
  """
  no_metrics = intrinsics.federated_value((), placements.SERVER)
  mean = intrinsics.federated_mean(value, weight=weight)
  return collections.OrderedDict(
      state=state, result=mean, measurements=no_metrics)
def simple_weighted_mean():
  """Averages the constant 1.0 placed at clients, weighted by itself."""
  ones = intrinsics.federated_value(1.0, placement_literals.CLIENTS)
  weighted = intrinsics.federated_mean(ones, weight=ones)
  return weighted
def trivial_mean():
  """Returns the unweighted federated mean of an empty tuple at clients."""
  empties = intrinsics.federated_value((), placement_literals.CLIENTS)
  averaged = intrinsics.federated_mean(empties)
  return averaged
def stateless_mean(state, value, weight):
  """Computes a weighted federated mean and leaves `state` unchanged.

  Returns:
    A `measured_process.MeasuredProcessOutput`. Its `state` is the input
    `state`, its `result` is the mean of `value` weighted by `weight`, and
    its `measurements` is an empty tuple placed at `placements.SERVER`.
  """
  no_metrics = intrinsics.federated_value((), placements.SERVER)
  mean = intrinsics.federated_mean(value, weight=weight)
  return measured_process.MeasuredProcessOutput(
      state=state, result=mean, measurements=no_metrics)
def simple_mean():
  """Returns the unweighted federated mean of the constant 1.0 at clients."""
  ones = intrinsics.federated_value(1.0, placements.CLIENTS)
  averaged = intrinsics.federated_mean(ones)
  return averaged
def foo(x, y):
  """Computes the mean of `x` weighted by `y` and checks its type.

  `self` is a free name here, presumably the enclosing test case (confirm
  against the surrounding code). It is used to assert that the result is a
  `value_base.Value`.
  """
  weighted_mean = intrinsics.federated_mean(x, weight=y)
  self.assertIsInstance(weighted_mean, value_base.Value)
  return weighted_mean
def comp(x):
  """Returns `intrinsics.federated_mean` of `x`, called with no weight."""
  averaged = intrinsics.federated_mean(x)
  return averaged
def _(x, y):
  """Returns the federated mean of `x` weighted by `y`."""
  weighted = intrinsics.federated_mean(x, weight=y)
  return weighted
def _state_incrementing_mean_next(server_state, client_value, weight=None):
  """Maps `server_state` through `_add_one` and averages `client_value`.

  Args:
    server_state: Value passed to `intrinsics.federated_map` with `_add_one`.
    client_value: Value averaged with `intrinsics.federated_mean`.
    weight: Optional weight forwarded to `intrinsics.federated_mean`.

  Returns:
    A 2-tuple of (mapped state, federated mean of `client_value`).
  """
  incremented_state = intrinsics.federated_map(_add_one, server_state)
  mean = intrinsics.federated_mean(client_value, weight=weight)
  return incremented_state, mean