Esempio n. 1
0
    def training(self,
                 sem_depth_head,
                 ms_feat,
                 sem_feat,
                 depth_gt,
                 po_mask=None,
                 sem_class_mask=None,
                 img=None):
        """Run the semantic depth head and compute its depth losses.

        Args:
            sem_depth_head: Callable invoked as ``sem_depth_head(ms_feat, sem_feat)``;
                it must return ``(sem_depth_feat, sem_class_depth_pred, sem_depth_pred)``.
            ms_feat: Multi-scale features forwarded to the head.
            sem_feat: Semantic features forwarded to the head.
            depth_gt: Ground-truth depth, forwarded unchanged to ``self.depth_loss``.
            po_mask: Optional packed mask. When given it is padded with
                ``pad_packed_images`` and cast to ``torch.float`` before being
                passed to ``self.depth_loss``; when ``None`` it is passed as ``None``.
            sem_class_mask: Optional mask forwarded to ``self.depth_loss``.
            img: Optional image forwarded to ``self.depth_loss``.

        Returns:
            Tuple ``(sem_class_depth_pred, sem_depth_pred, bts_loss, class_loss,
            panoptic_edge_loss, depth_stats)``; the last four come straight from
            ``self.depth_loss``.
        """
        # Only the semantic depth head runs here; the instance depth head and
        # the depth fusion step are currently disabled.
        _, sem_class_depth_pred, sem_depth_pred = sem_depth_head(ms_feat, sem_feat)

        # po_mask defaults to None, so only unpack and cast it when provided.
        if po_mask is not None:
            po_mask, _ = pad_packed_images(po_mask)
            po_mask = po_mask.type(torch.float)

        bts_loss, class_loss, panoptic_edge_loss, depth_stats = self.depth_loss(
            sem_depth_pred,
            sem_class_depth_pred,
            depth_gt,
            po_mask=po_mask,
            sem_class_mask=sem_class_mask,
            img=img)

        return (sem_class_depth_pred, sem_depth_pred, bts_loss, class_loss,
                panoptic_edge_loss, depth_stats)
Esempio n. 2
0
    def forward(self,
                img,
                msk=None,
                cat=None,
                iscrowd=None,
                bbx=None,
                do_loss=False,
                do_prediction=True):
        """Run the detection / instance-segmentation network on packed images.

        Args:
            img: Packed input images; padded with ``pad_packed_images``.
            msk, cat, iscrowd, bbx: Ground truth, converted by
                ``self._prepare_inputs`` and used only when ``do_loss`` is set.
            do_loss: Compute the RPN and ROI training losses.
            do_prediction: Compute ROI predictions (and, without ``do_loss``,
                the RPN proposals via the inference path).

        Returns:
            ``(loss, pred)`` OrderedDicts; entries that were not computed are None.
        """
        padded, valid_size = pad_packed_images(img)
        img_size = padded.shape[-2:]

        if do_loss:
            cat, iscrowd, bbx, ids = self._prepare_inputs(msk, cat, iscrowd, bbx)

        feats = self.body(padded)

        # Region proposals; the training path also returns proposals (do_inference=True).
        obj_loss = bbx_loss = proposals = None
        if do_loss:
            obj_loss, bbx_loss, proposals = self.rpn_algo.training(
                self.rpn_head, feats, bbx, iscrowd, valid_size,
                training=self.training, do_inference=True)
        elif do_prediction:
            proposals = self.rpn_algo.inference(
                self.rpn_head, feats, valid_size, self.training)

        # ROI losses and predictions are independent switches.
        roi_cls_loss = roi_bbx_loss = roi_msk_loss = None
        if do_loss:
            roi_cls_loss, roi_bbx_loss, roi_msk_loss = self.instance_seg_algo.training(
                self.roi_head, feats, proposals, bbx, cat, iscrowd, ids, msk,
                img_size)

        bbx_pred = cls_pred = obj_pred = msk_pred = None
        if do_prediction:
            bbx_pred, cls_pred, obj_pred, msk_pred = self.instance_seg_algo.inference(
                self.roi_head, feats, proposals, valid_size, img_size)

        loss = OrderedDict(zip(
            ("obj_loss", "bbx_loss", "roi_cls_loss", "roi_bbx_loss", "roi_msk_loss"),
            (obj_loss, bbx_loss, roi_cls_loss, roi_bbx_loss, roi_msk_loss)))
        pred = OrderedDict(zip(
            ("bbx_pred", "cls_pred", "obj_pred", "msk_pred"),
            (bbx_pred, cls_pred, obj_pred, msk_pred)))
        return loss, pred
Esempio n. 3
0
    def __call__(self, flow_pred_iter, flow_gt, valid_mask):
        """Compute the optical-flow training loss.

        Args:
            flow_pred_iter: Iterable of flow predictions (presumably one per
                refinement iteration — confirm), passed to
                ``self.computeSequenceLoss``.
            flow_gt: Packed ground-truth flow; padded with ``pad_packed_images``.
            valid_mask: Packed validity mask; padded with ``pad_packed_images``.

        Returns:
            dict with the single key ``'total'`` holding the sequence loss.
        """
        flow_gt, _ = pad_packed_images(flow_gt)
        valid_mask, _ = pad_packed_images(valid_mask)

        # Diagnostic only: report NaNs in any prediction; the loss is unaffected.
        nan_count = sum(torch.isnan(flow_pred).sum() for flow_pred in flow_pred_iter)
        if nan_count > 0:
            print("NAN FOUND!: {}".format(nan_count))

        # NOTE: the large-displacement filter is currently disabled, so every
        # pixel of valid_mask contributes to the loss:
        # mag = torch.sum(flow_gt ** 2, dim=1).sqrt().unsqueeze(1)
        # valid_mask = valid_mask & (mag < self.max_flow)
        flow_seq_loss = self.computeSequenceLoss(flow_pred_iter, flow_gt, valid_mask)

        return {'total': flow_seq_loss}
Esempio n. 4
0
    def forward(self,
                img_pair,
                msk=None,
                cat=None,
                iscrowd=None,
                bbx=None,
                depth_gt=None,
                flow_gt=None,
                flow_mask=None,
                do_loss=False,
                do_prediction=False,
                get_test_metrics=False,
                get_depth_vis=False,
                get_sem_vis=False,
                get_flow_vis=False):
        """Run the panoptic branches of the multi-task network on an image pair.

        Only the panoptic heads (RPN, ROI, semantic) are active. The depth and
        flow branches, their visualisations and the test metrics are disabled,
        so ``depth_gt``, ``flow_gt``, ``flow_mask``, ``get_test_metrics``,
        ``get_depth_vis`` and ``get_flow_vis`` are accepted but ignored. Both
        images are still passed through ``self.body``; only the first image's
        features feed the heads.

        Args:
            img_pair: Two packed image batches.
            msk, cat, iscrowd, bbx: Panoptic ground truth, used when ``do_loss``.
            do_loss: Compute the panoptic losses and their multi-loss weighting.
            do_prediction: Compute instance and semantic predictions.
            get_sem_vis: Add a semantic visualisation (only together with
                ``do_loss``, because it needs the semantic ground truth).

        Returns:
            ``(total_loss, loss, result, stats, ms_img_feat_1)``; ``total_loss``
            and ``loss`` are None unless ``do_loss`` is set.
        """
        result, loss, stats = OrderedDict(), OrderedDict(), OrderedDict()

        first_img, first_valid = pad_packed_images(img_pair[0])
        second_img, _ = pad_packed_images(img_pair[1])
        img_size = first_img.shape[-2:]

        if do_loss:
            cat, iscrowd, bbx, ids, sem_gt = self._prepare_inputs(
                msk, cat, iscrowd, bbx)
            # The class mask only fed the disabled depth branch; it is still built.
            self._makeSemanticClassMask(sem_gt, self.sem_class_count)

        feats = self.body(first_img)
        # Second-image features are unused while the flow branch is disabled;
        # the call is kept as-is (it may still matter, e.g. for norm statistics).
        self.body(second_img)

        # Region proposals; the training path also returns proposals.
        obj_loss = bbx_loss = proposals = None
        if do_loss:
            obj_loss, bbx_loss, proposals = self.rpn_algo.training(
                self.rpn_head, feats, bbx, iscrowd, first_valid,
                training=self.training, do_inference=True)
        elif do_prediction:
            proposals = self.rpn_algo.inference(
                self.rpn_head, feats, first_valid, self.training)

        roi_cls_loss = roi_bbx_loss = roi_msk_loss = None
        if do_loss:
            roi_cls_loss, roi_bbx_loss, roi_msk_loss = self.inst_algo.training(
                self.roi_head, feats, proposals, bbx, cat, iscrowd, ids, msk,
                img_size)

        bbx_pred = cls_pred = obj_pred = msk_pred = None
        if do_prediction:
            bbx_pred, cls_pred, obj_pred, msk_pred = self.inst_algo.inference(
                self.roi_head, feats, proposals, first_valid, img_size)

        sem_loss = sem_conf_mat = sem_pred = sem_logits = None
        if do_loss:
            sem_loss, sem_conf_mat, sem_pred, sem_logits, _ = self.sem_algo.training(
                self.sem_head, feats, sem_gt, first_valid, img_size)
        elif do_prediction:
            sem_pred, sem_logits, _ = self.sem_algo.inference(
                self.sem_head, feats, first_valid, img_size)

        total_loss = None
        if do_loss:
            panoptic_terms = (obj_loss, bbx_loss, roi_cls_loss, roi_bbx_loss,
                              roi_msk_loss, sem_loss)
            panoptic_loss = obj_loss + bbx_loss + roi_cls_loss + roi_bbx_loss + roi_msk_loss + sem_loss
            # Depth (bts, class, edge) and flow slots are zero placeholders with
            # their enable flags switched off.
            total_loss, loss_weights = self.multi_loss_algo.computeMultiLoss(
                self.multi_loss_head, [panoptic_loss, 0, 0, 0, 0],
                [True, False, False, False, False])

            panoptic_weight = loss_weights[0]
            loss_names = ("obj_loss", "bbx_loss", "roi_cls_loss",
                          "roi_bbx_loss", "roi_msk_loss", "sem_loss")
            for name, term in zip(loss_names, panoptic_terms):
                loss[name] = panoptic_weight * term

            stats['sem_conf'] = sem_conf_mat
        else:
            loss = None

        predictions = (('bbx_pred', bbx_pred), ('cls_pred', cls_pred),
                       ('obj_pred', obj_pred), ('msk_pred', msk_pred),
                       ('sem_pred', sem_pred), ('sem_logits', sem_logits))
        for key, value in predictions:
            result[key] = value

        # The semantic visualisation needs sem_gt, which only exists with do_loss.
        if get_sem_vis and do_loss:
            result['sem_vis'] = visualiseSemanticSegmentation(
                first_img, sem_gt, sem_pred, self.dataset)

        return total_loss, loss, result, stats, feats
Esempio n. 5
0
    def __call__(self,
                 depth_pred,
                 class_depth_pred,
                 depth_gt,
                 po_mask=None,
                 sem_class_mask=None,
                 img=None):
        """Compute the depth losses for one batch.

        Ground-truth pixels count as valid when ``0.1 < depth < max_depth``.
        ``max_depth`` is 100 for the Cityscapes variants and 80 for the KITTI
        variants.

        Args:
            depth_pred: Predicted depth. It is indexed with the boolean mask
                built from the padded ground truth, so presumably it has the
                same shape as that ground truth (confirm).
            class_depth_pred: Per-class depth prediction for ``computeClassDepthLoss``.
            depth_gt: Packed ground-truth depth; padded with ``pad_packed_images``.
            po_mask: Optional mask forwarded to ``computePanopticEdgeLoss``.
            sem_class_mask: Optional mask forwarded to ``computeClassDepthLoss``.
            img: Accepted for interface compatibility; not used.

        Returns:
            ``(bts_loss, class_loss, panoptic_edge_loss, stats)``. ``stats`` holds
            those three losses plus the scale-invariant loss under the keys
            ``depth_bts``, ``depth_class``, ``depth_po_edge`` and ``depth_si``,
            each reshaped with ``.reshape(1)``.

        Raises:
            NotImplementedError: if ``self.dataset`` has no configured depth range.
        """
        depth_gt, _ = pad_packed_images(depth_gt)

        if self.dataset in ("Cityscapes", "CityscapesSample", "CityscapesDepth",
                            "CityscapesDepthSample", "CityscapesSeam"):
            max_depth = 100.
        elif self.dataset in ("KittiRawDepth", "KittiDepthValSelection",
                              "KittiPanoptic"):
            max_depth = 80.
        else:
            raise NotImplementedError(
                "No valid depth range configured for dataset {!r}".format(
                    self.dataset))
        mask = (depth_gt > 0.1) & (depth_gt < max_depth)

        stats = {}

        bts_loss = self.bts_wt * self.computeBTSLoss(depth_pred[mask],
                                                     depth_gt[mask])
        stats['depth_bts'] = bts_loss

        # NOTE(review): the class loss is scaled by bts_wt rather than a
        # dedicated weight — confirm this is intended.
        class_loss = self.bts_wt * self.computeClassDepthLoss(
            class_depth_pred,
            sem_class_mask=sem_class_mask,
            depth_gt=depth_gt,
            valid_mask=mask)
        stats["depth_class"] = class_loss

        panoptic_edge_loss = self.po_edge_wt * self.computePanopticEdgeLoss(
            depth_pred, depth_gt, po_mask, valid_mask=mask)
        stats['depth_po_edge'] = panoptic_edge_loss

        # The scale-invariant loss is only recorded in stats, not returned as a loss.
        stats["depth_si"] = self.computeScaleInvariantLoss(depth_pred[mask],
                                                           depth_gt[mask])

        stats = {stat_type: stat_value.reshape(1)
                 for stat_type, stat_value in stats.items()}

        return bts_loss, class_loss, panoptic_edge_loss, stats
Esempio n. 6
0
    def forward(self,
                img,
                msk=None,
                cat=None,
                iscrowd=None,
                bbx=None,
                depth_gt=None,
                do_train=False,
                do_validate=False,
                do_test=False,
                get_test_metrics=False,
                get_vis=False,
                get_sem_vis=False):
        """Run the panoptic + semantic-depth network on packed images.

        Args:
            img: Packed input images; padded with ``pad_packed_images``.
            msk, cat, iscrowd, bbx: Panoptic ground truth, used when training or
                validating.
            depth_gt: Packed ground-truth depth for the depth loss, the depth
                visualisation and the test metrics.
            do_train, do_validate: Compute all losses. Both flags take the same
                code path here.
            do_test: Compute ROI predictions. When neither training nor
                validating, also run the RPN, semantic and depth inference paths.
            get_test_metrics: Add depth test metrics to ``result``. This needs a
                depth prediction and ``depth_gt``.
            get_vis, get_sem_vis: Add depth / semantic visualisations. They need
                ground truth, so they are skipped unless training or validating.

        Returns:
            ``(total_loss, loss, result, stats)``. When neither ``do_train`` nor
            ``do_validate`` is set, ``total_loss`` and ``loss`` are None and the
            ``wt_*`` entries of ``result`` are None.
        """
        result = OrderedDict()
        stats = OrderedDict()

        img, valid_size = pad_packed_images(img)
        img_size = img.shape[-2:]

        # Ground truth only exists when training/validating. These stay None
        # otherwise, so optional steps below can check for them explicitly.
        has_gt = do_train or do_validate
        sem_gt = None
        sem_class_mask = None
        if has_gt:
            cat, iscrowd, bbx, ids, sem_gt = self._prepare_inputs(
                msk, cat, iscrowd, bbx)

        ms_img_feat = self.body(img)

        # RPN Part
        if has_gt:
            obj_loss, bbx_loss, proposals = self.rpn_algo.training(
                self.rpn_head,
                ms_img_feat,
                bbx,
                iscrowd,
                valid_size,
                training=self.training,
                do_inference=True)
        elif do_test:
            proposals = self.rpn_algo.inference(self.rpn_head, ms_img_feat,
                                                valid_size, self.training)
            obj_loss, bbx_loss = None, None
        else:
            obj_loss, bbx_loss, proposals = None, None, None

        # ROI Part
        if has_gt:
            roi_cls_loss, roi_bbx_loss, roi_msk_loss = self.inst_algo.training(
                self.roi_head, ms_img_feat, proposals, bbx, cat, iscrowd, ids,
                msk, img_size)
        else:
            roi_cls_loss, roi_bbx_loss, roi_msk_loss = None, None, None
        if do_test:
            bbx_pred, cls_pred, obj_pred, msk_pred = self.inst_algo.inference(
                self.roi_head, ms_img_feat, proposals, valid_size, img_size)
        else:
            bbx_pred, cls_pred, obj_pred, msk_pred = None, None, None, None

        # Segmentation Part
        if has_gt:
            sem_loss, sem_conf_mat, sem_pred, sem_logits, sem_feat = self.sem_algo.training(
                self.sem_head, ms_img_feat, sem_gt, valid_size, img_size)
        elif do_test:
            sem_pred, sem_logits, sem_feat = self.sem_algo.inference(
                self.sem_head, ms_img_feat, valid_size, img_size)
            sem_loss, sem_conf_mat = None, None
        else:
            sem_loss, sem_conf_mat, sem_pred, sem_logits, sem_feat = None, None, None, None, None

        # Depth Part (training and validation share the same loss computation)
        if has_gt:
            sem_class_mask = self._makeSemanticClassMask(sem_gt)
            (sem_class_depth_pred, sem_depth_pred, depth_bts_loss,
             depth_class_loss, depth_panoptic_edge_loss,
             sem_depth_stats) = self.sem_depth_algo.training(
                 self.sem_depth_head,
                 ms_img_feat,
                 sem_feat,
                 depth_gt,
                 po_mask=msk,
                 sem_class_mask=sem_class_mask)
        elif do_test:
            sem_class_depth_pred, sem_depth_pred = self.sem_depth_algo.inference(
                self.sem_depth_head, ms_img_feat, sem_feat)
            depth_bts_loss, depth_class_loss, depth_panoptic_edge_loss, sem_depth_stats = None, None, None, None
        else:
            sem_class_depth_pred, sem_depth_pred, depth_bts_loss, depth_class_loss, depth_panoptic_edge_loss, sem_depth_stats = None, None, None, None, None, None

        # Multi-loss head: only meaningful when losses were computed. Without
        # ground truth the individual losses are None and cannot be summed.
        if has_gt:
            panoptic_loss = obj_loss + bbx_loss + roi_cls_loss + roi_bbx_loss + roi_msk_loss + sem_loss
            total_loss, loss_weights = self.multi_loss_algo.computeMultiLoss(
                self.multi_loss_head, [
                    panoptic_loss, depth_bts_loss, depth_class_loss,
                    depth_panoptic_edge_loss
                ], [True, True, True, True])

            loss = OrderedDict()
            loss['obj_loss'] = loss_weights[0] * obj_loss
            loss['bbx_loss'] = loss_weights[0] * bbx_loss
            loss['roi_cls_loss'] = loss_weights[0] * roi_cls_loss
            loss['roi_bbx_loss'] = loss_weights[0] * roi_bbx_loss
            loss['roi_msk_loss'] = loss_weights[0] * roi_msk_loss
            loss["sem_loss"] = loss_weights[0] * sem_loss
            loss['depth_bts_loss'] = loss_weights[1] * depth_bts_loss
            loss['depth_class_loss'] = loss_weights[2] * depth_class_loss
            loss['depth_po_edge_loss'] = loss_weights[3] * depth_panoptic_edge_loss
        else:
            total_loss, loss = None, None
            loss_weights = [None] * 4

        # OTHER STATISTICS
        if sem_depth_stats is not None:
            stats.update(sem_depth_stats)
        stats['sem_conf'] = sem_conf_mat

        # PREDICTIONS
        result['bbx_pred'] = bbx_pred
        result['cls_pred'] = cls_pred
        result['obj_pred'] = obj_pred
        result['msk_pred'] = msk_pred
        result["sem_pred"] = sem_pred
        result['sem_logits'] = sem_logits
        result["sem_depth_pred"] = sem_depth_pred
        # All panoptic sub-losses share the first multi-loss weight.
        result["wt_obj"] = loss_weights[0]
        result["wt_bbx"] = loss_weights[0]
        result["wt_roi_cls"] = loss_weights[0]
        result["wt_roi_bbx"] = loss_weights[0]
        result['wt_roi_msk'] = loss_weights[0]
        result['wt_sem'] = loss_weights[0]
        result['wt_depth_bts'] = loss_weights[1]
        result['wt_depth_class'] = loss_weights[2]
        result['wt_depth_po_edge'] = loss_weights[3]

        # Visualisations need ground-truth derived inputs; skip them otherwise.
        if get_vis and sem_class_mask is not None:
            result['vis'] = visualisePanopticDepth(img, depth_gt,
                                                   sem_depth_pred,
                                                   sem_class_depth_pred,
                                                   sem_class_mask,
                                                   self.dataset)

        if get_sem_vis and sem_gt is not None:
            result['sem_vis'] = visualiseSemanticSegmentation(
                img, sem_gt, sem_pred, self.dataset)

        # Test metrics on the clamped depth prediction
        if get_test_metrics:
            pred_depth = sem_depth_pred.detach().clone()
            depth_gt, _ = pad_packed_images(depth_gt)
            gt_depth = depth_gt.detach().clone()
            # Clamp into [min_depth, max_depth]; inf -> max, NaN -> min.
            pred_depth[pred_depth < self.min_depth] = self.min_depth
            pred_depth[pred_depth > self.max_depth] = self.max_depth
            pred_depth[torch.isinf(pred_depth)] = self.max_depth
            pred_depth[torch.isnan(pred_depth)] = self.min_depth

            valid_mask = (gt_depth > self.min_depth) & (gt_depth < self.max_depth)

            (result['depth_log10'], result['depth_abs_rel'], result['depth_rms'],
             result['depth_sq_rel'], result['depth_log_rms'], result['depth_d1'],
             result['depth_d2'], result['depth_d3'],
             result['depth_si_log']) = computeTestMetrics(
                 gt_depth, pred_depth, valid_mask)

        return total_loss, loss, result, stats