Example #1
0
    def forward(self, outputs, batch):
        """Compute the weighted multi-task loss and its per-term breakdown.

        Args:
            outputs: 8-item sequence unpacked as ``(hm, wh, hps, reg, hm_hp,
                hp_offset, seg_feat, seg)``.
            batch: mapping of targets. Depending on the configuration the keys
                read are ``'hm'``, ``'hm_hp'``, ``'hps'``, ``'hps_mask'``,
                ``'dense_hps'``, ``'dense_hps_mask'``, ``'ind'``,
                ``'reg_mask'``, ``'wh'``, ``'reg'``, ``'hp_mask'``,
                ``'hp_ind'``, ``'hp_offset'`` and ``'seg'``.

        Returns:
            ``(loss, loss_stats)`` where ``loss`` is the weighted sum of all
            terms and ``loss_stats`` maps names to the total and to each
            unweighted component.
        """
        cfg = self.cfg
        num_stacks = cfg.MODEL.NUM_STACKS
        hm_loss, wh_loss, off_loss, seg_loss = 0, 0, 0, 0
        hp_loss, hm_hp_loss, hp_offset_loss = 0, 0, 0
        hm, wh, hps, reg, hm_hp, hp_offset, seg_feat, seg = outputs

        if num_stacks > 0:
            # The predictions are not indexed per stack, so activate and
            # (optionally) replace them exactly once. Doing this inside the
            # stack loop would re-apply the sigmoid to an already activated
            # map on every additional iteration.
            hm = _sigmoid(hm)
            if cfg.LOSS.HM_HP and not cfg.LOSS.MSE_LOSS:
                hm_hp = _sigmoid(hm_hp)

            # Oracle evaluation: substitute ground truth for a prediction.
            if cfg.TEST.EVAL_ORACLE_HMHP:
                hm_hp = batch['hm_hp']
            if cfg.TEST.EVAL_ORACLE_HM:
                hm = batch['hm']
            if cfg.TEST.EVAL_ORACLE_KPS:
                if cfg.LOSS.DENSE_HP:
                    hps = batch['dense_hps']
                else:
                    hps = torch.from_numpy(
                        gen_oracle_map(
                            batch['hps'].detach().cpu().numpy(),
                            batch['ind'].detach().cpu().numpy(),
                            cfg.MODEL.OUTPUT_RES, cfg.MODEL.OUTPUT_RES)).to(
                                torch.device('cuda:%d' % self.local_rank))
            if cfg.TEST.EVAL_ORACLE_HP_OFFSET:
                # Build the oracle from the ground-truth offsets, mirroring
                # the keypoint oracle above (previously the network output
                # itself was fed to gen_oracle_map).
                hp_offset = torch.from_numpy(
                    gen_oracle_map(
                        batch['hp_offset'].detach().cpu().numpy(),
                        batch['hp_ind'].detach().cpu().numpy(),
                        cfg.MODEL.OUTPUT_RES, cfg.MODEL.OUTPUT_RES)).to(
                            torch.device('cuda:%d' % self.local_rank))

        for _ in range(num_stacks):
            hm_loss += self.crit(hm, batch['hm']) / num_stacks
            if cfg.LOSS.DENSE_HP:
                # Normalise by the mask sum; the epsilon avoids dividing by 0.
                mask_weight = batch['dense_hps_mask'].sum() + 1e-4
                hp_loss += (self.crit_kp(
                    hps * batch['dense_hps_mask'],
                    batch['dense_hps'] * batch['dense_hps_mask']) /
                            mask_weight) / num_stacks
            else:
                hp_loss += self.crit_kp(hps, batch['hps_mask'], batch['ind'],
                                        batch['hps']) / num_stacks
            if cfg.LOSS.WH_WEIGHT > 0:
                wh_loss += self.crit_reg(wh, batch['reg_mask'], batch['ind'],
                                         batch['wh']) / num_stacks
            if cfg.LOSS.REG_OFFSET and cfg.LOSS.OFF_WEIGHT > 0:
                off_loss += self.crit_reg(reg, batch['reg_mask'], batch['ind'],
                                          batch['reg']) / num_stacks
            if cfg.LOSS.REG_HP_OFFSET and cfg.LOSS.OFF_WEIGHT > 0:
                hp_offset_loss += self.crit_reg(
                    hp_offset, batch['hp_mask'], batch['hp_ind'],
                    batch['hp_offset']) / num_stacks
            if cfg.LOSS.HM_HP and cfg.LOSS.HM_HP_WEIGHT > 0:
                hm_hp_loss += self.crit_hm_hp(
                    hm_hp, batch['hm_hp']) / num_stacks

            # NOTE(review): unlike the other terms, the segmentation loss is
            # summed rather than averaged over stacks (kept as in the
            # original) and enters the total with an implicit weight of 1.
            seg_loss += self.crit_seg(seg, seg_feat, batch['ind'],
                                      batch['seg'])

        loss = cfg.LOSS.HM_WEIGHT * hm_loss + cfg.LOSS.WH_WEIGHT * wh_loss + \
               cfg.LOSS.OFF_WEIGHT * off_loss + cfg.LOSS.HP_WEIGHT * hp_loss + \
               cfg.LOSS.HM_HP_WEIGHT * hm_hp_loss + \
               cfg.LOSS.OFF_WEIGHT * hp_offset_loss + seg_loss

        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'hp_loss': hp_loss,
            'hm_hp_loss': hm_hp_loss,
            'hp_offset_loss': hp_offset_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
            'seg_loss': seg_loss
        }
        return loss, loss_stats
    def forward(self, outputs, batch):
        """Accumulate the pose-estimation losses over every stack.

        Args:
            outputs: sequence indexed by stack; each item is a dict of
                predictions (``'hm'``, ``'hps'`` and, depending on options,
                ``'hm_hp'``, ``'wh'``, ``'reg'``, ``'hp_offset'``). The dicts
                are updated in place with activated / oracle values.
            batch: dict of targets read by the individual criteria.

        Returns:
            ``(loss, loss_stats)``; ``loss_stats`` holds the total and each
            component already multiplied by its weight.
        """
        opt = self.opt
        n_stacks = opt.num_stacks

        def oracle_map(values, indices):
            # Run gen_oracle_map on CPU numpy copies and move the result to
            # the configured device.
            return torch.from_numpy(
                gen_oracle_map(values.detach().cpu().numpy(),
                               indices.detach().cpu().numpy(),
                               opt.output_res,
                               opt.output_res)).to(opt.device)

        loss_hm, loss_wh, loss_off = 0, 0, 0
        loss_kp, loss_hm_hp, loss_kp_off = 0, 0, 0

        for stack_idx in range(n_stacks):
            pred = outputs[stack_idx]

            # Activations.
            pred['hm'] = _sigmoid(pred['hm'])
            if opt.hm_hp and not opt.mse_loss:
                pred['hm_hp'] = _sigmoid(pred['hm_hp'])

            # Optional oracle substitutions (ground truth replaces a head).
            if opt.eval_oracle_hmhp:
                pred['hm_hp'] = batch['hm_hp']
            if opt.eval_oracle_hm:
                pred['hm'] = batch['hm']
            if opt.eval_oracle_kps:
                pred['hps'] = (batch['dense_hps'] if opt.dense_hp else
                               oracle_map(batch['hps'], batch['ind']))
            if opt.eval_oracle_hp_offset:
                pred['hp_offset'] = oracle_map(batch['hp_offset'],
                                               batch['hp_ind'])

            # Centre heatmap.
            loss_hm += self.crit(pred['hm'], batch['hm']) / n_stacks

            # Keypoint regression, dense or gathered at object indices.
            if opt.dense_hp:
                dense_mask = batch['dense_hps_mask']
                mask_weight = dense_mask.sum() + 1e-4
                loss_kp += (self.crit_kp(pred['hps'] * dense_mask,
                                         batch['dense_hps'] * dense_mask) /
                            mask_weight) / n_stacks
            else:
                loss_kp += self.crit_kp(pred['hps'], batch['hps_mask'],
                                        batch['ind'],
                                        batch['hps']) / n_stacks

            # Box size and the two offset heads.
            if opt.wh_weight > 0:
                loss_wh += self.crit_reg(pred['wh'], batch['reg_mask'],
                                         batch['ind'],
                                         batch['wh']) / n_stacks
            offsets_enabled = opt.off_weight > 0
            if opt.reg_offset and offsets_enabled:
                loss_off += self.crit_reg(pred['reg'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['reg']) / n_stacks
            if opt.reg_hp_offset and offsets_enabled:
                loss_kp_off += self.crit_reg(
                    pred['hp_offset'], batch['hp_mask'], batch['hp_ind'],
                    batch['hp_offset']) / n_stacks

            # Keypoint heatmap.
            if opt.hm_hp and opt.hm_hp_weight > 0:
                loss_hm_hp += self.crit_hm_hp(pred['hm_hp'],
                                              batch['hm_hp']) / n_stacks

        # Weight every term once; the same values feed the total and stats.
        weighted_hm = opt.hm_weight * loss_hm
        weighted_wh = opt.wh_weight * loss_wh
        weighted_off = opt.off_weight * loss_off
        weighted_kp = opt.hp_weight * loss_kp
        weighted_hm_hp = opt.hm_hp_weight * loss_hm_hp
        weighted_kp_off = opt.off_weight * loss_kp_off

        loss = weighted_hm + weighted_wh + weighted_off + weighted_kp + \
               weighted_hm_hp + weighted_kp_off

        loss_stats = {
            'loss': loss,
            'hm_loss': weighted_hm,
            'hp_loss': weighted_kp,
            'hm_hp_loss': weighted_hm_hp,
            'hp_offset_loss': weighted_kp_off,
            'wh_loss': weighted_wh,
            'off_loss': weighted_off
        }
        return loss, loss_stats
Example #3
0
    def forward(self, outputs, batch):
        """Compute the detection loss, optionally with proposal/scale heads.

        Args:
            outputs: sequence indexed by stack; each item is a dict of
                predictions (``'hm'``, ``'wh'``, ``'reg'`` and, when
                ``opt.reg_proposal`` is set, ``'proposal'`` and ``'scale'``).
                The dicts are updated in place with activated / oracle values.
            batch: dict of targets read by the individual criteria.

        Returns:
            ``(loss, loss_stats)``; ``loss_stats`` holds the total and the
            unweighted components, including ``'proposal_loss'`` and
            ``'scale_loss'`` only when ``opt.reg_proposal`` is set.
        """
        opt = self.opt
        hm_loss, wh_loss, off_loss = 0, 0, 0
        proposal_loss, proposal_scale_loss = 0, 0

        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            # Oracle evaluation: substitute ground truth for a prediction.
            if opt.eval_oracle_hm:
                output['hm'] = batch['hm']
            if opt.eval_oracle_wh:
                output['wh'] = torch.from_numpy(
                    gen_oracle_map(batch['wh'].detach().cpu().numpy(),
                                   batch['ind'].detach().cpu().numpy(),
                                   output['wh'].shape[3],
                                   output['wh'].shape[2])).to(opt.device)
            if opt.eval_oracle_offset:
                output['reg'] = torch.from_numpy(
                    gen_oracle_map(batch['reg'].detach().cpu().numpy(),
                                   batch['ind'].detach().cpu().numpy(),
                                   output['reg'].shape[3],
                                   output['reg'].shape[2])).to(opt.device)

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks

            if opt.wh_weight > 0:
                if opt.dense_wh:
                    # Normalise by the mask sum; epsilon avoids dividing by 0.
                    mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                    wh_loss += (self.crit_wh(
                        output['wh'] * batch['dense_wh_mask'],
                        batch['dense_wh'] * batch['dense_wh_mask']) /
                                mask_weight) / opt.num_stacks
                elif opt.cat_spec_wh:
                    wh_loss += self.crit_wh(
                        output['wh'], batch['cat_spec_mask'], batch['ind'],
                        batch['cat_spec_wh']) / opt.num_stacks
                elif opt.reg_proposal:
                    # The size head predicts a residual on top of the scale
                    # head; detach so this term does not train the scale head.
                    out_scale = output['scale'].detach()
                    wh_loss += self.crit_reg(
                        output['wh'] + out_scale, batch['reg_mask'],
                        batch['ind'], batch['wh']) / opt.num_stacks
                else:
                    wh_loss += self.crit_reg(
                        output['wh'], batch['reg_mask'], batch['ind'],
                        batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['reg']) / opt.num_stacks

            if opt.reg_proposal and opt.proposal_weight > 0:
                output['proposal'] = _sigmoid(
                    output['proposal'])  # for focal loss
                # Targets equal to -1 (or lower) are ignored.
                ignore_mask = batch['proposal'].gt(-1).float()
                proposal_loss += self.crit_centerness(
                    output['proposal'], batch['proposal'],
                    ignore_mask) / opt.num_stacks
                ignore_scale_mask = batch['scale'].gt(0).float()
                # Clamp so a batch without any positive scale target does not
                # divide by zero; counts >= 1 are left untouched.
                valid_num = ignore_scale_mask.sum().clamp(min=1.0)
                proposal_scale_loss += self.crit_scale(
                    output['scale'] * ignore_scale_mask,
                    batch['scale']) / valid_num / opt.num_stacks

        loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + \
               opt.off_weight * off_loss + \
               opt.proposal_weight * proposal_loss + \
               opt.scale_weight * proposal_scale_loss

        loss_stats = {'loss': loss}
        if opt.reg_proposal:
            loss_stats['proposal_loss'] = proposal_loss
            loss_stats['scale_loss'] = proposal_scale_loss
        loss_stats['hm_loss'] = hm_loss
        loss_stats['wh_loss'] = wh_loss
        loss_stats['off_loss'] = off_loss
        return loss, loss_stats
Example #4
0
    def forward(self, outputs, batch):
        """Accumulate the 3D-detection losses over every stack.

        Args:
            outputs: sequence indexed by stack; each item is a dict of
                predictions (``'hm'``, ``'dep'``, ``'dim'``, ``'rot'`` and
                optionally ``'wh'`` / ``'reg'``). ``'hm'`` and ``'dep'`` are
                overwritten in place with their transformed values.
            batch: dict of targets read by the individual criteria.

        Returns:
            ``(loss, loss_stats)``. With ``opt.reg_bbox`` the total and the
            stats also include the box-size term (plus a ``'tilt_loss'``
            entry that is always 0); otherwise both are omitted.
        """
        opt = self.opt
        n_stacks = opt.num_stacks

        heat, depth, rotation, dims = 0, 0, 0, 0
        box_wh, offset = 0, 0
        # Tilt regression is currently disabled, so this stays at zero.
        tilt = 0

        for stack_idx in range(n_stacks):
            pred = outputs[stack_idx]
            pred['hm'] = _sigmoid(pred['hm'])
            # Map the raw depth logit through 1 / sigmoid(x) - 1.
            pred['dep'] = 1. / (pred['dep'].sigmoid() + 1e-6) - 1.

            if opt.eval_oracle_dep:
                gt_dep = batch['dep'].detach().cpu().numpy()
                gt_ind = batch['ind'].detach().cpu().numpy()
                pred['dep'] = torch.from_numpy(
                    gen_oracle_map(gt_dep, gt_ind,
                                   opt.output_w, opt.output_h)).to(opt.device)

            heat += self.crit(pred['hm'], batch['hm']) / n_stacks

            # Depth.
            if opt.dep_weight > 0:
                depth += self.crit_reg(pred['dep'], batch['reg_mask'],
                                       batch['ind'],
                                       batch['dep']) / n_stacks
            # 3D dimensions.
            if opt.dim_weight > 0:
                dims += self.crit_reg(pred['dim'], batch['reg_mask'],
                                      batch['ind'],
                                      batch['dim']) / n_stacks
            # Orientation.
            if opt.rot_weight > 0:
                rotation += self.crit_rot(pred['rot'], batch['rot_mask'],
                                          batch['ind'], batch['rotbin'],
                                          batch['rotres']) / n_stacks
            # 2D box size and centre offset both use the rotation mask.
            if opt.reg_bbox and opt.wh_weight > 0:
                box_wh += self.crit_reg(pred['wh'], batch['rot_mask'],
                                        batch['ind'],
                                        batch['wh']) / n_stacks
            if opt.reg_offset and opt.off_weight > 0:
                offset += self.crit_reg(pred['reg'], batch['rot_mask'],
                                        batch['ind'],
                                        batch['reg']) / n_stacks

        # Terms shared by both configurations.
        core = opt.hm_weight * heat + opt.dep_weight * depth + \
               opt.dim_weight * dims + opt.rot_weight * rotation

        if opt.reg_bbox:
            loss = core + opt.wh_weight * box_wh + opt.off_weight * offset
            tail = [('wh_loss', box_wh), ('off_loss', offset),
                    ('tilt_loss', tilt)]
        else:
            loss = core + opt.off_weight * offset
            tail = [('off_loss', offset)]

        loss_stats = {
            'loss': loss,
            'hm_loss': heat,
            'dep_loss': depth,
            'dim_loss': dims,
            'rot_loss': rotation
        }
        for name, value in tail:
            loss_stats[name] = value

        return loss, loss_stats
Example #5
0
    def forward(self, outputs, batch):
        """Accumulate focal, tag (pull/push) and offset losses over stacks.

        Args:
            outputs: sequence indexed by stack; each item is a dict with the
                ``'lm'``/``'rm'``/``'ct'`` heatmaps, ``'lm_tag'``/``'rm_tag'``
                and ``'lm_reg'``/``'rm_reg'``/``'ct_reg'`` predictions. The
                dicts are updated in place with activated / oracle values.
            batch: dict of targets read by the individual criteria.

        Returns:
            ``(loss, loss_stats)``. The per-head entries
            (``'lm_focal_loss'``, ``'ct_reg_loss'``, ...) hold the value
            computed for the last stack only, whereas the aggregated entries
            are summed over all stacks.
        """
        opt = self.opt
        heads = ('lm', 'rm', 'ct')

        def oracle_like(pred, values, indices):
            # Run gen_oracle_map on CPU numpy copies, sized like the
            # prediction's last two dims, and move it to the device.
            return torch.from_numpy(
                gen_oracle_map(values.detach().cpu().numpy(),
                               indices.detach().cpu().numpy(),
                               pred.shape[3],
                               pred.shape[2])).to(opt.device)

        focal_loss = pull_loss = push_loss = reg_loss = 0
        focal_part = {head: 0 for head in heads}
        reg_part = {head: 0 for head in heads}

        for stack_idx in range(opt.num_stacks):
            output = outputs[stack_idx]

            if not opt.mse_loss:
                for head in heads:
                    output[head] = _sigmoid(output[head])

            # Optional oracle substitutions.
            if opt.eval_oracle_lm:
                output['lm'] = batch['lm']
            if opt.eval_oracle_rm:
                output['rm'] = batch['rm']
            if opt.eval_oracle_ct:
                output['ct'] = batch['ct']
            if opt.eval_oracle_ae:
                for tag in ('lm_tag', 'rm_tag'):
                    output[tag] = oracle_like(output[tag], batch[tag],
                                              batch[tag])
            if opt.eval_oracle_offset:
                for head in heads:
                    reg_key, tag_key = head + '_reg', head + '_tag'
                    output[reg_key] = oracle_like(output[reg_key],
                                                  batch[reg_key],
                                                  batch[tag_key])

            # Heatmap loss, one term per head.
            for head in heads:
                focal_part[head] = self.crit(output[head],
                                             batch[head]) / opt.num_stacks
                focal_loss += focal_part[head]

            # Tag loss.
            pull, push = self.crit_tag(output['rm_tag'], output['lm_tag'],
                                       batch['rm_tag'], batch['lm_tag'],
                                       batch['reg_mask'])
            pull_loss += opt.pull_weight * pull / opt.num_stacks
            push_loss += opt.push_weight * push / opt.num_stacks

            # Offset regression, one term per head, indexed by its tag.
            for head in heads:
                reg_part[head] = opt.regr_weight * self.crit_reg(
                    output[head + '_reg'], batch['reg_mask'],
                    batch[head + '_tag'],
                    batch[head + '_reg']) / opt.num_stacks
                reg_loss += reg_part[head]

        loss = focal_loss + pull_loss + push_loss + reg_loss
        loss_stats = {
            'loss': loss,
            'focal_loss': focal_loss,
            'pull_loss': pull_loss,
            'push_loss': push_loss,
            'reg_loss': reg_loss,
            'lm_focal_loss': focal_part['lm'],
            'rm_focal_loss': focal_part['rm'],
            'ct_focal_loss': focal_part['ct'],
            'lm_reg_loss': reg_part['lm'],
            'rm_reg_loss': reg_part['rm'],
            'ct_reg_loss': reg_part['ct']
        }
        return loss, loss_stats
Example #6
0
    def forward(self, outputs, batch):
        """Accumulate the 3D-detection losses over every stack.

        Args:
            outputs: sequence indexed by stack; each item is a dict of
                predictions (``'hm'``, ``'dep'``, ``'dim'``, ``'rot'`` and
                optionally ``'wh'`` / ``'reg'``). ``'hm'`` and ``'dep'`` are
                overwritten in place with their transformed values.
            batch: dict of targets read by the individual criteria.

        Returns:
            ``(loss, loss_stats)`` where ``loss`` is the weighted sum of all
            terms and ``loss_stats`` holds the total plus every unweighted
            component.
        """
        opt = self.opt

        hm_loss, dep_loss, rot_loss, dim_loss = 0, 0, 0, 0
        wh_loss, off_loss = 0, 0

        for s in range(opt.num_stacks):
            output = outputs[s]
            output['hm'] = _sigmoid(output['hm'])
            # Map the raw depth logit through 1 / sigmoid(x) - 1.
            output['dep'] = 1. / (output['dep'].sigmoid() + 1e-6) - 1.

            if opt.eval_oracle_dep:
                output['dep'] = torch.from_numpy(
                    gen_oracle_map(batch['dep'].detach().cpu().numpy(),
                                   batch['ind'].detach().cpu().numpy(),
                                   opt.output_w, opt.output_h)).to(opt.device)

            # The leftover interactive pdb breakpoint was removed here: it
            # halted every forward pass and hangs non-interactive runs.

            # NOTE(review): the original comment labelled this an L2 loss;
            # the actual form depends on how self.crit is configured.
            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
            if opt.dep_weight > 0:
                # Shapes noted by the original author during a debug run
                # (not enforced here):
                #   output['dep']:     torch.Size([1, 1, 96, 320])
                #   batch['reg_mask']: torch.Size([1, 50])
                #   batch['ind']:      torch.Size([1, 50])
                #   batch['dep']:      torch.Size([1, 50, 1])
                dep_loss += self.crit_reg(output['dep'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['dep']) / opt.num_stacks
            if opt.dim_weight > 0:
                # NOTE(review): original comment says crit_reg is the
                # project's L1 loss; confirm against its definition.
                dim_loss += self.crit_reg(output['dim'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['dim']) / opt.num_stacks
            if opt.rot_weight > 0:
                rot_loss += self.crit_rot(output['rot'], batch['rot_mask'],
                                          batch['ind'], batch['rotbin'],
                                          batch['rotres']) / opt.num_stacks
            # 2D box size and centre offset both use the rotation mask.
            if opt.reg_bbox and opt.wh_weight > 0:
                wh_loss += self.crit_reg(output['wh'], batch['rot_mask'],
                                         batch['ind'],
                                         batch['wh']) / opt.num_stacks
            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['rot_mask'],
                                          batch['ind'],
                                          batch['reg']) / opt.num_stacks

        loss = opt.hm_weight * hm_loss + opt.dep_weight * dep_loss + \
               opt.dim_weight * dim_loss + opt.rot_weight * rot_loss + \
               opt.wh_weight * wh_loss + opt.off_weight * off_loss

        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'dep_loss': dep_loss,
            'dim_loss': dim_loss,
            'rot_loss': rot_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss
        }
        return loss, loss_stats
Example #7
0
  def forward(self, outputs, batch, global_step, tb_writer):
    """Compute the pose loss, optionally with a mixture-density keypoint head.

    Args:
      outputs: sequence indexed by stack; each item is a dict of predictions
        (``'hm'``, ``'hps'`` and, depending on options, ``'hm_hp'``,
        ``'wh'``, ``'reg'``, ``'hp_offset'``, ``'mdn_logits'``,
        ``'mdn_sigma'``). The dicts are updated in place with activated /
        oracle values.
      batch: dict of targets read by the individual criteria.
      global_step: step index forwarded to the TensorBoard histograms.
      tb_writer: object providing ``add_histogram``, or ``None`` to disable
        logging.

    Returns:
      ``(loss, loss_stats)`` where ``loss`` is the weighted sum of all terms
      and ``loss_stats`` holds the total plus every unweighted component.
    """
    opt = self.opt
    hm_loss, wh_loss, off_loss = 0, 0, 0
    hp_loss, hm_hp_loss, hp_offset_loss = 0, 0, 0

    loss_stats = {}
    for s in range(opt.num_stacks):
      output = outputs[s]
      output['hm'] = _sigmoid(output['hm'])
      if opt.hm_hp:
        output['hm_hp'] = _sigmoid(output['hm_hp'])

      # Oracle evaluation: substitute ground truth for a prediction.
      if opt.eval_oracle_hmhp:
        output['hm_hp'] = batch['hm_hp']
      if opt.eval_oracle_hm:
        output['hm'] = batch['hm']
      if opt.eval_oracle_kps:
        if opt.dense_hp:
          output['hps'] = batch['dense_hps']
        else:
          output['hps'] = torch.from_numpy(gen_oracle_map(
            batch['hps'].detach().cpu().numpy(),
            batch['ind'].detach().cpu().numpy(),
            opt.output_res, opt.output_res)).to(opt.device)
      if opt.eval_oracle_hp_offset:
        output['hp_offset'] = torch.from_numpy(gen_oracle_map(
          batch['hp_offset'].detach().cpu().numpy(),
          batch['hp_ind'].detach().cpu().numpy(),
          opt.output_res, opt.output_res)).to(opt.device)

      if opt.mdn:
        # 17 per-keypoint constants, scaled by 1/10 and passed to crit_kp.
        V=torch.Tensor((np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0).astype(np.float32)).float().cuda()

      # The heatmap / regression criteria used with [0] below return an
      # indexable result whose first item is the loss value.
      hm_loss += self.crit(output['hm'], batch['hm'])[0] / opt.num_stacks

      if opt.mdn:
        mdn_logits = output['mdn_logits']
        # Author's note: mdn_logits.shape was torch.Size([2, 3, 128, 128]).
        if opt.mdn_dropout > 0 and opt.epoch < opt.mdn_dropout_stop:
          M = opt.mdn_n_comps

          # Zero the logits of a random subset of 1..mdn_dropout components.
          # NOTE(review): a zeroed logit still gets softmax mass, so the
          # component is damped rather than removed - confirm this is the
          # intended form of dropout.
          ridx = torch.randperm(M)[:torch.randint(1, 1 + opt.mdn_dropout, (1,))]
          drop_mask = torch.ones((M,))
          drop_mask[ridx] = 0
          drop_mask = torch.reshape(drop_mask, (1, -1, 1, 1)).float().cuda()
          mdn_logits = mdn_logits * drop_mask
          if tb_writer is not None:
            tb_writer.add_histogram('drop_out_idx', ridx + 1, global_step=global_step)

        # Always derive the mixing weights from the (possibly masked) logits.
        # Previously this sat in an `else`, leaving mdn_pi undefined whenever
        # the dropout branch ran and raising NameError further down.
        mdn_pi = torch.clamp(torch.nn.Softmax(dim=1)(mdn_logits), 1e-4, 1.-1e-4)

        mdn_sigma = torch.clamp(torch.nn.ELU()(output['mdn_sigma']) + opt.mdn_min_sigma, 1e-4, 1e5)
        mdn_mu = output['hps']

        if tb_writer is not None:
          for i in range(mdn_pi.shape[1]):
            tb_writer.add_histogram('mdn_pi/{}'.format(i), mdn_pi[:,i], global_step=global_step)
            tb_writer.add_histogram('mdn_sigma/{}'.format(i), mdn_sigma[:,i*2:i*2+2], global_step=global_step)

        if opt.dense_hp:
          gt = batch['dense_hps']
          # Keep every second mask channel.
          mask = batch['dense_hps_mask'][:,0::2,:,:]
          _, max_pi_ind = torch.max(mdn_pi, 1)
        else:
          gt = batch['hps']
          mask = batch['hps_mask'][:,:,0::2]
          # Gather the mixture parameters at the object indices.
          mdn_mu = _tranpose_and_gather_feat(mdn_mu, batch['ind'])
          mdn_pi = _tranpose_and_gather_feat(mdn_pi, batch['ind'])
          mdn_sigma = _tranpose_and_gather_feat(mdn_sigma, batch['ind'])
          _, max_pi_ind = torch.max(mdn_pi, -1)

        if tb_writer is not None:
          tb_writer.add_histogram('mdn_pi_max_comp', max_pi_ind + 1, global_step=global_step)

        # Shapes noted by the original author for mdn_n_comps=3
        # (non-dense path, not enforced here):
        #   batch['hps'], batch['hps_mask'], gt: torch.Size([2, 32, 34])
        #   batch['ind']: torch.Size([2, 32]); mask: torch.Size([2, 32, 17])
        #   before gather -> after gather
        #   mdn_mu:    [2, 102, 128, 128] -> [2, 32, 102]
        #   mdn_pi:    [2, 3, 128, 128]   -> [2, 32, 3]
        #   mdn_sigma: [2, 6, 128, 128]   -> [2, 32, 6]
        if opt.mdn_inter:
          hp_loss += self.crit_kp(gt,mdn_mu,mdn_sigma,mdn_pi,mask,V,debug=opt.debug==6)[0] / opt.num_stacks
        else:
          # NOTE(review): assignment, not accumulation - without mdn_inter
          # only the last stack's keypoint loss is kept (as in the original).
          hp_loss = self.crit_kp(gt,mdn_mu,mdn_sigma,mdn_pi,mask,V,debug=opt.debug==6)[0] / opt.num_stacks
      else:
        if opt.dense_hp:
          # Normalise by the mask sum; the epsilon avoids dividing by 0.
          mask_weight = batch['dense_hps_mask'].sum() + 1e-4
          hp_loss += (self.crit_kp(output['hps'] * batch['dense_hps_mask'],
                                  batch['dense_hps'] * batch['dense_hps_mask']) /
                                  mask_weight) / opt.num_stacks
        else:
          hp_loss += self.crit_kp(output['hps'], batch['hps_mask'],
                                batch['ind'], batch['hps']) / opt.num_stacks

      if opt.wh_weight > 0:
        wh_loss += self.crit_reg(output['wh'], batch['reg_mask'],
                                 batch['ind'], batch['wh'])[0] / opt.num_stacks
      if opt.reg_offset and opt.off_weight > 0:
        off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                  batch['ind'], batch['reg'])[0] / opt.num_stacks
      if opt.reg_hp_offset and opt.off_weight > 0:
        hp_offset_loss += self.crit_reg(
          output['hp_offset'], batch['hp_mask'],
          batch['hp_ind'], batch['hp_offset'])[0] / opt.num_stacks
      if opt.hm_hp and opt.hm_hp_weight > 0:
        hm_hp_loss += self.crit_hm_hp(
          output['hm_hp'], batch['hm_hp'])[0] / opt.num_stacks

    loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + \
           opt.off_weight * off_loss + opt.hp_weight * hp_loss + \
           opt.hm_hp_weight * hm_hp_loss + opt.off_weight * hp_offset_loss

    loss_stats.update({'loss': loss, 'hm_loss': hm_loss, 'hp_loss': hp_loss,
                  'hm_hp_loss': hm_hp_loss, 'hp_offset_loss': hp_offset_loss,
                  'wh_loss': wh_loss, 'off_loss': off_loss})
    return loss, loss_stats
Example #8
0
    def forward(self, outputs, batch):
        """Compute classification, box-regression and centerness losses.

        Arguments:
            outputs: sequence whose last item unpacks as
                ``(box_cls, box_regression, centerness)``, each a per-level
                list of tensors; ``centerness`` may be empty.
            batch: targets, forwarded to ``self.prepare_targets``.

        Returns:
            loss (Tensor): ``cls_loss + reg_loss + centerness_loss``
            loss_stats (dict): the total under ``'loss'`` and the three terms
                under ``'cls'``, ``'reg'`` and ``'centerness'``
        """
        box_cls, box_regression, centerness = outputs[-1]
        has_centerness = len(centerness) > 0

        box_cls = [_sigmoid(level_cls) for level_cls in box_cls]
        if has_centerness:
            centerness = [_sigmoid(level_ness) for level_ness in centerness]

        locations = self.compute_locations(box_regression)
        num_classes = box_cls[0].size(1)
        labels, reg_targets = self.prepare_targets(locations, batch)

        # Flatten every level to (N, C) / (N, 4) / (N,) and concatenate.
        levels = range(len(labels))
        box_cls_flatten = torch.cat(
            [box_cls[lvl].permute(0, 2, 3, 1).reshape(-1, num_classes)
             for lvl in levels], dim=0)
        box_regression_flatten = torch.cat(
            [box_regression[lvl].permute(0, 2, 3, 1).reshape(-1, 4)
             for lvl in levels], dim=0)
        labels_flatten = torch.cat(
            [labels[lvl].reshape(-1) for lvl in levels], dim=0)
        reg_targets_flatten = torch.cat(
            [reg_targets[lvl].reshape(-1, 4) for lvl in levels], dim=0)

        # Computed on all locations; restricted to positives further down.
        centerness_targets = self.compute_centerness_targets(reg_targets_flatten)

        # Locations whose label is >= 0 are treated as positives.
        pos_inds = torch.nonzero(labels_flatten >= 0)  # pytorch version incompatible
        num_pos = pos_inds.numel()
        if num_pos > 0:
            pos_inds = pos_inds.squeeze(1)

        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]

        int_labels = labels_flatten.int()
        if has_centerness:
            centerness_flatten = torch.cat(
                [centerness[lvl].reshape(-1) for lvl in levels], dim=0)
            centerness_flatten = centerness_flatten[pos_inds]
            cls_loss = self.cls_loss_func(box_cls_flatten, int_labels) / max(num_pos, 1)
        else:
            # Without a centerness head the targets are handed to the
            # classification loss and also provide the normaliser.
            cls_loss = self.cls_loss_func(box_cls_flatten, int_labels,
                                          centerness_targets[:, None])
            normaliser = torch.clamp(centerness_targets[pos_inds].sum(), min=1)
            cls_loss = cls_loss / normaliser

        centerness_loss = cls_loss.new_zeros(1)
        if num_pos == 0:
            # No positives: sum of the (empty) selection.
            reg_loss = box_regression_flatten.sum()
        else:
            centerness_targets = centerness_targets[pos_inds]
            reg_loss = self.box_reg_loss_func(
                box_regression_flatten,
                reg_targets_flatten,
                centerness_targets
            ) / centerness_targets.sum()
            if has_centerness:
                centerness_loss = self.centerness_loss_func(
                    centerness_flatten,
                    centerness_targets
                ) / num_pos

        loss = cls_loss + reg_loss + centerness_loss
        loss_stats = {'loss': loss, 'cls': cls_loss, 'reg': reg_loss,
                      'centerness': centerness_loss}
        return loss, loss_stats
Example #9
0
  def forward(self, outputs_1, outputs_2, batch):
    """Compute the combined loss of two detection tasks over all stacks.

    Each stack's contribution is divided by ``opt.num_stacks`` before being
    accumulated. The per-stack dicts in ``outputs_1`` / ``outputs_2`` are
    modified in place when the sigmoid or an oracle substitution is applied.

    Args:
      outputs_1: Indexable collection of per-stack prediction dicts for
        task 1; the keys read here are 'hm_1', 'wh_1' and 'reg_1'.
      outputs_2: Indexable collection of per-stack prediction dicts for
        task 2; the keys read here are 'hm_2', 'wh_2' and 'reg_2'.
      batch: Dict holding the matching '*_1' / '*_2' ground-truth entries
        (heatmaps, targets, masks and indices).

    Returns:
      Tuple ``(loss, loss_stats)``. ``loss`` is ``loss_1 + loss_2`` (both
      task weights are fixed to 1.0) and ``loss_stats`` maps every loss
      component name to its value.

    Raises:
      NotImplementedError: If ``opt.wh_weight > 0`` and ``opt.dense_wh`` is
        set; the dense size loss was never ported to the two-task layout.
    """
    opt = self.opt
    hm_loss_1, wh_loss_1, off_loss_1 = 0, 0, 0
    hm_loss_2, wh_loss_2, off_loss_2 = 0, 0, 0
    for s in range(opt.num_stacks):
      output_1 = outputs_1[s]
      output_2 = outputs_2[s]
      if not opt.mse_loss:
        output_1['hm_1'] = _sigmoid(output_1['hm_1'])
        output_2['hm_2'] = _sigmoid(output_2['hm_2'])

      # Oracle switches: replace predictions with ground-truth derived values.
      if opt.eval_oracle_hm:
        output_1['hm_1'] = batch['hm_1']
        output_2['hm_2'] = batch['hm_2']

      if opt.eval_oracle_wh:
        output_1['wh_1'] = torch.from_numpy(gen_oracle_map(
          batch['wh_1'].detach().cpu().numpy(),
          batch['ind_1'].detach().cpu().numpy(),
          output_1['wh_1'].shape[3], output_1['wh_1'].shape[2])).to(opt.device)
        output_2['wh_2'] = torch.from_numpy(gen_oracle_map(
          batch['wh_2'].detach().cpu().numpy(),
          batch['ind_2'].detach().cpu().numpy(),
          output_2['wh_2'].shape[3], output_2['wh_2'].shape[2])).to(opt.device)

      if opt.eval_oracle_offset:
        output_1['reg_1'] = torch.from_numpy(gen_oracle_map(
          batch['reg_1'].detach().cpu().numpy(),
          batch['ind_1'].detach().cpu().numpy(),
          output_1['reg_1'].shape[3], output_1['reg_1'].shape[2])).to(opt.device)
        output_2['reg_2'] = torch.from_numpy(gen_oracle_map(
          batch['reg_2'].detach().cpu().numpy(),
          batch['ind_2'].detach().cpu().numpy(),
          output_2['reg_2'].shape[3], output_2['reg_2'].shape[2])).to(opt.device)

      hm_loss_1 += self.crit(output_1['hm_1'], batch['hm_1']) / opt.num_stacks
      hm_loss_2 += self.crit(output_2['hm_2'], batch['hm_2']) / opt.num_stacks

      if opt.wh_weight > 0:
        if opt.dense_wh:
          # The previous code here still used the single-task names
          # (`wh_loss`, `output`) and could only fail with a NameError;
          # fail with an explicit message instead.
          raise NotImplementedError(
            'dense_wh is not supported by the two-task loss')
        elif opt.cat_spec_wh:
          wh_loss_1 += self.crit_wh(
            output_1['wh_1'], batch['cat_spec_mask_1'],
            batch['ind_1'], batch['cat_spec_wh_1']) / opt.num_stacks
          wh_loss_2 += self.crit_wh(
            output_2['wh_2'], batch['cat_spec_mask_2'],
            batch['ind_2'], batch['cat_spec_wh_2']) / opt.num_stacks
        else:
          wh_loss_1 += self.crit_reg(
            output_1['wh_1'], batch['reg_mask_1'],
            batch['ind_1'], batch['wh_1'],
            output_1['hm_1'], batch['hm_1']) / opt.num_stacks
          # Task 2 must read its own heatmap: 'hm_2' lives in output_2.
          wh_loss_2 += self.crit_reg(
            output_2['wh_2'], batch['reg_mask_2'],
            batch['ind_2'], batch['wh_2'],
            output_2['hm_2'], batch['hm_2']) / opt.num_stacks

      if opt.reg_offset and opt.off_weight > 0:
        off_loss_1 += self.crit_reg(
          output_1['reg_1'], batch['reg_mask_1'],
          batch['ind_1'], batch['reg_1'],
          output_1['hm_1'], batch['hm_1']) / opt.num_stacks
        # Same as above: use the task-2 heatmap from output_2.
        off_loss_2 += self.crit_reg(
          output_2['reg_2'], batch['reg_mask_2'],
          batch['ind_2'], batch['reg_2'],
          output_2['hm_2'], batch['hm_2']) / opt.num_stacks

    loss_1 = opt.hm_weight * hm_loss_1 + opt.wh_weight * wh_loss_1 + \
           opt.off_weight * off_loss_1
    loss_2 = opt.hm_weight * hm_loss_2 + opt.wh_weight * wh_loss_2 + \
           opt.off_weight * off_loss_2

    # Only the equally weighted sum was ever reachable; the 'weighted' and
    # 'geometric' alternatives referenced undefined names and were dropped.
    loss_weight_1, loss_weight_2 = 1.0, 1.0
    loss = (loss_weight_1 * loss_1) + (loss_weight_2 * loss_2)

    loss_stats = {'loss': loss, 'loss_1': loss_1, 'hm_loss_1': hm_loss_1,
                  'wh_loss_1': wh_loss_1, 'off_loss_1': off_loss_1,
                  'loss_2': loss_2, 'hm_loss_2': hm_loss_2,
                  'wh_loss_2': wh_loss_2, 'off_loss_2': off_loss_2}
    return loss, loss_stats
Example #10
0
    def forward(self, outputs, batch):
        """Accumulate heatmap, size, offset and angle losses over all stacks.

        Every stack's contribution is divided by ``opt.num_stacks`` before it
        is added, and the returned loss is the ``opt.*_weight``-weighted sum
        of the four terms. The per-stack dicts in ``outputs`` are modified in
        place when the sigmoid or an oracle substitution is applied.

        Args:
            outputs: Indexable collection of per-stack prediction dicts; the
                keys read here are 'hm', 'wh', 'reg' and 'ang'.
            batch: Dict of ground-truth entries ('hm', 'wh', 'reg', 'ang',
                'ind', 'reg_mask', plus the dense / category-specific
                entries used by the optional size branches).

        Returns:
            Tuple ``(loss, loss_stats)``; ``loss_stats`` maps 'loss',
            'hm_loss', 'wh_loss', 'off_loss' and 'ang_loss' to their values.
        """
        opt = self.opt
        num_stacks = opt.num_stacks
        hm_loss = wh_loss = off_loss = ang_loss = 0

        def oracle_map(stack_out, key):
            # Build the replacement for stack_out[key] from the batch entries.
            target = batch[key].detach().cpu().numpy()
            index = batch['ind'].detach().cpu().numpy()
            generated = gen_oracle_map(
                target, index,
                stack_out[key].shape[3], stack_out[key].shape[2])
            return torch.from_numpy(generated).to(opt.device)

        for stack_idx in range(num_stacks):
            output = outputs[stack_idx]

            if not opt.mse_loss:
                # Original author's note (translated): the sigmoid is
                # understood as normalising the heatmap, presumably so the
                # heatmap loss converges faster -- unverified.
                output['hm'] = _sigmoid(output['hm'])

            # Oracle switches swap predictions for ground-truth derived data.
            if opt.eval_oracle_hm:
                output['hm'] = batch['hm']
            if opt.eval_oracle_wh:
                output['wh'] = oracle_map(output, 'wh')
            if opt.eval_oracle_offset:
                output['reg'] = oracle_map(output, 'reg')

            # This criterion returns a dict; only its 'loss' entry is used.
            hm_term = self.crit(output['hm'], batch['hm'])['loss']
            hm_loss += hm_term / num_stacks

            if opt.wh_weight > 0:
                if opt.dense_wh:
                    dense_mask = batch['dense_wh_mask']
                    mask_weight = dense_mask.sum() + 1e-4
                    wh_term = self.crit_wh(
                        output['wh'] * dense_mask,
                        batch['dense_wh'] * dense_mask) / mask_weight
                elif opt.cat_spec_wh:
                    wh_term = self.crit_wh(
                        output['wh'], batch['cat_spec_mask'],
                        batch['ind'], batch['cat_spec_wh'])
                else:
                    wh_term = self.crit_reg(
                        output['wh'], batch['reg_mask'],
                        batch['ind'], batch['wh'])
                wh_loss += wh_term / num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_term = self.crit_reg(
                    output['reg'], batch['reg_mask'],
                    batch['ind'], batch['reg'])
                off_loss += off_term / num_stacks

            if opt.ang_weight > 0:
                ang_term = self.crit_reg(
                    output['ang'], batch['reg_mask'],
                    batch['ind'], batch['ang'])
                ang_loss += ang_term / num_stacks

        loss = (opt.hm_weight * hm_loss + opt.wh_weight * wh_loss
                + opt.off_weight * off_loss + opt.ang_weight * ang_loss)
        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
            'ang_loss': ang_loss,
        }
        return loss, loss_stats
Example #11
0
    def forward(self, outputs, batch):
        """Average heatmap, size, angle and offset losses over ``outputs``.

        Each entry of ``outputs`` contributes its loss divided by
        ``len(outputs)``; the returned loss is the ``opt.*_weight``-weighted
        sum of the four terms. The dicts in ``outputs`` are modified in place
        when the sigmoid or an oracle substitution is applied.

        Args:
            outputs: Sequence of prediction dicts; the keys read here are
                'hm', 'wh', 'angle' and 'reg'.
            batch: Dict of ground-truth entries ('hm', 'wh', 'angle', 'reg',
                'ind', 'reg_mask', plus the dense / category-specific
                entries used by the optional branches).

        Returns:
            Tuple ``(loss, loss_stats)``; ``loss_stats`` maps 'loss',
            'hm_loss', 'wh_loss', 'angle_loss' and 'off_loss' to their values.
        """
        opt = self.opt
        n_outputs = len(outputs)
        hm_loss = wh_loss = angle_loss = off_loss = 0

        def oracle_map(stack_out, key):
            # Build the replacement for stack_out[key] from the batch entries.
            target = batch[key].detach().cpu().numpy()
            index = batch['ind'].detach().cpu().numpy()
            generated = gen_oracle_map(
                target, index,
                stack_out[key].shape[3], stack_out[key].shape[2])
            return torch.from_numpy(generated).to(opt.device)

        for output in outputs:
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            # Oracle switches swap predictions for ground-truth derived data.
            if opt.eval_oracle_hm:
                output['hm'] = batch['hm']
            if opt.eval_oracle_wh:
                output['wh'] = oracle_map(output, 'wh')
            if opt.eval_oracle_angle:
                output['angle'] = oracle_map(output, 'angle')
            if opt.eval_oracle_offset:
                output['reg'] = oracle_map(output, 'reg')

            hm_loss += self.crit(output['hm'], batch['hm']) / n_outputs

            if opt.wh_weight > 0:
                if opt.dense_wh:
                    dense_mask = batch['dense_wh_mask']
                    mask_weight = dense_mask.sum() + 1e-4
                    wh_term = self.crit_wh(
                        output['wh'] * dense_mask,
                        batch['dense_wh'] * dense_mask) / mask_weight
                elif opt.cat_spec_wh:
                    wh_term = self.crit_wh(
                        output['wh'], batch['cat_spec_mask'],
                        batch['ind'], batch['cat_spec_wh'])
                else:
                    wh_term = self.crit_reg(
                        output['wh'], batch['reg_mask'],
                        batch['ind'], batch['wh'])
                wh_loss += wh_term / n_outputs

            if opt.angle_weight > 0:
                if opt.dense_angle:
                    # The dense criterion receives the mask directly instead
                    # of pre-multiplied inputs.
                    angle_term = self.crit_dense_angle(
                        output['angle'], batch['dense_angle_mask'],
                        batch['dense_angle'])
                elif opt.cat_spec_angle:
                    angle_term = self.crit_angle(
                        output['angle'], batch['cat_spec_angle_mask'],
                        batch['ind'], batch['cat_spec_angle'])
                else:
                    angle_term = self.crit_reg(
                        output['angle'], batch['reg_mask'],
                        batch['ind'], batch['angle'])
                angle_loss += angle_term / n_outputs

            if opt.reg_offset and opt.off_weight > 0:
                off_term = self.crit_reg(
                    output['reg'], batch['reg_mask'],
                    batch['ind'], batch['reg'])
                off_loss += off_term / n_outputs

        loss = (opt.hm_weight * hm_loss + opt.wh_weight * wh_loss
                + opt.off_weight * off_loss + opt.angle_weight * angle_loss)
        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'angle_loss': angle_loss,
            'off_loss': off_loss,
        }
        return loss, loss_stats
Example #12
0
    def forward(self, outputs, batch):
        """Accumulate detection and mask losses over all stacks.

        Heatmap, size, offset, linear-combination mask and (optionally)
        semantic segmentation terms are each divided by ``opt.num_stacks``
        per stack and then combined with their ``opt.*_weight`` factors. The
        mask term is computed on every stack regardless of any flag. The
        per-stack dicts in ``outputs`` are modified in place when the sigmoid
        or an oracle substitution is applied.

        Args:
            outputs: Indexable collection of per-stack prediction dicts; the
                keys read here are 'hm', 'wh', 'reg', 'segm', 'masks' and
                'proto'.
            batch: Dict of ground-truth entries ('hm', 'wh', 'reg', 'ind',
                'reg_mask', 'masks', 'gt_bbox_lbl', plus the dense /
                category-specific entries used by the optional branches).

        Returns:
            Tuple ``(loss, loss_stats)``; ``loss_stats`` maps 'loss',
            'hm_loss', 'wh_loss', 'off_loss', 'lincomb_mask_loss' and
            'segm_loss' to their values.
        """
        opt = self.opt
        num_stacks = opt.num_stacks
        hm_loss = wh_loss = off_loss = 0
        mask_loss = segm_loss = 0

        def oracle_map(stack_out, key):
            # Build the replacement for stack_out[key] from the batch entries.
            target = batch[key].detach().cpu().numpy()
            index = batch['ind'].detach().cpu().numpy()
            generated = gen_oracle_map(
                target, index,
                stack_out[key].shape[3], stack_out[key].shape[2])
            return torch.from_numpy(generated).to(opt.device)

        for stack_idx in range(num_stacks):
            output = outputs[stack_idx]

            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            # Oracle switches swap predictions for ground-truth derived data.
            if opt.eval_oracle_hm:
                output['hm'] = batch['hm']
            if opt.eval_oracle_wh:
                output['wh'] = oracle_map(output, 'wh')
            if opt.eval_oracle_offset:
                output['reg'] = oracle_map(output, 'reg')

            hm_loss += self.crit(output['hm'], batch['hm']) / num_stacks

            if opt.wh_weight > 0:
                if opt.dense_wh:
                    dense_mask = batch['dense_wh_mask']
                    mask_weight = dense_mask.sum() + 1e-4
                    wh_term = self.crit_wh(
                        output['wh'] * dense_mask,
                        batch['dense_wh'] * dense_mask) / mask_weight
                elif opt.cat_spec_wh:
                    wh_term = self.crit_wh(
                        output['wh'], batch['cat_spec_mask'],
                        batch['ind'], batch['cat_spec_wh'])
                else:
                    wh_term = self.crit_reg(
                        output['wh'], batch['reg_mask'],
                        batch['ind'], batch['wh'])
                wh_loss += wh_term / num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_term = self.crit_reg(
                    output['reg'], batch['reg_mask'],
                    batch['ind'], batch['reg'])
                off_loss += off_term / num_stacks

            if opt.use_semantic_segmentation_loss:
                segm_term = self.semantic_segmentation_loss(
                    output['segm'], batch['masks'], batch['gt_bbox_lbl'])
                segm_loss += segm_term / num_stacks

            mask_term = self.lincomb_mask_loss(
                output['masks'], output['proto'], batch['reg_mask'],
                batch['masks'], batch['gt_bbox_lbl'])
            mask_loss += mask_term / num_stacks

        loss = (opt.hm_weight * hm_loss + opt.wh_weight * wh_loss
                + opt.off_weight * off_loss
                + opt.lincomb_mask_weight * mask_loss
                + opt.segm_weight * segm_loss)
        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
            'lincomb_mask_loss': mask_loss,
            'segm_loss': segm_loss,
        }
        return loss, loss_stats
Example #13
0
    def forward(self, outputs, batch):
        """Accumulate the detection loss and per-category heatmap statistics.

        Heatmap, size and offset terms are each divided by ``opt.num_stacks``
        per stack and combined with their ``opt.*_weight`` factors. In
        addition, the heatmap criterion is evaluated on channel ``i`` of the
        heatmap for every entry of ``cf.categories``; these values are only
        reported in ``loss_stats`` and do not enter ``loss``. The per-stack
        dicts in ``outputs`` are modified in place when the sigmoid or an
        oracle substitution is applied.

        Args:
            outputs: Indexable collection of per-stack prediction dicts; the
                keys read here are 'hm', 'wh' and 'reg'.
            batch: Dict of ground-truth entries ('hm', 'wh', 'reg', 'ind',
                'reg_mask', plus the dense / category-specific entries used
                by the optional size branches).

        Returns:
            Tuple ``(loss, loss_stats)``; ``loss_stats`` maps 'loss',
            'hm_loss', 'wh_loss', 'off_loss' and one
            ``'hm_loss_<category>'`` entry per category to their values.
        """
        opt = self.opt
        hm_loss, wh_loss, off_loss = 0, 0, 0

        # Created before the loop so the per-category values are summed over
        # all stacks (like hm_loss) and the name exists even with zero stacks.
        hm_loss_categorie = {}

        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            # Oracle switches swap predictions for ground-truth derived data.
            if opt.eval_oracle_hm:
                output['hm'] = batch['hm']
            if opt.eval_oracle_wh:
                output['wh'] = torch.from_numpy(
                    gen_oracle_map(batch['wh'].detach().cpu().numpy(),
                                   batch['ind'].detach().cpu().numpy(),
                                   output['wh'].shape[3],
                                   output['wh'].shape[2])).to(opt.device)
            if opt.eval_oracle_offset:
                output['reg'] = torch.from_numpy(
                    gen_oracle_map(batch['reg'].detach().cpu().numpy(),
                                   batch['ind'].detach().cpu().numpy(),
                                   output['reg'].shape[3],
                                   output['reg'].shape[2])).to(opt.device)

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks

            # Per-category heatmap loss: channel i is paired with category i.
            for i, categorie in enumerate(cf.categories):
                cat_term = self.crit(
                    output['hm'][:, i, :, :],
                    batch['hm'][:, i, :, :]) / opt.num_stacks
                hm_loss_categorie[categorie] = \
                    hm_loss_categorie.get(categorie, 0) + cat_term

            if opt.wh_weight > 0:
                if opt.dense_wh:
                    mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                    wh_loss += (self.crit_wh(
                        output['wh'] * batch['dense_wh_mask'],
                        batch['dense_wh'] * batch['dense_wh_mask']) /
                                mask_weight) / opt.num_stacks
                elif opt.cat_spec_wh:
                    wh_loss += self.crit_wh(
                        output['wh'], batch['cat_spec_mask'], batch['ind'],
                        batch['cat_spec_wh']) / opt.num_stacks
                else:
                    wh_loss += self.crit_reg(output['wh'], batch['reg_mask'],
                                             batch['ind'],
                                             batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['reg']) / opt.num_stacks

        loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + \
               opt.off_weight * off_loss
        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss
        }

        for cat_name, cat_loss in hm_loss_categorie.items():
            loss_stats['hm_loss_' + cat_name] = cat_loss

        return loss, loss_stats
Example #14
0
    def forward(self, outputs, batch):
        """Accumulate heatmap, size and offset losses over all stacks.

        Every stack's contribution is divided by ``opt.num_stacks`` before it
        is added, and the returned loss is the ``opt.*_weight``-weighted sum
        of the three terms. The per-stack dicts in ``outputs`` are modified
        in place when the sigmoid or an oracle substitution is applied.

        Args:
            outputs: Indexable collection of per-stack prediction dicts; the
                keys read here are 'hm', 'wh' and 'reg'.
            batch: Dict of ground-truth entries ('hm', 'wh', 'reg', 'ind',
                'reg_mask', plus the dense / category-specific entries used
                by the optional size branches).

        Returns:
            Tuple ``(loss, loss_stats)``; ``loss_stats`` maps 'loss',
            'hm_loss', 'wh_loss' and 'off_loss' to their values.
        """
        opt = self.opt
        num_stacks = opt.num_stacks
        hm_loss = wh_loss = off_loss = 0

        def oracle_map(stack_out, key):
            # Build the replacement for stack_out[key] from the batch entries.
            target = batch[key].detach().cpu().numpy()
            index = batch['ind'].detach().cpu().numpy()
            generated = gen_oracle_map(
                target, index,
                stack_out[key].shape[3], stack_out[key].shape[2])
            return torch.from_numpy(generated).to(opt.device)

        for stack_idx in range(num_stacks):
            output = outputs[stack_idx]

            # The sigmoid is skipped only when the MSE option is enabled.
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            # Oracle switches swap predictions for ground-truth derived data.
            if opt.eval_oracle_hm:
                output['hm'] = batch['hm']
            if opt.eval_oracle_wh:
                output['wh'] = oracle_map(output, 'wh')
            if opt.eval_oracle_offset:
                output['reg'] = oracle_map(output, 'reg')

            # Heatmap term.
            hm_loss += self.crit(output['hm'], batch['hm']) / num_stacks

            # Size term: dense, category-specific or plain regression.
            if opt.wh_weight > 0:
                if opt.dense_wh:
                    dense_mask = batch['dense_wh_mask']
                    mask_weight = dense_mask.sum() + 1e-4
                    wh_term = self.crit_wh(
                        output['wh'] * dense_mask,
                        batch['dense_wh'] * dense_mask) / mask_weight
                elif opt.cat_spec_wh:
                    wh_term = self.crit_wh(
                        output['wh'], batch['cat_spec_mask'],
                        batch['ind'], batch['cat_spec_wh'])
                else:
                    wh_term = self.crit_reg(
                        output['wh'], batch['reg_mask'],
                        batch['ind'], batch['wh'])
                wh_loss += wh_term / num_stacks

            # Offset term.
            if opt.reg_offset and opt.off_weight > 0:
                off_term = self.crit_reg(
                    output['reg'], batch['reg_mask'],
                    batch['ind'], batch['reg'])
                off_loss += off_term / num_stacks

        loss = (opt.hm_weight * hm_loss
                + opt.wh_weight * wh_loss
                + opt.off_weight * off_loss)
        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
        }
        return loss, loss_stats