Example #1
0
    def _get_bbox_head_logit(self, conv_feat):
        """Build the 4conv1fc bbox head feature and cache it on the instance.

        The symbol graph is constructed once; later calls return the cached
        ``self._head_feat``. Dispatches between plain conv+relu ("fix_bn")
        and normalized conv ("sync_bn"/"gn") based on the configured
        normalizer.
        """
        if self._head_feat is not None:
            return self._head_feat

        p = self.p
        norm_name = p.normalizer.__name__

        # Four identical 3x3 conv layers; dispatch on the normalizer once
        # per layer so the generated names stay "bbox_conv1".."bbox_conv4".
        for idx in range(1, 5):
            layer_name = "bbox_conv%d" % idx
            if norm_name == "fix_bn":
                conv_feat = X.convrelu(
                    conv_feat, filter=256, kernel=3, name=layer_name)
            elif norm_name in ["sync_bn", "gn"]:
                conv_feat = X.convnormrelu(
                    p.normalizer, conv_feat, filter=256, kernel=3,
                    name=layer_name)
            else:
                raise NotImplementedError(
                    "Unsupported normalizer: {}".format(norm_name))

        flat = X.flatten(conv_feat, name="bbox_feat_flatten")
        feat = X.reshape(flat, (0, 0, 1, 1), name="bbox_feat_reshape")

        # Single 1024-channel fc layer (implemented as a 1x1 conv).
        if norm_name == "fix_bn":
            fc1 = X.convrelu(feat, filter=1024, name="bbox_fc1")
        elif norm_name in ["sync_bn", "gn"]:
            fc1 = X.convnormrelu(p.normalizer, feat, filter=1024,
                                 name="bbox_fc1")
        else:
            raise NotImplementedError(
                "Unsupported normalizer: {}".format(norm_name))

        self._head_feat = fc1
        return self._head_feat
Example #2
0
    def _get_bbox_head_logit(self, conv_feat):
        """Build the 2fc bbox head for one cascade stage.

        NOTE: the usual ``self._head_feat`` early-return cache is
        intentionally disabled here so the head is re-built for
        re-inference in the test stage. Layer names are suffixed with the
        stage identifier, and fc weights/biases are shared via the
        ``self.fc*_weight`` / ``self.fc*_bias`` attributes.
        """
        p = self.p
        stage = self.stage

        flat = X.flatten(conv_feat, name="bbox_feat_flatten_" + stage)
        feat = X.reshape(flat, (0, 0, 1, 1),
                         name="bbox_feat_reshape_" + stage)

        norm_name = p.normalizer.__name__
        if norm_name == "fix_bn":
            fc1 = X.convrelu(
                feat, filter=1024, weight=self.fc1_weight,
                bias=self.fc1_bias, no_bias=False,
                name="bbox_fc1_" + stage)
            fc2 = X.convrelu(
                fc1, filter=1024, weight=self.fc2_weight,
                bias=self.fc2_bias, no_bias=False,
                name="bbox_fc2_" + stage)
        elif norm_name in ["sync_bn", "gn"]:
            fc1 = X.convnormrelu(
                p.normalizer, feat, filter=1024, weight=self.fc1_weight,
                bias=self.fc1_bias, no_bias=False,
                name="bbox_fc1_" + stage)
            fc2 = X.convnormrelu(
                p.normalizer, fc1, filter=1024, weight=self.fc2_weight,
                bias=self.fc2_bias, no_bias=False,
                name="bbox_fc2_" + stage)
        else:
            raise NotImplementedError("Unsupported normalizer: {}".format(
                norm_name))

        self._head_feat = fc2
        return self._head_feat
Example #3
0
    def get_output(self, conv_feat):
        """Attach the RPN heads and return (cls_logit, bbox_delta) symbols.

        Both outputs are cached on the instance so repeated calls return
        the same pair. A shared 3x3 conv (normalizer-dependent) feeds two
        1x1 conv heads: 2 channels per base anchor for classification and
        4 per base anchor for box regression.
        """
        if self._cls_logit is not None and self._bbox_delta is not None:
            return self._cls_logit, self._bbox_delta

        p = self.p
        anchor_cfg = p.anchor_generate
        num_base_anchor = len(anchor_cfg.ratio) * len(anchor_cfg.scale)
        conv_channel = p.head.conv_channel

        norm_name = p.normalizer.__name__
        if norm_name == "fix_bn":
            conv = X.convrelu(
                conv_feat, kernel=3, filter=conv_channel,
                name="rpn_conv_3x3", no_bias=False, init=X.gauss(0.01))
        elif norm_name in ["sync_bn", "gn"]:
            conv = X.convnormrelu(
                p.normalizer, conv_feat, kernel=3, filter=conv_channel,
                name="rpn_conv_3x3", no_bias=False, init=X.gauss(0.01))
        else:
            raise NotImplementedError(
                "Unsupported normalizer: {}".format(norm_name))

        # Under mixed precision, cast back to fp32 before the logit layers.
        if p.fp16:
            conv = X.to_fp32(conv, name="rpn_conv_3x3_fp32")

        self._cls_logit = X.conv(
            conv, filter=2 * num_base_anchor, name="rpn_cls_logit",
            no_bias=False, init=X.gauss(0.01))
        self._bbox_delta = X.conv(
            conv, filter=4 * num_base_anchor, name="rpn_bbox_delta",
            no_bias=False, init=X.gauss(0.01))

        return self._cls_logit, self._bbox_delta
Example #4
0
    def get_rcnn_feature(self, rcnn_feat):
        """Reduce the backbone feature with one 3x3 conv before the RCNN head.

        Output channel count comes from ``p.reduce.channel``; the conv
        variant (plain relu vs. normalized) follows the configured
        normalizer.
        """
        p = self.p
        norm_name = p.normalizer.__name__

        if norm_name == "fix_bn":
            return X.convrelu(
                rcnn_feat, filter=p.reduce.channel, kernel=3,
                name="backbone_reduce")
        if norm_name in ["sync_bn", "gn"]:
            return X.convnormrelu(
                p.normalizer, rcnn_feat, filter=p.reduce.channel, kernel=3,
                name="backbone_reduce")
        raise NotImplementedError("Unsupported normalizer: {}".format(
            norm_name))