def _create_train_loop_fn(train_step_fn, options: StandardTrainerOptions):
    """Creates a training loop from the given step function and options.

    Args:
      train_step_fn: The single-step training callable to wrap.
      options: Trainer options. This function reads `use_tf_while_loop`,
        `use_tpu_summary_optimization` and `use_tf_function`.

    Returns:
      The loop function that wraps repeated calls of `train_step_fn`:
        * `use_tf_while_loop` set: the result of
          `loop_fns.create_tf_while_loop_fn`, wrapped in
          `loop_fns.LoopFnWithSummaries` when `use_tpu_summary_optimization`
          is set and in `tf.function` otherwise.
        * `use_tf_while_loop` unset: `loop_fns.create_loop_fn` applied to the
          step function, which is first wrapped in `tf.function` when
          `use_tf_function` is set.
    """
    if not options.use_tf_while_loop:
        # Plain loop path: only the step itself may be compiled.
        step_fn = (
            tf.function(train_step_fn) if options.use_tf_function else train_step_fn
        )
        return loop_fns.create_loop_fn(step_fn)

    while_loop_fn = loop_fns.create_tf_while_loop_fn(train_step_fn)
    if options.use_tpu_summary_optimization:
        # The summary-optimizing wrapper is used instead of a direct
        # `tf.function` wrap.
        return loop_fns.LoopFnWithSummaries(while_loop_fn)
    return tf.function(while_loop_fn)
def create_train_loop_fn(self):
    """Creates a training loop from the current step function and options.

    Delegates to the module-level `_create_train_loop_fn` helper, passing
    `self.train_step` as the step function and `self._train_options` as the
    options. The loop-construction branches therefore live in one place and
    cannot drift from the helper's copy.

    Returns:
      The train loop function, i.e. wrapper of multiple train steps.
    """
    return _create_train_loop_fn(self.train_step, self._train_options)