def fn(*args):
      """Pads every leaf of a (possibly nested) tuple structure of windows.

      Each non-tuple leaf is passed to `batching.padded_batch_window` together
      with the `padded_shape` captured from the enclosing scope; tuple
      elements are unpacked and handled by calling `fn` recursively.

      Args:
        *args: leaves accepted by `batching.padded_batch_window` and/or
          (nested) tuples of them.

      Returns:
        The padded result itself when `args` is a single non-tuple value;
        otherwise a tuple with one entry per element of `args`. Because tuple
        elements are splatted into `fn`, a nested 1-tuple holding a single
        non-tuple leaf collapses to the bare padded result.
      """

      def _pad_leaf_or_recurse(item):
        # Tuples are unpacked so that nested structures are processed by
        # `fn` itself; anything else is treated as a leaf window.
        if isinstance(item, tuple):
          return fn(*item)
        return batching.padded_batch_window(item, padded_shape)

      # A lone non-tuple argument is returned unwrapped (not as a 1-tuple).
      if len(args) == 1 and not isinstance(args[0], tuple):
        return _pad_leaf_or_recurse(args[0])
      return tuple([_pad_leaf_or_recurse(item) for item in args])
  def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
                                                     padded_shape):
    """Checks padded batching of sparse windows whose shapes are fed at run time.

    Args:
      shapes: the input shapes
      padded_shape: the shape to pad the output to
    """
    # The shapes are fed through a placeholder declared without a static
    # shape, so the element shapes are only known once the iterator runs.
    shapes_t = array_ops.placeholder(dtypes.int32)

    def _to_zeros(shape):
      return array_ops.zeros(shape, dtype=dtypes.int32)

    def _pad_window(window):
      return batching.padded_batch_window(window, padded_shape)

    dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t)
    dataset = dataset.map(_to_zeros)
    dataset = dataset.map(self._make_dense_to_sparse_fn(False))
    # Window size equals the number of input shapes, so presumably all
    # elements end up in a single window before being pad-batched.
    dataset = dataset.apply(grouping.window_dataset(len(shapes)))
    dataset = dataset.apply(grouping._map_x_dataset(_pad_window))

    iterator = dataset.make_initializable_iterator()
    initializer = iterator.initializer
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(initializer, feed_dict={shapes_t: shapes})
      expected_t = self._structuredRaggedSparseElement(
          None, shapes, dtypes.int32, padded_shape)
      expected = sess.run(expected_t)
      actual = sess.run(next_element)
      self._assertEqual(expected, actual)
  def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
    """Tests invalid padded batching of dense tensor windows.

    Builds all-zero int32 tensors with the given `shapes`, groups them into a
    window of `len(shapes)` elements, pads that window to `padded_shape`, and
    expects evaluating the result to raise `InvalidArgumentError` (presumably
    because some input shape does not fit within `padded_shape` -- confirm
    against the parameters supplied to this test).

    Args:
      shapes: the input shapes
      padded_shape: the shape to pad the output to
    """
    dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
        lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
            grouping.window_dataset(len(shapes))).apply(
                grouping._map_x_dataset(
                    lambda x: batching.padded_batch_window(x, padded_shape)))
    get_next = dataset.make_one_shot_iterator().get_next()
    with self.cached_session() as sess:
      # The error surfaces only when the padded batch is actually evaluated.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)