Example #1
    def style_transfer(self, inputs, styles, style_layers=None):
        """style transfer via recursive feature transforms

        Args:
            inputs: input images [batch_size, height, width, channel]
            styles: input styles [1 or batch_size, height, width, channel]
            style_layers: a list of enforced style layer ids; defaults to None,
                which applies all style layers in self.style_loss_layers

        Returns:
            outputs: the stylized images [batch_size, height, width, channel]
        """
        # get the style features for the styles
        styles_image_features = losses.extract_image_features(
            styles, self.network_name)
        styles_features = losses.compute_content_features(
            styles_image_features, self.style_loss_layers)

        # construct the recurrent modules
        selected_style_layers = self.style_loss_layers
        if style_layers:
            selected_style_layers = [
                selected_style_layers[i] for i in style_layers
            ]
        else:
            style_layers = range(len(selected_style_layers))

        outputs = tf.identity(inputs)
        num_modules = len(selected_style_layers)
        for i in range(num_modules, 0, -1):
            starting_layer = selected_style_layers[i - 1]
            # encoding the inputs
            contents_image_features = losses.extract_image_features(
                outputs, self.network_name)
            content_features = losses.compute_content_features(
                contents_image_features, [starting_layer])

            # feature transformation
            transformed_features = feature_transform(
                content_features[starting_layer],
                styles_features[starting_layer])
            if self.blending_weight:
                transformed_features = (
                    self.blending_weight * transformed_features +
                    (1 - self.blending_weight) * content_features[starting_layer])

            # decoding the contents
            with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope()):
                outputs = vgg_decoder.vgg_decoder(transformed_features,
                                                  self.network_name,
                                                  starting_layer,
                                                  reuse=True,
                                                  scope='decoder_%d' %
                                                  style_layers[i - 1])
            outputs = preprocessing.batch_mean_image_subtraction(outputs)
            print('Finished module [%s]' % starting_layer)

        # recover the outputs
        return preprocessing.batch_mean_image_summation(outputs)
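
The blending step above linearly interpolates the transformed features with the original content features. A minimal, self-contained sketch of that interpolation (NumPy stands in for the TF tensors; blend is a hypothetical helper, not a function from this repo):

    import numpy as np

    def blend(transformed, content, weight):
        """Linearly interpolate transformed and content features.

        weight = 1.0 keeps only the transformed (stylized) features,
        weight = 0.0 keeps only the original content features.
        """
        return weight * transformed + (1.0 - weight) * content

    # toy feature maps [batch, height, width, channels]
    content = np.random.rand(1, 8, 8, 64).astype(np.float32)
    transformed = np.random.rand(1, 8, 8, 64).astype(np.float32)
    output = blend(transformed, content, weight=0.6)
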
Example #2
    def style_transfer(self, inputs, styles, style_layers=(2, )):
        """style transfer via patch swapping

        Args:
            inputs: input images [batch_size, height, width, channel]
            styles: input styles [1, height, width, channel]
            style_layers: a list of layer ids used for style swapping;
                defaults to (2,)

        Returns:
            outputs: the stylized images [batch_size, height, width, channel]
        """
        styles_image_features = losses.extract_image_features(
            styles, self.network_name)
        styles_features = losses.compute_content_features(
            styles_image_features, self.style_loss_layers)

        # construct the recurrent modules
        selected_style_layers = self.style_loss_layers
        if style_layers:
            selected_style_layers = [
                selected_style_layers[i] for i in style_layers
            ]
        else:
            style_layers = range(len(selected_style_layers))

        # input preprocessing
        outputs = tf.identity(inputs)
        # start style transfer
        num_modules = len(selected_style_layers)
        for i in range(num_modules, 0, -1):
            starting_layer = selected_style_layers[i - 1]
            # encoding the inputs
            contents_image_features = losses.extract_image_features(
                outputs, self.network_name)
            contents_features = losses.compute_content_features(
                contents_image_features, [starting_layer])
            # feature transformation
            transformed_features = feature_transform(
                contents_features[starting_layer],
                styles_features[starting_layer],
                patch_size=self.patch_size)
            # decoding the contents
            with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope()):
                outputs = vgg_decoder.vgg_decoder(transformed_features,
                                                  self.network_name,
                                                  starting_layer,
                                                  scope='decoder_%d' %
                                                  style_layers[i - 1])
            outputs = preprocessing.batch_mean_image_subtraction(outputs)
            print('Finished module [%s]' % starting_layer)

        # recover the outputs
        return preprocessing.batch_mean_image_summation(outputs)
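
Here feature_transform performs a patch swap: each content patch is replaced by its most similar style patch. The repo's kernel is not shown, so below is a naive NumPy reference of the style-swap idea (cosine similarity over k×k patches, overlapping outputs averaged); the function name and the overlap-averaging convention are assumptions:

    import numpy as np

    def style_swap(content, style, patch_size=3):
        """Replace each content patch with its best-matching style patch.

        content, style: feature maps [height, width, channels] with the
        same channel count; similarity is cosine over flattened patches.
        """
        h, w, c = content.shape
        ph = patch_size
        # collect all style patches, keep an L2-normalized copy for matching
        patches = np.stack([style[i:i + ph, j:j + ph, :]
                            for i in range(style.shape[0] - ph + 1)
                            for j in range(style.shape[1] - ph + 1)])
        flat = patches.reshape(len(patches), -1)
        flat = flat / (np.linalg.norm(flat, axis=1, keepdims=True) + 1e-8)

        output = np.zeros_like(content)
        counts = np.zeros((h, w, 1), dtype=content.dtype)
        for i in range(h - ph + 1):
            for j in range(w - ph + 1):
                q = content[i:i + ph, j:j + ph, :].reshape(-1)
                scores = flat @ (q / (np.linalg.norm(q) + 1e-8))
                # paste the best-matching (unnormalized) style patch
                output[i:i + ph, j:j + ph, :] += patches[np.argmax(scores)]
                counts[i:i + ph, j:j + ph, :] += 1.0
        return output / np.maximum(counts, 1.0)

    content = np.random.rand(16, 16, 64).astype(np.float32)
    style = np.random.rand(16, 16, 64).astype(np.float32)
    stylized = style_swap(content, style, patch_size=3)
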
Example #3
    def auto_encoder(self, inputs, content_layer=2, reuse=True):
        """reconstruct the inputs from a single content feature layer"""
        # extract the content features
        image_features = losses.extract_image_features(inputs, self.network_name)
        content_features = losses.compute_content_features(image_features, self.content_layers)

        # select the content feature to use
        selected_layer = self.content_layers[content_layer]
        content_feature = content_features[selected_layer]
        input_content_features = {selected_layer: content_feature}

        # reconstruct the images
        with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope(self.weight_decay)):
            outputs = vgg_decoder.vgg_decoder(
                content_feature,
                self.network_name,
                selected_layer,
                reuse=reuse,
                scope='decoder_%d' % content_layer)
        return outputs, input_content_features
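
An autoencoder like this is typically trained so the decoded output matches the input image. A minimal NumPy stand-in for that reconstruction objective (the function name and the L2 loss choice are illustrative, not taken from the repo):

    import numpy as np

    def l2_reconstruction_loss(decoded, original):
        # mean squared error over batch, pixels, and channels
        return float(np.mean((decoded - original) ** 2))

    original = np.random.rand(4, 32, 32, 3).astype(np.float32)
    decoded = original + 0.05 * np.random.randn(*original.shape).astype(np.float32)
    loss = l2_reconstruction_loss(decoded, original)
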
Example #4
    def hierarchical_autoencoder(self, inputs, reuse=True):
        """hierarchical autoencoder for content reconstruction"""
        # extract the content features
        image_features = losses.extract_image_features(inputs,
                                                       self.network_name)
        content_features = losses.compute_content_features(
            image_features, self.style_loss_layers)

        # the applied content feature for the decode network
        selected_layer = self.style_loss_layers[-1]
        hidden_feature = content_features[selected_layer]

        # decode the hidden feature to the output image
        with slim.arg_scope(
                vgg_decoder.vgg_decoder_arg_scope(self.weight_decay)):
            outputs = vgg_decoder.vgg_combined_decoder(
                hidden_feature,
                content_features,
                fusion_fn=network_ops.adaptive_instance_normalization,
                network_name=self.network_name,
                starting_layer=selected_layer,
                reuse=reuse)
        return outputs
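
The fusion function passed to the combined decoder is adaptive instance normalization (AdaIN), which re-normalizes content features to match the channel-wise statistics of the style features. A self-contained NumPy sketch of the standard AdaIN operation (not necessarily the repo's exact network_ops implementation):

    import numpy as np

    def adaptive_instance_normalization(content, style, eps=1e-5):
        """Align channel-wise mean/std of content to those of style.

        Both inputs are feature maps [batch, height, width, channels];
        statistics are computed per sample over the spatial dimensions.
        """
        c_mean = content.mean(axis=(1, 2), keepdims=True)
        c_std = content.std(axis=(1, 2), keepdims=True)
        s_mean = style.mean(axis=(1, 2), keepdims=True)
        s_std = style.std(axis=(1, 2), keepdims=True)
        normalized = (content - c_mean) / (c_std + eps)
        return normalized * s_std + s_mean

    content = np.random.rand(1, 32, 32, 256).astype(np.float32)
    style = np.random.rand(1, 32, 32, 256).astype(np.float32)
    fused = adaptive_instance_normalization(content, style)
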
Example #5
    def transfer_styles(self,
                        inputs,
                        styles,
                        inter_weight=1.0,
                        intra_weights=(1, )):
        """transfer the content image by style images

        Args:
            inputs: input images [batch_size, height, width, channel]
            styles: a list of input style images; a single style by default
            inter_weight: the blending weight between the content and style
            intra_weights: a list of blending weights across the styles;
                defaults to (1,)

        Returns:
            outputs: the stylized images [batch_size, height, width, channel]
        """
        if not isinstance(styles, (list, tuple)):
            styles = [styles]

        if not isinstance(intra_weights, (list, tuple)):
            intra_weights = [intra_weights]

        # 1) extract the style features
        styles_features = []
        for style in styles:
            style_image_features = losses.extract_image_features(
                style, self.network_name)
            style_features = losses.compute_content_features(
                style_image_features, self.style_loss_layers)
            styles_features.append(style_features)

        # 2) content features
        inputs_image_features = losses.extract_image_features(
            inputs, self.network_name)
        inputs_features = losses.compute_content_features(
            inputs_image_features, self.style_loss_layers)

        # 3) style decorator
        # the applied content feature from the content input
        selected_layer = self.style_loss_layers[-1]
        hidden_feature = inputs_features[selected_layer]

        # applying the style decorator
        blended_feature = 0.0
        for weight, style_features in zip(intra_weights, styles_features):
            swapped_feature = style_decorator(hidden_feature,
                                              style_features[selected_layer],
                                              style_coding=self.style_coding,
                                              style_interp=self.style_interp,
                                              ratio_interp=inter_weight,
                                              patch_size=self.patch_size)
            blended_feature += weight * swapped_feature

        # 4) decode the hidden feature to the output image
        with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope()):
            outputs = vgg_decoder.vgg_multiple_combined_decoder(
                blended_feature,
                styles_features,
                intra_weights,
                fusion_fn=network_ops.adaptive_instance_normalization,
                network_name=self.network_name,
                starting_layer=selected_layer)
        return outputs
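
The loop above blends the per-style swapped features with intra_weights. A small sketch of that blending step in NumPy (normalizing the weights to sum to 1 is an assumption here; the method as written uses the weights directly):

    import numpy as np

    def blend_styles(swapped_features, intra_weights):
        """Weighted sum of per-style feature maps.

        Assumption: weights are normalized to sum to 1 so the blended
        feature keeps the same scale regardless of the number of styles.
        """
        weights = np.asarray(intra_weights, dtype=np.float32)
        weights = weights / weights.sum()
        blended = np.zeros_like(swapped_features[0])
        for w, feat in zip(weights, swapped_features):
            blended += w * feat
        return blended

    feats = [np.random.rand(1, 16, 16, 512).astype(np.float32) for _ in range(2)]
    blended = blend_styles(feats, intra_weights=(0.7, 0.3))
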