Exemplo n.º 1
0
    def resnet_unit(cls, data, name, filter, stride, dilate, proj, norm_type,
                    norm_mom, ndev):
        """
        Build one bottleneck residual unit: three conv/norm stages plus a shortcut.

        Main branch: conv1 -> norm -> relu -> conv2 (kernel=3) -> norm -> relu
        -> conv3 -> norm. It is summed with the shortcut and passed through a
        final relu. conv1, conv3 and the projection conv use the ``conv``
        helper's default kernel (presumably 1x1 -- confirm in ``conv``).

        :param data: input symbol of the unit
        :param name: prefix used for every layer name created here
        :param filter: output channels of the unit; conv1/conv2 use filter // 4
        :param stride: stride applied to conv1 and to the projection shortcut
        :param dilate: dilation of the 3x3 conv2
        :param proj: if true, the shortcut is a conv + norm projection of
            ``data``; otherwise ``data`` is added unchanged
        :param norm_type: normalizer type forwarded to ``normalizer_factory``
        :param norm_mom: normalizer momentum forwarded to ``normalizer_factory``
        :param ndev: device count forwarded to ``normalizer_factory``
        :return: output symbol of the unit (relu of branch + shortcut)
        """
        norm = normalizer_factory(type=norm_type, ndev=ndev, mom=norm_mom)
        bottleneck = filter // 4

        # Stage 1: channel-reducing conv, carries the unit's stride.
        branch = conv(data, name=name + "_conv1", filter=bottleneck,
                      stride=stride)
        branch = norm(data=branch, name=name + "_bn1")
        branch = relu(branch, name=name + "_relu1")

        # Stage 2: 3x3 conv, the only place dilation is applied.
        branch = conv(branch, name=name + "_conv2", filter=bottleneck,
                      kernel=3, dilate=dilate)
        branch = norm(data=branch, name=name + "_bn2")
        branch = relu(branch, name=name + "_relu2")

        # Stage 3: restore the unit's output width; no relu before the sum.
        branch = conv(branch, name=name + "_conv3", filter=filter)
        branch = norm(data=branch, name=name + "_bn3")

        if not proj:
            identity = data
        else:
            # Projection shortcut so the shapes match the main branch.
            identity = conv(data, name=name + "_sc", filter=filter,
                            stride=stride)
            identity = norm(data=identity, name=name + "_sc_bn")

        merged = add(branch, identity, name=name + "_plus")
        return relu(merged, name=name + "_relu")
Exemplo n.º 2
0
    def resnet_c1(cls, data, use_3x3_conv0, use_bn_preprocess, norm_type,
                  norm_mom, ndev):
        """
        Build the ResNet stem (C1): optional whitening, conv/norm/relu, max pool.

        :param data: image symbol
        :param use_3x3_conv0: if true, use three 3x3 convs instead of one 7x7 conv
        :param use_bn_preprocess: if true, apply ``whiten`` (batchnorm used as a
            whitening layer, introduced by tornadomeet) before the first conv
        :param norm_type: normalizer type forwarded to ``normalizer_factory``,
            e.g. "local", "fixbn", "syncbn" (see ``normalizer_factory`` for
            the full list)
        :param norm_mom: normalization momentum, specific to batchnorm
        :param ndev: number of GPUs for sync batchnorm
        :return: C1 symbol (output of the "pool0" max pool)
        """
        if use_bn_preprocess:
            data = whiten(data, name="bn_data")

        norm = normalizer_factory(type=norm_type, ndev=ndev, mom=norm_mom)

        # Each row: (conv name, norm name, relu name, extra conv kwargs).
        if use_3x3_conv0:
            # Deep stem: only the first 3x3 conv is given an explicit stride=2.
            stem_layers = (
                ("conv0_0", 'bn0_0', 'relu0_0', dict(kernel=3, stride=2)),
                ("conv0_1", 'bn0_1', 'relu0_1', dict(kernel=3)),
                ("conv0_2", 'bn0_2', 'relu0_2', dict(kernel=3)),
            )
        else:
            stem_layers = (
                ("conv0", 'bn0', 'relu0', dict(kernel=7, stride=2)),
            )

        for conv_name, bn_name, relu_name, conv_kwargs in stem_layers:
            data = conv(data, filter=64, name=conv_name, **conv_kwargs)
            data = norm(data, name=bn_name)
            data = relu(data, name=relu_name)

        return pool(data, name="pool0", kernel=3, stride=2, pool_type='max')
Exemplo n.º 3
0
 class NormalizeParam:
     """Normalization settings: exposes ``normalizer`` built by ``normalizer_factory``.

     NOTE(review): presumably consumed by the backbone/network builders as
     ``<params>.normalizer`` -- confirm against callers.
     """
     # Alternative kept for reference: synchronized batchnorm across 8 devices.
     # normalizer = normalizer_factory(type="syncbn", ndev=8, wd_mult=1.0)
     # Active choice: the "local" normalization type.
     normalizer = normalizer_factory(type="local")
Exemplo n.º 4
0
    def __init__(self, pBackbone):
        """Keep the backbone parameter object and create a fresh ``Builder``.

        :param pBackbone: parameter object; ``get_feature`` reads its
            ``normalizer`` and ``fp16`` attributes
        """
        self.p, self.b = pBackbone, Builder()

    def get_feature(self, data):
        """Build the backbone on ``data``, cache it on ``self.symbol``, and return it.

        Uses the fixed configuration ("msra", 50, "c5x2") together with the
        ``normalizer`` and ``fp16`` settings from the parameter object.
        """
        params = self.p
        backbone = self.b.get_backbone(data, "msra", 50, "c5x2",
                                       params.normalizer, params.fp16)
        self.symbol = backbone
        return backbone


# TODO: make importing this module yield a ResNetV1Builder instance by replacing the module object in sys.modules, e.g.:
# import sys
# sys.modules[__name__] = ResNetV1Builder()

if __name__ == "__main__":
    #############################################################
    # python -m poi.models.common.backbone.resnet_v1
    #############################################################
    # Demo: build the "msra" / 50 (presumably ResNet-50) "fpn" backbone with a
    # "fixbn" normalizer in fp16, then print a per-layer summary.
    import mxnet as mx

    builder = Builder()
    fpn_outputs = builder.get_backbone(
        var("data"),
        "msra",
        50,
        "fpn",
        normalizer_factory(type="fixbn"),
        fp16=True,
    )
    # get_backbone's result is grouped so print_summary sees a single symbol.
    mx.viz.print_summary(mx.sym.Group(fpn_outputs))