def extract_feats(self, img, img_metas):
    """Extract features for `img` during testing.

    Args:
        img (Tensor): of shape (1, C, H, W) encoding input image.
            Typically these should be mean centered and std scaled.
        img_metas (list[dict]): list of image information dict where each
            dict has: 'img_shape', 'scale_factor', 'flip', and may also
            contain 'filename', 'ori_shape', 'pad_shape', and
            'img_norm_cfg'. For details on the values of these keys see
            `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.

    Returns:
        list[Tensor]: Multi level feature maps of `img`.
    """
    key_frame_interval = self.test_cfg.get('key_frame_interval', 10)
    frame_id = img_metas[0].get('frame_id', -1)
    assert frame_id >= 0
    is_key_frame = frame_id % key_frame_interval == 0

    if is_key_frame:
        # Key frame: run the detector backbone and cache both the image
        # and its multi-level features for the following non-key frames.
        self.memo = Dict()
        self.memo.img = img
        x = self.detector.extract_feat(img)
        self.memo.feats = x
    else:
        # Non-key frame: estimate the flow between the current frame and
        # the cached key frame, then warp the cached key-frame features
        # instead of running the backbone again.
        flow_img = torch.cat((img, self.memo.img), dim=1)
        flow = self.motion(flow_img, img_metas)
        x = []
        for i in range(len(self.memo.feats)):
            x_single = flow_warp_feats(self.memo.feats[i], flow)
            x.append(x_single)
    return x
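
# The flow-guided warping above is delegated to `flow_warp_feats`. Below is a
# minimal, hedged sketch of how such warping can be implemented, assuming the
# flow is predicted at image resolution with shape (N, 2, H_img, W_img) and is
# applied to one pyramid level of shape (N, C, H_feat, W_feat); the actual
# `flow_warp_feats` in mmtrack may differ in rescaling, padding and other
# details.
import torch
import torch.nn.functional as F


def flow_warp_feats_sketch(ref_feats, flow):
    """Warp `ref_feats` with `flow` via bilinear sampling (illustrative only)."""
    n, _, h, w = ref_feats.shape
    # Resize the flow to the feature resolution and rescale its magnitude by
    # the same spatial factors.
    scale_w = w / flow.shape[3]
    scale_h = h / flow.shape[2]
    flow = F.interpolate(flow, size=(h, w), mode='bilinear', align_corners=False)
    flow = torch.stack((flow[:, 0] * scale_w, flow[:, 1] * scale_h), dim=1)

    # Base sampling grid in (x, y) pixel coordinates, shape (1, 2, h, w).
    xs = torch.arange(w, device=ref_feats.device).view(1, 1, 1, w).expand(1, 1, h, w)
    ys = torch.arange(h, device=ref_feats.device).view(1, 1, h, 1).expand(1, 1, h, w)
    grid = torch.cat((xs, ys), dim=1).float()

    # Displace the grid by the flow and normalize to [-1, 1] for grid_sample.
    grid = grid + flow
    grid_x = 2.0 * grid[:, 0] / max(w - 1, 1) - 1.0
    grid_y = 2.0 * grid[:, 1] / max(h - 1, 1) - 1.0
    grid = torch.stack((grid_x, grid_y), dim=3)  # (n, h, w, 2)
    return F.grid_sample(ref_feats, grid, mode='bilinear', align_corners=True)

# Usage (illustrative): warped = flow_warp_feats_sketch(key_feats, flow),
# applied per pyramid level as in the loops above and below.
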
def forward_train(self,
                  img,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  ref_img,
                  ref_img_metas,
                  ref_gt_bboxes,
                  ref_gt_labels,
                  gt_instance_ids=None,
                  gt_bboxes_ignore=None,
                  gt_masks=None,
                  proposals=None,
                  ref_gt_instance_ids=None,
                  ref_gt_bboxes_ignore=None,
                  ref_gt_masks=None,
                  ref_proposals=None,
                  **kwargs):
    """
    Args:
        img (Tensor): of shape (N, C, H, W) encoding input images.
            Typically these should be mean centered and std scaled.
        img_metas (list[dict]): list of image info dict where each dict
            has: 'img_shape', 'scale_factor', 'flip', and may also contain
            'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For
            details on the values of these keys see
            `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.
        gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
            shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (list[Tensor]): class indices corresponding to each box.
        ref_img (Tensor): of shape (N, 1, C, H, W) encoding input images.
            Typically these should be mean centered and std scaled. 1
            denotes there is only one reference image for each input
            image.
        ref_img_metas (list[list[dict]]): The first list only has one
            element. The second list contains reference image information
            dict where each dict has: 'img_shape', 'scale_factor', 'flip',
            and may also contain 'filename', 'ori_shape', 'pad_shape', and
            'img_norm_cfg'. For details on the values of these keys see
            `mmtrack/datasets/pipelines/formatting.py:VideoCollect`.
        ref_gt_bboxes (list[Tensor]): The list only has one Tensor. The
            Tensor contains ground truth bboxes for each reference image
            with shape (num_all_ref_gts, 5) in [ref_img_id, tl_x, tl_y,
            br_x, br_y] format. The ref_img_id starts from 0 and denotes
            the id of the reference image for each key image.
        ref_gt_labels (list[Tensor]): The list only has one Tensor. The
            Tensor contains class indices corresponding to each reference
            box with shape (num_all_ref_gts, 2) in
            [ref_img_id, class_index] format.
        gt_instance_ids (None | list[Tensor]): specify the instance id for
            each ground truth bbox.
        gt_bboxes_ignore (None | list[Tensor]): specify which bounding
            boxes can be ignored when computing the loss.
        gt_masks (None | Tensor): true segmentation masks for each box
            used if the architecture supports a segmentation task.
        proposals (None | Tensor): override rpn proposals with custom
            proposals. Use when `with_rpn` is False.
        ref_gt_instance_ids (None | list[Tensor]): specify the instance id
            for each ground truth bbox of reference images.
        ref_gt_bboxes_ignore (None | list[Tensor]): specify which bounding
            boxes of reference images can be ignored when computing the
            loss.
        ref_gt_masks (None | Tensor): true segmentation masks for each box
            of reference images used if the architecture supports a
            segmentation task.
        ref_proposals (None | Tensor): override rpn proposals with custom
            proposals of reference images. Use when `with_rpn` is False.

    Returns:
        dict[str, Tensor]: a dictionary of loss components
    """
    assert len(img) == 1, \
        'DFF video detectors only support 1 batch size per GPU for now.'
    is_video_data = img_metas[0]['is_video_data']

    # Predict the flow between the key image and its reference image, then
    # warp the reference features to approximate the key-frame features.
    flow_img = torch.cat((img, ref_img[:, 0]), dim=1)
    flow = self.motion(flow_img, img_metas)
    ref_x = self.detector.extract_feat(ref_img[:, 0])
    x = []
    for i in range(len(ref_x)):
        x_single = flow_warp_feats(ref_x[i], flow)
        if not is_video_data:
            # Still-image data has no real motion, so fall back to the
            # unwarped reference features; the `0 * x_single` term keeps
            # the flow branch connected to the computation graph.
            x_single = 0 * x_single + ref_x[i]
        x.append(x_single)

    losses = dict()

    # Two stage detector
    if hasattr(self.detector, 'roi_head'):
        # RPN forward and loss
        if self.detector.with_rpn:
            proposal_cfg = self.detector.train_cfg.get(
                'rpn_proposal', self.detector.test_cfg.rpn)
            rpn_losses, proposal_list = \
                self.detector.rpn_head.forward_train(
                    x,
                    img_metas,
                    gt_bboxes,
                    gt_labels=None,
                    gt_bboxes_ignore=gt_bboxes_ignore,
                    proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = proposals

        roi_losses = self.detector.roi_head.forward_train(
            x, img_metas, proposal_list, gt_bboxes, gt_labels,
            gt_bboxes_ignore, gt_masks, **kwargs)
        losses.update(roi_losses)
    # Single stage detector
    elif hasattr(self.detector, 'bbox_head'):
        bbox_losses = self.detector.bbox_head.forward_train(
            x, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore)
        losses.update(bbox_losses)
    else:
        raise TypeError('detector must have a roi_head or bbox_head.')

    return losses
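
# A hedged, illustrative example of the argument shapes `forward_train`
# expects (dummy data only; `dff_model` is a placeholder for an already-built
# DFF model, and the call is commented out because none of this comes from a
# repository test):
import torch

img = torch.randn(1, 3, 256, 256)         # key image, (N, C, H, W) with N == 1
ref_img = torch.randn(1, 1, 3, 256, 256)  # one reference image per key image
img_metas = [
    dict(img_shape=(256, 256, 3), scale_factor=1.0, flip=False,
         is_video_data=True)
]
ref_img_metas = [[dict(img_shape=(256, 256, 3), scale_factor=1.0, flip=False)]]
gt_bboxes = [torch.tensor([[10., 10., 50., 60.]])]  # (num_gts, 4)
gt_labels = [torch.tensor([2])]
# [ref_img_id, tl_x, tl_y, br_x, br_y] and [ref_img_id, class_index]
ref_gt_bboxes = [torch.tensor([[0., 12., 11., 52., 61.]])]
ref_gt_labels = [torch.tensor([[0, 2]])]

# losses = dff_model.forward_train(img, img_metas, gt_bboxes, gt_labels,
#                                  ref_img, ref_img_metas, ref_gt_bboxes,
#                                  ref_gt_labels)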