示例#1
0
 def GetNext(self):
   """Fetches the next element produced by this host's dataset iterator.

   Iterator construction runs under `tf.init_scope()`. This keeps dataset and
   iterator creation out of any function-building graph, so repeated
   `tf.function` calls do not rebuild them. If a unit-test session is active,
   `Initialize` is run against it before the element is fetched.

   Returns:
     The result of `get_next()` on `self._iterator[self.host_id]`.
   """
   # Build datasets/iterators outside the function-building graph so the
   # creation ops are not repeated in later `tf.function` calls.
   with tf.init_scope():
     self._InitIterator()
   test_session = py_utils.GetUnitTestSession()
   if test_session:
     self.Initialize(test_session)
   host_iterator = self._iterator[self.host_id]
   return host_iterator.get_next()
示例#2
0
 def GetNext(self):
     """Override of the root's GetNext to support checking repeat sentinel.

     Datasets and iterators are created under `tf.init_scope()`. This builds
     the creation ops outside any function-building graph, so they are not
     repeated in subsequent `tf.function` calls. If a unit-test session is
     active, `Initialize` is run against it first.

     When `self._repeat_with_sentinel` is set and `self._repeat_steps` is not,
     the returned batch is gated (via control dependencies) on an assertion
     that `batch[self.params.sentinel_key]` never equals
     `self.params.sentinel_value`.

     Returns:
         The next batch from `self._iterator[self.host_id]`. If the sentinel
         check is enabled, it has `Transform(tf.identity)` applied under the
         assertion's control dependency.

     Raises:
         tf.errors.InvalidArgumentError: when the assertion fails, i.e. the
             sentinel key takes on the sentinel value ('REPEAT_SENTINEL_').
     """
     # Use `init_scope()` to ensure that the datasets and iterators are created
     # outside of the function-building graph. This ensures that these creation
     # operations are not repeated in subsequent `tf.function` calls.
     with tf.init_scope():
         self._InitIterator()
     if py_utils.GetUnitTestSession():
         self.Initialize(py_utils.GetUnitTestSession())
     batch = self._iterator[self.host_id].get_next()
     # Sentinel check.
     if self._repeat_with_sentinel and not self._repeat_steps:
         # Build the sentinel constant once. It is reused for both the
         # assertion and the log line, rather than adding a second constant
         # op to the graph just for logging.
         sentinel = tf.constant(self.params.sentinel_value)
         assert_op = tf.debugging.assert_none_equal(
             batch[self.params.sentinel_key],
             sentinel,
             summarize=1,
             message='REPEAT_SENTINEL_')
         tf.logging.info('sentinel constant dtype %r', sentinel.dtype)
         with tf.control_dependencies([assert_op]):
             # This identity transform will throw tf.errors.InvalidArgumentError
             # if assert_op fails (sentinel_key takes on sentinel_value).
             batch = batch.Transform(tf.identity)
     return batch
示例#3
0
 def GetNext(self):
   """Override of the root's GetNext to support checking repeat sentinel.

   Datasets and iterators are created under `tf.init_scope()`, so their
   creation ops are not repeated in subsequent `tf.function` calls. If a
   unit-test session is active, `Initialize` is run against it first.

   When `self._repeat_with_sentinel` is set and `self._repeat_steps` is not,
   the returned batch is gated (via control dependencies) on an assertion that
   `batch[self.params.sentinel_key]` never equals `self.params.sentinel_value`.

   Returns:
     The next batch from `self._iterator[self.host_id]`. If the sentinel check
     is enabled, it has `Transform(tf.identity)` applied under the assertion's
     control dependency.

   Raises:
     tf.errors.InvalidArgumentError: when the assertion fails, i.e. the
       sentinel key takes on the sentinel value ('REPEAT_SENTINEL_').
   """
   # Use `init_scope()` to ensure that the datasets and iterators are created
   # outside of the function-building graph. This ensures that these creation
   # operations are not repeated in subsequent `tf.function` calls.
   with tf.init_scope():
     self._InitIterator()
   if py_utils.GetUnitTestSession():
     self.Initialize(py_utils.GetUnitTestSession())
   batch = self._iterator[self.host_id].get_next()
   # Sentinel check.
   if self._repeat_with_sentinel and not self._repeat_steps:
     # Build the sentinel constant once. It is reused for both the assertion
     # and the log line, rather than adding a second constant op just for
     # logging.
     sentinel = tf.constant(self.params.sentinel_value)
     assert_op = tf.debugging.assert_none_equal(
         batch[self.params.sentinel_key],
         sentinel,
         summarize=1,
         message='REPEAT_SENTINEL_')
     tf.logging.info('sentinel constant dtype %r', sentinel.dtype)
     with tf.control_dependencies([assert_op]):
       # This identity transform will throw tf.errors.InvalidArgumentError
       # if assert_op fails (sentinel_key takes on sentinel_value).
       batch = batch.Transform(tf.identity)
   return batch
示例#4
0
 def GetNext(self):
     """Returns the next element from the dataset.

     Datasets and iterators are created under `tf.init_scope()`, so their
     creation ops are built outside any function-building graph and are not
     repeated in subsequent `tf.function` calls. If a unit-test session is
     active, `Initialize` is run against it before fetching.

     Returns:
         The result of `get_next()` on `self._iterator[self.host_id]`.
     """
     # Use `init_scope()` to ensure that the datasets and iterators are created
     # outside of the function-building graph. This ensures that these creation
     # operations are not repeated in subsequent `tf.function` calls.
     with tf.init_scope():
         self._InitIterator()
     if py_utils.GetUnitTestSession():
         self.Initialize(py_utils.GetUnitTestSession())
     return self._iterator[self.host_id].get_next()