Example #1
def _initialize_optimizer_vars(model: tff.learning.Model,
                               optimizer: tf.keras.optimizers.Optimizer):
    """Ensures variables holding the state of `optimizer` are created."""
    delta = tf.nest.map_structure(tf.zeros_like, _get_weights(model).trainable)
    model_weights = _get_weights(model)
    grads_and_vars = tf.nest.map_structure(lambda x, v: (x, v), delta,
                                           model_weights.trainable)
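    # Debug prints: confirm that the deltas, the model's variables, and the
    # optimizer's variables all live in the same graph before and after the
    # placeholder apply_gradients call.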
    print([delta_.graph for delta_ in delta])
    print([var.graph for var in model.trainable_variables])
    print([var.graph for var in optimizer.variables()])
    optimizer.apply_gradients(grads_and_vars, name='server_update')
    print([delta_.graph for delta_ in delta])
    print([var.graph for var in model.trainable_variables])
    print([var.graph for var in optimizer.variables()])
    assert optimizer.variables()
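All of these helpers exploit the same behavior: a Keras optimizer creates its slot variables lazily, on the first `apply_gradients` call, so applying an all-zero gradient materializes them without changing the weights. Below is a minimal standalone sketch of that behavior in plain Keras, with no TFF; the model is purely illustrative, and it assumes the TF 2.x `optimizer_v2` API that exposes `optimizer.variables()` as a method, as used throughout these examples.

import tensorflow as tf

# Hedged sketch: force eager creation of Adam's slot variables by applying an
# all-zero gradient. The model and shapes are illustrative only.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
optimizer = tf.keras.optimizers.Adam()

# The slots do not exist yet; Keras creates them on first use.
assert not optimizer.variables()

zero_grads = [tf.zeros_like(v) for v in model.trainable_variables]
optimizer.apply_gradients(zip(zero_grads, model.trainable_variables))

# Adam now holds an iteration counter plus `m` and `v` slots per weight, and
# the zero gradient left the model weights unchanged.
assert optimizer.variables()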
Example #2
def update(model: tff.learning.Model, optimizer: tf.keras.optimizers.Optimizer,
           state: State, weights_delta: list) -> State:
    state.model.assign_weights_to(model)
    tf.nest.map_structure(lambda v, t: v.assign(t), optimizer.variables(),
                          state.optimizer_state)

    neg_weights_delta = [-1.0 * x for x in weights_delta]
    optimizer.apply_gradients(zip(neg_weights_delta,
                                  model.trainable_variables),
                              name='server_update')

    return tff.structure.update_struct(
        state,
        model=tff.learning.ModelWeights.from_model(model),
        optimizer_state=optimizer.variables(),
        round_num=state.round_num + 1)
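The server update in Example #2 relies on `apply_gradients` subtracting `learning_rate * gradient`: with plain SGD at learning rate 1.0, passing the negated delta simply adds the delta to the weights. A minimal sketch of that equivalence on a single variable (the values are illustrative):

import tensorflow as tf

# Sketch: SGD(learning_rate=1.0) applied to a negated delta adds the delta.
w = tf.Variable([1.0, 2.0])
delta = tf.constant([0.5, -0.25])

sgd = tf.keras.optimizers.SGD(learning_rate=1.0)
sgd.apply_gradients([(-1.0 * delta, w)])

# w == old value + delta == [1.5, 1.75]
tf.debugging.assert_near(w, [1.5, 1.75])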
Example #3
def __initialize_optimizer(model: utils.PersonalizationLayersDecorator,
                           optimizer: tf.keras.optimizers.Optimizer):
    zero_gradient = tf.nest.map_structure(tf.zeros_like,
                                          model.base_model.trainable_variables)
    optimizer.apply_gradients(
        zip(zero_gradient, model.base_model.trainable_variables))
    assert optimizer.variables()
Example #4
def _initialize_optimizer_vars(model: tff.learning.Model,
                               optimizer: tf.keras.optimizers.Optimizer):
    """Ensures variables holding the state of `optimizer` are created."""
    delta = tf.nest.map_structure(tf.zeros_like, _get_weights(model).trainable)
    model_weights = _get_weights(model)
    grads_and_vars = tf.nest.map_structure(lambda x, v: (x, v), delta,
                                           model_weights.trainable)
    optimizer.apply_gradients(grads_and_vars)
    assert optimizer.variables()
Example #5
def initialize_optimizer_vars(model: tf.keras.Model,
                              optimizer: tf.keras.optimizers.Optimizer):
    """Ensures variables holding the state of `optimizer` are created."""
    delta = tf.nest.map_structure(tf.zeros_like, model.trainable_variables)
    grads_and_vars = tf.nest.map_structure(lambda x, v: (x, v), delta,
                                           model.trainable_weights)
    optimizer.apply_gradients(grads_and_vars, name='server_update')

    assert optimizer.variables()
Example #6
def create_optimizer_vars(
        model: model_lib.Model,
        optimizer: tf.keras.optimizers.Optimizer) -> Iterable[tf.Variable]:
    """Applies a placeholder update to optimizer to enable getting its variables."""
    delta = tf.nest.map_structure(tf.zeros_like,
                                  get_global_variables(model).trainable)
    grads_and_vars = tf.nest.map_structure(
        lambda x, v: (-1.0 * x, v), tf.nest.flatten(delta),
        tf.nest.flatten(get_global_variables(model).trainable))
    optimizer.apply_gradients(grads_and_vars, name='server_update')
    return optimizer.variables()
Example #7
def _initialize_optimizer_vars(model: tff.learning.Model,
                               optimizer: tf.keras.optimizers.Optimizer):
    """Creates optimizer variables to assign the optimizer's state."""
    # Create zero gradients to force an update that doesn't modify the weights.
    # This forces eager construction of the optimizer variables; normally Keras
    # lazily creates them on the first use of the optimizer. Optimizers such as
    # Adam, Adagrad, or SGD with momentum create an additional set of variables
    # shaped like the model weights.
    model_weights = tff.learning.ModelWeights.from_model(model)
    zero_gradient = [tf.zeros_like(t) for t in model_weights.trainable]
    optimizer.apply_gradients(zip(zero_gradient, model_weights.trainable))
    assert optimizer.variables()
Example #8
def create_train_discriminator_fn(
    gan_loss_fns: gan_losses.AbstractGanLossFns,
    disc_optimizer: tf.keras.optimizers.Optimizer):
  """Create a function that trains discriminator, binding loss and optimizer.

  Args:
    gan_loss_fns: Instance of gan_losses.AbstractGanLossFns interface,
      specifying the generator/discriminator training losses.
    disc_optimizer: Optimizer for training the discriminator.

  Returns:
    Function that executes one step of discriminator training.
  """
  # We assert that the optimizer has not been used previously, which ensures
  # that when it is bound, the train fn isn't holding onto a different copy of
  # the optimizer variables than the copy being exchanged between server and
  # clients.
  if disc_optimizer.variables():
    raise ValueError(
        'Expected disc_optimizer to not have been used previously, but '
        'variables were already initialized.')

  @tf.function
  def train_discriminator_fn(generator: tf.keras.Model,
                             discriminator: tf.keras.Model, generator_inputs,
                             real_data):
    """Trains the discriminator on a single batch.

    Args:
      generator:  The generator.
      discriminator: The discriminator.
      generator_inputs: A batch of inputs (usually noise) for the generator.
      real_data: A batch of real data for the discriminator.

    Returns:
      The size of the batch.
    """

    def disc_loss():
      """Does the forward pass and computes losses for the discriminator."""
      # N.B. The complete pass must be inside loss() for gradient tracing.
      return gan_loss_fns.discriminator_loss(generator, discriminator,
                                             generator_inputs, real_data)

    disc_optimizer.minimize(
        disc_loss, var_list=discriminator.trainable_variables)
    return tf.shape(real_data)[0]

  return train_discriminator_fn
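`Optimizer.minimize` with a callable loss is an alternative to computing gradients manually: it evaluates the loss under a tape, applies the gradients to `var_list`, and, like `apply_gradients`, creates the optimizer's slot variables on first use. A minimal sketch with a hypothetical discriminator and a stand-in squared-error loss in place of `gan_loss_fns.discriminator_loss`:

import tensorflow as tf

# Hedged sketch of the minimize-with-callable-loss pattern used above.
discriminator = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
disc_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
real_data = tf.random.normal([8, 2])

def disc_loss():
    # The full forward pass must run inside the callable so the gradients can
    # be traced with respect to the discriminator weights.
    scores = discriminator(real_data)
    return tf.reduce_mean(tf.square(scores - 1.0))

disc_optimizer.minimize(disc_loss, var_list=discriminator.trainable_variables)
assert disc_optimizer.variables()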
Example #9
def _build_server_optimizer(
    model: model_lib.Model, optimizer: tf.keras.optimizers.Optimizer
) -> Tuple[Callable[..., tf.Tensor], List[tf.Variable]]:
    """A helper for server computations that constructs  the optimizer.

  This code is needed both in server_init (to introduce variables so
  we can read their initial values) and in server_update_model.

  Args:
    model: A `tff.learning.Model`.
    optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A tuple of (apply_delta_fn, optimizer_vars), where:
      *  apply_delta_fn is a TensorFlow function that takes a model delta and
         updates the trainable weights of `model` as well as possibly
         optimizer_state variables introduced by the optimizer.
      *  optimizer_vars is a list of optimizer variables.
  """
    @tf.function
    def apply_delta(delta):
        """Applies `delta` to `model.weights`."""
        tf.nest.assert_same_structure(delta, model.weights.trainable)
        grads_and_vars = tf.nest.map_structure(
            lambda x, v: (-1.0 * x, v), tf.nest.flatten(delta),
            tf.nest.flatten(model.weights.trainable))
        # N.B. This may create variables.
        optimizer.apply_gradients(grads_and_vars, name='server_update')
        return tf.constant(1)  # We have to return something.

    # Create a dummy input and trace apply_delta so that
    # we can determine the optimizer's variables.
    weights_delta = tf.nest.map_structure(tf.zeros_like,
                                          model.weights.trainable)

    # TODO(b/109733734): We would like to call get_concrete_function,
    # but that does not currently work with structured inputs.
    # For now, we just call the function on dummy input, which
    # still ensures the function is traced (so variables are created).
    apply_delta(delta=weights_delta)

    # N.B. Using to_var_dict doesn't work here, because we
    # may get non-unique names, so we just use a flat list.
    optimizer_vars = optimizer.variables()

    return (apply_delta, optimizer_vars)
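The helper above exists because the server state stores the optimizer variables' values: `server_init` needs them created so their initial values can be read, and `server_update_model` needs the same variables so the saved state can be assigned back before applying the next delta. A minimal sketch of that round trip with plain Keras (the names are illustrative, not the TFF API):

import tensorflow as tf

def _force_optimizer_vars(model, optimizer):
    # Apply an all-zero delta so the optimizer creates its variables.
    zeros = [tf.zeros_like(v) for v in model.trainable_variables]
    optimizer.apply_gradients(zip(zeros, model.trainable_variables))
    return optimizer.variables()

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])

# "server_init": create the variables and snapshot their initial values.
init_optimizer = tf.keras.optimizers.Adam()
optimizer_state = [tf.identity(v)
                   for v in _force_optimizer_vars(model, init_optimizer)]

# "server_update": a fresh optimizer gets the saved state assigned back in
# before its apply_gradients call, exactly as in Example #2.
update_optimizer = tf.keras.optimizers.Adam()
_force_optimizer_vars(model, update_optimizer)
tf.nest.map_structure(lambda v, t: v.assign(t),
                      update_optimizer.variables(), optimizer_state)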
Example #10
def _eagerly_create_optimizer_variables(
    *, model: model_lib.Model,
    optimizer: tf.keras.optimizers.Optimizer) -> List[tf.Variable]:
  """Forces eager construction of the optimizer variables.

  This code is needed both in `server_init` and `server_update` (to introduce
  variables so we can read their initial values for the initial state).

  Args:
    model: A `tff.learning.Model`.
    optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A list of optimizer variables.
  """
  delta_tensor_spec = tf.nest.map_structure(
      lambda v: tf.TensorSpec.from_tensor(v.read_value()),
      model_utils.ModelWeights.from_model(model).trainable)
  # Trace the function, which forces eager variable creation.
  tf.function(_apply_delta).get_concrete_function(
      optimizer=optimizer, model=model, delta=delta_tensor_spec)
  return optimizer.variables()
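Tracing a concrete function from `TensorSpec`s goes one step further than calling on a dummy input: the optimizer variables are created during tracing, and the update itself is never executed. A minimal sketch of the same idea with a plain Keras model; the body of `apply_delta` is assumed to negate the delta and call `apply_gradients`, as in the earlier examples:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
optimizer = tf.keras.optimizers.Adam()

def apply_delta(delta):
    # Assumed body: add the delta to the trainable weights via the optimizer.
    grads_and_vars = [(-1.0 * d, v)
                      for d, v in zip(delta, model.trainable_variables)]
    optimizer.apply_gradients(grads_and_vars)

delta_spec = [tf.TensorSpec.from_tensor(v.read_value())
              for v in model.trainable_variables]

# Tracing alone is enough to create the Adam slot variables; nothing runs.
tf.function(apply_delta).get_concrete_function(delta_spec)
assert optimizer.variables()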
Example #11
def __initialize_optimizer(model: tff.learning.Model,
                           optimizer: tf.keras.optimizers.Optimizer):
    zero_gradient = tf.nest.map_structure(tf.zeros_like,
                                          model.trainable_variables)
    optimizer.apply_gradients(zip(zero_gradient, model.trainable_variables))
    assert optimizer.variables()
def server_computation(
    # Tensor/Dataset arguments that will be supplied by TFF:
    server_state: ServerState,
    gen_inputs_ds: tf.data.Dataset,
    client_output: ClientOutput,
    # Python arguments to be bound at TFF computation construction time:
    generator: tf.keras.Model,
    discriminator: tf.keras.Model,
    state_gen_optimizer: tf.keras.optimizers.Optimizer,
    state_disc_optimizer: tf.keras.optimizers.Optimizer,
    # Not an argument bound at TFF computation construction time, but placed
    # last so that it can be defaulted to empty tuple (for non-DP use cases).
    new_dp_averaging_state=()
) -> ServerState:
    """The computation to run on the server, training the generator.

  Args:
    server_state: The initial `ServerState` for the round.
    gen_inputs_ds: An infinite `tf.data.Dataset` of inputs to the `generator`.
    client_output: The (possibly aggregated) `ClientOutput`.
    generator:  The generator.
    discriminator: The discriminator.
    state_gen_optimizer: Optimizer used to train the generator; its variables
      carry the generator optimizer state across rounds.
    state_disc_optimizer: Optimizer whose variables carry the discriminator
      optimizer state across rounds.
    new_dp_averaging_state: The updated state of the DP averaging aggregator.

  Returns:
    An updated `ServerState` object.
  """
    # A tf.function can't modify the structure of its input arguments,
    # so we make a semi-shallow copy:
    server_state = attr.evolve(server_state,
                               counters=dict(server_state.counters))

    tf.nest.map_structure(conditioned_assign, state_gen_optimizer.variables(),
                          server_state.state_gen_optimizer_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(generator),
                          server_state.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(discriminator),
                          server_state.discriminator_weights)
    server_gen_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1)
    server_disc_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1)

    delta = client_output.discriminator_weights_delta
    tf.nest.assert_same_structure(delta, discriminator.trainable_weights)
    grads_and_vars = tf.nest.map_structure(lambda x, v: (-1.0 * x, v), delta,
                                           discriminator.trainable_weights)
    server_disc_update_optimizer.apply_gradients(grads_and_vars,
                                                 name='server_update_disc')

    for k, v in client_output.counters.items():
        server_state.counters[k] += v

    # Update the state of the DP averaging aggregator.
    server_state.dp_averaging_state = new_dp_averaging_state

    gen_examples_this_round = tf.constant(0)

    loss_fns = gan_losses.WassersteinGanLossFns()
    for gen_inputs in gen_inputs_ds:  # Compiled by autograph.
        with tf.GradientTape() as tape2:
            loss2 = loss_fns.generator_loss(generator, discriminator,
                                            gen_inputs)
        grads2 = tape2.gradient(loss2, generator.trainable_variables)
        grads_and_vars2 = zip(grads2, generator.trainable_variables)
        state_gen_optimizer.apply_gradients(grads_and_vars2)
        gen_examples_this_round += tf.shape(gen_inputs)[0]
    # update discriminator optimizer
    delta_opt_D = client_output.state_disc_opt_delta
    updated_opt_D = tf.nest.map_structure(
        lambda a, b: a + b, server_state.state_disc_optimizer_weights,
        delta_opt_D)

    G_change = tf.nest.map_structure(tf.subtract, generator.trainable_weights,
                                     server_state.generator_weights.trainable)
    D_change = tf.nest.map_structure(
        tf.subtract, discriminator.trainable_weights,
        server_state.discriminator_weights.trainable)
    server_state.state_gen_optimizer_weights = tf.nest.map_structure(
        lambda x: tf.cast(x, tf.float32), state_gen_optimizer.variables())
    server_state.state_disc_optimizer_weights = updated_opt_D
    server_state.counters[
        'num_generator_train_examples'] += gen_examples_this_round
    server_state.generator_weights = _weights(generator)
    server_state.discriminator_weights = _weights(discriminator)
    server_state.counters['num_rounds'] += 1
    server_state.generator_diff = G_change
    server_state.discriminator_diff = D_change
    return server_state
def client_computation_fedadam(
        # Tensor/Dataset arguments that will be supplied by TFF:
        gen_inputs_ds: tf.data.Dataset,
        real_data_ds: tf.data.Dataset,
        from_server: FromServer,
        # Python arguments to be bound at TFF computation construction time:
        generator: tf.keras.Model,
        discriminator: tf.keras.Model,
        state_gen_optimizer: tf.keras.optimizers.Optimizer,
        state_disc_optimizer: tf.keras.optimizers.Optimizer) -> ClientOutput:
    """The computation to run on the client, training the discriminator.

  Args:
    gen_inputs_ds: A `tf.data.Dataset` of generator_inputs.
    real_data_ds: A `tf.data.Dataset` of data from the real distribution.
    from_server: A `FromServer` object, including the current model weights.
    generator:  The generator.
    discriminator: The discriminator.
    state_gen_optimizer: Optimizer whose variables mirror the server's
      generator optimizer state.
    state_disc_optimizer: Optimizer whose variables mirror the server's
      discriminator optimizer state.

  Returns:
    A `ClientOutput` object.
  """
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(generator),
                          from_server.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(discriminator),
                          from_server.discriminator_weights)
    tf.nest.map_structure(conditioned_assign, state_disc_optimizer.variables(),
                          from_server.state_disc_optimizer_weights)
    tf.nest.map_structure(conditioned_assign, state_gen_optimizer.variables(),
                          from_server.state_gen_optimizer_weights)
    num_examples = tf.constant(0)
    loss_fns = gan_losses.WassersteinGanLossFns()
    gen_inputs_and_real_data = tf.data.Dataset.zip(
        (gen_inputs_ds, real_data_ds))

    sgd_optimizer = tf.keras.optimizers.SGD(learning_rate=1)
    for gen_inputs, real_data in gen_inputs_and_real_data:
        # It's possible that real_data and gen_inputs have different batch sizes.
        # For calculating the discriminator loss, it's desirable to have equal-sized
        # contributions from both the real and fake data. Also, it's necessary if
        # using the Wasserstein gradient penalty (where a difference is taken b/w
        # the real and fake data). So here we reduce to the min batch size. This
        # also ensures num_examples properly reflects the amount of data trained on.
        min_batch_size = tf.minimum(
            tf.shape(real_data)[0],
            tf.shape(gen_inputs)[0])
        real_data = real_data[0:min_batch_size]
        gen_inputs = gen_inputs[0:min_batch_size]
        with tf.GradientTape() as tape:
            loss = loss_fns.discriminator_loss(generator, discriminator,
                                               gen_inputs, real_data)
        grads = tape.gradient(loss, discriminator.trainable_variables)
        grads_and_vars = zip(grads, discriminator.trainable_variables)
        sgd_optimizer.apply_gradients(grads_and_vars)
        num_examples += min_batch_size

    state_disc_opt_delta = tf.nest.map_structure(
        tf.subtract,
        tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                              state_disc_optimizer.variables()),
        from_server.state_disc_optimizer_weights)
    # Should be zero, since the generator optimizer is never applied on the client.
    state_gen_opt_delta = tf.nest.map_structure(
        tf.subtract,
        tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                              state_gen_optimizer.variables()),
        from_server.state_gen_optimizer_weights)
    weights_delta = tf.nest.map_structure(
        tf.subtract, discriminator.trainable_weights,
        from_server.discriminator_weights.trainable)
    weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    update_weight = tf.cast(num_examples, tf.float32)
    # Zero out the weight if there are any non-finite values.
    # TODO(b/122071074): federated_mean might not do the right thing if
    # all clients have zero weight.
    update_weight = tf.cond(tf.equal(has_non_finite_delta, 0),
                            lambda: update_weight, lambda: tf.constant(0.0))
    return ClientOutput(
        discriminator_weights_delta=weights_delta,
        generator_weights_delta=weights_delta,
        state_disc_opt_delta=state_disc_opt_delta,
        state_gen_opt_delta=state_gen_opt_delta,
        update_weight_D=update_weight,
        update_weight_G=update_weight,
        update_weight=update_weight,
        counters={'num_discriminator_train_examples': num_examples})
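The `zero_all_if_any_non_finite` call and the `update_weight` guard together drop a client's contribution entirely if its delta contains NaNs or Infs. The sketch below is an illustrative re-implementation of that idea, not the actual `tensor_utils` helper from TFF:

import tensorflow as tf

def zero_all_if_any_non_finite(structure):
    """Illustrative sketch: zero every tensor if any entry is NaN or Inf."""
    flat = tf.nest.flatten(structure)
    all_finite = tf.reduce_all(
        [tf.reduce_all(tf.math.is_finite(t)) for t in flat])
    zeroed = tf.cond(all_finite,
                     lambda: flat,
                     lambda: [tf.zeros_like(t) for t in flat])
    # 0 when everything was finite, 1 when the structure was zeroed; the caller
    # then zeroes `update_weight` so federated averaging ignores this client.
    has_non_finite = tf.cast(tf.logical_not(all_finite), tf.int32)
    return tf.nest.pack_sequence_as(structure, zeroed), has_non_finite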