示例#1
0
    def test_image_false_when_no_file(self):
        """ImageTarget.test() must reject a .jpg path that was never created."""
        # The file name is derived from this test's own name; nothing is
        # written to disk, so the path should not exist.
        missing_file = '%s.jpg' % current_function_name()

        self.assertFalse(ImageTarget().test(missing_file))
示例#2
0
    def test_image_true_when_true(self):
        """ImageTarget.test() must accept an image file that exists on disk."""
        with TestDir() as test_dir:
            image_path = osp.join(test_dir, 'test.jpg')
            # Write a small 3-channel image so there is a real file to check.
            save_image(image_path, np.ones([10, 7, 3]))

            is_image = ImageTarget().test(image_path)
            self.assertTrue(is_image)
示例#3
0
    def test_image_true_when_true(self):
        """ImageTarget.test() must accept a valid JPEG written by OpenCV.

        The fixture is built as deterministic 8-bit data. OpenCV's JPEG
        encoder only supports 8-bit pixels, so a float image in [0, 1) may be
        converted without rescaling into an almost black picture. Random
        content would also make the test non-reproducible.
        """
        with TestDir() as test_dir:
            path = osp.join(test_dir.path, 'test.jpg')
            image = np.full([10, 10, 3], 128, dtype=np.uint8)
            # Fail here if the fixture could not be written, instead of
            # letting ImageTarget.test() report a misleading result.
            self.assertTrue(cv2.imwrite(path, image))

            target = ImageTarget()

            status = target.test(path)

            self.assertTrue(status)
示例#4
0
    def test_image_false_when_false(self):
        """ImageTarget.test() must reject a .jpg file whose content is text."""
        with TestDir() as test_dir:
            fake_image = osp.join(test_dir.path, 'test.jpg')
            # Plain text behind an image extension: the test expects the
            # suffix alone not to make the file count as an image.
            with open(fake_image, 'w+') as stream:
                stream.write('qwerty123')

            self.assertFalse(ImageTarget().test(fake_image))
示例#5
0
def main(args=None):
    """Parse command-line arguments, resolve the target and run the command.

    Returns 1 when the target cannot be resolved. In that case the
    ``argparse.ArgumentTypeError`` is printed, followed by the usage help.
    Otherwise returns whatever ``process_command`` returns.
    """
    parser = build_parser()
    args = parser.parse_args(args)

    # Only load a project when the project directory actually contains one.
    project = None
    if is_project_path(args.project_dir):
        project = Project.load(args.project_dir)

    try:
        select_target = target_selector(
            ProjectTarget(is_default=True, project=project),
            SourceTarget(project=project),
            ExternalDatasetTarget(),
            ImageTarget(),
        )
        args.target = select_target(args.target)

        # A project target that names a project path moves the project
        # directory to the directory containing that path.
        is_project_target = args.target[0] == TargetKinds.project
        if is_project_target and is_project_path(args.target[1]):
            args.project_dir = osp.dirname(osp.abspath(args.target[1]))
    except argparse.ArgumentTypeError as error:
        print(error)
        parser.print_help()
        return 1

    return process_command(args.target, args.params, args)
示例#6
0
def explain_command(args):
    """Run an inference-explanation algorithm (only RISE is supported).

    ``args.target`` resolves to one of three kinds: an image file, a project
    source, or the whole project. RISE heatmaps are computed for every
    image. The heatmaps are then saved as PNG files into ``args.save_dir``,
    shown in OpenCV windows, or both.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``project_dir``, ``target``, ``model``, ``algorithm``,
            ``max_samples``, ``mask_width``, ``mask_height``, ``prob``,
            ``iou_thresh``, ``nms_iou_thresh``, ``det_conf_thresh``,
            ``batch_size``, ``display`` and ``save_dir``.
            ``args.target`` (and possibly ``args.project_dir``) is
            overwritten.

    Returns:
        0 on success.

    Raises:
        NotImplementedError: if the algorithm is not RISE, or the target
            kind is not an image, source or project.
    """
    project_path = args.project_dir
    if is_project_path(project_path):
        project = Project.load(project_path)
    else:
        project = None
    args.target = target_selector(
        ProjectTarget(is_default=True, project=project),
        SourceTarget(project=project), ImageTarget())(args.target)
    if args.target[0] == TargetKinds.project:
        # A project target pointing at a project path redirects the project
        # directory to the directory containing that path.
        if is_project_path(args.target[1]):
            args.project_dir = osp.dirname(osp.abspath(args.target[1]))

    # Imported here rather than at module level, presumably so that these
    # heavy optional dependencies are only needed by this command.
    import cv2
    from matplotlib import cm

    # (Re)load the project from the possibly redirected project directory.
    project = load_project(args.project_dir)

    model = project.make_executable_model(args.model)

    if str(args.algorithm).lower() != 'rise':
        raise NotImplementedError(
            "Unsupported explanation algorithm '%s'" % (args.algorithm, ))

    from datumaro.components.algorithms.rise import RISE
    rise = RISE(model,
                max_samples=args.max_samples,
                mask_width=args.mask_width,
                mask_height=args.mask_height,
                prob=args.prob,
                iou_thresh=args.iou_thresh,
                nms_thresh=args.nms_iou_thresh,
                det_conf_thresh=args.det_conf_thresh,
                batch_size=args.batch_size)

    def _overlay(scaled_image, heatmap):
        # cm.jet maps the heatmap to RGBA. [:, :, 2::-1] keeps channels
        # 2, 1, 0 (BGR, the order OpenCV displays). The colored map is then
        # averaged with the image that was divided by 255.
        return (scaled_image + cm.jet(heatmap)[:, :, 2::-1]) / 2

    if args.target[0] == TargetKinds.image:
        image_path = args.target[1]
        image = load_image(image_path)

        log.info("Running inference explanation for '%s'", image_path)
        heatmap_iter = rise.apply(image, progressive=args.display)

        image = image / 255.0
        file_name = osp.splitext(osp.basename(image_path))[0]
        if args.display:
            # Progressive mode: show every intermediate set of heatmaps.
            # The last set yielded is kept as the final result.
            for i, heatmaps in enumerate(heatmap_iter):
                for j, heatmap in enumerate(heatmaps):
                    hm_painted = cm.jet(heatmap)[:, :, 2::-1]
                    disp = (image + hm_painted) / 2
                    cv2.imshow('heatmap-%s' % j, hm_painted)
                    cv2.imshow(file_name + '-heatmap-%s' % j, disp)
                cv2.waitKey(10)
                print("Iter", i, "of", args.max_samples, end='\r')
        else:
            heatmaps = next(heatmap_iter)

        if args.save_dir is not None:
            log.info("Saving inference heatmaps at '%s'", args.save_dir)
            os.makedirs(args.save_dir, exist_ok=True)

            for j, heatmap in enumerate(heatmaps):
                save_path = osp.join(args.save_dir,
                                     file_name + '-heatmap-%s.png' % j)
                save_image(save_path, heatmap * 255.0)
        else:
            for j, heatmap in enumerate(heatmaps):
                cv2.imshow(file_name + '-heatmap-%s' % j,
                           _overlay(image, heatmap))
            cv2.waitKey(0)
    elif args.target[0] in (TargetKinds.source, TargetKinds.project):
        if args.target[0] == TargetKinds.source:
            source_name = args.target[1]
            dataset = project.make_source_project(source_name).make_dataset()
            log.info("Running inference explanation for '%s'", source_name)
        else:
            project_name = project.config.project_name
            dataset = project.make_dataset()
            log.info("Running inference explanation for '%s'", project_name)

        if args.save_dir is not None:
            # Report and create the output directory once, not per item.
            log.info("Saving inference heatmaps to '%s'", args.save_dir)
            os.makedirs(args.save_dir, exist_ok=True)

        for item in dataset:
            image = item.image.data
            if image is None:
                log.warning(
                    "Dataset item %s does not have image data. Skipping.",
                    item.id)
                continue

            heatmap_iter = rise.apply(image)

            image = image / 255.0
            heatmaps = next(heatmap_iter)

            if args.save_dir is not None:
                for j, heatmap in enumerate(heatmaps):
                    # create_dir=True: item ids may contain subdirectories.
                    save_image(osp.join(args.save_dir,
                                        item.id + '-heatmap-%s.png' % j),
                               heatmap * 255.0,
                               create_dir=True)

            # Show the heatmaps when they are not being saved, or when
            # display was explicitly requested.
            if args.save_dir is None or args.display:
                for j, heatmap in enumerate(heatmaps):
                    cv2.imshow(item.id + '-heatmap-%s' % j,
                               _overlay(image, heatmap))
                cv2.waitKey(0)
    else:
        raise NotImplementedError(
            "Unsupported explanation target kind '%s'" % (args.target[0], ))

    return 0
示例#7
0
    def test_image_false_when_no_file(self):
        """ImageTarget.test() must reject a path that (presumably) does not exist."""
        self.assertFalse(ImageTarget().test('somepath.jpg'))