def test_out_of_range_with_for_loop(self):
    """Drains a two-element dataset placed on a worker from a Python for-loop.

    The dataset yields the batches [1.0] and [2.0]. Each step adds the batch
    mean to ``v``, which starts at 1.0. The third step runs past the end of
    the iterator. The resulting OutOfRangeError is caught, the async error
    state is cleared, and the loop stops. ``v`` must end at 1.0 + 1.0 + 2.0.
    """
    worker_device = '/job:worker/task:0'
    with ops.device(worker_device):
        batched = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0]).batch(
            1, drop_remainder=False)
        data_iter = iter(batched)
        v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
        batch = next(iterator)
        v.assign_add(math_ops.reduce_mean(batch))

    num_steps = 3
    final_step = num_steps - 1
    for step in range(num_steps):
        try:
            with ops.device(worker_device):
                train_step(data_iter)
            # Only the last step blocks on pending work. Presumably this is
            # so that an error raised asynchronously on the worker surfaces
            # inside this try block.
            if step == final_step:
                context.async_wait()
        except errors.OutOfRangeError:
            context.async_clear_error()
            break

    self.assertAllEqual(v.numpy(), 4.0)
def testCopyBetweenDevicesAsync(self):
    """Copies a tensor back and forth between CPU and GPU in async mode.

    It then checks that asking for a nonexistent GPU index raises
    RuntimeError. The test is skipped when no GPU is visible, because every
    ``x.gpu()`` call below needs one.
    """
    if not context.context().num_gpus():
        self.skipTest('No GPUs found')
    with context.execution_mode(context.ASYNC):
        x = constant_op.constant([[1., 2.], [3., 4.]])
        x = x.cpu()
        x = x.gpu()
        x = x.gpu()
        x = x.cpu()
        context.async_wait()

        # Invalid device: an index beyond the number of visible GPUs.
        try:
            with self.assertRaises(RuntimeError):
                x.gpu(context.context().num_gpus() + 1)
                # In async mode the failure may only surface once pending
                # work is waited on.
                context.async_wait()
        finally:
            # Always reset the async error state, even if the assertion
            # above fails, so it cannot leak into later tests.
            context.async_clear_error()
def testCopyBetweenDevicesAsync(self):
    """Copies a tensor back and forth between CPU and GPU in async mode.

    It then checks that asking for a nonexistent GPU index raises
    RuntimeError. The test is skipped when no GPU is visible, because every
    ``x.gpu()`` call below needs one.
    """
    if not context.context().num_gpus():
        self.skipTest('No GPUs found')
    with context.execution_mode(context.ASYNC):
        x = constant_op.constant([[1., 2.], [3., 4.]])
        x = x.cpu()
        x = x.gpu()
        x = x.gpu()
        x = x.cpu()
        context.async_wait()

        # Invalid device: an index beyond the number of visible GPUs.
        try:
            with self.assertRaises(RuntimeError):
                x.gpu(context.context().num_gpus() + 1)
                # In async mode the failure may only surface once pending
                # work is waited on.
                context.async_wait()
        finally:
            # Always reset the async error state, even if the assertion
            # above fails, so it cannot leak into later tests.
            context.async_clear_error()
def testAsyncExceptionStackTrace(self):
    """An async op error names the Python function that created the op.

    Synchronous execution is turned off for the duration of the test. The
    ``finally`` clause clears any pending async error and restores
    synchronous execution even when the assertion fails, so the global mode
    cannot leak into later tests.
    """
    config.set_synchronous_execution(False)
    try:

        def exception_originated_from_here():
            # Invalid shapes for matmul.
            return math_ops.matmul([[1]], [[2], [3]])

        # In sync mode, an exception would have been raised here but since
        # this is in async, the exception will be raised next.
        x = exception_originated_from_here()

        with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                    'in exception_originated_from_here'):
            x.numpy()
    finally:
        context.async_clear_error()
        config.set_synchronous_execution(True)
def testExecuteBasicAsync(self):
    """Runs raw ``execute`` calls in async mode: one valid op, one invalid.

    ``Mul`` on the scalars 3 and 5 must produce 15. ``MatMul`` on the same
    two scalars is expected to fail with InvalidArgumentError. The error
    section switches the global execution mode to ASYNC. A ``finally`` clause
    clears the pending error and switches back to SYNC, so neither leaks into
    later tests when the assertion fails.
    """
    with context.execution_mode(context.ASYNC):
        three = constant_op.constant(3)
        five = constant_op.constant(5)
        product = execute(
            b'Mul',
            num_outputs=1,
            inputs=[three, five],
            attrs=('T', three.dtype.as_datatype_enum))[0]
        self.assertAllEqual(15, product)
    # Error: Invalid arguments
    context.set_execution_mode(context.ASYNC)
    try:
        with self.assertRaises(errors.InvalidArgumentError):
            execute(
                b'MatMul',
                num_outputs=1,
                inputs=[three, five],
                attrs=('transpose_a', False, 'transpose_b', False, 'T',
                       three.dtype.as_datatype_enum))
            # In async mode the failure may only surface on wait.
            context.async_wait()
    finally:
        context.async_clear_error()
        context.set_execution_mode(context.SYNC)
def testExecuteBasicAsync(self):
    """Runs raw ``execute`` calls in async mode: one valid op, one invalid.

    ``Mul`` on the scalars 3 and 5 must produce 15. ``MatMul`` on the same
    two scalars is expected to fail with InvalidArgumentError. The error
    section switches the global execution mode to ASYNC. A ``finally`` clause
    clears the pending error and switches back to SYNC, so neither leaks into
    later tests when the assertion fails.
    """
    with context.execution_mode(context.ASYNC):
        three = constant_op.constant(3)
        five = constant_op.constant(5)
        product = execute(
            b'Mul',
            num_outputs=1,
            inputs=[three, five],
            attrs=('T', three.dtype.as_datatype_enum))[0]
        self.assertAllEqual(15, product)
    # Error: Invalid arguments
    context.set_execution_mode(context.ASYNC)
    try:
        with self.assertRaises(errors.InvalidArgumentError):
            execute(
                b'MatMul',
                num_outputs=1,
                inputs=[three, five],
                attrs=('transpose_a', False, 'transpose_b', False, 'T',
                       three.dtype.as_datatype_enum))
            # In async mode the failure may only surface on wait.
            context.async_wait()
    finally:
        context.async_clear_error()
        context.set_execution_mode(context.SYNC)