def call(self, inputs, genres):
    """
    Executes the discriminator model on a batch of input images and
    outputs whether each image is real or fake.

    :param inputs: a batch of images, shape=[batch_size, height, width, channels]
    :param genres: a genre label, shape=[1,]
    :return: a batch of values indicating whether each image is real or fake,
        shape=[batch_size, 1]
    """
    genres = tf.convert_to_tensor(genres)

    # Set up the genre embedding.
    embed = self.embedding(genres)
    embed = self.embed_reshape(self.embed_dense(embed))

    # Merge the genre information with the image batch.
    out = self.merge([embed, inputs])

    # Proceed with the normal discriminator pipeline:
    # conv -> batch norm -> activation, repeated per block.
    out = self.conv1(out)

    # out = self.activate(self.norm(self.conv2(out)))
    out = self.conv2(out)
    mean, variance = tf.nn.moments(out, (0, 1, 2))
    out = tf.nn.batch_normalization(out, mean, variance, offset=0, scale=1,
                                    variance_epsilon=0.00001)
    out = self.activate(out)

    # out = self.activate(self.norm(self.conv3(out)))
    out = self.conv3(out)
    mean, variance = tf.nn.moments(out, (0, 1, 2))
    out = tf.nn.batch_normalization(out, mean, variance, offset=0, scale=1,
                                    variance_epsilon=0.00001)
    out = self.activate(out)

    # out = self.activate(self.norm(self.conv4(out)))
    out = self.conv4(out)
    mean, variance = tf.nn.moments(out, (0, 1, 2))
    out = tf.nn.batch_normalization(out, mean, variance, offset=0, scale=1,
                                    variance_epsilon=0.00001)
    out = self.activate(out)

    flat = self.flat(out)
    return self.decision(flat)
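# A minimal, self-contained sketch (not from the source) of the
# conv -> batch-norm -> activation step that the discriminator's `call`
# repeats above; the layer sizes and the LeakyReLU slope are illustrative
# assumptions.
import tensorflow as tf

conv = tf.keras.layers.Conv2D(64, 3, strides=2, padding='same')
activate = tf.keras.layers.LeakyReLU(0.2)

x = tf.random.normal([8, 32, 32, 3])             # a fake batch of images
out = conv(x)
mean, variance = tf.nn.moments(out, (0, 1, 2))   # per-channel batch statistics
out = tf.nn.batch_normalization(out, mean, variance, offset=0, scale=1,
                                variance_epsilon=0.00001)
out = activate(out)
print(out.shape)                                 # (8, 16, 16, 64)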
def call(self, inputs):
    mean, variance = nn.moments(inputs, self.moments_axes, keepdims=True)
    outputs = nn.batch_normalization(
        inputs, mean, variance, None, None, self.epsilon,
        name='LayerInstanceNorm')
    return outputs
def call(self, inputs, gamma, beta):
    # Same normalization as above, but with externally supplied scale (gamma)
    # and offset (beta). Note that nn.batch_normalization takes the offset
    # argument before the scale argument.
    mean, variance = nn.moments(inputs, self.moments_axes, keepdims=True)
    outputs = nn.batch_normalization(
        inputs, mean, variance, beta, gamma, self.epsilon,
        name='AdaptiveInstanceNorm')
    return outputs
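# A minimal sketch (an assumption, not the source class) of the layer that
# the two normalization `call` methods above belong to: `moments_axes`
# selects the spatial axes, so each (sample, channel) pair is normalized
# independently. The adaptive variant simply takes `gamma` and `beta`
# (computed elsewhere, e.g. from a style vector) instead of None.
import tensorflow as tf
from tensorflow import nn

class InstanceNorm(tf.keras.layers.Layer):
    def __init__(self, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon

    def build(self, input_shape):
        # For NHWC inputs this yields [1, 2]: reduce over height and width.
        self.moments_axes = list(range(1, len(input_shape) - 1))

    def call(self, inputs):
        mean, variance = nn.moments(inputs, self.moments_axes, keepdims=True)
        return nn.batch_normalization(
            inputs, mean, variance, None, None, self.epsilon)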
def batch_norm_layer(x, is_train, decay=0.9, name_or_scope=None):
    """x: [b, emb_dim]"""
    with tf.variable_scope(name_or_scope=name_or_scope,
                           default_name="batch_norm_layer"):
        params_shape = [1, x.shape[-1]]
        beta = tf.get_variable(
            "beta", params_shape, tf.float32,
            initializer=tf.constant_initializer(0.0, tf.float32))
        gamma = tf.get_variable(
            "gamma", params_shape, tf.float32,
            initializer=tf.constant_initializer(1.0, tf.float32))
        if is_train:
            # Training: normalize with batch statistics and register
            # moving-average updates as a side effect.
            mean, variance = tf.nn.moments(x, axes=[0], keep_dims=True)
            moving_mean = tf.get_variable(
                'moving_mean', shape=params_shape, dtype=tf.float32,
                initializer=tf.constant_initializer(0.0, tf.float32),
                trainable=False)
            moving_variance = tf.get_variable(
                'moving_variance', shape=params_shape, dtype=tf.float32,
                initializer=tf.constant_initializer(1.0, tf.float32),
                trainable=False)
            tf.add_to_collection(
                tf.GraphKeys.TRAIN_OP,
                tf.assign(moving_mean,
                          decay * moving_mean + (1 - decay) * mean))
            tf.add_to_collection(
                tf.GraphKeys.TRAIN_OP,
                tf.assign(moving_variance,
                          decay * moving_variance + (1 - decay) * variance))
        else:
            # Inference: normalize with the stored moving averages.
            mean = tf.get_variable(
                'moving_mean', shape=params_shape, dtype=tf.float32,
                initializer=tf.constant_initializer(0.0, tf.float32),
                trainable=False)
            variance = tf.get_variable(
                'moving_variance', shape=params_shape, dtype=tf.float32,
                initializer=tf.constant_initializer(1.0, tf.float32),
                trainable=False)
        x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-6)
        return x
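# Hedged TF1-style usage sketch for batch_norm_layer above (assumes a graph /
# tf.compat.v1 environment to match the tf.get_variable API; the input width
# 128 is arbitrary). The assign ops registered under tf.GraphKeys.TRAIN_OP
# must be fetched together with the training step, or the moving statistics
# never update.
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, [None, 128])
y = batch_norm_layer(x, is_train=True)
update_ops = tf.get_collection(tf.GraphKeys.TRAIN_OP)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out, *_ = sess.run([y] + update_ops,
                       feed_dict={x: np.random.randn(32, 128)})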
def resnetBlock(X, weights1, bias1, weights2, bias2,
                mean1, variance1, offset1, scale1,
                mean2, variance2, offset2, scale2):
    # First conv -> batch norm -> ReLU.
    conv1 = nn.conv2d(X, weights1, strides=[1, 1, 1, 1],
                      padding="VALID", data_format="NCHW")
    conv1_bias = nn.bias_add(conv1, bias1, data_format="NCHW")
    bn1 = nn.batch_normalization(conv1_bias, mean1, variance1,
                                 offset1, scale1, EPSILON)
    relu1 = nn.relu(bn1)
    # Second conv -> batch norm; no activation on the block output.
    conv2 = nn.conv2d(relu1, weights2, strides=[1, 1, 1, 1],
                      padding="SAME", data_format="NCHW")
    conv2_bias = nn.bias_add(conv2, bias2, data_format="NCHW")
    bn2 = nn.batch_normalization(conv2_bias, mean2, variance2,
                                 offset2, scale2, EPSILON)
    return bn2
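# Shape note for resnetBlock (an inference, not stated in the source): with
# NCHW tensors of shape [N, C, H, W], nn.batch_normalization broadcasts
# mean/variance/offset/scale against the input, so per-channel parameters
# should be shaped [C, 1, 1] (or [1, C, 1, 1]) rather than [C]. Also, the
# first conv uses padding="VALID" while the second uses "SAME", so at
# stride 1 the block's output is spatially smaller than X by (kernel - 1)
# in each dimension.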
def trans_conv_layer(previous_layer, filters, kernel_size=(3, 3), stride=2,
                     padding='SAME', relu=True, norm='instance'):
    batch, height, width, channels = previous_layer.get_shape().as_list()
    # Transposed-conv filter shape: [kh, kw, out_channels, in_channels].
    shape = [kernel_size[0], kernel_size[1], filters, channels]
    # Output shape after upsampling by `stride`.
    new_shape = tf.convert_to_tensor(
        [batch, height * stride, width * stride, filters])
    conv = conv2d_transpose(previous_layer, create_weights(shape), new_shape,
                            [1, stride, stride, 1], padding=padding)
    # If using instance norm...
    if norm == 'instance':
        normalised = tf.contrib.layers.instance_norm(conv)
    # Otherwise, use batch norm.
    else:
        mean, variance = tf.nn.moments(conv, [0, 1, 2])
        normalised = batch_normalization(conv, mean, variance,
                                         None, None, 0.0001)
    # If requested, apply a ReLU activation.
    if relu:
        return tf.nn.relu(normalised)
    return normalised
def conv_layer(previous_layer, filters, kernel_size=(3, 3), stride=2,
               relu=True, padding='SAME', norm='instance'):
    # Get the number of input channels from the previous layer's output.
    channels = previous_layer.get_shape().as_list()[3]
    # The filter shape for this convolutional layer:
    # [kh, kw, in_channels, out_channels].
    shape = [kernel_size[0], kernel_size[1], channels, filters]
    # Create the convolution layer.
    conv = conv2d(previous_layer, create_weights(shape),
                  [1, stride, stride, 1], padding=padding)
    # If using instance norm...
    if norm == 'instance':
        normalised = tf.contrib.layers.instance_norm(conv)
    # Otherwise, use batch norm.
    else:
        mean, variance = tf.nn.moments(conv, [0, 1, 2])
        normalised = batch_normalization(conv, mean, variance,
                                         None, None, 0.0001)
    # If requested, apply a ReLU activation.
    if relu:
        return tf.nn.relu(normalised)
    return normalised
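# A plausible sketch of the create_weights helper that conv_layer and
# trans_conv_layer above both assume (TF1-style, matching the tf.contrib
# usage; the real initializer may differ).
def create_weights(shape, stddev=0.1):
    # Filter weights drawn from a truncated normal distribution.
    return tf.Variable(tf.truncated_normal(shape, stddev=stddev))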
def call(self, inputs):
    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)

    # Broadcasting only necessary for norm where the axis is not just
    # the last dimension
    broadcast_shape = [1] * ndims
    for dim in self.axis:
        broadcast_shape[dim] = input_shape.dims[dim].value

    def _broadcast(v):
        if (v is not None and len(v.shape) != ndims and
                self.axis != [ndims - 1]):
            return array_ops.reshape(v, broadcast_shape)
        return v

    if not self._fused:
        input_dtype = inputs.dtype
        if input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32':
            # If mixed precision is used, cast inputs to float32 so that this
            # is at least as numerically stable as the fused version.
            inputs = math_ops.cast(inputs, 'float32')

        # Calculate the moments on the last axis (layer activations).
        mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        # Compute layer normalization using the batch_normalization function.
        outputs = nn.batch_normalization(inputs, mean, variance,
                                         offset=offset, scale=scale,
                                         variance_epsilon=self.epsilon)
        outputs = math_ops.cast(outputs, input_dtype)
    else:
        # Collapse dims before self.axis, and dims in self.axis
        pre_dim, in_dim = (1, 1)
        axis = sorted(self.axis)
        tensor_shape = array_ops.shape(inputs)
        for dim in range(0, ndims):
            dim_tensor = tensor_shape[dim]
            if dim < axis[0]:
                pre_dim = pre_dim * dim_tensor
            else:
                assert dim in axis
                in_dim = in_dim * dim_tensor

        squeezed_shape = [1, pre_dim, in_dim, 1]
        # This fused operation requires reshaped inputs to be NCHW.
        data_format = 'NCHW'

        inputs = array_ops.reshape(inputs, squeezed_shape)

        def _set_const_tensor(val, dtype, shape):
            return array_ops.fill(shape, constant_op.constant(val, dtype=dtype))

        # self.gamma and self.beta have the wrong shape for fused_batch_norm,
        # so we cannot pass them as the scale and offset parameters. Therefore,
        # we create two constant tensors in correct shapes for fused_batch_norm
        # and later construct a separate calculation on the scale and offset.
        scale = _set_const_tensor(1.0, self.dtype, [pre_dim])
        offset = _set_const_tensor(0.0, self.dtype, [pre_dim])

        # Compute layer normalization using the fused_batch_norm function.
        outputs, _, _ = nn.fused_batch_norm(inputs,
                                            scale=scale,
                                            offset=offset,
                                            epsilon=self.epsilon,
                                            data_format=data_format)

        outputs = array_ops.reshape(outputs, tensor_shape)

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
        if scale is not None:
            outputs = outputs * math_ops.cast(scale, outputs.dtype)
        if offset is not None:
            outputs = outputs + math_ops.cast(offset, outputs.dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)
    return outputs
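# Equivalent high-level usage (a sketch; this `call` matches the internals of
# tf.keras.layers.LayerNormalization): normalize over the last axis of each
# sample.
import tensorflow as tf

layer_norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-3)
x = tf.random.normal([4, 10, 128])
y = layer_norm(x)   # per-sample mean ~0 and variance ~1 along the last axis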