def NAL_stage_3(input, filters):
    '''
    The third attention module (stage 3) of the naive attention learning (NAL) model, using mixed attention without a skip connection in the mask branch.
    
    inputs:
    parameter input: Input data.
    parameter filters: a list of length 3 giving the numbers of output filters used in the residual units.
    
    outputs:
    X: Output data.
    '''

    F1, F2, F3 = filters

    # p = 1: pre-processing residual unit before the trunk/mask split
    X = residual_unit(input, filters, s=1)

    # t = 2: trunk branch with two stacked residual units
    trunk = residual_unit(X, filters, s=1)
    trunk = residual_unit(trunk, filters, s=1)

    # soft mask branch (r = 1)
    # downsample with max pooling, apply residual units, then upsample back
    X = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(X)
    X = residual_unit(X, filters, s=1)
    X = residual_unit(X, filters, s=1)
    X = UpSampling2D()(X)

    # two 1x1 convolution blocks followed by a sigmoid produce the attention mask in [0, 1]
    X = BatchNormalization()(X)
    X = Activation('relu')(X)
    X = Conv2D(filters=F3,
               kernel_size=(1, 1),
               strides=(1, 1),
               padding='valid',
               bias_initializer='zeros',
               kernel_initializer=tf.keras.initializers.RandomNormal(
                   mean=0.0, stddev=tf.sqrt(2 / (F3 * 1 * 1))))(X)
    X = BatchNormalization()(X)
    X = Activation('relu')(X)
    X = Conv2D(filters=F3,
               kernel_size=(1, 1),
               strides=(1, 1),
               padding='valid',
               bias_initializer='zeros',
               kernel_initializer=tf.keras.initializers.RandomNormal(
                   mean=0.0, stddev=tf.sqrt(2 / (F3 * 1 * 1))))(X)

    X = Activation('sigmoid')(X)

    # element-wise product of the attention mask and the trunk output
    # (no identity connection is added back in naive attention learning)
    X = Multiply()([X, trunk])

    # p = 1: post-processing residual unit after merging mask and trunk
    X = residual_unit(X, filters, s=1)

    return X
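

# `residual_unit` is assumed to be defined earlier in this repository; it is not
# part of this snippet. As a reference for how it is called above (filters is
# [F1, F2, F3], s is the stride of the middle convolution), a minimal sketch of
# a pre-activation bottleneck unit follows. The name `residual_unit_sketch` is
# hypothetical and used only for illustration.
def residual_unit_sketch(input, filters, s=1):
    F1, F2, F3 = filters
    shortcut = input

    # pre-activation: BN -> ReLU before each convolution
    X = BatchNormalization()(input)
    X = Activation('relu')(X)
    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(1, 1), padding='valid')(X)

    X = BatchNormalization()(X)
    X = Activation('relu')(X)
    X = Conv2D(filters=F2, kernel_size=(3, 3), strides=(s, s), padding='same')(X)

    X = BatchNormalization()(X)
    X = Activation('relu')(X)
    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid')(X)

    # project the shortcut when the stride or the channel count changes
    if s != 1 or input.shape[-1] != F3:
        shortcut = Conv2D(filters=F3, kernel_size=(1, 1), strides=(s, s),
                          padding='valid')(input)

    return tf.keras.layers.Add()([X, shortcut])
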
def ResNet92_NAL(X):
    '''
    The attention92_NAL model, built from the mixed attention (NAL) modules defined above.
    
    inputs:
    parameter X: Input data with shape (batch_size, 32, 32, 3).
    
    outputs:
    X: Output data with shape (batch_size, 10); the 10 values for each sample are the predicted probabilities of the 10 classes.
    '''

    X = Conv2D(filters=16,
               kernel_size=3,
               padding='same',
               bias_initializer='zeros',
               kernel_initializer=tf.keras.initializers.RandomNormal(
                   mean=0.0, stddev=tf.sqrt(2 / (16 * 3 * 3))))(X)
    X = BatchNormalization()(X)
    X = Activation('relu')(X)
    X = MaxPooling2D(pool_size=(3, 3), strides=1, padding='same')(X)

    X = residual_unit(X, [4, 4, 16], s=1)
    X = NAL_stage_1(X, [4, 4, 16])  # 32*32*16

    X = residual_unit(X, [8, 8, 32], s=2)  # 16*16*32
    X = NAL_stage_2(X, [8, 8, 32])
    X = NAL_stage_2(X, [8, 8, 32])

    X = residual_unit(X, [16, 16, 64], s=2)  # 8*8*64
    X = NAL_stage_3(X, [16, 16, 64])
    X = NAL_stage_3(X, [16, 16, 64])
    X = NAL_stage_3(X, [16, 16, 64])

    X = residual_unit(X, [16, 16, 64], s=1)
    X = residual_unit(X, [16, 16, 64], s=1)
    X = residual_unit(X, [16, 16, 64], s=1)

    X = BatchNormalization()(X)
    X = Activation('relu')(X)
    X = GlobalAveragePooling2D()(X)
    X = Flatten()(X)

    output = Dense(10,
                   activation='softmax',
                   kernel_initializer=tf.keras.initializers.RandomNormal(
                       mean=0.0, stddev=tf.sqrt(2 / 10)))(X)

    return output
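

# A minimal usage sketch, assuming the model is trained end-to-end on 32x32 RGB
# images with integer class labels. The helper name, optimizer and loss below
# are illustrative assumptions, not taken from the original training code.
def build_attention92_NAL_model():
    inputs = tf.keras.Input(shape=(32, 32, 3))
    outputs = ResNet92_NAL(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='attention92_NAL')
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
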
    
# NOTE: the original 'def' line and the opening of this docstring were lost;
# the function name below is an assumed reconstruction, not the original one.
def ResNet56_attention(X):
  '''
  Attention model using one mixed attention module per stage, with 100-class output.

  inputs:
  parameter X: Input data with shape (batch_size, 32, 32, 3).

  outputs:
  X: Output data with shape (batch_size, 100); the 100 values for each sample are
  the predicted probabilities of the 100 classes.
  '''
    
  X = Conv2D(filters=16, kernel_size=3, padding='same',
             bias_initializer='zeros',
             kernel_initializer=tf.keras.initializers.RandomNormal(
                 mean=0.0, stddev=tf.sqrt(2 / (16 * 3 * 3))))(X)
  X = BatchNormalization()(X)
  X = Activation('relu')(X)
  X = MaxPooling2D(pool_size=(3, 3), strides=1, padding='same')(X)

  X = residual_unit(X, [4, 4, 16], s=1)
  X = attention_stage_1(X, [4, 4, 16])  # 32*32*16

  X = residual_unit(X, [8, 8, 32], s=2)  # 16*16*32
  X = attention_stage_2(X, [8, 8, 32])

  X = residual_unit(X, [16, 16, 64], s=2)  # 8*8*64
  X = attention_stage_3(X, [16, 16, 64])

  X = residual_unit(X, [16, 16, 64], s=1)
  X = residual_unit(X, [16, 16, 64], s=1)
  X = residual_unit(X, [16, 16, 64], s=1)

  X = BatchNormalization()(X)
  X = Activation('relu')(X)
  X = GlobalAveragePooling2D()(X)