def testMapDefunWithVariantTensorAsCaptured(self):
    """Checks map_defun with a mapped function that returns a captured variant.

    A sparse tensor is serialized to a variant tensor outside the mapped
    function. The function discards its int32 argument and returns that outer
    tensor. After mapping over a two-element input and deserializing, the
    result is compared against a sparse value holding the original entries
    once per mapped element (dense_shape [2, 3, 4]).
    """
    # NOTE(review): a method with this exact name is defined again later in
    # this chunk; if both live in the same class body, the later definition
    # replaces this one -- confirm which copy is meant to stay.
    expected_value = sparse_tensor.SparseTensorValue(
        indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
        values=[1, 2, 1, 2],
        dense_shape=[2, 3, 4])

    sparse_input = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    captured_variant = sparse_ops.serialize_sparse_v2(
        sparse_input, out_type=dtypes.variant)

    scalar_int32_spec = tensor_spec.TensorSpec([], dtypes.int32)

    @function.defun(input_signature=[scalar_int32_spec])
    def fn(x):
        # The argument only drives the mapping; the result is the capture.
        del x
        return captured_variant

    mapped_over = constant_op.constant([0, 0])
    map_outputs = map_defun.map_defun(
        fn, [mapped_over], [dtypes.variant], [None])
    restored = sparse_ops.deserialize_sparse(map_outputs[0], dtypes.int32)

    self.assertSparseValuesEqual(expected_value, self.evaluate(restored))
def testMapDefunWithVariantTensor(self):
    """Checks map_defun with an identity function over variant inputs.

    A sparse tensor is serialized to a variant tensor and stacked with itself.
    The stack is passed through map_defun with a function that returns its
    argument unchanged, and the deserialized output is compared against a
    sparse value with dense_shape [2, 3, 4].
    """
    expected_value = sparse_tensor.SparseTensorValue(
        indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
        values=[1, 2, 1, 2],
        dense_shape=[2, 3, 4])

    scalar_variant_spec = tensor_spec.TensorSpec([], dtypes.variant)

    @function.defun(input_signature=[scalar_variant_spec])
    def fn(x):
        return x

    sparse_input = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    single_variant = sparse_ops.serialize_sparse_v2(
        sparse_input, out_type=dtypes.variant)
    stacked_variants = array_ops.stack([single_variant, single_variant])

    map_outputs = map_defun.map_defun(
        fn, [stacked_variants], [dtypes.variant], [None])
    restored = sparse_ops.deserialize_sparse(map_outputs[0], dtypes.int32)

    self.assertValuesEqual(expected_value, self.evaluate(restored))
def testMapDefunWithStrTensor(self):
    """Checks map_defun with an identity function over string inputs.

    A sparse tensor is serialized with a string out_type and stacked with
    itself. The stack is passed through map_defun with a function that returns
    its argument unchanged, and the deserialized output is compared against a
    sparse value with dense_shape [2, 3, 4].
    """
    expected_value = sparse_tensor.SparseTensorValue(
        indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
        values=[1, 2, 1, 2],
        dense_shape=[2, 3, 4])

    scalar_string_spec = tensor_spec.TensorSpec([], dtypes.string)

    @function.defun(input_signature=[scalar_string_spec])
    def fn(x):
        return x

    sparse_input = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    single_serialized = sparse_ops.serialize_sparse_v2(
        sparse_input, out_type=dtypes.string)
    stacked_serialized = array_ops.stack(
        [single_serialized, single_serialized])

    map_outputs = map_defun.map_defun(
        fn, [stacked_serialized], [dtypes.string], [None])
    restored = sparse_ops.deserialize_sparse(map_outputs[0], dtypes.int32)

    self.assertSparseValuesEqual(expected_value, self.evaluate(restored))
def testMapDefunWithVariantTensorAsCaptured(self):
    """Checks that a variant tensor captured by the mapped function works.

    The mapped function ignores its int32 argument and returns a variant
    tensor created outside of it (a serialized sparse tensor). The output of
    map_defun over a two-element input is deserialized and compared against a
    sparse value with dense_shape [2, 3, 4].
    """
    # NOTE(review): an identically named method is defined earlier in this
    # chunk; if both are in the same class body, only this later definition
    # stays bound -- confirm the duplication is intentional.
    source_sparse = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    outer_variant = sparse_ops.serialize_sparse_v2(
        source_sparse, out_type=dtypes.variant)

    signature = [tensor_spec.TensorSpec([], dtypes.int32)]

    @function.defun(input_signature=signature)
    def fn(x):
        # Input is unused on purpose; the outer variant tensor is returned.
        del x
        return outer_variant

    driver = constant_op.constant([0, 0])
    first_output = map_defun.map_defun(
        fn, [driver], [dtypes.variant], [None])[0]
    round_tripped = sparse_ops.deserialize_sparse(first_output, dtypes.int32)

    want = sparse_tensor.SparseTensorValue(
        indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
        values=[1, 2, 1, 2],
        dense_shape=[2, 3, 4])
    got = self.evaluate(round_tripped)
    self.assertSparseValuesEqual(want, got)