# Assumed imports: numpy, TF1, and the parameter-store module that provides
# the lib.param(...) calls used throughout this file.
import numpy as np
import tensorflow as tf

import tflib as lib

# Module-level defaults referenced by Linear and Conv2D below (assumed values:
# weight normalization off, no stdev override).
_default_weightnorm = False
_weights_stdev = None


def Layernorm(name, norm_axes, inputs):
    mean, var = tf.nn.moments(inputs, norm_axes, keep_dims=True)

    # Assume the 'neurons' axis is the first of norm_axes.
    # This is the case for fully-connected and BCHW conv layers.
    n_neurons = inputs.get_shape().as_list()[norm_axes[0]]

    offset = lib.param(name + '.offset', np.zeros(n_neurons, dtype='float32'))
    scale = lib.param(name + '.scale', np.ones(n_neurons, dtype='float32'))

    # Add broadcasting dims to offset and scale (e.g. BCHW conv data)
    offset = tf.reshape(offset, [-1] + [1 for i in range(len(norm_axes) - 1)])
    scale = tf.reshape(scale, [-1] + [1 for i in range(len(norm_axes) - 1)])

    result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5)

    return result
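
# Illustrative usage sketch (not part of the original library): layer norm over
# the (C, H, W) axes of BCHW feature maps. The placeholder shape and the
# 'Example.LN' parameter name are made up for this example.
def _example_layernorm():
    x = tf.placeholder(tf.float32, [None, 64, 32, 32])  # BCHW feature maps
    # Normalizes each sample over axes [1, 2, 3]; learns one offset/scale per channel.
    return Layernorm('Example.LN', [1, 2, 3], x)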

def Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, update_moving_stats=True, fused=True, labels=None, n_labels=None):
    """Conditional batchnorm (Dumoulin et al., 2016) for BCHW conv filtermaps"""
    if axes != [0, 2, 3]:
        raise Exception('unsupported')

    mean, var = tf.nn.moments(inputs, axes, keep_dims=True)
    shape = mean.get_shape().as_list()  # shape is [1, n, 1, 1]

    # One row of offset/scale parameters per label
    offset_m = lib.param(name + '.offset', np.zeros([n_labels, shape[1]], dtype='float32'))
    scale_m = lib.param(name + '.scale', np.ones([n_labels, shape[1]], dtype='float32'))
    offset = tf.nn.embedding_lookup(offset_m, labels)
    scale = tf.nn.embedding_lookup(scale_m, labels)

    result = tf.nn.batch_normalization(inputs, mean, var, offset[:, :, None, None], scale[:, :, None, None], 1e-5)

    return result
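
# Keep a handle on the conditional version before the unconditional Batchnorm
# below redefines the same function name. The alias and the example are
# illustrative additions, not part of the original library.
_CondBatchnorm = Batchnorm


# Illustrative usage sketch: conditional BN with 10 classes. Placeholder shapes
# and the 'Example.CondBN' parameter name are made up for this example.
def _example_cond_batchnorm():
    x = tf.placeholder(tf.float32, [None, 64, 32, 32])  # BCHW feature maps
    y = tf.placeholder(tf.int32, [None])                 # one class label per sample
    # Each label indexes its own row of offset/scale parameters.
    return _CondBatchnorm('Example.CondBN', [0, 2, 3], x, labels=y, n_labels=10)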

def Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, update_moving_stats=True, fused=True):
    if ((axes == [0, 2, 3]) or (axes == [0, 2])) and fused == True:
        if axes == [0, 2]:
            inputs = tf.expand_dims(inputs, 3)

        # Old (working but pretty slow) implementation:
        ##########
        # inputs = tf.transpose(inputs, [0,2,3,1])
        # mean, var = tf.nn.moments(inputs, [0,1,2], keep_dims=False)
        # offset = lib.param(name+'.offset', np.zeros(mean.get_shape()[-1], dtype='float32'))
        # scale = lib.param(name+'.scale', np.ones(var.get_shape()[-1], dtype='float32'))
        # result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-4)
        # return tf.transpose(result, [0,3,1,2])

        # New (super fast but untested) implementation:
        offset = lib.param(name + '.offset', np.zeros(inputs.get_shape()[1], dtype='float32'))
        scale = lib.param(name + '.scale', np.ones(inputs.get_shape()[1], dtype='float32'))

        moving_mean = lib.param(name + '.moving_mean', np.zeros(inputs.get_shape()[1], dtype='float32'), trainable=False)
        moving_variance = lib.param(name + '.moving_variance', np.ones(inputs.get_shape()[1], dtype='float32'), trainable=False)

        def _fused_batch_norm_training():
            return tf.nn.fused_batch_norm(inputs, scale, offset, epsilon=1e-5, data_format='NCHW')

        def _fused_batch_norm_inference():
            # Version which blends in the current item's statistics
            batch_size = tf.cast(tf.shape(inputs)[0], 'float32')
            mean, var = tf.nn.moments(inputs, [2, 3], keep_dims=True)
            mean = ((1. / batch_size) * mean) + (((batch_size - 1.) / batch_size) * moving_mean)[None, :, None, None]
            var = ((1. / batch_size) * var) + (((batch_size - 1.) / batch_size) * moving_variance)[None, :, None, None]
            return tf.nn.batch_normalization(inputs, mean, var, offset[None, :, None, None], scale[None, :, None, None], 1e-5), mean, var

            # Standard version
            # return tf.nn.fused_batch_norm(
            #     inputs,
            #     scale,
            #     offset,
            #     epsilon=1e-2,
            #     mean=moving_mean,
            #     variance=moving_variance,
            #     is_training=False,
            #     data_format='NCHW'
            # )

        if is_training is None:
            outputs, batch_mean, batch_var = _fused_batch_norm_training()
        else:
            outputs, batch_mean, batch_var = tf.cond(is_training,
                                                     _fused_batch_norm_training,
                                                     _fused_batch_norm_inference)

            if update_moving_stats:
                no_updates = lambda: outputs

                def _force_updates():
                    """Internal function that forces updates of moving_vars if is_training."""
                    float_stats_iter = tf.cast(stats_iter, tf.float32)
                    update_moving_mean = tf.assign(moving_mean, ((float_stats_iter / (float_stats_iter + 1)) * moving_mean) + ((1 / (float_stats_iter + 1)) * batch_mean))
                    update_moving_variance = tf.assign(moving_variance, ((float_stats_iter / (float_stats_iter + 1)) * moving_variance) + ((1 / (float_stats_iter + 1)) * batch_var))
                    with tf.control_dependencies([update_moving_mean, update_moving_variance]):
                        return tf.identity(outputs)

                outputs = tf.cond(is_training, _force_updates, no_updates)

        if axes == [0, 2]:
            return outputs[:, :, :, 0]  # collapse last dim
        else:
            return outputs
    else:
        # raise Exception('old BN')
        # TODO we can probably use nn.fused_batch_norm here too for speedup
        mean, var = tf.nn.moments(inputs, axes, keep_dims=True)
        shape = mean.get_shape().as_list()
        if 0 not in axes:
            print("WARNING ({}): didn't find 0 in axes, but not using separate BN params for each item in batch".format(name))
            shape[0] = 1
        offset = lib.param(name + '.offset', np.zeros(shape, dtype='float32'))
        scale = lib.param(name + '.scale', np.ones(shape, dtype='float32'))
        result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5)

        return result
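
# Illustrative usage sketch (not part of the original library): fused BN on BCHW
# feature maps with a boolean training switch and an iteration counter for the
# moving-average updates. Placeholder shapes and names are made up.
def _example_batchnorm():
    x = tf.placeholder(tf.float32, [None, 64, 32, 32])  # BCHW feature maps
    is_training = tf.placeholder(tf.bool, [])             # True: batch stats, False: moving stats
    stats_iter = tf.placeholder(tf.int32, [])             # how many batches of stats accumulated so far
    return Batchnorm('Example.BN', [0, 2, 3], x, is_training=is_training, stats_iter=stats_iter)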

def Linear(name, input_dim, output_dim, inputs, biases=True, initialization=None, weightnorm=None, gain=1.):
    """
    initialization: None, 'lecun', 'glorot', 'he', 'glorot_he', 'orthogonal', ('uniform', range)
    """
    with tf.name_scope(name) as scope:

        def uniform(stdev, size):
            if _weights_stdev is not None:
                stdev = _weights_stdev
            return np.random.uniform(
                low=-stdev * np.sqrt(3),
                high=stdev * np.sqrt(3),
                size=size
            ).astype('float32')

        if initialization == 'lecun':  # and input_dim != output_dim):
            # disabling orth. init for now because it's too slow
            weight_values = uniform(
                np.sqrt(1. / input_dim),
                (input_dim, output_dim)
            )

        elif initialization == 'glorot' or (initialization is None):
            weight_values = uniform(
                np.sqrt(2. / (input_dim + output_dim)),
                (input_dim, output_dim)
            )

        elif initialization == 'he':
            weight_values = uniform(
                np.sqrt(2. / input_dim),
                (input_dim, output_dim)
            )

        elif initialization == 'glorot_he':
            weight_values = uniform(
                np.sqrt(4. / (input_dim + output_dim)),
                (input_dim, output_dim)
            )

        elif initialization == 'orthogonal' or \
                (initialization is None and input_dim == output_dim):
            # From lasagne
            def sample(shape):
                if len(shape) < 2:
                    raise RuntimeError("Only shapes of length 2 or more are "
                                       "supported.")
                flat_shape = (shape[0], np.prod(shape[1:]))
                # TODO: why normal and not uniform?
                a = np.random.normal(0.0, 1.0, flat_shape)
                u, _, v = np.linalg.svd(a, full_matrices=False)
                # pick the one with the correct shape
                q = u if u.shape == flat_shape else v
                q = q.reshape(shape)
                return q.astype('float32')

            weight_values = sample((input_dim, output_dim))

        elif initialization[0] == 'uniform':
            weight_values = np.random.uniform(
                low=-initialization[1],
                high=initialization[1],
                size=(input_dim, output_dim)
            ).astype('float32')

        else:
            raise Exception('Invalid initialization!')

        weight_values *= gain

        weight = lib.param(name + '.W', weight_values)

        if weightnorm is None:
            weightnorm = _default_weightnorm
        if weightnorm:
            norm_values = np.sqrt(np.sum(np.square(weight_values), axis=0))
            # norm_values = np.linalg.norm(weight_values, axis=0)

            target_norms = lib.param(name + '.g', norm_values)

            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(tf.reduce_sum(tf.square(weight), reduction_indices=[0]))
                weight = weight * (target_norms / norms)

        # if 'Discriminator' in name:
        #     print "WARNING weight constraint on {}".format(name)
        #     weight = tf.nn.softsign(10.*weight)*.1

        if inputs.get_shape().ndims == 2:
            result = tf.matmul(inputs, weight)
        else:
            reshaped_inputs = tf.reshape(inputs, [-1, input_dim])
            result = tf.matmul(reshaped_inputs, weight)
            result = tf.reshape(result, tf.stack(tf.unstack(tf.shape(inputs))[:-1] + [output_dim]))

        if biases:
            result = tf.nn.bias_add(
                result,
                lib.param(name + '.b', np.zeros((output_dim,), dtype='float32'))
            )

        return result
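
# Illustrative usage sketch (not part of the original library): a 128 -> 256
# affine layer with the default (Glorot uniform) initialization. The shape and
# the 'Example.Linear' parameter name are made up for this example.
def _example_linear():
    x = tf.placeholder(tf.float32, [None, 128])
    return Linear('Example.Linear', 128, 256, x)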

def Conv2D(name, input_dim, output_dim, filter_size, inputs, he_init=True, mask_type=None, stride=1, weightnorm=None, biases=True, gain=1.):
    """
    inputs: tensor of shape (batch size, num channels, height, width)
    mask_type: one of None, 'a', 'b'

    returns: tensor of shape (batch size, num channels, height, width)
    """
    with tf.name_scope(name) as scope:

        if mask_type is not None:
            mask_type, mask_n_channels = mask_type

            mask = np.ones(
                (filter_size, filter_size, input_dim, output_dim),
                dtype='float32'
            )
            center = filter_size // 2

            # Mask out future locations
            # filter shape is (height, width, input channels, output channels)
            mask[center + 1:, :, :, :] = 0.
            mask[center, center + 1:, :, :] = 0.

            # Mask out future channels
            for i in range(mask_n_channels):
                for j in range(mask_n_channels):
                    if (mask_type == 'a' and i >= j) or (mask_type == 'b' and i > j):
                        mask[center, center, i::mask_n_channels, j::mask_n_channels] = 0.

        def uniform(stdev, size):
            return np.random.uniform(
                low=-stdev * np.sqrt(3),
                high=stdev * np.sqrt(3),
                size=size
            ).astype('float32')

        fan_in = input_dim * filter_size**2
        fan_out = output_dim * filter_size**2 / (stride**2)

        if mask_type is not None:  # only approximately correct
            fan_in /= 2.
            fan_out /= 2.

        if he_init:
            filters_stdev = np.sqrt(4. / (fan_in + fan_out))
        else:  # Normalized init (Glorot & Bengio)
            filters_stdev = np.sqrt(2. / (fan_in + fan_out))

        if _weights_stdev is not None:
            filter_values = uniform(
                _weights_stdev,
                (filter_size, filter_size, input_dim, output_dim)
            )
        else:
            filter_values = uniform(
                filters_stdev,
                (filter_size, filter_size, input_dim, output_dim)
            )

        # print "WARNING IGNORING GAIN"
        filter_values *= gain

        filters = lib.param(name + '.Filters', filter_values)

        if weightnorm is None:
            weightnorm = _default_weightnorm
        if weightnorm:
            norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0, 1, 2)))
            target_norms = lib.param(name + '.g', norm_values)
            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0, 1, 2]))
                filters = filters * (target_norms / norms)

        if mask_type is not None:
            with tf.name_scope('filter_mask'):
                filters = filters * mask

        result = tf.nn.conv2d(
            input=inputs,
            filter=filters,
            strides=[1, 1, stride, stride],
            padding='SAME',
            data_format='NCHW'
        )

        if biases:
            _biases = lib.param(name + '.Biases', np.zeros(output_dim, dtype='float32'))
            result = tf.nn.bias_add(result, _biases, data_format='NCHW')

        return result
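
# Illustrative usage sketch (not part of the original library): 3 -> 64 channels,
# 5x5 filters, stride 2, on BCHW images. Note that NCHW conv2d generally requires
# a GPU build of TF1. The shape and the 'Example.Conv' name are made up.
def _example_conv2d():
    x = tf.placeholder(tf.float32, [None, 3, 32, 32])  # BCHW images
    return Conv2D('Example.Conv', 3, 64, 5, x, stride=2)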