def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
                                       padded_shape):
  """Tests padded batching of sparse tensor windows.

  Builds a ragged sparse dataset, windows it with a window size of
  `len(shapes)`, applies `batching.padded_batch_window` to every window
  component, and compares the first produced element with the expected one.

  Args:
    structure: the input structure
    shapes: the input shapes
    dtype: the input data type
    padded_shape: the shape to pad the output to
  """

  def fn(*args):
    # A single non-tuple argument is a leaf window: pad-and-batch it directly.
    if len(args) == 1 and not isinstance(args[0], tuple):
      return batching.padded_batch_window(args[0], padded_shape)

    # Otherwise recurse into nested tuples and pad-and-batch every leaf.
    return tuple(
        fn(*arg) if isinstance(arg, tuple) else
        batching.padded_batch_window(arg, padded_shape) for arg in args)

  dataset = self._structuredRaggedSparseDataset(
      structure, shapes, dtype).apply(
          grouping.window_dataset(len(shapes))).apply(
              grouping._map_x_dataset(fn))
  get_next = dataset.make_one_shot_iterator().get_next()
  # `test_session()` is deprecated in TensorFlow's test utilities;
  # `cached_session()` is its replacement.
  with self.cached_session() as sess:
    expected = sess.run(
        self._structuredRaggedSparseElement(structure, shapes, dtype,
                                            padded_shape))
    actual = sess.run(get_next)
    self._assertEqual(expected, actual)
def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
                                                   padded_shape):
  """Tests padded batching of dynamically shaped sparse tensor windows.

  The input shapes are fed through a placeholder when the iterator is
  initialized, so they are not known when the dataset is constructed.

  Args:
    shapes: the input shapes
    padded_shape: the shape to pad the output to
  """
  shapes_t = array_ops.placeholder(dtypes.int32)
  dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
      lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
          self._make_dense_to_sparse_fn(False)).apply(
              grouping.window_dataset(len(shapes))).apply(
                  grouping._map_x_dataset(
                      lambda x: batching.padded_batch_window(x, padded_shape)))
  iterator = dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()
  # `test_session()` is deprecated in TensorFlow's test utilities;
  # `cached_session()` is its replacement.
  with self.cached_session() as sess:
    # The concrete shapes are supplied only here, via the placeholder.
    sess.run(init_op, {shapes_t: shapes})
    expected = sess.run(
        self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
                                            padded_shape))
    actual = sess.run(get_next)
    self._assertEqual(expected, actual)
def testWindowDatasetBatchSparse(self, structure, shape, dtype):
  """Tests batching of sparse tensor windows.

  Repeats the structured sparse dataset five times, groups it into windows
  of size five, applies `batching.batch_window` to every window component,
  and compares the first produced element with the expected one, whose shape
  is `shape` with a leading 5 prepended.

  Args:
    structure: the input structure
    shape: the input shape
    dtype: the input data type
  """

  def fn(*args):
    # A single non-tuple argument is a leaf window: batch it directly.
    if len(args) == 1 and not isinstance(args[0], tuple):
      return batching.batch_window(args[0])

    # Otherwise recurse into nested tuples and batch every leaf.
    return tuple(
        fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg)
        for arg in args)

  dataset = self._structuredSparseDataset(
      structure, shape, dtype).repeat(5).apply(
          grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
  get_next = dataset.make_one_shot_iterator().get_next()
  # `test_session()` is deprecated in TensorFlow's test utilities;
  # `cached_session()` is its replacement.
  with self.cached_session() as sess:
    expected = sess.run(
        self._structuredSparseElement(
            structure, np.concatenate(([5], shape), axis=0), dtype))
    actual = sess.run(get_next)
    self._assertEqual(expected, actual)
def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
                                       padded_shape):
  """Checks padded batching applied to windows of sparse tensors.

  Args:
    structure: structure of the input
    shapes: shapes of the input
    dtype: data type of the input
    padded_shape: shape the output is padded to
  """

  def pad_leaf(window):
    return batching.padded_batch_window(window, padded_shape)

  def pad_nested(*components):
    single_leaf = len(components) == 1 and not isinstance(
        components[0], tuple)
    if single_leaf:
      return pad_leaf(components[0])

    padded = []
    for component in components:
      if isinstance(component, tuple):
        padded.append(pad_nested(*component))
      else:
        padded.append(pad_leaf(component))
    return tuple(padded)

  source = self._structuredRaggedSparseDataset(structure, shapes, dtype)
  windowed = source.apply(grouping.window_dataset(len(shapes)))
  dataset = windowed.apply(grouping._map_x_dataset(pad_nested))
  get_next = dataset.make_one_shot_iterator().get_next()

  with self.cached_session() as sess:
    expected_element = self._structuredRaggedSparseElement(
        structure, shapes, dtype, padded_shape)
    expected = sess.run(expected_element)
    actual = sess.run(get_next)
    self._assertEqual(expected, actual)
def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
                                                   padded_shape):
  """Checks padded batching of sparse windows whose shapes are fed at runtime.

  Args:
    shapes: shapes of the input, supplied through a placeholder
    padded_shape: shape the output is padded to
  """
  shapes_t = array_ops.placeholder(dtypes.int32)

  def make_zeros(shape):
    return array_ops.zeros(shape, dtype=dtypes.int32)

  def pad_window(window):
    return batching.padded_batch_window(window, padded_shape)

  dense = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(make_zeros)
  sparse = dense.map(self._make_dense_to_sparse_fn(False))
  windowed = sparse.apply(grouping.window_dataset(len(shapes)))
  dataset = windowed.apply(grouping._map_x_dataset(pad_window))

  iterator = dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op, feed_dict={shapes_t: shapes})
    expected_element = self._structuredRaggedSparseElement(
        None, shapes, dtypes.int32, padded_shape)
    expected = sess.run(expected_element)
    actual = sess.run(get_next)
    self._assertEqual(expected, actual)
def testWindowDatasetBatchSparse(self, structure, shape, dtype):
  """Checks batching applied to windows of sparse tensors.

  Args:
    structure: structure of the input
    shape: shape of the input
    dtype: data type of the input
  """

  def batch_nested(*components):
    single_leaf = len(components) == 1 and not isinstance(
        components[0], tuple)
    if single_leaf:
      return batching.batch_window(components[0])

    batched = []
    for component in components:
      if isinstance(component, tuple):
        batched.append(batch_nested(*component))
      else:
        batched.append(batching.batch_window(component))
    return tuple(batched)

  repeated = self._structuredSparseDataset(structure, shape, dtype).repeat(5)
  windowed = repeated.apply(grouping.window_dataset(5))
  dataset = windowed.apply(grouping._map_x_dataset(batch_nested))
  get_next = dataset.make_one_shot_iterator().get_next()

  with self.cached_session() as sess:
    # The expected element has the window size prepended to the input shape.
    batched_shape = np.concatenate(([5], shape), axis=0)
    expected_element = self._structuredSparseElement(
        structure, batched_shape, dtype)
    expected = sess.run(expected_element)
    actual = sess.run(get_next)
    self._assertEqual(expected, actual)
def testWindowDatasetFlatMap(self, structure, shape, dtype):
  """Checks that windowing followed by `flat_map` yields the input elements.

  Args:
    structure: structure of the input
    shape: shape of the input
    dtype: data type of the input
  """

  def flatten_windows(*components):
    single_leaf = len(components) == 1 and not isinstance(
        components[0], tuple)
    if single_leaf:
      return components[0]

    zipped_parts = []
    for component in components:
      if isinstance(component, tuple):
        zipped_parts.append(flatten_windows(*component))
      else:
        zipped_parts.append(component)
    return dataset_ops.Dataset.zip(tuple(zipped_parts))

  repeated = self._structuredDataset(structure, shape, dtype).repeat(5)
  windowed = repeated.apply(grouping.window_dataset(5))
  dataset = windowed.flat_map(flatten_windows)
  get_next = dataset.make_one_shot_iterator().get_next()

  with self.cached_session() as sess:
    expected_element = self._structuredElement(structure, shape, dtype)
    expected = sess.run(expected_element)
    # Each of the five elements pulled must equal the expected element.
    for _ in range(5):
      actual = sess.run(get_next)
      self._assertEqual(expected, actual)
def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
  """Tests invalid padded batching of dense tensor windows.

  Builds windows of zero-filled int32 tensors with the given shapes, applies
  `batching.padded_batch_window` with `padded_shape`, and expects evaluating
  the result to raise `InvalidArgumentError`.

  Args:
    shapes: the input shapes
    padded_shape: the shape to pad the output to
  """
  dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
      lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
          grouping.window_dataset(len(shapes))).apply(
              grouping._map_x_dataset(
                  lambda x: batching.padded_batch_window(x, padded_shape)))
  get_next = dataset.make_one_shot_iterator().get_next()
  # `test_session()` is deprecated in TensorFlow's test utilities;
  # `cached_session()` is its replacement.
  with self.cached_session() as sess:
    with self.assertRaises(errors.InvalidArgumentError):
      sess.run(get_next)
def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
  """Checks that padded batching of dense windows raises for these inputs.

  Args:
    shapes: shapes of the input
    padded_shape: shape the output is padded to
  """

  def make_zeros(shape):
    return array_ops.zeros(shape, dtype=dtypes.int32)

  def pad_window(window):
    return batching.padded_batch_window(window, padded_shape)

  dense = dataset_ops.Dataset.from_tensor_slices(shapes).map(make_zeros)
  windowed = dense.apply(grouping.window_dataset(len(shapes)))
  dataset = windowed.apply(grouping._map_x_dataset(pad_window))
  get_next = dataset.make_one_shot_iterator().get_next()

  # Evaluating the next element is expected to fail with InvalidArgumentError.
  with self.cached_session() as sess, self.assertRaises(
      errors.InvalidArgumentError):
    sess.run(get_next)
def testWindowDatasetBatchDenseDynamicShape(self, shape):
  """Checks batching of dense windows whose shape is fed at runtime.

  Args:
    shape: shape of the input, supplied through a placeholder
  """
  shape_t = array_ops.placeholder(dtypes.int32)

  zeros = array_ops.zeros(shape_t)
  repeated = dataset_ops.Dataset.from_tensors(zeros).repeat(5)
  windowed = repeated.apply(grouping.window_dataset(5))
  dataset = windowed.apply(grouping._map_x_dataset(batching.batch_window))

  iterator = dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op, feed_dict={shape_t: shape})
    # The expected element has the window size prepended to the input shape.
    batched_shape = np.concatenate(([5], shape), axis=0)
    expected_element = self._structuredElement(None, batched_shape,
                                               dtypes.int32)
    expected = sess.run(expected_element)
    actual = sess.run(get_next)
    self._assertEqual(expected, actual)
def testWindowDatasetFlatMap(self, structure, shape, dtype):
  """Tests windowing by chaining it with flat map.

  Windows the structured dataset with a window size of five, flattens the
  windows back with `flat_map`, and compares the first produced element with
  the expected structured element.

  Args:
    structure: the input structure
    shape: the input shape
    dtype: the input data type
  """

  def fn(*args):
    # A single non-tuple argument is already a dataset: return it unchanged.
    if len(args) == 1 and not isinstance(args[0], tuple):
      return args[0]

    # Otherwise zip the nested components, recursing into nested tuples.
    return dataset_ops.Dataset.zip(
        tuple(fn(*arg) if isinstance(arg, tuple) else arg for arg in args))

  dataset = self._structuredDataset(structure, shape, dtype).apply(
      grouping.window_dataset(5)).flat_map(fn)
  get_next = dataset.make_one_shot_iterator().get_next()
  # `test_session()` is deprecated in TensorFlow's test utilities;
  # `cached_session()` is its replacement.
  with self.cached_session() as sess:
    expected = sess.run(self._structuredElement(structure, shape, dtype))
    actual = sess.run(get_next)
    self._assertEqual(expected, actual)