Example #1
    def __init__(self, train_dir, model, train_params=None, save_only=False):
        """Initialize Checkpointer.

        Args:
          train_dir: Training directory for saving checkpoints.
          model: A BaseModel instance or None.
          train_params: If specified, use these training params instead of
            those in the `model`.
          save_only: This checkpointer is only intended for saving checkpoints.
        """
        self._train_dir = train_dir
        self._save_only = save_only

        self._save_path = os.path.join(self._train_dir, 'ckpt')

        if train_params:
            self._train_params = train_params
            self._model = None
        else:
            assert model
            self._train_params = model.params.train
            self._model = model

        # `model` may be None only in save-only mode; when save_only=False it
        # is required here even if `train_params` was passed explicitly.
        if not self._save_only:
            self._params = model.params
            self._model_tasks = model.tasks
            self._model = model

        self._next_checkpoint_seconds = 0
        self._save_interval_seconds = self._train_params.save_interval_seconds
        self._saver = self._GetSaver()

        self._uninitialized_vars = tf.report_uninitialized_variables(
            tf.global_variables())
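
This variant supports two construction modes: pass explicit `train_params` (with `model=None`, valid only together with `save_only=True`), or let a built `model` supply `params.train`. Below is a minimal, runnable sketch of that branching; `TrainParams` is a hypothetical stand-in for the `model.params.train` object, of which only `save_interval_seconds` is read here.

import collections

# Hypothetical stand-in for `model.params.train`; the constructor above only
# reads save_interval_seconds from it.
TrainParams = collections.namedtuple('TrainParams', ['save_interval_seconds'])

def resolve_train_params(model, train_params):
    # Mirrors the branching in __init__: explicit params win; otherwise the
    # model must be present and supplies params.train.
    if train_params:
        return train_params, None
    assert model
    return model.params.train, model

params, model = resolve_train_params(
    model=None, train_params=TrainParams(save_interval_seconds=600))
print(params.save_interval_seconds)  # 600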
Example #2
    def __init__(self,
                 train_dir,
                 models,
                 init_op=None,
                 train_params=None,
                 save_only=False):
        """Initialize Checkpointer.

        Args:
          train_dir: Training directory for saving checkpoints.
          models: One or a list of BaseModel instances. Cannot be empty. If
            more than one model is given and `train_params` is None, the save
            intervals are determined solely by the first model.
          init_op: The variable-initialization op. If unset, defaults to
            tf.global_variables_initializer().
          train_params: If specified, use these training params instead of
            those from the first model.
          save_only: This checkpointer is only intended for saving checkpoints.
        """
        self._train_dir = train_dir
        self._save_only = save_only

        if init_op:
            self._init_op = init_op
        else:
            self._init_op = tf.global_variables_initializer()

        self._save_path = os.path.join(self._train_dir, 'ckpt')

        if not isinstance(models, (list, tuple)):
            models = [models]
        self._models = models

        if train_params:
            self._train_params = train_params
        else:
            self._train_params = models[0].params.train

        self._next_checkpoint_seconds = 0
        self._save_interval_seconds = self._train_params.save_interval_seconds
        self._save_interval_steps = self._train_params.save_interval_steps
        self._prev_ckpt_step = None
        self._saver = self._GetSaver()

        if not py_utils.IsEagerMode():
            self._uninitialized_vars = tf.report_uninitialized_variables(
                tf.global_variables())

        self._BuildInitFromCheckpointRules()
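
This multi-model variant normalizes a bare model to a one-element list and, when `train_params` is not given, takes its save intervals from the first model only. A runnable sketch of that resolution, using hypothetical namedtuple stand-ins rather than real BaseModel objects:

import collections

TrainParams = collections.namedtuple(
    'TrainParams', ['save_interval_seconds', 'save_interval_steps'])
Params = collections.namedtuple('Params', ['train'])
Model = collections.namedtuple('Model', ['params'])

def resolve(models, train_params=None):
    # Same normalization as __init__: accept one model or a list of them.
    if not isinstance(models, (list, tuple)):
        models = [models]
    # Explicit params win; otherwise the first model decides the intervals.
    return train_params or models[0].params.train

m1 = Model(Params(TrainParams(600, 1000)))
m2 = Model(Params(TrainParams(60, 100)))
print(resolve([m1, m2]).save_interval_seconds)  # 600, from the first model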
Example #3
    def __init__(self, train_dir, model):
        """Initialize Checkpointer.

        Args:
          train_dir: Training directory for saving checkpoints.
          model: A BaseModel instance.
        """
        self._train_dir = train_dir
        self._model = model
        self._params = model.params

        self._vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        self._uninitialized_vars = tf.report_uninitialized_variables(
            self._vars)
        self._initialize_vars = tf.global_variables_initializer()

        self._save_path = os.path.join(self._train_dir, 'ckpt')
        self._model_tasks = model.tasks

        tp = self._params.train
        self._save_interval_seconds = tp.save_interval_seconds
        self._next_checkpoint_seconds = 0
        self._saver = self._GetSaver()
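
Here both the uninitialized-variable report and the init op are built eagerly at construction time. A graph-mode sketch of how those two ops typically work together; the session logic is an assumed usage pattern written against `tensorflow.compat.v1`, not code from the library:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

v = tf.get_variable('v', shape=[], initializer=tf.zeros_initializer())
uninitialized = tf.report_uninitialized_variables(tf.global_variables())
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    # The report yields the names of variables that still lack a value;
    # run the init op only when something is actually uninitialized.
    if sess.run(uninitialized).size:
        sess.run(init_op)
    print(sess.run(v))  # 0.0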
Example #4
  def __init__(self,
               train_dir,
               model,
               init_op=None,
               train_params=None,
               save_only=False):
    """Initialize Checkpointer.

    Args:
      train_dir: Training directory for saving checkpoints.
      model: A BaseModel instance or None.
      init_op: The variable-initialization op. If unset, defaults to
        tf.global_variables_initializer().
      train_params: If specified, use these training params instead of those
        in the `model`.
      save_only: This checkpointer is only intended for saving checkpoints.
    """
    self._train_dir = train_dir
    self._save_only = save_only
    if init_op:
      self._init_op = init_op
    else:
      self._init_op = tf.global_variables_initializer()

    self._save_path = os.path.join(self._train_dir, 'ckpt')

    if train_params:
      self._train_params = train_params
      self._model = None
    else:
      assert model
      self._train_params = model.params.train
      self._model = model

    if self._save_only:
      self._params = None
    else:
      self._params = model.params
      self._model_tasks = model.tasks
      self._model = model

    self._next_checkpoint_seconds = 0
    self._save_interval_seconds = self._train_params.save_interval_seconds
    self._saver = self._GetSaver()

    self._uninitialized_vars = tf.report_uninitialized_variables(
        tf.global_variables())

    # TODO(b/160786085): Move this logic into Overriding vars logic itself,
    # which requires refactoring things out of py_utils to avoid circular deps.
    def _ResolveCkptPath(ckpt_rules):
      return {GetSpecificCheckpoint(k): v for k, v in ckpt_rules.items()}

    self._restore_fns = []

    # Add graph nodes to restore specific variables based on
    # init_from_checkpoint_rules.
    # TODO(b/159267006): Move this back to Restore().
    if self._model:
      for task in self._model.tasks:
        tp = task.params.train
        if tp.init_from_checkpoint_rules:
          rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
          tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
          fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(),
                                                    rules)
          self._restore_fns.append(fn)

    if self._params and self._params.train.init_from_checkpoint_rules:
      tp = self._params.train
      rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
      tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
      fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(), rules)
      self._restore_fns.append(fn)
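
`_ResolveCkptPath` only rewrites the keys of the rules dict: each checkpoint spec is resolved to a concrete path via `GetSpecificCheckpoint`, while the per-checkpoint loading rules are passed through untouched. A runnable sketch of that shape with a hypothetical resolver; the rule value is an opaque placeholder, not the library's real rule format:

def resolve_ckpt_path(ckpt_rules, get_specific_checkpoint):
  # Same dict comprehension as _ResolveCkptPath above: rewrite keys only,
  # leave each checkpoint's loading rules unchanged.
  return {get_specific_checkpoint(k): v for k, v in ckpt_rules.items()}

# Hypothetical resolver that picks a concrete checkpoint file in a directory.
resolved = resolve_ckpt_path(
    {'/dirs/pretrain': 'loading rules for this checkpoint (opaque here)'},
    lambda d: d + '/ckpt-00123')
print(resolved)  # {'/dirs/pretrain/ckpt-00123': 'loading rules ...'}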