예제 #1
0
파일: bench.py 프로젝트: maxizi/THOR
def run_bench(delete_after=False, args_ext=None):
    """Run a tracker over a benchmark dataset and evaluate the results.

    Args:
        delete_after: forwarded unchanged to the dataset's evaluation
            function (its exact meaning is defined there).
        args_ext: pre-built argument namespace. When ``None`` the
            module-level ``parser`` is used to parse the command line.

    Returns:
        The dict returned by the evaluation function, with an extra
        ``'mean_fps'`` entry: the mean of the per-video speeds returned by
        the test function.

    Raises:
        ValueError: if ``args.tracker`` is not a supported tracker.
        NotImplementedError: if ``args.dataset`` has no test/eval procedure.
    """
    # Identity check: `None` is a sentinel, so `is` is the correct test.
    args = parser.parse_args() if args_ext is None else args_ext

    cfg = load_cfg(args)
    cfg['THOR']['viz'] = args.viz
    cfg['THOR']['verbose'] = args.verbose

    # setup tracker
    if args.tracker == 'SiamFC':
        tracker = SiamFC_Tracker(cfg)
    elif args.tracker == 'SiamRPN':
        tracker = SiamRPN_Tracker(cfg)
    elif args.tracker == 'SiamMask':
        tracker = SiamMask_Tracker(cfg)
    else:
        raise ValueError(f"Tracker {args.tracker} does not exist.")

    dataset = load_dataset(args.dataset)
    # optionally restrict the run to a single video
    if args.spec_video:
        dataset = {args.spec_video: dataset[args.spec_video]}

    if args.dataset == "VOT2018":
        test_bench, eval_bench = test_vot, eval_vot
    elif args.dataset == "OTB2015":
        test_bench, eval_bench = test_otb, eval_otb
    else:
        raise NotImplementedError(
            f"Procedure for {args.dataset} does not exist.")

    # testing: collect one speed value per video
    speed_list = []
    for v_id, video in enumerate(dataset.values(), start=1):
        # Set the tracker memory's full-init flag before each new video
        # (presumably forces a complete re-initialisation — confirm in temp_mem).
        tracker.temp_mem.do_full_init = True
        speed_list.append(test_bench(v_id, tracker, video, args))

    # evaluation
    mean_fps = np.mean(np.array(speed_list))
    bench_res = eval_bench(args.save_path, delete_after)
    bench_res['mean_fps'] = mean_fps
    print(bench_res)

    return bench_res
예제 #2
0
def load_cfg(args):
    """Load the JSON tracker configuration selected by the CLI arguments.

    The path is ``configs/<tracker>/VOT2018_vanilla.json`` when
    ``args.vanilla`` is truthy. Otherwise it is
    ``configs/<tracker>/VOT2018_THOR_<lb_type>.json``. Both paths are
    relative to the current working directory.

    Args:
        args: namespace providing ``tracker`` and ``vanilla``, plus
            ``lb_type`` when ``vanilla`` is falsy.

    Returns:
        The parsed JSON document.
    """
    suffix = "vanilla.json" if args.vanilla else f"THOR_{args.lb_type}.json"
    json_path = f"configs/{args.tracker}/VOT2018_{suffix}"
    # A context manager closes the handle deterministically. The previous
    # `json.load(open(...))` left the file open until garbage collection.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)


if __name__ == '__main__':
    # Parse the command line, load the matching JSON config and let the
    # viz/verbose switches from the CLI override the file's THOR settings.
    args = parser.parse_args()
    cfg = load_cfg(args)
    thor_section = cfg['THOR']
    thor_section['viz'] = args.viz
    thor_section['verbose'] = args.verbose

    print("[INFO] Initializing the tracker.")
    # Reject unknown tracker names up front. A SiamRPN_PP option
    # (SiamRPN_PP_Tracker) exists upstream but is currently disabled.
    if args.tracker not in ('SiamFC', 'SiamRPN', 'SiamMask'):
        raise ValueError(f"Tracker {args.tracker} does not exist.")
    if args.tracker == 'SiamFC':
        tracker = SiamFC_Tracker(cfg)
    elif args.tracker == 'SiamRPN':
        tracker = SiamRPN_Tracker(cfg)
    else:
        tracker = SiamMask_Tracker(cfg)

    print("[INFO] Starting video stream.")
    show_webcam(tracker, mirror=True, viz=args.viz)
예제 #3
0
def run_bench(delete_after=False, args_ext=None):
    """Run a tracker over a benchmark dataset and, where possible, evaluate it.

    Args:
        delete_after: forwarded unchanged to the dataset's evaluation
            function (its exact meaning is defined there).
        args_ext: pre-built argument namespace. When ``None`` the
            module-level ``parser`` is used to parse the command line.

    Returns:
        For datasets with an offline evaluator, the dict returned by that
        evaluator plus a ``'mean_fps'`` entry (mean of per-video speeds).
        For datasets without one (GOT10k, GOT10k_train_val, OXUVA), ``None``
        after printing a notice.

    Raises:
        ValueError: if ``args.tracker`` is not a supported tracker.
        NotImplementedError: if ``args.dataset`` has no test procedure.
    """
    args = parser.parse_args() if args_ext is None else args_ext

    cfg = load_cfg(args)
    cfg['THOR']['viz'] = args.viz
    cfg['THOR']['verbose'] = args.verbose

    # setup tracker
    if args.tracker == 'SiamFC':
        tracker = SiamFC_Tracker(cfg)
    elif args.tracker == 'SiamRPN':
        tracker = SiamRPN_Tracker(cfg)
    elif args.tracker == 'SiamMask':
        tracker = SiamMask_Tracker(cfg)
    else:
        raise ValueError(f"Tracker {args.tracker} does not exist.")

    dataset = load_dataset(args.dataset)
    # For every dataset except OTB2015 the loader's result is parsed with
    # ast.literal_eval (assumed to be the text repr of a dict — confirm in
    # load_dataset). This must happen BEFORE the spec_video filter below;
    # otherwise the filter would index into a string.
    if args.dataset == "OTB2015":
        print("==>> No processing for the json file ... ")
    elif isinstance(dataset, str):
        dataset = ast.literal_eval(dataset)

    # optionally restrict the run to a single video
    if args.spec_video:
        dataset = {args.spec_video: dataset[args.spec_video]}

    # Select the test routine and, where one exists, the offline evaluator.
    # eval_bench stays None for datasets that have no local evaluation.
    eval_bench = None
    if args.dataset == "VOT2018":
        test_bench, eval_bench = test_vot, eval_vot
    elif args.dataset == "OTB2015":
        test_bench, eval_bench = test_otb, eval_otb
    elif args.dataset == "GOT10k":
        test_bench = test_got
    elif args.dataset == "GOT10k_train_val":
        test_bench = test_gottrainval
    elif args.dataset == "LaSOT":
        test_bench, eval_bench = test_lasot, eval_lasot
    elif args.dataset == "UAV20L":
        test_bench, eval_bench = test_uav20l, eval_uav20l
    elif args.dataset == "UAV123":
        test_bench, eval_bench = test_uav123, eval_uav123
    elif args.dataset == "OXUVA":
        test_bench = test_oxuva
    elif args.dataset == "TC128":
        test_bench, eval_bench = test_tc128, eval_tc128
    else:
        raise NotImplementedError(
            f"Procedure for {args.dataset} does not exist.")

    # testing: collect one speed value per video
    speed_list = []
    for v_id, video in enumerate(dataset.values(), start=1):
        # Set the tracker memory's full-init flag before each new video
        # (presumably forces a complete re-initialisation — confirm in temp_mem).
        tracker.temp_mem.do_full_init = True
        speed_list.append(test_bench(v_id, tracker, video, args))

    if eval_bench is None:
        # No local evaluator: previously this path reached eval_bench(...)
        # unbound for OXUVA and GOT10k_train_val (UnboundLocalError).
        print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
        if args.dataset == "GOT10k":
            print("==>> Please evaluate online for GOT10k dataset ... ")
        elif args.dataset == "OXUVA":
            print("==>> Please evaluate online for OxUvA dataset ... ")
        else:
            print(f"==>> No offline evaluation procedure for {args.dataset} ... ")
        print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
        return None

    # evaluation
    bench_res = eval_bench(args.save_path, delete_after)
    bench_res['mean_fps'] = np.mean(np.array(speed_list))
    print(bench_res)

    return bench_res