Example #1
    def forward(self, features, rois):
        """
        Args:
            features: NCHW images
            rois: Bx5 boxes. First column is the index into N. The other 4
            columns are xyxy.
        """
        assert rois.dim() == 2 and rois.size(1) == 5

        return ps_roi_align(features, rois, self.out_size,
                            self.spatial_scale, self.sample_num)
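
A minimal usage sketch of this forward pass, assuming the torchvision implementation of ps_roi_align and a hypothetical wrapper module (here named PSRoIAlignHead) that stores out_size, spatial_scale and sample_num; the channel count of the NCHW features must be divisible by out_size * out_size for position-sensitive pooling:

import torch
from torchvision.ops import ps_roi_align  # assumed source of ps_roi_align


class PSRoIAlignHead(torch.nn.Module):  # hypothetical wrapper around the forward() above
    def __init__(self, out_size, spatial_scale, sample_num):
        super().__init__()
        self.out_size = out_size
        self.spatial_scale = spatial_scale
        self.sample_num = sample_num

    def forward(self, features, rois):
        assert rois.dim() == 2 and rois.size(1) == 5
        return ps_roi_align(features, rois, self.out_size,
                            self.spatial_scale, self.sample_num)


out_size = 3
head = PSRoIAlignHead(out_size, spatial_scale=1.0, sample_num=2)
features = torch.rand(1, 2 * out_size * out_size, 32, 32)  # NCHW; C divisible by out_size**2
rois = torch.tensor([[0., 4., 4., 20., 20.]])               # [index into N, x1, y1, x2, y2]
out = head(features, rois)                                   # -> shape (1, 2, out_size, out_size)
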
Example #2
def script_fn(input, rois, pool_size):
    # type: (Tensor, Tensor, int) -> Tensor
    return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]
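
A short sketch of compiling and calling this TorchScript-annotated function, assuming ops is torchvision.ops; the trailing [0] selects the pooled output of the first RoI:

import torch
from torch import Tensor
from torchvision import ops


def script_fn(input, rois, pool_size):
    # type: (Tensor, Tensor, int) -> Tensor
    return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]


scripted = torch.jit.script(script_fn)  # compile the function with TorchScript

pool_size = 3
x = torch.rand(1, 2 * pool_size * pool_size, 16, 16)  # channels divisible by pool_size**2
rois = torch.tensor([[0., 0., 0., 8., 8.]])           # [batch index, x1, y1, x2, y2]
out = scripted(x, rois, pool_size)                    # pooled map of the first RoI
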
Example #3
    def forward(self, images, targets=None):
        """
        Arguments:
            -images: Tensor[N,C,H,W], here H=W=416
            -targets: Tensor[m, 6], where m is the number of annotated bboxes in the batch.
                6 for [index_in_a_batch, class_num, x_center, y_center, w, h].
                Here (x_center, y_center, w, h) are scaled to (0, 1).

        Returns:
            -loss: float
            -output: the new score for every bbox, Tensor[m, 8].
                8 for [image_i, x1, y1, x2, y2, object_conf, class_score, class_pred]
                Here (x1, y1, x2, y2) are scaled to the image size (e.g., 416)
        """
        ########################################
        # Get candidate boxes from base detector
        ########################################
        # for tiny yolov3, feature_map: (1,256,26,26); output_tensor: (1, 2535, 85). No gradient.
        feature_map, output_tensor = self.base_detector(images)

        # NMS in a batch manner, CPP version
        detections = non_max_suppression_cpp(output_tensor.cpu(),
                                             conf_thresh=self.conf_thresh)

        # detections: List[tensor[n,7+c]] -> boxes: tensor[N,8+c], (8+c) for
        # (image_i, x1, y1, x2, y2, object_conf, class_score, class_pred,
        #  scores_of_c_classes); xyxy are scaled to image size
        boxes = []
        for image_i, detection_i in enumerate(
                detections):  # iter through a batch
            if detection_i is not None:
                boxes_i = torch.zeros((len(detection_i), 8 + self.class_num))
                boxes_i[:, 0] = image_i
                boxes_i[:, 1:] = detection_i
                boxes.append(boxes_i)
        if len(boxes) > 0:
            boxes = torch.cat(boxes, 0).to(self.device)
        else:
            boxes = torch.empty((0, 8 + self.class_num)).to(self.device)

        ##################
        # Generate outputs
        ##################
        # obtain the RoI score maps
        roi_score_map = self.fcn_layers(feature_map)

        # RoI cropping in a batch manner
        cropped_img_feature = ps_roi_align(roi_score_map,
                                           boxes[:, :5], (7, 7),
                                           spatial_scale=1. / 16)

        # combine yolo outputs and the refinement vector to obtain the masks;
        # masks[:, 0]: p(background), masks[:, 1]: p(foreground)
        regress_param, refinement_vector = self.refinement_head(
            cropped_img_feature)
        yolo_vector = torch.cat((boxes[:, 5:6], boxes[:, 8:]),
                                1)  # (obj_conf, scores_of_c_classes)
        yolo_vector.requires_grad = False
        masks = self.ensemble_head(refinement_vector, yolo_vector)

        # generate masks, i.e., the new confidence score
        positive_masks = (masks[:, 1] > self.refine_threshold)
        output = torch.cat(
            (boxes[positive_masks, :1],
             box_regress(regress_param[positive_masks], boxes[positive_masks,
                                                              1:5]),
             masks[positive_masks, 1:], boxes[positive_masks, 6:8]), -1)

        # sort according to the confidence
        output = output[torch.sort(output[:, 5],
                                   descending=True).indices].cpu()

        # print the grad
        """
        def save_grad():
            def hook(grad):
                print(grad)
            return hook
        # register gradient hook for masks
        if masks.requires_grad == True:
            masks.register_hook(save_grad())
        """

        ########################################
        # get labels and loss, only for training
        ########################################
        if targets is not None:

            # transform targets from xywh to x1y1x2y2; scale from (0,1) to (0,image_size)
            targets[:, 2:] = xywh2xyxy(targets[:, 2:])
            targets[:, 2:] *= images.shape[-1]  # restore to image size

            # rebuild box
            boxes_cpu = torch.cat((boxes[:, :1], boxes[:, 7:8], boxes[:, 1:5]),
                                  1).cpu()

            # use single process to obtain binary label for each box
            t = time.time()
            iou_labels, target_location = obtain_iou_labels(
                boxes_cpu, targets)  # on cpu is faster than gpu
            pos_filter = (iou_labels > self.iou_thresh[1])
            neg_filter = (iou_labels < self.iou_thresh[0])

            # log infos
            conf_1 = boxes[:, 5].cpu()
            conf_2 = masks[:, 1].cpu()

            conf_1_pos, conf_2_pos = conf_1[
                iou_labels.flatten() > 0.5], conf_2[iou_labels.flatten() > 0.5]
            conf_1_neg, conf_2_neg = conf_1[
                iou_labels.flatten() < 0.5], conf_2[iou_labels.flatten() < 0.5]
            confs = dict(conf_1_pos=conf_1_pos,
                         conf_1_neg=conf_1_neg,
                         conf_2_pos=conf_2_pos,
                         conf_2_neg=conf_2_neg)

            total_sample = len(iou_labels)
            refined = positive_masks.sum()
            true = pos_filter.sum()
            tps = (positive_masks.cpu() * (pos_filter.flatten())).sum().float()
            print(f"preocess time: {time.time()-t}", f"RoIs: {total_sample}",
                  f"refined_RoIs: {refined}", f"true_RoIs: {true}",
                  f"tps: {tps}")
            metric = dict(total=total_sample,
                          true=true,
                          positive=refined,
                          tp=tps,
                          conf=confs)

            # balance positive : negative = 1 : balance_fac
            pos_idx = np.where(
                pos_filter.flatten())[0]  # np.where(cond) returns a tuple
            neg_idx = np.where(neg_filter.flatten())[0]
            top_k = min(len(pos_idx) * self.balance_fac, len(neg_idx))

            # one-hot encoding for labels
            label_onehot = torch.tensor([1.0, 0.0]).repeat(masks.shape[0], 1)
            for i in pos_idx:
                label_onehot[i] = torch.tensor(
                    [0.0, 1.0])  # true label: [0, 1], false label: [1, 0]

            sample_filter = pos_filter.flatten().clone(
            )  # Tensor([True, False, ...])
            selected_neg_idx = neg_idx[random.sample(range(len(neg_idx)),
                                                     k=top_k)]
            sample_filter[selected_neg_idx] = True

            label_onehot = label_onehot[sample_filter]
            masks = masks[sample_filter]

            # focal loss for masks
            masks_loss = FocalLoss(self.device,
                                   self.alpha)(masks,
                                               label_onehot.to(self.device))
            # loss =  nn.BCELoss()(masks, label_onehot)

            # confidence loss for positive and negative samples
            conf_label = torch.zeros(len(boxes_cpu))
            for i in pos_idx:
                conf_label[i] = 1.0
            conf_loss = nn.BCELoss(reduction="sum")(
                refinement_vector[sample_filter, 0],
                conf_label[sample_filter].to(self.device))

            # box regression loss for positive samples
            loss_xy, loss_wh = regression_loss(
                regress_param[pos_filter.flatten()],
                target_location[pos_filter.flatten()].to(self.device),
                boxes[pos_filter.flatten(), 1:5])

            # category loss for positive samples
            class_label = torch.zeros((len(boxes_cpu), self.class_num))
            # index rows by the box index so the boolean selection below picks the filled rows
            for idx in pos_idx:
                class_label[idx, int(boxes_cpu[idx, 1])] = 1.0
            category_loss = nn.BCELoss(reduction="sum")(
                refinement_vector[pos_filter.flatten(), 1:],
                class_label[pos_filter.flatten()].to(self.device))

            loss = masks_loss + (conf_loss +
                                 category_loss) / self.loss_lambda[0] + (
                                     loss_xy + loss_wh) / self.loss_lambda[1]
            print(masks_loss.data, conf_loss.data/self.loss_lambda[0], category_loss.data/self.loss_lambda[0], \
                loss_xy.data/self.loss_lambda[1], loss_wh.data/self.loss_lambda[1])

        return output if targets is None else (output, loss, metric)
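
To isolate the RoI-cropping step above in a self-contained form, here is a minimal sketch with made-up shapes that mirror the comments (416x416 images, a 26x26 score map, hence spatial_scale = 26/416 = 1/16), assuming torchvision's ps_roi_align:

import torch
from torchvision.ops import ps_roi_align

# Hypothetical shapes mirroring the comments above: the score map must provide
# K * 7 * 7 channels for a 7x7 position-sensitive output with K channels per bin.
K = 2
roi_score_map = torch.rand(1, K * 7 * 7, 26, 26)

# boxes[:, :5] layout: (image_index, x1, y1, x2, y2), xyxy in image coordinates (0..416)
rois = torch.tensor([[0., 32., 32., 160., 160.],
                     [0., 200., 80., 400., 300.]])

cropped = ps_roi_align(roi_score_map, rois, (7, 7), spatial_scale=1. / 16)
print(cropped.shape)  # torch.Size([2, 2, 7, 7]) -> one K-channel 7x7 map per RoI
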
Example #4
    def forward(self, x, boxes, image_shapes):
        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> Tensor
        """
        Arguments:
            x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
                all the same number of channels, but they can have different sizes.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled.
        Returns:
            result (Tensor)
        """
        x_filtered = []
        for k, v in x.items():
            if k in self.featmap_names:
                x_filtered.append(v)
        num_levels = len(x_filtered)
        rois = self.convert_to_roi_format(boxes)
        if self.scales is None:
            self.setup_scales(x_filtered, image_shapes)

        scales = self.scales
        assert scales is not None

        if num_levels == 1:
            return ps_roi_align(x_filtered[0],
                                rois,
                                output_size=self.output_size,
                                spatial_scale=scales[0],
                                sampling_ratio=self.sampling_ratio)

        mapper = self.map_levels
        assert mapper is not None

        levels = mapper(boxes)

        num_rois = len(rois)
        num_channels = x_filtered[0].shape[1]

        dtype, device = x_filtered[0].dtype, x_filtered[0].device
        result = torch.zeros(
            (
                num_rois,
                num_channels,
            ) + self.output_size,
            dtype=dtype,
            device=device,
        )

        tracing_results = []
        for level, (per_level_feature,
                    scale) in enumerate(zip(x_filtered, scales)):
            idx_in_level = torch.nonzero(levels == level).squeeze(1)
            rois_per_level = rois[idx_in_level]

            result_idx_in_level = ps_roi_align(
                per_level_feature,
                rois_per_level,
                output_size=self.output_size,
                spatial_scale=scale,
                sampling_ratio=self.sampling_ratio)

            if torchvision._is_tracing():
                tracing_results.append(result_idx_in_level.to(dtype))
            else:
                result[idx_in_level] = result_idx_in_level

        if torchvision._is_tracing():
            result = _onnx_merge_levels(levels, tracing_results)

        return result
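
The convert_to_roi_format call above flattens the per-image box lists into the single Bx5 layout expected by ps_roi_align; below is a standalone sketch of that conversion (close to, but not necessarily identical to, the library's implementation):

import torch
from torch import Tensor
from typing import List


def convert_to_roi_format(boxes):
    # type: (List[Tensor]) -> Tensor
    # Concatenate per-image (N_i, 4) boxes and prepend the image index,
    # giving the (sum N_i, 5) layout [batch_index, x1, y1, x2, y2].
    concat_boxes = torch.cat(boxes, dim=0)
    dtype, device = concat_boxes.dtype, concat_boxes.device
    ids = torch.cat(
        [torch.full((b.shape[0], 1), i, dtype=dtype, device=device)
         for i, b in enumerate(boxes)],
        dim=0)
    return torch.cat([ids, concat_boxes], dim=1)


boxes = [torch.tensor([[0., 0., 10., 10.]]),       # image 0: one box
         torch.tensor([[5., 5., 20., 20.],
                       [1., 1., 8., 8.]])]         # image 1: two boxes
rois = convert_to_roi_format(boxes)                # -> shape (3, 5)
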