def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
  """A fully-defined but wrong shape is reported when an element is fetched.

  The elements are produced by `py_func`, so the dataset's static shapes are
  unknown (asserted below). The test therefore expects `assert_element_shape`
  to be applied without error and the mismatch to surface as an
  `InvalidArgumentError` from `get_next`.
  """

  def create_unknown_shape_dataset(x):
    # `py_func` does not propagate static shapes, which is what makes the
    # dataset's `output_shapes` unknown for both components.
    return script_ops.py_func(
        lambda _: (  # pylint: disable=g-long-lambda
            np.ones(2, dtype=np.float32),
            np.zeros((3, 4), dtype=np.int32)),
        [x],
        [dtypes.float32, dtypes.int32])

  dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
  unknown_shapes = (tensor_shape.TensorShape(None),
                    tensor_shape.TensorShape(None))
  self.assertEqual(unknown_shapes, dataset.output_shapes)

  # The second component is really (3, 4), so (3, 10) must be rejected.
  wrong_shapes = (tensor_shape.TensorShape(2),
                  tensor_shape.TensorShape((3, 10)))
  iterator = (
      dataset.apply(batching.assert_element_shape(wrong_shapes))
      .make_initializable_iterator())
  init_op = iterator.initializer
  get_next = iterator.get_next()
  # `test_session` is deprecated; `cached_session` matches the sibling tests.
  with self.cached_session() as sess:
    sess.run(init_op)
    with self.assertRaises(errors.InvalidArgumentError):
      sess.run(get_next)
def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
  """A partially-known wrong shape is caught only when elements are read.

  Both components come out of `py_func`, so their static shapes are unknown.
  Asserting `(None, 10)` for a component that is actually `(3, 4)` is
  expected to fail with `InvalidArgumentError` on `get_next`.
  """

  def _py_elements(_):
    # Plain NumPy values returned through `py_func`.
    return (np.ones(2, dtype=np.float32),
            np.zeros((3, 4), dtype=np.int32))

  def create_unknown_shape_dataset(x):
    return script_ops.py_func(_py_elements, [x],
                              [dtypes.float32, dtypes.int32])

  source = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
  self.assertEqual(
      (tensor_shape.TensorShape(None), tensor_shape.TensorShape(None)),
      source.output_shapes)

  # Leading dim left unknown; trailing dim 10 disagrees with the real 4.
  bad_shapes = (tensor_shape.TensorShape(2),
                tensor_shape.TensorShape((None, 10)))
  asserted = source.apply(batching.assert_element_shape(bad_shapes))
  it = asserted.make_initializable_iterator()
  init_op = it.initializer
  next_element = it.get_next()
  with self.cached_session() as sess:
    sess.run(init_op)
    with self.assertRaises(errors.InvalidArgumentError):
      sess.run(next_element)
def test_assert_partial_element_shape(self):
  """Partial expected shapes are merged with the dataset's static shapes.

  One expected shape is fully unknown and the other only fixes the trailing
  dimension. After the assertion the resulting dataset reports the fully
  defined shapes, and all five elements can be read.
  """

  def create_dataset(_):
    return (array_ops.ones(2, dtype=dtypes.float32),
            array_ops.zeros((3, 4), dtype=dtypes.int32))

  num_elements = 5
  dataset = dataset_ops.Dataset.range(num_elements).map(create_dataset)
  partial_expected_shape = (
      tensor_shape.TensorShape(None),  # Fully unknown.
      tensor_shape.TensorShape((None, 4)),  # Unknown leading dimension.
  )
  result = dataset.apply(batching.assert_element_shape(partial_expected_shape))

  # Merging the partial shapes with the actual static shapes fills them in.
  self.assertEqual(
      (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 4))),
      result.output_shapes)

  iterator = result.make_initializable_iterator()
  init_op = iterator.initializer
  next_element = iterator.get_next()
  with self.cached_session() as sess:
    sess.run(init_op)
    for _ in range(num_elements):
      sess.run(next_element)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
def test_assert_element_shape(self):
  """Asserting correct shapes on an unknown-shape dataset sets them statically.

  Elements come from `py_func`, so the input dataset's static shapes are
  unknown. After `assert_element_shape` with the true shapes, the result
  reports those shapes, and all five elements are readable before
  `OutOfRangeError`.
  """

  def create_unknown_shape_dataset(x):
    # `py_func` hides the static shapes of its outputs.
    return script_ops.py_func(
        lambda _: (  # pylint: disable=g-long-lambda
            np.ones(2, dtype=np.float32),
            np.zeros((3, 4), dtype=np.int32)),
        [x],
        [dtypes.float32, dtypes.int32])

  dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
  unknown_shapes = (tensor_shape.TensorShape(None),
                    tensor_shape.TensorShape(None))
  self.assertEqual(unknown_shapes, dataset.output_shapes)

  expected_shapes = (tensor_shape.TensorShape(2),
                     tensor_shape.TensorShape((3, 4)))
  result = dataset.apply(batching.assert_element_shape(expected_shapes))
  self.assertEqual(expected_shapes, result.output_shapes)

  iterator = result.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()
  # `test_session` is deprecated; `cached_session` matches the sibling tests.
  with self.cached_session() as sess:
    sess.run(init_op)
    for _ in range(5):
      sess.run(get_next)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
def test_assert_element_shape(self):
  """Asserting a dataset's true static shapes keeps them and its elements.

  The input shapes are already fully defined. Applying `assert_element_shape`
  with the same shapes leaves the legacy output shapes unchanged, and the
  iterator yields all five elements before `OutOfRangeError`.
  """

  def create_dataset(_):
    return (array_ops.ones(2, dtype=dtypes.float32),
            array_ops.zeros((3, 4), dtype=dtypes.int32))

  element_count = 5
  dataset = dataset_ops.Dataset.range(element_count).map(create_dataset)
  shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 4)))
  self.assertEqual(shapes, dataset_ops.get_legacy_output_shapes(dataset))

  asserted = dataset.apply(batching.assert_element_shape(shapes))
  self.assertEqual(shapes, dataset_ops.get_legacy_output_shapes(asserted))

  iterator = dataset_ops.make_initializable_iterator(asserted)
  init_op = iterator.initializer
  next_element = iterator.get_next()
  with self.cached_session() as sess:
    sess.run(init_op)
    for _ in range(element_count):
      sess.run(next_element)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
def test_assert_partial_element_shape(self):
  """Partial shape assertions are refined to the actual static shapes.

  The expected shapes leave the first component entirely unknown and the
  second one's leading dimension unknown. The asserted dataset reports the
  fully-defined shapes and still produces five elements.
  """

  def create_dataset(_):
    float_part = array_ops.ones(2, dtype=dtypes.float32)
    int_part = array_ops.zeros((3, 4), dtype=dtypes.int32)
    return float_part, int_part

  dataset = dataset_ops.Dataset.range(5).map(create_dataset)
  unknown = tensor_shape.TensorShape(None)
  leading_unknown = tensor_shape.TensorShape((None, 4))
  transformed = dataset.apply(
      batching.assert_element_shape((unknown, leading_unknown)))

  # The merge of partial and actual static shapes gives fully-defined shapes.
  merged = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 4)))
  self.assertEqual(merged, transformed.output_shapes)

  itr = transformed.make_initializable_iterator()
  initializer = itr.initializer
  fetch = itr.get_next()
  with self.cached_session() as sess:
    sess.run(initializer)
    for _ in range(5):
      sess.run(fetch)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(fetch)
def test_assert_element_shape(self):
  """Asserting a dataset's true static shapes keeps them and its elements.

  The input shapes are already fully defined. Applying `assert_element_shape`
  with the same shapes leaves the legacy output shapes unchanged, and the
  iterator yields all five elements before `OutOfRangeError`.
  """

  def create_dataset(_):
    return (array_ops.ones(2, dtype=dtypes.float32),
            array_ops.zeros((3, 4), dtype=dtypes.int32))

  dataset = dataset_ops.Dataset.range(5).map(create_dataset)
  expected_shapes = (tensor_shape.TensorShape(2),
                     tensor_shape.TensorShape((3, 4)))
  self.assertEqual(expected_shapes,
                   dataset_ops.get_legacy_output_shapes(dataset))

  result = dataset.apply(batching.assert_element_shape(expected_shapes))
  self.assertEqual(expected_shapes,
                   dataset_ops.get_legacy_output_shapes(result))

  # Use the free function rather than the deprecated
  # `Dataset.make_initializable_iterator` method, consistent with the
  # `get_legacy_output_shapes` calls above.
  iterator = dataset_ops.make_initializable_iterator(result)
  init_op = iterator.initializer
  get_next = iterator.get_next()
  with self.cached_session() as sess:
    sess.run(init_op)
    for _ in range(5):
      sess.run(get_next)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
def test_assert_wrong_element_shape(self):
  """A mismatching fully-defined shape is rejected when the transform is applied.

  The dataset's static shapes are known ((2,) and (3, 4)). Asserting (3, 10)
  for the second component is expected to raise `ValueError` from `apply`.
  """

  def create_dataset(_):
    ones = array_ops.ones(2, dtype=dtypes.float32)
    zeros = array_ops.zeros((3, 4), dtype=dtypes.int32)
    return ones, zeros

  dataset = dataset_ops.Dataset.range(3).map(create_dataset)
  mismatched = (tensor_shape.TensorShape(2),
                tensor_shape.TensorShape((3, 10)))
  with self.assertRaises(ValueError):
    dataset.apply(batching.assert_element_shape(mismatched))
def test_assert_wrong_partial_element_shape(self):
  """A mismatching partial shape is rejected when the transform is applied.

  The second component's static shape is (3, 4). Asserting (None, 10) conflicts
  in the trailing dimension, so `apply` is expected to raise `ValueError`.
  """

  def create_dataset(_):
    ones = array_ops.ones(2, dtype=dtypes.float32)
    zeros = array_ops.zeros((3, 4), dtype=dtypes.int32)
    return ones, zeros

  dataset = dataset_ops.Dataset.range(3).map(create_dataset)
  conflicting = (tensor_shape.TensorShape(2),
                 tensor_shape.TensorShape((None, 10)))
  with self.assertRaises(ValueError):
    dataset.apply(batching.assert_element_shape(conflicting))