def testSerializeDeserializeNestedBatch(self):
    """Round-trips a 2x2 nested batch built from one serialized SparseTensor.

    The same serialized value is stacked twice, and the result is stacked
    twice again, so the deserialized tensor should hold four copies of the
    input. Each copy's indices are prefixed with its position in the 2x2
    batch, and the combined dense shape is expected to be [2, 2, 5, 6].
    """
    with self.test_session(use_gpu=False) as sess:
      sp_input = self._SparseTensorValue_5x6(np.arange(6))
      single = sparse_ops.serialize_sparse(sp_input)
      row = array_ops.stack([single, single])
      batch = array_ops.stack([row, row])

      sp_deserialized = sparse_ops.deserialize_sparse(
          batch, dtype=dtypes.int32)

      indices, values, shape = sess.run(sp_deserialized)

      # (rows of the combined result, expected leading batch coordinates).
      # The assertions expect each copy to span 6 rows; the final slice is
      # left open-ended so that any trailing rows are included in the check.
      expected_blocks = [
          (slice(None, 6), [0, 0]),
          (slice(6, 12), [0, 1]),
          (slice(12, 18), [1, 0]),
          (slice(18, None), [1, 1]),
      ]
      for rows, batch_coords in expected_blocks:
        self.assertAllEqual(indices[rows, :2], [batch_coords] * 6)
        self.assertAllEqual(indices[rows, 2:], sp_input[0])
        self.assertAllEqual(values[rows], sp_input[1])

      self.assertAllEqual(shape, [2, 2, 5, 6])
  def testDeserializeFailsInconsistentRank(self):
    """`deserialize_many_sparse` rejects batch elements of differing rank.

    The two placeholders are fed values from `_SparseTensorValue_5x6` and
    `_SparseTensorValue_1x1x1`. Their serializations are stacked into one
    batch, and running the deserialization must raise the inconsistent-rank
    OpError matched below.
    """
    with self.test_session(use_gpu=False) as sess:
      first = self._SparseTensorPlaceholder()
      second = self._SparseTensorPlaceholder()
      feed = {
          first: self._SparseTensorValue_5x6(np.arange(6)),
          second: self._SparseTensorValue_1x1x1(),
      }
      batch = array_ops.stack([
          sparse_ops.serialize_sparse(first),
          sparse_ops.serialize_sparse(second),
      ])

      sp_deserialized = sparse_ops.deserialize_many_sparse(
          batch, dtype=dtypes.int32)

      with self.assertRaisesOpError(
          r"Inconsistent rank across SparseTensors: rank prior to "
          r"SparseTensor\[1\] was: 3 but rank of SparseTensor\[1\] is: 4"):
        sess.run(sp_deserialized, feed)
  def testDeserializeFailsWrongType(self):
    """Asking for int64 values from int32 SparseTensors raises an OpError.

    Two placeholders are fed values from `_SparseTensorValue_5x6` and
    `_SparseTensorValue_3x4`. Their serializations are batched and
    deserialized with `dtype=int64`, and the op must fail reporting that
    element 0 holds int32 values.
    """
    with self.test_session(use_gpu=False) as sess:
      first = self._SparseTensorPlaceholder()
      second = self._SparseTensorPlaceholder()
      feed = {
          first: self._SparseTensorValue_5x6(np.arange(6)),
          second: self._SparseTensorValue_3x4(np.arange(6)),
      }
      batch = array_ops.stack([
          sparse_ops.serialize_sparse(first),
          sparse_ops.serialize_sparse(second),
      ])

      sp_deserialized = sparse_ops.deserialize_many_sparse(
          batch, dtype=dtypes.int64)

      with self.assertRaisesOpError(
          r"Requested SparseTensor of type int64 but "
          r"SparseTensor\[0\].values.dtype\(\) == int32"):
        sess.run(sp_deserialized, feed)
  def testSerializeDeserializeMany(self):
    """Round-trips two differently sized SparseTensors as a single batch.

    The serialized 5x6 and 3x4 inputs are stacked and deserialized with
    `deserialize_many_sparse`. Each input's rows should appear in order,
    prefixed by its batch position, and the combined shape should be [2, 5, 6].
    """
    with self.test_session(use_gpu=False) as sess:
      inputs = [
          self._SparseTensorValue_5x6(np.arange(6)),
          self._SparseTensorValue_3x4(np.arange(6)),
      ]
      batch = array_ops.stack(
          [sparse_ops.serialize_sparse(sp) for sp in inputs])

      sp_deserialized = sparse_ops.deserialize_many_sparse(
          batch, dtype=dtypes.int32)

      indices, values, shape = sess.run(sp_deserialized)

      # Element 0 occupies the first 6 rows; element 1 everything after.
      row_slices = [slice(None, 6), slice(6, None)]
      for position, (rows, sp_input) in enumerate(zip(row_slices, inputs)):
        self.assertAllEqual(indices[rows, 0], [position] * 6)
        self.assertAllEqual(indices[rows, 1:], sp_input[0])
        self.assertAllEqual(values[rows], sp_input[1])

      self.assertAllEqual(shape, [2, 5, 6])
  def testSerializeDeserialize(self):
    """A single SparseTensor survives serialize_sparse -> deserialize_sparse.

    The deserialized indices, values and dense shape must equal the
    corresponding components of the original value.
    """
    with self.test_session(use_gpu=False) as sess:
      original = self._SparseTensorValue_5x6(np.arange(6))
      round_tripped = sparse_ops.deserialize_sparse(
          sparse_ops.serialize_sparse(original), dtype=dtypes.int32)

      indices, values, shape = sess.run(round_tripped)

      for component, got in enumerate((indices, values, shape)):
        self.assertAllEqual(got, original[component])
  def testDeserializeFailsInvalidProto(self):
    """A batch row of arbitrary strings fails to parse during deserialization.

    Row 0 is a real serialized SparseTensor and row 1 is ["a", "b", "c"].
    Running `deserialize_many_sparse` must raise an OpError naming
    serialized_sparse[1, 0].
    """
    with self.test_session(use_gpu=False) as sess:
      placeholder = self._SparseTensorPlaceholder()
      feed = {placeholder: self._SparseTensorValue_5x6(np.arange(6))}
      # Not a serialized SparseTensor; the op should fail to parse it.
      bogus_row = ["a", "b", "c"]
      batch = array_ops.stack(
          [sparse_ops.serialize_sparse(placeholder), bogus_row])

      sp_deserialized = sparse_ops.deserialize_many_sparse(
          batch, dtype=dtypes.int32)

      with self.assertRaisesOpError(
          r"Could not parse serialized_sparse\[1, 0\]"):
        sess.run(sp_deserialized, feed)
Example #7
0
def serialize_sparse_tensors(tensors):
  """Serializes every SparseTensor found in a nested structure.

  Args:
    tensors: a (possibly nested) structure of tensors, as accepted by
      `nest.flatten`.

  Returns:
    A structure with the same nesting as `tensors`. Each
    `sparse_tensor.SparseTensor` leaf is replaced by
    `sparse_ops.serialize_sparse(leaf, out_type=dtypes.variant)`; every other
    leaf is passed through unchanged.
  """

  def _serialize_if_sparse(leaf):
    if isinstance(leaf, sparse_tensor.SparseTensor):
      return sparse_ops.serialize_sparse(leaf, out_type=dtypes.variant)
    return leaf

  flat_leaves = [_serialize_if_sparse(leaf) for leaf in nest.flatten(tensors)]
  return nest.pack_sequence_as(tensors, flat_leaves)
 def testVariantSerializeDeserializeScalar(self):
   """Round-trips a rank-0 SparseTensor through a variant-typed serialization.

   The fed value has a single entry (37), one empty index row and an empty
   dense shape. After serialize_sparse(out_type=variant) followed by
   deserialize_sparse, every component must come back unchanged.
   """
   with self.session(use_gpu=False) as sess:
     indices_value = np.array([[]], dtype=np.int64)
     values_value = np.array([37], dtype=np.int32)
     shape_value = np.array([], dtype=np.int64)

     # Named so it does not shadow the `sparse_tensor` module.
     sp_placeholder = self._SparseTensorPlaceholder()
     as_variant = sparse_ops.serialize_sparse(
         sp_placeholder, out_type=dtypes.variant)
     round_tripped = sparse_ops.deserialize_sparse(
         as_variant, dtype=dtypes.int32)

     feed = {
         sp_placeholder.indices: indices_value,
         sp_placeholder.values: values_value,
         sp_placeholder.dense_shape: shape_value,
     }
     result = sess.run(round_tripped, feed_dict=feed)

     self.assertAllEqual(result.indices, indices_value)
     self.assertAllEqual(result.values, values_value)
     self.assertAllEqual(result.dense_shape, shape_value)
Example #9
0
 def _maybe_serialize(t):
   """Serializes `t` if it is an `ops.SparseTensor`; otherwise returns it as-is.

   The choice between `serialize_many_sparse` and `serialize_sparse` follows
   the free variable `enqueue_many` (presumably bound in the enclosing
   function; not visible here).
   """
   if isinstance(t, ops.SparseTensor):
     if enqueue_many:
       return sparse_ops.serialize_many_sparse(t)
     return sparse_ops.serialize_sparse(t)
   return t
Example #10
0
 def _maybe_serialize(t):
     """Returns `t` serialized when it is an `ops.SparseTensor`, else `t`.

     Uses `serialize_many_sparse` when the free variable `enqueue_many` is
     truthy (presumably from the enclosing scope), and `serialize_sparse`
     otherwise.
     """
     if isinstance(t, ops.SparseTensor):
         serialize = (sparse_ops.serialize_many_sparse if enqueue_many
                      else sparse_ops.serialize_sparse)
         return serialize(t)
     return t
Example #11
0
 def _to_tensor_list(self, value):
     """Returns a one-element list holding `value` serialized as a variant."""
     serialized = sparse_ops.serialize_sparse(value, out_type=dtypes.variant)
     return [serialized]
Example #12
0
 def _to_tensor_list(self, value):
   """Serializes `value` to a variant tensor and wraps it in a list."""
   variant = sparse_ops.serialize_sparse(value, out_type=dtypes.variant)
   return [variant]
Example #13
0
 def _maybe_serialize(t, is_sparse):
   """Serializes `t` when `is_sparse` is truthy; otherwise returns it unchanged.

   With `enqueue_many` truthy (a free variable, presumably from the enclosing
   scope) `serialize_many_sparse` is used, else `serialize_sparse`.
   """
   if is_sparse:
     if enqueue_many:
       return sparse_ops.serialize_many_sparse(t)
     return sparse_ops.serialize_sparse(t)
   return t
Example #14
0
 def _maybe_serialize(t, is_sparse):
     """Returns `t` serialized if `is_sparse` is truthy, else `t` itself.

     The serializer is `serialize_many_sparse` when the free variable
     `enqueue_many` (presumably from the enclosing scope) is truthy, and
     `serialize_sparse` otherwise.
     """
     if is_sparse:
         serialize = (sparse_ops.serialize_many_sparse if enqueue_many
                      else sparse_ops.serialize_sparse)
         return serialize(t)
     return t