Example #1
# Imports assumed for this example (model_urls exists in torchvision < 0.13)
import torch
import torch.nn as nn
from collections import OrderedDict

from torch.hub import load_state_dict_from_url
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import model_urls


class FasterRCNN_Encoder(nn.Module):
    def __init__(self, out_dim=None, fine_tune=False):
        super(FasterRCNN_Encoder, self).__init__()
        # the backbone is initialized randomly here; COCO weights for the full
        # detector are loaded just below
        backbone = resnet_fpn_backbone('resnet50', pretrained=False)
        self.faster_rcnn = FasterRCNN(backbone,
                                      num_classes=91,
                                      rpn_post_nms_top_n_train=200,
                                      rpn_post_nms_top_n_test=100)
        state_dict = load_state_dict_from_url(
            model_urls['fasterrcnn_resnet50_fpn_coco'], progress=True)
        self.faster_rcnn.load_state_dict(state_dict)

        # replace the last linear layer of the RoI box head when a specific
        # output feature size is required
        if out_dim is not None:
            self.faster_rcnn.roi_heads.box_head.fc7 = nn.Linear(
                in_features=1024, out_features=out_dim)

        # for the captioning task we may not want to fine-tune the Faster R-CNN model
        if not fine_tune:
            for param in self.faster_rcnn.parameters():
                param.requires_grad = False

    def forward(self, images, targets=None):
        '''
        Forward propagation of the Faster R-CNN encoder.
        Args:
            images: List[Tensor], a list of image tensors
            targets: List[Dict[str, Tensor]], ground-truth boxes and labels,
                     used only when fine-tuning
        Returns:
            proposal features after RoI pooling, and the RPN losses
        '''
        images, targets = self.faster_rcnn.transform(images, targets)
        # base feature maps produced by the backbone network (ResNet-50 + FPN)
        features = self.faster_rcnn.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            # use a string key, matching the featmap_names of MultiScaleRoIAlign
            features = OrderedDict([('0', features)])
        # proposals produced by the RPN, i.e. coordinates of bounding boxes
        # that are likely to contain foreground objects
        proposals, proposal_losses = self.faster_rcnn.rpn(
            images, features, targets)
        # gather the features corresponding to the RPN proposals and perform
        # RoI pooling
        box_features = self.faster_rcnn.roi_heads.box_roi_pool(
            features, proposals, images.image_sizes)
        # flatten and project the pooled features through the box head; the
        # output has shape (total_num_proposals, feature_dim)
        box_features = self.faster_rcnn.roi_heads.box_head(box_features)
        return box_features, proposal_losses
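
A minimal usage sketch (not part of the original example), assuming the imports above; out_dim=512 and the dummy image sizes are arbitrary choices for illustration.

# --- usage sketch ---
encoder = FasterRCNN_Encoder(out_dim=512, fine_tune=False)
encoder.eval()

# two dummy RGB images of different sizes; the detector's transform resizes them
images = [torch.rand(3, 480, 640), torch.rand(3, 600, 800)]
with torch.no_grad():
    box_features, rpn_losses = encoder(images)

# up to rpn_post_nms_top_n_test (100) proposals per image, each a 512-d feature
print(box_features.shape)  # e.g. torch.Size([200, 512])
print(rpn_losses)          # empty dict in eval mode (no targets)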
Example #2
# Imports assumed for this example; project-specific helpers (load_vgg, IM_SCALE,
# UnionBoxesAndFeats, FrequencyBias, Result, enumerate_by_image, bbox_overlaps,
# diagonal_inds, proposal_assignments_gtbox, get_node_edge_features,
# convert_roi_to_list) are expected to come from the surrounding repository.
import copy
from collections import OrderedDict

import torch
import torch.nn as nn
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.rpn import AnchorGenerator


class RelModelBase(nn.Module):
    """
    Base class for scene graph generation (SGG) relationship models.
    """
    def __init__(self,
                 train_data,
                 mode='sgcls',
                 require_overlap_det=True,
                 use_bias=False,
                 test_bias=False,
                 backbone='vgg16',
                 RELS_PER_IMG=1024,
                 min_size=None,
                 max_size=None,
                 edge_model='motifs'):
        """
        Base class for an SGG model
        :param mode: (sgcls, predcls, or sgdet)
        :param require_overlap_det: Whether two objects must intersect
        """
        super(RelModelBase, self).__init__()
        self.classes = train_data.ind_to_classes
        self.rel_classes = train_data.ind_to_predicates
        self.mode = mode
        self.backbone = backbone
        self.RELS_PER_IMG = RELS_PER_IMG
        self.pool_sz = 7
        self.stride = 16

        self.use_bias = use_bias
        self.test_bias = test_bias

        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        if self.backbone == 'resnet50':
            self.obj_dim = 1024
            self.fmap_sz = 21

            if min_size is None:
                min_size = 1333
            if max_size is None:
                max_size = 1333

            print('\nLoading COCO pretrained model maskrcnn_resnet50_fpn...\n')
            # See https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
            self.detector = torchvision.models.detection.maskrcnn_resnet50_fpn(
                pretrained=True,
                min_size=min_size,
                max_size=max_size,
                box_detections_per_img=50,
                box_score_thresh=0.2)
            in_features = self.detector.roi_heads.box_predictor.cls_score.in_features
            # replace the pre-trained head with a new one
            self.detector.roi_heads.box_predictor = FastRCNNPredictor(
                in_features, len(self.classes))
            self.detector.roi_heads.mask_predictor = None

            # roi_heads children are registered in order:
            # [box_roi_pool, box_head, box_predictor, ...]
            layers = list(self.detector.roi_heads.children())[:2]
            self.roi_fmap_obj = copy.deepcopy(layers[1])  # box head (TwoMLPHead)
            self.roi_fmap = copy.deepcopy(layers[1])
            self.roi_pool = copy.deepcopy(layers[0])  # MultiScaleRoIAlign

        elif self.backbone == 'vgg16':
            self.obj_dim = 4096
            self.fmap_sz = 38

            if min_size is None:
                min_size = IM_SCALE
            if max_size is None:
                max_size = IM_SCALE

            vgg = load_vgg(use_dropout=False,
                           use_relu=False,
                           use_linear=True,
                           pretrained=False)
            vgg.features.out_channels = 512
            anchor_generator = AnchorGenerator(
                sizes=((32, 64, 128, 256, 512),),
                aspect_ratios=((0.5, 1.0, 2.0),))

            roi_pooler = torchvision.ops.MultiScaleRoIAlign(
                featmap_names=['0'],
                output_size=self.pool_sz,
                sampling_ratio=2)

            self.detector = FasterRCNN(
                vgg.features,
                min_size=min_size,
                max_size=max_size,
                rpn_anchor_generator=anchor_generator,
                box_head=TwoMLPHead(
                    vgg.features.out_channels * self.pool_sz**2, self.obj_dim),
                box_predictor=FastRCNNPredictor(self.obj_dim,
                                                len(train_data.ind_to_classes)),
                box_roi_pool=roi_pooler,
                box_detections_per_img=50,
                box_score_thresh=0.2)

            self.roi_fmap = nn.Sequential(nn.Flatten(), vgg.classifier)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier
            self.roi_pool = copy.deepcopy(
                list(self.detector.roi_heads.children())[0])

        else:
            raise NotImplementedError(self.backbone)

        self.edge_dim = self.detector.backbone.out_channels

        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pool_sz,
                                              stride=self.stride,
                                              dim=self.edge_dim,
                                              edge_model=edge_model)
        if self.use_bias:
            self.freq_bias = FrequencyBias(train_data)

    @property
    def num_classes(self):
        return len(self.classes)

    @property
    def num_rels(self):
        return len(self.rel_classes)

    def predict(self, node_feat, edge_feat, rel_inds, rois, im_sizes):
        raise NotImplementedError('predict')

    def forward(self, batch):
        raise NotImplementedError('forward')

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        """Get relationship candidates as (im_ind, subj_idx, obj_idx) triples."""
        if self.training:
            rel_inds = rel_labels[:, :3].data.clone()
        else:
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

            # Require overlap for detection
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                       box_priors.data) > 0)

                # if there are fewer than 100 candidates we might as well add some?
                # amt_to_add = 100 - rel_cands.long().sum()

            rel_cands = rel_cands.nonzero()
            # fall back to a single dummy pair when no candidates were found
            # (nonzero() returns a (0, 2) tensor in that case)
            if rel_cands.numel() == 0:
                rel_cands = im_inds.data.new(1, 2).fill_(0)

            rel_inds = torch.cat(
                (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)

        return rel_inds

    def set_box_score_thresh(self, box_score_thresh):
        if self.backbone != 'vgg16_old':
            self.detector.roi_heads.score_thresh = box_score_thresh
        else:
            self.detector.thresh = box_score_thresh

    def faster_rcnn(self, x, gt_boxes, gt_classes, gt_rels):
        targets, x_lst, original_image_sizes = [], [], []
        # rel_fc is expected to be defined by subclasses
        device = self.rel_fc.weight.device
        for i, s, e in enumerate_by_image(gt_classes[:, 0].long().data):
            targets.append({
                'boxes': copy.deepcopy(gt_boxes[s:e]),
                'labels': gt_classes[s:e, 1].long()
            })
            x_lst.append(x[i].to(device).squeeze())
            original_image_sizes.append(x[i].shape[-2:])

        images, targets = self.detector.transform(x_lst, targets)
        fmap_multiscale = self.detector.backbone(images.tensors)
        if isinstance(fmap_multiscale, torch.Tensor):
            fmap_multiscale = OrderedDict([('0', fmap_multiscale)])

        if self.mode != 'sgdet':
            rois, obj_labels, rel_labels = self.gt_labels(
                gt_boxes, gt_classes, gt_rels)
            rm_box_priors, rm_box_priors_org = [], []
            for i, s, e in enumerate_by_image(gt_classes[:, 0].long().data):
                rm_box_priors.append(targets[i]['boxes'])
                rm_box_priors_org.append(gt_boxes[s:e])

            im_inds = rois[:, 0]
            result = Result(
                od_box_targets=None,
                rm_box_targets=None,
                od_obj_labels=obj_labels,
                rm_box_priors=torch.cat(rm_box_priors),
                rm_obj_labels=obj_labels,
                rpn_scores=None,
                rpn_box_deltas=None,
                rel_labels=rel_labels,
                im_inds=im_inds.long(),
            )
            result.rm_box_priors_org = torch.cat(rm_box_priors_org)

        else:
            proposals, _ = self.detector.rpn(images, fmap_multiscale, targets)
            detections, _ = self.detector.roi_heads(fmap_multiscale, proposals,
                                                    images.image_sizes,
                                                    targets)
            boxes = copy.deepcopy(detections)
            boxes_all_dict = self.detector.transform.postprocess(
                detections, images.image_sizes, original_image_sizes)
            rm_box_priors, rm_box_priors_org, im_inds, obj_labels = [], [], [], []
            for i in range(len(proposals)):
                if len(boxes[i]['boxes']) <= 1:
                    raise ValueError(
                        'at least two objects must be detected to build relationships, make sure the detector is properly pretrained',
                        boxes)
                rm_box_priors.append(boxes[i]['boxes'])
                rm_box_priors_org.append(boxes_all_dict[i]['boxes'])
                obj_labels.append(boxes_all_dict[i]['labels'])
                im_inds.append(torch.zeros(len(detections[i]['boxes'])) + i)

            im_inds = torch.cat(im_inds).to(device)
            result = Result(rm_obj_labels=torch.cat(obj_labels).view(-1),
                            rm_box_priors=torch.cat(rm_box_priors),
                            rel_labels=None,
                            im_inds=im_inds.long())
            result.rm_box_priors_org = torch.cat(rm_box_priors_org)

            if len(result.rm_box_priors) <= 1:
                raise ValueError(
                    'at least two objects must be detected to build relationships'
                )

        result.im_sizes_org = original_image_sizes
        result.im_sizes = images.image_sizes
        # keep the last (coarsest) scale as the global feature map
        result.fmap = fmap_multiscale[list(fmap_multiscale.keys())[-1]]
        result.rois = torch.cat(
            (im_inds.float()[:, None], result.rm_box_priors), 1)

        return result

    def node_edge_features(self, fmap, rois, union_inds, im_sizes):

        assert union_inds.shape[1] == 2, union_inds.shape
        # union box of each (subject, object) pair: same image index, min of the
        # top-left corners, max of the bottom-right corners
        union_rois = torch.cat((rois[:, 0][union_inds[:, 0]][:, None],
                                torch.min(rois[:, 1:3][union_inds[:, 0]],
                                          rois[:, 1:3][union_inds[:, 1]]),
                                torch.max(rois[:, 3:5][union_inds[:, 0]],
                                          rois[:, 3:5][union_inds[:, 1]])), 1)

        if self.backbone == 'vgg16_old':
            return get_node_edge_features(fmap,
                                          rois,
                                          union_rois=union_rois,
                                          pooling_size=self.pool_sz,
                                          stride=self.stride)
        else:
            if isinstance(fmap, torch.Tensor):
                fmap = OrderedDict([('0', fmap)])
            node_feat = self.roi_pool(fmap, convert_roi_to_list(rois),
                                      im_sizes)  # images.image_sizes
            edge_feat = self.roi_pool(fmap, convert_roi_to_list(union_rois),
                                      im_sizes)
            return node_feat, edge_feat

    def get_scaled_boxes(self, boxes, im_inds, im_sizes):
        if self.backbone == 'vgg16_old':
            boxes_scaled = boxes / IM_SCALE
        else:
            boxes_scaled = boxes.clone()
            for im_ind, s, e in enumerate_by_image(im_inds.long().data):
                h, w = im_sizes[im_ind]
                boxes_scaled[s:e, [0, 2]] = boxes_scaled[s:e, [0, 2]] / w  # x by width
                boxes_scaled[s:e, [1, 3]] = boxes_scaled[s:e, [1, 3]] / h  # y by height

        assert boxes_scaled.max() <= 1 + 1e-3, (boxes_scaled.max(),
                                                boxes.max(), im_sizes)

        return boxes_scaled

    def gt_labels(self, gt_boxes, gt_classes, gt_rels=None, sample_factor=-1):
        """
        Build RoIs and labels directly from the ground-truth boxes.
        :param gt_boxes: (num_boxes, 4) ground-truth boxes
        :param gt_classes: (num_boxes, 2) tensor of (image index, class label)
        :param gt_rels: ground-truth relationship triples, used only in training
        :param sample_factor: forwarded to proposal_assignments_gtbox
        :return: rois, obj_labels, rel_labels
        """
        assert gt_boxes is not None
        im_inds = gt_classes[:, 0]
        rois = torch.cat((im_inds.float()[:, None], gt_boxes), 1)
        if gt_rels is not None and self.training:
            rois, obj_labels, rel_labels = proposal_assignments_gtbox(
                rois.data,
                gt_boxes.data,
                gt_classes.data,
                gt_rels.data,
                0,
                self.RELS_PER_IMG,
                sample_factor=sample_factor)
        else:
            obj_labels = gt_classes[:, 1]
            rel_labels = None

        return rois, obj_labels, rel_labels
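
A small standalone sketch (not part of the original model) of the two index manipulations used above: building all (subject, object) candidate pairs within an image as in get_rel_inds, and forming union boxes as in node_edge_features. The tensors and the union_box helper are made up for illustration.

# --- illustration sketch ---
import torch

# candidate pairs: boxes from the same image, excluding self-pairs (test-time path)
im_inds = torch.tensor([0, 0, 0, 1, 1])
same_im = im_inds[:, None] == im_inds[None]
same_im.fill_diagonal_(False)
pairs = same_im.nonzero()                                        # (num_pairs, 2)
rel_inds = torch.cat((im_inds[pairs[:, 0]][:, None], pairs), 1)  # (im_ind, subj, obj)

# union box of a pair: same image index, min of top-left, max of bottom-right
def union_box(rois, union_inds):
    subj, obj = rois[union_inds[:, 0]], rois[union_inds[:, 1]]
    return torch.cat((subj[:, :1],
                      torch.min(subj[:, 1:3], obj[:, 1:3]),
                      torch.max(subj[:, 3:5], obj[:, 3:5])), 1)

rois = torch.tensor([[0., 10, 10, 50, 60],
                     [0., 40, 20, 90, 80]])
print(union_box(rois, torch.tensor([[0, 1]])))  # tensor([[ 0., 10., 10., 90., 80.]])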