コード例 #1
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.utils.testing import integration
from official.vision.image_classification import mnist_main

# NOTE(review): module-level side effect - define_mnist_flags() runs as soon as
# this module is imported, presumably to register the flags the MNIST tests
# rely on before any test executes. Confirm it is not also invoked elsewhere
# (e.g. from a setUpClass); defining the same absl flags twice is likely to fail.
mnist_main.define_mnist_flags()


def eager_strategy_combinations():
    """Return the parameter combinations used by the eager-mode tests.

    Each combination pairs one of three distribution strategies (the default
    strategy, the Cloud TPU strategy, or the single-GPU one-device strategy)
    with ``mode="eager"``.

    Returns:
        The result of ``combinations.combine`` for those strategies in eager
        mode.
    """
    strategies = [
        strategy_combinations.default_strategy,
        strategy_combinations.cloud_tpu_strategy,
        strategy_combinations.one_device_strategy_gpu,
    ]
    return combinations.combine(distribution=strategies, mode="eager")


class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
    """Unit tests for the sample Keras MNIST model.

    Mixes ``tf.test.TestCase`` with absl's ``parameterized.TestCase`` so test
    methods can be parameterized; presumably over the strategy combinations
    produced by ``eager_strategy_combinations`` - confirm against the test
    methods.
    """
コード例 #2
0
 def setUpClass(cls):  # pylint: disable=invalid-name
   """Run one-time class-level setup before this class's tests.

   Delegates to the parent ``setUpClass`` and then calls
   ``mnist_main.define_mnist_flags()``.

   NOTE(review): unittest invokes ``setUpClass`` on the class itself, so this
   method needs the ``@classmethod`` decorator; none is visible directly
   above this ``def`` - confirm it is present.
   """
   super(KerasMnistTest, cls).setUpClass()
   # NOTE(review): if define_mnist_flags() is also called at module import
   # time in this version of the file, absl may reject the duplicate flag
   # definitions - confirm the flags are defined exactly once.
   mnist_main.define_mnist_flags()