    def forward(self, _images, targets=None, return_result=False):
        bs = _images.size(0)
        assert bs == 1

        # Process images
        device = _images.device
        images = torch.zeros(1, 6, 3, 400, 400)
        for i in range(6):
            images[0, i] = self.img_transform(_images[0, i].cpu())
        del _images
        images = images.to(device)

        # Process targets
        #         label_index = targets[0]['labels'] == 2
        #         targets[0]['boxes'] = targets[0]['boxes'][label_index]
        #         targets[0]['labels'] = targets[0]['labels'][label_index]

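        # Convert each box from its per-corner representation (assumed to be
        # stored as [N, 2, 4] corner coordinates) into an axis-aligned
        # [x_min, y_min, x_max, y_max] box for the detection heads.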
        targets = [{k: v for k, v in t.items()} for t in targets]
        targets[0]['old_boxes'] = targets[0]['boxes'] / 2.
        min_coordinates, _ = torch.min(targets[0]['boxes'], 2)
        max_coordinates, _ = torch.max(targets[0]['boxes'], 2)
        targets[0]['boxes'] = torch.cat([min_coordinates, max_coordinates], 1)
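        # Run the targets through the target transform using a dummy 800x800
        # image; the transformed image is discarded and only the adjusted
        # targets are kept.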
        temp_tensor = torch.zeros(1, 3, 800, 800)
        _, targets = self.target_transform(temp_tensor, targets)

        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        # images, targets = self.transform(images, targets)
        # HACK
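        # Bypass self.transform: wrap the already-resized tensors in an
        # ImageList with fixed 400x400 sizes and drop the masks from the
        # targets before they reach the detection heads.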
        images = ImageList(images, ((400, 400), ) * images.size(0))
        targets = [{
            k: v.to(images.tensors.device)
            for k, v in t.items() if k != 'masks'
        } for t in targets]

        # Pass the image from each of the 6 camera angles through the
        # backbone and stack the per-camera feature maps along a new
        # camera dimension.
        features_list = torch.stack(
            [self.backbone(images.tensors[:, i])['0']
             for i in range(self.input_img_num)],
            dim=1)

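        # Fold the per-camera feature maps into the channel dimension so the
        # RPN sees a single [bs, backbone_out_channels * 6, H, W] map.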
        feature_h, feature_w = features_list.size()[-2:]
        features_list = features_list.view(
            bs, self.backbone_out_channels * self.input_img_num, feature_h,
            feature_w)

        features = OrderedDict([('0', features_list)])
        #         if isinstance(features, torch.Tensor):
        #             features = OrderedDict([('0', features)])

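        # Generate proposals with the RPN, run box classification/regression
        # in the ROI heads, and map detections back to the original image
        # sizes.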
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals,
                                                     images.image_sizes,
                                                     targets)
        detections = self.transform.postprocess(detections, images.image_sizes,
                                                original_image_sizes)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        losses.update(
            {'loss_mask': torch.zeros(1, device=images.tensors.device)})

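        # Mask threat score is not computed in this variant; keep placeholder
        # values for the return tuple.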
        mask_ts = 0.
        mask_ts_numerator = 0
        mask_ts_denominator = 1

        with torch.no_grad():

            # Get object detection threat score
            cpu_detections = [{k: v.cpu()
                               for k, v in t.items()} for t in detections]
            # TODO: also evaluate at thresholds above 0.5
            detection_ts, detection_ts_numerator, detection_ts_denominator =\
                get_detection_threat_score(cpu_detections, targets, 0.5)

        if return_result:
            # DEBUG
            masks = 0
            #             return losses, mask_ts, mask_ts_numerator,\
            #                    mask_ts_denominator, detection_ts, detection_ts_numerator,\
            #                    detection_ts_denominator, detections, masks
            return mask_ts, mask_ts_numerator,\
                   mask_ts_denominator, detection_ts, detection_ts_numerator,\
                   detection_ts_denominator, detections, masks
        else:
            #             return losses, mask_ts, mask_ts_numerator, mask_ts_denominator,\
            #                    detection_ts, detection_ts_numerator, detection_ts_denominator
            return losses
Example #2
    def forward(self,
                _images,
                _targets=None,
                return_result=False,
                return_losses=False):
        bs = _images.size(0)
        assert bs == 1

        device = _images.device

        # Process images
        images = torch.zeros(1, 6, 3, 400, 400)
        depths = torch.zeros(1, 6, 3, 128, 416)
        for i in range(6):
            images[0, i] = self.img_transform(_images[0, i].cpu())
            depths[0, i] = self.depth_transform(_images[0, i].cpu())
        del _images
        images = images.to(device)
        depths = depths.to(device)

        # Get depth map: run the depth estimator on each camera image,
        # upsample the prediction to 400x400, and append it as a fourth
        # input channel.
        depths = self.depth_estimator(depths.squeeze(0))[0]
        depths = self.depth_resize(depths.unsqueeze(1))
        depths = depths.view(1, 6, 1, 400, 400)
        images = torch.cat((images, depths), dim=2)
        del depths

        # Process targets
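        # Keep only boxes whose centre lies within 300 px of the map centre
        # at (400, 400) (assumed to be the ego vehicle) and whose label is
        # one of the categories of interest (0, 2, 4, 5).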
        dis = torch.mean(_targets[0]['boxes'], dim=2) - torch.tensor(
            [400., 400.])
        index_1 = torch.sqrt(torch.sum(torch.pow(dis, 2), dim=1)) < 300.
        index_2 = (_targets[0]['labels'] == 0) | (_targets[0]['labels'] == 2) |\
            (_targets[0]['labels'] == 4) | (_targets[0]['labels'] == 5)
        label_index = index_1 & index_2

        targets = [copy.deepcopy(_targets[0])]
        targets[0]['boxes'] = targets[0]['boxes'][label_index]
        targets[0]['labels'] = targets[0]['labels'][label_index]

        targets = [{k: v for k, v in t.items()} for t in targets]
        # targets[0]['old_boxes'] = targets[0]['boxes'] / 2.
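        # As above, collapse the per-corner box coordinates into axis-aligned
        # [x_min, y_min, x_max, y_max] boxes and run the targets through the
        # target transform on a dummy 800x800 image.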
        min_coordinates, _ = torch.min(targets[0]['boxes'], 2)
        max_coordinates, _ = torch.max(targets[0]['boxes'], 2)
        targets[0]['boxes'] = torch.cat([min_coordinates, max_coordinates], 1)
        temp_tensor = torch.zeros(1, 3, 800, 800)
        _, targets = self.target_transform(temp_tensor, targets)

        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        # images, targets = self.transform(images, targets)
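        # Bypass self.transform as before; the dense masks are split out as
        # separate training targets for the mask head, and the remaining
        # target fields are moved to the model's device.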
        device = images.device
        images = ImageList(images, ((400, 400), ) * images.size(0))
        target_masks = torch.stack(
            [t['masks'].float().to(device) for t in targets])
        targets = [{k: v.to(device)
                    for k, v in t.items() if k != 'masks'} for t in targets]

        # Mask backbone
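        # Stack the per-camera features and collapse them into a single
        # [bs, 6, H, W] map (this assumes the mask backbone produces a
        # single-channel feature map per camera).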
        features_list = torch.stack(
            [self.backbone(images.tensors[:, i])
             for i in range(self.input_img_num)],
            dim=1)

        feature_h, feature_w = features_list.size()[-2:]
        combined_feature_map = features_list.view(bs, self.input_img_num,
                                                  feature_h, feature_w)

        masks, mask_losses = self.mask_net(combined_feature_map, target_masks)

        del features_list
        torch.cuda.empty_cache()

        # Detection backbone
        features_list = torch.stack(
            [self.backbone_(images.tensors[:, i])
             for i in range(self.input_img_num)],
            dim=1)

        feature_h, feature_w = features_list.size()[-2:]
        # Fold the per-camera features into the channel dimension (this
        # assumes the detection backbone outputs 64 channels per camera).
        detection_combined_feature_map = features_list.view(
            bs, 64 * self.input_img_num, feature_h, feature_w)
        del features_list
        torch.cuda.empty_cache()

        road_map_features = OrderedDict([('0', combined_feature_map)])
        detection_features = OrderedDict(
            [('0', detection_combined_feature_map)])

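        # The RPN consumes the mask/road-map features, while the ROI heads
        # consume the separate detection features.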
        proposals, proposal_losses = self.rpn(images, road_map_features,
                                              targets)
        # try:
        #     detections, detector_losses = self.roi_heads(detection_features, proposals, images.image_sizes, targets)
        #     detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
        # except RuntimeError as e:
        #     print(e)
        #     detections = None
        #     detector_losses = {
        #         'loss_box_reg': torch.zeros(1),
        #         'loss_classifier': torch.zeros(1)}
        detections, detector_losses = self.roi_heads(detection_features,
                                                     proposals,
                                                     images.image_sizes,
                                                     targets)
        detections = self.transform.postprocess(detections, images.image_sizes,
                                                original_image_sizes)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        losses.update(mask_losses)

        if return_result:
            return masks, detections
        else:
            return losses