Example #1
0
 def testRoundTrip(self, dtype, shape):
     """Round-trip a tensor through a DLPack capsule and compare values.

     The source tensor is deleted before the capsule is decoded; the test
     expects the capsule on its own to be enough to rebuild the tensor.
     """
     np.random.seed(42)
     expected = np.random.randint(0, 10, shape)
     source = constant_op.constant(expected, dtype=dtype)
     capsule = dlpack.to_dlpack(source)
     # Drop the Python reference; decoding the capsule should still succeed.
     del source
     restored = dlpack.from_dlpack(capsule)
     self.assertAllClose(expected, restored)
    def testTensorsCanBeConsumedOnceOnly(self):
        """Check that a DLPack capsule cannot be decoded a second time.

        The first ``from_dlpack`` call on the capsule succeeds. A second call
        on the same capsule is expected to raise an exception whose message
        says a DLPack tensor may be consumed at most once.
        """
        np.random.seed(42)
        np_array = np.random.randint(0, 10, (2, 3, 4))
        tf_tensor = constant_op.constant(np_array, dtype=np.float32)
        dlcapsule = dlpack.to_dlpack(tf_tensor)
        del tf_tensor  # the capsule should remain usable without the tensor
        _ = dlpack.from_dlpack(dlcapsule)  # first, valid consumption

        # A capsule can be consumed only once, so a second decode must fail.
        with self.assertRaisesRegex(
                Exception, ".*a DLPack tensor may be consumed at most once.*"):
            dlpack.from_dlpack(dlcapsule)
 def testRoundTrip(self, dtype, shape):
     """Round-trip a tensor through DLPack and check values and placement.

     The tensor is passed through ``array_ops.identity`` so it is copied to
     the GPU when one is available. After decoding, the values must match.
     int32 results are expected on the CPU; any other dtype must come back
     on the same device the source tensor was on.
     """
     np.random.seed(42)
     np_array = np.random.randint(0, 10, shape)
     # Copy to the GPU if one is available.
     source = array_ops.identity(constant_op.constant(np_array, dtype=dtype))
     # Capture metadata up front: the tensor is deleted before decoding.
     source_device = source.device
     source_dtype = source.dtype
     capsule = dlpack.to_dlpack(source)
     del source  # the capsule alone should be enough to rebuild the tensor
     restored = dlpack.from_dlpack(capsule)
     self.assertAllClose(np_array, restored)
     if source_dtype == dtypes.int32:
         # int32 tensors are currently always placed on the CPU.
         self.assertEqual(restored.device,
                          "/job:localhost/replica:0/task:0/device:CPU:0")
         return
     self.assertEqual(source_device, restored.device)
 def testDLPackFromWithoutContextInitialization(self):
     """Decode a capsule after the eager context has been reset.

     The capsule is created first. The context is then reset with
     ``context._reset_context()``, and decoding the capsule must not raise.
     """
     source = constant_op.constant(1)
     capsule = dlpack.to_dlpack(source)
     # Resetting the context must not break decoding of the capsule.
     context._reset_context()
     _ = dlpack.from_dlpack(capsule)
 def testMustPassTensorArgumentToDLPack(self):
     """Check that to_dlpack rejects a plain Python list.

     The call is expected to raise ``errors.InvalidArgumentError`` with a
     message saying the argument must be a TF tensor.
     """
     expected_message = (
         "The argument to `to_dlpack` must be a TF tensor, not Python object")
     with self.assertRaisesRegex(errors.InvalidArgumentError,
                                 expected_message):
         dlpack.to_dlpack([1])
 def UnsupportedComplex64():
     """Try to export a 2x2 complex64 constant through ``dlpack.to_dlpack``.

     NOTE(review): this looks like a callable handed to an assertRaises-style
     check for an unsupported dtype; the caller is not visible here.
     """
     values = [[1, 4], [5, 2]]
     complex_tensor = constant_op.constant(values, dtype=dtypes.complex64)
     _ = dlpack.to_dlpack(complex_tensor)
 def UnsupportedQint16():
     """Try to export a 2x2 qint16 constant through ``dlpack.to_dlpack``.

     NOTE(review): this looks like a callable handed to an assertRaises-style
     check for an unsupported dtype; the caller is not visible here.
     """
     values = [[1, 4], [5, 2]]
     quantized_tensor = constant_op.constant(values, dtype=dtypes.qint16)
     _ = dlpack.to_dlpack(quantized_tensor)