Example #1
    def _obj_det_loss_comp(self, cls_score, cls_prob, bbox_pred, rois_label,
                           rois_target, rois_inside_ws, rois_outside_ws):
        # classification loss
        if cfg.TRAIN.COMMON.USE_FOCAL_LOSS:
            RCNN_loss_cls = F.cross_entropy(cls_score,
                                            rois_label,
                                            reduce=False)
            focal_loss_factor = torch.pow(
                (1 - cls_prob[range(int(cls_prob.size(0))), rois_label]),
                cfg.TRAIN.COMMON.FOCAL_LOSS_GAMMA)
            RCNN_loss_cls = torch.mean(RCNN_loss_cls * focal_loss_factor)
        else:
            RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)

        if not self.class_agnostic:
            # select the corresponding columns according to roi labels
            bbox_pred_view = bbox_pred.view(bbox_pred.size(0),
                                            int(bbox_pred.size(1) / 4), 4)
            bbox_pred_select = torch.gather(
                bbox_pred_view, 1,
                rois_label.view(rois_label.size(0), 1,
                                1).expand(rois_label.size(0), 1, 4))
            bbox_pred = bbox_pred_select.squeeze(1)
        # bounding box regression L1 loss
        RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target,
                                         rois_inside_ws, rois_outside_ws)
        return RCNN_loss_cls, RCNN_loss_bbox
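
A note on the focal-loss branch above: it rescales the per-RoI cross-entropy by (1 - p_t)^gamma, where p_t is the probability the classifier assigns to the true class. A minimal self-contained sketch of the same weighting in current PyTorch (the function name and `gamma` default are illustrative, not from the config above):

    import torch
    import torch.nn.functional as F

    def focal_weighted_ce(cls_score, labels, gamma=2.0):
        # per-sample cross-entropy, no reduction
        ce = F.cross_entropy(cls_score, labels, reduction='none')
        # probability assigned to the true class of each sample
        p_t = F.softmax(cls_score, dim=1)[torch.arange(cls_score.size(0)), labels]
        # down-weight easy examples, then average
        return torch.mean(ce * (1.0 - p_t) ** gamma)
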
Example #2
    def forward(self, base_feat, im_info, gt_boxes):
        batch_size = base_feat.size(0)

        # feature map after the conv + relu layer
        rpn_conv1 = F.relu(self.RPN_Conv(base_feat), inplace=True)
        # get rpn classification score
        rpn_cls_score = self.RPN_cls_score(rpn_conv1)

        rpn_cls_score_reshape = self.reshape(rpn_cls_score, 2)
        rpn_cls_prob_reshape = F.softmax(rpn_cls_score_reshape, 1)
        rpn_cls_prob = self.reshape(rpn_cls_prob_reshape, self.nc_score_out)

        # get rpn offsets to the anchor boxes
        rpn_bbox_pred = self.RPN_bbox_pred(rpn_conv1)

        rois = self.RPN_proposal(
            (rpn_cls_prob.data, rpn_bbox_pred.data, im_info, self.is_training))
        self.rpn_loss_cls = 0
        self.rpn_loss_box = 0

        if self.training:
            assert gt_boxes is not None

            rpn_data = self.RPN_anchor_target(
                (rpn_cls_score.data, gt_boxes, im_info))

            # compute classification loss
            rpn_cls_score = rpn_cls_score_reshape.permute(
                0, 2, 3, 1).contiguous().view(batch_size, -1, 2)
            rpn_label = rpn_data[0].view(batch_size, -1)

            rpn_keep = Variable(rpn_label.view(-1).ne(-1).nonzero().view(-1))
            rpn_cls_score = torch.index_select(rpn_cls_score.view(-1, 2), 0,
                                               rpn_keep)
            rpn_label = torch.index_select(rpn_label.view(-1), 0,
                                           rpn_keep.data)
            rpn_label = Variable(rpn_label.long())
            self.rpn_loss_cls = F.cross_entropy(rpn_cls_score, rpn_label)
            #fg_cnt = torch.sum(rpn_label.data.ne(0))

            rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = \
                rpn_data[1:]

            rpn_bbox_inside_weights = Variable(rpn_bbox_inside_weights)
            rpn_bbox_outside_weights = Variable(rpn_bbox_outside_weights)
            rpn_bbox_targets = Variable(rpn_bbox_targets)

            self.rpn_loss_box = _smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets,
                                                rpn_bbox_inside_weights,
                                                rpn_bbox_outside_weights,
                                                sigma=3, dim=[1, 2, 3])

        return rois, self.rpn_loss_cls, self.rpn_loss_box
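
The training branch above flattens the two-channel RPN scores and drops anchors labeled -1 ("don't care") before computing cross-entropy. A minimal sketch of that filtering step in isolation (function name and shapes assumed for illustration):

    import torch
    import torch.nn.functional as F

    def rpn_cls_loss(scores, labels):
        # scores: (N, 2) fg/bg logits per anchor; labels: (N,) in {-1, 0, 1},
        # where -1 marks anchors excluded from the loss
        keep = labels.ne(-1).nonzero().view(-1)
        scores = scores.index_select(0, keep)
        labels = labels.index_select(0, keep).long()
        return F.cross_entropy(scores, labels)
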
Example #3
    def forward_rcnn_batch(self, base_feat, branch, rois, wgt_boxes, wnum_boxes, gt_boxes, num_boxes, im_info, image_classes, output_refine=False):
        batch_size = base_feat.size(0)

        # if in the training phase, use ground-truth bboxes for refining
        if self.training:
            roi_data = self.RCNN_proposal_target(
                rois, wgt_boxes, wnum_boxes, gt_boxes, num_boxes)
            out_rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data

            rois_label = Variable(rois_label.view(-1).long())
            rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
            rois_inside_ws = Variable(
                rois_inside_ws.view(-1, rois_inside_ws.size(2)))
            rois_outside_ws = Variable(
                rois_outside_ws.view(-1, rois_outside_ws.size(2)))
        else:
            out_rois = rois
            rois_label = None
            rois_target = None
            rois_inside_ws = None
            rois_outside_ws = None

        out_rois = Variable(out_rois)

        # do roi pooling based on predicted rois
        if cfg.POOLING_MODE == 'crop':
            # pdb.set_trace()
            # pooled_feat_anchor = _crop_pool_layer(base_feat, rois.view(-1, 5))
            grid_xy = _affine_grid_gen(
                out_rois.view(-1, 5), base_feat.size()[2:], self.grid_size)
            grid_yx = torch.stack(
                [grid_xy.data[:, :, :, 1], grid_xy.data[:, :, :, 0]], 3).contiguous()
            pooled_feat = self.RCNN_roi_crop(
                base_feat, Variable(grid_yx).detach())
            if cfg.CROP_RESIZE_WITH_MAX_POOL:
                pooled_feat = F.max_pool2d(pooled_feat, 2, 2)
        elif cfg.POOLING_MODE == 'align':
            pooled_feat = self.RCNN_roi_align(base_feat, out_rois.view(-1, 5))
        elif cfg.POOLING_MODE == 'pool':
            pooled_feat = self.RCNN_roi_pool(base_feat, out_rois.view(-1, 5))

        # feed pooled features to top model
        pooled_feat = self._head_to_tail(pooled_feat, branch)

        # compute bbox offset
        bbox_pred = branch.RCNN_bbox_pred(pooled_feat)
        if self.training and not self.class_agnostic:
            # select the corresponding columns according to roi labels
            bbox_pred_view = bbox_pred.view(
                bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4)
            bbox_pred_select = torch.gather(bbox_pred_view, 1, rois_label.view(
                rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4))
            bbox_pred = bbox_pred_select.squeeze(1)

        # compute object classification probability
        cls_score = branch.RCNN_cls_score(pooled_feat)
        cls_prob = F.softmax(cls_score, 1)

        RCNN_loss_cls = 0
        RCNN_loss_bbox = 0

        if self.training:

            # classification loss
            RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)

            # bounding box regression L1 loss
            RCNN_loss_bbox = _smooth_l1_loss(
                bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)

            # add image-level label regularization
            rois_batch_size = out_rois.size(1)
            rois_prob = F.softmax(cls_score, 1).view(batch_size, rois_batch_size, -1)

            valid_rois_prob = (rois_label > 0).view(batch_size, rois_batch_size, -1).float()
            rois_attention = F.softmax(cls_score, 1).view(batch_size, rois_batch_size, -1)
            rois_attention = rois_attention * valid_rois_prob

            # ignore background
            rois_prob = rois_prob[:, :, 1:]
            rois_attention = rois_attention[:, :, 1:]

            # rois_attention_prob = torch.sum(rois_prob * rois_attention, dim=1) / (torch.sum(rois_attention, dim=1) + 1e-10)
            rois_attention_prob, _ = torch.max(rois_prob, dim=1)
            image_loss_cls = F.binary_cross_entropy(rois_attention_prob, image_classes[:, 1:])
        else:
            image_loss_cls = None

        if self.training:
            cls_prob = cls_prob.view(batch_size, out_rois.size(1), -1)
            bbox_pred = bbox_pred.view(batch_size, out_rois.size(1), -1)
        else:
            cls_prob = cls_prob.view(1, out_rois.size(1), -1)
            bbox_pred = bbox_pred.view(1, out_rois.size(1), -1)

        if self.training and output_refine:
            # get transformation for wgt_boxes
            wgt_rois = wgt_boxes.new(wgt_boxes.size()).zero_()
            wgt_rois[:, :, 1:5] = wgt_boxes[:, :, :4]
            batch_size = base_feat.size(0)
            for i in range(batch_size):
                wgt_rois[i, :, 0] = i  # batch index for each image's rois

            # do roi pooling based on predicted rois
            if cfg.POOLING_MODE == 'crop':
                # pdb.set_trace()
                # pooled_feat_anchor = _crop_pool_layer(base_feat, rois.view(-1, 5))
                grid_xy = _affine_grid_gen(
                    wgt_rois.view(-1, 5), base_feat.size()[2:], self.grid_size)
                grid_yx = torch.stack(
                    [grid_xy.data[:, :, :, 1], grid_xy.data[:, :, :, 0]], 3).contiguous()
                gt_pooled_feat = self.RCNN_roi_crop(
                    base_feat, Variable(grid_yx).detach())
                if cfg.CROP_RESIZE_WITH_MAX_POOL:
                    gt_pooled_feat = F.max_pool2d(gt_pooled_feat, 2, 2)
            elif cfg.POOLING_MODE == 'align':
                gt_pooled_feat = self.RCNN_roi_align(
                    base_feat, wgt_rois.view(-1, 5))
            elif cfg.POOLING_MODE == 'pool':
                gt_pooled_feat = self.RCNN_roi_pool(
                    base_feat, wgt_rois.view(-1, 5))

            # feed pooled features to top model
            gt_pooled_feat = self._head_to_tail(gt_pooled_feat, branch)

            # compute bbox offset
            wgt_bbox_delta = branch.RCNN_bbox_pred(gt_pooled_feat)
            wgt_bbox_delta = wgt_bbox_delta.view(-1, 4) * torch.FloatTensor(
                cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
            wgt_bbox_delta = wgt_bbox_delta.view(batch_size, -1, 4 * 21)  # 21 classes: 20 foreground + background
            wgt_bbox_out_rois = bbox_transform_inv(
                wgt_boxes, wgt_bbox_delta, batch_size)

            wgt_bbox_out_rois = clip_boxes(
                wgt_bbox_out_rois, im_info.data, batch_size)

            wgt_bbox_out = wgt_boxes.new(wgt_boxes.size()).zero_()

            wgt_cls = Variable(
                wgt_boxes[:, :, 4].data, requires_grad=False).long()
            for i in range(batch_size):
                for j in range(20):
                    cls_ind = wgt_cls[i, j]
                    wgt_bbox_out[i, j, :4] = wgt_bbox_out_rois[i,
                                                               j, cls_ind * 4:cls_ind * 4 + 4]

            wgt_bbox_out[:, :, 4] = wgt_boxes[:, :, 4]

            wgt_boxes_x = (wgt_boxes[:, :, 2] - wgt_boxes[:, :, 0] + 1)
            wgt_boxes_y = (wgt_boxes[:, :, 3] - wgt_boxes[:, :, 1] + 1)
            wgt_area_zero = (wgt_boxes_x == 1) & (wgt_boxes_y == 1)
            wgt_bbox_out.masked_fill_(wgt_area_zero.view(
                batch_size, wgt_area_zero.size(1), 1).expand(wgt_boxes.size()), 0)
            wgt_bbox_out = wgt_bbox_out.detach()
        else:
            wgt_bbox_out = None

        return (out_rois, cls_prob, bbox_pred, RCNN_loss_cls, RCNN_loss_bbox, rois_label, image_loss_cls), wgt_bbox_out
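
Several of these examples select the four regression outputs belonging to each RoI's class with the same view/gather/squeeze idiom. A standalone sketch of that selection (names are illustrative):

    import torch

    def select_class_bbox(bbox_pred, labels):
        # bbox_pred: (N, 4 * num_classes) offsets; labels: (N,) class per RoI
        n = bbox_pred.size(0)
        per_class = bbox_pred.view(n, -1, 4)               # (N, num_classes, 4)
        idx = labels.view(n, 1, 1).expand(n, 1, 4)         # (N, 1, 4)
        return torch.gather(per_class, 1, idx).squeeze(1)  # (N, 4)
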
Example #4
    def forward(self, im_data, im_info, gt_boxes, num_boxes):
        batch_size = im_data.size(0)

        im_info = im_info.data
        gt_boxes = gt_boxes.data
        num_boxes = num_boxes.data

        # feed image data to base model to obtain base feature map
        base_feat = self.RCNN_base(im_data)

        # feed base feature map to RPN to obtain rois
        rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(
            base_feat, im_info, gt_boxes, num_boxes)

        # if in the training phase, use ground-truth bboxes for refining
        if self.training:
            roi_data = self.RCNN_proposal_target(rois, gt_boxes, num_boxes)
            rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data

            rois_label = Variable(rois_label.view(-1).long())
            rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
            rois_inside_ws = Variable(
                rois_inside_ws.view(-1, rois_inside_ws.size(2)))
            rois_outside_ws = Variable(
                rois_outside_ws.view(-1, rois_outside_ws.size(2)))
        else:
            rois_label = None
            rois_target = None
            rois_inside_ws = None
            rois_outside_ws = None
            rpn_loss_cls = 0
            rpn_loss_bbox = 0

        rois = Variable(rois)
        # do roi pooling based on predicted rois

        if cfg.POOLING_MODE == 'crop':
            # pdb.set_trace()
            # pooled_feat_anchor = _crop_pool_layer(base_feat, rois.view(-1, 5))
            grid_xy = _affine_grid_gen(rois.view(-1, 5),
                                       base_feat.size()[2:], self.grid_size)
            grid_yx = torch.stack(
                [grid_xy.data[:, :, :, 1], grid_xy.data[:, :, :, 0]],
                3).contiguous()
            pooled_feat = self.RCNN_roi_crop(base_feat,
                                             Variable(grid_yx).detach())
            if cfg.CROP_RESIZE_WITH_MAX_POOL:
                pooled_feat = F.max_pool2d(pooled_feat, 2, 2)
        elif cfg.POOLING_MODE == 'align':
            pooled_feat = self.RCNN_roi_align(base_feat, rois.view(-1, 5))
        elif cfg.POOLING_MODE == 'pool':
            pooled_feat = self.RCNN_roi_pool(base_feat, rois.view(-1, 5))

        # feed pooled features to top model
        pooled_feat = self._head_to_tail(pooled_feat)

        # compute bbox offset
        bbox_pred = self.RCNN_bbox_pred(pooled_feat)
        if self.training and not self.class_agnostic:
            # select the corresponding columns according to roi labels
            bbox_pred_view = bbox_pred.view(bbox_pred.size(0),
                                            int(bbox_pred.size(1) / 4), 4)
            bbox_pred_select = torch.gather(
                bbox_pred_view, 1,
                rois_label.view(rois_label.size(0), 1,
                                1).expand(rois_label.size(0), 1, 4))
            bbox_pred = bbox_pred_select.squeeze(1)

        # compute object classification probability
        cls_score = self.RCNN_cls_score(pooled_feat)
        cls_prob = F.softmax(cls_score, 1)

        RCNN_loss_cls = 0
        RCNN_loss_bbox = 0

        if self.training:
            # classification loss
            RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)

            # bounding box regression L1 loss
            RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target,
                                             rois_inside_ws, rois_outside_ws)

        cls_prob = cls_prob.view(batch_size, rois.size(1), -1)
        bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1)

        return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label
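
`_smooth_l1_loss` is referenced throughout these examples but never shown. A sketch of the usual Faster R-CNN variant with inside/outside weights, assuming the weight semantics of the common py-faster-rcnn ports (an assumption, not necessarily the exact helper these repositories use):

    import torch

    def smooth_l1_loss(pred, target, inside_w, outside_w, sigma=1.0, dim=[1]):
        # inside weights zero out coordinates that should not be regressed;
        # outside weights normalize and balance the loss
        sigma2 = sigma ** 2
        diff = inside_w * (pred - target)
        abs_diff = diff.abs()
        # quadratic for |x| < 1/sigma^2, linear beyond that point
        quad = (abs_diff < 1.0 / sigma2).detach().float()
        loss = quad * 0.5 * sigma2 * diff ** 2 + (1.0 - quad) * (abs_diff - 0.5 / sigma2)
        loss = outside_w * loss
        for d in sorted(dim, reverse=True):
            loss = loss.sum(d)
        return loss.mean()
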
Example #5
    def forward(self, data_batch):
        im_data = data_batch[0]
        im_info = data_batch[1]
        gt_boxes = data_batch[2]
        num_boxes = data_batch[3]

        batch_size = im_data.size(0)

        if self.training:
            self.iter_counter += 1

        # feed image data to base model to obtain base feature map
        base_feat = self.FeatExt(im_data)
        base_feat.append(self.conv6_layer(base_feat[-1]))

        # C2 C3 C4 C5 C6
        C256 = []
        for i, newconv in enumerate(self.RCNN_newconvs):
            C256.append(newconv(base_feat[i]))

        source = [C256[3]]
        for i, upsampleconv in enumerate(self.RCNN_upsampleconvs):
            if cfg.FPN.UPSAMPLE_CONV:
                source.append(
                    F.upsample(upsampleconv(source[i]),
                               size=(C256[2 - i].size(-2),
                                     C256[2 - i].size(-1)),
                               mode='bilinear') + C256[2 - i])
            else:
                source.append(
                    F.upsample(source[i],
                               size=(C256[2 - i].size(-2),
                                     C256[2 - i].size(-1)),
                               mode='bilinear') + C256[2 - i])
        # reverse the top-down list back to bottom-up order
        source = source[::-1]

        # P2 P3 P4 P5 P6
        source.append(C256[4])

        for i in range(len(source)):
            source[i] = self.RCNN_mixconvs[i](source[i])

        # feed base feature map to RPN to obtain rois
        rois_rpn, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(
            source, im_info, gt_boxes, num_boxes)

        # if in the training phase, use ground-truth bboxes for refining
        if self.training:
            roi_data = self.RCNN_proposal_target(rois_rpn, gt_boxes, num_boxes)
            # outputs is a tuple of list.
            roi_data = self._assign_layer(roi_data)
            rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data

            rois_label = [i for i in rois_label if i.numel() > 0]
            rois_target = [i for i in rois_target if i.numel() > 0]
            rois_inside_ws = [i for i in rois_inside_ws if i.numel() > 0]
            rois_outside_ws = [i for i in rois_outside_ws if i.numel() > 0]

        else:
            rois = self._assign_layer(rois_rpn)
            rois_label = None
            rois_target = None
            rois_inside_ws = None
            rois_outside_ws = None
            rpn_loss_cls = [0]
            rpn_loss_bbox = [0]

        for i in range(len(rois)):
            rois[i] = Variable(rois[i])
        # do roi pooling based on predicted rois
        pooled_feat = []
        if cfg.RCNN_COMMON.POOLING_MODE == 'align':
            for i in range(len(source)):
                if rois[i].numel() > 0:
                    pooled_feat.append(self.RCNN_roi_aligns[i](source[i],
                                                               rois[i].view(
                                                                   -1, 5)))
        elif cfg.RCNN_COMMON.POOLING_MODE == 'pool':
            for i in range(len(source)):
                if rois[i].numel() > 0:
                    pooled_feat.append(self.RCNN_roi_pools[i](source[i],
                                                              rois[i].view(
                                                                  -1, 5)))

        rois = torch.cat(rois, dim=0)
        img_inds = rois[:, 0]
        pooled_feat = torch.cat(pooled_feat, dim=0)
        if self.training:
            rois_label = torch.cat(rois_label, dim=0)
            rois_target = torch.cat(rois_target, dim=0)
            rois_inside_ws = torch.cat(rois_inside_ws, dim=0)
            rois_outside_ws = torch.cat(rois_outside_ws, dim=0)

        # put all rois belonging to the same image together
        inds = []
        for i in range(batch_size):
            # rois indexes in ith image
            rois_num_i = int(torch.sum(img_inds == i))
            _, inds_i = torch.sort(img_inds == i, descending=True)
            inds.append(inds_i[:rois_num_i])
        inds = torch.cat(inds, dim=0)

        rois = rois[inds]
        pooled_feat = pooled_feat[inds]
        if self.training:
            rois_label = rois_label[inds]
            rois_target = rois_target[inds]
            rois_inside_ws = rois_inside_ws[inds]
            rois_outside_ws = rois_outside_ws[inds]

        # feed pooled features to top model
        pooled_feat = self._head_to_tail(pooled_feat)

        # compute bbox offset
        bbox_pred = self.RCNN_bbox_pred(pooled_feat)
        if self.training and not self.class_agnostic:
            # select the corresponding columns according to roi labels
            bbox_pred_view = bbox_pred.view(bbox_pred.size(0),
                                            int(bbox_pred.size(1) / 4), 4)
            bbox_pred_select = torch.gather(
                bbox_pred_view, 1,
                rois_label.view(rois_label.size(0), 1,
                                1).expand(rois_label.size(0), 1, 4))
            bbox_pred = bbox_pred_select.squeeze(1)

        # compute object classification probability
        cls_score = self.RCNN_cls_score(pooled_feat)
        cls_prob = F.softmax(cls_score, 1)

        RCNN_loss_cls = 0
        RCNN_loss_bbox = 0

        if self.training:
            # classification loss
            if cfg.TRAIN.COMMON.USE_FOCAL_LOSS:
                RCNN_loss_cls = F.cross_entropy(cls_score,
                                                rois_label,
                                                reduce=False)
                focal_loss_factor = torch.pow(
                    (1 - cls_prob[range(int(cls_prob.size(0))), rois_label]),
                    cfg.TRAIN.COMMON.FOCAL_LOSS_GAMMA)
                RCNN_loss_cls = torch.mean(RCNN_loss_cls * focal_loss_factor)
            else:
                RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)

            # bounding box regression L1 loss
            RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target,
                                             rois_inside_ws, rois_outside_ws)

        rois = rois.contiguous().view(batch_size, -1, 5)
        cls_prob = cls_prob.contiguous().view(batch_size, rois.size(1), -1)
        bbox_pred = bbox_pred.contiguous().view(batch_size, rois.size(1), -1)

        return rois, cls_prob, bbox_pred, \
               rpn_loss_cls, rpn_loss_bbox, \
               RCNN_loss_cls, RCNN_loss_bbox, \
               rois_label
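
The upsample-and-add loop above is the standard FPN top-down pathway; note that `F.upsample` has since been deprecated in favor of `F.interpolate`. A minimal sketch of one merge step in current PyTorch, assuming the same bilinear mode:

    import torch.nn.functional as F

    def top_down_merge(coarse, lateral):
        # upsample the coarser level to the lateral's spatial size, then add
        up = F.interpolate(coarse, size=lateral.shape[-2:],
                           mode='bilinear', align_corners=False)
        return up + lateral
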
Example #6
    def forward(self, im_data, im_info, gt_boxes, num_boxes, rel_mat):
        # object detection
        if self.training:
            self._train_iter_conter += 1
        self.batch_size = im_data.size(0)

        im_info = im_info.data
        gt_boxes = gt_boxes.data
        num_boxes = num_boxes.data

        # feed image data to base model to obtain base feature map
        base_feat = self.VMRN_base(im_data)

        # feed base feature map to RPN to obtain rois
        rois, rpn_loss_cls, rpn_loss_bbox = self.VMRN_obj_rpn(base_feat, im_info, gt_boxes, num_boxes)

        # rois preprocess
        if self.training:
            obj_det_rois = rois[:,:cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET]
            roi_data = self.VMRN_obj_proposal_target(rois, gt_boxes, num_boxes)
            rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data
            rois = torch.cat([obj_det_rois,rois],1)

            rois_label = Variable(rois_label.view(-1).long())
            rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
            rois_inside_ws = Variable(rois_inside_ws.view(-1, rois_inside_ws.size(2)))
            rois_outside_ws = Variable(rois_outside_ws.view(-1, rois_outside_ws.size(2)))
        else:
            rois_label = None
            rois_target = None
            rois_inside_ws = None
            rois_outside_ws = None
            rpn_loss_cls = 0
            rpn_loss_bbox = 0

        rois = Variable(rois)

        pooled_feat = self._roi_pooing(base_feat, rois)

        # feed pooled features to top model
        pooled_feat = self._head_to_tail(pooled_feat)

        # compute bbox offset
        bbox_pred = self.VMRN_obj_bbox_pred(pooled_feat)
        if self.training:
            if self.class_agnostic:
                bbox_pred = bbox_pred.contiguous().view(self.batch_size, -1, 4)
            else:
                bbox_pred = bbox_pred.contiguous().view(self.batch_size, -1, 4 * self.n_classes)
            obj_det_bbox_pred = bbox_pred[:,:cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET]
            bbox_pred = bbox_pred[:,cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET:]
            if self.class_agnostic:
                obj_det_bbox_pred = obj_det_bbox_pred.contiguous().view(-1, 4)
                bbox_pred = bbox_pred.contiguous().view(-1, 4)
            else:
                obj_det_bbox_pred = obj_det_bbox_pred.contiguous().view(-1, 4 * self.n_classes)
                bbox_pred = bbox_pred.contiguous().view(-1, 4 * self.n_classes)

        # compute object classification probability
        cls_score = self.VMRN_obj_cls_score(pooled_feat)
        cls_prob = F.softmax(cls_score, 1)
        if self.training:
            cls_score = cls_score.contiguous().view(self.batch_size, -1, self.n_classes)
            obj_det_cls_score = cls_score[:, :cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET]
            cls_score = cls_score[:, cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET:]
            obj_det_cls_score = obj_det_cls_score.contiguous().view(-1, self.n_classes)
            cls_score = cls_score.contiguous().view(-1, self.n_classes)

            cls_prob = cls_prob.contiguous().view(self.batch_size, -1, self.n_classes)
            obj_det_cls_prob = cls_prob[:, :cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET]
            cls_prob = cls_prob[:, cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET:]
            obj_det_cls_prob = obj_det_cls_prob.contiguous().view(-1, self.n_classes)
            cls_prob = cls_prob.contiguous().view(-1, self.n_classes)

        VMRN_obj_loss_cls = 0
        VMRN_obj_loss_bbox = 0

        # compute object detector loss
        if self.training and not self.class_agnostic:
            # select the corresponding columns according to roi labels
            bbox_pred_view = bbox_pred.view(bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4)
            bbox_pred_select = torch.gather(bbox_pred_view, 1,
                                            rois_label.view(rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4))
            bbox_pred = bbox_pred_select.squeeze(1)

        if self.training:
            # classification loss
            VMRN_obj_loss_cls = F.cross_entropy(cls_score, rois_label)
            # bounding box regression L1 loss
            VMRN_obj_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)

        # online data
        if self.training:
            if self._train_iter_conter > cfg.TRAIN.VMRN.ONLINEDATA_BEGIN_ITER:
                obj_rois, obj_num = self._obj_det(obj_det_rois,
                        obj_det_cls_prob.contiguous().view(self.batch_size, -1, self.n_classes),
                        obj_det_bbox_pred.contiguous().view(self.batch_size,
                                -1, 4 if self.class_agnostic else 4 * self.n_classes),
                        self.batch_size, im_info)
                obj_rois = obj_rois.type_as(gt_boxes)
                obj_num = obj_num.type_as(num_boxes)
            else:
                obj_rois = torch.FloatTensor([]).type_as(gt_boxes)
                obj_num = torch.LongTensor([]).type_as(num_boxes)
            obj_labels = None
        else:
            # at test time, these are the object detection results
            # TODO: SUPPORT MULTI-IMAGE BATCH
            obj_rois, obj_num = self._obj_det(rois,
                    cls_prob.contiguous().view(self.batch_size, -1, self.n_classes),
                    bbox_pred.contiguous().view(self.batch_size,
                                -1, 4 if self.class_agnostic else 4 * self.n_classes),
                    self.batch_size, im_info)
            if obj_rois.numel() > 0:
                obj_labels = obj_rois[:,5]
                obj_rois = obj_rois[:,:5]
                obj_rois = obj_rois.type_as(gt_boxes)
                obj_num = obj_num.type_as(num_boxes)
            else:
                # no objects detected
                obj_labels = torch.Tensor([]).type_as(gt_boxes).long()
                obj_rois = obj_rois.type_as(gt_boxes)
                obj_num = obj_num.type_as(num_boxes)

        # offline data

        if self.training:
            for i in range(self.batch_size):
                # append ground-truth boxes as (batch_index, x1, y1, x2, y2) rows
                batch_inds = (i * torch.ones(num_boxes[i].item(), 1)).type_as(gt_boxes)
                gt_rois_i = torch.cat([batch_inds, gt_boxes[i][:num_boxes[i]][:, 0:4]], 1)
                obj_rois = torch.cat([obj_rois, gt_rois_i])
                obj_num = torch.cat([obj_num, torch.Tensor([num_boxes[i]]).type_as(obj_num)])


        obj_rois = Variable(obj_rois)

        if (obj_num > 1).sum().item() > 0:
            # relationship pairs exist only when some image has at least two objects
            obj_pair_feat = self.VMRN_rel_op2l(base_feat, obj_rois, self.batch_size, obj_num)
            obj_pair_feat = obj_pair_feat.detach()
            obj_pair_feat = self._rel_head_to_tail(obj_pair_feat)
            rel_cls_score = self.VMRN_rel_cls_score(obj_pair_feat)

            rel_cls_prob = F.softmax(rel_cls_score, 1)

            VMRN_rel_loss_cls = 0
            if self.training:
                self.rel_batch_size = rel_cls_prob.size(0)

                obj_pair_rel_label = self._generate_rel_labels(obj_rois, gt_boxes, obj_num, rel_mat)
                obj_pair_rel_label = obj_pair_rel_label.type_as(gt_boxes).long()

                rel_not_keep = (obj_pair_rel_label == 0)

                # proceed only if at least one pair keeps a relationship label
                if (rel_not_keep == 0).sum().item() > 0:
                    rel_keep = torch.nonzero(rel_not_keep == 0).view(-1)

                    rel_cls_score = rel_cls_score[rel_keep]
                    obj_pair_rel_label = obj_pair_rel_label[rel_keep]

                    obj_pair_rel_label -= 1

                    VMRN_rel_loss_cls = F.cross_entropy(rel_cls_score, obj_pair_rel_label)
            else:
                if (not cfg.TEST.VMRN.ISEX) and cfg.TRAIN.VMRN.ISEX:
                    rel_cls_prob = rel_cls_prob[::2,:]

        else:
            VMRN_rel_loss_cls = 0
            # no detected relationships
            rel_cls_prob = Variable(torch.Tensor([]).type_as(cls_prob))

        rel_result = None
        if not self.training:
            if obj_rois.numel() > 0:
                pred_boxes = obj_rois.data[:,1:5]
                pred_boxes[:, 0::2] /= im_info[0][3].item()
                pred_boxes[:, 1::2] /= im_info[0][2].item()
                rel_result = (pred_boxes, obj_labels, rel_cls_prob.data)
            else:
                rel_result = (obj_rois.data, obj_labels, rel_cls_prob.data)

        return rois, cls_prob, bbox_pred, rel_result, rpn_loss_cls, rpn_loss_bbox, \
               VMRN_obj_loss_cls, VMRN_obj_loss_bbox, VMRN_rel_loss_cls, rois_label
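
The relationship branch drops object pairs whose label is 0 (unlabeled) and shifts the remaining labels down by one before cross-entropy. A minimal sketch of that bookkeeping in isolation (names assumed):

    import torch
    import torch.nn.functional as F

    def rel_cls_loss(rel_scores, pair_labels):
        # pair_labels: LongTensor, 0 = ignore, 1..K = relationship classes
        keep = torch.nonzero(pair_labels != 0).view(-1)
        if keep.numel() == 0:
            return rel_scores.sum() * 0  # zero loss, graph stays connected
        # shift kept labels to 0-based class indices
        return F.cross_entropy(rel_scores[keep], pair_labels[keep] - 1)
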
Example #7
    def forward(self, im_data, gt_boxes, im_info):
        batch_size = im_data.size(0)
        im_info = im_info.data


        if gt_boxes is not None:
            gt_boxes = gt_boxes.data

        # feed image data to base model to obtain base feature map
        base_feat = self.RCNN_base(im_data)

        # feed base feature map to RPN to obtain rois
        rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(base_feat, im_info, gt_boxes)

        # if it is training phase, then use ground truth bboxes for refining
        if self.training:
            roi_data = self.RCNN_proposal_target(rois, gt_boxes)
            rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data

            rois_label = Variable(rois_label.view(-1).long())
            rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
            rois_inside_ws = Variable(rois_inside_ws.view(-1, rois_inside_ws.size(2)))
            rois_outside_ws = Variable(rois_outside_ws.view(-1, rois_outside_ws.size(2)))
        else:
            rois_label = None
            rois_target = None
            rois_inside_ws = None
            rois_outside_ws = None
            rpn_loss_cls = 0
            rpn_loss_bbox = 0

        rois = Variable(rois)
        # do roi pooling based on predicted rois

        if cfg.pooling_mode == 'align':
            # pooled_feat = self.RCNN_roi_align(feature_map, rois.view(-1, 5))
            pooled_feat = roi_align(base_feat, rois.view(-1, 5), (cfg.pool_size, cfg.pool_size), 1.0/16)
        elif cfg.pooling_mode == 'pool':
            #pooled_feat = self.RCNN_roi_pool(feature_map, rois.view(-1, 5))
            pooled_feat = roi_pool(base_feat, rois.view(-1, 5), (cfg.pool_size, cfg.pool_size), 1.0/16)

        # feed pooled features to top model
        pooled_feat = self._head_to_tail(pooled_feat)

        # compute bbox offset
        bbox_pred = self.RCNN_bbox_pred(pooled_feat)
        if self.training and not self.class_agnostic:
            # select the corresponding columns according to roi labels
            bbox_pred_view = bbox_pred.view(bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4)
            bbox_pred_select = torch.gather(bbox_pred_view, 1, rois_label.view(rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4))
            bbox_pred = bbox_pred_select.squeeze(1)

        # compute object classification probability
        cls_score = self.RCNN_cls_score(pooled_feat)
        cls_prob = F.softmax(cls_score, 1)

        RCNN_loss_cls = 0
        RCNN_loss_bbox = 0

        if self.training:
            # classification loss
            RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)

            # bounding box regression L1 loss
            RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)


        cls_prob = cls_prob.view(batch_size, rois.size(1), -1)
        bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1)

        return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label
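
The functional `roi_align` / `roi_pool` calls above appear to match the `torchvision.ops` signatures, with rois given as (batch_index, x1, y1, x2, y2) rows and a 1/16 spatial scale for a stride-16 backbone. A minimal usage sketch (shapes are illustrative):

    import torch
    from torchvision.ops import roi_align

    feat = torch.randn(1, 512, 38, 50)               # (N, C, H/16, W/16)
    rois = torch.tensor([[0., 32., 32., 96., 96.]])  # (batch_idx, x1, y1, x2, y2)
    pooled = roi_align(feat, rois, output_size=(7, 7), spatial_scale=1.0 / 16)
    print(pooled.shape)                              # torch.Size([1, 512, 7, 7])
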
Example #8
    def forward(self, im_data, gt):
        # object detection
        if self.training:
            self._train_iter_conter += 1

        self.batch_size = im_data.size(0)

        gt_boxes = gt['boxes']
        gt_grasps = gt['grasps']
        gt_grasp_inds = gt['grasp_inds']
        num_boxes = gt['num_boxes']
        num_grasps = gt['num_grasps']
        im_info = gt['im_info']
        rel_mat = gt['rel_mat']

        # feed image data to base model to obtain base feature map
        base_feat = self.VMRN_base(im_data)

        # feed base feature map to RPN to obtain rois
        rois, rpn_loss_cls, rpn_loss_bbox = self.VMRN_obj_rpn(base_feat, im_info, gt_boxes, num_boxes)

        # rois preprocess
        if self.training:
            obj_det_rois = rois[:,:cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET]
            roi_data = self.VMRN_obj_proposal_target(rois, gt_boxes, num_boxes)
            rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data
            grasp_rois = rois.clone()
            rois = torch.cat([obj_det_rois,rois],1)

            rois_label = Variable(rois_label.view(-1).long())
            rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
            rois_inside_ws = Variable(rois_inside_ws.view(-1, rois_inside_ws.size(2)))
            rois_outside_ws = Variable(rois_outside_ws.view(-1, rois_outside_ws.size(2)))
        else:
            rois_label = None
            rois_target = None
            rois_inside_ws = None
            rois_outside_ws = None
            rpn_loss_cls = 0
            rpn_loss_bbox = 0

        rois = Variable(rois)

        pooled_feat = self._roi_pooing(base_feat, rois)

        if self.training:
            pooled_feat_shape = pooled_feat.size()
            pooled_feat = pooled_feat.contiguous().view((self.batch_size, -1) + pooled_feat_shape[1:])
            grasp_feat = pooled_feat[:, cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET:].\
                contiguous().view((-1,) + pooled_feat_shape[1:])
            pooled_feat = pooled_feat.view(pooled_feat_shape)
            if self._MGN_USE_POOLED_FEATS:
                rois_overlaps = bbox_overlaps_batch(rois[:, cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET:], gt_boxes)
                # bs x N_{rois}
                _, rois_inds = torch.max(rois_overlaps, dim=2)
                rois_inds += 1
                grasp_rois_mask = rois_label.view(-1) > 0
            else:
                raise NotImplementedError

        ############################################
        # OBJECT DETECTION
        ############################################
        # feed pooled features to top model
        VMRN_feat = self._head_to_tail(pooled_feat)

        # compute bbox offset
        bbox_pred = self.VMRN_obj_bbox_pred(VMRN_feat)
        if self.training:
            if self.class_agnostic:
                bbox_pred = bbox_pred.contiguous().view(self.batch_size, -1, 4)
            else:
                bbox_pred = bbox_pred.contiguous().view(self.batch_size, -1, 4 * self.n_classes)
            obj_det_bbox_pred = bbox_pred[:,:cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET]
            bbox_pred = bbox_pred[:,cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET:]
            if self.class_agnostic:
                obj_det_bbox_pred = obj_det_bbox_pred.contiguous().view(-1, 4)
                bbox_pred = bbox_pred.contiguous().view(-1, 4)
            else:
                obj_det_bbox_pred = obj_det_bbox_pred.contiguous().view(-1, 4 * self.n_classes)
                bbox_pred = bbox_pred.contiguous().view(-1, 4 * self.n_classes)

        # compute object classification probability
        cls_score = self.VMRN_obj_cls_score(VMRN_feat)
        cls_prob = F.softmax(cls_score, 1)
        if self.training:
            cls_score = cls_score.contiguous().view(self.batch_size, -1, self.n_classes)
            obj_det_cls_score = cls_score[:, :cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET]
            cls_score = cls_score[:, cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET:]
            obj_det_cls_score = obj_det_cls_score.contiguous().view(-1, self.n_classes)
            cls_score = cls_score.contiguous().view(-1, self.n_classes)

            cls_prob = cls_prob.contiguous().view(self.batch_size, -1, self.n_classes)
            obj_det_cls_prob = cls_prob[:, :cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET]
            cls_prob = cls_prob[:, cfg.TRAIN.VMRN.TOP_N_ROIS_FOR_OBJ_DET:]
            obj_det_cls_prob = obj_det_cls_prob.contiguous().view(-1, self.n_classes)
            cls_prob = cls_prob.contiguous().view(-1, self.n_classes)

        VMRN_obj_loss_cls = 0
        VMRN_obj_loss_bbox = 0

        # compute object detector loss
        if self.training and not self.class_agnostic:
            # select the corresponding columns according to roi labels
            bbox_pred_view = bbox_pred.view(bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4)
            bbox_pred_select = torch.gather(bbox_pred_view, 1,
                                            rois_label.view(rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4))
            bbox_pred = bbox_pred_select.squeeze(1)

        if self.training:
            # classification loss
            VMRN_obj_loss_cls = F.cross_entropy(cls_score, rois_label)
            # bounding box regression L1 loss
            VMRN_obj_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)

        ############################################
        # VISUAL MANIPULATION RELATIONSHIP
        ############################################
        # online data
        if self.training:
            if self._train_iter_conter > cfg.TRAIN.VMRN.ONLINEDATA_BEGIN_ITER:
                obj_rois, obj_num = self._obj_det(obj_det_rois,
                        obj_det_cls_prob.contiguous().view(self.batch_size, -1, self.n_classes),
                        obj_det_bbox_pred.contiguous().view(self.batch_size,
                                -1, 4 if self.class_agnostic else 4 * self.n_classes),
                        self.batch_size, im_info)
                obj_rois = obj_rois.type_as(gt_boxes)
                obj_num = obj_num.type_as(num_boxes)
            else:
                obj_rois = torch.FloatTensor([]).type_as(gt_boxes)
                obj_num = torch.LongTensor([]).type_as(num_boxes)
            obj_labels = None
        else:
            # at test time, these are the object detection results
            # TODO: SUPPORT MULTI-IMAGE BATCH
            obj_rois, obj_num = self._obj_det(rois,
                    cls_prob.contiguous().view(self.batch_size, -1, self.n_classes),
                    bbox_pred.contiguous().view(self.batch_size,
                                -1, 4 if self.class_agnostic else 4 * self.n_classes),
                    self.batch_size, im_info)
            if obj_rois.numel() > 0:
                obj_labels = obj_rois[:,5]
                obj_rois = obj_rois[:,:5]
                obj_rois = obj_rois.type_as(gt_boxes)
                obj_num = obj_num.type_as(num_boxes)
            else:
                # no objects detected
                obj_labels = torch.Tensor([]).type_as(gt_boxes).long()
                obj_rois = obj_rois.type_as(gt_boxes)
                obj_num = obj_num.type_as(num_boxes)

        # offline data
        if self.training:
            for i in range(self.batch_size):
                # append ground-truth boxes as (batch_index, x1, y1, x2, y2) rows
                batch_inds = (i * torch.ones(num_boxes[i].item(), 1)).type_as(gt_boxes)
                gt_rois_i = torch.cat([batch_inds, gt_boxes[i][:num_boxes[i]][:, 0:4]], 1)
                obj_rois = torch.cat([obj_rois, gt_rois_i])
                obj_num = torch.cat([obj_num, torch.Tensor([num_boxes[i]]).type_as(obj_num)])

        obj_rois = Variable(obj_rois)

        if obj_rois.size(0) > 1:
            # relationship pairs exist only when at least two objects are present
            obj_pair_feat = self.VMRN_rel_op2l(base_feat, obj_rois, self.batch_size, obj_num)
            # obj_pair_feat = obj_pair_feat.detach()
            obj_pair_feat = self._rel_head_to_tail(obj_pair_feat)
            rel_cls_score = self.VMRN_rel_cls_score(obj_pair_feat)

            rel_cls_prob = F.softmax(rel_cls_score, 1)

            VMRN_rel_loss_cls = 0
            if self.training:
                self.rel_batch_size = rel_cls_prob.size(0)

                obj_pair_rel_label = self._generate_rel_labels(obj_rois, gt_boxes, obj_num, rel_mat)
                obj_pair_rel_label = obj_pair_rel_label.type_as(gt_boxes).long()

                rel_not_keep = (obj_pair_rel_label == 0)
                rel_keep = torch.nonzero(rel_not_keep == 0).view(-1)

                rel_cls_score = rel_cls_score[rel_keep]
                obj_pair_rel_label = obj_pair_rel_label[rel_keep]

                obj_pair_rel_label -= 1

                VMRN_rel_loss_cls = F.cross_entropy(rel_cls_score, obj_pair_rel_label)
            else:
                if (not cfg.TEST.VMRN.ISEX) and cfg.TRAIN.VMRN.ISEX:
                    rel_cls_prob = rel_cls_prob[::2,:]

        else:
            VMRN_rel_loss_cls = 0
            # no detected relationships; obj_labels is None during training, so
            # borrow the dtype from cls_prob instead
            rel_cls_prob = Variable(torch.Tensor([]).type_as(cls_prob))

        rel_result = None
        if not self.training:
            if obj_rois.numel() > 0:
                pred_boxes = obj_rois.data[:,1:5]
                pred_boxes[:, 0::2] /= im_info[0][3].item()
                pred_boxes[:, 1::2] /= im_info[0][2].item()
                rel_result = (pred_boxes, obj_labels, rel_cls_prob.data)
            else:
                rel_result = (obj_rois.data, obj_labels, rel_cls_prob.data)

        ############################################
        # ROI-BASED GRASP DETECTION
        ############################################
        if self.training:
            if (grasp_rois_mask > 0).sum().item() > 0:
                grasp_feat = self._MGN_head_to_tail(grasp_feat[grasp_rois_mask])
            else:
                # when there are no positive rois:
                grasp_loc = Variable(torch.Tensor([]).type_as(gt_grasps))
                grasp_prob = Variable(torch.Tensor([]).type_as(gt_grasps))
                grasp_bbox_loss = Variable(torch.Tensor([0]).type_as(VMRN_obj_loss_bbox))
                grasp_cls_loss = Variable(torch.Tensor([0]).type_as(VMRN_obj_loss_cls))
                grasp_conf_label = torch.Tensor([-1]).type_as(rois_label)
                grasp_all_anchors = torch.Tensor([]).type_as(gt_grasps)
                return rois, cls_prob, bbox_pred, rel_result, rpn_loss_cls, rpn_loss_bbox, \
                    VMRN_obj_loss_cls, VMRN_obj_loss_bbox, VMRN_rel_loss_cls, rois_label, \
                    grasp_loc, grasp_prob, grasp_bbox_loss, grasp_cls_loss, grasp_conf_label, grasp_all_anchors
        else:
            grasp_feat = self._MGN_head_to_tail(pooled_feat)

        grasp_pred = self.MGN_classifier(grasp_feat)
        # bs*N x K*A x 5, bs*N x K*A x 2
        grasp_loc, grasp_conf = grasp_pred

        # generate anchors
        # bs*N x K*A x 5
        if self.training:
            grasp_all_anchors = self._generate_anchors(grasp_conf.size(1), grasp_conf.size(2), grasp_rois)
        else:
            grasp_all_anchors = self._generate_anchors(grasp_conf.size(1), grasp_conf.size(2), rois)
        # filter out negative samples
        grasp_all_anchors = grasp_all_anchors.type_as(gt_grasps)
        if self.training:
            grasp_all_anchors = grasp_all_anchors[grasp_rois_mask]
            # bs*N x 1 x 1
            rois_w = (grasp_rois[:, :, 3] - grasp_rois[:, :, 1]).data.view(-1).unsqueeze(1).unsqueeze(2)
            rois_h = (grasp_rois[:, :, 4] - grasp_rois[:, :, 2]).data.view(-1).unsqueeze(1).unsqueeze(2)
            rois_w = rois_w[grasp_rois_mask]
            rois_h = rois_h[grasp_rois_mask]
            # bs*N x 1 x 1
            fsx = rois_w / grasp_conf.size(1)
            fsy = rois_h / grasp_conf.size(2)
            # bs*N x 1 x 1
            xleft = grasp_rois[:, :, 1].data.view(-1).unsqueeze(1).unsqueeze(2)
            ytop = grasp_rois[:, :, 2].data.view(-1).unsqueeze(1).unsqueeze(2)
            xleft = xleft[grasp_rois_mask]
            ytop = ytop[grasp_rois_mask]

        # reshape grasp_loc and grasp_conf
        grasp_loc = grasp_loc.contiguous().view(grasp_loc.size(0), -1, 5)
        grasp_conf = grasp_conf.contiguous().view(grasp_conf.size(0), -1, 2)
        grasp_batch_size = grasp_loc.size(0)

        # bs*N x K*A x 2
        grasp_prob = F.softmax(grasp_conf, 2)

        grasp_bbox_loss = 0
        grasp_cls_loss = 0
        grasp_conf_label = None
        if self.training:
            # inside weights indicate which bounding boxes should be regressed
            # outside weights indicate two things:
            # 1. which bounding boxes contribute to the classification loss
            # 2. how to balance the cls loss and bbox loss
            grasp_gt_xywhc = points2labels(gt_grasps)
            # bs*N x N_{Gr_gt} x 5
            grasp_gt_xywhc = self._assign_rois_grasps(grasp_gt_xywhc, gt_grasp_inds, rois_inds)
            # filter out negative samples
            grasp_gt_xywhc = grasp_gt_xywhc[grasp_rois_mask]

            # absolute coords to relative coords
            grasp_gt_xywhc[:, :, 0:1] -= xleft
            grasp_gt_xywhc[:, :, 0:1] = torch.clamp(grasp_gt_xywhc[:, :, 0:1], min = 0)
            grasp_gt_xywhc[:, :, 0:1] = torch.min(grasp_gt_xywhc[:, :, 0:1], rois_w)
            grasp_gt_xywhc[:, :, 1:2] -= ytop
            grasp_gt_xywhc[:, :, 1:2] = torch.clamp(grasp_gt_xywhc[:, :, 1:2], min = 0)
            grasp_gt_xywhc[:, :, 1:2] = torch.min(grasp_gt_xywhc[:, :, 1:2], rois_h)

            # grasp training data
            grasp_loc_label, grasp_conf_label, grasp_iw, grasp_ow = self.MGN_proposal_target(
                grasp_conf, grasp_gt_xywhc, grasp_all_anchors,
                xthresh=fsx / 2, ythresh=fsy / 2)

            grasp_keep = Variable(grasp_conf_label.view(-1).ne(-1).nonzero().view(-1))
            grasp_conf = torch.index_select(grasp_conf.view(-1, 2), 0, grasp_keep.data)
            grasp_conf_label = torch.index_select(grasp_conf_label.view(-1), 0, grasp_keep.data)
            grasp_cls_loss = F.cross_entropy(grasp_conf, grasp_conf_label)

            grasp_iw = Variable(grasp_iw)
            grasp_ow = Variable(grasp_ow)
            grasp_loc_label = Variable(grasp_loc_label)
            grasp_bbox_loss = _smooth_l1_loss(grasp_loc, grasp_loc_label, grasp_iw, grasp_ow, dim=[2, 1])

        return rois, cls_prob, bbox_pred, rel_result, rpn_loss_cls, rpn_loss_bbox, \
                VMRN_obj_loss_cls, VMRN_obj_loss_bbox, VMRN_rel_loss_cls, rois_label, \
                grasp_loc, grasp_prob, grasp_bbox_loss, grasp_cls_loss, grasp_conf_label, grasp_all_anchors
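
The grasp branch above converts ground-truth grasp centers from image coordinates to RoI-relative coordinates by subtracting the RoI origin and clipping into the RoI, via the subtract/clamp/min steps on `grasp_gt_xywhc`. A standalone sketch of that transform (names assumed):

    import torch

    def to_roi_relative(xy, roi_origin, roi_wh):
        # xy: (..., 2) absolute coordinates; roi_origin: (..., 2) top-left of
        # the RoI; roi_wh: (..., 2) RoI width/height
        rel = torch.clamp(xy - roi_origin, min=0)  # clip at the RoI's near edge
        return torch.min(rel, roi_wh)              # clip at the RoI's far edge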