def weight_init(name='undefined', method='msra', height=None, width=None,
                input_channel=None, output_channel=None, collection='TRAINABLED'):
    """Create a 4-D kernel variable and register it in collections.

    The variable is named ``name + '_weight'`` and has shape
    ``[height, width, input_channel, output_channel]`` (the filter layout
    expected by ``tf.nn.conv2d``), dtype float32.

    :param name: prefix of the created variable.
    :param method: initialisation scheme; only ``'msra'`` is supported.
    :param height: kernel height.
    :param width: kernel width.
    :param input_channel: number of input channels.
    :param output_channel: number of output channels.
    :param collection: extra graph collection the variable is added to.
    :return: the created ``tf.Variable``.
    :raises AssertionError: if ``method`` is not supported.
    """
    assert method in ['msra'], 'method not support: {!r}'.format(method)
    if method == 'msra':
        # factor=2.0 / FAN_IN / non-uniform is the He et al. ("Delving Deep
        # into Rectifiers") setting documented for this initializer.
        initer = tf.contrib.layers.variance_scaling_initializer(
            factor=2.0, mode='FAN_IN', uniform=False)
    else:
        # Only reachable when asserts are stripped (python -O).
        sys.exit()
    w = tf.get_variable(
        name=name + '_weight',
        shape=[height, width, input_channel, output_channel],
        dtype=tf.float32,
        initializer=initer,
    )
    Util.AddToCollectionInfo(tf.GraphKeys.GLOBAL_VARIABLES, w)
    tf.add_to_collection(collection, w)
    Util.AddToCollectionInfo(collection, w)
    return w
def Conv_Layer(
        name='undefined',
        input=None,
        height=None,
        width=None,
        output_channel=None
):
    """Convolve ``input`` with a newly created kernel (stride 1, 'SAME' padding).

    The kernel is built by ``init.weight_init`` and placed in the
    'ConvWeight' collection. Its input depth is the static size of the last
    dimension of ``input``. The convolution output is added to the 'ConvOut'
    collection and reported through the ``Util`` helpers.

    :param name: name of the conv op, also passed as the kernel name prefix.
    :param input: input tensor; its last dimension is used as the channel
        count (tf.nn.conv2d's default 'NHWC' layout — TODO confirm callers).
    :param height: kernel height.
    :param width: kernel width.
    :param output_channel: number of output channels.
    :return: the convolution output tensor.
    """
    in_channels = input.get_shape().as_list()[-1]
    kernel = init.weight_init(
        name=name,
        height=height,
        width=width,
        input_channel=in_channels,
        output_channel=output_channel,
        collection='ConvWeight',
    )
    unit_strides = [1, 1, 1, 1]
    conv_out = tf.nn.conv2d(
        input=input,
        filter=kernel,
        strides=unit_strides,
        padding='SAME',
        name=name,
    )
    tf.add_to_collection(name='ConvOut', value=conv_out)
    Util.AddToCollectionInfo('ConvOut', conv_out)
    Util.CLayerInfo(name, input, conv_out)
    return conv_out
def Activate_Layer(
        name='undefined',
        input=None,
        method='LeakReLU'
):
    """Apply an element-wise activation to ``input`` and register the output.

    Default data format is 'NHWC'; since both activations are element-wise,
    the layout does not affect the result.

    :param name: prefix for the activation op name.
    :param input: input tensor.
    :param method: ``'LeakReLU'`` (leaky ReLU with alpha=0.1) or ``'ReLU'``.
    :return: the activated tensor.
    :raises AssertionError: if ``method`` is not supported.
    """
    assert method in ['LeakReLU', 'ReLU'], Util.CError('method is not supported')
    if method == 'LeakReLU':
        activate = keras.layers.LeakyReLU(
            alpha=0.1,
            name=name + 'LeakReLU'
        )(input)
    elif method == 'ReLU':
        # Plain ReLU, applied to `input` just like the LeakyReLU branch.
        activate = tf.nn.relu(input, name=name + 'ReLU')
    else:
        # Only reachable when asserts are stripped (python -O).
        Util.CError('method is not supported!')
        sys.exit()
    tf.add_to_collection(name='ActiOut', value=activate)
    Util.CLayerInfo(name, input, activate)
    Util.AddToCollectionInfo('ActiOut', activate)
    return activate
def Bias_Layer(
        name='undefined',
        input=None,
):
    """Add a learnable per-channel bias to an 'NHWC' tensor.

    The bias vector comes from ``init.bias_init`` with length equal to the
    static last dimension of ``input``, and is placed in the 'BiasBias'
    collection. The result is added to the 'BiasOut' collection and reported
    through the ``Util`` helpers.

    :param name: name of the bias-add op, also passed as the bias name prefix.
    :param input: input tensor in 'NHWC' layout.
    :return: ``input`` plus the broadcast bias.
    """
    channel_count = input.get_shape().as_list()[-1]
    bias_var = init.bias_init(
        name=name,
        output_channel=channel_count,
        collection='BiasBias',
    )
    biased = tf.nn.bias_add(
        value=input,
        bias=bias_var,
        data_format='NHWC',
        name=name,
    )
    tf.add_to_collection(name='BiasOut', value=biased)
    Util.AddToCollectionInfo('BiasOut', biased)
    Util.CLayerInfo(name, input, biased)
    return biased
def bias_init(name='undefined', method='zero', output_channel=None, collection='TRAINABLED'):
    """Create a 1-D float32 variable of length ``output_channel``.

    The variable is named ``name + '_bias'`` and is registered in
    ``collection`` (plus informational logging through ``Util``).

    :param name: prefix of the created variable.
    :param method: ``'zero'`` for an all-zeros initializer, ``'one'`` for all-ones.
    :param output_channel: length of the vector.
    :param collection: extra graph collection the variable is added to.
    :return: the created ``tf.Variable``.
    :raises AssertionError: if ``method`` is not supported.
    """
    assert method in ['zero', 'one'], 'method not support: {!r}'.format(method)
    if method == 'zero':
        initializer = tf.zeros_initializer()
    elif method == 'one':
        initializer = tf.ones_initializer()
    else:
        # Only reachable when asserts are stripped (python -O).
        sys.exit()
    b = tf.get_variable(
        name=name + '_bias',
        shape=[output_channel],
        dtype=tf.float32,
        initializer=initializer,
    )
    Util.AddToCollectionInfo(tf.GraphKeys.GLOBAL_VARIABLES, b)
    tf.add_to_collection(collection, b)
    Util.AddToCollectionInfo(collection, b)
    return b
def BatchNormal_Layer(
        name='undefined',
        input=None,
        train=True,
):
    """Batch-normalise an 'NHWC' tensor per channel, with moving statistics.

    When ``train`` is true the batch mean/variance (over the N, H and W axes)
    are used, and an exponential moving average of them is updated as a side
    effect. When ``train`` is false the moving averages are used instead.

    Reads module-level ``MOVING_DECAY`` (EMA decay) and ``BNEPS`` (variance
    epsilon).

    :param name: prefix for created variables and ops.
    :param input: 4-D input tensor in 'NHWC' layout.
    :param train: Python bool or scalar boolean tensor selecting batch
        statistics (True) or the moving averages (False).
    :return: the normalised tensor.
    """
    assert Util.CGlobalExit('MOVING_DECAY')
    assert Util.CGlobalExit('BNEPS')
    # tf.cond requires a boolean tensor predicate, not a Python bool.
    train = tf.convert_to_tensor(train, dtype=tf.bool)
    channels = input.get_shape().as_list()[-1]

    # Per-channel statistics: reduce over batch, height and width.
    train_mean = tf.reduce_mean(
        input_tensor=input,
        axis=[0, 1, 2],
        name=name + '_t_mean',
    )
    train_var = tf.reduce_mean(
        tf.square(x=tf.subtract(x=input, y=train_mean)),
        axis=[0, 1, 2],
        name=name + '_t_var',
    )
    beta = init.bias_init(
        name=name + '_beta',
        method='zero',
        output_channel=channels,
        collection='BnBeta'
    )
    gamma = init.bias_init(
        name=name + '_gama',
        method='one',
        output_channel=channels,
        collection='BnGama'
    )

    ema = tf.train.ExponentialMovingAverage(MOVING_DECAY)
    # apply() returns a single update Operation; the smoothed values are
    # read back through average().
    ema_update = ema.apply([train_mean, train_var])
    predict_mean = ema.average(train_mean)
    predict_var = ema.average(train_var)

    def depend_in_train():
        # Using batch statistics forces the moving-average update to run.
        with tf.control_dependencies([ema_update]):
            return tf.identity(train_mean), tf.identity(train_var)

    mean, var = tf.cond(
        train,
        depend_in_train,
        lambda: (predict_mean, predict_var)
    )
    bn = tf.nn.batch_normalization(
        x=input,
        mean=mean,
        variance=var,
        offset=beta,
        scale=gamma,
        variance_epsilon=BNEPS
    )
    tf.add_to_collection(name='BnOut', value=bn)
    Util.AddToCollectionInfo('BnOut', bn)
    Util.CLayerInfo(name, input, bn)
    return bn