Example #1
def identity_block(input_tensor, kernel_size, filters, stage, block, model_name='resnet50_v20', dilate=(1, 1)):
    """The identity block is the block that has no conv layer at shortcut.

    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of the middle conv layer on the main path
        filters: list of integers, the filters of the 3 conv layers on the main path
        stage: integer, current stage label, used for generating layer names
        block: integer, current block index within the stage, used for generating layer names
        model_name: string prefix for all layer names
        dilate: dilation of the middle conv layer

    # Returns
        Output tensor for the block.
    """
    filters1, filters2, filters3 = filters
    conv_name_base = model_name + '_stage' + str(stage)
    bn_name_base = conv_name_base

    x = Norm(input_tensor, fix_gamma=False, use_global_stats=False, eps=1e-5, name=bn_name_base + '_batchnorm' + str(block + 0))
    x = Activation(x, name=bn_name_base + '_activation' + str(block + 0), act_type='relu')
    x = Convolution(x, kernel=(1, 1), num_filter=filters1, no_bias=True, name=conv_name_base + '_conv' + str(block + 1))

    x = Norm(x, fix_gamma=False, use_global_stats=False, eps=1e-5, name=bn_name_base + '_batchnorm' + str(block + 1))
    x = Activation(x, name=bn_name_base + '_activation' + str(block + 1), act_type='relu')
    x = Convolution(x, kernel=(kernel_size, kernel_size), pad=dilate, no_bias=True, num_filter=filters2, name=conv_name_base + '_conv' + str(block + 2), dilate=dilate)

    x = Norm(x, fix_gamma=False, use_global_stats=False, eps=1e-5, name=bn_name_base + '_batchnorm' + str(block + 2))
    x = Activation(x, name=bn_name_base + '_activation' + str(block + 2), act_type='relu')
    x = Convolution(x, kernel=(1, 1), num_filter=filters3, no_bias=True, name=conv_name_base + '_conv' + str(block + 3))

    # Pre-activation (ResNet-v2) block: the identity shortcut is added last.
    x = x + input_tensor
    return x
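These snippets call thin aliases over MXNet's symbol API rather than mx.sym directly. The alias module is not shown on this page, so the mapping below is a minimal sketch inferred from the call signatures used above; treat it as an assumption, not the original import block.

import mxnet as mx

# Assumed aliases (inferred, not from the source): each call above matches
# the corresponding mx.sym operator's signature.
Norm = mx.sym.BatchNorm
Convolution = mx.sym.Convolution
Activation = mx.sym.Activation
Pooling = mx.sym.Pooling
Flatten = mx.sym.Flatten
FullyConnected = mx.sym.FullyConnected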
Example #2
def residual_unit(data, num_filter, stride, dilate, dim_match, name):
    s = stride
    d = dilate

    bn1 = BN(data=data,
             fix_gamma=fix_gamma,
             use_global_stats=use_global_stats,
             eps=eps,
             momentum=bn_mom,
             name=name + '_bn1')
    act1 = Relu(data=bn1, act_type='relu', name=name + '_relu1')
    conv1 = Conv(data=act1,
                 num_filter=int(num_filter * 0.25),
                 kernel=(1, 1),
                 no_bias=True,
                 name=name + '_conv1')

    bn2 = BN(data=conv1,
             fix_gamma=fix_gamma,
             use_global_stats=use_global_stats,
             eps=eps,
             momentum=bn_mom,
             name=name + '_bn2')
    act2 = Relu(data=bn2, act_type='relu', name=name + '_relu2')
    conv2 = Conv(data=act2,
                 num_filter=int(num_filter * 0.25),
                 kernel=(3, 3),
                 pad=(d, d),
                 stride=(s, s),
                 dilate=(d, d),
                 no_bias=True,
                 name=name + '_conv2')

    bn3 = BN(data=conv2,
             fix_gamma=fix_gamma,
             use_global_stats=use_global_stats,
             eps=eps,
             momentum=bn_mom,
             name=name + '_bn3')
    act3 = Relu(data=bn3, act_type='relu', name=name + '_relu3')
    conv3 = Conv(data=act3,
                 num_filter=num_filter,
                 kernel=(1, 1),
                 no_bias=True,
                 name=name + '_conv3')
    if dim_match:
        shortcut = data
    else:
        shortcut = Conv(data=act1,
                        num_filter=num_filter,
                        kernel=(1, 1),
                        stride=(s, s),
                        no_bias=True,
                        name=name + '_sc')

    shortcut._set_attr(mirror_stage='True')
    return conv3 + shortcut
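residual_unit (and the ASPP head below) reads several free names from module scope. A minimal sketch of those assumed globals, with placeholder values; the real source file defines them elsewhere:

import mxnet as mx

# Assumed operator aliases (not shown on this page).
BN = mx.sym.BatchNorm
Relu = mx.sym.Activation
Conv = mx.sym.Convolution

# Assumed batch-norm configuration; the values here are placeholders.
fix_gamma = False
use_global_stats = False
eps = 1e-5
bn_mom = 0.9
args = {}  # extra kwargs forwarded to some BN calls in the ASPP head

The shortcut._set_attr(mirror_stage='True') call marks the shortcut as a mirroring boundary for MXNet's memory-saving (memonger) pass; it does not change the computed value.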
def atrous_spatial_pyramid_pooling(feat, rate, aspp_with_separable_conv, oc_context=False):
    conv_1x1 = Conv(feat, num_filter=256, kernel=(1, 1), name="aspp_1x1")
    conv_1x1 = BN(conv_1x1, use_global_stats=use_global_stats, fix_gamma=fix_gamma,
                  momentum=bn_mom, name="aspp_1x1_bn", eps=eps, **args)
    conv_1x1 = Relu(conv_1x1, act_type='relu', name='aspp_1x1_relu')

    if aspp_with_separable_conv:
        conv_3x3_d6 = Sepconv(data=feat, in_channel=2048, num_filter=256, stride=1,
                              dilate=6 * rate, name="aspp_3x3_d6")
        conv_3x3_d6 = BN(conv_3x3_d6, use_global_stats=use_global_stats, fix_gamma=fix_gamma,
                         momentum=bn_mom, name="aspp_3x3_d6_bn", eps=eps, **args)
        conv_3x3_d6 = Relu(conv_3x3_d6, act_type='relu', name='aspp_3x3_d6_relu')
        conv_3x3_d12 = Sepconv(data=feat, in_channel=2048, num_filter=256, stride=1,
                               dilate=12 * rate, name="aspp_3x3_d12")
        conv_3x3_d12 = BN(conv_3x3_d12, use_global_stats=use_global_stats, fix_gamma=fix_gamma,
                          momentum=bn_mom, name="aspp_3x3_d12_bn", eps=eps, **args)
        conv_3x3_d12 = Relu(conv_3x3_d12, act_type='relu', name='aspp_3x3_d12_relu')
        conv_3x3_d18 = Sepconv(data=feat, in_channel=2048, num_filter=256, stride=1,
                               dilate=18 * rate, name="aspp_3x3_d18")
        conv_3x3_d18 = BN(conv_3x3_d18, use_global_stats=use_global_stats, fix_gamma=fix_gamma,
                          momentum=bn_mom, name="aspp_3x3_d18_bn", eps=eps, **args)
        conv_3x3_d18 = Relu(conv_3x3_d18, act_type='relu', name='aspp_3x3_d18_relu')
    else:
        conv_3x3_d6 = Conv(feat, num_filter=256, kernel=(3, 3), dilate=(6 * rate, 6 * rate),
                           pad=(6 * rate, 6 * rate), name="aspp_3x3_d6")
        conv_3x3_d6 = BN(conv_3x3_d6, use_global_stats=use_global_stats, fix_gamma=fix_gamma,
                         momentum=bn_mom, name="aspp_3x3_d6_bn", eps=eps)
        conv_3x3_d6 = Relu(conv_3x3_d6, act_type='relu', name='aspp_3x3_d6_relu')
        conv_3x3_d12 = Conv(feat, num_filter=256, kernel=(3, 3), dilate=(12 * rate, 12 * rate),
                            pad=(12 * rate, 12 * rate), name="aspp_3x3_d12")
        conv_3x3_d12 = BN(conv_3x3_d12, use_global_stats=use_global_stats, fix_gamma=fix_gamma,
                          momentum=bn_mom, name="aspp_3x3_d12_bn", eps=eps)
        conv_3x3_d12 = Relu(conv_3x3_d12, act_type='relu', name='aspp_3x3_d12_relu')
        conv_3x3_d18 = Conv(feat, num_filter=256, kernel=(3, 3), dilate=(18 * rate, 18 * rate),
                            pad=(18 * rate, 18 * rate), name="aspp_3x3_d18")
        conv_3x3_d18 = BN(conv_3x3_d18, use_global_stats=use_global_stats, fix_gamma=fix_gamma,
                          momentum=bn_mom, name="aspp_3x3_d18_bn", eps=eps)
        conv_3x3_d18 = Relu(conv_3x3_d18, act_type='relu', name='aspp_3x3_d18_relu')

    if oc_context:
        gap = oc_context_block(feat, 128, 256, 256, resample_rate=2)
    else:
        gap = Pool(feat, kernel=(1, 1), global_pool=True, pool_type="avg", name="aspp_gap")
    gap = Conv(gap, num_filter=256, kernel=(1, 1), name="aspp_gap_1x1")
    gap = BN(gap, use_global_stats=use_global_stats, fix_gamma=fix_gamma, momentum=bn_mom,
             name="aspp_gap_1x1_bn", eps=eps, **args)
    if not oc_context:
        gap = Relu(gap, act_type='relu', name='aspp_gap_1x1_relu')
        gap = broadcast_like(gap, conv_1x1, name="aspp_gap_broadcast")
    aspp = concat(conv_1x1, conv_3x3_d6, conv_3x3_d12, conv_3x3_d18, gap, dim=1, name="aspp_concat")
    aspp_1x1 = Conv(aspp, num_filter=256, kernel=(1, 1), name="aspp_concat_1x1")
    aspp_1x1 = BN(aspp_1x1, use_global_stats=use_global_stats, fix_gamma=fix_gamma, momentum=bn_mom,
                  name="aspp_concat_1x1_bn", eps=eps, **args)
    aspp_1x1._set_attr(mirror_stage='True')
    aspp_1x1 = Relu(aspp_1x1, act_type='relu', name='aspp_concat_1x1_relu')
    return aspp_1x1
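A usage sketch (assumed names: the 'feat' variable stands in for a stride-16, 2048-channel backbone output, and 21 is an arbitrary class count): ASPP is applied to the backbone feature map and followed by a 1x1 classifier, as in DeepLabv3.

feat = mx.sym.Variable('feat')  # stand-in for a stride-16, 2048-channel feature map
head = atrous_spatial_pyramid_pooling(feat, rate=1, aspp_with_separable_conv=False)
score = mx.sym.Convolution(head, num_filter=21, kernel=(1, 1), name='score')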
Example #4
def detai_net(inputs, classes=1000, use_att=False, model_name='resnetv20'):
    """Instantiates a ResNet50-v2 backbone with an optional attention branch.

    With use_att=True, inputs is a pair: the image tensor and an extra
    attention feature vector that is concatenated before the classifier.
    """
    if use_att:
        x = inputs[0]
    else:
        x = inputs

    x = Norm(x, fix_gamma=True, use_global_stats=False, eps=1e-5, name=model_name + '_batchnorm0')
    x = Convolution(x, num_filter=64, kernel=(7, 7), stride=(2, 2), pad=(3, 3), no_bias=True, name=model_name + '_conv0')
    x = Norm(x, fix_gamma=False, use_global_stats=False, eps=1e-5, name=model_name + '_batchnorm1')
    x = Activation(x, name=model_name + '_relu0', act_type='relu')
    x = Pooling(x, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max', name=model_name + '_pool0')

    x = conv_block(x, 3, [64, 64, 256], stage=1, block=0, strides=(1, 1), model_name=model_name)
    x = identity_block(x, 3, [64, 64, 256], stage=1, block=3, model_name=model_name)
    x = identity_block(x, 3, [64, 64, 256], stage=1, block=6, model_name=model_name)

    x = conv_block(x, 3, [128, 128, 512], stage=2, block=0, model_name=model_name)
    x = identity_block(x, 3, [128, 128, 512], stage=2, block=3, model_name=model_name)
    x = identity_block(x, 3, [128, 128, 512], stage=2, block=6, model_name=model_name)
    x = identity_block(x, 3, [128, 128, 512], stage=2, block=9, model_name=model_name)

    x = conv_block(x, 3, [256, 256, 1024], stage=3, block=0, model_name=model_name)
    x = identity_block(x, 3, [256, 256, 1024], stage=3, block=3, model_name=model_name)
    x = identity_block(x, 3, [256, 256, 1024], stage=3, block=6, model_name=model_name)
    x = identity_block(x, 3, [256, 256, 1024], stage=3, block=9, model_name=model_name)
    x = identity_block(x, 3, [256, 256, 1024], stage=3, block=12, model_name=model_name)
    x = identity_block(x, 3, [256, 256, 1024], stage=3, block=15, model_name=model_name)

    x = conv_block(x, 3, [512, 512, 2048], stage=4, block=0, model_name=model_name)
    x = identity_block(x, 3, [512, 512, 2048], stage=4, block=3, model_name=model_name)
    x = identity_block(x, 3, [512, 512, 2048], stage=4, block=6, model_name=model_name)

    x = Norm(x, fix_gamma=False, use_global_stats=False, eps=1e-5, name=model_name + '_batchnorm2')
    x = Activation(x, name=model_name + '_relu1', act_type='relu')

    x = Pooling(x, kernel=(3, 3), stride=(2, 2), pool_type='avg', global_pool=True, name=model_name + '_pool1')
    x = Flatten(x, name=model_name + '_flatten0')

    if use_att:
        # Concatenate the 2048-d pooled features with the attention features
        # (4096-d total) before the bias-free classifier.
        x = mx.sym.Concat(x, inputs[1], dim=-1)
        weight = mx.sym.Variable(model_name + '_concat_dense_weight', shape=(classes, 4096))
        x = FullyConnected(x, num_hidden=classes, weight=weight, no_bias=True, name=model_name + '_dense1')
    else:
        x = FullyConnected(x, num_hidden=classes, name=model_name + '_dense1')
    return x
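A minimal driver for detai_net without the attention branch, mirroring the commented-out example at the bottom of this page (the class count is arbitrary):

import mxnet as mx

data = mx.sym.Variable('data')
logits = detai_net(data, classes=1000, use_att=False)
net = mx.gluon.SymbolBlock(logits, [data])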
Example #5
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2), model_name='resnet50_v20', dilate=(1, 1)):
    """A block that has a conv layer at shortcut.

    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of the middle conv layer on the main path
        filters: list of integers, the filters of the 3 conv layers on the main path
        stage: integer, current stage label, used for generating layer names
        block: integer, current block index within the stage, used for generating layer names
        strides: stride of the middle conv layer and of the shortcut projection
        model_name: string prefix for all layer names
        dilate: dilation of the middle conv layer

    # Returns
        Output tensor for the block.

    Note that downsampling blocks use strides=(2, 2) on both the main path
    and the shortcut; stage 1 (and dilated stages) pass strides=(1, 1).
    """
    filters1, filters2, filters3 = filters

    conv_name_base = model_name + '_stage' + str(stage)
    bn_name_base = conv_name_base

    x = Norm(input_tensor, fix_gamma=False, use_global_stats=False, eps=1e-5, name=bn_name_base + '_batchnorm' + str(block + 0))
    # The pre-activation output feeds both the main path and the shortcut.
    x_sc = Activation(x, name=bn_name_base + '_activation' + str(block + 0), act_type='relu')
    x = Convolution(x_sc, kernel=(1, 1), num_filter=filters1, no_bias=True, name=conv_name_base + '_conv' + str(block + 0))

    x = Norm(x, fix_gamma=False, use_global_stats=False, eps=1e-5, name=bn_name_base + '_batchnorm' + str(block + 1))
    x = Activation(x, name=bn_name_base + '_activation' + str(block + 1), act_type='relu')
    x = Convolution(x, kernel=(kernel_size, kernel_size), stride=strides, pad=dilate, num_filter=filters2,
                    no_bias=True, dilate=dilate, name=conv_name_base + '_conv' + str(block + 1))

    x = Norm(x, fix_gamma=False, use_global_stats=False, eps=1e-5, name=bn_name_base + '_batchnorm' + str(block + 2))
    x = Activation(x, name=bn_name_base + '_activation' + str(block + 2), act_type='relu')
    x = Convolution(x, kernel=(1, 1), num_filter=filters3, no_bias=True, name=conv_name_base + '_conv' + str(block + 2))

    shortcut = Convolution(x_sc, kernel=(1, 1), stride=strides, num_filter=filters3,
                           no_bias=True, name=conv_name_base + '_conv' + str(block + 3))

    x = x + shortcut
    return x
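The block's shape arithmetic can be checked with infer_shape (a sketch; the input size is arbitrary): strides=(2, 2) halves the spatial dims while filters3 sets the output channel count.

data = mx.sym.Variable('data')
out = conv_block(data, 3, [128, 128, 512], stage=2, block=0)
_, out_shapes, _ = out.infer_shape(data=(1, 256, 56, 56))
print(out_shapes[0])  # (1, 512, 28, 28)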
Example #6
def ResNet18_V2(inputs, classes=1000, model_name='resnetv20'):
    """Instantiates the ResNet18 architecture.

    Note: this snippet assumes basic-block variants of conv_block and
    identity_block (two filter counts, a short_connect flag), not the
    bottleneck versions shown in Examples 1 and 5.
    """
    x = inputs
    x = Norm(x, fix_gamma=True, use_global_stats=False, eps=1e-5, name=model_name + '_batchnorm0')
    x = Convolution(x, num_filter=64, kernel=(7, 7), stride=(2, 2), pad=(3, 3), no_bias=True, name=model_name + '_conv0')
    x = Norm(x, fix_gamma=False, use_global_stats=False, eps=1e-5, name=model_name + '_batchnorm1')
    x = Activation(x, name=model_name + '_relu0', act_type='relu')
    x = Pooling(x, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max', name=model_name + '_pool0')

    x = conv_block(x, 3, [64, 64], stage=1, block=0, strides=(1, 1), short_connect=False, model_name=model_name)
    x = identity_block(x, 3, [64, 64], stage=1, block=2, model_name=model_name)

    x = conv_block(x, 3, [128, 128], stage=2, block=0, model_name=model_name)
    x = identity_block(x, 3, [128, 128], stage=2, block=3, model_name=model_name)

    x = conv_block(x, 3, [256, 256], stage=3, block=0, model_name=model_name)
    x = identity_block(x, 3, [256, 256], stage=3, block=3, model_name=model_name)

    x = conv_block(x, 3, [512, 512], stage=4, block=0, model_name=model_name)
    x = identity_block(x, 3, [512, 512], stage=4, block=3, model_name=model_name)

    x = Norm(x, fix_gamma=False, use_global_stats=False, eps=1e-5, name=model_name + '_batchnorm2')
    x = Activation(x, name=model_name + '_relu1', act_type='relu')

    x = Pooling(x, kernel=(3, 3), stride=(2, 2), pool_type='avg', global_pool=True, name=model_name + '_pool1')
    x = Flatten(x, name=model_name + '_flatten0')

    # Second feature view: same direction, with the L2 norm clipped to 8.
    xl2 = mx.sym.norm(x, axis=1, keepdims=True)
    xl2 = mx.sym.clip(xl2, 0, 8)
    x1 = mx.sym.L2Normalization(x)
    x1 = mx.sym.broadcast_mul(x1, xl2)

    # Both heads share the same classifier weight and bias.
    weight = mx.sym.Variable(model_name + '_dense1_weight')
    bias = mx.sym.Variable(model_name + '_dense1_bias')

    output = FullyConnected(x, num_hidden=classes, weight=weight, bias=bias, name=model_name + '_dense1')

#    weight1 = mx.sym.L2Normalization(weight)
    output1 = FullyConnected(x1, num_hidden=classes, weight=weight, bias=bias, name=model_name + '_dense2')

    return output, output1
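What the clipped-norm branch computes, as a NumPy sketch (illustration only, not the author's code): x1 keeps each feature vector's direction but caps its length at 8, while the raw features x are left untouched; both views feed classifiers that share weights.

import numpy as np

f = np.random.randn(4, 512)                        # stand-in for pooled features
norms = np.linalg.norm(f, axis=1, keepdims=True)
f1 = f / norms * np.clip(norms, 0, 8)              # unchanged where ||f|| <= 8, capped otherwise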
Example #7
def atten_net(inputs, classes=200, model_name='resnet50_v10'):
    """Instantiates the ResNet50 architecture.
    """
    x = inputs

    x = Convolution(x,
                    num_filter=64,
                    kernel=(7, 7),
                    stride=(2, 2),
                    pad=(3, 3),
                    no_bias=True,
                    name=model_name + '_conv0')
    x = Norm(x,
             fix_gamma=False,
             use_global_stats=False,
             eps=1e-5,
             name=model_name + '_batchnorm0')
    x = Activation(x, name=model_name + '_relu0', act_type='relu')
    x = Pooling(x,
                kernel=(3, 3),
                stride=(2, 2),
                pad=(1, 1),
                pool_type='max',
                name=model_name + '_pool0')

    x = conv_block(x,
                   3, [64, 64, 256],
                   stage=1,
                   block=0,
                   strides=(1, 1),
                   model_name=model_name)
    x = identity_block(x,
                       3, [64, 64, 256],
                       stage=1,
                       block=4,
                       model_name=model_name)
    x = identity_block(x,
                       3, [64, 64, 256],
                       stage=1,
                       block=7,
                       model_name=model_name)

    x = conv_block(x,
                   3, [128, 128, 512],
                   stage=2,
                   block=0,
                   model_name=model_name)
    x = identity_block(x,
                       3, [128, 128, 512],
                       stage=2,
                       block=4,
                       model_name=model_name)
    x = identity_block(x,
                       3, [128, 128, 512],
                       stage=2,
                       block=7,
                       model_name=model_name)
    x = identity_block(x,
                       3, [128, 128, 512],
                       stage=2,
                       block=10,
                       model_name=model_name)

    x = conv_block(x,
                   3, [256, 256, 1024],
                   stage=3,
                   block=0,
                   strides=(1, 1),
                   dilate=(2, 2),
                   model_name=model_name)
    x = identity_block(x,
                       3, [256, 256, 1024],
                       stage=3,
                       block=4,
                       dilate=(2, 2),
                       model_name=model_name)
    x = identity_block(x,
                       3, [256, 256, 1024],
                       stage=3,
                       block=7,
                       dilate=(2, 2),
                       model_name=model_name)
    x = identity_block(x,
                       3, [256, 256, 1024],
                       stage=3,
                       block=10,
                       dilate=(2, 2),
                       model_name=model_name)
    x = identity_block(x,
                       3, [256, 256, 1024],
                       stage=3,
                       block=13,
                       dilate=(2, 2),
                       model_name=model_name)
    x = identity_block(x,
                       3, [256, 256, 1024],
                       stage=3,
                       block=16,
                       dilate=(2, 2),
                       model_name=model_name)

    x = conv_block(x,
                   3, [512, 512, 2048],
                   stage=4,
                   block=0,
                   strides=(1, 1),
                   dilate=(2, 2),
                   model_name=model_name)
    x = identity_block(x,
                       3, [512, 512, 2048],
                       stage=4,
                       block=4,
                       dilate=(1, 1),
                       model_name=model_name)
    x = identity_block(x,
                       3, [512, 512, 2048],
                       stage=4,
                       block=7,
                       dilate=(1, 1),
                       model_name=model_name)

    att_fea = x
    x = Pooling(x,
                kernel=(3, 3),
                stride=(2, 2),
                pool_type='avg',
                global_pool=True,
                name=model_name + '_pool1')
    x = Flatten(x)

    weight = mx.sym.Variable(model_name + '_dense1_weight')
    x = FullyConnected(x,
                       num_hidden=classes,
                       weight=weight,
                       name=model_name + '_dense1')

    #    x = SoftmaxOutput(x, name = 'softmax')

    return x, att_fea, weight
def Conv_BN_ACT(data, num_filter, kernel, pad=(0, 0), stride=(1,1), name=None, no_bias=True, num_group=1, act_type='relu'):
    conv = Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, num_group=num_group, no_bias=no_bias, name=('%s_conv' % name))
    b_act = BN_ACT(data=conv, act_type=act_type, name=name)
    return b_act
def Conv_BN(data, num_filter, kernel, pad=(0, 0), stride=(1,1), name=None, no_bias=True, num_group=1, zero_init_gamma=False):
    conv = Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, num_group=num_group, no_bias=no_bias, name=('%s_conv' % name))
    bn = BatchNorm(data=conv, zero_init_gamma=zero_init_gamma, name=('%s_bn' % name))
    return bn
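Both helpers assume BN_ACT and a BatchNorm wrapper (with a zero_init_gamma flag) defined elsewhere in the source file. A usage sketch for the fused conv-BN-ReLU helper, with an arbitrary input symbol:

data = mx.sym.Variable('data')
stem = Conv_BN_ACT(data, num_filter=64, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name='stem')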
Example #10
def ResNet50_V2(inputs, classes=1000, batch_size=144, model_name='resnetv20'):
    """Instantiates the ResNet50 architecture.
    """

    x = inputs[0]
    lam = inputs[1]

    x = Norm(x,
             fix_gamma=True,
             use_global_stats=False,
             eps=1e-5,
             name=model_name + '_batchnorm0')
    x = Convolution(x,
                    num_filter=64,
                    kernel=(7, 7),
                    stride=(2, 2),
                    pad=(3, 3),
                    no_bias=True,
                    name=model_name + '_conv0')
    x = Norm(x,
             fix_gamma=False,
             use_global_stats=False,
             eps=1e-5,
             name=model_name + '_batchnorm1')
    x = Activation(x, name=model_name + '_relu0', act_type='relu')
    x = Pooling(x,
                kernel=(3, 3),
                stride=(2, 2),
                pad=(1, 1),
                pool_type='max',
                name=model_name + '_pool0')

    x = conv_block(x,
                   3, [64, 64, 256],
                   stage=1,
                   block=0,
                   strides=(1, 1),
                   model_name=model_name)
    x = identity_block(x,
                       3, [64, 64, 256],
                       stage=1,
                       block=3,
                       model_name=model_name)
    x = identity_block(x,
                       3, [64, 64, 256],
                       stage=1,
                       block=6,
                       model_name=model_name)

    # Reverse the batch (slice with step=-1 along axis 0) so each sample can
    # be mixed with another sample from the same batch.
    x_f = mx.sym.slice(x,
                       begin=(None, None, None, None),
                       end=(None, None, None, None),
                       step=(-1, 1, 1, 1))
    lam1 = mx.sym.slice(lam, begin=(0,), end=(1,))
    lam2 = mx.sym.slice(lam, begin=(1,), end=(2,))
    x_mix = mx.sym.broadcast_mul(x, lam1) + mx.sym.broadcast_mul(x_f, lam2)
    # Stack the clean and mixed features so the rest of the network sees both.
    x = mx.sym.concat(x, x_mix, dim=0)

    x = conv_block(x,
                   3, [128, 128, 512],
                   stage=2,
                   block=0,
                   model_name=model_name)
    x = identity_block(x,
                       3, [128, 128, 512],
                       stage=2,
                       block=3,
                       model_name=model_name)
    x = identity_block(x,
                       3, [128, 128, 512],
                       stage=2,
                       block=6,
                       model_name=model_name)
    x = identity_block(x,
                       3, [128, 128, 512],
                       stage=2,
                       block=9,
                       model_name=model_name)

    x = conv_block(x,
                   3, [256, 256, 1024],
                   stage=3,
                   block=0,
                   model_name=model_name)
    x = identity_block(x,
                       3, [256, 256, 1024],
                       stage=3,
                       block=3,
                       model_name=model_name)
    x = identity_block(x,
                       3, [256, 256, 1024],
                       stage=3,
                       block=6,
                       model_name=model_name)
    x = identity_block(x,
                       3, [256, 256, 1024],
                       stage=3,
                       block=9,
                       model_name=model_name)
    x = identity_block(x,
                       3, [256, 256, 1024],
                       stage=3,
                       block=12,
                       model_name=model_name)
    x = identity_block(x,
                       3, [256, 256, 1024],
                       stage=3,
                       block=15,
                       model_name=model_name)

    x = conv_block(x,
                   3, [512, 512, 2048],
                   stage=4,
                   block=0,
                   model_name=model_name)
    x = identity_block(x,
                       3, [512, 512, 2048],
                       stage=4,
                       block=3,
                       model_name=model_name)
    x = identity_block(x,
                       3, [512, 512, 2048],
                       stage=4,
                       block=6,
                       model_name=model_name)

    x = Norm(x,
             fix_gamma=False,
             use_global_stats=False,
             eps=1e-5,
             name=model_name + '_batchnorm2')
    x = Activation(x, name=model_name + '_relu1', act_type='relu')

    x = Pooling(x,
                kernel=(3, 3),
                stride=(2, 2),
                pool_type='avg',
                global_pool=True,
                name=model_name + '_pool1')
    x = Flatten(x, name=model_name + '_flatten0')

    #    x = mx.sym.L2Normalization(x)

    weight = mx.sym.Variable(model_name + '_dense_weight')
    #    weight = mx.sym.L2Normalization(weight)
    output = FullyConnected(x,
                            num_hidden=classes,
                            weight=weight,
                            no_bias=True,
                            name=model_name + '_dense1')
    #

    # Split logits back into the clean half and the mixed-feature half.
    output1 = mx.sym.slice_axis(output, axis=0, begin=0, end=batch_size)
    mix_output = mx.sym.slice_axis(output,
                                   axis=0,
                                   begin=batch_size,
                                   end=2 * batch_size)

    # Mix the clean logits with their batch-reversed counterparts, matching
    # the feature-level mixing above.
    output1_f = mx.sym.slice(output1,
                             begin=(None, None),
                             end=(None, None),
                             step=(-1, 1))
    output_mix = mx.sym.broadcast_mul(output1, lam1) + mx.sym.broadcast_mul(
        output1_f, lam2)

    #    x = mx.sym.L2Normalization(x)
    #    weight = mx.sym.Variable(model_name+'_dense1_weight',shape=(classes,2048))
    #    weight = mx.sym.L2Normalization(weight)

    #    alpha_w = mx.sym.Variable(model_name+'_alpha',shape=(1,1),init=mx.init.Constant(3))

    #    alpha_w = mx.sym.clip(alpha_w,2,6)

    #    weight = mx.sym.broadcast_mul(weight, alpha_w)
    #    x = mx.sym.broadcast_mul(x, alpha_w+1)

    #    weight = weight * alpha
    #    x = x * (alpha+2)

    #    weight = mx.sym.expand_dims(weight,0)
    #    x = mx.sym.expand_dims(x,1)
    #    x = mx.sym.tile(x,(1,classes,1))
    #    dis = mx.sym.broadcast_minus(x,weight)
    #    x = -mx.sym.sum(dis*dis,axis=2)
    #    x = x / alpha

    return output1, output_mix, mix_output


#inputs = mx.sym.Variable('data')
#lam = mx.sym.Variable('lam')
#outputs = ResNet50_V2([inputs,lam],200,batch_size=144)
#net = mx.gluon.SymbolBlock(outputs,[inputs,lam])
##import gluoncv
##gluoncv.utils.viz.plot_network(net,shape=)
#
#a = mx.viz.plot_network(outputs[1], shape={'data':(144,3,224,224),'lam':(2,)},save_format='png')
#a.view('12')