Example #1
0
 def __init__(self, opt):
     """Build the CtdetLoss criteria from the options object ``opt``.

     ``crit_reg`` is ``None`` when ``opt.reg_loss`` is neither ``'l1'`` nor
     ``'sl1'``. When none of ``opt.dense_wh`` / ``opt.norm_wh`` /
     ``opt.cat_spec_wh`` is set, ``crit_wh`` is that very same object.
     """
     super(CtdetLoss, self).__init__()
     if opt.mse_loss:
         self.crit = torch.nn.MSELoss()
     else:
         self.crit = FocalLoss()

     if opt.reg_loss == 'l1':
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     # Flags are checked in priority order: dense_wh, norm_wh, cat_spec_wh.
     if opt.dense_wh:
         self.crit_wh = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         self.crit_wh = NormRegL1Loss()
     elif opt.cat_spec_wh:
         self.crit_wh = RegWeightedL1Loss()
     else:
         self.crit_wh = self.crit_reg  # shared object, may be None

     self.crit_angle = RegL1Loss()
     self.opt = opt
Example #2
0
    def __init__(self, opt):
        """Build ``crit`` and ``crit_reg`` from the options object ``opt``.

        ``crit`` is ``torch.nn.MSELoss`` when ``opt.mse_loss`` is truthy,
        otherwise ``FocalLoss``. ``crit_reg`` is ``RegL1Loss`` for ``'l1'``,
        ``RegLoss`` for ``'sl1'``, and ``None`` for any other ``opt.reg_loss``.
        """
        super(CtdetLoss, self).__init__()
        if opt.mse_loss:
            self.crit = torch.nn.MSELoss()
        else:
            self.crit = FocalLoss()

        if opt.reg_loss == 'l1':
            self.crit_reg = RegL1Loss()
        elif opt.reg_loss == 'sl1':
            self.crit_reg = RegLoss()
        else:
            self.crit_reg = None

        self.opt = opt
Example #3
0
 def __init__(self, opt):
     """Build the detection and re-ID criteria for MotLoss from ``opt``.

     Args:
         opt: options object. Reads ``mse_loss``, ``reg_loss``, ``dense_wh``,
             ``norm_wh``, ``cat_spec_wh``, ``reid_dim`` and ``nID``.

     ``crit_reg`` is ``None`` when ``opt.reg_loss`` is neither ``"l1"`` nor
     ``"sl1"``. ``crit_wh`` falls back to that same object when no width/height
     flag is set.
     """
     super(MotLoss, self).__init__()
     self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()
     if opt.reg_loss == "l1":
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == "sl1":
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None
     if opt.dense_wh:
         self.crit_wh = torch.nn.L1Loss(reduction="sum")
     elif opt.norm_wh:
         self.crit_wh = NormRegL1Loss()
     elif opt.cat_spec_wh:
         self.crit_wh = RegWeightedL1Loss()
     else:
         self.crit_wh = self.crit_reg
     self.opt = opt
     self.emb_dim = opt.reid_dim
     self.nID = opt.nID
     # Linear layer from emb_dim input features to nID outputs.
     self.classifier = nn.Linear(self.emb_dim, self.nID)
     # Targets equal to -1 are ignored by the identity loss.
     self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)
     # math.log(nID - 1) raises ValueError (math domain error) for nID <= 1,
     # so fall back to 0.0 rather than failing at construction time.
     self.emb_scale = math.sqrt(2) * math.log(self.nID - 1) if self.nID > 1 else 0.0
     # Learnable scalars initialised to -1.85 / -1.05; presumably the
     # detection / re-ID balancing weights used in forward() — confirm.
     self.s_det = nn.Parameter(-1.85 * torch.ones(1))
     self.s_id = nn.Parameter(-1.05 * torch.ones(1))
Example #4
0
    def __init__(self, opt):
        """Build the detection and re-ID criteria for the Paddle MotLoss.

        Args:
            opt: options object. Reads ``mse_loss``, ``reg_loss``, ``dense_wh``,
                ``norm_wh``, ``cat_spec_wh``, ``reid_dim``, ``nID`` and ``id_loss``.

        The ID classifier is constructed exactly once. With
        ``opt.id_loss == 'focal'`` it uses a Normal(std=0.01) weight initialiser
        and a constant bias of ``-log((1 - 0.01) / 0.01)``. Otherwise it uses
        Paddle's default initialisers with a bias.
        """
        super(MotLoss, self).__init__()
        self.crit = paddle.nn.MSELoss() if opt.mse_loss else FocalLoss()
        self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \
            RegLoss() if opt.reg_loss == 'sl1' else None
        # Falls back to the same object as crit_reg (may be None).
        self.crit_wh = paddle.nn.L1Loss(reduction='sum') if opt.dense_wh else \
            NormRegL1Loss() if opt.norm_wh else \
                RegWeightedL1Loss() if opt.cat_spec_wh else self.crit_reg
        self.opt = opt
        self.emb_dim = opt.reid_dim
        self.nID = opt.nID

        if opt.id_loss == 'focal':  # rarely used
            prior_prob = 0.01
            bias_value = -math.log((1 - prior_prob) / prior_prob)
            weight_attr = paddle.framework.ParamAttr(initializer=nn.initializer.Normal(std=0.01))
            bias_attr = paddle.framework.ParamAttr(initializer=nn.initializer.Constant(bias_value))
            self.classifier = nn.Linear(self.emb_dim, self.nID, weight_attr=weight_attr, bias_attr=bias_attr)
        else:
            self.classifier = nn.Linear(self.emb_dim, self.nID, bias_attr=True)
        self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)
        # math.log(nID - 1) raises ValueError (math domain error) for nID <= 1.
        self.emb_scale = math.sqrt(2) * math.log(self.nID - 1) if self.nID > 1 else 0.0
        # Learnable scalars initialised to -1.85 / -1.05; presumably the
        # detection / re-ID balancing weights used in forward() — confirm.
        self.s_det = paddle.create_parameter([1], dtype='float32', default_initializer=nn.initializer.Constant(value=-1.85))
        self.s_id = paddle.create_parameter([1], dtype='float32', default_initializer=nn.initializer.Constant(value=-1.05))
Example #5
0
 def __init__(self, opt):
     """Build the PoseLoss criteria from the options object ``opt``.

     ``crit_kp`` is a summed ``torch.nn.L1Loss`` when ``opt.dense_hp`` is
     truthy, otherwise ``RegWeightedL1Loss``. ``crit_reg`` is ``None`` unless
     ``opt.reg_loss`` is ``'l1'`` or ``'sl1'``.
     """
     super(PoseLoss, self).__init__()
     if opt.mse_loss:
         self.crit_hm_hp = torch.nn.MSELoss()
     else:
         self.crit_hm_hp = FocalLoss()

     if opt.dense_hp:
         self.crit_kp = torch.nn.L1Loss(reduction='sum')
     else:
         self.crit_kp = RegWeightedL1Loss()

     if opt.reg_loss == 'l1':
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     self.opt = opt
 def __init__(self, opt):
     """Build the CenterLandmarkLoss criteria from the options object ``opt``.

     ``crit`` is always ``FocalLoss``. ``crit_hm_hp`` is MSE or focal depending
     on ``opt.mse_loss``, and ``crit_kp`` depends on ``opt.dense_hp``.
     ``crit_reg`` is ``None`` unless ``opt.reg_loss`` is ``'l1'`` or ``'sl1'``.
     """
     super(CenterLandmarkLoss, self).__init__()
     self.crit = FocalLoss()

     if opt.mse_loss:
         self.crit_hm_hp = nn.MSELoss()
     else:
         self.crit_hm_hp = FocalLoss()

     if opt.dense_hp:
         # NOTE(review): the original author questioned why reduction='sum'
         # is used here; left unchanged.
         self.crit_kp = nn.L1Loss(reduction='sum')
     else:
         self.crit_kp = RegWeightedL1Loss()

     if opt.reg_loss == 'l1':
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     self.opt = opt
Example #7
0
 def __init__(self, cfg, local_rank):
     """Build the MultiPoseLoss criteria from the ``cfg.LOSS`` section.

     Args:
         cfg: config object. Reads ``LOSS.MSE_LOSS``, ``LOSS.DENSE_HP`` and
             ``LOSS.REG_LOSS``; stored as ``self.cfg``.
         local_rank: stored unchanged as ``self.local_rank``.
     """
     super(MultiPoseLoss, self).__init__()
     loss_cfg = cfg.LOSS
     self.crit = FocalLoss()

     if loss_cfg.MSE_LOSS:
         self.crit_hm_hp = torch.nn.MSELoss()
     else:
         self.crit_hm_hp = FocalLoss()

     if loss_cfg.DENSE_HP:
         self.crit_kp = torch.nn.L1Loss(reduction='sum')
     else:
         self.crit_kp = RegWeightedL1Loss()

     if loss_cfg.REG_LOSS == 'l1':
         self.crit_reg = RegL1Loss()
     elif loss_cfg.REG_LOSS == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     self.cfg = cfg
     self.local_rank = local_rank
Example #8
0
 def __init__(self, opt):
     """Build the DetLoss criteria from ``opt`` plus a learnable ``s_det``.

     ``crit_reg`` is ``None`` unless ``opt.reg_loss`` is ``'l1'`` or ``'sl1'``.
     ``crit_wh`` falls back to that same object when no width/height flag is set.
     """
     super(DetLoss, self).__init__()
     if opt.mse_loss:
         self.crit = torch.nn.MSELoss()
     else:
         self.crit = FocalLoss()

     if opt.reg_loss == 'l1':
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     if opt.dense_wh:
         self.crit_wh = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         self.crit_wh = NormRegL1Loss()
     elif opt.cat_spec_wh:
         self.crit_wh = RegWeightedL1Loss()
     else:
         self.crit_wh = self.crit_reg

     self.opt = opt
     # Single learnable scalar initialised to -1.85 (its use is in forward(),
     # not visible here).
     self.s_det = nn.Parameter(-1.85 * torch.ones(1))
Example #9
0
 def __init__(self, opt):
     """Build the MSPCtdetLoss criteria from the options object ``opt``.

     ``crit_reg`` is ``None`` unless ``opt.reg_loss`` is ``'l1'`` or ``'sl1'``.
     ``crit_wh`` falls back to that same object when no width/height flag is set.
     ``crit_scale`` is a summed Smooth-L1 loss.
     """
     super(MSPCtdetLoss, self).__init__()
     self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()
     self.crit_centerness = FocalLoss()
     self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \
               RegLoss() if opt.reg_loss == 'sl1' else None
     self.crit_wh = torch.nn.L1Loss(reduction='sum') if opt.dense_wh else \
               NormRegL1Loss() if opt.norm_wh else \
               RegWeightedL1Loss() if opt.cat_spec_wh else self.crit_reg
     # `size_average=False` is deprecated; it is exactly reduction='sum'.
     self.crit_scale = torch.nn.SmoothL1Loss(reduction='sum')
     self.opt = opt
Example #10
0
 def __init__(self, opt):
     """Build the CtdetLoss_doublehm criteria from the options object ``opt``.

     In addition to the standard heatmap/regression/size criteria, this adds an
     unreduced-by-default ``torch.nn.L1Loss`` for image reconstruction and sets
     the ``test_reconstruct_conflict_with_class`` flag to ``True``.
     """
     super(CtdetLoss_doublehm, self).__init__()
     if opt.mse_loss:
         self.crit = torch.nn.MSELoss()
     else:
         self.crit = FocalLoss()

     if opt.reg_loss == 'l1':
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     if opt.dense_wh:
         self.crit_wh = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         self.crit_wh = NormRegL1Loss()
     elif opt.cat_spec_wh:
         self.crit_wh = RegWeightedL1Loss()
     else:
         self.crit_wh = self.crit_reg  # shared object, may be None

     self.crit_reconstruct_img = torch.nn.L1Loss()
     self.opt = opt
     self.test_reconstruct_conflict_with_class = True
Example #11
0
 def __init__(self, opt):
     """Build the fixed FvdetLoss criteria.

     Unlike the configurable variants, ``opt`` is not consulted here: the
     criteria are always ``FocalLoss``, ``RegL1Loss`` and ``RegL1LossDepIOU``.
     ``opt`` is only stored as ``self.opt``.
     """
     super(FvdetLoss, self).__init__()
     self.crit = FocalLoss()
     # Always L1; opt.reg_loss is intentionally ignored in this variant.
     self.crit_reg = RegL1Loss()
     self.crit_dep_iou = RegL1LossDepIOU()
     self.opt = opt
Example #12
0
 def __init__(self, opt):
     """Build the CtdetLossSpotNetVid criteria from the options object ``opt``.

     Adds a ``BCEWithLogitsLoss`` segmentation criterion to the standard
     heatmap/regression/size criteria.
     """
     super(CtdetLossSpotNetVid, self).__init__()
     # An MSELoss alternative was noted by the original author; BCE is used.
     self.crit_seg = torch.nn.BCEWithLogitsLoss()
     if opt.mse_loss:
         self.crit = torch.nn.MSELoss()
     else:
         self.crit = FocalLoss()

     if opt.reg_loss == 'l1':
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     if opt.dense_wh:
         self.crit_wh = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         self.crit_wh = NormRegL1Loss()
     elif opt.cat_spec_wh:
         self.crit_wh = RegWeightedL1Loss()
     else:
         self.crit_wh = self.crit_reg

     self.opt = opt
Example #13
0
    def __init__(self, opt):
        """Build the CtdetLoss criteria, including direction and magnitude terms.

        ``crit_wh`` is always the same object as ``crit_reg`` (may be ``None``).
        ``crit_direct`` is ``CrossEntropyLoss`` when ``opt.direct_loss == 'cls'``,
        otherwise ``RegL1Loss``. ``crit_grad_magnitude`` is only set when
        ``opt.loss_hm_magnitude`` is truthy. ``pos_only`` takes precedence over
        ``neg_only``.
        """
        super(CtdetLoss, self).__init__()
        self.crit = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()

        if opt.reg_loss == 'l1':
            self.crit_reg = RegL1Loss()
        elif opt.reg_loss == 'sl1':
            self.crit_reg = RegLoss()
        else:
            self.crit_reg = None
        self.crit_wh = self.crit_reg

        self.crit_direct = CrossEntropyLoss() if opt.direct_loss == 'cls' else RegL1Loss()
        self.opt = opt

        if not opt.loss_hm_magnitude:
            return
        if opt.loss_hm_magnitude_pos_only:
            self.crit_grad_magnitude = FocalLossMagnitudePosOnly()
        elif opt.loss_hm_magnitude_neg_only:
            self.crit_grad_magnitude = FocalLossMagnitudeNegOnly()
        else:
            self.crit_grad_magnitude = FocalLossMagnitude()
Example #14
0
 def __init__(self, opt):
     """Build the CtsegLoss criteria from the options object ``opt``.

     Adds a ``FastDiceLoss`` mask criterion, constructed with
     ``opt.seg_feat_channel``, to the standard heatmap/regression/size criteria.
     """
     super(CtsegLoss, self).__init__()
     if opt.mse_loss:
         self.crit = torch.nn.MSELoss()
     else:
         self.crit = FocalLoss()

     if opt.reg_loss == 'l1':
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     if opt.dense_wh:
         self.crit_wh = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         self.crit_wh = NormRegL1Loss()
     elif opt.cat_spec_wh:
         self.crit_wh = RegWeightedL1Loss()
     else:
         self.crit_wh = self.crit_reg

     self.crit_mask = FastDiceLoss(opt.seg_feat_channel)
     self.opt = opt
Example #15
0
    def __init__(self, opt):
        """Build the CircleLoss criteria from the options object ``opt``.

        ``crit_reg`` is ``None`` unless ``opt.reg_loss`` is ``'l1'`` or ``'sl1'``.
        ``crit_wh`` falls back to that same object when no width/height flag is set.
        """
        super(CircleLoss, self).__init__()
        # A masked focal-loss variant (FocalLoss_mask) was once considered here.
        if opt.mse_loss:
            self.crit = torch.nn.MSELoss()
        else:
            self.crit = FocalLoss()

        if opt.reg_loss == 'l1':
            self.crit_reg = RegL1Loss()
        elif opt.reg_loss == 'sl1':
            self.crit_reg = RegLoss()
        else:
            self.crit_reg = None

        if opt.dense_wh:
            self.crit_wh = torch.nn.L1Loss(reduction='sum')
        elif opt.norm_wh:
            self.crit_wh = NormRegL1Loss()
        elif opt.cat_spec_wh:
            self.crit_wh = RegWeightedL1Loss()
        else:
            self.crit_wh = self.crit_reg

        self.opt = opt
Example #16
0
 def __init__(self, opt):
   """Build the MultiPoseLoss criteria, with optional MDN keypoint losses.

   When ``opt.mdn`` is truthy, ``crit_kp`` is one of the plain callables
   ``th_mdn_loss_dense`` / ``th_mdn_loss_ind``, selected by ``opt.dense_hp``.
   Otherwise it is a summed ``torch.nn.L1Loss`` (dense) or ``RegWeightedL1Loss``.
   """
   super(MultiPoseLoss, self).__init__()
   self.crit = FocalLoss()
   self.crit_hm_hp = FocalLoss()
   if opt.mdn:
     if opt.dense_hp:
       self.crit_kp = th_mdn_loss_dense
     else:
       self.crit_kp = th_mdn_loss_ind
   elif opt.dense_hp:
     self.crit_kp = torch.nn.L1Loss(reduction='sum')
   else:
     self.crit_kp = RegWeightedL1Loss()

   if opt.reg_loss == 'l1':
     self.crit_reg = RegL1Loss()
   elif opt.reg_loss == 'sl1':
     self.crit_reg = RegLoss()
   else:
     self.crit_reg = None
   self.opt = opt
Example #17
0
    def __init__(self, opt):
        """Build the MultiPoseLoss criteria plus two view-classification losses.

        ``crit_view_side`` and ``crit_view_front_rear`` are separate
        ``CrossEntropyLossWMask`` instances.
        """
        super(MultiPoseLoss, self).__init__()
        self.crit = FocalLoss()
        self.crit_hm_hp = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()
        if opt.dense_hp:
            self.crit_kp = torch.nn.L1Loss(reduction='sum')
        else:
            self.crit_kp = RegWeightedL1Loss()

        if opt.reg_loss == 'l1':
            self.crit_reg = RegL1Loss()
        elif opt.reg_loss == 'sl1':
            self.crit_reg = RegLoss()
        else:
            self.crit_reg = None

        # View-classification heads added on top of the base pose losses.
        self.crit_view_side = CrossEntropyLossWMask()
        self.crit_view_front_rear = CrossEntropyLossWMask()

        self.opt = opt
Example #18
0
    def __init__(self, opt):
        """Build the MultiKPSLoss criteria from the options object ``opt``.

        Covers the center-point heatmap, center-point regression, width/height
        regression, keypoint heatmap and keypoint regression terms.
        """
        super(MultiKPSLoss, self).__init__()
        # Center-point heatmap.
        self.crit = FocalLoss()

        # Center-point regression.
        if opt.reg_loss == 'l1':
            self.crit_reg = RegL1Loss()
        elif opt.reg_loss == 'sl1':
            self.crit_reg = RegLoss()
        else:
            self.crit_reg = None

        # Width/height regression.
        if opt.dense_wh:
            self.crit_wh = torch.nn.L1Loss(reduction='sum')
        elif opt.norm_wh:
            self.crit_wh = NormRegL1Loss()
        elif opt.cat_spec_wh:
            self.crit_wh = RegWeightedL1Loss()
        else:
            self.crit_wh = self.crit_reg

        # Keypoint heatmap.
        self.crit_hm_hp = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()

        # Keypoint regression.
        if opt.dense_hp:
            self.crit_kp = torch.nn.L1Loss(reduction='sum')
        else:
            self.crit_kp = RegWeightedL1Loss()

        self.opt = opt
Example #19
0
 def __init__(self, opt):
     """Build the HoidetLoss criteria from the options object ``opt``.

     ``crit_h`` is only set when ``opt.hard_negative`` is truthy.
     ``crit_reg_offset`` and ``bce`` are only set when either of the
     ``hm_rel_dcn25_i_*match`` flags is truthy.
     """
     super(HoidetLoss, self).__init__()
     if opt.mse_loss:
         self.crit = torch.nn.MSELoss()
     else:
         self.crit = FocalLoss()

     if opt.reg_loss == 'l1':
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     if opt.dense_wh:
         self.crit_wh = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         self.crit_wh = NormRegL1Loss()
     elif opt.cat_spec_wh:
         self.crit_wh = RegWeightedL1Loss()
     else:
         self.crit_wh = self.crit_reg

     if opt.hard_negative:
         self.crit_h = FocalLossHardNeg(7)

     needs_offset = opt.hm_rel_dcn25_i_casc_match or opt.hm_rel_dcn25_i_match
     if needs_offset:
         self.crit_reg_offset = RegL1LossOffset()
         self.bce = torch.nn.BCELoss()

     self.opt = opt
Example #20
0
 def __init__(self, opt):
     """Build the MultiPoseLoss criteria and a zeroed ``loss_dict`` accumulator.

     ``loss_dict`` starts with the keys ``count``, ``pos_loss``, ``neg_loss``,
     ``neg_loss1`` and ``num_pos``, all set to ``0``.
     """
     super(MultiPoseLoss, self).__init__()
     self.crit = FocalLoss()
     if opt.mse_loss:
         self.crit_hm_hp = torch.nn.MSELoss()
     else:
         self.crit_hm_hp = FocalLoss()

     if opt.dense_hp:
         self.crit_kp = torch.nn.L1Loss(reduction='sum')
     else:
         self.crit_kp = RegWeightedL1Loss()

     if opt.reg_loss == 'l1':
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     self.opt = opt
     self.loss_dict = {
         'count': 0,
         'pos_loss': 0,
         'neg_loss': 0,
         'neg_loss1': 0,
         'num_pos': 0,
     }
Example #21
0
    def __init__(self, opt):
        """Build the MultiPoseLoss criteria plus a teacher MSE criterion.

        ``crit_teacher`` is a mean-reduced ``torch.nn.MSELoss``.
        ``crit_reg`` is ``None`` unless ``opt.reg_loss`` is ``'l1'`` or ``'sl1'``.
        """
        super(MultiPoseLoss, self).__init__()
        self.crit = FocalLoss()
        self.crit_hm_hp = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()
        self.crit_kp = RegWeightedL1Loss() if not opt.dense_hp else \
            torch.nn.L1Loss(reduction='sum')
        self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \
            RegLoss() if opt.reg_loss == 'sl1' else None

        # `size_average=True` is deprecated; it is exactly reduction='mean'.
        self.crit_teacher = torch.nn.MSELoss(reduction='mean')
        self.opt = opt
Example #22
0
 def __init__(self, opt):
     """Build the CtdetLoss criteria from the options object ``opt``.

     ``crit_reg`` is ``None`` unless ``opt.reg_loss`` is ``'l1'`` or ``'sl1'``.
     ``crit_wh`` falls back to that same object when no width/height flag is set.
     """
     super(CtdetLoss, self).__init__()
     # Classification criterion: FocalLoss unless opt.mse_loss is set. The
     # original author notes opt.mse_loss defaults to False (confirm in the
     # option parser).
     if opt.mse_loss:
         self.crit = torch.nn.MSELoss()
     else:
         self.crit = FocalLoss()

     # Regression criterion; the original author notes opt.reg_loss defaults
     # to 'l1' (confirm in the option parser).
     if opt.reg_loss == 'l1':
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     # Size criterion; the original author notes opt.dense_wh defaults to
     # False (confirm in the option parser).
     if opt.dense_wh:
         self.crit_wh = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         self.crit_wh = NormRegL1Loss()
     elif opt.cat_spec_wh:
         self.crit_wh = RegWeightedL1Loss()
     else:
         self.crit_wh = self.crit_reg

     self.opt = opt
Example #23
0
    def __init__(self, opt):
        """Build the MultiPoseLoss criteria and wrap each in ``CritAuto``.

        The four criteria are first created as usual. Then each attribute
        (``crit``, ``crit_hm_hp``, ``crit_kp``, ``crit_reg``, in that order) is
        replaced by ``CritAuto(<original>)`` for 3D compatibility. ``crit_reg``
        may be ``None`` before wrapping.
        """
        super(MultiPoseLoss, self).__init__()
        self.crit = FocalLoss()
        self.crit_hm_hp = FocalLoss() if not opt.mse_loss else torch.nn.MSELoss()
        if opt.dense_hp:
            self.crit_kp = torch.nn.L1Loss(reduction='sum')
        else:
            self.crit_kp = RegWeightedL1Loss()
        if opt.reg_loss == 'l1':
            self.crit_reg = RegL1Loss()
        elif opt.reg_loss == 'sl1':
            self.crit_reg = RegLoss()
        else:
            self.crit_reg = None

        # 3D compatibility: re-assign each criterion wrapped in CritAuto.
        for attr_name in ('crit', 'crit_hm_hp', 'crit_kp', 'crit_reg'):
            setattr(self, attr_name, CritAuto(getattr(self, attr_name)))

        self.opt = opt
Example #24
0
 def __init__(self, opt):
     """Build the detection and re-ID criteria for MotLoss from ``opt``.

     Args:
         opt: options object. Reads ``mse_loss``, ``reg_loss``, ``dense_wh``,
             ``norm_wh``, ``cat_spec_wh``, ``reid_dim`` and ``nID``.
     """
     super(MotLoss, self).__init__()
     # Heatmap/classification, center offset and width/height criteria.
     self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()
     self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \
         RegLoss() if opt.reg_loss == 'sl1' else None
     # Falls back to the same object as crit_reg (may be None).
     self.crit_wh = torch.nn.L1Loss(reduction='sum') if opt.dense_wh else \
         NormRegL1Loss() if opt.norm_wh else \
             RegWeightedL1Loss() if opt.cat_spec_wh else self.crit_reg
     self.opt = opt
     self.emb_dim = opt.reid_dim  # re-ID feature length
     self.nID = opt.nID  # number of identities across all targets
     self.classifier = nn.Linear(self.emb_dim, self.nID)
     self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)  # re-ID loss
     # math.log(nID - 1) raises ValueError (math domain error) for nID <= 1.
     self.emb_scale = math.sqrt(2) * math.log(self.nID - 1) if self.nID > 1 else 0.0
     # Learnable scalars initialised to -1.85 / -1.05; presumably the
     # detection / re-ID balancing weights used in forward() — confirm.
     self.s_det = nn.Parameter(-1.85 * torch.ones(1))
     self.s_id = nn.Parameter(-1.05 * torch.ones(1))
Example #25
0
    def __init__(self, opt):
        """Build the CtdetLoss criteria from the options object ``opt``.

        ``crit``: ``torch.nn.MSELoss`` if ``opt.mse_loss``, else ``FocalLoss``.
        ``crit_reg``: ``RegL1Loss`` ('l1'), ``RegLoss`` ('sl1'), else ``None``.
        ``crit_wh``: chosen by ``dense_wh`` / ``norm_wh`` / ``cat_spec_wh`` in
        that order, falling back to the same object as ``crit_reg``.
        """
        super(CtdetLoss, self).__init__()

        if opt.mse_loss:
            self.crit = torch.nn.MSELoss()
        else:
            self.crit = FocalLoss()

        if opt.reg_loss == 'l1':
            self.crit_reg = RegL1Loss()
        elif opt.reg_loss == 'sl1':
            self.crit_reg = RegLoss()
        else:
            self.crit_reg = None

        if opt.dense_wh:
            self.crit_wh = torch.nn.L1Loss(reduction='sum')
        elif opt.norm_wh:
            self.crit_wh = NormRegL1Loss()
        elif opt.cat_spec_wh:
            self.crit_wh = RegWeightedL1Loss()
        else:
            self.crit_wh = self.crit_reg

        self.opt = opt
Example #26
0
 def __init__(self, opt):
   """Build the CtdetLoss_NFS criteria from the options object ``opt``.

   ``centerloss`` is ``CenterLoss_gt_eq1_cuda`` when ``opt.eq1`` is truthy,
   otherwise ``CenterLoss_gt_cuda``. The choice is printed to stdout.
   """
   super(CtdetLoss_NFS, self).__init__()
   if opt.mse_loss:
     self.crit = torch.nn.MSELoss()
   else:
     self.crit = FocalLoss()

   if not opt.eq1:
     print("eq1 is False")
     self.centerloss = CenterLoss_gt_cuda()
   else:
     print("eq1 is True")
     self.centerloss = CenterLoss_gt_eq1_cuda()

   if opt.reg_loss == 'l1':
     self.crit_reg = RegL1Loss()
   elif opt.reg_loss == 'sl1':
     self.crit_reg = RegLoss()
   else:
     self.crit_reg = None

   if opt.dense_wh:
     self.crit_wh = torch.nn.L1Loss(reduction='sum')
   elif opt.norm_wh:
     self.crit_wh = NormRegL1Loss()
   elif opt.cat_spec_wh:
     self.crit_wh = RegWeightedL1Loss()
   else:
     self.crit_wh = self.crit_reg
   self.opt = opt
Example #27
0
 def __init__(self, opt):
     """Build the MOT criteria plus a BCE edge-regression criterion.

     Args:
         opt: options object. Reads ``mse_loss``, ``reg_loss``, ``dense_wh``,
             ``norm_wh``, ``cat_spec_wh``, ``reid_dim`` and ``nID``.
     """
     super(MotLossWithEdgeRegression, self).__init__()
     self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()
     self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \
         RegLoss() if opt.reg_loss == 'sl1' else None
     # Falls back to the same object as crit_reg (may be None).
     self.crit_wh = torch.nn.L1Loss(reduction='sum') if opt.dense_wh else \
         NormRegL1Loss() if opt.norm_wh else \
             RegWeightedL1Loss() if opt.cat_spec_wh else self.crit_reg
     self.crit_edge = torch.nn.BCEWithLogitsLoss()
     self.opt = opt
     self.emb_dim = opt.reid_dim
     self.nID = opt.nID
     self.classifier = nn.Linear(self.emb_dim, self.nID)
     self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)
     # math.log(nID - 1) raises ValueError (math domain error) for nID <= 1.
     self.emb_scale = math.sqrt(2) * math.log(self.nID - 1) if self.nID > 1 else 0.0
     # Learnable scalars initialised to -1.85 / -1.05; presumably the
     # detection / re-ID balancing weights used in forward() — confirm.
     self.s_det = nn.Parameter(-1.85 * torch.ones(1))
     self.s_id = nn.Parameter(-1.05 * torch.ones(1))
Example #28
0
 def __init__(self, opt):
     """Build the MOT criteria plus density and count criteria.

     ``emb_scale`` is ``sqrt(2) * log(nID - 1)`` except when ``nID == 1``,
     where it is ``0``.
     """
     super(MotLoss, self).__init__()
     if opt.mse_loss:
         self.crit = torch.nn.MSELoss()
     else:
         self.crit = FocalLoss()

     if opt.reg_loss == 'l1':
         self.crit_reg = RegL1Loss()
     elif opt.reg_loss == 'sl1':
         self.crit_reg = RegLoss()
     else:
         self.crit_reg = None

     if opt.dense_wh:
         self.crit_wh = torch.nn.L1Loss(reduction='sum')
     elif opt.norm_wh:
         self.crit_wh = NormRegL1Loss()
     elif opt.cat_spec_wh:
         self.crit_wh = RegWeightedL1Loss()
     else:
         self.crit_wh = self.crit_reg

     self.crit_density_focal = FocalLoss()
     self.crit_density_ssim = NORMMSSSIM()
     self.crit_count = torch.nn.MSELoss()
     self.opt = opt
     self.emb_dim = opt.reid_dim
     self.nID = opt.nID
     self.classifier = nn.Linear(self.emb_dim, self.nID)
     self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)
     if self.nID != 1:
         self.emb_scale = math.sqrt(2) * math.log(self.nID - 1)
     else:
         self.emb_scale = 0
     self.s_det = nn.Parameter(-1.85 * torch.ones(1))
     self.s_id = nn.Parameter(-1.05 * torch.ones(1))
Example #29
0
 def __init__(self, opt):
     """Build the MOT criteria, optionally with focal-style ID classifier init.

     Args:
         opt: options object. Reads ``mse_loss``, ``reg_loss``, ``dense_wh``,
             ``norm_wh``, ``cat_spec_wh``, ``reid_dim``, ``nID`` and ``id_loss``.

     With ``opt.id_loss == 'focal'``, the classifier weight is re-initialised
     from N(0, 0.01) and its bias to ``-log((1 - 0.01) / 0.01)``.
     """
     super(MotLoss, self).__init__()
     self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()
     self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \
         RegLoss() if opt.reg_loss == 'sl1' else None
     # Falls back to the same object as crit_reg (may be None).
     self.crit_wh = torch.nn.L1Loss(reduction='sum') if opt.dense_wh else \
         NormRegL1Loss() if opt.norm_wh else \
             RegWeightedL1Loss() if opt.cat_spec_wh else self.crit_reg
     self.opt = opt
     self.emb_dim = opt.reid_dim
     self.nID = opt.nID
     self.classifier = nn.Linear(self.emb_dim, self.nID)
     if opt.id_loss == 'focal':
         torch.nn.init.normal_(self.classifier.weight, std=0.01)
         prior_prob = 0.01
         bias_value = -math.log((1 - prior_prob) / prior_prob)
         torch.nn.init.constant_(self.classifier.bias, bias_value)
     self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)
     # math.log(nID - 1) raises ValueError (math domain error) for nID <= 1.
     self.emb_scale = math.sqrt(2) * math.log(self.nID - 1) if self.nID > 1 else 0.0
     # Learnable scalars initialised to -1.85 / -1.05; presumably the
     # detection / re-ID balancing weights used in forward() — confirm.
     self.s_det = nn.Parameter(-1.85 * torch.ones(1))
     self.s_id = nn.Parameter(-1.05 * torch.ones(1))
Example #30
0
    def __init__(self, opt, loss_states):
        """Build supervised and self-supervised MOT criteria from ``opt``.

        Args:
            opt: options object. Reads ``reid_dim``, ``nID``, ``mse_loss``,
                ``reg_loss``, ``dense_wh``, ``norm_wh``, ``cat_spec_wh``,
                ``device``, ``unsup_loss`` and ``unsup``.
            loss_states: stored unchanged as ``self.loss_states``.

        Raises:
            ValueError: if ``opt.unsup`` is truthy but ``opt.unsup_loss`` is not
                one of ``nt_xent``, ``triplet_all`` or ``triplet_hard``.
        """
        super(MotLoss, self).__init__()
        self.opt = opt
        self.loss_states = loss_states
        self.emb_dim = opt.reid_dim
        self.nID = opt.nID

        # Loss for heatmap
        self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()

        # Loss for offsets
        self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \
            RegLoss() if opt.reg_loss == 'sl1' else None

        # Loss for object sizes (falls back to the same object as crit_reg)
        self.crit_wh = torch.nn.L1Loss(reduction='sum') if opt.dense_wh else \
            NormRegL1Loss() if opt.norm_wh else \
                RegWeightedL1Loss() if opt.cat_spec_wh else self.crit_reg

        # Supervised loss for object IDs
        self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)

        # FC layer for supervised object ID prediction
        self.classifier = nn.Linear(self.emb_dim, self.nID)

        # Self-supervised loss for object embeddings (None if unrecognised)
        if opt.unsup_loss == 'nt_xent':
            self.SelfSupLoss = NTXentLoss(opt.device, 0.5)
        elif opt.unsup_loss == 'triplet_all':
            self.SelfSupLoss = TripletLoss(opt.device, 'batch_all', 0.5)
        elif opt.unsup_loss == 'triplet_hard':
            self.SelfSupLoss = TripletLoss(opt.device, 'batch_hard', 0.5)
        else:
            self.SelfSupLoss = None

        if opt.unsup and self.SelfSupLoss is None:
            raise ValueError('{} is not a supported self-supervised loss. '.format(opt.unsup_loss) +
                             'Choose nt_xent, triplet_all, or triplet_hard')

        # math.log(nID - 1) raises ValueError (math domain error) for nID <= 1.
        self.emb_scale = math.sqrt(2) * math.log(self.nID - 1) if self.nID > 1 else 0.0
        # Learnable scalars initialised to -1.85 / -1.05; presumably the
        # detection / re-ID balancing weights used in forward() — confirm.
        self.s_det = nn.Parameter(-1.85 * torch.ones(1))
        self.s_id = nn.Parameter(-1.05 * torch.ones(1))