Example #1
    def forward(self, images, targets=None):
        """
        Arguments:
            images: Image batch, normalized [NxCxHxW]
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[BoxList] that contains additional
                fields such as `scores`, `labels` and `mask` (for Mask R-CNN models).
        """
        image_sizes = [tuple(images.shape[-2:])] * images.shape[0]

        features = self.backbone(images)

        # The RPN expects the features as a list of per-level feature maps, so torch.chunk may be needed here.
        image_list = ImageList(images, image_sizes)
        try:
            proposals, proposal_losses = self.rpn(image_list, features,
                                                  targets)
        except Exception as e:
            # Dirty data that was not cleaned upstream; re-raise so we do not
            # fall through with `proposals` undefined below.
            print(e)
            raise
        detections, detector_losses = self.roi_heads(features, proposals,
                                                     image_sizes, targets)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if targets is not None:
            return detections, features, losses
        else:
            return detections, features
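
A quick, hypothetical usage sketch of the forward above (`model` stands for an instance of the detector class defining this method, and the torchvision-style target dicts are an assumption, not part of the original snippet):

import torch

# Hypothetical: `model` is an instance of the class whose forward() is shown above.
images = torch.randn(2, 3, 800, 800)  # already-normalized NxCxHxW batch
targets = [{"boxes": torch.tensor([[10., 20., 200., 300.]]),
            "labels": torch.tensor([1])} for _ in range(2)]

model.train()
detections, features, losses = model(images, targets)  # with targets: losses included
total_loss = sum(losses.values())

model.eval()
with torch.no_grad():
    detections, features = model(images)  # without targets: no losses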
Example #2
 def test_incorrect_anchors(self):
     incorrect_sizes = ((2, 4, 8), (32, 8), )
     incorrect_aspects = (0.5, 1.0)
     anc = AnchorGenerator(incorrect_sizes, incorrect_aspects)
     image1 = torch.randn(3, 800, 800)
     image_list = ImageList(image1, [(800, 800)])
     feature_maps = [torch.randn(1, 50)]
     pytest.raises(ValueError, anc, image_list, feature_maps)
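
For contrast, a minimal sketch of a configuration that does not raise, assuming a recent torchvision (import paths vary slightly across versions): the generator expects one entry in `sizes` and one in `aspect_ratios` per feature-map level.

import torch
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList

# One (sizes, aspect_ratios) pair per feature-map level: a single level here.
anchor_generator = AnchorGenerator(sizes=((32, 64, 128),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

images = torch.randn(1, 3, 800, 800)
image_list = ImageList(images, [(800, 800)])
feature_maps = [torch.randn(1, 256, 50, 50)]

anchors = anchor_generator(image_list, feature_maps)  # one tensor per image
print(anchors[0].shape)  # torch.Size([22500, 4]): 50*50 locations x 9 anchors each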
Example #3
def demo():
    # Parse the input arguments.
    parser = argparse.ArgumentParser(description="Simple demo for real-time-panoptic model")
    parser.add_argument("--config-file", metavar="FILE", help="path to config", required=True)
    parser.add_argument("--pretrained-weight", metavar="FILE", help="path to pretrained_weight", required=True)
    parser.add_argument("--input", metavar="FILE", help="path to jpg/png file", required=True)
    parser.add_argument("--device", help="inference device", default='cuda')
    args = parser.parse_args()

    # General config object from given config files.
    cfg.merge_from_file(args.config_file)

    # Initialize model.
    model = RTPanoNet(
        backbone=cfg.model.backbone, 
        num_classes=cfg.model.panoptic.num_classes,
        things_num_classes=cfg.model.panoptic.num_thing_classes,
        pre_nms_thresh=cfg.model.panoptic.pre_nms_thresh,
        pre_nms_top_n=cfg.model.panoptic.pre_nms_top_n,
        nms_thresh=cfg.model.panoptic.nms_thresh,
        fpn_post_nms_top_n=cfg.model.panoptic.fpn_post_nms_top_n,
        instance_id_range=cfg.model.panoptic.instance_id_range)
    device = args.device
    model.to(device)
    model.load_state_dict(torch.load(args.pretrained_weight, map_location=device))

    # Print out model architecture for sanity checking.
    print(model)

    # Prepare for model inference.
    model.eval()
    input_image = Image.open(args.input)
    data = {'image': input_image}
    # data pre-processing
    normalize_transform = P.Normalize(mean=cfg.input.pixel_mean, std=cfg.input.pixel_std, to_bgr255=cfg.input.to_bgr255)
    transform = P.Compose([
        P.ToTensor(),
        normalize_transform,
    ])
    data = transform(data)
    print("Done with data preparation and model configuration.")
    with torch.no_grad():
        input_image_list = ImageList([data['image'].to(device)], image_sizes=[input_image.size[::-1]])
        panoptic_result, _ = model.forward(input_image_list)
        print("Done with model inference.")
        print("Process and visualizing the outputs...")
        instance_detection = [o.to('cpu') for o in panoptic_result["instance_segmentation_result"]]
        semseg_logics = [o.to('cpu') for o in panoptic_result["semantic_segmentation_result"]]
        semseg_prob = [torch.argmax(semantic_logit, dim=0) for semantic_logit in semseg_logics]

        seg_vis = visualize_segmentation_image(semseg_prob[0], input_image, cityscapes_colormap)
        Image.fromarray(seg_vis.astype('uint8')).save('semantic_segmentation_result.jpg')
        print("Saved semantic segmentation visualization in semantic_segmentation_result.jpg")
        det_vis = visualize_detection_image(instance_detection[0], input_image, cityscapes_instance_label_name)
        Image.fromarray(det_vis.astype('uint8')).save('instance_segmentation_result.jpg')
        print("Saved instance segmentation visualization in instance_segmentation_result.jpg")
        print("Demo finished.")
Example #4
def visualize_img(images, index, backbone, rpn, boxHead):
    """
    Run inference and visualization for one image
    :param images:
    :param index:
    :param backbone:
    :param rpn:
    :param boxHead:
    :return:
    """
    with torch.no_grad():
        # Take the features from the backbone
        backout = backbone(images)

        # The RPN implementation takes as first argument the following image list
        im_lis = ImageList(images, [(800, 1088)] * images.shape[0])
        # Then we pass the image list and the backbone output through the rpn
        rpnout = rpn(im_lis, backout)

        # The final output is
        # A list of proposal tensors: list:len(bz){(keep_topK,4)}
        proposals = [proposal[0:keep_topK, :] for proposal in rpnout[0]]
        # A list of features produced by the backbone's FPN levels: list:len(FPN){(bz,256,H_feat,W_feat)}
        fpn_feat_list = list(backout.values())

        feature_vectors = boxHead.MultiScaleRoiAlign(fpn_feat_list, proposals)

        class_logits, box_pred = boxHead(feature_vectors)
        class_logits = torch.softmax(class_logits, dim=1)  # todo: check softmax is applied everywhere

        # convert proposal to xywh
        proposal_torch = torch.cat(proposals, dim=0)  # x1 y1 x2 y2
        proposal_xywh = torch.zeros_like(proposal_torch, device=proposal_torch.device)
        proposal_xywh[:, 0] = ((proposal_torch[:, 0] + proposal_torch[:, 2]) / 2)
        proposal_xywh[:, 1] = ((proposal_torch[:, 1] + proposal_torch[:, 3]) / 2)
        proposal_xywh[:, 2] = torch.abs(proposal_torch[:, 2] - proposal_torch[:, 0])
        proposal_xywh[:, 3] = torch.abs(proposal_torch[:, 3] - proposal_torch[:, 1])

        # decode output
        prob_simp, class_simp, box_simp = utils.simplifyOutputs(class_logits, box_pred)
        # box_decoded: format x1, y1, x2, y2
        box_decoded = utils.decode_output(proposal_xywh, box_simp)

        # visualization: PreNMS
        prob_selected, class_selected, box_selected = selectResult(prob_simp, class_simp, box_decoded)
        plot_prediction(images, class_selected, box_selected, index=index, result_dir=dir_prenms)

        # Do whatever post-processing you find performs best
        post_nms_prob, post_nms_class, post_nms_box = boxHead.postprocess_detections(prob_simp, class_simp, box_decoded, conf_thresh=0.8,
                                                               keep_num_preNMS=200, keep_num_postNMS=3, IOU_thresh=0.5)

        # visualization: PostNMS
        assert post_nms_class.dim() == 1
        assert post_nms_box.dim() == 2
        plot_prediction(images, post_nms_class, post_nms_box, index=index, result_dir=dir_postnms)
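
The manual xyxy-to-center-format conversion above can also be written with torchvision.ops.box_convert where available; a small illustrative sketch (not part of the original code):

import torch
from torchvision.ops import box_convert

proposal_torch = torch.tensor([[10., 20., 110., 220.]])  # x1, y1, x2, y2
proposal_xywh = box_convert(proposal_torch, in_fmt="xyxy", out_fmt="cxcywh")
print(proposal_xywh)  # tensor([[ 60., 120., 100., 200.]]) -> cx, cy, w, h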
Example #5
    def forward(self,
                image: torch.Tensor,  # (batch_size, c, h, w)
                image_sizes: torch.Tensor,  # (batch_size, 2)
                boxes: torch.Tensor = None,  # (batch_size, max_boxes_in_batch, 4)
                box_classes: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        im_sizes = [(x[1].item(), x[0].item()) for x in image_sizes]
        image_list = ImageList(image, im_sizes)
        features = self.backbone.forward(image)
        objectness, rpn_box_regression = self._rpn_head(features)
        anchors: List[torch.Tensor] = self.anchor_generator(image_list, features)
        num_anchors_per_level = [o[0].numel() for o in objectness]
        objectness, rpn_box_regression = \
            concat_box_prediction_layers(objectness, rpn_box_regression)

        out = {'features': features,
               'objectness': objectness,
               'rpn_box_regression': rpn_box_regression,
               'anchors': anchors,
               'sizes': image_sizes,
               'num_anchors_per_level': num_anchors_per_level}
        if boxes is not None:
            labels, matched_gt_boxes = self.assign_targets_to_anchors(
                    anchors, object_utils.unpad(boxes))
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)

            sampled_pos_inds, sampled_neg_inds = self.sampler(labels)
            sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
            sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)

            sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

            objectness = objectness.flatten()

            labels = torch.cat(labels, dim=0)
            regression_targets = torch.cat(regression_targets, dim=0)

            loss_rpn_box_reg = F.l1_loss(
                    rpn_box_regression[sampled_pos_inds],
                    regression_targets[sampled_pos_inds],
                    reduction="sum",
            ) / (sampled_inds.numel())

            loss_objectness = F.binary_cross_entropy_with_logits(
                    objectness[sampled_inds], labels[sampled_inds]
            )
            self._loss_meters['rpn_cls_loss'](loss_objectness.item())
            self._loss_meters['rpn_reg_loss'](loss_rpn_box_reg.item())
            out["loss_objectness"] = loss_objectness
            out["loss_rpn_box_reg"] = loss_rpn_box_reg
            out["loss"] = loss_objectness + 10*loss_rpn_box_reg
        return out
Example #6
    def forward(
        self,
        images,  # type: List[Tensor]
        targets=None  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
        images = [img for img in images]
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place
            # once torchscript supports dict comprehension
            # this can be simplified as follows:
            # targets = [{k: v for k,v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError(
                    "images is expected to be a list of 3d tensors "
                    "of shape [C, H, W], got {}".format(image.shape))
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images)
        image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], [])
        for image_size in image_sizes:
            assert len(image_size) == 2
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets
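
`self.batch_images` is not shown in this snippet; as a rough sketch of what a torchvision-style implementation typically does (an assumption, not the original code), it zero-pads every image on the bottom/right to a common size, rounded up to a stride multiple, and stacks the results:

import torch
import torch.nn.functional as F

def batch_images_sketch(images, size_divisible=32):
    # Pad a list of CxHxW tensors to a shared (H, W) and stack them into one batch.
    max_h = max(img.shape[-2] for img in images)
    max_w = max(img.shape[-1] for img in images)
    # Round up so downstream feature-map strides divide the padded size evenly.
    max_h = (max_h + size_divisible - 1) // size_divisible * size_divisible
    max_w = (max_w + size_divisible - 1) // size_divisible * size_divisible
    padded = [F.pad(img, (0, max_w - img.shape[-1], 0, max_h - img.shape[-2]))
              for img in images]  # pad order: (left, right, top, bottom)
    return torch.stack(padded, dim=0)

# Two differently sized images become one (2, 3, 800, 1088) batch.
batch = batch_images_sketch([torch.randn(3, 768, 1024), torch.randn(3, 800, 1066)])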
Example #7
def inference(model, input, transform, device="cuda"):
    input_image = Image.open(input)
    data = {'image': input_image}
    # data pre-processing
    data = transform(data)
    with torch.no_grad():
        input_image_list = ImageList([data['image'].to(device)], image_sizes=[input_image.size[::-1]])
        panoptic_result, _ = model.forward(input_image_list)
        semseg_logics = [o.to('cpu') for o in panoptic_result["semantic_segmentation_result"]]
        # Export the result
        output = input.replace("/data/", "/output/")
        os.makedirs(parent(output), exist_ok=True)
        assert os.path.exists(parent(output))
        semseg_prob = [torch.argmax(semantic_logit, dim=0) for semantic_logit in semseg_logics]
        seg_vis = visualize_segmentation_image(semseg_prob[0], input_image, cityscapes_colormap_sky)
        Image.fromarray(seg_vis.astype('uint8')).save(output)
Example #8
def forward(self, images, targets=None):
    for i in range(len(images)):
        image = images[i]
        target = targets[i] if targets is not None else targets
        if image.dim() != 3:
            raise ValueError("images is expected to be a list of 3d tensors "
                             "of shape [C, H, W], got {}".format(image.shape))
        # image = self.normalize(image)
        # image, target = self.resize(image, target)
        images[i] = image
        if targets is not None:
            targets[i] = target
    image_sizes = [img.shape[-2:] for img in images]
    images = self.batch_images(images)
    image_list = ImageList(images, image_sizes)
    return image_list, targets
Example #9
def box_validation(box_head, test_loader, optimizer, epoch, backbone, rpn, keep_topK):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    epoch_loss = 0
    epoch_clas_loss = 0
    epoch_regr_loss = 0

    # TODO: double-check the following two values; they are placeholders for now
    l = 5
    effective_batch = 32  # suggested 150 for 4 images

    for i, data in enumerate(test_loader):
        # imgs, label_list, mask_list, bbox_list, index_list = [data[key] for key in data.keys()]
        images = data['images'].to(device)
        boxes = data['bbox']
        labels = data['labels']
        with torch.no_grad():
            backout = backbone(images)
            im_lis = ImageList(images, [(800, 1088)] * images.shape[0])
            rpnout = rpn(im_lis, backout)
            proposals = [proposal[0:keep_topK, :] for proposal in rpnout[0]]
            fpn_feat_list = list(backout.values())
            gt_labels, gt_regressor_target = box_head.create_ground_truth(proposals, labels, boxes)
            roi_align_result = box_head.MultiScaleRoiAlign(fpn_feat_list, proposals)  # This is the input to Box head
            clas_out, regr_out = box_head.forward(roi_align_result.to(device))
            loss, loss_c, loss_r = box_head.compute_loss(clas_out, regr_out, gt_labels, gt_regressor_target, l,
                                                         effective_batch)
        epoch_loss += loss.item()
        epoch_clas_loss += loss_c.item()
        epoch_regr_loss += loss_r.item()

        # Delete variables after use to free GPU RAM; double-check that they are not needed later.
        del loss, loss_c, loss_r
        del images, labels, boxes
        del clas_out, regr_out
        del gt_labels, gt_regressor_target
        torch.cuda.empty_cache()

    num_batches = i + 1
    epoch_loss /= num_batches
    epoch_clas_loss /= num_batches
    epoch_regr_loss /= num_batches

    return epoch_loss, epoch_clas_loss, epoch_regr_loss
Example #10
    def test_defaultbox_generator(self):
        images = torch.zeros(2, 3, 15, 15)
        features = [torch.zeros(2, 8, 1, 1)]
        image_shapes = [i.shape[-2:] for i in images]
        images = ImageList(images, image_shapes)

        model = self._init_test_defaultbox_generator()
        model.eval()
        dboxes = model(images, features)

        dboxes_output = torch.tensor([[6.9750, 6.9750, 8.0250, 8.0250],
                                      [6.7315, 6.7315, 8.2685, 8.2685],
                                      [6.7575, 7.1288, 8.2425, 7.8712],
                                      [7.1288, 6.7575, 7.8712, 8.2425]])

        self.assertEqual(len(dboxes), 2)
        self.assertEqual(tuple(dboxes[0].shape), (4, 4))
        self.assertEqual(tuple(dboxes[1].shape), (4, 4))
        self.assertTrue(dboxes[0].allclose(dboxes_output))
        self.assertTrue(dboxes[1].allclose(dboxes_output))
Example #11
    def test_defaultbox_generator(self):
        images = torch.zeros(2, 3, 15, 15)
        features = [torch.zeros(2, 8, 1, 1)]
        image_shapes = [i.shape[-2:] for i in images]
        images = ImageList(images, image_shapes)

        model = self._init_test_defaultbox_generator()
        model.eval()
        dboxes = model(images, features)

        dboxes_output = torch.tensor([[6.3750, 6.3750, 8.6250, 8.6250],
                                      [4.7443, 4.7443, 10.2557, 10.2557],
                                      [5.9090, 6.7045, 9.0910, 8.2955],
                                      [6.7045, 5.9090, 8.2955, 9.0910]])

        self.assertEqual(len(dboxes), 2)
        self.assertEqual(tuple(dboxes[0].shape), (4, 4))
        self.assertEqual(tuple(dboxes[1].shape), (4, 4))
        self.assertTrue(dboxes[0].allclose(dboxes_output))
        self.assertTrue(dboxes[1].allclose(dboxes_output))
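
The `_init_test_defaultbox_generator` helper is not shown in either copy of this test; it most likely amounts to constructing a DefaultBoxGenerator with a single aspect-ratio group, roughly as sketched below (the exact parameters are an assumption, and the expected box coordinates differ between torchvision versions, which is why the two copies above disagree):

from torchvision.models.detection.anchor_utils import DefaultBoxGenerator

def _init_test_defaultbox_generator(self):
    # Assumed helper: one feature-map level with aspect ratio 2 (plus its inverse),
    # giving 4 default boxes per location, matching the (4, 4) shapes asserted above.
    return DefaultBoxGenerator([[2]])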
Example #12
    def forward(self, _images, targets=None, return_result=False):
        bs = _images.size(0)
        assert bs == 1

        # Process images
        device = _images.device
        images = torch.zeros(1, 6, 3, 400, 400)
        for i in range(6):
            images[0, i] = self.img_transform(_images[0, i].cpu())
        del _images
        images = images.to(device)

        # Process targets
        #         label_index = targets[0]['labels'] == 2
        #         targets[0]['boxes'] = targets[0]['boxes'][label_index]
        #         targets[0]['labels'] = targets[0]['labels'][label_index]

        targets = [{k: v for k, v in t.items()} for t in targets]
        targets[0]['old_boxes'] = targets[0]['boxes'] / 2.
        min_coordinates, _ = torch.min(targets[0]['boxes'], 2)
        max_coordinates, _ = torch.max(targets[0]['boxes'], 2)
        targets[0]['boxes'] = torch.cat([min_coordinates, max_coordinates], 1)
        temp_tensor = torch.zeros(1, 3, 800, 800)
        _, targets = self.target_transform(temp_tensor, targets)

        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        # images, targets = self.transform(images, targets)
        # HACK
        images = ImageList(images, ((400, 400), ) * images.size(0))
        targets = [{
            k: v.to(images.tensors.device)
            for k, v in t.items() if k != 'masks'
        } for t in targets]

        # Pass images from 6 camera angle to different backbone
        features_list = torch.stack([
            self.backbone(images.tensors[:, i])['0']
            for i in range(self.input_img_num)
        ],
                                    dim=1)

        feature_h, feature_w = features_list.size()[-2:]
        features_list = features_list.view(
            bs, self.backbone_out_channels * self.input_img_num, feature_h,
            feature_w)

        features = OrderedDict([('0', features_list)])
        #         if isinstance(features, torch.Tensor):
        #             features = OrderedDict([('0', features)])

        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals,
                                                     images.image_sizes,
                                                     targets)
        detections = self.transform.postprocess(detections, images.image_sizes,
                                                original_image_sizes)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        losses.update(
            {'loss_mask': torch.zeros(1, device=images.tensors.device)})

        mask_ts = 0.
        mask_ts_numerator = 0
        mask_ts_denominator = 1

        with torch.no_grad():

            # Get object detection threat score
            cpu_detections = [{k: v.cpu()
                               for k, v in t.items()} for t in detections]
            # TODO: add threshold more than 0.5
            detection_ts, detection_ts_numerator, detection_ts_denominator =\
                get_detection_threat_score(cpu_detections, targets, 0.5)

        if return_result:
            # DEBUG
            masks = 0
            #             return losses, mask_ts, mask_ts_numerator,\
            #                    mask_ts_denominator, detection_ts, detection_ts_numerator,\
            #                    detection_ts_denominator, detections, masks
            return mask_ts, mask_ts_numerator,\
                   mask_ts_denominator, detection_ts, detection_ts_numerator,\
                   detection_ts_denominator, detections, masks
        else:
            #             return losses, mask_ts, mask_ts_numerator, mask_ts_denominator,\
            #                    detection_ts, detection_ts_numerator, detection_ts_denominator
            return losses
Example #13
    def forward(self,
                _images,
                _targets=None,
                return_result=False,
                return_losses=False):
        bs = _images.size(0)
        assert bs == 1

        device = _images.device

        # Process images
        images = torch.zeros(1, 6, 3, 400, 400)
        depths = torch.zeros(1, 6, 3, 128, 416)
        for i in range(6):
            images[0, i] = self.img_transform(_images[0, i].cpu())
            depths[0, i] = self.depth_transform(_images[0, i].cpu())
        del _images
        images = images.to(device)
        depths = depths.to(device)

        # Get depth map
        depths = self.depth_estimator(depths.squeeze(0))[0]
        depths = self.depth_resize(depths.unsqueeze(1))
        depths = depths.view(1, 6, 1, 400, 400)
        images = torch.cat((images, depths), dim=2)
        del depths

        # Process targets
        dis = torch.mean(_targets[0]['boxes'], dim=2) - torch.tensor(
            [400., 400.])
        index_1 = torch.sqrt(torch.sum(torch.pow(dis, 2), dim=1)) < 300.
        index_2 = (_targets[0]['labels'] == 0) | (_targets[0]['labels'] == 2) |\
            (_targets[0]['labels'] == 4) | (_targets[0]['labels'] == 5)
        label_index = index_1 * index_2

        targets = [copy.deepcopy(_targets[0])]
        targets[0]['boxes'] = targets[0]['boxes'][label_index]
        targets[0]['labels'] = targets[0]['labels'][label_index]

        targets = [{k: v for k, v in t.items()} for t in targets]
        # targets[0]['old_boxes'] = targets[0]['boxes'] / 2.
        min_coordinates, _ = torch.min(targets[0]['boxes'], 2)
        max_coordinates, _ = torch.max(targets[0]['boxes'], 2)
        targets[0]['boxes'] = torch.cat([min_coordinates, max_coordinates], 1)
        temp_tensor = torch.zeros(1, 3, 800, 800)
        _, targets = self.target_transform(temp_tensor, targets)

        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        # images, targets = self.transform(images, targets)
        device = images.device
        images = ImageList(images, ((400, 400), ) * images.size(0))
        target_masks = torch.stack(
            [t['masks'].float().to(device) for t in targets])
        targets = [{k: v.to(device)
                    for k, v in t.items() if k != 'masks'} for t in targets]

        # Mask backbone
        features_list = torch.stack([
            self.backbone(images.tensors[:, i])
            for i in range(self.input_img_num)
        ],
                                    dim=1)

        feature_h, feature_w = features_list.size()[-2:]
        combined_feature_map = features_list.view(bs, self.input_img_num,
                                                  feature_h, feature_w)

        masks, mask_losses = self.mask_net(combined_feature_map, target_masks)

        del features_list
        torch.cuda.empty_cache()

        # Detection backbone
        features_list = torch.stack([
            self.backbone_(images.tensors[:, i])
            for i in range(self.input_img_num)
        ],
                                    dim=1)

        feature_h, feature_w = features_list.size()[-2:]
        detection_combined_feature_map = features_list.view(
            bs, 64 * self.input_img_num, 400, 400)
        del features_list
        torch.cuda.empty_cache()

        road_map_features = OrderedDict([('0', combined_feature_map)])
        detection_features = OrderedDict([('0', detection_combined_feature_map)
                                          ])

        proposals, proposal_losses = self.rpn(images, road_map_features,
                                              targets)
        # try:
        #     detections, detector_losses = self.roi_heads(detection_features, proposals, images.image_sizes, targets)
        #     detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
        # except RuntimeError as e:
        #     print(e)
        #     detections = None
        #     detector_losses = {
        #         'loss_box_reg': torch.zeros(1),
        #         'loss_classifier': torch.zeros(1)}
        detections, detector_losses = self.roi_heads(detection_features,
                                                     proposals,
                                                     images.image_sizes,
                                                     targets)
        detections = self.transform.postprocess(detections, images.image_sizes,
                                                original_image_sizes)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        losses.update(mask_losses)

        if return_result:
            return masks, detections
        else:
            return losses
Example #14
        num_bbox_class.append(num_class2)
        num_class3 = torch.count_nonzero(bbox_list[0] == 3)
        num_bbox_class.append(num_class3)

        image = transforms.functional.normalize(img[0].cpu().detach(),
                                                [-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
                                                [1 / 0.229, 1 / 0.224, 1 / 0.225], inplace=False)

        image_vis = image.permute(1, 2, 0).cpu().detach().numpy()
        num_grnd_box = len(bbox_list)

        # Take the features from the backbone
        backout = backbone(img)

        # The RPN implementation takes as first argument the following image list
        im_lis = ImageList(img, [(800, 1088)] * img.shape[0])
        rpnout = rpn(im_lis, backout)

        # The final output is a list of proposal tensors: list:len(bz){(keep_topK,4)}
        proposals = [proposal[0:keep_topK_check, :] for proposal in rpnout[0]]
        # generate gt labels
        labels, regressor_target = box_head.create_ground_truth(proposals, label_list, bbox_list)      # tx, ty, tw, th
        labels = labels.flatten()
        # A list of features produced by the backbone's FPN levels: list:len(FPN){(bz,256,H_feat,W_feat)}
        fpn_feat_list = list(backout.values())
        proposal_torch = torch.cat(proposals, dim=0)  # x1 y1 x2 y2
        proposal_xywh = torch.zeros_like(proposal_torch, device=proposal_torch.device)
        proposal_xywh[:, 0] = ((proposal_torch[:, 0] + proposal_torch[:, 2]) / 2)
        proposal_xywh[:, 1] = ((proposal_torch[:, 1] + proposal_torch[:, 3]) / 2)
        proposal_xywh[:, 2] = torch.abs(proposal_torch[:, 2] - proposal_torch[:, 0])
        proposal_xywh[:, 3] = torch.abs(proposal_torch[:, 3] - proposal_torch[:, 1])
Example #15
def proposal_confusion_matrix(loader):
    """Compute proposal classification accuracy, precision, and recall over `loader`."""

    TP, FP, TN, FN = 0, 0, 0, 0

    with torch.no_grad():

        for idx, (batch, pad_lengths) in enumerate(loader):

            images, masks, bboxes, labels = batch

            images = images.to(device=DEVICE, dtype=torch.float)
            bboxes = bboxes.to(device=DEVICE, dtype=torch.float)
            labels = labels.to(device=DEVICE, dtype=torch.float)
            
            backbone_out = BACKBONE(images)

            img_list = ImageList(
                images, list(itertools.repeat((TARGET_HEIGHT, TARGET_WIDTH), len(images))))

            rpn_proposals = RPN(img_list, backbone_out)[0]

            sel_pos_proposals = [1]*len(images)
            sel_pos_bboxes = [1]*len(images)
            sel_pos_labels = [1]*len(images)

            sel_neg_proposals = [1]*len(images)
            sel_neg_bboxes = [1]*len(images)
            sel_neg_labels = [1]*len(images)
            
            # Sample rpn proposals for positive and negative proposals
            for ix, proposals in enumerate(rpn_proposals):
                ground_truth = sample_ground_truth(
                    proposals,
                    bboxes[ix][:pad_lengths["bboxes"][ix]],
                    labels[ix][:pad_lengths["labels"][ix]],
                    iou_thresh=0.5)

                positive_proposals, positive_bboxes, positive_labels, negative_proposals, negative_bboxes, negative_labels = ground_truth
                
                # Positive samples
                sel_pos_proposals[ix] = positive_proposals
                sel_pos_bboxes[ix] = positive_bboxes
                sel_pos_labels[ix] = positive_labels
                
                # Negative samples
                sel_neg_proposals[ix] = negative_proposals
                sel_neg_bboxes[ix] = negative_bboxes
                sel_neg_labels[ix] = negative_labels

            sel_proposals = sel_pos_proposals + filter_none(sel_neg_proposals)
            sel_bboxes = sel_pos_bboxes + filter_none(sel_neg_bboxes)
            sel_labels = sel_pos_labels + filter_none(sel_neg_labels)

            # ROI Align
            roi_aligned_proposals = torchvision.ops.roi_align(
                backbone_out[0],
                sel_proposals,
                (7,7),
                spatial_scale=1./4.,
                sampling_ratio=4)

            sel_proposals = torch.cat(sel_proposals, dim=0)
            sel_bboxes = torch.cat(sel_bboxes, dim=0)
            sel_labels = torch.cat(sel_labels, dim=0)
        
            sel_pos_proposals = torch.cat(sel_pos_proposals, dim=0)
            sel_pos_bboxes = torch.cat(sel_pos_bboxes, dim=0)
            sel_pos_labels = torch.cat(sel_pos_labels, dim=0)
            
            # the total num. of positive proposals
            n_pos_proposals = len(sel_pos_proposals) 

            # Roi Aligned into Intermediate then Regressor/Classifier
            roi_out = ROI_NET(roi_aligned_proposals)
            class_out = CLASS_NET(roi_out)

            pred_probs, pred_classes = torch.softmax(class_out, dim=1).max(dim=1)

            tp = (pred_classes[(sel_labels != 0).nonzero().squeeze()] == sel_labels[(sel_labels != 0).nonzero().squeeze()]).sum(0).item()
            fp = (pred_classes != 0).sum(0).item() - tp

            tn = (pred_classes[(sel_labels == 0).nonzero().squeeze()] == sel_labels[(sel_labels == 0).nonzero().squeeze()]).sum(0).item()
            fn = (pred_classes == 0).sum(0).item() - tn

            TP += tp
            FP += fp
            TN += tn
            FN += fn

    # max(..., 1) guards against division by zero when a class never occurs
    accuracy = (TP + TN) / max(TP + FP + TN + FN, 1)
    precision = TP / max(TP + FP, 1)
    recall = TP / max(TP + FN, 1)
    
    return accuracy, precision, recall
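
A hypothetical follow-up showing how the returned rates might be summarized; `val_loader` is an assumed DataLoader yielding (batch, pad_lengths) pairs in the format the function expects:

accuracy, precision, recall = proposal_confusion_matrix(val_loader)
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
print(f"acc={accuracy:.3f}  precision={precision:.3f}  recall={recall:.3f}  f1={f1:.3f}")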