def fn(*args):
  """Pads every window in a (possibly nested) tuple of windows.

  Each non-tuple leaf is passed through `batching.padded_batch_window` with
  the enclosing scope's `padded_shape`. Tuple arguments are unpacked and
  handled recursively, so the result mirrors the nesting of the input.

  Args:
    *args: windows, or tuples of windows, to pad.

  Returns:
    The padded result itself when exactly one non-tuple window is given;
    otherwise a tuple with one entry per positional argument.
  """
  # `padded_shape` is a free variable captured from the enclosing scope.
  is_single_leaf = len(args) == 1 and not isinstance(args[0], tuple)
  if is_single_leaf:
    return batching.padded_batch_window(args[0], padded_shape)

  padded = []
  for item in args:
    if isinstance(item, tuple):
      # Nested structure: recurse with the tuple's elements as arguments.
      padded.append(fn(*item))
    else:
      padded.append(batching.padded_batch_window(item, padded_shape))
  return tuple(padded)
def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes, padded_shape):
  """Checks padded batching of sparse windows whose shapes are fed at run time.

  The element shapes come from a placeholder that is declared without a
  static shape. The values are supplied only when the iterator is
  initialized.

  Args:
    shapes: the input shapes
    padded_shape: the shape to pad the output to
  """
  shapes_t = array_ops.placeholder(dtypes.int32)

  def _zeros_of_shape(shape):
    return array_ops.zeros(shape, dtype=dtypes.int32)

  def _pad_window(window):
    return batching.padded_batch_window(window, padded_shape)

  dense = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(_zeros_of_shape)
  sparse = dense.map(self._make_dense_to_sparse_fn(False))
  # Group all elements into a single window of size len(shapes).
  windowed = sparse.apply(grouping.window_dataset(len(shapes)))
  dataset = windowed.apply(grouping._map_x_dataset(_pad_window))

  iterator = dataset.make_initializable_iterator()
  init_op = iterator.initializer
  next_element = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op, {shapes_t: shapes})
    expected_value = sess.run(
        self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
                                            padded_shape))
    actual_value = sess.run(next_element)
    self._assertEqual(expected_value, actual_value)
def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
  """Checks that invalid padded batching of dense windows raises an error.

  Running the pipeline is expected to raise `errors.InvalidArgumentError`.

  NOTE(review): another method with this exact name is defined later in this
  file. If both are in the same class body, the later definition replaces
  this one and this test never runs — consider removing one of them.

  Args:
    shapes: the input shapes
    padded_shape: the shape to pad the output to
  """
  def _zeros_of_shape(shape):
    return array_ops.zeros(shape, dtype=dtypes.int32)

  def _pad_window(window):
    return batching.padded_batch_window(window, padded_shape)

  window_size = len(shapes)
  dataset = (
      dataset_ops.Dataset.from_tensor_slices(shapes)
      .map(_zeros_of_shape)
      .apply(grouping.window_dataset(window_size))
      .apply(grouping._map_x_dataset(_pad_window)))
  next_element = dataset.make_one_shot_iterator().get_next()

  with self.cached_session() as sess, \
      self.assertRaises(errors.InvalidArgumentError):
    sess.run(next_element)
def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
  """Tests invalid padded batching of dense tensor windows.

  Builds int32 zero tensors from `shapes` and groups them into one window of
  size `len(shapes)`. It then pads that window to `padded_shape` and expects
  evaluation to raise `errors.InvalidArgumentError`.

  NOTE(review): a method with this exact name is defined earlier in this
  file. If both live in the same class body, this later definition replaces
  the earlier one — consider removing the duplicate.

  Args:
    shapes: the input shapes
    padded_shape: the shape to pad the output to
  """
  dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
      lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
          grouping.window_dataset(len(shapes))).apply(
              grouping._map_x_dataset(
                  lambda x: batching.padded_batch_window(x, padded_shape)))
  get_next = dataset.make_one_shot_iterator().get_next()
  # `test_session()` is deprecated; use `cached_session()` as the sibling
  # tests in this file do.
  with self.cached_session() as sess:
    with self.assertRaises(errors.InvalidArgumentError):
      sess.run(get_next)