def has_checkpoint(self):
    """
    Returns:
        bool: whether a checkpoint exists in the target directory.
    """
    # The checkpointer records the latest checkpoint name in a
    # "last_checkpoint" marker file inside ``save_dir``.
    marker_file = os.path.join(self.save_dir, "last_checkpoint")
    return PathManager.exists(marker_file)
def convert_to_coco_json(dataset_name, output_file, allow_cached=True):
    """
    Converts dataset into COCO format and saves it to a json file.
    dataset_name must be registered in DatasetCatalog and in cvpods's standard format.

    Args:
        dataset_name: reference from the config file to the catalogs
            must be registered in DatasetCatalog and in cvpods's standard format
        output_file: path of json file that will be saved to
        allow_cached: if json file is already present then skip conversion
    """
    # TODO: The dataset or the conversion script *may* change,
    # a checksum would be useful for validating the cached data
    PathManager.mkdirs(os.path.dirname(output_file))
    # Serialize against concurrent workers converting to the same file.
    with file_lock(output_file):
        if PathManager.exists(output_file) and allow_cached:
            logger.info(
                f"Cached annotations in COCO format already exist: {output_file}"
            )
        else:
            # Fix: the original message had a stray ")" inside the f-string
            # (rendered as "... to COCO format ...)").
            logger.info(
                f"Converting dataset annotations in '{dataset_name}' to COCO format ..."
            )
            coco_dict = convert_to_coco_dict(dataset_name)
            with PathManager.open(output_file, "w") as json_file:
                logger.info(f"Caching annotations in COCO format: {output_file}")
                json.dump(coco_dict, json_file)
def get_valid_files(args, cfg, logger):
    """
    Collect the checkpoint files to be tested.

    If ``MODEL.WEIGHTS`` was given on the command line, only that file is
    used. Otherwise, ``model_*.pth`` files are gathered from the local
    output dir (falling back to the corresponding OSS dump location when
    nothing is found locally) and filtered by the requested iteration range.
    """
    # An explicitly specified weight file takes precedence over discovery.
    if "MODEL.WEIGHTS" in args.opts:
        model_weights = cfg.MODEL.WEIGHTS
        assert PathManager.exists(model_weights), "{} not exist!!!".format(
            model_weights)
        return [model_weights]

    candidates = glob.glob(os.path.join(cfg.OUTPUT_DIR, "model_*.pth"))
    if not candidates:
        # local file invalid, get it from oss
        model_prefix = cfg.OUTPUT_DIR.split("cvpods_playground")[-1][1:]
        remote_file_path = os.path.join(cfg.OSS.DUMP_PREFIX, model_prefix)
        logger.warning(
            "No checkpoint file was found locally, try to "
            f"load the corresponding dump file on OSS site: {remote_file_path}."
        )
        ckpt_pattern = re.compile(r"model_.+\.pth")
        candidates = [
            str(filename)
            for filename in PathManager.ls(remote_file_path)
            if ckpt_pattern.match(filename.name) is not None
        ]
        assert len(candidates) != 0, "No valid file found on OSS"

    candidates = filter_by_iters(candidates, args.start_iter, args.end_iter)
    assert candidates, "No checkpoint valid in {}.".format(cfg.OUTPUT_DIR)
    logger.info("All files below will be tested in order:\n{}".format(
        pformat(candidates)))
    return candidates
def main(args):
    """
    Evaluate one or more checkpoints selected from the config/CLI options.

    Resolves the list of checkpoint files to test (either the explicitly
    given MODEL.WEIGHTS, or ``*.pth`` files from OUTPUT_DIR filtered by the
    ``start_iter``/``end_iter`` range), then builds the model, loads each
    checkpoint, and runs ``Trainer.test`` (plus TTA when enabled).
    """
    config.merge_from_list(args.opts)
    cfg, logger = default_setup(config, args)
    if args.debug:
        # Rescale the batch size for a smaller debug run; "/ 8" presumably
        # assumes the config was tuned for 8 GPUs — TODO confirm.
        batches = int(cfg.SOLVER.IMS_PER_BATCH / 8 * args.num_gpus)
        if cfg.SOLVER.IMS_PER_BATCH != batches:
            cfg.SOLVER.IMS_PER_BATCH = batches
            logger.warning(
                "SOLVER.IMS_PER_BATCH is changed to {}".format(batches))

    if "MODEL.WEIGHTS" in args.opts:
        # A weight file was given explicitly; if the local .pth is missing,
        # fall back to the matching dump file on OSS.
        if cfg.MODEL.WEIGHTS.endswith(".pth") and not PathManager.exists(
                cfg.MODEL.WEIGHTS):
            ckpt_name = cfg.MODEL.WEIGHTS.split("/")[-1]
            # NOTE(review): uses split(...)[1] here while get_valid_files
            # uses [-1] for the same derivation — verify they agree when
            # "cvpods_playground" appears more than once in the path.
            model_prefix = cfg.OUTPUT_DIR.split("cvpods_playground")[1][1:]
            remote_file_path = os.path.join(cfg.OSS.DUMP_PREFIX, model_prefix,
                                            ckpt_name)
            logger.warning(
                f"The specified ckpt file ({cfg.MODEL.WEIGHTS}) was not found locally,"
                f" try to load the corresponding dump file on OSS ({remote_file_path})."
            )
            cfg.MODEL.WEIGHTS = remote_file_path
        valid_files = [cfg.MODEL.WEIGHTS]
    else:
        # Discover checkpoints locally, ordered by creation time.
        list_of_files = glob.glob(os.path.join(cfg.OUTPUT_DIR, '*.pth'))
        assert list_of_files, "No checkpoint file found in {}.".format(
            cfg.OUTPUT_DIR)
        list_of_files.sort(key=os.path.getctime)
        latest_file = list_of_files[-1]
        if not args.end_iter:
            # No range requested: test only the newest checkpoint.
            valid_files = [latest_file]
        else:
            files = [f for f in list_of_files if str(f) <= str(latest_file)]
            valid_files = []
            for f in files:
                try:
                    # Extract the iteration number from "model_<iter>.pth";
                    # non-conforming filenames are skipped.
                    model_iter = int(re.split(r'(model_|\.pth)', f)[-3])
                except Exception:
                    logger.warning("remove {}".format(f))
                    continue
                if args.start_iter <= model_iter <= args.end_iter:
                    valid_files.append(f)
            assert valid_files, "No .pth files satisfy your requirement"

    # Evaluate every selected checkpoint in order.
    for current_file in valid_files:
        cfg.MODEL.WEIGHTS = current_file
        model = build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume)
        res = Trainer.test(cfg, model)
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
            # Augment the results with test-time-augmentation metrics.
            res.update(Trainer.test_with_TTA(cfg, model))