def forward(self, anchors, box_cls, box_regression, coeffs, prototypes):
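        """
        Post-process predictions: decode boxes per feature level, assemble
        instance masks from prototypes and per-box coefficients, crop and
        binarize them, and attach the masks to the per-image BoxLists.

        Arguments:
            anchors: per-image lists of per-level anchor BoxLists
            box_cls: per-level classification predictions
            box_regression: per-level box regression predictions
            coeffs: per-level mask coefficient predictions
            prototypes: per-image prototype masks

        Returns:
            list of BoxList, one per image, with a "masks" field added
        """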
        sampled_boxes = []
        num_levels = len(box_cls)
        # regroup anchors from per-image/per-level to per-level/per-image
        anchors = list(zip(*anchors))
        
        for a, c, r, co in zip(anchors, box_cls, box_regression, coeffs):
            sampled_boxes.append(self.forward_for_single_feature_map(a, c, r, co))

        # regroup per-level results into per-image lists, then concatenate
        # detections from all feature levels for each image
        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]

        if num_levels > 1:
            boxlists = self.select_over_all_levels(boxlists)

        results = []
        for prototypes_per_image, boxlists_per_image in zip(prototypes, boxlists):

            coeffs_per_image = boxlists_per_image.get_field("coeffs")

            # if DEBUG:
            #     print('range of prototypes_per_image:',\
            #         prototypes_per_image.min(), prototypes_per_image.max())

            # assemble masks: (H, W, k) prototypes @ (k, N) coefficients -> (N, H, W) after permute
            masks_pred_per_image = prototypes_per_image.permute(1, 2, 0) @ coeffs_per_image.t()
            masks_pred_per_image = masks_pred_per_image.permute(2, 0, 1)
            masks_pred_per_image = self.mask_activation(masks_pred_per_image)

            # crop: zero out mask values outside each predicted box (boxes resized to prototype resolution)
            mask_h, mask_w = masks_pred_per_image.shape[1:]
            resized_pred_bbox = boxlists_per_image.resize((mask_w, mask_h))
            masks_pred_per_image = crop_zero_out(masks_pred_per_image, resized_pred_bbox.bbox)

            # binarize
            masks_pred_per_image = masks_pred_per_image > self.mask_threshold
            
            # convert mask predictions to polygon format to save memory
            if cfg.MODEL.YOLACT.CONVERT_MASK_TO_POLY:
                cpu_device = torch.device("cpu")
                masks_pred_per_image = SegmentationMask(
                    masks_pred_per_image.to(cpu_device), (mask_w, mask_h), "mask")
                if DEBUG:
                    print(len(masks_pred_per_image), mask_w, mask_h)
                masks_pred_per_image = masks_pred_per_image.convert("poly")
            else:
                masks_pred_per_image = SegmentationMask(masks_pred_per_image, (mask_w, mask_h), "mask")
            
            if DEBUG:
                print(len(masks_pred_per_image), mask_w, mask_h)

            # resize masks from prototype resolution to the input image size
            img_w, img_h = boxlists_per_image.size
            masks_pred_per_image = masks_pred_per_image.resize((img_w, img_h))

            boxlists_per_image.add_field("masks", masks_pred_per_image)
            results.append(boxlists_per_image)

        return results
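
# Minimal, self-contained sketch of the "assemble mask" step above:
# prototypes of shape (k, H, W) are combined with per-detection coefficients
# of shape (N, k) by a matrix product to produce N instance masks. The sizes
# below are illustrative assumptions, and sigmoid stands in for
# self.mask_activation.
import torch

k, H, W, N = 32, 69, 69, 5
prototypes = torch.randn(k, H, W)
coeffs = torch.randn(N, k)

masks = prototypes.permute(1, 2, 0) @ coeffs.t()  # (H, W, k) @ (k, N) -> (H, W, N)
masks = masks.permute(2, 0, 1).sigmoid()          # (N, H, W)
assert masks.shape == (N, H, W)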
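
# Imports needed by the test below; the module path assumes the standard
# maskrcnn-benchmark package layout.
import unittest

import torch

from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
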
class TestSegmentationMask(unittest.TestCase):
    def __init__(self, method_name='runTest'):
        super(TestSegmentationMask, self).__init__(method_name)
        poly = [[
            [
                423.0, 306.5, 406.5, 277.0, 400.0, 271.5, 389.5, 277.0, 387.5,
                292.0, 384.5, 295.0, 374.5, 220.0, 378.5, 210.0, 391.0, 200.5,
                404.0, 199.5, 414.0, 203.5, 425.5, 221.0, 438.5, 297.0, 423.0,
                306.5
            ],
            [100, 100, 200, 100, 200, 200, 100, 200],
        ]]
        width = 640
        height = 480
        size = width, height

        self.P = SegmentationMask(poly, size, 'poly')
        self.M = SegmentationMask(poly, size, 'poly').convert('mask')

    def L1(self, A, B):
        # L1 distance between the mask tensors of two SegmentationMask objects
        diff = A.get_mask_tensor() - B.get_mask_tensor()
        diff = torch.sum(torch.abs(diff.float())).item()
        return diff

    def test_convert(self):
        M_hat = self.M.convert('poly').convert('mask')
        P_hat = self.P.convert('mask').convert('poly')

        diff_mask = self.L1(self.M, M_hat)
        diff_poly = self.L1(self.P, P_hat)
        self.assertEqual(diff_mask, diff_poly)
        self.assertLessEqual(diff_mask, 8169.)
        self.assertLessEqual(diff_poly, 8169.)

    def test_crop(self):
        box = [400, 250, 500, 300]  # xyxy
        diff = self.L1(self.M.crop(box), self.P.crop(box))
        self.assertLessEqual(diff, 1.)

    def test_resize(self):
        new_size = 50, 25
        M_hat = self.M.resize(new_size)
        P_hat = self.P.resize(new_size)
        diff = self.L1(M_hat, P_hat)

        self.assertEqual(self.M.size, self.P.size)
        self.assertEqual(M_hat.size, P_hat.size)
        self.assertNotEqual(self.M.size, M_hat.size)
        self.assertLessEqual(diff, 255.)

    def test_transpose(self):
        FLIP_LEFT_RIGHT = 0
        FLIP_TOP_BOTTOM = 1
        diff_hor = self.L1(self.M.transpose(FLIP_LEFT_RIGHT),
                           self.P.transpose(FLIP_LEFT_RIGHT))

        diff_ver = self.L1(self.M.transpose(FLIP_TOP_BOTTOM),
                           self.P.transpose(FLIP_TOP_BOTTOM))

        self.assertLessEqual(diff_hor, 53250.)
        self.assertLessEqual(diff_ver, 42494.)
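

if __name__ == "__main__":
    # Standard unittest entry point so the tests above can be run directly.
    unittest.main()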