def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
                                           padded_shape):
        """Tests padded batching of sparse tensor windows.

        Builds a dataset with `_structuredRaggedSparseDataset`, groups it into
        windows of `len(shapes)` elements, applies
        `batching.padded_batch_window` to every window component, and compares
        the first element produced with the value built by
        `_structuredRaggedSparseElement`.

        Args:
          structure: the input structure
          shapes: the input shapes
          dtype: the input data type
          padded_shape: the shape to pad the output to
        """
        def fn(*args):
            # A single non-tuple argument is a leaf window: pad and batch it.
            if len(args) == 1 and not isinstance(args[0], tuple):
                return batching.padded_batch_window(args[0], padded_shape)

            # Otherwise recurse into nested tuples so every leaf is handled.
            return tuple([
                fn(*arg) if isinstance(arg, tuple) else
                batching.padded_batch_window(arg, padded_shape) for arg in args
            ])

        dataset = self._structuredRaggedSparseDataset(
            structure, shapes,
            dtype).apply(grouping.window_dataset(len(shapes))).apply(
                grouping._map_x_dataset(fn))
        get_next = dataset.make_one_shot_iterator().get_next()
        # `test_session()` is deprecated; `cached_session()` is its documented
        # replacement on `tf.test.TestCase`.
        with self.cached_session() as sess:
            expected = sess.run(
                self._structuredRaggedSparseElement(structure, shapes, dtype,
                                                    padded_shape))
            actual = sess.run(get_next)
            self._assertEqual(expected, actual)
    def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
                                                       padded_shape):
        """Tests padded batching of dynamically shaped sparse tensor windows.

        The shapes are fed through an `int32` placeholder when the iterator
        is initialized, so they are not fixed when the dataset is built. Each
        shape becomes a zero-filled `int32` tensor, is passed through the
        function returned by `_make_dense_to_sparse_fn(False)`, grouped into
        windows of `len(shapes)` elements, and padded and batched with
        `batching.padded_batch_window`.

        Args:
          shapes: the input shapes
          padded_shape: the shape to pad the output to
        """

        shapes_t = array_ops.placeholder(dtypes.int32)
        dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
            lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
                self._make_dense_to_sparse_fn(False)
            ).apply(grouping.window_dataset(len(shapes))).apply(
                grouping._map_x_dataset(
                    lambda x: batching.padded_batch_window(x, padded_shape)))
        iterator = dataset.make_initializable_iterator()
        init_op = iterator.initializer
        get_next = iterator.get_next()
        # `test_session()` is deprecated; `cached_session()` is its documented
        # replacement on `tf.test.TestCase`.
        with self.cached_session() as sess:
            # The concrete shapes are supplied only here, at initialization.
            sess.run(init_op, {shapes_t: shapes})
            expected = sess.run(
                self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
                                                    padded_shape))
            actual = sess.run(get_next)
            self._assertEqual(expected, actual)
    def testWindowDatasetBatchSparse(self, structure, shape, dtype):
        """Tests batching of sparse tensor windows.

        Builds a dataset with `_structuredSparseDataset`, repeats it 5 times,
        groups it into windows of 5 elements, and applies
        `batching.batch_window` to every window component. The expected value
        is built by `_structuredSparseElement` with a leading dimension of 5
        prepended to `shape`.

        Args:
          structure: the input structure
          shape: the input shape
          dtype: the input data type
        """
        def fn(*args):
            # A single non-tuple argument is a leaf window: batch it directly.
            if len(args) == 1 and not isinstance(args[0], tuple):
                return batching.batch_window(args[0])

            # Otherwise recurse into nested tuples so every leaf is handled.
            return tuple([
                fn(*arg)
                if isinstance(arg, tuple) else batching.batch_window(arg)
                for arg in args
            ])

        dataset = self._structuredSparseDataset(
            structure, shape, dtype).repeat(5).apply(
                grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
        get_next = dataset.make_one_shot_iterator().get_next()
        # `test_session()` is deprecated; `cached_session()` is its documented
        # replacement on `tf.test.TestCase`.
        with self.cached_session() as sess:
            expected = sess.run(
                self._structuredSparseElement(
                    structure, np.concatenate(([5], shape), axis=0), dtype))
            actual = sess.run(get_next)
            self._assertEqual(expected, actual)
  def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
                                                     padded_shape):
    """Checks padded batching of sparse windows with run-time shapes.

    The shapes are fed through an `int32` placeholder when the iterator is
    initialized. Each shape becomes a zero-filled `int32` tensor, is passed
    through the function returned by `_make_dense_to_sparse_fn(False)`,
    grouped into windows of `len(shapes)` elements, and then padded and
    batched with `batching.padded_batch_window`. The first element produced
    is compared with the value built by `_structuredRaggedSparseElement`.

    Args:
      shapes: the input shapes
      padded_shape: the shape to pad the output to
    """

    def _zeros_of(shape):
      return array_ops.zeros(shape, dtype=dtypes.int32)

    def _pad_window(window):
      return batching.padded_batch_window(window, padded_shape)

    shapes_ph = array_ops.placeholder(dtypes.int32)

    # Build the pipeline one stage at a time.
    zeros_ds = dataset_ops.Dataset.from_tensor_slices(shapes_ph).map(_zeros_of)
    sparse_ds = zeros_ds.map(self._make_dense_to_sparse_fn(False))
    windowed_ds = sparse_ds.apply(grouping.window_dataset(len(shapes)))
    dataset = windowed_ds.apply(grouping._map_x_dataset(_pad_window))

    iterator = dataset.make_initializable_iterator()
    initializer = iterator.initializer
    next_element = iterator.get_next()
    with self.cached_session() as sess:
      # The concrete shapes are supplied only here, at initialization.
      sess.run(initializer, feed_dict={shapes_ph: shapes})
      expected_element = self._structuredRaggedSparseElement(
          None, shapes, dtypes.int32, padded_shape)
      expected = sess.run(expected_element)
      actual = sess.run(next_element)
      self._assertEqual(expected, actual)
  def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
                                         padded_shape):
    """Checks padded batching of (possibly nested) sparse tensor windows.

    Builds a dataset with `_structuredRaggedSparseDataset`, groups it into
    windows of `len(shapes)` elements, applies `batching.padded_batch_window`
    to every window component, and compares the first element produced with
    the value built by `_structuredRaggedSparseElement`.

    Args:
      structure: the input structure
      shapes: the input shapes
      dtype: the input data type
      padded_shape: the shape to pad the output to
    """

    def _pad_leaf(window):
      return batching.padded_batch_window(window, padded_shape)

    def fn(*args):
      # A single non-tuple argument is a leaf window.
      if len(args) == 1 and not isinstance(args[0], tuple):
        return _pad_leaf(args[0])

      # Otherwise walk the arguments, recursing into nested tuples.
      batched = []
      for item in args:
        if isinstance(item, tuple):
          batched.append(fn(*item))
        else:
          batched.append(_pad_leaf(item))
      return tuple(batched)

    source_ds = self._structuredRaggedSparseDataset(structure, shapes, dtype)
    windowed_ds = source_ds.apply(grouping.window_dataset(len(shapes)))
    dataset = windowed_ds.apply(grouping._map_x_dataset(fn))
    next_element = dataset.make_one_shot_iterator().get_next()
    with self.cached_session() as sess:
      expected_element = self._structuredRaggedSparseElement(
          structure, shapes, dtype, padded_shape)
      expected = sess.run(expected_element)
      actual = sess.run(next_element)
      self._assertEqual(expected, actual)
  def testWindowDatasetBatchSparse(self, structure, shape, dtype):
    """Checks batching of (possibly nested) sparse tensor windows.

    Builds a dataset with `_structuredSparseDataset`, repeats it 5 times,
    groups it into windows of 5 elements, and applies `batching.batch_window`
    to every window component. The expected value is built by
    `_structuredSparseElement` with a leading dimension of 5 prepended to
    `shape`.

    Args:
      structure: the input structure
      shape: the input shape
      dtype: the input data type
    """

    def fn(*args):
      # A single non-tuple argument is a leaf window.
      if len(args) == 1 and not isinstance(args[0], tuple):
        return batching.batch_window(args[0])

      # Otherwise walk the arguments, recursing into nested tuples.
      batched = []
      for item in args:
        if isinstance(item, tuple):
          batched.append(fn(*item))
        else:
          batched.append(batching.batch_window(item))
      return tuple(batched)

    repeated_ds = self._structuredSparseDataset(
        structure, shape, dtype).repeat(5)
    windowed_ds = repeated_ds.apply(grouping.window_dataset(5))
    dataset = windowed_ds.apply(grouping._map_x_dataset(fn))
    next_element = dataset.make_one_shot_iterator().get_next()
    with self.cached_session() as sess:
      # The batch dimension (5) is prepended to the per-element shape.
      batched_shape = np.concatenate(([5], shape), axis=0)
      expected = sess.run(
          self._structuredSparseElement(structure, batched_shape, dtype))
      actual = sess.run(next_element)
      self._assertEqual(expected, actual)
    def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
        """Tests invalid padded batching of dense tensor windows.

        Each entry of `shapes` becomes a zero-filled `int32` tensor; the
        tensors are grouped into windows of `len(shapes)` elements and passed
        to `batching.padded_batch_window` with `padded_shape`. Fetching the
        first element is expected to raise `errors.InvalidArgumentError`.

        Args:
          shapes: the input shapes
          padded_shape: the shape to pad the output to
        """

        dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
            lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)
        ).apply(grouping.window_dataset(len(shapes))).apply(
            grouping._map_x_dataset(
                lambda x: batching.padded_batch_window(x, padded_shape)))
        get_next = dataset.make_one_shot_iterator().get_next()
        # `test_session()` is deprecated; `cached_session()` is its documented
        # replacement on `tf.test.TestCase`.
        with self.cached_session() as sess:
            # The error is expected when the element is fetched.
            with self.assertRaises(errors.InvalidArgumentError):
                sess.run(get_next)
  def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
    """Checks that invalid padded batching of dense windows is rejected.

    Each entry of `shapes` becomes a zero-filled `int32` tensor; the tensors
    are grouped into windows of `len(shapes)` elements and passed to
    `batching.padded_batch_window` with `padded_shape`. Fetching the first
    element is expected to raise `errors.InvalidArgumentError`.

    Args:
      shapes: the input shapes
      padded_shape: the shape to pad the output to
    """

    def _zeros_of(shape):
      return array_ops.zeros(shape, dtype=dtypes.int32)

    def _pad_window(window):
      return batching.padded_batch_window(window, padded_shape)

    zeros_ds = dataset_ops.Dataset.from_tensor_slices(shapes).map(_zeros_of)
    windowed_ds = zeros_ds.apply(grouping.window_dataset(len(shapes)))
    dataset = windowed_ds.apply(grouping._map_x_dataset(_pad_window))
    next_element = dataset.make_one_shot_iterator().get_next()
    with self.cached_session() as sess:
      # The error is expected when the element is fetched.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(next_element)
Пример #9
0
  def testWindowDatasetBatchDenseDynamicShape(self, shape):
    """Checks batching of dense windows whose shape is fed at run time.

    The shape is fed through an `int32` placeholder when the iterator is
    initialized. A zero-filled tensor of that shape is repeated 5 times,
    grouped into windows of 5 elements, and batched with
    `batching.batch_window`. The expected value is built by
    `_structuredElement` with a leading dimension of 5 prepended to `shape`.

    Args:
      shape: the input shape
    """

    shape_ph = array_ops.placeholder(dtypes.int32)

    # Build the pipeline one stage at a time.
    repeated_ds = dataset_ops.Dataset.from_tensors(
        array_ops.zeros(shape_ph)).repeat(5)
    windowed_ds = repeated_ds.apply(grouping.window_dataset(5))
    dataset = windowed_ds.apply(grouping._map_x_dataset(batching.batch_window))

    iterator = dataset.make_initializable_iterator()
    initializer = iterator.initializer
    next_element = iterator.get_next()
    with self.cached_session() as sess:
      # The concrete shape is supplied only here, at initialization.
      sess.run(initializer, feed_dict={shape_ph: shape})
      batched_shape = np.concatenate(([5], shape), axis=0)
      expected = sess.run(
          self._structuredElement(None, batched_shape, dtypes.int32))
      actual = sess.run(next_element)
      self._assertEqual(expected, actual)