Beispiel #1
0
    def test_out_of_range_with_for_loop(self):
        """Stop a training loop cleanly when the dataset runs out.

        A two-element dataset placed on '/job:worker/task:0' is consumed one
        batch per step while a variable (starting at 1.0) accumulates each
        batch's mean. Up to three steps are attempted; when the iterator is
        exhausted the OutOfRangeError is caught, the pending error state is
        cleared and the loop exits. The variable must end at
        1.0 + 1.0 + 2.0 = 4.0.
        """
        with ops.device('/job:worker/task:0'):
            remote_iter = iter(
                dataset_ops.Dataset.from_tensor_slices([1.0, 2.0]).batch(
                    1, drop_remainder=False))
            total = variables.Variable(1.0)

        @def_function.function
        def train_step(it):
            batch = next(it)
            total.assign_add(math_ops.reduce_mean(batch))

        num_steps = 3
        last_step = num_steps - 1
        for step in range(num_steps):
            try:
                with ops.device('/job:worker/task:0'):
                    train_step(remote_iter)
                # Assuming async execution is configured outside this method,
                # errors may only surface on a wait, so flush after the last
                # step.
                if step == last_step:
                    context.async_wait()
            except errors.OutOfRangeError:
                # Reset the error state so later ops are not affected.
                context.async_clear_error()
                break

        self.assertAllEqual(total.numpy(), 4.0)
Beispiel #2
0
    def testCopyBetweenDevicesAsync(self):
        """Copy a tensor CPU -> GPU -> GPU -> CPU under async execution.

        Afterwards, copying to a GPU index beyond the available devices must
        raise RuntimeError.
        """
        with context.execution_mode(context.ASYNC):
            # Each copy returns a new tensor; the chain keeps the same order.
            x = (constant_op.constant([[1., 2.], [3., 4.]])
                 .cpu().gpu().gpu().cpu())
            context.async_wait()

        # Invalid device. NOTE: this check runs after the ASYNC scope above
        # has exited.
        with self.assertRaises(RuntimeError):
            bad_gpu_index = context.context().num_gpus() + 1
            x.gpu(bad_gpu_index)
            context.async_wait()
        context.async_clear_error()
Beispiel #3
0
  def testCopyBetweenDevicesAsync(self):
    """Round-trip a tensor between CPU and GPU with async execution enabled.

    Then check that copying to an out-of-range GPU index raises RuntimeError.
    """
    with context.execution_mode(context.ASYNC):
      on_host = constant_op.constant([[1., 2.], [3., 4.]]).cpu()
      # The second .gpu() copies a tensor that is already on a GPU.
      on_device = on_host.gpu().gpu()
      x = on_device.cpu()
      context.async_wait()

    # Invalid device (checked after the ASYNC scope above has exited).
    with self.assertRaises(RuntimeError):
      x.gpu(context.context().num_gpus() + 1)
      context.async_wait()
    context.async_clear_error()
Beispiel #4
0
    def testAsyncExceptionStackTrace(self):
        """An async op error names the Python function that issued the op.

        The matmul below has incompatible shapes. With synchronous execution
        disabled, the InvalidArgumentError is not raised at the call site but
        when the result is read, and its message must mention
        `exception_originated_from_here`.

        The pending async error is cleared and synchronous execution is
        restored in a `finally`, so a failing assertion does not leave the
        process in async mode for subsequent tests.
        """
        config.set_synchronous_execution(False)

        def exception_originated_from_here():
            # Invalid shapes for matmul.
            return math_ops.matmul([[1]], [[2], [3]])

        try:
            # In sync mode, an exception would have been raised here but since
            # this is in async, the exception will be raised next.
            x = exception_originated_from_here()

            with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                        'in exception_originated_from_here'):
                x.numpy()
        finally:
            # Always undo the global mode change, even if the assertion fails.
            context.async_clear_error()
            config.set_synchronous_execution(True)
Beispiel #5
0
 def testExecuteBasicAsync(self):
     """Run raw ops through `execute` with async execution.

     A Mul of the scalars 3 and 5 must produce 15. A MatMul on the same
     scalar inputs must raise InvalidArgumentError. The execution mode is
     reset to SYNC and any pending async error cleared in a `finally`, so a
     failing assertion does not leak global async mode into later tests.
     """
     with context.execution_mode(context.ASYNC):
         three = constant_op.constant(3)
         five = constant_op.constant(5)
         product = execute(b'Mul',
                           num_outputs=1,
                           inputs=[three, five],
                           attrs=('T', three.dtype.as_datatype_enum))[0]
         self.assertAllEqual(15, product)
     # Error: Invalid arguments
     context.set_execution_mode(context.ASYNC)
     try:
         with self.assertRaises(errors.InvalidArgumentError):
             execute(b'MatMul',
                     num_outputs=1,
                     inputs=[three, five],
                     attrs=('transpose_a', False, 'transpose_b', False, 'T',
                            three.dtype.as_datatype_enum))
             # The error may only surface once queued async ops are awaited.
             context.async_wait()
     finally:
         context.async_clear_error()
         context.set_execution_mode(context.SYNC)
Beispiel #6
0
 def testExecuteBasicAsync(self):
   """Run raw ops through `execute` with async execution.

   A Mul of the scalars 3 and 5 must produce 15. A MatMul on the same scalar
   inputs must raise InvalidArgumentError. The execution mode is reset to
   SYNC and any pending async error cleared in a `finally`, so a failing
   assertion does not leak global async mode into later tests.
   """
   with context.execution_mode(context.ASYNC):
     three = constant_op.constant(3)
     five = constant_op.constant(5)
     product = execute(
         b'Mul',
         num_outputs=1,
         inputs=[three, five],
         attrs=('T', three.dtype.as_datatype_enum))[0]
     self.assertAllEqual(15, product)
   # Error: Invalid arguments
   context.set_execution_mode(context.ASYNC)
   try:
     with self.assertRaises(errors.InvalidArgumentError):
       execute(
           b'MatMul',
           num_outputs=1,
           inputs=[three, five],
           attrs=('transpose_a', False, 'transpose_b', False, 'T',
                  three.dtype.as_datatype_enum))
       # The error may only surface once queued async ops are awaited.
       context.async_wait()
   finally:
     context.async_clear_error()
     context.set_execution_mode(context.SYNC)