def forward(self, outputs, batch):
    """Compute the weighted keypoint, box and segmentation loss for a batch.

    Args:
        outputs: Sequence unpacked as ``(hm, wh, hps, reg, hm_hp, hp_offset,
            seg_feat, seg)``.
        batch: Mapping of targets, masks and indices read by key (``'hm'``,
            ``'hps'``, ``'ind'``, ``'seg'``, ...).

    Returns:
        Tuple ``(loss, loss_stats)``: the weighted total and a dict holding
        the total plus every unweighted component.
    """
    cfg = self.cfg
    num_stacks = cfg.MODEL.NUM_STACKS
    hm_loss, wh_loss, off_loss, seg_loss = 0, 0, 0, 0
    hp_loss, hm_hp_loss, hp_offset_loss = 0, 0, 0
    hm, wh, hps, reg, hm_hp, hp_offset, seg_feat, seg = outputs

    def oracle(values, index):
        """Run gen_oracle_map on CPU copies and move the result to the GPU."""
        dense = gen_oracle_map(
            values.detach().cpu().numpy(),
            index.detach().cpu().numpy(),
            cfg.MODEL.OUTPUT_RES, cfg.MODEL.OUTPUT_RES)
        return torch.from_numpy(dense).to(
            torch.device('cuda:%d' % self.local_rank))

    # The heads are unpacked once (not per stack), so they are activated and
    # optionally replaced exactly once. Doing this inside the stack loop would
    # push ``hm`` through the sigmoid again for every additional stack.
    hm = _sigmoid(hm)
    if cfg.LOSS.HM_HP and not cfg.LOSS.MSE_LOSS:
        hm_hp = _sigmoid(hm_hp)

    if cfg.TEST.EVAL_ORACLE_HMHP:
        hm_hp = batch['hm_hp']
    if cfg.TEST.EVAL_ORACLE_HM:
        hm = batch['hm']
    if cfg.TEST.EVAL_ORACLE_KPS:
        if cfg.LOSS.DENSE_HP:
            hps = batch['dense_hps']
        else:
            hps = oracle(batch['hps'], batch['ind'])
    if cfg.TEST.EVAL_ORACLE_HP_OFFSET:
        # The oracle map is built from the target offsets in ``batch``; the
        # predicted ``hp_offset`` must not be fed back in as its own oracle.
        hp_offset = oracle(batch['hp_offset'], batch['hp_ind'])

    for _ in range(num_stacks):
        hm_loss += self.crit(hm, batch['hm']) / num_stacks

        if cfg.LOSS.DENSE_HP:
            mask_weight = batch['dense_hps_mask'].sum() + 1e-4
            hp_loss += (self.crit_kp(
                hps * batch['dense_hps_mask'],
                batch['dense_hps'] * batch['dense_hps_mask'])
                / mask_weight) / num_stacks
        else:
            hp_loss += self.crit_kp(
                hps, batch['hps_mask'], batch['ind'],
                batch['hps']) / num_stacks

        if cfg.LOSS.WH_WEIGHT > 0:
            wh_loss += self.crit_reg(
                wh, batch['reg_mask'], batch['ind'],
                batch['wh']) / num_stacks
        if cfg.LOSS.REG_OFFSET and cfg.LOSS.OFF_WEIGHT > 0:
            off_loss += self.crit_reg(
                reg, batch['reg_mask'], batch['ind'],
                batch['reg']) / num_stacks
        if cfg.LOSS.REG_HP_OFFSET and cfg.LOSS.OFF_WEIGHT > 0:
            hp_offset_loss += self.crit_reg(
                hp_offset, batch['hp_mask'], batch['hp_ind'],
                batch['hp_offset']) / num_stacks
        if cfg.LOSS.HM_HP and cfg.LOSS.HM_HP_WEIGHT > 0:
            hm_hp_loss += self.crit_hm_hp(
                hm_hp, batch['hm_hp']) / num_stacks

        # NOTE(review): unlike the other terms this one is not divided by
        # the stack count and carries no weight in the total - confirm.
        seg_loss += self.crit_seg(seg, seg_feat, batch['ind'], batch['seg'])

    loss = cfg.LOSS.HM_WEIGHT * hm_loss + cfg.LOSS.WH_WEIGHT * wh_loss + \
        cfg.LOSS.OFF_WEIGHT * off_loss + cfg.LOSS.HP_WEIGHT * hp_loss + \
        cfg.LOSS.HM_HP_WEIGHT * hm_hp_loss + \
        cfg.LOSS.OFF_WEIGHT * hp_offset_loss + \
        seg_loss

    loss_stats = {
        'loss': loss,
        'hm_loss': hm_loss,
        'hp_loss': hp_loss,
        'hm_hp_loss': hm_hp_loss,
        'hp_offset_loss': hp_offset_loss,
        'wh_loss': wh_loss,
        'off_loss': off_loss,
        'seg_loss': seg_loss,
    }
    return loss, loss_stats
def forward(self, outputs, batch):
    """Compute the weighted multi-head pose loss over all stacks.

    Args:
        outputs: Sequence indexed by stack; each entry is a dict of head
            outputs (``'hm'``, ``'hps'``, ``'wh'``, ``'reg'``, ``'hm_hp'``,
            ``'hp_offset'``). The ``'hm'`` entry (and optionally others) is
            overwritten in place.
        batch: Mapping of targets, masks and indices read by key.

    Returns:
        Tuple ``(loss, loss_stats)``. Every component stored in
        ``loss_stats`` is already multiplied by its weight.
    """
    opt = self.opt
    n_stacks = opt.num_stacks
    hm_total = wh_total = off_total = 0
    hp_total = hm_hp_total = hp_off_total = 0

    def oracle(values, index):
        """Run gen_oracle_map on CPU copies and move the result to opt.device."""
        dense = gen_oracle_map(values.detach().cpu().numpy(),
                               index.detach().cpu().numpy(),
                               opt.output_res, opt.output_res)
        return torch.from_numpy(dense).to(opt.device)

    for stack in range(n_stacks):
        head = outputs[stack]
        head['hm'] = _sigmoid(head['hm'])
        if opt.hm_hp and not opt.mse_loss:
            head['hm_hp'] = _sigmoid(head['hm_hp'])

        # Optional oracle substitution of individual heads.
        if opt.eval_oracle_hmhp:
            head['hm_hp'] = batch['hm_hp']
        if opt.eval_oracle_hm:
            head['hm'] = batch['hm']
        if opt.eval_oracle_kps:
            head['hps'] = (batch['dense_hps'] if opt.dense_hp
                           else oracle(batch['hps'], batch['ind']))
        if opt.eval_oracle_hp_offset:
            head['hp_offset'] = oracle(batch['hp_offset'], batch['hp_ind'])

        hm_total += self.crit(head['hm'], batch['hm']) / n_stacks

        if opt.dense_hp:
            dense_mask = batch['dense_hps_mask']
            mask_weight = dense_mask.sum() + 1e-4
            kp_term = self.crit_kp(head['hps'] * dense_mask,
                                   batch['dense_hps'] * dense_mask)
            hp_total += (kp_term / mask_weight) / n_stacks
        else:
            hp_total += self.crit_kp(head['hps'], batch['hps_mask'],
                                     batch['ind'], batch['hps']) / n_stacks

        if opt.wh_weight > 0:
            wh_total += self.crit_reg(head['wh'], batch['reg_mask'],
                                      batch['ind'], batch['wh']) / n_stacks
        if opt.reg_offset and opt.off_weight > 0:
            off_total += self.crit_reg(head['reg'], batch['reg_mask'],
                                       batch['ind'], batch['reg']) / n_stacks
        if opt.reg_hp_offset and opt.off_weight > 0:
            hp_off_total += self.crit_reg(
                head['hp_offset'], batch['hp_mask'],
                batch['hp_ind'], batch['hp_offset']) / n_stacks
        if opt.hm_hp and opt.hm_hp_weight > 0:
            hm_hp_total += self.crit_hm_hp(
                head['hm_hp'], batch['hm_hp']) / n_stacks

    total = opt.hm_weight * hm_total + opt.wh_weight * wh_total + \
        opt.off_weight * off_total + opt.hp_weight * hp_total + \
        opt.hm_hp_weight * hm_hp_total + opt.off_weight * hp_off_total

    # Components are reported weighted, unlike the raw accumulators above.
    loss_stats = {
        'loss': total,
        'hm_loss': opt.hm_weight * hm_total,
        'hp_loss': opt.hp_weight * hp_total,
        'hm_hp_loss': opt.hm_hp_weight * hm_hp_total,
        'hp_offset_loss': opt.off_weight * hp_off_total,
        'wh_loss': opt.wh_weight * wh_total,
        'off_loss': opt.off_weight * off_total,
    }
    return total, loss_stats
def forward(self, outputs, batch):
    """Compute the detection loss with optional proposal and scale terms.

    Args:
        outputs: Sequence indexed by stack; each entry is a dict of head
            outputs (``'hm'``, ``'wh'``, ``'reg'`` and, with
            ``opt.reg_proposal``, ``'proposal'`` and ``'scale'``). Some
            entries are overwritten in place.
        batch: Mapping of targets, masks and indices read by key.

    Returns:
        Tuple ``(loss, loss_stats)``. ``loss_stats`` holds the unweighted
        components; the proposal and scale entries are only present when
        ``opt.reg_proposal`` is set.
    """
    opt = self.opt
    num_stacks = opt.num_stacks
    hm_loss, wh_loss, off_loss = 0, 0, 0
    proposal_loss, proposal_scale_loss = 0, 0

    for s in range(num_stacks):
        output = outputs[s]
        if not opt.mse_loss:
            output['hm'] = _sigmoid(output['hm'])

        if opt.eval_oracle_hm:
            output['hm'] = batch['hm']
        if opt.eval_oracle_wh:
            output['wh'] = torch.from_numpy(gen_oracle_map(
                batch['wh'].detach().cpu().numpy(),
                batch['ind'].detach().cpu().numpy(),
                output['wh'].shape[3],
                output['wh'].shape[2])).to(opt.device)
        if opt.eval_oracle_offset:
            output['reg'] = torch.from_numpy(gen_oracle_map(
                batch['reg'].detach().cpu().numpy(),
                batch['ind'].detach().cpu().numpy(),
                output['reg'].shape[3],
                output['reg'].shape[2])).to(opt.device)

        hm_loss += self.crit(output['hm'], batch['hm']) / num_stacks

        if opt.wh_weight > 0:
            if opt.dense_wh:
                mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                wh_loss += (self.crit_wh(
                    output['wh'] * batch['dense_wh_mask'],
                    batch['dense_wh'] * batch['dense_wh_mask'])
                    / mask_weight) / num_stacks
            elif opt.cat_spec_wh:
                wh_loss += self.crit_wh(
                    output['wh'], batch['cat_spec_mask'],
                    batch['ind'], batch['cat_spec_wh']) / num_stacks
            elif opt.reg_proposal:
                # The scale head is detached, so this term sends no gradient
                # into it.
                out_scale = output['scale'].detach()
                wh_loss += self.crit_reg(
                    output['wh'] + out_scale, batch['reg_mask'],
                    batch['ind'], batch['wh']) / num_stacks
            else:
                wh_loss += self.crit_reg(
                    output['wh'], batch['reg_mask'],
                    batch['ind'], batch['wh']) / num_stacks

        if opt.reg_offset and opt.off_weight > 0:
            off_loss += self.crit_reg(
                output['reg'], batch['reg_mask'],
                batch['ind'], batch['reg']) / num_stacks

        if opt.reg_proposal and opt.proposal_weight > 0:
            output['proposal'] = _sigmoid(output['proposal'])  # for focal loss
            ignore_mask = batch['proposal'].gt(-1).float()
            proposal_loss += self.crit_centerness(
                output['proposal'], batch['proposal'],
                ignore_mask) / num_stacks

            ignore_scale_mask = batch['scale'].gt(0).float()
            # A batch without any positive scale target would otherwise
            # divide by zero and turn the whole loss into NaN/inf.
            valid_num = ignore_scale_mask.sum().clamp(min=1)
            proposal_scale_loss += self.crit_scale(
                output['scale'] * ignore_scale_mask,
                batch['scale']) / valid_num / num_stacks

    loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + \
        opt.off_weight * off_loss + opt.proposal_weight * proposal_loss + \
        opt.scale_weight * proposal_scale_loss

    if opt.reg_proposal:
        loss_stats = {
            'loss': loss,
            'proposal_loss': proposal_loss,
            'scale_loss': proposal_scale_loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
        }
    else:
        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
        }
    return loss, loss_stats
def forward(self, outputs, batch):
    """Compute the weighted 3D-detection loss over all stacks.

    Args:
        outputs: Sequence indexed by stack; each entry is a dict of head
            outputs (``'hm'``, ``'dep'``, ``'dim'``, ``'rot'``, ``'wh'``,
            ``'reg'``). ``'hm'`` and ``'dep'`` are overwritten in place.
        batch: Mapping of targets, masks and indices read by key.

    Returns:
        Tuple ``(loss, loss_stats)``. With ``opt.reg_bbox`` the stats also
        carry ``'wh_loss'`` and a ``'tilt_loss'`` entry that is always 0.
    """
    opt = self.opt
    n_stacks = opt.num_stacks
    hm_loss = dep_loss = rot_loss = dim_loss = 0
    wh_loss = off_loss = 0
    tilt_loss = 0  # reported only; the tilt term is currently disabled

    for stack in range(n_stacks):
        head = outputs[stack]
        head['hm'] = _sigmoid(head['hm'])
        # Maps the raw depth head through 1 / sigmoid(x) - 1.
        head['dep'] = 1. / (head['dep'].sigmoid() + 1e-6) - 1.

        if opt.eval_oracle_dep:
            dep_map = gen_oracle_map(batch['dep'].detach().cpu().numpy(),
                                     batch['ind'].detach().cpu().numpy(),
                                     opt.output_w, opt.output_h)
            head['dep'] = torch.from_numpy(dep_map).to(opt.device)

        hm_loss += self.crit(head['hm'], batch['hm']) / n_stacks

        if opt.dep_weight > 0:
            # depth (original note: L1 loss)
            dep_loss += self.crit_reg(head['dep'], batch['reg_mask'],
                                      batch['ind'], batch['dep']) / n_stacks
        if opt.dim_weight > 0:
            # 3D dimension (original note: L1 loss)
            dim_loss += self.crit_reg(head['dim'], batch['reg_mask'],
                                      batch['ind'], batch['dim']) / n_stacks
        if opt.rot_weight > 0:
            # orientation
            rot_loss += self.crit_rot(head['rot'], batch['rot_mask'],
                                      batch['ind'], batch['rotbin'],
                                      batch['rotres']) / n_stacks

        # The box-size and offset terms are masked with ``rot_mask``.
        if opt.reg_bbox and opt.wh_weight > 0:
            wh_loss += self.crit_reg(head['wh'], batch['rot_mask'],
                                     batch['ind'], batch['wh']) / n_stacks
        if opt.reg_offset and opt.off_weight > 0:
            off_loss += self.crit_reg(head['reg'], batch['rot_mask'],
                                      batch['ind'], batch['reg']) / n_stacks

    shared = opt.hm_weight * hm_loss + opt.dep_weight * dep_loss + \
        opt.dim_weight * dim_loss + opt.rot_weight * rot_loss

    if opt.reg_bbox:
        loss = shared + opt.wh_weight * wh_loss + opt.off_weight * off_loss
        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'dep_loss': dep_loss,
            'dim_loss': dim_loss,
            'rot_loss': rot_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
            'tilt_loss': tilt_loss,
        }
    else:
        loss = shared + opt.off_weight * off_loss
        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'dep_loss': dep_loss,
            'dim_loss': dim_loss,
            'rot_loss': rot_loss,
            'off_loss': off_loss,
        }
    return loss, loss_stats
def forward(self, outputs, batch):
    """Compute focal, tag (pull/push) and regression losses for lm/rm/ct heads.

    Args:
        outputs: Sequence indexed by stack; each entry is a dict with the
            keys ``'lm'``, ``'rm'``, ``'ct'``, ``'lm_tag'``, ``'rm_tag'``,
            ``'lm_reg'``, ``'rm_reg'`` and ``'ct_reg'``. Entries may be
            overwritten in place.
        batch: Mapping of targets, masks and indices read by key.

    Returns:
        Tuple ``(loss, loss_stats)``. Each per-head component in
        ``loss_stats`` is accumulated over the stacks, so the three focal
        components add up to ``'focal_loss'`` and the three regression
        components add up to ``'reg_loss'``.
    """
    opt = self.opt
    num_stacks = opt.num_stacks
    focal_loss, pull_loss, push_loss, reg_loss = 0, 0, 0, 0
    lm_focal_loss, rm_focal_loss, ct_focal_loss = 0, 0, 0
    lm_reg_loss, rm_reg_loss, ct_reg_loss = 0, 0, 0

    for s in range(num_stacks):
        output = outputs[s]

        def oracle(value_key, index_key, like_key):
            """Run gen_oracle_map sized like ``output[like_key]``."""
            dense = gen_oracle_map(
                batch[value_key].detach().cpu().numpy(),
                batch[index_key].detach().cpu().numpy(),
                output[like_key].shape[3], output[like_key].shape[2])
            return torch.from_numpy(dense).to(opt.device)

        if not opt.mse_loss:
            output['lm'] = _sigmoid(output['lm'])
            output['rm'] = _sigmoid(output['rm'])
            output['ct'] = _sigmoid(output['ct'])

        if opt.eval_oracle_lm:
            output['lm'] = batch['lm']
        if opt.eval_oracle_rm:
            output['rm'] = batch['rm']
        if opt.eval_oracle_ct:
            output['ct'] = batch['ct']
        if opt.eval_oracle_ae:
            # NOTE(review): the tag entry is passed as both the values and
            # the index argument - confirm this is intended.
            output['lm_tag'] = oracle('lm_tag', 'lm_tag', 'lm_tag')
            output['rm_tag'] = oracle('rm_tag', 'rm_tag', 'rm_tag')
        if opt.eval_oracle_offset:
            output['lm_reg'] = oracle('lm_reg', 'lm_tag', 'lm_reg')
            output['rm_reg'] = oracle('rm_reg', 'rm_tag', 'rm_reg')
            output['ct_reg'] = oracle('ct_reg', 'ct_tag', 'ct_reg')

        # focal loss
        lm_focal = self.crit(output['lm'], batch['lm']) / num_stacks
        rm_focal = self.crit(output['rm'], batch['rm']) / num_stacks
        ct_focal = self.crit(output['ct'], batch['ct']) / num_stacks
        # Accumulate the per-head values as well; assigning them would make
        # the reported components reflect the last stack only.
        lm_focal_loss += lm_focal
        rm_focal_loss += rm_focal
        ct_focal_loss += ct_focal
        focal_loss += lm_focal
        focal_loss += rm_focal
        focal_loss += ct_focal

        # tag loss
        pull, push = self.crit_tag(output['rm_tag'], output['lm_tag'],
                                   batch['rm_tag'], batch['lm_tag'],
                                   batch['reg_mask'])
        pull_loss += opt.pull_weight * pull / num_stacks
        push_loss += opt.push_weight * push / num_stacks

        # reg loss
        lm_reg = opt.regr_weight * self.crit_reg(
            output['lm_reg'], batch['reg_mask'],
            batch['lm_tag'], batch['lm_reg']) / num_stacks
        rm_reg = opt.regr_weight * self.crit_reg(
            output['rm_reg'], batch['reg_mask'],
            batch['rm_tag'], batch['rm_reg']) / num_stacks
        ct_reg = opt.regr_weight * self.crit_reg(
            output['ct_reg'], batch['reg_mask'],
            batch['ct_tag'], batch['ct_reg']) / num_stacks
        lm_reg_loss += lm_reg
        rm_reg_loss += rm_reg
        ct_reg_loss += ct_reg
        reg_loss += lm_reg
        reg_loss += rm_reg
        reg_loss += ct_reg

    loss = focal_loss + pull_loss + push_loss + reg_loss
    loss_stats = {
        'loss': loss,
        'focal_loss': focal_loss,
        'pull_loss': pull_loss,
        'push_loss': push_loss,
        'reg_loss': reg_loss,
        'lm_focal_loss': lm_focal_loss,
        'rm_focal_loss': rm_focal_loss,
        'ct_focal_loss': ct_focal_loss,
        'lm_reg_loss': lm_reg_loss,
        'rm_reg_loss': rm_reg_loss,
        'ct_reg_loss': ct_reg_loss,
    }
    return loss, loss_stats
def forward(self, outputs, batch):
    """Compute the weighted 3D-detection loss over all stacks.

    Args:
        outputs: Sequence indexed by stack; each entry is a dict of head
            outputs (``'hm'``, ``'dep'``, ``'dim'``, ``'rot'``, ``'wh'``,
            ``'reg'``). ``'hm'`` and ``'dep'`` are overwritten in place.
        batch: Mapping of targets, masks and indices read by key.

    Returns:
        Tuple ``(loss, loss_stats)``: the weighted total and a dict holding
        the total plus each unweighted component.
    """
    opt = self.opt
    num_stacks = opt.num_stacks
    hm_loss, dep_loss, rot_loss, dim_loss = 0, 0, 0, 0
    wh_loss, off_loss = 0, 0

    for s in range(num_stacks):
        output = outputs[s]
        output['hm'] = _sigmoid(output['hm'])
        # Maps the raw depth head through 1 / sigmoid(x) - 1.
        output['dep'] = 1. / (output['dep'].sigmoid() + 1e-6) - 1.

        if opt.eval_oracle_dep:
            output['dep'] = torch.from_numpy(gen_oracle_map(
                batch['dep'].detach().cpu().numpy(),
                batch['ind'].detach().cpu().numpy(),
                opt.output_w, opt.output_h)).to(opt.device)

        # No debugger breakpoint belongs here: a pdb.set_trace() in this loop
        # would suspend every training and evaluation step.

        # Original author's note labels this the "L2 loss"; which criterion
        # ``self.crit`` actually is cannot be seen from here.
        hm_loss += self.crit(output['hm'], batch['hm']) / num_stacks

        if opt.dep_weight > 0:
            # Shapes noted by the original author for one run (not enforced):
            #   output['dep']     : [1, 1, 96, 320]
            #   batch['reg_mask'] : [1, 50]
            #   batch['ind']      : [1, 50]
            #   batch['dep']      : [1, 50, 1]
            dep_loss += self.crit_reg(output['dep'], batch['reg_mask'],
                                      batch['ind'], batch['dep']) / num_stacks
        if opt.dim_weight > 0:
            # Original note: uses the custom-defined L1 loss function.
            dim_loss += self.crit_reg(output['dim'], batch['reg_mask'],
                                      batch['ind'], batch['dim']) / num_stacks
        if opt.rot_weight > 0:
            rot_loss += self.crit_rot(output['rot'], batch['rot_mask'],
                                      batch['ind'], batch['rotbin'],
                                      batch['rotres']) / num_stacks
        if opt.reg_bbox and opt.wh_weight > 0:
            wh_loss += self.crit_reg(output['wh'], batch['rot_mask'],
                                     batch['ind'], batch['wh']) / num_stacks
        if opt.reg_offset and opt.off_weight > 0:
            off_loss += self.crit_reg(output['reg'], batch['rot_mask'],
                                      batch['ind'], batch['reg']) / num_stacks

    loss = opt.hm_weight * hm_loss + opt.dep_weight * dep_loss + \
        opt.dim_weight * dim_loss + opt.rot_weight * rot_loss + \
        opt.wh_weight * wh_loss + opt.off_weight * off_loss

    loss_stats = {
        'loss': loss,
        'hm_loss': hm_loss,
        'dep_loss': dep_loss,
        'dim_loss': dim_loss,
        'rot_loss': rot_loss,
        'wh_loss': wh_loss,
        'off_loss': off_loss,
    }
    return loss, loss_stats
def forward(self, outputs, batch, global_step, tb_writer):
    """Compute the pose loss, optionally with a mixture-density keypoint head.

    Args:
        outputs: Sequence indexed by stack; each entry is a dict of head
            outputs (``'hm'``, ``'hps'``, ``'wh'``, ``'reg'``, ``'hm_hp'``,
            ``'hp_offset'`` and, with ``opt.mdn``, ``'mdn_logits'`` and
            ``'mdn_sigma'``). Some entries are overwritten in place.
        batch: Mapping of targets, masks and indices read by key.
        global_step: Step value forwarded to ``tb_writer.add_histogram``.
        tb_writer: Object exposing ``add_histogram``, or ``None`` to skip
            histogram logging.

    Returns:
        Tuple ``(loss, loss_stats)``: the weighted total and a dict holding
        the total plus each unweighted component.
    """
    opt = self.opt
    num_stacks = opt.num_stacks
    hm_loss, wh_loss, off_loss = 0, 0, 0
    hp_loss, hm_hp_loss, hp_offset_loss = 0, 0, 0
    loss_stats = {}

    if opt.mdn:
        # Loop-invariant per-keypoint constants, built once. The 17 values
        # appear to match the COCO OKS keypoint sigmas, scaled by 1/10 -
        # TODO confirm.
        V = torch.Tensor((np.array(
            [.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,
             .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0).astype(
                 np.float32)).float().cuda()

    for s in range(num_stacks):
        output = outputs[s]
        output['hm'] = _sigmoid(output['hm'])
        if opt.hm_hp:
            output['hm_hp'] = _sigmoid(output['hm_hp'])

        if opt.eval_oracle_hmhp:
            output['hm_hp'] = batch['hm_hp']
        if opt.eval_oracle_hm:
            output['hm'] = batch['hm']
        if opt.eval_oracle_kps:
            if opt.dense_hp:
                output['hps'] = batch['dense_hps']
            else:
                output['hps'] = torch.from_numpy(gen_oracle_map(
                    batch['hps'].detach().cpu().numpy(),
                    batch['ind'].detach().cpu().numpy(),
                    opt.output_res, opt.output_res)).to(opt.device)
        if opt.eval_oracle_hp_offset:
            output['hp_offset'] = torch.from_numpy(gen_oracle_map(
                batch['hp_offset'].detach().cpu().numpy(),
                batch['hp_ind'].detach().cpu().numpy(),
                opt.output_res, opt.output_res)).to(opt.device)

        # Criteria indexed with [0] return a sequence whose first item is
        # the loss value.
        hm_loss += self.crit(output['hm'], batch['hm'])[0] / num_stacks

        if opt.mdn:
            mdn_logits = output['mdn_logits']
            if opt.mdn_dropout > 0 and opt.epoch < opt.mdn_dropout_stop:
                # Zero the logits of a random subset of 1..mdn_dropout
                # mixture components.
                M = opt.mdn_n_comps
                ridx = torch.randperm(M)[
                    :torch.randint(1, 1 + opt.mdn_dropout, (1,))]
                drop_mask = torch.ones((M,))
                drop_mask[ridx] = 0
                drop_mask = torch.reshape(
                    drop_mask, (1, -1, 1, 1)).float().cuda()
                mdn_logits = mdn_logits * drop_mask
                if tb_writer is not None:
                    tb_writer.add_histogram(
                        'drop_out_idx', ridx + 1, global_step=global_step)

            # The mixture weights must exist on both paths. When they are
            # only assigned in an ``else`` of the dropout check, the dropout
            # path reaches the code below with ``mdn_pi`` undefined.
            mdn_pi = torch.clamp(
                torch.nn.Softmax(dim=1)(mdn_logits), 1e-4, 1. - 1e-4)
            mdn_sigma = torch.clamp(
                torch.nn.ELU()(output['mdn_sigma']) + opt.mdn_min_sigma,
                1e-4, 1e5)
            mdn_mu = output['hps']

            if tb_writer is not None:
                for i in range(mdn_pi.shape[1]):
                    tb_writer.add_histogram(
                        'mdn_pi/{}'.format(i), mdn_pi[:, i],
                        global_step=global_step)
                    tb_writer.add_histogram(
                        'mdn_sigma/{}'.format(i),
                        mdn_sigma[:, i * 2:i * 2 + 2],
                        global_step=global_step)

            if opt.dense_hp:
                gt = batch['dense_hps']
                mask = batch['dense_hps_mask'][:, 0::2, :, :]
                _, max_pi_ind = torch.max(mdn_pi, 1)
            else:
                gt = batch['hps']
                mask = batch['hps_mask'][:, :, 0::2]
                mdn_mu = _tranpose_and_gather_feat(mdn_mu, batch['ind'])
                mdn_pi = _tranpose_and_gather_feat(mdn_pi, batch['ind'])
                mdn_sigma = _tranpose_and_gather_feat(mdn_sigma, batch['ind'])
                _, max_pi_ind = torch.max(mdn_pi, -1)
            if tb_writer is not None:
                tb_writer.add_histogram(
                    'mdn_pi_max_comp', max_pi_ind + 1,
                    global_step=global_step)

            # Shapes noted by the original author for mdn_n_comps=3 on the
            # non-dense path (not enforced by the code):
            #   batch['hps'], batch['hps_mask'], gt : [2, 32, 34]
            #   batch['ind']                        : [2, 32]
            #   mask                                : [2, 32, 17]
            #   mdn_mu    : [2, 102, 128, 128] -> [2, 32, 102] after gather
            #   mdn_pi    : [2, 3, 128, 128]   -> [2, 32, 3]   after gather
            #   mdn_sigma : [2, 6, 128, 128]   -> [2, 32, 6]   after gather
            if opt.mdn_inter:
                hp_loss += self.crit_kp(
                    gt, mdn_mu, mdn_sigma, mdn_pi, mask, V,
                    debug=opt.debug == 6)[0] / num_stacks
            else:
                # Without ``mdn_inter`` the value is overwritten, so only
                # the last stack contributes to the keypoint term.
                hp_loss = self.crit_kp(
                    gt, mdn_mu, mdn_sigma, mdn_pi, mask, V,
                    debug=opt.debug == 6)[0] / num_stacks
        else:
            if opt.dense_hp:
                mask_weight = batch['dense_hps_mask'].sum() + 1e-4
                hp_loss += (self.crit_kp(
                    output['hps'] * batch['dense_hps_mask'],
                    batch['dense_hps'] * batch['dense_hps_mask'])
                    / mask_weight) / num_stacks
            else:
                hp_loss += self.crit_kp(
                    output['hps'], batch['hps_mask'],
                    batch['ind'], batch['hps']) / num_stacks

        if opt.wh_weight > 0:
            wh_loss += self.crit_reg(
                output['wh'], batch['reg_mask'],
                batch['ind'], batch['wh'])[0] / num_stacks
        if opt.reg_offset and opt.off_weight > 0:
            off_loss += self.crit_reg(
                output['reg'], batch['reg_mask'],
                batch['ind'], batch['reg'])[0] / num_stacks
        if opt.reg_hp_offset and opt.off_weight > 0:
            hp_offset_loss += self.crit_reg(
                output['hp_offset'], batch['hp_mask'],
                batch['hp_ind'], batch['hp_offset'])[0] / num_stacks
        if opt.hm_hp and opt.hm_hp_weight > 0:
            hm_hp_loss += self.crit_hm_hp(
                output['hm_hp'], batch['hm_hp'])[0] / num_stacks

    loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + \
        opt.off_weight * off_loss + opt.hp_weight * hp_loss + \
        opt.hm_hp_weight * hm_hp_loss + opt.off_weight * hp_offset_loss

    loss_stats.update({
        'loss': loss,
        'hm_loss': hm_loss,
        'hp_loss': hp_loss,
        'hm_hp_loss': hm_hp_loss,
        'hp_offset_loss': hp_offset_loss,
        'wh_loss': wh_loss,
        'off_loss': off_loss,
    })
    return loss, loss_stats
def forward(self, outputs, batch):
    """Compute the classification, box-regression and centerness losses.

    Args:
        outputs: Sequence whose last element unpacks as
            ``(box_cls, box_regression, centerness)``. Each is indexed per
            level; ``centerness`` may be empty.
        batch: Targets, forwarded unchanged to ``self.prepare_targets``.

    Returns:
        Tuple ``(loss, loss_stats)`` with ``loss = cls + reg + centerness``
        and ``loss_stats`` holding the keys ``'loss'``, ``'cls'``, ``'reg'``
        and ``'centerness'``.
    """
    box_cls, box_regression, centerness = outputs[-1]
    box_cls = [_sigmoid(level_cls) for level_cls in box_cls]
    use_centerness = bool(len(centerness))
    if use_centerness:
        centerness = [_sigmoid(level_ctr) for level_ctr in centerness]

    locations = self.compute_locations(box_regression)
    num_classes = box_cls[0].size(1)
    labels, reg_targets = self.prepare_targets(locations, batch)

    # Flatten every level to (N, C) / (N, 4) / (N,) and join the levels.
    level_ids = range(len(labels))
    box_cls_flatten = torch.cat(
        [box_cls[lvl].permute(0, 2, 3, 1).reshape(-1, num_classes)
         for lvl in level_ids], dim=0)
    box_regression_flatten = torch.cat(
        [box_regression[lvl].permute(0, 2, 3, 1).reshape(-1, 4)
         for lvl in level_ids], dim=0)
    labels_flatten = torch.cat(
        [labels[lvl].reshape(-1) for lvl in level_ids], dim=0)
    reg_targets_flatten = torch.cat(
        [reg_targets[lvl].reshape(-1, 4) for lvl in level_ids], dim=0)

    centerness_targets = self.compute_centerness_targets(reg_targets_flatten)

    pos_inds = torch.nonzero(labels_flatten >= 0)
    num_pos = pos_inds.numel()
    # Squeezed only when non-empty; the original note attributes the guard
    # to a PyTorch version incompatibility.
    if num_pos > 0:
        pos_inds = pos_inds.squeeze(1)

    box_regression_flatten = box_regression_flatten[pos_inds]
    reg_targets_flatten = reg_targets_flatten[pos_inds]

    if use_centerness:
        centerness_flatten = torch.cat(
            [centerness[lvl].reshape(-1) for lvl in level_ids],
            dim=0)[pos_inds]
        cls_loss = self.cls_loss_func(
            box_cls_flatten, labels_flatten.int()) / max(num_pos, 1)
    else:
        # Without a centerness head the centerness targets are handed to
        # the classification loss and also normalise it.
        cls_loss = self.cls_loss_func(
            box_cls_flatten, labels_flatten.int(),
            centerness_targets[:, None])
        cls_loss = cls_loss / torch.clamp(
            centerness_targets[pos_inds].sum(), min=1)

    centerness_loss = cls_loss.new_zeros(1)
    if num_pos > 0:
        centerness_targets = centerness_targets[pos_inds]
        reg_loss = self.box_reg_loss_func(
            box_regression_flatten,
            reg_targets_flatten,
            centerness_targets,
        ) / centerness_targets.sum()
        if use_centerness:
            centerness_loss = self.centerness_loss_func(
                centerness_flatten,
                centerness_targets,
            ) / num_pos
    else:
        reg_loss = box_regression_flatten.sum()

    loss = cls_loss + reg_loss + centerness_loss
    loss_stats = {
        'loss': loss,
        'cls': cls_loss,
        'reg': reg_loss,
        'centerness': centerness_loss,
    }
    return loss, loss_stats
def forward(self, outputs_1, outputs_2, batch):
    """Compute the summed detection loss of two tasks.

    Args:
        outputs_1: Sequence indexed by stack; each entry is a dict with the
            keys ``'hm_1'``, ``'wh_1'`` and ``'reg_1'``.
        outputs_2: Same layout for the second task (``'hm_2'``, ``'wh_2'``,
            ``'reg_2'``).
        batch: Mapping of targets, masks and indices for both tasks; keys
            carry the matching ``_1`` / ``_2`` suffix.

    Returns:
        Tuple ``(loss, loss_stats)`` where ``loss = loss_1 + loss_2`` and
        ``loss_stats`` holds both task totals and their unweighted
        components.

    Raises:
        NotImplementedError: If ``opt.dense_wh`` is set while
            ``opt.wh_weight > 0``; the dense size loss has no two-task
            implementation.
    """
    opt = self.opt
    num_stacks = opt.num_stacks
    hm_loss_1, wh_loss_1, off_loss_1 = 0, 0, 0
    hm_loss_2, wh_loss_2, off_loss_2 = 0, 0, 0

    def oracle(head, key, index_key):
        """Run gen_oracle_map sized like ``head[key]``."""
        dense = gen_oracle_map(
            batch[key].detach().cpu().numpy(),
            batch[index_key].detach().cpu().numpy(),
            head[key].shape[3], head[key].shape[2])
        return torch.from_numpy(dense).to(opt.device)

    for s in range(num_stacks):
        output_1 = outputs_1[s]
        output_2 = outputs_2[s]
        if not opt.mse_loss:
            output_1['hm_1'] = _sigmoid(output_1['hm_1'])
            output_2['hm_2'] = _sigmoid(output_2['hm_2'])

        if opt.eval_oracle_hm:
            output_1['hm_1'] = batch['hm_1']
            output_2['hm_2'] = batch['hm_2']
        if opt.eval_oracle_wh:
            output_1['wh_1'] = oracle(output_1, 'wh_1', 'ind_1')
            output_2['wh_2'] = oracle(output_2, 'wh_2', 'ind_2')
        if opt.eval_oracle_offset:
            output_1['reg_1'] = oracle(output_1, 'reg_1', 'ind_1')
            output_2['reg_2'] = oracle(output_2, 'reg_2', 'ind_2')

        hm_loss_1 += self.crit(output_1['hm_1'], batch['hm_1']) / num_stacks
        hm_loss_2 += self.crit(output_2['hm_2'], batch['hm_2']) / num_stacks

        if opt.wh_weight > 0:
            if opt.dense_wh:
                # The dense branch was never ported to the two-task layout
                # (original note: "unusable, not updated to the multitask
                # version"); fail with a clear message instead of a
                # NameError on undefined locals.
                raise NotImplementedError(
                    'dense_wh is not supported by the multi-task loss')
            elif opt.cat_spec_wh:
                wh_loss_1 += self.crit_wh(
                    output_1['wh_1'], batch['cat_spec_mask_1'],
                    batch['ind_1'], batch['cat_spec_wh_1']) / num_stacks
                wh_loss_2 += self.crit_wh(
                    output_2['wh_2'], batch['cat_spec_mask_2'],
                    batch['ind_2'], batch['cat_spec_wh_2']) / num_stacks
            else:
                wh_loss_1 += self.crit_reg(
                    output_1['wh_1'], batch['reg_mask_1'],
                    batch['ind_1'], batch['wh_1'],
                    output_1['hm_1'], batch['hm_1']) / num_stacks
                # Task 2 must read its own 'hm_2' head from output_2.
                wh_loss_2 += self.crit_reg(
                    output_2['wh_2'], batch['reg_mask_2'],
                    batch['ind_2'], batch['wh_2'],
                    output_2['hm_2'], batch['hm_2']) / num_stacks

        if opt.reg_offset and opt.off_weight > 0:
            off_loss_1 += self.crit_reg(
                output_1['reg_1'], batch['reg_mask_1'],
                batch['ind_1'], batch['reg_1'],
                output_1['hm_1'], batch['hm_1']) / num_stacks
            off_loss_2 += self.crit_reg(
                output_2['reg_2'], batch['reg_mask_2'],
                batch['ind_2'], batch['reg_2'],
                output_2['hm_2'], batch['hm_2']) / num_stacks

    loss_1 = opt.hm_weight * hm_loss_1 + opt.wh_weight * wh_loss_1 + \
        opt.off_weight * off_loss_1
    loss_2 = opt.hm_weight * hm_loss_2 + opt.wh_weight * wh_loss_2 + \
        opt.off_weight * off_loss_2

    # Both tasks are weighted equally. The 'weighted' and 'geometric'
    # combinations were unreachable and referenced undefined names, so only
    # the plain sum remains.
    loss_weight_1, loss_weight_2 = 1.0, 1.0
    loss = (loss_weight_1 * loss_1) + (loss_weight_2 * loss_2)

    loss_stats = {
        'loss': loss,
        'loss_1': loss_1,
        'hm_loss_1': hm_loss_1,
        'wh_loss_1': wh_loss_1,
        'off_loss_1': off_loss_1,
        'loss_2': loss_2,
        'hm_loss_2': hm_loss_2,
        'wh_loss_2': wh_loss_2,
        'off_loss_2': off_loss_2,
    }
    return loss, loss_stats
def forward(self, outputs, batch):
    """Compute the weighted detection loss including an angle term.

    Args:
        outputs: Sequence indexed by stack; each entry is a dict of head
            outputs (``'hm'``, ``'wh'``, ``'reg'``, ``'ang'``). Entries may
            be overwritten in place.
        batch: Mapping of targets, masks and indices read by key.

    Returns:
        Tuple ``(loss, loss_stats)``: the weighted total and a dict holding
        the total plus each unweighted component.
    """
    opt = self.opt
    n_stacks = opt.num_stacks
    hm_total = wh_total = off_total = ang_total = 0

    def oracle(head, key):
        """Run gen_oracle_map for ``key``, sized like ``head[key]``."""
        dense = gen_oracle_map(batch[key].detach().cpu().numpy(),
                               batch['ind'].detach().cpu().numpy(),
                               head[key].shape[3], head[key].shape[2])
        return torch.from_numpy(dense).to(opt.device)

    for stack in range(n_stacks):
        head = outputs[stack]
        if not opt.mse_loss:
            # Original author's note (translated): their understanding is
            # that the sigmoid normalises the heatmap before the heatmap
            # loss is computed, which speeds up convergence.
            head['hm'] = _sigmoid(head['hm'])

        if opt.eval_oracle_hm:
            head['hm'] = batch['hm']
        if opt.eval_oracle_wh:
            head['wh'] = oracle(head, 'wh')
        if opt.eval_oracle_offset:
            head['reg'] = oracle(head, 'reg')

        # ``self.crit`` returns a mapping here; only its 'loss' entry is used.
        hm_total += self.crit(head['hm'], batch['hm'])['loss'] / n_stacks

        if opt.wh_weight > 0:
            if opt.dense_wh:
                dense_mask = batch['dense_wh_mask']
                mask_weight = dense_mask.sum() + 1e-4
                wh_term = self.crit_wh(
                    head['wh'] * dense_mask,
                    batch['dense_wh'] * dense_mask) / mask_weight
            elif opt.cat_spec_wh:
                wh_term = self.crit_wh(
                    head['wh'], batch['cat_spec_mask'],
                    batch['ind'], batch['cat_spec_wh'])
            else:
                wh_term = self.crit_reg(
                    head['wh'], batch['reg_mask'],
                    batch['ind'], batch['wh'])
            wh_total += wh_term / n_stacks

        if opt.reg_offset and opt.off_weight > 0:
            off_total += self.crit_reg(
                head['reg'], batch['reg_mask'],
                batch['ind'], batch['reg']) / n_stacks

        if opt.ang_weight > 0:
            ang_total += self.crit_reg(
                head['ang'], batch['reg_mask'],
                batch['ind'], batch['ang']) / n_stacks

    total = opt.hm_weight * hm_total + opt.wh_weight * wh_total + \
        opt.off_weight * off_total + opt.ang_weight * ang_total

    loss_stats = {
        'loss': total,
        'hm_loss': hm_total,
        'wh_loss': wh_total,
        'off_loss': off_total,
        'ang_loss': ang_total,
    }
    return total, loss_stats
def forward(self, outputs, batch):
    """Compute the weighted detection loss with size, angle and offset terms.

    Every entry of ``outputs`` is used, and each term is averaged over
    ``len(outputs)``.

    Args:
        outputs: Sequence of dicts of head outputs (``'hm'``, ``'wh'``,
            ``'angle'``, ``'reg'``). Entries may be overwritten in place.
        batch: Mapping of targets, masks and indices read by key.

    Returns:
        Tuple ``(loss, loss_stats)``: the weighted total and a dict holding
        the total plus each unweighted component.
    """
    opt = self.opt
    n_outputs = len(outputs)
    hm_total = wh_total = angle_total = off_total = 0

    def oracle(head, key):
        """Run gen_oracle_map for ``key``, sized like ``head[key]``."""
        dense = gen_oracle_map(batch[key].detach().cpu().numpy(),
                               batch['ind'].detach().cpu().numpy(),
                               head[key].shape[3], head[key].shape[2])
        return torch.from_numpy(dense).to(opt.device)

    for head in outputs:
        if not opt.mse_loss:
            head['hm'] = _sigmoid(head['hm'])

        if opt.eval_oracle_hm:
            head['hm'] = batch['hm']
        if opt.eval_oracle_wh:
            head['wh'] = oracle(head, 'wh')
        if opt.eval_oracle_angle:
            head['angle'] = oracle(head, 'angle')
        if opt.eval_oracle_offset:
            head['reg'] = oracle(head, 'reg')

        hm_total += self.crit(head['hm'], batch['hm']) / n_outputs

        if opt.wh_weight > 0:
            if opt.dense_wh:
                dense_mask = batch['dense_wh_mask']
                mask_weight = dense_mask.sum() + 1e-4
                wh_term = self.crit_wh(
                    head['wh'] * dense_mask,
                    batch['dense_wh'] * dense_mask) / mask_weight
            elif opt.cat_spec_wh:
                wh_term = self.crit_wh(
                    head['wh'], batch['cat_spec_mask'],
                    batch['ind'], batch['cat_spec_wh'])
            else:
                wh_term = self.crit_reg(
                    head['wh'], batch['reg_mask'],
                    batch['ind'], batch['wh'])
            wh_total += wh_term / n_outputs

        if opt.angle_weight > 0:
            # The angle head is used raw; no sigmoid is applied to it.
            if opt.dense_angle:
                angle_term = self.crit_dense_angle(
                    head['angle'], batch['dense_angle_mask'],
                    batch['dense_angle'])
            elif opt.cat_spec_angle:
                angle_term = self.crit_angle(
                    head['angle'], batch['cat_spec_angle_mask'],
                    batch['ind'], batch['cat_spec_angle'])
            else:
                angle_term = self.crit_reg(
                    head['angle'], batch['reg_mask'],
                    batch['ind'], batch['angle'])
            angle_total += angle_term / n_outputs

        if opt.reg_offset and opt.off_weight > 0:
            off_total += self.crit_reg(
                head['reg'], batch['reg_mask'],
                batch['ind'], batch['reg']) / n_outputs

    total = opt.hm_weight * hm_total + opt.wh_weight * wh_total + \
        opt.off_weight * off_total + opt.angle_weight * angle_total

    loss_stats = {
        'loss': total,
        'hm_loss': hm_total,
        'wh_loss': wh_total,
        'angle_loss': angle_total,
        'off_loss': off_total,
    }
    return total, loss_stats
def forward(self, outputs, batch):
    """Compute the weighted detection loss with mask and segmentation terms.

    Args:
        outputs: Sequence indexed by stack; each entry is a dict of head
            outputs (``'hm'``, ``'wh'``, ``'reg'``, ``'segm'``, ``'masks'``,
            ``'proto'``). Entries may be overwritten in place.
        batch: Mapping of targets, masks and indices read by key.

    Returns:
        Tuple ``(loss, loss_stats)``: the weighted total and a dict holding
        the total plus each unweighted component.
    """
    opt = self.opt
    n_stacks = opt.num_stacks
    hm_total = wh_total = off_total = 0
    mask_total = segm_total = 0

    def oracle(head, key):
        """Run gen_oracle_map for ``key``, sized like ``head[key]``."""
        dense = gen_oracle_map(batch[key].detach().cpu().numpy(),
                               batch['ind'].detach().cpu().numpy(),
                               head[key].shape[3], head[key].shape[2])
        return torch.from_numpy(dense).to(opt.device)

    for stack in range(n_stacks):
        head = outputs[stack]
        if not opt.mse_loss:
            head['hm'] = _sigmoid(head['hm'])

        if opt.eval_oracle_hm:
            head['hm'] = batch['hm']
        if opt.eval_oracle_wh:
            head['wh'] = oracle(head, 'wh')
        if opt.eval_oracle_offset:
            head['reg'] = oracle(head, 'reg')

        hm_total += self.crit(head['hm'], batch['hm']) / n_stacks

        if opt.wh_weight > 0:
            if opt.dense_wh:
                dense_mask = batch['dense_wh_mask']
                mask_weight = dense_mask.sum() + 1e-4
                wh_term = self.crit_wh(
                    head['wh'] * dense_mask,
                    batch['dense_wh'] * dense_mask) / mask_weight
            elif opt.cat_spec_wh:
                wh_term = self.crit_wh(
                    head['wh'], batch['cat_spec_mask'],
                    batch['ind'], batch['cat_spec_wh'])
            else:
                wh_term = self.crit_reg(
                    head['wh'], batch['reg_mask'],
                    batch['ind'], batch['wh'])
            wh_total += wh_term / n_stacks

        if opt.reg_offset and opt.off_weight > 0:
            off_total += self.crit_reg(
                head['reg'], batch['reg_mask'],
                batch['ind'], batch['reg']) / n_stacks

        if opt.use_semantic_segmentation_loss:
            segm_total += self.semantic_segmentation_loss(
                head['segm'], batch['masks'],
                batch['gt_bbox_lbl']) / n_stacks

        # NOTE(review): the linear-combination mask term is evaluated for
        # every stack regardless of the segmentation flag - confirm it was
        # not meant to sit under that flag.
        mask_total += self.lincomb_mask_loss(
            head['masks'], head['proto'], batch['reg_mask'],
            batch['masks'], batch['gt_bbox_lbl']) / n_stacks

    total = opt.hm_weight * hm_total + opt.wh_weight * wh_total + \
        opt.off_weight * off_total + \
        opt.lincomb_mask_weight * mask_total + opt.segm_weight * segm_total

    loss_stats = {
        'loss': total,
        'hm_loss': hm_total,
        'wh_loss': wh_total,
        'off_loss': off_total,
        'lincomb_mask_loss': mask_total,
        'segm_loss': segm_total,
    }
    return total, loss_stats
def forward(self, outputs, batch):
    """Compute the weighted detection loss plus per-category heatmap stats.

    Args:
        outputs: Sequence indexed by stack; each entry is a dict of head
            outputs (``'hm'``, ``'wh'``, ``'reg'``). Entries may be
            overwritten in place.
        batch: Mapping of targets, masks and indices read by key.

    Returns:
        Tuple ``(loss, loss_stats)``. Besides the total and its unweighted
        components, ``loss_stats`` gets one ``'hm_loss_<category>'`` entry
        per name in ``cf.categories``, evaluated on channel ``i`` of the
        ``'hm'`` head and accumulated over the stacks like ``'hm_loss'``.
    """
    opt = self.opt
    num_stacks = opt.num_stacks
    hm_loss, wh_loss, off_loss = 0, 0, 0
    # Created before the loop so it exists even for zero stacks and so the
    # per-category values accumulate instead of reflecting the last stack.
    hm_loss_categorie = {}

    for s in range(num_stacks):
        output = outputs[s]
        if not opt.mse_loss:
            output['hm'] = _sigmoid(output['hm'])

        if opt.eval_oracle_hm:
            output['hm'] = batch['hm']
        if opt.eval_oracle_wh:
            output['wh'] = torch.from_numpy(gen_oracle_map(
                batch['wh'].detach().cpu().numpy(),
                batch['ind'].detach().cpu().numpy(),
                output['wh'].shape[3],
                output['wh'].shape[2])).to(opt.device)
        if opt.eval_oracle_offset:
            output['reg'] = torch.from_numpy(gen_oracle_map(
                batch['reg'].detach().cpu().numpy(),
                batch['ind'].detach().cpu().numpy(),
                output['reg'].shape[3],
                output['reg'].shape[2])).to(opt.device)

        hm_loss += self.crit(output['hm'], batch['hm']) / num_stacks

        for i, categorie in enumerate(cf.categories):
            channel_loss = self.crit(
                output['hm'][:, i, :, :],
                batch['hm'][:, i, :, :]) / num_stacks
            hm_loss_categorie[categorie] = \
                hm_loss_categorie.get(categorie, 0) + channel_loss

        if opt.wh_weight > 0:
            if opt.dense_wh:
                mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                wh_loss += (self.crit_wh(
                    output['wh'] * batch['dense_wh_mask'],
                    batch['dense_wh'] * batch['dense_wh_mask'])
                    / mask_weight) / num_stacks
            elif opt.cat_spec_wh:
                wh_loss += self.crit_wh(
                    output['wh'], batch['cat_spec_mask'],
                    batch['ind'], batch['cat_spec_wh']) / num_stacks
            else:
                wh_loss += self.crit_reg(
                    output['wh'], batch['reg_mask'],
                    batch['ind'], batch['wh']) / num_stacks

        if opt.reg_offset and opt.off_weight > 0:
            off_loss += self.crit_reg(
                output['reg'], batch['reg_mask'],
                batch['ind'], batch['reg']) / num_stacks

    loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + \
        opt.off_weight * off_loss

    loss_stats = {
        'loss': loss,
        'hm_loss': hm_loss,
        'wh_loss': wh_loss,
        'off_loss': off_loss,
    }
    for cat_name, cat_loss in hm_loss_categorie.items():
        loss_stats['hm_loss_' + cat_name] = cat_loss
    return loss, loss_stats
def forward(self, outputs, batch):
    """Return the weighted detection loss and its unweighted components.

    For each of the first ``opt.num_stacks`` entries of ``outputs`` the
    ``hm``, size (``wh``) and offset (``reg``) terms are evaluated against
    ``batch`` and averaged over the stacks.

    Args:
        outputs: Sequence indexed by stack; each entry is a dict of head
            outputs. Entries may be overwritten in place.
        batch: Mapping of targets, masks and indices read by key.

    Returns:
        Tuple ``(loss, loss_stats)`` where ``loss_stats`` holds ``'loss'``,
        ``'hm_loss'``, ``'wh_loss'`` and ``'off_loss'``.
    """
    opt = self.opt
    stack_count = opt.num_stacks
    heat_total = size_total = offset_total = 0

    def oracle(head, key):
        """Run gen_oracle_map for ``key``, sized like ``head[key]``."""
        dense = gen_oracle_map(batch[key].detach().cpu().numpy(),
                               batch['ind'].detach().cpu().numpy(),
                               head[key].shape[3], head[key].shape[2])
        return torch.from_numpy(dense).to(opt.device)

    for stack in range(stack_count):
        head = outputs[stack]

        # The sigmoid is skipped when the MSE flag is set.
        if not opt.mse_loss:
            head['hm'] = _sigmoid(head['hm'])

        # Oracle evaluation: swap selected heads for maps built from batch.
        if opt.eval_oracle_hm:
            head['hm'] = batch['hm']
        if opt.eval_oracle_wh:
            head['wh'] = oracle(head, 'wh')
        if opt.eval_oracle_offset:
            head['reg'] = oracle(head, 'reg')

        # hm term
        heat_total += self.crit(head['hm'], batch['hm']) / stack_count

        # size term: dense, category-specific or plain regression
        if opt.wh_weight > 0:
            if opt.dense_wh:
                dense_mask = batch['dense_wh_mask']
                mask_weight = dense_mask.sum() + 1e-4
                size_term = self.crit_wh(
                    head['wh'] * dense_mask,
                    batch['dense_wh'] * dense_mask) / mask_weight
            elif opt.cat_spec_wh:
                size_term = self.crit_wh(
                    head['wh'], batch['cat_spec_mask'],
                    batch['ind'], batch['cat_spec_wh'])
            else:
                size_term = self.crit_reg(
                    head['wh'], batch['reg_mask'],
                    batch['ind'], batch['wh'])
            size_total += size_term / stack_count

        # offset term
        if opt.reg_offset and opt.off_weight > 0:
            offset_total += self.crit_reg(
                head['reg'], batch['reg_mask'],
                batch['ind'], batch['reg']) / stack_count

    total = opt.hm_weight * heat_total + \
        opt.wh_weight * size_total + \
        opt.off_weight * offset_total

    loss_stats = {
        'loss': total,
        'hm_loss': heat_total,
        'wh_loss': size_total,
        'off_loss': offset_total,
    }
    return total, loss_stats