Example #1
import os

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator, inference_on_dataset


def inference(config_file, coco_to_kitti_dict):
    # Build the config from a model-zoo config file and its pretrained weights.
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_file))
    cfg.DATALOADER.NUM_WORKERS = 4
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_file)
    cfg.DATASETS.TRAIN = ("kitti_mots_train",)
    cfg.DATASETS.TEST = ("kitti_mots_test",)
    cfg.SOLVER.IMS_PER_BATCH = 8
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)

    # Run COCO-style evaluation on the KITTI-MOTS test split.
    evaluator = COCOEvaluator("kitti_mots_test",
                              cfg,
                              False,
                              output_dir="./output/")
    val_loader = build_detection_test_loader(cfg, "kitti_mots_test")
    inference_on_dataset(trainer.model, val_loader, evaluator)

    # Filter/remap the raw predictions with the project-specific filter_preds
    # helper (defined elsewhere), then recompute the metrics on the filtered
    # set. Note that _predictions is an internal attribute of COCOEvaluator.
    preds = evaluator._predictions
    filtered_preds = filter_preds(preds, coco_to_kitti_dict)
    evaluator._predictions = filtered_preds

    evaluator.evaluate()
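
This example (and Example #2 below) post-processes the evaluator's buffered predictions through a project-specific filter_preds helper that is not shown on this page. A minimal sketch of what such a helper could look like, assuming the predictions follow COCOEvaluator's internal layout (a list of {"image_id", "instances"} dicts whose instances are COCO-format records carrying a "category_id") and assuming the mapping's keys are the COCO class ids to keep; the real project helper may differ:

def filter_preds(predictions, class_mapping):
    """Hypothetical filter: keep detections whose COCO class id appears in
    class_mapping and rewrite their category_id to the mapped dataset id."""
    filtered = []
    for per_image in predictions:
        kept = []
        for instance in per_image["instances"]:
            coco_id = instance["category_id"]
            if coco_id in class_mapping:
                kept.append(dict(instance, category_id=class_mapping[coco_id]))
        filtered.append({"image_id": per_image["image_id"], "instances": kept})
    return filtered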
Example #2
import os

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, build_detection_test_loader
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.evaluation import COCOEvaluator, inference_on_dataset


def inference(config_file, correspondences):

    # test_set = 'kitti_mots_test'
    test_set = 'mots_challenge_train'

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_file))
    cfg.DATALOADER.NUM_WORKERS = 4
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_file)
    cfg.DATASETS.TRAIN = (test_set,)
    cfg.DATASETS.TEST = (test_set,)
    cfg.SOLVER.IMS_PER_BATCH = 16
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
    cfg.OUTPUT_DIR = "../week4/output/r50_fpn_cityscapes/"
    # Override the model-zoo weights with the locally fine-tuned checkpoint.
    cfg.MODEL.WEIGHTS = "../week4/output/r50_fpn_cityscapes/model_final.pth"

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=True)

    evaluator = COCOEvaluator(test_set, cfg, False, output_dir="./output/")
    print(evaluator._metadata.get("thing_classes"))
    val_loader = build_detection_test_loader(cfg, test_set)
    inference_on_dataset(trainer.model, val_loader, evaluator)

    # Filter/remap the raw predictions with the project-specific filter_preds
    # helper (defined elsewhere), then recompute the metrics on the filtered
    # set. Note that _predictions is an internal attribute of COCOEvaluator.
    preds = evaluator._predictions
    filtered_preds = filter_preds(preds, correspondences)
    evaluator._predictions = filtered_preds

    evaluator.evaluate()

    # Visualize qualitative results with the project's show_results helper
    # (defined elsewhere).
    predictor = DefaultPredictor(cfg)
    motschallenge = DatasetCatalog.get(test_set)
    show_results(cfg, motschallenge, predictor)
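
The show_results call at the end relies on another project helper that is not included here. A minimal sketch of what it might do, assuming it overlays the predictor's instance predictions on a handful of dataset images with detectron2's Visualizer and writes them to the output directory (the function signature and the num_images parameter are illustrative assumptions, not the project's actual helper):

import os

import cv2
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer


def show_results(cfg, dataset_dicts, predictor, num_images=5):
    """Hypothetical visualization helper: overlay predictions on a few images."""
    metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    for d in dataset_dicts[:num_images]:
        img = cv2.imread(d["file_name"])                      # BGR image
        outputs = predictor(img)                              # forward pass
        vis = Visualizer(img[:, :, ::-1], metadata=metadata)  # Visualizer expects RGB
        out = vis.draw_instance_predictions(outputs["instances"].to("cpu"))
        out_path = os.path.join(cfg.OUTPUT_DIR, os.path.basename(d["file_name"]))
        cv2.imwrite(out_path, out.get_image()[:, :, ::-1])    # back to BGR for cv2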