Exemplo n.º 1
0
    def forward(self, outputs, batch):
        """Combine heat-map, size, offset and segmentation losses.

        Args:
            outputs: per-stack sequence of head dictionaries; each one is
                read at 'hm', 'wh', 'reg', 'seg' and 'seg_feat'.
            batch: target dictionary read at 'hm', 'reg_mask', 'ind', 'wh',
                'reg' and 'seg'.

        Returns:
            ``(loss, loss_stats)`` where ``loss_stats`` maps 'loss',
            'hm_loss', 'wh_loss', 'off_loss' and 'seg_loss' to their values.

        Side effect: unless ``opt.mse_loss`` is set, ``output['hm']`` of every
        stack is replaced by ``_sigmoid(output['hm'])``.
        """
        cfg = self.opt
        n_stacks = cfg.num_stacks
        hm_loss = wh_loss = off_loss = seg_loss = 0

        for stack_idx in range(n_stacks):
            head = outputs[stack_idx]
            if not cfg.mse_loss:
                head['hm'] = _sigmoid(head['hm'])

            # Heat-map, size and offset terms are averaged over the stacks.
            hm_term = self.crit(head['hm'], batch['hm'])
            hm_loss += hm_term / n_stacks

            wh_term = self.crit_reg(
                head['wh'], batch['reg_mask'], batch['ind'], batch['wh'])
            wh_loss += wh_term / n_stacks

            off_term = self.crit_reg(
                head['reg'], batch['reg_mask'], batch['ind'], batch['reg'])
            off_loss += off_term / n_stacks

            # NOTE(review): unlike the terms above, the segmentation term is
            # summed over stacks rather than averaged - confirm this is
            # intended when num_stacks > 1.
            seg_loss += self.crit_seg(
                head['seg'], head['seg_feat'], batch['ind'], batch['seg'])

        loss = (cfg.hm_weight * hm_loss + cfg.wh_weight * wh_loss
                + cfg.off_weight * off_loss + cfg.seg_weight * seg_loss)

        loss_stats = dict(
            loss=loss,
            hm_loss=hm_loss,
            wh_loss=wh_loss,
            off_loss=off_loss,
            seg_loss=seg_loss,
        )
        return loss, loss_stats
Exemplo n.º 2
0
    def forward(self, outputs, batch):
        """Compute the joint detection and identity loss over all stacks.

        :param outputs: per-stack sequence of head dictionaries ('hm', 'wh',
            'reg' and 'id' are read depending on the enabled options).
        :param batch: target dictionary ('hm', 'reg_mask', 'ind', 'wh', 'reg',
            'ids', and the 'dense_wh*' entries when ``opt.dense_wh`` is set).
        :return: ``(loss, loss_stats)``; ``loss_stats`` holds 'loss',
            'hm_loss', 'wh_loss', 'off_loss' and 'id_loss'.

        Side effect: unless ``opt.mse_loss`` is set, ``output['hm']`` of every
        stack is replaced by ``_sigmoid(output['hm'])``.
        """
        cfg = self.opt
        n_stacks = cfg.num_stacks
        hm_loss = wh_loss = off_loss = id_loss = 0.0

        for stack_idx in range(n_stacks):
            head = outputs[stack_idx]
            if not cfg.mse_loss:
                head['hm'] = _sigmoid(head['hm'])

            # Heat-map loss, averaged over stacks.
            hm_loss += self.crit(head['hm'], batch['hm']) / n_stacks

            # Box size loss: dense variant or the sparse regression variant.
            if cfg.wh_weight > 0:
                if cfg.dense_wh:
                    dense_mask = batch['dense_wh_mask']
                    mask_weight = dense_mask.sum() + 1e-4
                    dense_term = self.crit_wh(head['wh'] * dense_mask,
                                              batch['dense_wh'] * dense_mask)
                    wh_loss += (dense_term / mask_weight) / n_stacks
                else:
                    sparse_term = self.crit_reg(head['wh'], batch['reg_mask'],
                                                batch['ind'], batch['wh'])
                    wh_loss += sparse_term / n_stacks

            # Box centre offset loss.
            if cfg.reg_offset and cfg.off_weight > 0:
                off_term = self.crit_reg(head['reg'], batch['reg_mask'],
                                         batch['ind'], batch['reg'])
                off_loss += off_term / n_stacks

            # Identity classification loss, restricted to entries whose
            # reg_mask is positive. Not divided by the number of stacks.
            if cfg.id_weight > 0:
                embedding = _tranpose_and_gather_feat(head['id'], batch['ind'])
                has_object = batch['reg_mask'] > 0
                embedding = embedding[has_object].contiguous()
                embedding = self.emb_scale * F.normalize(embedding)
                id_target = batch['ids'][has_object]
                id_logits = self.classifier.forward(embedding).contiguous()
                id_loss += self.IDLoss(id_logits, id_target)

        det_loss = (cfg.hm_weight * hm_loss
                    + cfg.wh_weight * wh_loss
                    + cfg.off_weight * off_loss)

        # Each task term is scaled by exp(-s) and s itself is added back as a
        # penalty; the sum is then halved.
        det_term = torch.exp(-self.s_det) * det_loss
        id_term = torch.exp(-self.s_id) * id_loss
        loss = det_term + id_term + (self.s_det + self.s_id)
        loss *= 0.5

        loss_stats = dict(
            loss=loss,
            hm_loss=hm_loss,
            wh_loss=wh_loss,
            off_loss=off_loss,
            id_loss=id_loss,
        )
        return loss, loss_stats
Exemplo n.º 3
0
    def forward(self, outputs, batch):
        """Compute the weighted heat-map / size / offset detection loss.

        Args:
            outputs: per-stack sequence of head dictionaries ('hm', 'wh',
                'reg').
            batch: target dictionary ('hm', 'reg_mask', 'ind', 'wh', 'reg',
                plus the 'dense_wh*' or 'cat_spec_*' entries for the matching
                size-loss variant).

        Returns:
            ``(loss, loss_stats)`` with 'loss', 'hm_loss', 'wh_loss' and
            'off_loss' in ``loss_stats``.

        Side effect: the head dictionaries are modified in place (sigmoid on
        'hm' unless ``opt.mse_loss``; oracle substitutions when the
        ``opt.eval_oracle_*`` switches are set).
        """
        cfg = self.opt
        n_stacks = cfg.num_stacks
        hm_loss = wh_loss = off_loss = 0

        for stack_idx in range(n_stacks):
            head = outputs[stack_idx]
            if not cfg.mse_loss:
                head['hm'] = _sigmoid(head['hm'])

            # Oracle switches: overwrite a prediction with a map derived from
            # the targets before the losses are evaluated.
            if cfg.eval_oracle_hm:
                head['hm'] = batch['hm']
            if cfg.eval_oracle_wh:
                wh_map = gen_oracle_map(
                    batch['wh'].detach().cpu().numpy(),
                    batch['ind'].detach().cpu().numpy(),
                    head['wh'].shape[3],
                    head['wh'].shape[2])
                head['wh'] = torch.from_numpy(wh_map).to(cfg.device)
            if cfg.eval_oracle_offset:
                reg_map = gen_oracle_map(
                    batch['reg'].detach().cpu().numpy(),
                    batch['ind'].detach().cpu().numpy(),
                    head['reg'].shape[3],
                    head['reg'].shape[2])
                head['reg'] = torch.from_numpy(reg_map).to(cfg.device)

            hm_loss += self.crit(head['hm'], batch['hm']) / n_stacks

            if cfg.wh_weight > 0:
                if cfg.dense_wh:
                    dense_mask = batch['dense_wh_mask']
                    mask_weight = dense_mask.sum() + 1e-4
                    dense_term = self.crit_wh(head['wh'] * dense_mask,
                                              batch['dense_wh'] * dense_mask)
                    wh_loss += (dense_term / mask_weight) / n_stacks
                elif cfg.cat_spec_wh:
                    cat_term = self.crit_wh(head['wh'], batch['cat_spec_mask'],
                                            batch['ind'], batch['cat_spec_wh'])
                    wh_loss += cat_term / n_stacks
                else:
                    sparse_term = self.crit_reg(head['wh'], batch['reg_mask'],
                                                batch['ind'], batch['wh'])
                    wh_loss += sparse_term / n_stacks

            if cfg.reg_offset and cfg.off_weight > 0:
                off_term = self.crit_reg(head['reg'], batch['reg_mask'],
                                         batch['ind'], batch['reg'])
                off_loss += off_term / n_stacks

        loss = (cfg.hm_weight * hm_loss + cfg.wh_weight * wh_loss
                + cfg.off_weight * off_loss)
        loss_stats = dict(
            loss=loss,
            hm_loss=hm_loss,
            wh_loss=wh_loss,
            off_loss=off_loss,
        )
        return loss, loss_stats
    def forward(self, outputs, batch):
        """Compute the weighted loss over heat-map, depth, dimension,
        rotation, box size and centre offset heads.

        Args:
            outputs: per-stack sequence of head dictionaries ('hm', 'dep',
                'dim', 'rot', 'wh', 'reg').
            batch: target dictionary ('hm', 'dep', 'dim', 'rotbin', 'rotres',
                'wh', 'reg', 'ind', 'reg_mask', 'rot_mask').

        Returns:
            ``(loss, loss_stats)`` with 'loss', 'hm_loss', 'dep_loss',
            'dim_loss', 'rot_loss', 'wh_loss' and 'off_loss'.

        Side effect: 'hm' and 'dep' of every head dictionary are overwritten
        with their transformed values.
        """
        cfg = self.opt
        n_stacks = cfg.num_stacks

        hm_loss = dep_loss = rot_loss = dim_loss = 0
        wh_loss = off_loss = 0

        for stack_idx in range(n_stacks):
            head = outputs[stack_idx]
            head['hm'] = _sigmoid(head['hm'])
            # Map the raw depth output through 1 / sigmoid(x) - 1; the 1e-6
            # keeps the denominator away from zero.
            head['dep'] = 1. / (head['dep'].sigmoid() + 1e-6) - 1.

            if cfg.eval_oracle_dep:
                dep_map = gen_oracle_map(
                    batch['dep'].detach().cpu().numpy(),
                    batch['ind'].detach().cpu().numpy(),
                    cfg.output_w, cfg.output_h)
                head['dep'] = torch.from_numpy(dep_map).to(cfg.device)

            hm_loss += self.crit(head['hm'], batch['hm']) / n_stacks

            if cfg.dep_weight > 0:
                dep_term = self.crit_reg(head['dep'], batch['reg_mask'],
                                         batch['ind'], batch['dep'])
                dep_loss += dep_term / n_stacks

            if cfg.dim_weight > 0:
                dim_term = self.crit_reg(head['dim'], batch['reg_mask'],
                                         batch['ind'], batch['dim'])
                dim_loss += dim_term / n_stacks

            if cfg.rot_weight > 0:
                rot_term = self.crit_rot(head['rot'], batch['rot_mask'],
                                         batch['ind'], batch['rotbin'],
                                         batch['rotres'])
                rot_loss += rot_term / n_stacks

            # The 2D box size and offset terms are masked with 'rot_mask'.
            if cfg.reg_bbox and cfg.wh_weight > 0:
                wh_term = self.crit_reg(head['wh'], batch['rot_mask'],
                                        batch['ind'], batch['wh'])
                wh_loss += wh_term / n_stacks

            if cfg.reg_offset and cfg.off_weight > 0:
                off_term = self.crit_reg(head['reg'], batch['rot_mask'],
                                         batch['ind'], batch['reg'])
                off_loss += off_term / n_stacks

        loss = (cfg.hm_weight * hm_loss + cfg.dep_weight * dep_loss
                + cfg.dim_weight * dim_loss + cfg.rot_weight * rot_loss
                + cfg.wh_weight * wh_loss + cfg.off_weight * off_loss)

        loss_stats = dict(
            loss=loss,
            hm_loss=hm_loss,
            dep_loss=dep_loss,
            dim_loss=dim_loss,
            rot_loss=rot_loss,
            wh_loss=wh_loss,
            off_loss=off_loss,
        )
        return loss, loss_stats
Exemplo n.º 5
0
    def forward(self, outputs, batch):
        """Compute the detection loss plus the identity loss and blend them.

        Args:
            outputs: per-stack sequence of head dictionaries ('hm', 'wh',
                'reg', 'id').
            batch: target dictionary ('hm', 'reg_mask', 'ind', 'wh', 'reg',
                'ids', and 'dense_wh*' entries when ``opt.dense_wh`` is set).

        Returns:
            ``(loss, loss_stats)`` with 'loss', 'hm_loss', 'wh_loss',
            'off_loss' and 'id_loss' in ``loss_stats``.

        Side effect: unless ``opt.mse_loss`` is set, ``output['hm']`` of every
        stack is replaced by ``_sigmoid(output['hm'])``.
        """
        settings = self.opt
        stack_count = settings.num_stacks
        hm_loss = wh_loss = off_loss = id_loss = 0

        for k in range(stack_count):
            preds = outputs[k]
            if not settings.mse_loss:
                preds['hm'] = _sigmoid(preds['hm'])

            hm_loss += self.crit(preds['hm'], batch['hm']) / stack_count

            if settings.wh_weight > 0:
                if not settings.dense_wh:
                    wh_loss += self.crit_reg(preds['wh'], batch['reg_mask'],
                                             batch['ind'],
                                             batch['wh']) / stack_count
                else:
                    mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                    masked_pred = preds['wh'] * batch['dense_wh_mask']
                    masked_gt = batch['dense_wh'] * batch['dense_wh_mask']
                    wh_loss += (self.crit_wh(masked_pred, masked_gt)
                                / mask_weight) / stack_count

            if settings.reg_offset and settings.off_weight > 0:
                off_loss += self.crit_reg(preds['reg'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['reg']) / stack_count

            if settings.id_weight > 0:
                # Keep only the gathered embeddings whose reg_mask is
                # positive, then classify the scaled, normalised vectors.
                feats = _tranpose_and_gather_feat(preds['id'], batch['ind'])
                valid = batch['reg_mask'] > 0
                feats = self.emb_scale * F.normalize(feats[valid].contiguous())
                id_target = batch['ids'][valid]
                logits = self.classifier(feats).contiguous()
                id_loss += self.IDLoss(logits, id_target)

        det_loss = (settings.hm_weight * hm_loss
                    + settings.wh_weight * wh_loss
                    + settings.off_weight * off_loss)

        # Each task is scaled by exp(-s) with s added back as a penalty.
        loss = (torch.exp(-self.s_det) * det_loss
                + torch.exp(-self.s_id) * id_loss
                + (self.s_det + self.s_id))
        loss *= 0.5

        loss_stats = dict(loss=loss, hm_loss=hm_loss, wh_loss=wh_loss,
                          off_loss=off_loss, id_loss=id_loss)
        return loss, loss_stats
Exemplo n.º 6
0
def _neg_seg_loss(pred, gt, mask):
    """Weighted log-likelihood loss on `_sigmoid(pred)`.

    Positive entries (``gt == 1``) contribute ``log(p)`` scaled by 1.5;
    the remaining entries (``gt < 1``) contribute ``log(1 - p)`` weighted by
    ``mask``. The negated sum is divided by the sum of the expanded mask.

    Arguments:
      pred: raw predictions; passed through `_sigmoid` first.
      gt: targets, compared element-wise against 1 (same shape as pred
          presumably - verify against caller).
      mask: gets a trailing axis appended and is expanded to ``pred.size()``.

    NOTE(review): the positive term is not multiplied by the mask, and a mask
    that sums to zero makes the final division undefined - confirm callers
    never pass an all-zero mask.
    """
    prob = _sigmoid(pred)
    is_pos = gt.eq(1).float()
    is_neg = gt.lt(1).float()
    weight = mask.unsqueeze(-1).expand(prob.size()).float()

    pos_total = (torch.log(prob) * is_pos).sum()
    neg_total = (torch.log(1 - prob) * is_neg * weight).sum()
    normaliser = weight.sum()

    return 0 - (pos_total * 1.5 + neg_total) / normaliser
Exemplo n.º 7
0
    def forward(self, outputs, batch):
        """Compute the detection loss and, when enabled, a per-class ReID
        cross-entropy loss, then blend them.

        :param outputs: per-stack sequence of head dictionaries ('hm', 'wh',
            'reg', 'id').
        :param batch: target dictionary ('hm', 'reg_mask', 'ind', 'wh', 'reg';
            'cls_id_map' and 'cls_tr_ids' when ``opt.id_weight > 0``; the
            'dense_wh*' entries when ``opt.dense_wh`` is set).
        :return: ``(loss, loss_stats)``. ``loss_stats`` contains 'loss',
            'hm_loss', 'wh_loss', 'off_loss', plus 'id_loss' only when
            ``opt.id_weight > 0``.

        Side effect: unless ``opt.mse_loss`` is set, ``output['hm']`` of every
        stack is replaced by ``_sigmoid(output['hm'])``.
        """
        cfg = self.opt
        n_stacks = cfg.num_stacks
        use_reid = cfg.id_weight > 0

        hm_loss = wh_loss = off_loss = reid_loss = 0.0

        for stack_idx in range(n_stacks):
            head = outputs[stack_idx]

            # ----- Detection terms (each averaged over the stacks)
            if not cfg.mse_loss:
                head['hm'] = _sigmoid(head['hm'])

            hm_loss += self.crit(head['hm'], batch['hm']) / n_stacks

            if cfg.wh_weight > 0:
                if cfg.dense_wh:
                    dense_mask = batch['dense_wh_mask']
                    mask_weight = dense_mask.sum() + 1e-4
                    dense_term = self.crit_wh(head['wh'] * dense_mask,
                                              batch['dense_wh'] * dense_mask)
                    wh_loss += (dense_term / mask_weight) / n_stacks
                else:
                    sparse_term = self.crit_reg(head['wh'], batch['reg_mask'],
                                                batch['ind'], batch['wh'])
                    wh_loss += sparse_term / n_stacks

            if cfg.reg_offset and cfg.off_weight > 0:
                off_term = self.crit_reg(head['reg'], batch['reg_mask'],
                                         batch['ind'], batch['reg'])
                off_loss += off_term / n_stacks

            # ----- ReID term: one classifier per class listed in nID_dict
            if not use_reid:
                continue

            cls_id_map = batch['cls_id_map']
            for cls_id, _ in self.nID_dict.items():
                # Index positions 0, 2 and 3 of the match are used below
                # (presumably batch / row / column - verify against caller).
                hits = torch.where(cls_id_map == cls_id)
                n_idx, y_idx, x_idx = hits[0], hits[2], hits[3]
                if n_idx.shape[0] == 0:
                    continue  # no location carries this class

                # Embedding vectors at the matched locations, normalised and
                # scaled with the class-specific factor.
                embedding = head['id'][n_idx, :, y_idx, x_idx]
                embedding = self.emb_scale_dict[cls_id] * F.normalize(embedding)

                # Track-id targets at the same locations.
                track_target = batch['cls_tr_ids'][n_idx, cls_id, y_idx, x_idx]

                logits = self.classifiers[str(cls_id)].forward(
                    embedding).contiguous()

                # Accumulated over classes and stacks without normalisation.
                reid_loss += self.IDLoss(logits, track_target)

        det_loss = (cfg.hm_weight * hm_loss
                    + cfg.wh_weight * wh_loss
                    + cfg.off_weight * off_loss)

        # Each active task is scaled by exp(-s) with s added back as penalty;
        # s_id only takes part when the ReID branch is active.
        if use_reid:
            loss = (torch.exp(-self.s_det) * det_loss
                    + torch.exp(-self.s_id) * reid_loss
                    + (self.s_det + self.s_id))
        else:
            loss = torch.exp(-self.s_det) * det_loss + self.s_det
        loss *= 0.5

        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
        }
        if use_reid:
            loss_stats['id_loss'] = reid_loss

        return loss, loss_stats
Exemplo n.º 8
0
    def forward(self, outputs, batch):
        """Compute the detection loss and, when enabled, a per-class ReID
        circle loss, then blend them.

        :param outputs: per-stack sequence of head dictionaries ('hm', 'wh',
            'reg', 'id').
        :param batch: target dictionary ('hm', 'reg_mask', 'ind', 'wh', 'reg';
            'cls_id_map' and 'cls_tr_ids' are only required when
            ``opt.id_weight > 0``; 'dense_wh*' only when ``opt.dense_wh``).
        :return: ``(loss, loss_stats)``; ``loss_stats`` holds 'loss',
            'hm_loss', 'wh_loss', 'off_loss' and 'id_loss' ('id_loss' stays
            0.0 when the ReID branch is disabled).

        Side effect: unless ``opt.mse_loss`` is set, ``output['hm']`` of every
        stack is replaced by ``_sigmoid(output['hm'])``.
        """
        opt = self.opt
        use_reid = opt.id_weight > 0

        hm_loss, wh_loss, off_loss, reid_loss = 0.0, 0.0, 0.0, 0.0
        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            # Heat-map loss
            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks

            # Box size loss: dense variant or sparse regression variant
            if opt.wh_weight > 0:
                if opt.dense_wh:
                    mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                    wh_loss += (self.crit_wh(
                        output['wh'] * batch['dense_wh_mask'],
                        batch['dense_wh'] * batch['dense_wh_mask']) /
                                mask_weight) / opt.num_stacks
                else:
                    wh_loss += self.crit_reg(output['wh'], batch['reg_mask'],
                                             batch['ind'],
                                             batch['wh']) / opt.num_stacks

            # Box centre offset loss
            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'],
                                          batch['reg']) / opt.num_stacks

            # ----- ReID loss: only the classes listed in nID_dict
            if use_reid:
                # Read the class map only here, so detection-only batches
                # do not need to carry the 'cls_id_map' key.
                cls_id_map = batch['cls_id_map']

                for cls_id in self.nID_dict:
                    inds = torch.where(cls_id_map == cls_id)
                    if inds[0].shape[0] == 0:
                        continue  # no location carries this class

                    # Embedding vectors at the matched locations, normalised
                    # and scaled with the class-specific factor.
                    cls_id_head = output['id'][inds[0], :, inds[2], inds[3]]
                    cls_id_head = self.emb_scale_dict[cls_id] * F.normalize(
                        cls_id_head)

                    # Track-id targets at the same locations.
                    cls_id_target = batch['cls_tr_ids'][inds[0], cls_id,
                                                        inds[2], inds[3]]

                    # Classification output of the per-class classifier.
                    cls_id_output = self.classifiers[str(cls_id)].forward(
                        cls_id_head).contiguous()

                    # Circle loss on the label-derived similarity pairs,
                    # accumulated over classes and stacks.
                    reid_loss += self.circle_loss(*convert_label_to_similarity(
                        cls_id_output, cls_id_target))

        det_loss = opt.hm_weight * hm_loss \
                   + opt.wh_weight * wh_loss \
                   + opt.off_weight * off_loss

        if use_reid:
            loss = torch.exp(-self.s_det) * det_loss \
                   + torch.exp(-self.s_id) * reid_loss \
                   + (self.s_det + self.s_id)
        else:
            # Detection only: s_id must not enter the objective. With no ReID
            # term to counterbalance it, a bare "+ s_id" would let s_id
            # decrease without bound.
            loss = torch.exp(-self.s_det) * det_loss \
                   + self.s_det

        loss *= 0.5

        loss_stats = {
            'loss': loss,
            'hm_loss': hm_loss,
            'wh_loss': wh_loss,
            'off_loss': off_loss,
            'id_loss': reid_loss
        }
        return loss, loss_stats
Exemplo n.º 9
0
    def loss(self, stuff, cur_labels, ref_labels):
        """Build the per-stage classification, quality and similarity losses.

        Args:
            stuff: 5-tuple ``(affs, sims, qlts, sfeats, cur_feats)``; only
                ``sims``, ``qlts`` and ``sfeats`` are used here, each indexed
                by stage.
            cur_labels: labels of the current items; values ``> 0`` are
                treated as positive and the float-cast labels are also used
                as the binary target for the ``pre_cls`` prediction.
            ref_labels: labels of the reference items; values ``> 0`` are
                treated as positive.

        Returns:
            ``(loss, losses)`` where ``losses`` maps 'pre_cls.<i>.loss',
            'qlt.<i>.loss' and 'sim.<i>.loss' for every stage ``i`` to its
            value and ``loss`` is half the sum of all entries.

        Note: pair sampling uses ``torch.randperm``, so the result depends on
        the global torch RNG state.
        """
        # `affs` and `cur_feats` are unpacked but not used in this method.
        affs, sims, qlts, sfeats, cur_feats = stuff
        losses = {}
        # Pairwise match matrix: entry [i, j] is True when the i-th current
        # label equals the j-th reference label.
        sim_target = cur_labels.view(-1, 1) == ref_labels.view(1, -1)
        # weight_1 marks rows whose current label is positive, weight_2 marks
        # columns whose reference label is positive.
        # NOTE(review): the row/column indexing below assumes cur_labels and
        # ref_labels are 1-D - confirm against the caller.
        weight_1 = sim_target.new_zeros(sim_target.shape).int()
        weight_2 = sim_target.new_zeros(sim_target.shape).int()
        pos_cur = torch.nonzero(cur_labels > 0)
        pos_ref = torch.nonzero(ref_labels > 0)
        weight_1[pos_cur, :] = 1
        weight_2[:, pos_ref] = 1
        # Flattened pair code: 2 = both labels positive, 1 = exactly one
        # positive, 0 = neither positive.
        weight = (weight_1 + weight_2).flatten()
        pos_pos_inds = torch.nonzero(weight == 2)
        pos_neg_inds = torch.nonzero(weight == 1)
        neg_neg_inds = torch.nonzero(weight == 0)
        # Sampling budget relative to the number of positive/positive pairs:
        # up to twice as many mixed pairs and up to as many negative pairs.
        pos_pos_cnt = pos_pos_inds.size(0)
        pos_neg_cnt = pos_pos_cnt * 2
        neg_neg_cnt = pos_pos_cnt * 1
        # Randomly pick mixed pairs and promote them to code 2 (= selected).
        rand_inds = torch.randperm(pos_neg_inds.size(0))
        rand_inds = pos_neg_inds[rand_inds[:pos_neg_cnt]]
        weight[rand_inds] = 2
        # Same for negative/negative pairs; neg_neg_inds was computed before
        # the promotion above, so the two selections do not interfere.
        rand_inds = torch.randperm(neg_neg_inds.size(0))
        rand_inds = neg_neg_inds[rand_inds[:neg_neg_cnt]]
        weight[rand_inds] = 2

        # Every pair with code 2 now takes part in the similarity loss.
        sim_inds = torch.nonzero(weight == 2)
        # Float targets of the selected pairs (1.0 where labels are equal).
        sim_target = sim_target.reshape(-1).float()
        sim_target = sim_target[sim_inds]
        for i in range(self.rel.stages):
            # Stage similarity: summed over dim 1, flattened, restricted to
            # the selected pairs, then passed through `_sigmoid`.
            sim = sims[i].sum(dim=1).reshape(-1)[sim_inds]
            sim = _sigmoid(sim)
            qlt = _sigmoid(qlts[i])
            sfeat = sfeats[i]
            # The same `self.pre_cls` is applied to every stage's features.
            pre_cls_pred = _sigmoid(self.pre_cls(sfeat))
            # Classification loss against the float-cast current labels,
            # averaged over stages and weighted by 0.5.
            losses['pre_cls.%d.loss' % i] = self.binary_loss(
                pre_cls_pred.flatten(),
                cur_labels.flatten().float()) / self.rel.stages * 0.5
            # Quality target: 1 - |label - prediction|, detached so that no
            # gradient flows into the classifier through this target.
            qlt_target = (1 - torch.abs(cur_labels.flatten().float() -
                                        pre_cls_pred.flatten())).detach()
            losses['qlt.%d.loss' % i] = self.binary_loss(
                qlt.flatten(), qlt_target.flatten()) / self.rel.stages * 0.5
            # Similarity loss on the selected pairs, averaged over stages.
            losses['sim.%d.loss' % i] = self.binary_loss(
                sim.flatten(), sim_target.flatten()) / self.rel.stages
        # Total: half the sum of every per-stage entry.
        loss = 0
        for v in losses.values():
            loss += v
        loss = loss * 0.5
        return loss, losses
    def forward(self, outputs, batch):
        """Compute the weighted loss over the detection and keypoint heads.

        Args:
            outputs: per-stack sequence of head dictionaries ('hm', 'hps',
                'wh', 'reg', 'hp_offset', 'hm_hp').
            batch: target dictionary ('hm', 'hps', 'hps_mask', 'ind', 'wh',
                'reg', 'reg_mask', 'hp_offset', 'hp_mask', 'hp_ind', 'hm_hp',
                and the 'dense_hps*' entries when ``opt.dense_hp`` is set).

        Returns:
            ``(loss, loss_stats)`` with 'loss', 'hm_loss', 'hp_loss',
            'hm_hp_loss', 'hp_offset_loss', 'wh_loss' and 'off_loss'.

        Side effect: the head dictionaries are modified in place (sigmoid on
        'hm', optionally on 'hm_hp', and the oracle substitutions).
        """
        cfg = self.opt
        n_stacks = cfg.num_stacks

        hm_loss = wh_loss = off_loss = 0
        hp_loss = hm_hp_loss = hp_offset_loss = 0

        for stack_idx in range(n_stacks):
            head = outputs[stack_idx]
            head['hm'] = _sigmoid(head['hm'])
            if cfg.hm_hp and not cfg.mse_loss:
                head['hm_hp'] = _sigmoid(head['hm_hp'])

            # Oracle switches: overwrite a prediction with a map derived from
            # the targets before the losses are evaluated.
            if cfg.eval_oracle_hmhp:
                head['hm_hp'] = batch['hm_hp']
            if cfg.eval_oracle_hm:
                head['hm'] = batch['hm']
            if cfg.eval_oracle_kps:
                if cfg.dense_hp:
                    head['hps'] = batch['dense_hps']
                else:
                    kps_map = gen_oracle_map(
                        batch['hps'].detach().cpu().numpy(),
                        batch['ind'].detach().cpu().numpy(),
                        cfg.output_res,
                        cfg.output_res)
                    head['hps'] = torch.from_numpy(kps_map).to(cfg.device)
            if cfg.eval_oracle_hp_offset:
                hp_off_map = gen_oracle_map(
                    batch['hp_offset'].detach().cpu().numpy(),
                    batch['hp_ind'].detach().cpu().numpy(),
                    cfg.output_res,
                    cfg.output_res)
                head['hp_offset'] = torch.from_numpy(hp_off_map).to(cfg.device)

            hm_loss += self.crit(head['hm'], batch['hm']) / n_stacks

            # Keypoint regression: dense variant or sparse variant.
            if cfg.dense_hp:
                dense_mask = batch['dense_hps_mask']
                mask_weight = dense_mask.sum() + 1e-4
                dense_term = self.crit_kp(head['hps'] * dense_mask,
                                          batch['dense_hps'] * dense_mask)
                hp_loss += (dense_term / mask_weight) / n_stacks
            else:
                sparse_term = self.crit_kp(head['hps'], batch['hps_mask'],
                                           batch['ind'], batch['hps'])
                hp_loss += sparse_term / n_stacks

            if cfg.wh_weight > 0:
                wh_term = self.crit_reg(head['wh'], batch['reg_mask'],
                                        batch['ind'], batch['wh'])
                wh_loss += wh_term / n_stacks

            if cfg.reg_offset and cfg.off_weight > 0:
                off_term = self.crit_reg(head['reg'], batch['reg_mask'],
                                         batch['ind'], batch['reg'])
                off_loss += off_term / n_stacks

            if cfg.reg_hp_offset and cfg.off_weight > 0:
                hp_off_term = self.crit_reg(head['hp_offset'], batch['hp_mask'],
                                            batch['hp_ind'], batch['hp_offset'])
                hp_offset_loss += hp_off_term / n_stacks

            if cfg.hm_hp and cfg.hm_hp_weight > 0:
                hm_hp_term = self.crit_hm_hp(head['hm_hp'], batch['hm_hp'])
                hm_hp_loss += hm_hp_term / n_stacks

        # The keypoint offset term shares `off_weight` with the box offset.
        loss = (cfg.hm_weight * hm_loss + cfg.wh_weight * wh_loss
                + cfg.off_weight * off_loss + cfg.hp_weight * hp_loss
                + cfg.hm_hp_weight * hm_hp_loss
                + cfg.off_weight * hp_offset_loss)

        loss_stats = dict(
            loss=loss,
            hm_loss=hm_loss,
            hp_loss=hp_loss,
            hm_hp_loss=hm_hp_loss,
            hp_offset_loss=hp_offset_loss,
            wh_loss=wh_loss,
            off_loss=off_loss,
        )
        return loss, loss_stats