Code example #1
File: fast_rcnn.py  Project: coldmanck/RVL-BERT
    def forward(self, images, boxes, box_mask, im_info, classes=None, segms=None, mvrc_ops=None, mask_visual_embed=None, copy_images=False, union_boxes=None, rels_cand=None):
        """
        :param images: [batch_size, 3, im_height, im_width]
        :param boxes: [batch_size, max_num_objects, 4] Padded boxes
        :param box_mask: [batch_size, max_num_objects] Mask for whether or not each box is OK
        :return: object reps [batch_size, max_num_objects, dim]
        """

        box_inds = box_mask.nonzero()
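        # box_inds: [n_valid, 2] rows of (batch_idx, box_idx), one per unmasked box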
        obj_labels = classes[box_inds[:, 0], box_inds[:, 1]].type(torch.long) if classes is not None else None
        assert box_inds.shape[0] > 0

        if copy_images:
            # Run the backbone on the first image only...
            images = images[0].unsqueeze(0)
        img_feats = self.backbone(images)
        if copy_images:
            # ...then replicate its feature map once per valid box.
            img_feats['body4'] = torch.cat([img_feats['body4']] * box_inds.shape[0])
        
        if union_boxes is not None:  # relation mode: pair objects after RoI pooling
            # n_boxes counts the leading full-image bbox as well
            n_boxes, n_union_boxes = len(boxes), len(union_boxes)
            rois = torch.cat((boxes, union_boxes))
            # Prepend a batch-index column of zeros (all RoIs reference image 0).
            rois = torch.cat((
                torch.zeros_like(rois)[:, 0].view(-1, 1),
                rois
            ), 1)
        else:
            # Prepend each box's batch index so RoIAlign can address its image.
            rois = torch.cat((
                box_inds[:, 0, None].type(boxes.dtype),
                boxes[box_inds[:, 0], box_inds[:, 1]],
            ), 1)
        
        roi_align_res = self.roi_align(img_feats['body4'], rois).type(images.dtype)
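        # roi_align_res: one pooled feature map per RoI from the 'body4' backbone stage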

        if segms is not None:
            pool_layers = self.head[1:]
            post_roialign = self.roi_head_feature_extractor(roi_align_res)
            post_roialign = post_roialign * segms[box_inds[:, 0], None, box_inds[:, 1]].to(dtype=post_roialign.dtype)
            for _layer in pool_layers:
                post_roialign = _layer(post_roialign)
        else:
            if self.config.TRAIN.DEBUG and union_boxes is not None:
                print(f'cur max_nb_boxes: {self.max_nb_boxes}, n_boxes: {n_boxes}, n_union_boxes: {n_union_boxes}, rois.shape: {rois.shape}')
                if self.max_nb_boxes < n_boxes:
                    print(f'Update self.max_nb_boxes from {self.max_nb_boxes} to {n_boxes}')
                    self.max_nb_boxes = n_boxes
            # Run the head in two stages so the pre-pooling feature map
            # (post_roialign_raw) is kept alongside the pooled features.
            post_roialign = self.head[0](roi_align_res)
            post_roialign_raw = post_roialign.clone().detach()
            post_roialign = self.head[1:](post_roialign)

        if union_boxes is not None:
            # Example shapes observed while debugging:
            #   post_roialign:     [81, 2048]
            #   post_roialign_raw: [81, 2048, 14, 14]
            full_img_feat = post_roialign[0]
            post_roialign_boxes = post_roialign[1:n_boxes]
            post_roialign_union_boxes = post_roialign[n_boxes:]

            full_img_feat_raw = post_roialign_raw[0]
            post_roialign_raw_boxes = post_roialign_raw[1:n_boxes]
            post_roialign_raw_union_boxes = post_roialign_raw[n_boxes:]
            n_boxes -= 1  # exclude the leading full-image bbox from the count
            
            device = boxes.device
            # Preallocate grouped outputs: 4 slots per candidate relation.
            new_boxes = torch.zeros([rels_cand.shape[0], 4, 4], device=device)
            post_roialign = torch.zeros((box_inds.shape[0], 2048), device=device)
            post_roialign_raw = torch.zeros((box_inds.shape[0], 2048, 14, 14), device=device)
            
            for i, (sub_id, obj_id) in enumerate(rels_cand):
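                # Fill the 4 slots for this candidate: (full image, subject, union, object).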
                post_roialign[i*4] = full_img_feat
                post_roialign[i*4 + 1] = post_roialign_boxes[sub_id]
                post_roialign[i*4 + 2] = post_roialign_union_boxes[i]
                post_roialign[i*4 + 3] = post_roialign_boxes[obj_id]

                post_roialign_raw[i*4] = full_img_feat_raw
                post_roialign_raw[i*4 + 1] = post_roialign_raw_boxes[sub_id]
                post_roialign_raw[i*4 + 2] = post_roialign_raw_union_boxes[i]
                post_roialign_raw[i*4 + 3] = post_roialign_raw_boxes[obj_id]

                new_boxes[i][0] = boxes[0]
                new_boxes[i][1] = boxes[sub_id + 1]
                new_boxes[i][2] = union_boxes[i]
                new_boxes[i][3] = boxes[obj_id + 1]
            
            boxes = new_boxes

        # Example shapes observed while debugging (relation mode, 32 candidates):
        #   boxes:             [32, 4, 4]
        #   post_roialign:     [128, 2048]
        #   post_roialign_raw: [128, 2048, 14, 14]

        # Add some regularization, encouraging the model to keep giving decent enough predictions
        if self.enable_cnn_reg_loss:  # False in these experiments
            obj_logits = self.regularizing_predictor(post_roialign)
            cnn_regularization = F.cross_entropy(obj_logits, obj_labels)[None]

        feats_to_downsample = post_roialign if (self.object_embed is None or obj_labels is None) else \
            torch.cat((post_roialign, self.object_embed(obj_labels)), -1)
        if mvrc_ops is not None and mask_visual_embed is not None:  # not taken in these experiments
            _to_masked = (mvrc_ops == 1)[box_inds[:, 0], box_inds[:, 1]]
            feats_to_downsample[_to_masked] = mask_visual_embed
        # Embed box coordinates together with the image size, then concatenate
        # with the visual features before downsampling.
        coord_embed = coordinate_embeddings(
            torch.cat((boxes[box_inds[:, 0], box_inds[:, 1]], im_info[box_inds[:, 0], :2]), 1),
            256
        )
        
        feats_to_downsample = torch.cat((coord_embed.view((coord_embed.shape[0], -1)), feats_to_downsample), -1)
        final_feats = self.obj_downsample(feats_to_downsample)

        # Reshape into a padded sequence - this is expensive and annoying but easier to implement and debug...
        if union_boxes is None:
            obj_reps = pad_sequence(final_feats, box_mask.sum(1).tolist())
            post_roialign = pad_sequence(post_roialign, box_mask.sum(1).tolist())
            post_roialign_raw = pad_sequence(post_roialign_raw, box_mask.sum(1).tolist())
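            # pad_sequence here takes (flat_rows, per-image counts) and zero-pads
            # them back into a [batch, max_boxes, ...] tensor (repo-local helper,
            # not torch.nn.utils.rnn.pad_sequence).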
        else:
            obj_reps = final_feats.view(-1, 4, final_feats.shape[1])
            post_roialign = post_roialign.view(-1, 4, post_roialign.shape[1])
            post_roialign_raw = post_roialign_raw.view(-1, 4, post_roialign_raw.shape[1])

        # DataParallel compatibility
        if union_boxes is None: # off for VG exps
            obj_reps_padded = obj_reps.new_zeros((obj_reps.shape[0], boxes.shape[1], obj_reps.shape[2]))
            obj_reps_padded[:, :obj_reps.shape[1]] = obj_reps
            obj_reps = obj_reps_padded

            post_roialign_padded = post_roialign.new_zeros((post_roialign.shape[0], boxes.shape[1], post_roialign.shape[2]))
            post_roialign_padded[:, :post_roialign.shape[1]] = post_roialign
            post_roialign = post_roialign_padded

            post_roialign_raw_padded = post_roialign_raw.new_zeros((post_roialign_raw.shape[0], boxes.shape[1], post_roialign_raw.shape[2], post_roialign_raw.shape[3], post_roialign_raw.shape[4]))
            post_roialign_raw_padded[:, :post_roialign_raw.shape[1]] = post_roialign_raw
            post_roialign_raw = post_roialign_raw_padded

        # Output
        output_dict = {
            'obj_reps_raw': post_roialign,
            'obj_reps': obj_reps,
            'obj_reps_rawraw': post_roialign_raw,
        }

        if (not self.image_feat_precomputed) and self.enable_cnn_reg_loss:
            output_dict.update({'obj_logits': obj_logits,
                                'obj_labels': obj_labels,
                                'cnn_regularization_loss': cnn_regularization})

        if (not self.image_feat_precomputed) and self.output_conv5:
            image_feature = self.img_head(img_feats['body4'])
            output_dict['image_feature'] = image_feature

        if union_boxes is not None:
            return output_dict, boxes
        else:
            return output_dict
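
A minimal usage sketch (not from the source file): it assumes `model` is an already-constructed instance of this module, and all shapes follow the docstring above.

import torch

# Hypothetical inputs -- shapes follow the forward() docstring.
batch_size, max_num_objects = 1, 5
images = torch.randn(batch_size, 3, 600, 800)                 # [B, 3, H, W]
xy = torch.rand(batch_size, max_num_objects, 2) * 300
wh = torch.rand(batch_size, max_num_objects, 2) * 200 + 20
boxes = torch.cat((xy, xy + wh), dim=-1)                      # valid [x1, y1, x2, y2]
box_mask = torch.zeros(batch_size, max_num_objects, dtype=torch.bool)
box_mask[:, :4] = True                                        # first 4 boxes are valid
im_info = torch.tensor([[600., 800., 1.]])                    # assumed (height, width, scale)

# Standard mode: padded per-box representations.
output_dict = model(images, boxes, box_mask, im_info)
obj_reps = output_dict['obj_reps']                            # [B, max_num_objects, dim]

# Relation mode (union_boxes + rels_cand) instead returns (output_dict, boxes),
# with obj_reps grouped as (full image, subject, union, object) per candidate.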
Code example #2
    def forward(self,
                images,
                boxes,
                box_mask,
                im_info,
                classes=None,
                segms=None,
                mvrc_ops=None,
                mask_visual_embed=None):
        """
        :param images: [batch_size, 3, im_height, im_width]
        :param boxes: [batch_size, max_num_objects, 4] Padded boxes
        :param box_mask: [batch_size, max_num_objects] Mask for whether or not each box is OK
        :return: object reps [batch_size, max_num_objects, dim]
        """

        box_inds = box_mask.nonzero()
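        # box_inds: [n_valid, 2] rows of (batch_idx, box_idx), one per unmasked box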
        obj_labels = classes[box_inds[:, 0], box_inds[:, 1]].type(
            torch.long) if classes is not None else None
        assert box_inds.shape[0] > 0

        if self.image_feat_precomputed:
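            # Precomputed mode: each padded box row is [x1, y1, x2, y2, feature...];
            # take the feature part directly and keep only the coordinates.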
            post_roialign = boxes[box_inds[:, 0], box_inds[:, 1]][:, 4:]
            boxes = boxes[:, :, :4]
        else:
            img_feats = self.backbone(images)
            rois = torch.cat([
                box_inds[:, 0, None].type(boxes.dtype),
                boxes[box_inds[:, 0], box_inds[:, 1]],
            ], 1)
            roi_align_res = self.roi_align(img_feats['body4'],
                                           rois).type(images.dtype)
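            # roi_align_res: one pooled feature map per RoI from the 'body4' stage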

            if segms is not None:
                pool_layers = self.head[1:]
                post_roialign = self.roi_head_feature_extractor(roi_align_res)
                post_roialign = post_roialign * segms[
                    box_inds[:, 0], None,
                    box_inds[:, 1]].to(dtype=post_roialign.dtype)
                for _layer in pool_layers:
                    post_roialign = _layer(post_roialign)
            else:
                post_roialign = self.head(roi_align_res)

            # Add some regularization, encouraging the model to keep giving decent enough predictions
            if self.enable_cnn_reg_loss:
                obj_logits = self.regularizing_predictor(post_roialign)
                cnn_regularization = F.cross_entropy(obj_logits,
                                                     obj_labels)[None]
        feats_to_downsample = post_roialign if (self.object_embed is None or obj_labels is None) else \
            torch.cat((post_roialign, self.object_embed(obj_labels)), -1)
        if mvrc_ops is not None and mask_visual_embed is not None:
            # Replace the features of masked regions (mvrc op == 1) with the mask embedding.
            _to_masked = (mvrc_ops == 1)[box_inds[:, 0], box_inds[:, 1]]
            feats_to_downsample[_to_masked] = mask_visual_embed
        coord_embed = coordinate_embeddings(
            torch.cat((boxes[box_inds[:, 0],
                             box_inds[:, 1]], im_info[box_inds[:, 0], :2]), 1),
            256)
        feats_to_downsample = torch.cat((coord_embed.view(
            (coord_embed.shape[0], -1)), feats_to_downsample), -1)
        final_feats = self.obj_downsample(feats_to_downsample)

        # Reshape into a padded sequence - this is expensive and annoying but easier to implement and debug...
        obj_reps = pad_sequence(final_feats, box_mask.sum(1).tolist())
        post_roialign = pad_sequence(post_roialign, box_mask.sum(1).tolist())

        # DataParallel compatibility
        obj_reps_padded = obj_reps.new_zeros(
            (obj_reps.shape[0], boxes.shape[1], obj_reps.shape[2]))
        obj_reps_padded[:, :obj_reps.shape[1]] = obj_reps
        obj_reps = obj_reps_padded
        post_roialign_padded = post_roialign.new_zeros(
            (post_roialign.shape[0], boxes.shape[1], post_roialign.shape[2]))
        post_roialign_padded[:, :post_roialign.shape[1]] = post_roialign
        post_roialign = post_roialign_padded

        # Output
        output_dict = {
            'obj_reps_raw': post_roialign,
            'obj_reps': obj_reps,
        }
        if (not self.image_feat_precomputed) and self.enable_cnn_reg_loss:
            output_dict.update({
                'obj_logits': obj_logits,
                'obj_labels': obj_labels,
                'cnn_regularization_loss': cnn_regularization
            })

        if (not self.image_feat_precomputed) and self.output_conv5:
            image_feature = self.img_head(img_feats['body4'])
            output_dict['image_feature'] = image_feature

        return output_dict
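
A similar hedged sketch for this simpler variant (again assuming `model` is a constructed instance; when self.image_feat_precomputed is set, each box row would instead carry [x1, y1, x2, y2, feature...] and the backbone is skipped).

import torch

# Hypothetical inputs -- shapes follow the forward() docstring.
batch_size, max_num_objects = 2, 10
images = torch.randn(batch_size, 3, 600, 800)
xy = torch.rand(batch_size, max_num_objects, 2) * 300
wh = torch.rand(batch_size, max_num_objects, 2) * 200 + 20
boxes = torch.cat((xy, xy + wh), dim=-1)       # valid [x1, y1, x2, y2]
box_mask = torch.zeros(batch_size, max_num_objects, dtype=torch.bool)
box_mask[:, :6] = True                         # first 6 boxes per image are valid
im_info = torch.tensor([[600., 800., 1.],
                        [600., 800., 1.]])     # assumed (height, width, scale)

output_dict = model(images, boxes, box_mask, im_info)
obj_reps = output_dict['obj_reps']             # [B, max_num_objects, dim]
obj_reps_raw = output_dict['obj_reps_raw']     # pooled RoI features, same padding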