def __getitem__(self, idx):
        """ Returns (id, img) tuple

        Draws the annotations of image ``self.ids[idx]`` in ``self.colors[0]``
        and then overlays every detection set in ``self.dets`` that has an
        entry for this image. Each detection set is split with
        ``bbb.filter_split(..., bbb.MatchFilter(anno))``; the ``ok`` part is
        drawn normally and the ``nok`` part is drawn faded, both in the same
        color. Detection set ``i`` cycles through ``self.colors[1:]``; when
        only a single color is configured, ``self.colors[0]`` is reused.

        Args:
            idx (int): Position in ``self.ids``.

        Returns:
            tuple: ``(img_idx, img)`` -- the image id and the drawn image.
        """
        img_idx = self.ids[idx]
        img_path = Path(self.folder) / (img_idx + self.ext)
        anno = self.annos[img_idx]

        img = bbb.draw_boxes(img_path,
                             anno,
                             show_labels=self.anno_labels,
                             color=self.colors[0])

        # colors[0] is reserved for the annotations; detection sets cycle
        # through the remaining colors. With a one-color palette the old
        # ``i % (len(self.colors) - 1)`` was ``i % 0`` -> ZeroDivisionError,
        # so fall back to the annotation color instead.
        det_colors = self.colors[1:] or self.colors[:1]
        for i, det in enumerate(self.dets):
            if img_idx not in det:
                continue
            color = det_colors[i % len(det_colors)]
            ok, nok = bbb.filter_split(det[img_idx], bbb.MatchFilter(anno))
            # Keep the returned image (as for the annotation pass above)
            # instead of relying on draw_boxes mutating ``img`` in place.
            img = bbb.draw_boxes(img,
                                 ok,
                                 show_labels=self.det_labels,
                                 color=color)
            img = bbb.draw_boxes(
                img,
                nok,
                show_labels=self.det_labels,
                color=color,
                faded=lambda b: True,
            )

        return img_idx, img
示例#2
0
    def test_basic_cv2(self):
        """ Test if cv2 drawing works """
        # Build the expected image by drawing the same box directly with
        # OpenCV. draw_boxes is given (255, 0, 0) while the reference uses
        # (0, 0, 255) -- presumably an RGB -> BGR conversion inside brambox.
        expected = np.zeros((25, 25, 3), np.uint8)
        res = bbb.draw_boxes(expected.copy(), [self.anno], (255, 0, 0))
        cv2.rectangle(expected, (1, 5), (11, 20), (0, 0, 255), 3)

        # Same pass/fail semantics as np.array_equal, but a failure reports
        # the mismatching pixels instead of "False is not true".
        np.testing.assert_array_equal(res, expected)
    def test_draw_faded(self):
        """ Test drawing faded bounding boxes

        A second, difficult annotation is passed together with ``self.anno``
        and ``faded`` selects boxes by their ``difficult`` flag. The expected
        image draws the regular box with line width 3 and the faded one with
        line width 1, each with its label.
        """
        difficult = bbb.Annotation()
        difficult.difficult = True
        difficult.class_label = 'object'
        difficult.object_id = 2
        difficult.x_top_left = 20
        difficult.y_top_left = 20
        difficult.width = 10
        difficult.height = 10

        red = (255, 0, 0)
        canvas = Image.new('RGB', (40, 40))
        res = bbb.draw_boxes(canvas.copy(), [self.anno, difficult], red, True,
                             faded=lambda a: a.difficult)

        pen = ImageDraw.Draw(canvas)
        pen.line([(5, 6), (15, 6), (15, 21), (5, 21), (5, 6)], red, 3)
        pen.text((5, -9), 'object', red, font)
        pen.line([(20, 20), (30, 20), (30, 30), (20, 30), (20, 20)], red, 1)
        pen.text((20, 7), 'object 2', red, font)

        self.assertEqual(list(res.getdata()), list(canvas.getdata()))
示例#4
0
 def __getitem__(self, idx):
     """ Draw the boxes belonging to one image of the dataset.

     Args:
         idx (int): Position in ``self.ids``.

     Returns:
         tuple: ``(img_idx, img)`` -- the image id and the result of
         ``bbb.draw_boxes`` applied to that image file with
         ``self.boxes[img_idx]``.
     """
     image_id = self.ids[idx]
     path = os.path.join(self.folder, image_id + self.ext)
     drawn = bbb.draw_boxes(
         path,
         self.boxes[image_id],
         show_labels=self.labels,
         faded=self.faded,
     )
     return image_id, drawn
示例#5
0
    def test_basic_pil(self):
        """ Test if Pillow drawing works """
        # Reference image: the same outline drawn by hand with ImageDraw.
        expected = Image.new('RGB', (25, 25))
        pen = ImageDraw.Draw(expected)
        res = bbb.draw_boxes(expected.copy(), [self.anno], (255, 0, 0))
        outline = [(1, 5), (11, 5), (11, 20), (1, 20), (1, 5)]
        pen.line(outline, (255, 0, 0), 3)

        self.assertEqual(list(expected.getdata()), list(res.getdata()))
    def test_draw_pil(self):
        """ Test if Pillow drawing works """
        expected = Image.new('RGB', (25, 25))
        res = bbb.draw_boxes(expected.copy(), [self.anno], (255, 0, 0), True)

        # Reference: outline plus the 'object' label drawn above the box.
        red = (255, 0, 0)
        pen = ImageDraw.Draw(expected)
        pen.line([(5, 6), (15, 6), (15, 21), (5, 21), (5, 6)], red, 3)
        pen.text((5, -9), 'object', red, font)

        self.assertEqual(list(res.getdata()), list(expected.getdata()))
    def test_draw_cv(self):
        """ Test if cv2 drawing works """
        # Build the expected image with OpenCV directly: the box plus its
        # 'object' label. (255, 0, 0) given to draw_boxes corresponds to
        # (0, 0, 255) here -- presumably an RGB -> BGR conversion in brambox.
        expected = np.zeros((25, 25, 3), np.uint8)
        res = bbb.draw_boxes(expected.copy(), [self.anno], (255, 0, 0), True)

        cv2.rectangle(expected, (5, 6), (15, 21), (0, 0, 255), 3)
        cv2.putText(
            expected,
            'object',
            (5, 1),
            cv2.FONT_HERSHEY_PLAIN,
            0.75,
            (0, 0, 255),
            1,
            cv2.LINE_AA,
        )

        # Same pass/fail semantics as np.array_equal, but a failure reports
        # the mismatching pixels instead of "False is not true".
        np.testing.assert_array_equal(res, expected)
    def test_draw_detection(self):
        """ Test drawing a detection and printing its confidence value """
        detection = bbb.Detection()
        detection.confidence = 0.66
        detection.class_label = 'obj'
        detection.object_id = 1
        detection.x_top_left = 10
        detection.y_top_left = 10
        detection.width = 10
        detection.height = 10

        purple = (100, 0, 125)
        expected = Image.new('RGB', (50, 50))
        res = bbb.draw_boxes(expected.copy(), [detection], purple, True)

        # Expected label text: "<class_label> <object_id>|<confidence %>".
        pen = ImageDraw.Draw(expected)
        pen.line([(10, 10), (20, 10), (20, 20), (10, 20), (10, 10)], purple, 3)
        pen.text((10, -5), 'obj 1|66.00%', purple, font)

        self.assertEqual(list(res.getdata()), list(expected.getdata()))
    def test_draw_color_cycle(self):
        """ Test color cycle

        Four 3x3 boxes with labels a, b, c, b are drawn without an explicit
        color; both 'b' boxes are expected in the same color.
        """
        def make_anno(label, x, y):
            # One 3x3 annotation at (x, y) with the given class label.
            box = bbb.Annotation()
            box.class_label = label
            box.x_top_left = x
            box.y_top_left = y
            box.width = 3
            box.height = 3
            return box

        def outline(x0, y0, x1, y1):
            # Closed polyline around the rectangle (x0, y0)-(x1, y1).
            return [(x0, y0), (x1, y0), (x1, y1), (x0, y1), (x0, y0)]

        annos = [
            make_anno('a', 1, 1),
            make_anno('b', 5, 1),
            make_anno('c', 1, 5),
            make_anno('b', 5, 5),
        ]

        expected = Image.new('RGB', (10, 10))
        res = bbb.draw_boxes(expected.copy(), annos)

        pen = ImageDraw.Draw(expected)
        pen.line(outline(1, 1, 4, 4), (31, 119, 180), 3)
        pen.line(outline(5, 1, 8, 4), (255, 127, 14), 3)
        pen.line(outline(1, 5, 4, 8), (44, 160, 44), 3)
        pen.line(outline(5, 5, 8, 8), (255, 127, 14), 3)

        self.assertEqual(list(res.getdata()), list(expected.getdata()))
示例#10
0
            device = torch.device('cuda')
        else:
            log.error('CUDA not available')

    # Network
    network = create_network()
    print(network)
    print()

    # Detection
    if len(args.image) > 0:
        for img_name in args.image:
            log.info(img_name)
            image, output = detect(network, img_name)

            image = bbb.draw_boxes(image, output[0], show_labels=args.label)
            if args.save:
                cv2.imwrite('detections.png', image)
            else:
                cv2.imshow('image', image)
                cv2.waitKey(0)
                cv2.destroyAllWindows()
    else:
        while True:
            try:
                img_path = input('Enter image path: ')
            except (KeyboardInterrupt, EOFError):
                print('')
                break

            if not os.path.isfile(img_path):