    # Legacy variant: builds the TFF model from a compiled Keras model via
    # `keras_utils.from_compiled_keras_model`, which takes a dummy batch to
    # define the input signature.
    def test_keras_model_using_batch_norm_with_compiled_model(self):
        model = model_examples.build_conv_batch_norm_keras_model()

        def loss_fn(y_true, y_pred):
            loss_per_example = tf.keras.losses.sparse_categorical_crossentropy(
                y_true=y_true, y_pred=y_pred)
            return tf.reduce_mean(loss_per_example)

        model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
                      loss=loss_fn,
                      metrics=[NumBatchesCounter(),
                               NumExamplesCounter()])

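        # The dummy batch only pins down input shapes and dtypes for tracing;
        # its zero-valued contents are never trained on.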
        dummy_batch = collections.OrderedDict([
            ('x', np.zeros([1, 28 * 28], dtype=np.float32)),
            ('y', np.zeros([1, 1], dtype=np.int64)),
        ])
        tff_model = keras_utils.from_compiled_keras_model(
            keras_model=model, dummy_batch=dummy_batch)

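        # A small random batch of flattened 28x28 inputs with labels in [0, 9].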
        batch_size = 2
        batch = {
            'x':
                np.random.uniform(low=0.0, high=1.0,
                                  size=[batch_size, 28 * 28]).astype(np.float32),
            'y':
                np.random.randint(low=0, high=10,
                                  size=[batch_size, 1]).astype(np.int64),
        }

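        # Run a couple of training steps so the BatchNorm moving statistics
        # (non-trainable variables) move away from their initial values.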
        num_iterations = 2
        for _ in range(num_iterations):
            self.evaluate(tff_model.train_on_batch(batch))

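        # `report_local_outputs` returns each metric as a list; `loss` is an
        # (accumulated value, example count) pair, as the assertions below check.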
        m = self.evaluate(tff_model.report_local_outputs())
        self.assertEqual(m['num_batches'], [num_iterations])
        self.assertEqual(m['num_examples'], [batch_size * num_iterations])
        self.assertGreater(m['loss'][0], 0.0)
        self.assertEqual(m['loss'][1], batch_size * num_iterations)

        # Ensure we can assign the FL trained model weights to a new model.
        tff_weights = model_utils.ModelWeights.from_model(tff_model)
        keras_model = model_examples.build_conv_batch_norm_keras_model()
        tff_weights.assign_weights_to(keras_model)

        def assert_all_weights_close(keras_weights, tff_weights):
            for keras_w, tff_w in zip(keras_weights, tff_weights.values()):
                self.assertAllClose(self.evaluate(keras_w),
                                    self.evaluate(tff_w),
                                    atol=1e-4,
                                    msg='Variable [{}]'.format(keras_w.name))

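        # Both the trainable weights and the BatchNorm moving statistics
        # (non-trainable) should round-trip through `ModelWeights`.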
        assert_all_weights_close(keras_model.trainable_weights,
                                 tff_weights.trainable)
        assert_all_weights_close(keras_model.non_trainable_weights,
                                 tff_weights.non_trainable)

    # Newer variant: builds the TFF model from an uncompiled Keras model via
    # `keras_utils.from_keras_model`, passing the input spec, loss and metrics
    # directly.
    def test_keras_model_using_batch_norm(self):
        model = model_examples.build_conv_batch_norm_keras_model()
        input_spec = collections.OrderedDict(
            x=tf.TensorSpec(shape=[None, 28 * 28], dtype=tf.float32),
            y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64))
        tff_model = keras_utils.from_keras_model(
            keras_model=model,
            input_spec=input_spec,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=[NumBatchesCounter(),
                     NumExamplesCounter()])

        batch_size = 2
        batch = collections.OrderedDict(
            x=np.random.uniform(low=0.0, high=1.0,
                                size=[batch_size, 28 * 28]).astype(np.float32),
            y=np.random.randint(low=0, high=10,
                                size=[batch_size, 1]).astype(np.int64))

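        # `forward_pass` defaults to training mode, so BatchNorm statistics are
        # updated just as in the compiled-model variant above.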
        num_train_steps = 2
        for _ in range(num_train_steps):
            self.evaluate(tff_model.forward_pass(batch))

        m = self.evaluate(tff_model.report_local_outputs())
        self.assertEqual(m['num_batches'], [num_train_steps])
        self.assertEqual(m['num_examples'], [batch_size * num_train_steps])
        self.assertGreater(m['loss'][0], 0.0)
        self.assertEqual(m['loss'][1], batch_size * num_train_steps)

        # Ensure we can assign the FL trained model weights to a new model.
        tff_weights = model_utils.ModelWeights.from_model(tff_model)
        keras_model = model_examples.build_conv_batch_norm_keras_model()
        tff_weights.assign_weights_to(keras_model)

        def assert_all_weights_close(keras_weights, tff_weights):
            for keras_w, tff_w in zip(keras_weights, tff_weights):
                self.assertAllClose(self.evaluate(keras_w),
                                    self.evaluate(tff_w),
                                    atol=1e-4,
                                    msg='Variable [{}]'.format(keras_w.name))

        assert_all_weights_close(keras_model.trainable_weights,
                                 tff_weights.trainable)
        assert_all_weights_close(keras_model.non_trainable_weights,
                                 tff_weights.non_trainable)