コード例 #1
0
def main():
    """Evaluate a single checkpoint on each requested dataset.

    The checkpoint named by ``args.checkpoint`` is resolved under
    ``cfg.INTERACTIVE_MODELS_PATH`` and loaded once. For every
    comma-separated name in ``args.datasets``, a predictor is built around
    that model, ``evaluate_dataset`` is run, and the results are passed to
    ``save_results``. Output goes under ``args.logs_path /
    get_eval_exp_name(args)``, which is created if it does not exist.
    Only the first dataset's ``save_results`` call uses ``print_header=True``.
    """
    args, cfg = parse_args()

    model = utils.load_is_model(
        utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, args.checkpoint),
        args.device,
        num_max_clicks=args.n_clicks,
    )

    eval_exp_path = args.logs_path / get_eval_exp_name(args)
    eval_exp_path.mkdir(parents=True, exist_ok=True)

    for index, name in enumerate(args.datasets.split(',')):
        dataset = utils.get_dataset(name, cfg)

        # DAVIS uses a larger zoom-in target size than every other dataset.
        if name == 'DAVIS':
            target_size = 600
        else:
            target_size = 400

        predictor = get_predictor(
            model,
            args.mode,
            args.device,
            prob_thresh=args.thresh,
            predictor_params={'num_max_points': args.n_clicks},
            zoom_in_params={'target_size': target_size},
        )

        results = evaluate_dataset(
            dataset,
            predictor,
            pred_thr=args.thresh,
            max_iou_thr=args.target_iou,
            max_clicks=args.n_clicks,
        )

        # print_header is True only for the first dataset processed.
        save_results(args, name, eval_exp_path, results,
                     print_header=index == 0)
コード例 #2
0
def main():
    """Evaluate one or more checkpoints on each requested dataset.

    ``get_checkpoints_list_and_logs_path`` supplies the checkpoints to
    evaluate and the log location. For every comma-separated name in
    ``args.datasets``, and for every checkpoint, the model is loaded, a
    predictor is built, ``evaluate_dataset`` is run, and the results are
    handed to ``save_results`` (plus ``save_iou_analysis_data`` when
    ``args.iou_analysis`` is set).

    Raises:
        ValueError: if IoU analysis is requested while more than one
            checkpoint is being evaluated.
    """
    args, cfg = parse_args()

    checkpoints_list, logs_path, logs_prefix = get_checkpoints_list_and_logs_path(
        args, cfg)

    single_model_eval = len(checkpoints_list) == 1
    # Validate before creating the log directory so a rejected invocation
    # leaves nothing behind. An explicit raise (instead of ``assert``) keeps
    # the check active when Python runs with -O.
    if args.iou_analysis and not single_model_eval:
        raise ValueError("Can't perform IoU analysis for multiple checkpoints")

    logs_path.mkdir(parents=True, exist_ok=True)

    # In multi-checkpoint mode this starts False, so save_results is never
    # asked to print a header here; it also receives single_model_eval.
    print_header = single_model_eval
    for dataset_name in args.datasets.split(','):
        dataset = utils.get_dataset(dataset_name, cfg)

        for checkpoint_path in checkpoints_list:
            model = utils.load_is_model(checkpoint_path, args.device)

            predictor_params, zoomin_params = get_predictor_and_zoomin_params(
                args, dataset_name)
            predictor = get_predictor(model,
                                      args.mode,
                                      args.device,
                                      prob_thresh=args.thresh,
                                      predictor_params=predictor_params,
                                      zoom_in_params=zoomin_params)

            # Prediction visualisation is optional (--vis-preds style flag).
            vis_callback = get_prediction_vis_callback(
                logs_path, dataset_name,
                args.thresh) if args.vis_preds else None
            dataset_results = evaluate_dataset(dataset,
                                               predictor,
                                               pred_thr=args.thresh,
                                               max_iou_thr=args.target_iou,
                                               min_clicks=args.min_n_clicks,
                                               max_clicks=args.n_clicks,
                                               callback=vis_callback)

            # A single-model run is labelled by predictor mode; a
            # multi-checkpoint run is labelled by each checkpoint's file stem.
            row_name = args.mode if single_model_eval else checkpoint_path.stem
            if args.iou_analysis:
                save_iou_analysis_data(args,
                                       dataset_name,
                                       logs_path,
                                       logs_prefix,
                                       dataset_results,
                                       model_name=args.model_name)

            save_results(args,
                         row_name,
                         dataset_name,
                         logs_path,
                         logs_prefix,
                         dataset_results,
                         # Per-sample IoUs are saved only in single-model mode.
                         save_ious=single_model_eval and args.save_ious,
                         single_model_eval=single_model_eval,
                         print_header=print_header)
            print_header = False