Example #1
0
    def test_build_gan_training_process(self, with_dp):
        """Runs two federated rounds and checks counters and DP query state."""
        gan = _get_gan(with_dp)
        process = tff_gans.build_gan_training_process(gan)
        server_state = process.initialize()

        if with_dp:
            # The freshly initialized DP averaging aggregator should carry the
            # configured clip norm and noise stddev.
            sum_state = (
                server_state.dp_averaging_state.numerator_state.sum_state)
            self.assertAlmostEqual(
                sum_state.l2_norm_clip, BEFORE_DP_L2_NORM_CLIP, places=5)
            self.assertAlmostEqual(
                sum_state.stddev, BEFORE_DP_STD_DEV, places=5)

        dataset_sizes = [1, 3]
        gen_input_datasets = [
            one_dim_gan.create_generator_inputs().take(size)
            for size in dataset_sizes
        ]
        real_data_datasets = [
            one_dim_gan.create_real_data().take(size)
            for size in dataset_sizes
        ]

        num_rounds = 2
        for _ in range(num_rounds):
            server_state = process.next(
                server_state,
                one_dim_gan.create_generator_inputs().take(1),
                gen_input_datasets, real_data_datasets)

        # Server counters should reflect both rounds of training.
        expected_counters = {
            'num_rounds': num_rounds,
            'num_generator_train_examples':
            one_dim_gan.BATCH_SIZE * num_rounds,
            'num_discriminator_train_examples':
            num_rounds * one_dim_gan.BATCH_SIZE * sum(dataset_sizes),
        }
        self.assertDictEqual(server_state.counters, expected_counters)

        if with_dp:
            # After two rounds the adaptive clipping query should have moved
            # the clip norm and stddev to their post-training values.
            sum_state = (
                server_state.dp_averaging_state.numerator_state.sum_state)
            self.assertAlmostEqual(
                sum_state.l2_norm_clip, AFTER_2_RDS_DP_L2_NORM_CLIP, places=5)
            self.assertAlmostEqual(
                sum_state.stddev, AFTER_2_RDS_DP_STD_DEV, places=5)
Example #2
0
def _get_gan(with_dp=False):
    """Builds a `GanFnsAndTypes` for the one-dimensional test GAN.

    Args:
        with_dp: If True, attach a `QuantileAdaptiveClipAverageQuery` so that
            discriminator weight deltas are aggregated with differential
            privacy; otherwise no DP query is used.

    Returns:
        A `tff_gans.GanFnsAndTypes` wired up with Wasserstein losses, Adam
        optimizers for generator/discriminator training, an SGD optimizer for
        the server-side discriminator update, and (optionally) the DP query.
    """
    gan_loss_fns = gan_losses.get_gan_loss_fns('wasserstein')
    server_gen_optimizer = tf.keras.optimizers.Adam()
    client_disc_optimizer = tf.keras.optimizers.Adam()
    train_generator_fn = gan_training_tf_fns.create_train_generator_fn(
        gan_loss_fns, server_gen_optimizer)
    train_discriminator_fn = gan_training_tf_fns.create_train_discriminator_fn(
        gan_loss_fns, client_disc_optimizer)

    if with_dp:
        # NOTE(review): target_unclipped_quantile=3 lies outside the usual
        # [0, 1] range for a quantile; presumably intentional for this test,
        # but worth confirming against the tensorflow_privacy API.
        dp_average_query = tensorflow_privacy.QuantileAdaptiveClipAverageQuery(
            initial_l2_norm_clip=BEFORE_DP_L2_NORM_CLIP,
            noise_multiplier=0.3,
            target_unclipped_quantile=3,
            learning_rate=0.1,
            clipped_count_stddev=0.0,
            expected_num_records=10,
            denominator=10.0)
    else:
        dp_average_query = None

    return tff_gans.GanFnsAndTypes(
        generator_model_fn=one_dim_gan.create_generator,
        discriminator_model_fn=one_dim_gan.create_discriminator,
        dummy_gen_input=next(iter(one_dim_gan.create_generator_inputs())),
        dummy_real_data=next(iter(one_dim_gan.create_real_data())),
        train_generator_fn=train_generator_fn,
        train_discriminator_fn=train_discriminator_fn,
        # 'learning_rate' replaces the deprecated 'lr' alias, which newer
        # tf.keras versions no longer accept.
        server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
            learning_rate=1.0),
        train_discriminator_dp_average_query=dp_average_query)
Example #3
0
    def test_client_computation(self, with_dp):
        """Runs the client computation once and checks the example counter."""
        gan = _get_gan(with_dp)
        client_comp = tff_gans.build_client_computation(gan)

        generator = gan.generator_model_fn()
        discriminator = gan.discriminator_model_fn()

        num_batches = 10
        from_server = gan_training_tf_fns.FromServer(
            generator_weights=generator.weights,
            discriminator_weights=discriminator.weights)
        client_output = client_comp(
            one_dim_gan.create_generator_inputs().take(num_batches),
            one_dim_gan.create_real_data().take(num_batches), from_server)
        # One counter increment of BATCH_SIZE per discriminator batch.
        expected_counters = {
            'num_discriminator_train_examples':
            num_batches * one_dim_gan.BATCH_SIZE
        }
        self.assertDictEqual(client_output.counters, expected_counters)
Example #4
0
    def test_client_and_server_computations(self):
        """Exercises client_computation and server_computation end to end."""
        train_generator_fn, train_discriminator_fn = (
            _get_train_generator_and_discriminator_fns())

        # N.B. The way we are using datasets and re-using the same
        # generator and discriminator doesn't really "make sense" from an ML
        # perspective, but it's sufficient for testing. For more proper usage of
        # these functions, see training_loops.py.
        generator = one_dim_gan.create_generator()
        discriminator = one_dim_gan.create_discriminator()
        gen_inputs = one_dim_gan.create_generator_inputs()
        real_data = one_dim_gan.create_real_data()

        server_state = gan_training_tf_fns.server_initial_state(
            generator, discriminator, INIT_DP_AVERAGING_STATE)

        # server_initial_state() should carry the DP averaging state through
        # unchanged.
        self.assertEqual(server_state.dp_averaging_state,
                         INIT_DP_AVERAGING_STATE)

        num_batches = 3
        from_server = gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights)
        client_output = gan_training_tf_fns.client_computation(
            gen_inputs.take(num_batches), real_data.take(num_batches),
            from_server, generator, discriminator, train_discriminator_fn)

        num_rounds = 2
        server_disc_update_optimizer = tf.keras.optimizers.Adam()
        for _ in range(num_rounds):
            server_state = gan_training_tf_fns.server_computation(
                server_state, gen_inputs.take(num_batches), client_output,
                generator, discriminator, server_disc_update_optimizer,
                train_generator_fn, NEW_DP_AVERAGING_STATE)

        counters = self.evaluate(server_state.counters)
        examples_per_round = num_batches * one_dim_gan.BATCH_SIZE
        self.assertDictEqual(
            counters, {
                'num_rounds': num_rounds,
                'num_discriminator_train_examples':
                num_rounds * examples_per_round,
                'num_generator_train_examples':
                num_rounds * examples_per_round,
            })

        # server_computation() should replace the DP averaging state with the
        # one passed in as an argument.
        self.assertEqual(server_state.dp_averaging_state,
                         NEW_DP_AVERAGING_STATE)
Example #5
0
    def test_server_computation(self, with_dp):
        """Runs the server computation for one round and checks its outputs.

        Verifies that the server counters increment relative to the initial
        state, and (when `with_dp` is set) that the DP averaging state passed
        as an argument is reflected in the returned server state.
        """
        gan = _get_gan(with_dp)
        initial_state_comp = tff_gans.build_server_initial_state_comp(gan)

        # TODO(b/131700944): Remove this workaround, and directly instantiate a
        # ClientOutput instance (once TFF has a utility to infer TFF types of
        # objects directly).
        @tff.tf_computation
        def client_output_fn():
            # Synthetic client output: zero weight deltas and a fixed example
            # counter (13), so the server-side counter math is easy to check.
            discriminator = gan.discriminator_model_fn()
            return gan_training_tf_fns.ClientOutput(
                discriminator_weights_delta=[
                    tf.zeros(shape=v.shape, dtype=v.dtype)
                    for v in discriminator.weights
                ],
                update_weight=1.0,
                counters={'num_discriminator_train_examples': 13})

        def _update_dp_averaging_state(with_dp, dp_averaging_state):
            # Without DP there is no nested query state to rewrite.
            if not with_dp:
                return dp_averaging_state

            # Rebuild the nested state from the inside out, replacing only the
            # innermost l2_norm_clip value with UPDATE_DP_L2_NORM_CLIP.
            new_sum_state = tff.utils.update_state(
                dp_averaging_state.numerator_state.sum_state,
                l2_norm_clip=UPDATE_DP_L2_NORM_CLIP)
            new_numerator_state = tff.utils.update_state(
                dp_averaging_state.numerator_state, sum_state=new_sum_state)
            new_dp_averaging_state = tff.utils.update_state(
                dp_averaging_state, numerator_state=new_numerator_state)
            return new_dp_averaging_state

        server_comp = tff_gans.build_server_computation(
            gan, initial_state_comp.type_signature.result,
            client_output_fn.type_signature.result)

        server_state = initial_state_comp()

        client_output = client_output_fn()
        new_dp_averaging_state = _update_dp_averaging_state(
            with_dp, server_state.dp_averaging_state)
        # 7 generator-input batches drive the num_generator_train_examples
        # counter below.
        final_server_state = server_comp(
            server_state,
            one_dim_gan.create_generator_inputs().take(7), client_output,
            new_dp_averaging_state)

        # Check that server counters have incremented (compare before and after).
        self.assertDictEqual(
            server_state.counters, {
                'num_rounds': 0,
                'num_generator_train_examples': 0,
                'num_discriminator_train_examples': 0
            })
        self.assertDictEqual(
            final_server_state.counters, {
                'num_rounds': 1,
                'num_discriminator_train_examples': 13,
                'num_generator_train_examples': 7 * one_dim_gan.BATCH_SIZE
            })

        if with_dp:
            # Check that DP averaging aggregator state reflects the new state that was
            # passed as argument to server computation (compare before and after).
            initial_dp_averaging_state = server_state.dp_averaging_state
            self.assertAlmostEqual(
                initial_dp_averaging_state.numerator_state.sum_state.
                l2_norm_clip, BEFORE_DP_L2_NORM_CLIP)
            new_dp_averaging_state = final_server_state.dp_averaging_state
            self.assertAlmostEqual(
                new_dp_averaging_state.numerator_state.sum_state.l2_norm_clip,
                UPDATE_DP_L2_NORM_CLIP)
Example #6
0
    def test_build_gan_training_process(self, with_dp):
        """Runs two federated rounds and checks counters and DP query state."""
        gan = _get_gan(with_dp)
        process = tff_gans.build_gan_training_process(gan)
        server_state = gan_training_tf_fns.ServerState.from_tff_result(
            process.initialize())

        if with_dp:
            # The freshly initialized DP averaging aggregator should carry the
            # configured clip norm and noise stddev.
            numerator_state = server_state.dp_averaging_state[
                'numerator_state']
            self.assertAlmostEqual(numerator_state['l2_norm_clip'],
                                   BEFORE_DP_L2_NORM_CLIP)
            self.assertAlmostEqual(
                numerator_state['sum_state']['l2_norm_clip'],
                BEFORE_DP_L2_NORM_CLIP)
            self.assertAlmostEqual(numerator_state['sum_state']['stddev'],
                                   BEFORE_DP_STD_DEV)

        dataset_sizes = [1, 3]
        gen_input_datasets = [
            one_dim_gan.create_generator_inputs().take(size)
            for size in dataset_sizes
        ]
        real_data_datasets = [
            one_dim_gan.create_real_data().take(size)
            for size in dataset_sizes
        ]

        num_rounds = 2
        for _ in range(num_rounds):
            server_state = process.next(
                server_state,
                one_dim_gan.create_generator_inputs().take(1),
                gen_input_datasets, real_data_datasets)

        # TODO(b/123092620): Won't need to convert from AnonymousTuple, eventually.
        server_state = gan_training_tf_fns.ServerState.from_tff_result(
            server_state)

        # Server counters should reflect both rounds of training.
        expected_counters = {
            'num_rounds': num_rounds,
            'num_generator_train_examples':
            one_dim_gan.BATCH_SIZE * num_rounds,
            'num_discriminator_train_examples':
            num_rounds * one_dim_gan.BATCH_SIZE * sum(dataset_sizes),
        }
        self.assertDictEqual(server_state.counters, expected_counters)

        if with_dp:
            # After two rounds the adaptive clipping query should have updated
            # both the clip norm and the noise stddev.
            numerator_state = server_state.dp_averaging_state[
                'numerator_state']
            self.assertAlmostEqual(numerator_state['l2_norm_clip'],
                                   AFTER_2_RDS_DP_L2_NORM_CLIP)
            self.assertAlmostEqual(
                numerator_state['sum_state']['l2_norm_clip'],
                AFTER_2_RDS_DP_L2_NORM_CLIP)
            self.assertAlmostEqual(numerator_state['sum_state']['stddev'],
                                   AFTER_2_RDS_DP_STD_DEV)
    def test_tff_training_loop(self, dp_average_query, checkpoint):
        """Runs federated_training_loop and verifies counters.

        When `checkpoint` is True, trains 2 rounds with checkpointing enabled,
        then invokes the loop again with total_rounds=3; the second invocation
        is expected to resume from the checkpoint and add a single round.
        """
        if checkpoint:
            root_checkpoint_dir = os.path.join(self.get_temp_dir(),
                                               'checkpoints')
        else:
            root_checkpoint_dir = None

        train_generator_fn, train_discriminator_fn = (
            _get_train_generator_and_discriminator_fns())

        gan = tff_gans.GanFnsAndTypes(
            generator_model_fn=one_dim_gan.create_generator,
            discriminator_model_fn=one_dim_gan.create_discriminator,
            dummy_gen_input=next(iter(one_dim_gan.create_generator_inputs())),
            dummy_real_data=next(iter(one_dim_gan.create_real_data())),
            train_generator_fn=train_generator_fn,
            train_discriminator_fn=train_discriminator_fn,
            server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
                lr=1.0),
            train_discriminator_dp_average_query=dp_average_query)

        gen_inputs = one_dim_gan.create_generator_inputs()
        real_data = one_dim_gan.create_real_data()

        client_disc_train_steps = 2
        server_gen_train_steps = 3

        # .window() splits the infinite input streams into per-round chunks;
        # the iterators below are shared across all rounds (and across both
        # federated_training_loop invocations in the checkpoint case).
        server_gen_inputs = iter(gen_inputs.window(server_gen_train_steps))
        client_gen_inputs = iter(gen_inputs.window(client_disc_train_steps))
        client_real_data = iter(real_data.window(client_disc_train_steps))

        def server_gen_inputs_fn(_):
            # Ignores the round number; just yields the next chunk.
            return next(server_gen_inputs)

        num_clients = 2

        def client_datasets_fn(_):
            # One (gen_inputs, real_data) dataset pair per client, per round.
            return [(next(client_gen_inputs), next(client_real_data))
                    for _ in range(num_clients)]

        server_state, _ = training_loops.federated_training_loop(
            gan,
            server_gen_inputs_fn=server_gen_inputs_fn,
            client_datasets_fn=client_datasets_fn,
            total_rounds=2,
            rounds_per_checkpoint=1,
            root_checkpoint_dir=root_checkpoint_dir)

        self.assertDictEqual(
            server_state.counters, {
                'num_rounds':
                2,
                'num_generator_train_examples':
                2 * 3 * one_dim_gan.BATCH_SIZE,
                'num_discriminator_train_examples':
                (2 * 2 * one_dim_gan.BATCH_SIZE * num_clients)
            })
        if checkpoint:
            # TODO(b/141112101): We shouldn't need to re-create the gan, should be
            # able to reuse the instance from above. See comment inside tff_gans.py.
            train_generator_fn, train_discriminator_fn = (
                _get_train_generator_and_discriminator_fns())
            gan = tff_gans.GanFnsAndTypes(
                generator_model_fn=one_dim_gan.create_generator,
                discriminator_model_fn=one_dim_gan.create_discriminator,
                dummy_gen_input=next(
                    iter(one_dim_gan.create_generator_inputs())),
                dummy_real_data=next(iter(one_dim_gan.create_real_data())),
                train_generator_fn=train_generator_fn,
                train_discriminator_fn=train_discriminator_fn,
                server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.
                SGD(lr=1.0),
                train_discriminator_dp_average_query=dp_average_query)
            # Train one more round, which should resume from the checkpoint.
            server_state, _ = training_loops.federated_training_loop(
                gan,
                server_gen_inputs_fn=server_gen_inputs_fn,
                client_datasets_fn=client_datasets_fn,
                total_rounds=3,
                rounds_per_checkpoint=1,
                root_checkpoint_dir=root_checkpoint_dir)
            # Note: It would be better to return something from
            # federated_training_loop indicating the number of rounds trained in this
            # invocation, so we could verify the checkpoint was read.
            self.assertDictEqual(
                server_state.counters, {
                    'num_rounds':
                    3,
                    'num_generator_train_examples':
                    3 * 3 * one_dim_gan.BATCH_SIZE,
                    'num_discriminator_train_examples':
                    (3 * 2 * one_dim_gan.BATCH_SIZE * num_clients)
                })