def call(self, inputs, training=True):
    """Apply the transposed 2-D convolution using the spectrally normalised kernel.

    Args:
        inputs: 4-D input tensor laid out according to ``self.data_format``.
        training: Forwarded unchanged to ``self.compute_spectral_normal``,
            which supplies the kernel used for the convolution.

    Returns:
        The transposed-convolution output, with the bias added when
        ``self.use_bias`` is set and ``self.activation`` applied when it is
        not ``None``.
    """
    channels_first = self.data_format == 'channels_first'
    if channels_first:
        c_axis, h_axis, w_axis = 1, 2, 3
    else:
        c_axis, h_axis, w_axis = 3, 1, 2

    # Dynamic (run-time) sizes of the input.
    dynamic_shape = array_ops.shape(inputs)
    batch_size = dynamic_shape[0]
    in_height = dynamic_shape[h_axis]
    in_width = dynamic_shape[w_axis]

    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides

    # Infer the dynamic output shape.
    out_height = utils.deconv_output_length(
        in_height, kernel_h, self.padding, stride_h)
    out_width = utils.deconv_output_length(
        in_width, kernel_w, self.padding, stride_w)

    if channels_first:
        target_dims = (batch_size, self.filters, out_height, out_width)
        strides = (1, 1, stride_h, stride_w)
    else:
        target_dims = (batch_size, out_height, out_width, self.filters)
        strides = (1, stride_h, stride_w, 1)
    output_shape_tensor = array_ops.stack(target_dims)

    normalised_kernel = self.compute_spectral_normal(training=training)
    outputs = nn.conv2d_transpose(
        inputs,
        normalised_kernel,
        output_shape_tensor,
        strides,
        padding=self.padding.upper(),
        data_format=utils.convert_data_format(self.data_format, ndim=4))

    if not context.executing_eagerly():
        # Infer the static output shape so downstream layers see it.
        static_dims = inputs.get_shape().as_list()
        static_dims[c_axis] = self.filters
        for axis, kernel, stride in ((h_axis, kernel_h, stride_h),
                                     (w_axis, kernel_w, stride_w)):
            static_dims[axis] = utils.deconv_output_length(
                static_dims[axis], kernel, self.padding, stride)
        outputs.set_shape(static_dims)

    if self.use_bias:
        outputs = nn.bias_add(
            outputs,
            self.bias,
            data_format=utils.convert_data_format(self.data_format, ndim=4))

    if self.activation is None:
        return outputs
    return self.activation(outputs)
def style_swap(content, style, patch_size, stride):
    '''Efficiently swap content feature patches with nearest-neighbor style patches.

    Original paper: https://arxiv.org/abs/1612.04337
    Adapted from: https://github.com/rtqichen/style-swap/blob/master/lib/NonparametricPatchAutoencoderFactory.lua

    Args:
        content: Content feature tensor, used as the input of the matching conv.
        style: Style feature tensor from which the patches are extracted.
        patch_size: Side length of the square patches.
        stride: Stride used for patch extraction, matching conv and deconv.

    Returns:
        The style-swapped feature tensor. The transposed-conv output shape is
        hard-coded to a batch size of 1.
    '''
    # Num channels of input content feature and style-swapped output
    nC = tf.shape(style)[-1]

    ### Extract patches from style image that will be used for conv/deconv layers
    style_patches = tf.extract_image_patches(
        style,
        [1, patch_size, patch_size, 1],
        [1, stride, stride, 1],
        [1, 1, 1, 1],
        'VALID')
    before_reshape = tf.shape(style_patches)  # N x Rows x Cols x Patch_size*Patch_size*nC
    style_patches = tf.reshape(
        style_patches,
        [before_reshape[1] * before_reshape[2], patch_size, patch_size, nC])
    style_patches = tf.transpose(style_patches, [1, 2, 3, 0])  # Patch_size x Patch_size x In_C x Out_C

    # Normalize each style patch. After the transpose every patch is one
    # output filter (last axis), so its norm must be taken over the spatial
    # and input-channel axes. Normalizing along axis 3 would instead mix
    # statistics across different patches and bias the nearest-neighbor search.
    style_patches_norm = tf.nn.l2_normalize(style_patches, dim=[0, 1, 2])

    # Compute cross-correlation/nearest neighbors of patches by using style patches as conv filters
    ss_enc = tf.nn.conv2d(content, style_patches_norm, [1, stride, stride, 1], 'VALID')

    # For each spatial position find index of max along channel/patch dim
    ss_argmax = tf.argmax(ss_enc, axis=3)
    encC = tf.shape(ss_enc)[-1]  # Num channels in intermediate conv output, same as # of patches

    # One-hot encode argmax with same size as ss_enc, with 1's in max channel idx for each spatial pos
    ss_oh = tf.one_hot(ss_argmax, encC, 1., 0., 3)

    # Calc size of transposed conv out
    deconv_out_H = utils.deconv_output_length(tf.shape(ss_oh)[1], patch_size, 'valid', stride)
    deconv_out_W = utils.deconv_output_length(tf.shape(ss_oh)[2], patch_size, 'valid', stride)
    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, nC])

    # Deconv back to original content size with highest matching (unnormalized)
    # style patch swapped in for each content patch
    ss_dec = tf.nn.conv2d_transpose(
        ss_oh, style_patches, deconv_out_shape, [1, stride, stride, 1], 'VALID')

    ### Interpolate to average overlapping patch locations
    ss_oh_sum = tf.reduce_sum(ss_oh, axis=3, keep_dims=True)
    filter_ones = tf.ones([patch_size, patch_size, 1, 1], dtype=tf.float32)
    # Same spatial size as ss_dec with 1 channel
    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, 1])
    counting = tf.nn.conv2d_transpose(
        ss_oh_sum, filter_ones, deconv_out_shape, [1, stride, stride, 1], 'VALID')
    # Repeat along channel dim to make same size as ss_dec
    counting = tf.tile(counting, [1, 1, 1, nC])

    interpolated_dec = tf.divide(ss_dec, counting)
    return interpolated_dec
def _conv_sn(conv,
             inputs,
             filters,
             kernel_size,
             name,
             strides=1,
             padding='valid',
             activation=None,
             use_bias=True,
             kernel_initializer=tf.glorot_uniform_initializer(),
             bias_initializer=tf.zeros_initializer(),
             use_gamma=False,
             factor=None,
             transposed=False):
    """Build a (transposed) convolution whose kernel goes through ``spectral_norm``.

    Args:
        conv: Convolution callable. Called as ``conv(inputs, kernel,
            output_shape, strides=..., padding=...)`` when ``transposed`` is
            set, otherwise as ``conv(inputs, kernel, strides=..., padding=...)``.
        inputs: Channels-last 4-D input tensor.
        filters: Number of output channels.
        kernel_size: ``(height, width)`` of the kernel, or a single int used
            for both.
        name: Variable scope under which ``kernel`` and ``bias`` are created.
        strides: ``(stride_h, stride_w)``, or a single int used for both.
        padding: Padding mode; upper-cased before being handed to ``conv``.
        activation: Optional callable applied to the result.
        use_bias: Whether a ``bias`` variable is created and added.
        kernel_initializer: Initializer for the ``kernel`` variable.
        bias_initializer: Initializer for the ``bias`` variable.
        use_gamma: Forwarded to ``spectral_norm``.
        factor: Forwarded to ``spectral_norm``.
        transposed: Build a transposed convolution when true.

    Returns:
        The output tensor, after optional bias and activation.
    """

    def _pair(value):
        # Accept a bare int (such as the default ``strides=1``) as well as a
        # 2-sequence, so unpacking below never fails on an int.
        if isinstance(value, int):
            return (value, value)
        return tuple(value)

    kernel_size = _pair(kernel_size)
    strides = _pair(strides)
    conv_strides = (1,) + strides + (1,)

    input_shape = inputs.get_shape().as_list()
    c_axis, h_axis, w_axis = 3, 1, 2  # channels last
    input_dim = input_shape[c_axis]

    with tf.variable_scope(name):
        if transposed:
            kernel_shape = kernel_size + (filters, input_dim)
            kernel = tf.get_variable('kernel', shape=kernel_shape,
                                     initializer=kernel_initializer)

            def _size(axis):
                # Fall back to the run-time size when the static one is
                # unknown; ``tf.stack`` cannot pack ``None``.
                if input_shape[axis] is not None:
                    return input_shape[axis]
                return tf.shape(inputs)[axis]

            kernel_h, kernel_w = kernel_size
            stride_h, stride_w = strides
            # The conv receives ``padding.upper()``; give the length helper
            # the lower-case spelling so both agree on the mode.
            out_height = deconv_output_length(
                _size(h_axis), kernel_h, padding.lower(), stride_h)
            out_width = deconv_output_length(
                _size(w_axis), kernel_w, padding.lower(), stride_w)
            output_shape = (_size(0), out_height, out_width, filters)
            outputs = conv(inputs,
                           spectral_norm(kernel, use_gamma=use_gamma, factor=factor),
                           tf.stack(output_shape),
                           strides=conv_strides,
                           padding=padding.upper())
        else:
            kernel_shape = kernel_size + (input_dim, filters)
            kernel = tf.get_variable('kernel', shape=kernel_shape,
                                     initializer=kernel_initializer)
            outputs = conv(inputs,
                           spectral_norm(kernel, use_gamma=use_gamma, factor=factor),
                           strides=conv_strides,
                           padding=padding.upper())

        if use_bias:
            bias = tf.get_variable('bias', shape=(filters,),
                                   initializer=bias_initializer)
            outputs = tf.nn.bias_add(outputs, bias)

        if activation is not None:
            outputs = activation(outputs)

    return outputs
def call(self, inputs):
    """Run the transposed 2-D convolution with an L2-normalised kernel.

    The kernel is L2-normalised over axes ``[0, 1, 3]`` and, when
    ``self.use_scale`` is set, multiplied by ``self.scale`` reshaped to
    ``[1, 1, self.filters, 1]``.

    Args:
        inputs: 4-D input tensor laid out according to ``self.data_format``.

    Returns:
        The transposed-convolution output, with the bias added when
        ``self.use_bias`` is set and ``self.activation`` applied when it is
        not ``None``.
    """
    channels_first = self.data_format == 'channels_first'
    if channels_first:
        c_axis, h_axis, w_axis = 1, 2, 3
    else:
        c_axis, h_axis, w_axis = 3, 1, 2

    # Dynamic (run-time) sizes of the input.
    dynamic_shape = array_ops.shape(inputs)
    batch_size = dynamic_shape[0]
    in_height = dynamic_shape[h_axis]
    in_width = dynamic_shape[w_axis]

    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides

    # Infer the dynamic output shape.
    out_height = utils.deconv_output_length(
        in_height, kernel_h, self.padding, stride_h)
    out_width = utils.deconv_output_length(
        in_width, kernel_w, self.padding, stride_w)

    if channels_first:
        target_dims = (batch_size, self.filters, out_height, out_width)
        strides = (1, 1, stride_h, stride_w)
    else:
        target_dims = (batch_size, out_height, out_width, self.filters)
        strides = (1, stride_h, stride_w, 1)
    output_shape_tensor = array_ops.stack(target_dims)

    normalised_kernel = nn.l2_normalize(self.kernel, [0, 1, 3])
    if self.use_scale:
        scale = tf.reshape(self.scale, [1, 1, self.filters, 1])
        normalised_kernel = scale * normalised_kernel

    outputs = nn.conv2d_transpose(
        inputs,
        normalised_kernel,
        output_shape_tensor,
        strides,
        padding=self.padding.upper(),
        data_format=utils.convert_data_format(self.data_format, ndim=4))

    if context.in_graph_mode():
        # Infer the static output shape so downstream layers see it.
        static_dims = inputs.get_shape().as_list()
        static_dims[c_axis] = self.filters
        for axis, kernel, stride in ((h_axis, kernel_h, stride_h),
                                     (w_axis, kernel_w, stride_w)):
            static_dims[axis] = utils.deconv_output_length(
                static_dims[axis], kernel, self.padding, stride)
        outputs.set_shape(static_dims)

    if self.use_bias:
        outputs = nn.bias_add(
            outputs,
            self.bias,
            data_format=utils.convert_data_format(self.data_format, ndim=4))

    if self.activation is None:
        return outputs
    return self.activation(outputs)
def deconv2d_real(name, input, oc, f_h=3, f_w=3, s_h=1, s_w=1, bn=True,
                  is_training=True, print_shape=False, act='lrelu'):
    """Transposed 2-D convolution with optional batch norm and activation.

    Args:
        name: Variable scope name for the layer.
        input: Channels-last 4-D input tensor.
        oc: Number of output channels.
        f_h: Filter height.
        f_w: Filter width.
        s_h: Stride along the height axis.
        s_w: Stride along the width axis.
        bn: Whether to apply ``tf.contrib.layers.batch_norm`` to the result.
        is_training: Passed to the batch-norm layer as ``is_training``.
        print_shape: When true, print the name and shape of the output.
        act: ``'relu'``, ``'lrelu'`` or ``'abs'``; any other value leaves the
            tensor unchanged.

    Returns:
        The output tensor after optional batch norm and activation.
    """
    with tf.variable_scope(name):
        padding = "SAME"
        static_shape = input.get_shape().as_list()

        if None in static_shape[:3]:
            # Some leading sizes are only known at run time (typically the
            # batch size), so read those from the dynamic shape instead of
            # putting ``None`` into ``output_shape``.
            runtime_shape = tf.shape(input)
            batch, in_h, in_w = [
                static_shape[axis] if static_shape[axis] is not None
                else runtime_shape[axis]
                for axis in range(3)
            ]
        else:
            batch, in_h, in_w = static_shape[:3]

        out_h = utils.deconv_output_length(in_h, f_h, padding, s_h)
        out_w = utils.deconv_output_length(in_w, f_w, padding, s_w)
        output_shape = (batch, out_h, out_w, oc)
        if None in static_shape[:3]:
            output_shape = tf.stack(output_shape)
        strides = [1, s_h, s_w, 1]

        W = tf.get_variable("W",
                            shape=[f_h, f_w, oc, static_shape[-1]],
                            initializer=tf.contrib.layers.xavier_initializer())
        deconv = tf.nn.conv2d_transpose(input,
                                        filter=W,
                                        output_shape=output_shape,
                                        strides=strides,
                                        padding=padding)

        if bn:
            # Forward ``is_training`` so that callers asking for inference
            # mode actually get it; previously the argument was ignored.
            normed = tf.contrib.layers.batch_norm(deconv,
                                                  center=True,
                                                  scale=True,
                                                  renorm=True,
                                                  is_training=is_training,
                                                  scope='bn')
        else:
            normed = deconv

        if act == 'relu':
            # Just use the regular relu
            activated = tf.nn.relu(normed)
        elif act == 'lrelu':
            # This is leaky relu
            activated = tf.nn.leaky_relu(normed)
        elif act == 'abs':
            activated = tf.abs(normed)
        else:
            activated = normed

        if print_shape:
            print("{} shape : {}".format(activated.name, activated.get_shape()))

        return activated
def deconv(self, input, k_h, k_w, c_o, s_h, s_w, name, relu=True,
           padding=DEFAULT_PADDING, group=1, biased=True):
    """Transposed convolution layer with optional grouping, bias and ReLU.

    Args:
        input: Channels-last 4-D input tensor.
        k_h: Kernel height.
        k_w: Kernel width.
        c_o: Number of output channels; must be divisible by ``group``.
        s_h: Stride along the height axis.
        s_w: Stride along the width axis.
        name: Variable scope name; also used as the name of the ReLU op.
        relu: Whether to apply ReLU to the result.
        padding: Padding mode, checked by ``self.validate_padding``.
        group: Number of groups the input channels and kernel are split into.
        biased: Whether to create and add a ``biases`` variable.

    Returns:
        The output tensor.
    """
    # Verify that the padding is acceptable
    self.validate_padding(padding)
    # Get the number of channels in the input
    c_i = input.get_shape()[-1]
    # Verify that the grouping parameter is valid
    assert c_i % group == 0
    assert c_o % group == 0
    # Floor division keeps the channel count an int on Python 3 as well;
    # a float is not a valid shape entry.
    c_o_per_group = c_o // group

    stride_shape = [1, s_h, s_w, 1]
    input_shape = input.get_shape()
    # NOTE(review): the padding argument handed to the helper is the literal
    # 1 rather than `padding` -- confirm against utils.deconv_output_length
    # whether this is intended.
    out_h = utils.deconv_output_length(input_shape[1].value, k_h, 1, s_h)
    out_w = utils.deconv_output_length(input_shape[2].value, k_w, 1, s_w)
    out_shape = tf.stack([tf.shape(input)[0], out_h, out_w, c_o_per_group])

    def _transpose_conv(i, k):
        # Transposed convolution for a given input and kernel.
        return tf.nn.conv2d_transpose(i, k,
                                      output_shape=out_shape,
                                      strides=stride_shape,
                                      padding=padding)

    with tf.variable_scope(name) as scope:
        kernel = self.make_var('weights', shape=[k_h, k_w, c_o_per_group, c_i])
        if group == 1:
            # This is the common-case. Convolve the input without any further complications.
            output = _transpose_conv(input, kernel)
        else:
            # Split the input into groups and then convolve each of them independently
            input_groups = tf.split(input, num_or_size_splits=group, axis=3)
            kernel_groups = tf.split(kernel, num_or_size_splits=group, axis=3)
            output_groups = [
                _transpose_conv(i, k)
                for i, k in zip(input_groups, kernel_groups)
            ]
            # Concatenate the groups
            output = tf.concat(output_groups, axis=3)
        # Add the biases
        if biased:
            biases = self.make_var('biases', [c_o])
            output = tf.nn.bias_add(output, biases)
    if relu:
        # ReLU non-linearity
        output = tf.nn.relu(output, name=scope.name)
    return output
def conv2d_transpose(x, W, strides=(1, 1), padding="SAME"):
    """Transposed 2-D convolution that infers its own output shape.

    Args:
        x: Channels-last 4-D input tensor.
        W: Kernel whose axes 0 and 1 are the spatial sizes and whose axis 2
            is the number of output channels.
        strides: ``(stride_h, stride_w)``.
        padding: Padding mode, in either case (``"SAME"``/``"same"``, ...).

    Returns:
        The result of ``tf.nn.conv2d_transpose``.
    """
    strides = list(strides)

    # Compute output size, cf. tf.layers.Conv2DTranspose
    W_shape = tf.shape(W)
    inputs_shape = tf.shape(x)

    # Prefer the statically known kernel size: the length helper then does
    # its kernel/stride arithmetic on plain Python ints. Fall back to the
    # dynamic size only where the static one is unavailable.
    static_kernel = [None, None]
    kernel_static_shape = W.get_shape() if hasattr(W, "get_shape") else None
    if kernel_static_shape is not None and kernel_static_shape.ndims is not None:
        static_kernel = kernel_static_shape.as_list()[:2]
    kernel_h, kernel_w = [
        size if size is not None else W_shape[axis]
        for axis, size in enumerate(static_kernel)
    ]

    # NOTE(review): assumes deconv_output_length compares the padding mode in
    # lower case, as the tf.layers helper referenced above does -- confirm
    # against the imported helper. Passing the upper-case default through
    # would make "VALID" be sized as if it were "SAME".
    pad_mode = padding.lower()

    # Infer the dynamic output shape:
    out_height = deconv_output_length(inputs_shape[1], kernel_h, pad_mode, strides[0])
    out_width = deconv_output_length(inputs_shape[2], kernel_w, pad_mode, strides[1])
    output_shape = (inputs_shape[0], out_height, out_width, W_shape[2])

    return tf.nn.conv2d_transpose(x, W, tf.stack(output_shape),
                                  strides=[1] + strides + [1],
                                  padding=padding.upper())
def _compute_output_shape(self, input_shape):
    """Return the static output shape of the layer for ``input_shape``.

    The channel axis becomes ``self.filters``; the height and width axes are
    mapped through ``utils.deconv_output_length`` with the layer's kernel
    size, padding and strides. All other axes are kept.

    Args:
        input_shape: Anything accepted by ``tensor_shape.TensorShape``.

    Returns:
        A ``tensor_shape.TensorShape`` describing the output.
    """
    dims = tensor_shape.TensorShape(input_shape).as_list()

    if self.data_format == 'channels_first':
        c_axis, h_axis, w_axis = 1, 2, 3
    else:
        c_axis, h_axis, w_axis = 3, 1, 2

    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides

    dims[c_axis] = self.filters
    for axis, kernel, stride in ((h_axis, kernel_h, stride_h),
                                 (w_axis, kernel_w, stride_w)):
        dims[axis] = utils.deconv_output_length(
            dims[axis], kernel, self.padding, stride)

    return tensor_shape.TensorShape(dims)
def testDeconvOutputLength(self):
    """Check ``utils.deconv_output_length`` for every padding mode at strides 1 and 2."""
    # (expected output length, padding, stride) for input length 4, filter size 2.
    cases = (
        (4, 'same', 1),
        (8, 'same', 2),
        (5, 'valid', 1),
        (8, 'valid', 2),
        (3, 'full', 1),
        (6, 'full', 2),
    )
    for expected, padding, stride in cases:
        self.assertEqual(
            expected, utils.deconv_output_length(4, 2, padding, stride))
def call(self, inputs, **kwargs):
    """Apply the 1-D transposed convolution, then bias and optional activation.

    Args:
        inputs: Input tensor; axis 0 is used as the batch size and axis 1 as
            the length that is expanded.
        **kwargs: Accepted and ignored.

    Returns:
        ``conv_transpose_1d`` output plus ``self.bias``, passed through
        ``self.activation`` when it is not ``None``.
    """
    in_width = tf.shape(inputs)[1]
    out_width = utils.deconv_output_length(
        in_width, self.kernel_size, "valid", self.stride)
    batch_size = tf.shape(inputs)[0]
    output_shape = (batch_size, out_width, self.out_channels)

    conv1d_output = conv_transpose_1d(
        inputs, self.kernel, output_shape, self.stride, padding="VALID")

    biased = conv1d_output + self.bias
    if self.activation is None:
        return biased
    return self.activation(biased)
def style_swap(content, style, patch_size, stride):
    '''Efficiently swap content feature patches with nearest-neighbor style patches.

    Original paper: https://arxiv.org/abs/1612.04337
    Adapted from: https://github.com/rtqichen/style-swap/blob/master/lib/NonparametricPatchAutoencoderFactory.lua

    Args:
        content: Content feature tensor, used as the input of the matching conv.
        style: Style feature tensor from which the patches are extracted.
        patch_size: Side length of the square patches.
        stride: Stride used for patch extraction, matching conv and deconv.

    Returns:
        The style-swapped feature tensor. The transposed-conv output shape is
        hard-coded to a batch size of 1.
    '''
    # Num channels of input content feature and style-swapped output
    nC = tf.shape(style)[-1]

    ### Extract patches from style image that will be used for conv/deconv layers
    style_patches = tf.extract_image_patches(style,
                                             [1, patch_size, patch_size, 1],
                                             [1, stride, stride, 1],
                                             [1, 1, 1, 1], 'VALID')
    # N x Rows x Cols x Patch_size*Patch_size*nC
    before_reshape = tf.shape(style_patches)
    style_patches = tf.reshape(
        style_patches,
        [before_reshape[1] * before_reshape[2], patch_size, patch_size, nC])
    # Patch_size x Patch_size x In_C x Out_C
    style_patches = tf.transpose(style_patches, [1, 2, 3, 0])

    # Normalize each style patch. Each patch is one output filter (last
    # axis) after the transpose, so the norm is taken over the spatial and
    # input-channel axes; reducing along axis 3 would normalize across
    # patches instead of within each patch.
    style_patches_norm = tf.nn.l2_normalize(style_patches, dim=[0, 1, 2])

    # Compute cross-correlation/nearest neighbors of patches by using style
    # patches as conv filters
    ss_enc = tf.nn.conv2d(content, style_patches_norm,
                          [1, stride, stride, 1], 'VALID')

    # For each spatial position find index of max along channel/patch dim
    ss_argmax = tf.argmax(ss_enc, axis=3)
    # Num channels in intermediate conv output, same as # of patches
    encC = tf.shape(ss_enc)[-1]

    # One-hot encode argmax with same size as ss_enc, with 1's in max channel
    # idx for each spatial pos
    ss_oh = tf.one_hot(ss_argmax, encC, 1., 0., 3)

    # Calc size of transposed conv out
    deconv_out_H = utils.deconv_output_length(
        tf.shape(ss_oh)[1], patch_size, 'valid', stride)
    deconv_out_W = utils.deconv_output_length(
        tf.shape(ss_oh)[2], patch_size, 'valid', stride)
    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, nC])

    # Deconv back to original content size with highest matching
    # (unnormalized) style patch swapped in for each content patch
    ss_dec = tf.nn.conv2d_transpose(ss_oh, style_patches, deconv_out_shape,
                                    [1, stride, stride, 1], 'VALID')

    ### Interpolate to average overlapping patch locations
    ss_oh_sum = tf.reduce_sum(ss_oh, axis=3, keep_dims=True)
    filter_ones = tf.ones([patch_size, patch_size, 1, 1], dtype=tf.float32)
    # Same spatial size as ss_dec with 1 channel
    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, 1])
    counting = tf.nn.conv2d_transpose(ss_oh_sum, filter_ones,
                                      deconv_out_shape,
                                      [1, stride, stride, 1], 'VALID')
    # Repeat along channel dim to make same size as ss_dec
    counting = tf.tile(counting, [1, 1, 1, nC])

    interpolated_dec = tf.divide(ss_dec, counting)
    return interpolated_dec
def style_swap(content, style, patch_size, stride):
    """Replace each content feature patch with its best-matching style patch.

    Style patches are used as convolution filters to score every content
    location, the best-scoring patch per location is selected, and a
    transposed convolution pastes the selected (unnormalized) patches back;
    overlapping locations are averaged.

    Args:
        content: Content feature tensor, used as the input of the matching conv.
        style: Style feature tensor from which the patches are extracted.
        patch_size: Side length of the square patches.
        stride: Stride used for patch extraction, matching conv and deconv.

    Returns:
        The style-swapped feature tensor. The transposed-conv output shape is
        hard-coded to a batch size of 1.
    """
    nC = tf.shape(style)[-1]  # Number of feature-map channels

    ### Extract patches from the style features
    style_patches = tf.extract_image_patches(style,
                                             [1, patch_size, patch_size, 1],
                                             [1, stride, stride, 1],
                                             [1, 1, 1, 1], 'VALID')
    # N x Rows x Cols x Patch_size*Patch_size*nC
    before_reshape = tf.shape(style_patches)
    style_patches = tf.reshape(
        style_patches,
        [before_reshape[1] * before_reshape[2], patch_size, patch_size, nC])
    # Patch_size x Patch_size x In_C x Out_C
    style_patches = tf.transpose(style_patches, [1, 2, 3, 0])

    # L2-normalize each style patch. Each patch is one output filter (last
    # axis) after the transpose, so the norm is taken over the spatial and
    # input-channel axes; reducing along axis 3 would normalize across
    # patches instead of within each patch.
    style_patches_norm = tf.nn.l2_normalize(style_patches, dim=[0, 1, 2])

    # Correlate every style patch with the content features; each patch
    # produces one output channel
    ss_enc = tf.nn.conv2d(content, style_patches_norm,
                          [1, stride, stride, 1], 'VALID')

    # For each spatial position, find the channel (style patch) with the
    # largest response, i.e. the patch closest to the content at that position
    ss_argmax = tf.argmax(ss_enc, axis=3)
    encC = tf.shape(ss_enc)[-1]

    # One-hot over patches: 1 for the best-matching patch, 0 elsewhere
    ss_oh = tf.one_hot(ss_argmax, encC, 1., 0., 3)

    # Size of the transposed-conv output
    deconv_out_H = utils.deconv_output_length(
        tf.shape(ss_oh)[1], patch_size, 'valid', stride)
    deconv_out_W = utils.deconv_output_length(
        tf.shape(ss_oh)[2], patch_size, 'valid', stride)
    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, nC])

    # Transposed conv back to the computed output size; only the
    # best-matching patches contribute
    ss_dec = tf.nn.conv2d_transpose(ss_oh, style_patches, deconv_out_shape,
                                    [1, stride, stride, 1], 'VALID')

    ### Average the overlapping regions
    ss_oh_sum = tf.reduce_sum(ss_oh, axis=3, keep_dims=True)
    filter_ones = tf.ones([patch_size, patch_size, 1, 1], dtype=tf.float32)
    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, 1])
    counting = tf.nn.conv2d_transpose(ss_oh_sum, filter_ones,
                                      deconv_out_shape,
                                      [1, stride, stride, 1], 'VALID')
    counting = tf.tile(counting, [1, 1, 1, nC])

    interpolated_dec = tf.divide(ss_dec, counting)
    return interpolated_dec