コード例 #1
0
    def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg_0,
                                   gt_semantic_seg_1, gt_semantic_seg_2):
        """Run forward function and calculate loss for the three lane decode
        heads in training.

        Each head is supervised by its own ground-truth map:
        ``lane_dir_head`` by ``gt_semantic_seg_2``, ``lane_sty_head`` by
        ``gt_semantic_seg_1`` and ``lane_typ_head`` by ``gt_semantic_seg_0``.

        Args:
            x: Input features, passed unchanged to every head's
                ``forward_train``.
            img_metas: Image meta information, passed to every head.
            gt_semantic_seg_0: Ground truth for ``lane_typ_head``.
            gt_semantic_seg_1: Ground truth for ``lane_sty_head``.
            gt_semantic_seg_2: Ground truth for ``lane_dir_head``.

        Returns:
            dict: Merged losses of all three heads, prefixed (via
            ``add_prefix``) with ``decode_lane_dir``, ``decode_lane_sty``
            and ``decode_lane_typ`` respectively.
        """
        # (head, ground truth for that head, loss-key prefix), trained in
        # dir -> sty -> typ order.
        head_specs = (
            (self.lane_dir_head, gt_semantic_seg_2, 'decode_lane_dir'),
            (self.lane_sty_head, gt_semantic_seg_1, 'decode_lane_sty'),
            (self.lane_typ_head, gt_semantic_seg_0, 'decode_lane_typ'),
        )
        losses = dict()
        for head, gt, prefix in head_specs:
            loss_decode = head.forward_train(x, img_metas, gt, self.train_cfg)
            losses.update(add_prefix(loss_decode, prefix))
        return losses
コード例 #2
0
    def _loss_regularization_forward_train(self):
        """Calculate the weight-regularization loss(es) in training.

        ``self.loss_regularization`` is either a single callable or an
        ``nn.ModuleList`` of callables. Each one is called with
        ``self.modules()`` and its result is stored under the key
        ``loss_regularize``. That key is prefixed with ``regularize_{idx}``
        for list entries, or ``regularize`` for a single regularizer.

        Returns:
            dict: The prefixed regularization losses.
        """
        regularizers = self.loss_regularization
        if isinstance(regularizers, nn.ModuleList):
            named = [(f'regularize_{idx}', reg)
                     for idx, reg in enumerate(regularizers)]
        else:
            named = [('regularize', regularizers)]

        losses = dict()
        for prefix, reg in named:
            # A fresh ``self.modules()`` iterator is handed to each call.
            term = reg(self.modules())
            losses.update(add_prefix(dict(loss_regularize=term), prefix))
        return losses
コード例 #3
0
 def losses(self, seg_logit, seg_label):
     """Compute ``pam_cam``, ``pam``, ``cam`` loss.

     Args:
         seg_logit: A 3-item sequence of logits, unpacked in the order
             (pam_cam, pam, cam).
         seg_label: Ground truth shared by all three branches.

     Returns:
         dict: Each branch's loss from the parent class's ``losses``,
         prefixed with its branch name.
     """
     pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
     branches = (
         ('pam_cam', pam_cam_seg_logit),
         ('pam', pam_seg_logit),
         ('cam', cam_seg_logit),
     )
     loss = dict()
     for branch_name, branch_logit in branches:
         # Delegate the per-branch loss to the parent head implementation.
         branch_loss = super(DAHead, self).losses(branch_logit, seg_label)
         loss.update(add_prefix(branch_loss, branch_name))
     return loss
コード例 #4
0
    def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
        """Run forward function and calculate loss for auxiliary head in
        training.

        Supports a single auxiliary head (losses prefixed with ``aux``) or an
        ``nn.ModuleList`` of heads (losses prefixed with ``aux_{idx}``).

        Returns:
            dict: The prefixed auxiliary-head losses.
        """
        aux_heads = self.auxiliary_head
        if isinstance(aux_heads, nn.ModuleList):
            tagged = [(f'aux_{idx}', head)
                      for idx, head in enumerate(aux_heads)]
        else:
            tagged = [('aux', aux_heads)]

        losses = dict()
        for prefix, head in tagged:
            head_loss = head.forward_train(x, img_metas, gt_semantic_seg,
                                           self.train_cfg)
            losses.update(add_prefix(head_loss, prefix))
        return losses
コード例 #5
0
 def _decode_head_forward_train_r(self, x, img_metas, gt_semantic_seg):
     """Run forward function and calculate loss for the ``decode_head_r``
     head in training.

     Returns:
         dict: The head's losses, prefixed with ``decode_r``.
     """
     head_loss = self.decode_head_r.forward_train(
         x, img_metas, gt_semantic_seg, self.train_cfg)
     losses = dict()
     losses.update(add_prefix(head_loss, 'decode_r'))
     return losses
コード例 #6
0
    def _decode_head_forward_train_r(self, x, img_metas, gt_semantic_seg):
        """Run forward function and calculate loss for the cascaded
        ``decode_head_r`` stages in training.

        Stage 0 is trained directly on ``x``. Each later stage ``k`` (up to
        ``self.num_stages - 1``) also receives the ``forward_test`` output
        of stage ``k - 1``.

        Returns:
            dict: Losses of every stage, prefixed with ``decode_r_{k}``.
        """
        losses = dict()
        first_stage_loss = self.decode_head_r[0].forward_train(
            x, img_metas, gt_semantic_seg, self.train_cfg)
        losses.update(add_prefix(first_stage_loss, 'decode_r_0'))

        for stage in range(1, self.num_stages):
            # Re-run the previous stage in test mode to obtain its outputs;
            # this extra forward may be unnecessary for most methods.
            prev_outputs = self.decode_head_r[stage - 1].forward_test(
                x, img_metas, self.test_cfg)
            stage_loss = self.decode_head_r[stage].forward_train(
                x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
            losses.update(add_prefix(stage_loss, f'decode_r_{stage}'))

        return losses
コード例 #7
0
ファイル: kd_encoder_decoder.py プロジェクト: jiamingNo1/skd
    def _auxiliary_head_forward_train_kd(self, x, T_x, img_metas, gt_semantic_seg):
        """Run forward function and calculate knowledge-distillation loss for
        the auxiliary head(s) in training.

        The teacher is switched to eval mode and its logits are computed
        under ``torch.no_grad()``. They are then passed to every auxiliary
        head's ``forward_train_kd``.

        Args:
            x: Student features for the auxiliary head(s).
            T_x: Teacher features; presumably the output of
                ``self.teacher.extract_feat(img)`` — TODO confirm.
            img_metas: Image meta information.
            gt_semantic_seg: Ground-truth segmentation.

        Returns:
            dict: Losses prefixed with ``aux`` (single head) or
            ``aux_{idx}`` (``nn.ModuleList`` of heads).
        """
        teacher = self.teacher
        with torch.no_grad():
            teacher.eval()
            # NOTE(review): these logits come from the teacher's *decode*
            # head, not an auxiliary head — confirm this is intended.
            teacher_seg_logit = teacher._decode_head_forward_test(
                T_x, img_metas)

        aux_heads = self.auxiliary_head
        if isinstance(aux_heads, nn.ModuleList):
            tagged = [(f'aux_{idx}', head)
                      for idx, head in enumerate(aux_heads)]
        else:
            tagged = [('aux', aux_heads)]

        losses = dict()
        for prefix, head in tagged:
            head_loss = head.forward_train_kd(x, img_metas, gt_semantic_seg,
                                              self.train_cfg,
                                              teacher_seg_logit)
            losses.update(add_prefix(head_loss, prefix))
        return losses
コード例 #8
0
ファイル: kd_encoder_decoder.py プロジェクト: jiamingNo1/skd
    def _decode_head_forward_train_kd(self, x, T_x, img_metas, gt_semantic_seg):
        """Run forward function and calculate knowledge-distillation loss for
        the decode head in training.

        The teacher is switched to eval mode and its decode-head logits are
        computed under ``torch.no_grad()``. The student decode head's
        ``forward_train_kd`` receives those logits alongside the ground
        truth.

        Args:
            x: Student features for the decode head.
            T_x: Teacher features; presumably the output of
                ``self.teacher.extract_feat(img)`` — TODO confirm.
            img_metas: Image meta information.
            gt_semantic_seg: Ground-truth segmentation.

        Returns:
            dict: The decode head's losses, prefixed with ``decode``.
        """
        teacher = self.teacher
        with torch.no_grad():
            teacher.eval()
            teacher_seg_logit = teacher._decode_head_forward_test(
                T_x, img_metas)

        student_loss = self.decode_head.forward_train_kd(
            x, img_metas, gt_semantic_seg, self.train_cfg, teacher_seg_logit)
        losses = dict()
        losses.update(add_prefix(student_loss, 'decode'))
        return losses