def test_bad_type_coercion_raises(self):
  """Checks that a failing in-graph reshape raises unplaced and at both placements."""
  arg_type = tff.TensorType(shape=[None], dtype=tf.float32)

  @tff.tf_computation(arg_type)
  def reshape_to_scalar(x):
    # The argument satisfies the TFF type check, but the reshape to a
    # scalar fails unless the tensor holds exactly one element.
    return tf.reshape(x, [])

  @tff.federated_computation(tff.type_at_clients(arg_type))
  def reshape_at_clients(x):
    return tff.federated_map(reshape_to_scalar, x)

  @tff.federated_computation(tff.type_at_server(arg_type))
  def reshape_at_server(x):
    return tff.federated_map(reshape_to_scalar, x)

  failing_tensor = tf.constant([1.] * 10, dtype=tf.float32)
  passing_tensor = tf.constant([1.], dtype=tf.float32)

  # Running the computation unplaced, at the server, and at the clients
  # must raise in all three cases.
  with self.assertRaises(Exception):
    reshape_to_scalar(failing_tensor)
  with self.assertRaises(Exception):
    reshape_at_server(failing_tensor)
  with self.assertRaises(Exception):
    reshape_at_clients([failing_tensor] * 10)

  # We give the distributed runtime a chance to clean itself up, otherwise
  # workers may be getting SIGABRT while they are handling another exception,
  # causing the test infra to crash. Making a successful call ensures that
  # cleanup happens after failures have been handled.
  reshape_at_clients([passing_tensor] * 10)
def validator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN, client_state_fn: CLIENT_STATE_FN):
  """Builds a federated computation that validates server weights on client data."""
  model = model_fn()
  template_client_state = client_state_fn()

  # Unplaced types of the per-client computation's three arguments.
  dataset_type = tff.SequenceType(model.input_spec)
  client_state_type = tff.framework.type_from_tensors(template_client_state)
  weights_type = tff.learning.framework.weights_type_from_model(model)

  validate_client_tf = tff.tf_computation(
      lambda dataset, state, weights: __validate_client(
          dataset, state, weights, coefficient_fn, model_fn,
          tf.function(client.validate)),
      (dataset_type, client_state_type, weights_type))

  def validate(weights, datasets, client_states):
    # Ship the server weights to the clients, run local validation, and
    # aggregate the resulting metrics back at the server.
    broadcast = tff.federated_broadcast(weights)
    outputs = tff.federated_map(
        validate_client_tf, (datasets, client_states, broadcast))
    return model.federated_output_computation(outputs.metrics)

  return tff.federated_computation(
      validate,
      (tff.type_at_server(weights_type),
       tff.type_at_clients(dataset_type),
       tff.type_at_clients(client_state_type)))
def __attrs_post_init__(self):
  """Derives TFF type metadata from the dummy batches and model factories.

  Populates input/weights/FromServer types, the federated placement types
  for client and server inputs, and the aggregation process used for
  discriminator weight deltas (DP if a DP query is configured).
  """
  self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
  self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

  def _build_keras_model(model_fn, dummy_batch):
    # Shared validation for both model factories (previously duplicated
    # inline). Fail fast with a clear TypeError *before* calling the model,
    # so a non-Keras factory output doesn't produce a confusing call error;
    # the dummy-batch call then builds the model's variables so their specs
    # can be read below.
    model = model_fn()
    if not isinstance(model, tf.keras.models.Model):
      raise TypeError(
          'Expected `tf.keras.models.Model`, found {}.'.format(type(model)))
    _ = model(dummy_batch)
    return model

  # Model-weights based types.
  self._generator = _build_keras_model(self.generator_model_fn,
                                       self.dummy_gen_input)
  self._discriminator = _build_keras_model(self.discriminator_model_fn,
                                           self.dummy_real_data)

  def vars_to_type(var_struct):
    # TODO(b/131681951): read_value() shouldn't be needed
    return tf.nest.map_structure(
        lambda v: tf.TensorSpec.from_tensor(v.read_value()), var_struct)

  self.discriminator_weights_type = vars_to_type(self._discriminator.weights)
  self.generator_weights_type = vars_to_type(self._generator.weights)
  self.from_server_type = gan_training_tf_fns.FromServer(
      generator_weights=self.generator_weights_type,
      discriminator_weights=self.discriminator_weights_type,
      meta_gen=self.generator_weights_type,
      meta_disc=self.discriminator_weights_type)
  self.client_gen_input_type = tff.type_at_clients(
      tff.SequenceType(self.gen_input_type))
  self.client_real_data_type = tff.type_at_clients(
      tff.SequenceType(self.real_data_type))
  self.server_gen_input_type = tff.type_at_server(
      tff.SequenceType(self.gen_input_type))

  # With a DP query configured, aggregate discriminator deltas through a
  # differentially private factory; otherwise use a plain weighted mean.
  if self.train_discriminator_dp_average_query is not None:
    self.aggregation_process = tff.aggregators.DifferentiallyPrivateFactory(
        query=self.train_discriminator_dp_average_query).create(
            value_type=tff.to_type(self.discriminator_weights_type))
  else:
    self.aggregation_process = tff.aggregators.MeanFactory().create(
        value_type=tff.to_type(self.discriminator_weights_type),
        weight_type=tff.to_type(tf.float32))
def __attrs_post_init__(self):
  """Derives TFF type metadata from the dummy batches and model factories.

  Populates input/weights/FromServer types, the federated placement types
  for client and server inputs, and (when a DP query is configured) the
  DP aggregation process for discriminator weight deltas.
  """
  self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
  self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

  def _build_keras_model(model_fn, dummy_batch):
    # Shared validation for both model factories (previously duplicated
    # inline). Fail fast with a clear TypeError *before* calling the model,
    # so a non-Keras factory output doesn't produce a confusing call error;
    # the dummy-batch call then builds the model's variables so their specs
    # can be read below.
    model = model_fn()
    if not isinstance(model, tf.keras.models.Model):
      raise TypeError(
          'Expected `tf.keras.models.Model`, found {}.'.format(type(model)))
    _ = model(dummy_batch)
    return model

  # Model-weights based types.
  self._generator = _build_keras_model(self.generator_model_fn,
                                       self.dummy_gen_input)
  self._discriminator = _build_keras_model(self.discriminator_model_fn,
                                           self.dummy_real_data)

  def vars_to_type(var_struct):
    # TODO(b/131681951): read_value() shouldn't be needed
    return tf.nest.map_structure(
        lambda v: tf.TensorSpec.from_tensor(v.read_value()), var_struct)

  self.discriminator_weights_type = vars_to_type(self._discriminator.weights)
  self.generator_weights_type = vars_to_type(self._generator.weights)
  self.from_server_type = gan_training_tf_fns.FromServer(
      generator_weights=self.generator_weights_type,
      discriminator_weights=self.discriminator_weights_type)
  self.client_gen_input_type = tff.type_at_clients(
      tff.SequenceType(self.gen_input_type))
  self.client_real_data_type = tff.type_at_clients(
      tff.SequenceType(self.real_data_type))
  self.server_gen_input_type = tff.type_at_server(
      tff.SequenceType(self.gen_input_type))

  # Right now, the logic in this library is effectively "if DP use stateful
  # aggregator, else don't use stateful aggregator". An alternative
  # formulation would be to always use a stateful aggregator, but when not
  # using DP default the aggregator to be a stateless mean, e.g.,
  # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
  if self.train_discriminator_dp_average_query is not None:
    self.dp_averaging_fn = tff.utils.build_dp_aggregate_process(
        value_type=tff.to_type(self.discriminator_weights_type),
        query=self.train_discriminator_dp_average_query)
def iterator(
    model_fn: MODEL_FN,
    client_state_fn: CLIENT_STATE_FN,
    client_optimizer_fn: OPTIMIZER_FN
):
  """Builds an IterativeProcess whose next step runs local client updates."""
  model = model_fn()
  template_client_state = client_state_fn()

  # The server keeps no state: initialization produces an empty tuple.
  init_tf = tff.tf_computation(lambda: ())

  dataset_type = tff.SequenceType(model.input_spec)
  client_state_type = tff.framework.type_from_tensors(template_client_state)
  server_state_type = init_tf.type_signature.result

  update_client_tf = tff.tf_computation(
      lambda dataset, state: __update_client(
          dataset, state, model_fn, client_optimizer_fn,
          tf.function(client.update)),
      (dataset_type, client_state_type))

  def init_tff():
    return tff.federated_value(init_tf(), tff.SERVER)

  def next_tff(server_state, datasets, client_states):
    # Run the local update on every client, aggregate the metrics, and
    # return the (unchanged) server state plus the new client states.
    outputs = tff.federated_map(update_client_tf, (datasets, client_states))
    metrics = model.federated_output_computation(outputs.metrics)
    return server_state, metrics, outputs.client_state

  return tff.templates.IterativeProcess(
      initialize_fn=tff.federated_computation(init_tff),
      next_fn=tff.federated_computation(
          next_tff,
          (tff.type_at_server(server_state_type),
           tff.type_at_clients(dataset_type),
           tff.type_at_clients(client_state_type))))
def build_federated_reconstruction_process(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn] = None,
    server_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 1.0),
    client_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None,
    evaluate_reconstruction: bool = False,
    jointly_train_variables: bool = False,
    client_weight_fn: Optional[ClientWeightFn] = None,
    aggregation_factory: Optional[
        tff.aggregators.WeightedAggregationFactory] = None,
) -> tff.templates.IterativeProcess:
  """Builds the IterativeProcess for optimization using FedRecon.

  Returns a `tff.templates.IterativeProcess` for Federated Reconstruction. On
  the client, computation can be divided into two stages: (1) reconstruction of
  local variables and (2) training of global variables (possibly jointly with
  reconstructed local variables).

  Args:
    model_fn: A no-arg function that returns a `ReconstructionModel`. This
      method must *not* capture Tensorflow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
      compute local model updates during reconstruction and post-reconstruction
      and evaluate the model during training. The final loss metric is the
      example-weighted mean loss across batches and across clients. Depending
      on whether `evaluate_reconstruction` is True, the loss metric may or may
      not include reconstruction batches in the loss.
    metrics_fn: A no-arg function returning a list of
      `tf.keras.metrics.Metric`s to evaluate the model. Metrics results are
      computed locally as described by the metric, and are aggregated across
      clients as in `federated_aggregate_keras_metric`. If None, no metrics
      are applied. Depending on whether evaluate_reconstruction is True,
      metrics may or may not be computed on reconstruction batches as well.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for applying updates to the global model
      on the server.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for local client training after
      reconstruction.
    reconstruction_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` used to reconstruct the local variables,
      with the global ones frozen, or the first stage described above.
    dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` taking in a
      client dataset and training round number (1-indexed) and producing two
      TF datasets. The first is iterated over during reconstruction, and the
      second is iterated over post-reconstruction. This can be used to
      preprocess datasets to e.g. iterate over them for multiple epochs or use
      disjoint data for reconstruction and post-reconstruction. If None,
      `reconstruction_utils.simple_dataset_split_fn` is used, which results in
      iterating over the original client data for both phases of training. See
      `reconstruction_utils.build_dataset_split_fn` for options.
    evaluate_reconstruction: If True, metrics (including loss) are computed on
      batches during reconstruction and post-reconstruction. If False, metrics
      are computed on batches only post-reconstruction, when global weights
      are being updated. Note that metrics are aggregated across batches as
      given by the metric (example-weighted mean for the loss). Setting this
      to True includes all local batches in metric calculations. Setting this
      to False brings the interpretation of these metrics closer to the
      interpretation of metrics in FedAvg. Note that this does not affect
      training at all: losses for individual batches are calculated and used
      to update variables regardless.
    jointly_train_variables: Whether to train local variables during the
      second stage described above. If True, global and local variables are
      trained jointly after reconstruction of local variables using the
      optimizer given by client_optimizer_fn. If False, only global variables
      are trained during the second stage with local variables frozen, similar
      to alternating minimization.
    client_weight_fn: Optional function that takes the local model's output,
      and returns a tensor that provides the weight in the federated average
      of model deltas. If not provided, the default is the total number of
      examples processed on device during post-reconstruction phase.
    aggregation_factory: An optional instance of
      `tff.aggregators.WeightedAggregationFactory` determining the method of
      aggregation to perform. If unspecified, uses a default
      `tff.aggregators.MeanFactory` which computes a stateless weighted mean
      across clients.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # Build a throwaway model inside a scratch graph purely to read type
  # metadata (weights structure, input spec); its variables are never used.
  with tf.Graph().as_default():
    throwaway_model_for_metadata = model_fn()

  model_weights_type = tff.framework.type_from_tensors(
      reconstruction_utils.get_global_variables(throwaway_model_for_metadata))

  aggregation_process = _instantiate_aggregation_process(
      aggregation_factory, model_weights_type, client_weight_fn)
  aggregator_state_type = (
      aggregation_process.initialize.type_signature.result.member)

  server_init_tff = build_server_init_fn(model_fn, server_optimizer_fn,
                                         aggregation_process)
  server_state_type = server_init_tff.type_signature.result.member
  server_update_fn = build_server_update_fn(
      model_fn,
      server_optimizer_fn,
      server_state_type,
      server_state_type.model,
      aggregator_state_type=aggregator_state_type)

  tf_dataset_type = tff.SequenceType(throwaway_model_for_metadata.input_spec)
  # Default split: iterate the same client data for both the reconstruction
  # and post-reconstruction phases.
  if dataset_split_fn is None:
    dataset_split_fn = reconstruction_utils.simple_dataset_split_fn
  client_update_fn = build_client_update_fn(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      tf_dataset_type=tf_dataset_type,
      model_weights_type=server_state_type.model,
      client_optimizer_fn=client_optimizer_fn,
      reconstruction_optimizer_fn=reconstruction_optimizer_fn,
      dataset_split_fn=dataset_split_fn,
      evaluate_reconstruction=evaluate_reconstruction,
      jointly_train_variables=jointly_train_variables,
      client_weight_fn=client_weight_fn)

  federated_server_state_type = tff.type_at_server(server_state_type)
  federated_dataset_type = tff.type_at_clients(tf_dataset_type)
  # Create placeholder metrics to produce a corresponding federated output
  # computation. The mean loss is always appended after any user metrics.
  metrics = []
  if metrics_fn is not None:
    metrics.extend(metrics_fn())
  metrics.append(keras_utils.MeanLossMetric(loss_fn()))
  federated_output_computation = (
      keras_utils.federated_output_computation_from_metrics(metrics))

  run_one_round_tff = build_run_one_round_fn(
      server_update_fn,
      client_update_fn,
      federated_output_computation,
      federated_server_state_type,
      federated_dataset_type,
      aggregation_process=aggregation_process,
  )

  iterative_process = tff.templates.IterativeProcess(
      initialize_fn=server_init_tff, next_fn=run_one_round_tff)

  # Attach a convenience computation for extracting the global model weights
  # from the server state.
  @tff.tf_computation(server_state_type)
  def get_model_weights(server_state):
    return server_state.model

  iterative_process.get_model_weights = get_model_weights
  return iterative_process
def build_triehh_process(
    possible_prefix_extensions: List[str],
    num_sub_rounds: int,
    max_num_prefixes: int,
    threshold: int,
    max_user_contribution: int,
    default_terminator: str = triehh_tf.DEFAULT_TERMINATOR):
  """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has discovered so far
  and the list of `possible_prefix_extensions` to a small fraction of selected
  clients. The selected clients sample `max_user_contribution` words from
  their local datasets, and use them to vote on character extensions to the
  broadcasted popular prefixes. Client votes are accumulated across
  `num_sub_rounds` rounds, and then the top `max_num_prefixes` extensions that
  get at least `threshold` votes are used to extend the already discovered
  prefixes, and the extended prefixes are used in the next round. When an
  already discovered prefix is extended by `default_terminator` it is added to
  the list of discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions
      to learned prefixes. Each extension must be a single-character string.
      This list should not contain the default_terminator.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_prefixes: The maximum number of prefixes we can keep in the trie.
      Must be positive.
    threshold: The threshold for heavy hitters and discovered prefixes. Only
      those that get at least `threshold` votes are discovered. Must be
      positive.
    max_user_contribution: The maximum number of examples a user can
      contribute. Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ValueError: If possible_prefix_extensions contains default_terminator.
  """
  if default_terminator in possible_prefix_extensions:
    raise ValueError(
        'default_terminator should not appear in possible_prefix_extensions')

  # Append `default_terminator` to `possible_prefix_extensions` to make sure
  # it is the last item in the list. Work on a copy so the caller's list is
  # not mutated (a second call with the same list would otherwise see the
  # terminator already appended and raise above).
  possible_prefix_extensions = possible_prefix_extensions + [default_terminator]

  @tff.tf_computation
  def server_init_tf():
    return ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        heavy_hitters_counts=tf.constant([], dtype=tf.int32),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            dtype=tf.int32,
            shape=[max_num_prefixes, len(possible_prefix_extensions)]))

  # We cannot use server_init_tf.type_signature.result because the
  # discovered_* fields need to have [None] shapes, since they will grow over
  # time.
  server_state_type = (tff.to_type(
      ServerState(
          discovered_heavy_hitters=tff.TensorType(
              dtype=tf.string, shape=[None]),
          heavy_hitters_counts=tff.TensorType(dtype=tf.int32, shape=[None]),
          discovered_prefixes=tff.TensorType(dtype=tf.string, shape=[None]),
          round_num=tff.TensorType(dtype=tf.int32, shape=[]),
          accumulated_votes=tff.TensorType(
              dtype=tf.int32, shape=[None, len(possible_prefix_extensions)]),
      )))

  sub_round_votes_type = tff.TensorType(
      dtype=tf.int32,
      shape=[max_num_prefixes, len(possible_prefix_extensions)])

  @tff.tf_computation(server_state_type, sub_round_votes_type)
  def server_update_fn(server_state, sub_round_votes):
    return server_update(
        server_state,
        tf.constant(possible_prefix_extensions),
        sub_round_votes,
        num_sub_rounds=tf.constant(num_sub_rounds),
        max_num_prefixes=tf.constant(max_num_prefixes),
        threshold=tf.constant(threshold))

  tf_dataset_type = tff.SequenceType(tf.string)
  discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
  round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(tf_dataset_type, discovered_prefixes_type,
                      round_num_type)
  def client_update_fn(tf_dataset, discovered_prefixes, round_num):
    return client_update(tf_dataset, discovered_prefixes,
                         tf.constant(possible_prefix_extensions), round_num,
                         num_sub_rounds, max_num_prefixes,
                         max_user_contribution,
                         tf.constant(default_terminator, dtype=tf.string))

  federated_server_state_type = tff.type_at_server(server_state_type)
  federated_dataset_type = tff.type_at_clients(tf_dataset_type)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      An updated `ServerState`
    """
    discovered_prefixes = tff.federated_broadcast(
        server_state.discovered_prefixes)
    round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, discovered_prefixes, round_num))

    # Sum client votes, then decode them into an updated server state.
    accumulated_votes = tff.federated_sum(client_outputs.client_votes)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, accumulated_votes))

    server_output = tff.federated_value([], tff.SERVER)

    return server_state, server_output

  return tff.templates.IterativeProcess(
      initialize_fn=tff.federated_computation(
          lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
      next_fn=run_one_round)