def testAddManyTakeManyRoundTripBatched(self):
        """Checks a batched deconstruct/reconstruct round trip of sparse inputs.

        Two string SparseTensors with dense shape [4, 5] are deconstructed
        into per-row handles, rows 1..3 of each are stacked into a single
        handle vector, and the result is reconstructed with batch_size=2 and
        num_unroll=3. Row 0 of each input is left out of the stack, so the
        entries "a" and "b" must not appear in the output.
        """
        with self.test_session(use_gpu=False) as sess:
            # Concrete values fed at run time. Both inputs have dense shape
            # [4, 5], hence N == 4 rows per input.
            first_indices = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
            first_values = np.array([b"a", b"b", b"c"])
            first_shape = np.array([4, 5], dtype=np.int64)
            second_indices = np.array([[1, 4], [2, 3]], dtype=np.int64)
            second_values = np.array([b"d", b"e"])
            second_shape = np.array([4, 5], dtype=np.int64)

            # Placeholder-backed sparse tensors, one per input sequence.
            first_sparse = sparse_tensor.SparseTensor(
                array_ops.placeholder(dtypes.int64),
                array_ops.placeholder(dtypes.string),
                array_ops.placeholder(dtypes.int64))
            first_inputs = {"key": first_sparse}
            second_sparse = sparse_tensor.SparseTensor(
                array_ops.placeholder(dtypes.int64),
                array_ops.placeholder(dtypes.string),
                array_ops.placeholder(dtypes.int64))
            second_inputs = {"key": second_sparse}

            # Both deconstructions use the same shared_name; only the keys
            # and tensor list of the first call are reused below.
            first_seq, sparse_keys, sparse_tensors = (
                sqss._deconstruct_sparse_tensor_seq(first_inputs,
                                                    shared_name="a"))
            first_handles = first_seq["key"]
            second_seq, _, _ = sqss._deconstruct_sparse_tensor_seq(
                second_inputs, shared_name="a")
            second_handles = second_seq["key"]

            # Handles 1..3 of the first input, followed by handles 1..3 of
            # the second input.
            stacked_handles = array_ops.stack([
                handles[row]
                for handles in (first_handles, second_handles)
                for row in (1, 2, 3)
            ])
            batched = {"key": stacked_handles}
            # The return value is discarded; the entry is read back from
            # `batched` afterwards, so this call presumably rewrites the
            # dict in place -- confirm against sqss.
            sqss._reconstruct_sparse_tensor_seq(
                batched, sparse_keys, sparse_tensors,
                batch_size=2, num_unroll=3)

            feeds = {
                first_sparse.indices: first_indices,
                first_sparse.values: first_values,
                first_sparse.dense_shape: first_shape,
                second_sparse.indices: second_indices,
                second_sparse.values: second_values,
                second_sparse.dense_shape: second_shape,
            }
            (result,) = sess.run([batched["key"]], feed_dict=feeds)

            # Expected layout is [batch, unroll step, column]: "c" sat in
            # row 2 of the first input, "d" and "e" in rows 1 and 2 of the
            # second.
            expected_indices = np.array(
                [[0, 1, 0], [1, 0, 4], [1, 1, 3]], dtype=np.int64)
            expected_values = np.array([b"c", b"d", b"e"])
            expected_shape = np.array([2, 3, 5], dtype=np.int64)

            self.assertAllEqual(result.indices, expected_indices)
            self.assertAllEqual(result.values, expected_values)
            self.assertAllEqual(result.dense_shape, expected_shape)
  def testAddManyTakeManyRoundTripBatched(self):
    """Round-trips two sparse inputs through stacked handles, batched.

    Each input is a string SparseTensor with dense shape [4, 5]. Handles
    1..3 of each input are stacked (handle 0 is skipped) and reconstructed
    with batch_size=2 and num_unroll=3; the test then compares the indices,
    values and dense shape of the fetched result with hand-written
    expectations.
    """

    def _placeholder_sparse():
      # Indices, values and dense shape are all fed at session run time.
      return sparse_tensor.SparseTensor(
          array_ops.placeholder(dtypes.int64),
          array_ops.placeholder(dtypes.string),
          array_ops.placeholder(dtypes.int64))

    with self.test_session(use_gpu=False) as sess:
      # N == 4 for both inputs because each dense shape is [4, 5].
      idx_a = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
      vals_a = np.array([b"a", b"b", b"c"])
      shape_a = np.array([4, 5], dtype=np.int64)
      sparse_a = _placeholder_sparse()

      idx_b = np.array([[1, 4], [2, 3]], dtype=np.int64)
      vals_b = np.array([b"d", b"e"])
      shape_b = np.array([4, 5], dtype=np.int64)
      sparse_b = _placeholder_sparse()

      # Deconstruct both inputs under the same shared name; keys and the
      # tensor list are kept from the first call only.
      seq_a, keys_a, tensor_list_a = sqss._deconstruct_sparse_tensor_seq(
          {"key": sparse_a}, shared_name="a")
      seq_b, _, _ = sqss._deconstruct_sparse_tensor_seq(
          {"key": sparse_b}, shared_name="a")
      handles_a = seq_a["key"]
      handles_b = seq_b["key"]

      # Three handles from the first input, then three from the second.
      picked = [handles_a[1], handles_a[2], handles_a[3]]
      picked += [handles_b[1], handles_b[2], handles_b[3]]
      batched = {"key": array_ops.stack(picked)}
      # Result is read back out of `batched` below, so the call presumably
      # updates the dict in place -- confirm against sqss.
      sqss._reconstruct_sparse_tensor_seq(
          batched, keys_a, tensor_list_a, batch_size=2, num_unroll=3)

      fetched = sess.run(
          [batched["key"]],
          feed_dict={
              sparse_a.indices: idx_a,
              sparse_a.values: vals_a,
              sparse_a.dense_shape: shape_a,
              sparse_b.indices: idx_b,
              sparse_b.values: vals_b,
              sparse_b.dense_shape: shape_b,
          })
      roundtrip, = fetched

      # Only "c", "d" and "e" live in the stacked rows; "a" and "b" sit in
      # the skipped row 0 of the first input.
      want_indices = np.array(
          [[0, 1, 0], [1, 0, 4], [1, 1, 3]], dtype=np.int64)
      want_values = np.array([b"c", b"d", b"e"])
      want_shape = np.array([2, 3, 5], dtype=np.int64)

      self.assertAllEqual(roundtrip.indices, want_indices)
      self.assertAllEqual(roundtrip.values, want_values)
      self.assertAllEqual(roundtrip.dense_shape, want_shape)