def main():
    """Evaluate a single checkpoint on each dataset listed in ``args.datasets``.

    The checkpoint is resolved under ``cfg.INTERACTIVE_MODELS_PATH`` and loaded
    once. A fresh predictor is built per dataset because its zoom-in target size
    depends on the dataset (600 for 'DAVIS', 400 for anything else). Results are
    written under ``args.logs_path / get_eval_exp_name(args)``; only the first
    ``save_results`` call receives ``print_header=True``.
    """
    args, cfg = parse_args()

    model = utils.load_is_model(
        utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH, args.checkpoint),
        args.device,
        num_max_clicks=args.n_clicks,
    )

    eval_exp_path = args.logs_path / get_eval_exp_name(args)
    eval_exp_path.mkdir(parents=True, exist_ok=True)

    # Per-dataset zoom-in target size; datasets not listed here fall back to 400.
    zoom_in_sizes = {'DAVIS': 600}

    for index, dataset_name in enumerate(args.datasets.split(',')):
        dataset = utils.get_dataset(dataset_name, cfg)
        predictor = get_predictor(
            model,
            args.mode,
            args.device,
            prob_thresh=args.thresh,
            predictor_params={'num_max_points': args.n_clicks},
            zoom_in_params={'target_size': zoom_in_sizes.get(dataset_name, 400)},
        )
        results = evaluate_dataset(
            dataset,
            predictor,
            pred_thr=args.thresh,
            max_iou_thr=args.target_iou,
            max_clicks=args.n_clicks,
        )
        # Only the first dataset's save asks for the header.
        save_results(args, dataset_name, eval_exp_path, results,
                     print_header=index == 0)
 def reset_predictor(self, predictor_params=None):
     """Rebuild ``self.predictor`` from the stored predictor parameters.

     Args:
         predictor_params: optional mapping; when given it replaces
             ``self.predictor_params`` before the predictor is rebuilt.
             The stored mapping is unpacked as keyword arguments to
             ``get_predictor``.

     If ``self.image_nd`` is set (not None), it is attached to the new
     predictor via ``set_input_image``.
     """
     if predictor_params is not None:
         self.predictor_params = predictor_params

     predictor = get_predictor(self.net, device=self.device, **self.predictor_params)
     self.predictor = predictor

     if self.image_nd is None:
         return
     predictor.set_input_image(self.image_nd)
示例#3
0
    def handle(self, image, pos_points, neg_points, threshold):
        """Segment ``image`` from user clicks with a 'NoBRS' predictor.

        Args:
            image: input image; converted with ``np.array`` before being handed
                to the predictor.
            pos_points: iterable of ``(x, y)`` positive click points.
            neg_points: iterable of ``(x, y)`` negative click points.
            threshold: cut-off applied to the returned probability map.

        Returns:
            ``convert_mask_to_polygon`` applied to ``probabilities > threshold``.
        """
        image_nd = np.array(image)

        clicker = Clicker()
        # Positive clicks are registered before negative ones. The Click coords
        # are passed as (y, x), i.e. swapped relative to the input points.
        for is_positive, points in ((True, pos_points), (False, neg_points)):
            for x, y in points:
                clicker.add_click(Click(is_positive=is_positive, coords=(y, x)))

        # The predictor's own prob_thresh is fixed at 0.49; the caller's
        # ``threshold`` is applied separately to the returned probabilities.
        predictor = get_predictor(self.net, 'NoBRS', device=self.device, prob_thresh=0.49)
        predictor.set_input_image(image_nd)

        probabilities = predictor.get_prediction(clicker)
        if self.device == 'cuda':
            torch.cuda.empty_cache()
        return convert_mask_to_polygon(probabilities > threshold)
示例#4
0
    def handle(self, image, pos_points, neg_points, threshold):
        """Segment ``image`` from user clicks with an 'f-BRS-B' predictor.

        Args:
            image: input image; converted with ``ToTensor`` and normalized with
                the standard ImageNet mean/std, then moved to ``self.device``.
            pos_points: iterable of ``(x, y)`` positive click points.
            neg_points: iterable of ``(x, y)`` negative click points.
            threshold: used both as the predictor's ``prob_thresh`` and as the
                cut-off applied to the returned probability map.

        Returns:
            ``convert_mask_to_polygon`` applied to ``probabilities > threshold``.
        """
        normalize = transforms.Normalize([.485, .456, .406], [.229, .224, .225])
        input_transform = transforms.Compose([transforms.ToTensor(), normalize])
        image_nd = input_transform(image).to(self.device)

        clicker = Clicker()
        # Positive clicks are registered before negative ones. The Click coords
        # are passed as (y, x), i.e. swapped relative to the input points.
        for is_positive, points in ((True, pos_points), (False, neg_points)):
            for x, y in points:
                clicker.add_click(Click(is_positive=is_positive, coords=(y, x)))

        predictor = get_predictor(
            self.net,
            device=self.device,
            brs_mode='f-BRS-B',
            brs_opt_func_params={'min_iou_diff': 0.001},
            lbfgs_params={'maxfun': 20},
            predictor_params={'max_size': 800, 'net_clicks_limit': 8},
            prob_thresh=threshold,
            zoom_in_params={'expansion_ratio': 1.4, 'skip_clicks': 1, 'target_size': 480},
        )
        predictor.set_input_image(image_nd)

        probabilities = predictor.get_prediction(clicker)
        if self.device == 'cuda':
            torch.cuda.empty_cache()
        return convert_mask_to_polygon(probabilities > threshold)
def main():
    """Evaluate every checkpoint on every requested dataset and save the results.

    The checkpoints and the output location come from
    ``get_checkpoints_list_and_logs_path(args, cfg)``. With exactly one checkpoint,
    rows are labelled with ``args.mode``, ``args.save_ious`` is honoured, and IoU
    analysis (``args.iou_analysis``) is allowed. With several checkpoints, rows are
    labelled by the checkpoint file stem.

    Raises:
        ValueError: if IoU analysis is requested while evaluating more than one
            checkpoint.
    """
    args, cfg = parse_args()

    checkpoints_list, logs_path, logs_prefix = get_checkpoints_list_and_logs_path(
        args, cfg)
    logs_path.mkdir(parents=True, exist_ok=True)

    single_model_eval = len(checkpoints_list) == 1
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would silently let this unsupported combination run.
    if args.iou_analysis and not single_model_eval:
        raise ValueError("Can't perform IoU analysis for multiple checkpoints")

    # Only the first save_results call can request the header, and only in
    # single-model mode.
    print_header = single_model_eval
    for dataset_name in args.datasets.split(','):
        dataset = utils.get_dataset(dataset_name, cfg)

        for checkpoint_path in checkpoints_list:
            model = utils.load_is_model(checkpoint_path, args.device)

            predictor_params, zoomin_params = get_predictor_and_zoomin_params(
                args, dataset_name)
            predictor = get_predictor(model,
                                      args.mode,
                                      args.device,
                                      prob_thresh=args.thresh,
                                      predictor_params=predictor_params,
                                      zoom_in_params=zoomin_params)

            vis_callback = get_prediction_vis_callback(
                logs_path, dataset_name,
                args.thresh) if args.vis_preds else None
            dataset_results = evaluate_dataset(dataset,
                                               predictor,
                                               pred_thr=args.thresh,
                                               max_iou_thr=args.target_iou,
                                               min_clicks=args.min_n_clicks,
                                               max_clicks=args.n_clicks,
                                               callback=vis_callback)

            row_name = args.mode if single_model_eval else checkpoint_path.stem
            if args.iou_analysis:
                save_iou_analysis_data(args,
                                       dataset_name,
                                       logs_path,
                                       logs_prefix,
                                       dataset_results,
                                       model_name=args.model_name)

            save_results(args,
                         row_name,
                         dataset_name,
                         logs_path,
                         logs_prefix,
                         dataset_results,
                         save_ious=single_model_eval and args.save_ious,
                         single_model_eval=single_model_eval,
                         print_header=print_header)
            print_header = False
示例#6
0
def main():
    """Segment one image from a recorded click sequence and save the masked image.

    Loads the checkpoint named by ``args.checkpoint`` and reads the image at
    ``args.imgpath``. It then replays every click listed under the ``"clicks"``
    key of ``clicks.json``; each entry provides ``is_positive``, ``x`` and ``y``.
    The final probability map is thresholded at 0.5 and optionally
    post-processed with ``structural_integrity_strategy`` (when ``args.sis``).
    The image, with every pixel outside the mask zeroed, is written to
    ``args.outpath``.

    Raises:
        FileNotFoundError: if ``args.imgpath`` cannot be read as an image.
    """
    args, cfg = parse_args()

    # get model
    torch.backends.cudnn.deterministic = True
    checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH,
                                            args.checkpoint)
    model = utils.load_is_model(checkpoint_path,
                                args.device,
                                cpu_dist_maps=True,
                                norm_radius=args.norm_radius)
    predictor = get_predictor(model.to(args.device),
                              device=args.device,
                              brs_mode=args.mode)
    clicker = clicker_.Clicker()
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225])
    ])

    # get image
    bgr_image = cv2.imread(args.imgpath)
    if bgr_image is None:
        # cv2.imread signals a missing/unreadable file by returning None rather
        # than raising; fail here with a clear message instead of in cvtColor.
        raise FileNotFoundError(f"Could not read image: {args.imgpath}")
    image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    image_nd = input_transform(image).to(args.device)
    predictor.set_input_image(image_nd)

    # Use a context manager so the clicks file handle is closed promptly.
    with open('clicks.json', 'r') as clicks_file:
        clicks = json.load(clicks_file)['clicks']

    probs_history = []
    for c in clicks:
        click = clicker_.Click(is_positive=c['is_positive'],
                               coords=(c['y'], c['x']))
        clicker.add_click(click)
        pred = predictor.get_prediction(clicker)
        torch.cuda.empty_cache()

        # The first element of every entry stays the initial all-zero map, so
        # the final probabilities below are the latest prediction clipped at 0.
        if probs_history:
            probs_history.append((probs_history[-1][0], pred))
        else:
            probs_history.append((np.zeros_like(pred), pred))

    if probs_history:
        current_prob_total, current_prob_additive = probs_history[-1]
        final_pred = np.maximum(current_prob_total, current_prob_additive)
        final_mask = (final_pred > 0.5).astype(np.uint8)
    else:
        # No clicks were recorded, so no prediction exists (``pred`` is unbound
        # here). Keep the whole image with an (H, W) all-ones mask, which matches
        # the broadcast against the (H, W, 3) image below.
        final_mask = np.ones(image.shape[:2], dtype=np.uint8)

    if args.sis:
        final_mask = structural_integrity_strategy(final_mask, clicker)

    cv2.imwrite(
        args.outpath,
        cv2.cvtColor(image, cv2.COLOR_RGB2BGR) *
        np.expand_dims(final_mask, axis=2))
示例#7
0
    def _reset_predictor(self):
        """Rebuild ``self.predictor`` from the values currently held in ``self.state``.

        Reads the BRS mode, probability threshold, per-network click limit and
        L-BFGS iteration cap via ``.get()`` on the corresponding ``self.state``
        entries. The click limit is passed as ``None`` when the mode is 'NoBRS'.
        If ``self.state['_image_nd']`` is not None, it is set as the new
        predictor's input image.
        """
        brs_mode = self.state['brs_mode'].get()
        prob_thresh = self.state['prob_thresh'].get()
        if brs_mode == 'NoBRS':
            net_clicks_limit = None
        else:
            net_clicks_limit = self.state['predictor_params']['net_clicks_limit'].get()

        predictor_params = {
            'net_clicks_limit': net_clicks_limit,
            'max_size': self.limit_longest_size,
        }
        lbfgs_max_iters = self.state['lbfgs_max_iters'].get()

        self.predictor = get_predictor(
            self.net,
            brs_mode,
            self.device,
            prob_thresh=prob_thresh,
            zoom_in_params=self.zoomin_params,
            predictor_params=predictor_params,
            brs_opt_func_params={'min_iou_diff': 1e-3},
            lbfgs_params={'maxfun': lbfgs_max_iters},
        )

        image_nd = self.state['_image_nd']
        if image_nd is not None:
            self.predictor.set_input_image(image_nd)
# Notebook-style script: load one dataset and a pretrained checkpoint, then run
# evaluate_sample on a single randomly chosen sample. `utils`, `cfg`, `device`,
# `random` and `evaluate_sample` are used here but not defined in this chunk;
# presumably they come from earlier cells/imports -- confirm.
DATASET = 'Berkeley'
dataset = utils.get_dataset(DATASET, cfg)

## init model
from isegm.inference.predictors import get_predictor

# Passed to evaluate_sample as max_clicks.
EVAL_MAX_CLICKS = 20
# Used twice: as the predictor's prob_thresh and as evaluate_sample's pred_thr.
MODEL_THRESH = 0.49

checkpoint_path = utils.find_checkpoint(cfg.INTERACTIVE_MODELS_PATH,
                                        'coco_lvis_h18_itermask')
model = utils.load_is_model(checkpoint_path, device)

# Possible choices: 'NoBRS', 'f-BRS-A', 'f-BRS-B', 'f-BRS-C', 'RGB-BRS', 'DistMap-BRS'
brs_mode = 'f-BRS-B'
predictor = get_predictor(model, brs_mode, device, prob_thresh=MODEL_THRESH)

## single sample test
# One index drawn uniformly from the dataset. No seed is set in this chunk, so
# the chosen sample varies between runs unless `random` was seeded earlier.
sample_id = random.sample(range(len(dataset)), 1)[0]
# Passed to evaluate_sample as max_iou_thr (presumably the IoU at which clicking
# stops early -- confirm against evaluate_sample).
TARGET_IOU = 0.95

sample = dataset.get_sample(sample_id)
gt_mask = sample.gt_mask

# evaluate_sample returns three values; the names suggest the click list, the
# per-click IoUs and the final prediction -- verify against its definition.
clicks_list, ious_arr, pred = evaluate_sample(sample.image,
                                              gt_mask,
                                              predictor,
                                              pred_thr=MODEL_THRESH,
                                              max_iou_thr=TARGET_IOU,
                                              max_clicks=EVAL_MAX_CLICKS)