Example #1
    def forward(self, data_batch):
        x = data_batch[0]
        im_info = data_batch[1]
        gt_boxes = data_batch[2]
        num_boxes = data_batch[3]

        if self.training:
            self.iter_counter += 1

        # features
        x = self.FeatExt(x)
        pred = self.FCGN_classifier(x)
        loc, conf = pred
        self.batch_size = loc.size(0)

        all_anchors = self._generate_anchors(conf.size(1), conf.size(2))
        all_anchors = all_anchors.type_as(gt_boxes)
        all_anchors = all_anchors.expand(self.batch_size, all_anchors.size(1),
                                         all_anchors.size(2))

        loc = loc.contiguous().view(loc.size(0), -1, 5)
        conf = conf.contiguous().view(conf.size(0), -1, 2)
        prob = F.softmax(conf, 2)
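        # loc: one 5-value regression per anchor; prob: the corresponding 2-way classification probability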

        bbox_loss = 0
        cls_loss = 0
        conf_label = None
        if self.training:
            # inside weights indicate which bounding boxes should be regressed
            # outside weights indicate two things:
            # 1. which bounding boxes should contribute to the classification loss
            # 2. how to balance the cls loss and the bbox loss
            gt_xywhc = points2labels(gt_boxes)
            loc_label, conf_label, iw, ow = self.FCGN_proposal_target(
                conf, gt_xywhc, all_anchors)

            keep = Variable(conf_label.view(-1).ne(-1).nonzero().view(-1))
            conf = torch.index_select(conf.view(-1, 2), 0, keep.data)
            conf_label = torch.index_select(conf_label.view(-1), 0, keep.data)
            cls_loss = F.cross_entropy(conf, conf_label)

            iw = Variable(iw)
            ow = Variable(ow)
            loc_label = Variable(loc_label)
            bbox_loss = _smooth_l1_loss(loc, loc_label, iw, ow, dim=[2, 1])

        return loc, prob, bbox_loss, cls_loss, conf_label, all_anchors
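
The inside/outside weights described in the comments above are consumed by _smooth_l1_loss. Below is a minimal sketch of such a helper in the common Faster R-CNN style; it is an assumption of what the repository's actual helper looks like, and details such as the sigma default may differ.

import torch

def _smooth_l1_loss(bbox_pred, bbox_targets, inside_weights, outside_weights,
                    sigma=1.0, dim=[1]):
    # Smooth L1 loss with per-element gating:
    # inside_weights zero out coordinates that should not be regressed,
    # outside_weights re-weight the remaining terms (e.g. to balance cls and bbox losses).
    sigma_2 = sigma ** 2
    box_diff = inside_weights * (bbox_pred - bbox_targets)
    abs_diff = torch.abs(box_diff)
    smooth_sign = (abs_diff < 1. / sigma_2).float()
    loss = (box_diff ** 2) * (sigma_2 / 2.) * smooth_sign \
           + (abs_diff - 0.5 / sigma_2) * (1. - smooth_sign)
    loss = outside_weights * loss
    # sum over the requested dimensions (dim=[2, 1] in the call above), then average
    for d in sorted(dim, reverse=True):
        loss = loss.sum(d)
    return loss.mean()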
Example #2

    def forward(self, data_batch):
        im_data = data_batch[0]
        im_info = data_batch[1]
        gt_boxes = data_batch[2]
        gt_grasps = data_batch[3]
        num_boxes = data_batch[4]
        num_grasps = data_batch[5]
        rel_mat = data_batch[6]
        gt_grasp_inds = data_batch[7]

        # object detection
        if self.training:
            self.iter_counter += 1
        self.batch_size = im_data.size(0)

        # feed image data to base model to obtain base feature map
        base_feat = self.FeatExt(im_data)
        ### GENERATE ROIs
        rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(base_feat, im_info, gt_boxes, num_boxes)
        if self.training:
            rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = \
                self._get_header_train_data(rois, gt_boxes, num_boxes)
            pos_rois_labels = [(rois_label[i * rois.size(1): (i + 1) * rois.size(1)] > 0) for i in range(self.batch_size)]
            od_rois = [rois[i][pos_rois_labels[i]].data for i in range(self.batch_size)]
        else:
            rois_label, rois_target, rois_inside_ws, rois_outside_ws = None, None, None, None
            od_rois = rois.data
        pooled_feat = self._roi_pooling(base_feat, rois)

        ### OBJECT DETECTION
        cls_score, cls_prob, bbox_pred = self._get_obj_det_result(pooled_feat)
        RCNN_loss_bbox, RCNN_loss_cls = 0, 0
        if self.training:
            RCNN_loss_bbox, RCNN_loss_cls = self._obj_det_loss_comp(cls_score, cls_prob, bbox_pred, rois_label, rois_target,
                                                                    rois_inside_ws, rois_outside_ws)
        cls_prob = cls_prob.contiguous().view(self.batch_size, rois.size(1), -1)
        bbox_pred = bbox_pred.contiguous().view(self.batch_size, rois.size(1), -1)

        ### VISUAL MANIPULATION RELATIONSHIP DETECTION
        # for object detection before relationship detection
        if self.training:
            od_cls_prob = [cls_prob[i][pos_rois_labels[i]].data for i in range(self.batch_size)]
            od_bbox_pred = [bbox_pred[i][pos_rois_labels[i]].data for i in range(self.batch_size)]
        else:
            od_cls_prob = cls_prob.data
            od_bbox_pred = bbox_pred.data

        # generate object RoIs.
        obj_rois, obj_num = torch.Tensor([]).type_as(rois), torch.Tensor([]).type_as(num_boxes)
        # online data
        if not self.training or cfg.TRAIN.VMRN.TRAINING_DATA in ('all', 'online'):
            obj_rois, obj_num = self._object_detection(od_rois, od_cls_prob, od_bbox_pred, self.batch_size, im_info.data)
        # offline data
        if self.training and cfg.TRAIN.VMRN.TRAINING_DATA in ('all', 'offline'):
            for i in range(self.batch_size):
                img_ind = (i * torch.ones(num_boxes[i].item(), 1)).type_as(gt_boxes)
                obj_rois = torch.cat([obj_rois, torch.cat([img_ind, gt_boxes[i][:num_boxes[i]]], 1)])
            obj_num = torch.cat([obj_num, num_boxes])

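        # obj_rois columns: [img_ind, x1, y1, x2, y2, cls]; obj_num holds the number of objects per image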
        obj_labels = torch.Tensor([]).type_as(gt_boxes).long()
        if obj_rois.size(0) > 0:
            obj_labels = obj_rois[:, 5]
            obj_rois = obj_rois[:, :5]

        VMRN_rel_loss_cls = 0
        if (obj_num > 1).sum().item() > 0:
            rel_cls_score, rel_cls_prob = self._get_rel_det_result(base_feat, obj_rois, obj_num)
            if self.training:
                obj_pair_rel_label = self._generate_rel_labels(obj_rois, gt_boxes, obj_num, rel_mat, rel_cls_prob.size(0))
                VMRN_rel_loss_cls = self._rel_det_loss_comp(obj_pair_rel_label.type_as(gt_boxes).long(), rel_cls_score)
            else:
                rel_cls_prob = self._rel_cls_prob_post_process(rel_cls_prob)
        else:
            rel_cls_prob = torch.Tensor([]).type_as(cls_prob)

        rel_result = None
        if not self.training:
            if obj_rois.numel() > 0:
                pred_boxes = obj_rois.data[:,1:5]
                pred_boxes[:, 0::2] /= im_info[0][3].item()
                pred_boxes[:, 1::2] /= im_info[0][2].item()
                rel_result = (pred_boxes, obj_labels, rel_cls_prob.data)
            else:
                rel_result = (obj_rois.data, obj_labels, rel_cls_prob.data)

        ### ROI-BASED GRASP DETECTION
        if self.training:
            rois_overlaps = bbox_overlaps_batch(rois, gt_boxes)
            # bs x N_{rois}
            _, rois_inds = torch.max(rois_overlaps, dim=2)
            rois_inds += 1
            grasp_rois_mask = rois_label.view(-1) > 0

            if (grasp_rois_mask > 0).sum().item() > 0:
                grasp_feat = self._MGN_head_to_tail(pooled_feat[grasp_rois_mask])
                grasp_rois = rois.view(-1, 5)[grasp_rois_mask]
                # process grasp ground truth, return: N_{gr_rois} x N_{Gr_gt} x 5
                grasp_gt_xywhc = points2labels(gt_grasps)
                grasp_gt_xywhc = self._assign_rois_grasps(grasp_gt_xywhc, gt_grasp_inds, rois_inds)
                grasp_gt_xywhc = grasp_gt_xywhc[grasp_rois_mask]
            else:
                # when there are no positive rois, return dummy results
                grasp_loc = torch.Tensor([]).type_as(gt_grasps)
                grasp_prob = torch.Tensor([]).type_as(gt_grasps)
                grasp_bbox_loss = torch.Tensor([0]).type_as(gt_grasps)
                grasp_cls_loss = torch.Tensor([0]).type_as(gt_grasps)
                grasp_conf_label = torch.Tensor([-1]).type_as(rois_label)
                grasp_all_anchors = torch.Tensor([]).type_as(gt_grasps)
                return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label, \
                    grasp_loc, grasp_prob, grasp_bbox_loss, grasp_cls_loss, grasp_conf_label, grasp_all_anchors
        else:
            grasp_feat = self._MGN_head_to_tail(pooled_feat)

        # N_{gr_rois} x W x H x A*5, N_{gr_rois} x W x H x A*2
        grasp_loc, grasp_conf = self.FCGN_classifier(grasp_feat)
        feat_height, feat_width = grasp_conf.size(1), grasp_conf.size(2)
        # reshape grasp_loc and grasp_conf
        grasp_loc = grasp_loc.contiguous().view(grasp_loc.size(0), -1, 5)
        grasp_conf = grasp_conf.contiguous().view(grasp_conf.size(0), -1, 2)
        grasp_prob = F.softmax(grasp_conf, 2)

        # 2. calculate grasp loss
        grasp_bbox_loss, grasp_cls_loss, grasp_conf_label = 0, 0, None
        if self.training:
            # N_{gr_rois} x K*A x 5
            grasp_all_anchors = self._generate_anchors(feat_height, feat_width, grasp_rois)
            grasp_bbox_loss, grasp_cls_loss, grasp_conf_label = self._grasp_loss_comp(grasp_rois,
                grasp_conf, grasp_loc, grasp_gt_xywhc, grasp_all_anchors, feat_height, feat_width)
        else:
            # bs*N x K*A x 5
            grasp_all_anchors = self._generate_anchors(feat_height, feat_width, rois.view(-1, 5))

        return rois, cls_prob, bbox_pred, rel_result, rpn_loss_cls, rpn_loss_bbox, \
                RCNN_loss_cls, RCNN_loss_bbox, VMRN_rel_loss_cls, rois_label, \
                grasp_loc, grasp_prob, grasp_bbox_loss, grasp_cls_loss, grasp_conf_label, grasp_all_anchors
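
For context, here is a hypothetical training step that consumes the tuple returned above. The names model, optimizer and the contents of data_batch are assumptions; the actual training script may weight or average the individual losses differently.

outputs = model(data_batch)
(rois, cls_prob, bbox_pred, rel_result, rpn_loss_cls, rpn_loss_bbox,
 RCNN_loss_cls, RCNN_loss_bbox, VMRN_rel_loss_cls, rois_label,
 grasp_loc, grasp_prob, grasp_bbox_loss, grasp_cls_loss,
 grasp_conf_label, grasp_all_anchors) = outputs

# sum the detection, relationship and grasp losses into a single objective
total_loss = (rpn_loss_cls + rpn_loss_bbox
              + RCNN_loss_cls + RCNN_loss_bbox
              + VMRN_rel_loss_cls
              + grasp_bbox_loss + grasp_cls_loss)

optimizer.zero_grad()
total_loss.backward()
optimizer.step()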
Example #3
    def forward(self, data_batch):
        im_data = data_batch[0]
        im_info = data_batch[1]
        gt_boxes = data_batch[2]
        gt_grasps = data_batch[3]
        num_boxes = data_batch[4]
        num_grasps = data_batch[5]
        gt_grasp_inds = data_batch[6]
        batch_size = im_data.size(0)
        if self.training:
            self.iter_counter += 1

        # For the Jacquard dataset, the bounding box labels are set to -1. For training, we set them to 1,
        # which does not affect the training process.
        if self.training:
            if gt_boxes[:, :, -1].sum().item() < 0:
                gt_boxes[:, :, -1] = 1

        for i in range(batch_size):
            if torch.sum(gt_grasp_inds[i]).item() == 0:
                gt_grasp_inds[i, :num_grasps[i].item()] = 1

        # features
        base_feat = self.FeatExt(im_data)

        # generate rois of RCNN
        rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(
            base_feat, im_info, gt_boxes, num_boxes)
        if not self.use_objdet_branch:
            rois_scores = rois[:, :, 5:].clone()
            rois = rois[:, :, :5].clone()

        if self.training:
            rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = \
                self._get_header_train_data(rois, gt_boxes, num_boxes)
        else:
            rois_label, rois_target, rois_inside_ws, rois_outside_ws = None, None, None, None
        pooled_feat = self._roi_pooling(base_feat, rois)

        cls_prob, bbox_pred, RCNN_loss_bbox, RCNN_loss_cls = \
            None, None, torch.Tensor([0]).type_as(rois), torch.Tensor([0]).type_as(rois)
        if self.use_objdet_branch:
            # object detection branch
            cls_score, cls_prob, bbox_pred = self._get_obj_det_result(
                pooled_feat)
            if self.training:
                RCNN_loss_bbox, RCNN_loss_cls = self._obj_det_loss_comp(
                    cls_score, cls_prob, bbox_pred, rois_label, rois_target,
                    rois_inside_ws, rois_outside_ws)
            cls_prob = cls_prob.view(batch_size, rois.size(1), -1)
            bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1)
        else:
            cls_prob = torch.cat([1 - rois_scores, rois_scores], dim=-1)

        # grasp detection branch
        # 1. obtaining grasp features of the positive ROIs and prepare grasp training data
        if self.training:
            rois_overlaps = bbox_overlaps_batch(rois, gt_boxes)
            # bs x N_{rois}
            _, rois_inds = torch.max(rois_overlaps, dim=2)
            rois_inds += 1
            grasp_rois_mask = rois_label.view(-1) > 0

            if (grasp_rois_mask > 0).sum().item() > 0:
                grasp_feat = self._MGN_head_to_tail(
                    pooled_feat[grasp_rois_mask])
                grasp_rois = rois.view(-1, 5)[grasp_rois_mask]
                # process grasp ground truth, return: N_{gr_rois} x N_{Gr_gt} x 5
                grasp_gt_xywhc = points2labels(gt_grasps)
                grasp_gt_xywhc = self._assign_rois_grasps(
                    grasp_gt_xywhc, gt_grasp_inds, rois_inds)
                grasp_gt_xywhc = grasp_gt_xywhc[grasp_rois_mask]
            else:
                # when there are no positive rois, return dummy results
                grasp_loc = torch.Tensor([]).type_as(gt_grasps)
                grasp_prob = torch.Tensor([]).type_as(gt_grasps)
                grasp_bbox_loss = torch.Tensor([0]).type_as(gt_grasps)
                grasp_cls_loss = torch.Tensor([0]).type_as(gt_grasps)
                grasp_conf_label = torch.Tensor([-1]).type_as(rois_label)
                grasp_all_anchors = torch.Tensor([]).type_as(gt_grasps)
                return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label, \
                    grasp_loc, grasp_prob, grasp_bbox_loss, grasp_cls_loss, grasp_conf_label, grasp_all_anchors
        else:
            grasp_feat = self._MGN_head_to_tail(pooled_feat)

        # N_{gr_rois} x W x H x A*5, N_{gr_rois} x W x H x A*2
        grasp_loc, grasp_conf = self.FCGN_classifier(grasp_feat)
        feat_height, feat_width = grasp_conf.size(1), grasp_conf.size(2)
        # reshape grasp_loc and grasp_conf
        grasp_loc = grasp_loc.contiguous().view(grasp_loc.size(0), -1, 5)
        grasp_conf = grasp_conf.contiguous().view(grasp_conf.size(0), -1, 2)
        grasp_prob = F.softmax(grasp_conf, 2)

        # 2. calculate grasp loss
        grasp_bbox_loss, grasp_cls_loss, grasp_conf_label = 0, 0, None
        if self.training:
            # N_{gr_rois} x K*A x 5
            grasp_all_anchors = self._generate_anchors(feat_height, feat_width,
                                                       grasp_rois)
            grasp_bbox_loss, grasp_cls_loss, grasp_conf_label = self._grasp_loss_comp(
                grasp_rois, grasp_conf, grasp_loc, grasp_gt_xywhc,
                grasp_all_anchors, feat_height, feat_width)
        else:
            # bs*N x K*A x 5
            grasp_all_anchors = self._generate_anchors(feat_height, feat_width,
                                                       rois.view(-1, 5))

        return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label,\
               grasp_loc, grasp_prob, grasp_bbox_loss, grasp_cls_loss, grasp_conf_label, grasp_all_anchors
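
Both forwards above match each RoI to its best-overlapping ground-truth box through bbox_overlaps_batch before assigning grasp annotations. The following is a minimal single-image IoU sketch of that matching step; the repository's batched helper additionally handles the batch dimension and padded boxes, so treat this only as an illustration.

import torch

def bbox_overlaps(boxes, gt_boxes):
    # boxes: N x 4 (x1, y1, x2, y2), gt_boxes: K x 4 -> N x K IoU matrix
    N, K = boxes.size(0), gt_boxes.size(0)
    area = ((boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)).view(N, 1)
    gt_area = ((gt_boxes[:, 2] - gt_boxes[:, 0] + 1) * (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)).view(1, K)
    ix1 = torch.max(boxes[:, None, 0], gt_boxes[None, :, 0])
    iy1 = torch.max(boxes[:, None, 1], gt_boxes[None, :, 1])
    ix2 = torch.min(boxes[:, None, 2], gt_boxes[None, :, 2])
    iy2 = torch.min(boxes[:, None, 3], gt_boxes[None, :, 3])
    iw = (ix2 - ix1 + 1).clamp(min=0)
    ih = (iy2 - iy1 + 1).clamp(min=0)
    inter = iw * ih
    return inter / (area + gt_area - inter)

# per-RoI index of the best-matching ground-truth box (1-based, as in rois_inds above),
# illustrated for a single image where rois is N x 5 with the batch index in column 0:
# _, inds = torch.max(bbox_overlaps(rois[:, 1:5], gt_boxes[:, :4]), dim=1)
# inds += 1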
Example #4
    def forward_without_od(self, data_batch):
        im_data = data_batch[0]
        im_info = data_batch[1]
        gt_boxes = data_batch[2]
        gt_grasps = data_batch[3]
        num_boxes = data_batch[4]
        num_grasps = data_batch[5]
        rel_mat = data_batch[6]
        gt_grasp_inds = data_batch[7]

        # object detection
        if self.training:
            self.iter_counter += 1
        self.batch_size = im_data.size(0)

        # feed image data to base model to obtain base feature map
        base_feat = self.FeatExt(im_data)

        # object detection loss
        rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox = 0, 0, 0, 0

        ### VISUAL MANIPULATION RELATIONSHIP DETECTION
        # generate object RoIs.
        obj_rois, obj_num = torch.Tensor([]).type_as(gt_boxes), torch.Tensor([]).type_as(num_boxes)
        # offline data
        for i in range(self.batch_size):
            img_ind = (i * torch.ones(num_boxes[i].item(), 1)).type_as(gt_boxes)
            obj_rois = torch.cat([obj_rois, torch.cat([img_ind, (gt_boxes[i][:num_boxes[i]])], 1)])
        obj_num = torch.cat([obj_num, num_boxes])

        obj_labels = obj_rois[:, 5]
        obj_rois = obj_rois[:, :5]

        VMRN_rel_loss_cls, rel_reg_loss = 0, 0
        if (obj_num > 1).sum().item() > 0:
            rel_cls_score, rel_cls_prob, rel_reg_loss = self._get_rel_det_result(base_feat, obj_rois, obj_num, im_info)
            if self.training:
                obj_pair_rel_label = self._generate_rel_labels(obj_rois, gt_boxes, obj_num, rel_mat,
                                                               rel_cls_prob.size(0))
                VMRN_rel_loss_cls = self._rel_det_loss_comp(obj_pair_rel_label.type_as(gt_boxes).long(), rel_cls_score)
            else:
                rel_cls_prob = self._rel_cls_prob_post_process(rel_cls_prob)
        else:
            rel_cls_prob = torch.Tensor([]).type_as(gt_boxes)

        rel_result = None
        if not self.training:
            pred_boxes = obj_rois[:, 1:5].view(-1, 4)
            pred_boxes[:, 0::2] /= im_info[0][3].item()
            pred_boxes[:, 1::2] /= im_info[0][2].item()
            rel_result = (pred_boxes.data, obj_labels.data, rel_cls_prob.data)

        ### ROI-BASED GRASP DETECTION
        img_ind = torch.cat([(i * torch.ones(1, gt_boxes.shape[1], 1)).type_as(gt_boxes)
                             for i in range(self.batch_size)], dim=0)
        rois = torch.cat([img_ind, gt_boxes[:, :, :4]], dim=-1)
        rois_inds = torch.ones((self.batch_size, rois.shape[1])).type_as(rois).long()
        for i in range(self.batch_size):
            rois_inds[i][:num_boxes[i].item()] = torch.arange(1, num_boxes[i].item() + 1)

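        # column 4 of gt_boxes is the class label, so this mask keeps only real (non-padded) boxes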
        grasp_rois_mask = gt_boxes[:, :, 4].view(-1) > 0
        grasp_rois = rois.view(-1, 5)[grasp_rois_mask]
        pooled_feat = self._roi_pooling(base_feat, grasp_rois)
        grasp_feat = self._MGN_head_to_tail(pooled_feat)
        if self.training:
            # process grasp ground truth, return: N_{gr_rois} x N_{Gr_gt} x 5
            grasp_gt_xywhc = points2labels(gt_grasps)
            grasp_gt_xywhc = self._assign_rois_grasps(grasp_gt_xywhc, gt_grasp_inds, rois_inds)
            grasp_gt_xywhc = grasp_gt_xywhc[grasp_rois_mask]

        # N_{gr_rois} x W x H x A*5, N_{gr_rois} x W x H x A*2
        grasp_loc, grasp_conf = self.FCGN_classifier(grasp_feat)
        feat_height, feat_width = grasp_conf.size(1), grasp_conf.size(2)
        # reshape grasp_loc and grasp_conf
        grasp_loc = grasp_loc.contiguous().view(grasp_loc.size(0), -1, 5)
        grasp_conf = grasp_conf.contiguous().view(grasp_conf.size(0), -1, 2)
        grasp_prob = F.softmax(grasp_conf, 2)

        # 2. calculate grasp loss
        grasp_bbox_loss, grasp_cls_loss, grasp_conf_label = 0, 0, None
        if self.training:
            # N_{gr_rois} x K*A x 5
            grasp_all_anchors = self._generate_anchors(feat_height, feat_width, grasp_rois)
            grasp_bbox_loss, grasp_cls_loss, grasp_conf_label = self._grasp_loss_comp(grasp_rois,
                            grasp_conf, grasp_loc, grasp_gt_xywhc, grasp_all_anchors, feat_height, feat_width)
        else:
            # bs*N x K*A x 5
            grasp_all_anchors = self._generate_anchors(feat_height, feat_width, rois.view(-1, 5))

        cls_prob, bbox_pred, rois_label = None, None, None
        if not self.training:
            cls_prob = torch.zeros((1, num_boxes[0].item(), self.n_classes)).type_as(gt_boxes)
            for i in range(num_boxes[0].item()):
                cls_prob[0, i, gt_boxes[0, i, -1].long().item()] = 1
            bbox_pred = torch.zeros((1, num_boxes[0].item(), 4 if self.class_agnostic
                                                                else 4 * self.n_classes)).type_as(gt_boxes)

        return rois, cls_prob, bbox_pred, rel_result, rpn_loss_cls, rpn_loss_bbox, \
               RCNN_loss_cls, RCNN_loss_bbox, VMRN_rel_loss_cls, rel_reg_loss, rois_label, \
               grasp_loc, grasp_prob, grasp_bbox_loss, grasp_cls_loss, grasp_conf_label, grasp_all_anchors
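
Finally, a hypothetical evaluation-time call to forward_without_od; model and the input tensors are assumptions. Because this variant skips the object-detection head, the ground-truth boxes serve as RoIs and only the relationship and grasp outputs are meaningful.

model.eval()
with torch.no_grad():
    data_batch = [im_data, im_info, gt_boxes, gt_grasps,
                  num_boxes, num_grasps, rel_mat, gt_grasp_inds]
    (rois, cls_prob, bbox_pred, rel_result, _, _, _, _, _, _, _,
     grasp_loc, grasp_prob, grasp_bbox_loss, grasp_cls_loss,
     grasp_conf_label, grasp_all_anchors) = model.forward_without_od(data_batch)

# rel_result packs the rescaled boxes, their class labels and the pairwise relationship probabilities
pred_boxes, obj_labels, rel_cls_prob = rel_result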