Exemple #1
0
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """Runs one federated round and returns the updated server state.

        The generator and discriminator weights held in `server_state` are
        broadcast, `client_computation` is mapped over the client inputs, the
        client results are aggregated, and `server_computation` is mapped over
        the aggregate together with the server-side inputs.

        Args:
          server_state: State entering the round; its `generator_weights`,
            `discriminator_weights` and `dp_averaging_state` are read here.
          server_gen_inputs: Forwarded unchanged to `server_computation`.
          client_gen_inputs: Forwarded unchanged to `client_computation`.
          client_real_data: Forwarded unchanged to `client_computation`.

        Returns:
          The value produced by mapping `server_computation`.
        """
        # TODO(b/131429028): The federated_zip should be automatic.
        weights_for_clients = tff.federated_zip(
            gan_training_tf_fns.FromServer(
                generator_weights=server_state.generator_weights,
                discriminator_weights=server_state.discriminator_weights))
        broadcast_weights = tff.federated_broadcast(weights_for_clients)
        per_client_results = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, broadcast_weights))

        if gan.dp_averaging_fn is not None:
            # Differentially private path. The DP aggregation code explicitly
            # does not do weighted aggregation, hence weight=None. (Weighted
            # aggregation requires turning differential privacy off.)
            dp_state, mean_disc_delta = gan.dp_averaging_fn(
                server_state.dp_averaging_state,
                per_client_results.discriminator_weights_delta,
                weight=None)
        else:
            # No differential privacy: the DP state is carried over untouched
            # and the deltas are averaged using the client update weights.
            dp_state = server_state.dp_averaging_state
            mean_disc_delta = tff.federated_mean(
                per_client_results.discriminator_weights_delta,
                weight=per_client_results.update_weight)

        # The summed update_weight is not strictly needed downstream, but
        # including it keeps the aggregated structure the same type as a
        # single client's output, which is convenient.
        summed_update_weight = tff.federated_sum(
            per_client_results.update_weight)
        summed_counters = tff.federated_sum(per_client_results.counters)

        # TODO(b/131085687): Perhaps reconsider the choice to also use
        # ClientOutput to hold the aggregated client output.
        # TODO(b/131839522): This federated_zip shouldn't be needed.
        aggregate = tff.federated_zip(
            gan_training_tf_fns.ClientOutput(
                discriminator_weights_delta=mean_disc_delta,
                update_weight=summed_update_weight,
                counters=summed_counters))

        return tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregate, dp_state))
Exemple #2
0
    def __attrs_post_init__(self):
        """Derives type information from the user-supplied fields.

        Computes tensor specs for the dummy generator-input and real-data
        batches, builds one generator and one discriminator (each is invoked
        once on its dummy batch and checked to be a Keras model), and records
        the specs of their weights plus the federated types built from them.
        If a DP average query was supplied, the DP aggregation function and
        its state type are built as well.
        """
        self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
        self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

        # Model-weights based types. Each model is called once on its dummy
        # batch before its weights are inspected below.
        self._generator = self.generator_model_fn()
        _ = self._generator(self.dummy_gen_input)
        py_typecheck.check_type(self._generator, tf.keras.models.Model)
        self._discriminator = self.discriminator_model_fn()
        _ = self._discriminator(self.dummy_real_data)
        py_typecheck.check_type(self._discriminator, tf.keras.models.Model)

        def weight_specs(variables):
            """Maps a structure of variables to a structure of TensorSpecs."""

            def spec_of(variable):
                # TODO(b/131681951): read_value() shouldn't be needed
                return tf.TensorSpec.from_tensor(variable.read_value())

            return tf.nest.map_structure(spec_of, variables)

        self.discriminator_weights_type = weight_specs(
            self._discriminator.weights)
        self.generator_weights_type = weight_specs(self._generator.weights)

        self.from_server_type = gan_training_tf_fns.FromServer(
            generator_weights=self.generator_weights_type,
            discriminator_weights=self.discriminator_weights_type)

        # Sequences of batches, placed at the clients or at the server.
        self.client_gen_input_type = tff.FederatedType(
            tff.SequenceType(self.gen_input_type), tff.CLIENTS)
        self.client_real_data_type = tff.FederatedType(
            tff.SequenceType(self.real_data_type), tff.CLIENTS)
        self.server_gen_input_type = tff.FederatedType(
            tff.SequenceType(self.gen_input_type), tff.SERVER)

        # Right now, the logic in this library is effectively "if DP use stateful
        # aggregator, else don't use stateful aggregator". An alternative
        # formulation would be to always use a stateful aggregator, but when not
        # using DP default the aggregator to be a stateless mean, e.g.,
        # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
        # This change will be easier to make if the tff.StatefulAggregateFn is
        # modified to have a property that gives the type of the aggregation state
        # (i.e., what we're storing in self.dp_averaging_state_type).
        if self.train_discriminator_dp_average_query is not None:

            def discriminator_value_type(value):
                # The incoming value is ignored; the type is always that of
                # the discriminator weights, looked up on self at call time.
                return self.discriminator_weights_type

            def record_as_list(record):
                return list(record)

            dp_fn_and_state_type = tff.utils.build_dp_aggregate(
                query=self.train_discriminator_dp_average_query,
                value_type_fn=discriminator_value_type,
                from_tff_result_fn=record_as_list)
            self.dp_averaging_fn, self.dp_averaging_state_type = (
                dp_fn_and_state_type)
Exemple #3
0
    def test_client_computation(self, with_dp):
        """Checks the counters returned by the built client computation.

        The client computation is fed `num_batches` batches of generator
        inputs and of real data, and is expected to report
        `num_batches * BATCH_SIZE` discriminator training examples.
        """
        num_batches = 10

        gan = _get_gan(with_dp)
        client_comp = tff_gans.build_client_computation(gan)

        # Fresh models are created only to supply weights for FromServer.
        gen_model = gan.generator_model_fn()
        disc_model = gan.discriminator_model_fn()
        server_message = gan_training_tf_fns.FromServer(
            generator_weights=gen_model.weights,
            discriminator_weights=disc_model.weights)

        gen_input_batches = one_dim_gan.create_generator_inputs().take(
            num_batches)
        real_data_batches = one_dim_gan.create_real_data().take(num_batches)
        result = client_comp(gen_input_batches, real_data_batches,
                             server_message)

        expected_counters = {
            'num_discriminator_train_examples':
                num_batches * one_dim_gan.BATCH_SIZE,
        }
        self.assertDictEqual(result.counters, expected_counters)
    def test_client_and_server_computations(self):
        """Runs the client computation once and the server computation twice.

        Verifies that the DP averaging state is stored by
        `server_initial_state()` and replaced by `server_computation()`, and
        that the server counters add up over the rounds.
        """
        num_rounds = 2
        batches_per_round = 3

        train_generator_fn, train_discriminator_fn = (
            _get_train_generator_and_discriminator_fns())

        # N.B. Re-using the same datasets and the same generator and
        # discriminator everywhere doesn't really "make sense" from an ML
        # perspective, but it's sufficient for testing. For more proper usage
        # of these functions, see training_loops.py.
        generator = one_dim_gan.create_generator()
        discriminator = one_dim_gan.create_discriminator()
        gen_inputs = one_dim_gan.create_generator_inputs()
        real_data = one_dim_gan.create_real_data()

        server_state = gan_training_tf_fns.server_initial_state(
            generator, discriminator, INIT_DP_AVERAGING_STATE)

        # server_initial_state() must store the DP averaging state it is given.
        self.assertEqual(server_state.dp_averaging_state,
                         INIT_DP_AVERAGING_STATE)

        server_message = gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights)
        client_output = gan_training_tf_fns.client_computation(
            gen_inputs.take(batches_per_round),
            real_data.take(batches_per_round), server_message, generator,
            discriminator, train_discriminator_fn)

        server_disc_update_optimizer = tf.keras.optimizers.Adam()
        # The same client_output is fed to the server on every round.
        for _ in range(num_rounds):
            server_state = gan_training_tf_fns.server_computation(
                server_state, gen_inputs.take(batches_per_round),
                client_output, generator, discriminator,
                server_disc_update_optimizer, train_generator_fn,
                NEW_DP_AVERAGING_STATE)

        examples_seen = (
            num_rounds * batches_per_round * one_dim_gan.BATCH_SIZE)
        expected_counters = {
            'num_rounds': num_rounds,
            'num_discriminator_train_examples': examples_seen,
            'num_generator_train_examples': examples_seen,
        }
        self.assertDictEqual(
            self.evaluate(server_state.counters), expected_counters)

        # server_computation() must install the new DP averaging state.
        self.assertEqual(server_state.dp_averaging_state,
                         NEW_DP_AVERAGING_STATE)
Exemple #5
0
def simple_training_loop(generator_model_fn,
                         discriminator_model_fn,
                         real_data_fn,
                         gen_inputs_fn,
                         train_generator_fn,
                         train_discriminator_fn,
                         total_rounds=30,
                         client_disc_train_steps=16,
                         server_gen_train_steps=8,
                         rounds_per_eval=10,
                         eval_hook=lambda *args: None):
  """Trains in TF using client_computation and server_computation.

  This is not intended to be a general-purpose training loop (e.g., the
  optimizers are hard-coded), it is primarily intended for testing.

  Each round consumes one window from each of the three dataset iterators
  created below, so every dataset must be able to supply at least
  `total_rounds` windows.

  Args:
    generator_model_fn: A no-arg function returning the generator model.
    discriminator_model_fn: A no-arg function returning the discriminator
      model.
    real_data_fn: A no-arg function returning a dataset of real data batches.
    gen_inputs_fn: A no-arg function returning a dataset of generator input
      batches.
    train_generator_fn: A function which takes the two networks and generator
      input and trains the generator.
    train_discriminator_fn: A function which takes the two networks, generator
      input, and real data and trains the discriminator.
    total_rounds: Number of rounds to train.
    client_disc_train_steps: Number of discriminator training batches per round.
    server_gen_train_steps: Number of generator training batches per round.
    rounds_per_eval: How often to call the `eval_hook` function. Must be
      non-zero, since it is used as a modulus.
    eval_hook: A function taking arguments (generator, discriminator,
      server_state, round_num) and performs evaluation. Optional. It is called
      before every `rounds_per_eval`-th round (including round 0) and once
      more after the last round.

  Returns:
    A tuple (final `ServerState`, train_time_in_seconds). The training time
    spans all rounds, including the `eval_hook` calls made between rounds, but
    not the final `eval_hook` call.
  """
  logging.info('Starting simple_training_loop')
  # N.B. We can't use real_data.take(...) in the loops below,
  # or we would get the same examples on every round. Using window
  # essentially breaks one Dataset into a sequence of Datasets,
  # which is exactly what we need here.
  client_gen_inputs = iter(gen_inputs_fn().window(client_disc_train_steps))
  client_real_data = iter(real_data_fn().window(client_disc_train_steps))

  server_gen_inputs = iter(gen_inputs_fn().window(server_gen_train_steps))

  server_generator = generator_model_fn()
  server_discriminator = discriminator_model_fn()
  # We could probably use a single copy of the generator and discriminator, but
  # using separate copies is more faithful to how this code will be used in TFF.
  client_generator = generator_model_fn()
  client_discriminator = discriminator_model_fn()

  server_disc_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)

  server_state = gan_training_tf_fns.server_initial_state(
      server_generator, server_discriminator)

  start_time = time.time()

  def do_eval(round_num):
    # `server_state` is looked up at call time, so this always evaluates the
    # most recent state assigned by the training loop below.
    eval_hook(server_generator, server_discriminator, server_state, round_num)
    elapsed_minutes = (time.time() - start_time) / 60
    # The first call happens right after `start_time` is taken; with a coarse
    # (or adjusted) wall clock the elapsed time can be zero or negative, which
    # would otherwise raise ZeroDivisionError or report a nonsensical rate.
    if elapsed_minutes > 0:
      rounds_per_minute = round_num / elapsed_minutes
    else:
      rounds_per_minute = 0.0
    print(
        'Total training time {:.2f} minutes for {} rounds '
        '({:.2f} rounds per minute)'.format(elapsed_minutes, round_num,
                                            rounds_per_minute),
        flush=True)

  logging.info('Starting training')
  for round_num in range(total_rounds):
    if round_num % rounds_per_eval == 0:
      do_eval(round_num)

    client_output = gan_training_tf_fns.client_computation(
        gen_inputs_ds=next(client_gen_inputs),
        real_data_ds=next(client_real_data),
        from_server=gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights),
        generator=client_generator,
        discriminator=client_discriminator,
        train_discriminator_fn=train_discriminator_fn)

    server_state = gan_training_tf_fns.server_computation(
        server_state=server_state,
        gen_inputs_ds=next(server_gen_inputs),
        client_output=client_output,
        generator=server_generator,
        discriminator=server_discriminator,
        server_disc_update_optimizer=server_disc_update_optimizer,
        train_generator_fn=train_generator_fn)

  # Measured before the final evaluation so that it is not counted.
  train_time = time.time() - start_time
  do_eval(total_rounds)
  return server_state, train_time
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """The `tff.Computation` to be returned.

        Broadcasts the model weights, optimizer state and counters held in
        `server_state`, maps `client_computation` over the client inputs,
        aggregates the client results, and maps `server_computation` over the
        aggregate together with the server-side inputs.

        Args:
          server_state: State entering the round.
          server_gen_inputs: Forwarded unchanged to `server_computation`.
          client_gen_inputs: Forwarded unchanged to `client_computation`.
          client_real_data: Forwarded unchanged to `client_computation`.

        Returns:
          The value produced by mapping `server_computation`.
        """

        from_server = gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights,
            state_gen_optimizer_weights=server_state.
            state_gen_optimizer_weights,
            state_disc_optimizer_weights=server_state.
            state_disc_optimizer_weights,
            counters=server_state.counters)
        client_input = tff.federated_broadcast(from_server)
        client_outputs = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, client_input))

        if gan.dp_averaging_fn is None:
            # Not using differential privacy.
            # NOTE(review): every mean below, including the discriminator
            # ones, is weighted by update_weight_G; update_weight_D is only
            # summed further down. Confirm this is intentional.
            new_dp_averaging_state = server_state.dp_averaging_state
            averaged_discriminator_weights_delta = tff.federated_mean(
                client_outputs.discriminator_weights_delta,
                weight=client_outputs.update_weight_G)
            averaged_generator_weights_delta = tff.federated_mean(
                client_outputs.generator_weights_delta,
                weight=client_outputs.update_weight_G)
            averaged_gen_opt_delta = tff.federated_mean(
                client_outputs.state_gen_opt_delta,
                weight=client_outputs.update_weight_G)
            averaged_disc_opt_delta = tff.federated_mean(
                client_outputs.state_disc_opt_delta,
                weight=client_outputs.update_weight_G)
        else:
            # Using differential privacy. Note that the weight argument is set to None
            # here. This is because the DP aggregation code explicitly does not do
            # weighted aggregation. (If weighted aggregation is desired, differential
            # privacy needs to be turned off.)
            new_dp_averaging_state, averaged_discriminator_weights_delta = (
                gan.dp_averaging_fn(server_state.dp_averaging_state,
                                    client_outputs.discriminator_weights_delta,
                                    weight=None))
            # The remaining deltas must also be bound on this path: they are
            # read unconditionally when ClientOutput is built below, and
            # leaving them unassigned made the DP configuration fail with
            # UnboundLocalError. They are averaged without weights, matching
            # the unweighted DP aggregation above.
            # NOTE(review): these three means do NOT go through the DP query,
            # so only the discriminator weights delta is privatized. Confirm
            # that this is acceptable before enabling DP with this variant.
            averaged_generator_weights_delta = tff.federated_mean(
                client_outputs.generator_weights_delta)
            averaged_gen_opt_delta = tff.federated_mean(
                client_outputs.state_gen_opt_delta)
            averaged_disc_opt_delta = tff.federated_mean(
                client_outputs.state_disc_opt_delta)

        aggregated_client_output = gan_training_tf_fns.ClientOutput(
            discriminator_weights_delta=averaged_discriminator_weights_delta,
            generator_weights_delta=averaged_generator_weights_delta,
            state_gen_opt_delta=averaged_gen_opt_delta,
            state_disc_opt_delta=averaged_disc_opt_delta,
            # We don't actually need the aggregated update_weight, but
            # this keeps the types of the non-aggregated and aggregated
            # client_output the same, which is convenient. And I can
            # imagine wanting this.
            update_weight=tff.federated_sum(client_outputs.update_weight_D),
            update_weight_D=tff.federated_sum(client_outputs.update_weight_D),
            update_weight_G=tff.federated_sum(client_outputs.update_weight_G),
            counters=tff.federated_sum(client_outputs.counters))

        server_state = tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregated_client_output,
             new_dp_averaging_state))
        return server_state