Exemplo n.º 1
0
class ResNetRoIHead(chainer.Chain):
    """RoI head predicting box offsets, class scores and mask logits.

    Each RoI is pooled from the input feature map with ``pooling_func``,
    passed through a ``res5`` block, and then fed to two branches: a
    box/classification branch (``cls_loc``, ``score``) and a mask branch
    (``deconv6`` followed by the 1x1 ``mask`` convolution).

    Args:
        n_layers (int): 50 or 101. Only consulted when
            ``pretrained_model == 'auto'`` to pick the extractor whose
            ``res5`` weights are copied.
        n_class (int): Number of classes **including** the background.
        roi_size (int): Output height and width requested from
            ``pooling_func``; ``res5`` uses stride ``roi_size // 7``.
        spatial_scale (float): Forwarded unchanged to ``pooling_func``.
        pretrained_model (str or None): ``'auto'`` copies ``res5`` weights
            from the matching extractor; ``None`` skips the copy. Any other
            value raises ``AssertionError``.
        res_initialW, loc_initialW, score_initialW, mask_initialW:
            Initializers forwarded as ``initialW`` to ``res5``,
            ``cls_loc``, ``score`` and the two mask-branch layers.
        pooling_func (callable): RoI pooling function, called with
            ``outh``, ``outw``, ``spatial_scale`` and ``axes='yx'``.
        n_mask_class (int): Number of mask channels predicted per
            foreground class.

    Raises:
        AssertionError: If ``pretrained_model`` is neither ``'auto'`` nor
            ``None``.
        ValueError: If ``pretrained_model == 'auto'`` and ``n_layers`` is
            not 50 or 101.
    """

    mask_size = 14  # Size of the predicted mask.

    def __init__(self, n_layers, n_class, roi_size, spatial_scale,
                 pretrained_model='auto',
                 res_initialW=None, loc_initialW=None, score_initialW=None,
                 mask_initialW=None, pooling_func=functions.roi_align_2d,
                 n_mask_class=3):
        # n_class includes the background
        super(ResNetRoIHead, self).__init__()
        with self.init_scope():
            self.res5 = BuildingBlock(
                3, 1024, 512, 2048, stride=roi_size // 7,
                initialW=res_initialW)
            self.cls_loc = L.Linear(2048, n_class * 4, initialW=loc_initialW)
            self.score = L.Linear(2048, n_class, initialW=score_initialW)

            # 7 x 7 x 2048 -> 14 x 14 x 256
            self.deconv6 = L.Deconvolution2D(
                2048, 256, 2, stride=2, initialW=mask_initialW)
            # 14 x 14 x 256 -> 14 x 14 x (n_fg_class * n_mask_class)
            n_fg_class = n_class - 1
            self.mask = L.Convolution2D(
                256, n_fg_class * n_mask_class, 1, initialW=mask_initialW)

        self.n_class = n_class
        self.roi_size = roi_size
        self.spatial_scale = spatial_scale
        self.pooling_func = pooling_func

        _convert_bn_to_affine(self)

        if pretrained_model == 'auto':
            self._copy_imagenet_pretrained_resnet(n_layers)
        elif pretrained_model is not None:
            # Raised explicitly instead of via ``assert`` so the check is
            # not stripped under ``python -O``; the exception type callers
            # see is unchanged.
            raise AssertionError(
                'Unsupported pretrained_model: {}'.format(pretrained_model))

    def _copy_imagenet_pretrained_resnet(self, n_layers):
        """Copy ``res5`` parameters and persistents from an extractor.

        Args:
            n_layers (int): 50 selects ``ResNet50Extractor``, 101 selects
                ``ResNet101Extractor``.

        Raises:
            ValueError: If ``n_layers`` is neither 50 nor 101.
        """
        if n_layers == 50:
            pretrained_model = ResNet50Extractor(pretrained_model='auto')
        elif n_layers == 101:
            pretrained_model = ResNet101Extractor(pretrained_model='auto')
        else:
            raise ValueError(
                'Unsupported n_layers: {} (expected 50 or 101)'.format(
                    n_layers))
        self.res5.copyparams(pretrained_model.res5)
        _copy_persistent_chain(self.res5, pretrained_model.res5)

    def __call__(self, x, rois, roi_indices, pred_bbox=True, pred_mask=True):
        """Run the head on a set of RoIs.

        Args:
            x: Feature map handed to ``pooling_func``.
            rois: 2-D array with one row per RoI. It is passed on with
                ``axes='yx'``, so the box columns are presumably in
                y-before-x order -- confirm against ``pooling_func``.
            roi_indices: 1-D array with one entry per RoI; cast to
                ``float32`` and prepended to ``rois`` as the first column.
            pred_bbox (bool): Compute ``roi_cls_locs`` and ``roi_scores``.
            pred_mask (bool): Compute ``roi_masks``.

        Returns:
            tuple: ``(roi_cls_locs, roi_scores, roi_masks)``. An entry is
            ``None`` when its branch was disabled. ``roi_cls_locs`` has
            ``n_class * 4`` columns, ``roi_scores`` has ``n_class`` columns
            and ``roi_masks`` has ``(n_class - 1) * n_mask_class`` channels.
        """
        roi_indices = roi_indices.astype(np.float32)
        indices_and_rois = self.xp.concatenate(
            (roi_indices[:, None], rois), axis=1)
        pool = self.pooling_func(
            x,
            indices_and_rois,
            outh=self.roi_size,
            outw=self.roi_size,
            spatial_scale=self.spatial_scale,
            axes='yx',
        )

        res5 = self.res5(pool)

        roi_cls_locs = None
        roi_scores = None
        roi_masks = None

        if pred_bbox:
            # Average over a 7 x 7 window before the linear layers
            # (see the ``7 x 7 x 2048`` note in ``__init__``).
            pool5 = F.average_pooling_2d(res5, 7, stride=7)
            roi_cls_locs = self.cls_loc(pool5)
            roi_scores = self.score(pool5)

        if pred_mask:
            deconv6 = F.relu(self.deconv6(res5))
            roi_masks = self.mask(deconv6)

        return roi_cls_locs, roi_scores, roi_masks
Exemplo n.º 2
0
class ResNetRoIHead(chainer.Chain):
    """RoI head whose mask-branch layout is selected by ``mask_loss``.

    Each RoI is pooled from the input feature map, passed through ``res5``
    (always with ``train=False``), and fed to a box/classification branch
    and a mask branch. Depending on ``mask_loss`` the mask branch emits a
    single output, a pair of outputs, or -- for the "relook" variants --
    a first mask prediction plus a second prediction made after the first
    masks were pasted back onto the whole feature map.

    Args:
        n_layers (int): 50 or 101; selects the extractor whose ``res5``
            weights are copied.
        n_class (int): Number of classes **including** the background.
        roi_size (int): Output height and width requested from the RoI
            pooling; ``res5`` uses stride ``roi_size // 7``.
        spatial_scale (float): Forwarded to the RoI pooling and to
            ``roi_mask_to_whole_mask``.
        pretrained_model (str or None): ``None`` skips copying pretrained
            ``res5`` weights. Any other value (default ``'auto'``) copies
            them, which is what every value did before this argument was
            honoured.
        res_initialW, loc_initialW, score_initialW, mask_initialW:
            Initializers forwarded as ``initialW`` to ``res5``,
            ``cls_loc``, ``score`` and all mask-branch layers.
        pooling_func (callable): Forwarded to ``_roi_pooling_2d_yx``.
        mask_loss (str): One of ``'softmax'``, ``'softmax_x2'``,
            ``'sigmoid_softmax'``, ``'sigmoid_sigmoid'``,
            ``'sigmoid_sigmoid+'`` or an entry of ``_RELOOK_MASK_LOSSES``.

    Raises:
        ValueError: If ``mask_loss`` is not supported, or if pretrained
            weights are requested and ``n_layers`` is not 50 or 101.
    """

    # ``mask_loss`` values that use the two-pass "relook" mask branch
    # (``mask`` -> whole-image masks -> ``conv5`` -> ``res5`` -> ``mask2``).
    _RELOOK_MASK_LOSSES = (
        'softmax_relook_softmax',
        'softmax_relook_softmax+',
        'softmax_relook_softmax+_res',
        'softmax_relook_softmax_cls',
        'softmax_relook_softmax+_cls',
        'softmax_relook_softmax_tt',
        'softmax_relook_softmax+_tt',
        'softmax_relook_softmax+_tt2',
        'softmax_relook_softmax_cls_tt',
        'softmax_relook_softmax+_cls_tt',
        'softmax_relook_softmax_bbox',
        'softmax_relook_softmax+_bbox',
    )

    def __init__(self,
                 n_layers,
                 n_class,
                 roi_size,
                 spatial_scale,
                 pretrained_model='auto',
                 res_initialW=None,
                 loc_initialW=None,
                 score_initialW=None,
                 mask_initialW=None,
                 pooling_func=functions.roi_align_2d,
                 mask_loss='softmax'):
        # n_class includes the background
        super(ResNetRoIHead, self).__init__()
        self.mask_loss = mask_loss
        with self.init_scope():
            self.res5 = BuildingBlock(3,
                                      1024,
                                      512,
                                      2048,
                                      stride=roi_size // 7,
                                      initialW=res_initialW)
            self.cls_loc = L.Linear(2048, n_class * 4, initialW=loc_initialW)
            self.score = L.Linear(2048, n_class, initialW=score_initialW)

            # 7 x 7 x 2048 -> 14 x 14 x 256
            self.deconv6 = L.Deconvolution2D(2048,
                                             256,
                                             2,
                                             stride=2,
                                             initialW=mask_initialW)
            # 14 x 14 x 256
            n_fg_class = n_class - 1
            if self.mask_loss in ('softmax', 'softmax_x2'):
                # -> 14 x 14 x (n_fg_class * 3) (bg vs. vis vs. inv)
                self.mask = L.Convolution2D(256,
                                            n_fg_class * 3,
                                            1,
                                            initialW=mask_initialW)
            elif self.mask_loss == 'sigmoid_softmax':
                # -> 14 x 14 x n_fg_class: vis
                self.mask = L.Convolution2D(256,
                                            n_fg_class,
                                            1,
                                            initialW=mask_initialW)
                # -> 14 x 14 x (n_fg_class * 2): bg vs. inv
                self.mask_bginv = L.Convolution2D(256,
                                                  n_fg_class * 2,
                                                  1,
                                                  initialW=mask_initialW)
            elif self.mask_loss in ('sigmoid_sigmoid', 'sigmoid_sigmoid+'):
                # -> 14 x 14 x n_fg_class: vis
                self.mask = L.Convolution2D(256,
                                            n_fg_class,
                                            1,
                                            initialW=mask_initialW)
                # -> 14 x 14 x n_fg_class: vis + inv
                self.mask_visinv = L.Convolution2D(256,
                                                   n_fg_class,
                                                   1,
                                                   initialW=mask_initialW)
            elif self.mask_loss in self._RELOOK_MASK_LOSSES:
                # First-pass masks: 3 channels per foreground class.
                self.mask = L.Convolution2D(
                    in_channels=256,
                    out_channels=n_fg_class * 3,
                    ksize=1,
                    initialW=mask_initialW,
                )
                # conv5 consumes the whole-image masks; its input width
                # depends on whether the masks are laid out per class
                # ('_cls'), used alone and added residually to the feature
                # map ('_res'), or concatenated with the 1024-channel
                # feature map (default).
                if '_cls' in self.mask_loss:
                    conv5_in_channels = n_fg_class * 3 + 1024
                elif '_res' in self.mask_loss:
                    conv5_in_channels = 3
                else:
                    conv5_in_channels = 3 + 1024
                self.conv5 = L.Convolution2D(
                    in_channels=conv5_in_channels,
                    out_channels=1024,
                    ksize=3,
                    pad=1,
                    initialW=mask_initialW,
                )
                # Second-pass masks: 3 channels, class-agnostic.
                self.mask2 = L.Convolution2D(
                    in_channels=256,
                    out_channels=3,
                    ksize=1,
                    initialW=mask_initialW,
                )
            else:
                raise ValueError(
                    'Unsupported mask_loss: {}'.format(self.mask_loss))

        self.n_class = n_class
        self.roi_size = roi_size
        self.spatial_scale = spatial_scale
        self.pooling_func = pooling_func

        _convert_bn_to_affine(self)
        # Previously the weights were copied unconditionally and the
        # ``pretrained_model`` argument was ignored; ``None`` now opts out.
        if pretrained_model is not None:
            self._copy_imagenet_pretrained_resnet(n_layers)

    def _copy_imagenet_pretrained_resnet(self, n_layers):
        """Copy ``res5`` parameters and persistents from an extractor.

        Args:
            n_layers (int): 50 selects ``ResNet50Extractor``, 101 selects
                ``ResNet101Extractor``.

        Raises:
            ValueError: If ``n_layers`` is neither 50 nor 101.
        """
        if n_layers == 50:
            pretrained_model = ResNet50Extractor(pretrained_model='auto')
        elif n_layers == 101:
            pretrained_model = ResNet101Extractor(pretrained_model='auto')
        else:
            raise ValueError(
                'Unsupported n_layers: {} (expected 50 or 101)'.format(
                    n_layers))
        self.res5.copyparams(pretrained_model.res5)
        _copy_persistent_chain(self.res5, pretrained_model.res5)

    def __call__(self,
                 x,
                 rois,
                 roi_indices,
                 pred_bbox=True,
                 pred_mask=True,
                 pred_bbox2=False,
                 pred_mask2=True,
                 labels=None):
        """Run the head on a set of RoIs.

        Args:
            x: Feature map the RoIs are pooled from. In the relook
                variants ``x.shape[2:4]`` is used as the whole-mask size
                and ``x`` is concatenated with (or added to) a batch-of-one
                tensor, so it is presumably a single-image NCHW map with
                1024 channels -- confirm against the caller.
            rois: 2-D array with one row per RoI.
            roi_indices: 1-D array with one entry per RoI; cast to
                ``float32`` and prepended to ``rois`` as the first column.
            pred_bbox (bool): Compute ``roi_cls_locs`` and ``roi_scores``.
            pred_mask (bool): Compute ``roi_masks``.
            pred_bbox2 (bool): Relook variants only. Also compute box
                outputs from the second pass; ``roi_cls_locs`` and
                ``roi_scores`` then become ``(first, second)`` pairs.
            pred_mask2 (bool): Relook variants only. Compute the second
                mask prediction; otherwise its slot is ``None``.
            labels: Required by the relook variants. Per-RoI class labels
                with 0 meaning background; the positive RoIs must occupy
                the leading entries (checked by an assertion).

        Returns:
            tuple: ``(roi_cls_locs, roi_scores, roi_masks)``. Entries are
            ``None`` when the corresponding branch is disabled.
            ``roi_masks`` is a single output for ``'softmax'`` and
            ``'softmax_x2'``, and a 2-tuple for every other ``mask_loss``.

        Raises:
            AssertionError: If a relook variant is used without ``labels``.
            ValueError: If ``self.mask_loss`` is not supported.
        """
        roi_indices = roi_indices.astype(np.float32)
        indices_and_rois = self.xp.concatenate((roi_indices[:, None], rois),
                                               axis=1)
        pool = _roi_pooling_2d_yx(x, indices_and_rois, self.roi_size,
                                  self.roi_size, self.spatial_scale,
                                  self.pooling_func)

        with chainer.using_config('train', False):
            res5 = self.res5(pool)

        roi_cls_locs = None
        roi_scores = None
        roi_masks = None

        if pred_bbox:
            pool5 = F.average_pooling_2d(res5, 7, stride=7)
            roi_cls_locs = self.cls_loc(pool5)
            roi_scores = self.score(pool5)

        if pred_mask:
            deconv6 = F.relu(self.deconv6(res5))
            if self.mask_loss in ('softmax', 'softmax_x2'):
                roi_masks = self.mask(deconv6)
            elif self.mask_loss == 'sigmoid_softmax':
                roi_masks = self.mask(deconv6)
                roi_masks_bginv = self.mask_bginv(deconv6)
                roi_masks = (roi_masks, roi_masks_bginv)
            elif self.mask_loss == 'sigmoid_sigmoid':
                roi_masks = self.mask(deconv6)
                roi_masks_visinv = self.mask_visinv(deconv6)
                roi_masks = (roi_masks, roi_masks_visinv)
            elif self.mask_loss == 'sigmoid_sigmoid+':
                # The second output is the sum of both heads' logits.
                roi_masks = self.mask(deconv6)
                roi_masks_visinv = self.mask_visinv(deconv6)
                roi_masks = (roi_masks, roi_masks + roi_masks_visinv)
            elif self.mask_loss in self._RELOOK_MASK_LOSSES:
                if labels is None:
                    # Explicit raise (same exception type as the former
                    # ``assert``) so the check also runs under ``-O``.
                    raise AssertionError(
                        'labels is required when mask_loss is {}'.format(
                            self.mask_loss))

                # roi_masks: (n_roi, n_fg_class * 3, H, W)
                roi_masks = self.mask(deconv6)

                n_roi = rois.shape[0]
                n_positive = int((labels > 0).sum())

                # Only the leading ``n_positive`` RoIs are kept.
                labels = labels[:n_positive]
                rois = rois[:n_positive]

                # (n_roi, n_fg_class * 3, H, W) -> (n_roi, n_fg_class, 3,
                # H, W), then pick each kept RoI's own class:
                # -> (n_positive, 3, H, W)
                roi_masks = F.reshape(
                    roi_masks,
                    (n_roi, -1, 3, roi_masks.shape[2], roi_masks.shape[3]))
                roi_masks = roi_masks[np.arange(n_positive), labels - 1]
                # Every kept label must be foreground, i.e. the positive
                # RoIs really were the leading ones.
                assert (labels == 0).sum() == 0

                # Paste the per-RoI mask probabilities back onto a map of
                # the feature map's spatial size. ``.array`` detaches them
                # from the graph.
                if '_cls' in self.mask_loss:
                    whole_masks = roi_mask_to_whole_mask(
                        F.softmax(roi_masks).array,
                        rois,
                        x.shape[2:4],
                        self.spatial_scale,
                        fg_labels=labels - 1,
                        n_fg_class=self.n_class - 1)
                else:
                    whole_masks = roi_mask_to_whole_mask(
                        F.softmax(roi_masks).array, rois, x.shape[2:4],
                        self.spatial_scale)
                # Fold everything but the spatial axes into the channel
                # axis of a single-item batch.
                whole_masks = F.reshape(
                    whole_masks,
                    (1, -1, whole_masks.shape[2], whole_masks.shape[3]))

                if '_res' in self.mask_loss:
                    # Residual variant: masks alone go through conv5 and
                    # are added to the feature map.
                    h = self.conv5(whole_masks)
                    h = F.relu(h + x)
                else:
                    # Default: masks are concatenated with the feature map
                    # along the channel axis before conv5.
                    h = F.concat([whole_masks, x], axis=1)
                    h = F.relu(self.conv5(h))

                # Second pass: re-pool ALL RoIs (``indices_and_rois`` was
                # built before ``rois`` was truncated) from the fused map.
                h = _roi_pooling_2d_yx(h, indices_and_rois, self.roi_size,
                                       self.roi_size, self.spatial_scale,
                                       self.pooling_func)

                with chainer.using_config('train', False):
                    res5 = self.res5(h)

                if pred_bbox2:
                    pool5 = F.average_pooling_2d(res5, 7, stride=7)
                    roi_cls_locs2 = self.cls_loc(pool5)
                    roi_scores2 = self.score(pool5)
                    roi_cls_locs = (roi_cls_locs, roi_cls_locs2)
                    roi_scores = (roi_scores, roi_scores2)

                roi_masks2 = None
                if pred_mask2:
                    h = F.relu(self.deconv6(res5))
                    # Second-pass masks only for the kept (positive) RoIs.
                    h = h[:n_positive, :, :, :]
                    roi_masks2 = self.mask2(h)

                roi_masks = (roi_masks, roi_masks2)
            else:
                raise ValueError(
                    'Unsupported mask_loss: {}'.format(self.mask_loss))

        return roi_cls_locs, roi_scores, roi_masks