def load_weights(self, log_dir=None, type='latest', **kwargs):
    """
    Load weights into the model by restoring a checkpoint with self.saver into self.sess.

    :param log_dir: experiment directory into which the checkpoints have been written;
                    when not given (or falsy), self.log_dir is used
    :param type: which checkpoint to restore. One of:
                 'latest'    -> checkpoint with prefix 'model.ckpt',
                 'best_dice' -> checkpoint with prefix 'model_best_dice.ckpt',
                 'best_loss' -> checkpoint with prefix 'model_best_loss.ckpt',
                 'best_ged'  -> checkpoint with prefix 'model_best_ged.ckpt'
                 (each resolved via tfutils.get_latest_model_checkpoint_path), or
                 'iter'      -> the file 'model.ckpt-<iteration>', which requires passing
                                the ``iteration`` keyword argument with a valid step number
    :raises ValueError: if ``type`` is not one of the options above
    """
    if not log_dir:
        log_dir = self.log_dir

    # Options that are resolved by looking up a checkpoint for a given file prefix.
    checkpoint_prefixes = {
        'latest': 'model.ckpt',
        'best_dice': 'model_best_dice.ckpt',
        'best_loss': 'model_best_loss.ckpt',
        'best_ged': 'model_best_ged.ckpt',
    }

    if type in checkpoint_prefixes:
        init_checkpoint_path = tfutils.get_latest_model_checkpoint_path(
            log_dir, checkpoint_prefixes[type])
    elif type == 'iter':
        assert 'iteration' in kwargs, "argument 'iteration' must be provided for type='iter'"
        iteration = kwargs['iteration']
        init_checkpoint_path = os.path.join(log_dir, 'model.ckpt-%d' % iteration)
    else:
        # List every accepted option so the error is actionable.
        raise ValueError(
            'Argument type=%s is unknown. '
            'type can be latest/best_dice/best_loss/best_ged/iter.' % type)

    self.saver.restore(self.sess, init_checkpoint_path)
def load_weights(self, log_dir=None, type='latest', **kwargs):
    """
    Restore the model's variables from a checkpoint, using self.saver on self.sess.

    :param log_dir: experiment directory holding the checkpoint files; when falsy,
                    self.log_dir is used instead
    :param type: checkpoint selector:
                 'latest'           -> path found by tf_utils.get_latest_model_checkpoint_path
                                       for the prefix 'model.ckpt',
                 'best_wasserstein' -> path found the same way for the prefix
                                       'model_best_wasserstein.ckpt',
                 'iter'             -> the file 'model.ckpt-<iteration>' inside log_dir, where
                                       the step is given by the ``iteration`` keyword argument
    :raises ValueError: for any other value of ``type``
    """
    log_dir = log_dir or self.log_dir

    # Selectors that map onto a checkpoint file prefix, checked in this order.
    prefix_by_selector = (
        ('latest', 'model.ckpt'),
        ('best_wasserstein', 'model_best_wasserstein.ckpt'),
    )

    for selector, prefix in prefix_by_selector:
        if type == selector:
            restore_path = tf_utils.get_latest_model_checkpoint_path(log_dir, prefix)
            break
    else:
        # Not a prefix-based selector: only an explicit iteration remains valid.
        if not type == 'iter':
            raise ValueError(
                'Argument type=%s is unknown. type can be latest/best_wasserstein/iter.' % type)
        assert 'iteration' in kwargs, "argument 'iteration' must be provided for type='iter'"
        step = kwargs['iteration']
        restore_path = os.path.join(log_dir, 'model.ckpt-%d' % step)

    self.saver.restore(self.sess, restore_path)
def _setup_log_dir_and_continue_mode(self):
    """
    Set up the log directory and, if a previous checkpoint exists, enable continue mode.

    Sets self.log_dir, self.init_checkpoint_path, self.continue_run and self.init_step.
    Then it creates the log directory and attaches a tf.summary.FileWriter
    (self.summary_writer) for self.sess.graph. It also copies the experiment config file
    into the log directory.

    If <sys_config.log_root>/classifier/<experiment_name> exists and contains a
    'model.ckpt' checkpoint, continue mode is enabled:
    - the start step is parsed from the checkpoint file name ('...-<step>');
    - new output goes to the same path with a '_cont' suffix.
    """
    # Default values: fresh run starting at step 0 with no checkpoint to restore.
    self.log_dir = os.path.join(sys_config.log_root, 'classifier', self.exp_config.experiment_name)
    self.init_checkpoint_path = None
    self.continue_run = False
    self.init_step = 0

    # If a checkpoint file already exists enable continue mode
    if tf.gfile.Exists(self.log_dir):
        init_checkpoint_path = tf_utils.get_latest_model_checkpoint_path(self.log_dir, 'model.ckpt')
        # NOTE(review): the helper appears to signal "no checkpoint found" with False
        # (not None) -- confirm against tf_utils.
        if init_checkpoint_path is not False:
            self.init_checkpoint_path = init_checkpoint_path
            self.continue_run = True
            # The step is the number after the last '-' in the checkpoint *file name*.
            # os.path.basename also understands OS-native separators, unlike split('/').
            checkpoint_name = os.path.basename(self.init_checkpoint_path)
            self.init_step = int(checkpoint_name.split('-')[-1])
            # Write the continued run to a separate directory rather than the original one.
            self.log_dir += '_cont'
            logging.info('--------------------------- Continuing previous run --------------------------------')
            logging.info('Checkpoint path: %s', self.init_checkpoint_path)
            logging.info('Latest step was: %d', self.init_step)
            logging.info('------------------------------------------------------------------------------------')

    tf.gfile.MakeDirs(self.log_dir)
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    # Copy experiment config file to log_dir for future reference
    shutil.copy(self.exp_config.__file__, self.log_dir)