Beispiel #1
0
    def test_rise_can_be_applied_to_classification_model(self):
        # A stub classifier picks class 0 or 1 from channel 0 of one fixed
        # image region, so the single heatmap RISE returns is expected to be
        # denser inside that region than over the rest of the image.
        class TestLauncher(Launcher):
            """Stub classifier whose decision depends on one image region."""

            def __init__(self, class_count, roi, **kwargs):
                self.class_count = class_count
                self.roi = roi

            def launch(self, inputs):
                """Lazily yield one list of label objects per input image."""
                yield from map(self._process, inputs)

            def _process(self, image):
                """Score every class for ``image``.

                Class 0 wins when the channel-0 sum over the region exceeds
                half of the region area, otherwise class 1 wins. The winner
                gets a score of 0.5; the remainder is split evenly between
                the other classes.
                """
                bounds = self.roi
                region_area = (bounds[1] - bounds[0]) * (bounds[3] - bounds[2])
                region_sum = np.sum(
                    image[bounds[0]:bounds[1], bounds[2]:bounds[3], 0])
                cls = 0 if 0.5 * region_area < region_sum else 1

                cls_conf = 0.5
                other_conf = (1.0 - cls_conf) / (self.class_count - 1)

                labels = []
                for i in range(self.class_count):
                    score = cls_conf if cls == i else other_conf
                    labels.append(LabelObject(i, attributes={'score': score}))
                return labels

        roi = [70, 90, 7, 90]
        model = TestLauncher(class_count=3, roi=roi)

        rise = RISE(model, max_samples=(7 * 7)**2, mask_width=7, mask_height=7)

        image = np.ones((100, 100, 3))
        heatmaps = next(rise.apply(image))

        self.assertEqual(1, len(heatmaps))

        heatmap = heatmaps[0]
        self.assertEqual(image.shape[:2], heatmap.shape)

        # Mean heatmap value inside the region versus everywhere else.
        inside_sum = np.sum(heatmap[roi[0]:roi[1], roi[2]:roi[3]])
        inside_area = (roi[1] - roi[0]) * (roi[3] - roi[2])
        outside_sum = np.sum(heatmap) - inside_sum
        outside_area = np.prod(heatmap.shape) - inside_area

        inside_density = inside_sum / inside_area
        outside_density = outside_sum / outside_area
        self.assertLess(outside_density, inside_density)
Beispiel #2
0
    def DISABLED_test_roi_nms():
        # Manual, interactive check of RISE.nms: builds reference boxes plus
        # jittered noisy copies, draws them, runs NMS and draws the survivors
        # on the same canvas. Each cv2.waitKey(0) blocks until a key press.
        ROI = namedtuple('ROI', ['conf', 'x', 'y', 'w', 'h', 'label'])

        class_count = 3
        noisy_count = 3
        pixel_jitter = 10
        rois = [
            ROI(0.3, 10, 40, 30, 10, 0),
            ROI(0.5, 70, 90, 7, 10, 0),
            ROI(0.7, 5, 20, 40, 60, 2),
            ROI(0.9, 30, 20, 10, 40, 1),
        ]

        detections = []
        for i, roi in enumerate(rois):
            # The reference detection itself ...
            reference = Bbox(roi.x, roi.y, roi.w, roi.h,
                             label=roi.label,
                             attributes={'score': roi.conf})
            detections.append(reference)

            # ... followed by `noisy_count` randomly shifted copies with a
            # lower score and a rotating label.
            box = [roi.x, roi.y, roi.w, roi.h]
            for j in range(noisy_count):
                cls_conf = roi.conf * j / noisy_count
                cls = (i + j) % class_count
                offset = (np.random.rand(4) - 0.5) * pixel_jitter
                noisy = Bbox(*(box + offset),
                             label=cls,
                             attributes={'score': cls_conf})
                detections.append(noisy)

        import cv2
        image = np.zeros((100, 100, 3))

        def draw(boxes, caption, color_for):
            # Paint every box with its caption, then show the canvas and wait
            # for a key press.
            for idx, det in enumerate(boxes):
                roi = ROI(det.attributes['score'], *det.get_bbox(), det.label)
                top_left = (int(roi.x), int(roi.y))
                bottom_right = (int(roi.x + roi.w), int(roi.y + roi.h))
                color = color_for(idx)
                cv2.rectangle(image, top_left, bottom_right, color)
                cv2.putText(image, caption % (idx, roi.label, roi.conf),
                            top_left, cv2.FONT_HERSHEY_SIMPLEX, 0.25, color)
            cv2.imshow('nms_image', image)
            cv2.waitKey(0)

        # Every (1 + noisy_count)-th detection is a reference box and gets a
        # different colour from its noisy copies.
        draw(detections, 'd%s-%s-%.2f',
             lambda idx: (0, int(idx % (1 + noisy_count) == 0), 1))

        nms_boxes = RISE.nms(detections, iou_thresh=0.25)
        print(len(detections), len(nms_boxes))

        draw(nms_boxes, 'p%s-%s-%.2f', lambda idx: (0, 1, 0))
Beispiel #3
0
def explain_command(args):
    """Run RISE on an image, a source or a project and save/show heatmaps.

    The target given in ``args.target`` is resolved to one of the
    ``TargetKinds`` (project, source or image). For an image target, a
    single image is explained; with ``args.display`` set the intermediate
    heatmaps are shown while they are refined. For source and project
    targets, every dataset item that carries image data is explained.

    Heatmaps are written as ``<name>-heatmap-<index>.png`` files into
    ``args.save_dir`` when it is set; otherwise they are blended with the
    input image and shown in OpenCV windows.

    Args:
        args: Parsed command-line namespace. Reads ``project_dir``,
            ``target``, ``model``, ``algorithm``, ``max_samples``,
            ``mask_width``, ``mask_height``, ``prob``, ``iou_thresh``,
            ``nms_iou_thresh``, ``det_conf_thresh``, ``batch_size``,
            ``display`` and ``save_dir``. ``target`` (and possibly
            ``project_dir``) are overwritten with their resolved values.

    Returns:
        0 on completion.

    Raises:
        NotImplementedError: If ``args.algorithm`` is not ``'rise'``
            (case-insensitive) or the resolved target kind is unsupported.
    """
    project_path = args.project_dir
    if is_project_path(project_path):
        project = Project.load(project_path)
    else:
        project = None
    args.target = target_selector(
        ProjectTarget(is_default=True, project=project),
        SourceTarget(project=project), ImageTarget())(args.target)
    if args.target[0] == TargetKinds.project:
        # A project target that points at a project path replaces the
        # project directory used for loading below.
        if is_project_path(args.target[1]):
            args.project_dir = osp.dirname(osp.abspath(args.target[1]))

    import cv2
    from matplotlib import cm

    project = load_project(args.project_dir)

    model = project.make_executable_model(args.model)

    if str(args.algorithm).lower() != 'rise':
        raise NotImplementedError()

    from datumaro.components.algorithms.rise import RISE
    rise = RISE(model,
                max_samples=args.max_samples,
                mask_width=args.mask_width,
                mask_height=args.mask_height,
                prob=args.prob,
                iou_thresh=args.iou_thresh,
                nms_thresh=args.nms_iou_thresh,
                det_conf_thresh=args.det_conf_thresh,
                batch_size=args.batch_size)

    if args.target[0] == TargetKinds.image:
        image_path = args.target[1]
        image = load_image(image_path)

        log.info("Running inference explanation for '%s'" % image_path)
        heatmap_iter = rise.apply(image, progressive=args.display)

        # Scale the image for blending with the colour-mapped heatmaps
        # (assumes 8-bit pixel values - TODO confirm).
        image = image / 255.0
        file_name = osp.splitext(osp.basename(image_path))[0]
        if args.display:
            # Show each intermediate result; the heatmaps from the last
            # iteration are the ones saved or shown below.
            for i, heatmaps in enumerate(heatmap_iter):
                for j, heatmap in enumerate(heatmaps):
                    # Keep the first three colour channels in reversed order
                    # (presumably RGB -> BGR for OpenCV).
                    hm_painted = cm.jet(heatmap)[:, :, 2::-1]
                    disp = (image + hm_painted) / 2
                    cv2.imshow('heatmap-%s' % j, hm_painted)
                    cv2.imshow(file_name + '-heatmap-%s' % j, disp)
                cv2.waitKey(10)
                print("Iter", i, "of", args.max_samples, end='\r')
        else:
            heatmaps = next(heatmap_iter)

        if args.save_dir is not None:
            log.info("Saving inference heatmaps at '%s'" % args.save_dir)
            os.makedirs(args.save_dir, exist_ok=True)

            for j, heatmap in enumerate(heatmaps):
                save_path = osp.join(args.save_dir,
                                     file_name + '-heatmap-%s.png' % j)
                save_image(save_path, heatmap * 255.0)
        else:
            for j, heatmap in enumerate(heatmaps):
                disp = (image + cm.jet(heatmap)[:, :, 2::-1]) / 2
                cv2.imshow(file_name + '-heatmap-%s' % j, disp)
            cv2.waitKey(0)
    elif args.target[0] == TargetKinds.source or \
         args.target[0] == TargetKinds.project:
        if args.target[0] == TargetKinds.source:
            source_name = args.target[1]
            dataset = project.make_source_project(source_name).make_dataset()
            log.info("Running inference explanation for '%s'" % source_name)
        else:
            project_name = project.config.project_name
            dataset = project.make_dataset()
            log.info("Running inference explanation for '%s'" % project_name)

        for item in dataset:
            image = item.image.data
            if image is None:
                # `warn` is a deprecated alias; use `warning` and let the
                # logging machinery format the message lazily.
                log.warning(
                    "Dataset item %s does not have image data. Skipping.",
                    item.id)
                continue

            heatmap_iter = rise.apply(image)

            image = image / 255.0
            heatmaps = next(heatmap_iter)

            if args.save_dir is not None:
                log.info("Saving inference heatmaps to '%s'" % args.save_dir)
                os.makedirs(args.save_dir, exist_ok=True)

                for j, heatmap in enumerate(heatmaps):
                    save_image(osp.join(args.save_dir,
                                        item.id + '-heatmap-%s.png' % j),
                               heatmap * 255.0,
                               create_dir=True)

            # Show the result when nothing is saved or when display was
            # explicitly requested.
            if not args.save_dir or args.display:
                for j, heatmap in enumerate(heatmaps):
                    disp = (image + cm.jet(heatmap)[:, :, 2::-1]) / 2
                    cv2.imshow(item.id + '-heatmap-%s' % j, disp)
                cv2.waitKey(0)
    else:
        raise NotImplementedError()

    return 0
Beispiel #4
0
    def test_rise_can_be_applied_to_detection_model(self):
        # A stub detector reports a box for each predefined region while the
        # region's pixel sum stays above a per-region fraction of its first
        # observed value. RISE is expected to return one heatmap per distinct
        # label followed by one heatmap per region, each denser over the
        # relevant region(s) than elsewhere.
        ROI = namedtuple('ROI', ['threshold', 'x', 'y', 'w', 'h', 'label'])

        class TestLauncher(Launcher):
            """Stub detector driven by pixel sums over fixed regions."""

            def __init__(self, rois, class_count, fp_count=4, pixel_jitter=20,
                         **kwargs):
                self.rois = rois
                # Pixel sum of each region as seen on the first processed
                # image; None until that image has been processed.
                self.roi_base_sums = [None] * len(rois)
                self.class_count = class_count
                self.fp_count = fp_count
                self.pixel_jitter = pixel_jitter

            @staticmethod
            def roi_value(roi, image):
                """Return the sum of all channels of ``image`` inside ``roi``."""
                rows = slice(roi.y, roi.y + roi.h)
                cols = slice(roi.x, roi.x + roi.w)
                return np.sum(image[rows, cols, :])

            def launch(self, inputs):
                """Lazily yield one list of detections per input image."""
                yield from map(self._process, inputs)

            def _process(self, image):
                """Produce true and jittered extra detections for ``image``."""
                detections = []
                for i, roi in enumerate(self.rois):
                    roi_sum = self.roi_value(roi, image)

                    first_run = self.roi_base_sums[i] is None
                    if first_run:
                        self.roi_base_sums[i] = roi_sum

                    # Score is the region sum relative to its first value.
                    cls_conf = roi_sum / self.roi_base_sums[i]
                    passes = roi.threshold < cls_conf

                    if passes:
                        detections.append(
                            BboxObject(roi.x, roi.y, roi.w, roi.h,
                                       label=roi.label,
                                       attributes={'score': cls_conf}))

                    # No jittered extras on the very first image.
                    if first_run:
                        continue

                    box = [roi.x, roi.y, roi.w, roi.h]
                    for j in range(self.fp_count):
                        cls = roi.label if passes \
                            else (i + j) % self.class_count
                        offset = (np.random.rand(4) - 0.5) * self.pixel_jitter
                        detections.append(
                            BboxObject(*(box + offset),
                                       label=cls,
                                       attributes={'score': cls_conf}))

                return detections

        rois = [
            ROI(0.3, 10, 40, 30, 10, 0),
            ROI(0.5, 70, 90, 7, 10, 0),
            ROI(0.7, 5, 20, 40, 60, 2),
            ROI(0.9, 30, 20, 10, 40, 1),
        ]
        model = TestLauncher(class_count=3, rois=rois)

        rise = RISE(model, max_samples=(7 * 7)**2, mask_width=7, mask_height=7)

        image = np.ones((100, 100, 3))
        heatmaps = next(rise.apply(image))
        heatmaps_class_count = len({roi.label for roi in rois})
        self.assertEqual(heatmaps_class_count + len(rois), len(heatmaps))

        # Per-class heatmaps: mean over non-zero pixels inside the class
        # regions must exceed the mean over non-zero pixels outside them.
        for c in range(heatmaps_class_count):
            class_roi = np.zeros(image.shape[:2])
            for roi in rois:
                if roi.label == c:
                    class_roi[roi.y:roi.y + roi.h, roi.x:roi.x + roi.w] \
                        += roi.threshold

            heatmap = heatmaps[c]

            inside = heatmap[class_roi != 0]
            h_den = np.sum(inside) / np.sum(inside != 0)

            outside = heatmap[class_roi == 0]
            r_den = np.sum(outside) / np.sum(outside != 0)

            self.assertLess(r_den, h_den)

        # Per-region heatmaps follow the per-class ones: mean value inside
        # the region must exceed the mean value over the rest of the map.
        for i, roi in enumerate(rois):
            heatmap = heatmaps[heatmaps_class_count + i]

            inside_sum = np.sum(
                heatmap[roi.y:roi.y + roi.h, roi.x:roi.x + roi.w])
            inside_area = roi.h * roi.w
            outside_sum = np.sum(heatmap) - inside_sum
            outside_area = np.prod(heatmap.shape) - inside_area

            roi_den = inside_sum / inside_area
            hrest_den = outside_sum / outside_area
            self.assertLess(hrest_den, roi_den)
Beispiel #5
0
def explain_command(args):
    """Explain model predictions with RISE and save or display heatmaps.

    If ``args.target`` is set and is an image, that single image is
    explained; with ``args.display`` the intermediate heatmaps are shown
    while they are refined. Otherwise the target (``'project'`` when unset)
    is parsed as a revpath and every dataset item with image data is
    explained.

    Heatmaps are written as ``<name>-heatmap-<index>.png`` into
    ``args.save_dir`` when it is set; otherwise they are blended with the
    input image and shown in OpenCV windows.

    Args:
        args: Parsed command-line namespace. Reads ``project_dir``,
            ``target``, ``model``, ``algorithm``, ``max_samples``,
            ``mask_width``, ``mask_height``, ``prob``, ``iou_thresh``,
            ``nms_iou_thresh``, ``det_conf_thresh``, ``batch_size``,
            ``display`` and ``save_dir``.

    Returns:
        0 on completion.

    Raises:
        NotImplementedError: If ``args.algorithm`` is not ``'rise'``
            (case-insensitive).
    """
    from matplotlib import cm
    import cv2

    def paint(heat):
        # Colour-map the heatmap and keep its first three channels in
        # reversed order (presumably RGB -> BGR for OpenCV).
        return cm.jet(heat)[:, :, 2::-1]

    def blend(picture, heat):
        # Equal-weight mix of the scaled picture and the painted heatmap.
        return (picture + paint(heat)) / 2

    project = scope_add(load_project(args.project_dir))

    model = project.working_tree.models.make_executable_model(args.model)

    if str(args.algorithm).lower() != 'rise':
        raise NotImplementedError()

    from datumaro.components.algorithms.rise import RISE
    rise_options = dict(
        max_samples=args.max_samples,
        mask_width=args.mask_width,
        mask_height=args.mask_height,
        prob=args.prob,
        iou_thresh=args.iou_thresh,
        nms_thresh=args.nms_iou_thresh,
        det_conf_thresh=args.det_conf_thresh,
        batch_size=args.batch_size,
    )
    rise = RISE(model, **rise_options)

    if args.target and is_image(args.target):
        # Single-image mode.
        source_path = args.target
        picture = load_image(source_path)

        log.info("Running inference explanation for '%s'" % source_path)
        heatmap_iter = rise.apply(picture, progressive=args.display)

        picture = picture / 255.0
        stem = osp.splitext(osp.basename(source_path))[0]
        if not args.display:
            heatmaps = next(heatmap_iter)
        else:
            # Show every intermediate result; the heatmaps of the last
            # iteration are the ones saved or shown afterwards.
            for step, heatmaps in enumerate(heatmap_iter):
                for idx, heatmap in enumerate(heatmaps):
                    painted = paint(heatmap)
                    mixed = (picture + painted) / 2
                    cv2.imshow('heatmap-%s' % idx, painted)
                    cv2.imshow(stem + '-heatmap-%s' % idx, mixed)
                cv2.waitKey(10)
                print("Iter", step, "of", args.max_samples, end='\r')

        if args.save_dir is None:
            for idx, heatmap in enumerate(heatmaps):
                cv2.imshow(stem + '-heatmap-%s' % idx,
                           blend(picture, heatmap))
            cv2.waitKey(0)
        else:
            log.info("Saving inference heatmaps at '%s'" % args.save_dir)
            os.makedirs(args.save_dir, exist_ok=True)

            for idx, heatmap in enumerate(heatmaps):
                target_file = osp.join(args.save_dir,
                                       stem + '-heatmap-%s.png' % idx)
                save_image(target_file, heatmap * 255.0)

        return 0

    # Dataset mode: the target is a revpath, defaulting to the project.
    dataset, target_project = \
        parse_full_revpath(args.target or 'project', project)
    if target_project:
        scope_add(target_project)

    log.info("Running inference explanation for '%s'" % args.target)

    for item in dataset:
        picture = item.image.data
        if picture is None:
            log.warning("Item %s does not have image data. Skipping.",
                        item.id)
            continue

        heatmap_iter = rise.apply(picture)

        picture = picture / 255.0
        heatmaps = next(heatmap_iter)

        if args.save_dir is not None:
            log.info("Saving inference heatmaps to '%s'" % args.save_dir)
            os.makedirs(args.save_dir, exist_ok=True)

            for idx, heatmap in enumerate(heatmaps):
                target_file = osp.join(args.save_dir,
                                       item.id + '-heatmap-%s.png' % idx)
                save_image(target_file, heatmap * 255.0, create_dir=True)

        # Show the result when nothing is saved or when display was
        # explicitly requested.
        if args.display or not args.save_dir:
            for idx, heatmap in enumerate(heatmaps):
                cv2.imshow(item.id + '-heatmap-%s' % idx,
                           blend(picture, heatmap))
            cv2.waitKey(0)

    return 0