예제 #1
0
파일: eval.py 프로젝트: Jinming-Su/STMask
def validation(net: STMask, valid_data=False, output_metrics_file=None):
    """Run ``net`` over a validation split and dump the detections.

    Args:
        net: Model invoked as
            ``net(images, img_meta=..., ref_x=..., ref_imgs_meta=...)``.
        valid_data: When False, ``cfg.valid_sub_dataset`` is evaluated, the
            detections are written to ``args.mask_det_file`` and
            ``calc_metrics`` is run against that split's annotation file.
            When True, ``cfg.valid_dataset`` is evaluated and the detections
            are only written to a JSON file.
        output_metrics_file: Forwarded to ``calc_metrics`` as ``output_file``
            when ``valid_data`` is False. When ``valid_data`` is True it must
            be a string: the JSON path is derived from it by replacing
            ``'.txt'`` with ``'.json'``.

    Returns:
        None. A ``KeyboardInterrupt`` during evaluation is caught and only
        prints ``'Stopping...'``.
    """
    cfg.mask_proto_debug = args.mask_proto_debug
    if not valid_data:
        cfg.valid_sub_dataset.test_mode = True
        dataset = get_dataset(cfg.valid_sub_dataset)
    else:
        cfg.valid_dataset.test_mode = True
        dataset = get_dataset(cfg.valid_dataset)

    frame_times = MovingAverage()
    # Only used for the progress display below.
    dataset_size = math.ceil(len(dataset) /
                             args.batch_size) if args.max_images < 0 else min(
                                 args.max_images, len(dataset))
    progress_bar = ProgressBar(30, dataset_size)

    print()
    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  shuffle=False,
                                  collate_fn=detection_collate,
                                  pin_memory=True)
    results = []

    # The final batch may hold fewer real samples than the others. A remainder
    # of 0 means the final batch is full, NOT empty, so it must not be used as
    # the sample count (doing so dropped the whole last batch).
    num_batches = len(data_loader)
    last_batch_remainder = len(dataset) % args.batch_size

    try:
        # Main eval loop
        for it, data_batch in enumerate(data_loader):
            timer.reset()
            with timer.env('Load Data'):
                images, images_meta, ref_images, ref_images_meta = prepare_data(
                    data_batch, is_cuda=True, train_mode=False)
                pad_h, pad_w = images.size()[2:4]

            with timer.env('Network Extra'):
                preds = net(images,
                            img_meta=images_meta,
                            ref_x=ref_images,
                            ref_imgs_meta=ref_images_meta)

                # NOTE(review): the explicit remainder is kept for a partial
                # final batch because prepare_data (not visible here) may
                # return more entries than real samples -- confirm.
                if it == num_batches - 1 and last_batch_remainder > 0:
                    batch_size = last_batch_remainder
                else:
                    batch_size = images.size(0)

                for batch_id in range(batch_size):
                    cfg.preserve_aspect_ratio = True
                    preds_cur = postprocess_ytbvis(
                        preds[batch_id],
                        pad_h,
                        pad_w,
                        images_meta[batch_id],
                        score_threshold=cfg.eval_conf_thresh)
                    segm_results = bbox2result_with_id(preds_cur, cfg.classes)
                    results.append(segm_results)

            # First couple of images take longer because we're constructing the graph.
            # Since that's technically initialization, don't include those in the FPS calculations.
            if it > 1:
                # max() keeps the division safe should a batch ever be empty.
                frame_times.add(timer.total_time() / max(batch_size, 1))

            if it > 1 and frame_times.get_avg() > 0:
                fps = 1 / frame_times.get_avg()
            else:
                fps = 0
            progress = (it + 1) / dataset_size * 100
            progress_bar.set_val(it + 1)
            print(
                '\rProcessing Images  %s %6d / %6d (%5.2f%%)    %5.2f fps        '
                % (repr(progress_bar), it + 1, dataset_size, progress, fps),
                end='')

        print()
        print('Dumping detections...')

        if not valid_data:
            results2json_videoseg(dataset, results, args.mask_det_file)
            print('calculate evaluation metrics ...')
            ann_file = cfg.valid_sub_dataset.ann_file
            dt_file = args.mask_det_file
            calc_metrics(ann_file, dt_file, output_file=output_metrics_file)
        else:
            results2json_videoseg(dataset, results,
                                  output_metrics_file.replace('.txt', '.json'))

    except KeyboardInterrupt:
        print('Stopping...')
예제 #2
0
파일: eval.py 프로젝트: Jinming-Su/STMask
def evaluate(net: STMask, dataset):
    """Run ``net`` over ``dataset`` and either visualise, benchmark or dump results.

    The mode is selected through the module-level ``args``:

    * ``args.display``: every frame is rendered with ``prep_display`` and
      saved as a PNG below ``args.mask_det_file[:-12]`` (``out/<video_id>/``,
      or ``out_single/`` with one image per detection when
      ``cfg.display_mask_single`` is set).
    * otherwise: detections are post-processed and collected; with
      ``args.output_json`` they are written to ``args.mask_det_file``.
    * ``args.benchmark``: timing statistics are printed at the end.

    Args:
        net: Model invoked as
            ``net(images, img_meta=..., ref_x=..., ref_imgs_meta=...)``.
        dataset: Dataset handed to ``data.DataLoader``; must support ``len``.

    Returns:
        The value of ``calc_metrics`` when results were dumped and
        ``cfg.use_valid_sub`` or ``cfg.use_train_sub`` is set; otherwise None.
        A ``KeyboardInterrupt`` is caught and only prints ``'Stopping...'``.
    """
    net.detect.use_fast_nms = args.fast_nms
    cfg.mask_proto_debug = args.mask_proto_debug

    frame_times = MovingAverage()
    # Only used for the progress display below.
    dataset_size = math.ceil(len(dataset) /
                             args.batch_size) if args.max_images < 0 else min(
                                 args.max_images, len(dataset))
    progress_bar = ProgressBar(30, dataset_size)

    print()

    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  shuffle=False,
                                  collate_fn=detection_collate,
                                  pin_memory=True)
    results = []

    # The final batch may hold fewer real samples than the others. A remainder
    # of 0 means the final batch is full, NOT empty, so it must not be used as
    # the sample count (doing so dropped the whole last batch).
    num_batches = len(data_loader)
    last_batch_remainder = len(dataset) % args.batch_size

    try:
        # Main eval loop
        for it, data_batch in enumerate(data_loader):
            timer.reset()

            with timer.env('Load Data'):
                images, images_meta, ref_images, ref_images_meta = prepare_data(
                    data_batch, is_cuda=True, train_mode=False)
            pad_h, pad_w = images.size()[2:4]

            with timer.env('Network Extra'):
                preds = net(images,
                            img_meta=images_meta,
                            ref_x=ref_images,
                            ref_imgs_meta=ref_images_meta)

            # Perform the meat of the operation here depending on our mode.
            # NOTE(review): the explicit remainder is kept for a partial final
            # batch because prepare_data (not visible here) may return more
            # entries than real samples -- confirm.
            if it == num_batches - 1 and last_batch_remainder > 0:
                batch_size = last_batch_remainder
            else:
                batch_size = images.size(0)

            for batch_id in range(batch_size):
                if args.display:
                    img_id = (images_meta[batch_id]['video_id'],
                              images_meta[batch_id]['frame_id'])
                    if not cfg.display_mask_single:
                        img_numpy = prep_display(
                            preds[batch_id],
                            images[batch_id],
                            pad_h,
                            pad_w,
                            img_meta=images_meta[batch_id],
                            img_ids=img_id)
                    else:
                        detection = preds[batch_id]['detection']
                        single_dir = args.mask_det_file[:-12] + 'out_single/'
                        # savefig does not create missing directories.
                        os.makedirs(single_dir, exist_ok=True)

                        # Render every detection on its own copy of the frame.
                        for p in range(detection['box'].size(0)):
                            preds_single = {'detection': {}}
                            for k in detection:
                                if detection[k] is not None and k not in {'proto'}:
                                    preds_single['detection'][k] = detection[k][p]
                                else:
                                    preds_single['detection'][k] = None
                            preds_single['net'] = preds[batch_id]['net']
                            preds_single['detection'][
                                'box_ids'] = torch.tensor(-1)

                            img_numpy = prep_display(
                                preds_single,
                                images[batch_id],
                                pad_h,
                                pad_w,
                                img_meta=images_meta[batch_id],
                                img_ids=img_id)
                            plt.imshow(img_numpy)
                            plt.axis('off')
                            plt.savefig(''.join([
                                single_dir,
                                str(img_id), '_',
                                str(p), '.png'
                            ]))
                            plt.clf()

                else:
                    cfg.preserve_aspect_ratio = True
                    preds_cur = postprocess_ytbvis(
                        preds[batch_id],
                        pad_h,
                        pad_w,
                        images_meta[batch_id],
                        score_threshold=cfg.eval_conf_thresh)
                    segm_results = bbox2result_with_id(preds_cur, cfg.classes)
                    results.append(segm_results)

                # First couple of images take longer because we're constructing the graph.
                # Since that's technically initialization, don't include those in the FPS calculations.
                if it > 1:
                    frame_times.add(timer.total_time() / batch_size)

                if args.display and not cfg.display_mask_single:
                    # Skip the FPS line while no usable timing is available
                    # instead of dividing by zero.
                    if it > 1 and frame_times.get_avg() > 0:
                        print('Avg FPS: %.4f' % (1 / frame_times.get_avg()))
                    plt.imshow(img_numpy)
                    plt.axis('off')
                    plt.title(str(img_id))

                    root_dir = ''.join([
                        args.mask_det_file[:-12], 'out/',
                        str(images_meta[batch_id]['video_id']), '/'
                    ])
                    os.makedirs(root_dir, exist_ok=True)
                    plt.savefig(''.join([
                        root_dir,
                        str(images_meta[batch_id]['frame_id']), '.png'
                    ]))
                    plt.clf()
                elif not args.no_bar:
                    if it > 1 and frame_times.get_avg() > 0:
                        fps = 1 / frame_times.get_avg()
                    else:
                        fps = 0
                    progress = (it + 1) / dataset_size * 100
                    progress_bar.set_val(it + 1)
                    print(
                        '\rProcessing Images  %s %6d / %6d (%5.2f%%)    %5.2f fps        '
                        % (repr(progress_bar), it + 1, dataset_size, progress,
                           fps),
                        end='')

        if not args.display and not args.benchmark:
            print()
            if args.output_json:
                print('Dumping detections...')
                results2json_videoseg(dataset, results, args.mask_det_file)

                if cfg.use_valid_sub or cfg.use_train_sub:
                    if cfg.use_valid_sub:
                        print('calculate evaluation metrics ...')
                        ann_file = cfg.valid_sub_dataset.ann_file
                    else:
                        print('calculate train_sub metrics ...')
                        ann_file = cfg.train_dataset.ann_file
                    dt_file = args.mask_det_file
                    metrics = calc_metrics(ann_file, dt_file)

                    return metrics

        elif args.benchmark:
            print()
            print()
            print('Stats for the last frame:')
            timer.print_stats()
            avg_seconds = frame_times.get_avg()
            # Report 0 fps when no frame was timed rather than dividing by zero.
            avg_fps = 1 / avg_seconds if avg_seconds > 0 else 0
            print('Average: %5.2f fps, %5.2f ms' %
                  (avg_fps, 1000 * avg_seconds))

    except KeyboardInterrupt:
        print('Stopping...')
예제 #3
0
파일: eval.py 프로젝트: Jinming-Su/STMask
def prep_display_single(dets_out,
                        img,
                        pad_h,
                        pad_w,
                        img_ids=None,
                        img_meta=None,
                        undo_transform=True,
                        mask_alpha=0.45,
                        fps_str='',
                        display_mode=None):
    """Draw the detections of one frame (masks, boxes, labels) onto the image.

    Args:
        dets_out: Raw detections for one frame; passed to
            ``postprocess_ytbvis``.
        img: Input image. With ``undo_transform=True`` it is converted back
            through ``undo_image_transformation``; otherwise it is divided by
            255 and its shape is unpacked as ``(pad_h, pad_w, channels)``.
        pad_h: Padded height handed to the post-processing. Ignored (taken
            from ``img.shape``) when ``undo_transform`` is False.
        pad_w: Padded width, handled like ``pad_h``.
        img_ids: Forwarded to ``postprocess_ytbvis``.
        img_meta: Frame meta data. ``img_meta['img_shape']`` is read whenever
            boxes or text are drawn, so it is required in that case.
        undo_transform: See ``img``.
        mask_alpha: Opacity used when blending the coloured masks.
        fps_str: Text drawn in the top-left corner when ``args.display_fps``.
        display_mode: ``'test'`` labels each box with class, score and box id
            (when ``args.display_scores``); any other value labels it with the
            class name only. The original note states that ``None`` means
            ground-truth results.

    Returns:
        The rendered image as a uint8 numpy array.

    Side effects:
        Sets ``cfg.mask_proto_debug``, ``cfg.preserve_aspect_ratio`` (to
        False) and ``args.display_masks``.
    """

    if undo_transform:
        img_numpy = undo_image_transformation(img, img_meta, pad_h, pad_w)
        img_gpu = torch.Tensor(img_numpy).cuda()
    else:
        img_gpu = img / 255.0
        pad_h, pad_w, _ = img.shape

    with timer.env('Postprocess'):
        cfg.mask_proto_debug = args.mask_proto_debug
        cfg.preserve_aspect_ratio = False
        dets_out = postprocess_ytbvis(dets_out,
                                      pad_h,
                                      pad_w,
                                      img_meta,
                                      display_mask=True,
                                      visualize_lincomb=args.display_lincomb,
                                      crop_masks=args.crop,
                                      score_threshold=cfg.eval_conf_thresh,
                                      img_ids=img_ids,
                                      mask_det_file=args.mask_det_file)
        torch.cuda.synchronize()
        scores = dets_out['score'][:args.top_k].detach().cpu().numpy()
        boxes = dets_out['box'][:args.top_k].detach().cpu().numpy()

    # Masks are only drawn when the post-processing produced them.
    args.display_masks = 'segm' in dets_out
    if args.display_masks:
        masks = dets_out['segm'][:args.top_k]

    classes = dets_out['class'][:args.top_k].detach().cpu().numpy()

    # Keep only the leading detections up to the first one whose score is
    # below the threshold.
    num_dets_to_consider = min(args.top_k, classes.shape[0])
    color_type = dets_out['box_ids']
    for j in range(num_dets_to_consider):
        if scores[j] < args.score_threshold:
            num_dets_to_consider = j
            break

    if num_dets_to_consider == 0:
        # No detections found so just output the original image
        return (img_gpu * 255).byte().cpu().numpy()

    # First, draw the masks on the GPU where we can do it really fast
    if args.display_masks and cfg.eval_mask_branch:
        # After this, mask is of size [num_dets, h, w, 1]
        masks = masks[:num_dets_to_consider, :, :, None]

        # Prepare the RGB images for each mask given their color (size [num_dets, h, w, 1])
        colors = torch.cat([
            get_color(j,
                      color_type,
                      on_gpu=img_gpu.device.index,
                      undo_transform=undo_transform).view(1, 1, 1, 3)
            for j in range(num_dets_to_consider)
        ],
                           dim=0)
        masks_color = masks.repeat(1, 1, 1, 3) * colors * mask_alpha

        # This is 1 everywhere except for 1-mask_alpha where the mask is
        inv_alph_masks = masks * (-mask_alpha) + 1

        # This whole block should be equivalent to:
        #    for j in range(num_dets_to_consider):
        #        img_gpu = img_gpu * inv_alph_masks[j] + masks_color[j]
        masks_color_summand = masks_color[0]
        if num_dets_to_consider > 1:
            inv_alph_cumul = inv_alph_masks[:(num_dets_to_consider -
                                              1)].cumprod(dim=0)
            masks_color_cumul = masks_color[1:] * inv_alph_cumul
            masks_color_summand += masks_color_cumul.sum(dim=0)
        img_gpu = img_gpu * inv_alph_masks.prod(dim=0) + masks_color_summand

    if args.display_fps:
        # Draw the box for the fps on the GPU
        font_face = cv2.FONT_HERSHEY_DUPLEX
        font_scale = 0.6
        font_thickness = 1

        text_w, text_h = cv2.getTextSize(fps_str, font_face, font_scale,
                                         font_thickness)[0]

        img_gpu[0:text_h + 8, 0:text_w + 8] *= 0.6  # 1 - Box alpha

    # Then draw the stuff that needs to be done on the cpu
    # Note, make sure this is a uint8 tensor or opencv will not anti alias text for whatever reason
    img_numpy = (img_gpu * 255).byte().cpu().numpy()

    if args.display_fps:
        # Draw the text on the CPU
        text_pt = (4, text_h + 2)
        text_color = [255, 255, 255]

        cv2.putText(img_numpy, fps_str, text_pt, font_face, font_scale,
                    text_color, font_thickness, cv2.LINE_AA)

    if args.display_text or args.display_bboxes:
        # Loop-invariant values: fetch them once instead of once per
        # detection (the priors conversion copies the tensor to the CPU).
        h, w, _ = img_meta['img_shape']
        priors = dets_out['priors'].detach().cpu().numpy()
        num_priors = dets_out['priors'].size(0)
        pcolor = [255, 0, 255]

        for j in reversed(range(num_dets_to_consider)):
            x1, y1, x2, y2 = boxes[j, :]
            color = get_color(j, color_type)

            # Prior box: stored as (cx, cy, w, h) and scaled by the image
            # size, converted here to integer corner coordinates.
            has_prior = j < num_priors
            if has_prior:
                cpx, cpy, pw, ph = priors[j, :] * [w, h, w, h]
                px1, py1 = cpx - pw / 2.0, cpy - ph / 2.0
                px2, py2 = cpx + pw / 2.0, cpy + ph / 2.0
                px1, py1, px2, py2 = int(px1), int(py1), int(px2), int(py2)

            # Keep the drawn box inside the frame.
            # NOTE(review): the limits are hard-coded; presumably they assume
            # a 640x360 output image -- confirm.
            x = torch.clamp(torch.tensor([x1, x2]), min=2, max=638).tolist()
            y = torch.clamp(torch.tensor([y1, y2]), min=2, max=358).tolist()

            if args.display_bboxes:
                cv2.rectangle(img_numpy, (x[0], y[0]), (x[1], y[1]), color, 1)
                if has_prior:
                    cv2.rectangle(img_numpy, (px1, py1), (px2, py2),
                                  pcolor,
                                  2,
                                  lineType=8)

            if args.display_text:
                # Class ids are 1-based here; anything below 1 has no name.
                if classes[j] - 1 < 0:
                    _class = 'None'
                else:
                    _class = cfg.classes[classes[j] - 1]

                if display_mode == 'test':
                    if args.display_scores:
                        text_str = '%s: %.2f: %s' % (
                            _class, scores[j],
                            str(color_type[j].cpu().numpy()))
                    else:
                        text_str = _class
                else:
                    text_str = '%s' % _class

                font_face = cv2.FONT_HERSHEY_DUPLEX
                font_scale = 0.5
                font_thickness = 1

                text_w, text_h = cv2.getTextSize(text_str, font_face,
                                                 font_scale, font_thickness)[0]

                # Filled label background with the text on top of it.
                text_pt = (x1, y1 - 3)
                text_color = [255, 255, 255]
                cv2.rectangle(img_numpy, (x1, y1),
                              (x1 + text_w, y1 - text_h - 4), color, -1)
                cv2.putText(img_numpy, text_str, text_pt, font_face,
                            font_scale, text_color, font_thickness,
                            cv2.LINE_AA)

    return img_numpy