def train(
        self,
        num_steps: Optional[tf.Tensor],
) -> Optional[Dict[Text, tf.Tensor]]:
    """Runs one training loop and returns the result of `train_loop_end`.

    See base class. Sequence of work:
      1. `train_loop_begin()` is invoked.
      2. On first use, `iter` is applied to every leaf of `self.train_dataset`
         (via `tf.nest.map_structure`) and the result is cached on
         `self._train_iter`.
      3. On first use, a loop function is built from `self.train_step` and
         cached on `self._train_loop_fn`. How it is built depends on
         `_use_tf_while_loop`, `_use_tpu_summary_optimization` and
         `_use_tf_function`.
      4. The cached loop function is called with the cached iterator(s) and
         `num_steps`.

    Args:
      num_steps: Forwarded unchanged to the loop function as its second
        argument.

    Returns:
      Whatever `self.train_loop_end()` returns.
    """
    self.train_loop_begin()

    # Build the iterator(s) once and reuse them across calls, so successive
    # `train` calls continue from where the previous one stopped.
    if self._train_iter is None:
        self._train_iter = tf.nest.map_structure(iter, self.train_dataset)

    if self._train_loop_fn is None:
        step_fn = self.train_step
        if not self._use_tf_while_loop:
            # Python-driven loop; each step is optionally compiled on its own.
            if self._use_tf_function:
                step_fn = tf.function(step_fn)
            self._train_loop_fn = utils.create_loop_fn(step_fn)
        else:
            # The loop itself is expressed with a TF while loop, then wrapped
            # either by the summary-aware helper or by plain `tf.function`.
            self._train_loop_fn = utils.create_tf_while_loop_fn(step_fn)
            wrapper = (
                utils.train_function_with_summaries
                if self._use_tpu_summary_optimization
                else tf.function
            )
            self._train_loop_fn = wrapper(self._train_loop_fn)

    self._train_loop_fn(self._train_iter, num_steps)
    return self.train_loop_end()
def evaluate(
        self,
        num_steps: Optional[tf.Tensor],
) -> Optional[Dict[Text, tf.Tensor]]:
    """Runs one evaluation loop and returns the result of `eval_end`.

    See base class. Sequence of work:
      1. `eval_begin()` is called. Its return value seeds the loop state.
      2. A fresh iterator structure is created from `self._eval_dataset` on
         every call. Unlike training, it is not cached.
      3. On first use, a loop function is built from `self.eval_step`
         (optionally wrapped in `tf.function`) and cached on
         `self._eval_loop_fn`.
      4. The loop runs with `num_steps`, threading state through
         `self.eval_reduce`.

    Args:
      num_steps: Forwarded unchanged to the loop function as its second
        argument.

    Returns:
      `self.eval_end()` if the loop produced `None`, otherwise
      `self.eval_end(outputs)` with the loop's result.
    """
    state = self.eval_begin()  # pylint: disable=assignment-from-no-return

    # A new iterator each call, so every evaluation starts from the beginning
    # of the eval dataset.
    eval_iter = tf.nest.map_structure(iter, self._eval_dataset)

    if self._eval_loop_fn is None:
        step_fn = self.eval_step
        if self._eval_use_tf_function:
            step_fn = tf.function(step_fn)
        self._eval_loop_fn = utils.create_loop_fn(step_fn)

    outputs = self._eval_loop_fn(
        eval_iter, num_steps, state=state, reduce_fn=self.eval_reduce)

    # `eval_end` is called without arguments when there is nothing to pass.
    return self.eval_end() if outputs is None else self.eval_end(outputs)