import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.models import resnet50

# FPN50, FPNSSDBoxCoder, MobileNetV2, ImageSlicer, tensor_from_rgb_image and
# to_numpy are project-local helpers; they are assumed importable from the
# surrounding package.


class FPNSSD512(nn.Module):
    num_anchors = 9

    def __init__(self, num_classes, pretrained=True, **kwargs):
        super(FPNSSD512, self).__init__()
        self.fpn = FPN50()
        self.num_classes = num_classes + 1  # +1 for the implicit background class
        self.loc_head = self._make_head(self.num_anchors * 4)
        self.cls_head = self._make_head(self.num_anchors * self.num_classes)
        self.box_coder = FPNSSDBoxCoder()

        resnet_state = resnet50(pretrained=pretrained).state_dict()
        self.fpn.load_state_dict(resnet_state, strict=False)

    def forward(self, image):
        loc_preds = []
        cls_preds = []
        fms = self.fpn(image)
        for fm in fms:
            loc_pred = self.loc_head(fm)
            cls_pred = self.cls_head(fm)
            loc_pred = loc_pred.permute(0, 2, 3, 1).contiguous().view(
                image.size(0), -1,
                4)  # [N, 9*4,H,W] -> [N,H,W, 9*4] -> [N,H*W*9, 4]
            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(
                image.size(0), -1, self.num_classes
            )  # [N,9*NC,H,W] -> [N,H,W,9*NC] -> [N,H*W*9,NC]
            loc_preds.append(loc_pred)
            cls_preds.append(cls_pred)

        bboxes = torch.cat(loc_preds, 1)
        labels = torch.cat(cls_preds, 1)

        return bboxes, labels

    def _make_head(self, out_planes):
        layers = []
        for _ in range(4):
            layers.append(
                nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(True))
        layers.append(
            nn.Conv2d(256, out_planes, kernel_size=3, stride=1, padding=1))
        return nn.Sequential(*layers)

    def predict(self, image):
        import albumentations as A
        self.eval()

        normalize = A.Normalize()
        image = normalize(image=image)['image']

        slicer = ImageSlicer(image.shape, 512, 512 // 2)
        patches = [
            tensor_from_rgb_image(patch)
            for patch in slicer.split(image, borderType=cv2.BORDER_CONSTANT)
        ]
        offsets = torch.tensor([[crop[0], crop[1], crop[0], crop[1]]
                                for crop in slicer.bbox_crops],
                               dtype=torch.float32)

        all_bboxes = []
        all_labels = []

        with torch.no_grad():
            for patch, patch_loc in DataLoader(list(zip(patches, offsets)),
                                               batch_size=8,
                                               pin_memory=True):
                patch = patch.to(self.fpn.conv1.weight.device)
                bboxes, labels = self(patch)

                all_bboxes.extend(bboxes.cpu())
                all_labels.extend(labels.cpu())

        boxes, labels, scores = self.box_coder.decode_multi(
            all_bboxes, all_labels, offsets)
        return to_numpy(boxes), to_numpy(labels), to_numpy(scores)
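
# For a 512x512 input, FPN50 is assumed (following the torchcv FPNSSD512
# design) to return seven pyramid levels with strides 8..512, i.e. spatial
# sizes 64, 32, 16, 8, 4, 2 and 1. The shared heads above therefore emit
# (64**2 + 32**2 + 16**2 + 8**2 + 4**2 + 2**2 + 1) * 9 = 5461 * 9 = 49149
# anchor predictions per image, which FPNSSDBoxCoder must match.
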
Example #2
class MobilenetSSD512(nn.Module):
    num_anchors = 9

    def __init__(self, num_classes, **kwargs):
        super(MobilenetSSD512, self).__init__()
        encoder = MobileNetV2()

        self.layer0 = encoder.layer0
        self.layer1 = encoder.layer1
        self.layer2 = encoder.layer2
        self.layer3 = encoder.layer3
        self.layer4 = encoder.layer4
        self.layer5 = encoder.layer5
        self.layer6 = encoder.layer6
        self.layer7 = encoder.layer7

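        # Extra coarse pyramid levels: stride-2 convs reduce the 320-channel
        # layer7 output to the common 64-channel pyramid width and halve the
        # spatial size at each step.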
        self.conv6 = nn.Conv2d(320, 64, kernel_size=3, stride=2, padding=1)
        self.conv7 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.conv8 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.conv9 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)

        # Top-down layers
        self.toplayer = nn.Conv2d(320, 64, kernel_size=1, stride=1, padding=0)

        # Lateral layers
        self.latlayer1 = nn.Conv2d(96, 64, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0)

        # Smooth layers
        self.smooth1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.num_classes = num_classes + 1  # +1 for the implicit background class
        self.loc_head = self._make_head(self.num_anchors * 4)
        self.cls_head = self._make_head(self.num_anchors * self.num_classes)
        self.box_coder = FPNSSDBoxCoder()

    def forward(self, image):

        # Extract features
        c0 = self.layer0(image)
        c1 = self.layer1(c0)
        c2 = self.layer2(c1)
        c3 = self.layer3(c2)
        c4 = self.layer4(c3)
        c5 = self.layer5(c4)
        c6 = self.layer6(c5)
        c7 = self.layer7(c6)


        p6 = self.conv6(c7)
        p7 = self.conv7(F.relu(p6))
        p8 = self.conv8(F.relu(p7))
        p9 = self.conv9(F.relu(p8))

        # Top-down
        p5 = self.toplayer(c7)
        p4 = self._upsample_add(p5, self.latlayer1(c5))
        p3 = self._upsample_add(p4, self.latlayer2(c3))
        p4 = self.smooth1(p4)
        p3 = self.smooth2(p3)

        features = [p3, p4, p5, p6, p7, p8, p9]

        loc_preds = []
        cls_preds = []

        for fm in features:
            loc_pred = self.loc_head(fm)
            cls_pred = self.cls_head(fm)
            loc_pred = loc_pred.permute(0, 2, 3, 1).contiguous().view(
                image.size(0), -1,
                4)  # [N, 9*4,H,W] -> [N,H,W, 9*4] -> [N,H*W*9, 4]
            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(
                image.size(0), -1, self.num_classes
            )  # [N,9*NC,H,W] -> [N,H,W,9*NC] -> [N,H*W*9,NC]
            loc_preds.append(loc_pred)
            cls_preds.append(cls_pred)

        bboxes = torch.cat(loc_preds, 1)
        labels = torch.cat(cls_preds, 1)

        return bboxes, labels

    def _upsample_add(self, x, y):
        '''Upsample `x` to the spatial size of `y` and add the two feature maps.

        Args:
          x: (tensor) top feature map to be upsampled.
          y: (tensor) lateral feature map.

        Returns:
          (tensor) added feature map.

        Note that in PyTorch, when the input size is odd, the feature map
        upsampled with `F.interpolate(..., scale_factor=2, mode='nearest')`
        may not match the lateral feature map size (see the illustration
        after this method).

        e.g.
        original input size: [N,_,15,15] ->
        conv2d feature map size: [N,_,8,8] ->
        upsampled feature map size: [N,_,16,16]

        So we upsample to the explicit target size with bilinear
        interpolation, which supports arbitrary output sizes.
        '''
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear',
                             align_corners=False) + y
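
    # Illustration of the size mismatch described above (a sketch, not part
    # of the model):
    #
    #   y = torch.randn(1, 64, 15, 15)                # lateral map
    #   x = torch.randn(1, 64, 8, 8)                  # map after stride-2 conv
    #   F.interpolate(x, scale_factor=2).shape        # [1, 64, 16, 16] != y
    #   F.interpolate(x, size=y.shape[2:]).shape      # [1, 64, 15, 15] == y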

    def _make_head(self, out_planes):
        layers = []
        for _ in range(4):
            layers.append(nn.Conv2d(64, 64, kernel_size=3, stride=1,
                                    padding=1))
            layers.append(nn.ReLU(True))
        layers.append(
            nn.Conv2d(64, out_planes, kernel_size=3, stride=1, padding=1))
        return nn.Sequential(*layers)

    def predict(self, image):
        import albumentations as A
        self.eval()

        normalize = A.Normalize()
        image = normalize(image=image)['image']

        slicer = ImageSlicer(image.shape, 512, 512 // 2)
        patches = [
            tensor_from_rgb_image(patch)
            for patch in slicer.split(image, borderType=cv2.BORDER_CONSTANT)
        ]
        offsets = torch.tensor([[crop[0], crop[1], crop[0], crop[1]]
                                for crop in slicer.bbox_crops],
                               dtype=torch.float32)

        all_bboxes = []
        all_labels = []

        with torch.no_grad():
            for patch, patch_loc in DataLoader(list(zip(patches, offsets)),
                                               batch_size=8,
                                               pin_memory=True):
                patch = patch.to(self.conv6.weight.device)
                bboxes, labels = self(patch)

                all_bboxes.extend(bboxes.cpu())
                all_labels.extend(labels.cpu())

        boxes, labels, scores = self.box_coder.decode_multi(
            all_bboxes, all_labels, offsets)
        return to_numpy(boxes), to_numpy(labels), to_numpy(scores)
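

# Usage sketch (an assumption, not part of the original example): run the
# MobileNet variant on random data and check that the two heads agree on the
# number of anchor locations. num_classes=1 means one real class plus the
# background class added in __init__.
if __name__ == '__main__':
    model = MobilenetSSD512(num_classes=1)
    loc, cls = model(torch.randn(2, 3, 512, 512))
    assert loc.shape[:2] == cls.shape[:2]       # same anchor count per image
    assert loc.shape[2] == 4                    # box offsets
    assert cls.shape[2] == model.num_classes    # per-class logits
    print(loc.shape, cls.shape)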