Ejemplo n.º 1
0
    def losses(self, outputs):
        """Gather the per-branch losses of a multi-branch re-id head.

        Args:
            outputs: 3-tuple ``(logits, feats, targets)``. ``logits`` is indexed
                0-7 and ``feats`` 0-4 below; ``targets`` is forwarded unchanged
                to every loss call.

        Returns:
            dict: the merged results of ``reid_losses`` for the five branches
            that use both logits and features (prefixes b1_, b2_, b3_, b21_,
            b31_), followed by one cross-entropy entry for each logits-only
            branch (prefixes b22_, b32_, b33_).
        """
        logits, feats, targets = outputs
        loss_dict = {}

        # (logits index, feats index, key prefix) for branches scored with both
        # logits and features. logits[4] belongs to the logits-only 'b22_'
        # branch below, so from 'b31_' on the two indices no longer match.
        # NOTE(review): presumably the head emits no feature for b22/b32/b33 —
        # confirm against the model's forward().
        paired_branches = (
            (0, 0, 'b1_'),
            (1, 1, 'b2_'),
            (2, 2, 'b3_'),
            (3, 3, 'b21_'),
            (5, 4, 'b31_'),
        )
        for logit_idx, feat_idx, prefix in paired_branches:
            loss_dict.update(
                reid_losses(self._cfg, logits[logit_idx], feats[feat_idx],
                            targets, prefix))

        # Logits-only branches: cross-entropy called with no feature argument.
        # All three are evaluated before any of them is re-keyed.
        ce_results = [
            (prefix, CrossEntropyLoss(self._cfg)(logits[logit_idx], None,
                                                 targets))
            for logit_idx, prefix in ((4, 'b22_'), (6, 'b32_'), (7, 'b33_'))
        ]

        # Each CE result is used as a mapping: only its first entry is kept,
        # re-keyed with the branch prefix (an empty result raises IndexError).
        named_ce_loss = {}
        for prefix, result in ce_results:
            first_key = list(result)[0]
            named_ce_loss[prefix + first_key] = list(result.values())[0]
        loss_dict.update(named_ce_loss)
        return loss_dict
Ejemplo n.º 2
0
    def losses(self, outputs, gt_labels):
        """Compute every enabled loss family over the head's branch outputs.

        Args:
            outputs: 14-tuple unpacked, in order, as eight branch logits
                (b1, b2, b21, b22, b3, b31, b32, b33), five pooled features
                (b1, b2, b3, b22, b33) and ``pred_class_logits``.
            gt_labels: labels forwarded unchanged to every loss call.

        Returns:
            dict mapping ``loss_cls_*`` / ``loss_triplet_*`` / ``loss_npair_*``
            keys to loss values. A family is populated only when its name
            ("CrossEntropyLoss", "TripletLoss", "NpairLoss") appears in
            ``self._cfg.MODEL.LOSSES.NAME``.

        Raises:
            ValueError: if ``outputs`` does not unpack into exactly 14 items.
        """
        (b1_logits, b2_logits, b21_logits, b22_logits,
         b3_logits, b31_logits, b32_logits, b33_logits,
         b1_pool_feat, b2_pool_feat, b3_pool_feat, b22_pool_feat, b33_pool_feat,
         pred_class_logits) = outputs

        enabled_losses = self._cfg.MODEL.LOSSES.NAME

        # Classification terms: one per branch that produces logits.
        cls_terms = (
            ('loss_cls_b1', b1_logits),
            ('loss_cls_b2', b2_logits),
            ('loss_cls_b21', b21_logits),
            ('loss_cls_b22', b22_logits),
            ('loss_cls_b3', b3_logits),
            ('loss_cls_b31', b31_logits),
            ('loss_cls_b32', b32_logits),
            ('loss_cls_b33', b33_logits),
        )
        # Metric-learning terms: only these five branches expose pooled features.
        triplet_terms = (
            ('loss_triplet_b1', b1_pool_feat),
            ('loss_triplet_b2', b2_pool_feat),
            ('loss_triplet_b3', b3_pool_feat),
            ('loss_triplet_b22', b22_pool_feat),
            ('loss_triplet_b33', b33_pool_feat),
        )
        npair_terms = (
            ('loss_npair_b1', b1_pool_feat),
            ('loss_npair_b2', b2_pool_feat),
            ('loss_npair_b3', b3_pool_feat),
            ('loss_npair_b22', b22_pool_feat),
            ('loss_npair_b33', b33_pool_feat),
        )

        result = {}

        if "CrossEntropyLoss" in enabled_losses:
            for key, branch_logits in cls_terms:
                result[key] = CrossEntropyLoss(self._cfg)(branch_logits,
                                                          gt_labels)
            # Accuracy is logged once, from a detached copy of the overall
            # prediction logits, after every classification term is computed.
            CrossEntropyLoss.log_accuracy(pred_class_logits.detach(),
                                          gt_labels)

        if "TripletLoss" in enabled_losses:
            for key, pool_feat in triplet_terms:
                result[key] = TripletLoss(self._cfg)(pool_feat, gt_labels)

        if "NpairLoss" in enabled_losses:
            for key, pool_feat in npair_terms:
                result[key] = NpairLoss(self._cfg)(pool_feat, gt_labels)

        return result
Ejemplo n.º 3
0
    def losses(self, b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits,
               b1_pool_feat, b2_pool_feat, b3_pool_feat, b22_pool_feat, b33_pool_feat, gt_labels):
        """Compute the enabled classification and triplet losses per branch.

        Args:
            b*_logits: logits of the eight branches, each scored with
                cross-entropy when "CrossEntropyLoss" is enabled.
            b*_pool_feat: pooled features of the five branches that expose one,
                each scored with a triplet loss when "TripletLoss" is enabled.
            gt_labels: labels forwarded unchanged to every loss call.

        Returns:
            dict mapping ``loss_cls_*`` and ``loss_triplet_*`` keys to loss
            values; families absent from ``self._cfg.MODEL.LOSSES.NAME`` are
            omitted.
        """
        enabled_losses = self._cfg.MODEL.LOSSES.NAME

        cls_terms = (
            ('loss_cls_b1', b1_logits),
            ('loss_cls_b2', b2_logits),
            ('loss_cls_b21', b21_logits),
            ('loss_cls_b22', b22_logits),
            ('loss_cls_b3', b3_logits),
            ('loss_cls_b31', b31_logits),
            ('loss_cls_b32', b32_logits),
            ('loss_cls_b33', b33_logits),
        )
        triplet_terms = (
            ('loss_triplet_b1', b1_pool_feat),
            ('loss_triplet_b2', b2_pool_feat),
            ('loss_triplet_b3', b3_pool_feat),
            ('loss_triplet_b22', b22_pool_feat),
            ('loss_triplet_b33', b33_pool_feat),
        )

        result = {}
        if "CrossEntropyLoss" in enabled_losses:
            for key, branch_logits in cls_terms:
                result[key] = CrossEntropyLoss(self._cfg)(branch_logits, gt_labels)

        if "TripletLoss" in enabled_losses:
            for key, pool_feat in triplet_terms:
                result[key] = TripletLoss(self._cfg)(pool_feat, gt_labels)

        return result