def build(self, args):
    """Build the style-transfer graph: placeholders, the stylized output,
    VGG feature extraction, and the style / content / total-variation
    losses combined into ``self.L_full``.

    Args:
        args: config object; must expose ``net`` (transform network with a
            mutable ``type`` attribute and a ``build`` method), ``vgg_net``
            (feature extractor with ``feed_forward``), and the scalar loss
            weights ``wgt_style``, ``wgt_content``, ``wgt_totvar``.

    Side effects: sets self.ys, self.yc, self.yh, self.ysi, self.Fs,
    self.Ss, self.Cs, self.L_style, self.L_content, self.L_totvar,
    self.L_full.
    """
    self.preBuild()
    # Float placeholders: a single style image (H, W, 3) and a batch of
    # content images (B, H, W, 3).
    self.ys = kld.plchf([None, None, 3], 'style')
    self.yc = kld.plchf([None, None, None, 3], 'content')
    with tf.variable_scope('', reuse=tf.AUTO_REUSE):
        # NOTE(review): the first pass with type='calc' presumably creates
        # or calibrates the network's variables, and the second pass with
        # type='eval' produces the stylized output self.yh — confirm
        # against the args.net implementation.
        args.net.type = 'calc'
        args.net.build(self.yc)
        args.net.type = 'eval'
        self.yh = args.net.build(self.yc)
    # Give the single style image a batch dimension before feeding VGG.
    self.ysi = tf.expand_dims(self.ys, 0)
    style_layers = args.vgg_net.feed_forward(self.ysi, 'style')
    content_layers = args.vgg_net.feed_forward(self.yc, 'content')
    self.Fs = args.vgg_net.feed_forward(self.yh, 'mixed')
    # Keep only the VGG activations selected by self.style_layers /
    # self.content_layers (assumed to be {layer_id: weight} dicts set
    # before build() is called — e.g. in preBuild; verify).
    self.Ss = {}
    for id in self.style_layers:
        self.Ss[id] = style_layers[id]
    self.Cs = {}
    for id in self.content_layers:
        self.Cs[id] = content_layers[id]
    L_style, L_content = 0, 0
    for id in self.Fs:
        if id in self.style_layers:
            # Style term: weighted squared distance between Gram matrices
            # of stylized vs. style features, normalized by b * d1 * d2.
            F = kld.gram_matrix(self.Fs[id])
            S = kld.gram_matrix(self.Ss[id])
            b, d1, d2 = kld.get_shape(F)
            bd1d2 = kld.toFloat(b * d1 * d2)
            wgt = self.style_layers[id]
            L_style += wgt * 2 * tf.nn.l2_loss(F - S) / bd1d2
        if id in self.content_layers:
            # Content term: weighted squared distance between raw
            # activations, normalized by the total element count.
            F = self.Fs[id]
            C = self.Cs[id]
            b, h, w, d = kld.get_shape(F)
            bhwd = kld.toFloat(b * h * w * d)
            wgt = self.content_layers[id]
            L_content += wgt * 2 * tf.nn.l2_loss(F - C) / bhwd
    L_totvar = kld.total_variation_loss(self.yh)
    # Final objective: weighted sum of the three terms.
    self.L_style = args.wgt_style * L_style
    self.L_content = args.wgt_content * L_content
    self.L_totvar = args.wgt_totvar * L_totvar
    self.L_full = self.L_style + self.L_content + self.L_totvar
def conv_tranpose_layer(net, num_filters, filter_size, strides, padding='SAME', relu=True, name=None):
    """Transposed (fractionally-strided) convolution layer.

    Upsamples the spatial dimensions of `net` by a factor of `strides`,
    applies instance normalization, and optionally a ReLU.

    NOTE(review): the function name keeps the historical 'tranpose'
    spelling so existing callers are unaffected.
    """
    kernel = conv_init_vars(net, num_filters, filter_size, name=name, transpose=True)
    batch, in_rows, in_cols, _ = kld.get_shape(net)
    # Output spatial size is the input scaled by the stride factor.
    out_shape = tf.stack([batch, in_rows * strides, in_cols * strides, num_filters])
    stride_vec = [1, strides, strides, 1]
    out = tf.nn.conv2d_transpose(net, kernel, out_shape, stride_vec, padding=padding)
    out = instance_norm(out, name=name)
    return tf.nn.relu(out) if relu else out
def gram_matrix(tensor):
    """Batched Gram matrix of a (b, h, w, c) feature map.

    Flattens the spatial dimensions, computes the (c, c) channel
    correlation matrix per batch item, and normalizes by c * h * w.
    """
    batch, height, width, channels = kld.get_shape(tensor)
    norm = kld.toFloat(channels * height * width)
    flat = tf.reshape(tensor, (batch, height * width, channels))
    # matmul with transpose_a=True is flat^T @ flat, i.e. the same as
    # explicitly transposing the (b, hw, c) tensor first.
    return tf.matmul(flat, flat, transpose_a=True) / norm
def residual_block(net, filter_size=3, name=None):
    """Residual block: two VALID-padded convolutions plus a cropped skip.

    Each VALID 3x3 convolution shrinks the spatial extent by 2, so the
    skip connection crops 2 pixels from every border of the input to
    match the (rows - 4, cols - 4) conv output.
    """
    batch, rows, cols, channels = kld.get_shape(net)
    out = conv_layer(net, 128, filter_size, 1, padding='VALID', relu=True, name=name + '_1')
    out = conv_layer(out, 128, filter_size, 1, padding='VALID', relu=False, name=name + '_2')
    skip = tf.slice(net, [0, 2, 2, 0], [batch, rows - 4, cols - 4, channels])
    return out + skip
def total_variation_loss(img):
    """Anisotropic total-variation loss over a (b, h, w, d) image batch.

    Sums the squared horizontal and vertical neighbor differences, each
    normalized by its own element count, and averages over the batch.
    """
    batch, height, width, depth = kld.get_shape(img)
    # Neighbor differences along the width and height axes.
    horiz_diff = img[:, :, 1:, :] - img[:, :, :width - 1, :]
    vert_diff = img[:, 1:, :, :] - img[:, :height - 1, :, :]
    horiz_norm = kld.toFloat(height * (width - 1) * depth)
    vert_norm = kld.toFloat((height - 1) * width * depth)
    batch_f = kld.toFloat(batch)
    return 2.0 * (tf.nn.l2_loss(horiz_diff) / horiz_norm
                  + tf.nn.l2_loss(vert_diff) / vert_norm) / batch_f