def main():
    """Evaluate a single interactive-segmentation checkpoint on several datasets.

    The datasets come from the comma-separated ``args.datasets`` string. For each
    dataset a predictor is built with a zoom-in target size of 600 for 'DAVIS'
    and 400 otherwise. ``evaluate_dataset`` is run, and its results are handed
    to ``save_results``. Only the first dataset's ``save_results`` call is asked
    to print the header.
    """
    args, cfg = parse_args()

    model = utils.load_is_model(
        utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, args.checkpoint),
        args.device,
        num_max_clicks=args.n_clicks,
    )

    results_dir = args.logs_path / get_eval_exp_name(args)
    results_dir.mkdir(parents=True, exist_ok=True)

    for idx, name in enumerate(args.datasets.split(',')):
        dataset = utils.get_dataset(name, cfg)

        # DAVIS gets a larger zoom-in crop than the other benchmarks.
        target_size = 600 if name == 'DAVIS' else 400
        predictor = get_predictor(
            model,
            args.mode,
            args.device,
            prob_thresh=args.thresh,
            predictor_params={'num_max_points': args.n_clicks},
            zoom_in_params={'target_size': target_size},
        )

        results = evaluate_dataset(
            dataset,
            predictor,
            pred_thr=args.thresh,
            max_iou_thr=args.target_iou,
            max_clicks=args.n_clicks,
        )

        # Only the very first dataset writes the table header.
        save_results(args, name, results_dir, results, print_header=(idx == 0))
def main():
    """Evaluate one or more checkpoints on each dataset in ``args.datasets``.

    For every (dataset, checkpoint) pair the checkpoint is loaded, a predictor is
    built from ``get_predictor_and_zoomin_params`` and ``evaluate_dataset`` is run.
    When ``args.vis_preds`` is set, a visualisation callback is attached.
    The results are passed to ``save_results``. IoU-analysis data is also saved
    when ``args.iou_analysis`` is set.

    With a single checkpoint, the result row is named after ``args.mode`` and
    the header is printed once. With several checkpoints, each row is named
    after the checkpoint file stem and no header is printed.

    Raises:
        ValueError: if ``args.iou_analysis`` is requested while more than one
            checkpoint is being evaluated.
    """
    args, cfg = parse_args()

    checkpoints_list, logs_path, logs_prefix = get_checkpoints_list_and_logs_path(
        args, cfg)

    single_model_eval = len(checkpoints_list) == 1
    # Validate before touching the filesystem, and use an explicit raise:
    # an `assert` is stripped under `python -O`, which would let IoU analysis
    # run (and be mixed up) across multiple checkpoints.
    if args.iou_analysis and not single_model_eval:
        raise ValueError("Can't perform IoU analysis for multiple checkpoints")

    logs_path.mkdir(parents=True, exist_ok=True)

    print_header = single_model_eval
    for dataset_name in args.datasets.split(','):
        dataset = utils.get_dataset(dataset_name, cfg)

        for checkpoint_path in checkpoints_list:
            # NOTE(review): the model is reloaded for every dataset; swapping the
            # loops would avoid that but would change the order results are saved.
            model = utils.load_is_model(checkpoint_path, args.device)

            predictor_params, zoomin_params = get_predictor_and_zoomin_params(
                args, dataset_name)
            predictor = get_predictor(model,
                                      args.mode,
                                      args.device,
                                      prob_thresh=args.thresh,
                                      predictor_params=predictor_params,
                                      zoom_in_params=zoomin_params)

            vis_callback = get_prediction_vis_callback(
                logs_path, dataset_name,
                args.thresh) if args.vis_preds else None
            dataset_results = evaluate_dataset(dataset,
                                               predictor,
                                               pred_thr=args.thresh,
                                               max_iou_thr=args.target_iou,
                                               min_clicks=args.min_n_clicks,
                                               max_clicks=args.n_clicks,
                                               callback=vis_callback)

            # A lone checkpoint is labelled by mode; several by their file stems.
            row_name = args.mode if single_model_eval else checkpoint_path.stem
            if args.iou_analysis:
                save_iou_analysis_data(args,
                                       dataset_name,
                                       logs_path,
                                       logs_prefix,
                                       dataset_results,
                                       model_name=args.model_name)

            save_results(args,
                         row_name,
                         dataset_name,
                         logs_path,
                         logs_prefix,
                         dataset_results,
                         save_ious=single_model_eval and args.save_ious,
                         single_model_eval=single_model_eval,
                         print_header=print_header)
            print_header = False
Ejemplo n.º 3
0
def main():
    """Replay scripted clicks on one image and save the masked cut-out.

    Loads the checkpoint ``args.checkpoint`` and builds a predictor in
    ``args.mode``. It then reads clicks from ``clicks.json`` in the current
    working directory. That file is expected to hold
    ``{"clicks": [{"is_positive": ..., "x": ..., "y": ...}, ...]}``.

    Each click is fed to the predictor in turn. The last prediction is
    thresholded at 0.5 to form a mask, which is optionally post-processed by
    ``structural_integrity_strategy`` when ``args.sis`` is set. The image is
    written to ``args.outpath`` with pixels outside the mask zeroed. If there
    are no clicks, the whole image is kept.

    Raises:
        FileNotFoundError: if ``args.imgpath`` cannot be read as an image.
    """
    args, cfg = parse_args()

    # get model
    torch.backends.cudnn.deterministic = True
    checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH,
                                            args.checkpoint)
    model = utils.load_is_model(checkpoint_path,
                                args.device,
                                cpu_dist_maps=True,
                                norm_radius=args.norm_radius)
    predictor = get_predictor(model.to(args.device),
                              device=args.device,
                              brs_mode=args.mode)
    clicker = clicker_.Clicker()
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225])
    ])

    # get image
    image_bgr = cv2.imread(args.imgpath)
    if image_bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {args.imgpath}")
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    image_nd = input_transform(image).to(args.device)
    predictor.set_input_image(image_nd)

    # Close the clicks file deterministically; JSON is UTF-8 by specification.
    with open('clicks.json', 'r', encoding='utf-8') as clicks_file:
        clicks = json.load(clicks_file)['clicks']

    probs_history = []
    for c in clicks:
        click = clicker_.Click(is_positive=c['is_positive'],
                               coords=(c['y'], c['x']))
        clicker.add_click(click)
        pred = predictor.get_prediction(clicker)
        torch.cuda.empty_cache()

        # The first tuple element is the zeros array created on the first click
        # and is copied forward unchanged; the second is the latest prediction.
        if probs_history:
            probs_history.append((probs_history[-1][0], pred))
        else:
            probs_history.append((np.zeros_like(pred), pred))

    if probs_history:
        current_prob_total, current_prob_additive = probs_history[-1]
        final_pred = np.maximum(current_prob_total, current_prob_additive)
        final_mask = (final_pred > 0.5).astype(np.uint8)
    else:
        # No clicks: keep the whole image. The previous code used
        # np.ones_like(pred) here, but `pred` is never bound when there are no
        # clicks (NameError). The mask must broadcast against the (H, W, 3)
        # image below, so size it from the image.
        final_mask = np.ones(image.shape[:2], dtype=np.uint8)

    if args.sis:
        final_mask = structural_integrity_strategy(final_mask, clicker)

    cv2.imwrite(
        args.outpath,
        cv2.cvtColor(image, cv2.COLOR_RGB2BGR) *
        np.expand_dims(final_mask, axis=2))
def main():
    """Open the Tkinter interactive segmentation demo for the selected checkpoint.

    The model is loaded on ``args.device`` with ``cpu_dist_maps=True`` and the
    ``args.norm_radius`` click radius, then handed to ``InteractiveDemoApp``,
    whose ``mainloop`` blocks until the window is closed.
    """
    args, cfg = parse_args()
    torch.backends.cudnn.deterministic = True

    model = utils.load_is_model(
        utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, args.checkpoint),
        args.device,
        cpu_dist_maps=True,
        norm_radius=args.norm_radius,
    )

    window = tk.Tk()
    window.minsize(960, 480)
    demo = InteractiveDemoApp(window, args, model)
    window.deiconify()
    demo.mainloop()
Ejemplo n.º 5
0
    def __init__(self):
        """Load the ``coco_lvis_h18_itermask.pth`` checkpoint for inference.

        The checkpoint is looked up under ``$MODEL_PATH`` (default
        ``/opt/nuclio/hrnet``, made absolute). It is loaded onto CUDA when
        ``torch.cuda.is_available()`` and onto the CPU otherwise. Sets
        ``self.device`` and ``self.net``.
        """
        torch.backends.cudnn.deterministic = True
        # A single-argument os.path.join is a no-op, so the absolute path is used directly.
        model_dir = os.path.abspath(os.environ.get("MODEL_PATH", "/opt/nuclio/hrnet"))

        self.net = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        weights_path = utils.find_checkpoint(model_dir, "coco_lvis_h18_itermask.pth")
        self.net = utils.load_is_model(weights_path, self.device)
Ejemplo n.º 6
0
def main():
    """Run ``BetterThanApp.AppReplacement`` on one image with three fixed clicks.

    The image path is taken from ``args.imgpath`` when ``parse_args`` provides
    one. Otherwise it falls back to the original hard-coded developer path.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    args, cfg = parse_args()

    torch.backends.cudnn.deterministic = True
    checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, args.checkpoint)
    model = utils.load_is_model(checkpoint_path, args.device, cpu_dist_maps=True, norm_radius=args.norm_radius)

    # Click i is at (x_coords[i], y_coords[i]); is_pos[i] presumably marks a
    # positive (1) or negative (0) click — confirm against AppReplacement.
    x_coords = (300, 300, 150)
    y_coords = (300, 400, 300)
    is_pos = (1, 1, 0)

    # The fallback path only exists on the original author's machine; prefer a
    # caller-supplied path when the argument parser defines one.
    img_pth = getattr(args, 'imgpath', None) or '/Users/jason/Desktop/img_1205.jpg'
    image_bgr = cv2.imread(img_pth)
    if image_bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_pth}")
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    print(type(image))
    print(sys.getsizeof(image))

    BetterThanApp.AppReplacement(image, args, model, x_coords, y_coords, is_pos)
# Load the experiment config. return_edict=True presumably yields an attribute-
# accessible dict (cfg.INTERACTIVE_MODELS_PATH is read below) — confirm in exp.
cfg = exp.load_config_file('./config.yml', return_edict=True)

## init dataset
# Possible choices: 'GrabCut', 'Berkeley', 'DAVIS', 'COCO_MVal', 'SBD'
DATASET = 'Berkeley'
dataset = utils.get_dataset(DATASET, cfg)

## init model
# NOTE(review): mid-script import; consider moving it to the top-of-file imports.
from isegm.inference.predictors import get_predictor

# Click budget for evaluation — not used in the visible lines; presumably passed
# to the evaluate_sample call that follows.
EVAL_MAX_CLICKS = 20
# Probability threshold handed to the predictor as prob_thresh.
MODEL_THRESH = 0.49

checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH,
                                        'coco_lvis_h18_itermask')
# NOTE(review): `device` is not defined in this view; it must be set earlier in the file.
model = utils.load_is_model(checkpoint_path, device)

# Possible choices: 'NoBRS', 'f-BRS-A', 'f-BRS-B', 'f-BRS-C', 'RGB-BRS', 'DistMap-BRS'
brs_mode = 'f-BRS-B'
predictor = get_predictor(model, brs_mode, device, prob_thresh=MODEL_THRESH)

## single sample test
# Draw one sample index uniformly at random. No seed is set in the visible
# lines, so the chosen sample is not reproducible unless `random` is seeded
# elsewhere.
sample_id = random.sample(range(len(dataset)), 1)[0]
TARGET_IOU = 0.95

sample = dataset.get_sample(sample_id)
gt_mask = sample.gt_mask
clicks_list, ious_arr, pred = evaluate_sample(sample.image,
                                              gt_mask,
                                              predictor,