def main():
    """Evaluate one MSPN checkpoint on the validation split.

    Command-line arguments:
        --local_rank: CUDA device index for this process; only used to select
            the device in distributed mode and passed to the logger/loader.
        --iter / -i: training iteration whose checkpoint
            ``cfg.OUTPUT_DIR/iter-{iter}.pth`` is evaluated. The default (-1)
            means "not designated" and aborts after logging a message.

    Distributed mode is enabled when the ``WORLD_SIZE`` environment variable
    is greater than 1. After inference, the main process sorts the results,
    writes them to ``cfg.TEST_DIR/results.json`` and calls the dataset's
    ``evaluate`` on that file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--iter", "-i", type=int, default=-1)
    args = parser.parse_args()

    num_gpus = int(os.environ.get("WORLD_SIZE", "1"))
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    if is_main_process():
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(cfg.TEST_DIR, exist_ok=True)
    logger = get_logger(cfg.DATASET.NAME, cfg.TEST_DIR, args.local_rank,
                        'test_log.txt')

    if args.iter == -1:
        logger.info("Please designate one iteration.")
        # Without an iteration there is no checkpoint to evaluate; continuing
        # would look for "iter--1.pth".
        return

    model = MSPN(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    model_file = os.path.join(cfg.OUTPUT_DIR, "iter-{}.pth".format(args.iter))
    if not os.path.exists(model_file):
        # Evaluating randomly initialised weights would only produce (and
        # overwrite results.json with) meaningless numbers.
        logger.info("No such checkpoint {}".format(model_file))
        return

    # Load onto CPU storage first; load_state_dict copies into the model's
    # parameters on whatever device they already live on.
    state_dict = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    state_dict = state_dict['model']
    model.load_state_dict(state_dict)

    data_loader = get_test_loader(cfg, num_gpus, args.local_rank, 'val',
                                  is_dist=distributed)

    results = inference(model, data_loader, logger, device)
    synchronize()

    if is_main_process():
        logger.info("Dumping results ...")
        results.sort(key=lambda res: (res['image_id'], res['score']),
                     reverse=True)
        results_path = os.path.join(cfg.TEST_DIR, 'results.json')
        with open(results_path, 'w') as f:
            json.dump(results, f)
        logger.info("Get all results.")

        data_loader.ori_dataset.evaluate(results_path)
def main():
    """Run the SMAP test pipeline, optionally followed by RefineNet.

    The command-line arguments are copied onto ``cfg`` (TEST_MODE, DATA_MODE,
    REFINE, DO_FLIP, JSON_SUFFIX_NAME, TEST.IMG_PER_GPU). A data loader is
    then built:
      * ``run_inference``: a CustomDataset over ``--dataset_path`` wrapped in a
        plain DataLoader of ``--batch_size``;
      * otherwise: ``get_test_loader`` on the ``--data_mode`` stage.

    The SMAP checkpoint (and the RefineNet checkpoint when ``--RefineNet_path``
    is non-empty) are loaded. ``generate_3d_point_pairs`` then writes its
    output under ``cfg.OUTPUT_DIR/result``. A missing checkpoint is logged and
    the function returns without running the pipeline.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test_mode", "-t", type=str, default="run_inference",
        choices=['generate_train', 'generate_result', 'run_inference'],
        help='Type of test. One of "generate_train": generate refineNet datasets, '
             '"generate_result": save inference result and groundtruth, '
             '"run_inference": save inference result for input images.')
    parser.add_argument(
        "--data_mode", "-d", type=str, default="test",
        choices=['test', 'generation'],
        help='Only used for "generate_train" test_mode, "generation" for refineNet train dataset,'
             '"test" for refineNet test dataset.')
    parser.add_argument("--SMAP_path", "-p", type=str, default='log/SMAP.pth',
                        help='Path to SMAP model')
    parser.add_argument(
        "--RefineNet_path", "-rp", type=str, default='',
        help='Path to RefineNet model, empty means without RefineNet')
    parser.add_argument("--batch_size", type=int, default=1,
                        help='Batch_size of test')
    parser.add_argument("--do_flip", type=float, default=0,
                        help='Set to 1 if do flip when test')
    parser.add_argument("--dataset_path", type=str, default="",
                        help='Image dir path of "run_inference" test mode')
    parser.add_argument("--json_name", type=str, default="",
                        help='Add a suffix to the result json.')
    args = parser.parse_args()

    cfg.TEST_MODE = args.test_mode
    cfg.DATA_MODE = args.data_mode
    # RefineNet is enabled exactly when a checkpoint path was supplied.
    cfg.REFINE = len(args.RefineNet_path) > 0
    cfg.DO_FLIP = args.do_flip
    cfg.JSON_SUFFIX_NAME = args.json_name
    cfg.TEST.IMG_PER_GPU = args.batch_size

    os.makedirs(cfg.TEST_DIR, exist_ok=True)
    logger = get_logger(cfg.DATASET.NAME, cfg.TEST_DIR, 0,
                        'test_log_{}.txt'.format(args.test_mode))

    model = SMAP(cfg, run_efficient=cfg.RUN_EFFICIENT)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if args.test_mode == "run_inference":
        test_dataset = CustomDataset(cfg, args.dataset_path)
        data_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                                 shuffle=False)
    else:
        data_loader = get_test_loader(cfg, num_gpu=1, local_rank=0,
                                      stage=args.data_mode)

    if cfg.REFINE:
        refine_model = RefineNet()
        refine_model.to(device)
        refine_model_file = args.RefineNet_path
    else:
        refine_model = None
        refine_model_file = ""

    model_file = args.SMAP_path
    if not os.path.exists(model_file):
        logger.info("No such checkpoint of SMAP {}".format(args.SMAP_path))
        return

    # Map every tensor to CPU storage so checkpoints saved on GPU load on any
    # machine; load_state_dict then copies onto the model's device.
    state_dict = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    state_dict = state_dict['model']
    model.load_state_dict(state_dict)

    if refine_model is not None:
        if not os.path.exists(refine_model_file):
            logger.info("No such RefineNet checkpoint of {}".format(
                args.RefineNet_path))
            return
        # Same CPU mapping as above; previously this load had no map_location
        # and failed for GPU-saved checkpoints on CPU-only hosts.
        refine_model.load_state_dict(
            torch.load(refine_model_file,
                       map_location=lambda storage, loc: storage))

    generate_3d_point_pairs(model, refine_model, data_loader, cfg, logger,
                            device,
                            output_dir=os.path.join(cfg.OUTPUT_DIR, "result"))