コード例 #1
0
    def __init__(
        self,
        backbone: Union[nn.Module, Mapping],
        neck: Union[nn.Module, Mapping, None] = None,
        *,
        pretrained_backbone: Optional[str] = None,
        output_shapes: List[ShapeSpec],
        output_names: Optional[List[str]] = None,
    ):
        """
        Wrap an mmdet backbone (and optional neck) and record the names and
        shapes of its outputs.

        Args:
            backbone: either a backbone module or a mmdet config dict that defines a
                backbone. The backbone takes a 4D image tensor and returns a
                sequence of tensors.
            neck: either a neck module or a mmdet config dict that defines a
                neck. The neck takes outputs of backbone and returns a
                sequence of tensors. If None, no neck is used.
            pretrained_backbone: defines the backbone weights that can be loaded by
                mmdet, such as "torchvision://resnet50".
            output_shapes: shape for every output of the backbone (or neck, if given).
                stride and channels are often needed.
            output_names: names for every output of the backbone (or neck, if given).
                By default (None or empty), will use "out0", "out1", ...
                If given, must have the same length as ``output_shapes``.

        Raises:
            ValueError: if a non-empty ``output_names`` has a different length
                than ``output_shapes``.
        """
        super().__init__()
        # Names and shapes both describe every output one-to-one (see Args), so
        # a length mismatch is a configuration error. Check it before building
        # modules or loading weights so the failure is cheap and explicit.
        if output_names and len(output_names) != len(output_shapes):
            raise ValueError(
                "Length of output_names does not match output_shapes: "
                f"{len(output_names)} != {len(output_shapes)}"
            )

        if isinstance(backbone, Mapping):
            from mmdet.models import build_backbone

            backbone = build_backbone(_to_container(backbone))
        self.backbone = backbone

        if isinstance(neck, Mapping):
            from mmdet.models import build_neck

            neck = build_neck(_to_container(neck))
        self.neck = neck

        # It's confusing that backbone weights are given as a separate argument,
        # but "neck" weights, if any, are part of neck itself. This is the interface
        # of mmdet so we follow it. Reference:
        # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/two_stage.py
        logger.info("Initializing mmdet backbone weights: %s ...", pretrained_backbone)
        self.backbone.init_weights(pretrained_backbone)
        # train() in mmdet modules is non-trivial, and has to be explicitly
        # called. Reference:
        # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py
        self.backbone.train()
        if self.neck is not None:
            logger.info("Initializing mmdet neck weights ...")
            if isinstance(self.neck, nn.Sequential):
                # A plain nn.Sequential has no init_weights() of its own, so
                # initialize each contained module individually.
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
            self.neck.train()

        self._output_shapes = output_shapes
        if not output_names:
            output_names = [f"out{i}" for i in range(len(output_shapes))]
        self._output_names = output_names
コード例 #2
0
 def __init__(self,
              backbone,
              neck,
              neck_3d,
              bbox_head,
              n_voxels,
              anchor_generator,
              train_cfg=None,
              test_cfg=None,
              pretrained=None,
              init_cfg=None):
     """Build the detector's sub-modules from mmdet-style config dicts.

     Args:
         backbone: config passed to ``build_backbone``.
         neck: config for the 2D neck, passed to ``build_neck``.
         neck_3d: config for the 3D neck, also passed to ``build_neck``.
         bbox_head: config for the detection head. ``train_cfg`` and
             ``test_cfg`` are injected into a shallow copy of it before
             ``build_head`` is called; the caller's dict is not modified.
         n_voxels: stored unchanged on ``self.n_voxels`` (presumably the
             voxel-grid size used elsewhere in the class; confirm).
         anchor_generator: config passed to ``build_anchor_generator``.
         train_cfg: training config; stored and forwarded to the head.
         test_cfg: testing config; stored and forwarded to the head.
         pretrained: not used by this constructor; a warning is emitted
             when it is set so the value is not silently dropped.
         init_cfg: forwarded to the base-class initializer.
     """
     super().__init__(init_cfg=init_cfg)
     if pretrained:
         import warnings

         warnings.warn(
             "'pretrained' is not used by this detector and will be ignored; "
             "consider specifying pretrained weights via init_cfg instead")
     self.backbone = build_backbone(backbone)
     self.neck = build_neck(neck)
     self.neck_3d = build_neck(neck_3d)
     # Work on a copy so a config shared with other models / later builds
     # does not get train_cfg/test_cfg written into it.
     bbox_head = bbox_head.copy()
     bbox_head.update(train_cfg=train_cfg)
     bbox_head.update(test_cfg=test_cfg)
     self.bbox_head = build_head(bbox_head)
     self.n_voxels = n_voxels
     self.anchor_generator = build_anchor_generator(anchor_generator)
     self.train_cfg = train_cfg
     self.test_cfg = test_cfg
コード例 #3
0
 def __init__(self,
              backbone,
              neck=None,
              bbox_head=None,
              train_cfg=None,
              test_cfg=None,
              init_cfg=None,
              pretrained=None):
     """Build the backbone and the optional neck / bbox head from configs.

     Args:
         backbone: config passed to ``build_backbone``.
         neck: optional config passed to ``build_neck``. When None, no
             ``self.neck`` attribute is created.
         bbox_head: optional head config. When given, ``train_cfg`` and
             ``test_cfg`` are injected into a shallow copy of it (the
             caller's dict is left untouched) before ``build_head``. When
             None, no ``self.bbox_head`` attribute is created, mirroring
             how ``neck`` is handled.
         train_cfg: training config; stored and forwarded to the head.
         test_cfg: testing config; stored and forwarded to the head.
         init_cfg: forwarded positionally to the base-class initializer.
         pretrained: not used by this constructor; a warning is emitted
             when it is set so the value is not silently dropped.
     """
     super(SingleStage3DDetector, self).__init__(init_cfg)
     if pretrained:
         import warnings

         warnings.warn(
             "'pretrained' is not used by this detector and will be ignored; "
             "consider specifying pretrained weights via init_cfg instead")
     self.backbone = build_backbone(backbone)
     if neck is not None:
         self.neck = build_neck(neck)
     # bbox_head defaults to None, so guard it like neck instead of calling
     # .update() on None.
     if bbox_head is not None:
         # Copy so the caller's (possibly shared) config is not mutated.
         bbox_head = bbox_head.copy()
         bbox_head.update(train_cfg=train_cfg)
         bbox_head.update(test_cfg=test_cfg)
         self.bbox_head = build_head(bbox_head)
     self.train_cfg = train_cfg
     self.test_cfg = test_cfg
コード例 #4
0
def main():
    """Build the backbone, neck and RPN head from a config and set eval mode.

    The config path comes from ``parse_args()`` and is joined onto the
    module-level ``root``. Nothing is run or saved yet; the process exits
    once the modules have been built.
    """
    import sys

    args = parse_args()

    config = mmcv.Config.fromfile(os.path.join(root, args.config))

    # torch.no_grad() only affects autograd recording of tensor operations;
    # it belongs around the eventual forward pass, not module construction.
    backbone = build_backbone(config.model.backbone)
    neck = build_neck(config.model.neck)
    rpn_head = build_head(config.model.rpn_head)

    backbone.eval()
    neck.eval()
    rpn_head.eval()

    # TODO: run or export the modules here (e.g. feed a dummy
    # torch.randn(1, 3, 800, 800) input under torch.no_grad(), or save
    # rpn_head with torch.jit).

    # exit() is injected by the `site` module for interactive use;
    # sys.exit() raises the same SystemExit without depending on it.
    sys.exit()
コード例 #5
0
ファイル: two_stage.py プロジェクト: anorthman/custom
    def __init__(self, cfg, train_cfg, test_cfg):
        """Create the optional detector components described by ``cfg``.

        ``neck`` and ``rpn_head`` are built when their keys are present.
        ``bbox_head`` and ``mask_head`` are built when their keys are
        present, each preceded by its matching ``*_roi_extractor`` entry,
        which must then also exist in ``cfg`` (indexing a missing key fails).
        ``train_cfg`` and ``test_cfg`` are stored unchanged.
        """
        # NOTE(review): super(BaseDetector, self) starts the MRO lookup after
        # BaseDetector, so an __init__ defined on BaseDetector itself would be
        # skipped here. Confirm this is intentional.
        super(BaseDetector, self).__init__()

        if 'neck' in cfg:
            self.neck = build_neck(cfg['neck'])
        if 'rpn_head' in cfg:
            self.rpn_head = build_head(cfg['rpn_head'])

        # RoI branches: the extractor is always built before the head it
        # feeds; attribute names match the config keys.
        roi_branches = (
            ('bbox_roi_extractor', 'bbox_head'),
            ('mask_roi_extractor', 'mask_head'),
        )
        for extractor_key, head_key in roi_branches:
            if head_key not in cfg:
                continue
            setattr(self, extractor_key, build_roi_extractor(cfg[extractor_key]))
            setattr(self, head_key, build_head(cfg[head_key]))

        self.train_cfg, self.test_cfg = train_cfg, test_cfg