def __init__(self, train_dir, model, train_params=None, save_only=False):
  """Initialize Checkpointer.

  Args:
    train_dir: Training directory for saving checkpoints.
    model: A BaseModel instance or None.
    train_params: If specified, use these training params instead of those
      in the `model`.
    save_only: This checkpointer is only intended for saving checkpoints.
  """
  self._train_dir = train_dir
  self._save_only = save_only
  self._save_path = os.path.join(self._train_dir, 'ckpt')

  if train_params:
    self._train_params = train_params
    self._model = None
  else:
    assert model
    self._train_params = model.params.train
    self._model = model

  if not self._save_only:
    self._params = model.params
    self._model_tasks = model.tasks
    self._model = model

  self._next_checkpoint_seconds = 0
  self._save_interval_seconds = self._train_params.save_interval_seconds
  self._saver = self._GetSaver()

  self._uninitialized_vars = tf.report_uninitialized_variables(
      tf.global_variables())
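# Hypothetical usage sketch (not from the original source): assumes a
# Lingvo-style Checkpointer class containing the __init__ above and a
# constructed BaseModel `model`; the paths below are illustrative.
#
#   ckpt = Checkpointer('/tmp/exp1/train', model)
#
# A save-only checkpointer can be built without a model by passing the
# training params directly:
#
#   ckpt = Checkpointer('/tmp/exp1/train', model=None,
#                       train_params=some_params.train, save_only=True)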
def __init__(self,
             train_dir,
             models,
             init_op=None,
             train_params=None,
             save_only=False):
  """Initialize Checkpointer.

  Args:
    train_dir: Training directory for saving checkpoints.
    models: One or a list of BaseModel instances. Cannot be empty. If more
      than one model is given and `train_params` is None, the save intervals
      are determined only by the first model.
    init_op: The initialize variables op. If unset, it will call
      tf.global_variables_initializer().
    train_params: If specified, use these training params instead of those
      in the first model.
    save_only: This checkpointer is only intended for saving checkpoints.
  """
  self._train_dir = train_dir
  self._save_only = save_only
  if init_op:
    self._init_op = init_op
  else:
    self._init_op = tf.global_variables_initializer()
  self._save_path = os.path.join(self._train_dir, 'ckpt')

  if not isinstance(models, (list, tuple)):
    models = [models]
  self._models = models
  if train_params:
    self._train_params = train_params
  else:
    self._train_params = models[0].params.train

  self._next_checkpoint_seconds = 0
  self._save_interval_seconds = self._train_params.save_interval_seconds
  self._save_interval_steps = self._train_params.save_interval_steps
  self._prev_ckpt_step = None
  self._saver = self._GetSaver()
  if not py_utils.IsEagerMode():
    self._uninitialized_vars = tf.report_uninitialized_variables(
        tf.global_variables())

  self._BuildInitFromCheckpointRules()
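# Hypothetical usage sketch (assumptions: `model_a` and `model_b` are
# constructed BaseModel instances; paths are illustrative). With a list of
# models and no explicit `train_params`, the save intervals come from
# model_a.params.train:
#
#   ckpt = Checkpointer('/tmp/exp1/train', [model_a, model_b])
#
# A custom init op may be supplied in place of the default
# tf.global_variables_initializer():
#
#   init_op = tf.group(tf.global_variables_initializer(),
#                      tf.local_variables_initializer())
#   ckpt = Checkpointer('/tmp/exp1/train', model_a, init_op=init_op)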
def __init__(self, train_dir, model):
  """Initialize Checkpointer.

  Args:
    train_dir: Training directory for saving checkpoints.
    model: A BaseModel instance.
  """
  self._train_dir = train_dir
  self._model = model
  self._params = model.params

  self._vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  self._uninitialized_vars = tf.report_uninitialized_variables(self._vars)
  self._initialize_vars = tf.global_variables_initializer()

  self._save_path = os.path.join(self._train_dir, 'ckpt')
  self._model_tasks = model.tasks

  tp = self._params.train
  self._save_interval_seconds = tp.save_interval_seconds
  self._next_checkpoint_seconds = 0
  self._saver = self._GetSaver()
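# Hypothetical usage sketch (assumption: `model` is a constructed BaseModel;
# the path is illustrative). This variant always requires a model, and
# checkpoints are written under the <train_dir>/ckpt prefix on the
# `save_interval_seconds` schedule:
#
#   ckpt = Checkpointer('/tmp/exp1/train', model)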
def __init__(self,
             train_dir,
             model,
             init_op=None,
             train_params=None,
             save_only=False):
  """Initialize Checkpointer.

  Args:
    train_dir: Training directory for saving checkpoints.
    model: A BaseModel instance or None.
    init_op: The initialize variables op. If unset, it will call
      tf.global_variables_initializer().
    train_params: If specified, use these training params instead of those
      in the `model`.
    save_only: This checkpointer is only intended for saving checkpoints.
  """
  self._train_dir = train_dir
  self._save_only = save_only
  if init_op:
    self._init_op = init_op
  else:
    self._init_op = tf.global_variables_initializer()
  self._save_path = os.path.join(self._train_dir, 'ckpt')

  if train_params:
    self._train_params = train_params
    self._model = None
  else:
    assert model
    self._train_params = model.params.train
    self._model = model

  if self._save_only:
    self._params = None
  else:
    self._params = model.params
    self._model_tasks = model.tasks
    self._model = model

  self._next_checkpoint_seconds = 0
  self._save_interval_seconds = self._train_params.save_interval_seconds
  self._saver = self._GetSaver()

  self._uninitialized_vars = tf.report_uninitialized_variables(
      tf.global_variables())

  # TODO(b/160786085): Move this logic into Overriding vars logic itself,
  # which requires refactoring things out of py_utils to avoid circular deps.
  def _ResolveCkptPath(ckpt_rules):
    return {GetSpecificCheckpoint(k): v for k, v in ckpt_rules.items()}

  self._restore_fns = []
  # Add graph nodes to restore specific variables based on
  # init_from_checkpoint_rules.
  # TODO(b/159267006): Move this back to Restore().
  if self._model:
    for task in self._model.tasks:
      tp = task.params.train
      if tp.init_from_checkpoint_rules:
        rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
        tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
        fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(),
                                                  rules)
        self._restore_fns.append(fn)

  if self._params and self._params.train.init_from_checkpoint_rules:
    tp = self._params.train
    rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
    tf.logging.info('OverrideVarsFromCheckpoints %s', rules)
    fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(), rules)
    self._restore_fns.append(fn)
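# Hypothetical usage sketch (assumptions: `model` is a BaseModel whose task
# params set `train.init_from_checkpoint_rules`, a dict keyed by checkpoint
# path whose values are variable-loading rules in whatever format
# py_utils.OverrideVarsFromCheckpoints expects; `loading_rules` and the path
# below are illustrative placeholders):
#
#   task_params.train.init_from_checkpoint_rules = {
#       '/tmp/pretrained/ckpt-00100': loading_rules,
#   }
#   ckpt = Checkpointer('/tmp/exp1/train', model)
#
# __init__ pre-builds one restore fn per rule set; each fn in
# ckpt._restore_fns overrides the matching variables from its source
# checkpoint when run.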