def testSerializeDeserializeNestedBatch(self):
  """Round-trips a [2, 2] nested batch of serialized SparseTensors."""
  with self.test_session(use_gpu=False) as sess:
    sp_input = self._SparseTensorValue_5x6(np.arange(6))
    serialized = sparse_ops.serialize_sparse(sp_input)
    # Stack twice to build a [2, 2] batch of serialized copies of the same
    # SparseTensor.
    serialized = array_ops.stack([serialized, serialized])
    serialized = array_ops.stack([serialized, serialized])
    sp_deserialized = sparse_ops.deserialize_sparse(
        serialized, dtype=dtypes.int32)

    combined_indices, combined_values, combined_shape = sess.run(
        sp_deserialized)

    # Every minibatch entry carries the same 6 nonzeros; only the two
    # leading batch coordinates of the combined indices differ.
    for i, batch_coord in enumerate([[0, 0], [0, 1], [1, 0], [1, 1]]):
      lo, hi = 6 * i, 6 * (i + 1)
      self.assertAllEqual(combined_indices[lo:hi, :2], [batch_coord] * 6)
      self.assertAllEqual(combined_indices[lo:hi, 2:], sp_input[0])
      self.assertAllEqual(combined_values[lo:hi], sp_input[1])
    self.assertAllEqual(combined_shape, [2, 2, 5, 6])
def testDeserializeFailsInconsistentRank(self):
  """deserialize_many_sparse rejects inputs whose ranks disagree."""
  with self.test_session(use_gpu=False) as sess:
    sp_input0 = self._SparseTensorPlaceholder()
    sp_input1 = self._SparseTensorPlaceholder()
    input0_val = self._SparseTensorValue_5x6(np.arange(6))
    input1_val = self._SparseTensorValue_1x1x1()
    serialized_concat = array_ops.stack([
        sparse_ops.serialize_sparse(sp_input0),
        sparse_ops.serialize_sparse(sp_input1),
    ])
    sp_deserialized = sparse_ops.deserialize_many_sparse(
        serialized_concat, dtype=dtypes.int32)

    # The ranks in the message include the minibatch dimension added by
    # deserialize_many_sparse.
    with self.assertRaisesOpError(
        r"Inconsistent rank across SparseTensors: rank prior to "
        r"SparseTensor\[1\] was: 3 but rank of SparseTensor\[1\] is: 4"):
      sess.run(sp_deserialized,
               {sp_input0: input0_val,
                sp_input1: input1_val})
def testDeserializeFailsWrongType(self):
  """deserialize_many_sparse rejects a dtype mismatch with the inputs."""
  with self.test_session(use_gpu=False) as sess:
    sp_input0 = self._SparseTensorPlaceholder()
    sp_input1 = self._SparseTensorPlaceholder()
    input0_val = self._SparseTensorValue_5x6(np.arange(6))
    input1_val = self._SparseTensorValue_3x4(np.arange(6))
    serialized_concat = array_ops.stack([
        sparse_ops.serialize_sparse(sp_input0),
        sparse_ops.serialize_sparse(sp_input1),
    ])
    # Request int64 values while the serialized tensors hold int32.
    sp_deserialized = sparse_ops.deserialize_many_sparse(
        serialized_concat, dtype=dtypes.int64)

    with self.assertRaisesOpError(
        r"Requested SparseTensor of type int64 but "
        r"SparseTensor\[0\].values.dtype\(\) == int32"):
      sess.run(sp_deserialized,
               {sp_input0: input0_val,
                sp_input1: input1_val})
def testSerializeDeserializeMany(self):
  """Round-trips two SparseTensors through (de)serialize_many_sparse."""
  with self.test_session(use_gpu=False) as sess:
    sp_input0 = self._SparseTensorValue_5x6(np.arange(6))
    sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
    serialized_concat = array_ops.stack([
        sparse_ops.serialize_sparse(sp_input0),
        sparse_ops.serialize_sparse(sp_input1),
    ])
    sp_deserialized = sparse_ops.deserialize_many_sparse(
        serialized_concat, dtype=dtypes.int32)

    combined_indices, combined_values, combined_shape = sess.run(
        sp_deserialized)

    # Minibatch 0 occupies the first six entries of the combined output.
    self.assertAllEqual(combined_indices[:6, 0], [0] * 6)
    self.assertAllEqual(combined_indices[:6, 1:], sp_input0[0])
    self.assertAllEqual(combined_values[:6], sp_input0[1])
    # Minibatch 1 occupies the remainder.
    self.assertAllEqual(combined_indices[6:, 0], [1] * 6)
    self.assertAllEqual(combined_indices[6:, 1:], sp_input1[0])
    self.assertAllEqual(combined_values[6:], sp_input1[1])
    # Batch of 2, with a dense shape large enough for both inputs.
    self.assertAllEqual(combined_shape, [2, 5, 6])
def testSerializeDeserialize(self):
  """A single SparseTensor survives a serialize/deserialize round trip."""
  with self.test_session(use_gpu=False) as sess:
    sp_input = self._SparseTensorValue_5x6(np.arange(6))
    round_tripped = sparse_ops.deserialize_sparse(
        sparse_ops.serialize_sparse(sp_input), dtype=dtypes.int32)

    indices_out, values_out, shape_out = sess.run(round_tripped)

    # All three components must come back unchanged.
    self.assertAllEqual(indices_out, sp_input[0])
    self.assertAllEqual(values_out, sp_input[1])
    self.assertAllEqual(shape_out, sp_input[2])
def testDeserializeFailsInvalidProto(self):
  """deserialize_many_sparse reports which entry could not be parsed."""
  with self.test_session(use_gpu=False) as sess:
    sp_input0 = self._SparseTensorPlaceholder()
    input0_val = self._SparseTensorValue_5x6(np.arange(6))
    serialized0 = sparse_ops.serialize_sparse(sp_input0)
    # A row of strings that are not valid serialized SparseTensors.
    serialized1 = ["a", "b", "c"]
    serialized_concat = array_ops.stack([serialized0, serialized1])
    sp_deserialized = sparse_ops.deserialize_many_sparse(
        serialized_concat, dtype=dtypes.int32)

    with self.assertRaisesOpError(
        r"Could not parse serialized_sparse\[1, 0\]"):
      sess.run(sp_deserialized, {sp_input0: input0_val})
def serialize_sparse_tensors(tensors):
  """Serializes sparse tensors.

  Args:
    tensors: a tensor structure to serialize.

  Returns:
    `tensors` with any sparse tensors replaced by their serialized version.
  """

  def _serialize(t):
    # Only SparseTensors are rewritten; every other element passes through
    # untouched.
    if isinstance(t, sparse_tensor.SparseTensor):
      return sparse_ops.serialize_sparse(t, out_type=dtypes.variant)
    return t

  return nest.pack_sequence_as(
      tensors, [_serialize(t) for t in nest.flatten(tensors)])
def testVariantSerializeDeserializeScalar(self):
  """Round-trips a rank-0 SparseTensor through variant serialization."""
  with self.session(use_gpu=False) as sess:
    # A scalar SparseTensor: empty dense shape and a single value whose
    # index is the empty tuple.
    indices_value = np.array([[]], dtype=np.int64)
    values_value = np.array([37], dtype=np.int32)
    shape_value = np.array([], dtype=np.int64)
    # Renamed from `sparse_tensor`: the old local shadowed the
    # `sparse_tensor` module used elsewhere in this file.
    sp_input = self._SparseTensorPlaceholder()
    serialized = sparse_ops.serialize_sparse(
        sp_input, out_type=dtypes.variant)
    deserialized = sparse_ops.deserialize_sparse(
        serialized, dtype=dtypes.int32)
    deserialized_value = sess.run(
        deserialized,
        feed_dict={
            sp_input.indices: indices_value,
            sp_input.values: values_value,
            sp_input.dense_shape: shape_value
        })
    self.assertAllEqual(deserialized_value.indices, indices_value)
    self.assertAllEqual(deserialized_value.values, values_value)
    self.assertAllEqual(deserialized_value.dense_shape, shape_value)
def _maybe_serialize(t):
  # Non-sparse values pass through unchanged.
  if not isinstance(t, ops.SparseTensor):
    return t
  # `enqueue_many` is a free variable from the enclosing scope: a batch of
  # SparseTensors needs serialize_many_sparse, a single one serialize_sparse.
  serialize_fn = (sparse_ops.serialize_many_sparse
                  if enqueue_many else sparse_ops.serialize_sparse)
  return serialize_fn(t)
def _to_tensor_list(self, value):
  # Represent the SparseTensor as a single variant-encoded tensor so it can
  # travel through a flat tensor list.
  serialized = sparse_ops.serialize_sparse(value, out_type=dtypes.variant)
  return [serialized]
def _maybe_serialize(t, is_sparse):
  # Non-sparse values pass through unchanged.
  if not is_sparse:
    return t
  # `enqueue_many` is a free variable from the enclosing scope: a batch of
  # SparseTensors needs serialize_many_sparse, a single one serialize_sparse.
  serialize_fn = (sparse_ops.serialize_many_sparse
                  if enqueue_many else sparse_ops.serialize_sparse)
  return serialize_fn(t)