예제 #1
0
    def train(
            self,
            num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
        """Runs the training loop for `num_steps` and returns its final result.

        The call sequence is: `train_loop_begin()`, lazy creation of the
        dataset iterator(s) and of the loop function (both are stored on the
        instance and reused by later calls), one invocation of the loop
        function, then `train_loop_end()`.

        Args:
          num_steps: Step count, forwarded unchanged to the loop function.

        Returns:
          Whatever `self.train_loop_end()` returns.
        """
        self.train_loop_begin()

        # Created on first use only; later calls reuse the stored iterator(s).
        if self._train_iter is None:
            self._train_iter = tf.nest.map_structure(iter, self.train_dataset)

        # The loop function is likewise built once and cached on the instance.
        if self._train_loop_fn is None:
            step_fn = self.train_step
            if not self._use_tf_while_loop:
                # Non-while-loop path: optionally wrap the step itself in
                # `tf.function`, then build the loop around it.
                if self._use_tf_function:
                    step_fn = tf.function(step_fn)
                self._train_loop_fn = utils.create_loop_fn(step_fn)
            else:
                # While-loop path: build the loop first, then wrap the whole
                # loop with either the summary-aware helper or `tf.function`.
                self._train_loop_fn = utils.create_tf_while_loop_fn(step_fn)
                wrapper = (
                    utils.train_function_with_summaries
                    if self._use_tpu_summary_optimization else tf.function)
                self._train_loop_fn = wrapper(self._train_loop_fn)

        self._train_loop_fn(self._train_iter, num_steps)
        return self.train_loop_end()
예제 #2
0
  def evaluate(
      self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
    """Runs the evaluation loop for `num_steps` and returns `eval_end`'s result.

    `eval_begin()` supplies the initial loop state. A new iterator over
    `self._eval_dataset` is created on every call, while the loop function is
    built once and cached on the instance.

    Args:
      num_steps: Step count, forwarded unchanged to the loop function.

    Returns:
      `self.eval_end(outputs)` when the loop produced a non-None result,
      otherwise `self.eval_end()` called with no arguments.
    """
    initial_state = self.eval_begin()  # pylint: disable=assignment-from-no-return

    # Unlike the loop function below, the iterator is not cached between calls.
    dataset_iter = tf.nest.map_structure(iter, self._eval_dataset)

    if self._eval_loop_fn is None:
      raw_step = self.eval_step
      step_fn = (
          tf.function(raw_step) if self._eval_use_tf_function else raw_step)
      self._eval_loop_fn = utils.create_loop_fn(step_fn)

    loop_result = self._eval_loop_fn(
        dataset_iter,
        num_steps,
        state=initial_state,
        reduce_fn=self.eval_reduce)

    # `eval_end` is called without arguments when there is nothing to pass on.
    if loop_result is not None:
      return self.eval_end(loop_result)
    return self.eval_end()