Example #1
0
 def begin(self):
     """Lazily create the merged summary op on the first call; later calls are no-ops."""
     if self._summary is not None:
         return
     self._summary = summary.merge_all()
Example #2
0
    def build(self):
        """Construct the TensorFlow (compat.v1) graph.

        In training mode (``self.options.is_training``) this wires the full
        GAN: input placeholders, the encoder/decoder generator path, a
        multi-scale discriminator, the discriminator / generator / image /
        feature losses, the two Adam optimization steps, and the TensorBoard
        summaries plus a ``FileWriter``.

        In inference mode only the photo -> features -> stylized photo path
        is built.

        Side effects: sets many ``self.*`` graph attributes; creates a
        summary ``FileWriter`` under ``self.logs_dir`` when training.
        """
        print('build')
        if self.options.is_training:
            # ==================== Define placeholders. ===================== #
            with tf.name_scope('placeholder'):
                # Spatial dims are dynamic (None) — assumes the encoder,
                # decoder and discriminator are fully convolutional.
                self.input_painting = placeholder(
                    dtype=tf.float32,
                    shape=[self.batch_size, None, None, 3],
                    name='painting')
                self.input_photo = placeholder(
                    dtype=tf.float32,
                    shape=[self.batch_size, None, None, 3],
                    name='photo')
                self.lr = placeholder(dtype=tf.float32,
                                      shape=(),
                                      name='learning_rate')

            # ===================== Wire the graph. ========================= #
            # Encode input images.
            self.input_photo_features = encoder(image=self.input_photo,
                                                options=self.options,
                                                reuse=False)

            # Decode obtained features.
            self.output_photo = decoder(features=self.input_photo_features,
                                        options=self.options,
                                        reuse=False)

            # Get features of output images. Need them to compute feature loss.
            self.output_photo_features = encoder(image=self.output_photo,
                                                 options=self.options,
                                                 reuse=True)

            # Add discriminators.
            # Each call returns a dict of predictions at multiple scales
            # (keys match ``scale_weight`` below).
            self.input_painting_discr_predictions = discriminator(
                image=self.input_painting, options=self.options, reuse=False)
            self.input_photo_discr_predictions = discriminator(
                image=self.input_photo, options=self.options, reuse=True)
            self.output_photo_discr_predictions = discriminator(
                image=self.output_photo, options=self.options, reuse=True)

            # ===================== Final losses that we optimize. ===================== #

            # Discriminator.
            # Has to predict ones only for original paintings, zeros otherwise.
            # Per-scale relative weights for the multi-scale predictions.
            scale_weight = {
                "scale_0": 1.,
                "scale_1": 1.,
                "scale_3": 1.,
                "scale_5": 1.,
                "scale_6": 1.
            }
            self.input_painting_discr_loss = {
                key: self.loss(pred, tf.ones_like(pred)) * scale_weight[key]
                for key, pred in self.input_painting_discr_predictions.items()
            }
            self.input_photo_discr_loss = {
                key: self.loss(pred, tf.zeros_like(pred)) * scale_weight[key]
                for key, pred in self.input_photo_discr_predictions.items()
            }
            self.output_photo_discr_loss = {
                key: self.loss(pred, tf.zeros_like(pred)) * scale_weight[key]
                for key, pred in self.output_photo_discr_predictions.items()
            }

            self.discr_loss = tf.add_n(list(self.input_painting_discr_loss.values())) + \
                              tf.add_n(list(self.input_photo_discr_loss.values())) + \
                              tf.add_n(list(self.output_photo_discr_loss.values()))

            # Compute discriminator accuracies: a prediction counts as correct
            # when it is on the right side of zero for its class.
            self.input_painting_discr_acc = {
                key: tf.reduce_mean(
                    tf.cast(x=(pred > tf.zeros_like(pred)), dtype=tf.float32))
                * scale_weight[key]
                for key, pred in self.input_painting_discr_predictions.items()
            }
            self.input_photo_discr_acc = {
                key: tf.reduce_mean(
                    tf.cast(x=(pred < tf.zeros_like(pred)), dtype=tf.float32))
                * scale_weight[key]
                for key, pred in self.input_photo_discr_predictions.items()
            }
            self.output_photo_discr_acc = {
                key: tf.reduce_mean(
                    tf.cast(x=(pred < tf.zeros_like(pred)), dtype=tf.float32))
                * scale_weight[key]
                for key, pred in self.output_photo_discr_predictions.items()
            }
            self.discr_acc = (tf.add_n(list(self.input_painting_discr_acc.values())) + \
                              tf.add_n(list(self.input_photo_discr_acc.values())) + \
                              tf.add_n(list(self.output_photo_discr_acc.values()))) / float(len(scale_weight.keys())*3)

            # Generator.
            # Predicts ones for the stylized output images.
            self.output_photo_gener_loss = {
                key: self.loss(pred, tf.ones_like(pred)) * scale_weight[key]
                for key, pred in self.output_photo_discr_predictions.items()
            }

            self.gener_loss = tf.add_n(
                list(self.output_photo_gener_loss.values()))

            # Compute generator accuracies (fraction of outputs the
            # discriminator is fooled by, per scale).
            self.output_photo_gener_acc = {
                key: tf.reduce_mean(
                    tf.cast(x=(pred > tf.zeros_like(pred)), dtype=tf.float32))
                * scale_weight[key]
                for key, pred in self.output_photo_discr_predictions.items()
            }

            self.gener_acc = tf.add_n(
                list(self.output_photo_gener_acc.values())) / float(
                    len(scale_weight.keys()))

            # Image loss: reconstruction in the transformer-block space.
            self.img_loss_photo = mse_criterion(
                transformer_block(self.output_photo),
                transformer_block(self.input_photo))
            self.img_loss = self.img_loss_photo

            # Features loss: encoder features of the output should match the
            # encoder features of the input photo.
            self.feature_loss_photo = abs_criterion(self.output_photo_features,
                                                    self.input_photo_features)
            self.feature_loss = self.feature_loss_photo

            # ================== Define optimization steps. =============== #
            t_vars = tf.compat.v1.trainable_variables()
            self.discr_vars = [
                var for var in t_vars if 'discriminator' in var.name
            ]
            self.encoder_vars = [
                var for var in t_vars if 'encoder' in var.name
            ]
            self.decoder_vars = [
                var for var in t_vars if 'decoder' in var.name
            ]

            # Discriminator and generator steps. Depend on UPDATE_OPS so that
            # e.g. batch-norm statistics are refreshed with every step.
            update_ops = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.UPDATE_OPS)

            with tf.control_dependencies(update_ops):
                # BUG FIX: var_list must be a flat list of variables; the
                # original wrapped the lists in an extra list.
                self.d_optim_step = tf.compat.v1.train.AdamOptimizer(
                    self.lr).minimize(loss=self.options.discr_loss_weight *
                                      self.discr_loss,
                                      var_list=self.discr_vars)
                self.g_optim_step = tf.compat.v1.train.AdamOptimizer(
                    self.lr).minimize(
                        loss=self.options.discr_loss_weight * self.gener_loss +
                        self.options.transformer_loss_weight * self.img_loss +
                        self.options.feature_loss_weight * self.feature_loss,
                        var_list=self.encoder_vars + self.decoder_vars)

            # ============= Write statistics to tensorboard. ================ #

            # Discriminator loss summary.
            s_d1 = [
                summary.scalar(
                    "discriminator/input_painting_discr_loss/" + key, val)
                for key, val in self.input_painting_discr_loss.items()
            ]
            s_d2 = [
                summary.scalar("discriminator/input_photo_discr_loss/" + key,
                               val)
                for key, val in self.input_photo_discr_loss.items()
            ]
            s_d3 = [
                summary.scalar("discriminator/output_photo_discr_loss/" + key,
                               val)
                for key, val in self.output_photo_discr_loss.items()
            ]
            s_d = summary.scalar("discriminator/discr_loss", self.discr_loss)
            self.summary_discriminator_loss = summary.merge(s_d1 + s_d2 +
                                                            s_d3 + [s_d])

            # Discriminator acc summary.
            s_d1_acc = [
                summary.scalar("discriminator/input_painting_discr_acc/" + key,
                               val)
                for key, val in self.input_painting_discr_acc.items()
            ]
            s_d2_acc = [
                summary.scalar("discriminator/input_photo_discr_acc/" + key,
                               val)
                for key, val in self.input_photo_discr_acc.items()
            ]
            s_d3_acc = [
                summary.scalar("discriminator/output_photo_discr_acc/" + key,
                               val)
                for key, val in self.output_photo_discr_acc.items()
            ]
            s_d_acc = summary.scalar("discriminator/discr_acc", self.discr_acc)
            # BUG FIX: the generator accuracy summary reused the tag
            # "discriminator/discr_acc", colliding with s_d_acc (TF would
            # silently uniquify it to "..._1"). Give it its own tag. It is
            # still picked up by summary.merge_all() below.
            s_d_acc_g = summary.scalar("discriminator/gener_acc",
                                       self.gener_acc)
            self.summary_discriminator_acc = summary.merge(s_d1_acc +
                                                           s_d2_acc +
                                                           s_d3_acc +
                                                           [s_d_acc])

            # Image loss summary.
            # BUG FIX: merge takes a list of summaries; the original passed
            # [s_i1 + s_i], which *adds* the two serialized summary tensors.
            s_i1 = summary.scalar("image_loss/photo", self.img_loss_photo)
            s_i = summary.scalar("image_loss/loss", self.img_loss)
            self.summary_image_loss = summary.merge([s_i1, s_i])

            # Feature loss summary (same list-vs-sum fix as above).
            s_f1 = summary.scalar("feature_loss/photo",
                                  self.feature_loss_photo)
            s_f = summary.scalar("feature_loss/loss", self.feature_loss)
            self.summary_feature_loss = summary.merge([s_f1, s_f])

            self.summary_merged_all = summary.merge_all()
            self.writer = summary.FileWriter(self.logs_dir, self.sess.graph)
        else:
            # ==================== Define placeholders. ===================== #
            with tf.name_scope('placeholder'):
                self.input_photo = placeholder(
                    dtype=tf.float32,
                    shape=[self.batch_size, None, None, 3],
                    name='photo')

            # ===================== Wire the graph. ========================= #
            # Encode input images.
            self.input_photo_features = encoder(image=self.input_photo,
                                                options=self.options,
                                                reuse=False)

            # Decode obtained features.
            self.output_photo = decoder(features=self.input_photo_features,
                                        options=self.options,
                                        reuse=False)