def create_eval_loop_fn(self, has_state: bool):
    """Builds the eval loop that wraps `self.eval_step` per the eval options.

    Args:
      has_state: Whether the eval step function carries state; when True the
        state is kept in the loop.

    Returns:
      The eval loop function, i.e. a wrapper that runs multiple eval steps.
    """
    step = self.eval_step
    options = self._eval_options

    if not options.use_tf_while_loop:
        # Python-level loop; each step is optionally compiled with tf.function.
        if options.use_tf_function:
            step = tf.function(step)
        return loop_fns.create_loop_fn(step)

    # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
    # even when it is not used inside the loop. To work around this
    # limitation, the stateful and stateless cases use separate builders and
    # therefore separate tf.functions.
    build_while_loop = (
        loop_fns.create_tf_while_loop_fn_with_state
        if has_state
        else loop_fns.create_tf_while_loop_fn
    )
    return tf.function(build_while_loop(step))
def _create_train_loop_fn(train_step_fn, options: StandardTrainerOptions):
    """Creates a training loop from the given step function and options.

    Args:
      train_step_fn: The function that performs a single training step.
      options: Trainer options. This reads `use_tf_while_loop`,
        `use_tpu_summary_optimization` and `use_tf_function`.

    Returns:
      A callable that runs multiple training steps.
    """
    if not options.use_tf_while_loop:
        # Python-level loop; the step itself may be wrapped in tf.function.
        step = tf.function(train_step_fn) if options.use_tf_function else train_step_fn
        return loop_fns.create_loop_fn(step)

    while_loop_fn = loop_fns.create_tf_while_loop_fn(train_step_fn)
    if options.use_tpu_summary_optimization:
        # When summary optimization is on, the while-loop function is wrapped
        # in LoopFnWithSummaries instead of a plain tf.function.
        return loop_fns.LoopFnWithSummaries(while_loop_fn)
    return tf.function(while_loop_fn)
def _create_eval_loop_fn(eval_step_fn, has_state: bool, options: StandardEvaluatorOptions):
    """Creates an evaluation loop from the given step function and options.

    Args:
      eval_step_fn: The function that performs a single evaluation step.
      has_state: Whether the step function carries state that must be kept in
        the loop.
      options: Evaluator options. This reads `use_tf_while_loop` and
        `use_tf_function`.

    Returns:
      A callable that runs multiple evaluation steps.
    """
    if not options.use_tf_while_loop:
        # Python-level loop; the step itself may be wrapped in tf.function.
        if options.use_tf_function:
            eval_step_fn = tf.function(eval_step_fn)
        return loop_fns.create_loop_fn(eval_step_fn)

    # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
    # even when it is not used inside the loop. To work around this
    # limitation, the stateful and stateless cases use separate builders and
    # therefore separate tf.functions.
    if has_state:
        while_loop_fn = loop_fns.create_tf_while_loop_fn_with_state(eval_step_fn)
    else:
        while_loop_fn = loop_fns.create_tf_while_loop_fn(eval_step_fn)
    return tf.function(while_loop_fn)
def create_train_loop_fn(self):
    """Builds the training loop that wraps `self.train_step` per the options.

    Returns:
      The train loop function, i.e. a wrapper that runs multiple train steps.
    """
    step = self.train_step
    options = self._train_options

    if not options.use_tf_while_loop:
        # Python-level loop; the step itself may be wrapped in tf.function.
        if options.use_tf_function:
            step = tf.function(step)
        return loop_fns.create_loop_fn(step)

    while_loop_fn = loop_fns.create_tf_while_loop_fn(step)
    if options.use_tpu_summary_optimization:
        # When summary optimization is on, the while-loop function is wrapped
        # in LoopFnWithSummaries instead of a plain tf.function.
        return loop_fns.LoopFnWithSummaries(while_loop_fn)
    return tf.function(while_loop_fn)
def _create_eval_loop_fn(eval_step_fn, options: StandardEvaluatorOptions):
    """Creates a Python-level evaluation loop around `eval_step_fn`.

    NOTE(review): this defines `_create_eval_loop_fn` with a different
    signature from the variant that takes `(eval_step_fn, has_state, options)`.
    It also never consults `options.use_tf_while_loop`. If both definitions
    live in the same module, the one defined later replaces the other. Confirm
    which signature the callers expect.

    Args:
      eval_step_fn: The function that performs a single evaluation step.
      options: Evaluator options. Only `use_tf_function` is read.

    Returns:
      A callable that runs multiple evaluation steps.
    """
    step = tf.function(eval_step_fn) if options.use_tf_function else eval_step_fn
    return loop_fns.create_loop_fn(step)