def RestoreIfNeeded(self, sess): """If vars are not initialized, restore frome checkpoint. Args: sess: tf.Session. """ assert not self._save_only uninitialized_var_names = list(sess.run(self._uninitialized_vars)) if not uninitialized_var_names: return tf.logging.info('Uninitialized var list: %s ', uninitialized_var_names) if self._Restore(sess): return if (not any(task.params.train.init_from_checkpoint_rules for task in self._model_tasks) and not self._params.train.init_from_checkpoint_rules): tf.logging.info('Initialize ALL variables: %s', uninitialized_var_names) sess.run([self._initialize_vars]) tf.logging.info('Initialize variables done.') return # There was a race in local run. Another thread will get unblocked once # _initialize_all is called. OverrideVarsFromCheckpoints # might not happen at the right time. for task in self._model.tasks: tp = task.params.train if tp.init_from_checkpoint_rules: tf.logging.info('OverrideVarsFromCheckpoints %s', tp.init_from_checkpoint_rules) py_utils.OverrideVarsFromCheckpoints( sess, self._vars, tp.init_from_checkpoint_rules) if self._params.train.init_from_checkpoint_rules: tp = self._params.train tf.logging.info('OverrideVarsFromCheckpoints %s', tp.init_from_checkpoint_rules) py_utils.OverrideVarsFromCheckpoints(sess, self._vars, tp.init_from_checkpoint_rules) uninitialized_var_names = list(sess.run(self._uninitialized_vars)) if not uninitialized_var_names: return # uninitialized_var_names is a list of strings without ":0" suffix. # tf.report_uninitialized_variables returns binary strings. assert all( isinstance(s, six.binary_type) for s in uninitialized_var_names) # Need to retrieve vars, removing ":0" suffix from names. uninitialized_vars = [ v for v in self._vars if six.ensure_binary(v.name[:-2]) in uninitialized_var_names ] tf.logging.info('Initialize variables: %s', [v.name for v in uninitialized_vars]) sess.run(tf.variables_initializer(uninitialized_vars))
def _RestoreIfNeeded(self, sess):
  uninitialized_var_names = list(sess.run(self._uninitialized))
  if not uninitialized_var_names:
    return

  tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)
  path = tf.train.latest_checkpoint(self._train_dir)
  if path:
    tf.logging.info('Load from checkpoint %s.', path)
    self._saver.restore(sess, path)
    tf.logging.info('Load checkpoint done.')
    return

  if (not any(task.params.train.init_from_checkpoint_rules
              for task in self._model.tasks) and
      not self._params.train.init_from_checkpoint_rules):
    tf.logging.info('Initialize ALL variables: %s', uninitialized_var_names)
    sess.run([self._initialize_all])
    tf.logging.info('Initialize variables done.')
    return

  # There was a race in local run. Another thread will get unblocked once
  # _initialize_all is called. OverrideVarsFromCheckpoints
  # might not happen at the right time.
  for task in self._model.tasks:
    tp = task.params.train
    if tp.init_from_checkpoint_rules:
      tf.logging.info('OverrideVarsFromCheckpoints %s',
                      tp.init_from_checkpoint_rules)
      py_utils.OverrideVarsFromCheckpoints(sess, self._vars,
                                           tp.init_from_checkpoint_rules)

  if self._params.train.init_from_checkpoint_rules:
    tp = self._params.train
    tf.logging.info('OverrideVarsFromCheckpoints %s',
                    tp.init_from_checkpoint_rules)
    py_utils.OverrideVarsFromCheckpoints(sess, self._vars,
                                         tp.init_from_checkpoint_rules)

  uninitialized_var_names = list(sess.run(self._uninitialized))
  if not uninitialized_var_names:
    return

  # uninitialized_var_names is a list of strings without ":0" suffix.
  assert all(isinstance(s, str) for s in uninitialized_var_names)
  # Need to retrieve vars, removing ":0" suffix from names.
  uninitialized_vars = [
      v for v in self._vars if v.name[:-2] in uninitialized_var_names
  ]
  tf.logging.info('Initialize variables: %s',
                  [v.name for v in uninitialized_vars])
  sess.run(tf.variables_initializer(uninitialized_vars))

def Restore(self, sess, force_reinitialize=False):
  """Restore from latest checkpoint if available, or initialize."""
  # Try and restore from the latest checkpoint.
  if self._RestoreFromLatestCheckpoint(sess):
    # Successfully restored from checkpoint.
    uninitialized_var_names = self._GetUninitializedVarNames(sess)
    assert not uninitialized_var_names, uninitialized_var_names
    return

  # Otherwise we need to initialize.
  uninitialized_var_names = self._GetUninitializedVarNames(sess)
  tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)
  if not force_reinitialize:
    # There should only be uninitialized variables if all variables are
    # uninitialized - with the exception of global_step due to
    # RestoreGlobalStepIfNeeded in the _LoopEnqueue of TrainerTpu.
    all_var_names = [
        six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
    ]
    already_initialized_vars = (
        set(all_var_names) - set(uninitialized_var_names))
    already_initialized_vars.discard(b'global_step')
    assert not already_initialized_vars, (
        'Already initialized vars: %s' % sorted(already_initialized_vars))

  # At this point all variables are uninitialized, so it is safe to run a
  # global initializer.
  sess.run(self._init_op)
  tf.logging.info('Initialized all vars.')

  # TODO(b/160786085): Move this logic into Overriding vars logic itself,
  # which requires refactoring things out of py_utils to avoid circular deps.
  def _ResolveCkptPath(ckpt_rules):
    return {GetSpecificCheckpoint(k): v for k, v in ckpt_rules.items()}

  # Restore specific variables based on init_from_checkpoint_rules.
  for task in self._model.tasks:
    tp = task.params.train
    if tp.init_from_checkpoint_rules:
      rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
      tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
      py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(), rules)

  if self._params.train.init_from_checkpoint_rules:
    tp = self._params.train
    rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
    tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
    py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(), rules)

def Restore(self, sess, force_reinitialize=False):
  """Restore from latest checkpoint if available, or initialize."""
  # Try and restore from the latest checkpoint.
  if self._RestoreFromLatestCheckpoint(sess):
    # Successfully restored from checkpoint.
    uninitialized_var_names = self._GetUninitializedVarNames(sess)
    assert not uninitialized_var_names, uninitialized_var_names
    return

  # Otherwise we need to initialize.
  uninitialized_var_names = self._GetUninitializedVarNames(sess)
  tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)
  if not force_reinitialize:
    # There should only be uninitialized variables if all variables are
    # uninitialized - with the exception of global_step due to
    # RestoreGlobalStepIfNeeded in the _LoopEnqueue of TrainerTpu.
    all_var_names = [
        six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
    ]
    already_initialized_vars = (
        set(all_var_names) - set(uninitialized_var_names))
    already_initialized_vars.discard(b'global_step')
    assert not already_initialized_vars, (
        'Already initialized vars: %s' % sorted(already_initialized_vars))

  # At this point all variables are uninitialized, so it is safe to run a
  # global initializer.
  sess.run(tf.global_variables_initializer())
  tf.logging.info('Initialized all vars.')

  # Restore specific variables based on init_from_checkpoint_rules.
  for task in self._model.tasks:
    tp = task.params.train
    if tp.init_from_checkpoint_rules:
      tf.logging.info('OverrideVarsFromCheckpoints %s',
                      tp.init_from_checkpoint_rules)
      py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                           tp.init_from_checkpoint_rules)

  if self._params.train.init_from_checkpoint_rules:
    tp = self._params.train
    tf.logging.info('OverrideVarsFromCheckpoints %s',
                    tp.init_from_checkpoint_rules)
    py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                         tp.init_from_checkpoint_rules)

def _BuildInitFromCheckpointRules(self):
  """Build restore fns for init_from_checkpoint_rules."""
  self._restore_fns = []
  all_vars = list(_GetSaveableVariablesDict(self._models).values())

  # TODO(b/160786085): Move this logic into Overriding vars logic itself,
  # which requires refactoring things out of py_utils to avoid circular deps.
  def _ResolveCkptPath(ckpt_rules):
    res_rules = {}
    for k, v in ckpt_rules.items():
      new_k = GetSpecificCheckpoint(k)
      if not new_k:
        tf.logging.warning(
            f'Empty checkpoint path init rules are ignored, key={k}')
      else:
        res_rules.update({new_k: v})
    return res_rules

  def _MergeLoadingRules(a, b):
    res = copy.deepcopy(a)
    for k, (load_rules, ignore_rules) in b.items():
      if k in res:
        res_load, res_ignore = res[k]
        for load in load_rules:
          if load not in res_load:
            res_load.append(load)
        for ignore in ignore_rules:
          if ignore not in res_ignore:
            res_ignore.append(ignore)
      else:
        res[k] = (load_rules, ignore_rules)
    return res

  # Restore specific variables based on init_from_checkpoint_rules.
  rules = {}
  for model in self._models:
    for task in model.tasks:
      tp = task.params.train
      if tp.init_from_checkpoint_rules:
        rules = _MergeLoadingRules(
            rules, _ResolveCkptPath(tp.init_from_checkpoint_rules))

  if self._train_params.init_from_checkpoint_rules:
    rules = _MergeLoadingRules(
        rules,
        _ResolveCkptPath(self._train_params.init_from_checkpoint_rules))

  # Add graph nodes to restore specific variables based on
  # init_from_checkpoint_rules.
  # TODO(b/159267006): Move this back to Restore().
  self._restore_fns.append(
      (f'OverrideVarsFromCheckpoints {rules}',
       py_utils.OverrideVarsFromCheckpoints(all_vars, rules)))

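# Illustration only, not from the original source: a sketch of the `rules`
# dict shape that _MergeLoadingRules above produces. Keys are resolved
# checkpoint paths and values are (loading_rules, ignore_rules) tuples, as
# implied by the tuple unpacking in _MergeLoadingRules. The paths and the
# regex/pattern pairs below are made-up placeholders, and the exact inner
# rule format expected by py_utils.OverrideVarsFromCheckpoints is an
# assumption here.
task_rules = {'/tmp/exp_a/ckpt-100': ([('enc/(.*)', 'enc/%s')], [])}
model_rules = {'/tmp/exp_a/ckpt-100': ([('dec/(.*)', 'dec/%s')],
                                       ['.*global_step.*'])}
# Merging de-duplicates load and ignore rules per checkpoint path, so the
# merged result for these two inputs would be:
# {'/tmp/exp_a/ckpt-100': ([('enc/(.*)', 'enc/%s'), ('dec/(.*)', 'dec/%s')],
#                          ['.*global_step.*'])}
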
def __init__(self,
             train_dir,
             model,
             init_op=None,
             train_params=None,
             save_only=False):
  """Initialize Checkpointer.

  Args:
    train_dir: Training directory for saving checkpoints.
    model: A BaseModel instance or None.
    init_op: The initialize variables op. If unset, it will call
      tf.global_variables_initializer().
    train_params: If specified, use these training params instead of those
      in the `model`.
    save_only: This checkpointer is only intended for saving checkpoints.
  """
  self._train_dir = train_dir
  self._save_only = save_only
  if init_op:
    self._init_op = init_op
  else:
    self._init_op = tf.global_variables_initializer()

  self._save_path = os.path.join(self._train_dir, 'ckpt')

  if train_params:
    self._train_params = train_params
    self._model = None
  else:
    assert model
    self._train_params = model.params.train
    self._model = model

  if self._save_only:
    self._params = None
  else:
    self._params = model.params
    self._model_tasks = model.tasks
    self._model = model

  self._next_checkpoint_seconds = 0
  self._save_interval_seconds = self._train_params.save_interval_seconds
  self._saver = self._GetSaver()

  self._uninitialized_vars = tf.report_uninitialized_variables(
      tf.global_variables())

  # TODO(b/160786085): Move this logic into Overriding vars logic itself,
  # which requires refactoring things out of py_utils to avoid circular deps.
  def _ResolveCkptPath(ckpt_rules):
    return {GetSpecificCheckpoint(k): v for k, v in ckpt_rules.items()}

  self._restore_fns = []
  # Add graph nodes to restore specific variables based on
  # init_from_checkpoint_rules.
  # TODO(b/159267006): Move this back to Restore().
  if self._model:
    for task in self._model.tasks:
      tp = task.params.train
      if tp.init_from_checkpoint_rules:
        rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
        tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
        fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(),
                                                  rules)
        self._restore_fns.append(fn)

  if self._params and self._params.train.init_from_checkpoint_rules:
    tp = self._params.train
    rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
    tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
    fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(), rules)
    self._restore_fns.append(fn)

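# A minimal usage sketch, not taken from the original source. It assumes this
# __init__ belongs to a `Checkpointer` class (per its docstring), that `model`
# is an already-constructed BaseModel whose variables live in the default
# graph, and that the restore entry point is RestoreIfNeeded as in the other
# versions shown here; the exact entry point name differs across versions.
def _run_with_checkpointing(train_dir, model):
  # Uses model.params.train since no explicit train_params is passed.
  checkpointer = Checkpointer(train_dir, model)
  with tf.Session() as sess:
    # Restores from the latest checkpoint in train_dir if one exists;
    # otherwise initializes variables, applying any
    # init_from_checkpoint_rules configured on the model or its tasks.
    checkpointer.RestoreIfNeeded(sess)
    # ... training steps and periodic checkpoint saves go here ...
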
def RestoreIfNeeded(self, sess):
  """If vars are not initialized, restore from checkpoint."""
  assert not self._save_only
  uninitialized_var_names = self.GetUninitializedVars(sess)
  # uninitialized_var_names is a list of strings without ":0" suffix.
  # tf.report_uninitialized_variables returns binary strings.
  assert all(
      isinstance(s, six.binary_type) for s in uninitialized_var_names)
  if not uninitialized_var_names:
    # All variables are already initialized.
    return

  tf.logging.info('Uninitialized var list: %s', uninitialized_var_names)

  # There should only be uninitialized variables if all variables are
  # uninitialized.
  all_var_names = [
      six.ensure_binary(v.name[:-2]) for v in tf.global_variables()
  ]
  assert set(uninitialized_var_names) == set(all_var_names), sorted(
      set(all_var_names) - set(uninitialized_var_names))

  if self._Restore(sess):
    # Successfully restored from checkpoint.
    uninitialized_var_names = self.GetUninitializedVars(sess)
    assert not uninitialized_var_names, uninitialized_var_names
    return

  if (self._params.train.init_from_checkpoint_rules or
      any(task.params.train.init_from_checkpoint_rules
          for task in self._model_tasks)):
    for task in self._model.tasks:
      tp = task.params.train
      if tp.init_from_checkpoint_rules:
        tf.logging.info('OverrideVarsFromCheckpoints %s',
                        tp.init_from_checkpoint_rules)
        py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                             tp.init_from_checkpoint_rules)

    if self._params.train.init_from_checkpoint_rules:
      tp = self._params.train
      tf.logging.info('OverrideVarsFromCheckpoints %s',
                      tp.init_from_checkpoint_rules)
      py_utils.OverrideVarsFromCheckpoints(sess, tf.global_variables(),
                                           tp.init_from_checkpoint_rules)

    uninitialized_var_names = self.GetUninitializedVars(sess)
    if not uninitialized_var_names:
      return

    tf.logging.info('Remaining uninitialized vars: %s',
                    uninitialized_var_names)

  # Need to retrieve vars, removing ":0" suffix from names.
  uninitialized_vars = [
      v for v in tf.global_variables()
      if six.ensure_binary(v.name[:-2]) in uninitialized_var_names
  ]
  tf.logging.info('Initialize variables: %s',
                  sorted([v.name[:-2] for v in uninitialized_vars]))
  sess.run(tf.variables_initializer(uninitialized_vars))