def _initialize_optimizer_vars(model: tff.learning.Model,
                               optimizer: tf.keras.optimizers.Optimizer):
    """Ensures variables holding the state of `optimizer` are created.

    An all-zero gradient, shaped like the model's trainable weights, is applied
    through `optimizer.apply_gradients` so that the optimizer builds its
    variables.

    Args:
        model: The model whose trainable weights, as returned by
            `_get_weights(model).trainable`, are paired with the zero gradients.
        optimizer: The optimizer whose variables should be created.

    Raises:
        AssertionError: If `optimizer.variables()` is still empty after the
            placeholder update.
    """
    # Look the weights up once and reuse them for both the zero gradients and
    # the (gradient, variable) pairing.
    model_weights = _get_weights(model)
    delta = tf.nest.map_structure(tf.zeros_like, model_weights.trainable)
    grads_and_vars = tf.nest.map_structure(lambda x, v: (x, v), delta,
                                           model_weights.trainable)
    optimizer.apply_gradients(grads_and_vars, name='server_update')
    assert optimizer.variables()
def update(model: tff.learning.Model,
           optimizer: tf.keras.optimizers.Optimizer,
           state: State,
           weights_delta: list) -> State:
    """Applies `weights_delta` to the model and returns the next `State`.

    Args:
        model: Model whose variables are overwritten from `state.model` and then
            updated in place.
        optimizer: Optimizer whose variables are overwritten from
            `state.optimizer_state` before the update is applied.
        state: The incoming state (model weights, optimizer state, round number).
        weights_delta: One delta per entry of `model.trainable_variables`.

    Returns:
        `state` with `model`, `optimizer_state` and `round_num` replaced; the
        round number is incremented by one.
    """

    def _assign(variable, value):
        return variable.assign(value)

    # Restore the model and optimizer variables from the incoming state.
    state.model.assign_weights_to(model)
    tf.nest.map_structure(_assign, optimizer.variables(), state.optimizer_state)

    # The negated delta is handed to the optimizer as if it were a gradient.
    negated_delta = [-1.0 * delta for delta in weights_delta]
    grads_and_vars = zip(negated_delta, model.trainable_variables)
    optimizer.apply_gradients(grads_and_vars, name='server_update')

    next_round = state.round_num + 1
    return tff.structure.update_struct(
        state,
        model=tff.learning.ModelWeights.from_model(model),
        optimizer_state=optimizer.variables(),
        round_num=next_round)
def __initialize_optimizer(model: utils.PersonalizationLayersDecorator,
                           optimizer: tf.keras.optimizers.Optimizer):
    """Creates the variables of `optimizer` by applying zero gradients.

    Args:
        model: Decorated model; only `model.base_model.trainable_variables` is
            used.
        optimizer: The optimizer to initialize.

    Raises:
        AssertionError: If the optimizer has no variables afterwards.
    """
    base_variables = model.base_model.trainable_variables
    zeros = tf.nest.map_structure(tf.zeros_like, base_variables)
    grads_and_vars = zip(zeros, base_variables)
    optimizer.apply_gradients(grads_and_vars)
    assert optimizer.variables()
def _initialize_optimizer_vars(model: tff.learning.Model,
                               optimizer: tf.keras.optimizers.Optimizer):
    """Ensures variables holding the state of `optimizer` are created.

    An all-zero gradient, shaped like the model's trainable weights, is applied
    through `optimizer.apply_gradients` so that the optimizer builds its
    variables.

    Args:
        model: The model whose trainable weights, as returned by
            `_get_weights(model).trainable`, are paired with the zero gradients.
        optimizer: The optimizer whose variables should be created.

    Raises:
        AssertionError: If `optimizer.variables()` is still empty after the
            placeholder update.
    """
    # A single lookup serves both the zero gradients and the pairing below.
    model_weights = _get_weights(model)
    delta = tf.nest.map_structure(tf.zeros_like, model_weights.trainable)
    grads_and_vars = tf.nest.map_structure(lambda x, v: (x, v), delta,
                                           model_weights.trainable)
    optimizer.apply_gradients(grads_and_vars)
    assert optimizer.variables()
def initialize_optimizer_vars(model: tf.keras.Model,
                              optimizer: tf.keras.optimizers.Optimizer):
    """Makes sure the state variables of `optimizer` exist.

    Args:
        model: Keras model providing the trainable variables.
        optimizer: Optimizer that receives one all-zero gradient update.

    Raises:
        AssertionError: If the optimizer has no variables afterwards.
    """

    def _pair(gradient, variable):
        return (gradient, variable)

    zero_grads = tf.nest.map_structure(tf.zeros_like,
                                       model.trainable_variables)
    grads_and_vars = tf.nest.map_structure(_pair, zero_grads,
                                           model.trainable_weights)
    optimizer.apply_gradients(grads_and_vars, name='server_update')
    assert optimizer.variables()
def create_optimizer_vars(
        model: model_lib.Model,
        optimizer: tf.keras.optimizers.Optimizer) -> Iterable[tf.Variable]:
    """Runs a placeholder update so the optimizer's variables can be read.

    Args:
        model: Model whose trainable global variables (via
            `get_global_variables(model).trainable`) receive the update.
        optimizer: Optimizer to apply the all-zero update with.

    Returns:
        The result of `optimizer.variables()` after the placeholder update.
    """

    def _negated_pair(gradient, variable):
        return (-1.0 * gradient, variable)

    zeros = tf.nest.map_structure(tf.zeros_like,
                                  get_global_variables(model).trainable)
    flat_zeros = tf.nest.flatten(zeros)
    flat_variables = tf.nest.flatten(get_global_variables(model).trainable)
    grads_and_vars = tf.nest.map_structure(_negated_pair, flat_zeros,
                                           flat_variables)
    optimizer.apply_gradients(grads_and_vars, name='server_update')
    return optimizer.variables()
def _initialize_optimizer_vars(model: tff.learning.Model,
                               optimizer: tf.keras.optimizers.Optimizer):
    """Creates optimizer variables to assign the optimizer's state.

    Args:
        model: Model whose trainable weights are read through
            `tff.learning.ModelWeights.from_model`.
        optimizer: Optimizer whose variables should be created.

    Raises:
        AssertionError: If the optimizer has no variables afterwards.
    """
    # Keras normally creates optimizer variables lazily on first use.
    # Optimizers such as Adam or Adagrad, or any using momentum, need a set of
    # variables shaped like the model weights, so we force their creation here
    # by applying zero gradients.
    trainable = tff.learning.ModelWeights.from_model(model).trainable
    zeros = list(map(tf.zeros_like, trainable))
    optimizer.apply_gradients(zip(zeros, trainable))
    assert optimizer.variables()
def create_train_discriminator_fn(
        gan_loss_fns: gan_losses.AbstractGanLossFns,
        disc_optimizer: tf.keras.optimizers.Optimizer):
    """Builds a discriminator training step bound to a loss and an optimizer.

    Args:
        gan_loss_fns: Instance of the gan_losses.AbstractGanLossFns interface,
            specifying the generator/discriminator training losses.
        disc_optimizer: Optimizer for training the discriminator. It must not
            have created any variables yet.

    Returns:
        Function that executes one step of discriminator training.

    Raises:
        ValueError: If `disc_optimizer` already has variables.
    """
    # Requiring an unused optimizer ensures that, once bound, the train fn is
    # not holding onto a different copy of the optimizer variables than the
    # copy that is being exchanged between server and clients.
    optimizer_already_used = bool(disc_optimizer.variables())
    if optimizer_already_used:
        raise ValueError(
            'Expected disc_optimizer to not have been used previously, but '
            'variables were already initialized.')

    @tf.function
    def train_discriminator_fn(generator: tf.keras.Model,
                               discriminator: tf.keras.Model,
                               generator_inputs, real_data):
        """Runs one discriminator training step on a single batch.

        Args:
            generator: The generator.
            discriminator: The discriminator.
            generator_inputs: A batch of inputs (usually noise) for the
                generator.
            real_data: A batch of real data for the discriminator.

        Returns:
            The size of the batch (leading dimension of `real_data`).
        """

        def _forward_and_loss():
            # N.B. The complete forward pass must happen inside this callable
            # so that gradients can be traced through it.
            return gan_loss_fns.discriminator_loss(
                generator, discriminator, generator_inputs, real_data)

        disc_optimizer.minimize(
            _forward_and_loss, var_list=discriminator.trainable_variables)
        batch_size = tf.shape(real_data)[0]
        return batch_size

    return train_discriminator_fn
def _build_server_optimizer(
        model: model_lib.Model, optimizer: tf.keras.optimizers.Optimizer
) -> Tuple[Callable[..., tf.Tensor], List[tf.Variable]]:
    """Constructs the server-side update function and optimizer variables.

    This code is needed both in server_init (to introduce variables so we can
    read their initial values) and in server_update_model.

    Args:
        model: A `tff.learning.Model`.
        optimizer: A `tf.keras.optimizers.Optimizer`.

    Returns:
        A tuple of (apply_delta_fn, optimizer_vars), where:
        * apply_delta_fn is a TensorFlow function that takes a model delta and
          updates the trainable weights of `model` as well as possibly
          optimizer_state variables introduced by the optimizer.
        * optimizer_vars is a list of optimizer variables.
    """

    @tf.function
    def apply_delta(delta):
        """Applies `delta` to `model.weights`."""
        trainable = model.weights.trainable
        tf.nest.assert_same_structure(delta, trainable)

        def _negated_pair(update, variable):
            return (-1.0 * update, variable)

        grads_and_vars = tf.nest.map_structure(
            _negated_pair, tf.nest.flatten(delta), tf.nest.flatten(trainable))
        # N.B. This may create variables.
        optimizer.apply_gradients(grads_and_vars, name='server_update')
        # A tf.function has to return something.
        return tf.constant(1)

    # Trace apply_delta on an all-zero dummy delta so that the optimizer's
    # variables get created and can be read below.
    # TODO(b/109733734): We would like to call get_concrete_function,
    # but that does not currently work with structured inputs.
    # For now, we just call the function on dummy input, which
    # still ensures the function is traced (so variables are created).
    zero_delta = tf.nest.map_structure(tf.zeros_like, model.weights.trainable)
    apply_delta(delta=zero_delta)

    # N.B. Using to_var_dict doesn't work here, because we
    # may get non-unique names, so we just use a flat list.
    return (apply_delta, optimizer.variables())
def _eagerly_create_optimizer_variables(
        *, model: model_lib.Model,
        optimizer: tf.keras.optimizers.Optimizer) -> List[tf.Variable]:
    """Forces eager construction of the optimizer variables.

    This code is needed both in `server_init` and `server_update` (to introduce
    variables so we can read their initial values for the initial state).

    Args:
        model: A `tff.learning.Model`.
        optimizer: A `tf.keras.optimizers.Optimizer`.

    Returns:
        A list of optimizer variables.
    """

    def _spec_of(variable):
        return tf.TensorSpec.from_tensor(variable.read_value())

    trainable = model_utils.ModelWeights.from_model(model).trainable
    delta_spec = tf.nest.map_structure(_spec_of, trainable)

    # Tracing `_apply_delta` against the specs is what creates the variables.
    traced_apply_delta = tf.function(_apply_delta)
    traced_apply_delta.get_concrete_function(
        optimizer=optimizer, model=model, delta=delta_spec)
    return optimizer.variables()
def __initialize_optimizer(model: tff.learning.Model,
                           optimizer: tf.keras.optimizers.Optimizer):
    """Creates the variables of `optimizer` by applying zero gradients.

    Args:
        model: Model providing `trainable_variables`.
        optimizer: The optimizer to initialize.

    Raises:
        AssertionError: If the optimizer has no variables afterwards.
    """
    variables = model.trainable_variables
    zeros = tf.nest.map_structure(tf.zeros_like, variables)
    grads_and_vars = zip(zeros, variables)
    optimizer.apply_gradients(grads_and_vars)
    assert optimizer.variables()
def server_computation(
        # Tensor/Dataset arguments that will be supplied by TFF:
        server_state: ServerState,
        gen_inputs_ds: tf.data.Dataset,
        client_output: ClientOutput,
        # Python arguments to be bound at TFF computation construction time:
        generator: tf.keras.Model,
        discriminator: tf.keras.Model,
        state_gen_optimizer: tf.keras.optimizers.Optimizer,
        state_disc_optimizer: tf.keras.optimizers.Optimizer,
        # Not an argument bound at TFF computation construction time, but placed
        # last so that it can be defaulted to empty tuple (for non-DP use cases).
        new_dp_averaging_state=()
) -> ServerState:
    """The computation to run on the server, training the generator.

    The discriminator delta from the clients is applied with a plain
    `SGD(learning_rate=1)` optimizer, after which the generator is trained on
    `gen_inputs_ds` with `state_gen_optimizer` and a Wasserstein GAN loss.

    Args:
        server_state: The initial `ServerState` for the round.
        gen_inputs_ds: A `tf.data.Dataset` of inputs to the `generator`; every
            batch in it is used for one generator training step.
        client_output: The (possibly aggregated) `ClientOutput`.
        generator: The generator.
        discriminator: The discriminator.
        state_gen_optimizer: Optimizer used to train the generator. Its
            variables are loaded from `server_state.state_gen_optimizer_weights`
            via `conditioned_assign` and written back (cast to float32) at the
            end.
        state_disc_optimizer: Not referenced by this function; the discriminator
            optimizer state is updated by adding
            `client_output.state_disc_opt_delta` to the stored weights instead.
        new_dp_averaging_state: The updated state of the DP averaging
            aggregator.

    Returns:
        An updated `ServerState` object.
    """
    # A tf.function can't modify the structure of its input arguments,
    # so we make a semi-shallow copy:
    server_state = attr.evolve(server_state,
                               counters=dict(server_state.counters))

    # Load the incoming state into the live optimizer and model variables.
    tf.nest.map_structure(conditioned_assign,
                          state_gen_optimizer.variables(),
                          server_state.state_gen_optimizer_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(generator),
                          server_state.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(discriminator),
                          server_state.discriminator_weights)

    # The negated client delta is applied as a gradient by SGD with learning
    # rate 1.
    server_disc_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1)
    delta = client_output.discriminator_weights_delta
    tf.nest.assert_same_structure(delta, discriminator.trainable_weights)
    grads_and_vars = tf.nest.map_structure(lambda x, v: (-1.0 * x, v), delta,
                                           discriminator.trainable_weights)
    server_disc_update_optimizer.apply_gradients(grads_and_vars,
                                                 name='server_update_disc')

    for k, v in client_output.counters.items():
        server_state.counters[k] += v

    # Update the state of the DP averaging aggregator.
    server_state.dp_averaging_state = new_dp_averaging_state

    gen_examples_this_round = tf.constant(0)
    loss_fns = gan_losses.WassersteinGanLossFns()
    for gen_inputs in gen_inputs_ds:  # Compiled by autograph.
        with tf.GradientTape() as tape2:
            loss2 = loss_fns.generator_loss(generator, discriminator,
                                            gen_inputs)
        grads2 = tape2.gradient(loss2, generator.trainable_variables)
        grads_and_vars2 = zip(grads2, generator.trainable_variables)
        state_gen_optimizer.apply_gradients(grads_and_vars2)
        gen_examples_this_round += tf.shape(gen_inputs)[0]

    # update discriminator optimizer
    delta_opt_D = client_output.state_disc_opt_delta
    updated_opt_D = tf.nest.map_structure(
        lambda a, b: a + b, server_state.state_disc_optimizer_weights,
        delta_opt_D)

    # Differences between the freshly updated weights and the weights this
    # round started from; `server_state` still holds the old weights here.
    G_change = tf.nest.map_structure(
        tf.subtract, generator.trainable_weights,
        server_state.generator_weights.trainable)
    D_change = tf.nest.map_structure(
        tf.subtract, discriminator.trainable_weights,
        server_state.discriminator_weights.trainable)

    server_state.state_gen_optimizer_weights = tf.nest.map_structure(
        lambda x: tf.cast(x, tf.float32), state_gen_optimizer.variables())
    server_state.state_disc_optimizer_weights = updated_opt_D
    server_state.counters[
        'num_generator_train_examples'] += gen_examples_this_round
    server_state.generator_weights = _weights(generator)
    server_state.discriminator_weights = _weights(discriminator)
    server_state.counters['num_rounds'] += 1
    server_state.generator_diff = G_change
    server_state.discriminator_diff = D_change
    return server_state
def client_computation_fedadam(
        # Tensor/Dataset arguments that will be supplied by TFF:
        gen_inputs_ds: tf.data.Dataset,
        real_data_ds: tf.data.Dataset,
        from_server: FromServer,
        # Python arguments to be bound at TFF computation construction time:
        generator: tf.keras.Model,
        discriminator: tf.keras.Model,
        state_gen_optimizer: tf.keras.optimizers.Optimizer,
        state_disc_optimizer: tf.keras.optimizers.Optimizer) -> ClientOutput:
    """The computation to run on the client, training the discriminator.

    The discriminator is trained with a local `SGD(learning_rate=1)` optimizer
    and a Wasserstein GAN loss on zipped batches of generator inputs and real
    data. The two state optimizers only have their variables loaded and read
    back; `apply_gradients` is never called on them here.

    Args:
        gen_inputs_ds: A `tf.data.Dataset` of generator_inputs.
        real_data_ds: A `tf.data.Dataset` of data from the real distribution.
        from_server: A `FromServer` object, including the current model weights
            and the optimizer state weights.
        generator: The generator.
        discriminator: The discriminator.
        state_gen_optimizer: Optimizer whose variables are loaded from
            `from_server.state_gen_optimizer_weights` and then differenced
            against them.
        state_disc_optimizer: Optimizer whose variables are loaded from
            `from_server.state_disc_optimizer_weights` and then differenced
            against them.

    Returns:
        A `ClientOutput` object.
    """
    # Load the server's model weights and optimizer state into local variables.
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(generator),
                          from_server.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(discriminator),
                          from_server.discriminator_weights)
    tf.nest.map_structure(conditioned_assign,
                          state_disc_optimizer.variables(),
                          from_server.state_disc_optimizer_weights)
    tf.nest.map_structure(conditioned_assign,
                          state_gen_optimizer.variables(),
                          from_server.state_gen_optimizer_weights)

    num_examples = tf.constant(0)
    loss_fns = gan_losses.WassersteinGanLossFns()
    gen_inputs_and_real_data = tf.data.Dataset.zip(
        (gen_inputs_ds, real_data_ds))
    sgd_optimizer = tf.keras.optimizers.SGD(learning_rate=1)
    for gen_inputs, real_data in gen_inputs_and_real_data:
        # It's possible that real_data and gen_inputs have different batch
        # sizes. For calculating the discriminator loss, it's desirable to have
        # equal-sized contributions from both the real and fake data. Also, it's
        # necessary if using the Wasserstein gradient penalty (where a
        # difference is taken between the real and fake data). So here we reduce
        # to the min batch size. This also ensures num_examples properly
        # reflects the amount of data trained on.
        min_batch_size = tf.minimum(
            tf.shape(real_data)[0], tf.shape(gen_inputs)[0])
        real_data = real_data[0:min_batch_size]
        gen_inputs = gen_inputs[0:min_batch_size]
        with tf.GradientTape() as tape:
            loss = loss_fns.discriminator_loss(generator, discriminator,
                                               gen_inputs, real_data)
        grads = tape.gradient(loss, discriminator.trainable_variables)
        grads_and_vars = zip(grads, discriminator.trainable_variables)
        sgd_optimizer.apply_gradients(grads_and_vars)
        num_examples += min_batch_size

    # Optimizer-state deltas: current optimizer variables (cast to float32)
    # minus the values received from the server.
    state_disc_opt_delta = tf.nest.map_structure(
        tf.subtract,
        tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                              state_disc_optimizer.variables()),
        from_server.state_disc_optimizer_weights)
    # should be zero
    # NOTE(review): neither state optimizer is stepped in this function, so both
    # deltas are presumably zero -- this depends on what `conditioned_assign`
    # does; confirm.
    state_gen_opt_delta = tf.nest.map_structure(
        tf.subtract,
        tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                              state_gen_optimizer.variables()),
        from_server.state_gen_optimizer_weights)

    # Change in the discriminator's trainable weights relative to the server's.
    weights_delta = tf.nest.map_structure(
        tf.subtract, discriminator.trainable_weights,
        from_server.discriminator_weights.trainable)
    weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    update_weight = tf.cast(num_examples, tf.float32)
    # Zero out the weight if there are any non-finite values.
    # TODO(b/122071074): federated_mean might not do the right thing if
    # all clients have zero weight.
    update_weight = tf.cond(tf.equal(has_non_finite_delta, 0),
                            lambda: update_weight, lambda: tf.constant(0.0))
    # NOTE(review): `generator_weights_delta` is filled with the discriminator
    # delta; no generator delta is computed in this function. Confirm this is
    # an intentional placeholder.
    return ClientOutput(
        discriminator_weights_delta=weights_delta,
        generator_weights_delta=weights_delta,
        state_disc_opt_delta=state_disc_opt_delta,
        state_gen_opt_delta=state_gen_opt_delta,
        update_weight_D=update_weight,
        update_weight_G=update_weight,
        update_weight=update_weight,
        counters={'num_discriminator_train_examples': num_examples})