def create_eval_loop_fn(self, has_state: bool):
    """Creates an eval loop from the current step function and options.

    The loop shape is driven by `self._eval_options`:

    * `use_tf_while_loop` set: the step is wrapped by
      `loop_fns.create_tf_while_loop_fn_with_state` (when `has_state`) or
      `loop_fns.create_tf_while_loop_fn`, and the resulting loop as a whole is
      compiled with `tf.function`.
    * otherwise: the loop comes from `loop_fns.create_loop_fn`, and only the
      step itself is compiled with `tf.function` when `use_tf_function` is set.

    Args:
      has_state: If the step function has state, state will be kept in the loop.

    Returns:
      The eval loop function, i.e. wrapper of multiple eval steps.
    """
    step = self.eval_step
    options = self._eval_options

    if not options.use_tf_while_loop:
        if options.use_tf_function:
            step = tf.function(step)
        return loop_fns.create_loop_fn(step)

    # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
    # even when it is not used inside the loop. To work around this limitation,
    # the stateful and stateless loops are built as two separate tf.functions.
    if has_state:
        make_loop = loop_fns.create_tf_while_loop_fn_with_state
    else:
        make_loop = loop_fns.create_tf_while_loop_fn
    return tf.function(make_loop(step))
def _create_train_loop_fn(train_step_fn, options: StandardTrainerOptions):
  """Creates a training loop from the given step function and options.

  Args:
    train_step_fn: The single training step that the loop repeats.
    options: Trainer options. `use_tf_while_loop`,
      `use_tpu_summary_optimization` and `use_tf_function` select how the
      loop is built.

  Returns:
    The train loop function, i.e. a wrapper of multiple train steps.
  """
  if not options.use_tf_while_loop:
    # Loop from `loop_fns.create_loop_fn`; only the step may be compiled.
    if options.use_tf_function:
      step_fn = tf.function(train_step_fn)
    else:
      step_fn = train_step_fn
    return loop_fns.create_loop_fn(step_fn)

  while_loop_fn = loop_fns.create_tf_while_loop_fn(train_step_fn)
  if options.use_tpu_summary_optimization:
    # NOTE(review): the loop is not wrapped in tf.function on this path;
    # presumably `LoopFnWithSummaries` does its own compilation — confirm.
    return loop_fns.LoopFnWithSummaries(while_loop_fn)
  return tf.function(while_loop_fn)
Example #3
0
def _create_eval_loop_fn(eval_step_fn, has_state: bool,
                         options: StandardEvaluatorOptions):
    """Create evaluation loop function.

    Args:
      eval_step_fn: The single evaluation step that the loop repeats.
      has_state: If the step function has state, a state-carrying while-loop
        builder is used (only consulted when `options.use_tf_while_loop`).
      options: Evaluator options; `use_tf_while_loop` and `use_tf_function`
        select how the loop is built.

    Returns:
      The eval loop function, i.e. a wrapper of multiple eval steps.
    """
    if not options.use_tf_while_loop:
        step_fn = eval_step_fn
        if options.use_tf_function:
            step_fn = tf.function(step_fn)
        return loop_fns.create_loop_fn(step_fn)

    # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
    # even when it is not used inside the loop. To work around this limitation,
    # the stateful and stateless loops are built as two separate tf.functions.
    if has_state:
        builder = loop_fns.create_tf_while_loop_fn_with_state
    else:
        builder = loop_fns.create_tf_while_loop_fn
    return tf.function(builder(eval_step_fn))
    def create_train_loop_fn(self):
        """Creates a training loop from the current step function and options.

        With `use_tf_while_loop`, the step is wrapped by
        `loop_fns.create_tf_while_loop_fn`. That loop is then either handed to
        `loop_fns.LoopFnWithSummaries` (when `use_tpu_summary_optimization`)
        or compiled with `tf.function`. Otherwise the loop comes from
        `loop_fns.create_loop_fn`, with the step compiled via `tf.function`
        when `use_tf_function` is set.

        Returns:
          The train loop function, i.e. wrapper of multiple train steps.
        """
        step = self.train_step
        opts = self._train_options

        if opts.use_tf_while_loop:
            loop = loop_fns.create_tf_while_loop_fn(step)
            return (loop_fns.LoopFnWithSummaries(loop)
                    if opts.use_tpu_summary_optimization
                    else tf.function(loop))

        if opts.use_tf_function:
            step = tf.function(step)
        return loop_fns.create_loop_fn(step)
def _create_eval_loop_fn(eval_step_fn, options: StandardEvaluatorOptions):
  """Creates an evaluation loop around `eval_step_fn`.

  The loop is always built by `loop_fns.create_loop_fn`;
  `options.use_tf_while_loop` is not consulted here.

  NOTE(review): this `def` reuses the name of a three-argument
  `_create_eval_loop_fn`; in a single module only the last definition
  survives — confirm which signature callers expect.

  Args:
    eval_step_fn: The single evaluation step that the loop repeats.
    options: Evaluator options; only `use_tf_function` is read, and it decides
      whether the step is compiled with `tf.function`.

  Returns:
    The eval loop function, i.e. a wrapper of multiple eval steps.
  """
  step_fn = eval_step_fn
  if options.use_tf_function:
    step_fn = tf.function(step_fn)
  return loop_fns.create_loop_fn(step_fn)