# Example #1
# 0
def MnasNet_block_first(network,last_name='',block_name='conv2_',block_n=3,num_out1=16,num_out2=16,downsampling=False,down_method='pooling',use_global_stats='False'):
    """Append MnasNet's first (depthwise + pointwise) block to a prototxt.

    Each repetition emits a 3x3 depthwise convolution -> BatchNorm -> ReLU,
    followed by a 1x1 pointwise ("sep") convolution -> BatchNorm -> ReLU.
    There is no shortcut connection.

    Args:
        network: prototxt text accumulated so far.
        last_name: name of the blob feeding the block.
        block_name: prefix for every layer/blob emitted by the block.
        block_n: number of repetitions. The first repetition uses the names
            ``block_name + '/dw'`` and ``block_name + '/sep'``. Later
            repetitions are suffixed with ``'_<i>'`` so that layer names stay
            unique. Presumably repetitions beyond the first only make sense
            when ``num_out1 == num_out2``, because the depthwise conv then
            consumes the previous pointwise output (TODO confirm).
        num_out1: ``num_output`` of the depthwise convolution.
        num_out2: ``num_output`` of the pointwise convolution.
        downsampling: if true, the first depthwise convolution uses stride 2.
        down_method: unused; kept for signature compatibility.
        use_global_stats: forwarded unchanged to every BatchNorm layer.

    Returns:
        ``(network, last_name)``: the extended prototxt text and the name of
        the last blob written.
    """
    for i in range(1, block_n + 1):
        first_stride = 2 if (i == 1 and downsampling) else 1
        # The first repetition keeps the historical names so existing
        # prototxts/weights still match. Later ones get a unique suffix.
        prefix = block_name if i == 1 else '%s_%d' % (block_name, i)
        dw_name = prefix + '/dw'
        sep_name = prefix + '/sep'

        # 3x3 depthwise convolution -> BatchNorm -> ReLU.
        network, last_name = write_prototxt.ConvolutionDepthwise(
            network, name=dw_name, bottom_name=last_name, top_name=dw_name,
            num_output=num_out1, bias_term=False, pad=1, kernel_size=3,
            stride=first_stride, weight_type='msra', bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network, name_bn=dw_name + '/bn', name_scale=dw_name + '/scale',
            bottom_name=dw_name, top_name=dw_name,
            use_global_stats=use_global_stats)
        network, last_name = write_prototxt.ReLU(
            network, name=dw_name + '/ReLU', bottom_name=dw_name,
            top_name=dw_name)

        # 1x1 pointwise ("sep") convolution -> BatchNorm -> ReLU.
        network, last_name = write_prototxt.Convolution(
            network, name=sep_name, bottom_name=last_name, top_name=sep_name,
            num_output=num_out2, bias_term=False, pad=0, kernel_size=1,
            stride=1, weight_type='msra', bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network, name_bn=sep_name + '/bn', name_scale=sep_name + '/scale',
            bottom_name=sep_name, top_name=sep_name,
            use_global_stats=use_global_stats)
        network, last_name = write_prototxt.ReLU(
            network, name=sep_name + '/ReLU', bottom_name=sep_name,
            top_name=sep_name)

    # Return only after all block_n repetitions have been emitted.
    return network, last_name
def ResBlock(network,last_name='',block_name='conv2_',block_n=3,num_out=16,downsampling=False,down_method='pooling',use_global_stats='False'):
    """Append ``block_n`` basic residual units to ``network``.

    Every unit is ``conv3x3 -> BN -> ReLU -> conv3x3 -> BN``. Its output is
    summed with a shortcut by an ``Eltwise`` SUM layer and then passed through
    a ReLU. When ``downsampling`` is true, the first unit's first conv uses
    stride 2 and its shortcut is projected. The projection is a stride-2 1x1
    conv + BN when ``down_method == 'conv'``, and otherwise a 2x2 stride-2
    max pool.

    Returns:
        ``(network, last_name)``: the extended prototxt text and the name of
        the last blob written (the incoming ``last_name`` when the loop does
        not run).
    """
    unit_input = last_name
    for idx in range(1, block_n + 1):
        tag = block_name + str(idx)
        project = idx == 1 and downsampling == True
        stride = 2 if project else 1

        # Residual branch, first conv: 3x3 conv -> BN -> ReLU.
        first = tag + '_0'
        network, last_name = write_prototxt.Convolution(
            network, name=first, bottom_name=last_name, top_name=first,
            num_output=num_out, bias_term=True, pad=1, kernel_size=3,
            stride=stride, weight_type='msra', bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network, name_bn=tag + '_bn0', name_scale=tag + '_scale0',
            bottom_name=first, top_name=first,
            use_global_stats=use_global_stats)
        network, last_name = write_prototxt.ReLU(
            network, name=tag + '_ReLU0', bottom_name=first, top_name=first)

        # Residual branch, second conv: 3x3 conv -> BN. The activation is
        # applied after the sum.
        second = tag + '_1'
        network, last_name = write_prototxt.Convolution(
            network, name=second, bottom_name=last_name, top_name=second,
            num_output=num_out, bias_term=True, pad=1, kernel_size=3,
            stride=1, weight_type='msra', bias_type='constant')
        # NOTE: 'bn1' has no leading underscore, unlike '_bn0'. It is kept
        # as-is because layer names are part of the generated model.
        network, last_name = write_prototxt.BatchNorm(
            network, name_bn=tag + 'bn1', name_scale=tag + '_scale1',
            bottom_name=second, top_name=second,
            use_global_stats=use_global_stats)
        residual = last_name

        if project:
            if down_method == 'conv':
                down = tag + '_down'
                network, shortcut = write_prototxt.Convolution(
                    network, name=down, bottom_name=unit_input, top_name=down,
                    num_output=num_out, bias_term=True, pad=0, kernel_size=1,
                    stride=stride, weight_type='msra', bias_type='constant')
                network, shortcut = write_prototxt.BatchNorm(
                    network, name_bn=tag + '_bn_down',
                    name_scale=tag + '_scale_down',
                    bottom_name=down, top_name=down,
                    use_global_stats=use_global_stats)
            else:
                pooled = tag + '_Pooling_down'
                network, shortcut = write_prototxt.Pooling(
                    network, name=pooled, bottom_name=unit_input,
                    top_name=pooled, pool='MAX', kernel_size=2, stride=2)
            lhs, rhs = residual, shortcut
        else:
            lhs, rhs = unit_input, residual

        eltwise = block_name + 'Eltwise_' + str(idx)
        network, summed = write_prototxt.Eltwise(
            network, name=eltwise, bottom_name1=lhs, bottom_name2=rhs,
            top_name=eltwise, operation='SUM')
        network, last_name = write_prototxt.ReLU(
            network, name=tag + 'ReLU_1', bottom_name=summed, top_name=summed)
        unit_input = last_name

    return network, last_name
# Example #3
# 0
def MnasNet_block_bd11(network,last_name='',block_name='conv2_',block_n=3,num_out0=16,num_out1=16,num_out2=16,downsampling=False,down_method='pooling',use_global_stats='False'):
    """Append inverted-residual (expand / depthwise / linear) units to a prototxt.

    Each repetition emits three stages:

    * a 1x1 expansion conv -> BN -> ReLU;
    * a 3x3 depthwise conv -> BN -> ReLU;
    * a 1x1 linear projection conv -> BN, with no activation.

    An Eltwise SUM shortcut from the unit input is added only when shapes
    match, i.e. stride 1 and equal input/output channel counts.

    Args:
        network: prototxt text accumulated so far.
        last_name: name of the blob feeding the block.
        block_name: prefix for layer/blob names.
        block_n: number of repetitions. Repetition ``i`` (1-based) uses the
            prefix ``block_name + str(i - 1)``, so the first repetition keeps
            the ``block_name + '0'`` names.
        num_out0: channel count of the block input. It decides whether the
            first repetition can have a shortcut.
        num_out1: channels of the expansion and depthwise convolutions.
        num_out2: channels of the linear projection (the block output).
        downsampling: if true, the first depthwise conv uses stride 2.
        down_method: unused; kept for signature compatibility.
        use_global_stats: forwarded unchanged to every BatchNorm layer.

    Returns:
        ``(network, last_name)``: the extended prototxt text and the name of
        the last blob written.
    """
    input_name = last_name
    in_channels = num_out0
    for i in range(1, block_n + 1):
        first_stride = 2 if (i == 1 and downsampling) else 1
        prefix = block_name + str(i - 1)
        relu_prefix = block_name + str(i) + '0'
        expand = prefix + '/expand'
        dw = prefix + '/dw'
        linear = prefix + '/linear'

        # 1x1 expansion -> BN -> ReLU.
        network, last_name = write_prototxt.Convolution(
            network, name=expand, bottom_name=last_name, top_name=expand,
            num_output=num_out1, bias_term=False, pad=0, kernel_size=1,
            stride=1, weight_type='msra', bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network, name_bn=prefix + '/bn/expand',
            name_scale=prefix + '/scale/expand',
            bottom_name=expand, top_name=expand,
            use_global_stats=use_global_stats)
        # Named '.../expand' so it does not collide with the depthwise ReLU's
        # name below.
        network, last_name = write_prototxt.ReLU(
            network, name=relu_prefix + 'ReLU/expand',
            bottom_name=expand, top_name=expand)

        # 3x3 depthwise conv -> BN -> ReLU (carries the stride).
        network, last_name = write_prototxt.ConvolutionDepthwise(
            network, name=dw, bottom_name=last_name, top_name=dw,
            num_output=num_out1, bias_term=False, pad=1, kernel_size=3,
            stride=first_stride, weight_type='msra', bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network, name_bn=prefix + '/bn/dw', name_scale=prefix + '/scale/dw',
            bottom_name=dw, top_name=dw, use_global_stats=use_global_stats)
        network, last_name = write_prototxt.ReLU(
            network, name=relu_prefix + 'ReLU/dw', bottom_name=dw, top_name=dw)

        # 1x1 linear projection -> BN. There is no activation after this stage.
        network, last_name = write_prototxt.Convolution(
            network, name=linear, bottom_name=last_name, top_name=linear,
            num_output=num_out2, bias_term=False, pad=0, kernel_size=1,
            stride=1, weight_type='msra', bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network, name_bn=prefix + '/bn/linear',
            name_scale=prefix + '/scale/linear',
            bottom_name=linear, top_name=linear,
            use_global_stats=use_global_stats)

        # The shortcut is only valid when the spatial size and channel count
        # are unchanged.
        if first_stride == 1 and in_channels == num_out2:
            eltwise_name = 'block_' + prefix
            eltwise_top = 'block_' + block_name if i == 1 else eltwise_name
            network, last_name = write_prototxt.Eltwise(
                network, name=eltwise_name, bottom_name1=input_name,
                bottom_name2=last_name, top_name=eltwise_top, operation='SUM')

        # The next repetition reads this unit's output.
        input_name = last_name
        in_channels = num_out2

    return network, last_name
# Example #4
# 0
def MnasNet_block_bdce(network,last_name='',block_name='conv2_',block_n=3,kernel_size=3,num_out1=16,num_out2=16,downsampling=False,down_method='pooling',use_global_stats='False'):
    """Append MnasNet MBConv-style units (1x1 expand, kxk depthwise, 1x1 project).

    Each repetition emits three stages:

    * sep1: 1x1 conv -> BN -> ReLU;
    * dw: ``kernel_size`` x ``kernel_size`` depthwise conv -> BN -> ReLU;
    * sep2: 1x1 conv -> BN.

    What follows sep2 depends on the last '_'-separated token of
    ``block_name``:

    * ``"1"`` (e.g. ``'conv3_1'``, the stage-opening blocks in ``Net``): a
      ReLU on sep2 and no shortcut;
    * anything else: an Eltwise SUM with the unit input, followed by a ReLU.

    Args:
        network: prototxt text accumulated so far.
        last_name: name of the blob feeding the block.
        block_name: prefix for layer/blob names. It also selects the shortcut
            behaviour, as described above.
        block_n: number of repetitions. The first repetition keeps the
            ``block_name + '/...'`` names. Later ones use the prefix
            ``block_name + '_<i>'`` so that names stay unique.
        kernel_size: depthwise kernel size, padded by ``kernel_size / 2``.
        num_out1: channels of sep1 and the depthwise conv.
        num_out2: channels of sep2 (the block output).
        downsampling: if true, the first depthwise conv uses stride 2.
        down_method: unused; kept for signature compatibility.
        use_global_stats: forwarded unchanged to every BatchNorm layer.

    Returns:
        ``(network, last_name)``: the extended prototxt text and the name of
        the last blob written.
    """
    input_name = last_name
    stage_head = block_name.split("_")[-1] == "1"

    for i in range(1, block_n + 1):
        first_stride = 2 if (i == 1 and downsampling) else 1
        prefix = block_name if i == 1 else '%s_%d' % (block_name, i)
        sep1 = prefix + '/sep1'
        dw = prefix + '/dw'
        sep2 = prefix + '/sep2'

        # 1x1 expansion -> BN -> ReLU.
        network, last_name = write_prototxt.Convolution(
            network, name=sep1, bottom_name=last_name, top_name=sep1,
            num_output=num_out1, bias_term=False, pad=0, kernel_size=1,
            stride=1, weight_type='msra', bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network, name_bn=sep1 + '/bn', name_scale=sep1 + '/scale',
            bottom_name=sep1, top_name=sep1, use_global_stats=use_global_stats)
        network, last_name = write_prototxt.ReLU(
            network, name=sep1 + '/ReLU', bottom_name=sep1, top_name=sep1)

        # kxk depthwise conv ("same" padding for odd k) -> BN -> ReLU.
        network, last_name = write_prototxt.ConvolutionDepthwise(
            network, name=dw, bottom_name=last_name, top_name=dw,
            num_output=num_out1, bias_term=False, pad=int(kernel_size / 2),
            kernel_size=kernel_size, stride=first_stride,
            weight_type='msra', bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network, name_bn=dw + '/bn', name_scale=dw + '/scale',
            bottom_name=dw, top_name=dw, use_global_stats=use_global_stats)
        network, last_name = write_prototxt.ReLU(
            network, name=dw + '/ReLU', bottom_name=dw, top_name=dw)

        # 1x1 projection -> BN.
        network, last_name = write_prototxt.Convolution(
            network, name=sep2, bottom_name=last_name, top_name=sep2,
            num_output=num_out2, bias_term=False, pad=0, kernel_size=1,
            stride=1, weight_type='msra', bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network, name_bn=sep2 + '/bn', name_scale=sep2 + '/scale',
            bottom_name=sep2, top_name=sep2, use_global_stats=use_global_stats)

        if stage_head:
            network, last_name = write_prototxt.ReLU(
                network, name=sep2 + '/ReLU', bottom_name=sep2, top_name=sep2)
        else:
            eltwise = block_name + "/Eltwise" + str(i)
            network, last_name = write_prototxt.Eltwise(
                network, name=eltwise, bottom_name1=input_name,
                bottom_name2=last_name, top_name=eltwise, operation='SUM')
            network, last_name = write_prototxt.ReLU(
                network, name=prefix + "/Eltwise/ReLU",
                bottom_name=last_name, top_name=last_name)

        # The next repetition's shortcut starts from this unit's output.
        input_name = last_name

    # Return only after all block_n repetitions have been emitted.
    return network, last_name
def ResNet_cifar( mode='train',root_path='',batch_size=32,block_n=3):
    """Build a CIFAR-style ResNet prototxt and return it as a string.

    Layout: conv1 (16 ch) -> three stages of ``block_n`` ``ResBlock`` units
    (16 / 32 / 64 channels; the last two stages downsample with a conv
    shortcut) -> global average pooling -> 10-way ``fc1``. With two 3x3
    convs per unit, the conventional depth is ``6 * block_n + 2``
    (ResNet-20 for the default ``block_n=3``).

    Args:
        mode: 'train' (mirrored, shuffled input; BatchNorm with
            ``use_global_stats=False``) or 'test' (no mirror/shuffle;
            ``use_global_stats=True``; an extra Accuracy layer).
        root_path: LMDB dataset path forwarded to the data layer.
        batch_size: data-layer batch size.
        block_n: residual units per stage.

    Returns:
        The prototxt text. It is also printed to stdout.

    Raises:
        ValueError: if ``mode`` is not 'train' or 'test'.
    """
    if mode not in ('train', 'test'):
        raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))
    is_train = (mode == 'train')
    use_global_stats = not is_train

    # The name reflects the actual depth (unchanged for the default block_n=3).
    network = 'name:"ResNet-%d"' % (6 * block_n + 2) + '\n'

    network, last_name = write_prototxt.data(
        network, name="Data1", mirror=is_train, scale=0.00390625,
        crop_size=32, batch_size=batch_size, backend="LMDB",
        shuffle=is_train, datasets_path=root_path)

    network, last_name = write_prototxt.Convolution(
        network, name="conv1", bottom_name=last_name, top_name='conv1',
        num_output=16, bias_term=True, pad=1, kernel_size=3, stride=1,
        weight_type='msra', bias_type='constant')
    network, last_name = write_prototxt.BatchNorm(
        network, name_bn="conv1/bn", name_scale="conv1/scale",
        bottom_name="conv1", top_name="conv1",
        use_global_stats=use_global_stats)
    network, last_name = write_prototxt.ReLU(
        network, name="conv1/ReLU", bottom_name='conv1', top_name='conv1')

    # (block_name, channels, downsampling, down_method)
    stages = (
        ('conv2_', 16, False, 'pooling'),
        ('conv3_', 32, True, 'conv'),
        ('conv4_', 64, True, 'conv'),
    )
    for stage_name, channels, downsample, down_method in stages:
        network, last_name = ResBlock(
            network, last_name=last_name, block_name=stage_name,
            block_n=block_n, num_out=channels, downsampling=downsample,
            down_method=down_method, use_global_stats=use_global_stats)

    network, last_name = write_prototxt.Pooling(
        network, name="Pooling1", bottom_name=last_name, top_name='Pooling1',
        pool='AVE', global_pooling=True)
    network, last_name = write_prototxt.InnerProduct(
        network, name="fc1", bottom_name=last_name, top_name='fc1',
        num_output=10, weight_type='msra', bias_type='constant')

    # The loss is emitted in both modes; test additionally reports accuracy.
    network, last_name = write_prototxt.SoftmaxWithLoss(
        network, name="Softmax1", bottom_name1='fc1', bottom_name2='label',
        top_name='Softmax1')
    if mode == 'test':
        network, last_name = write_prototxt.Accuracy(
            network, name="prob", bottom_name1='fc1', bottom_name2='label',
            top_name='prob')

    print(network)
    return network
# Example #6
# 0
def Net( mode='train',root_path='',batch_size=32,num_classes=2):
    """Build a MnasNet prototxt (224x224 input) and return it as a string.

    Args:
        mode: one of the following.
            'train': data layer with mirror and shuffle; BatchNorm with
                ``use_global_stats=False``.
            'test': data layer without mirror or shuffle;
                ``use_global_stats=True``; an extra Accuracy layer.
            'deploy': a bare ``input: "data"`` of shape 1x3x224x224 and no
                loss layer.
        root_path: LMDB dataset path forwarded to the data layer.
        batch_size: data-layer batch size (train/test only).
        num_classes: ``num_output`` of the final ``fc1`` layer.

    Returns:
        The prototxt text.

    Raises:
        ValueError: if ``mode`` is not 'train', 'test' or 'deploy'.
    """
    if mode not in ('train', 'test', 'deploy'):
        raise ValueError(
            "mode must be 'train', 'test' or 'deploy', got %r" % (mode,))
    use_global_stats = mode != 'train'

    network = 'name:"MnasNet"' + '\n'
    if mode == 'deploy':
        # Append (not assign) so the name header above is kept.
        network += 'input: "data"' + '\n'
        network += 'input_dim: 1' + '\n'
        network += 'input_dim: 3' + '\n'
        network += 'input_dim: 224' + '\n'
        network += 'input_dim: 224' + '\n'
        last_name = "data"
    else:
        is_train = (mode == 'train')
        network, last_name = write_prototxt.data(
            network, name="data", mirror=is_train, scale=0.00390625,
            crop_size=224, batch_size=batch_size, backend="LMDB",
            shuffle=is_train, datasets_path=root_path)

    # Stem: 3x3 stride-2 conv -> BN -> ReLU.
    network, last_name = write_prototxt.Convolution(
        network, name="conv1", bottom_name=last_name, top_name='conv1',
        num_output=32, bias_term=False, pad=1, kernel_size=3, stride=2,
        weight_type='msra', bias_type='constant')
    network, last_name = write_prototxt.BatchNorm(
        network, name_bn="conv1/bn", name_scale="conv1/scale",
        bottom_name="conv1", top_name="conv1",
        use_global_stats=use_global_stats)
    network, last_name = write_prototxt.ReLU(
        network, name="conv1/ReLU", bottom_name='conv1', top_name='conv1')

    network, last_name = MnasNet_block_first(
        network, last_name=last_name, block_name='conv2_1', block_n=1,
        num_out1=32, num_out2=16, downsampling=False,
        use_global_stats=use_global_stats)

    # (block_name, kernel_size, num_out1 (expansion), num_out2 (output),
    #  downsampling)
    mbconv_blocks = (
        ('conv3_1', 3, 48, 24, True),
        ('conv3_2', 3, 72, 24, False),
        ('conv3_3', 3, 72, 24, False),
        ('conv4_1', 5, 72, 40, True),
        ('conv4_2', 5, 120, 40, False),
        ('conv4_3', 5, 120, 40, False),
        ('conv5_1', 5, 240, 80, True),
        ('conv5_2', 5, 480, 80, False),
        ('conv5_3', 5, 480, 80, False),
        ('conv6_1', 3, 480, 96, False),
        ('conv6_2', 3, 576, 96, False),
        ('conv7_1', 5, 576, 192, True),
        ('conv7_2', 5, 1152, 192, False),
        ('conv7_3', 5, 1152, 192, False),
        ('conv7_4', 5, 1152, 192, False),
        ('conv8_1', 3, 1152, 320, False),
    )
    for name, ksize, expand, out, downsample in mbconv_blocks:
        network, last_name = MnasNet_block_bdce(
            network, last_name=last_name, block_name=name, block_n=1,
            kernel_size=ksize, num_out1=expand, num_out2=out,
            downsampling=downsample, use_global_stats=use_global_stats)

    network, last_name = write_prototxt.Pooling(
        network, name="Pooling1", bottom_name=last_name, top_name='Pooling1',
        pool='AVE', global_pooling=True)
    network, last_name = write_prototxt.InnerProduct(
        network, name="fc1", bottom_name=last_name, top_name='fc1',
        num_output=num_classes, weight_type='msra', bias_type='constant')

    # A deploy net has no 'label' blob, so the loss/accuracy layers are only
    # emitted for train/test. Deploy output is the fc1 logits.
    if mode != 'deploy':
        network, last_name = write_prototxt.SoftmaxWithLoss(
            network, name="Softmax1", bottom_name1='fc1',
            bottom_name2='label', top_name='Softmax1')
    if mode == 'test':
        network, last_name = write_prototxt.Accuracy(
            network, name="prob", bottom_name1='fc1', bottom_name2='label',
            top_name='prob')

    return network
def Net(mode='train', root_path='', batch_size=32):
    """Build a small MobileNetV1-style prototxt (28x28 crops, 10 classes).

    Args:
        mode: 'train' (mirror + shuffle, ``use_global_stats=False``,
            SoftmaxWithLoss) or 'test' (no mirror/shuffle,
            ``use_global_stats=True``, Accuracy only).
        root_path: LMDB dataset path forwarded to the data layer.
        batch_size: data-layer batch size.

    Returns:
        The prototxt text. It is also printed to stdout.

    Raises:
        ValueError: if ``mode`` is not 'train' or 'test'.
    """
    if mode not in ('train', 'test'):
        raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))
    is_train = (mode == 'train')
    use_global_stats = not is_train

    network = 'name:"MobileNetV1"' + '\n'
    network, last_name = write_prototxt.data(network,
                                             name="Data1",
                                             mirror=is_train,
                                             crop_size=28,
                                             batch_size=batch_size,
                                             backend="LMDB",
                                             shuffle=is_train,
                                             datasets_path=root_path)

    network, last_name = write_prototxt.Convolution(network,
                                                    name="conv1",
                                                    bottom_name=last_name,
                                                    top_name='conv1',
                                                    num_output=32,
                                                    bias_term=False,
                                                    pad=1,
                                                    kernel_size=3,
                                                    stride=1,
                                                    weight_type='msra',
                                                    bias_type='constant')
    # NOTE: "conv1__scale" (double underscore) is kept as-is. Renaming a
    # parameterised layer would likely break loading existing weights by name.
    network, last_name = write_prototxt.BatchNorm(
        network,
        name_bn="conv1_bn",
        name_scale="conv1__scale",
        bottom_name="conv1",
        top_name="conv1",
        use_global_stats=use_global_stats)
    network, last_name = write_prototxt.ReLU(network,
                                             name="conv1_ReLU",
                                             bottom_name='conv1',
                                             top_name='conv1')

    # (block_name, num_out1, num_out2, downsampling)
    units = (
        ('conv2_', 32, 64, False),
        ('conv3_', 64, 128, True),
        ('conv4_', 128, 128, False),
        ('conv5_', 128, 256, False),
        ('conv6_', 256, 512, True),
        ('conv7_', 512, 512, False),
        ('conv8_', 512, 512, False),
    )
    for unit_name, out1, out2, downsample in units:
        network, last_name = Mobile_Unit(network,
                                         last_name=last_name,
                                         block_name=unit_name,
                                         block_n=1,
                                         num_out1=out1,
                                         num_out2=out2,
                                         downsampling=downsample,
                                         use_global_stats=use_global_stats)

    network, last_name = write_prototxt.Pooling(network,
                                                name="Pooling1",
                                                bottom_name=last_name,
                                                top_name='Pooling1',
                                                pool='AVE',
                                                global_pooling=True)
    network, last_name = write_prototxt.InnerProduct(network,
                                                     name="fc1",
                                                     bottom_name=last_name,
                                                     top_name='fc1',
                                                     num_output=10,
                                                     weight_type='xavier',
                                                     bias_type='constant')

    # The label blob is named 'Data2' here (as passed by the original code).
    if is_train:
        network, last_name = write_prototxt.SoftmaxWithLoss(
            network,
            name="Softmax1",
            bottom_name1='fc1',
            bottom_name2='Data2',
            top_name='Softmax1')
    else:
        network, last_name = write_prototxt.Accuracy(network,
                                                     name="prob",
                                                     bottom_name1='fc1',
                                                     bottom_name2='Data2',
                                                     top_name='prob')

    print(network)
    return network
def ResBlock(network,
             last_name='',
             block_name='conv2_',
             block_n=3,
             num_out=16,
             downsampling=False,
             down_method='pooling',
             use_global_stats='False',
             dropout_ratio=0.5,
             first_layer=False):
    """Append ``block_n`` squeeze-and-excitation residual units to ``network``.

    Each unit is built as follows:

    * Main branch: conv3x3 -> BN -> ReLU -> Dropout -> conv3x3 -> BN.
    * SE branch on the main output: global average pool -> 1x1 conv -> ReLU
      -> 1x1 conv (``num_out`` channels) -> Sigmoid.
    * Combination: an ``Axpy`` layer over (SE scale, main output, shortcut),
      presumably ``scale * main + shortcut`` (TODO confirm against the Axpy
      layer implementation), followed by a ReLU.

    The first unit gets a projected shortcut when ``downsampling`` or
    ``first_layer`` is true. The projection is a 1x1 conv + BN with stride
    ``first_stride`` when ``down_method == 'conv'``, and otherwise a 2x2
    stride-2 max pool.

    NOTE(review): with ``first_layer=True``, ``downsampling=False`` and the
    pooling shortcut, the shortcut is halved spatially while the main branch
    is not, so the Axpy inputs would presumably not match.

    Args:
        network: prototxt text accumulated so far.
        last_name: name of the blob feeding the block.
        block_name: prefix for layer/blob names.
        block_n: number of units.
        num_out: channels of every conv in the main branch and shortcut.
        downsampling: if true, the first unit's first conv uses stride 2.
        down_method: 'conv' for a 1x1-conv shortcut, anything else for
            max pooling.
        use_global_stats: forwarded unchanged to every BatchNorm layer.
        dropout_ratio: forwarded to the Dropout layer.
        first_layer: force a projected shortcut on the first unit.

    Returns:
        ``(network, last_name)``: the extended prototxt text and the name of
        the last blob written.
    """
    input_name = last_name

    for i in range(1, block_n + 1):
        first_stride = 2 if (i == 1 and downsampling) else 1
        tag = block_name + str(i)

        # Main branch, first conv: 3x3 conv -> BN -> ReLU -> Dropout.
        network, last_name = write_prototxt.Convolution(
            network,
            name=tag + '_0',
            bottom_name=last_name,
            top_name=tag + '_0',
            num_output=num_out,
            bias_term=True,
            pad=1,
            kernel_size=3,
            stride=first_stride,
            weight_type='msra',
            bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network,
            name_bn=tag + "_bn" + '0',
            name_scale=tag + "_scale" + '0',
            bottom_name=tag + '_0',
            top_name=tag + '_0',
            use_global_stats=use_global_stats)
        network, last_name = write_prototxt.ReLU(
            network,
            name=tag + "_ReLU" + '0',
            bottom_name=tag + '_0',
            top_name=tag + '_0')
        network, last_name = write_prototxt.Dropout(
            network,
            name=block_name + "Drop" + str(i),
            bottom_name1=tag + '_0',
            top_name=tag + '_0',
            dropout_ratio=dropout_ratio)

        # Main branch, second conv: 3x3 conv -> BN. The activation is applied
        # after the Axpy. 'bn1' (no underscore) is kept because layer names
        # are part of the model.
        network, last_name = write_prototxt.Convolution(
            network,
            name=tag + '_1',
            bottom_name=last_name,
            top_name=tag + '_1',
            num_output=num_out,
            bias_term=True,
            pad=1,
            kernel_size=3,
            stride=1,
            weight_type='msra',
            bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network,
            name_bn=tag + "bn" + '1',
            name_scale=tag + "_scale" + '1',
            bottom_name=tag + '_1',
            top_name=tag + '_1',
            use_global_stats=use_global_stats)
        origin_name = last_name

        # SE branch: global average pool -> bottleneck 1x1 conv -> ReLU.
        network, last_name = write_prototxt.Pooling(
            network,
            name=tag + "_Pooling",
            bottom_name=last_name,
            top_name=tag + '_Pooling',
            pool='AVE',
            global_pooling=True)

        # Bottleneck width: num_out // 16, at least 16, then x10. Floor
        # division keeps num_output an integer under true division too.
        # TODO: the x10 factor was marked "needs revision" in the original.
        num_output_temp = max(num_out // 16, 16) * 10
        network, last_name = write_prototxt.Convolution(
            network,
            name=tag + '_2',
            bottom_name=last_name,
            top_name=tag + '_2',
            num_output=num_output_temp,
            bias_term=True,
            pad=0,
            kernel_size=1,
            stride=1,
            weight_type='msra',
            bias_type='constant')
        network, last_name = write_prototxt.ReLU(
            network,
            name=tag + "_ReLU" + '2',
            bottom_name=tag + '_2',
            top_name=tag + '_2')

        # SE branch: 1x1 conv back to num_out channels -> Sigmoid gate.
        network, last_name = write_prototxt.Convolution(
            network,
            name=tag + '_3',
            bottom_name=last_name,
            top_name=tag + '_3',
            num_output=num_out,
            bias_term=True,
            pad=0,
            kernel_size=1,
            stride=1,
            weight_type='msra',
            bias_type='constant')
        network, last_name = write_prototxt.Sigmoid(
            network,
            name=tag + "_Prob" + '3',
            bottom_name=tag + '_3',
            top_name=tag + '_3')
        SE_name = last_name

        # Projected shortcut for the first unit when shapes change.
        if i == 1 and (downsampling or first_layer):
            if down_method == 'conv':
                network, last_name = write_prototxt.Convolution(
                    network,
                    name=tag + '_down',
                    bottom_name=input_name,
                    top_name=tag + '_down',
                    num_output=num_out,
                    bias_term=True,
                    pad=0,
                    kernel_size=1,
                    stride=first_stride,
                    weight_type='msra',
                    bias_type='constant')
                network, last_name = write_prototxt.BatchNorm(
                    network,
                    name_bn=tag + "_bn" + '_down',
                    name_scale=tag + "_scale" + '_down',
                    bottom_name=tag + '_down',
                    top_name=tag + '_down',
                    use_global_stats=use_global_stats)
            else:
                network, last_name = write_prototxt.Pooling(
                    network,
                    name=tag + "_Pooling" + '_down',
                    bottom_name=input_name,
                    top_name=tag + "_Pooling" + '_down',
                    pool='MAX',
                    kernel_size=2,
                    stride=2)
            input_name = last_name

        # Combine the SE scale, main-branch output and shortcut, then ReLU.
        network, last_name = write_prototxt.Axpy(
            network,
            name=block_name + "Axpy_" + str(i),
            bottom_name1=SE_name,
            bottom_name2=origin_name,
            bottom_name3=input_name,
            top_name=block_name + "Axpy_" + str(i))
        network, last_name = write_prototxt.ReLU(network,
                                                 name=tag + "ReLU" + '_1',
                                                 bottom_name=last_name,
                                                 top_name=last_name)

        input_name = last_name
    return network, last_name
def Net(mode='train', root_path='', batch_size=32):
    """Build the prototxt text for a small ShuffleNet-style classifier.

    Layout: data layer -> conv1 (3x3, 15 outputs) + BatchNorm + ReLU ->
    ShuffleNet_Unit0 ('resx1_conv') -> ten ShuffleNet_Unit stages (two of
    them with downsampling) -> global average pooling -> fc1 (10 outputs)
    -> SoftmaxWithLoss (train) or Accuracy (test).

    Args:
        mode: 'train' or 'test'. 'train' enables mirror/shuffle in the data
            layer and passes use_global_stats=False to every BatchNorm;
            'test' disables both and passes use_global_stats=True.
        root_path: forwarded to the data layer as ``datasets_path``.
        batch_size: forwarded to the data layer.

    Returns:
        The generated prototxt as a string (it is also printed to stdout).

    Raises:
        ValueError: if ``mode`` is neither 'train' nor 'test'.
    """
    if mode not in ('train', 'test'):
        raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))
    is_train = mode == 'train'
    # BatchNorm layers use global statistics only in the test net.
    use_global_stats = not is_train

    network = 'name:"ShuffleNet"' + '\n'

    network, last_name = write_prototxt.data(network,
                                             name="Data1",
                                             mirror=is_train,
                                             crop_size=28,
                                             batch_size=batch_size,
                                             backend="LMDB",
                                             shuffle=is_train,
                                             datasets_path=root_path)

    # Stem: 3x3 conv -> BN -> ReLU.
    network, last_name = write_prototxt.Convolution(network,
                                                    name="conv1",
                                                    bottom_name=last_name,
                                                    top_name='conv1',
                                                    num_output=15,
                                                    bias_term=False,
                                                    pad=1,
                                                    kernel_size=3,
                                                    stride=1,
                                                    weight_type='msra',
                                                    bias_type='constant')
    # NOTE: layer names (including the double underscore in "conv1__scale"
    # and the resx8/resx9 gap below) are kept verbatim so the generated
    # prototxt keeps the same layer names as before.
    network, last_name = write_prototxt.BatchNorm(
        network,
        name_bn="conv1_bn",
        name_scale="conv1__scale",
        bottom_name="conv1",
        top_name="conv1",
        use_global_stats=use_global_stats)
    network, last_name = write_prototxt.ReLU(network,
                                             name="conv1_ReLU",
                                             bottom_name='conv1',
                                             top_name='conv1')

    network, last_name = ShuffleNet_Unit0(network,
                                          last_name=last_name,
                                          block_name='resx1_conv',
                                          block_n=1,
                                          num_out0=24,
                                          num_out1=45,
                                          downsampling=False,
                                          group=3,
                                          use_global_stats=use_global_stats)

    # (block_name, num_out0, num_out1, downsampling) for each stage, in
    # emission order.
    unit_specs = (
        ('resx2_conv', 30, 60, False),
        ('resx3_conv', 30, 60, False),
        ('resx4_conv', 30, 60, False),
        ('resx5_conv', 30, 60, True),
        ('resx6_conv', 48, 120, False),
        ('resx7_conv', 48, 120, False),
        ('resx10_conv', 48, 120, True),
        ('resx11_conv', 60, 240, False),
        ('resx12_conv', 60, 240, False),
        ('resx13_conv', 60, 240, False),
    )
    for block_name, num_out0, num_out1, downsampling in unit_specs:
        network, last_name = ShuffleNet_Unit(network,
                                             last_name=last_name,
                                             block_name=block_name,
                                             block_n=1,
                                             num_out0=num_out0,
                                             num_out1=num_out1,
                                             downsampling=downsampling,
                                             group=3,
                                             use_global_stats=use_global_stats)

    # Head: global average pooling -> fully connected classifier.
    network, last_name = write_prototxt.Pooling(network,
                                                name="Pooling1",
                                                bottom_name=last_name,
                                                top_name='Pooling1',
                                                pool='AVE',
                                                global_pooling=True)
    network, last_name = write_prototxt.InnerProduct(network,
                                                     name="fc1",
                                                     bottom_name=last_name,
                                                     top_name='fc1',
                                                     num_output=10,
                                                     weight_type='xavier',
                                                     bias_type='constant')

    # 'Data2' is presumably the label blob emitted by write_prototxt.data
    # -- TODO confirm against that helper.
    if is_train:
        network, last_name = write_prototxt.SoftmaxWithLoss(
            network,
            name="Softmax1",
            bottom_name1='fc1',
            bottom_name2='Data2',
            top_name='Softmax1')
    else:
        network, last_name = write_prototxt.Accuracy(network,
                                                     name="prob",
                                                     bottom_name1='fc1',
                                                     bottom_name2='Data2',
                                                     top_name='prob')

    # Single-argument print(...) behaves the same on Python 2 and 3.
    print(network)

    return network
def ShuffleNet_Unit(network,
                    last_name='',
                    block_name='conv2_',
                    block_n=3,
                    num_out0=16,
                    num_out1=16,
                    downsampling=False,
                    group=3,
                    down_method='pooling',
                    use_global_stats='False'):
    """Append ``block_n`` ShuffleNet units to the prototxt text ``network``.

    Each unit emits: grouped 1x1 conv -> BN -> ReLU -> channel shuffle ->
    3x3 depthwise conv -> BN -> grouped 1x1 conv -> BN, then a shortcut
    merge followed by ReLU:

    * first unit with ``downsampling``: the depthwise conv uses stride 2,
      the shortcut is a 3x3/stride-2 average pool of the unit input, and
      the two branches are concatenated;
    * every other unit: element-wise SUM with the unit input.

    Args:
        network: prototxt text built so far.
        last_name: name of the blob feeding the first unit.
        block_name: prefix for layer/blob names. The first unit uses it
            unchanged; unit ``i > 1`` uses ``block_name + '_' + str(i)`` so
            stacked units do not produce duplicate layer names.
        block_n: number of units to stack.
        num_out0: output channels of the first 1x1 conv and the depthwise conv.
        num_out1: output channels of the last grouped 1x1 conv.
        downsampling: if True, the first unit uses the stride-2 / concat path.
        group: group count for both 1x1 convs and the channel shuffle.
        down_method: unused; kept for signature compatibility.
        use_global_stats: forwarded to every BatchNorm layer.
            NOTE(review): the default is the *string* 'False'; whether
            write_prototxt.BatchNorm treats it as false is not visible
            here -- callers in this file always pass a bool.

    Returns:
        (network, last_name): updated prototxt text and the output blob
        name of the last unit.
    """
    input_name = last_name
    # Strict comparison kept from the original so non-bool values such as
    # the string 'False' do not enable downsampling.
    down_first = downsampling == True

    for i in range(1, block_n + 1):
        # Only the first unit of the stage downsamples.
        is_down = down_first and i == 1
        stride = 2 if is_down else 1
        # Unit 1 keeps the historical names; later units get a unique prefix.
        prefix = block_name if i == 1 else '%s_%d' % (block_name, i)

        # Grouped 1x1 conv -> BN -> ReLU.
        conv1 = prefix + '1'
        network, last_name = write_prototxt.Convolution(
            network,
            name=conv1,
            bottom_name=last_name,
            top_name=conv1,
            num_output=num_out0,
            bias_term=False,
            pad=0,
            kernel_size=1,
            stride=1,
            weight_type='msra',
            bias_type='constant',
            group=group)
        network, last_name = write_prototxt.BatchNorm(
            network,
            name_bn=conv1 + "_bn",
            name_scale=conv1 + "_scale",
            bottom_name=conv1,
            top_name=conv1,
            use_global_stats=use_global_stats)
        network, last_name = write_prototxt.ReLU(network,
                                                 name=conv1 + "_ReLU",
                                                 bottom_name=conv1,
                                                 top_name=conv1)

        # Channel shuffle across the groups.
        shuffle = prefix + "_shuffle"
        network, last_name = write_prototxt.ShuffleChannel(
            network,
            name=shuffle,
            bottom_name=last_name,
            top_name=shuffle,
            group=group)

        # 3x3 depthwise conv -> BN (intentionally no ReLU afterwards).
        conv2 = prefix + '2'
        network, last_name = write_prototxt.ConvolutionDepthwise(
            network,
            name=conv2,
            bottom_name=last_name,
            top_name=conv2,
            num_output=num_out0,
            bias_term=False,
            pad=1,
            kernel_size=3,
            stride=stride,
            weight_type='msra',
            bias_type='constant')
        network, last_name = write_prototxt.BatchNorm(
            network,
            name_bn=conv2 + '_bn',
            name_scale=conv2 + "_scale",
            bottom_name=conv2,
            top_name=conv2,
            use_global_stats=use_global_stats)

        # Grouped 1x1 conv -> BN; the activation comes after the merge.
        conv3 = prefix + '3'
        network, last_name = write_prototxt.Convolution(
            network,
            name=conv3,
            bottom_name=last_name,
            top_name=conv3,
            num_output=num_out1,
            bias_term=False,
            pad=0,
            kernel_size=1,
            stride=1,
            weight_type='msra',
            bias_type='constant',
            group=group)
        network, last_name = write_prototxt.BatchNorm(
            network,
            name_bn=conv3 + "_bn",
            name_scale=conv3 + "_scale",
            bottom_name=conv3,
            top_name=conv3,
            use_global_stats=use_global_stats)

        # Shortcut merge.
        if is_down:
            match = prefix + "_match"
            network, shortcut = write_prototxt.Pooling(
                network,
                name=match,
                bottom_name=input_name,
                top_name=match,
                pool='AVE',
                kernel_size=3,
                stride=2)
            merged = prefix + '_concat'
            network, last_name = write_prototxt.Concat(
                network,
                name=merged,
                bottom_name1=shortcut,
                bottom_name2=last_name,
                top_name=merged)
            relu_name = merged + '_ReLU'
        else:
            merged = prefix + '_elewise'
            network, last_name = write_prototxt.Eltwise(
                network,
                name=merged,
                bottom_name1=input_name,
                bottom_name2=last_name,
                top_name=merged,
                operation='SUM')
            # Historical name (no underscore before 'elewise') kept as-is.
            relu_name = prefix + 'elewise_ReLU'
        network, last_name = write_prototxt.ReLU(network,
                                                 name=relu_name,
                                                 bottom_name=merged,
                                                 top_name=merged)

        # The next unit's shortcut starts from this unit's output.
        input_name = last_name
    return network, last_name