def get_all_checkpoint_files(self):
    """
    Collect the checkpoint files stored directly inside ``self.save_dir``.

    Returns:
        list: full paths (``self.save_dir`` joined with the entry name) of
        every entry listed by ``PathManager.ls`` that is a regular file and
        whose name ends with ".pth", in listing order.
    """
    checkpoints = []
    for name in PathManager.ls(self.save_dir):
        candidate = os.path.join(self.save_dir, name)
        if PathManager.isfile(candidate) and name.endswith(".pth"):
            checkpoints.append(candidate)
    return checkpoints
def setup_cfg(args):
    """
    Build the config for ``args``, falling back to Model Zoo weights.

    The config is loaded from ``args.config`` and then overridden by the
    key/value list in ``args.opts``. When ``MODEL.WEIGHTS`` is still the
    empty string, it is pointed at
    ``<cfg.OSS.MODEL_PREFIX>/model_zoo/<args.config>/model_final.pth``.

    Returns:
        The merged config object.

    Raises:
        AssertionError: if the Model Zoo fallback file does not exist.
    """
    cfg = get_config(args.config, None)
    cfg.merge_from_list(args.opts)
    if cfg.MODEL.WEIGHTS != "":
        # Weights were given explicitly; nothing to resolve.
        return cfg

    zoo_root = os.path.join(cfg.OSS.MODEL_PREFIX, "model_zoo")
    zoo_ckpt = os.path.join(zoo_root, args.config, "model_final.pth")
    logger.warning(
        "No checkpoint file specified, trying to get it from Model Zoo "
        f"(OSS URI: {zoo_ckpt})."
    )
    assert PathManager.isfile(zoo_ckpt), (
        f"No checkpoint file found in {zoo_ckpt}.")
    cfg.MODEL.WEIGHTS = zoo_ckpt
    return cfg
def load(self, path: str):
    """
    Load the checkpoint at ``path`` into the model and, when resuming, into
    every registered checkpointable.

    If ``path`` is not a local file it is first resolved through
    ``PathManager.get_local_path``; when it points to a network file this
    function has to be called on all ranks.

    Args:
        path (str): path or url to the checkpoint. If empty, nothing is
            loaded.

    Returns:
        dict: if ``self.resume`` is true, the checkpoint entries that were
        not consumed by a checkpointable (for example, those saved with
        :meth:`.save(**extra_data)`); otherwise, or when ``path`` is empty,
        an empty dict.

    Raises:
        AssertionError: if a non-local ``path`` cannot be resolved to an
            existing file.
    """
    if not path:
        # No checkpoint given: skip loading entirely.
        self.logger.info(
            "No checkpoint found. Initializing model from scratch")
        return {}

    self.logger.info("Loading checkpoint from {}".format(path))
    if not os.path.isfile(path):
        # Not a local file: let PathManager resolve it to a local path.
        path = PathManager.get_local_path(path)
        assert PathManager.isfile(path), "Checkpoint {} not found!".format(
            path)

    checkpoint = self._load_file(path)
    self._load_model(checkpoint)

    if not self.resume:
        return {}

    for key, obj in self.checkpointables.items():
        if key not in checkpoint:
            continue
        self.logger.info("Loading {} from {}".format(key, path))
        # pop() so that only entries nobody consumed are returned below
        obj.load_state_dict(checkpoint.pop(key))
    return checkpoint
def _debug_batch_size(ims_per_batch, num_gpus):
    """Scale ``ims_per_batch`` by ``num_gpus / 8``, keeping >= 1 image per GPU."""
    # Integer floor division equals the former int(a / 8 * b) for
    # non-negative ints; the clamp prevents a batch size of 0 (or fewer
    # images than GPUs) when IMS_PER_BATCH is small.
    return max(ims_per_batch * num_gpus // 8, num_gpus, 1)


def _resolve_explicit_weights(cfg, logger):
    """Return the user-given MODEL.WEIGHTS, remapped to its OSS dump if missing locally."""
    weights = cfg.MODEL.WEIGHTS
    if weights.endswith(".pth") and not PathManager.exists(weights):
        # The OSS dump layout mirrors the part of OUTPUT_DIR that follows the
        # first "cvpods_playground" component.
        parts = cfg.OUTPUT_DIR.split("cvpods_playground")
        if len(parts) < 2:
            # Previously an opaque IndexError; keep the given path so the
            # checkpointer reports the missing file itself.
            logger.warning(
                f"The specified ckpt file ({weights}) was not found locally, "
                f"and OUTPUT_DIR ({cfg.OUTPUT_DIR}) does not contain "
                f"'cvpods_playground', so its OSS dump location cannot be "
                f"inferred.")
        else:
            ckpt_name = weights.split("/")[-1]
            model_prefix = parts[1][1:]
            remote_file_path = os.path.join(cfg.OSS.DUMP_PREFIX, model_prefix,
                                            ckpt_name)
            logger.warning(
                f"The specified ckpt file ({cfg.MODEL.WEIGHTS}) was not found locally,"
                f" try to load the corresponding dump file on OSS ({remote_file_path})."
            )
            cfg.MODEL.WEIGHTS = remote_file_path
    return cfg.MODEL.WEIGHTS


def _model_zoo_checkpoint(cfg, logger):
    """Return the Model Zoo ``model_final.pth`` URI derived from the working directory."""
    oss_prefix = os.path.join(cfg.OSS.MODEL_PREFIX, "model_zoo")
    # The last five components of the cwd identify the model in the zoo.
    model_prefix = '/'.join(os.getcwd().split('/')[-5:])
    file_path = os.path.join(oss_prefix, model_prefix, "model_final.pth")
    logger.warning(
        f"No checkpoint file found in the local log path ({cfg.OUTPUT_DIR}), "
        f"trying to get it from Model Zoo (OSS URI: {file_path}).")
    assert PathManager.isfile(
        file_path), f"No checkpoint file found in {file_path}."
    return file_path


def _select_local_checkpoints(list_of_files, args, logger):
    """Pick the newest local checkpoint, or all ``model_<iter>.pth`` in the iteration range."""
    list_of_files = sorted(list_of_files, key=os.path.getctime)
    latest_file = list_of_files[-1]
    if not args.end_iter:
        return [latest_file]

    # NOTE(review): this keeps only paths that sort lexicographically at or
    # before the newest file's path; the intent is unclear but preserved.
    files = [f for f in list_of_files if f <= latest_file]
    valid_files = []
    for f in files:
        try:
            # ".../model_0004999.pth" splits into [..., "model_", "0004999",
            # ".pth", ""], so the iteration number is element -3.
            model_iter = int(re.split(r'(model_|\.pth)', f)[-3])
        except (ValueError, IndexError):
            # e.g. "model_final.pth" carries no iteration number.
            logger.warning("remove {}".format(f))
            continue
        if args.start_iter <= model_iter <= args.end_iter:
            valid_files.append(f)
    assert valid_files, "No .pth files satisfy your requirement"
    return valid_files


def _resolve_checkpoints(cfg, args, logger):
    """Return the list of checkpoint paths that ``main`` should evaluate."""
    if "MODEL.WEIGHTS" in args.opts:
        return [_resolve_explicit_weights(cfg, logger)]
    list_of_files = glob.glob(os.path.join(cfg.OUTPUT_DIR, '*.pth'))
    if not list_of_files:
        return [_model_zoo_checkpoint(cfg, logger)]
    return _select_local_checkpoints(list_of_files, args, logger)


def main(args):
    """
    Evaluate one or more checkpoints using the module-level ``config``.

    ``args.opts`` is merged into ``config`` before ``default_setup`` produces
    the final ``cfg``. The checkpoints to evaluate are chosen as follows:

    * ``MODEL.WEIGHTS`` given in ``args.opts``: evaluate that file. A
      ``.pth`` path that does not exist locally is remapped under
      ``cfg.OSS.DUMP_PREFIX`` (see ``_resolve_explicit_weights``).
    * otherwise, if ``cfg.OUTPUT_DIR`` has no ``*.pth``: evaluate the Model
      Zoo ``model_final.pth`` derived from the working directory.
    * otherwise: evaluate the most recently created local checkpoint, or,
      when ``args.end_iter`` is set, every ``model_<iter>.pth`` with
      ``args.start_iter <= iter <= args.end_iter``.

    For each selected checkpoint a fresh model is built, loaded and tested
    (with TTA when ``cfg.TEST.AUG.ENABLED``); the main process then runs
    ``verify_results``.

    Args:
        args: parsed command-line arguments; reads ``opts``, ``debug``,
            ``num_gpus``, ``end_iter``, ``start_iter`` and ``resume``.

    Raises:
        AssertionError: if no checkpoint can be found, or none falls in the
            requested iteration range.
    """
    config.merge_from_list(args.opts)
    cfg, logger = default_setup(config, args)
    if args.debug:
        batches = _debug_batch_size(cfg.SOLVER.IMS_PER_BATCH, args.num_gpus)
        if cfg.SOLVER.IMS_PER_BATCH != batches:
            cfg.SOLVER.IMS_PER_BATCH = batches
            logger.warning(
                "SOLVER.IMS_PER_BATCH is changed to {}".format(batches))

    valid_files = _resolve_checkpoints(cfg, args, logger)

    for current_file in valid_files:
        cfg.MODEL.WEIGHTS = current_file
        model = build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        if cfg.TEST.AUG.ENABLED:
            res = Trainer.test_with_TTA(cfg, model)
        else:
            res = Trainer.test(cfg, model)
        if comm.is_main_process():
            verify_results(cfg, res)