예제 #1
0
파일: hoidet.py 프로젝트: samshin7/PPDM
 def __init__(self, opt):
     """Create the loss criteria chosen by the flags on ``opt``.

     Args:
         opt: options object; ``mse_loss``, ``reg_loss``, ``dense_wh``,
             ``norm_wh`` and ``cat_spec_wh`` are read here, and the object
             itself is kept as ``self.opt``.
     """
     super(HoidetLoss, self).__init__()
     self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

     # 'l1' and 'sl1' are the only recognised names; anything else leaves
     # the regression criterion unset (None).
     if opt.reg_loss == 'l1':
         reg_criterion = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         reg_criterion = RegLoss()
     else:
         reg_criterion = None
     self.crit_reg = reg_criterion

     # First matching flag wins; with none set, crit_wh reuses the very
     # same object as crit_reg.
     if opt.dense_wh:
         wh_criterion = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         wh_criterion = NormRegL1Loss()
     elif opt.cat_spec_wh:
         wh_criterion = RegWeightedL1Loss()
     else:
         wh_criterion = reg_criterion
     self.crit_wh = wh_criterion
     self.opt = opt
예제 #2
0
 def __init__(self, opt):
     """Instantiate the criteria used by this loss from ``opt``.

     Args:
         opt: options object; reads ``mse_loss``, ``reg_loss``,
             ``dense_wh``, ``norm_wh``, ``cat_spec_wh`` and
             ``seg_feat_channel``. Stored as ``self.opt``.
     """
     super(CtsegLoss, self).__init__()
     self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

     # Regression criterion: None unless reg_loss is 'l1' or 'sl1'.
     if opt.reg_loss == 'l1':
         reg_criterion = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         reg_criterion = RegLoss()
     else:
         reg_criterion = None
     self.crit_reg = reg_criterion

     # Flags are checked in priority order; the fallback shares crit_reg.
     if opt.dense_wh:
         wh_criterion = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         wh_criterion = NormRegL1Loss()
     elif opt.cat_spec_wh:
         wh_criterion = RegWeightedL1Loss()
     else:
         wh_criterion = reg_criterion
     self.crit_wh = wh_criterion

     # Mask criterion is constructed with opt.seg_feat_channel.
     self.crit_mask = FastDiceLoss(opt.seg_feat_channel)
     self.opt = opt
예제 #3
0
 def __init__(self, opt):
     """Set up the criteria for this loss according to ``opt``.

     Args:
         opt: options object; reads ``mse_loss``, ``reg_loss``,
             ``dense_wh``, ``norm_wh`` and ``cat_spec_wh``. Stored as
             ``self.opt``.
     """
     super(CtdetLoss_doublehm, self).__init__()
     self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

     # Any reg_loss value other than 'l1' / 'sl1' yields None.
     self.crit_reg = (
         RegL1Loss() if opt.reg_loss == 'l1'
         else RegLoss() if opt.reg_loss == 'sl1'
         else None
     )

     # Priority: dense_wh, then norm_wh, then cat_spec_wh; otherwise the
     # crit_reg object is reused.
     self.crit_wh = (
         torch.nn.L1Loss(reduction='sum') if opt.dense_wh
         else NormRegL1Loss() if opt.norm_wh
         else RegWeightedL1Loss() if opt.cat_spec_wh
         else self.crit_reg
     )

     # Plain L1 criterion with default arguments.
     self.crit_reconstruct_img = torch.nn.L1Loss()
     self.opt = opt
     self.test_reconstruct_conflict_with_class = True
예제 #4
0
 def __init__(self, opt):
     """Build the loss criteria selected by the flags on ``opt``.

     Args:
         opt: options object; reads ``mse_loss``, ``reg_loss``,
             ``dense_wh``, ``norm_wh`` and ``cat_spec_wh``. Stored as
             ``self.opt``.
     """
     super(CtdetLoss, self).__init__()
     self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()
     self.crit_centerness = FocalLoss()

     # Only 'l1' and 'sl1' are recognised; any other value leaves the
     # regression criterion as None.
     if opt.reg_loss == 'l1':
         reg_criterion = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         reg_criterion = RegLoss()
     else:
         reg_criterion = None
     self.crit_reg = reg_criterion

     # First matching flag wins; with none set, crit_wh is the same object
     # as crit_reg.
     if opt.dense_wh:
         wh_criterion = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         wh_criterion = NormRegL1Loss()
     elif opt.cat_spec_wh:
         wh_criterion = RegWeightedL1Loss()
     else:
         wh_criterion = reg_criterion
     self.crit_wh = wh_criterion

     # reduction='sum' is the documented replacement for the deprecated
     # size_average=False and matches the L1Loss call above.
     self.crit_scale = torch.nn.SmoothL1Loss(reduction='sum')
     self.opt = opt
예제 #5
0
 def __init__(self, opt):
     """Instantiate the criteria for this loss from ``opt``.

     Args:
         opt: options object; reads ``mse_loss``, ``reg_loss``,
             ``dense_wh``, ``norm_wh`` and ``cat_spec_wh``. Stored as
             ``self.opt``.
     """
     super(CtdetLossSpotNetVid, self).__init__()
     # BCE-with-logits criterion with default arguments (an earlier
     # revision apparently used torch.nn.MSELoss here).
     self.crit_seg = torch.nn.BCEWithLogitsLoss()
     self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

     # Unrecognised reg_loss values leave crit_reg as None.
     self.crit_reg = (
         RegL1Loss() if opt.reg_loss == 'l1'
         else RegLoss() if opt.reg_loss == 'sl1'
         else None
     )

     # Flags are tested in order; the final fallback reuses crit_reg.
     self.crit_wh = (
         torch.nn.L1Loss(reduction='sum') if opt.dense_wh
         else NormRegL1Loss() if opt.norm_wh
         else RegWeightedL1Loss() if opt.cat_spec_wh
         else self.crit_reg
     )
     self.opt = opt
예제 #6
0
 def __init__(self, opt):
     """Create the criteria and the learnable ``s_det`` parameter.

     Args:
         opt: options object; reads ``mse_loss``, ``reg_loss``,
             ``dense_wh``, ``norm_wh`` and ``cat_spec_wh``. Stored as
             ``self.opt``.
     """
     super(DetLoss, self).__init__()
     self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

     # None unless reg_loss is exactly 'l1' or 'sl1'.
     if opt.reg_loss == 'l1':
         reg_criterion = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         reg_criterion = RegLoss()
     else:
         reg_criterion = None
     self.crit_reg = reg_criterion

     # Priority order of the flags matters; fallback shares crit_reg.
     if opt.dense_wh:
         wh_criterion = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         wh_criterion = NormRegL1Loss()
     elif opt.cat_spec_wh:
         wh_criterion = RegWeightedL1Loss()
     else:
         wh_criterion = reg_criterion
     self.crit_wh = wh_criterion
     self.opt = opt

     # One-element learnable parameter initialised to -1.85.
     self.s_det = nn.Parameter(-1.85 * torch.ones(1))
예제 #7
0
 def __init__(self, opt):
     """Set up the criteria, with a separate ``crit_wh`` path for ``opt.mdn``.

     Args:
         opt: options object; reads ``reg_loss``, ``mdn``, ``dense_wh``,
             ``norm_wh`` and (only when ``mdn`` is falsy) ``cat_spec_wh``.
             Stored as ``self.opt``.
     """
     super(CtdetLoss, self).__init__()
     self.crit = FocalLoss()

     # None unless reg_loss is exactly 'l1' or 'sl1'.
     if opt.reg_loss == 'l1':
         reg_criterion = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         reg_criterion = RegLoss()
     else:
         reg_criterion = None
     self.crit_reg = reg_criterion

     if opt.mdn:
         # With mdn enabled, dense_wh / norm_wh leave crit_wh unset;
         # otherwise the th_mdn_loss_ind callable itself is stored.
         if opt.dense_wh or opt.norm_wh:
             wh_criterion = None
         else:
             wh_criterion = th_mdn_loss_ind
     elif opt.dense_wh:
         wh_criterion = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         wh_criterion = NormRegL1Loss()
     elif opt.cat_spec_wh:
         wh_criterion = RegWeightedL1Loss()
     else:
         # Fallback shares the crit_reg object.
         wh_criterion = reg_criterion
     self.crit_wh = wh_criterion
     self.opt = opt
예제 #8
0
 def __init__(self, opt):
   """Create the loss criteria chosen by the flags on ``opt``.

   Args:
     opt: options object; reads ``mse_loss``, ``reg_loss``, ``dense_wh``,
       ``norm_wh`` and ``cat_spec_wh``. Stored as ``self.opt``.
   """
   super(CtdetLoss, self).__init__()
   # Mean squared error (L2 loss) when opt.mse_loss is set, else focal loss.
   self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

   # Several forms of regression loss; None for unrecognised names.
   if opt.reg_loss == 'l1':
     reg_criterion = RegL1Loss()
   elif opt.reg_loss == 'sl1':
     reg_criterion = RegLoss()
   else:
     reg_criterion = None
   self.crit_reg = reg_criterion

   # Several forms of the wh loss; the fallback shares crit_reg.
   if opt.dense_wh:
     wh_criterion = torch.nn.L1Loss(reduction='sum')
   elif opt.norm_wh:
     wh_criterion = NormRegL1Loss()
   elif opt.cat_spec_wh:
     wh_criterion = RegWeightedL1Loss()
   else:
     wh_criterion = reg_criterion
   self.crit_wh = wh_criterion
   self.opt = opt
예제 #9
0
 def __init__(self, opt):
     """Build the criteria, including the optional ones gated by ``opt``.

     ``crit_h`` exists only when ``opt.hard_negative`` is truthy, and
     ``crit_reg_offset`` / ``bce`` only when one of the two match flags is.

     Args:
         opt: options object; reads ``mse_loss``, ``reg_loss``,
             ``dense_wh``, ``norm_wh``, ``cat_spec_wh``, ``hard_negative``,
             ``hm_rel_dcn25_i_casc_match`` and ``hm_rel_dcn25_i_match``.
             Stored as ``self.opt``.
     """
     super(HoidetLoss, self).__init__()
     self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

     # None unless reg_loss is exactly 'l1' or 'sl1'.
     if opt.reg_loss == 'l1':
         reg_criterion = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         reg_criterion = RegLoss()
     else:
         reg_criterion = None
     self.crit_reg = reg_criterion

     # First matching flag wins; the fallback shares crit_reg.
     if opt.dense_wh:
         wh_criterion = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         wh_criterion = NormRegL1Loss()
     elif opt.cat_spec_wh:
         wh_criterion = RegWeightedL1Loss()
     else:
         wh_criterion = reg_criterion
     self.crit_wh = wh_criterion

     if opt.hard_negative:
         self.crit_h = FocalLossHardNeg(7)

     uses_matching = opt.hm_rel_dcn25_i_casc_match or opt.hm_rel_dcn25_i_match
     if uses_matching:
         self.crit_reg_offset = RegL1LossOffset()
         self.bce = torch.nn.BCELoss()
     self.opt = opt
예제 #10
0
    def __init__(self, opt):
        """Create the criteria for this loss.

        The inline role notes are translated from the original author's
        comments.

        Args:
            opt: options object; reads ``reg_loss``, ``dense_wh``,
                ``norm_wh``, ``cat_spec_wh``, ``mse_loss`` and ``dense_hp``.
                Stored as ``self.opt``.
        """
        super(MultiKPSLoss, self).__init__()
        # Center-point heatmap.
        self.crit = FocalLoss()

        # Center-point regression; None for unrecognised reg_loss values.
        if opt.reg_loss == 'l1':
            reg_criterion = RegL1Loss()
        elif opt.reg_loss == 'sl1':
            reg_criterion = RegLoss()
        else:
            reg_criterion = None
        self.crit_reg = reg_criterion

        # Width/height regression; the fallback shares crit_reg.
        if opt.dense_wh:
            wh_criterion = torch.nn.L1Loss(reduction='sum')
        elif opt.norm_wh:
            wh_criterion = NormRegL1Loss()
        elif opt.cat_spec_wh:
            wh_criterion = RegWeightedL1Loss()
        else:
            wh_criterion = reg_criterion
        self.crit_wh = wh_criterion

        # Keypoint heatmap.
        self.crit_hm_hp = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

        # Keypoint regression.
        if opt.dense_hp:
            kp_criterion = torch.nn.L1Loss(reduction='sum')
        else:
            kp_criterion = RegWeightedL1Loss()
        self.crit_kp = kp_criterion

        self.opt = opt
예제 #11
0
 def __init__(self, opt):
     """Instantiate the loss criteria from ``opt``.

     The defaults mentioned in the inline comments come from the original
     author's notes; the option definitions are not visible here.

     Args:
         opt: options object; reads ``mse_loss``, ``reg_loss``,
             ``dense_wh``, ``norm_wh`` and ``cat_spec_wh``. Stored as
             ``self.opt``.
     """
     super(CtdetLoss, self).__init__()
     # For 2D object detection the classification loss is focal loss;
     # opt.mse_loss reportedly defaults to False.
     self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

     # For 2D object detection the regression loss is L1;
     # opt.reg_loss reportedly defaults to 'l1'.
     self.crit_reg = (
         RegL1Loss() if opt.reg_loss == 'l1'
         else RegLoss() if opt.reg_loss == 'sl1'
         else None
     )

     # For 2D object detection the offset loss is L1;
     # opt.dense_wh reportedly defaults to False.
     self.crit_wh = (
         torch.nn.L1Loss(reduction='sum') if opt.dense_wh
         else NormRegL1Loss() if opt.norm_wh
         else RegWeightedL1Loss() if opt.cat_spec_wh
         else self.crit_reg
     )
     self.opt = opt
예제 #12
0
 def __init__(self, opt):
     """Create the criteria, the ID classifier and the learnable scalars.

     Args:
         opt: options object; reads ``mse_loss``, ``reg_loss``,
             ``dense_wh``, ``norm_wh``, ``cat_spec_wh``, ``reid_dim`` and
             ``nID``. Stored as ``self.opt``.
     """
     super(MotLoss, self).__init__()
     self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

     # None unless reg_loss is exactly 'l1' or 'sl1'.
     if opt.reg_loss == 'l1':
         reg_criterion = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         reg_criterion = RegLoss()
     else:
         reg_criterion = None
     self.crit_reg = reg_criterion

     # First matching flag wins; the fallback shares crit_reg.
     if opt.dense_wh:
         wh_criterion = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         wh_criterion = NormRegL1Loss()
     elif opt.cat_spec_wh:
         wh_criterion = RegWeightedL1Loss()
     else:
         wh_criterion = reg_criterion
     self.crit_wh = wh_criterion

     self.opt = opt
     self.emb_dim = opt.reid_dim
     self.nID = opt.nID

     # Linear layer mapping emb_dim features to nID outputs, trained with
     # cross-entropy that skips targets equal to -1.
     self.classifier = nn.Linear(self.emb_dim, self.nID)
     self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)

     # NOTE(review): math.log raises ValueError when nID <= 1 - confirm
     # callers always provide nID >= 2.
     self.emb_scale = math.sqrt(2) * math.log(self.nID - 1)

     # One-element learnable parameters with fixed initial values.
     self.s_det = nn.Parameter(-1.85 * torch.ones(1))
     self.s_id = nn.Parameter(-1.05 * torch.ones(1))
예제 #13
0
 def __init__(self, opt):
   """Set up the criteria, choosing the center loss variant by ``opt.eq1``.

   Prints which branch of ``opt.eq1`` was taken.

   Args:
     opt: options object; reads ``mse_loss``, ``eq1``, ``reg_loss``,
       ``dense_wh``, ``norm_wh`` and ``cat_spec_wh``. Stored as
       ``self.opt``.
   """
   super(CtdetLoss_NFS, self).__init__()
   self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

   if not opt.eq1:
     print("eq1 is False")
     self.centerloss = CenterLoss_gt_cuda()
   else:
     print("eq1 is True")
     self.centerloss = CenterLoss_gt_eq1_cuda()

   # None unless reg_loss is exactly 'l1' or 'sl1'.
   if opt.reg_loss == 'l1':
     reg_criterion = RegL1Loss()
   elif opt.reg_loss == 'sl1':
     reg_criterion = RegLoss()
   else:
     reg_criterion = None
   self.crit_reg = reg_criterion

   # First matching flag wins; the fallback shares crit_reg.
   if opt.dense_wh:
     wh_criterion = torch.nn.L1Loss(reduction='sum')
   elif opt.norm_wh:
     wh_criterion = NormRegL1Loss()
   elif opt.cat_spec_wh:
     wh_criterion = RegWeightedL1Loss()
   else:
     wh_criterion = reg_criterion
   self.crit_wh = wh_criterion
   self.opt = opt
예제 #14
0
    def __init__(self, opt):
        """Define the loss criteria selected by the flags on ``opt``.

        Args:
            opt: options object; reads ``mse_loss``, ``reg_loss``,
                ``dense_wh``, ``norm_wh`` and ``cat_spec_wh``. Stored as
                ``self.opt``.
        """
        super(CtdetLoss, self).__init__()

        # hm criterion: MSE when opt.mse_loss is set, focal loss otherwise.
        self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

        # Regression criterion keyed by opt.reg_loss ('l1' or 'sl1');
        # None for any other value.
        if opt.reg_loss == 'l1':
            reg_criterion = RegL1Loss()
        elif opt.reg_loss == 'sl1':
            reg_criterion = RegLoss()
        else:
            reg_criterion = None
        self.crit_reg = reg_criterion

        # wh criterion: first matching flag wins, fallback shares crit_reg.
        if opt.dense_wh:
            wh_criterion = torch.nn.L1Loss(reduction='sum')
        elif opt.norm_wh:
            wh_criterion = NormRegL1Loss()
        elif opt.cat_spec_wh:
            wh_criterion = RegWeightedL1Loss()
        else:
            wh_criterion = reg_criterion
        self.crit_wh = wh_criterion
        self.opt = opt
예제 #15
0
 def __init__(self, opt):
     """Create the criteria, the ID classifier and the learnable scalars.

     When ``opt.id_loss == 'focal'`` the classifier is re-initialised:
     weights are drawn from a normal distribution with std 0.01 and the
     bias is set to ``-log((1 - p) / p)`` with ``p = 0.01``.

     Args:
         opt: options object; reads ``mse_loss``, ``reg_loss``,
             ``dense_wh``, ``norm_wh``, ``cat_spec_wh``, ``reid_dim``,
             ``nID`` and ``id_loss``. Stored as ``self.opt``.
     """
     super(MotLoss, self).__init__()
     self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

     # None unless reg_loss is exactly 'l1' or 'sl1'.
     if opt.reg_loss == 'l1':
         reg_criterion = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         reg_criterion = RegLoss()
     else:
         reg_criterion = None
     self.crit_reg = reg_criterion

     # First matching flag wins; the fallback shares crit_reg.
     if opt.dense_wh:
         wh_criterion = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         wh_criterion = NormRegL1Loss()
     elif opt.cat_spec_wh:
         wh_criterion = RegWeightedL1Loss()
     else:
         wh_criterion = reg_criterion
     self.crit_wh = wh_criterion

     self.opt = opt
     self.emb_dim = opt.reid_dim
     self.nID = opt.nID
     self.classifier = nn.Linear(self.emb_dim, self.nID)

     if opt.id_loss == 'focal':
         torch.nn.init.normal_(self.classifier.weight, std=0.01)
         prior = 0.01
         focal_bias = -math.log((1 - prior) / prior)
         torch.nn.init.constant_(self.classifier.bias, focal_bias)

     # Cross-entropy that skips targets equal to -1.
     self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)

     # NOTE(review): math.log raises ValueError when nID <= 1 - confirm
     # callers always provide nID >= 2.
     self.emb_scale = math.sqrt(2) * math.log(self.nID - 1)

     # One-element learnable parameters with fixed initial values.
     self.s_det = nn.Parameter(-1.85 * torch.ones(1))
     self.s_id = nn.Parameter(-1.05 * torch.ones(1))
예제 #16
0
    def __init__(self, opt, loss_states):
        """Create supervised and self-supervised criteria for this loss.

        Args:
            opt: options object; reads ``reid_dim``, ``nID``, ``mse_loss``,
                ``reg_loss``, ``dense_wh``, ``norm_wh``, ``cat_spec_wh``,
                ``device``, ``unsup_loss`` and ``unsup``. Stored as
                ``self.opt``.
            loss_states: stored unchanged as ``self.loss_states``.

        Raises:
            ValueError: if ``opt.unsup`` is truthy and ``opt.unsup_loss`` is
                not one of 'nt_xent', 'triplet_all' or 'triplet_hard'.
        """
        super(MotLoss, self).__init__()
        self.opt = opt
        self.loss_states = loss_states
        self.emb_dim = opt.reid_dim
        self.nID = opt.nID

        # Heatmap criterion.
        self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

        # Offset criterion; None for unrecognised reg_loss values.
        if opt.reg_loss == 'l1':
            reg_criterion = RegL1Loss()
        elif opt.reg_loss == 'sl1':
            reg_criterion = RegLoss()
        else:
            reg_criterion = None
        self.crit_reg = reg_criterion

        # Object-size criterion; the fallback shares crit_reg.
        if opt.dense_wh:
            wh_criterion = torch.nn.L1Loss(reduction='sum')
        elif opt.norm_wh:
            wh_criterion = NormRegL1Loss()
        elif opt.cat_spec_wh:
            wh_criterion = RegWeightedL1Loss()
        else:
            wh_criterion = reg_criterion
        self.crit_wh = wh_criterion

        # Supervised object-ID criterion (targets equal to -1 are skipped)
        # and the fully connected layer that predicts the IDs.
        self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)
        self.classifier = nn.Linear(self.emb_dim, self.nID)

        # Self-supervised embedding criterion chosen by opt.unsup_loss.
        if opt.unsup_loss == 'nt_xent':
            self_sup = NTXentLoss(opt.device, 0.5)
        elif opt.unsup_loss == 'triplet_all':
            self_sup = TripletLoss(opt.device, 'batch_all', 0.5)
        elif opt.unsup_loss == 'triplet_hard':
            self_sup = TripletLoss(opt.device, 'batch_hard', 0.5)
        else:
            self_sup = None
        self.SelfSupLoss = self_sup

        # An unsupported name is only an error when opt.unsup is requested.
        if opt.unsup and self_sup is None:
            raise ValueError(
                '{} is not a supported self-supervised loss. '.format(opt.unsup_loss)
                + 'Choose nt_xent, triplet_all, or triplet_hard'
            )

        # NOTE(review): math.log raises ValueError when nID <= 1 - confirm
        # callers always provide nID >= 2.
        self.emb_scale = math.sqrt(2) * math.log(self.nID - 1)

        # One-element learnable parameters with fixed initial values.
        self.s_det = nn.Parameter(-1.85 * torch.ones(1))
        self.s_id = nn.Parameter(-1.05 * torch.ones(1))