def __attrs_post_init__(self):
  """Derives the TFF/TF type signatures this process needs.

  Instantiates the generator and discriminator once on dummy data so their
  variables are built, records the tensor and weight type specs, and
  constructs the federated placement types plus the aggregation process used
  for discriminator-weight deltas (differentially private if a DP average
  query was configured, plain weighted mean otherwise).
  """
  self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
  self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

  def _validate_keras_model(model):
    # Both model_fns must produce Keras models; fail fast otherwise.
    if not isinstance(model, tf.keras.models.Model):
      raise TypeError(
          'Expected `tf.keras.models.Model`, found {}.'.format(type(model)))

  # Model-weights based types. Each model is called once on dummy data so
  # that Keras actually creates its variables before we read `.weights`.
  self._generator = self.generator_model_fn()
  self._generator(self.dummy_gen_input)
  _validate_keras_model(self._generator)

  self._discriminator = self.discriminator_model_fn()
  self._discriminator(self.dummy_real_data)
  _validate_keras_model(self._discriminator)

  def _weights_type(var_struct):
    # TODO(b/131681951): read_value() shouldn't be needed
    return tf.nest.map_structure(
        lambda v: tf.TensorSpec.from_tensor(v.read_value()), var_struct)

  self.discriminator_weights_type = _weights_type(self._discriminator.weights)
  self.generator_weights_type = _weights_type(self._generator.weights)

  self.from_server_type = gan_training_tf_fns.FromServer(
      generator_weights=self.generator_weights_type,
      discriminator_weights=self.discriminator_weights_type,
      meta_gen=self.generator_weights_type,
      meta_disc=self.discriminator_weights_type)

  # Federated placement types for the client / server datasets.
  self.client_gen_input_type = tff.type_at_clients(
      tff.SequenceType(self.gen_input_type))
  self.client_real_data_type = tff.type_at_clients(
      tff.SequenceType(self.real_data_type))
  self.server_gen_input_type = tff.type_at_server(
      tff.SequenceType(self.gen_input_type))

  disc_weights_type = tff.to_type(self.discriminator_weights_type)
  if self.train_discriminator_dp_average_query is None:
    self.aggregation_process = tff.aggregators.MeanFactory().create(
        value_type=disc_weights_type,
        weight_type=tff.to_type(tf.float32))
  else:
    self.aggregation_process = tff.aggregators.DifferentiallyPrivateFactory(
        query=self.train_discriminator_dp_average_query).create(
            value_type=disc_weights_type)
def build_jax_federated_averaging_process(batch_type, model_type, loss_fn,
                                          step_size):
  """Constructs an iterative process that implements simple federated averaging.

  Args:
    batch_type: An instance of `tff.Type` that represents the type of a single
      batch of data to use for training. This type should be constructed with
      standard Python containers (such as `collections.OrderedDict`) of the
      sort that are expected as parameters to `loss_fn`.
    model_type: An instance of `tff.Type` that represents the type of the
      model. Similarly to `batch_type`, this type should be constructed with
      standard Python containers (such as `collections.OrderedDict`) of the
      sort that are expected as parameters to `loss_fn`.
    loss_fn: A loss function for the model. Must be a Python function that
      takes two parameters, one of them being the model, and the other being a
      single batch of data (with types matching `batch_type` and
      `model_type`).
    step_size: The step size to use during training (an `np.float32`).

  Returns:
    An instance of `tff.templates.IterativeProcess` that implements federated
    training in JAX.
  """
  batch_type = tff.to_type(batch_type)
  model_type = tff.to_type(model_type)

  def _tensor_zeros(tensor_type):
    # A zero-filled array matching a single tensor's shape and dtype.
    return jax.numpy.zeros(
        tensor_type.shape.dims, dtype=tensor_type.dtype.as_numpy_dtype)

  @tff.jax_computation
  def _create_zero_model():
    # Initial model: every tensor in `model_type` filled with zeros.
    model_zeros = tff.structure.map_structure(_tensor_zeros, model_type)
    return tff.types.type_to_py_container(model_zeros, model_type)

  @tff.federated_computation
  def _create_zero_model_on_server():
    return tff.federated_eval(_create_zero_model, tff.SERVER)

  def _apply_update(model_param, param_delta):
    # One gradient-descent step on a single flattened parameter.
    return model_param - step_size * param_delta

  @tff.jax_computation(model_type, batch_type)
  def _train_on_one_batch(model, batch):
    # Flatten model and gradients into parallel lists, update element-wise,
    # then repack into the original model structure.
    params = tff.structure.flatten(
        tff.structure.from_container(model, recursive=True))
    grads = tff.structure.flatten(
        tff.structure.from_container(jax.grad(loss_fn)(model, batch)))
    updated_params = [_apply_update(x, y) for (x, y) in zip(params, grads)]
    trained_model = tff.structure.pack_sequence_as(model_type, updated_params)
    return tff.types.type_to_py_container(trained_model, model_type)

  local_dataset_type = tff.SequenceType(batch_type)

  @tff.federated_computation(model_type, local_dataset_type)
  def _train_on_one_client(model, batches):
    # Sequentially trains on each local batch, starting from `model`.
    return tff.sequence_reduce(batches, model, _train_on_one_batch)

  @tff.federated_computation(
      tff.FederatedType(model_type, tff.SERVER),
      tff.FederatedType(local_dataset_type, tff.CLIENTS))
  def _train_one_round(model, federated_data):
    # Broadcast the server model, train locally on each client, then average.
    locally_trained_models = tff.federated_map(
        _train_on_one_client,
        collections.OrderedDict([('model', tff.federated_broadcast(model)),
                                 ('batches', federated_data)]))
    return tff.federated_mean(locally_trained_models)

  return tff.templates.IterativeProcess(
      initialize_fn=_create_zero_model_on_server, next_fn=_train_one_round)