    def test_client_and_server_computations(self):
        train_generator_fn, train_discriminator_fn = (
            _get_train_generator_and_discriminator_fns())

        # N.B. The way we are using datasets and re-using the same
        # generator and discriminator doesn't really "make sense" from an ML
        # perspective, but it's sufficient for testing. For more proper usage of
        # these functions, see training_loops.py.
        generator = one_dim_gan.create_generator()
        discriminator = one_dim_gan.create_discriminator()
        gen_inputs = one_dim_gan.create_generator_inputs()
        real_data = one_dim_gan.create_real_data()

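        # Seed the server state from the freshly created models and the
        # initial DP averaging aggregation state.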
        server_state = gan_training_tf_fns.server_initial_state(
            generator, discriminator, INIT_DP_AVERAGING_STATE)

        # DP averaging aggregation state is initialized properly in
        # server_initial_state().
        self.assertEqual(server_state.dp_averaging_state,
                         INIT_DP_AVERAGING_STATE)

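        # Run a single client update: train the discriminator on three batches
        # of generator inputs and real data, starting from the weights
        # broadcast by the server.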
        client_output = gan_training_tf_fns.client_computation(
            gen_inputs.take(3), real_data.take(3),
            gan_training_tf_fns.FromServer(
                generator_weights=server_state.generator_weights,
                discriminator_weights=server_state.discriminator_weights),
            generator, discriminator, train_discriminator_fn)

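        # Apply the client output on the server, training the generator on
        # three batches per round and installing the new DP averaging state.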
        server_disc_update_optimizer = tf.keras.optimizers.Adam()
        for _ in range(2):  # Train for 2 rounds
            server_state = gan_training_tf_fns.server_computation(
                server_state, gen_inputs.take(3), client_output, generator,
                discriminator, server_disc_update_optimizer,
                train_generator_fn, NEW_DP_AVERAGING_STATE)

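        # Each of the 2 rounds consumes 3 batches, so both networks should
        # have seen 2 * 3 * BATCH_SIZE training examples.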
        counters = self.evaluate(server_state.counters)
        self.assertDictEqual(
            counters, {
                'num_rounds': 2,
                'num_discriminator_train_examples':
                2 * 3 * one_dim_gan.BATCH_SIZE,
                'num_generator_train_examples': 2 * 3 * one_dim_gan.BATCH_SIZE
            })

        # DP averaging aggregation state updates properly in server_computation().
        self.assertEqual(server_state.dp_averaging_state,
                         NEW_DP_AVERAGING_STATE)
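
    # Note: `gan` below is a free variable; this wrapper presumably lives
    # inside a builder function (not shown) that closes over a `gan`
    # configuration object supplying the model and optimizer constructors
    # (generator_model_fn, state_gen_optimizer_fn, etc.) and the
    # `disc_status` flag.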
    def server_computation(server_state, gen_inputs, client_output,
                           new_dp_averaging_state):
        """The wrapped server_computation."""
        # initialize the optimizers beforehand so you don't create them within the tf.function
        steps = server_state.counters['num_rounds']
        scheduler = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            [1000], [0.001, 0.0005])
        state_gen_optimizer = gan.state_gen_optimizer_fn(scheduler(steps))
        generator = gan.generator_model_fn()
        gan_training_tf_fns.initialize_optimizer_vars(generator,
                                                      state_gen_optimizer)
        discriminator = gan.discriminator_model_fn(0.0002)
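        # Note: unlike the generator path, the discriminator optimizer fn is
        # given the raw round counter rather than a scheduled learning rate;
        # any schedule is presumably applied inside state_disc_optimizer_fn.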
        state_disc_optimizer = gan.state_disc_optimizer_fn(steps)
        gan_training_tf_fns.initialize_optimizer_vars(discriminator,
                                                      state_disc_optimizer)

        if gan.disc_status == 'fedadam':
            return gan_training_tf_fns.server_computation_fedadam(
                server_state=server_state,
                gen_inputs_ds=gen_inputs,
                client_output=client_output,
                generator=generator,
                discriminator=discriminator,
                state_disc_optimizer=state_disc_optimizer,
                state_gen_optimizer=state_gen_optimizer,
                new_dp_averaging_state=new_dp_averaging_state)
        else:
            return gan_training_tf_fns.server_computation(
                server_state=server_state,
                gen_inputs_ds=gen_inputs,
                client_output=client_output,
                generator=generator,
                discriminator=discriminator,
                state_disc_optimizer=state_disc_optimizer,
                state_gen_optimizer=state_gen_optimizer,
                new_dp_averaging_state=new_dp_averaging_state)
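

# A minimal standalone sketch (a hypothetical helper, not part of the
# original module) of how the PiecewiseConstantDecay schedule above maps the
# round counter to a generator learning rate: rounds 0..1000 use 1e-3, and
# every later round uses 5e-4. Assumes `tensorflow` is imported as `tf`, as
# in the surrounding code.
def _demo_piecewise_schedule():
    scheduler = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        [1000], [0.001, 0.0005])
    # Steps at or below the boundary get the first value; later steps get
    # the second.
    for step, expected_lr in [(0, 0.001), (1000, 0.001), (1001, 0.0005)]:
        tf.debugging.assert_near(scheduler(step), expected_lr)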