Example #1
    def build_train_graph(self, inputs):
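        """Build the training graph, constructing one decoder per selected content layer."""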
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        for i in range(len(self.content_layers)):
            # skip the first three content layers; only build decoders for index >= 3
            if i < 3:
                continue

            selected_layer = self.content_layers[i]

            outputs, inputs_content_features = self.auto_encoder(
                inputs, content_layer=i, reuse=False)
            outputs = preprocessing.batch_mean_image_subtraction(outputs)

            ########################
            # construct the losses #
            ########################
            # 1) reconstruction loss
            recons_loss = tf.losses.mean_squared_error(
                inputs, outputs, scope='recons_loss/decoder_%d' % i)
            self.recons_loss[selected_layer] = recons_loss
            self.total_loss += self.recons_weight * recons_loss
            summaries.add(
                tf.summary.scalar('recons_loss/decoder_%d' % i, recons_loss))

            # 2) content loss
            outputs_image_features = losses.extract_image_features(
                outputs, self.network_name)
            outputs_content_features = losses.compute_content_features(
                outputs_image_features, [selected_layer])
            content_loss = losses.compute_content_loss(
                outputs_content_features, inputs_content_features,
                [selected_layer])
            self.content_loss[selected_layer] = content_loss
            self.total_loss += self.content_weight * content_loss
            summaries.add(
                tf.summary.scalar('content_loss/decoder_%d' % i, content_loss))

            # 3) total variation loss
            tv_loss = losses.compute_total_variation_loss_l1(outputs)
            self.tv_loss[selected_layer] = tv_loss
            self.total_loss += self.tv_weight * tv_loss
            summaries.add(tf.summary.scalar('tv_loss/decoder_%d' % i, tv_loss))

            image_tiles = tf.concat([inputs, outputs], axis=2)
            image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
            image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0),
                                  tf.uint8)
            summaries.add(
                tf.summary.image('image_comparison/decoder_%d' % i,
                                 image_tiles,
                                 max_outputs=8))

        self.summaries = summaries
        return self.total_loss
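
Both examples call losses.compute_total_variation_loss_l1 without showing its definition. The following is a minimal sketch of what such an L1 total-variation penalty typically looks like in TF 1.x; the function name is taken from the examples, but the NHWC layout and the optional weights argument are assumptions about the project's losses module, not its actual code.

import tensorflow as tf

def compute_total_variation_loss_l1(images, weights=1.0):
    """L1 total variation: mean absolute difference between neighboring pixels.

    Assumes an NHWC float tensor; `weights` scales the penalty (default 1.0).
    """
    # differences between vertically and horizontally adjacent pixels
    pixel_diff_h = images[:, 1:, :, :] - images[:, :-1, :, :]
    pixel_diff_w = images[:, :, 1:, :] - images[:, :, :-1, :]
    tv_loss = tf.reduce_mean(tf.abs(pixel_diff_h)) + tf.reduce_mean(tf.abs(pixel_diff_w))
    return weights * tv_loss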
Example #2
    def build_train_graph(self, inputs):
        """build the training graph for the training of the hierarchical autoencoder"""
        outputs = self.hierarchical_autoencoder(inputs, reuse=False)
        outputs = preprocessing.batch_mean_image_subtraction(outputs)

        # summaries
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        ########################
        # construct the losses #
        ########################
        # 1) reconstruction loss
        if self.recons_weight > 0.0:
            recons_loss = tf.losses.mean_squared_error(
                inputs,
                outputs,
                weights=self.recons_weight,
                scope='recons_loss')
            self.recons_loss = recons_loss
            self.total_loss += recons_loss
            summaries.add(tf.summary.scalar('losses/recons_loss', recons_loss))

        # 2) content loss
        if self.content_weight > 0.0:
            outputs_image_features = losses.extract_image_features(
                outputs, self.network_name)
            outputs_content_features = losses.compute_content_features(
                outputs_image_features, self.style_loss_layers)

            inputs_image_features = losses.extract_image_features(
                inputs, self.network_name)
            inputs_content_features = losses.compute_content_features(
                inputs_image_features, self.style_loss_layers)

            content_loss = losses.compute_content_loss(
                outputs_content_features,
                inputs_content_features,
                content_loss_layers=self.style_loss_layers,
                weights=self.content_weight)
            self.content_loss = content_loss
            self.total_loss += content_loss
            summaries.add(
                tf.summary.scalar('losses/content_loss', content_loss))

        # 3) total variation loss
        if self.tv_weight > 0.0:
            tv_loss = losses.compute_total_variation_loss_l1(
                outputs, self.tv_weight)
            self.tv_loss = tv_loss
            self.total_loss += tv_loss
            summaries.add(tf.summary.scalar('losses/tv_loss', tv_loss))

        image_tiles = tf.concat([inputs, outputs], axis=2)
        image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
        image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0),
                              tf.uint8)
        summaries.add(
            tf.summary.image('image_comparison', image_tiles, max_outputs=8))

        self.summaries = summaries
        return self.total_loss
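
Both examples subtract a per-channel mean from the images before feature extraction and add it back with preprocessing.batch_mean_image_summation before writing image summaries. Below is a self-contained sketch of such a helper pair, assuming VGG-style RGB channel means and an NHWC layout; the constants and signatures in the project's preprocessing module may differ.

import tensorflow as tf

# Assumed VGG-style per-channel RGB means; the project's actual constants may differ.
_R_MEAN, _G_MEAN, _B_MEAN = 123.68, 116.78, 103.94

def batch_mean_image_subtraction(images, means=(_R_MEAN, _G_MEAN, _B_MEAN)):
    """Subtract per-channel means from an NHWC batch of RGB images."""
    channels = tf.split(images, num_or_size_splits=len(means), axis=3)
    for i, mean in enumerate(means):
        channels[i] -= mean
    return tf.concat(channels, axis=3)

def batch_mean_image_summation(images, means=(_R_MEAN, _G_MEAN, _B_MEAN)):
    """Add the per-channel means back, e.g. before visualizing image summaries."""
    channels = tf.split(images, num_or_size_splits=len(means), axis=3)
    for i, mean in enumerate(means):
        channels[i] += mean
    return tf.concat(channels, axis=3)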