def inference(config_file, coco_to_kitti_dict, test_set="kitti_mots_test",
              output_dir="./output/"):
    """Evaluate a model-zoo checkpoint on a KITTI-MOTS split with filtered predictions.

    Builds a config from the detectron2 model zoo entry ``config_file``, loads
    its pretrained checkpoint, runs inference on ``test_set`` with a
    ``COCOEvaluator``, post-processes the raw predictions with
    ``filter_preds(preds, coco_to_kitti_dict)`` and re-runs the COCO
    evaluation on the filtered predictions.

    NOTE(review): another function named ``inference`` is defined later in this
    file and shadows this one at import time; rename one of them if both are
    meant to be callable.

    Args:
        config_file: model-zoo relative config path (passed to
            ``model_zoo.get_config_file`` / ``model_zoo.get_checkpoint_url``).
        coco_to_kitti_dict: mapping forwarded to ``filter_preds``; presumably
            COCO category ids -> KITTI-MOTS ids (TODO confirm against
            ``filter_preds``).
        test_set: registered dataset name to evaluate on.
        output_dir: directory the ``COCOEvaluator`` writes its files to.

    Returns:
        The result of ``evaluator.evaluate()`` on the filtered predictions
        (previously computed and discarded).
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_file))
    cfg.DATALOADER.NUM_WORKERS = 4
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_file)
    # DefaultTrainer builds a training loader on construction, so a registered
    # train split is required even though no training happens here.
    cfg.DATASETS.TRAIN = ("kitti_mots_train", )
    cfg.DATASETS.TEST = (test_set, )
    cfg.SOLVER.IMS_PER_BATCH = 8
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # The trainer is only used as a convenient way to build and load the model.
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)

    evaluator = COCOEvaluator(test_set, cfg, False, output_dir=output_dir)
    val_loader = build_detection_test_loader(cfg, test_set)
    # inference_on_dataset already calls evaluator.evaluate() once on the
    # unfiltered predictions; that first result is intentionally ignored.
    inference_on_dataset(trainer.model, val_loader, evaluator)

    # Relies on COCOEvaluator's private ``_predictions`` buffer surviving
    # evaluate(); swap in the filtered predictions and evaluate again.
    preds = evaluator._predictions
    evaluator._predictions = filter_preds(preds, coco_to_kitti_dict)
    return evaluator.evaluate()
def inference(config_file, correspondences, test_set="mots_challenge_train",
              model_dir="../week4/output/r50_fpn_cityscapes/"):
    """Evaluate and visualise a fine-tuned 2-class model on a MOTS dataset.

    The architecture comes from the model-zoo config ``config_file``. The
    weights come from ``<model_dir>/model_final.pth``. The function runs COCO
    evaluation on ``test_set``, filters the raw predictions with
    ``filter_preds(preds, correspondences)`` and re-evaluates. It then draws
    predictions on the dataset via ``show_results``.

    Args:
        config_file: model-zoo relative config path providing the base config.
        correspondences: mapping forwarded to ``filter_preds``; presumably
            predicted class ids -> dataset class ids (TODO confirm against
            ``filter_preds``).
        test_set: registered dataset name to evaluate on (the author also
            used ``'kitti_mots_test'``).
        model_dir: directory holding ``model_final.pth``; also used as
            ``cfg.OUTPUT_DIR``.

    Returns:
        The result of ``evaluator.evaluate()`` on the filtered predictions
        (previously computed and discarded).
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_file))
    cfg.DATALOADER.NUM_WORKERS = 4
    # DefaultTrainer builds a training loader on construction, so TRAIN must
    # name a registered dataset even though no training happens here.
    cfg.DATASETS.TRAIN = (test_set, )
    cfg.DATASETS.TEST = (test_set, )
    cfg.SOLVER.IMS_PER_BATCH = 16
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
    cfg.OUTPUT_DIR = model_dir
    # Single assignment: the model-zoo checkpoint URL was previously set and
    # immediately overwritten here.
    cfg.MODEL.WEIGHTS = os.path.join(model_dir, "model_final.pth")
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    trainer = DefaultTrainer(cfg)
    # resume=False loads exactly cfg.MODEL.WEIGHTS. DefaultPredictor below loads
    # the same file, so the evaluated and visualised models match.
    # (resume=True could pick up a different ``last_checkpoint`` in OUTPUT_DIR.)
    trainer.resume_or_load(resume=False)

    evaluator = COCOEvaluator(test_set, cfg, False, output_dir="./output/")
    print(evaluator._metadata.get("thing_classes"))
    val_loader = build_detection_test_loader(cfg, test_set)
    # inference_on_dataset already calls evaluator.evaluate() once on the
    # unfiltered predictions; that first result is intentionally ignored.
    inference_on_dataset(trainer.model, val_loader, evaluator)

    # Relies on COCOEvaluator's private ``_predictions`` buffer surviving
    # evaluate(); swap in the filtered predictions and evaluate again.
    preds = evaluator._predictions
    evaluator._predictions = filter_preds(preds, correspondences)
    results = evaluator.evaluate()

    predictor = DefaultPredictor(cfg)
    motschallenge = DatasetCatalog.get(test_set)
    show_results(cfg, motschallenge, predictor)
    return results