Example No. 1
    def build_train_graph(self, inputs):
        """Build the training graph for every decoder and accumulate its losses."""
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        for i in range(len(self.content_layers)):
            # skip the first three decoders (content layers 0-2)
            if i < 3:
                continue

            selected_layer = self.content_layers[i]

            outputs, inputs_content_features = self.auto_encoder(
                inputs, content_layer=i, reuse=False)
            outputs = preprocessing.batch_mean_image_subtraction(outputs)

            ########################
            # construct the losses #
            ########################
            # 1) reconstruction loss
            recons_loss = tf.losses.mean_squared_error(
                inputs, outputs, scope='recons_loss/decoder_%d' % i)
            self.recons_loss[selected_layer] = recons_loss
            self.total_loss += self.recons_weight * recons_loss
            summaries.add(
                tf.summary.scalar('recons_loss/decoder_%d' % i, recons_loss))

            # 2) content loss
            outputs_image_features = losses.extract_image_features(
                outputs, self.network_name)
            outputs_content_features = losses.compute_content_features(
                outputs_image_features, [selected_layer])
            content_loss = losses.compute_content_loss(
                outputs_content_features, inputs_content_features,
                [selected_layer])
            self.content_loss[selected_layer] = content_loss
            self.total_loss += self.content_weight * content_loss
            summaries.add(
                tf.summary.scalar('content_loss/decoder_%d' % i, content_loss))

            # 3) total variation loss
            tv_loss = losses.compute_total_variation_loss_l1(outputs)
            self.tv_loss[selected_layer] = tv_loss
            self.total_loss += self.tv_weight * tv_loss
            summaries.add(tf.summary.scalar('tv_loss/decoder_%d' % i, tv_loss))

            image_tiles = tf.concat([inputs, outputs], axis=2)
            image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
            image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0),
                                  tf.uint8)
            summaries.add(
                tf.summary.image('image_comparison/decoder_%d' % i,
                                 image_tiles,
                                 max_outputs=8))

        self.summaries = summaries
        return self.total_loss
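For context, the total_loss returned above is typically handed to an optimizer, and the collected summaries merged into a single op. A minimal, hypothetical TF 1.x sketch; model, inputs and the learning rate are illustrative placeholders, not part of the example above:

import tensorflow as tf

# Hypothetical wiring: assumes `model` is an instance of the class above and
# `inputs` comes from an input pipeline (e.g. tf.data).
total_loss = model.build_train_graph(inputs)
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(total_loss)
summary_op = tf.summary.merge(list(model.summaries))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    _, loss_value, summary = sess.run([train_op, total_loss, summary_op])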
Example No. 2
    def build_model(self, inputs):
        """Build the model graph for every decoder and accumulate its losses."""
        for i in range(len(self.content_layers)):
            selected_layer = self.content_layers[i]

            outputs, inputs_content_features = self.auto_encoder(
                inputs, content_layer=i, reuse=False)
            outputs = preprocessing.batch_mean_image_subtraction(outputs)

            ########################
            # construct the losses #
            ########################
            # 1) reconstruction loss
            recons_loss = tf.losses.mean_squared_error(
                inputs, outputs, scope='recons_loss/decoder_%d' % i)
            self.recons_loss[selected_layer] = recons_loss
            self.total_loss += self.recons_weight * recons_loss
            self.summaries.add(tf.summary.scalar(
                'recons_loss/decoder_%d' % i, recons_loss))

            # 2) content loss
            outputs_image_features = losses.extract_image_features(
                outputs, self.network_name)
            outputs_content_features = losses.compute_content_features(
                outputs_image_features, [selected_layer])
            content_loss = losses.compute_content_loss(
                outputs_content_features, inputs_content_features, [selected_layer])
            self.content_loss[selected_layer] = content_loss
            self.total_loss += self.content_weight * content_loss
            self.summaries.add(tf.summary.scalar(
                'content_loss/decoder_%d' % i, content_loss))

            image_tiles = tf.concat([inputs, outputs], axis=2)
            image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
            image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0), tf.uint8)
            self.summaries.add(tf.summary.image(
                'image_comparison/decoder_%d' % i, image_tiles, max_outputs=8))

        return self.total_loss
Example No. 3
    def build_train_graph(self, inputs):
        """build the training graph for the training of the hierarchical autoencoder"""
        outputs = self.hierarchical_autoencoder(inputs, reuse=False)
        outputs = preprocessing.batch_mean_image_subtraction(outputs)

        # summaries
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        ########################
        # construct the losses #
        ########################
        # 1) reconstruction loss
        if self.recons_weight > 0.0:
            recons_loss = tf.losses.mean_squared_error(
                inputs,
                outputs,
                weights=self.recons_weight,
                scope='recons_loss')
            self.recons_loss = recons_loss
            self.total_loss += recons_loss
            summaries.add(tf.summary.scalar('losses/recons_loss', recons_loss))

        # 2) content loss
        if self.content_weight > 0.0:
            outputs_image_features = losses.extract_image_features(
                outputs, self.network_name)
            outputs_content_features = losses.compute_content_features(
                outputs_image_features, self.style_loss_layers)

            inputs_image_features = losses.extract_image_features(
                inputs, self.network_name)
            inputs_content_features = losses.compute_content_features(
                inputs_image_features, self.style_loss_layers)

            content_loss = losses.compute_content_loss(
                outputs_content_features,
                inputs_content_features,
                content_loss_layers=self.style_loss_layers,
                weights=self.content_weight)
            self.content_loss = content_loss
            self.total_loss += content_loss
            summaries.add(
                tf.summary.scalar('losses/content_loss', content_loss))

        # 3) total variation loss
        if self.tv_weight > 0.0:
            tv_loss = losses.compute_total_variation_loss_l1(
                outputs, self.tv_weight)
            self.tv_loss = tv_loss
            self.total_loss += tv_loss
            summaries.add(tf.summary.scalar('losses/tv_loss', tv_loss))

        image_tiles = tf.concat([inputs, outputs], axis=2)
        image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
        image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0),
                              tf.uint8)
        summaries.add(
            tf.summary.image('image_comparison', image_tiles, max_outputs=8))

        self.summaries = summaries
        return self.total_loss
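The examples call losses.compute_total_variation_loss_l1, which is not shown on this page. A plausible sketch of such an L1 total-variation helper, for reference only (the repository's actual losses module may differ):

import tensorflow as tf

def compute_total_variation_loss_l1(images, weights=1.0):
    # tf.image.total_variation sums the absolute differences between
    # neighboring pixel values, i.e. an L1 total-variation penalty per image.
    tv = tf.reduce_mean(tf.image.total_variation(images))
    return weights * tv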
Example No. 4
    def build_model(self, inputs, styles, reuse=False):
        """build the graph for the MSG model

        Args:
        inputs: the inputs [batch_size, height, width, channel]
        styles: the styles [1, height, width, channel]
        reuse: whether to reuse the parameters

        Returns:
        total_loss: the total loss for the style transfer
        """
        # extract the content features for the inputs
        inputs_image_features = losses.extract_image_features(
            inputs, self.network_name)
        inputs_content_features = losses.compute_content_features(
            inputs_image_features, self.content_loss_layers)

        # extract the style features from the style images
        styles_image_features = losses.extract_image_features(
            styles, self.network_name)
        styles_style_features = losses.compute_style_features(
            styles_image_features, self.style_loss_layers)

        # transfer the styles from the inputs
        outputs = self.style_transfer(inputs, styles, reuse=reuse)

        # subtract the image mean from the outputs and compute their content and style features
        outputs = preprocessing.batch_mean_image_subtraction(outputs)
        outputs_content_features, outputs_style_features = \
            losses.compute_content_and_style_features(
                outputs, self.network_name,
                self.content_loss_layers, self.style_loss_layers)

        # gather the summary operations
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # calculate the losses
        # the content loss
        if self.content_weight > 0.0:
            self.content_loss = losses.compute_content_loss(
                inputs_content_features, outputs_content_features,
                self.content_loss_layers)
            self.total_loss += self.content_weight * self.content_loss
            summaries.add(
                tf.summary.scalar('losses/content_loss', self.content_loss))
        # the style loss
        if self.style_weight > 0.0:
            self.style_loss = losses.compute_style_loss(
                styles_style_features, outputs_style_features,
                self.style_loss_layers)
            self.total_loss += self.style_weight * self.style_loss
            summaries.add(
                tf.summary.scalar('losses/style_loss', self.style_loss))
        # the total variation loss
        if self.tv_weight > 0.0:
            self.tv_loss = losses.compute_total_variation_loss(outputs)
            self.total_loss += self.tv_weight * self.tv_loss
            summaries.add(tf.summary.scalar('losses/tv_loss', self.tv_loss))

        summaries.add(tf.summary.scalar('total_loss', self.total_loss))

        # gather the image tiles for style transfer
        image_tiles = tf.concat([inputs, outputs], axis=2)
        image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
        image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0),
                              tf.uint8)
        summaries.add(
            tf.summary.image('style_results', image_tiles, max_outputs=8))

        # gather the styles
        summaries.add(
            tf.summary.image('styles',
                             preprocessing.batch_mean_image_summation(styles),
                             max_outputs=8))
        # gather the summaries
        self.summaries = summaries
        return self.total_loss
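The style loss above is built from losses.compute_style_features / compute_style_loss, which conventionally compare Gram-matrix statistics of the feature maps. A hypothetical sketch of the Gram-matrix step (the repository's actual helpers are not shown on this page):

import tensorflow as tf

def gram_matrix(features):
    # features: [batch, height, width, channels]
    shape = tf.shape(features)
    num_elements = tf.cast(shape[1] * shape[2] * shape[3], tf.float32)
    flat = tf.reshape(features, [shape[0], -1, shape[3]])
    # [batch, channels, channels] correlations between channel activations
    gram = tf.matmul(flat, flat, transpose_a=True)
    return gram / num_elements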
Example No. 5
    def build_model(self, inputs, styles):
        # apply style transfer to the inputs
        outputs, inputs_content_features = self.style_transfer(inputs, styles)

        # calculate the style features for the outputs
        outputs = preprocessing.batch_mean_image_subtraction(outputs)
        # use approximated style loss instead
        # outputs_content_features, outputs_style_features = \
        #     losses.compute_content_and_style_features(
        #         outputs, self.network_name,
        #         self.content_loss_layers, self.style_loss_layers)
        outputs_image_features = losses.extract_image_features(outputs, self.network_name)
        outputs_content_features = losses.compute_content_features(
            outputs_image_features, self.content_loss_layers)
        outputs_style_features = losses.compute_approximate_style_features(
            outputs_image_features, self.style_loss_layers)

        # extract the image features of the style images
        styles_image_features = losses.extract_image_features(
            styles, self.network_name)

        # use approximated style features instead
        # styles_style_features = losses.compute_style_features(
        #     styles_image_features, self.style_loss_layers)
        styles_style_features = losses.compute_approximate_style_features(
            styles_image_features, self.style_loss_layers)

        # gather the summary operations
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # calculate the losses
        # the content loss
        if self.content_weight > 0.0:
            self.content_loss = losses.compute_content_loss(
                inputs_content_features, outputs_content_features, self.content_loss_layers)
            self.total_loss += self.content_weight * self.content_loss
            summaries.add(tf.summary.scalar('losses/content_loss', self.content_loss))
        # the style loss
        if self.style_weight > 0.0:
            # use approximated style features instead
            # self.style_loss = losses.compute_style_loss(
            #     styles_style_features, outputs_style_features, self.style_loss_layers)
            self.style_loss = losses.compute_approximate_style_loss(
                styles_style_features, outputs_style_features, self.style_loss_layers)
            self.total_loss += self.style_weight * self.style_loss
            summaries.add(tf.summary.scalar('losses/style_loss', self.style_loss))
        # the total variation loss
        if self.tv_weight > 0.0:
            self.tv_loss = losses.compute_total_variation_loss(outputs)
            self.total_loss += self.tv_weight * self.tv_loss
            summaries.add(tf.summary.scalar('losses/tv_loss', self.tv_loss))
        # the weight regularization loss
        if self.weight_decay > 0.0:
            self.weight_loss = tf.add_n(
                tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES),
                name='weight_loss')
            self.total_loss += self.weight_loss
            summaries.add(tf.summary.scalar('losses/weight_loss', self.weight_loss))

        summaries.add(tf.summary.scalar('total_loss', self.total_loss))

        # gather the image tiles for style transfer
        image_tiles = tf.concat([inputs, outputs], axis=2)
        image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
        image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0), tf.uint8)
        summaries.add(tf.summary.image('style_results', image_tiles, max_outputs=8))

        # gather the styles
        summaries.add(tf.summary.image('styles',
                                       preprocessing.batch_mean_image_summation(styles),
                                       max_outputs=8))

        self.summaries = summaries
        return self.total_loss
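For completeness, losses.compute_content_loss is commonly a mean-squared error between corresponding feature maps, summed over the selected layers. A hypothetical sketch under that assumption (the real helper may differ):

import tensorflow as tf

def compute_content_loss(features_a, features_b, content_loss_layers, weights=1.0):
    # features_a / features_b map layer names to feature tensors.
    loss = 0.0
    for layer in content_loss_layers:
        loss += tf.reduce_mean(tf.square(features_a[layer] - features_b[layer]))
    return weights * loss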