Example #1
0
    def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                      gpu_compatible):
        """Checks `get_next_as_optional` over a dataset repeating `np_value`.

        The dataset has three elements. The first three optionals must hold
        `np_value`. The next two must be empty, and reading their value must
        raise `InvalidArgumentError`. In graph mode the test additionally
        requires evaluation before iterator initialization to raise
        `FailedPreconditionError`.

        Args:
          np_value: Value the dataset is built from and compared against.
          tf_value_fn: Zero-argument callable whose result supplies the
            expected element type spec.
          gpu_compatible: If false, the case is skipped when a GPU is present.
        """
        if not gpu_compatible and test.is_gpu_available():
            self.skipTest("Test case not yet supported on GPU.")
        ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)

        def check_structure(optional):
            # The optional must be an `Optional` whose element spec matches
            # the spec derived from `tf_value_fn()`.
            self.assertIsInstance(optional, optional_ops.Optional)
            actual_spec = optional.element_spec
            expected_spec = structure.type_spec_from_value(tf_value_fn())
            self.assertTrue(structure.are_compatible(actual_spec,
                                                     expected_spec))

        if context.executing_eagerly():
            iterator = dataset_ops.make_one_shot_iterator(ds)
            # One present optional per dataset element.
            for _ in range(3):
                optional = iterator_ops.get_next_as_optional(iterator)
                check_structure(optional)
                self.assertTrue(optional.has_value())
                self.assertValuesEqual(np_value, optional.get_value())
            # Past the end: empty optional, and get_value() fails.
            for _ in range(2):
                optional = iterator_ops.get_next_as_optional(iterator)
                self.assertFalse(self.evaluate(optional.has_value()))
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(optional.get_value())
            return

        # Graph mode from here on.
        iterator = dataset_ops.make_initializable_iterator(ds)
        optional = iterator_ops.get_next_as_optional(iterator)
        check_structure(optional)
        has_value_t = optional.has_value()
        value_t = optional.get_value()

        # Without running the initializer, both tensors must fail.
        for pending in (has_value_t, value_t):
            with self.assertRaises(errors.FailedPreconditionError):
                self.evaluate(pending)

        self.evaluate(iterator.initializer)
        # Each run of the graph pulls the next dataset element.
        for _ in range(3):
            present, value = self.evaluate([has_value_t, value_t])
            self.assertTrue(present)
            self.assertValuesEqual(np_value, value)

        # Once exhausted, has_value() is false and get_value() errors out.
        for _ in range(2):
            self.assertFalse(self.evaluate(has_value_t))
            with self.assertRaises(errors.InvalidArgumentError):
                self.evaluate(value_t)
  def testSkipEagerIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                             works_on_gpu):
    """Graph-only check of `get_next_as_optional` on a 3-element dataset.

    Verifies that the optional's value structure is compatible with the
    structure of `tf_value_fn()`, and that both tensors raise
    `FailedPreconditionError` before initialization. It also checks that the
    first three fetches yield `np_value`, and that later fetches are empty
    and raise `InvalidArgumentError` from `get_value()`.
    """
    if not works_on_gpu and test.is_gpu_available():
      self.skipTest("Test case not yet supported on GPU.")
    dataset = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
    iterator = dataset.make_initializable_iterator()
    optional = iterator_ops.get_next_as_optional(iterator)
    self.assertIsInstance(optional, optional_ops.Optional)
    self.assertTrue(
        optional.value_structure.is_compatible_with(
            structure.Structure.from_value(tf_value_fn())))
    has_value_t = optional.has_value()
    value_t = optional.get_value()
    with self.cached_session() as sess:
      # Reading through an uninitialized iterator must fail.
      for pending in (has_value_t, value_t):
        with self.assertRaises(errors.FailedPreconditionError):
          sess.run(pending)

      sess.run(iterator.initializer)
      # Every dataset element surfaces as a present optional.
      for _ in range(3):
        present, value = sess.run([has_value_t, value_t])
        self.assertTrue(present)
        self._assertElementValueEqual(np_value, value)

      # After exhaustion the optional is empty and get_value() errors out.
      for _ in range(2):
        self.assertFalse(sess.run(has_value_t))
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(value_t)
  def testIteratorGetNextAsOptionalOnGPU(self):
    """Runs the `get_next_as_optional` checks on a dataset copied to GPU.

    Skipped when no GPU is available. The range(3) dataset is copied to
    "/gpu:0", and the iterator and optional tensors are built under that
    device scope.
    """
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    device_dataset = dataset_ops.Dataset.range(3).apply(
        prefetching_ops.copy_to_device("/gpu:0"))
    with ops.device("/gpu:0"):
      iterator = device_dataset.make_initializable_iterator()
      optional = iterator_ops.get_next_as_optional(iterator)
      has_value_t = optional.has_value()
      value_t = optional.get_value()

    with self.cached_session() as sess:
      # Uninitialized iterator: both tensors must fail.
      for pending in (has_value_t, value_t):
        with self.assertRaises(errors.FailedPreconditionError):
          sess.run(pending)

      sess.run(iterator.initializer)
      # The optional yields 0, 1, 2 in order.
      for expected in range(3):
        present, value = sess.run([has_value_t, value_t])
        self.assertTrue(present)
        self.assertEqual(expected, value)

      # Exhausted: has_value() is false and get_value() errors out.
      for _ in range(2):
        self.assertFalse(sess.run(has_value_t))
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(value_t)
Example #4
0
    def testIteratorGetNextAsOptional(self):
        """Checks `get_next_as_optional` over `Dataset.range(3)` in graph mode.

        The optional must expose the dataset's output types, shapes and
        classes. It must fail with `FailedPreconditionError` before the
        iterator is initialized, then yield 0, 1 and 2 in order. After that it
        must report no value, and `get_value()` must raise
        `InvalidArgumentError`.
        """
        ds = dataset_ops.Dataset.range(3)
        iterator = ds.make_initializable_iterator()
        next_elem = iterator_ops.get_next_as_optional(iterator)
        # assertIsInstance names the actual type on failure, whereas
        # assertTrue(isinstance(...)) only reports "False is not true".
        self.assertIsInstance(next_elem, optional_ops.Optional)
        # The optional must advertise the same element signature as the
        # dataset it was drawn from.
        self.assertEqual(ds.output_types, next_elem.output_types)
        self.assertEqual(ds.output_shapes, next_elem.output_shapes)
        self.assertEqual(ds.output_classes, next_elem.output_classes)
        elem_has_value_t = next_elem.has_value()
        elem_value_t = next_elem.get_value()
        with self.test_session() as sess:
            # Before initializing the iterator, evaluating the optional fails with
            # a FailedPreconditionError.
            with self.assertRaises(errors.FailedPreconditionError):
                sess.run(elem_has_value_t)
            with self.assertRaises(errors.FailedPreconditionError):
                sess.run(elem_value_t)

            # For each element of the dataset, assert that the optional evaluates to
            # the expected value.
            sess.run(iterator.initializer)
            for i in range(3):
                elem_has_value, elem_value = sess.run(
                    [elem_has_value_t, elem_value_t])
                self.assertTrue(elem_has_value)
                self.assertEqual(i, elem_value)

            # After exhausting the iterator, `next_elem.has_value()` will evaluate to
            # false, and attempting to get the value will fail.
            for _ in range(2):
                self.assertFalse(sess.run(elem_has_value_t))
                with self.assertRaises(errors.InvalidArgumentError):
                    sess.run(elem_value_t)
Example #5
0
    def aug_body():
        """Main body passed to _tf_while_stmt.

        Pulls the next element of `iter_` as an optional. A
        `control_flow_ops.cond` on the optional's presence then either runs
        `body` on the element (main_path) or keeps the previous loop state
        (noop_path). The state chosen by the cond is written back with
        `aug_set_state`.
        """
        # Fetch the next element without raising at end of input.
        opt_iterate = iterator_ops.get_next_as_optional(iter_)
        # Expose presence through the enclosing `has_next` holder; it is also
        # the cond predicate below. NOTE(review): presumably read by the loop
        # condition in the enclosing scope — confirm.
        has_next.value = opt_iterate.has_value()
        # Snapshot of the state before `body` runs; noop_path returns it as-is
        # and main_path verifies the new state against it.
        loop_vars = aug_get_state(
        )  # updated by set_state() in _tf_while_loop.

        def main_path():
            # Element present: run the user body on it, then read back the
            # state as it stands after `body` has run.
            body(opt_iterate.get_value())
            new_loop_vars = aug_get_state()
            # Note: this verification duplicates the one performed in tf_while_stmt,
            # but needs to be done earlier to prevent the tf.cond from blowing up
            # first.
            _verify_tf_loop_vars(init_vars, loop_vars, new_loop_vars,
                                 symbol_names, opts)
            return new_loop_vars

        def noop_path():
            # No element: keep the pre-body snapshot unchanged.
            return loop_vars

        # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
        # Call set_state so that get_state() in _tf_while_loop sees the
        # conditional tensors.
        aug_set_state(
            control_flow_ops.cond(has_next.value, main_path, noop_path))
Example #6
0
 def get_next_as_optional(self):
   """Returns a list holding one next-element `Optional` per device.

   Entry k comes from `self._device_iterators[k]`, fetched inside an
   `ops.device(self._devices[k])` scope.
   """
   optionals = []
   for idx, dev in enumerate(self._devices):
     with ops.device(dev):
       optional = iterator_ops.get_next_as_optional(
           self._device_iterators[idx])
       optionals.append(optional)
   return optionals
  def testIteratorGetNextAsOptional(self):
    """Checks `get_next_as_optional` over `Dataset.range(3)` in graph mode.

    The optional must expose the dataset's output types, shapes and classes.
    It must fail with `FailedPreconditionError` before the iterator is
    initialized, then yield 0, 1 and 2 in order. After that it must report
    no value, and `get_value()` must raise `InvalidArgumentError`.
    """
    ds = dataset_ops.Dataset.range(3)
    iterator = ds.make_initializable_iterator()
    next_elem = iterator_ops.get_next_as_optional(iterator)
    # assertIsInstance names the actual type on failure, whereas
    # assertTrue(isinstance(...)) only reports "False is not true".
    self.assertIsInstance(next_elem, optional_ops.Optional)
    # The optional must advertise the same element signature as the dataset
    # it was drawn from.
    self.assertEqual(ds.output_types, next_elem.output_types)
    self.assertEqual(ds.output_shapes, next_elem.output_shapes)
    self.assertEqual(ds.output_classes, next_elem.output_classes)
    elem_has_value_t = next_elem.has_value()
    elem_value_t = next_elem.get_value()
    with self.test_session() as sess:
      # Before initializing the iterator, evaluating the optional fails with
      # a FailedPreconditionError.
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(elem_has_value_t)
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(elem_value_t)

      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      sess.run(iterator.initializer)
      for i in range(3):
        elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
        self.assertTrue(elem_has_value)
        self.assertEqual(i, elem_value)

      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
      for _ in range(2):
        self.assertFalse(sess.run(elem_has_value_t))
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(elem_value_t)
  def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, works_on_gpu):
    """Graph-mode check of `get_next_as_optional` on a 3-element dataset.

    Args:
      np_value: Value the dataset repeats; each element is compared to it.
      tf_value_fn: Zero-argument callable whose structure the optional's
        value structure must be compatible with.
      works_on_gpu: If false, the case is skipped when a GPU is present.
    """
    if not works_on_gpu and test.is_gpu_available():
      self.skipTest("Test case not yet supported on GPU.")
    iterator = dataset_ops.Dataset.from_tensors(np_value).repeat(
        3).make_initializable_iterator()
    optional = iterator_ops.get_next_as_optional(iterator)
    self.assertIsInstance(optional, optional_ops.Optional)
    self.assertTrue(
        optional.value_structure.is_compatible_with(
            structure.Structure.from_value(tf_value_fn())))
    present_t = optional.has_value()
    payload_t = optional.get_value()
    with self.cached_session() as sess:
      # Both tensors fail until the iterator initializer has run.
      for pending in (present_t, payload_t):
        with self.assertRaises(errors.FailedPreconditionError):
          sess.run(pending)

      sess.run(iterator.initializer)
      # Three present optionals, each holding `np_value`.
      for _ in range(3):
        present, payload = sess.run([present_t, payload_t])
        self.assertTrue(present)
        self._assertElementValueEqual(np_value, payload)

      # Beyond the end: empty optional, and get_value() raises.
      for _ in range(2):
        self.assertFalse(sess.run(present_t))
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(payload_t)
 def get_next_as_optional(self):
   """Returns a list holding one next-element `Optional` per device.

   The i-th entry is fetched from `self._device_iterators[i]` inside an
   `ops.device(self._devices[i])` scope.
   """
   result = []
   # enumerate replaces the hand-maintained counter, so the index can never
   # drift out of step with the device being processed.
   for i, device in enumerate(self._devices):
     with ops.device(device):
       result.append(iterator_ops.get_next_as_optional(
           self._device_iterators[i]))
   return result
def next_tf_iterator(iterator, default=UNSPECIFIED):
    """Graph-friendly `next()` for a TF iterator, with an optional default.

    Args:
      iterator: The iterator to advance.
      default: Value returned when the iterator is exhausted. It must be
        structurally compatible with `iterator.element_spec`. When omitted,
        this is plain `next(iterator)`.

    Returns:
      The next element. If a default was supplied, this is the result of a
      `cond` choosing between the next element and `default`.
    """
    if default is not UNSPECIFIED:
        maybe_next = iterator_ops.get_next_as_optional(iterator)
        _verify_structure_compatible('the default argument', 'the iterate',
                                     default, iterator.element_spec)

        def _fallback():
            return default

        return control_flow_ops.cond(maybe_next.has_value(),
                                     maybe_next.get_value, _fallback)

    # No default: defer to the builtin, which raises a runtime exception at
    # the end of the iterator.
    return next(iterator)
Example #11
0
  def while_body(has_next, state):
    """Main loop body: advance `itr` once and conditionally update the state.

    Args:
      has_next: Incoming presence flag. It is not read; a new flag is
        recomputed from the freshly fetched optional.
      state: Tuple of current loop state values.

    Returns:
      A `(has_next, new_state)` pair. The flag says whether an element was
      fetched. The state is `while_body_actual(opt, *state)` when it was,
      otherwise the unchanged `state`.
    """
    next_optional = iterator_ops.get_next_as_optional(itr)
    element_present = next_optional.has_value()

    # cond_v2 requires at least one state tensor in V1, so pad with a
    # placeholder constant when there is no initial state.
    padding = () if init_state else (constant_op.constant(()),)

    def advance():
      return padding + while_body_actual(next_optional, *state)

    def hold():
      return padding + state

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    updated = control_flow_ops.cond(element_present, advance, hold)

    # Strip the placeholder again before handing the state back.
    if padding:
      updated = updated[1:]

    return element_present, updated
Example #12
0
  def while_body(has_next, state):
    """Main loop body: advance `itr` once and conditionally update the state.

    Args:
      has_next: Incoming presence flag. It is not read; a new flag is
        recomputed from the freshly fetched optional.
      state: Tuple of current loop state values.

    Returns:
      A `(has_next, new_state)` pair. The flag says whether an element was
      fetched. The state is `while_body_actual(opt, *state)` when it was,
      otherwise the unchanged `state`.
    """
    next_optional = iterator_ops.get_next_as_optional(itr)
    element_present = next_optional.has_value()

    # cond_v2 requires at least one state tensor in V1, so pad with a
    # placeholder constant when there is no initial state.
    padding = () if init_state else (constant_op.constant(()),)

    def advance():
      return padding + while_body_actual(next_optional, *state)

    def hold():
      return padding + state

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    updated = control_flow_ops.cond(element_present, advance, hold)

    # Strip the placeholder again before handing the state back.
    if padding:
      updated = updated[1:]

    return element_present, updated