예제 #1
0
 def output_shape(self):
     """Describe every output feature as a ShapeSpec.

     Returns:
         dict mapping each name in ``self._out_features`` (in that order) to a
         ``ShapeSpec`` built from the per-feature channel and stride tables.
     """
     shapes = {}
     for feature in self._out_features:
         shapes[feature] = ShapeSpec(
             channels=self._out_feature_channels[feature],
             stride=self._out_feature_strides[feature],
         )
     return shapes
예제 #2
0
def build_backbone(cfg, input_shape=None):
    """Instantiate the backbone selected by ``cfg.BACKBONE.NAME``.

    If ``cfg.BACKBONE`` contains a ``PLAN`` entry, the named backbone is built
    first and then passed to the registry's ``'PLAN'`` builder; that wrapper is
    what gets returned.

    Args:
        cfg: config node providing ``BACKBONE`` (with ``NAME``) and
            ``INPUT.PIXEL_MEAN``.
        input_shape: optional ShapeSpec for the network input. When omitted,
            it defaults to one channel per entry of ``cfg.INPUT.PIXEL_MEAN``.

    Returns:
        The constructed object, which must be a ``Backbone`` instance.
    """
    if input_shape is None:
        input_shape = ShapeSpec(channels=len(cfg.INPUT.PIXEL_MEAN))

    use_plan = 'PLAN' in cfg.BACKBONE
    base_builder = BACKBONE_REGISTRY.get(cfg.BACKBONE.NAME)
    backbone = base_builder(cfg, input_shape)

    if use_plan:
        # Wrap the base backbone; both builders receive the same input_shape.
        plan_builder = BACKBONE_REGISTRY.get('PLAN')
        backbone = plan_builder(cfg, input_shape, backbone)

    assert isinstance(backbone, Backbone)
    return backbone
예제 #3
0
파일: head.py 프로젝트: major196512/vistem
    def __init__(self, cfg, input_shape):
        """Configure the ROI box head: matcher, sampling, pooler, losses, box head.

        Args:
            cfg: config node. Reads ``META_ARCH.NUM_CLASSES``, the
                ``META_ARCH.ROI`` sub-tree (``IN_FEATURES``, ``MATCHER``,
                ``SAMPLING``, ``BOX_POOLING``, ``BOX_LOSS``, ``TEST``) and
                ``TEST.SCORE_THRESH`` / ``TEST.DETECTIONS_PER_IMAGE``.
            input_shape: mapping from feature name to a shape spec exposing
                ``.channels`` and ``.stride`` (presumably the backbone's
                ``output_shape()`` -- TODO confirm against the caller).

        Raises:
            AssertionError: if an entry of ``ROI.IN_FEATURES`` is missing from
                ``input_shape``, if the selected features do not share exactly
                one channel count (an empty feature list also fails), or if
                the loss-weight mapping lacks ``'loss_cls'`` or ``'loss_loc'``.
        """
        super().__init__(cfg)

        self.in_features = cfg.META_ARCH.ROI.IN_FEATURES
        self.num_classes = cfg.META_ARCH.NUM_CLASSES

        # Validate feature names *before* indexing input_shape with them.
        # Otherwise a typo in IN_FEATURES surfaces as a bare KeyError and this
        # descriptive message is never reached.
        for feat in self.in_features:
            assert feat in input_shape.keys(), (
                f"'{feat}' is not in backbone({input_shape.keys()})"
            )

        # Pooled features from every level go through a single BoxHead built
        # with in_channels[0] (below), so all levels must agree on channels.
        in_channels = [input_shape[f].channels for f in self.in_features]
        assert len(set(in_channels)) == 1, in_channels

        # Matcher built from the configured IoU thresholds and labels.
        iou_thres = cfg.META_ARCH.ROI.MATCHER.IOU_THRESHOLDS
        iou_labels = cfg.META_ARCH.ROI.MATCHER.IOU_LABELS
        allow_low_quality_matches = cfg.META_ARCH.ROI.MATCHER.LOW_QUALITY_MATCHES
        self.proposal_matcher = Matcher(
            iou_thres,
            iou_labels,
            allow_low_quality_matches=allow_low_quality_matches,
        )

        # Proposal sampling settings (consumed by methods not visible here).
        self.proposal_append_gt = cfg.META_ARCH.ROI.SAMPLING.PROPOSAL_APPEND_GT
        self.batch_size_per_image = cfg.META_ARCH.ROI.SAMPLING.BATCH_SIZE_PER_IMAGE
        self.positive_fraction = cfg.META_ARCH.ROI.SAMPLING.POSITIVE_FRACTION

        # Pooling parameters and module: one scale per input feature, equal to
        # 1 / stride of that feature map.
        box_pooler_type = cfg.META_ARCH.ROI.BOX_POOLING.TYPE
        box_pooler_resolution = cfg.META_ARCH.ROI.BOX_POOLING.RESOLUTION
        box_pooler_sampling_ratio = cfg.META_ARCH.ROI.BOX_POOLING.SAMPLING_RATIO
        self.box_pooler = ROIPooler(
            output_size=box_pooler_resolution,
            scales=tuple(1.0 / input_shape[k].stride for k in self.in_features),
            sampling_ratio=box_pooler_sampling_ratio,
            pooler_type=box_pooler_type,
        )

        # Loss parameters. A scalar LOSS_WEIGHT applies to both losses.
        # Integers are accepted as well as floats, e.g. ``LOSS_WEIGHT: 1`` in
        # YAML; previously an int fell through and crashed on the ``in``
        # checks below. Any other value must be a mapping with explicit
        # 'loss_cls' and 'loss_loc' entries.
        self.loss_weight = cfg.META_ARCH.ROI.BOX_LOSS.LOSS_WEIGHT
        self.smooth_l1_beta = cfg.META_ARCH.ROI.BOX_LOSS.SMOOTH_L1_BETA

        if isinstance(self.loss_weight, (int, float)):
            self.loss_weight = {
                "loss_cls": self.loss_weight,
                "loss_loc": self.loss_weight,
            }
        assert 'loss_cls' in self.loss_weight
        assert 'loss_loc' in self.loss_weight

        # Inference parameters.
        bbox_reg_weights = cfg.META_ARCH.ROI.TEST.BBOX_REG_WEIGHTS
        self.box2box_transform = Box2BoxTransform(weights=bbox_reg_weights)

        self.test_nms_thresh = cfg.META_ARCH.ROI.TEST.NMS_THRESH
        # NOTE: the misspelled attribute name ("threshhold") is kept as-is so
        # any other code reading it keeps working.
        self.score_threshhold = cfg.TEST.SCORE_THRESH
        self.max_detections_per_image = cfg.TEST.DETECTIONS_PER_IMAGE

        # ROI head: consumes pooled features of shape
        # (channels, resolution, resolution).
        self.box_head = BoxHead(
            cfg,
            ShapeSpec(channels=in_channels[0],
                      height=box_pooler_resolution,
                      width=box_pooler_resolution))