예제 #1
0
def make_model(config, num_thing, num_stuff):
    """Assemble the FPN-based instance segmentation network from ``config``.

    Args:
        config: Mapping from section name to a section object. The sections
            ``"body"``, ``"fpn"``, ``"rpn"`` and ``"roi"`` are read here via
            item access and the ``get`` / ``getint`` / ``getfloat`` /
            ``getboolean`` / ``getstruct`` accessors.
        num_thing: Number of "thing" classes.
        num_stuff: Number of "stuff" classes.

    Returns:
        The ``InstanceSegNet`` built from the FPN body, the RPN head and
        algorithm, and the ROI (box + mask) head and algorithm.
    """
    body_cfg = config["body"]
    fpn_cfg = config["fpn"]
    rpn_cfg = config["rpn"]
    roi_cfg = config["roi"]

    classes = {
        "total": num_thing + num_stuff,
        "stuff": num_stuff,
        "thing": num_thing,
    }

    # BN + activation: the "static" variant is handed to the backbone and the
    # FPN, the "dynamic" variant to the RPN and ROI heads.
    norm_act_static, norm_act_dynamic = norm_act_from_config(body_cfg)

    # ---- Backbone (bottom-up pathway) ----
    # Original note: RESNET, TODO: RESNEXT
    backbone_name = body_cfg["body"]
    log_debug("Creating backbone model %s", backbone_name)
    backbone_fn = models.__dict__["net_" + backbone_name]

    backbone_kwargs = {}
    if body_cfg.get("body_params"):
        backbone_kwargs = body_cfg.getstruct("body_params")
    body = backbone_fn(norm_act=norm_act_static, **backbone_kwargs)

    # Optionally initialise the backbone from a checkpoint.
    weights_path = body_cfg.get("weights")
    if weights_path:
        # NOTE(review): torch.load unpickles the file, so "weights" must only
        # ever point at a trusted checkpoint.
        body.load_state_dict(torch.load(weights_path, map_location="cpu"))

    # Freeze the first `num_frozen` backbone stages. The tags are built once
    # up front instead of re-reading the config for every submodule.
    # NOTE(review): matching is by substring, so "mod1" would also match a
    # module named "mod10" -- confirm the backbone never has that many stages.
    num_frozen = body_cfg.getint("num_frozen")
    frozen_tags = ["mod%d" % stage for stage in range(1, num_frozen + 1)]
    for module_name, module in body.named_modules():
        for tag in frozen_tags:
            if tag in module_name:
                freeze_params(module)

    body_channels = body_cfg.getstruct("out_channels")

    # ---- FPN (top-down pathway) ----
    fpn_inputs = fpn_cfg.getstruct("inputs")
    # Read once; the same channel count feeds the FPN, the RPN head and the
    # ROI head below.
    fpn_channels = fpn_cfg.getint("out_channels")
    fpn = FPN([body_channels[inp] for inp in fpn_inputs],
              fpn_channels,
              fpn_cfg.getint("extra_scales"),
              norm_act_static,
              fpn_cfg["interpolation"])

    # Bottom-up + top-down forward pass wrapped in a single module.
    body = FPNBody(body, fpn, fpn_inputs)

    # ---- RPN ----
    proposal_generator = ProposalGenerator(
        rpn_cfg.getfloat("nms_threshold"),
        rpn_cfg.getint("num_pre_nms_train"),
        rpn_cfg.getint("num_post_nms_train"),
        rpn_cfg.getint("num_pre_nms_val"),
        rpn_cfg.getint("num_post_nms_val"),
        rpn_cfg.getint("min_size"))

    # Trailing values are the ones noted in the original comments.
    anchor_matcher = AnchorMatcher(
        rpn_cfg.getint("num_samples"),  # 256
        rpn_cfg.getfloat("pos_ratio"),  # 0.5
        rpn_cfg.getfloat("pos_threshold"),  # 0.7
        rpn_cfg.getfloat("neg_threshold"),  # 0.3
        rpn_cfg.getfloat("void_threshold"))  # 0.7

    # TODO check sigma choice
    rpn_loss = RPNLoss(rpn_cfg.getfloat("sigma"))

    # Parse the ratios once; take the count before handing the list over.
    anchor_ratios = rpn_cfg.getstruct("anchor_ratios")
    num_anchor_ratios = len(anchor_ratios)

    # FPN-based region proposal algorithm.
    rpn_algo = RPNAlgoFPN(proposal_generator, anchor_matcher, rpn_loss,
                          rpn_cfg.getint("anchor_scale"),
                          anchor_ratios,
                          fpn_cfg.getstruct("out_strides"),
                          rpn_cfg.getint("fpn_min_level"),
                          rpn_cfg.getint("fpn_levels"))

    # RPN head with the two sibling outputs (classification, regression).
    rpn_head = RPNHead(
        fpn_channels,
        num_anchor_ratios,
        # TODO not always true original K=S*R already corrected
        1,  # TODO: add it to config.ini following detectron2
        rpn_cfg.getint("hidden_channels"),
        norm_act_dynamic)

    # ---- Instance segmentation (ROI) ----
    bbx_prediction_generator = BbxPredictionGenerator(
        roi_cfg.getfloat("nms_threshold"),
        roi_cfg.getfloat("score_threshold"),
        roi_cfg.getint("max_predictions"))
    msk_prediction_generator = MskPredictionGenerator()

    roi_size = roi_cfg.getstruct("roi_size")

    proposal_matcher = ProposalMatcher(
        classes,
        roi_cfg.getint("num_samples"),
        roi_cfg.getfloat("pos_ratio"),
        roi_cfg.getfloat("pos_threshold"),
        roi_cfg.getfloat("neg_threshold_hi"),
        roi_cfg.getfloat("neg_threshold_lo"),
        roi_cfg.getfloat("void_threshold"))

    bbx_loss = DetectionLoss(roi_cfg.getfloat("sigma"))
    msk_loss = InstanceSegLoss()

    # The label ROI size is twice the ROI size along every dimension.
    lbl_roi_size = tuple(s * 2 for s in roi_size)

    roi_algo = InstanceSegAlgoFPN(
        bbx_prediction_generator,
        msk_prediction_generator,
        proposal_matcher,
        bbx_loss,
        msk_loss,
        classes,
        roi_cfg.getstruct("bbx_reg_weights"),  # e.g. (10., 10., 5., 5.)
        roi_cfg.getint("fpn_canonical_scale"),
        roi_cfg.getint("fpn_canonical_level"),
        roi_size,
        roi_cfg.getint("fpn_min_level"),
        roi_cfg.getint("fpn_levels"),
        lbl_roi_size,
        roi_cfg.getboolean("void_is_background"))

    roi_head = FPNMaskHead(fpn_channels,
                           classes,
                           roi_size,
                           norm_act=norm_act_dynamic)

    # ---- Final network ----
    return InstanceSegNet(body, rpn_head, roi_head, rpn_algo, roi_algo,
                          classes)
예제 #2
0
def make_model(config, num_thing, num_stuff):
    """Assemble the FPN-based panoptic segmentation network from ``config``.

    Args:
        config: Mapping from section name to a section object. The sections
            ``"body"``, ``"fpn"``, ``"rpn"``, ``"roi"`` and ``"sem"`` are read
            via item access and the ``get`` / ``getint`` / ``getfloat`` /
            ``getboolean`` / ``getstruct`` accessors.
        num_thing: Number of "thing" classes.
        num_stuff: Number of "stuff" classes.

    Returns:
        The ``PanopticNet`` built from the FPN body and the RPN, ROI and
        semantic heads with their algorithms.

    Note:
        This function neither loads backbone weights nor freezes any backbone
        modules.
    """
    body_cfg, fpn_cfg, rpn_cfg, roi_cfg, sem_cfg = (
        config[section] for section in ("body", "fpn", "rpn", "roi", "sem"))

    classes = dict(total=num_thing + num_stuff,
                   stuff=num_stuff,
                   thing=num_thing)

    # BN + activation: "static" for backbone / FPN / semantic head,
    # "dynamic" for the RPN and ROI heads.
    norm_act_static, norm_act_dynamic = norm_act_from_config(body_cfg)

    # ---- Backbone ----
    log_debug("Creating backbone model %s", body_cfg["body"])
    backbone_factory = models.__dict__["net_" + body_cfg["body"]]
    if body_cfg.get("body_params"):
        backbone_kwargs = body_cfg.getstruct("body_params")
    else:
        backbone_kwargs = {}
    backbone = backbone_factory(norm_act=norm_act_static, **backbone_kwargs)

    backbone_channels = body_cfg.getstruct("out_channels")

    # ---- FPN ----
    fpn_inputs = fpn_cfg.getstruct("inputs")
    fpn_in_channels = []
    for input_name in fpn_inputs:
        fpn_in_channels.append(backbone_channels[input_name])
    fpn = FPN(fpn_in_channels,
              fpn_cfg.getint("out_channels"),
              fpn_cfg.getint("extra_scales"),
              norm_act_static,
              fpn_cfg["interpolation"])
    body = FPNBody(backbone, fpn, fpn_inputs)

    # ---- RPN ----
    proposal_generator = ProposalGenerator(
        rpn_cfg.getfloat("nms_threshold"),
        *(rpn_cfg.getint(key) for key in ("num_pre_nms_train",
                                          "num_post_nms_train",
                                          "num_pre_nms_val",
                                          "num_post_nms_val",
                                          "min_size")))
    anchor_matcher = AnchorMatcher(
        rpn_cfg.getint("num_samples"),
        *(rpn_cfg.getfloat(key) for key in ("pos_ratio",
                                            "pos_threshold",
                                            "neg_threshold",
                                            "void_threshold")))
    rpn_loss = RPNLoss(rpn_cfg.getfloat("sigma"))

    rpn_algo_args = (
        proposal_generator,
        anchor_matcher,
        rpn_loss,
        rpn_cfg.getint("anchor_scale"),
        rpn_cfg.getstruct("anchor_ratios"),
        fpn_cfg.getstruct("out_strides"),
        rpn_cfg.getint("fpn_min_level"),
        rpn_cfg.getint("fpn_levels"),
    )
    rpn_algo = RPNAlgoFPN(*rpn_algo_args)

    rpn_in_channels = fpn_cfg.getint("out_channels")
    num_anchor_ratios = len(rpn_cfg.getstruct("anchor_ratios"))
    rpn_hidden_channels = rpn_cfg.getint("hidden_channels")
    rpn_head = RPNHead(rpn_in_channels, num_anchor_ratios, 1,
                       rpn_hidden_channels, norm_act_dynamic)

    # ---- Instance segmentation (ROI) ----
    bbx_prediction_generator = BbxPredictionGenerator(
        *(roi_cfg.getfloat(key) for key in ("nms_threshold",
                                            "score_threshold")),
        roi_cfg.getint("max_predictions"))
    msk_prediction_generator = MskPredictionGenerator()

    roi_size = roi_cfg.getstruct("roi_size")

    proposal_matcher = ProposalMatcher(
        classes,
        roi_cfg.getint("num_samples"),
        *(roi_cfg.getfloat(key) for key in ("pos_ratio",
                                            "pos_threshold",
                                            "neg_threshold_hi",
                                            "neg_threshold_lo",
                                            "void_threshold")))
    bbx_loss = DetectionLoss(roi_cfg.getfloat("sigma"))
    msk_loss = InstanceSegLoss()

    # The label ROI size is twice the ROI size along every dimension.
    lbl_roi_size = tuple(2 * side for side in roi_size)

    roi_algo_args = (
        bbx_prediction_generator,
        msk_prediction_generator,
        proposal_matcher,
        bbx_loss,
        msk_loss,
        classes,
        roi_cfg.getstruct("bbx_reg_weights"),
        roi_cfg.getint("fpn_canonical_scale"),
        roi_cfg.getint("fpn_canonical_level"),
        roi_size,
        roi_cfg.getint("fpn_min_level"),
        roi_cfg.getint("fpn_levels"),
        lbl_roi_size,
        roi_cfg.getboolean("void_is_background"),
    )
    roi_algo = InstanceSegAlgoFPN(*roi_algo_args)

    roi_in_channels = fpn_cfg.getint("out_channels")
    roi_head = FPNMaskHead(roi_in_channels, classes, roi_size,
                           norm_act=norm_act_dynamic)

    # ---- Semantic segmentation ----
    sem_loss = SemanticSegLoss(ohem=sem_cfg.getfloat("ohem"))
    sem_algo = SemanticSegAlgo(sem_loss, classes["total"])

    sem_in_channels = fpn_cfg.getint("out_channels")
    sem_min_level = sem_cfg.getint("fpn_min_level")
    sem_levels = sem_cfg.getint("fpn_levels")
    sem_head = FPNSemanticHeadDeeplab(
        sem_in_channels,
        sem_min_level,
        sem_levels,
        classes["total"],
        pooling_size=sem_cfg.getstruct("pooling_size"),
        norm_act=norm_act_static)

    # ---- Final network ----
    panoptic_net = PanopticNet(body, rpn_head, roi_head, sem_head, rpn_algo,
                               roi_algo, sem_algo, classes)
    return panoptic_net
예제 #3
0
def make_model(config, num_thing, num_stuff):
    """Assemble the FPN-based detection network from ``config``.

    Args:
        config: Mapping from section name to a section object. The sections
            ``"body"``, ``"fpn"``, ``"rpn"`` and ``"roi"`` are read here via
            item access and the ``get`` / ``getint`` / ``getfloat`` /
            ``getstruct`` accessors.
        num_thing: Number of "thing" classes.
        num_stuff: Number of "stuff" classes.

    Returns:
        The ``DetectionNet`` built from the FPN body, the RPN head and
        algorithm, and the ROI head and algorithm.
    """
    body_cfg = config["body"]
    fpn_cfg = config["fpn"]
    rpn_cfg = config["rpn"]
    roi_cfg = config["roi"]

    classes = {
        "total": num_thing + num_stuff,
        "stuff": num_stuff,
        "thing": num_thing,
    }

    # BN + activation: the "static" variant is handed to the backbone and the
    # FPN, the "dynamic" variant to the RPN and ROI heads.
    norm_act_static, norm_act_dynamic = norm_act_from_config(body_cfg)

    # ---- Backbone ----
    backbone_name = body_cfg["body"]
    log_debug("Creating backbone model %s", backbone_name)
    backbone_fn = models.__dict__["net_" + backbone_name]

    backbone_kwargs = {}
    if body_cfg.get("body_params"):
        backbone_kwargs = body_cfg.getstruct("body_params")
    body = backbone_fn(norm_act=norm_act_static, **backbone_kwargs)

    # Optionally initialise the backbone from a checkpoint.
    weights_path = body_cfg.get("weights")
    if weights_path:
        # NOTE(review): torch.load unpickles the file, so "weights" must only
        # ever point at a trusted checkpoint.
        body.load_state_dict(torch.load(weights_path, map_location="cpu"))

    # Freeze the first `num_frozen` backbone stages. The tags are built once
    # up front instead of re-reading the config for every submodule.
    # NOTE(review): matching is by substring, so "mod1" would also match a
    # module named "mod10" -- confirm the backbone never has that many stages.
    num_frozen = body_cfg.getint("num_frozen")
    frozen_tags = ["mod%d" % stage for stage in range(1, num_frozen + 1)]
    for module_name, module in body.named_modules():
        for tag in frozen_tags:
            if tag in module_name:
                freeze_params(module)

    body_channels = body_cfg.getstruct("out_channels")

    # ---- FPN ----
    fpn_inputs = fpn_cfg.getstruct("inputs")
    # Read once; the same channel count feeds the FPN, the RPN head and the
    # ROI head below.
    fpn_channels = fpn_cfg.getint("out_channels")
    fpn = FPN([body_channels[inp] for inp in fpn_inputs],
              fpn_channels,
              fpn_cfg.getint("extra_scales"),
              norm_act_static,
              fpn_cfg["interpolation"])
    body = FPNBody(body, fpn, fpn_inputs)

    # ---- RPN ----
    proposal_generator = ProposalGenerator(
        rpn_cfg.getfloat("nms_threshold"),
        rpn_cfg.getint("num_pre_nms_train"),
        rpn_cfg.getint("num_post_nms_train"),
        rpn_cfg.getint("num_pre_nms_val"),
        rpn_cfg.getint("num_post_nms_val"),
        rpn_cfg.getint("min_size"))
    anchor_matcher = AnchorMatcher(
        rpn_cfg.getint("num_samples"),
        rpn_cfg.getfloat("pos_ratio"),
        rpn_cfg.getfloat("pos_threshold"),
        rpn_cfg.getfloat("neg_threshold"),
        rpn_cfg.getfloat("void_threshold"))
    rpn_loss = RPNLoss(rpn_cfg.getfloat("sigma"))

    # Parse the ratios once; take the count before handing the list over.
    anchor_ratios = rpn_cfg.getstruct("anchor_ratios")
    num_anchor_ratios = len(anchor_ratios)

    rpn_algo = RPNAlgoFPN(proposal_generator, anchor_matcher, rpn_loss,
                          rpn_cfg.getint("anchor_scale"),
                          anchor_ratios,
                          fpn_cfg.getstruct("out_strides"),
                          rpn_cfg.getint("fpn_min_level"),
                          rpn_cfg.getint("fpn_levels"))
    rpn_head = RPNHead(fpn_channels,
                       num_anchor_ratios,
                       1,
                       rpn_cfg.getint("hidden_channels"),
                       norm_act_dynamic)

    # ---- Detection (ROI) ----
    prediction_generator = PredictionGenerator(
        roi_cfg.getfloat("nms_threshold"),
        roi_cfg.getfloat("score_threshold"),
        roi_cfg.getint("max_predictions"))
    proposal_matcher = ProposalMatcher(
        classes,
        roi_cfg.getint("num_samples"),
        roi_cfg.getfloat("pos_ratio"),
        roi_cfg.getfloat("pos_threshold"),
        roi_cfg.getfloat("neg_threshold_hi"),
        roi_cfg.getfloat("neg_threshold_lo"),
        roi_cfg.getfloat("void_threshold"))
    roi_loss = DetectionLoss(roi_cfg.getfloat("sigma"))

    # Shared by the ROI algorithm and the ROI head.
    roi_size = roi_cfg.getstruct("roi_size")

    roi_algo = DetectionAlgoFPN(prediction_generator, proposal_matcher,
                                roi_loss, classes,
                                roi_cfg.getstruct("bbx_reg_weights"),
                                roi_cfg.getint("fpn_canonical_scale"),
                                roi_cfg.getint("fpn_canonical_level"),
                                roi_size,
                                roi_cfg.getint("fpn_min_level"),
                                roi_cfg.getint("fpn_levels"))
    roi_head = FPNROIHead(fpn_channels,
                          classes,
                          roi_size,
                          norm_act=norm_act_dynamic)

    # ---- Final network ----
    return DetectionNet(body, rpn_head, roi_head, rpn_algo, roi_algo, classes)