import tensorflow as tf
import tensorflow_federated as tff

# NOTE(review): TF1-compat API. Under TF2 resource variables are the default
# and this call may not exist in recent releases — confirm the pinned TF
# version before removing.
tf.enable_resource_variables()


@tff.federated_computation
def hello_world():
  """Trivial federated computation used as a smoke test."""
  # Fixed typo: was `hello_word`.
  return "Hello, World!"


print(hello_world())

# %%
# A float placed at the clients: each client may hold a different value.
federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)
print(str(federated_float_on_clients.member))
print(str(federated_float_on_clients.placement))
print(str(federated_float_on_clients))
# `all_equal` defaults to False for CLIENTS placement.
print(federated_float_on_clients.all_equal)
# With all_equal=True every client is guaranteed to hold the same value.
print(tff.FederatedType(tf.float32, tff.CLIENTS, all_equal=True))

# %%
# A simple named-tuple model type with two scalar parameters.
# NOTE(review): `tff.NamedTupleType` was renamed `tff.StructType` in newer
# TFF releases — confirm the pinned TFF version before migrating.
simple_regression_model_type = (
    tff.NamedTupleType([('a', tf.float32), ('b', tf.float32)]))
print(str(simple_regression_model_type))
print(
    str(
        tff.FederatedType(
            simple_regression_model_type, tff.CLIENTS, all_equal=True)))
# %%
model = model_fn() client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01) loss = tf.Variable(0.0, trainable=False, dtype=tf.float32) return client_update(model, tf_dataset, server_weights, client_optimizer, loss) # 将服务器更新代码转为tff代码 @tff.tf_computation(model_weights_type) def server_update_fn(mean_client_weights): model = model_fn() return server_update(model, mean_client_weights) # 将数据集结构和模型参数结构转为联邦结构 federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER) federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS) # 联邦学习过程 @tff.federated_computation(federated_server_type, federated_dataset_type) def next_fn(server_weights, federated_dataset): # 将服务器模型广播到客户端上 server_weights_at_client = tff.federated_broadcast(server_weights) # 客户端计算更新过程,并更新参数 client_weights, clients_loss = tff.federated_map( client_update_fn, (federated_dataset, server_weights_at_client)) # 服务器平均所有客户端更新的模型参数 mean_client_weights = tff.federated_mean(client_weights)
def build_run_one_round_fn_attacked(server_update_fn, client_update_fn,
                                    stateful_delta_aggregate_fn,
                                    dummy_model_for_metadata,
                                    federated_server_state_type,
                                    federated_dataset_type):
  """Builds a `tff.federated_computation` for a round of training.

  Args:
    server_update_fn: A function for updates in the server.
    client_update_fn: A function for updates in the clients.
    stateful_delta_aggregate_fn: A `tff.computation` that takes in model
      deltas placed@CLIENTS to an aggregated model delta placed@SERVER.
    dummy_model_for_metadata: A dummy `tff.learning.Model`.
    federated_server_state_type: type_signature of federated server state.
    federated_dataset_type: type_signature of federated dataset.

  Returns:
    A `tff.federated_computation` for a round of training.
  """

  # Each client additionally receives a boolean marking whether it behaves
  # maliciously this round (attack-simulation variant of FedAvg).
  federated_bool_type = tff.FederatedType(tf.bool, tff.CLIENTS)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type, federated_dataset_type,
                             federated_bool_type)
  def run_one_round(server_state, federated_dataset, malicious_dataset,
                    malicious_clients):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
      malicious_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`
        consisting of malicious datasets.
      malicious_clients: A federated `tf.bool` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    # Ship the current global model to every client.
    client_model = tff.federated_broadcast(server_state.model)
    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, malicious_dataset, malicious_clients,
         client_model))
    # Per-client averaging weight produced by the client update.
    weight_denom = client_outputs.weights_delta_weight
    # Stateful aggregation (e.g. robust/DP aggregator) carries state across
    # rounds alongside the aggregated delta.
    new_delta_aggregate_state, round_model_delta = stateful_delta_aggregate_fn(
        server_state.delta_aggregate_state,
        client_outputs.weights_delta,
        weight=weight_denom)
    server_state = tff.federated_map(
        server_update_fn,
        (server_state, round_model_delta, new_delta_aggregate_state))
    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)
    # Zip multi-element metric structs into a single federated value.
    if isinstance(aggregated_outputs.type_signature, tff.StructType):
      aggregated_outputs = tff.federated_zip(aggregated_outputs)
    return server_state, aggregated_outputs

  return run_one_round
def build_fixed_clip_norm_mean_process(
    *,
    clip_norm: float,
    model_update_type: Union[tff.NamedTupleType, tff.TensorType],
) -> tff.templates.MeasuredProcess:
  """Returns process that clips the client deltas before averaging.

  The returned `MeasuredProcess` has a next function with the TFF type
  signature:

  ```
  (<()@SERVER, {model_update_type}@CLIENTS> ->
   <state=()@SERVER,
    result=model_update_type@SERVER,
    measurements=NormClippedAggregationMetrics@SERVER>)
  ```

  Args:
    clip_norm: the clip norm to apply to the global norm of the model update.
      See https://www.tensorflow.org/api_docs/python/tf/clip_by_global_norm
      for details.
    model_update_type: a `tff.Type` describing the shape and type of the value
      that will be clipped and averaged.

  Returns:
    A `tff.templates.MeasuredProcess` with the type signature detailed above.
  """

  @tff.federated_computation
  def initialize_fn():
    # This aggregator carries no state across rounds: an empty tuple.
    return tff.federated_value((), tff.SERVER)

  @tff.federated_computation(
      tff.FederatedType((), tff.SERVER),
      tff.FederatedType(model_update_type, tff.CLIENTS),
      tff.FederatedType(tf.float32, tff.CLIENTS))
  def next_fn(state, deltas, weights):

    @tff.tf_computation(model_update_type)
    def clip_by_global_norm(update):
      # Clip the flattened tensor structure, then restore the nesting below.
      clipped_update, global_norm = tf.clip_by_global_norm(
          tf.nest.flatten(update), tf.constant(clip_norm))
      # 1 if this client's update was actually clipped, else 0; these are
      # summed across clients in the measurements below.
      was_clipped = tf.cond(
          tf.greater(global_norm, tf.constant(clip_norm)),
          lambda: tf.constant(1),
          lambda: tf.constant(0),
      )
      clipped_update = tf.nest.pack_sequence_as(update, clipped_update)
      return clipped_update, global_norm, was_clipped

    clipped_deltas, client_norms, client_was_clipped = tff.federated_map(
        clip_by_global_norm, deltas)

    return collections.OrderedDict(
        state=state,
        # Weighted average of the clipped client deltas.
        result=tff.federated_mean(clipped_deltas, weight=weights),
        measurements=tff.federated_zip(
            NormClippedAggregationMetrics(
                max_global_norm=tff.utils.federated_max(client_norms),
                num_clipped=tff.federated_sum(client_was_clipped),
            )))

  return tff.templates.MeasuredProcess(
      initialize_fn=initialize_fn, next_fn=next_fn)
def build_federated_averaging_process(
    model_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1)):
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a
      `simple_fedavg_tf.KerasModelWrapper`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for server update.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for client update.

  Returns:
    A `tff.utils.IterativeProcess`.
  """
  # Used only to read `input_spec` below; this instance is never trained.
  dummy_model = model_fn()

  @tff.tf_computation
  def server_init_tf():
    model = model_fn()
    server_optimizer = server_optimizer_fn()
    # Force creation of the optimizer slot variables so their state can be
    # captured inside `ServerState`.
    _initialize_optimizer_vars(model, server_optimizer)
    return ServerState(
        model_weights=model.weights,
        optimizer_state=server_optimizer.variables(),
        round_num=0)

  # Derive the TFF types from the init computation's result signature.
  server_state_type = server_init_tf.type_signature.result
  model_weights_type = server_state_type.model_weights

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_fn(server_state, model_delta):
    model = model_fn()
    server_optimizer = server_optimizer_fn()
    _initialize_optimizer_vars(model, server_optimizer)
    return server_update(model, server_optimizer, server_state, model_delta)

  @tff.tf_computation(server_state_type)
  def server_message_fn(server_state):
    # Packs only what clients need (weights, round number) for broadcast.
    return build_server_broadcast_message(server_state)

  server_message_type = server_message_fn.type_signature.result
  tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

  @tff.tf_computation(tf_dataset_type, server_message_type)
  def client_update_fn(tf_dataset, server_message):
    model = model_fn()
    client_optimizer = client_optimizer_fn()
    return client_update(model, tf_dataset, server_message, client_optimizer)

  federated_server_state_type = tff.FederatedType(server_state_type,
                                                  tff.SERVER)
  federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    # Weight each client's delta by its reported client weight (typically
    # the number of examples processed).
    weight_denom = client_outputs.client_weight
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(
        client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric

  @tff.federated_computation
  def server_init_tff():
    """Orchestration logic for server model initialization."""
    return tff.federated_value(server_init_tf(), tff.SERVER)

  return tff.utils.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round)
#apply the gradient using client optimizer client_optimizer.apply_gradients(grads_and_vars) return client_weights @tff.tf_computation(tf_dataset_type, model_weights_type) def client_update_fn(tf_dataset, server_weights): tff_model = wrap_model_with_tff(keras_model(), input_spec) client_optimizer = tf.keras.optimizers.Adam() return client_update(tff_model, tf_dataset, server_weights, client_optimizer) federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER) federated_dataset_data = tff.FederatedType(tf_dataset_type, tff.CLIENTS) @tff.federated_computation(federated_server_type, federated_dataset_data) def next_fn(server_weights, federated_dataset): # Send server weights to clients server_weights_to_clients = tff.federated_broadcast(server_weights) # Each client computes their updated weights client_weights = tff.federated_map( client_update_fn, (federated_dataset, server_weights_to_clients)) # Client mean mean_client_weights = tff.federated_mean(client_weights)
def create_trainer(batch_size, step_size):
  """Constructs a trainer for the given batch size.

  Args:
    batch_size: The size of a single data batch.
    step_size: The step size to use during training.

  Returns:
    An instance of `Trainer`.
  """
  # Types for a single batch of flattened 28x28 images with integer labels,
  # and for a dense softmax-regression model (weights + bias).
  batch_type = tff.to_type(
      collections.OrderedDict([
          ('pixels', tff.TensorType(np.float32, (batch_size, 784))),
          ('labels', tff.TensorType(np.int32, (batch_size,)))
      ]))
  model_type = tff.to_type(
      collections.OrderedDict([('weights', tff.TensorType(np.float32,
                                                          (784, 10))),
                               ('bias', tff.TensorType(np.float32, (10,)))]))

  @tff.experimental.jax_computation
  def create_zero_model():
    weights = jax.numpy.zeros((784, 10), dtype=np.float32)
    bias = jax.numpy.zeros((10,), dtype=np.float32)
    return collections.OrderedDict([('weights', weights), ('bias', bias)])

  def generate_random_batches(num_batches):
    # Yields synthetic batches for testing/benchmarking.
    # NOTE(review): `high=9` is exclusive for np.random.randint, so label 9
    # is never generated despite the model having 10 classes — confirm
    # whether `high=10` was intended.
    for _ in range(num_batches):
      pixels = np.random.uniform(
          low=0.0, high=1.0, size=(batch_size, 784)).astype(np.float32)
      labels = np.random.randint(
          low=0, high=9, size=(batch_size,), dtype=np.int32)
      yield collections.OrderedDict([('pixels', pixels), ('labels', labels)])

  def _loss_fn(model, batch):
    # Cross-entropy loss for softmax regression over 10 classes.
    y = jax.nn.softmax(
        jax.numpy.add(
            jax.numpy.matmul(batch['pixels'], model['weights']),
            model['bias']))
    targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
    return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

  @tff.experimental.jax_computation(model_type, batch_type)
  def train_on_one_batch(model, batch):
    # One SGD step on a single batch.
    grads = jax.api.grad(_loss_fn)(model, batch)
    return collections.OrderedDict([
        (k, model[k] - step_size * grads[k]) for k in ['weights', 'bias']
    ])

  @tff.federated_computation(model_type, tff.SequenceType(batch_type))
  def train_on_one_client(model, batches):
    # Folds `train_on_one_batch` over the client's local dataset.
    return tff.sequence_reduce(batches, model, train_on_one_batch)

  local_training_process = tff.templates.IterativeProcess(
      initialize_fn=create_zero_model, next_fn=train_on_one_client)

  # TODO(b/175888145): Switch to a simple tff.federated_mean after finding a
  # way to reduce reliance on the auto-generated TF bits in the executor stack
  # for the GENERIC_PLUS and similar intrinsics.

  # The computations below hand-roll an unweighted mean via
  # `tff.federated_aggregate`: sum the models, count the contributions,
  # divide at the end.

  @tff.experimental.jax_computation
  def create_zero_count():
    return np.int32(0)

  @tff.experimental.jax_computation
  def create_one_count():
    return np.int32(1)

  @tff.experimental.jax_computation(model_type, model_type)
  def combine_two_models(x, y):
    # Elementwise sum of two models.
    return collections.OrderedDict([
        ('weights', jax.numpy.add(x['weights'], y['weights'])),
        ('bias', jax.numpy.add(x['bias'], y['bias']))
    ])

  @tff.experimental.jax_computation(model_type, np.int32)
  def divide_model_by_count(model, count):
    # Finalizes the mean: divide the summed model by the contribution count.
    multiplier = 1.0 / count.astype(np.float32)
    return collections.OrderedDict([
        ('weights', jax.numpy.multiply(model['weights'], multiplier)),
        ('bias', jax.numpy.multiply(model['bias'], multiplier))
    ])

  @tff.experimental.jax_computation(np.int32, np.int32)
  def combine_two_counts(x, y):
    return jax.numpy.add(x, y)

  @tff.federated_computation
  def make_zero_model_and_count():
    # Identity element for the aggregation: zero model, zero count.
    return collections.OrderedDict([('model', create_zero_model()),
                                    ('count', create_zero_count())])

  model_and_count_type = make_zero_model_and_count.type_signature.result

  @tff.federated_computation(model_and_count_type, model_type)
  def accumulate(arg):
    # TODO(b/175888145): Diagnose the newly emergent problem with tuple arg
    # handling that gets in the way by forcing named elements here at input
    # (i.e., we can't just declare `def accumulate(accumulator, model)` for
    # reasons that yet need to be understood).
    accumulator = arg[0]
    model = arg[1]
    return collections.OrderedDict([
        ('model', combine_two_models(accumulator['model'], model)),
        ('count', combine_two_counts(accumulator['count'],
                                     create_one_count()))
    ])

  @tff.federated_computation(model_and_count_type, model_and_count_type)
  def merge(arg):
    # Merges two partial (sum, count) accumulators.
    x = arg[0]
    y = arg[1]
    return collections.OrderedDict([
        ('model', combine_two_models(x['model'], y['model'])),
        ('count', combine_two_counts(x['count'], y['count']))
    ])

  @tff.federated_computation(model_and_count_type)
  def report(x):
    return divide_model_by_count(x['model'], x['count'])

  @tff.federated_computation
  def create_zero_model_on_server():
    return tff.federated_eval(create_zero_model, tff.SERVER)

  @tff.federated_computation(
      tff.FederatedType(model_type, tff.SERVER),
      tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
  def train_one_round(model, federated_data):
    # Broadcast the model, train locally on each client, then aggregate the
    # locally-trained models back to the server as an unweighted mean.
    locally_trained_models = tff.federated_map(
        train_on_one_client,
        collections.OrderedDict([('model', tff.federated_broadcast(model)),
                                 ('batches', federated_data)]))
    return tff.federated_aggregate(locally_trained_models,
                                   make_zero_model_and_count(), accumulate,
                                   merge, report)

  federated_averaging_process = tff.templates.IterativeProcess(
      initialize_fn=create_zero_model_on_server, next_fn=train_one_round)

  compute_loss_on_one_batch = tff.experimental.jax_computation(
      _loss_fn, model_type, batch_type)

  return Trainer(
      create_initial_model=create_zero_model,
      generate_random_batches=generate_random_batches,
      train_on_one_batch=train_on_one_batch,
      train_on_one_client=train_on_one_client,
      local_training_process=local_training_process,
      train_one_round=train_one_round,
      federated_averaging_process=federated_averaging_process,
      compute_loss_on_one_batch=compute_loss_on_one_batch)
def build_fed_avg_process(
    model_fn: ModelBuilder,
    client_optimizer_fn: OptimizerBuilder,
    client_lr: Union[float, LRScheduleFn] = 0.1,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_lr: Union[float, LRScheduleFn] = 1.0,
    client_weight_fn: Optional[ClientWeightFn] = None,
) -> tff.templates.IterativeProcess:
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    client_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    server_optimizer_fn: A function that accepts a `learning_rate` argument
      and returns a `tf.keras.optimizers.Optimizer` instance.
    server_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the
      weight in the federated average of model deltas. If not provided, the
      default is the total number of examples processed on device.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # Normalize scalar learning rates into round_num -> lr schedules so the
  # rest of the code handles both cases uniformly.
  client_lr_schedule = client_lr
  if not callable(client_lr_schedule):
    client_lr_schedule = lambda round_num: client_lr

  server_lr_schedule = server_lr
  if not callable(server_lr_schedule):
    server_lr_schedule = lambda round_num: server_lr

  # Used only for type metadata (`input_spec`, output computation).
  dummy_model = model_fn()

  server_init_tf = build_server_init_fn(
      model_fn,
      # Initialize with the learning rate for round zero.
      lambda: server_optimizer_fn(server_lr_schedule(0)))
  server_state_type = server_init_tf.type_signature.result
  model_weights_type = server_state_type.model
  round_num_type = server_state_type.round_num

  # NOTE(review): these two types are identical here; the sibling variant
  # with `dataset_preprocess_comp` distinguishes them.
  tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
  model_input_type = tff.SequenceType(dummy_model.input_spec)

  @tff.tf_computation(model_input_type, model_weights_type, round_num_type)
  def client_update_fn(tf_dataset, initial_model_weights, round_num):
    # Resolve the client learning rate for this round, then run local
    # training starting from the broadcast weights.
    client_lr = client_lr_schedule(round_num)
    client_optimizer = client_optimizer_fn(client_lr)
    client_update = create_client_update_fn()
    return client_update(model_fn(), tf_dataset, initial_model_weights,
                         client_optimizer, client_weight_fn)

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_fn(server_state, model_delta):
    model = model_fn()
    server_lr = server_lr_schedule(server_state.round_num)
    server_optimizer = server_optimizer_fn(server_lr)
    # We initialize the server optimizer variables to avoid creating them
    # within the scope of the tf.function server_update.
    _initialize_optimizer_vars(model, server_optimizer)
    return server_update(model, server_optimizer, server_state, model_delta)

  @tff.federated_computation(
      tff.FederatedType(server_state_type, tff.SERVER),
      tff.FederatedType(tf_dataset_type, tff.CLIENTS))
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, client_model, client_round_num))

    # Weighted average of client deltas using the reported client weights.
    client_weight = client_outputs.client_weight
    model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=client_weight)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, model_delta))

    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    # Zip multi-element metric structs into a single federated value.
    if aggregated_outputs.type_signature.is_struct():
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  @tff.federated_computation
  def initialize_fn():
    return tff.federated_value(server_init_tf(), tff.SERVER)

  return tff.templates.IterativeProcess(
      initialize_fn=initialize_fn, next_fn=run_one_round)
def build_fed_avg_process(
    model_fn: ModelBuilder,
    client_optimizer_fn: OptimizerBuilder,
    client_lr: Union[float, LRScheduleFn] = 0.1,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_lr: Union[float, LRScheduleFn] = 1.0,
    client_weight_fn: Optional[ClientWeightFn] = None,
    dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> FederatedAveragingProcessAdapter:
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    client_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    server_optimizer_fn: A function that accepts a `learning_rate` argument
      and returns a `tf.keras.optimizers.Optimizer` instance.
    server_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the
      weight in the federated average of model deltas. If not provided, the
      default is the total number of examples processed on device.
    dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
      pipeline on the clients. The computation must take a sequence of values
      and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
      If `None`, no dataset preprocessing is applied.

  Returns:
    A `FederatedAveragingProcessAdapter`.
  """
  # Normalize scalar learning rates into round_num -> lr schedules so the
  # rest of the code handles both cases uniformly.
  client_lr_schedule = client_lr
  if not callable(client_lr_schedule):
    client_lr_schedule = lambda round_num: client_lr

  server_lr_schedule = server_lr
  if not callable(server_lr_schedule):
    server_lr_schedule = lambda round_num: server_lr

  # Used only for type metadata (`input_spec`, output computation).
  dummy_model = model_fn()

  server_init_tf = build_server_init_fn(
      model_fn,
      # Initialize with the learning rate for round zero.
      lambda: server_optimizer_fn(server_lr_schedule(0)))
  server_state_type = server_init_tf.type_signature.result
  model_weights_type = server_state_type.model
  round_num_type = server_state_type.round_num

  if dataset_preprocess_comp is not None:
    # Clients receive the *raw* dataset type; preprocessing runs on-client
    # and must yield batches the model can consume.
    tf_dataset_type = dataset_preprocess_comp.type_signature.parameter
    model_input_type = tff.SequenceType(dummy_model.input_spec)
    preprocessed_dataset_type = dataset_preprocess_comp.type_signature.result
    if not model_input_type.is_assignable_from(preprocessed_dataset_type):
      raise TypeError('Supplied `dataset_preprocess_comp` does not yield '
                      'batches that are compatible with the model constructed '
                      'by `model_fn`. Model expects type {m}, but dataset '
                      'yields type {d}.'.format(
                          m=model_input_type, d=preprocessed_dataset_type))
  else:
    tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
    model_input_type = tff.SequenceType(dummy_model.input_spec)

  @tff.tf_computation(model_input_type, model_weights_type, round_num_type)
  def client_update_fn(tf_dataset, initial_model_weights, round_num):
    # Resolve the client learning rate for this round, then run local
    # training starting from the broadcast weights.
    client_lr = client_lr_schedule(round_num)
    client_optimizer = client_optimizer_fn(client_lr)
    client_update = create_client_update_fn()
    return client_update(model_fn(), tf_dataset, initial_model_weights,
                         client_optimizer, client_weight_fn)

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_fn(server_state, model_delta):
    model = model_fn()
    server_lr = server_lr_schedule(server_state.round_num)
    server_optimizer = server_optimizer_fn(server_lr)
    # We initialize the server optimizer variables to avoid creating them
    # within the scope of the tf.function server_update.
    _initialize_optimizer_vars(model, server_optimizer)
    return server_update(model, server_optimizer, server_state, model_delta)

  @tff.federated_computation(
      tff.FederatedType(server_state_type, tff.SERVER),
      tff.FederatedType(tf_dataset_type, tff.CLIENTS))
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_round_num = tff.federated_broadcast(server_state.round_num)

    # Apply the optional on-client preprocessing pipeline.
    if dataset_preprocess_comp is not None:
      federated_dataset = tff.federated_map(dataset_preprocess_comp,
                                            federated_dataset)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, client_model, client_round_num))

    # Weighted average of client deltas using the reported client weights.
    client_weight = client_outputs.client_weight
    model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=client_weight)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, model_delta))

    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    # NOTE(review): a sibling variant in this file uses `is_struct()` here;
    # `is_tuple()` vs `is_struct()` depends on the TFF release — confirm
    # which API the pinned version exposes.
    if aggregated_outputs.type_signature.is_tuple():
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  @tff.federated_computation
  def initialize_fn():
    return tff.federated_value(server_init_tf(), tff.SERVER)

  tff_iterative_process = tff.templates.IterativeProcess(
      initialize_fn=initialize_fn, next_fn=run_one_round)

  return FederatedAveragingProcessAdapter(tff_iterative_process)
import collections
import sys
import time

import absl
import grpc
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

import nest_asyncio

# Patch the already-running event loop so TFF's async internals can be
# driven from notebook-style environments.
nest_asyncio.apply()


@tff.tf_computation(tf.int64)
@tf.function
def add_one(n):
  """Returns `n + 1`, logging the input as a side effect."""
  tf.print("Hello: ", n, output_stream=absl.logging.info)
  return tf.add(n, 1)


@tff.federated_computation(tff.FederatedType(tf.int64, tff.CLIENTS))
def add_one_on_clients(federated_n):
  """Applies `add_one` pointwise to an int64 value placed at CLIENTS."""
  return tff.federated_map(add_one, federated_n)


# Smoke test: a single client holding the value 1.
print(add_one_on_clients([1]))
def __attrs_post_init__(self):
  """Derives TFF type metadata from the configured GAN models/optimizers.

  Builds throwaway generator/discriminator instances (and their optimizers)
  purely to read off weight/variable structures, then records the TFF types
  used by the federated GAN training computations.
  """
  self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
  self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

  # Model-weights based types. Calling the models on dummy inputs forces
  # variable creation before the weight structures are inspected.
  self._generator = self.generator_model_fn()
  _ = self._generator(self.dummy_gen_input)
  py_typecheck.check_type(self._generator, tf.keras.models.Model)
  self._discriminator = self.discriminator_model_fn()
  _ = self._discriminator(self.dummy_real_data)
  py_typecheck.check_type(self._discriminator, tf.keras.models.Model)

  # Optimizers are built (arg `1` is presumably a learning-rate placeholder
  # — TODO confirm against `state_*_optimizer_fn` signatures) and their slot
  # variables initialized so `variables()` has the final structure.
  self._state_gen_opt = self.state_gen_optimizer_fn(1)
  self._state_disc_opt = self.state_disc_optimizer_fn(1)
  gan_training_tf_fns.initialize_optimizer_vars(self._generator,
                                                self._state_gen_opt)
  gan_training_tf_fns.initialize_optimizer_vars(self._discriminator,
                                                self._state_disc_opt)

  # Scalar int32 counters tracked across training rounds.
  self._counters = collections.OrderedDict({
      'num_discriminator_train_examples': tf.constant(0),
      'num_generator_train_examples': tf.constant(0),
      'num_rounds': tf.constant(0),
  })

  def vars_to_type(var_struct):
    # Maps a structure of tf.Variables to matching float32 TensorSpecs.
    # TODO(b/131681951): read_value() shouldn't be needed
    return tf.nest.map_structure(
        lambda v: tf.TensorSpec.from_tensor(
            tf.cast(v.read_value(), tf.float32)), var_struct)

  def vars_to_type_counter(var_struct):
    # Like `vars_to_type` but preserves the counters' original dtype.
    # TODO(b/131681951): read_value() shouldn't be needed
    return tf.nest.map_structure(
        lambda v: tf.TensorSpec.from_tensor(v), var_struct)

  self.discriminator_weights_type = vars_to_type(
      gan_training_tf_fns._weights(self._discriminator))
  self.generator_weights_type = vars_to_type(
      gan_training_tf_fns._weights(self._generator))
  self.state_gen_opt_weights_type = vars_to_type(
      self._state_gen_opt.variables())
  self.state_disc_opt_weights_type = vars_to_type(
      self._state_disc_opt.variables())
  self.counters_type = vars_to_type_counter(self._counters)

  # Type of the server -> client broadcast payload.
  self.from_server_type = gan_training_tf_fns.FromServer(
      generator_weights=self.generator_weights_type,
      discriminator_weights=self.discriminator_weights_type,
      state_gen_optimizer_weights=self.state_gen_opt_weights_type,
      state_disc_optimizer_weights=self.state_disc_opt_weights_type,
      counters=self.counters_type)

  self.client_gen_input_type = tff.FederatedType(
      tff.SequenceType(self.gen_input_type), tff.CLIENTS)
  self.client_real_data_type = tff.FederatedType(
      tff.SequenceType(self.real_data_type), tff.CLIENTS)
  self.server_gen_input_type = tff.FederatedType(
      tff.SequenceType(self.gen_input_type), tff.SERVER)

  # Right now, the logic in this library is effectively "if DP use stateful
  # aggregator, else don't use stateful aggregator". An alternative
  # formulation would be to always use a stateful aggregator, but when not
  # using DP default the aggregator to be a stateless mean, e.g.,
  # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
  # This change will be easier to make if the tff.StatefulAggregateFn is
  # modified to have a property that gives the type of the aggregation state
  # (i.e., what we're storing in self.dp_averaging_state_type).
  if self.train_discriminator_dp_average_query is not None:
    self.dp_averaging_fn, self.dp_averaging_state_type = (
        tff.utils.build_dp_aggregate(
            query=self.train_discriminator_dp_average_query,
            value_type_fn=lambda value: self.discriminator_weights_type))
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple temperature sensor example in TFF."""

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff


@tff.tf_computation(tff.SequenceType(tf.float32), tf.float32)
def count_over(ds, t):
  """Counts how many readings in `ds` exceed the threshold `t`."""
  return ds.reduce(
      np.float32(0),
      lambda acc, reading: acc + tf.cast(tf.greater(reading, t), tf.float32))


@tff.tf_computation(tff.SequenceType(tf.float32))
def count_total(ds):
  """Counts the total number of readings in `ds`."""
  return ds.reduce(np.float32(0.0), lambda acc, _: acc + 1.0)


@tff.federated_computation(
    tff.FederatedType(tff.SequenceType(tf.float32), tff.CLIENTS),
    tff.FederatedType(tf.float32, tff.SERVER))
def mean_over_threshold(temperatures, threshold):
  """Returns the mean over-threshold count, weighted by readings per client."""
  # Ship the server's threshold to every client, pair it with each client's
  # local readings, and count exceedances client-side.
  threshold_at_clients = tff.federated_broadcast(threshold)
  per_client_input = tff.federated_zip([temperatures, threshold_at_clients])
  over_counts = tff.federated_map(count_over, per_client_input)
  totals = tff.federated_map(count_total, temperatures)
  # Average the per-client counts, weighting by total readings per client.
  return tff.federated_mean(over_counts, totals)
def build_jax_federated_averaging_process(batch_type, model_type, loss_fn,
                                          step_size):
  """Constructs an iterative process that implements simple federated averaging.

  Args:
    batch_type: An instance of `tff.Type` that represents the type of a single
      batch of data to use for training. This type should be constructed with
      standard Python containers (such as `collections.OrderedDict`) of the
      sort that are expected as parameters to `loss_fn`.
    model_type: An instance of `tff.Type` that represents the type of the
      model. Similarly to `batch_size`, this type should be constructed with
      standard Python containers (such as `collections.OrderedDict`) of the
      sort that are expected as parameters to `loss_fn`.
    loss_fn: A loss function for the model. Must be a Python function that
      takes two parameters, one of them being the model, and the other being a
      single batch of data (with types matching `batch_type` and
      `model_type`).
    step_size: The step size to use during training (an `np.float32`).

  Returns:
    An instance of `tff.templates.IterativeProcess` that implements federated
    training in JAX.
  """
  batch_type = tff.to_type(batch_type)
  model_type = tff.to_type(model_type)

  # py_typecheck.check_type(batch_type, computation_types.Type)
  # py_typecheck.check_type(model_type, computation_types.Type)
  # py_typecheck.check_callable(loss_fn)
  # py_typecheck.check_type(step_size, np.float)

  def _tensor_zeros(tensor_type):
    # Builds a zero-filled JAX array matching a single TFF tensor type.
    return jax.numpy.zeros(
        tensor_type.shape.dims, dtype=tensor_type.dtype.as_numpy_dtype)

  @tff.jax_computation
  def _create_zero_model():
    # Zero-initialize every tensor in the model structure, then convert the
    # TFF structure back into the caller's Python container type.
    model_zeros = tff.structure.map_structure(_tensor_zeros, model_type)
    return tff.types.type_to_py_container(model_zeros, model_type)

  @tff.federated_computation
  def _create_zero_model_on_server():
    return tff.federated_eval(_create_zero_model, tff.SERVER)

  def _apply_update(model_param, param_delta):
    # One gradient-descent step on a single parameter tensor.
    return model_param - step_size * param_delta

  @tff.jax_computation(model_type, batch_type)
  def _train_on_one_batch(model, batch):
    # Flatten params and grads into parallel lists, apply the update
    # elementwise, then repack into the original model structure.
    params = tff.structure.flatten(
        tff.structure.from_container(model, recursive=True))
    grads = tff.structure.flatten(
        tff.structure.from_container(jax.grad(loss_fn)(model, batch)))
    updated_params = [_apply_update(x, y) for (x, y) in zip(params, grads)]
    trained_model = tff.structure.pack_sequence_as(model_type, updated_params)
    return tff.types.type_to_py_container(trained_model, model_type)

  local_dataset_type = tff.SequenceType(batch_type)

  @tff.federated_computation(model_type, local_dataset_type)
  def _train_on_one_client(model, batches):
    # Folds `_train_on_one_batch` over the client's local batches.
    return tff.sequence_reduce(batches, model, _train_on_one_batch)

  @tff.federated_computation(
      tff.FederatedType(model_type, tff.SERVER),
      tff.FederatedType(local_dataset_type, tff.CLIENTS))
  def _train_one_round(model, federated_data):
    # Broadcast, train locally, and take the unweighted federated mean of
    # the locally trained models.
    locally_trained_models = tff.federated_map(
        _train_on_one_client,
        collections.OrderedDict([('model', tff.federated_broadcast(model)),
                                 ('batches', federated_data)]))
    return tff.federated_mean(locally_trained_models)

  return tff.templates.IterativeProcess(
      initialize_fn=_create_zero_model_on_server, next_fn=_train_one_round)
def build_federated_averaging_process(
    model_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1)):
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.TrainableModel`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for server update.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for client update.

  Returns:
    A `tff.utils.IterativeProcess`.
  """
  # A throwaway model instance, used only to read `input_spec` and to access
  # `federated_output_computation`.
  dummy_model = model_fn(
  )  # TODO(b/144510813): try remove dependency on dummy model

  @tff.tf_computation
  def server_init_tf():
    """Creates the initial `ServerState` (model weights + optimizer slots)."""
    model = model_fn()
    server_optimizer = server_optimizer_fn()
    # Create optimizer variables eagerly so their type is part of ServerState.
    _initialize_optimizer_vars(model, server_optimizer)
    return ServerState(model=model.weights,
                       optimizer_state=server_optimizer.variables())

  # Derive TFF types from the init computation so all later signatures agree.
  server_state_type = server_init_tf.type_signature.result
  model_weights_type = server_state_type.model

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_fn(server_state, model_delta):
    """Applies the aggregated client delta to the server model."""
    model = model_fn()
    server_optimizer = server_optimizer_fn()
    _initialize_optimizer_vars(model, server_optimizer)
    return server_update(model, server_optimizer, server_state, model_delta)

  tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

  @tff.tf_computation(tf_dataset_type, model_weights_type)
  def client_update_fn(tf_dataset, initial_model_weights):
    """Runs local training on one client starting from the broadcast weights."""
    model = model_fn()
    client_optimizer = client_optimizer_fn()
    return client_update(model, tf_dataset, initial_model_weights,
                         client_optimizer)

  federated_server_state_type = tff.FederatedType(server_state_type, tff.SERVER)
  federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_outputs = tff.federated_map(client_update_fn,
                                       (federated_dataset, client_model))
    # Weighted FedAvg: each client's delta is weighted by its reported
    # client_weight (typically the number of examples it processed).
    weight_denom = client_outputs.client_weight
    round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                           weight=weight_denom)
    server_state = tff.federated_map(server_update_fn,
                                     (server_state, round_model_delta))
    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    aggregated_outputs = tff.federated_zip(aggregated_outputs)
    return server_state, aggregated_outputs

  return tff.utils.IterativeProcess(initialize_fn=tff.federated_computation(
      lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
                                    next_fn=run_one_round)
def build_fed_avg_process(
    model_fn: ModelBuilder,
    client_optimizer_fn: OptimizerBuilder,
    client_lr: Union[float, LRScheduleFn] = 0.1,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_lr: Union[float, LRScheduleFn] = 1.0,
    aggregation_process: Optional[tff.templates.MeasuredProcess] = None,
) -> tff.templates.IterativeProcess:
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    client_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    server_optimizer_fn: A function that accepts a `learning_rate` argument
      and returns a `tf.keras.optimizers.Optimizer` instance.
    server_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    aggregation_process: A `tff.templates.MeasuredProcess` used to aggregate
      the client model deltas; its state becomes part of the `ServerState`
      and its `next` is invoked every round. Required.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ValueError: If `aggregation_process` is `None`.
  """
  # Fail fast with a clear message: the round logic below unconditionally
  # dereferences `aggregation_process.initialize` / `.next`, so `None` could
  # otherwise only surface later as an opaque AttributeError.
  if aggregation_process is None:
    raise ValueError(
        'An `aggregation_process` must be supplied; this process calls '
        '`aggregation_process.next(...)` each round to aggregate client '
        'model deltas.')

  # Normalize scalar learning rates into round_num -> lr schedules.
  client_lr_schedule = client_lr
  if not callable(client_lr_schedule):
    client_lr_schedule = lambda round_num: client_lr

  server_lr_schedule = server_lr
  if not callable(server_lr_schedule):
    server_lr_schedule = lambda round_num: server_lr

  # Build throwaway model/optimizer inside a graph purely to capture the TFF
  # types of the model weights and the optimizer's slot variables.
  with tf.Graph().as_default():
    dummy_model = model_fn()
    model_weights_type = model_utils.weights_type_from_model(dummy_model)
    dummy_optimizer = server_optimizer_fn()
    _initialize_optimizer_vars(dummy_model, dummy_optimizer)
    optimizer_variable_type = tff.framework.type_from_tensors(
        dummy_optimizer.variables())

  initialize_computation = build_server_init_fn(
      model_fn=model_fn,
      # Initialize with the learning rate for round zero.
      server_optimizer_fn=lambda: server_optimizer_fn(server_lr_schedule(0)),
      aggregation_process=aggregation_process)

  round_num_type = tf.float32
  # Both dataset types are identical here; two names are kept because they
  # play different roles (pre- vs. post-client-preprocessing) in variants of
  # this builder.
  tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
  model_input_type = tff.SequenceType(dummy_model.input_spec)
  client_weight_type = tf.float32

  aggregation_state_type = aggregation_process.initialize.type_signature.result.member

  server_state_type = ServerState(
      model=model_weights_type,
      optimizer_state=optimizer_variable_type,
      round_num=round_num_type,
      aggregation_state=aggregation_state_type,
  )

  @tff.tf_computation(model_input_type, model_weights_type, round_num_type)
  def client_update_fn(tf_dataset, initial_model_weights, round_num):
    """Runs local training on one client for the given round."""
    client_lr = client_lr_schedule(round_num)
    client_optimizer = client_optimizer_fn(client_lr)
    client_update = create_client_update_fn()
    return client_update(model_fn(), tf_dataset, initial_model_weights,
                         client_optimizer)

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_fn(server_state, model_delta):
    """Applies the aggregated delta to the server model."""
    model = model_fn()
    server_lr = server_lr_schedule(server_state.round_num)
    server_optimizer = server_optimizer_fn(server_lr)
    # We initialize the server optimizer variables to avoid creating them
    # within the scope of the tf.function server_update.
    _initialize_optimizer_vars(model, server_optimizer)
    return server_update(model, server_optimizer, server_state, model_delta)

  @tff.federated_computation(tff.FederatedType(server_state_type, tff.SERVER),
                             tff.FederatedType(tf_dataset_type, tff.CLIENTS),
                             tff.FederatedType(client_weight_type,
                                               tff.CLIENTS))
  def run_one_round(server_state, federated_dataset, client_weight):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement
        `tff.CLIENTS`.
      client_weight: A federated `tf.float32` with placement `tff.CLIENTS`,
        multiplied into each client's self-reported weight before
        aggregation.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, client_model, client_round_num))

    # Combine the externally supplied per-client weight with the weight the
    # client computed locally (e.g. its number of examples).
    participant_client_weight = tff.federated_map(
        tff.tf_computation(lambda x, y: x * y),
        (client_weight, client_outputs.client_weight))

    # Aggregate deltas through the (possibly stateful) aggregation process.
    aggregation_output = aggregation_process.next(
        server_state.aggregation_state, client_outputs.weights_delta,
        participant_client_weight)

    server_state = tff.federated_map(
        server_update_fn, (server_state, aggregation_output.result))

    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    if aggregated_outputs.type_signature.is_struct():
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  return tff.templates.IterativeProcess(initialize_fn=initialize_computation,
                                        next_fn=run_one_round)
def test_build_with_preprocess_funtion(self):
  # NOTE(review): method name has a typo ("funtion"); left unchanged because
  # renaming it would change the test id picked up by the test runner.
  """Checks type signatures of a process built with a preprocessing comp."""
  test_dataset = tf.data.Dataset.range(5)
  client_datasets_type = tff.FederatedType(
      tff.SequenceType(test_dataset.element_spec), tff.CLIENTS)

  @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
  def preprocess_dataset(ds):
    """Maps raw range elements to batched (x, y) pairs for y = 3x + 1."""

    def to_batch(x):
      return collections.OrderedDict(x=[float(x) * 1.0],
                                     y=[float(x) * 3.0 + 1.0])

    return ds.map(to_batch).repeat().batch(2).take(3)

  # decay_factor=1.0 means the callbacks never actually decay the LR; only
  # the plumbing/type signatures are under test here.
  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1,
      min_delta=0.5,
      window_size=2,
      decay_factor=1.0,
      cooldown=0)
  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1,
      min_delta=0.5,
      window_size=2,
      decay_factor=1.0,
      cooldown=0)

  iterative_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_lr_callback,
      server_lr_callback,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      dataset_preprocess_comp=preprocess_dataset)

  lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)

  # Expected server state: a 1x1 linear model's weights plus optimizer state
  # and both LR callbacks, placed at the server.
  server_state_type = tff.FederatedType(
      adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
          trainable=(tff.TensorType(tf.float32, [1, 1]),
                     tff.TensorType(tf.float32, [1])),
          non_trainable=()),
                                   optimizer_state=[tf.int64],
                                   client_lr_callback=lr_callback_type,
                                   server_lr_callback=lr_callback_type),
      tff.SERVER)

  self.assertEqual(
      iterative_process.initialize.type_signature,
      tff.FunctionType(parameter=None, result=server_state_type))

  metrics_type = tff.FederatedType(
      collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER)
  output_type = collections.OrderedDict(before_training=metrics_type,
                                        during_training=metrics_type)

  expected_result_type = (server_state_type, output_type)
  expected_type = tff.FunctionType(parameter=collections.OrderedDict(
      server_state=server_state_type,
      federated_dataset=client_datasets_type),
                                   result=expected_result_type)

  actual_type = iterative_process.next.type_signature
  self.assertEqual(actual_type, expected_type,
                   msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
async def _encrypt_values_on_singleton(self, val, sender, receiver):
  """Encrypts a value held by a singleton `sender` for each `receiver` child.

  Produces one ciphertext per receiver child executor, each encrypted with
  that receiver's public key and the sender's secret key.

  Args:
    self: The executor/strategy object holding `key_references`, `strategy`,
      and the encryptor cache.
    val: A federated value placed at `sender` with exactly one underlying
      representation.
    sender: The placement holding the plaintext (must have cardinality 1).
    receiver: The placement whose children will receive ciphertexts.

  Returns:
    A `FederatedResolvingStrategyValue` wrapping a struct of encrypted
    values, one `FederatedType(..., sender, all_equal=False)` per receiver.
  """
  ###
  # we can safely assume sender has cardinality=1 when receiver is CLIENTS
  ###
  # Case 1: receiver=CLIENTS
  #     plaintext: Fed(Tensor, sender, all_equal=True)
  #     pk_receiver: Fed(Tuple(Tensor), sender, all_equal=True)
  #     sk_sender: Fed(Tensor, sender, all_equal=True)
  #   Returns:
  #     encrypted_values: Tuple(Fed(Tensor, sender, all_equal=True))
  ###

  ### Check proper key placement
  sk_sender = self.key_references.get_secret_key(sender)
  pk_receiver = self.key_references.get_public_key(receiver)
  type_analysis.check_federated_type(sk_sender.type_signature,
                                     placement=sender)
  assert sk_sender.type_signature.placement is sender
  # NOTE(review): the receiver's public keys are expected to already be
  # placed at the sender (i.e. distributed beforehand) — hence this check.
  assert pk_receiver.type_signature.placement is sender

  ### Check placement cardinalities
  rcv_children = self.strategy._get_child_executors(receiver)
  snd_children = self.strategy._get_child_executors(sender)
  py_typecheck.check_len(snd_children, 1)
  snd_child = snd_children[0]

  ### Check value cardinalities
  type_analysis.check_federated_type(val.type_signature, placement=sender)
  py_typecheck.check_len(val.internal_representation, 1)
  # One public key per receiver child executor.
  py_typecheck.check_type(pk_receiver.type_signature.member, tff.StructType)
  py_typecheck.check_len(pk_receiver.internal_representation,
                         len(rcv_children))
  py_typecheck.check_len(sk_sender.internal_representation, 1)

  ### Materialize encryptor function definition & type spec
  input_type = val.type_signature.member
  # Cache the plaintext type; presumably consumed later by the matching
  # decryption path — TODO confirm.
  self._input_type_cache = input_type
  pk_rcv_type = pk_receiver.type_signature.member
  sk_snd_type = sk_sender.type_signature.member
  pk_element_type = pk_rcv_type[0]
  encryptor_arg_spec = (input_type, pk_element_type, sk_snd_type)
  encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
      sodium_comp.make_encryptor, self._encryptor_cache, encryptor_arg_spec)

  ### Prepare encryption arguments
  v = val.internal_representation[0]
  sk = sk_sender.internal_representation[0]

  ### Encrypt values and return them
  encryptor_fn = await snd_child.create_value(encryptor_proto,
                                              encryptor_type)
  # One (value, receiver_pk, sender_sk) triple per receiver child.
  encryptor_args = await asyncio.gather(*[
      snd_child.create_struct([v, this_pk, sk])
      for this_pk in pk_receiver.internal_representation
  ])
  encrypted_values = await asyncio.gather(*[
      snd_child.create_call(encryptor_fn, arg) for arg in encryptor_args
  ])
  encrypted_value_types = [encryptor_type.result] * len(encrypted_values)
  return federated_resolving_strategy.FederatedResolvingStrategyValue(
      structure.from_container(encrypted_values),
      tff.StructType([
          tff.FederatedType(evt, sender, all_equal=False)
          for evt in encrypted_value_types
      ]))
def build_fed_avg_process(model_fn,
                          client_lr_callback,
                          client_callback_update_fn,
                          server_lr_callback,
                          server_callback_update_fn,
                          client_optimizer_fn=tf.keras.optimizers.SGD,
                          server_optimizer_fn=tf.keras.optimizers.SGD,
                          client_weight_fn=None):
  """Builds the TFF computations for FedAvg with learning rate decay.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_lr_callback: A `ReduceLROnPlateau` callback.
    client_callback_update_fn: A function that updates the client callback.
    server_lr_callback: A `ReduceLROnPlateau` callback.
    server_callback_update_fn: A function that updates the server callback.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    server_optimizer_fn: A function that accepts a `learning_rate` argument
      and returns a `tf.keras.optimizers.Optimizer` instance.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the
      weight in the federated average of model deltas. If not provided, the
      default is the total number of examples processed on device.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  dummy_model = model_fn()
  # Names of the metrics (e.g. 'loss') that drive the two LR callbacks.
  client_monitor = client_lr_callback.monitor
  server_monitor = server_lr_callback.monitor

  server_init_tf = build_server_init_fn(model_fn, server_optimizer_fn,
                                        client_lr_callback, server_lr_callback)

  # Derive all TFF types from the init computation so signatures stay in sync.
  server_state_type = server_init_tf.type_signature.result
  model_weights_type = server_state_type.model
  tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

  client_lr_type = server_state_type.client_lr_callback.learning_rate
  client_monitor_value_type = server_state_type.client_lr_callback.best
  server_monitor_value_type = server_state_type.server_lr_callback.best

  @tff.tf_computation(tf_dataset_type, model_weights_type, client_lr_type)
  def client_update_fn(tf_dataset, initial_model_weights, client_lr):
    """Local training; also evaluates the broadcast model before training."""
    client_optimizer = client_optimizer_fn(learning_rate=client_lr)
    # Metrics on the *initial* (pre-training) model feed the LR callbacks.
    initial_model_output = get_client_output(model_fn(), tf_dataset,
                                             initial_model_weights)
    client_state = client_update(model_fn(), tf_dataset, initial_model_weights,
                                 client_optimizer, client_weight_fn)
    return tff.utils.update_state(
        client_state, initial_model_output=initial_model_output)

  @tff.tf_computation(server_state_type, model_weights_type.trainable,
                      client_monitor_value_type, server_monitor_value_type)
  def server_update_fn(server_state, model_delta, client_monitor_value,
                       server_monitor_value):
    """Applies the aggregated delta and updates both LR callbacks."""
    model = model_fn()
    server_lr = server_state.server_lr_callback.learning_rate
    server_optimizer = server_optimizer_fn(learning_rate=server_lr)
    # We initialize the server optimizer variables to avoid creating them
    # within the scope of the tf.function server_update.
    _initialize_optimizer_vars(model, server_optimizer)
    return server_update(model, server_optimizer, server_state, model_delta,
                         client_monitor_value, client_callback_update_fn,
                         server_monitor_value, server_callback_update_fn)

  @tff.federated_computation(tff.FederatedType(server_state_type, tff.SERVER),
                             tff.FederatedType(tf_dataset_type, tff.CLIENTS))
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Note that in addition to updating the server weights according to the
    client model weight deltas, we extract metrics (governed by the `monitor`
    attribute of the `client_lr_callback` and `server_lr_callback` attributes
    of the `server_state`) and use these to update the client learning rate
    callbacks.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation` before and during
      local client training.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_lr = tff.federated_broadcast(
        server_state.client_lr_callback.learning_rate)
    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, client_model, client_lr))

    client_weight = client_outputs.client_weight
    aggregated_gradients = tff.federated_mean(
        client_outputs.accumulated_gradients, weight=client_weight)

    # Metrics on the pre-training model; these drive the LR plateau logic.
    initial_aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.initial_model_output)
    if isinstance(initial_aggregated_outputs.type_signature, tff.StructType):
      initial_aggregated_outputs = tff.federated_zip(
          initial_aggregated_outputs)

    # Metrics accumulated during local training.
    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    if isinstance(aggregated_outputs.type_signature, tff.StructType):
      aggregated_outputs = tff.federated_zip(aggregated_outputs)
    client_monitor_value = initial_aggregated_outputs[client_monitor]
    server_monitor_value = initial_aggregated_outputs[server_monitor]

    server_state = tff.federated_map(
        server_update_fn, (server_state, aggregated_gradients,
                           client_monitor_value, server_monitor_value))

    result = collections.OrderedDict(
        before_training=initial_aggregated_outputs,
        during_training=aggregated_outputs)

    return server_state, result

  @tff.federated_computation
  def initialize_fn():
    return tff.federated_value(server_init_tf(), tff.SERVER)

  return tff.templates.IterativeProcess(initialize_fn=initialize_fn,
                                        next_fn=run_one_round)
def __init__(
    self,
    model_fn,
    m,
    n,
    j_max,
    importance_sampling,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
):
  """Builds the TFF computations for optimization using federated averaging.

  After construction, `self.initialize` holds the server-initialization
  computation and `self.next` holds the one-round orchestration function.

  Args:
    model_fn: A no-arg function that returns a
      `simple_fedavg_tf.KerasModelWrapper`.
    m: A scalar used to scale client communication probabilities
      (`min(1, norm * m / sum_norms)` and the `m / n` uniform fallback) —
      presumably the expected number of communicating clients per round;
      TODO confirm.
    n: A scalar; `m / n` is the uniform communication probability when
      importance sampling is off — presumably the total number of clients;
      TODO confirm.
    j_max: Maximum number of iterations of the inner probability-rescaling
      loop when importance sampling is enabled.
    importance_sampling: Python bool; if True, clients communicate with
      probability proportional to their update norm, otherwise uniformly.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for server update.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for client update.
  """
  dummy_model = model_fn()

  @tff.tf_computation
  def server_init_tf():
    """Creates the initial `ServerState`."""
    model = model_fn()
    server_optimizer = server_optimizer_fn()
    _initialize_optimizer_vars(model, server_optimizer)
    return ServerState(model_weights=model.weights,
                       optimizer_state=server_optimizer.variables(),
                       round_num=0)

  server_state_type = server_init_tf.type_signature.result
  model_weights_type = server_state_type.model_weights

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_fn(server_state, model_delta):
    """Applies the aggregated delta to the server model."""
    model = model_fn()
    server_optimizer = server_optimizer_fn()
    _initialize_optimizer_vars(model, server_optimizer)
    return server_update(model, server_optimizer, server_state, model_delta)

  @tff.tf_computation(server_state_type)
  def server_message_fn(server_state):
    # Packs the part of the server state that gets broadcast to clients.
    return build_server_broadcast_message(server_state)

  server_message_type = server_message_fn.type_signature.result
  tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

  @tff.tf_computation(tf_dataset_type, server_message_type)
  def client_update_fn(tf_dataset, server_message):
    """Runs local training on one client."""
    model = model_fn()
    client_optimizer = client_optimizer_fn()
    return client_update(model, tf_dataset, server_message, client_optimizer)

  federated_server_state_type = tff.FederatedType(
      server_state_type, tff.SERVER)
  federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

  @tff.tf_computation(
      tf.float32,
      tf.float32,
  )
  def scale(update_norm, sum_update_norms):
    """Initial communication probability for one client."""
    if importance_sampling:
      # Probability proportional to the client's update norm, capped at 1.
      return tf.minimum(
          1., tf.divide(tf.multiply(update_norm, m), sum_update_norms))
    else:
      # Uniform probability when importance sampling is disabled.
      return tf.divide(m, n)

  @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                             tff.FederatedType(tf.float32, tff.CLIENTS, True))
  def scale_on_clients(update_norm, sum_update_norms):
    return tff.federated_map(scale, (update_norm, sum_update_norms))

  @tff.tf_computation(tf.float32)
  def create_prob_message(prob):
    """Emits [prob, 1] for clients with prob < 1, else [0, 0].

    Summing these over clients yields (sum of non-saturated probs,
    count of non-saturated clients), which drives the rescaling below.
    """

    def f1():
      return tf.stack([prob, 1.])

    def f2():
      return tf.constant([0., 0.])

    prob_message = tf.cond(tf.less(prob, 1), f1, f2)
    return prob_message

  @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
  def create_prob_message_on_clients(prob):
    return tff.federated_map(create_prob_message, prob)

  @tff.tf_computation(tff.TensorType(tf.float32, (2, )))
  def compute_rescaling(prob_aggreg):
    # prob_aggreg = [sum of non-saturated probs, count of non-saturated
    # clients]; rescale so expected communications stay at m.
    rescaling_factor = (m - n + prob_aggreg[1]) / prob_aggreg[0]
    return rescaling_factor

  @tff.federated_computation(
      tff.FederatedType(tff.TensorType(tf.float32, (2, )), tff.SERVER))
  def compute_rescaling_on_master(prob_aggreg):
    return tff.federated_map(compute_rescaling, prob_aggreg)

  @tff.tf_computation(tf.float32, tf.float32)
  def rescale_prob(prob, rescaling_factor):
    # Rescaled probability, still capped at 1.
    return tf.minimum(1., tf.multiply(prob, rescaling_factor))

  @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                             tff.FederatedType(tf.float32, tff.CLIENTS, True))
  def rescale_prob_on_clients(rob, rescaling_factor):
    # NOTE(review): parameter is named `rob` (likely a typo of `prob`);
    # harmless since it is purely local, so left unchanged.
    return tff.federated_map(rescale_prob, (rob, rescaling_factor))

  @tff.tf_computation(tf.float32)
  def compute_weights_is_fn(prob):
    """Bernoulli(prob) participation; importance weight 1/prob if selected."""

    def f1():
      return 1. / prob

    def f2():
      return 0.

    weight = tf.cond(tf.less(tf.random.uniform(()), prob), f1, f2)
    return weight

  @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
  def compute_weights_is(prob):
    return tff.federated_map(compute_weights_is_fn, prob)

  @tff.federated_computation(
      tff.FederatedType(model_weights_type.trainable, tff.CLIENTS),
      tff.FederatedType(tf.float32, tff.CLIENTS))
  def compute_round_model_delta(weights_delta, weights_denom):
    return tff.federated_mean(weights_delta, weight=weights_denom)

  @tff.federated_computation(federated_server_state_type,
                             tff.FederatedType(model_weights_type.trainable,
                                               tff.SERVER))
  def update_server_state(server_state, round_model_delta):
    return tff.federated_map(server_update_fn,
                             (server_state, round_model_delta))

  @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                             tff.FederatedType(tf.float32, tff.CLIENTS))
  def compute_loss_metric(model_output, weight_denom):
    return tff.federated_mean(model_output, weight=weight_denom)

  @tff.tf_computation(model_weights_type.trainable, tf.float32)
  def rescale_and_remove_fn(weights_delta, weights_is):
    # Multiplying by the importance weight (0 for non-participants) both
    # rescales and zeroes-out dropped clients in one pass.
    return [
        tf.math.scalar_mul(weights_is, weights_layer_delta)
        for weights_layer_delta in weights_delta
    ]

  @tff.federated_computation(
      tff.FederatedType(model_weights_type.trainable, tff.CLIENTS),
      tff.FederatedType(tf.float32, tff.CLIENTS))
  def rescale_and_remove(weights_delta, weights_is):
    return tff.federated_map(rescale_and_remove_fn,
                             (weights_delta, weights_is))

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_gradient_computation_round(server_state, federated_dataset):
    """Orchestration logic for one round of gradient computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `tf.Tensor` of clients initial probability and
      `ClientOutput`.
    """
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    # Sum of weighted update norms, broadcast back so each client can compute
    # its own initial communication probability.
    update_norm_sum_weighted = tff.federated_sum(
        client_outputs.update_norm_weighted)
    norm_sum_clients_weighted = tff.federated_broadcast(
        update_norm_sum_weighted)
    prob_init = scale_on_clients(client_outputs.update_norm_weighted,
                                 norm_sum_clients_weighted)

    return prob_init, client_outputs

  @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
  def run_one_inner_loop_weights_computation(prob):
    """Orchestration logic for one round of computation.

    Args:
      prob: Probability of each client to communicate update.

    Returns:
      A tuple of updated `Probabilities` and `tf.float32` of rescaling
      factor.
    """
    prob_message = create_prob_message_on_clients(prob)
    prob_aggreg = tff.federated_sum(prob_message)
    rescaling_factor_master = compute_rescaling_on_master(prob_aggreg)
    rescaling_factor_clients = tff.federated_broadcast(
        rescaling_factor_master)
    prob = rescale_prob_on_clients(prob, rescaling_factor_clients)
    return prob, rescaling_factor_master

  @tff.federated_computation
  def server_init_tff():
    """Orchestration logic for server model initialization."""
    return tff.federated_value(server_init_tf(), tff.SERVER)

  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    prob, client_outputs = run_gradient_computation_round(
        server_state, federated_dataset)

    if importance_sampling:
      # Iterate probability rescaling until the factor stops exceeding 1
      # (probabilities saturated), or j_max iterations at most.
      for j in range(j_max):
        prob, rescaling_factor = run_one_inner_loop_weights_computation(
            prob)
        if rescaling_factor <= 1:
          break

    weight_denom = [
        client_output.client_weight for client_output in client_outputs
    ]
    weights_delta = [
        client_output.weights_delta for client_output in client_outputs
    ]

    # rescale weights based on sampling procedure
    weights_is = compute_weights_is(prob)
    weights_delta = rescale_and_remove(weights_delta, weights_is)
    round_model_delta = compute_round_model_delta(weights_delta,
                                                 weight_denom)
    server_state = update_server_state(server_state, round_model_delta)
    model_output = [
        client_output.model_output for client_output in client_outputs
    ]
    round_loss_metric = compute_loss_metric(model_output, weight_denom)

    # Materialize per-client probabilities as Python floats for the caller.
    prob_numpy = []
    for p in prob:
      prob_numpy.append(p.numpy())

    return server_state, round_loss_metric, prob_numpy

  self.next = run_one_round
  self.initialize = server_init_tff
def build_gan_training_process(gan: GanFnsAndTypes):
  """Constructs a `tff.Computation` for GAN training.

  Args:
    gan: A `GanFnsAndTypes` object.

  Returns:
    A `tff.utils.IterativeProcess` for GAN training.
  """
  # Generally, it is easiest to get the types correct by building
  # all of the needed tf_computations first, since this ensures we only
  # have non-federated types.
  server_initial_state = build_server_initial_state_comp(gan)
  server_state_type = server_initial_state.type_signature.result
  client_computation = build_client_computation(gan)
  client_output_type = client_computation.type_signature.result
  server_computation = build_server_computation(gan, server_state_type,
                                                client_output_type)

  @tff.federated_computation
  def fed_server_initial_state():
    # The `initialize` computation: server state evaluated once at the server.
    return tff.federated_value(server_initial_state(), tff.SERVER)

  @tff.federated_computation(tff.FederatedType(server_state_type, tff.SERVER),
                             gan.server_gen_input_type,
                             gan.client_gen_input_type,
                             gan.client_real_data_type)
  def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                    client_real_data):
    """The `tff.Computation` to be returned."""
    # TODO(b/131429028): The federated_zip should be automatic.
    from_server = tff.federated_zip(
        gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights))
    client_input = tff.federated_broadcast(from_server)
    client_outputs = tff.federated_map(
        client_computation,
        (client_gen_inputs, client_real_data, client_input))

    if gan.dp_averaging_fn is None:
      # Not using differential privacy.
      new_dp_averaging_state = server_state.dp_averaging_state
      averaged_discriminator_weights_delta = tff.federated_mean(
          client_outputs.discriminator_weights_delta,
          weight=client_outputs.update_weight)
    else:
      # Using differential privacy. Note that the weight argument is set to
      # None here. This is because the DP aggregation code explicitly does
      # not do weighted aggregation. (If weighted aggregation is desired,
      # differential privacy needs to be turned off.)
      new_dp_averaging_state, averaged_discriminator_weights_delta = (
          gan.dp_averaging_fn(server_state.dp_averaging_state,
                              client_outputs.discriminator_weights_delta,
                              weight=None))

    # TODO(b/131085687): Perhaps reconsider the choice to also use
    # ClientOutput to hold the aggregated client output.
    aggregated_client_output = gan_training_tf_fns.ClientOutput(
        discriminator_weights_delta=averaged_discriminator_weights_delta,
        # We don't actually need the aggregated update_weight, but
        # this keeps the types of the non-aggregated and aggregated
        # client_output the same, which is convenient. And I can
        # imagine wanting this.
        update_weight=tff.federated_sum(client_outputs.update_weight),
        counters=tff.federated_sum(client_outputs.counters))

    # TODO(b/131839522): This federated_zip shouldn't be needed.
    aggregated_client_output = tff.federated_zip(aggregated_client_output)

    server_state = tff.federated_map(
        server_computation,
        (server_state, server_gen_inputs, aggregated_client_output,
         new_dp_averaging_state))
    return server_state

  return tff.utils.IterativeProcess(fed_server_initial_state, run_one_round)
def build_fed_avg_process(
    model_fn: ModelBuilder,
    client_lr_callback: callbacks.ReduceLROnPlateau,
    server_lr_callback: callbacks.ReduceLROnPlateau,
    client_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    client_weight_fn: Optional[ClientWeightFn] = None,
    dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> tff.templates.IterativeProcess:
  """Builds the TFF computations for FedAvg with learning rate decay.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_lr_callback: A `ReduceLROnPlateau` callback.
    server_lr_callback: A `ReduceLROnPlateau` callback.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    server_optimizer_fn: A function that accepts a `learning_rate` argument and
      returns a `tf.keras.optimizers.Optimizer` instance.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the
      weight in the federated average of model deltas. If not provided, the
      default is the total number of examples processed on device.
    dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
      pipeline on the clients. The computation must take a sequence of values
      and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
      If `None`, no dataset preprocessing is applied.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  dummy_model = model_fn()
  # Metric names (e.g. a loss key) that drive the LR-decay callbacks.
  client_monitor = client_lr_callback.monitor
  server_monitor = server_lr_callback.monitor

  server_init_tf = build_server_init_fn(model_fn, server_optimizer_fn,
                                        client_lr_callback, server_lr_callback)
  server_state_type = server_init_tf.type_signature.result
  model_weights_type = server_state_type.model

  if dataset_preprocess_comp is not None:
    # When a preprocessing computation is supplied, clients receive its
    # parameter type and the model consumes its result type; verify the two
    # are compatible up front rather than failing inside a traced computation.
    tf_dataset_type = dataset_preprocess_comp.type_signature.parameter
    model_input_type = tff.SequenceType(dummy_model.input_spec)
    preprocessed_dataset_type = dataset_preprocess_comp.type_signature.result
    if not model_input_type.is_assignable_from(preprocessed_dataset_type):
      raise TypeError(
          'Supplied `dataset_preprocess_comp` does not yield '
          'batches that are compatible with the model constructed '
          'by `model_fn`. Model expects type {m}, but dataset '
          'yields type {d}.'.format(m=model_input_type,
                                    d=preprocessed_dataset_type))
  else:
    tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
    model_input_type = tff.SequenceType(dummy_model.input_spec)

  # Per-field types pulled out of the server state for use as tf_computation
  # parameter types below.
  client_lr_type = server_state_type.client_lr_callback.learning_rate
  client_monitor_value_type = server_state_type.client_lr_callback.best
  server_monitor_value_type = server_state_type.server_lr_callback.best

  @tff.tf_computation(model_input_type, model_weights_type, client_lr_type)
  def client_update_fn(tf_dataset, initial_model_weights, client_lr):
    """Runs local training on one client and records pre-training metrics."""
    client_optimizer = client_optimizer_fn(client_lr)
    # Metrics of the broadcast model BEFORE local training, used later to
    # drive the LR callbacks ("before_training" outputs).
    initial_model_output = get_client_output(model_fn(), tf_dataset,
                                             initial_model_weights)
    client_state = client_update(model_fn(), tf_dataset, initial_model_weights,
                                 client_optimizer, client_weight_fn)
    return tff.utils.update_state(
        client_state, initial_model_output=initial_model_output)

  @tff.tf_computation(server_state_type, model_weights_type.trainable,
                      client_monitor_value_type, server_monitor_value_type)
  def server_update_fn(server_state, model_delta, client_monitor_value,
                       server_monitor_value):
    """Applies the aggregated delta and updates both LR callbacks."""
    model = model_fn()
    server_lr = server_state.server_lr_callback.learning_rate
    server_optimizer = server_optimizer_fn(server_lr)
    # We initialize the server optimizer variables to avoid creating them
    # within the scope of the tf.function server_update.
    _initialize_optimizer_vars(model, server_optimizer)
    return server_update(model, server_optimizer, server_state, model_delta,
                         client_monitor_value, server_monitor_value)

  @tff.federated_computation(
      tff.FederatedType(server_state_type, tff.SERVER),
      tff.FederatedType(tf_dataset_type, tff.CLIENTS))
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Note that in addition to updating the server weights according to the
    client model weight deltas, we extract metrics (governed by the `monitor`
    attribute of the `client_lr_callback` and `server_lr_callback` attributes
    of the `server_state`) and use these to update the client learning rate
    callbacks.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation` before and during
      local client training.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_lr = tff.federated_broadcast(
        server_state.client_lr_callback.learning_rate)

    if dataset_preprocess_comp is not None:
      federated_dataset = tff.federated_map(dataset_preprocess_comp,
                                            federated_dataset)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, client_model, client_lr))

    client_weight = client_outputs.client_weight
    aggregated_gradients = tff.federated_mean(
        client_outputs.accumulated_gradients, weight=client_weight)

    # Aggregate metrics from BEFORE local training; these drive the
    # LR-decay decision on both client and server callbacks.
    initial_aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.initial_model_output)
    if isinstance(initial_aggregated_outputs.type_signature, tff.StructType):
      initial_aggregated_outputs = tff.federated_zip(initial_aggregated_outputs)

    # Aggregate metrics from DURING local training (reporting only).
    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    if isinstance(aggregated_outputs.type_signature, tff.StructType):
      aggregated_outputs = tff.federated_zip(aggregated_outputs)
    client_monitor_value = initial_aggregated_outputs[client_monitor]
    server_monitor_value = initial_aggregated_outputs[server_monitor]

    server_state = tff.federated_map(
        server_update_fn,
        (server_state, aggregated_gradients, client_monitor_value,
         server_monitor_value))

    result = collections.OrderedDict(
        before_training=initial_aggregated_outputs,
        during_training=aggregated_outputs)

    return server_state, result

  @tff.federated_computation
  def initialize_fn():
    return tff.federated_value(server_init_tf(), tff.SERVER)

  return tff.templates.IterativeProcess(initialize_fn=initialize_fn,
                                        next_fn=run_one_round)
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from federated_aggregations import paillier

NUM_CLIENTS = 5

# Install an execution context whose executor stack performs Paillier-based
# secure aggregation for the configured number of clients.
paillier_factory = paillier.local_paillier_executor_factory(NUM_CLIENTS)
tff.framework.set_default_context(
    tff.framework.ExecutionContext(paillier_factory))


@tff.federated_computation(
    tff.FederatedType(tff.TensorType(tf.int32, [2]), tff.CLIENTS),
    tff.TensorType(tf.int32))
def secure_paillier_addition(x, bitwidth):
  """Securely sums client-placed int32 vectors using `bitwidth`-bit values."""
  return tff.federated_secure_sum(x, bitwidth)


# One length-2 vector per client: [1, 2] + i for client i.
base_vector = np.array([1, 2], np.int32)
client_values = [base_vector + i for i in range(NUM_CLIENTS)]
summed = secure_paillier_addition(client_values, 32)
print(summed)
    # NOTE(review): tail of a function whose definition begins above this
    # chunk; indentation of this fragment reconstructed — confirm against the
    # full file.
    else:
      # No leaked words at this threshold: all three rates are defined as 0.
      false_positive_rate[threshold] = 0.0
      false_discovery_rate[threshold] = 0.0
      harmonic_mean_fpr_fdr[threshold] = 0.0
    # The leaked_words in the next round must be a subset of this round.
    leaked_words_candidates = leaked_words
    bisect_upper_bound = below_threshold_index
  return false_positive_rate, false_discovery_rate, harmonic_mean_fpr_fdr


@tff.tf_computation(tff.SequenceType(tf.string))
def compute_lossless_result_per_user(dataset):
  """Returns every word a client holds (no contribution cap)."""
  # Do not have limit on each client's contribution in this case.
  k_words = get_top_elements(dataset, tf.constant(tf.int32.max))
  return k_words


@tff.federated_computation(
    tff.FederatedType(tff.SequenceType(tf.string), tff.CLIENTS))
def compute_lossless_results_federated(datasets):
  """Maps the per-user lossless extraction over all client datasets."""
  words = tff.federated_map(compute_lossless_result_per_user, datasets)
  return words


def compute_lossless_results(datasets):
  """Computes exact word counts across all client datasets.

  Args:
    datasets: A list of client datasets, one `tf.data.Dataset` of strings per
      client (fed to the federated computation above).

  Returns:
    A dict mapping each word (bytes) to its total count across clients.
  """
  # Concatenate the per-client word tensors, then count duplicates.
  all_words = tf.concat(compute_lossless_results_federated(datasets), axis=0)
  word, _, count = tf.unique_with_counts(all_words)
  return dict(zip(word.numpy(), count.numpy()))
    # NOTE(review): tail of a local-training computation whose definition
    # begins above this chunk; indentation reconstructed — confirm against
    # the full file.
    return batch_train(model, batch, learning_rate)

  # Fold `batch_fn` over the client's batches, starting from initial_model.
  l = tff.sequence_reduce(all_batches, initial_model, batch_fn)
  return l


@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def local_eval(model, all_batches):
  # Total loss of `model` over every batch in the client's local dataset.
  return tff.sequence_sum(
      tff.sequence_map(
          tff.federated_computation(lambda b: batch_loss(model, b),
                                    BATCH_TYPE), all_batches))


# Federated placements: a single model at the server, datasets at the clients.
SERVER_MODEL_TYPE = tff.FederatedType(MODEL_TYPE, tff.SERVER, all_equal=True)
CLIENT_DATA_TYPE = tff.FederatedType(LOCAL_DATA_TYPE, tff.CLIENTS)


@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def federated_eval(model, data):
  # Broadcast the server model, evaluate locally on each client, and average
  # the per-client losses (unweighted mean).
  return tff.federated_mean(
      tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))


SERVER_FLOAT_TYPE = tff.FederatedType(tf.float32, tff.SERVER, all_equal=True)


# NOTE(review): body of `federated_train` continues past this chunk.
@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,
                           CLIENT_DATA_TYPE)
def federated_train(model, learning_rate, data):
def build_fed_avg_process(
    total_clients: int,
    effective_num_clients: int,
    model_fn: ModelBuilder,
    client_optimizer_fn: OptimizerBuilder,
    client_lr: Union[float, LRScheduleFn] = 0.1,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_lr: Union[float, LRScheduleFn] = 1.0,
    client_weight_fn: Optional[ClientWeightFn] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
) -> tff.templates.IterativeProcess:
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    total_clients: The total number of clients, used to size the per-client
      loss/weight tensors kept at the server.
    effective_num_clients: The number of clients whose updates are actually
      aggregated each round (the rest are zeroed out by loss-based selection).
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    client_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    server_optimizer_fn: A function that accepts a `learning_rate` argument and
      returns a `tf.keras.optimizers.Optimizer` instance.
    server_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the
      weight in the federated average of model deltas. If not provided, the
      default is the total number of examples processed on device.
    aggregation_process: Optional `MeasuredProcess` used to aggregate client
      model deltas; a stateless weighted mean is built if `None`.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # Normalize scalar learning rates into round_num -> lr schedules.
  client_lr_schedule = client_lr
  if not callable(client_lr_schedule):
    client_lr_schedule = lambda round_num: client_lr

  server_lr_schedule = server_lr
  if not callable(server_lr_schedule):
    server_lr_schedule = lambda round_num: server_lr

  # Build throwaway model/optimizer in a fresh graph purely to capture the
  # TFF types of the model weights and optimizer variables.
  with tf.Graph().as_default():
    dummy_model = model_fn()
    model_weights_type = model_utils.weights_type_from_model(dummy_model)
    dummy_optimizer = server_optimizer_fn()
    _initialize_optimizer_vars(dummy_model, dummy_optimizer)
    optimizer_variable_type = type_conversions.type_from_tensors(
        dummy_optimizer.variables())

  if aggregation_process is None:
    aggregation_process = build_stateless_mean(
        model_delta_type=model_weights_type.trainable)
  if not _is_valid_aggregation_process(aggregation_process):
    raise ProcessTypeError(
        'aggregation_process type signature does not conform to expected '
        'signature (<state@S, input@C> -> <state@S, result@S, measurements@S>).'
        ' Got: {t}'.format(t=aggregation_process.next.type_signature))

  initialize_computation = build_server_init_fn(
      model_fn=model_fn,
      effective_num_clients=effective_num_clients,
      # Initialize with the learning rate for round zero.
      server_optimizer_fn=lambda: server_optimizer_fn(server_lr_schedule(0)),
      aggregation_process=aggregation_process)

  # server_state_type = initialize_computation.type_signature.result
  # model_weights_type = server_state_type.model

  round_num_type = tf.float32
  tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
  model_input_type = tff.SequenceType(dummy_model.input_spec)
  # Dense [total_clients, 1] tensors indexed by client id, held at the server.
  client_losses_at_server_type = tff.TensorType(
      dtype=tf.float32, shape=[total_clients, 1])
  clients_weights_at_server_type = tff.TensorType(
      dtype=tf.float32, shape=[total_clients, 1])
  aggregation_state = (
      aggregation_process.initialize.type_signature.result.member)

  server_state_type = ServerState(
      model=model_weights_type,
      optimizer_state=optimizer_variable_type,
      round_num=round_num_type,
      effective_num_clients=tf.int32,
      delta_aggregate_state=aggregation_state,
  )

  # @computations.tf_computation(clients_weights_type)
  # def get_zero_weights_all_clients(weights):
  #   return tf.zeros_like(weights, dtype=tf.float32)

  ######################################################
  # def federated_output(local_outputs):
  #   return federated_aggregate_keras_metric(self.get_metrics(), local_outputs)
  # federated_output_computation = computations.federated_computation(
  #     federated_output, federated_local_outputs_type)

  single_id_type = tff.TensorType(dtype=tf.int32, shape=[1, 1])

  @tff.tf_computation(model_input_type, model_weights_type, round_num_type,
                      single_id_type)
  def client_update_fn(tf_dataset, initial_model_weights, round_num, client_id):
    """Runs local training for one client at its scheduled learning rate."""
    client_lr = client_lr_schedule(round_num)
    client_optimizer = client_optimizer_fn(client_lr)
    client_update = create_client_update_fn()
    return client_update(model_fn(), tf_dataset, initial_model_weights,
                         client_optimizer, client_id, client_weight_fn)

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_fn(server_state, model_delta):
    """Applies the aggregated model delta on the server."""
    model = model_fn()
    server_lr = server_lr_schedule(server_state.round_num)
    server_optimizer = server_optimizer_fn(server_lr)
    # We initialize the server optimizer variables to avoid creating them
    # within the scope of the tf.function server_update.
    _initialize_optimizer_vars(model, server_optimizer)
    return server_update(model, server_optimizer, server_state, model_delta)

  id_type = tff.TensorType(shape=[1, 1], dtype=tf.int32)

  @tff.tf_computation(clients_weights_at_server_type, id_type)
  def select_weight_fn(clients_weights, local_id):
    # Picks this client's entry out of the broadcast weight table.
    return select_weight(clients_weights, local_id)

  @tff.tf_computation(client_losses_at_server_type,
                      clients_weights_at_server_type, tf.int32)
  def zero_small_loss_clients(losses_at_server, weights_at_server,
                              effective_num_clients):
    """Keeps the `effective_num_clients` highest-loss clients' weights.

    Args:
      losses_at_server: A `[total_clients, 1]` float tensor of per-client
        losses, indexed by client id.
      weights_at_server: A `[total_clients, 1]` float tensor of per-client
        aggregation weights, indexed by client id.
      effective_num_clients: The number of clients to keep; the others'
        weights are zeroed out.

    Returns:
      The redefined `[total_clients, 1]` weight tensor.
    """
    return redefine_client_weight(losses_at_server, weights_at_server,
                                  effective_num_clients)

  # @tff.tf_computation(client_losses_type)
  # def dataset_to_tensor_fn(dataset):
  #   return dataset_to_tensor(dataset)

  @tff.federated_computation(
      tff.FederatedType(server_state_type, tff.SERVER),
      tff.FederatedType(tf_dataset_type, tff.CLIENTS),
      tff.FederatedType(id_type, tff.CLIENTS))
  def run_one_round(server_state, federated_dataset, ids):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
      ids: A federated `[1, 1]` int32 client id with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, client_model, client_round_num, ids))

    client_weight = client_outputs.client_weight
    client_id = client_outputs.client_id

    #LOSS SELECTION:
    # losses_at_server = tff.federated_collect(client_outputs.model_output)
    # weights_at_server = tff.federated_collect(client_weight)

    @computations.tf_computation
    def zeros_fn():
      return tf.zeros(shape=[total_clients, 1], dtype=tf.float32)

    zero = zeros_fn()
    at_server_type = tff.TensorType(shape=[total_clients, 1], dtype=tf.float32)
    # list_type = tff.SequenceType( tff.TensorType(dtype=tf.float32))
    client_output_type = client_update_fn.type_signature.result

    # Reduction op: scatter one client's weight into the dense per-client
    # table at that client's id.
    @computations.tf_computation(at_server_type, client_output_type)
    def accumulate_weight(u, t):
      value = t.client_weight
      index = t.client_id
      new_u = tf.tensor_scatter_nd_update(u, index, value)
      return new_u

    # Reduction op: scatter one client's summed loss into the dense table.
    @computations.tf_computation(at_server_type, client_output_type)
    def accumulate_loss(u, t):
      value = tf.reshape(tf.math.reduce_sum(t.model_output['loss']),
                         shape=[1, 1])
      index = t.client_id
      new_u = tf.tensor_scatter_nd_update(u, index, value)
      return new_u

    # output_at_server= tff.federated_collect(client_outputs)
    weights_at_server = tff.federated_reduce(client_outputs, zero,
                                             accumulate_weight)
    losses_at_server = tff.federated_reduce(client_outputs, zero,
                                            accumulate_loss)
    #losses_at_server = tff.federated_aggregate(client_outputs.model_output, zero, accumulate, merge, report)

    # Server-side selection: zero out weights of clients that don't make the
    # effective_num_clients cut.
    selected_clients_weights = tff.federated_map(
        zero_small_loss_clients,
        (losses_at_server, weights_at_server,
         server_state.effective_num_clients))

    # selected_clients_weights_at_client = tff.federated_broadcast(selected_clients_weights)
    # Broadcast the full weight table and let each client pick its own entry.
    selected_clients_weights_broadcast = tff.federated_broadcast(
        selected_clients_weights)
    selected_clients_weights_at_client = tff.federated_map(
        select_weight_fn, (selected_clients_weights_broadcast, ids))

    aggregation_output = aggregation_process.next(
        server_state.delta_aggregate_state, client_outputs.weights_delta,
        selected_clients_weights_at_client)

    # model_delta = tff.federated_mean(
    #     client_outputs.weights_delta, weight=client_weight)

    server_state = tff.federated_map(
        server_update_fn, (server_state, aggregation_output.result))

    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    if aggregated_outputs.type_signature.is_struct():
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  # @tff.federated_computation
  # def initialize_fn():
  #   return tff.federated_value(server_init_tf(), tff.SERVER)

  return tff.templates.IterativeProcess(initialize_fn=initialize_computation,
                                        next_fn=run_one_round)
def build_triehh_process(possible_prefix_extensions: List[str],
                         num_sub_rounds: int,
                         max_num_heavy_hitters: int,
                         max_user_contribution: int,
                         default_terminator: str = '$'):
  """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has discovered so far
  and the list of `possible_prefix_extensions` to a small fraction of selected
  clients. The select clients sample `max_user_contributions` words from their
  local datasets, and use them to vote on character extensions to the
  broadcasted popular prefixes. Client votes are accumulated across
  `num_sub_rounds` rounds, and then the top `max_num_heavy_hitters` extensions
  are used to extend the already discovered prefixes, and the extended
  prefixes are used in the next round. When an already discovered prefix is
  extended by `default_terminator` it is added to the list of discovered
  heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions
      to learned prefixes. Each extension must be a single character string.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_heavy_hitters: The maximum number of discoverable heavy hitters.
      Must be positive.
    max_user_contribution: The maximum number of examples a user can
      contribute. Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.utils.IterativeProcess`.
  """

  @tff.tf_computation
  def server_init_tf():
    # Initial state: no heavy hitters yet, a single empty prefix to extend,
    # and a zeroed vote accumulator.
    return ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        possible_prefix_extensions=tf.constant(
            possible_prefix_extensions, dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            dtype=tf.int32,
            shape=[max_num_heavy_hitters, len(possible_prefix_extensions)]))

  # We cannot use server_init_tf.type_signature.result because the
  # discovered_* fields need to have [None] shapes, since they will grow over
  # time.
  server_state_type = (
      tff.to_type(
          ServerState(
              discovered_heavy_hitters=tff.TensorType(
                  dtype=tf.string, shape=[None]),
              discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
              possible_prefix_extensions=tff.TensorType(
                  dtype=tf.string, shape=[len(possible_prefix_extensions)]),
              round_num=tff.TensorType(dtype=tf.int32, shape=[]),
              accumulated_votes=tff.TensorType(
                  dtype=tf.int32,
                  shape=[None, len(possible_prefix_extensions)]),
          )))

  sub_round_votes_type = tff.TensorType(
      dtype=tf.int32,
      shape=[max_num_heavy_hitters, len(possible_prefix_extensions)])

  @tff.tf_computation(server_state_type, sub_round_votes_type)
  @tf.function
  def server_update_fn(server_state, sub_round_votes):
    """Accumulates votes and (at sub-round boundaries) extends prefixes."""
    server_state = server_update(
        server_state,
        sub_round_votes,
        num_sub_rounds=tf.constant(num_sub_rounds),
        max_num_heavy_hitters=tf.constant(max_num_heavy_hitters),
        default_terminator=tf.constant(default_terminator, dtype=tf.string))
    return server_state

  tf_dataset_type = tff.SequenceType(tf.string)
  discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
  round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(tf_dataset_type, discovered_prefixes_type,
                      round_num_type)
  @tf.function
  def client_update_fn(tf_dataset, discovered_prefixes, round_num):
    """Computes one client's votes on prefix extensions."""
    result = client_update(tf_dataset, discovered_prefixes,
                           tf.constant(possible_prefix_extensions), round_num,
                           num_sub_rounds, max_num_heavy_hitters,
                           max_user_contribution)
    return result

  federated_server_state_type = tff.FederatedType(server_state_type,
                                                  tff.SERVER)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      An updated `ServerState`
    """
    discovered_prefixes = tff.federated_broadcast(
        server_state.discovered_prefixes)
    round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        tff.federated_zip([federated_dataset, discovered_prefixes, round_num]))

    # Sum per-client votes into a single vote tensor at the server.
    accumulated_votes = tff.federated_sum(client_outputs.client_votes)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, accumulated_votes))

    # Empty server output; heavy hitters are read from the server state.
    server_output = tff.federated_value([], tff.SERVER)

    return server_state, server_output

  return tff.utils.IterativeProcess(
      initialize_fn=tff.federated_computation(
          lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
      next_fn=run_one_round)
def build_triehh_process(
    possible_prefix_extensions: List[str],
    num_sub_rounds: int,
    max_num_prefixes: int,
    threshold: int,
    max_user_contribution: int,
    default_terminator: str = triehh_tf.DEFAULT_TERMINATOR):
  """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has discovered so far
  and the list of `possible_prefix_extensions` to a small fraction of selected
  clients. The select clients sample `max_user_contributions` words from their
  local datasets, and use them to vote on character extensions to the
  broadcasted popular prefixes. Client votes are accumulated across
  `num_sub_rounds` rounds, and then the top `max_num_prefixes` extensions that
  get at least `threshold` votes are used to extend the already discovered
  prefixes, and the extended prefixes are used in the next round. When an
  already discovered prefix is extended by `default_terminator` it is added
  to the list of discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions
      to learned prefixes. Each extension must be a single character string.
      This list should not contain the default_terminator.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_prefixes: The maximum number of prefixes we can keep in the trie.
      Must be positive.
    threshold: The threshold for heavy hitters and discovered prefixes. Only
      those that get at least `threshold` votes are discovered. Must be
      positive.
    max_user_contribution: The maximum number of examples a user can
      contribute. Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ValueError: If possible_prefix_extensions contains default_terminator.
  """
  if default_terminator in possible_prefix_extensions:
    raise ValueError(
        'default_terminator should not appear in possible_prefix_extensions')

  # Append `default_terminator` to `possible_prefix_extensions` to make sure it
  # is the last item in the list.
  possible_prefix_extensions.append(default_terminator)

  @tff.tf_computation
  def server_init_tf():
    # Initial state: no heavy hitters yet, a single empty prefix to extend,
    # and zeroed vote/weight accumulators.
    return ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        heavy_hitter_frequencies=tf.constant([], dtype=tf.float64),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            dtype=tf.int32,
            shape=[max_num_prefixes, len(possible_prefix_extensions)]),
        accumulated_weights=tf.constant(0, dtype=tf.int32))

  # We cannot use server_init_tf.type_signature.result because the
  # discovered_* fields need to have [None] shapes, since they will grow over
  # time.
  server_state_type = (
      tff.to_type(
          ServerState(
              discovered_heavy_hitters=tff.TensorType(
                  dtype=tf.string, shape=[None]),
              heavy_hitter_frequencies=tff.TensorType(
                  dtype=tf.float64, shape=[None]),
              discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
              round_num=tff.TensorType(dtype=tf.int32, shape=[]),
              accumulated_votes=tff.TensorType(
                  dtype=tf.int32,
                  shape=[None, len(possible_prefix_extensions)]),
              accumulated_weights=tff.TensorType(dtype=tf.int32, shape=[]),
          )))

  sub_round_votes_type = tff.TensorType(
      dtype=tf.int32,
      shape=[max_num_prefixes, len(possible_prefix_extensions)])
  sub_round_weight_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(server_state_type, sub_round_votes_type,
                      sub_round_weight_type)
  def server_update_fn(server_state, sub_round_votes, sub_round_weight):
    """Accumulates votes/weights and extends prefixes at sub-round ends."""
    return server_update(
        server_state,
        tf.constant(possible_prefix_extensions),
        sub_round_votes,
        sub_round_weight,
        num_sub_rounds=tf.constant(num_sub_rounds),
        max_num_prefixes=tf.constant(max_num_prefixes),
        threshold=tf.constant(threshold))

  tf_dataset_type = tff.SequenceType(tf.string)
  discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
  round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(tf_dataset_type, discovered_prefixes_type,
                      round_num_type)
  def client_update_fn(tf_dataset, discovered_prefixes, round_num):
    """Computes one client's votes (and weight) on prefix extensions."""
    return client_update(tf_dataset, discovered_prefixes,
                         tf.constant(possible_prefix_extensions), round_num,
                         num_sub_rounds, max_num_prefixes,
                         max_user_contribution)

  federated_server_state_type = tff.FederatedType(server_state_type,
                                                  tff.SERVER)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      An updated `ServerState`
    """
    discovered_prefixes = tff.federated_broadcast(
        server_state.discovered_prefixes)
    round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, discovered_prefixes, round_num))

    # Sum per-client votes and contribution weights at the server.
    accumulated_votes = tff.federated_sum(client_outputs.client_votes)
    accumulated_weights = tff.federated_sum(client_outputs.client_weight)

    server_state = tff.federated_map(
        server_update_fn,
        (server_state, accumulated_votes, accumulated_weights))

    # Empty server output; heavy hitters are read from the server state.
    server_output = tff.federated_value([], tff.SERVER)

    return server_state, server_output

  return tff.templates.IterativeProcess(
      initialize_fn=tff.federated_computation(
          lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
      next_fn=run_one_round)