Code example #1
    def RestoreIfNeeded(self, sess):
        """If vars are not initialized, restore frome checkpoint.

    Args:
      sess: tf.Session.
    """
        assert not self._save_only
        uninitialized_var_names = list(sess.run(self._uninitialized_vars))
        if not uninitialized_var_names:
            return

        tf.logging.info('Uninitialized var list: %s ', uninitialized_var_names)
        if self._Restore(sess):
            return

        if (not any(task.params.train.init_from_checkpoint_rules
                    for task in self._model_tasks)
                and not self._params.train.init_from_checkpoint_rules):
            tf.logging.info('Initialize ALL variables: %s',
                            uninitialized_var_names)
            sess.run([self._initialize_vars])
            tf.logging.info('Initialize variables done.')
            return

        # There was a race in local run. Another thread will get unblocked once
        # _initialize_all is called. OverrideVarsFromCheckpoints
        # might not happen at the right time.
        for task in self._model.tasks:
            tp = task.params.train
            if tp.init_from_checkpoint_rules:
                tf.logging.info('OverrideVarsFromCheckpoints %s',
                                tp.init_from_checkpoint_rules)
                py_utils.OverrideVarsFromCheckpoints(
                    sess, self._vars, tp.init_from_checkpoint_rules)

        if self._params.train.init_from_checkpoint_rules:
            tp = self._params.train
            tf.logging.info('OverrideVarsFromCheckpoints %s',
                            tp.init_from_checkpoint_rules)
            py_utils.OverrideVarsFromCheckpoints(sess, self._vars,
                                                 tp.init_from_checkpoint_rules)

        uninitialized_var_names = list(sess.run(self._uninitialized_vars))
        if not uninitialized_var_names:
            return

        # uninitialized_var_names is a list of strings without ":0" suffix.
        # tf.report_uninitialized_variables returns binary strings.
        assert all(
            isinstance(s, six.binary_type) for s in uninitialized_var_names)

        # Need to retrieve vars, removing ":0" suffix from names.
        uninitialized_vars = [
            v for v in self._vars
            if six.ensure_binary(v.name[:-2]) in uninitialized_var_names
        ]
        tf.logging.info('Initialize variables: %s',
                        [v.name for v in uninitialized_vars])
        sess.run(tf.variables_initializer(uninitialized_vars))
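
This example's core pattern is to ask the graph which variables are still uninitialized (tf.report_uninitialized_variables returns variable names as byte strings without the ':0' suffix) and then initialize only those. Below is a minimal, self-contained sketch of that pattern, assuming TF1-style graph mode via tensorflow.compat.v1; the variable name 'w' is made up for illustration.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# A toy variable; 'w' is a made-up name for illustration.
w = tf.get_variable('w', shape=[2], initializer=tf.zeros_initializer())

# Names of uninitialized variables, returned as byte strings without ':0'.
uninitialized_op = tf.report_uninitialized_variables(tf.global_variables())

with tf.Session() as sess:
    names = list(sess.run(uninitialized_op))  # e.g. [b'w']
    # Match variables by stripping the ':0' suffix, as the example above does.
    to_init = [v for v in tf.global_variables()
               if v.name[:-2].encode() in names]
    sess.run(tf.variables_initializer(to_init))
    print(sess.run(uninitialized_op))  # now empty
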
Code example #2
File: trainer.py  Project: j-luo93/lingvo
  def _RestoreIfNeeded(self, sess):
    uninitialized_var_names = list(sess.run(self._uninitialized))
    if not uninitialized_var_names:
      return

    tf.logging.info('Uninitialized var list: %s ', uninitialized_var_names)
    path = tf.train.latest_checkpoint(self._train_dir)
    if path:
      tf.logging.info('Load from checkpoint %s.', path)
      self._saver.restore(sess, path)
      tf.logging.info('Load checkpoint done.')
      return

    if (not any(task.params.train.init_from_checkpoint_rules
                for task in self._model.tasks) and
        not self._params.train.init_from_checkpoint_rules):
      tf.logging.info('Initialize ALL variables: %s', uninitialized_var_names)
      sess.run([self._initialize_all])
      tf.logging.info('Initialize variables done.')
      return

    # There was a race in local run. Another thread will get unblocked once
    # _initialize_all is called. OverrideVarsFromCheckpoints
    # might not happen at the right time.
    for task in self._model.tasks:
      tp = task.params.train
      if tp.init_from_checkpoint_rules:
        tf.logging.info('OverrideVarsFromCheckpoints %s',
                        tp.init_from_checkpoint_rules)
        py_utils.OverrideVarsFromCheckpoints(sess, self._vars,
                                             tp.init_from_checkpoint_rules)

    if self._params.train.init_from_checkpoint_rules:
      tp = self._params.train
      tf.logging.info('OverrideVarsFromCheckpoints %s',
                      tp.init_from_checkpoint_rules)
      py_utils.OverrideVarsFromCheckpoints(sess, self._vars,
                                           tp.init_from_checkpoint_rules)

    uninitialized_var_names = list(sess.run(self._uninitialized))
    if not uninitialized_var_names:
      return

    # uninitialized_var_names is a list of strings without ":0" suffix.
    assert all(isinstance(s, str) for s in uninitialized_var_names)

    # Need to retrieve vars, removing ":0" suffix from names.
    uninitialized_vars = [
        v for v in self._vars if v.name[:-2] in uninitialized_var_names
    ]
    tf.logging.info('Initialize variables: %s',
                    [v.name for v in uninitialized_vars])
    sess.run(tf.variables_initializer(uninitialized_vars))
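
Example #2's restore path is just tf.train.latest_checkpoint plus tf.train.Saver.restore. A small, self-contained sketch of that save-then-restore round trip (TF1 graph mode assumed; the temporary directory and variable name are made up for illustration):

import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

train_dir = tempfile.mkdtemp()  # stand-in for the trainer's train_dir
w = tf.get_variable('w', shape=[2], initializer=tf.ones_initializer())
saver = tf.train.Saver()

# Write a checkpoint.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, os.path.join(train_dir, 'ckpt'), global_step=0)

# Restore from the latest checkpoint, the same way _RestoreIfNeeded does above.
with tf.Session() as sess:
    path = tf.train.latest_checkpoint(train_dir)
    if path:
        saver.restore(sess, path)
    else:
        sess.run(tf.global_variables_initializer())
    print(sess.run(w))  # [1. 1.]
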
Code example #3
    def Restore(self, sess, force_reinitialize=False):
        """Restore from latest checkpoint if available, or initialize."""
        # Try and restore from the latest checkpoint.
        if self._RestoreFromLatestCheckpoint(sess):
            # Successfully restored from checkpoint.
            uninitialized_var_names = self._GetUninitializedVarNames(sess)
            assert not uninitialized_var_names, uninitialized_var_names
            return

        # Otherwise we need to initialize.
        uninitialized_var_names = self._GetUninitializedVarNames(sess)
        tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)
        if not force_reinitialize:
            # There should only be uninitialized variables if all variables are
            # uninitialized - with the exception of global_step due to
            # RestoreGlobalStepIfNeeded in the _LoopEnqueue of TrainerTpu.
            all_var_names = [
                six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
            ]
            already_initialized_vars = (set(all_var_names) -
                                        set(uninitialized_var_names))
            already_initialized_vars.discard(b'global_step')
            assert not already_initialized_vars, (
                'Already initialized vars: %s' %
                sorted(already_initialized_vars))

        # At this point all variables are uninitialized, so it is safe to run a
        # global initializer.
        sess.run(self._init_op)
        tf.logging.info('Initialized all vars.')

        # TODO(b/160786085): Move this logic into Overriding vars logic itself,
        # which requires refactoring things out of py_utils to avoid circular deps.
        def _ResolveCkptPath(ckpt_rules):
            return {GetSpecificCheckpoint(k): v for k, v in ckpt_rules.items()}

        # Restore specific variables based on init_from_checkpoint_rules.
        for task in self._model.tasks:
            tp = task.params.train
            if tp.init_from_checkpoint_rules:
                rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
                tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
                py_utils.OverrideVarsFromCheckpoints(sess,
                                                     tf.global_variables(),
                                                     rules)

        if self._params.train.init_from_checkpoint_rules:
            tp = self._params.train
            rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
            tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
            py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                                 rules)
Code example #4
    def Restore(self, sess, force_reinitialize=False):
        """Restore from latest checkpoint if available, or initialize."""
        # Try and restore from the latest checkpoint.
        if self._RestoreFromLatestCheckpoint(sess):
            # Successfully restored from checkpoint.
            uninitialized_var_names = self._GetUninitializedVarNames(sess)
            assert not uninitialized_var_names, uninitialized_var_names
            return

        # Otherwise we need to initialize.
        uninitialized_var_names = self._GetUninitializedVarNames(sess)
        tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)
        if not force_reinitialize:
            # There should only be uninitialized variables if all variables are
            # uninitialized - with the exception of global_step due to
            # RestoreGlobalStepIfNeeded in the _LoopEnqueue of TrainerTpu.
            all_var_names = [
                six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
            ]
            already_initialized_vars = (set(all_var_names) -
                                        set(uninitialized_var_names))
            already_initialized_vars.discard(b'global_step')
            assert not already_initialized_vars, (
                'Already initialized vars: %s' %
                sorted(already_initialized_vars))

        # At this point all variables are uninitialized, so it is safe to run a
        # global initializer.
        sess.run(tf.global_variables_initializer())
        tf.logging.info('Initialized all vars.')

        # Restore specific variables based on init_from_checkpoint_rules.
        for task in self._model.tasks:
            tp = task.params.train
            if tp.init_from_checkpoint_rules:
                tf.logging.info('OverrideVarsFromCheckpoints %s',
                                tp.init_from_checkpoint_rules)
                py_utils.OverrideVarsFromCheckpoints(
                    sess, tf.global_variables(), tp.init_from_checkpoint_rules)

        if self._params.train.init_from_checkpoint_rules:
            tp = self._params.train
            tf.logging.info('OverrideVarsFromCheckpoints %s',
                            tp.init_from_checkpoint_rules)
            py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                                 tp.init_from_checkpoint_rules)
Code example #5
File: checkpointer.py  Project: huaxz1986/lingvo
    def _BuildInitFromCheckpointRules(self):
        """Build restore fns for init_from_checkpoint_rules."""
        self._restore_fns = []
        all_vars = list(_GetSaveableVariablesDict(self._models).values())

        # TODO(b/160786085): Move this logic into Overriding vars logic itself,
        # which requires refactoring things out of py_utils to avoid circular deps.
        def _ResolveCkptPath(ckpt_rules):
            res_rules = {}
            for k, v in ckpt_rules.items():
                new_k = GetSpecificCheckpoint(k)
                if not new_k:
                    tf.logging.warning(
                        f'Empty checkpoint path init rules are ignored, key={k}'
                    )
                else:
                    res_rules.update({new_k: v})
            return res_rules

        def _MergeLoadingRules(a, b):
            res = copy.deepcopy(a)
            for k, (load_rules, ignore_rules) in b.items():
                if k in res:
                    res_load, res_ignore = res[k]
                    for load in load_rules:
                        if load not in res_load:
                            res_load.append(load)
                    for ignore in ignore_rules:
                        if ignore not in res_ignore:
                            res_ignore.append(ignore)
                else:
                    res[k] = (load_rules, ignore_rules)
            return res

        # Restore specific variables based on init_from_checkpoint_rules.
        rules = {}
        for model in self._models:
            for task in model.tasks:
                tp = task.params.train
                if tp.init_from_checkpoint_rules:
                    rules = _MergeLoadingRules(
                        rules, _ResolveCkptPath(tp.init_from_checkpoint_rules))

        if self._train_params.init_from_checkpoint_rules:
            rules = _MergeLoadingRules(
                rules,
                _ResolveCkptPath(
                    self._train_params.init_from_checkpoint_rules))

        # Add graph nodes to restore specific variables based on
        # init_from_checkpoint_rules.
        # TODO(b/159267006): Move this back to Restore().
        self._restore_fns.append(
            (f'OverrideVarsFromCheckpoints {rules}',
             py_utils.OverrideVarsFromCheckpoints(all_vars, rules)))
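
The _MergeLoadingRules helper in example #5 takes the union of per-task and model-level init_from_checkpoint_rules keyed by checkpoint path. Below is a standalone sketch of the same merge behavior, assuming the {ckpt_path: (load_rules, ignore_rules)} format implied by the code above; the paths and regex rules are made up for illustration.

import copy

def merge_loading_rules(a, b):
    """Union of two {ckpt_path: (load_rules, ignore_rules)} dicts."""
    res = copy.deepcopy(a)
    for k, (load_rules, ignore_rules) in b.items():
        if k in res:
            res_load, res_ignore = res[k]
            for load in load_rules:
                if load not in res_load:
                    res_load.append(load)
            for ignore in ignore_rules:
                if ignore not in res_ignore:
                    res_ignore.append(ignore)
        else:
            res[k] = (load_rules, ignore_rules)
    return res

# Made-up rules: a task loads encoder variables from one checkpoint; the model
# additionally ignores optimizer slots and loads everything from a second one.
task_rules = {'/ckpts/enc': ([('enc/(.*)', 'enc/%s')], [])}
model_rules = {'/ckpts/enc': ([('enc/(.*)', 'enc/%s')], ['.*/Adam.*']),
               '/ckpts/all': ([('(.*)', '%s')], [])}
print(merge_loading_rules(task_rules, model_rules))
# {'/ckpts/enc': ([('enc/(.*)', 'enc/%s')], ['.*/Adam.*']),
#  '/ckpts/all': ([('(.*)', '%s')], [])}
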
Code example #6
  def __init__(self,
               train_dir,
               model,
               init_op=None,
               train_params=None,
               save_only=False):
    """Initialize Checkpointer.

    Args:
     train_dir: Training directory for saving checkpoints.
     model: A BaseModel instance or None.
     init_op: The initialize variables op. If unset, it will call
       tf.global_variables_initializer().
     train_params: If specified, use these training params instead of those in
       the `model`.
     save_only: This checkpointer is only intended for saving checkpoints.
    """
    self._train_dir = train_dir
    self._save_only = save_only
    if init_op:
      self._init_op = init_op
    else:
      self._init_op = tf.global_variables_initializer()

    self._save_path = os.path.join(self._train_dir, 'ckpt')

    if train_params:
      self._train_params = train_params
      self._model = None
    else:
      assert model
      self._train_params = model.params.train
      self._model = model

    if self._save_only:
      self._params = None
    else:
      self._params = model.params
      self._model_tasks = model.tasks
      self._model = model

    self._next_checkpoint_seconds = 0
    self._save_interval_seconds = self._train_params.save_interval_seconds
    self._saver = self._GetSaver()

    self._uninitialized_vars = tf.report_uninitialized_variables(
        tf.global_variables())

    # TODO(b/160786085): Move this logic into Overriding vars logic itself,
    # which requires refactoring things out of py_utils to avoid circular deps.
    def _ResolveCkptPath(ckpt_rules):
      return {GetSpecificCheckpoint(k): v for k, v in ckpt_rules.items()}

    self._restore_fns = []

    # Add graph nodes to restore specific variables based on
    # init_from_checkpoint_rules.
    # TODO(b/159267006): Move this back to Restore().
    if self._model:
      for task in self._model.tasks:
        tp = task.params.train
        if tp.init_from_checkpoint_rules:
          rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
          tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
          fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(),
                                                    rules)
          self._restore_fns.append(fn)

    if self._params and self._params.train.init_from_checkpoint_rules:
      tp = self._params.train
      rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
      tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
      fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(), rules)
      self._restore_fns.append(fn)
Code example #7
    def RestoreIfNeeded(self, sess):
        """If vars are not initialized, restore from checkpoint."""
        assert not self._save_only
        uninitialized_var_names = self.GetUninitializedVars(sess)
        # uninitialized_var_names is a list of strings without ":0" suffix.
        # tf.report_uninitialized_variables returns binary strings.
        assert all(
            isinstance(s, six.binary_type) for s in uninitialized_var_names)
        if not uninitialized_var_names:
            # All variables are already initialized.
            return

        tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)

        # There should only be uninitialized variables if all variables are
        # uninitialized.
        all_var_names = [
            six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
        ]
        assert (set(uninitialized_var_names) == set(all_var_names)
                ), sorted(set(all_var_names) - set(uninitialized_var_names))

        if self._Restore(sess):
            # Successfully restored from checkpoint.
            uninitialized_var_names = self.GetUninitializedVars(sess)
            assert not uninitialized_var_names, uninitialized_var_names
            return

        if (self._params.train.init_from_checkpoint_rules
                or any(task.params.train.init_from_checkpoint_rules
                       for task in self._model_tasks)):
            for task in self._model.tasks:
                tp = task.params.train
                if tp.init_from_checkpoint_rules:
                    tf.logging.info('OverrideVarsFromCheckpoints %s',
                                    tp.init_from_checkpoint_rules)
                    py_utils.OverrideVarsFromCheckpoints(
                        sess, tf.global_variables(),
                        tp.init_from_checkpoint_rules)

            if self._params.train.init_from_checkpoint_rules:
                tp = self._params.train
                tf.logging.info('OverrideVarsFromCheckpoints %s',
                                tp.init_from_checkpoint_rules)
                py_utils.OverrideVarsFromCheckpoints(
                    sess, tf.global_variables(), tp.init_from_checkpoint_rules)

            uninitialized_var_names = self.GetUninitializedVars(sess)
            if not uninitialized_var_names:
                return

            tf.logging.info('Remaining uninitialized vars: %s',
                            uninitialized_var_names)

        # Need to retrieve vars, removing ":0" suffix from names.
        uninitialized_vars = [
            v for v in tf.global_variables()
            if six.ensure_binary(v.name[:-2]) in uninitialized_var_names
        ]
        tf.logging.info('Initialize variables: %s',
                        sorted([v.name[:-2] for v in uninitialized_vars]))
        sess.run(tf.variables_initializer(uninitialized_vars))