Ejemplo n.º 1
0
def parsing_loss_evaluator():
    """Create the loss computation object for the parsing head.

    A ``Matcher`` is configured with the foreground/background IoU
    thresholds from ``cfg.FAST_RCNN``, with low-quality matches disabled.
    It is passed, together with ``cfg.PRCNN.RESOLUTION``, to
    ``ParsingRCNNLossComputation``.

    Returns:
        ParsingRCNNLossComputation: the configured loss evaluator.
    """
    fast_rcnn_cfg = cfg.FAST_RCNN
    proposal_matcher = Matcher(
        fast_rcnn_cfg.FG_IOU_THRESHOLD,
        fast_rcnn_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )
    return ParsingRCNNLossComputation(proposal_matcher, cfg.PRCNN.RESOLUTION)
Ejemplo n.º 2
0
def hier_loss_evaluator():
    """Create the loss computation object for the Hier R-CNN head.

    Both the matcher thresholds and the resolution come from ``cfg.HRCNN``.
    The matcher uses ``FG_IOU_THRESHOLD`` / ``BG_IOU_THRESHOLD`` with
    low-quality matches disabled. ``ROI_XFORM_RESOLUTION`` is passed as the
    resolution argument.

    Returns:
        HierRCNNLossComputation: the configured loss evaluator.
    """
    hrcnn_cfg = cfg.HRCNN
    proposal_matcher = Matcher(
        hrcnn_cfg.FG_IOU_THRESHOLD,
        hrcnn_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )
    return HierRCNNLossComputation(
        proposal_matcher,
        hrcnn_cfg.ROI_XFORM_RESOLUTION,
    )
Ejemplo n.º 3
0
def uv_loss_evaluator():
    """Create the loss computation object for the UV R-CNN head.

    Matcher thresholds and resolution are both taken from ``cfg.UVRCNN``.
    Low-quality matches are disabled.

    Returns:
        UVRCNNLossComputation: the configured loss evaluator.
    """
    uvrcnn_cfg = cfg.UVRCNN
    proposal_matcher = Matcher(
        uvrcnn_cfg.FG_IOU_THRESHOLD,
        uvrcnn_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )
    return UVRCNNLossComputation(proposal_matcher, uvrcnn_cfg.RESOLUTION)
Ejemplo n.º 4
0
def keypoint_loss_evaluator():
    """Create the loss computation object for the keypoint head.

    The matcher reuses the ``cfg.FAST_RCNN`` IoU thresholds, with
    low-quality matches disabled. The resolution argument comes from
    ``cfg.KRCNN.RESOLUTION``.

    Returns:
        KeypointRCNNLossComputation: the configured loss evaluator.
    """
    fast_rcnn_cfg = cfg.FAST_RCNN
    proposal_matcher = Matcher(
        fast_rcnn_cfg.FG_IOU_THRESHOLD,
        fast_rcnn_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )
    return KeypointRCNNLossComputation(
        proposal_matcher,
        cfg.KRCNN.RESOLUTION,
    )
Ejemplo n.º 5
0
def mask_loss_evaluator():
    """Create the loss computation object for the mask head.

    The matcher reuses the ``cfg.FAST_RCNN`` IoU thresholds, with
    low-quality matches disabled. ``cfg.MRCNN.RESOLUTION`` and
    ``cfg.MRCNN.MASKIOU_ON`` are forwarded positionally to
    ``MaskRCNNLossComputation``.

    Returns:
        MaskRCNNLossComputation: the configured loss evaluator.
    """
    fast_rcnn_cfg = cfg.FAST_RCNN
    proposal_matcher = Matcher(
        fast_rcnn_cfg.FG_IOU_THRESHOLD,
        fast_rcnn_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )
    mrcnn_cfg = cfg.MRCNN
    return MaskRCNNLossComputation(
        proposal_matcher,
        mrcnn_cfg.RESOLUTION,
        mrcnn_cfg.MASKIOU_ON,
    )
Ejemplo n.º 6
0
def make_rpn_loss_evaluator(box_coder):
    """Create the loss computation object for the RPN.

    Args:
        box_coder: box encoder forwarded unchanged to ``RPNLossComputation``.

    Returns:
        RPNLossComputation: built from a ``Matcher`` (``cfg.RPN`` IoU
        thresholds, low-quality matches *allowed*), a
        ``BalancedPositiveNegativeSampler`` (``cfg.RPN`` batch size and
        positive fraction), ``box_coder`` and ``generate_rpn_labels``.
    """
    rpn_cfg = cfg.RPN
    # NOTE(review): allow_low_quality_matches=True presumably lets every
    # ground-truth box keep its best-overlapping anchor even below the
    # foreground threshold -- confirm against Matcher's implementation.
    rpn_matcher = Matcher(
        rpn_cfg.FG_IOU_THRESHOLD,
        rpn_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=True,
    )
    sampler = BalancedPositiveNegativeSampler(
        rpn_cfg.BATCH_SIZE_PER_IMAGE,
        rpn_cfg.POSITIVE_FRACTION,
    )
    return RPNLossComputation(
        rpn_matcher,
        sampler,
        box_coder,
        generate_rpn_labels,
    )
Ejemplo n.º 7
0
def loss_evaluator():
    """Create the ``HGridLossComputation`` used by the OPLD head.

    The matcher reuses the ``cfg.FAST_RCNN`` IoU thresholds, with
    low-quality matches disabled. The remaining settings are read from
    ``cfg.OPLD``: ``LOSS_WEIGHT``, ``POS_RADIUS``, ``NUM_POINTS`` and
    ``ROI_FEAT_SIZE``.

    Returns:
        HGridLossComputation: the configured loss evaluator.
    """
    fast_rcnn_cfg = cfg.FAST_RCNN
    proposal_matcher = Matcher(
        fast_rcnn_cfg.FG_IOU_THRESHOLD,
        fast_rcnn_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )
    opld_cfg = cfg.OPLD
    # HGridLossComputation takes the loss weight *before* the matcher,
    # unlike the other loss computations built here; keep this order.
    return HGridLossComputation(
        opld_cfg.LOSS_WEIGHT,
        proposal_matcher,
        opld_cfg.POS_RADIUS,
        opld_cfg.NUM_POINTS,
        opld_cfg.ROI_FEAT_SIZE,
    )
Ejemplo n.º 8
0
def box_loss_evaluator():
    """Create the loss computation object for the Fast R-CNN box head.

    Every setting is read from ``cfg.FAST_RCNN``:
      * matcher: ``FG_IOU_THRESHOLD`` / ``BG_IOU_THRESHOLD``, low-quality
        matches disabled;
      * box coder: ``BBOX_REG_WEIGHTS``;
      * sampler: ``BATCH_SIZE_PER_IMAGE`` / ``POSITIVE_FRACTION``;
      * flags: ``CLS_AGNOSTIC_BBOX_REG``, ``CLS_ON``, ``REG_ON``.

    Returns:
        FastRCNNLossComputation: the configured loss evaluator.
    """
    fast_rcnn_cfg = cfg.FAST_RCNN
    proposal_matcher = Matcher(
        fast_rcnn_cfg.FG_IOU_THRESHOLD,
        fast_rcnn_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )
    coder = BoxCoder(weights=fast_rcnn_cfg.BBOX_REG_WEIGHTS)
    sampler = BalancedPositiveNegativeSampler(
        fast_rcnn_cfg.BATCH_SIZE_PER_IMAGE,
        fast_rcnn_cfg.POSITIVE_FRACTION,
    )
    return FastRCNNLossComputation(
        proposal_matcher,
        sampler,
        coder,
        fast_rcnn_cfg.CLS_AGNOSTIC_BBOX_REG,
        fast_rcnn_cfg.CLS_ON,
        fast_rcnn_cfg.REG_ON,
    )
Ejemplo n.º 9
0
def box_loss_evaluator(idx):
    """Create the ``CascadeRCNNLossComputation`` for one entry of the cascade.

    Args:
        idx: index into the per-entry lists
            ``cfg.CASCADE_RCNN.FG_IOU_THRESHOLD``, ``BG_IOU_THRESHOLD`` and
            ``BBOX_REG_WEIGHTS`` (presumably the cascade stage -- confirm
            with the caller).

    Returns:
        CascadeRCNNLossComputation: built from the indexed matcher
        thresholds and box-regression weights. The sampler settings and the
        class-agnostic flag are shared, read from ``cfg.FAST_RCNN`` rather
        than indexed.
    """
    cascade_cfg = cfg.CASCADE_RCNN
    stage_matcher = Matcher(
        cascade_cfg.FG_IOU_THRESHOLD[idx],
        cascade_cfg.BG_IOU_THRESHOLD[idx],
        allow_low_quality_matches=False,
    )
    stage_coder = BoxCoder(weights=cascade_cfg.BBOX_REG_WEIGHTS[idx])

    fast_rcnn_cfg = cfg.FAST_RCNN
    sampler = BalancedPositiveNegativeSampler(
        fast_rcnn_cfg.BATCH_SIZE_PER_IMAGE,
        fast_rcnn_cfg.POSITIVE_FRACTION,
    )
    return CascadeRCNNLossComputation(
        stage_matcher,
        sampler,
        stage_coder,
        fast_rcnn_cfg.CLS_AGNOSTIC_BBOX_REG,
    )