from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.utils.testing import integration
from official.vision.image_classification import mnist_main


# Define the MNIST flags once, when this test module is first imported.
mnist_main.define_mnist_flags()


def eager_strategy_combinations():
  """Builds the distribution-strategy combinations for eager-mode tests.

  Returns:
    The result of `combinations.combine` over the strategies
    `default_strategy`, `cloud_tpu_strategy` and `one_device_strategy_gpu`
    (from `strategy_combinations`), each paired with `mode="eager"`.
  """
  eager_strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=eager_strategies, mode="eager")


class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
  """Unit tests for the sample Keras MNIST model."""
def setUpClass(cls): # pylint: disable=invalid-name super(KerasMnistTest, cls).setUpClass() mnist_main.define_mnist_flags()