Esempio n. 1
0
  def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
    """A fully-specified wrong shape on an unknown-shape dataset fails late.

    The element shapes come from a `py_func`, so they are statically unknown.
    The test therefore expects `assert_element_shape` to be applied without
    error and the mismatch to surface as `InvalidArgumentError` when the first
    element is fetched.
    """

    def create_unknown_shape_dataset(x):
      # py_func only declares dtypes, so the static shapes are lost even though
      # the function always returns a length-2 float32 vector and a (3, 4)
      # int32 array.
      return script_ops.py_func(
          lambda _: (  # pylint: disable=g-long-lambda
              np.ones(2, dtype=np.float32),
              np.zeros((3, 4), dtype=np.int32)),
          [x],
          [dtypes.float32, dtypes.int32])

    dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
    unknown_shapes = (tensor_shape.TensorShape(None),
                      tensor_shape.TensorShape(None))
    self.assertEqual(unknown_shapes, dataset.output_shapes)

    # (3, 10) contradicts the real (3, 4) second component.
    wrong_shapes = (tensor_shape.TensorShape(2),
                    tensor_shape.TensorShape((3, 10)))
    iterator = (
        dataset.apply(batching.assert_element_shape(wrong_shapes))
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()
    # cached_session() replaces the deprecated test_session(), matching the
    # other tests in this listing.
    with self.cached_session() as sess:
      sess.run(init_op)
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)
  def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
    """A partially-known wrong shape is only detected at iteration time.

    Because the elements are produced by a `py_func`, their static shapes are
    unknown. The expected shape (None, 10) disagrees with the runtime (3, 4)
    array, and the test expects `InvalidArgumentError` on the first fetch.
    """

    def make_element(x):
      def _produce(_):
        # Runtime values: float32 vector of length 2, int32 array of (3, 4).
        return (np.ones(2, dtype=np.float32),
                np.zeros((3, 4), dtype=np.int32))

      return script_ops.py_func(_produce, [x], [dtypes.float32, dtypes.int32])

    source = dataset_ops.Dataset.range(3).map(make_element)
    self.assertEqual(
        (tensor_shape.TensorShape(None), tensor_shape.TensorShape(None)),
        source.output_shapes)

    mismatched = (tensor_shape.TensorShape(2),
                  tensor_shape.TensorShape((None, 10)))
    checked = source.apply(batching.assert_element_shape(mismatched))
    it = checked.make_initializable_iterator()
    initializer = it.initializer
    next_element = it.get_next()
    with self.cached_session() as sess:
      sess.run(initializer)
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(next_element)
  def test_assert_partial_element_shape(self):
    """Partial expected shapes merge with the dataset's known static shapes.

    The expectation leaves the first component fully unknown and fixes only
    the column count of the second. The asserted dataset should report the
    concrete shapes (2,) and (3, 4) and yield all five elements.
    """

    def create_dataset(_):
      vector = array_ops.ones(2, dtype=dtypes.float32)
      matrix = array_ops.zeros((3, 4), dtype=dtypes.int32)
      return vector, matrix

    source = dataset_ops.Dataset.range(5).map(create_dataset)
    expected_partial = (
        tensor_shape.TensorShape(None),       # nothing known
        tensor_shape.TensorShape((None, 4)),  # only the column count known
    )
    asserted = source.apply(batching.assert_element_shape(expected_partial))
    # After merging, the more specific static shape wins per component.
    merged = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 4)))
    self.assertEqual(merged, asserted.output_shapes)

    it = asserted.make_initializable_iterator()
    initializer = it.initializer
    next_element = it.get_next()
    with self.cached_session() as sess:
      sess.run(initializer)
      # range(5) produces five elements; one more fetch hits end-of-data.
      for _ in range(5):
        sess.run(next_element)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
Esempio n. 4
0
  # NOTE(review): another `test_assert_element_shape` is defined immediately
  # after this one in the same snippet. If both live in the same class, the
  # later definition shadows this one and this test never runs. Consider
  # renaming it (e.g. `test_assert_element_shape_on_unknown_shape_dataset`).
  def test_assert_element_shape(self):
    """Correct expected shapes restore static shape info and iterate cleanly.

    The source dataset's shapes are unknown because they come from a
    `py_func`. Asserting the true shapes (2,) and (3, 4) should make them the
    dataset's `output_shapes`, and all five elements should be readable before
    `OutOfRangeError`.
    """

    def create_unknown_shape_dataset(x):
      # py_func declares only dtypes, so static shapes are unknown even though
      # the returned arrays always have shapes (2,) and (3, 4).
      return script_ops.py_func(
          lambda _: (  # pylint: disable=g-long-lambda
              np.ones(2, dtype=np.float32),
              np.zeros((3, 4), dtype=np.int32)),
          [x],
          [dtypes.float32, dtypes.int32])

    dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
    unknown_shapes = (tensor_shape.TensorShape(None),
                      tensor_shape.TensorShape(None))
    self.assertEqual(unknown_shapes, dataset.output_shapes)

    expected_shapes = (tensor_shape.TensorShape(2),
                       tensor_shape.TensorShape((3, 4)))
    result = dataset.apply(batching.assert_element_shape(expected_shapes))
    self.assertEqual(expected_shapes, result.output_shapes)

    iterator = result.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()
    # cached_session() replaces the deprecated test_session().
    with self.cached_session() as sess:
      sess.run(init_op)
      for _ in range(5):
        sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
  def test_assert_element_shape(self):
    """Asserting already-known, matching shapes keeps them and iterates fully.

    Both the source and the asserted dataset are checked via
    `dataset_ops.get_legacy_output_shapes`. Five elements are read before
    `OutOfRangeError` is expected.
    """

    def create_dataset(_):
      vector = array_ops.ones(2, dtype=dtypes.float32)
      matrix = array_ops.zeros((3, 4), dtype=dtypes.int32)
      return vector, matrix

    shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 4)))
    source = dataset_ops.Dataset.range(5).map(create_dataset)
    self.assertEqual(shapes, dataset_ops.get_legacy_output_shapes(source))

    checked = source.apply(batching.assert_element_shape(shapes))
    self.assertEqual(shapes, dataset_ops.get_legacy_output_shapes(checked))

    it = dataset_ops.make_initializable_iterator(checked)
    initializer = it.initializer
    next_element = it.get_next()
    with self.cached_session() as sess:
      sess.run(initializer)
      for _ in range(5):
        sess.run(next_element)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
Esempio n. 6
0
    def test_assert_partial_element_shape(self):
        """Partially-specified expected shapes are merged with the static ones.

        Expecting (unknown, (None, 4)) against real shapes (2,) and (3, 4)
        should yield output shapes (2,) and (3, 4). All five elements should
        then be readable.
        """
        def create_dataset(_):
            vec = array_ops.ones(2, dtype=dtypes.float32)
            mat = array_ops.zeros((3, 4), dtype=dtypes.int32)
            return vec, mat

        ds = dataset_ops.Dataset.range(5).map(create_dataset)
        partial = (tensor_shape.TensorShape(None),       # nothing known
                   tensor_shape.TensorShape((None, 4)))  # columns only
        asserted = ds.apply(batching.assert_element_shape(partial))
        merged = (tensor_shape.TensorShape(2),
                  tensor_shape.TensorShape((3, 4)))
        self.assertEqual(merged, asserted.output_shapes)

        it = asserted.make_initializable_iterator()
        init = it.initializer
        nxt = it.get_next()
        with self.cached_session() as sess:
            sess.run(init)
            for _ in range(5):
                sess.run(nxt)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(nxt)
Esempio n. 7
0
  def test_assert_element_shape(self):
    """Matching fully-known shapes are accepted and all elements iterate.

    Checks the shapes before and after `assert_element_shape` via
    `dataset_ops.get_legacy_output_shapes`. It then reads five elements and
    expects `OutOfRangeError` on the sixth fetch.
    """

    def create_dataset(_):
      return (array_ops.ones(2, dtype=dtypes.float32),
              array_ops.zeros((3, 4), dtype=dtypes.int32))

    dataset = dataset_ops.Dataset.range(5).map(create_dataset)
    expected_shapes = (tensor_shape.TensorShape(2),
                       tensor_shape.TensorShape((3, 4)))
    self.assertEqual(expected_shapes,
                     dataset_ops.get_legacy_output_shapes(dataset))

    result = dataset.apply(batching.assert_element_shape(expected_shapes))
    self.assertEqual(expected_shapes,
                     dataset_ops.get_legacy_output_shapes(result))

    # Use the module-level helper instead of the deprecated
    # `Dataset.make_initializable_iterator` method, consistent with the
    # already-used `dataset_ops.get_legacy_output_shapes` API.
    iterator = dataset_ops.make_initializable_iterator(result)
    init_op = iterator.initializer
    get_next = iterator.get_next()
    with self.cached_session() as sess:
      sess.run(init_op)
      for _ in range(5):
        sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
    def test_assert_wrong_element_shape(self):
        """A statically-known shape mismatch is rejected with ValueError.

        The real second component is (3, 4), while (3, 10) is expected. The
        test expects the error when the transformation is applied.
        """
        def create_dataset(_):
            first = array_ops.ones(2, dtype=dtypes.float32)
            second = array_ops.zeros((3, 4), dtype=dtypes.int32)
            return first, second

        source = dataset_ops.Dataset.range(3).map(create_dataset)
        mismatched = (tensor_shape.TensorShape(2),
                      tensor_shape.TensorShape((3, 10)))
        with self.assertRaises(ValueError):
            source.apply(batching.assert_element_shape(mismatched))
  def test_assert_wrong_partial_element_shape(self):
    """A partial shape contradicting a known dimension raises ValueError.

    The expected (None, 10) conflicts with the static (3, 4) on the second
    axis. The test expects `apply` to fail.
    """

    def create_dataset(_):
      first = array_ops.ones(2, dtype=dtypes.float32)
      second = array_ops.zeros((3, 4), dtype=dtypes.int32)
      return first, second

    source = dataset_ops.Dataset.range(3).map(create_dataset)
    bad_shapes = (tensor_shape.TensorShape(2),
                  tensor_shape.TensorShape((None, 10)))
    with self.assertRaises(ValueError):
      source.apply(batching.assert_element_shape(bad_shapes))