Example 1
    def call(self, inputs, training=True):
        inputs_shape = array_ops.shape(inputs)
        batch_size = inputs_shape[0]
        if self.data_format == 'channels_first':
            c_axis, h_axis, w_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 3, 1, 2

        height, width = inputs_shape[h_axis], inputs_shape[w_axis]
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides

        # Infer the dynamic output shape:
        out_height = utils.deconv_output_length(height,
                                                kernel_h,
                                                self.padding,
                                                stride_h)
        out_width = utils.deconv_output_length(width,
                                               kernel_w,
                                               self.padding,
                                               stride_w)
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_height, out_width)
            strides = (1, 1, stride_h, stride_w)
        else:
            output_shape = (batch_size, out_height, out_width, self.filters)
            strides = (1, stride_h, stride_w, 1)

        output_shape_tensor = array_ops.stack(output_shape)
        outputs = nn.conv2d_transpose(
            inputs,
            self.compute_spectral_normal(training=training),
            output_shape_tensor,
            strides,
            padding=self.padding.upper(),
            data_format=utils.convert_data_format(self.data_format, ndim=4))

        if not context.executing_eagerly():
            # Infer the static output shape:
            out_shape = inputs.get_shape().as_list()
            out_shape[c_axis] = self.filters
            out_shape[h_axis] = utils.deconv_output_length(out_shape[h_axis],
                                                           kernel_h,
                                                           self.padding,
                                                           stride_h)
            out_shape[w_axis] = utils.deconv_output_length(out_shape[w_axis],
                                                           kernel_w,
                                                           self.padding,
                                                           stride_w)
            outputs.set_shape(out_shape)

        if self.use_bias:
            outputs = nn.bias_add(
                outputs,
                self.bias,
                data_format=utils.convert_data_format(self.data_format, ndim=4))

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
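To make the shape arithmetic above concrete, here is a quick sanity check of the output lengths (a sketch assuming the TF1 helper in tensorflow.python.layers.utils):

from tensorflow.python.layers import utils

# With 'same' padding a transposed conv simply multiplies the length by the
# stride; with 'valid' padding the kernel overhang past the last stride is
# added back.
assert utils.deconv_output_length(16, 5, 'same', 2) == 32   # 16 * 2
assert utils.deconv_output_length(16, 5, 'valid', 2) == 35  # 16 * 2 + max(5 - 2, 0)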
Example 2
def style_swap(content, style, patch_size, stride):
    '''Efficiently swap content feature patches with nearest-neighbor style patches
       Original paper: https://arxiv.org/abs/1612.04337
       Adapted from: https://github.com/rtqichen/style-swap/blob/master/lib/NonparametricPatchAutoencoderFactory.lua
    '''
    nC = tf.shape(style)[-1]  # Num channels of input content feature and style-swapped output

    ### Extract patches from style image that will be used for conv/deconv layers
    style_patches = tf.extract_image_patches(style, [1,patch_size,patch_size,1], [1,stride,stride,1], [1,1,1,1], 'VALID')
    before_reshape = tf.shape(style_patches)  # NxRowsxColsxPatch_size*Patch_size*nC
    style_patches = tf.reshape(style_patches, [before_reshape[1]*before_reshape[2],patch_size,patch_size,nC])
    style_patches = tf.transpose(style_patches, [1,2,3,0])  # Patch_sizexPatch_sizexIn_CxOut_c

    # Normalize each style patch to unit L2 norm over its spatial/channel dims
    # (reduce over axes [0, 1, 2]; the last axis indexes the patches)
    style_patches_norm = tf.nn.l2_normalize(style_patches, dim=[0, 1, 2])

    # Compute cross-correlation/nearest neighbors of patches by using style patches as conv filters
    ss_enc = tf.nn.conv2d(content,
                          style_patches_norm,
                          [1,stride,stride,1],
                          'VALID')

    # For each spatial position find index of max along channel/patch dim  
    ss_argmax = tf.argmax(ss_enc, axis=3)
    encC = tf.shape(ss_enc)[-1]  # Num channels in intermediate conv output, same as # of patches
    
    # One-hot encode argmax with same size as ss_enc, with 1's in max channel idx for each spatial pos
    ss_oh = tf.one_hot(ss_argmax, encC, 1., 0., 3)

    # Calc size of transposed conv out
    deconv_out_H = utils.deconv_output_length(tf.shape(ss_oh)[1], patch_size, 'valid', stride)
    deconv_out_W = utils.deconv_output_length(tf.shape(ss_oh)[2], patch_size, 'valid', stride)
    deconv_out_shape = tf.stack([1,deconv_out_H,deconv_out_W,nC])

    # Deconv back to original content size with highest matching (unnormalized) style patch swapped in for each content patch
    ss_dec = tf.nn.conv2d_transpose(ss_oh,
                                    style_patches,
                                    deconv_out_shape,
                                    [1,stride,stride,1],
                                    'VALID')

    ### Interpolate to average overlapping patch locations
    ss_oh_sum = tf.reduce_sum(ss_oh, axis=3, keep_dims=True)

    filter_ones = tf.ones([patch_size,patch_size,1,1], dtype=tf.float32)
    
    deconv_out_shape = tf.stack([1,deconv_out_H,deconv_out_W,1])  # Same spatial size as ss_dec with 1 channel

    counting = tf.nn.conv2d_transpose(ss_oh_sum,
                                         filter_ones,
                                         deconv_out_shape,
                                         [1,stride,stride,1],
                                         'VALID')

    counting = tf.tile(counting, [1,1,1,nC])  # Repeat along channel dim to make same size as ss_dec

    interpolated_dec = tf.divide(ss_dec, counting)

    return interpolated_dec
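A hypothetical call site for the function above; the inputs would typically be same-depth CNN feature maps, and the names and shapes here are assumptions for illustration:

# Hypothetical usage (shapes are assumptions): swap 3x3 patches at stride 1.
content_feat = tf.placeholder(tf.float32, [1, 64, 64, 256])
style_feat = tf.placeholder(tf.float32, [1, 64, 64, 256])
swapped = style_swap(content_feat, style_feat, patch_size=3, stride=1)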
Example 3
def _conv_sn(conv,
             inputs,
             filters,
             kernel_size,
             name,
             strides=1,
             padding='valid',
             activation=None,
             use_bias=True,
             kernel_initializer=tf.glorot_uniform_initializer(),
             bias_initializer=tf.zeros_initializer(),
             use_gamma=False,
             factor=None,
             transposed=False):
    input_shape = inputs.get_shape().as_list()
    c_axis, h_axis, w_axis = 3, 1, 2  # channels last
    input_dim = input_shape[c_axis]
    with tf.variable_scope(name):
        if transposed:
            kernel_shape = kernel_size + (filters, input_dim)
            kernel = tf.get_variable('kernel',
                                     shape=kernel_shape,
                                     initializer=kernel_initializer)
            height, width = input_shape[h_axis], input_shape[w_axis]
            kernel_h, kernel_w = kernel_size
            stride_h, stride_w = strides
            out_height = deconv_output_length(height, kernel_h, padding,
                                              stride_h)
            out_width = deconv_output_length(width, kernel_w, padding,
                                             stride_w)
            output_shape = (input_shape[0], out_height, out_width, filters)
            outputs = conv(inputs,
                           spectral_norm(kernel,
                                         use_gamma=use_gamma,
                                         factor=factor),
                           tf.stack(output_shape),
                           strides=(1, *strides, 1),
                           padding=padding.upper())
        else:
            kernel_shape = kernel_size + (input_dim, filters)
            kernel = tf.get_variable('kernel',
                                     shape=kernel_shape,
                                     initializer=kernel_initializer)
            outputs = conv(inputs,
                           spectral_norm(kernel,
                                         use_gamma=use_gamma,
                                         factor=factor),
                           strides=(1, *strides, 1),
                           padding=padding.upper())
        if use_bias:
            bias = tf.get_variable('bias',
                                   shape=(filters, ),
                                   initializer=bias_initializer)
            outputs = tf.nn.bias_add(outputs, bias)
        if activation is not None:
            outputs = activation(outputs)

    return outputs
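A sketch of how this helper might be invoked for a spectrally normalized transposed convolution; spectral_norm and deconv_output_length are assumed to be in scope, and the argument values are illustrative:

# Hypothetical call (values are assumptions). The batch size must be static,
# since output_shape is built from get_shape().
x = tf.placeholder(tf.float32, [8, 16, 16, 32])
y = _conv_sn(tf.nn.conv2d_transpose, x,
             filters=64, kernel_size=(4, 4), name='deconv_sn1',
             strides=(2, 2), padding='same', transposed=True)  # y: (8, 32, 32, 64)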
Example 4
    def call(self, inputs):
        inputs_shape = array_ops.shape(inputs)
        batch_size = inputs_shape[0]
        if self.data_format == 'channels_first':
            c_axis, h_axis, w_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 3, 1, 2

        height, width = inputs_shape[h_axis], inputs_shape[w_axis]
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides

        # Infer the dynamic output shape:
        out_height = utils.deconv_output_length(height, kernel_h, self.padding,
                                                stride_h)
        out_width = utils.deconv_output_length(width, kernel_w, self.padding,
                                               stride_w)
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_height, out_width)
            strides = (1, 1, stride_h, stride_w)
        else:
            output_shape = (batch_size, out_height, out_width, self.filters)
            strides = (1, stride_h, stride_w, 1)

        output_shape_tensor = array_ops.stack(output_shape)

        kernel_norm = nn.l2_normalize(self.kernel, [0, 1, 3])
        if self.use_scale:
            kernel_norm = tf.reshape(self.scale,
                                     [1, 1, self.filters, 1]) * kernel_norm

        outputs = nn.conv2d_transpose(inputs,
                                      kernel_norm,
                                      output_shape_tensor,
                                      strides,
                                      padding=self.padding.upper(),
                                      data_format=utils.convert_data_format(
                                          self.data_format, ndim=4))

        if context.in_graph_mode():
            # Infer the static output shape:
            out_shape = inputs.get_shape().as_list()
            out_shape[c_axis] = self.filters
            out_shape[h_axis] = utils.deconv_output_length(
                out_shape[h_axis], kernel_h, self.padding, stride_h)
            out_shape[w_axis] = utils.deconv_output_length(
                out_shape[w_axis], kernel_w, self.padding, stride_w)
            outputs.set_shape(out_shape)

        if self.use_bias:
            outputs = nn.bias_add(outputs,
                                  self.bias,
                                  data_format=utils.convert_data_format(
                                      self.data_format, ndim=4))

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
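The normalization axes deserve a note: a transposed-conv kernel is laid out as [k_h, k_w, out_channels, in_channels], so reducing over axes [0, 1, 3] leaves one unit-norm filter per output channel. A quick check (a sketch, not part of the layer):

# Each output filter of a [k_h, k_w, out_c, in_c] kernel gets unit L2 norm:
kernel = tf.random_normal([3, 3, 8, 4])
kernel_norm = tf.nn.l2_normalize(kernel, [0, 1, 3])
norms = tf.sqrt(tf.reduce_sum(tf.square(kernel_norm), axis=[0, 1, 3]))
# norms evaluates to a vector of 8 ones, one per output channel.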
Example 5
def deconv2d_real(name,
                  input,
                  oc,
                  f_h=3,
                  f_w=3,
                  s_h=1,
                  s_w=1,
                  bn=True,
                  is_training=True,
                  print_shape=False,
                  act='lrelu'):
    with tf.variable_scope(name) as scope:
        padding = "SAME"
        input_shape = input.get_shape().as_list()
        # deconv_output_length expects a lowercase padding mode
        out_h = utils.deconv_output_length(input_shape[1], f_h, padding.lower(), s_h)
        out_w = utils.deconv_output_length(input_shape[2], f_w, padding.lower(), s_w)
        output_shape = (input_shape[0], out_h, out_w, oc)
        strides = [1, s_h, s_w, 1]

        W = tf.get_variable("W",
                            shape=[f_h, f_w, oc, input_shape[-1]],
                            initializer=tf.contrib.layers.xavier_initializer())

        deconv = tf.nn.conv2d_transpose(input,
                                        filter=W,
                                        output_shape=output_shape,
                                        strides=strides,
                                        padding=padding)

        if bn:
            # Pass is_training through so batch norm uses batch statistics
            # during training and moving averages at inference.
            normed = tf.contrib.layers.batch_norm(deconv,
                                                  center=True,
                                                  scale=True,
                                                  renorm=True,
                                                  is_training=is_training,
                                                  scope='bn')
        else:
            normed = deconv

        if act == 'relu':
            out = tf.nn.relu(normed)
        elif act == 'lrelu':
            out = tf.nn.leaky_relu(normed)
        elif act == 'abs':
            out = tf.abs(normed)
        else:
            out = normed

        if print_shape:
            print("{} shape : {}".format(out.name, out.get_shape()))

        return out
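A hypothetical usage as a 2x upsampling block (names and sizes are assumptions). Because output_shape is built from get_shape(), the batch dimension must be static:

x = tf.placeholder(tf.float32, [16, 8, 8, 128])
h = deconv2d_real('up1', x, oc=64, f_h=4, f_w=4, s_h=2, s_w=2,
                  bn=True, is_training=True, print_shape=True)
# h: (16, 16, 16, 64) with SAME padding and stride 2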
Example 6
 def deconv(self,
          input,
          k_h,
          k_w,
          c_o,
          s_h,
          s_w,
          name,
          relu=True,
          padding=DEFAULT_PADDING,
          group=1,
          biased=True):
     # Verify that the padding is acceptable
     self.validate_padding(padding)
     # Get the number of channels in the input
     c_i = input.get_shape()[-1]
     # Verify that the grouping parameter is valid
     assert c_i % group == 0
     assert c_o % group == 0
     # Convolution for a given input and kernel
     stride_shape = [1, s_h, s_w, 1]
     input_shape = input.get_shape()
     # deconv_output_length expects a lowercase padding mode ('same'/'valid')
     out_h = utils.deconv_output_length(input_shape[1].value, k_h, padding.lower(), s_h)
     out_w = utils.deconv_output_length(input_shape[2].value, k_w, padding.lower(), s_w)
     out_shape = tf.stack([tf.shape(input)[0], out_h, out_w, c_o // group])
     deconv = lambda i, k: tf.nn.conv2d_transpose(i, k, output_shape=out_shape, strides=stride_shape, padding=padding)
     with tf.variable_scope(name) as scope:
         kernel = self.make_var('weights', shape=[k_h, k_w, c_o // group, c_i])
         if group == 1:
             # This is the common-case. Convolve the input without any further complications.
             output = deconv(input, kernel)
         else:
             # Split the input into groups and then convolve each of them independently
             #print "input: ",input.get_shape()
             #print "kernel: ",kernel.get_shape()
             input_groups  = tf.split( input,  num_or_size_splits=group, axis=3 )
             kernel_groups = tf.split( kernel, num_or_size_splits=group, axis=3 )
             #print "grouped input: ",input_groups[0].get_shape()
             #print "grouped kernel: ",kernel_groups[0].get_shape()
             output_groups = [deconv(i, k) for i, k in zip(input_groups, kernel_groups)]
             # Concatenate the groups
             output = tf.concat(output_groups,axis=3)
         # Add the biases
         if biased:
             biases = self.make_var('biases', [c_o])
             output = tf.nn.bias_add(output, biases)
         if relu:
             # ReLU non-linearity
             output = tf.nn.relu(output, name=scope.name)
         return output
Example 7
def conv2d_transpose(x, W, strides=(1, 1), padding="SAME"):
  strides = list(strides)

  # Compute output size, cf. tf.layers.Conv2DTranspose
  W_shape = tf.shape(W)
  inputs_shape = tf.shape(x)

  # Infer the dynamic output shape. deconv_output_length expects a lowercase
  # padding mode; with tensor-valued sizes only 'same' is safe here, since the
  # helper's 'valid' branch applies Python max() to a tf.Tensor.
  out_height = deconv_output_length(inputs_shape[1], W_shape[0], padding.lower(), strides[0])
  out_width = deconv_output_length(inputs_shape[2], W_shape[1], padding.lower(), strides[1])
  output_shape = (inputs_shape[0], out_height, out_width, W_shape[2])

  return tf.nn.conv2d_transpose(x, W, tf.stack(output_shape), strides=[1] + strides + [1], padding=padding)
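A minimal usage sketch, assuming deconv_output_length is imported from tensorflow.python.layers.utils; the filter layout follows tf.nn.conv2d_transpose, i.e. [k_h, k_w, out_channels, in_channels]:

# Hypothetical usage (shapes are assumptions): upsample 8x8x3 -> 16x16x5.
x = tf.placeholder(tf.float32, [None, 8, 8, 3])
W = tf.get_variable('W', [3, 3, 5, 3])  # [k_h, k_w, out_c, in_c]
y = conv2d_transpose(x, W, strides=(2, 2), padding='SAME')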
Example 8
    def _compute_output_shape(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        output_shape = list(input_shape)
        if self.data_format == 'channels_first':
            c_axis, h_axis, w_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 3, 1, 2

        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides

        output_shape[c_axis] = self.filters
        output_shape[h_axis] = utils.deconv_output_length(
            output_shape[h_axis], kernel_h, self.padding, stride_h)
        output_shape[w_axis] = utils.deconv_output_length(
            output_shape[w_axis], kernel_w, self.padding, stride_w)
        return tensor_shape.TensorShape(output_shape)
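Tracing the method for assumed attributes kernel_size=(3, 3), strides=(2, 2), padding='same', filters=16 and a channels_last input of (None, 32, 32, 8):

from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import utils

shape = tensor_shape.TensorShape([None, 32, 32, 8]).as_list()
shape[3] = 16                                                  # c_axis <- filters
shape[1] = utils.deconv_output_length(shape[1], 3, 'same', 2)  # 32 -> 64
shape[2] = utils.deconv_output_length(shape[2], 3, 'same', 2)  # 32 -> 64
print(tensor_shape.TensorShape(shape))                         # (?, 64, 64, 16)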
Example 9
 def testDeconvOutputLength(self):
   self.assertEqual(4, utils.deconv_output_length(4, 2, 'same', 1))
   self.assertEqual(8, utils.deconv_output_length(4, 2, 'same', 2))
   self.assertEqual(5, utils.deconv_output_length(4, 2, 'valid', 1))
   self.assertEqual(8, utils.deconv_output_length(4, 2, 'valid', 2))
   self.assertEqual(3, utils.deconv_output_length(4, 2, 'full', 1))
   self.assertEqual(6, utils.deconv_output_length(4, 2, 'full', 2))
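For reference, the values these tests assert match this minimal sketch of the helper (assuming the early TF signature, before dilation and output_padding were added):

def deconv_output_length(input_length, filter_size, padding, stride):
    # Output length of a transposed convolution along one dimension.
    if input_length is None:
        return None
    input_length *= stride
    if padding == 'valid':
        # Add back the kernel overhang past the final stride.
        input_length += max(filter_size - stride, 0)
    elif padding == 'full':
        # Remove the border a 'full' forward conv would have added.
        input_length -= (stride + filter_size - 2)
    return input_length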
Example 10
 def call(self, inputs, **kwargs):
     out_width = utils.deconv_output_length(
         tf.shape(inputs)[1], self.kernel_size, "valid", self.stride)
     output_shape = (tf.shape(inputs)[0], out_width, self.out_channels)
     conv1d_output = conv_transpose_1d(inputs,
                                       self.kernel,
                                       output_shape,
                                       self.stride,
                                       padding="VALID")
     pre_activation = conv1d_output + self.bias
     if self.activation is not None:
         return self.activation(pre_activation)
     return pre_activation
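conv_transpose_1d is a helper defined elsewhere in the codebase; one plausible implementation (an assumption, not the author's code) lifts the input to 2-D and reuses tf.nn.conv2d_transpose over a height-1 axis:

def conv_transpose_1d(x, kernel, output_shape, stride, padding='VALID'):
    # x: [N, W, C_in]; kernel: [k, C_out, C_in]; output_shape: (N, out_W, C_out)
    x2 = tf.expand_dims(x, axis=1)       # [N, 1, W, C_in]
    k2 = tf.expand_dims(kernel, axis=0)  # [1, k, C_out, C_in]
    shape2 = tf.stack([output_shape[0], 1, output_shape[1], output_shape[2]])
    y2 = tf.nn.conv2d_transpose(x2, k2, shape2, [1, 1, stride, 1], padding)
    return tf.squeeze(y2, axis=1)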
Example 11
def style_swap(content, style, patch_size, stride):

    nC = tf.shape(style)[-1]  # Number of channels in the feature maps

    ### Extract patches from the style feature map
    style_patches = tf.extract_image_patches(style,
                                             [1, patch_size, patch_size, 1],
                                             [1, stride, stride, 1],
                                             [1, 1, 1, 1], 'VALID')
    before_reshape = tf.shape(
        style_patches)  # NxRowsxColsxPatch_size*Patch_size*nC
    style_patches = tf.reshape(
        style_patches,
        [before_reshape[1] * before_reshape[2], patch_size, patch_size, nC])
    style_patches = tf.transpose(
        style_patches, [1, 2, 3, 0])  # Patch_sizexPatch_sizexIn_CxOut_c

    # L2-normalize each style patch over its spatial/channel dims
    style_patches_norm = tf.nn.l2_normalize(style_patches, dim=[0, 1, 2])

    # Convolve content with every style patch; each patch yields one channel
    ss_enc = tf.nn.conv2d(content, style_patches_norm, [1, stride, stride, 1],
                          'VALID')

    # The max along the channel axis marks the style patch closest to each
    # content location
    ss_argmax = tf.argmax(ss_enc, axis=3)
    encC = tf.shape(ss_enc)[-1]

    # One-hot: mark the best-matching patch with 1, all other patches with 0
    ss_oh = tf.one_hot(ss_argmax, encC, 1., 0., 3)

    # Size of the transposed-conv output
    deconv_out_H = utils.deconv_output_length(
        tf.shape(ss_oh)[1], patch_size, 'valid', stride)
    deconv_out_W = utils.deconv_output_length(
        tf.shape(ss_oh)[2], patch_size, 'valid', stride)
    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, nC])

    # Transposed conv back to the original size; the reconstruction contains
    # only the closest-matching style patches
    ss_dec = tf.nn.conv2d_transpose(ss_oh, style_patches, deconv_out_shape,
                                    [1, stride, stride, 1], 'VALID')

    ### Average over overlapping patch locations
    ss_oh_sum = tf.reduce_sum(ss_oh, axis=3, keep_dims=True)

    filter_ones = tf.ones([patch_size, patch_size, 1, 1], dtype=tf.float32)

    deconv_out_shape = tf.stack([1, deconv_out_H, deconv_out_W, 1])

    counting = tf.nn.conv2d_transpose(ss_oh_sum, filter_ones, deconv_out_shape,
                                      [1, stride, stride, 1], 'VALID')

    counting = tf.tile(counting, [1, 1, 1, nC])

    interpolated_dec = tf.divide(ss_dec, counting)

    return interpolated_dec