def style_transfer(self, inputs, styles, style_layers=None):
    """Style transfer via recursive feature transforms.

    Each selected style layer defines one encode -> transform -> decode
    module. Modules run from the deepest selected layer down to the
    shallowest, feeding each module's decoded image into the next.

    Args:
        inputs: input images [batch_size, height, width, channel].
        styles: input styles [1 or batch_size, height, width, channel].
        style_layers: a list of enforced style layer ids; default None
            applies all style layers in self.style_loss_layers.

    Returns:
        outputs: the stylized images [batch_size, height, width, channel].
    """
    # Style features are computed once and shared by every module.
    styles_image_features = losses.extract_image_features(
        styles, self.network_name)
    styles_features = losses.compute_content_features(
        styles_image_features, self.style_loss_layers)

    # Resolve which style layers drive the recursive modules.
    selected_style_layers = self.style_loss_layers
    if style_layers:
        selected_style_layers = [
            selected_style_layers[i] for i in style_layers
        ]
    else:
        # Keep indices aligned with selected_style_layers for the
        # decoder scope names below.
        style_layers = range(len(selected_style_layers))

    outputs = tf.identity(inputs)
    num_modules = len(selected_style_layers)
    # Deepest layer first: each iteration re-encodes the previous output.
    for i in range(num_modules, 0, -1):
        starting_layer = selected_style_layers[i - 1]
        # Encode the current outputs up to the starting layer.
        contents_image_features = losses.extract_image_features(
            outputs, self.network_name)
        content_features = losses.compute_content_features(
            contents_image_features, [starting_layer])
        # Feature transformation (content features re-colored by style).
        transformed_features = feature_transform(
            content_features[starting_layer],
            styles_features[starting_layer])
        # BUGFIX: compare against None rather than truthiness so a
        # blending weight of 0.0 still blends (yielding pure content
        # features) instead of silently using 100% transformed features.
        if self.blending_weight is not None:
            transformed_features = \
                self.blending_weight * transformed_features + \
                (1 - self.blending_weight) * content_features[starting_layer]
        # Decode the transformed features back to image space.
        with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope()):
            outputs = vgg_decoder.vgg_decoder(
                transformed_features,
                self.network_name,
                starting_layer,
                reuse=True,
                scope='decoder_%d' % style_layers[i - 1])
        # Decoders emit mean-subtracted images; keep that convention
        # for the next encoding pass.
        outputs = preprocessing.batch_mean_image_subtraction(outputs)
        print('Finish the module of [%s]' % starting_layer)
    # Undo the final mean subtraction to recover displayable images.
    return preprocessing.batch_mean_image_summation(outputs)
def style_transfer(self, inputs, styles, style_layers=(2, )):
    """Style transfer via patch swapping.

    Args:
        inputs: input images [batch_size, height, width, channel].
        styles: input styles [1, height, width, channel].
        style_layers: ids (indices into self.style_loss_layers) of the
            layers at which to perform style swapping; defaults to
            (2, ). A falsy value selects every style layer.

    Returns:
        outputs: the stylized images [batch_size, height, width, channel].
    """
    # Style features are extracted once and reused by every module.
    style_feature_maps = losses.extract_image_features(
        styles, self.network_name)
    styles_features = losses.compute_content_features(
        style_feature_maps, self.style_loss_layers)

    # Resolve the layers that drive the swapping modules.
    if style_layers:
        active_layers = [self.style_loss_layers[k] for k in style_layers]
    else:
        active_layers = self.style_loss_layers
        # Keep indices aligned with active_layers for decoder scopes.
        style_layers = range(len(active_layers))

    # Input preprocessing.
    outputs = tf.identity(inputs)

    # Run the modules from the deepest selected layer to the shallowest,
    # re-encoding the previous module's output each time.
    for idx in reversed(range(len(active_layers))):
        layer_name = active_layers[idx]
        # Encode the current outputs up to this layer.
        content_feature_maps = losses.extract_image_features(
            outputs, self.network_name)
        content_bank = losses.compute_content_features(
            content_feature_maps, [layer_name])
        # Patch-based feature swap between content and style.
        swapped = feature_transform(
            content_bank[layer_name],
            styles_features[layer_name],
            patch_size=self.patch_size)
        # Decode the swapped features back to image space.
        with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope()):
            outputs = vgg_decoder.vgg_decoder(
                swapped,
                self.network_name,
                layer_name,
                scope='decoder_%d' % style_layers[idx])
        outputs = preprocessing.batch_mean_image_subtraction(outputs)
        print('Finish the module of [%s]' % layer_name)
    # Recover the outputs from the mean-subtracted convention.
    return preprocessing.batch_mean_image_summation(outputs)
def auto_encoder(self, inputs, content_layer=2, reuse=True):
    """Reconstruct images via one encode/decode round trip.

    Args:
        inputs: input images [batch_size, height, width, channel].
        content_layer: index into self.content_layers selecting which
            layer's features are decoded back to image space.
        reuse: whether to reuse the decoder's variable scope.

    Returns:
        A tuple (outputs, input_content_features): the reconstructed
        images, and a dict mapping the chosen layer name to its
        feature tensor.
    """
    # Encode the inputs and keep only the requested content layer.
    feature_maps = losses.extract_image_features(
        inputs, self.network_name)
    all_content_features = losses.compute_content_features(
        feature_maps, self.content_layers)
    layer_name = self.content_layers[content_layer]
    layer_features = all_content_features[layer_name]
    input_content_features = {layer_name: layer_features}
    # Decode the selected features back into an image.
    with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope(self.weight_decay)):
        outputs = vgg_decoder.vgg_decoder(
            layer_features,
            self.network_name,
            layer_name,
            reuse=reuse,
            scope='decoder_%d' % content_layer)
    return outputs, input_content_features