Ejemplo n.º 1
0
    def testMapDefunWithVariantTensorAsCaptured(self):
        """Checks that a captured variant tensor is returned for every element.

        The mapped function discards its int32 argument and returns a
        variant-serialized SparseTensor captured from the enclosing scope.
        Mapping over a two-element input should therefore produce two copies
        of that SparseTensor, stacked along a new leading dimension.
        """
        sparse_input = sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
        # Built outside `fn` so it enters the mapped function as a capture
        # rather than as an argument.
        captured_variant = sparse_ops.serialize_sparse_v2(
            sparse_input, out_type=dtypes.variant)

        int_scalar_spec = tensor_spec.TensorSpec([], dtypes.int32)

        @function.defun(input_signature=[int_scalar_spec])
        def fn(x):
            del x  # Only the captured tensor matters; the element is ignored.
            return captured_variant

        elements = constant_op.constant([0, 0])
        outputs = map_defun.map_defun(
            fn, [elements], [dtypes.variant], [None])
        result = sparse_ops.deserialize_sparse(outputs[0], dtypes.int32)

        # Element i of the map contributes one copy of `sparse_input` whose
        # indices are prefixed with i.
        expected = sparse_tensor.SparseTensorValue(
            indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
            values=[1, 2, 1, 2],
            dense_shape=[2, 3, 4])
        self.assertSparseValuesEqual(expected, self.evaluate(result))
    def testMapDefunWithVariantTensor(self):
        """Checks that variant tensors pass through map_defun unchanged.

        Two copies of a variant-serialized SparseTensor are stacked and mapped
        through an identity function. Deserializing the mapped output must
        recover both copies along a new leading dimension.
        """

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.variant)])
        def fn(x):
            return x

        st = sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])

        serialized = sparse_ops.serialize_sparse_v2(
            st, out_type=dtypes.variant)
        # Batch two identical serialized SparseTensors so `fn` is applied to
        # each of the two elements.
        serialized = array_ops.stack([serialized, serialized])
        map_defun_op = map_defun.map_defun(fn, [serialized], [dtypes.variant],
                                           [None])[0]
        deserialized = sparse_ops.deserialize_sparse(
            map_defun_op, dtypes.int32)
        expected = sparse_tensor.SparseTensorValue(
            indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
            values=[1, 2, 1, 2],
            dense_shape=[2, 3, 4])
        actual = self.evaluate(deserialized)
        # Use the sparse-specific assertion, as the sibling map_defun tests do,
        # so indices, values and dense_shape are each compared explicitly.
        self.assertSparseValuesEqual(expected, actual)
Ejemplo n.º 3
0
  def testMapDefunWithStrTensor(self):
    """Checks that string tensors pass through map_defun unchanged.

    A SparseTensor is serialized to a string tensor, two copies are stacked,
    and the stack is mapped through an identity function. Deserializing the
    mapped output must recover both copies along a new leading dimension.
    """
    sparse_input = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    one_serialized = sparse_ops.serialize_sparse_v2(
        sparse_input, out_type=dtypes.string)
    batch = array_ops.stack([one_serialized, one_serialized])

    string_spec = tensor_spec.TensorSpec([], dtypes.string)

    @function.defun(input_signature=[string_spec])
    def fn(x):
      return x  # Identity: each element should come back untouched.

    outputs = map_defun.map_defun(fn, [batch], [dtypes.string], [None])
    result = sparse_ops.deserialize_sparse(outputs[0], dtypes.int32)

    # Element i contributes one copy of `sparse_input` with indices prefixed
    # by i.
    expected = sparse_tensor.SparseTensorValue(
        indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
        values=[1, 2, 1, 2],
        dense_shape=[2, 3, 4])
    self.assertSparseValuesEqual(expected, self.evaluate(result))
Ejemplo n.º 4
0
  def testMapDefunWithVariantTensorAsCaptured(self):
    """Verifies map_defun with a function that returns a captured variant.

    `fn` ignores its int32 input and returns a variant-serialized SparseTensor
    defined in the enclosing scope. Each of the two mapped elements therefore
    yields that same SparseTensor, and deserializing the result must give two
    copies stacked along a new leading dimension.
    """
    expected = sparse_tensor.SparseTensorValue(
        indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
        values=[1, 2, 1, 2],
        dense_shape=[2, 3, 4])

    source = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    # Created before `fn` so the function closes over it as a capture.
    variant_value = sparse_ops.serialize_sparse_v2(
        source, out_type=dtypes.variant)

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def fn(x):
      del x  # The mapped element itself is not used.
      return variant_value

    dummy_elements = constant_op.constant([0, 0])
    mapped = map_defun.map_defun(
        fn, [dummy_elements], [dtypes.variant], [None])[0]
    restored = sparse_ops.deserialize_sparse(mapped, dtypes.int32)
    self.assertSparseValuesEqual(expected, self.evaluate(restored))