Exemplo n.º 1
0
    def __init__(self, cfg):
        """Assemble the RetinaNet components described by ``cfg``.

        Creates the anchor generator, the head, a test-time box selector,
        a train-time box selector (only when ``MODEL.MASK_ON`` is truthy,
        otherwise ``None``) and the loss evaluator. When
        ``MODEL.RETINANET.FREEZE`` is truthy the finished module is handed
        to ``dfs_freeze`` with ``requires_grad=False``.

        Args:
            cfg: Project configuration node. A clone is kept on ``self.cfg``;
                the factory helpers receive the object that was passed in.
        """
        super(RetinaNetModule, self).__init__()

        self.cfg = cfg.clone()

        # Components are created and attached in the same relative order as
        # before, so construction order and attribute order are unchanged.
        self.anchor_generator = make_anchor_generator_retinanet(cfg)
        self.head = RetinaNetHead(cfg)

        # A single coder instance is shared by the selectors and the loss.
        coder = BoxCoder(weights=(10., 10., 5., 5.))

        self.box_selector_test = make_retinanet_postprocessor(cfg, 100, coder)
        self.box_selector_train = (
            make_retinanet_postprocessor(cfg, 100, coder)
            if self.cfg.MODEL.MASK_ON
            else None
        )
        self.loss_evaluator = make_retinanet_loss_evaluator(cfg, coder)

        self.freeze = cfg.MODEL.RETINANET.FREEZE
        if not self.freeze:
            return
        dfs_freeze(self, requires_grad=False)
Exemplo n.º 2
0
    def __init__(self, cfg, in_channels):
        """Assemble the RPN components described by ``cfg``.

        Creates the anchor generator, the head class registered under
        ``MODEL.RPN.RPN_HEAD``, separate train/test box selectors and the
        loss evaluator. When ``MODEL.RPN.FREEZE`` is truthy the finished
        module is handed to ``dfs_freeze`` with ``requires_grad=False``.

        Args:
            cfg: Project configuration node. A clone is kept on ``self.cfg``;
                the factory helpers receive the object that was passed in.
            in_channels: Forwarded unchanged to the head constructor.
        """
        super(RPNModule, self).__init__()

        self.cfg = cfg.clone()

        anchors = make_anchor_generator(cfg)
        self.anchor_generator = anchors

        # Look the head class up first, then build it with the first entry
        # of the generator's per-location anchor counts.
        head_cls = registry.RPN_HEADS[cfg.MODEL.RPN.RPN_HEAD]
        anchors_per_location = anchors.num_anchors_per_location()[0]
        self.head = head_cls(cfg, in_channels, anchors_per_location)

        # A single coder instance is shared by the selectors and the loss.
        coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        self.box_selector_train = make_rpn_postprocessor(
            cfg, coder, is_train=True)
        self.box_selector_test = make_rpn_postprocessor(
            cfg, coder, is_train=False)
        self.loss_evaluator = make_rpn_loss_evaluator(cfg, coder)

        self.use_extended_features = cfg.MODEL.RPN.USE_EXTENDED_FEATURES
        self.freeze = cfg.MODEL.RPN.FREEZE
        if not self.freeze:
            return
        dfs_freeze(self, requires_grad=False)
Exemplo n.º 3
0
    def __init__(
        self,
        in_channels,
        refine_level=2,
        refine_type='none',
        use_gn=False,
        freeze=False,
    ):
        """Store the settings and build the optional refinement layer.

        Args:
            in_channels: Channel count; passed as the first two positional
                arguments of ``make_conv3x3`` or as the first argument of
                ``NonLocal2D``.
            refine_level: Stored on the instance; must be non-negative.
            refine_type: One of ``'none'``, ``'conv'`` or ``'non_local'``.
                ``'none'`` leaves ``self.refine`` as ``None``.
            use_gn: Forwarded to whichever refinement layer is built.
            freeze: When truthy, the module is handed to ``dfs_freeze``
                with ``requires_grad=False`` once it is assembled.

        Raises:
            AssertionError: If ``refine_type`` is not one of the accepted
                values or ``refine_level`` is negative.
        """
        super(BFP, self).__init__()
        assert refine_type in ('none', 'conv', 'non_local')

        self.in_channels = in_channels
        self.refine_level = refine_level
        self.refine_type = refine_type
        assert self.refine_level >= 0

        # Choose the refinement layer first, then attach it exactly once.
        if refine_type == 'non_local':
            refine_layer = NonLocal2D(
                in_channels,
                reduction=1,
                use_scale=False,
                use_gn=use_gn,
            )
        elif refine_type == 'conv':
            refine_layer = make_conv3x3(
                in_channels,
                in_channels,
                use_gn=use_gn,
                use_relu=True,
                kaiming_init=True,
            )
        else:
            refine_layer = None
        self.refine = refine_layer

        self.freeze = freeze
        if not self.freeze:
            return
        dfs_freeze(self, requires_grad=False)
Exemplo n.º 4
0
def build_backbone(cfg):
    """Build the backbone registered under ``cfg.MODEL.BACKBONE.CONV_BODY``.

    Args:
        cfg: Project configuration node; it is also passed to the
            registered backbone factory.

    Returns:
        The object produced by the registered factory. When
        ``cfg.MODEL.BACKBONE.FREEZE`` is truthy it has first been handed to
        ``dfs_freeze`` with ``requires_grad=False``.

    Raises:
        AssertionError: If the configured name is not in
            ``registry.BACKBONES``.
    """
    body_name = cfg.MODEL.BACKBONE.CONV_BODY
    assert body_name in registry.BACKBONES, \
        "cfg.MODEL.BACKBONE.CONV_BODY: {} are not registered in registry".format(
            body_name
        )

    factory = registry.BACKBONES[body_name]
    backbone = factory(cfg)

    # The freeze flag is read only after the backbone has been built.
    should_freeze = cfg.MODEL.BACKBONE.FREEZE
    if should_freeze:
        dfs_freeze(backbone, requires_grad=False)
    return backbone
Exemplo n.º 5
0
    def __init__(self, cfg, in_channels):
        """Assemble the box-head components described by ``cfg``.

        Creates the feature extractor, a predictor sized from the
        extractor's ``out_channels``, the post-processor and the loss
        evaluator. When ``MODEL.ROI_BOX_HEAD.FREEZE`` is truthy the finished
        module is handed to ``dfs_freeze`` with ``requires_grad=False``.

        Args:
            cfg: Project configuration node; also forwarded to the parent
                constructor as ``cfg=cfg``.
            in_channels: Forwarded to the feature-extractor factory.
        """
        super(ROIBoxHead, self).__init__(cfg=cfg)

        self.use_extended_features = (
            cfg.MODEL.ROI_BOX_HEAD.USE_EXTENDED_FEATURES
        )

        # The predictor depends on the extractor, so the extractor is built
        # and attached first.
        extractor = make_roi_box_feature_extractor(cfg, in_channels)
        self.feature_extractor = extractor
        self.predictor = make_roi_box_predictor(cfg, extractor.out_channels)
        self.post_processor = make_roi_box_post_processor(cfg)
        self.loss_evaluator = make_roi_box_loss_evaluator(cfg)

        self.freeze = cfg.MODEL.ROI_BOX_HEAD.FREEZE
        if not self.freeze:
            return
        dfs_freeze(self, requires_grad=False)
Exemplo n.º 6
0
def build_roi_maskiou_head(cfg):
    """Create a ``ROIMaskIoUHead`` and optionally pass it to ``dfs_freeze``.

    Args:
        cfg: Project configuration node, forwarded to ``ROIMaskIoUHead``.

    Returns:
        The constructed head. When ``cfg.MODEL.ROI_MASK_HEAD.FREEZE`` is
        truthy it has first been handed to ``dfs_freeze`` with
        ``requires_grad=False``.
    """
    maskiou_head = ROIMaskIoUHead(cfg)

    # NOTE(review): the flag comes from ROI_MASK_HEAD rather than a
    # MaskIoU-specific option -- confirm this sharing is intended.
    should_freeze = cfg.MODEL.ROI_MASK_HEAD.FREEZE
    if should_freeze:
        dfs_freeze(maskiou_head, requires_grad=False)
    return maskiou_head
Exemplo n.º 7
0
def build_roi_keypoint_head(cfg, in_channels):
    """Create a ``ROIKeypointHead`` and optionally pass it to ``dfs_freeze``.

    Args:
        cfg: Project configuration node, forwarded to ``ROIKeypointHead``.
        in_channels: Forwarded unchanged to ``ROIKeypointHead``.

    Returns:
        The constructed head. When ``cfg.MODEL.ROI_KEYPOINT_HEAD.FREEZE`` is
        truthy it has first been handed to ``dfs_freeze`` with
        ``requires_grad=False``.
    """
    keypoint_head = ROIKeypointHead(cfg, in_channels)

    # The freeze flag is read only after the head has been built.
    should_freeze = cfg.MODEL.ROI_KEYPOINT_HEAD.FREEZE
    if should_freeze:
        dfs_freeze(keypoint_head, requires_grad=False)
    return keypoint_head