def build_control_computation(gan: GanFnsAndTypes,
                              disc_optimizer_fn: OptimizerBuilder,
                              gen_optimizer_fn: OptimizerBuilder, tau: float):
  """Returns a `tff.tf_computation` for the `control_computation`.

  This is a thin wrapper around `gan_training_tf_fns.client_control`.

  Args:
    gan: A `GanFnsAndTypes` object.
    disc_optimizer_fn: A no-arg callable returning the discriminator optimizer.
    gen_optimizer_fn: A no-arg callable returning the generator optimizer.
    tau: A scalar hyperparameter forwarded to `client_control`.

  Returns:
    A `tff.tf_computation.`
  """

  @tff.tf_computation(tff.SequenceType(gan.gen_input_type),
                      tff.SequenceType(gan.real_data_type),
                      gan.from_server_type)
  def control_computation(gen_inputs, real_data, from_server):
    """Returns the client_output."""
    generator = gan.generator_model_fn()
    discriminator = gan.discriminator_model_fn()
    # Zero-valued structures matching the generator/discriminator weights.
    zero_gen = tf.nest.map_structure(tf.zeros_like, generator.weights)
    zero_disc = tf.nest.map_structure(tf.zeros_like, discriminator.weights)
    return gan_training_tf_fns.client_control(
        gen_inputs_ds=gen_inputs,
        real_data_ds=real_data,
        from_server=from_server,
        generator=generator,
        discriminator=discriminator,
        disc_optimizer=disc_optimizer_fn(),
        gen_optimizer=gen_optimizer_fn(),
        # Fix: these zero structures were computed but then discarded; the
        # original passed freshly constructed models (with non-zero initial
        # weights) as `zero_gen`/`zero_disc`, leaving the locals dead.
        zero_gen=zero_gen,
        zero_disc=zero_disc,
        tau=tau)

  return control_computation
def build_client_computation(gan: GanFnsAndTypes):
  """Builds the `tff.tf_computation` that runs one round of client work.

  Thin wrapper around `gan_training_tf_fns.client_computation`.

  Args:
    gan: A `GanFnsAndTypes` object.

  Returns:
    A `tff.tf_computation.`
  """

  @tff.tf_computation(tff.SequenceType(gan.gen_input_type),
                      tff.SequenceType(gan.real_data_type),
                      gan.from_server_type)
  def client_computation(gen_inputs, real_data, from_server):
    """Returns the client_output."""
    # Fresh model instances are built on each invocation of the computation.
    generator = gan.generator_model_fn()
    discriminator = gan.discriminator_model_fn()
    return gan_training_tf_fns.client_computation(
        gen_inputs_ds=gen_inputs,
        real_data_ds=real_data,
        from_server=from_server,
        generator=generator,
        discriminator=discriminator,
        train_discriminator_fn=gan.train_discriminator_fn)

  return client_computation
def _temperature_sensor_example_next_fn():
  """Builds the example federated computation over temperature readings."""

  @tff.tf_computation(tff.SequenceType(tf.float32), tf.float32)
  def count_over(ds, t):
    # Number of readings strictly greater than the threshold `t`.
    accumulate = lambda n, x: n + tf.cast(tf.greater(x, t), tf.float32)
    return ds.reduce(np.float32(0), accumulate)

  @tff.tf_computation(tff.SequenceType(tf.float32))
  def count_total(ds):
    # Total number of readings in the dataset.
    return ds.reduce(np.float32(0.0), lambda n, _: n + 1.0)

  @tff.federated_computation(
      tff.FederatedType(tff.SequenceType(tf.float32), tff.CLIENTS),
      tff.FederatedType(tf.float32, tff.SERVER))
  def comp(temperatures, threshold):
    client_threshold = tff.federated_broadcast(threshold)
    over_counts = tff.federated_map(
        count_over, tff.federated_zip([temperatures, client_threshold]))
    total_counts = tff.federated_map(count_total, temperatures)
    # Mean fraction of readings over threshold, weighted by dataset size.
    return tff.federated_mean(over_counts, total_counts)

  return comp
def build_client_computation(gan: GanFnsAndTypes,
                             disc_optimizer_fn: OptimizerBuilder,
                             gen_optimizer_fn: OptimizerBuilder, tau: float):
  """Returns a `tff.tf_computation` for the `client_computation`.

  This is a thin wrapper around `gan_training_tf_fns.client_computation`.

  Args:
    gan: A `GanFnsAndTypes` object.
    disc_optimizer_fn: No-arg callable building the discriminator optimizer.
    gen_optimizer_fn: No-arg callable building the generator optimizer.
    tau: Scalar hyperparameter forwarded to the client computation
      (presumably a control/mixing coefficient -- confirm in
      `gan_training_tf_fns.client_computation`).

  Returns:
    A `tff.tf_computation.`
  """

  # The two extra parameters reuse the weight types from `from_server_type`,
  # so the control inputs have the same structure as the model weights.
  @tff.tf_computation(tff.SequenceType(gan.gen_input_type),
                      tff.SequenceType(gan.real_data_type),
                      gan.from_server_type,
                      gan.from_server_type.generator_weights,
                      gan.from_server_type.discriminator_weights)
  def client_computation(gen_inputs, real_data, from_server, control_input_gen,
                         control_input_disc):
    """Returns the client_output."""
    # Fresh models and optimizers are constructed per invocation.
    return gan_training_tf_fns.client_computation(
        gen_inputs_ds=gen_inputs,
        real_data_ds=real_data,
        from_server=from_server,
        generator=gan.generator_model_fn(),
        discriminator=gan.discriminator_model_fn(),
        gen_optimizer=gen_optimizer_fn(),
        disc_optimizer=disc_optimizer_fn(),
        control_input_gen=control_input_gen,
        control_input_disc=control_input_disc,
        tau=tau)

  return client_computation
def test_build_with_preprocess_function(self):
  """Composing a preprocess fn makes `next` accept raw client datasets."""
  # Raw (unpreprocessed) client data: the integers 0..4.
  test_dataset = tf.data.Dataset.range(5)
  client_datasets_type = tff.FederatedType(
      tff.SequenceType(test_dataset.element_spec), tff.CLIENTS)

  @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
  def preprocess_dataset(ds):
    """Maps each int to a (784-dim features, int64 label) pair, batched by 2."""

    def to_batch(x):
      return _Batch(
          tf.fill(dims=(784,), value=float(x) * 2.0),
          tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0))

    return ds.map(to_batch).batch(2)

  iterproc = _build_simple_fed_pa_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  # Compose so that `next` takes raw datasets and preprocesses them itself.
  iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
      preprocess_dataset, iterproc)

  with tf.Graph().as_default():
    test_model_for_types = _uncompiled_model_builder()

    server_state_type = tff.FederatedType(
        fed_pa_schedule.ServerState(
            model=tff.framework.type_from_tensors(
                tff.learning.ModelWeights(
                    test_model_for_types.trainable_variables,
                    test_model_for_types.non_trainable_variables)),
            optimizer_state=(tf.int64,),
            round_num=tf.float32), tff.SERVER)
    metrics_type = tff.FederatedType(
        tff.StructType([('loss', tf.float32),
                        ('model_delta_zeros_percent', tf.float32),
                        ('model_delta_correction_l2_norm', tf.float32)]),
        tff.SERVER)
    expected_parameter_type = collections.OrderedDict(
        server_state=server_state_type,
        federated_dataset=client_datasets_type,
    )
    expected_result_type = (server_state_type, metrics_type)
    expected_type = tff.FunctionType(
        parameter=expected_parameter_type, result=expected_result_type)
    # Check type equivalence (not strict equality) of the composed `next`.
    self.assertTrue(
        iterproc.next.type_signature.is_equivalent_to(expected_type),
        msg='{s}\n!={t}'.format(
            s=iterproc.next.type_signature, t=expected_type))
def __attrs_post_init__(self):
  """Derives TFF types and the aggregation process from the configured fns.

  Runs after the attrs-generated __init__: builds element types from the
  dummy batches, validates that the model fns return Keras models, and
  constructs the weight/FromServer/federated types plus the (optionally
  differentially private) aggregation process.
  """
  self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
  self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

  # Model-weights based types
  self._generator = self.generator_model_fn()
  _ = self._generator(self.dummy_gen_input)  # Forward pass builds variables.
  if not isinstance(self._generator, tf.keras.models.Model):
    raise TypeError('Expected `tf.keras.models.Model`, found {}.'.format(
        type(self._generator)))
  self._discriminator = self.discriminator_model_fn()
  _ = self._discriminator(self.dummy_real_data)  # Forward pass builds variables.
  if not isinstance(self._discriminator, tf.keras.models.Model):
    raise TypeError('Expected `tf.keras.models.Model`, found {}.'.format(
        type(self._discriminator)))

  def vars_to_type(var_struct):
    # TODO(b/131681951): read_value() shouldn't be needed
    return tf.nest.map_structure(
        lambda v: tf.TensorSpec.from_tensor(v.read_value()), var_struct)

  self.discriminator_weights_type = vars_to_type(self._discriminator.weights)
  self.generator_weights_type = vars_to_type(self._generator.weights)
  # meta_gen/meta_disc mirror the weight types (presumably for a
  # meta-learning variant -- confirm against gan_training_tf_fns.FromServer).
  self.from_server_type = gan_training_tf_fns.FromServer(
      generator_weights=self.generator_weights_type,
      discriminator_weights=self.discriminator_weights_type,
      meta_gen=self.generator_weights_type,
      meta_disc=self.discriminator_weights_type)
  self.client_gen_input_type = tff.type_at_clients(
      tff.SequenceType(self.gen_input_type))
  self.client_real_data_type = tff.type_at_clients(
      tff.SequenceType(self.real_data_type))
  self.server_gen_input_type = tff.type_at_server(
      tff.SequenceType(self.gen_input_type))

  # Use a DP aggregator when a DP query is configured, else a weighted mean.
  if self.train_discriminator_dp_average_query is not None:
    self.aggregation_process = tff.aggregators.DifferentiallyPrivateFactory(
        query=self.train_discriminator_dp_average_query).create(
            value_type=tff.to_type(self.discriminator_weights_type))
  else:
    self.aggregation_process = tff.aggregators.MeanFactory().create(
        value_type=tff.to_type(self.discriminator_weights_type),
        weight_type=tff.to_type(tf.float32))
def __attrs_post_init__(self):
  """Derives TFF types and (optionally) the DP averaging process.

  Runs after the attrs-generated __init__: builds element types from the
  dummy batches, validates that the model fns return Keras models, and
  constructs weight/FromServer/federated placement types.
  """
  self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
  self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

  # Model-weights based types
  self._generator = self.generator_model_fn()
  _ = self._generator(self.dummy_gen_input)  # Forward pass builds variables.
  if not isinstance(self._generator, tf.keras.models.Model):
    raise TypeError('Expected `tf.keras.models.Model`, found {}.'.format(
        type(self._generator)))
  self._discriminator = self.discriminator_model_fn()
  _ = self._discriminator(self.dummy_real_data)  # Forward pass builds variables.
  if not isinstance(self._discriminator, tf.keras.models.Model):
    raise TypeError('Expected `tf.keras.models.Model`, found {}.'.format(
        type(self._discriminator)))

  def vars_to_type(var_struct):
    # TODO(b/131681951): read_value() shouldn't be needed
    return tf.nest.map_structure(
        lambda v: tf.TensorSpec.from_tensor(v.read_value()), var_struct)

  self.discriminator_weights_type = vars_to_type(self._discriminator.weights)
  self.generator_weights_type = vars_to_type(self._generator.weights)
  self.from_server_type = gan_training_tf_fns.FromServer(
      generator_weights=self.generator_weights_type,
      discriminator_weights=self.discriminator_weights_type)
  self.client_gen_input_type = tff.type_at_clients(
      tff.SequenceType(self.gen_input_type))
  self.client_real_data_type = tff.type_at_clients(
      tff.SequenceType(self.real_data_type))
  self.server_gen_input_type = tff.type_at_server(
      tff.SequenceType(self.gen_input_type))

  # Right now, the logic in this library is effectively "if DP use stateful
  # aggregator, else don't use stateful aggregator". An alternative
  # formulation would be to always use a stateful aggregator, but when not
  # using DP default the aggregator to be a stateless mean, e.g.,
  # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
  if self.train_discriminator_dp_average_query is not None:
    self.dp_averaging_fn = tff.utils.build_dp_aggregate_process(
        value_type=tff.to_type(self.discriminator_weights_type),
        query=self.train_discriminator_dp_average_query)
def build_client_computation(gan: GanFnsAndTypes):
  """Returns a `tff.tf_computation` for the `client_computation`.

  This is a thin wrapper around `gan_training_tf_fns.client_computation`.

  Args:
    gan: A `GanFnsAndTypes` object.

  Returns:
    A `tff.tf_computation.`
  """

  @tff.tf_computation(tff.SequenceType(gan.gen_input_type),
                      tff.SequenceType(gan.real_data_type),
                      gan.from_server_type)
  def client_computation(gen_inputs, real_data, from_server):
    """Returns the client_output."""
    # Generator LR decays from 1e-3 to 5e-4 after 1000 rounds; the round
    # counter comes from the server state carried in `from_server`.
    steps = from_server.counters['num_rounds']
    scheduler = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        [1000], [0.001, 0.0005])
    generator = gan.generator_model_fn()
    state_gen_optimizer = gan.state_gen_optimizer_fn(scheduler.__call__(steps))
    # Create optimizer slot variables eagerly so they are not created inside
    # a tf.function.
    gan_training_tf_fns.initialize_optimizer_vars(generator,
                                                  state_gen_optimizer)
    discriminator = gan.discriminator_model_fn()
    # Discriminator uses a fixed learning rate of 2e-4.
    state_disc_optimizer = gan.state_disc_optimizer_fn(0.0002)
    gan_training_tf_fns.initialize_optimizer_vars(discriminator,
                                                  state_disc_optimizer)
    # Dispatch on the configured discriminator aggregation strategy.
    if gan.disc_status == 'fedadam':
      return gan_training_tf_fns.client_computation_fedadam(
          gen_inputs_ds=gen_inputs,
          real_data_ds=real_data,
          from_server=from_server,
          generator=generator,
          discriminator=discriminator,
          state_gen_optimizer=state_gen_optimizer,
          state_disc_optimizer=state_disc_optimizer)
    else:
      return gan_training_tf_fns.client_computation(
          gen_inputs_ds=gen_inputs,
          real_data_ds=real_data,
          from_server=from_server,
          generator=generator,
          discriminator=discriminator,
          state_gen_optimizer=state_gen_optimizer,
          state_disc_optimizer=state_disc_optimizer)

  return client_computation
def __attrs_post_init__(self):
  """Derives TFF types and (optionally) the DP averaging fn and state type."""
  self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
  self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

  # Model-weights based types
  self._generator = self.generator_model_fn()
  _ = self._generator(self.dummy_gen_input)  # Forward pass builds variables.
  py_typecheck.check_type(self._generator, tf.keras.models.Model)
  self._discriminator = self.discriminator_model_fn()
  _ = self._discriminator(self.dummy_real_data)  # Forward pass builds variables.
  py_typecheck.check_type(self._discriminator, tf.keras.models.Model)

  def vars_to_type(var_struct):
    # TODO(b/131681951): read_value() shouldn't be needed
    return tf.nest.map_structure(
        lambda v: tf.TensorSpec.from_tensor(v.read_value()), var_struct)

  self.discriminator_weights_type = vars_to_type(self._discriminator.weights)
  self.generator_weights_type = vars_to_type(self._generator.weights)
  self.from_server_type = gan_training_tf_fns.FromServer(
      generator_weights=self.generator_weights_type,
      discriminator_weights=self.discriminator_weights_type)
  self.client_gen_input_type = tff.FederatedType(
      tff.SequenceType(self.gen_input_type), tff.CLIENTS)
  self.client_real_data_type = tff.FederatedType(
      tff.SequenceType(self.real_data_type), tff.CLIENTS)
  self.server_gen_input_type = tff.FederatedType(
      tff.SequenceType(self.gen_input_type), tff.SERVER)

  # Right now, the logic in this library is effectively "if DP use stateful
  # aggregator, else don't use stateful aggregator". An alternative
  # formulation would be to always use a stateful aggregator, but when not
  # using DP default the aggregator to be a stateless mean, e.g.,
  # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
  # This change will be easier to make if the tff.StatefulAggregateFn is
  # modified to have a property that gives the type of the aggregation state
  # (i.e., what we're storing in self.dp_averaging_state_type).
  if self.train_discriminator_dp_average_query is not None:
    self.dp_averaging_fn, self.dp_averaging_state_type = (
        tff.utils.build_dp_aggregate(
            query=self.train_discriminator_dp_average_query,
            value_type_fn=lambda value: self.discriminator_weights_type,
            from_tff_result_fn=lambda record: list(record)))  # pylint: disable=unnecessary-lambda
def validator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
              client_state_fn: CLIENT_STATE_FN):
  """Builds a `tff.federated_computation` validating clients against weights.

  Args:
    coefficient_fn: Callable forwarded to `__validate_client` (presumably
      produces per-client mixing coefficients -- confirm in that helper).
    model_fn: No-arg callable returning the model.
    client_state_fn: No-arg callable returning a fresh client state.

  Returns:
    A federated computation taking (server weights, client datasets, client
    states) and returning aggregated validation metrics.
  """
  model = model_fn()
  client_state = client_state_fn()

  dataset_type = tff.SequenceType(model.input_spec)
  client_state_type = tff.framework.type_from_tensors(client_state)
  weights_type = tff.learning.framework.weights_type_from_model(model)

  validate_client_tf = tff.tf_computation(
      lambda dataset, state, weights: __validate_client(
          dataset, state, weights, coefficient_fn, model_fn,
          tf.function(client.validate)),
      (dataset_type, client_state_type, weights_type))

  federated_weights_type = tff.type_at_server(weights_type)
  federated_dataset_type = tff.type_at_clients(dataset_type)
  federated_client_state_type = tff.type_at_clients(client_state_type)

  def validate(weights, datasets, client_states):
    # Broadcast server weights, validate locally, then aggregate metrics.
    broadcast = tff.federated_broadcast(weights)
    outputs = tff.federated_map(validate_client_tf,
                                (datasets, client_states, broadcast))
    metrics = model.federated_output_computation(outputs.metrics)
    return metrics

  return tff.federated_computation(
      validate,
      (federated_weights_type, federated_dataset_type,
       federated_client_state_type))
def validator(model_fn: MODEL_FN, client_state_fn: CLIENT_STATE_FN):
  """Builds a `tff.federated_computation` that validates client models.

  Args:
    model_fn: No-arg callable returning the model.
    client_state_fn: No-arg callable returning a fresh client state.

  Returns:
    A federated computation mapping (client datasets, client states) to
    aggregated validation metrics.
  """
  model = model_fn()
  client_state = client_state_fn()

  dataset_type = tff.SequenceType(model.input_spec)
  client_state_type = tff.framework.type_from_tensors(client_state)

  validate_client_tf = tff.tf_computation(
      lambda dataset, state: __validate_client(
          dataset, state, model_fn, tf.function(client.validate)),
      (dataset_type, client_state_type))

  federated_dataset_type = tff.type_at_clients(dataset_type)
  federated_client_state_type = tff.type_at_clients(client_state_type)

  def validate(datasets, client_states):
    # Validate locally on each client, then aggregate the metrics.
    client_outputs = tff.federated_map(validate_client_tf,
                                       (datasets, client_states))
    return model.federated_output_computation(client_outputs.metrics)

  return tff.federated_computation(
      validate, (federated_dataset_type, federated_client_state_type))
def test_executes_dataset_concat_aggregation(self):
  """Aggregates client datasets by concatenation into a single dataset."""
  tensor_spec = tf.TensorSpec(shape=[2], dtype=tf.float32)

  @tff.tf_computation
  def create_empty_ds():
    # Zero-element dataset with the right element shape/dtype, used as the
    # aggregation `zero`.
    empty_tensor = tf.zeros(
        shape=[0] + tensor_spec.shape, dtype=tensor_spec.dtype)
    return tf.data.Dataset.from_tensor_slices(empty_tensor)

  @tff.tf_computation
  def concat_datasets(ds1, ds2):
    return ds1.concatenate(ds2)

  @tff.tf_computation
  def identity(ds):
    return ds

  @tff.federated_computation(
      tff.type_at_clients(tff.SequenceType(tensor_spec)))
  def do_a_federated_aggregate(client_ds):
    return tff.federated_aggregate(
        value=client_ds,
        zero=create_empty_ds(),
        accumulate=concat_datasets,
        merge=concat_datasets,
        report=identity)

  input_data = tf.data.Dataset.from_tensor_slices([[0.1, 0.2]])
  result = do_a_federated_aggregate([input_data])
  self.assertIsInstance(result, tf.data.Dataset)
def iterator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
             client_state_fn: CLIENT_STATE_FN,
             server_optimizer_fn: OPTIMIZER_FN,
             client_optimizer_fn: OPTIMIZER_FN):
  """Builds the `tff.templates.IterativeProcess` for federated training.

  Args:
    coefficient_fn: Callable forwarded to `__update_client` (presumably
      produces per-client mixing coefficients -- confirm in that helper).
    model_fn: No-arg callable returning the model.
    client_state_fn: No-arg callable returning a fresh client state.
    server_optimizer_fn: No-arg callable returning the server optimizer.
    client_optimizer_fn: No-arg callable returning the client optimizer.

  Returns:
    A `tff.templates.IterativeProcess` whose `next` maps
    (server state, client datasets, client states) to
    (next server state, metrics, next client states).
  """
  model = model_fn()
  client_state = client_state_fn()

  init_tf = tff.tf_computation(
      lambda: __initialize_server(model_fn, server_optimizer_fn))

  # The server state type is inferred from the initializer's result.
  server_state_type = init_tf.type_signature.result
  client_state_type = tff.framework.type_from_tensors(client_state)

  update_server_tf = tff.tf_computation(
      lambda state, weights_delta: __update_server(
          state, weights_delta, model_fn, server_optimizer_fn,
          tf.function(server.update)),
      (server_state_type, server_state_type.model.trainable))

  state_to_message_tf = tff.tf_computation(
      lambda state: __state_to_message(state, tf.function(server.to_message)),
      server_state_type)

  dataset_type = tff.SequenceType(model.input_spec)
  server_message_type = state_to_message_tf.type_signature.result

  update_client_tf = tff.tf_computation(
      lambda dataset, state, message: __update_client(
          dataset, state, message, coefficient_fn, model_fn,
          client_optimizer_fn, tf.function(client.update)),
      (dataset_type, client_state_type, server_message_type))

  federated_server_state_type = tff.type_at_server(server_state_type)
  federated_dataset_type = tff.type_at_clients(dataset_type)
  federated_client_state_type = tff.type_at_clients(client_state_type)

  def init_tff():
    # Place the freshly initialized server state at the server.
    return tff.federated_value(init_tf(), tff.SERVER)

  def next_tff(server_state, datasets, client_states):
    # One round: derive the broadcast message from server state, run local
    # client updates, take the client-weighted mean of the deltas, aggregate
    # metrics, and apply the server update.
    message = tff.federated_map(state_to_message_tf, server_state)
    broadcast = tff.federated_broadcast(message)
    outputs = tff.federated_map(update_client_tf,
                                (datasets, client_states, broadcast))
    weights_delta = tff.federated_mean(
        outputs.weights_delta, weight=outputs.client_weight)
    metrics = model.federated_output_computation(outputs.metrics)
    next_state = tff.federated_map(update_server_tf,
                                   (server_state, weights_delta))
    return next_state, metrics, outputs.client_state

  return tff.templates.IterativeProcess(
      initialize_fn=tff.federated_computation(init_tff),
      next_fn=tff.federated_computation(
          next_tff,
          (federated_server_state_type, federated_dataset_type,
           federated_client_state_type)))
def evaluator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
              client_state_fn: CLIENT_STATE_FN):
  """Builds a `tff.federated_computation` evaluating clients against weights.

  Args:
    coefficient_fn: Callable forwarded to `__evaluate_client`.
    model_fn: No-arg callable returning the model.
    client_state_fn: No-arg callable returning a fresh client state.

  Returns:
    A federated computation mapping (server weights, client datasets, client
    states) to (summed confusion matrix, aggregated metrics, collected
    per-client metrics).
  """
  model = model_fn()
  client_state = client_state_fn()

  dataset_type = tff.SequenceType(model.input_spec)
  client_state_type = tff.framework.type_from_tensors(client_state)
  weights_type = tff.framework.type_from_tensors(
      tff.learning.ModelWeights.from_model(model))

  evaluate_client_tf = tff.tf_computation(
      lambda dataset, state, weights: __evaluate_client(
          dataset, state, weights, coefficient_fn, model_fn,
          tf.function(client.evaluate)),
      (dataset_type, client_state_type, weights_type))

  federated_weights_type = tff.type_at_server(weights_type)
  federated_dataset_type = tff.type_at_clients(dataset_type)
  federated_client_state_type = tff.type_at_clients(client_state_type)

  def evaluate(weights, datasets, client_states):
    broadcast = tff.federated_broadcast(weights)
    outputs = tff.federated_map(evaluate_client_tf,
                                (datasets, client_states, broadcast))
    # Element-wise sum of the per-client confusion matrices.
    confusion_matrix = tff.federated_sum(outputs.confusion_matrix)
    aggregated_metrics = model.federated_output_computation(outputs.metrics)
    # NOTE(review): `tff.federated_collect` is deprecated/removed in newer
    # TFF releases -- confirm the pinned TFF version still provides it.
    collected_metrics = tff.federated_collect(outputs.metrics)
    return confusion_matrix, aggregated_metrics, collected_metrics

  return tff.federated_computation(
      evaluate,
      (federated_weights_type, federated_dataset_type,
       federated_client_state_type))
def test_eval_fn_has_correct_type_signature(self):
  """Checks the centralized evaluation computation's TFF type signature."""
  metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]
  eval_fn = evaluation.build_centralized_evaluation(tff_model_fn,
                                                    metrics_builder)

  # Expected server-placed model weights: a single 1x1 dense layer.
  weights_struct = tff.learning.ModelWeights(
      trainable=(
          tff.TensorType(tf.float32, [1, 1]),
          tff.TensorType(tf.float32, [1]),
      ),
      non_trainable=(),
  )
  model_type = tff.FederatedType(weights_struct, tff.SERVER)

  element_spec = collections.OrderedDict(
      x=tff.TensorType(tf.float32, [None, 1]),
      y=tff.TensorType(tf.float32, [None, 1]))
  dataset_type = tff.FederatedType(
      tff.SequenceType(element_spec), tff.SERVER)

  metrics_type = tff.FederatedType(
      collections.OrderedDict(
          mean_squared_error=tff.TensorType(tf.float32),
          num_examples=tff.TensorType(tf.float32)), tff.SERVER)

  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          model_weights=model_type, centralized_dataset=dataset_type),
      result=metrics_type)
  # Assignability (rather than equality) is sufficient here.
  eval_fn.type_signature.check_assignable_from(expected_type)
def test_execute_with_preprocess_function(self):
  """Training on raw data via a composed preprocess fn reduces the loss."""
  raw_dataset = tf.data.Dataset.range(1)

  @tff.tf_computation(tff.SequenceType(raw_dataset.element_spec))
  def preprocess_dataset(ds):

    def to_example(x):
      del x  # Unused.
      return _Batch(
          x=np.ones([784], dtype=np.float32), y=np.ones([1], dtype=np.int64))

    return ds.map(to_example).batch(1)

  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  process = tff.simulation.compose_dataset_computation_with_iterative_process(
      preprocess_dataset, process)

  _, train_outputs, _ = self._run_rounds(process, [raw_dataset], 6)
  losses = [output['loss'] for output in train_outputs]
  # Overall the loss should decrease across training...
  self.assertLess(losses[-1], losses[0])
  # ...with diminishing improvement: the first half of training closes more
  # of the gap than the second half.
  gap_first_half = losses[0] - losses[2]
  gap_second_half = losses[3] - losses[5]
  self.assertLess(gap_second_half, gap_first_half)
def build_server_computation(gan: GanFnsAndTypes, server_state_type: tff.Type,
                             client_output_type: tff.Type):
  """Builds the `tff.tf_computation` that applies the server update.

  Thin wrapper around `gan_training_tf_fns.server_computation`.

  Args:
    gan: A `GanFnsAndTypes` object.
    server_state_type: The `tff.Type` of the ServerState.
    client_output_type: The `tff.Type` of the ClientOutput.

  Returns:
    A `tff.tf_computation.`
  """

  @tff.tf_computation(server_state_type, tff.SequenceType(gan.gen_input_type),
                      client_output_type, gan.dp_averaging_state_type)
  def server_computation(server_state, gen_inputs, client_output,
                         new_dp_averaging_state):
    """The wrapped server_computation."""
    # Fresh model and optimizer instances per invocation.
    generator = gan.generator_model_fn()
    discriminator = gan.discriminator_model_fn()
    disc_update_optimizer = gan.server_disc_update_optimizer_fn()
    return gan_training_tf_fns.server_computation(
        server_state=server_state,
        gen_inputs_ds=gen_inputs,
        client_output=client_output,
        generator=generator,
        discriminator=discriminator,
        server_disc_update_optimizer=disc_update_optimizer,
        train_generator_fn=gan.train_generator_fn,
        new_dp_averaging_state=new_dp_averaging_state)

  return server_computation
def get_federated_tokenize_fn(dataset_name, dataset_element_type_structure):
  """Get a federated tokenizer function."""
  sequence_type = tff.SequenceType(dataset_element_type_structure)

  @tff.tf_computation(sequence_type)
  def tokenize_dataset(dataset):
    """The TF computation to tokenize a single client dataset."""
    return tokenize(dataset, dataset_name)

  @tff.federated_computation(tff.type_at_clients(sequence_type))
  def tokenize_datasets(datasets):
    """The TFF computation applying tokenization on every client."""
    return tff.federated_map(tokenize_dataset, datasets)

  return tokenize_datasets
def test_iterative_process_type_signature(self):
  """Checks initialize/next type signatures of the adaptive FedAvg process."""
  # Identical client and server LR callbacks; decay_factor=1.0 keeps the
  # learning rate effectively constant.
  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1,
      min_delta=0.5,
      window_size=2,
      decay_factor=1.0,
      cooldown=0)
  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1,
      min_delta=0.5,
      window_size=2,
      decay_factor=1.0,
      cooldown=0)
  iterative_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_lr_callback,
      server_lr_callback,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  # The callbacks themselves become part of the server state type.
  lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)
  server_state_type = tff.FederatedType(
      adaptive_fed_avg.ServerState(
          model=tff.learning.ModelWeights(
              trainable=(tff.TensorType(tf.float32, [1, 1]),
                         tff.TensorType(tf.float32, [1])),
              non_trainable=()),
          optimizer_state=[tf.int64],
          client_lr_callback=lr_callback_type,
          server_lr_callback=lr_callback_type), tff.SERVER)
  self.assertEqual(
      iterative_process.initialize.type_signature,
      tff.FunctionType(parameter=None, result=server_state_type))

  dataset_type = tff.FederatedType(
      tff.SequenceType(
          collections.OrderedDict(
              x=tff.TensorType(tf.float32, [None, 1]),
              y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)
  metrics_type = tff.FederatedType(
      collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER)
  # `next` reports metrics computed both before and during training.
  output_type = collections.OrderedDict(
      before_training=metrics_type, during_training=metrics_type)
  expected_result_type = (server_state_type, output_type)
  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          server_state=server_state_type, federated_dataset=dataset_type),
      result=expected_result_type)
  actual_type = iterative_process.next.type_signature
  self.assertEqual(
      actual_type,
      expected_type,
      msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def test_executes_passthru_dataset(self):
  """A tf_computation over a sequence can return the dataset unchanged."""

  @tff.tf_computation(tff.SequenceType(tf.int64))
  def passthru_dataset(ds):
    # Identity over the sequence.
    return ds

  source = tf.data.Dataset.range(10)
  result = passthru_dataset(source)
  self.assertIsInstance(result, tf.data.Dataset)
def test_build_with_preprocess_function(self):
  """Composing a preprocess fn makes `next` accept raw client datasets."""
  # Raw (unpreprocessed) client data: the integers 0..4.
  test_dataset = tf.data.Dataset.range(5)
  client_datasets_type = tff.type_at_clients(
      tff.SequenceType(test_dataset.element_spec))

  @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
  def preprocess_dataset(ds):
    """Maps each integer to an OrderedDict example, batched by 2."""

    def create_batch(x):
      return collections.OrderedDict(
          x=[tf.cast(x, dtype=tf.float32)], y=[2.0])

    return ds.map(create_batch).batch(2)

  iterproc = fed_avg_schedule.build_fed_avg_process(
      model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      client_lr=0.01,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  # Compose so that `next` takes raw datasets and preprocesses them itself.
  iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
      preprocess_dataset, iterproc)

  with tf.Graph().as_default():
    test_model_for_types = model_builder()

    server_state_type = tff.FederatedType(
        fed_avg_schedule.ServerState(
            model=tff.framework.type_from_tensors(
                tff.learning.ModelWeights(
                    test_model_for_types.trainable_variables,
                    test_model_for_types.non_trainable_variables)),
            optimizer_state=(tf.int64,),
            round_num=tf.float32), tff.SERVER)
    # Metrics type comes straight from the model's own output computation.
    metrics_type = test_model_for_types.federated_output_computation.type_signature.result
    expected_parameter_type = collections.OrderedDict(
        server_state=server_state_type,
        federated_dataset=client_datasets_type,
    )
    expected_result_type = (server_state_type, metrics_type)
    expected_type = tff.FunctionType(
        parameter=expected_parameter_type, result=expected_result_type)
    # Check type equivalence (not strict equality) of the composed `next`.
    self.assertTrue(
        iterproc.next.type_signature.is_equivalent_to(expected_type),
        msg='{s}\n!={t}'.format(
            s=iterproc.next.type_signature, t=expected_type))
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int = MAX_CLIENT_DATASET_SIZE,
    emnist_task: str = 'digit_recognition',
    num_parallel_calls: int = tf.data.experimental.AUTOTUNE
) -> tff.Computation:
  """Creates a preprocessing function for EMNIST client datasets.

  The preprocessing shuffles, repeats, batches, and then reshapes, using
  the `shuffle`, `repeat`, `batch`, and `map` attributes of a
  `tf.data.Dataset`, in that order.

  Args:
    num_epochs: An integer representing the number of epochs to repeat the
      client datasets.
    batch_size: An integer representing the batch size on clients.
    shuffle_buffer_size: An integer representing the shuffle buffer size on
      clients. If set to a number <= 1, no shuffling occurs.
    emnist_task: A string indicating the EMNIST task being performed. Must be
      one of 'digit_recognition' or 'autoencoder'. If the former, then
      elements are mapped to tuples of the form (pixels, label), if the
      latter then elements are mapped to tuples of the form (pixels, pixels).
    num_parallel_calls: An integer representing the number of parallel calls
      used when performing `tf.data.Dataset.map`.

  Returns:
    A `tff.Computation` performing the preprocessing discussed above.
  """
  if num_epochs < 1:
    raise ValueError('num_epochs must be a positive integer.')
  if shuffle_buffer_size <= 1:
    shuffle_buffer_size = 1  # A buffer of size 1 makes `shuffle` a no-op.

  # Select the reshaping fn for the requested task.
  if emnist_task == 'digit_recognition':
    mapping_fn = _reshape_for_digit_recognition
  elif emnist_task == 'autoencoder':
    mapping_fn = _reshape_for_autoencoder
  else:
    raise ValueError('emnist_task must be one of "digit_recognition" or '
                     '"autoencoder".')

  # Features are intentionally sorted lexicographically by key for consistency
  # across datasets.
  feature_dtypes = collections.OrderedDict(
      label=tff.TensorType(tf.int32),
      pixels=tff.TensorType(tf.float32, shape=(28, 28)))

  @tff.tf_computation(tff.SequenceType(feature_dtypes))
  def preprocess_fn(dataset):
    # shuffle -> repeat -> batch -> reshape, per the docstring.
    return dataset.shuffle(shuffle_buffer_size).repeat(num_epochs).batch(
        batch_size, drop_remainder=False).map(
            mapping_fn, num_parallel_calls=num_parallel_calls)

  return preprocess_fn
def test_build_with_preprocess_function(self):
  """Passing `dataset_preprocess_comp` makes `next` accept raw datasets."""
  # Raw (unpreprocessed) client data: the integers 0..4.
  test_dataset = tf.data.Dataset.range(5)
  client_datasets_type = tff.FederatedType(
      tff.SequenceType(test_dataset.element_spec), tff.CLIENTS)

  @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
  def preprocess_dataset(ds):
    """Maps each int to a (784-dim features, int64 label) pair, batched by 2."""

    def to_batch(x):
      return _Batch(
          tf.fill(dims=(784,), value=float(x) * 2.0),
          tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0))

    return ds.map(to_batch).batch(2)

  iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      dataset_preprocess_comp=preprocess_dataset)

  with tf.Graph().as_default():
    test_model_for_types = _uncompiled_model_builder()

    # The adapter wraps the underlying iterative process.
    iterproc = iterproc_adapter._iterative_process
    server_state_type = tff.FederatedType(
        fed_avg_schedule.ServerState(
            model=tff.framework.type_from_tensors(
                fed_avg_schedule.ModelWeights(
                    test_model_for_types.trainable_variables,
                    test_model_for_types.non_trainable_variables)),
            optimizer_state=(tf.int64,),
            round_num=tf.float32), tff.SERVER)
    # Metrics type comes straight from the model's own output computation.
    metrics_type = test_model_for_types.federated_output_computation.type_signature.result
    expected_type = tff.FunctionType(
        parameter=(server_state_type, client_datasets_type),
        result=(server_state_type, metrics_type))
    self.assertEqual(
        iterproc.next.type_signature,
        expected_type,
        msg='{s}\n!={t}'.format(
            s=iterproc.next.type_signature, t=expected_type))
def build_server_computation(gan: GanFnsAndTypes, server_state_type: tff.Type,
                             client_output_type: tff.Type):
  """Returns a `tff.tf_computation` for the `server_computation`.

  This is a thin wrapper around `gan_training_tf_fns.server_computation`.

  Args:
    gan: A `GanFnsAndTypes` object.
    server_state_type: The `tff.Type` of the ServerState.
    client_output_type: The `tff.Type` of the ClientOutput.

  Returns:
    A `tff.tf_computation.`
  """

  @tff.tf_computation(server_state_type, tff.SequenceType(gan.gen_input_type),
                      client_output_type, gan.dp_averaging_state_type)
  def server_computation(server_state, gen_inputs, client_output,
                         new_dp_averaging_state):
    """The wrapped server_computation."""
    # initialize the optimizers beforehand so you don't create them within the
    # tf.function
    steps = server_state.counters['num_rounds']
    # Generator LR decays from 1e-3 to 5e-4 after 1000 rounds.
    scheduler = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        [1000], [0.001, 0.0005])
    state_gen_optimizer = gan.state_gen_optimizer_fn(scheduler.__call__(steps))
    generator = gan.generator_model_fn()
    gan_training_tf_fns.initialize_optimizer_vars(generator,
                                                  state_gen_optimizer)
    # Fix: the arguments to these two calls were swapped -- the model fn was
    # given the 0.0002 learning rate and the optimizer builder was given the
    # round counter. Match the client-side convention: the model fn takes no
    # argument and the optimizer builder takes the fixed 2e-4 learning rate.
    discriminator = gan.discriminator_model_fn()
    state_disc_optimizer = gan.state_disc_optimizer_fn(0.0002)
    gan_training_tf_fns.initialize_optimizer_vars(discriminator,
                                                  state_disc_optimizer)
    # Dispatch on the configured discriminator aggregation strategy.
    if gan.disc_status == 'fedadam':
      return gan_training_tf_fns.server_computation_fedadam(
          server_state=server_state,
          gen_inputs_ds=gen_inputs,
          client_output=client_output,
          generator=generator,
          discriminator=discriminator,
          state_disc_optimizer=state_disc_optimizer,
          state_gen_optimizer=state_gen_optimizer,
          new_dp_averaging_state=new_dp_averaging_state)
    else:
      return gan_training_tf_fns.server_computation(
          server_state=server_state,
          gen_inputs_ds=gen_inputs,
          client_output=client_output,
          generator=generator,
          discriminator=discriminator,
          state_disc_optimizer=state_disc_optimizer,
          state_gen_optimizer=state_gen_optimizer,
          new_dp_averaging_state=new_dp_averaging_state)

  return server_computation
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int = 50,
    sequence_length: int = SEQUENCE_LENGTH,
    num_parallel_calls: int = tf.data.experimental.AUTOTUNE
) -> tff.Computation:
  """Creates a preprocessing function for Shakespeare client datasets.

  This function maps a dataset of string snippets to a dataset of input/output
  character ID sequences. This is done by first repeating the dataset and
  shuffling (according to `num_epochs` and `shuffle_buffer_size`), mapping the
  string sequences to tokens, and packing them into input/output sequences of
  length `sequence_length`.

  Args:
    num_epochs: An integer representing the number of epochs to repeat the
      client datasets.
    batch_size: An integer representing the batch size on clients.
    shuffle_buffer_size: An integer representing the shuffle buffer size on
      clients. If set to a number <= 1, no shuffling occurs.
    sequence_length: the length of each example in the batch.
    num_parallel_calls: An integer representing the number of parallel calls
      used when performing `tf.data.Dataset.map`.

  Returns:
    A `tff.Computation` performing the preprocessing described above.

  Raises:
    ValueError: If `num_epochs` or `sequence_length` is less than 1.
  """
  if num_epochs < 1:
    raise ValueError('num_epochs must be a positive integer.')
  if sequence_length < 1:
    raise ValueError('sequence_length must be a positive integer.')
  if shuffle_buffer_size <= 1:
    # A buffer of size 1 makes `Dataset.shuffle` a no-op.
    shuffle_buffer_size = 1

  feature_dtypes = collections.OrderedDict(snippets=tf.string,)

  @tff.tf_computation(tff.SequenceType(feature_dtypes))
  def preprocess_fn(dataset):
    # Tokenize into chunks of sequence_length + 1 so each chunk yields an
    # (input, target) pair once `_split_target` shifts by one character.
    to_tokens = _build_tokenize_fn(split_length=sequence_length + 1)
    return (
        dataset.shuffle(shuffle_buffer_size).repeat(num_epochs)
        # Convert snippets to int64 tokens and pad.
        .map(to_tokens, num_parallel_calls=num_parallel_calls)
        # Separate into individual tokens
        .unbatch()
        # Join into sequences of the desired length. The previous call of
        # map(to_ids,...) ensures that the collection of tokens has length
        # divisible by sequence_length + 1, so no batch dropping is expected.
        .batch(sequence_length + 1, drop_remainder=True)
        # Batch sequences together for mini-batching purposes.
        .batch(batch_size)
        # Convert batches into training examples.
        .map(_split_target, num_parallel_calls=num_parallel_calls))

  return preprocess_fn
def _create_tff_parallel_clients_with_dataset_reduce():
  """Builds a federated computation that reduces each client's int64
  dataset by summation, seeded from a `tf.Variable` initialized to 1."""

  @tf.function
  def _add(accumulator, element):
    return accumulator + element

  @tf.function
  def _reduce_dataset(ds, initial_val):
    return ds.reduce(initial_val, _add)

  @tff.tf_computation(tff.SequenceType(tf.int64))
  def dataset_reduce_fn_wrapper(ds):
    # The accumulator starts as a variable holding int64 one.
    initial_val = tf.Variable(np.int64(1.0))
    return _reduce_dataset(ds, initial_val)

  @tff.federated_computation(tff.at_clients(tff.SequenceType(tf.int64)))
  def parallel_client_run(client_datasets):
    return tff.federated_map(dataset_reduce_fn_wrapper, client_datasets)

  return parallel_client_run
def build_federated_evaluation(
    model_fn: Callable[[], tff.learning.Model],
    metrics_builder: Callable[[], Sequence[tf.keras.metrics.Metric]]
) -> tff.federated_computation:
  """Builds a federated evaluation `tff.federated_computation`.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    metrics_builder: A no-arg function that returns a sequence of
      `tf.keras.metrics.Metric` objects. These metrics must have a callable
      `update_state` accepting `y_true` and `y_pred` arguments, corresponding
      to the true and predicted label, respectively.

  Returns:
    A `tff.federated_computation` that accepts model weights and federated
    data, and returns the evaluation metrics, aggregated in both uniform-
    and example-weighted manners.
  """
  # Construct a throwaway model inside its own graph so the variables it
  # creates do not pollute the global context; we only need type metadata.
  with tf.Graph().as_default():
    placeholder_model = model_fn()
    model_weights_type = tff.learning.framework.weights_type_from_model(
        placeholder_model)
    model_input_type = tff.SequenceType(placeholder_model.input_spec)

  @tff.tf_computation(model_weights_type, model_input_type)
  def compute_client_metrics(model_weights, federated_dataset):
    """Evaluates one client's dataset against the given weights."""
    local_model = model_fn()
    local_metrics = metrics_builder()
    return compute_metrics(local_model, model_weights, local_metrics,
                           federated_dataset)

  @tff.federated_computation(tff.type_at_server(model_weights_type),
                             tff.type_at_clients(model_input_type))
  def federated_evaluate(model_weights, federated_dataset):
    broadcast_weights = tff.federated_broadcast(model_weights)
    client_metrics = tff.federated_map(
        compute_client_metrics, (broadcast_weights, federated_dataset))
    # Each client's example count serves as its weight in the
    # example-weighted aggregation.
    num_examples = client_metrics.num_examples
    uniform_weighted_metrics = tff.federated_mean(client_metrics, weight=None)
    example_weighted_metrics = tff.federated_mean(
        client_metrics, weight=num_examples)
    # Package both aggregations into a single nested dictionary,
    # example-weighted first.
    return collections.OrderedDict([
        (AggregationMethods.EXAMPLE_WEIGHTED.value, example_weighted_metrics),
        (AggregationMethods.UNIFORM_WEIGHTED.value, uniform_weighted_metrics),
    ])

  return federated_evaluate
def build_federated_averaging_process_attacked(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    aggregation_process=None,
    client_update_tf=ClientExplicitBoosting(boost_factor=1.0)):
  """Builds the TFF computations for optimization using federated averaging with potentially malicious clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, use during local client training.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, use to apply updates to the global
      model.
    aggregation_process: A 'tff.templates.MeasuredProcess' that aggregates
      model deltas placed@CLIENTS to an aggregated model delta placed@SERVER.
      If `None`, a stateless mean over trainable-weight deltas is used.
    client_update_tf: a 'tf.function' computes the ClientOutput.
      NOTE(review): the default is a single `ClientExplicitBoosting` instance
      created at function-definition time and shared across calls — confirm
      it carries no mutable state.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # Build a model only for type metadata, inside a private graph so its
  # variables do not leak into the default graph.
  with tf.Graph().as_default():
    dummy_model_for_metadata = model_fn()
    weights_type = tff.learning.framework.weights_type_from_model(
        dummy_model_for_metadata)

  # Default aggregation: stateless mean over the trainable weight deltas.
  if aggregation_process is None:
    aggregation_process = tff.learning.framework.build_stateless_mean(
        model_delta_type=weights_type.trainable)

  server_init = build_server_init_fn(model_fn, server_optimizer_fn,
                                     aggregation_process.initialize)

  # `server_init` returns a federated value; `.member` strips the placement.
  server_state_type = server_init.type_signature.result.member
  server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                            server_state_type,
                                            server_state_type.model)
  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)

  client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                            client_update_tf, tf_dataset_type,
                                            server_state_type.model)

  federated_server_state_type = tff.type_at_server(server_state_type)
  federated_dataset_type = tff.type_at_clients(tf_dataset_type)

  # Wire the per-round computation from the server/client updates and the
  # (possibly attacked) aggregation.
  run_one_round_tff = build_run_one_round_fn_attacked(
      server_update_fn, client_update_fn, aggregation_process,
      dummy_model_for_metadata, federated_server_state_type,
      federated_dataset_type)

  return tff.templates.IterativeProcess(
      initialize_fn=server_init, next_fn=run_one_round_tff)
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int = NUM_EXAMPLES_PER_CLIENT,
    crop_shape: Tuple[int, int, int] = CIFAR_SHAPE,
    distort_image=False,
    num_parallel_calls: int = tf.data.experimental.AUTOTUNE) -> tff.Computation:
  """Creates a preprocessing function for CIFAR-100 client datasets.

  Args:
    num_epochs: An integer representing the number of epochs to repeat the
      client datasets.
    batch_size: An integer representing the batch size on clients.
    shuffle_buffer_size: An integer representing the shuffle buffer size on
      clients. If set to a number <= 1, no shuffling occurs.
    crop_shape: A tuple (crop_height, crop_width, num_channels) specifying
      the desired crop shape for pre-processing. This tuple cannot have
      elements exceeding (32, 32, 3), element-wise. The element in the last
      index should be set to 3 to maintain the RGB image structure of the
      elements.
    distort_image: A boolean indicating whether to perform preprocessing that
      includes image distortion, including random crops and flips.
    num_parallel_calls: An integer representing the number of parallel calls
      used when performing `tf.data.Dataset.map`.

  Returns:
    A `tff.Computation` performing the preprocessing described above.

  Raises:
    ValueError: If `num_epochs` is less than 1.
  """
  if num_epochs < 1:
    raise ValueError('num_epochs must be a positive integer.')
  if shuffle_buffer_size <= 1:
    # A buffer of size 1 turns `Dataset.shuffle` into a no-op.
    shuffle_buffer_size = 1

  # Features are intentionally sorted lexicographically by key for
  # consistency across datasets.
  feature_dtypes = collections.OrderedDict(
      coarse_label=tff.TensorType(tf.int64),
      image=tff.TensorType(tf.uint8, shape=(32, 32, 3)),
      label=tff.TensorType(tf.int64))

  image_map_fn = build_image_map(crop_shape, distort_image)

  @tff.tf_computation(tff.SequenceType(feature_dtypes))
  def preprocess_fn(dataset):
    shuffled = dataset.shuffle(shuffle_buffer_size)
    repeated = shuffled.repeat(num_epochs)
    # Map before batching so cropping happens per image (e.g. we do not
    # perform the same crop on every image within a batch).
    mapped = repeated.map(image_map_fn, num_parallel_calls=num_parallel_calls)
    return mapped.batch(batch_size)

  return preprocess_fn
def build_federated_averaging_process_attacked(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    stateful_delta_aggregate_fn=build_stateless_mean(),
    client_update_tf=ClientExplicitBoosting(boost_factor=1.0)):
  """Builds the TFF computations for optimization using federated averaging with potentially malicious clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, use during local client training.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, use to apply updates to the global
      model.
    stateful_delta_aggregate_fn: A 'tff.computation' that aggregates model
      deltas placed@CLIENTS to an aggregated model delta placed@SERVER.
      NOTE(review): the default `build_stateless_mean()` is evaluated once at
      function-definition time and shared across calls — confirm the returned
      object is stateless.
    client_update_tf: a 'tf.function' computes the ClientOutput.
      NOTE(review): the default `ClientExplicitBoosting` instance is likewise
      created once and shared across calls.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # Model built only to read type metadata (input_spec) below.
  dummy_model_for_metadata = model_fn()

  server_init_tf = build_server_init_fn(
      model_fn, server_optimizer_fn, stateful_delta_aggregate_fn.initialize())

  # `server_init_tf` is a plain tf_computation, so its result type carries
  # no placement (contrast with variants that use `.result.member`).
  server_state_type = server_init_tf.type_signature.result
  server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                            server_state_type,
                                            server_state_type.model)
  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)

  client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                            client_update_tf, tf_dataset_type,
                                            server_state_type.model)

  federated_server_state_type = tff.FederatedType(server_state_type,
                                                  tff.SERVER)
  federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

  # Wire the per-round computation from the server/client updates and the
  # delta aggregation.
  run_one_round_tff = build_run_one_round_fn_attacked(
      server_update_fn, client_update_fn, stateful_delta_aggregate_fn,
      dummy_model_for_metadata, federated_server_state_type,
      federated_dataset_type)

  return tff.templates.IterativeProcess(
      initialize_fn=tff.federated_computation(
          # Lift the unplaced server state to a SERVER-placed federated value.
          lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
      next_fn=run_one_round_tff)