def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg_0,
                               gt_semantic_seg_1, gt_semantic_seg_2):
    """Run forward function and calculate loss for decode head in
    training.

    The same features ``x`` are passed to three lane heads, each paired
    with its own ground-truth tensor, and the three loss dicts are merged
    under distinct prefixes.

    Args:
        x: Features forwarded unchanged to every head's ``forward_train``.
        img_metas: Image meta information forwarded to every head.
        gt_semantic_seg_0: Ground truth handed to ``self.lane_typ_head``.
        gt_semantic_seg_1: Ground truth handed to ``self.lane_sty_head``.
        gt_semantic_seg_2: Ground truth handed to ``self.lane_dir_head``.

    Returns:
        dict: Losses of the three heads, re-keyed by ``add_prefix`` with
        ``'decode_lane_dir'``, ``'decode_lane_sty'`` and
        ``'decode_lane_typ'`` respectively.
    """
    losses = dict()

    # The rank is queried once and reused by both debug hooks below.
    rank, _ = get_dist_info()

    # Debug hook: only the process whose rank is 100 reports label maxima.
    if rank == 100:
        print(gt_semantic_seg_2.max(), gt_semantic_seg_1.max(),
              gt_semantic_seg_0.max())

    loss_decode_dir = self.lane_dir_head.forward_train(
        x, img_metas, gt_semantic_seg_2, self.train_cfg)
    loss_decode_sty = self.lane_sty_head.forward_train(
        x, img_metas, gt_semantic_seg_1, self.train_cfg)
    loss_decode_typ = self.lane_typ_head.forward_train(
        x, img_metas, gt_semantic_seg_0, self.train_cfg)

    # Debug hook for rank 10. The previous version indexed a
    # ``gt_semantic_seg`` name that does not exist in this scope and would
    # raise NameError; report on the three per-task tensors instead.
    if rank == 10:
        print('pin_size', gt_semantic_seg_0.size(),
              gt_semantic_seg_1.size(), gt_semantic_seg_2.size(),
              'pin_size')
        print('\npin 0', gt_semantic_seg_0.max(), '\npin 0')
        print('\npin 1', gt_semantic_seg_1.max(), '\npin 0')
        print('\npin 2', gt_semantic_seg_2.max(), '\npin 0')

    losses.update(add_prefix(loss_decode_dir, 'decode_lane_dir'))
    losses.update(add_prefix(loss_decode_sty, 'decode_lane_sty'))
    losses.update(add_prefix(loss_decode_typ, 'decode_lane_typ'))

    return losses
def _loss_regularization_forward_train(self):
    """Compute the weight-regularization loss terms used during training.

    Each regularizer stored in ``self.loss_regularization`` is called with
    ``self.modules()`` and its result is stored under the
    ``loss_regularize`` key before being re-keyed by ``add_prefix``.

    Returns:
        dict: Prefixed loss entries. The prefix is ``'regularize'`` for a
        single regularizer and ``'regularize_<idx>'`` for each member of
        an ``nn.ModuleList``.
    """
    collected = {}
    regularizers = self.loss_regularization

    # Single regularizer: one entry under the plain prefix.
    if not isinstance(regularizers, nn.ModuleList):
        term = {'loss_regularize': regularizers(self.modules())}
        collected.update(add_prefix(term, 'regularize'))
        return collected

    # Several regularizers: one entry each, distinguished by position.
    # ``self.modules()`` is called afresh for every regularizer.
    for position, regularizer in enumerate(regularizers):
        term = {'loss_regularize': regularizer(self.modules())}
        collected.update(add_prefix(term, f'regularize_{position}'))
    return collected
def losses(self, seg_logit, seg_label):
    """Compute ``pam_cam``, ``pam``, ``cam`` loss.

    Args:
        seg_logit: Iterable of exactly three logits, ordered as
            (pam_cam, pam, cam).
        seg_label: Label passed unchanged to the parent ``losses`` for
            every branch.

    Returns:
        dict: Parent-class losses of the three branches, re-keyed by
        ``add_prefix`` with ``'pam_cam'``, ``'pam'`` and ``'cam'``.
    """
    # Strict three-way unpacking: any other length raises here.
    fused_logit, position_logit, channel_logit = seg_logit
    branches = (
        ('pam_cam', fused_logit),
        ('pam', position_logit),
        ('cam', channel_logit),
    )

    combined = {}
    for tag, branch_logit in branches:
        branch_loss = super(DAHead, self).losses(branch_logit, seg_label)
        combined.update(add_prefix(branch_loss, tag))
    return combined
def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
    """Run the auxiliary head(s) in training mode and gather their losses.

    Args:
        x: Features forwarded to each auxiliary head.
        img_metas: Image meta information forwarded to each head.
        gt_semantic_seg: Ground truth forwarded to each head.

    Returns:
        dict: Head losses re-keyed by ``add_prefix``. The prefix is
        ``'aux'`` for a single head and ``'aux_<idx>'`` for each member of
        an ``nn.ModuleList``.
    """
    gathered = {}
    aux = self.auxiliary_head

    # Single auxiliary head: plain ``aux`` prefix.
    if not isinstance(aux, nn.ModuleList):
        single_loss = aux.forward_train(
            x, img_metas, gt_semantic_seg, self.train_cfg)
        gathered.update(add_prefix(single_loss, 'aux'))
        return gathered

    # Several auxiliary heads: prefix carries the head's position.
    for position, head in enumerate(aux):
        head_loss = head.forward_train(
            x, img_metas, gt_semantic_seg, self.train_cfg)
        gathered.update(add_prefix(head_loss, f'aux_{position}'))
    return gathered
def _decode_head_forward_train_r(self, x, img_metas, gt_semantic_seg):
    """Run ``self.decode_head_r`` in training mode and return its losses.

    Args:
        x: Features forwarded to the head.
        img_metas: Image meta information forwarded to the head.
        gt_semantic_seg: Ground truth forwarded to the head.

    Returns:
        dict: The head's losses re-keyed by ``add_prefix`` with
        ``'decode_r'``.
    """
    head_loss = self.decode_head_r.forward_train(
        x, img_metas, gt_semantic_seg, self.train_cfg)

    result = {}
    result.update(add_prefix(head_loss, 'decode_r'))
    return result
def _decode_head_forward_train_r(self, x, img_metas, gt_semantic_seg):
    """Run the staged ``decode_head_r`` heads in training mode.

    Stage 0 is trained directly on ``x``. Every later stage ``i`` first
    obtains the ``forward_test`` output of stage ``i - 1`` and receives it
    as an extra input to its own ``forward_train``.

    Args:
        x: Features forwarded to every stage.
        img_metas: Image meta information forwarded to every stage.
        gt_semantic_seg: Ground truth forwarded to every stage.

    Returns:
        dict: Losses of all stages, re-keyed by ``add_prefix`` with
        ``'decode_r_<stage>'``.
    """
    stages = self.decode_head_r
    result = {}

    first_loss = stages[0].forward_train(
        x, img_metas, gt_semantic_seg, self.train_cfg)
    result.update(add_prefix(first_loss, 'decode_r_0'))

    for stage_idx in range(1, self.num_stages):
        # The previous stage is run again in test mode to produce the
        # input for this stage; this may be redundant for most methods.
        previous_stage = stages[stage_idx - 1]
        previous_output = previous_stage.forward_test(
            x, img_metas, self.test_cfg)

        stage_loss = stages[stage_idx].forward_train(
            x, previous_output, img_metas, gt_semantic_seg, self.train_cfg)
        result.update(add_prefix(stage_loss, f'decode_r_{stage_idx}'))

    return result
def _auxiliary_head_forward_train_kd(self, x, T_x, img_metas,
                                     gt_semantic_seg):
    """Run the auxiliary head(s) in training mode with a teacher output.

    ``self.teacher`` is switched to eval mode and its
    ``_decode_head_forward_test`` is run on ``T_x`` under
    ``torch.no_grad()``. The result is passed as the last argument to each
    auxiliary head's ``forward_train_kd``.

    Args:
        x: Features forwarded to each auxiliary head.
        T_x: Input forwarded to the teacher's decode head.
        img_metas: Image meta information forwarded to teacher and heads.
        gt_semantic_seg: Ground truth forwarded to each auxiliary head.

    Returns:
        dict: Head losses re-keyed by ``add_prefix``. The prefix is
        ``'aux'`` for a single head and ``'aux_<idx>'`` for each member of
        an ``nn.ModuleList``.
    """
    # Teacher pass only: no gradients are recorded for it.
    with torch.no_grad():
        self.teacher.eval()
        teacher_logit = self.teacher._decode_head_forward_test(
            T_x, img_metas)

    gathered = {}
    aux = self.auxiliary_head

    # Single auxiliary head: plain ``aux`` prefix.
    if not isinstance(aux, nn.ModuleList):
        single_loss = aux.forward_train_kd(
            x, img_metas, gt_semantic_seg, self.train_cfg, teacher_logit)
        gathered.update(add_prefix(single_loss, 'aux'))
        return gathered

    # Several auxiliary heads: prefix carries the head's position.
    for position, head in enumerate(aux):
        head_loss = head.forward_train_kd(
            x, img_metas, gt_semantic_seg, self.train_cfg, teacher_logit)
        gathered.update(add_prefix(head_loss, f'aux_{position}'))
    return gathered
def _decode_head_forward_train_kd(self, x, T_x, img_metas,
                                  gt_semantic_seg):
    """Run the decode head in training mode with a teacher output.

    ``self.teacher`` is switched to eval mode and its
    ``_decode_head_forward_test`` is run on ``T_x`` under
    ``torch.no_grad()``. The result is passed as the last argument to
    ``self.decode_head.forward_train_kd``.

    Args:
        x: Features forwarded to the decode head.
        T_x: Input forwarded to the teacher's decode head.
        img_metas: Image meta information forwarded to teacher and head.
        gt_semantic_seg: Ground truth forwarded to the decode head.

    Returns:
        dict: The decode head's losses re-keyed by ``add_prefix`` with
        ``'decode'``.
    """
    # Teacher pass only: no gradients are recorded for it.
    with torch.no_grad():
        self.teacher.eval()
        teacher_logit = self.teacher._decode_head_forward_test(
            T_x, img_metas)

    head_loss = self.decode_head.forward_train_kd(
        x, img_metas, gt_semantic_seg, self.train_cfg, teacher_logit)

    result = {}
    result.update(add_prefix(head_loss, 'decode'))
    return result