Esempio n. 1
0
    def begin(self):
        """ Prepares the hook, then builds its StepTimer and SummaryWriter.

        The `_prepare()` method of the derived class runs first. The
        SummaryWriter (on the checkpoint directory) is only created when
        summaries are enabled.
        """
        self._prepare()
        interval, first_step = self._eval_steps, self._start_at
        self._timer = StepTimer(every_steps=interval, start_at=first_step)
        if not self._do_summary:
            return
        self._summary_writer = SummaryWriter(self._checkpoint_dir)
Esempio n. 2
0
 def after_create_session(self, session, coord):
     """ Restores the latest checkpoint, or else assigns pretrained variables.

     Args:
         session: The session used to restore variables and run ops.
         coord: Unused here.
     """
     latest = saver_lib.latest_checkpoint(self._checkpoint_dir)
     if not latest:
         # no checkpoint yet: fall back to the pretrained variables, if any
         if self._reload_var_ops:
             tf.logging.info("Assign all variables with pretrained variables.")
             session.run(self._reload_var_ops)
         return
     # a checkpoint exists: reload the model from it
     self._saver.restore(session, latest)
     restored_step = session.run(self._global_step)
     message = "CheckpointSaverHook (after_create_session): reloading models and reset global_step={}"
     tf.logging.info(message.format(restored_step))
     StepTimer.reset_init_triggered_step(restored_step)
Esempio n. 3
0
 def begin(self):
     """ Builds the StepTimer, the optional SummaryWriter and the reload ops.

     The pretrained-variable ops are only built when the checkpoint
     directory holds no checkpoint and a pretrained model was given;
     otherwise `_reload_var_ops` stays None.
     """
     self._first_call = True
     self._reload_var_ops = None
     save_interval = self._save_checkpoint_steps
     self._timer = StepTimer(every_steps=save_interval)
     if self._do_summary:
         self._summary_writer = SummaryWriter(self._checkpoint_dir)
     existing_checkpoint = saver_lib.latest_checkpoint(self._checkpoint_dir)
     if existing_checkpoint or not self._pretrain_model:
         return
     self._reload_var_ops = load_pretrain_model(
         model_name=self._model_name,
         pretrain_model_dir=self._pretrain_model,
         problem_name=self._problem_name)
Esempio n. 4
0
 def after_create_session(self, session, coord):
     """ Reloads the model from the newest checkpoint when one is present.

     Without a checkpoint, the pretrained-variable ops (if there are any)
     are run instead.

     Args:
         session: The session used to restore variables and run ops.
         coord: Unused here.
     """
     newest_ckpt = saver_lib.latest_checkpoint(self._checkpoint_dir)
     if newest_ckpt:
         self._saver.restore(session, newest_ckpt)
         step_now = session.run(self._global_step)
         log_line = (
             "CheckpointSaverHook (after_create_session): reloading models and reset global_step={}"
             .format(step_now))
         tf.logging.info(log_line)
         StepTimer.reset_init_triggered_step(step_now)
         return
     if self._reload_var_ops:
         tf.logging.info("Assign all variables with pretrained variables.")
         session.run(self._reload_var_ops)
Esempio n. 5
0
    def before_run(self, run_context):
        """ Dumps the graph and loads a checkpoint, if one exists, on the first call.

        Called before each call to run(). Only the first call does the
        dumping and loading; every call registers with the timer and
        requests the global step.

        Args:
            run_context: A `SessionRunContext` object.

        Returns: A `SessionRunArgs` object containing global_step.
        """
        # We do write graph and saver_def at the first call of before_run.
        # We cannot do this in begin, since we let other hooks to change graph and
        # add variables in begin. Graph is finalized after all begin calls.
        if self._is_chief and self._first_call:
            training_util.write_graph(
                ops.get_default_graph().as_graph_def(add_shapes=True),
                self._checkpoint_dir, "graph.pbtxt")
            # dump model details "model_analysis.txt"
            dump_model_analysis(self._checkpoint_dir)  # dump model configs
            graph = ops.get_default_graph()
            meta_graph_def = meta_graph.create_meta_graph_def(
                graph_def=graph.as_graph_def(add_shapes=True),
                saver_def=self._saver.saver_def)
            if self._summary_writer is not None:
                self._summary_writer.add_graph(graph)
                self._summary_writer.add_meta_graph(meta_graph_def)
            tf.logging.info("CheckpointSaverHook (before_run): dump graph...")
        if self._first_call:
            # The checkpoint path is only needed on the first call, so the
            # lookup is done here rather than before every training step.
            checkpoint_path = saver_lib.latest_checkpoint(self._checkpoint_dir)
            if checkpoint_path:
                # reloading model
                self._saver.restore(run_context.session, checkpoint_path)
                gs = run_context.session.run(self._global_step)
                tf.logging.info(
                    "CheckpointSaverHook (before_run): reloading models and reset global_step={}"
                    .format(gs))
                StepTimer.reset_init_triggered_step(gs)
            elif self._reload_var_ops:
                tf.logging.info(
                    "Assign all variables with pretrained variables.")
                run_context.session.run(self._reload_var_ops)
        self._first_call = False
        self._timer.register_before_run()
        return tf.train.SessionRunArgs(self._global_step)
Esempio n. 6
0
    def begin(self):
        """ Runs `_prepare()`, then creates the StepTimer and SummaryWriter.

        `_prepare()` is implemented by derived classes. No SummaryWriter
        is created when summaries are disabled.
        """
        self._prepare()
        timer = StepTimer(every_steps=self._eval_steps,
                          start_at=self._start_at)
        self._timer = timer
        if not self._do_summary:
            return
        self._summary_writer = SummaryWriter(self._checkpoint_dir)
Esempio n. 7
0
 def begin(self):
     """ Creates the StepTimer, the SummaryWriter and the pretrain reload ops.

     `_reload_var_ops` is filled by `load_pretrain_model` only if there
     is no checkpoint in the checkpoint directory and a pretrained model
     was configured; in every other case it is None.
     """
     self._first_call = True
     self._reload_var_ops = None
     self._timer = StepTimer(every_steps=self._save_checkpoint_steps)
     if self._do_summary:
         self._summary_writer = SummaryWriter(self._checkpoint_dir)
     found_checkpoint = saver_lib.latest_checkpoint(self._checkpoint_dir)
     if found_checkpoint or not self._pretrain_model:
         return
     self._reload_var_ops = load_pretrain_model(
         model_name=self._model_name,
         pretrain_model_dir=self._pretrain_model,
         problem_name=self._problem_name)
Esempio n. 8
0
class CheckpointSaverHook(tf.train.SessionRunHook):
    """ Define the hook that saves checkpoints every N steps.

    When the session is created, the latest checkpoint in `checkpoint_dir`
    is restored if there is one; otherwise the pretrained-variable ops
    built in `begin()` (if any) are run.
    """

    def __init__(self,
                 checkpoint_dir,
                 save_checkpoint_steps=1000,
                 saver=None,
                 pretrain_model=None,
                 problem_name=None,
                 model_name="njunmt.models.SequenceToSequence",
                 do_summary=True,
                 is_chief=True):
        """ Initializes the hook.

        Args:
            checkpoint_dir: A string, base directory for the checkpoint files.
            save_checkpoint_steps: A python integer, save every N steps.
            saver: `Saver` object, used for saving. When None, one is obtained
              from `get_saver_or_default(max_to_keep=8)`.
            pretrain_model: The pretrained model dir.
            problem_name: A string.
            model_name: The model name.
            do_summary: Whether to save summaries.
            is_chief: Whether this is the chief process.
        """
        tf.logging.info("Create CheckpointSaverHook.")
        if saver is None:
            saver = get_saver_or_default(max_to_keep=8)  # pylint: disable=protected-access
        self._saver = saver
        self._checkpoint_dir = checkpoint_dir
        self._save_path = os.path.join(checkpoint_dir, Constants.MODEL_CKPT_FILENAME)
        self._pretrain_model = pretrain_model
        self._problem_name = problem_name
        self._model_name = model_name
        # save every n steps
        self._save_checkpoint_steps = save_checkpoint_steps
        # variable for session.run
        self._global_step = training_util.get_global_step()
        # for after create session
        self._do_summary = do_summary
        self._is_chief = is_chief
        # timer & summary writer
        self._timer = None
        self._summary_writer = None
        # begin() resets both of these; they are defined here as well so that
        # every attribute read by the other callbacks exists after construction
        self._first_call = True
        self._reload_var_ops = None

    def begin(self):
        """ Creates StepTimer and SummaryWriter.

        Also builds the pretrained-variable ops when there is no checkpoint
        in the checkpoint directory and a pretrained model was given.
        """
        self._first_call = True
        self._timer = StepTimer(every_steps=self._save_checkpoint_steps)
        if self._do_summary:
            self._summary_writer = SummaryWriter(self._checkpoint_dir)
        self._reload_var_ops = None
        if not saver_lib.latest_checkpoint(self._checkpoint_dir) and self._pretrain_model:
            self._reload_var_ops = load_pretrain_model(
                model_name=self._model_name,
                pretrain_model_dir=self._pretrain_model,
                problem_name=self._problem_name)

    def after_create_session(self, session, coord):
        """ Restores the latest checkpoint, or else assigns pretrained variables.

        Args:
            session: The session used to restore variables and run ops.
            coord: Unused here.
        """
        checkpoint_path = saver_lib.latest_checkpoint(self._checkpoint_dir)
        if checkpoint_path:
            # reloading model
            self._saver.restore(session, checkpoint_path)
            gs = session.run(self._global_step)
            tf.logging.info(
                "CheckpointSaverHook (after_create_session): reloading models and reset global_step={}".format(gs))
            StepTimer.reset_init_triggered_step(gs)
        elif self._reload_var_ops:
            tf.logging.info("Assign all variables with pretrained variables.")
            session.run(self._reload_var_ops)

    def before_run(self, run_context):
        """ Dumps the graph on the first call (chief process only).

        Called before each call to run().

        Args:
            run_context: A `SessionRunContext` object.

        Returns: A `SessionRunArgs` object containing global_step.
        """
        # We do write graph and saver_def at the first call of before_run.
        # We cannot do this in begin, since we let other hooks to change graph and
        # add variables in begin. Graph is finalized after all begin calls.
        if self._is_chief and self._first_call:
            training_util.write_graph(
                ops.get_default_graph().as_graph_def(add_shapes=True),
                self._checkpoint_dir,
                "graph.pbtxt")
            # dump model details "model_analysis.txt"
            dump_model_analysis(self._checkpoint_dir)  # dump model configs
            graph = ops.get_default_graph()
            meta_graph_def = meta_graph.create_meta_graph_def(
                graph_def=graph.as_graph_def(add_shapes=True),
                saver_def=self._saver.saver_def)
            if self._summary_writer is not None:
                self._summary_writer.add_graph(graph)
                self._summary_writer.add_meta_graph(meta_graph_def)
            tf.logging.info("CheckpointSaverHook (before_run): dump graph...")
        self._first_call = False
        return tf.train.SessionRunArgs(self._global_step)

    def after_run(self, run_context, run_values):
        """ Checks running steps and save checkpoints.

        Only the chief process saves.

        Args:
            run_context: A `SessionRunContext` object.
            run_values: A SessionRunValues object.
        """
        global_step = run_values.results
        if self._is_chief and self._timer.should_trigger_for_step(global_step):
            self._timer.update_last_triggered_step(global_step)
            self._save(global_step, run_context.session)

    def _save(self, step, session):
        """ Saves the latest checkpoint.

        Also adds a CHECKPOINT session log to the summary writer when
        there is one.

        Args:
            step: A python integer, running step.
            session: A TensorFlow Session.
        """
        self._saver.save(session, self._save_path, global_step=step)
        tf.logging.info("Saving checkpoints for {} into {}".format(step, self._save_path))
        if self._summary_writer is not None:
            self._summary_writer.add_session_log(
                SessionLog(
                    status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
                step)
Esempio n. 9
0
 def begin(self):
     """ Builds the step timer, the logging timer and, optionally, the SummaryWriter. """
     display_interval = self._display_steps
     self._timer = StepTimer(every_steps=display_interval)
     self._logging_timer = LoggingTimer()
     if not self._do_summary:
         return
     self._summary_writer = SummaryWriter(self._checkpoint_dir)
Esempio n. 10
0
class DisplayHook(tf.train.SessionRunHook):
    """ Hook that logs the training loss and speed every n steps, writes
    the displayed values as summaries and requests a stop once the
    maximum number of training steps is reached. """

    def __init__(self,
                 checkpoint_dir,
                 display_steps=100,
                 maximum_train_steps=None,
                 do_summary=True,
                 is_chief=True):
        """ Initializes the hook.

        Args:
            checkpoint_dir: A string, base directory for the checkpoint files.
            display_steps: A python integer, display every N steps.
            maximum_train_steps: A python integer, the maximum training steps.
            do_summary: Whether to save summaries when display.
            is_chief: Whether this is the chief process (stored, not used now).
        """
        tf.logging.info("Create DisplayHook.")
        self._checkpoint_dir = checkpoint_dir
        self._display_steps = display_steps
        self._maximum_train_steps = maximum_train_steps
        self._do_summary = do_summary
        self._is_chief = is_chief  # not used now

        # tensors fetched on every run: the display collection plus global_step
        step_tensor = training_util.get_global_step()
        names = ops.get_collection(Constants.DISPLAY_KEY_COLLECTION_NAME)
        tensors = ops.get_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME)
        fetches = dict(zip(names, tensors))
        fetches["global_step"] = step_tensor
        self._display_args = fetches

        # created later, in begin()
        self._timer = None
        self._logging_timer = None
        self._summary_writer = None

    def begin(self):
        """ Creates the StepTimer, the LoggingTimer and, optionally, the SummaryWriter. """
        self._timer = StepTimer(every_steps=self._display_steps)
        self._logging_timer = LoggingTimer()
        if not self._do_summary:
            return
        self._summary_writer = SummaryWriter(self._checkpoint_dir)

    def after_create_session(self, session, coord):
        """ Updates the logging timer's last triggered time.

        Args:
            session: Unused here.
            coord: Unused here.
        """
        self._logging_timer.update_last_triggered_time()

    def before_run(self, run_context):
        """ Requests the values to be displayed.

        Called before each call to run().

        Args:
            run_context: A `SessionRunContext` object.

        Returns: A `SessionRunArgs` object containing global_step and
          arguments to be displayed.
        """
        return tf.train.SessionRunArgs(self._display_args)

    def after_run(self, run_context, run_values):
        """ Logs and summarizes the fetched values when the timer triggers.

        Also checks the maximum training steps to raise stop request.

        Args:
            run_context: A `SessionRunContext` object.
            run_values: A SessionRunValues object.
        """
        fetched = run_values.results
        # global_step is taken out so that it is not written as a summary below
        step = fetched.pop("global_step")
        if self._timer.should_trigger_for_step(step):
            loss = fetched[Constants.TRAIN_LOSS_KEY_NAME]
            steps_done, _ = self._timer.update_last_triggered_step(step)
            run_time = self._logging_timer.update_last_triggered_time()
            steps_per_sec = steps_done * 1. / run_time
            secs_per_step = run_time * 1. / steps_done

            tf.logging.info("Update %d \t TrainingLoss=%f   UD %f secs/step"
                            % (step, loss, secs_per_step))
            writer = self._summary_writer
            if writer is not None:
                scalars = [("global_step/sec", steps_per_sec),
                           ("global_step/secs_per_step", secs_per_step)]
                scalars.extend(fetched.items())
                for tag, value in scalars:
                    writer.add_summary(tag, value, step)
            self._logging_timer.update_last_triggered_time()
        # hit maximum training steps
        limit = self._maximum_train_steps
        if limit and step >= limit:
            tf.logging.info("Training maximum steps. maximum_train_step={}".format(self._maximum_train_steps))
            run_context.request_stop()
Esempio n. 11
0
class TextMetricSpec(tf.train.SessionRunHook):
    """ Base class for metric hooks.

    Derived classes implement `_prepare()` and `_do_evaluation()`.
    """
    def __init__(self,
                 model_configs,
                 dataset,
                 start_at=0,
                 eval_steps=100,
                 do_summary=True,
                 model_name=None):
        """ Initializes base metric hook.

        Args:
            model_configs: A dictionary of all configurations; a deep copy
              is stored and its "model_dir" entry is the checkpoint dir.
            dataset: A `Dataset` object.
            start_at: A python integer, start to evaluate model at this step.
            eval_steps: A python integer, evaluate model every N steps.
            do_summary: Whether to save summaries.
            model_name: A string, the top scope name of all variables.
        """
        self._model_configs = copy.deepcopy(model_configs)
        self._checkpoint_dir = model_configs["model_dir"]
        self._dataset = dataset
        self._model_name = model_name
        self._start_at, self._eval_steps = start_at, eval_steps
        self._do_summary = do_summary
        self._global_step = tf.train.get_global_step()
        # both are created in begin()
        self._summary_writer = None
        self._timer = None

    def begin(self):
        """ Runs `_prepare()`, then creates the StepTimer and SummaryWriter.

        `_prepare()` is implemented by derived classes. No SummaryWriter
        is created when summaries are disabled.
        """
        self._prepare()
        interval, first_step = self._eval_steps, self._start_at
        self._timer = StepTimer(every_steps=interval, start_at=first_step)
        if not self._do_summary:
            return
        self._summary_writer = SummaryWriter(self._checkpoint_dir)

    def before_run(self, run_context):
        """ Registers with the timer and requests the global step.

        Called before each call to run().

        Args:
            run_context: A `SessionRunContext` object.

        Returns: A `SessionRunArgs` object containing global_step.
        """
        self._timer.register_before_run()
        fetch_request = tf.train.SessionRunArgs(self._global_step)
        return fetch_request

    def after_run(self, run_context, run_values):
        """ Evaluates the model when the timer triggers for this step.

        Calls `_do_evaluation()` method implemented by derived classes.
        Args:
            run_context: A `SessionRunContext` object.
            run_values: A SessionRunValues object.
        """
        step = run_values.results
        if not self._timer.should_trigger_for_step(step):
            return
        self._do_evaluation(run_context, step)
        self._timer.update_last_triggered_step(step)

    @abstractmethod
    def _prepare(self):
        """ Prepares for evaluation, e.g. building the model (reusing variables).

        Must be overridden by derived classes.
        """
        raise NotImplementedError

    @abstractmethod
    def _do_evaluation(self, run_context, global_step):
        """ Evaluates the model. Must be overridden by derived classes.

        Args:
            run_context: A `SessionRunContext` object.
            global_step: A python integer, the current training step.
        """
        raise NotImplementedError
Esempio n. 12
0
 def begin(self):
     """ Builds the StepTimer and, when summaries are enabled, the SummaryWriter. """
     display_interval = self._display_steps
     self._timer = StepTimer(every_steps=display_interval)
     if not self._do_summary:
         return
     self._summary_writer = SummaryWriter(self._checkpoint_dir)
Esempio n. 13
0
class DisplayHook(tf.train.SessionRunHook):
    """ Hook that logs the training loss and speed every n steps, writes
    the displayed values as summaries and requests a stop once the
    maximum number of training steps is reached. """
    def __init__(self,
                 checkpoint_dir,
                 display_steps=100,
                 maximum_train_steps=None,
                 do_summary=True,
                 is_chief=True):
        """ Initializes the hook.

        Args:
            checkpoint_dir: A string, base directory for the checkpoint files.
            display_steps: A python integer, display every N steps.
            maximum_train_steps: A python integer, the maximum training steps.
            do_summary: Whether to save summaries when display.
            is_chief: Whether this is the chief process (stored, not used now).
        """
        tf.logging.info("Create DisplayHook.")
        self._checkpoint_dir = checkpoint_dir
        self._display_steps = display_steps
        self._maximum_train_steps = maximum_train_steps
        self._do_summary = do_summary
        self._is_chief = is_chief  # not used now

        # tensors fetched on every run: the display collection plus global_step
        step_tensor = training_util.get_global_step()
        names = ops.get_collection(Constants.DISPLAY_KEY_COLLECTION_NAME)
        tensors = ops.get_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME)
        fetches = dict(zip(names, tensors))
        fetches["global_step"] = step_tensor
        self._display_args = fetches

        # created later, in begin()
        self._timer = None
        self._summary_writer = None

    def begin(self):
        """ Creates the StepTimer and, when summaries are enabled, the SummaryWriter. """
        self._timer = StepTimer(every_steps=self._display_steps)
        if not self._do_summary:
            return
        self._summary_writer = SummaryWriter(self._checkpoint_dir)

    def before_run(self, run_context):
        """ Registers with the timer and requests the values to be displayed.

        Called before each call to run().

        Args:
            run_context: A `SessionRunContext` object.

        Returns: A `SessionRunArgs` object containing global_step and
          arguments to be displayed.
        """
        self._timer.register_before_run()
        fetch_request = tf.train.SessionRunArgs(self._display_args)
        return fetch_request

    def after_run(self, run_context, run_values):
        """ Logs and summarizes the fetched values when the timer triggers.

        Also checks the maximum training steps to raise stop request.

        Args:
            run_context: A `SessionRunContext` object.
            run_values: A SessionRunValues object.
        """
        fetched = run_values.results
        # global_step is taken out so that it is not written as a summary below
        step = fetched.pop("global_step")
        if self._timer.should_trigger_for_step(step):
            loss = fetched[Constants.TRAIN_LOSS_KEY_NAME]
            steps_done, _ = self._timer.update_last_triggered_step(step)
            run_time = self._timer.get_session_run_time()
            steps_per_sec = steps_done * 1. / run_time
            secs_per_step = run_time * 1. / steps_done

            tf.logging.info("Update %d \t TrainingLoss=%f   UD %f secs/step" %
                            (step, loss, secs_per_step))
            writer = self._summary_writer
            if writer is not None:
                scalars = [("global_step/sec", steps_per_sec),
                           ("global_step/secs_per_step", secs_per_step)]
                scalars.extend(fetched.items())
                for tag, value in scalars:
                    writer.add_summary(tag, value, step)
        # hit maximum training steps
        limit = self._maximum_train_steps
        if limit and step >= limit:
            tf.logging.info(
                "Training maximum steps. maximum_train_step={}".format(
                    self._maximum_train_steps))
            run_context.request_stop()
Esempio n. 14
0
class CheckpointSaverHook(tf.train.SessionRunHook):
    """ Define the hook that saves checkpoints every N steps.

    On the first `before_run()` call the latest checkpoint in
    `checkpoint_dir` is restored if there is one; otherwise the
    pretrained-variable ops built in `begin()` (if any) are run.
    """
    def __init__(self,
                 checkpoint_dir,
                 save_checkpoint_steps=1000,
                 saver=None,
                 pretrain_model=None,
                 problem_name=None,
                 model_name="njunmt.models.SequenceToSequence",
                 do_summary=True,
                 is_chief=True):
        """ Initializes the hook.

        Args:
            checkpoint_dir: A string, base directory for the checkpoint files.
            save_checkpoint_steps: A python integer, save every N steps.
            saver: `Saver` object, used for saving. When None, one is obtained
              from `get_saver_or_default(max_to_keep=8)`.
            pretrain_model: The pretrained model dir.
            problem_name: A string.
            model_name: The model name.
            do_summary: Whether to save summaries.
            is_chief: Whether this is the chief process.
        """
        tf.logging.info("Create CheckpointSaverHook.")
        if saver is None:
            saver = get_saver_or_default(max_to_keep=8)  # pylint: disable=protected-access
        self._saver = saver
        self._checkpoint_dir = checkpoint_dir
        self._save_path = os.path.join(checkpoint_dir,
                                       Constants.MODEL_CKPT_FILENAME)
        self._pretrain_model = pretrain_model
        self._problem_name = problem_name
        self._model_name = model_name
        # save every n steps
        self._save_checkpoint_steps = save_checkpoint_steps
        # variable for session.run
        self._global_step = training_util.get_global_step()
        # for after create session
        self._do_summary = do_summary
        self._is_chief = is_chief
        # timer & summary writer
        self._timer = None
        self._summary_writer = None
        # begin() resets both of these; they are defined here as well so that
        # every attribute read by the other callbacks exists after construction
        self._first_call = True
        self._reload_var_ops = None

    def begin(self):
        """ Creates StepTimer and SummaryWriter.

        Also builds the pretrained-variable ops when there is no checkpoint
        in the checkpoint directory and a pretrained model was given.
        """
        self._first_call = True
        self._timer = StepTimer(every_steps=self._save_checkpoint_steps)
        if self._do_summary:
            self._summary_writer = SummaryWriter(self._checkpoint_dir)
        self._reload_var_ops = None
        if not saver_lib.latest_checkpoint(
                self._checkpoint_dir) and self._pretrain_model:
            self._reload_var_ops = load_pretrain_model(
                model_name=self._model_name,
                pretrain_model_dir=self._pretrain_model,
                problem_name=self._problem_name)

    def before_run(self, run_context):
        """ Dumps the graph and loads a checkpoint, if one exists, on the first call.

        Called before each call to run(). Only the first call does the
        dumping and loading; every call registers with the timer and
        requests the global step.

        Args:
            run_context: A `SessionRunContext` object.

        Returns: A `SessionRunArgs` object containing global_step.
        """
        # We do write graph and saver_def at the first call of before_run.
        # We cannot do this in begin, since we let other hooks to change graph and
        # add variables in begin. Graph is finalized after all begin calls.
        if self._is_chief and self._first_call:
            training_util.write_graph(
                ops.get_default_graph().as_graph_def(add_shapes=True),
                self._checkpoint_dir, "graph.pbtxt")
            # dump model details "model_analysis.txt"
            dump_model_analysis(self._checkpoint_dir)  # dump model configs
            graph = ops.get_default_graph()
            meta_graph_def = meta_graph.create_meta_graph_def(
                graph_def=graph.as_graph_def(add_shapes=True),
                saver_def=self._saver.saver_def)
            if self._summary_writer is not None:
                self._summary_writer.add_graph(graph)
                self._summary_writer.add_meta_graph(meta_graph_def)
            tf.logging.info("CheckpointSaverHook (before_run): dump graph...")
        if self._first_call:
            # The checkpoint path is only needed on the first call, so the
            # lookup is done here rather than before every training step.
            checkpoint_path = saver_lib.latest_checkpoint(self._checkpoint_dir)
            if checkpoint_path:
                # reloading model
                self._saver.restore(run_context.session, checkpoint_path)
                gs = run_context.session.run(self._global_step)
                tf.logging.info(
                    "CheckpointSaverHook (before_run): reloading models and reset global_step={}"
                    .format(gs))
                StepTimer.reset_init_triggered_step(gs)
            elif self._reload_var_ops:
                tf.logging.info(
                    "Assign all variables with pretrained variables.")
                run_context.session.run(self._reload_var_ops)
        self._first_call = False
        self._timer.register_before_run()
        return tf.train.SessionRunArgs(self._global_step)

    def after_run(self, run_context, run_values):
        """ Checks running steps and save checkpoints.

        Only the chief process saves.

        Args:
            run_context: A `SessionRunContext` object.
            run_values: A SessionRunValues object.
        """
        global_step = run_values.results
        if self._is_chief and self._timer.should_trigger_for_step(global_step):
            self._timer.update_last_triggered_step(global_step)
            self._save(global_step, run_context.session)

    def _save(self, step, session):
        """ Saves the latest checkpoint.

        Also adds a CHECKPOINT session log to the summary writer when
        there is one.

        Args:
            step: A python integer, running step.
            session: A TensorFlow Session.
        """
        self._saver.save(session, self._save_path, global_step=step)
        tf.logging.info("Saving checkpoints for {} into {}".format(
            step, self._save_path))
        if self._summary_writer is not None:
            self._summary_writer.add_session_log(
                SessionLog(status=SessionLog.CHECKPOINT,
                           checkpoint_path=self._save_path), step)
Esempio n. 15
0
class TextMetricSpec(tf.train.SessionRunHook):
    """ Base class for metric hooks.

    Derived classes implement `_prepare()` and `_do_evaluation()`.
    """

    def __init__(self,
                 model_configs,
                 dataset,
                 start_at=0,
                 eval_steps=100,
                 do_summary=True,
                 model_name=None):
        """ Initializes base metric hook.

        Args:
            model_configs: A dictionary of all configurations; a deep copy
              is stored and its "model_dir" entry is the checkpoint dir.
            dataset: A `Dataset` object.
            start_at: A python integer, start to evaluate model at this step.
            eval_steps: A python integer, evaluate model every N steps.
            do_summary: Whether to save summaries.
            model_name: A string, the top scope name of all variables.
        """
        self._model_configs = copy.deepcopy(model_configs)
        self._checkpoint_dir = model_configs["model_dir"]
        self._dataset = dataset
        self._model_name = model_name
        self._start_at, self._eval_steps = start_at, eval_steps
        self._do_summary = do_summary
        self._global_step = tf.train.get_global_step()
        # both are created in begin()
        self._summary_writer = None
        self._timer = None

    def begin(self):
        """ Runs `_prepare()`, then creates the StepTimer and SummaryWriter.

        `_prepare()` is implemented by derived classes. No SummaryWriter
        is created when summaries are disabled.
        """
        self._prepare()
        interval, first_step = self._eval_steps, self._start_at
        self._timer = StepTimer(every_steps=interval, start_at=first_step)
        if not self._do_summary:
            return
        self._summary_writer = SummaryWriter(self._checkpoint_dir)

    def before_run(self, run_context):
        """ Requests the global step. Called before each call to run().

        Args:
            run_context: A `SessionRunContext` object.

        Returns: A `SessionRunArgs` object containing global_step.
        """
        fetch_request = tf.train.SessionRunArgs(self._global_step)
        return fetch_request

    def after_run(self, run_context, run_values):
        """ Evaluates the model when the timer triggers for this step.

        Calls `_do_evaluation()` method implemented by derived classes.
        Args:
            run_context: A `SessionRunContext` object.
            run_values: A SessionRunValues object.
        """
        step = run_values.results
        if not self._timer.should_trigger_for_step(step):
            return
        self._do_evaluation(run_context, step)
        self._timer.update_last_triggered_step(step)

    @abstractmethod
    def _prepare(self):
        """ Prepares for evaluation, e.g. building the model (reusing variables).

        Must be overridden by derived classes.
        """
        raise NotImplementedError

    @abstractmethod
    def _do_evaluation(self, run_context, global_step):
        """ Evaluates the model. Must be overridden by derived classes.

        Args:
            run_context: A `SessionRunContext` object.
            global_step: A python integer, the current training step.
        """
        raise NotImplementedError
Esempio n. 16
0
 def begin(self):
     """ Creates StepTimer and SummaryWriter. """
     self._first_call = True
     self._timer = StepTimer(every_steps=self._save_checkpoint_steps)
     if self._do_summary:
         self._summary_writer = SummaryWriter(self._checkpoint_dir)