def build(self, args):
    """Build the stylization graph and its weighted loss terms.

    Creates the placeholders ``self.ys`` (one style image, rank 3) and
    ``self.yc`` (a batch of content images, rank 4). It runs the
    transformation network ``args.net`` on ``self.yc`` to produce
    ``self.yh``, then feeds the style, content and generated images
    through ``args.vgg_net``. From those features it accumulates a style
    loss (Gram-matrix distance) over the layers in ``self.style_layers``
    and a content loss (feature distance) over the layers in
    ``self.content_layers``. Both attributes are used as dicts mapping
    layer id -> weight.

    On return the following attributes are set: ``self.ysi``, ``self.yh``,
    ``self.Fs``, ``self.Ss``, ``self.Cs``, ``self.L_style``,
    ``self.L_content``, ``self.L_totvar`` and ``self.L_full`` (the sum of
    the three weighted terms).

    Args:
        args: configuration object. Reads ``args.net``, ``args.vgg_net``,
            ``args.wgt_style``, ``args.wgt_content`` and
            ``args.wgt_totvar``; sets ``args.net.type``.
    """
    self.preBuild()

    self.ys = kld.plchf([None, None, 3], 'style')
    self.yc = kld.plchf([None, None, None, 3], 'content')

    # The network is built twice on the same input, so variables must be
    # reusable across the two passes.
    with tf.variable_scope('', reuse=tf.AUTO_REUSE):
        # NOTE(review): the 'calc' pass's output is discarded; presumably it
        # is run for its effect on args.net before the 'eval' pass whose
        # output is used — confirm against args.net.build.
        args.net.type = 'calc'
        args.net.build(self.yc)
        args.net.type = 'eval'
        self.yh = args.net.build(self.yc)

    # The style placeholder holds a single image; add a leading batch axis
    # of size 1 before feeding it to VGG.
    self.ysi = tf.expand_dims(self.ys, 0)
    style_feats = args.vgg_net.feed_forward(self.ysi, 'style')
    content_feats = args.vgg_net.feed_forward(self.yc, 'content')
    self.Fs = args.vgg_net.feed_forward(self.yh, 'mixed')

    # Keep only the target features for the layers that contribute a loss.
    self.Ss = {layer: style_feats[layer] for layer in self.style_layers}
    self.Cs = {layer: content_feats[layer] for layer in self.content_layers}

    # tf.nn.l2_loss(t) == sum(t ** 2) / 2, so `2 * l2_loss(x) / N` below is
    # the mean squared error over the N elements of x.
    L_style, L_content = 0, 0
    for layer in self.Fs:
        if layer in self.style_layers:
            # The style Gram matrix comes from a batch-1 input, so it
            # broadcasts against the generated batch's Gram matrices.
            F = kld.gram_matrix(self.Fs[layer])
            S = kld.gram_matrix(self.Ss[layer])
            b, d1, d2 = kld.get_shape(F)
            bd1d2 = kld.toFloat(b * d1 * d2)
            wgt = self.style_layers[layer]
            L_style += wgt * 2 * tf.nn.l2_loss(F - S) / bd1d2
        if layer in self.content_layers:
            F = self.Fs[layer]
            C = self.Cs[layer]
            b, h, w, d = kld.get_shape(F)
            bhwd = kld.toFloat(b * h * w * d)
            wgt = self.content_layers[layer]
            L_content += wgt * 2 * tf.nn.l2_loss(F - C) / bhwd

    # Smoothness regularizer on the generated images.
    L_totvar = kld.total_variation_loss(self.yh)

    self.L_style = args.wgt_style * L_style
    self.L_content = args.wgt_content * L_content
    self.L_totvar = args.wgt_totvar * L_totvar
    self.L_full = self.L_style + self.L_content + self.L_totvar
def total_variation_loss(img):
    """Total-variation smoothness penalty for a rank-4 image batch.

    The image is unpacked as ``(batch, height, width, depth)`` via
    ``kld.get_shape``. The result is the mean squared difference between
    horizontally adjacent pixels plus the mean squared difference between
    vertically adjacent pixels, divided by the batch size.

    NOTE(review): the normalizers are zero when width or height is 1, which
    makes the corresponding term 0/0.
    """
    batch, height, width, depth = kld.get_shape(img)
    horiz_count = kld.toFloat(height * (width - 1) * depth)
    vert_count = kld.toFloat((height - 1) * width * depth)
    batch_f = kld.toFloat(batch)

    # Differences between neighbouring columns (axis 2) and rows (axis 1).
    horiz_diff = img[:, :, 1:, :] - img[:, :, :-1, :]
    vert_diff = img[:, 1:, :, :] - img[:, :-1, :, :]

    # tf.nn.l2_loss halves the sum of squares; the 2.0 factor undoes that.
    horiz_term = tf.nn.l2_loss(horiz_diff) / horiz_count
    vert_term = tf.nn.l2_loss(vert_diff) / vert_count
    return 2.0 * (horiz_term + vert_term) / batch_f
def gram_matrix(tensor):
    """Return the normalized per-example Gram matrix of a rank-4 tensor.

    The tensor is unpacked as ``(b, h, w, c)`` via ``kld.get_shape``. Each
    example's features are flattened to an ``(h * w, c)`` matrix ``M``.
    The result is ``M^T M / (c * h * w)``, with shape ``(b, c, c)``.
    """
    batch, height, width, channels = kld.get_shape(tensor)
    norm = kld.toFloat(channels * height * width)
    flat = tf.reshape(tensor, (batch, height * width, channels))
    # Swap the last two axes to get (b, c, h*w) for the batched product.
    return tf.matmul(tf.transpose(flat, perm=[0, 2, 1]), flat) / norm