Esempio n. 1
0
def build_psr_fpn_backbone(cfg, input_shape: ShapeSpec, num_classes=None):
    """
    Build a ``PSRFPN`` backbone on top of a ResNet bottom-up network.

    Args:
        cfg: a detectron2 CfgNode. Besides whatever ``build_resnet_backbone``
            consumes, this reads ``MODEL.FPN.{IN_FEATURES, OUT_CHANNELS, NORM,
            FUSE_TYPE}``, ``MODEL.CUSTOM.FPN.{CONVF_NAME, NOISE_VAR}`` and
            ``MODEL.CUSTOM.BRANCH.NUM_BRANCH``.
        input_shape: input shape spec, forwarded to ``build_resnet_backbone``.
        num_classes: forwarded unchanged to ``PSRFPN`` (default ``None``).

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    resnet = build_resnet_backbone(cfg, input_shape)

    fpn_cfg = cfg.MODEL.FPN
    custom_fpn_cfg = cfg.MODEL.CUSTOM.FPN

    # Every PSRFPN argument is passed by keyword, so gathering them in one
    # mapping keeps the call site short without changing what is passed.
    psr_fpn_kwargs = dict(
        in_features=fpn_cfg.IN_FEATURES,
        out_channels=fpn_cfg.OUT_CHANNELS,
        convf_name=custom_fpn_cfg.CONVF_NAME,
        num_branch=cfg.MODEL.CUSTOM.BRANCH.NUM_BRANCH,
        noise_var=custom_fpn_cfg.NOISE_VAR,
        norm=fpn_cfg.NORM,
        top_block=LastLevelMaxPool(),
        fuse_type=fpn_cfg.FUSE_TYPE,
        num_classes=num_classes,
    )
    return PSRFPN(bottom_up=resnet, **psr_fpn_kwargs)
Esempio n. 2
0
    def __init__(self, config: Config, *args, **kwargs):
        """Build a ResNet encoder and optionally load pretrained weights into it.

        Args:
            config: configuration object. It is passed to
                ``build_resnet_backbone`` and also queried for the optional
                keys ``"pretrained"`` (default ``False``) and
                ``"pretrained_path"`` (checkpoint URL, default ``None``).
            *args, **kwargs: accepted for interface compatibility; unused.

        Raises:
            ValueError: if ``pretrained`` is truthy but ``pretrained_path`` is
                missing or empty.
        """
        super().__init__()
        self.config = config
        pretrained = config.get("pretrained", False)
        pretrained_path = config.get("pretrained_path", None)

        # The encoder is built for 3-channel input.
        self.resnet = build_resnet_backbone(config, ShapeSpec(channels=3))

        if pretrained:
            # Fail loudly here instead of letting torch.hub choke on None/"".
            if not pretrained_path:
                raise ValueError(
                    "config 'pretrained' is set but 'pretrained_path' is empty"
                )
            # map_location="cpu" lets checkpoints saved on GPU load on any
            # host; load_state_dict then copies each tensor into the module's
            # own parameters, so the final device is unaffected.
            checkpoint = torch.hub.load_state_dict_from_url(
                pretrained_path, progress=False, map_location="cpu"
            )
            # Weights live under checkpoint["model"]. Every "backbone."
            # substring is removed from each key so names line up with
            # self.resnet's parameters (presumably the checkpoint is a full
            # detector whose ResNet is nested under "backbone." - confirm).
            new_state_dict = OrderedDict(
                (key.replace("backbone.", ""), value)
                for key, value in checkpoint["model"].items()
            )
            # strict=False: keys that do not match self.resnet are skipped
            # instead of raising.
            self.resnet.load_state_dict(new_state_dict, strict=False)

        # Output channel width advertised to downstream modules. Hard-coded;
        # presumably matches the last ResNet stage (e.g. res5 of
        # ResNet-50/101) - TODO confirm against the configured depth.
        self.out_dim = 2048
Esempio n. 3
0
def build_normal_fpn_pretrain_backbone(cfg,
                                       input_shape: ShapeSpec,
                                       num_classes=None):
    """
    Build an ``FPN`` backbone on top of a ResNet bottom-up network.

    Args:
        cfg: a detectron2 CfgNode. Besides whatever ``build_resnet_backbone``
            consumes, this reads ``MODEL.FPN.{IN_FEATURES, OUT_CHANNELS, NORM,
            FUSE_TYPE}`` and ``MODEL.CUSTOM.FPN.NOISE_VAR``.
        input_shape: input shape spec, forwarded to ``build_resnet_backbone``.
        num_classes: forwarded unchanged to ``FPN`` (default ``None``).

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    resnet = build_resnet_backbone(cfg, input_shape)

    fpn_cfg = cfg.MODEL.FPN

    # All FPN arguments are keyword arguments; collect them up front.
    fpn_kwargs = {
        "in_features": fpn_cfg.IN_FEATURES,
        "out_channels": fpn_cfg.OUT_CHANNELS,
        "noise_var": cfg.MODEL.CUSTOM.FPN.NOISE_VAR,
        "norm": fpn_cfg.NORM,
        "top_block": LastLevelMaxPool(),
        "fuse_type": fpn_cfg.FUSE_TYPE,
        "num_classes": num_classes,
    }
    return FPN(bottom_up=resnet, **fpn_kwargs)