Example #1
0
def make_rpn_loss_evaluator(cfg, box_coder):
    """Build the RPN loss evaluator from the ``cfg.MODEL.RPN`` options.

    Each optional loss term (wing, focal, self-adjusting smooth L1, balanced
    L1) is constructed only when its ``USE_*`` flag is enabled; a disabled
    term is handed to ``RPNLossComputation`` as ``None``. The combination
    weight is ``COMBINATION_LOSS.WEIGHT`` when ``USE_COMBINATION_LOSS`` is
    set and ``0`` otherwise.

    Arguments:
        cfg: config node; only the ``MODEL.RPN`` subtree is read.
        box_coder (BoxCoder): passed through to ``RPNLossComputation``.

    Returns:
        RPNLossComputation: the configured loss evaluator.
    """
    rpn_cfg = cfg.MODEL.RPN

    proposal_matcher = Matcher(
        rpn_cfg.FG_IOU_THRESHOLD,
        rpn_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=True,
    )
    sampler = BalancedPositiveNegativeSampler(
        rpn_cfg.BATCH_SIZE_PER_IMAGE, rpn_cfg.POSITIVE_FRACTION)

    wing = None
    if rpn_cfg.USE_WING_LOSS:
        wing = WingLoss(
            width=rpn_cfg.WING_LOSS.WIDTH,
            curvature=rpn_cfg.WING_LOSS.SIGMA,
        )

    focal = None
    if rpn_cfg.USE_FOCAL_LOSS:
        focal = SigmoidFocalLoss(
            rpn_cfg.FOCAL_LOSS.GAMMA,
            rpn_cfg.FOCAL_LOSS.ALPHA,
        )

    adjust_l1 = None
    if rpn_cfg.USE_SELF_ADJUST_SMOOTH_L1_LOSS:
        # The leading 4 is presumably the number of box-regression
        # coordinates — confirm against AdjustSmoothL1Loss.
        adjust_l1 = AdjustSmoothL1Loss(
            4, beta=rpn_cfg.SELF_ADJUST_SMOOTH_L1_LOSS.BBOX_REG_BETA)

    combo_weight = (rpn_cfg.COMBINATION_LOSS.WEIGHT
                    if rpn_cfg.USE_COMBINATION_LOSS else 0)

    balanced_l1 = None
    if rpn_cfg.USE_BALANCE_L1_LOSS:
        balanced_l1 = BalancedL1Loss(
            alpha=rpn_cfg.BALANCE_L1_LOSS.ALPHA,
            beta=rpn_cfg.BALANCE_L1_LOSS.BETA,
            gamma=rpn_cfg.BALANCE_L1_LOSS.GAMMA)

    return RPNLossComputation(
        proposal_matcher,
        sampler,
        box_coder,
        generate_rpn_labels,
        wing_loss=wing,
        adjust_smooth_l1_loss=adjust_l1,
        balance_l1_loss=balanced_l1,
        focal_loss=focal,
        combination_weight=combo_weight,
    )
Example #2
0
 def __init__(self, cfg, proposal_matcher, box_coder):
     """Store the matcher/coder and build the RetinaNet loss functions.

     Arguments:
         cfg: config node; the ``RETINANET`` subtree is read.
         proposal_matcher (Matcher): kept as ``self.proposal_matcher``.
         box_coder (BoxCoder): kept as ``self.box_coder``.

     The classification loss is a ``SigmoidFocalLoss``. The regression loss
     is ``AdjustSmoothL1Loss`` when ``RETINANET.SELFADJUST_SMOOTH_L1`` is set,
     otherwise ``SmoothL1Loss``; both use ``RETINANET.BBOX_REG_BETA``.
     """
     retina = cfg.RETINANET

     self.proposal_matcher = proposal_matcher
     self.box_coder = box_coder
     # NOTE(review): the config count presumably includes background, which
     # the sigmoid focal loss does not model — confirm against the config.
     self.num_classes = retina.NUM_CLASSES - 1
     self.box_cls_loss_func = SigmoidFocalLoss(
         self.num_classes, retina.LOSS_GAMMA, retina.LOSS_ALPHA)

     if retina.SELFADJUST_SMOOTH_L1:
         regression_loss = AdjustSmoothL1Loss(4, beta=retina.BBOX_REG_BETA)
     else:
         regression_loss = SmoothL1Loss(beta=retina.BBOX_REG_BETA)
     self.regression_loss = regression_loss
Example #3
0
def make_retinanet_loss_evaluator(cfg, box_coder):
    """Build the RetinaNet loss evaluator.

    Arguments:
        cfg: config node. Both ``cfg.MODEL.RETINANET`` and the top-level
            ``cfg.RETINANET`` subtrees are read.
        box_coder (BoxCoder): passed through to ``RetinaNetLossComputation``.

    Returns:
        RetinaNetLossComputation: the configured loss evaluator. Its
        ``adjust_smooth_l1_loss`` argument is ``None`` unless
        ``cfg.RETINANET.USE_SELF_ADJUST_SMOOTH_L1_LOSS`` is enabled.
    """
    matcher = Matcher(
        cfg.MODEL.RETINANET.FG_IOU_THRESHOLD,
        cfg.MODEL.RETINANET.BG_IOU_THRESHOLD,
        allow_low_quality_matches=cfg.RETINANET.LOW_QUALITY_MATCHES,
        low_quality_threshold=cfg.RETINANET.LOW_QUALITY_THRESHOLD)
    sigmoid_focal_loss = SigmoidFocalLoss(cfg.MODEL.RETINANET.LOSS_GAMMA,
                                          cfg.MODEL.RETINANET.LOSS_ALPHA)

    # Bind unconditionally: when the flag is off the name must still exist,
    # otherwise the constructor call below raises UnboundLocalError. ``None``
    # mirrors how the other loss-evaluator factories signal a disabled term;
    # presumably RetinaNetLossComputation then uses ``bbox_reg_beta`` — confirm.
    adjust_smooth_l1_loss = None
    if cfg.RETINANET.USE_SELF_ADJUST_SMOOTH_L1_LOSS:
        # NOTE(review): beta is read from cfg.RETINANET here but from
        # cfg.MODEL.RETINANET for ``bbox_reg_beta`` below — confirm both keys
        # are intended and kept in sync.
        adjust_smooth_l1_loss = AdjustSmoothL1Loss(
            4, beta=cfg.RETINANET.BBOX_REG_BETA)

    loss_evaluator = RetinaNetLossComputation(
        matcher,
        box_coder,
        generate_retinanet_labels,
        sigmoid_focal_loss=sigmoid_focal_loss,
        adjust_smooth_l1_loss=adjust_smooth_l1_loss,
        bbox_reg_beta=cfg.MODEL.RETINANET.BBOX_REG_BETA,
        regress_norm=cfg.MODEL.RETINANET.BBOX_REG_WEIGHT,
    )
    return loss_evaluator
Example #4
0
def make_roi_box_loss_evaluator(cfg):
    """Build the ROI box-head (Fast R-CNN) loss evaluator.

    Optional loss terms (focal, class-balance weighting, wing,
    self-adjusting smooth L1, balanced L1) are constructed only when their
    ``USE_*`` flag under ``MODEL.ROI_BOX_HEAD`` is enabled. A disabled term
    is passed to ``FastRCNNLossComputation`` as ``None``.

    Class-balance weighting also requires
    ``CLASS_BALANCE_LOSS.WEIGHT_FILE`` to be an existing file. If the flag is
    on but the file is missing, weighting is disabled, matching the previous
    behavior. A ``UserWarning`` is now emitted so the misconfiguration is not
    silent.

    Arguments:
        cfg: config node; reads ``MODEL.ROI_HEADS``, ``MODEL.ROI_BOX_HEAD``,
            ``MODEL.CLS_AGNOSTIC_BBOX_REG`` and ``MODEL.DEVICE``.

    Returns:
        FastRCNNLossComputation: the configured loss evaluator.
    """
    import warnings

    roi_heads_cfg = cfg.MODEL.ROI_HEADS
    box_head_cfg = cfg.MODEL.ROI_BOX_HEAD

    # Unlike the RPN, the box head does not allow low-quality matches.
    matcher = Matcher(
        roi_heads_cfg.FG_IOU_THRESHOLD,
        roi_heads_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )

    box_coder = BoxCoder(weights=roi_heads_cfg.BBOX_REG_WEIGHTS)

    fg_bg_sampler = BalancedPositiveNegativeSampler(
        roi_heads_cfg.BATCH_SIZE_PER_IMAGE,
        roi_heads_cfg.POSITIVE_FRACTION)

    cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG

    focal_loss = None
    if box_head_cfg.USE_FOCAL_LOSS:
        focal_loss = SigmoidFocalLoss(box_head_cfg.FOCAL_LOSS.GAMMA,
                                      box_head_cfg.FOCAL_LOSS.ALPHA)

    class_balance_weight = None
    if box_head_cfg.USE_CLASS_BALANCE_LOSS:
        weight_file = box_head_cfg.CLASS_BALANCE_LOSS.WEIGHT_FILE
        if os.path.isfile(weight_file):
            num_class_list = ClassBalanceLoss.load_class_samples(
                filename=weight_file,
                category_type='category')
            class_balance_weight = ClassBalanceLoss(
                device=torch.device(cfg.MODEL.DEVICE),
                num_class_list=num_class_list,
                alpha=box_head_cfg.CLASS_BALANCE_LOSS.ALPHA,
                beta=box_head_cfg.CLASS_BALANCE_LOSS.BETA)
        else:
            # Keep the best-effort fallback (no crash), but surface it.
            warnings.warn(
                "MODEL.ROI_BOX_HEAD.USE_CLASS_BALANCE_LOSS is enabled but "
                "CLASS_BALANCE_LOSS.WEIGHT_FILE {!r} is not an existing file; "
                "class-balance weighting is disabled.".format(weight_file))

    wing_loss = None
    if box_head_cfg.USE_WING_LOSS:
        wing_loss = WingLoss(
            width=box_head_cfg.WING_LOSS.WIDTH,
            curvature=box_head_cfg.WING_LOSS.SIGMA,
        )

    adjust_smooth_l1_loss = None
    if box_head_cfg.USE_SELF_ADJUST_SMOOTH_L1_LOSS:
        # The leading 4 is presumably the number of box-regression
        # coordinates — confirm against AdjustSmoothL1Loss.
        adjust_smooth_l1_loss = AdjustSmoothL1Loss(
            4,
            beta=box_head_cfg.SELF_ADJUST_SMOOTH_L1_LOSS.BBOX_REG_BETA)

    balance_l1_loss = None
    if box_head_cfg.USE_BALANCE_L1_LOSS:
        balance_l1_loss = BalancedL1Loss(
            alpha=box_head_cfg.BALANCE_L1_LOSS.ALPHA,
            beta=box_head_cfg.BALANCE_L1_LOSS.BETA,
            gamma=box_head_cfg.BALANCE_L1_LOSS.GAMMA)

    loss_evaluator = FastRCNNLossComputation(
        matcher,
        fg_bg_sampler,
        box_coder,
        cls_agnostic_bbox_reg,
        focal_loss=focal_loss,
        class_balance_weight=class_balance_weight,
        wing_loss=wing_loss,
        adjust_smooth_l1_loss=adjust_smooth_l1_loss,
        balance_l1_loss=balance_l1_loss,
    )

    return loss_evaluator