Exemplo n.º 1
0
def _model_setup():
  """Set up a MNIST Keras model for testing purposes.

  Builds a MNIST Keras model and a synthetic MNIST dataset under a
  `CollectiveAllReduceStrategy` scope and returns them together with the
  batch configuration used to build the dataset.

  Side effect: enables device-placement logging on the eager context
  (`context.set_log_device_placement(True)`), which is a process-wide setting.

  Returns:
    A tuple of (batch_size, steps, train_dataset, model).
  """
  # Turned on before any dataset/model ops are created below.
  context.set_log_device_placement(True)
  batch_size = 64
  steps = 2
  with collective_strategy.CollectiveAllReduceStrategy().scope():
    # TODO(b/142509827): In rare cases this errors out at C++ level with the
    # "Connect failed" error message.
    train_ds, _ = mnist_testing_utils.mnist_synthetic_dataset(batch_size, steps)
    model = mnist_testing_utils.get_mnist_model((28, 28, 1))
  return batch_size, steps, train_ds, model
Exemplo n.º 2
0
    def testLogDevicePlacement(self):
        # The log-device-placement flag starts disabled, can be toggled
        # freely before the eager context is set up, and must always agree
        # with `context.context().log_device_placement`. Once a tensor has
        # been created, setting the flag in either direction must raise.
        self.assertEqual(context.get_log_device_placement(), False)

        for wanted in (True, False):
            context.set_log_device_placement(wanted)
            self.assertEqual(context.get_log_device_placement(), wanted)
            self.assertEqual(context.get_log_device_placement(),
                             context.context().log_device_placement)

        # Creating a tensor presumably initializes the eager context, after
        # which the setting is expected to be locked.
        constant_op.constant(1)
        for wanted in (True, False):
            with self.assertRaises(RuntimeError):
                context.set_log_device_placement(wanted)
Exemplo n.º 3
0
    def testLogDevicePlacement(self):
        # Toggling the flag before initialization must be reflected by the
        # getter and by the context's own property; after
        # `context.ensure_initialized()` any change must raise RuntimeError.
        def set_and_verify(flag):
            context.set_log_device_placement(flag)
            self.assertEqual(context.get_log_device_placement(), flag)
            self.assertEqual(context.get_log_device_placement(),
                             context.context().log_device_placement)

        self.assertFalse(context.get_log_device_placement())
        set_and_verify(True)
        set_and_verify(False)

        context.ensure_initialized()

        for flag in (True, False):
            with self.assertRaises(RuntimeError):
                context.set_log_device_placement(flag)
Exemplo n.º 4
0
  def testLogDevicePlacement(self):
    # The flag starts off, round-trips through set/get while the context is
    # uninitialized, and stays in sync with the context property. After a
    # tensor is created, every attempt to set it must raise RuntimeError.
    self.assertEqual(context.get_log_device_placement(), False)

    for wanted in (True, False):
      context.set_log_device_placement(wanted)
      self.assertEqual(context.get_log_device_placement(), wanted)
      self.assertEqual(
          context.get_log_device_placement(),
          context.context().log_device_placement)

    # Presumably forces eager-context initialization.
    constant_op.constant(1)
    for wanted in (True, False):
      with self.assertRaises(RuntimeError):
        context.set_log_device_placement(wanted)
Exemplo n.º 5
0
    def testLogDevicePlacement(self):
        # Before initialization the flag can be toggled and must match the
        # context property. After `context.ensure_initialized()` a real
        # change (False -> True) must raise, but re-setting the current
        # value (False) is a no-op and must not raise.
        self.assertFalse(context.get_log_device_placement())

        for value in (True, False):
            context.set_log_device_placement(value)
            self.assertEqual(context.get_log_device_placement(), value)
            self.assertEqual(context.get_log_device_placement(),
                             context.context().log_device_placement)

        context.ensure_initialized()

        # The flag is False at this point, so True is a genuine change.
        with self.assertRaises(RuntimeError):
            context.set_log_device_placement(True)

        # Same value as the current setting: no exception expected.
        context.set_log_device_placement(False)
Exemplo n.º 6
0
  def testLogDevicePlacement(self):
    # Round-trip the flag through set/get (checking it against the context
    # property), then verify it can no longer be changed in either
    # direction once the eager context has been initialized.
    def set_and_verify(flag):
      context.set_log_device_placement(flag)
      self.assertEqual(context.get_log_device_placement(), flag)
      self.assertEqual(
          context.get_log_device_placement(),
          context.context().log_device_placement)

    self.assertFalse(context.get_log_device_placement())
    set_and_verify(True)
    set_and_verify(False)

    context.ensure_initialized()

    for flag in (True, False):
      with self.assertRaises(RuntimeError):
        context.set_log_device_placement(flag)
Exemplo n.º 7
0
  def testLogDevicePlacement(self):
    # Toggle the flag pre-initialization and check the getter against the
    # context property. After `context.ensure_initialized()`, only a real
    # change is rejected; re-setting the current value is allowed.
    self.assertFalse(context.get_log_device_placement())

    for value in (True, False):
      context.set_log_device_placement(value)
      self.assertEqual(context.get_log_device_placement(), value)
      self.assertEqual(
          context.get_log_device_placement(),
          context.context().log_device_placement)

    context.ensure_initialized()

    # Current value is False, so switching to True must fail.
    with self.assertRaises(RuntimeError):
      context.set_log_device_placement(True)

    # Re-applying the current value (False) is a no-op and must not raise.
    context.set_log_device_placement(False)
Exemplo n.º 8
0
    def testLogDevicePlacement(self):
        # In this variant the flag stays mutable even after the eager context
        # is initialized: toggling before init must round-trip through the
        # getter and the context property, and enabling it afterwards must
        # not raise.
        def set_and_verify(flag):
            context.set_log_device_placement(flag)
            self.assertEqual(context.get_log_device_placement(), flag)
            self.assertEqual(context.get_log_device_placement(),
                             context.context().log_device_placement)

        self.assertFalse(context.get_log_device_placement())
        set_and_verify(True)
        set_and_verify(False)

        context.ensure_initialized()

        # No exception expected here.
        context.set_log_device_placement(True)
Exemplo n.º 9
0
        # Round trip between GPU:0 and the CPU device using the test's
        # `_send`/`_recv` helpers (defined outside this view; argument order
        # presumably (tensor, name, peer device) / (dtype, name, peer device)
        # -- TODO confirm). `gpu_device_name` is also defined outside this view.
        # Step 1: produce a value on GPU:0 and send it to the CPU device as 't0'.
        with ops.device('GPU:0'):
            t0 = array_ops.identity(1.0)
            self._send(t0, 't0', self.cpu_device)
        # Step 2: on CPU, receive 't0' and check its value, then send a new
        # constant back toward the GPU as 't1'.
        with ops.device('cpu:0'):
            self.assertAllEqual(
                self._recv(dtypes.float32, 't0', gpu_device_name), 1.0)
            self._send(constant_op.constant(2.0), 't1', gpu_device_name)
        # Step 3: on GPU:0, receive 't1' and check its value.
        with ops.device('GPU:0'):
            self.assertAllEqual(
                self._recv(dtypes.float32, 't1', self.cpu_device), 2.0)


class EagerTensorCacheTest(test_util.TensorFlowTestCase):
    """Tests for the private `context._EagerTensorCache` helper."""

    def setUp(self):
        super(EagerTensorCacheTest, self).setUp()
        # Start from a fresh eager context so state from other tests (e.g. a
        # log-device-placement setting) does not leak in.
        context._reset_context()
        configure_virtual_cpus()

    def testCacheSkipsTensorsTooLarge(self):
        # Tensors above `max_tensor_size` (presumably an element count:
        # 4 elements is rejected, 2 is accepted) must not be cached.
        cache = context._EagerTensorCache(max_items=100, max_tensor_size=3)

        # A 2x2 tensor (4 elements) exceeds the limit, so nothing is stored.
        cache.put('1', array_ops.zeros((2, 2)))
        self.assertIsNone(cache.get('1'))

        # A 1-D tensor of 2 elements fits within the limit and is stored.
        # Written as a 1-tuple shape; `(2)` would just be the int 2.
        cache.put('2', array_ops.zeros((2,)))
        self.assertIsNotNone(cache.get('2'))


if __name__ == '__main__':
    # Enable device-placement logging process-wide before running the tests.
    # NOTE(review): tests that assert the flag starts False rely on their
    # setUp resetting the context (e.g. `context._reset_context()`) -- confirm
    # every such test class does so.
    context.set_log_device_placement(True)
    test.main()