示例#1
0
def dcn_resnet_unit(input, name, filter, stride, dilate, proj, norm, **kwargs):
    """ResNet bottleneck unit whose 3x3 conv is a deformable convolution.

    :param input: input symbol
    :param name: name prefix for all symbols created by this unit
    :param filter: number of output channels (bottleneck uses filter // 4)
    :param stride: stride applied by the 3x3 (deformable) conv
    :param dilate: dilation of the 3x3 (deformable) conv
    :param proj: if True, project the shortcut with a strided 1x1 conv + norm
    :param norm: normalizer callable (e.g. batch norm factory)
    :return: output symbol after the residual add and final relu
    """
    # 1x1 reduce
    x = conv(input, name=name + "_conv1", filter=filter // 4)
    x = norm(x, name=name + "_bn1")
    x = relu(x, name=name + "_relu1")

    # offset branch feeding the deformable 3x3
    # 72 channels presumably = 2 coords * 3*3 taps * 4 deformable groups — TODO confirm
    offset = conv(x, name=name + "_conv2_offset", filter=72, kernel=3, stride=stride, dilate=dilate)
    x = mx.sym.contrib.DeformableConvolution(x, offset, kernel=(3, 3),
        stride=(stride, stride), dilate=(dilate, dilate), pad=(1, 1), num_filter=filter // 4,
        num_deformable_group=4, no_bias=True, name=name + "_conv2")
    x = norm(x, name=name + "_bn2")
    x = relu(x, name=name + "_relu2")

    # 1x1 expand
    x = conv(x, name=name + "_conv3", filter=filter)
    x = norm(x, name=name + "_bn3")

    # projection shortcut when the unit changes shape, identity otherwise
    if proj:
        sc = conv(input, name=name + "_sc", filter=filter, stride=stride)
        sc = norm(sc, name=name + "_sc_bn")
    else:
        sc = input

    out = add(x, sc, name=name + "_plus")
    return relu(out, name=name + "_relu")
示例#2
0
def se_v2_resnet_v1b_unit(input, name, filter, stride, dilate, proj, norm,
                          **kwargs):
    """
    ResNet v1b bottleneck unit with a Squeeze-and-Excitation block.

    diff with v1: move the SE module to 3x3 conv

    Fix: forward ``dilate`` to the 3x3 conv. The parameter was accepted but
    silently ignored, unlike the sibling units which pass it through.

    :param input: input symbol
    :param name: name prefix for all symbols created by this unit
    :param filter: number of output channels (bottleneck uses filter // 4)
    :param stride: stride of the 3x3 conv (v1b puts the stride here)
    :param dilate: dilation of the 3x3 conv
    :param proj: if True, project the shortcut with a strided 1x1 conv + norm
    :param norm: normalizer callable
    :return: output symbol after the residual add and final relu
    """
    conv1 = conv(input, name=name + "_conv1", filter=filter // 4)
    bn1 = norm(conv1, name=name + "_bn1")
    relu1 = relu(bn1, name=name + "_relu1")

    # 3x3 conv carries both the stride (v1b convention) and the dilation
    conv2 = conv(relu1,
                 name=name + "_conv2",
                 stride=stride,
                 dilate=dilate,
                 filter=filter // 4,
                 kernel=3)
    bn2 = norm(conv2, name=name + "_bn2")
    relu2 = relu(bn2, name=name + "_relu2")
    # channel re-weighting; squeeze to filter // 16, excite back to filter // 4
    relu2 = se(relu2,
               prefix=name + "_se2",
               f_down=filter // 16,
               f_up=filter // 4)

    conv3 = conv(relu2, name=name + "_conv3", filter=filter)
    bn3 = norm(conv3, name=name + "_bn3")

    # projection shortcut when the unit changes shape, identity otherwise
    if proj:
        shortcut = conv(input, name=name + "_sc", filter=filter, stride=stride)
        shortcut = norm(shortcut, name=name + "_sc_bn")
    else:
        shortcut = input

    eltwise = add(bn3, shortcut, name=name + "_plus")

    return relu(eltwise, name=name + "_relu")
示例#3
0
def trident_resnet_v1b_unit(input, name, id, filter, stride, dilate, proj, **kwargs):
    """
    Compared with v1, v1b moves stride=2 to the 3x3 conv instead of the 1x1 conv and use std in pre-processing
    This is also known as the facebook re-implementation of ResNet(a.k.a. the torch ResNet)

    ``id`` selects the trident branch; shared weights/statistics reuse one
    variable name across branches, per-branch ones get a branch postfix.
    """
    p = kwargs["params"]
    share_bn = p.branch_bn_shared
    share_conv = p.branch_conv_shared
    norm = p.normalizer

    ######################### prepare names #########################
    if id is None:
        conv_postfix = bn_postfix = other_postfix = ""
    else:
        conv_postfix = ("_shared%s" if share_conv else "_branch%s") % id
        bn_postfix = ("_shared%s" if share_bn else "_branch%s") % id
        other_postfix = "_branch%s" % id

    ######################### prepare parameters #########################
    def conv_params(x):
        # shared branches reuse a single weight variable
        return dict(
            weight=X.shared_var(name + "_%s_weight" % x) if share_conv else None,
            name=name + "_%s" % x + conv_postfix
        )

    def bn_params(x):
        ret = dict(
            gamma=X.shared_var(name + "_%s_gamma" % x) if share_bn else None,
            beta=X.shared_var(name + "_%s_beta" % x) if share_bn else None,
            moving_mean=X.shared_var(name + "_%s_moving_mean" % x) if share_bn else None,
            moving_var=X.shared_var(name + "_%s_moving_var" % x) if share_bn else None,
            name=name + "_%s" % x + bn_postfix
        )
        # GroupNorm has no running statistics
        if norm.__name__ == "gn":
            del ret["moving_mean"], ret["moving_var"]
        return ret

    ######################### construct graph #########################
    # 1x1 reduce
    x = conv(input, filter=filter // 4, **conv_params("conv1"))
    x = norm(x, **bn_params("bn1"))
    # NOTE(review): both relus share the name `name + other_postfix` — preserved; confirm intended
    x = relu(x, name=name + other_postfix)

    # 3x3 conv carries stride and dilation (v1b convention)
    x = conv(x, filter=filter // 4, kernel=3, stride=stride, dilate=dilate, **conv_params("conv2"))
    x = norm(x, **bn_params("bn2"))
    x = relu(x, name=name + other_postfix)

    # 1x1 expand
    x = conv(x, filter=filter, **conv_params("conv3"))
    x = norm(x, **bn_params("bn3"))

    # projection shortcut when the unit changes shape, identity otherwise
    if proj:
        sc = conv(input, filter=filter, stride=stride, **conv_params("sc"))
        sc = norm(sc, **bn_params("sc_bn"))
    else:
        sc = input

    out = add(x, sc, name=name + "_plus" + other_postfix)

    return relu(out, name=name + "_relu" + other_postfix)
示例#4
0
def trident_resnet_v1_unit(input, name, id, filter, stride, dilate, proj, **kwargs):
    """
    ResNet v1 trident unit: stride sits on the 1x1 reduce conv (contrast with
    v1b, which strides the 3x3 conv).

    ``id`` selects the trident branch; shared weights/statistics reuse one
    variable name across branches, per-branch ones get a branch postfix.

    Fix: ``bn_params`` now drops ``moving_mean``/``moving_var`` when the
    normalizer is GroupNorm, matching the v1b unit — GroupNorm has no running
    statistics, so passing them would be an error.
    """
    p = kwargs["params"]
    share_bn = p.branch_bn_shared
    share_conv = p.branch_conv_shared
    norm = p.normalizer

    ######################### prepare names #########################
    if id is not None:
        conv_postfix = ("_shared%s" if share_conv else "_branch%s") % id
        bn_postfix = ("_shared%s" if share_bn else "_branch%s") % id
        other_postfix = "_branch%s" % id
    else:
        conv_postfix = ""
        bn_postfix = ""
        other_postfix = ""

    ######################### prepare parameters #########################
    def conv_params(x):
        # shared branches reuse a single weight variable
        return dict(
            weight=X.shared_var(name + "_%s_weight" % x) if share_conv else None,
            name=name + "_%s" % x + conv_postfix
        )

    def bn_params(x):
        ret = dict(
            gamma=X.shared_var(name + "_%s_gamma" % x) if share_bn else None,
            beta=X.shared_var(name + "_%s_beta" % x) if share_bn else None,
            moving_mean=X.shared_var(name + "_%s_moving_mean" % x) if share_bn else None,
            moving_var=X.shared_var(name + "_%s_moving_var" % x) if share_bn else None,
            name=name + "_%s" % x + bn_postfix
        )
        if norm.__name__ == "gn":
            # GroupNorm has no running statistics (consistent with the v1b unit)
            del ret["moving_mean"], ret["moving_var"]
        return ret

    ######################### construct graph #########################
    # v1: the 1x1 reduce conv carries the stride
    conv1 = conv(input, filter=filter // 4, stride=stride, **conv_params("conv1"))
    bn1 = norm(conv1, **bn_params("bn1"))
    relu1 = relu(bn1, name=name + other_postfix)

    conv2 = conv(relu1, filter=filter // 4, kernel=3, dilate=dilate, **conv_params("conv2"))
    bn2 = norm(conv2, **bn_params("bn2"))
    relu2 = relu(bn2, name=name + other_postfix)

    conv3 = conv(relu2, filter=filter, **conv_params("conv3"))
    bn3 = norm(conv3, **bn_params("bn3"))

    # projection shortcut when the unit changes shape, identity otherwise
    if proj:
        shortcut = conv(input, filter=filter, stride=stride, **conv_params("sc"))
        shortcut = norm(shortcut, **bn_params("sc_bn"))
    else:
        shortcut = input

    eltwise = add(bn3, shortcut, name=name + "_plus" + other_postfix)

    return relu(eltwise, name=name + "_relu" + other_postfix)
示例#5
0
    def resnet_trident_unit(cls,
                            data,
                            name,
                            filter,
                            stride,
                            dilate,
                            proj,
                            norm_type,
                            norm_mom,
                            ndev,
                            branch_ids,
                            branch_bn_shared,
                            branch_conv_shared,
                            branch_deform=False):
        """
        One resnet unit is comprised of 2 or 3 convolutions and a shortcut.
        Pre-activation (v2-style) ordering: bn -> relu -> conv per stage,
        applied to every trident branch in ``data``.

        Fix: the projection shortcut conv now carries ``stride``. The main
        path's 3x3 conv is strided, so an unstrided shortcut would produce a
        spatial shape mismatch at the residual add whenever stride != 1.

        :param data: list of branch input symbols
        :param name: name prefix for all symbols created by this unit
        :param filter: number of output channels (bottleneck uses filter // 4)
        :param stride: stride of the 3x3 conv (and of the projection shortcut)
        :param dilate: dilation of the 3x3 conv (also used as its padding)
        :param proj: if True, project the shortcut with a 1x1 conv
        :param norm_type: normalizer type passed to X.normalizer_factory
        :param norm_mom: normalizer momentum
        :param ndev: number of devices for the normalizer (e.g. sync-bn)
        :param branch_ids: ids of the trident branches; defaults to range(len(data))
        :param branch_bn_shared: share bn statistics across branches
        :param branch_conv_shared: share conv weights across branches
        :param branch_deform: use a deformable 3x3 conv instead of a regular one
        :return: list of branch output symbols
        """
        if branch_ids is None:
            branch_ids = range(len(data))

        norm = X.normalizer_factory(type=norm_type, ndev=ndev, mom=norm_mom)

        # pre-activation: normalize and activate before the 1x1 reduce
        bn1 = cls.bn_shared(data,
                            name=name + "_bn1",
                            normalizer=norm,
                            branch_ids=branch_ids,
                            share_weight=branch_bn_shared)
        relu1 = [X.relu(bn) for bn in bn1]
        conv1 = cls.conv_shared(relu1,
                                name=name + "_conv1",
                                num_filter=filter // 4,
                                kernel=(1, 1),
                                branch_ids=branch_ids,
                                share_weight=branch_conv_shared)

        bn2 = cls.bn_shared(conv1,
                            name=name + "_bn2",
                            normalizer=norm,
                            branch_ids=branch_ids,
                            share_weight=branch_bn_shared)
        relu2 = [X.relu(bn) for bn in bn2]
        if not branch_deform:
            # pad == dilate keeps the 3x3 output aligned with the input grid
            conv2 = cls.conv_shared(relu2,
                                    name=name + "_conv2",
                                    num_filter=filter // 4,
                                    kernel=(3, 3),
                                    pad=dilate,
                                    stride=stride,
                                    dilate=dilate,
                                    branch_ids=branch_ids,
                                    share_weight=branch_conv_shared)
        else:
            # offset branch: 72 channels presumably = 2 coords * 3*3 taps * 4
            # deformable groups — TODO confirm against the operator docs
            conv2_offset = cls.conv_shared(relu2,
                                           name=name + "_conv2_offset",
                                           num_filter=72,
                                           kernel=(3, 3),
                                           pad=(1, 1),
                                           stride=(1, 1),
                                           dilate=(1, 1),
                                           no_bias=False,
                                           branch_ids=branch_ids,
                                           share_weight=branch_conv_shared)
            conv2 = cls.deform_conv_shared(relu2,
                                           name=name + "_conv2",
                                           conv_offset=conv2_offset,
                                           num_filter=filter // 4,
                                           kernel=(3, 3),
                                           pad=dilate,
                                           stride=stride,
                                           dilate=dilate,
                                           num_deformable_group=4,
                                           branch_ids=branch_ids,
                                           share_weight=branch_conv_shared)

        bn3 = cls.bn_shared(conv2,
                            name=name + "_bn3",
                            normalizer=norm,
                            branch_ids=branch_ids,
                            share_weight=branch_bn_shared)
        relu3 = [X.relu(bn) for bn in bn3]
        conv3 = cls.conv_shared(relu3,
                                name=name + "_conv3",
                                num_filter=filter,
                                kernel=(1, 1),
                                branch_ids=branch_ids,
                                share_weight=branch_conv_shared)

        if proj:
            # v2-style shortcut branches off the first activated tensor;
            # stride must match the main path's 3x3 conv (bug fix)
            shortcut = cls.conv_shared(relu1,
                                       name=name + "_sc",
                                       num_filter=filter,
                                       kernel=(1, 1),
                                       stride=stride,
                                       branch_ids=branch_ids,
                                       share_weight=branch_conv_shared)
        else:
            shortcut = data

        return [X.add(conv3_i, shortcut_i, name=name + "_plus_branch{}".format(i)) \
                for i, conv3_i, shortcut_i in zip(branch_ids, conv3, shortcut)]