def style_transfer(self, inputs, styles, style_layers=None): """style transfer via recursive feature transforms Args: inputs: input images [batch_size, height, width, channel] styles: input styles [1 or batch_size, height, width, channel] style_layers: a list of enforced style layer ids, default is None that applies all style layers as self.style_loss_layers Returns: outputs: the stylized images [batch_size, height, width, channel] """ # get the style features for the styles styles_image_features = losses.extract_image_features( styles, self.network_name) styles_features = losses.compute_content_features( styles_image_features, self.style_loss_layers) # construct the recurrent modules selected_style_layers = self.style_loss_layers if style_layers: selected_style_layers = [ selected_style_layers[i] for i in style_layers ] else: style_layers = range(len(selected_style_layers)) outputs = tf.identity(inputs) num_modules = len(selected_style_layers) for i in range(num_modules, 0, -1): starting_layer = selected_style_layers[i - 1] # encoding the inputs contents_image_features = losses.extract_image_features( outputs, self.network_name) content_features = losses.compute_content_features( contents_image_features, [starting_layer]) # feature transformation transformed_features = feature_transform( content_features[starting_layer], styles_features[starting_layer]) if self.blending_weight: transformed_features = self.blending_weight * transformed_features + \ (1-self.blending_weight) * content_features[starting_layer] # decoding the contents with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope()): outputs = vgg_decoder.vgg_decoder(transformed_features, self.network_name, starting_layer, reuse=True, scope='decoder_%d' % style_layers[i - 1]) outputs = preprocessing.batch_mean_image_subtraction(outputs) print('Finish the module of [%s]' % starting_layer) # recover the outputs return preprocessing.batch_mean_image_summation(outputs)
def style_transfer(self, inputs, styles, style_layers=(2, )):
    """Style transfer via patch swapping.

    Same recursive encode/transform/decode scheme as the feature-transform
    variant, but the transformation swaps content patches for their
    nearest style patches (patch_size taken from self.patch_size).

    Args:
        inputs: input images [batch_size, height, width, channel]
        styles: input styles [1, height, width, channel]
        style_layers: the list of layer indices (into
            self.style_loss_layers) to perform style swapping at; the
            default (2,) applies only the third style layer. Passing a
            falsy value (None or empty) applies all style layers in
            self.style_loss_layers.

    Returns:
        outputs: the stylized images [batch_size, height, width, channel]
    """
    # style features, computed once for all modules
    styles_image_features = losses.extract_image_features(
        styles, self.network_name)
    styles_features = losses.compute_content_features(
        styles_image_features, self.style_loss_layers)
    # construct the recurrent modules; selected_style_layers holds layer
    # names, style_layers the matching indices for decoder scope names
    selected_style_layers = self.style_loss_layers
    if style_layers:
        selected_style_layers = [
            selected_style_layers[i] for i in style_layers
        ]
    else:
        style_layers = range(len(selected_style_layers))
    # input preprocessing
    outputs = tf.identity(inputs)
    # start style transfer, deepest selected layer first
    num_modules = len(selected_style_layers)
    for i in range(num_modules, 0, -1):
        starting_layer = selected_style_layers[i - 1]
        # encoding the inputs (re-encode the current outputs each pass)
        contents_image_features = losses.extract_image_features(
            outputs, self.network_name)
        contents_features = losses.compute_content_features(
            contents_image_features, [starting_layer])
        # feature transformation (patch swap)
        transformed_features = feature_transform(
            contents_features[starting_layer],
            styles_features[starting_layer],
            patch_size=self.patch_size)
        # decoding the contents
        # NOTE(review): unlike the recursive-transform variant, no
        # reuse=True is passed here, so the decoder variables are created
        # on first call — confirm this is intended.
        with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope()):
            outputs = vgg_decoder.vgg_decoder(
                transformed_features,
                self.network_name,
                starting_layer,
                scope='decoder_%d' % style_layers[i - 1])
        # re-center for the next encoding pass
        outputs = preprocessing.batch_mean_image_subtraction(outputs)
        print('Finish the module of [%s]' % starting_layer)
    # recover the outputs (undo the mean subtraction)
    return preprocessing.batch_mean_image_summation(outputs)
def build_train_graph(self, inputs):
    """Build the training graph for the per-layer decoders.

    For each selected content layer, runs the auto_encoder, then adds a
    reconstruction loss, a content loss, and a total-variation loss to
    self.total_loss, recording scalar and image summaries for each.

    Args:
        inputs: input images [batch_size, height, width, channel]

    Returns:
        self.total_loss: the accumulated training loss
    """
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    for i in range(len(self.content_layers)):
        # skip some networks
        # NOTE(review): decoders 0-2 are skipped unconditionally here —
        # looks like a temporary training restriction; confirm whether
        # this hard-coded cutoff is still wanted.
        if i < 3:
            continue
        selected_layer = self.content_layers[i]
        outputs, inputs_content_features = self.auto_encoder(
            inputs, content_layer=i, reuse=False)
        # re-center decoder output before comparing against inputs
        outputs = preprocessing.batch_mean_image_subtraction(outputs)
        ########################
        # construct the losses #
        ########################
        # 1) reconstruction loss (pixel MSE against the inputs)
        recons_loss = tf.losses.mean_squared_error(
            inputs, outputs, scope='recons_loss/decoder_%d' % i)
        self.recons_loss[selected_layer] = recons_loss
        self.total_loss += self.recons_weight * recons_loss
        summaries.add(
            tf.summary.scalar('recons_loss/decoder_%d' % i, recons_loss))
        # 2) content loss (feature MSE at the selected layer)
        outputs_image_features = losses.extract_image_features(
            outputs, self.network_name)
        outputs_content_features = losses.compute_content_features(
            outputs_image_features, [selected_layer])
        content_loss = losses.compute_content_loss(
            outputs_content_features, inputs_content_features,
            [selected_layer])
        self.content_loss[selected_layer] = content_loss
        self.total_loss += self.content_weight * content_loss
        summaries.add(
            tf.summary.scalar('content_loss/decoder_%d' % i, content_loss))
        # 3) total variation loss (L1 smoothness on the outputs)
        tv_loss = losses.compute_total_variation_loss_l1(outputs)
        self.tv_loss[selected_layer] = tv_loss
        self.total_loss += self.tv_weight * tv_loss
        summaries.add(tf.summary.scalar('tv_loss/decoder_%d' % i, tv_loss))
        # side-by-side input/output tiles for visual inspection
        image_tiles = tf.concat([inputs, outputs], axis=2)
        image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
        image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0),
                              tf.uint8)
        summaries.add(
            tf.summary.image('image_comparison/decoder_%d' % i,
                             image_tiles,
                             max_outputs=8))
    self.summaries = summaries
    return self.total_loss
def auto_encoder(self, inputs, content_layer=2, reuse=True):
    """Encode the inputs and decode one content feature back to an image.

    Args:
        inputs: input images [batch_size, height, width, channel]
        content_layer: index into self.content_layers selecting which
            feature is reconstructed from
        reuse: whether to reuse the decoder variables

    Returns:
        outputs: the reconstructed images
        input_content_features: dict mapping the selected layer name to
            its content feature tensor
    """
    # encode: run the backbone once and gather all content features
    all_features = losses.compute_content_features(
        losses.extract_image_features(inputs, self.network_name),
        self.content_layers)
    # pick the single feature that the decoder reconstructs from
    layer_name = self.content_layers[content_layer]
    hidden = all_features[layer_name]
    # decode: map the selected feature back to image space, one decoder
    # scope per content-layer index
    with slim.arg_scope(
            vgg_decoder.vgg_decoder_arg_scope(self.weight_decay)):
        outputs = vgg_decoder.vgg_decoder(
            hidden,
            self.network_name,
            layer_name,
            reuse=reuse,
            scope='decoder_%d' % content_layer)
    return outputs, {layer_name: hidden}
def build_model(self, inputs):
    """Build the training losses for every per-layer decoder.

    For each content layer, runs the auto_encoder and accumulates a
    reconstruction loss and a content loss into self.total_loss,
    adding scalar and image summaries to self.summaries.

    Args:
        inputs: input images [batch_size, height, width, channel]

    Returns:
        self.total_loss: the accumulated loss over all decoders
    """
    for i in range(len(self.content_layers)):
        selected_layer = self.content_layers[i]
        outputs, inputs_content_features = self.auto_encoder(
            inputs, content_layer=i, reuse=False)
        # re-center decoder output before comparing against inputs
        outputs = preprocessing.batch_mean_image_subtraction(outputs)
        ########################
        # construct the losses #
        ########################
        # 1) reconstruction loss (pixel MSE against the inputs)
        recons_loss = tf.losses.mean_squared_error(
            inputs, outputs, scope='recons_loss/decoder_%d' % i)
        self.recons_loss[selected_layer] = recons_loss
        self.total_loss += self.recons_weight * recons_loss
        self.summaries.add(tf.summary.scalar(
            'recons_loss/decoder_%d' % i, recons_loss))
        # 2) content loss (feature MSE at the selected layer)
        outputs_image_features = losses.extract_image_features(
            outputs, self.network_name)
        outputs_content_features = losses.compute_content_features(
            outputs_image_features, [selected_layer])
        content_loss = losses.compute_content_loss(
            outputs_content_features, inputs_content_features,
            [selected_layer])
        self.content_loss[selected_layer] = content_loss
        self.total_loss += self.content_weight * content_loss
        self.summaries.add(tf.summary.scalar(
            'content_loss/decoder_%d' % i, content_loss))
        # side-by-side input/output tiles for visual inspection
        image_tiles = tf.concat([inputs, outputs], axis=2)
        image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
        image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0),
                              tf.uint8)
        self.summaries.add(tf.summary.image(
            'image_comparison/decoder_%d' % i, image_tiles, max_outputs=8))
    return self.total_loss
def hierarchical_autoencoder(self, inputs, reuse=True):
    """Hierarchical autoencoder for content reconstruction.

    Encodes the inputs, takes the deepest style-loss layer as the hidden
    representation, and decodes it back to an image while fusing in the
    intermediate content features via adaptive instance normalization.

    Args:
        inputs: input images [batch_size, height, width, channel]
        reuse: whether to reuse the decoder variables

    Returns:
        outputs: the reconstructed images
    """
    # encode the inputs: one backbone pass, features at every style layer
    features = losses.compute_content_features(
        losses.extract_image_features(inputs, self.network_name),
        self.style_loss_layers)
    # the deepest style layer is the decoder's starting hidden state
    deepest_layer = self.style_loss_layers[-1]
    hidden = features[deepest_layer]
    # decode back to image space, fusing skip features at each stage
    with slim.arg_scope(
            vgg_decoder.vgg_decoder_arg_scope(self.weight_decay)):
        outputs = vgg_decoder.vgg_combined_decoder(
            hidden,
            features,
            fusion_fn=network_ops.adaptive_instance_normalization,
            network_name=self.network_name,
            starting_layer=deepest_layer,
            reuse=reuse)
    return outputs
def transfer_styles(self, inputs, styles, inter_weight=1.0,
                    intra_weights=(1, )):
    """Transfer the content image by one or more style images.

    Args:
        inputs: input images [batch_size, height, width, channel]
        styles: a list of input styles, in default the size is 1
        inter_weight: the blending weight between the content and style
        intra_weights: a list of blending weights among the styles, in
            default it is (1,)

    Returns:
        outputs: the stylized images [batch_size, height, width, channel]
    """
    # normalize scalar arguments to lists
    if not isinstance(styles, (list, tuple)):
        styles = [styles]
    if not isinstance(intra_weights, (list, tuple)):
        intra_weights = [intra_weights]
    # 1) style features, one feature dict per style image
    styles_features = [
        losses.compute_content_features(
            losses.extract_image_features(style, self.network_name),
            self.style_loss_layers)
        for style in styles
    ]
    # 2) content features of the inputs
    inputs_features = losses.compute_content_features(
        losses.extract_image_features(inputs, self.network_name),
        self.style_loss_layers)
    # 3) style decorator applied at the deepest style layer; the swapped
    # features of each style are blended by their intra weights
    selected_layer = self.style_loss_layers[-1]
    hidden_feature = inputs_features[selected_layer]
    blended_feature = 0.0
    for n, style_features in enumerate(styles_features):
        swapped_feature = style_decorator(
            hidden_feature,
            style_features[selected_layer],
            style_coding=self.style_coding,
            style_interp=self.style_interp,
            ratio_interp=inter_weight,
            patch_size=self.patch_size)
        blended_feature += intra_weights[n] * swapped_feature
    # 4) decode the hidden feature to the output image, fusing the style
    # features of every style at the intermediate decoder stages
    with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope()):
        outputs = vgg_decoder.vgg_multiple_combined_decoder(
            blended_feature,
            styles_features,
            intra_weights,
            fusion_fn=network_ops.adaptive_instance_normalization,
            network_name=self.network_name,
            starting_layer=selected_layer)
    return outputs
def build_train_graph(self, inputs):
    """Build the training graph for the hierarchical autoencoder.

    Accumulates up to three weighted losses into self.total_loss —
    reconstruction, content, and total variation — each gated on its
    weight being positive, and records the matching summaries.

    Args:
        inputs: input images [batch_size, height, width, channel]

    Returns:
        self.total_loss: the accumulated training loss
    """
    outputs = self.hierarchical_autoencoder(inputs, reuse=False)
    # re-center decoder output before comparing against inputs
    outputs = preprocessing.batch_mean_image_subtraction(outputs)
    # summaries
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    ########################
    # construct the losses #
    ########################
    # 1) reconstruction loss (weight folded into the loss op itself)
    if self.recons_weight > 0.0:
        recons_loss = tf.losses.mean_squared_error(
            inputs, outputs, weights=self.recons_weight,
            scope='recons_loss')
        self.recons_loss = recons_loss
        self.total_loss += recons_loss
        summaries.add(tf.summary.scalar('losses/recons_loss', recons_loss))
    # 2) content loss over all style-loss layers (weight passed through
    # to the loss helper rather than multiplied here)
    if self.content_weight > 0.0:
        outputs_image_features = losses.extract_image_features(
            outputs, self.network_name)
        outputs_content_features = losses.compute_content_features(
            outputs_image_features, self.style_loss_layers)
        inputs_image_features = losses.extract_image_features(
            inputs, self.network_name)
        inputs_content_features = losses.compute_content_features(
            inputs_image_features, self.style_loss_layers)
        content_loss = losses.compute_content_loss(
            outputs_content_features, inputs_content_features,
            content_loss_layers=self.style_loss_layers,
            weights=self.content_weight)
        self.content_loss = content_loss
        self.total_loss += content_loss
        summaries.add(
            tf.summary.scalar('losses/content_loss', content_loss))
    # 3) total variation loss (L1 smoothness, weighted inside the helper)
    if self.tv_weight > 0.0:
        tv_loss = losses.compute_total_variation_loss_l1(
            outputs, self.tv_weight)
        self.tv_loss = tv_loss
        self.total_loss += tv_loss
        summaries.add(tf.summary.scalar('losses/tv_loss', tv_loss))
    # side-by-side input/output tiles for visual inspection
    image_tiles = tf.concat([inputs, outputs], axis=2)
    image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
    image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0),
                          tf.uint8)
    summaries.add(
        tf.summary.image('image_comparison', image_tiles, max_outputs=8))
    self.summaries = summaries
    return self.total_loss
def build_model(self, inputs, styles, reuse=False):
    """Build the graph for the MSG model.

    Runs the style transfer and accumulates the weighted content, style,
    and total-variation losses into self.total_loss, recording scalar
    and image summaries.

    Args:
        inputs: the inputs [batch_size, height, width, channel]
        styles: the styles [1, height, width, channel]
        reuse: whether to reuse the parameters

    Returns:
        total_loss: the total loss for the style transfer
    """
    # extract the content features for the inputs
    inputs_image_features = losses.extract_image_features(
        inputs, self.network_name)
    inputs_content_features = losses.compute_content_features(
        inputs_image_features, self.content_loss_layers)
    # extract the styles' style features (Gram-style statistics —
    # computed by the losses helper)
    styles_image_features = losses.extract_image_features(
        styles, self.network_name)
    styles_style_features = losses.compute_style_features(
        styles_image_features, self.style_loss_layers)
    # transfer the styles onto the inputs
    outputs = self.style_transfer(inputs, styles, reuse=reuse)
    # preprocessing the outputs to avoid biases and calculate the features
    outputs = preprocessing.batch_mean_image_subtraction(outputs)
    outputs_content_features, outputs_style_features = \
        losses.compute_content_and_style_features(
            outputs, self.network_name,
            self.content_loss_layers, self.style_loss_layers)
    # gather the summary operations
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    # calculate the losses, each gated on a positive weight
    # the content loss
    if self.content_weight > 0.0:
        self.content_loss = losses.compute_content_loss(
            inputs_content_features, outputs_content_features,
            self.content_loss_layers)
        self.total_loss += self.content_weight * self.content_loss
        summaries.add(
            tf.summary.scalar('losses/content_loss', self.content_loss))
    # the style loss
    if self.style_weight > 0.0:
        self.style_loss = losses.compute_style_loss(
            styles_style_features, outputs_style_features,
            self.style_loss_layers)
        self.total_loss += self.style_weight * self.style_loss
        summaries.add(
            tf.summary.scalar('losses/style_loss', self.style_loss))
    # the total variation loss
    if self.tv_weight > 0.0:
        self.tv_loss = losses.compute_total_variation_loss(outputs)
        self.total_loss += self.tv_weight * self.tv_loss
        summaries.add(tf.summary.scalar('losses/tv_loss', self.tv_loss))
    summaries.add(tf.summary.scalar('total_loss', self.total_loss))
    # gather the image tiles for style transfer (input | output, side
    # by side)
    image_tiles = tf.concat([inputs, outputs], axis=2)
    image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
    image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0),
                          tf.uint8)
    summaries.add(
        tf.summary.image('style_results', image_tiles, max_outputs=8))
    # gather the styles
    summaries.add(
        tf.summary.image('styles',
                         preprocessing.batch_mean_image_summation(styles),
                         max_outputs=8))
    # gather the summaries
    self.summaries = summaries
    return self.total_loss
def build_model(self, inputs, styles):
    """Build the training graph using approximated style features.

    Like the exact-style variant, but style statistics on both the
    outputs and the styles are computed with
    losses.compute_approximate_style_features, and the style loss with
    losses.compute_approximate_style_loss. Also adds an optional weight
    regularization loss.

    Args:
        inputs: input images [batch_size, height, width, channel]
        styles: style images

    Returns:
        self.total_loss: the accumulated training loss
    """
    # style transfer to the inputs; this variant of style_transfer also
    # returns the input content features, so no second encoding of the
    # inputs is needed
    outputs, inputs_content_features = self.style_transfer(inputs, styles)
    # calculate the style features for the outputs
    outputs = preprocessing.batch_mean_image_subtraction(outputs)
    # use approximated style loss instead
    # outputs_content_features, outputs_style_features = \
    #     losses.compute_content_and_style_features(
    #         outputs, self.network_name,
    #         self.content_loss_layers, self.style_loss_layers)
    outputs_image_features = losses.extract_image_features(
        outputs, self.network_name)
    outputs_content_features = losses.compute_content_features(
        outputs_image_features, self.content_loss_layers)
    outputs_style_features = losses.compute_approximate_style_features(
        outputs_image_features, self.style_loss_layers)
    # styles' style features
    styles_image_features = losses.extract_image_features(
        styles, self.network_name)
    # use approximated style features instead
    # styles_style_features = losses.compute_style_features(
    #     styles_image_features, self.style_loss_layers)
    styles_style_features = losses.compute_approximate_style_features(
        styles_image_features, self.style_loss_layers)
    # gather the summary operations
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    # calculate the losses, each gated on a positive weight
    # the content loss
    if self.content_weight > 0.0:
        self.content_loss = losses.compute_content_loss(
            inputs_content_features, outputs_content_features,
            self.content_loss_layers)
        self.total_loss += self.content_weight * self.content_loss
        summaries.add(tf.summary.scalar('losses/content_loss',
                                        self.content_loss))
    # the style loss
    if self.style_weight > 0.0:
        # use approximated style features instead
        # self.style_loss = losses.compute_style_loss(
        #     styles_style_features, outputs_style_features,
        #     self.style_loss_layers)
        self.style_loss = losses.compute_approximate_style_loss(
            styles_style_features, outputs_style_features,
            self.style_loss_layers)
        self.total_loss += self.style_weight * self.style_loss
        summaries.add(tf.summary.scalar('losses/style_loss',
                                        self.style_loss))
    # the total variation loss
    if self.tv_weight > 0.0:
        self.tv_loss = losses.compute_total_variation_loss(outputs)
        self.total_loss += self.tv_weight * self.tv_loss
        summaries.add(tf.summary.scalar('losses/tv_loss', self.tv_loss))
    # the weight regularization loss (sum of all collected
    # regularization terms; added unweighted)
    if self.weight_decay > 0.0:
        self.weight_loss = tf.add_n(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES),
            name='weight_loss')
        self.total_loss += self.weight_loss
        summaries.add(tf.summary.scalar('losses/weight_loss',
                                        self.weight_loss))
    summaries.add(tf.summary.scalar('total_loss', self.total_loss))
    # gather the image tiles for style transfer (input | output, side
    # by side)
    image_tiles = tf.concat([inputs, outputs], axis=2)
    image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
    image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0),
                          tf.uint8)
    summaries.add(tf.summary.image('style_results', image_tiles,
                                   max_outputs=8))
    # gather the styles
    summaries.add(tf.summary.image(
        'styles',
        preprocessing.batch_mean_image_summation(styles),
        max_outputs=8))
    self.summaries = summaries
    return self.total_loss