Ejemplo n.º 1
0
    def __attrs_post_init__(self):
        """Derives TFF type signatures and the aggregation process.

        Instantiates both Keras models once (running a forward pass on the
        dummy batches so their variables are created), records TensorSpec
        structures for their weights, and assembles the federated types and
        the discriminator-delta aggregation process used by the GAN
        computations.
        """
        self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
        self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

        def build_and_check(model_fn, dummy_batch):
            # Call the model once so every variable exists before `.weights`
            # is read below, then validate the model type.
            model = model_fn()
            _ = model(dummy_batch)
            if not isinstance(model, tf.keras.models.Model):
                raise TypeError(
                    'Expected `tf.keras.models.Model`, found {}.'.format(
                        type(model)))
            return model

        # Model-weights based types
        self._generator = build_and_check(self.generator_model_fn,
                                          self.dummy_gen_input)
        self._discriminator = build_and_check(self.discriminator_model_fn,
                                              self.dummy_real_data)

        def vars_to_type(var_struct):
            # TODO(b/131681951): read_value() shouldn't be needed
            return tf.nest.map_structure(
                lambda v: tf.TensorSpec.from_tensor(v.read_value()),
                var_struct)

        self.generator_weights_type = vars_to_type(self._generator.weights)
        self.discriminator_weights_type = vars_to_type(
            self._discriminator.weights)

        self.from_server_type = gan_training_tf_fns.FromServer(
            generator_weights=self.generator_weights_type,
            discriminator_weights=self.discriminator_weights_type,
            meta_gen=self.generator_weights_type,
            meta_disc=self.discriminator_weights_type)

        self.client_gen_input_type = tff.type_at_clients(
            tff.SequenceType(self.gen_input_type))
        self.client_real_data_type = tff.type_at_clients(
            tff.SequenceType(self.real_data_type))
        self.server_gen_input_type = tff.type_at_server(
            tff.SequenceType(self.gen_input_type))

        # With no DP query, fall back to a plain weighted mean aggregator;
        # otherwise wrap the query in a DP aggregation factory.
        if self.train_discriminator_dp_average_query is None:
            self.aggregation_process = tff.aggregators.MeanFactory().create(
                value_type=tff.to_type(self.discriminator_weights_type),
                weight_type=tff.to_type(tf.float32))
        else:
            self.aggregation_process = tff.aggregators.DifferentiallyPrivateFactory(
                query=self.train_discriminator_dp_average_query).create(
                    value_type=tff.to_type(self.discriminator_weights_type))
Ejemplo n.º 2
0
    def __attrs_post_init__(self):
        """Derives TFF type signatures and (optionally) the DP averaging fn.

        Instantiates both Keras models once (running a forward pass on the
        dummy batches so their variables are created), records TensorSpec
        structures for their weights, and assembles the federated types used
        to build the GAN computations.
        """
        self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
        self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

        def build_and_check(model_fn, dummy_batch):
            # Call the model once so every variable exists before `.weights`
            # is read below, then validate the model type.
            model = model_fn()
            _ = model(dummy_batch)
            if not isinstance(model, tf.keras.models.Model):
                raise TypeError(
                    'Expected `tf.keras.models.Model`, found {}.'.format(
                        type(model)))
            return model

        # Model-weights based types
        self._generator = build_and_check(self.generator_model_fn,
                                          self.dummy_gen_input)
        self._discriminator = build_and_check(self.discriminator_model_fn,
                                              self.dummy_real_data)

        def vars_to_type(var_struct):
            # TODO(b/131681951): read_value() shouldn't be needed
            return tf.nest.map_structure(
                lambda v: tf.TensorSpec.from_tensor(v.read_value()),
                var_struct)

        self.generator_weights_type = vars_to_type(self._generator.weights)
        self.discriminator_weights_type = vars_to_type(
            self._discriminator.weights)

        self.from_server_type = gan_training_tf_fns.FromServer(
            generator_weights=self.generator_weights_type,
            discriminator_weights=self.discriminator_weights_type)

        self.client_gen_input_type = tff.type_at_clients(
            tff.SequenceType(self.gen_input_type))
        self.client_real_data_type = tff.type_at_clients(
            tff.SequenceType(self.real_data_type))
        self.server_gen_input_type = tff.type_at_server(
            tff.SequenceType(self.gen_input_type))

        # The logic here is effectively "if DP, use a stateful aggregator,
        # else use none at all". An alternative would be to always use a
        # stateful aggregator, defaulting it to a stateless mean when DP is
        # off, e.g.
        # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
        # NOTE(review): when the query is None, `dp_averaging_fn` is never
        # assigned here — presumably the attrs class declares a default for
        # it; confirm against the class definition.
        if self.train_discriminator_dp_average_query is not None:
            self.dp_averaging_fn = tff.utils.build_dp_aggregate_process(
                value_type=tff.to_type(self.discriminator_weights_type),
                query=self.train_discriminator_dp_average_query)
Ejemplo n.º 3
0
    def test_client_computation(self, with_dp):
        """Checks the counters reported by the built client computation."""
        gan = _get_gan(with_dp)
        client_comp = tff_gans.build_client_computation(gan)

        gen_model = gan.generator_model_fn()
        disc_model = gan.discriminator_model_fn()
        from_server = gan_training_tf_fns.FromServer(
            generator_weights=gen_model.weights,
            discriminator_weights=disc_model.weights)

        num_batches = 10
        gen_inputs = one_dim_gan.create_generator_inputs().take(num_batches)
        real_data = one_dim_gan.create_real_data().take(num_batches)
        client_output = client_comp(gen_inputs, real_data, from_server)

        expected_counters = {
            'num_discriminator_train_examples':
                num_batches * one_dim_gan.BATCH_SIZE
        }
        self.assertDictEqual(client_output.counters, expected_counters)
Ejemplo n.º 4
0
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """The `tff.Computation` to be returned.

        One federated training round: broadcast the current weights to the
        clients, run the client computation, aggregate the discriminator
        weight deltas (with DP averaging when `gan.dp_averaging_fn` is set),
        and apply the aggregate on the server.

        NOTE(review): `gan`, `client_computation`, `client_output_type` and
        `build_server_computation` are free variables captured from the
        enclosing builder function, which is outside this excerpt.
        """
        # Package the current global model weights for broadcast to clients.
        from_server = gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights)
        client_input = tff.federated_broadcast(from_server)
        # Each client trains locally and reports a discriminator delta.
        client_outputs = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, client_input))

        if gan.dp_averaging_fn is None:
            # Not using differential privacy: plain weighted federated mean,
            # and the DP averaging state passes through unchanged.
            new_dp_averaging_state = server_state.dp_averaging_state
            averaged_discriminator_weights_delta = tff.federated_mean(
                client_outputs.discriminator_weights_delta,
                weight=client_outputs.update_weight)
        else:
            # Using differential privacy. Note that the weight argument is set to
            # a constant 1.0 here, however the underlying AggregationProcess ignores
            # the parameter and performs no weighting.
            ignored_weight = tff.federated_value(1.0, tff.CLIENTS)
            aggregation_output = gan.dp_averaging_fn.next(
                server_state.dp_averaging_state,
                client_outputs.discriminator_weights_delta,
                weight=ignored_weight)
            new_dp_averaging_state = aggregation_output.state
            averaged_discriminator_weights_delta = aggregation_output.result

        # TODO(b/131085687): Perhaps reconsider the choice to also use
        # ClientOutput to hold the aggregated client output.
        aggregated_client_output = gan_training_tf_fns.ClientOutput(
            discriminator_weights_delta=averaged_discriminator_weights_delta,
            # We don't actually need the aggregated update_weight, but
            # this keeps the types of the non-aggregated and aggregated
            # client_output the same, which is convenient. And I can
            # imagine wanting this.
            update_weight=tff.federated_sum(client_outputs.update_weight),
            counters=tff.federated_sum(client_outputs.counters))

        # Apply the aggregated delta (and the new DP state) on the server.
        server_computation = build_server_computation(
            gan, server_state.type_signature.member, client_output_type)
        server_state = tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregated_client_output,
             new_dp_averaging_state))
        return server_state
Ejemplo n.º 5
0
  def test_client_and_server_computations(self):
    """Wires the client and server computations together for two rounds."""
    train_generator_fn, train_discriminator_fn = (
        _get_train_generator_and_discriminator_fns())

    # Re-using one generator/discriminator with these datasets isn't
    # meaningful ML, but it is sufficient to exercise the plumbing; see
    # training_loops.py for more realistic usage of these functions.
    generator = one_dim_gan.create_generator()
    discriminator = one_dim_gan.create_discriminator()
    gen_inputs = one_dim_gan.create_generator_inputs()
    real_data = one_dim_gan.create_real_data()

    server_state = gan_training_tf_fns.server_initial_state(
        generator, discriminator)
    # server_initial_state() leaves the aggregation state empty; a caller is
    # expected to populate it, most likely from an instance of
    # tff.templates.AggregationProcess.
    self.assertEmpty(server_state.aggregation_state)

    from_server = gan_training_tf_fns.FromServer(
        generator_weights=server_state.generator_weights,
        discriminator_weights=server_state.discriminator_weights)
    client_output = gan_training_tf_fns.client_computation(
        gen_inputs.take(3), real_data.take(3), from_server, generator,
        discriminator, train_discriminator_fn)

    server_disc_update_optimizer = tf.keras.optimizers.Adam()
    num_rounds = 2
    for _ in range(num_rounds):
      server_state = gan_training_tf_fns.server_computation(
          server_state, gen_inputs.take(3), client_output, generator,
          discriminator, server_disc_update_optimizer, train_generator_fn,
          NEW_DP_AVERAGING_STATE)

    expected_examples = num_rounds * 3 * one_dim_gan.BATCH_SIZE
    self.assertDictEqual(
        self.evaluate(server_state.counters), {
            'num_rounds': num_rounds,
            'num_discriminator_train_examples': expected_examples,
            'num_generator_train_examples': expected_examples
        })

    # server_computation() must thread the aggregation state through.
    self.assertEqual(server_state.aggregation_state, NEW_DP_AVERAGING_STATE)
Ejemplo n.º 6
0
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """The `tff.Computation` to be returned.

        One federated training round: broadcast the current weights to the
        clients, run the client computation, aggregate the discriminator
        weight deltas via `gan.aggregation_process`, and apply the aggregate
        on the server.

        NOTE(review): `gan`, `client_computation`, `client_output_type` and
        `build_server_computation` are free variables captured from the
        enclosing builder function, which is outside this excerpt.
        """
        # Package the current global model weights for broadcast to clients.
        from_server = gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights)
        client_input = tff.federated_broadcast(from_server)
        # Each client trains locally and reports a discriminator delta.
        client_outputs = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, client_input))

        # Note that weight goes unused here if the aggregation is involving
        # Differential Privacy; the underlying AggregationProcess doesn't take the
        # parameter, as it just uniformly weights the clients.
        if gan.aggregation_process.is_weighted:
            aggregation_output = gan.aggregation_process.next(
                server_state.aggregation_state,
                client_outputs.discriminator_weights_delta,
                client_outputs.update_weight)
        else:
            aggregation_output = gan.aggregation_process.next(
                server_state.aggregation_state,
                client_outputs.discriminator_weights_delta)

        new_aggregation_state = aggregation_output.state
        averaged_discriminator_weights_delta = aggregation_output.result

        # TODO(b/131085687): Perhaps reconsider the choice to also use
        # ClientOutput to hold the aggregated client output.
        aggregated_client_output = gan_training_tf_fns.ClientOutput(
            discriminator_weights_delta=averaged_discriminator_weights_delta,
            # We don't actually need the aggregated update_weight, but
            # this keeps the types of the non-aggregated and aggregated
            # client_output the same, which is convenient. And I can
            # imagine wanting this.
            update_weight=tff.federated_sum(client_outputs.update_weight),
            counters=tff.federated_sum(client_outputs.counters))

        # Apply the aggregated delta (and the new aggregation state) on the
        # server.
        server_computation = build_server_computation(
            gan, server_state.type_signature.member, client_output_type,
            gan.aggregation_process.state_type.member)
        server_state = tff.federated_map(
            server_computation,
            (server_state, server_gen_inputs, aggregated_client_output,
             new_aggregation_state))
        return server_state
Ejemplo n.º 7
0
  def test_client_and_server_computations(self):
    """Runs client and server computations for two rounds with DP state."""
    train_generator_fn, train_discriminator_fn = (
        _get_train_generator_and_discriminator_fns())

    # Re-using one generator/discriminator with these datasets isn't
    # meaningful ML, but it is sufficient to exercise the plumbing; see
    # training_loops.py for more realistic usage of these functions.
    generator = one_dim_gan.create_generator()
    discriminator = one_dim_gan.create_discriminator()
    gen_inputs = one_dim_gan.create_generator_inputs()
    real_data = one_dim_gan.create_real_data()

    server_state = gan_training_tf_fns.server_initial_state(
        generator, discriminator, INIT_DP_AVERAGING_STATE)
    # server_initial_state() must seed the DP averaging aggregation state.
    self.assertEqual(server_state.dp_averaging_state, INIT_DP_AVERAGING_STATE)

    from_server = gan_training_tf_fns.FromServer(
        generator_weights=server_state.generator_weights,
        discriminator_weights=server_state.discriminator_weights)
    client_output = gan_training_tf_fns.client_computation(
        gen_inputs.take(3), real_data.take(3), from_server, generator,
        discriminator, train_discriminator_fn)

    server_disc_update_optimizer = tf.keras.optimizers.Adam()
    num_rounds = 2
    for _ in range(num_rounds):
      server_state = gan_training_tf_fns.server_computation(
          server_state, gen_inputs.take(3), client_output, generator,
          discriminator, server_disc_update_optimizer, train_generator_fn,
          NEW_DP_AVERAGING_STATE)

    expected_examples = num_rounds * 3 * one_dim_gan.BATCH_SIZE
    self.assertDictEqual(
        self.evaluate(server_state.counters), {
            'num_rounds': num_rounds,
            'num_discriminator_train_examples': expected_examples,
            'num_generator_train_examples': expected_examples
        })

    # server_computation() must thread the DP averaging state through.
    self.assertEqual(server_state.dp_averaging_state, NEW_DP_AVERAGING_STATE)
Ejemplo n.º 8
0
def simple_training_loop(generator_model_fn,
                         discriminator_model_fn,
                         real_data_fn,
                         gen_inputs_fn,
                         train_generator_fn,
                         train_discriminator_fn,
                         total_rounds=30,
                         client_disc_train_steps=16,
                         server_gen_train_steps=8,
                         rounds_per_eval=10,
                         eval_hook=lambda *args: None):
    """Trains in TF using client_computation and server_computation.

    This is not intended to be a general-purpose training loop (e.g., the
    optimizers are hard-coded); it is primarily intended for testing.

    Args:
      generator_model_fn: A no-arg function returning the generator model.
      discriminator_model_fn: A no-arg function returning the discriminator
        model.
      real_data_fn: A no-arg function returning a dataset of real data batches.
      gen_inputs_fn: A no-arg function returning a dataset of generator input
        batches.
      train_generator_fn: A function which takes the two networks and generator
        input and trains the generator.
      train_discriminator_fn: A function which takes the two networks,
        generator input, and real data and trains the discriminator.
      total_rounds: Number of rounds to train.
      client_disc_train_steps: Number of discriminator training batches per
        round.
      server_gen_train_steps: Number of generator training batches per round.
      rounds_per_eval: How often to call the `eval_hook` function.
      eval_hook: A function taking arguments (generator, discriminator,
        server_state, round_num) and performing evaluation. Optional.

    Returns:
      A tuple (final `ServerState`, train_time_in_seconds).
    """
    logging.info('Starting simple_training_loop')
    # We can't use .take(...) in the loops below or we would replay the same
    # examples every round; `window` splits one Dataset into a sequence of
    # per-round sub-Datasets, which is exactly what we need here.
    client_gen_inputs = iter(gen_inputs_fn().window(client_disc_train_steps))
    client_real_data = iter(real_data_fn().window(client_disc_train_steps))
    server_gen_inputs = iter(gen_inputs_fn().window(server_gen_train_steps))

    # Separate client/server copies of each network are more faithful to how
    # this code is used in TFF, though a single copy would probably work.
    server_generator = generator_model_fn()
    server_discriminator = discriminator_model_fn()
    client_generator = generator_model_fn()
    client_discriminator = discriminator_model_fn()

    server_disc_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)

    server_state = gan_training_tf_fns.server_initial_state(
        server_generator, server_discriminator)

    start_time = time.time()

    def report_progress(round_num):
        # `server_state` is read from the enclosing scope, so this always
        # sees the latest state.
        eval_hook(server_generator, server_discriminator, server_state,
                  round_num)
        elapsed_minutes = (time.time() - start_time) / 60
        print('Total training time {:.2f} minutes for {} rounds '
              '({:.2f} rounds per minute)'.format(elapsed_minutes, round_num,
                                                  round_num / elapsed_minutes),
              flush=True)

    logging.info('Starting training')
    for round_idx in range(total_rounds):
        if round_idx % rounds_per_eval == 0:
            report_progress(round_idx)

        broadcast = gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights)
        client_output = gan_training_tf_fns.client_computation(
            gen_inputs_ds=next(client_gen_inputs),
            real_data_ds=next(client_real_data),
            from_server=broadcast,
            generator=client_generator,
            discriminator=client_discriminator,
            train_discriminator_fn=train_discriminator_fn)

        server_state = gan_training_tf_fns.server_computation(
            server_state=server_state,
            gen_inputs_ds=next(server_gen_inputs),
            client_output=client_output,
            generator=server_generator,
            discriminator=server_discriminator,
            server_disc_update_optimizer=server_disc_update_optimizer,
            train_generator_fn=train_generator_fn)

    train_time = time.time() - start_time
    report_progress(total_rounds)
    return server_state, train_time
Ejemplo n.º 9
0
    def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                      client_real_data):
        """The `tff.Computation` to be returned.

        One federated training round with a control-variate correction:
        first a control pass computes per-client deltas and their federated
        means, then each client's correction term is derived from the
        difference between the mean delta and its own delta, and finally the
        main client computation runs with those corrections.

        NOTE(review): `gan`, `client_computation`, `control_computation`,
        `client_output_type`, `control` and `build_server_computation` are
        free variables captured from the enclosing builder function, which
        is outside this excerpt. `control` appears to be a Python bool that
        enables the correction — confirm against the builder.
        """
        # Broadcast current weights (plus the meta copies) to the clients.
        from_server = gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights,
            meta_gen=server_state.meta_gen,
            meta_disc=server_state.meta_disc)
        client_input = tff.federated_broadcast(from_server)

        # calculate the control variates
        control_output = tff.federated_map(
            control_computation,
            (client_gen_inputs, client_real_data, client_input))
        # Mean deltas across clients, re-broadcast so every client can
        # compare its own delta against the population mean.
        central_control_gen = tff.federated_broadcast(
            tff.federated_mean(control_output.generator_weights_delta,
                               weight=control_output.update_weight))
        central_control_disc = tff.federated_broadcast(
            tff.federated_mean(control_output.discriminator_weights_delta,
                               weight=control_output.update_weight))

        @tff.tf_computation(client_output_type.generator_weights_delta,
                            client_output_type.generator_weights_delta)
        def compute_control_input_gen(own_gen, server_gen):
            # Correction term: mean delta minus this client's own delta;
            # zeros when the control-variate mechanism is disabled.
            adj_gen = tf.nest.map_structure(lambda a, b: a - b, server_gen,
                                            own_gen)
            return tf.cond(
                tf.constant(control, dtype=tf.bool), lambda: adj_gen,
                lambda: tf.nest.map_structure(tf.zeros_like, server_gen))

        @tff.tf_computation(client_output_type.discriminator_weights_delta,
                            client_output_type.discriminator_weights_delta)
        def compute_control_input_disc(own_disc, server_disc):
            # Same correction as above, for the discriminator weights.
            adj_disc = tf.nest.map_structure(lambda a, b: a - b, server_disc,
                                             own_disc)
            return tf.cond(
                tf.constant(control, dtype=tf.bool), lambda: adj_disc,
                lambda: tf.nest.map_structure(tf.zeros_like, server_disc))

        control_input_gen = tff.federated_map(
            compute_control_input_gen,
            (control_output.generator_weights_delta, central_control_gen))
        control_input_disc = tff.federated_map(
            compute_control_input_disc,
            (control_output.discriminator_weights_delta, central_control_disc))
        # Main client pass, now including the per-client corrections.
        client_outputs = tff.federated_map(
            client_computation,
            (client_gen_inputs, client_real_data, client_input,
             control_input_gen, control_input_disc))

        # Weighted means of both deltas across clients.
        gen_delta = tff.federated_mean(client_outputs.generator_weights_delta,
                                       weight=client_outputs.update_weight)
        disc_delta = tff.federated_mean(
            client_outputs.discriminator_weights_delta,
            weight=client_outputs.update_weight)

        # Apply the aggregated deltas on the server.
        server_computation = build_server_computation(
            gan, server_state.type_signature.member)
        server_state = tff.federated_map(server_computation,
                                         (server_state, gen_delta, disc_delta))
        return server_state