Example #1
    def __init__(self,
                 checkpoint_dir,
                 save_checkpoint_steps=1000,
                 saver=None,
                 do_summary=True,
                 is_chief=True):
        """ Initializes the hook.

        Args:
            checkpoint_dir: A string, base directory for the checkpoint files.
            save_checkpoint_steps: A python integer, save every N steps.
            saver: `Saver` object, used for saving.
            do_summary: Whether to save summaries.
            is_chief: Whether this is the chief process.
        """
        tf.logging.info("Create CheckpointSaverHook.")
        if saver is None:
            saver = saver_lib._get_saver_or_default()  # pylint: disable=protected-access
        self._saver = saver
        self._checkpoint_dir = checkpoint_dir
        self._save_path = os.path.join(checkpoint_dir,
                                       GlobalNames.MODEL_CKPT_FILENAME)
        # save every n steps
        self._save_checkpoint_steps = save_checkpoint_steps
        # variable for session.run
        self._global_step = training_util.get_global_step()
        # for after create session
        self._do_summary = do_summary
        self._is_chief = is_chief
        # timer & summary writer
        self._timer = None
        self._summary_writer = None
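For context, hooks like this one are registered with a tf.train.MonitoredTrainingSession, which invokes them around each sess.run call. Below is a minimal TF 1.x sketch using the stock tf.train.CheckpointSaverHook, whose contract the class above mirrors; the toy loss, step counts and /tmp path are illustrative only:

import tensorflow as tf

# Toy graph: drive a single variable toward zero.
global_step = tf.train.get_or_create_global_step()
w = tf.get_variable("w", initializer=5.0)
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
    tf.square(w), global_step=global_step)

hooks = [
    tf.train.CheckpointSaverHook("/tmp/ckpts", save_steps=1000),
    tf.train.StopAtStepHook(last_step=3000),
]
with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
    while not sess.should_stop():
        sess.run(train_op)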
Example #2
    def __init__(self,
                 checkpoint_dir,
                 checkpoint_steps,
                 model_core,
                 dev_fetches=None,
                 firein_steps=0,
                 checkpoint_basename="model.ckpt",
                 dev_n=5,
                 dev_batch_size=128,
                 listeners=None):

        logging.info("Create NickCheckpointSaverHook.")
        if model_core.saver is None:
            model_core.saver = saver_lib._get_saver_or_default()  # pylint: disable=protected-access
        self._checkpoint_dir = checkpoint_dir
        self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
        self._saver = model_core.saver
        self._preproc_fn = model_core.preproc
        self._fetch_data_fn = model_core.fetch_data
        # Guard against the shared-mutable-default pitfall.
        self._dev_fetches = dev_fetches if dev_fetches is not None else []
        self._firein_steps = firein_steps
        self._summary_tag_scope = model_core.model_kind

        self._dev_n = dev_n
        self._dev_batch_size = dev_batch_size

        self._timer = tf.train.SecondOrStepTimer(every_secs=None,
                                                 every_steps=checkpoint_steps)
        self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
        self._listeners = listeners or []
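The tf.train.SecondOrStepTimer created above is the standard rate limiter for periodic hook work; a minimal sketch of its two-call protocol (the step values are made up):

import tensorflow as tf

timer = tf.train.SecondOrStepTimer(every_steps=100)
for step in (1, 50, 100, 150, 200):
    if timer.should_trigger_for_step(step):
        # Do the periodic work here (e.g. save a checkpoint), then tell
        # the timer so it can rate-limit the next trigger.
        timer.update_last_triggered_step(step)
        print("triggered at step", step)
# The first query always triggers; afterwards only once the last triggered
# step is at least every_steps behind, so this prints steps 1 and 150.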
Example #3
    def _update_bleu_ckpt(self, run_context, bleu, hypothesis, global_step):
        """ Updates the best checkpoints according to BLEU score and
        removes the worst model if the number of checkpoint archives
        exceeds maximum_keep_models.

        If the model no longer improves the BLEU score (i.e., hits the
        maximum patience), request that the session stop.

        Args:
            run_context: A `SessionRunContext` object.
            bleu: A python float, the BLEU score derived by the model
              at this step.
            hypothesis: A list of hypotheses for the validation set.
            global_step: A python integer, the current training step.
        """
        if bleu >= self._best_bleu_score:
            self._best_bleu_score = bleu
            self._bad_count = 0
        else:
            self._bad_count += 1
        if self._bad_count >= self._estop_patience_max and self._early_stop:
            tf.logging.info("early stop.")
            run_context.request_stop()
        # saving checkpoints if eval_steps and save_checkpoint_steps mismatch
        if len(self._best_checkpoint_names) == 0 or bleu > self._best_checkpoint_bleus[0]:
            with open_file(self._tmp_trans_file_prefix + str(global_step), mode="w") as fw:
                fw.write('\n'.join(hypothesis) + "\n")
            if not gfile.Exists("{}-{}.meta".format(
                    os.path.join(self._checkpoint_dir, Constants.MODEL_CKPT_FILENAME), global_step)):
                saver = saver_lib._get_saver_or_default()
                saver.save(run_context.session,
                           os.path.join(self._checkpoint_dir, Constants.MODEL_CKPT_FILENAME),
                           global_step=global_step)
            backup_dirname = os.path.join(self._model_configs["model_dir"], "../") \
                             + "{dirname_prefix}_iter{global_step}_bleu{bleu}".format(
                dirname_prefix=Constants.BACKUP_MODEL_DIRNAME_PREFIX,
                global_step=global_step,
                bleu=("%.1f" % bleu))
            tf.logging.info("Saving to directoruy: {}/".format(backup_dirname))
            os.system("mkdir {backup_dirname};"
                      "cp {ckpt_dirname}/checkpoint {backup_dirname}/;"
                      "cp {ckpt_dirname}/{model_config} {backup_dirname}/;"
                      "cp {ckpt_dirname}/{model_analysis} {backup_dirname}/;"
                      "cp {ckpt_dirname}/*{global_step}* {backup_dirname}/".format(
                backup_dirname=backup_dirname,
                ckpt_dirname=self._checkpoint_dir,
                model_config=Constants.MODEL_CONFIG_YAML_FILENAME,
                model_analysis=Constants.MODEL_ANALYSIS_FILENAME,
                global_step=global_step))
            self._best_checkpoint_bleus.append(bleu)
            self._best_checkpoint_names.append(backup_dirname)
            if len(self._best_checkpoint_bleus) > self._maximum_keep_models:
                tidx = numpy.argsort(self._best_checkpoint_bleus)
                _bleus = [self._best_checkpoint_bleus[i] for i in tidx]
                _names = [self._best_checkpoint_names[i] for i in tidx]
                self._best_checkpoint_bleus = _bleus[1:]
                self._best_checkpoint_names = _names[1:]
                os.system("rm -rf {}".format(_names[0]))
            self._write_ckpt_bleulog()
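The pruning at the end (keep at most maximum_keep_models archives, discarding the worst by BLEU) is easy to test in isolation; a framework-free sketch of the same argsort bookkeeping, with hypothetical names:

import numpy

def prune_worst(bleus, names, keep):
    """Sort (bleu, name) pairs worst-first and keep the best `keep` of them.

    Returns the pruned lists plus the names that should be deleted.
    """
    order = numpy.argsort(bleus)            # ascending, i.e. worst first
    bleus = [bleus[i] for i in order]
    names = [names[i] for i in order]
    n_drop = max(0, len(names) - keep)
    return bleus[n_drop:], names[n_drop:], names[:n_drop]

bleus, names, drop = prune_worst(
    [27.1, 25.3, 28.0], ["iter1k", "iter2k", "iter3k"], keep=2)
# bleus == [27.1, 28.0], names == ["iter1k", "iter3k"], drop == ["iter2k"]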
Example #4
    def _update_bleu_ckpt(self, run_context, bleu, global_step):
        """ Updates the best checkpoints according to BLEU score and
        removes the worst model if the number of checkpoint archives
        exceeds maximum_keep_models.

        If the model no longer improves the BLEU score (i.e., hits the
        maximum patience), request that the session stop.

        Args:
            run_context: A `SessionRunContext` object.
            bleu: A python float, the BLEU score derived by the model
              at this step.
            global_step: A python integer, the current training step.
        """
        if bleu >= self._best_bleu_score:
            self._best_bleu_score = bleu
            self._estop_patience = 0
        else:
            self._estop_patience += 1
        if self._estop_patience >= self._estop_patience_max and self._early_stop:
            tf.logging.info("early stop.")
            run_context.request_stop()
        # saving checkpoints if eval_steps and save_checkpoint_steps mismatch
        if not gfile.Exists("{}-{}.meta".format(
                os.path.join(self._checkpoint_dir,
                             GlobalNames.MODEL_CKPT_FILENAME), global_step)):
            saver = saver_lib._get_saver_or_default()
            saver.save(run_context.session,
                       os.path.join(self._checkpoint_dir,
                                    GlobalNames.MODEL_CKPT_FILENAME),
                       global_step=global_step)
        if len(self._best_checkpoint_names) == 0 or bleu > self._best_checkpoint_bleus[0]:
            tarname = "{}{}.tar.gz".format(
                GlobalNames.CKPT_TGZ_FILENAME_PREFIX, global_step)
            os.system(
                "tar -zcvf {tarname} {checkpoint} {model_config} {model_analysis} {ckptdir}/*{global_step}*"
                .format(tarname=tarname,
                        checkpoint=os.path.join(self._checkpoint_dir,
                                                "checkpoint"),
                        model_config=os.path.join(
                            self._checkpoint_dir,
                            GlobalNames.MODEL_CONFIG_YAML_FILENAME),
                        model_analysis=os.path.join(
                            self._checkpoint_dir,
                            GlobalNames.MODEL_ANALYSIS_FILENAME),
                        ckptdir=self._checkpoint_dir,
                        global_step=global_step))
            self._best_checkpoint_bleus.append(bleu)
            self._best_checkpoint_names.append(tarname)
            if len(self._best_checkpoint_bleus) > self._maximum_keep_models:
                tidx = numpy.argsort(self._best_checkpoint_bleus)
                _bleus = [self._best_checkpoint_bleus[i] for i in tidx]
                _names = [self._best_checkpoint_names[i] for i in tidx]
                self._best_checkpoint_bleus = _bleus[1:]
                self._best_checkpoint_names = _names[1:]
                os.system("rm {}".format(_names[0]))
            self._write_ckpt_bleulog()
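The counter logic at the top of the method is an ordinary patience-based early stop; a minimal framework-free sketch of the same rule (the class name is made up):

class PatienceStopper:
    """Stop after `patience` consecutive evaluations without improvement."""

    def __init__(self, patience):
        self._patience = patience
        self._best = float("-inf")
        self._bad_count = 0

    def update(self, score):
        """Returns True once training should stop."""
        if score >= self._best:
            self._best = score
            self._bad_count = 0
        else:
            self._bad_count += 1
        return self._bad_count >= self._patience

stopper = PatienceStopper(patience=2)
for bleu in (25.0, 26.1, 25.9, 25.8):
    if stopper.update(bleu):
        print("early stop at BLEU", bleu)   # fires on 25.8, the 2nd bad eval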
Example #5
    def finalize(self):
        """Creates operations if needed and finalizes the graph."""
        if self._init_op is None:

            def default_init_op():
                return control_flow_ops.group(
                    variables.global_variables_initializer(),
                    resources.initialize_resources(
                        resources.shared_resources()))

            self._init_op = Scaffold.get_or_default('init_op',
                                                    ops.GraphKeys.INIT_OP,
                                                    default_init_op)
        if self._ready_op is None:

            def default_ready_op():
                return array_ops.concat([
                    variables.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ], 0)

            self._ready_op = Scaffold.get_or_default('ready_op',
                                                     ops.GraphKeys.READY_OP,
                                                     default_ready_op)
        if self._ready_for_local_init_op is None:

            def default_ready_for_local_init_op():
                return variables.report_uninitialized_variables(
                    variables.global_variables())

            self._ready_for_local_init_op = Scaffold.get_or_default(
                'ready_for_local_init_op',
                ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
                default_ready_for_local_init_op)
        if self._local_init_op is None:
            self._local_init_op = Scaffold.get_or_default(
                'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
                Scaffold._default_local_init_op)
        if self._summary_op is None:
            self._summary_op = Scaffold.get_or_default(
                'summary_op', ops.GraphKeys.SUMMARY_OP, summary.merge_all)
        # pylint: disable=g-long-lambda
        if self._saver is None:
            self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
        # pylint: enable=g-long-lambda
        self._saver.build()

        ops.get_default_graph().finalize()
        return self
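In ordinary use none of this is called by hand: handing an empty Scaffold to a MonitoredTrainingSession triggers finalize(), which fills in the default init_op, ready_op, local_init_op, summary_op and saver shown above. A minimal TF 1.x sketch:

import tensorflow as tf

tf.train.get_or_create_global_step()  # gives the default Saver a variable
v = tf.get_variable("v", initializer=0.0)
inc = tf.assign_add(v, 1.0)

scaffold = tf.train.Scaffold()  # every field None; defaults get filled in
with tf.train.MonitoredTrainingSession(scaffold=scaffold) as sess:
    sess.run(inc)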
Example #6
    def _update_bleu_ckpt(self, run_context, bleu, global_step):
        """ Updates the best checkpoints according to BLEU score and
        removes the worst model if the number of checkpoint archives
        exceeds maximum_keep_models.

        If the model no longer improves the BLEU score (i.e., hits the
        maximum patience), request that the session stop.

        Args:
            run_context: A `SessionRunContext` object.
            bleu: A python float, the BLEU score derived by the model
              at this step.
            global_step: A python integer, the current training step.
        """
        if bleu >= self._best_bleu_score:
            self._best_bleu_score = bleu
            self._bad_count = 0
        else:
            self._bad_count += 1
        if self._bad_count >= self._estop_patience_max and self._early_stop:
            tf.logging.info("early stop.")
            run_context.request_stop()
        # saving checkpoints if eval_steps and save_checkpoint_steps mismatch
        if not gfile.Exists("{}-{}.meta".format(
                os.path.join(self._checkpoint_dir, Constants.MODEL_CKPT_FILENAME), global_step)):
            saver = saver_lib._get_saver_or_default()
            saver.save(run_context.session,
                       os.path.join(self._checkpoint_dir, Constants.MODEL_CKPT_FILENAME),
                       global_step=global_step)
        if len(self._best_checkpoint_names) == 0 or bleu > self._best_checkpoint_bleus[0]:
            tarname = "{}{}.tar.gz".format(Constants.CKPT_TGZ_FILENAME_PREFIX, global_step)
            os.system("tar -zcvf {tarname} {checkpoint} {model_config} {model_analysis} {ckptdir}/*{global_step}*"
                      .format(tarname=tarname,
                              checkpoint=os.path.join(self._checkpoint_dir, "checkpoint"),
                              model_config=os.path.join(self._checkpoint_dir, Constants.MODEL_CONFIG_YAML_FILENAME),
                              model_analysis=os.path.join(self._checkpoint_dir, Constants.MODEL_ANALYSIS_FILENAME),
                              ckptdir=self._checkpoint_dir,
                              global_step=global_step))
            self._best_checkpoint_bleus.append(bleu)
            self._best_checkpoint_names.append(tarname)
            if len(self._best_checkpoint_bleus) > self._maximum_keep_models:
                tidx = numpy.argsort(self._best_checkpoint_bleus)
                _bleus = [self._best_checkpoint_bleus[i] for i in tidx]
                _names = [self._best_checkpoint_names[i] for i in tidx]
                self._best_checkpoint_bleus = _bleus[1:]
                self._best_checkpoint_names = _names[1:]
                os.system("rm {}".format(_names[0]))
            self._write_ckpt_bleulog()
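A side note on the shell call: os.system with an interpolated format string silently ignores failures (it only returns the exit status). A more defensive sketch of the same archiving step using subprocess, with illustrative names:

import glob
import os
import subprocess

def archive_checkpoint(ckpt_dir, global_step, tarname, extra_files):
    """tar up the checkpoint files for `global_step` plus metadata files."""
    members = [os.path.join(ckpt_dir, "checkpoint")] + extra_files
    members += glob.glob(os.path.join(ckpt_dir, "*{}*".format(global_step)))
    # check=True raises CalledProcessError on a non-zero exit status.
    subprocess.run(["tar", "-zcf", tarname] + members, check=True)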
Example #7
  def finalize(self):
    """Creates operations if needed and finalizes the graph."""
    if self._init_op is None:
      def default_init_op():
        return control_flow_ops.group(
            variables.global_variables_initializer(),
            resources.initialize_resources(resources.shared_resources()))
      self._init_op = Scaffold.get_or_default(
          'init_op',
          ops.GraphKeys.INIT_OP,
          default_init_op)
    if self._ready_op is None:
      def default_ready_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(),
            resources.report_uninitialized_resources()
        ], 0)
      self._ready_op = Scaffold.get_or_default(
          'ready_op', ops.GraphKeys.READY_OP,
          default_ready_op)
    if self._ready_for_local_init_op is None:
      def default_ready_for_local_init_op():
        return variables.report_uninitialized_variables(
            variables.global_variables())
      self._ready_for_local_init_op = Scaffold.get_or_default(
          'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
          default_ready_for_local_init_op)
    if self._local_init_op is None:
      self._local_init_op = Scaffold.get_or_default(
          'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
          Scaffold._default_local_init_op)
    if self._summary_op is None:
      self._summary_op = Scaffold.get_or_default('summary_op',
                                                 ops.GraphKeys.SUMMARY_OP,
                                                 summary.merge_all)
    # pylint: disable=g-long-lambda
    if self._saver is None:
      self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
    # pylint: enable=g-long-lambda
    self._saver.build()

    ops.get_default_graph().finalize()
    logging.info('Graph was finalized.')
    return self
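The closing ops.get_default_graph().finalize() makes the graph read-only, which is why finalize() must run only after every op has been created; adding anything afterwards raises. A small sketch:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    tf.constant(1.0)
g.finalize()

with g.as_default():
    try:
        tf.constant(2.0)          # the graph is frozen now
    except RuntimeError as e:
        print(e)                  # Graph is finalized and cannot be modified.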
Example #8
  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances.
        Used for callbacks that run immediately before or after this hook saves
        the checkpoint.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: Exactly one of saver or scaffold should be set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    if saver is None and scaffold is None:
      saver = saver_lib._get_saver_or_default()  # pylint: disable=protected-access
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)
    self._listeners = listeners or []
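As the two checks above enforce, saver and scaffold are mutually exclusive, and omitting both falls back to a default Saver. A quick sketch of the construction modes (TF 1.x; the /tmp path is illustrative):

import tensorflow as tf

tf.train.get_or_create_global_step()  # gives the Saver something to save

# Either a Saver, a Scaffold, or neither:
tf.train.CheckpointSaverHook("/tmp/ckpts", save_steps=500,
                             saver=tf.train.Saver())
tf.train.CheckpointSaverHook("/tmp/ckpts", save_secs=600,
                             scaffold=tf.train.Scaffold())

# ...but never both:
try:
    tf.train.CheckpointSaverHook("/tmp/ckpts", save_steps=500,
                                 saver=tf.train.Saver(),
                                 scaffold=tf.train.Scaffold())
except ValueError as e:
    print(e)  # You cannot provide both saver and scaffold.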