def build(self):
        """Build the TF1 computation graph for the temporal-colorization GAN.

        Idempotent: a second call returns immediately via ``is_built``.
        Creates placeholders (``input_rgb``, ``input_rgb_prev``), the
        generator/discriminator sub-graphs, GAN + L1 losses, a sampler for
        inference, accuracy metric, learning-rate schedule, Adam train ops,
        and a ``tf.train.Saver``.
        """
        if self.is_built:
            return

        self.is_built = True

        gen_factory = self.create_generator()
        dis_factory = self.create_discriminator()
        # One-sided label smoothing: real labels become 0.9 instead of 1.
        smoothing = 0.9 if self.options.label_smoothing else 1
        seed = self.options.seed  # fixed duplicated assignment (`seed = seed = ...`)
        kernel = self.options.kernel_size

        # Current frame and previous (already colorized) frame, RGB in [?, H, W, 3].
        self.input_rgb = tf.placeholder(tf.float32,
                                        shape=(None, None, None, 3),
                                        name='input_rgb')
        self.input_rgb_prev = tf.placeholder(tf.float32,
                                             shape=(None, None, None, 3),
                                             name='input_rgb_prev')

        self.input_gray = tf.image.rgb_to_grayscale(self.input_rgb)
        self.input_color = preprocess(self.input_rgb,
                                      colorspace_in=COLORSPACE_RGB,
                                      colorspace_out=self.options.color_space)
        self.input_color_prev = preprocess(
            self.input_rgb_prev,
            colorspace_in=COLORSPACE_RGB,
            colorspace_out=self.options.color_space)

        # Generator conditions on the current grayscale frame plus the previous
        # color frame (channel-wise concat); discriminator sees the previous
        # color frame alongside either the real or the generated current frame.
        gen = gen_factory.create(
            tf.concat([self.input_gray, self.input_color_prev], 3), kernel,
            seed)
        dis_real = dis_factory.create(
            tf.concat([self.input_color, self.input_color_prev], 3), kernel,
            seed)
        dis_fake = dis_factory.create(tf.concat([gen, self.input_color_prev],
                                                3),
                                      kernel,
                                      seed,
                                      reuse_variables=True)

        # Standard non-saturating GAN cross-entropy terms.
        gen_ce = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=dis_fake, labels=tf.ones_like(dis_fake))
        dis_real_ce = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=dis_real, labels=tf.ones_like(dis_real) * smoothing)
        dis_fake_ce = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=dis_fake, labels=tf.zeros_like(dis_fake))

        self.dis_loss_real = tf.reduce_mean(dis_real_ce)
        self.dis_loss_fake = tf.reduce_mean(dis_fake_ce)
        self.dis_loss = tf.reduce_mean(dis_real_ce + dis_fake_ce)

        self.gen_loss_gan = tf.reduce_mean(gen_ce)
        self.gen_loss_l1 = tf.reduce_mean(
            tf.abs(self.input_color - gen)) * self.options.l1_weight

        # NOTE(review): the adversarial term is deliberately disabled here —
        # the generator is trained with the L1 reconstruction loss only.
        # Restore `self.gen_loss_gan + self.gen_loss_l1` to re-enable GAN training.
        self.gen_loss = self.gen_loss_l1

        # Inference-time generator sharing the trained variables.
        self.sampler = gen_factory.create(tf.concat(
            [self.input_gray, self.input_color_prev], 3),
                                          kernel,
                                          seed,
                                          reuse_variables=True)
        self.accuracy = pixelwise_accuracy(self.input_color, gen,
                                           self.options.color_space,
                                           self.options.acc_thresh)
        self.learning_rate = tf.constant(self.options.lr)

        # Exponential learning-rate decay, floored at 1e-8.
        if self.options.lr_decay_rate > 0:
            self.learning_rate = tf.maximum(
                1e-8,
                tf.train.exponential_decay(
                    learning_rate=self.options.lr,
                    global_step=self.global_step,
                    decay_steps=self.options.lr_decay_steps,
                    decay_rate=self.options.lr_decay_rate))

        # Generator optimizer (only generator variables are updated).
        self.gen_train = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate,
            beta1=self.options.beta1).minimize(self.gen_loss,
                                               var_list=gen_factory.var_list)

        # Discriminator optimizer; this op also advances global_step.
        self.dis_train = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate,
            beta1=self.options.beta1).minimize(self.dis_loss,
                                               var_list=dis_factory.var_list,
                                               global_step=self.global_step)

        self.saver = tf.train.Saver()
# --- Code example #2 ---
# File: models.py — project: z1z9b89/StyleTransfer
    def build(self):
        """Build the TF1 computation graph for the single-image colorization GAN.

        Idempotent: a second call returns immediately via ``is_built``.
        Creates the RGB/grayscale input placeholders, the
        generator/discriminator sub-graphs, GAN + L1 losses, a named
        ``'output'`` sampler tensor for inference, accuracy metric,
        learning-rate schedule, Adam train ops, and a ``tf.train.Saver``.
        """
        if self.is_built:
            return

        self.is_built = True

        gen_factory = self.create_generator()
        dis_factory = self.create_discriminator()
        # One-sided label smoothing: real labels become 0.9 instead of 1.
        smoothing = 0.9 if self.options.label_smoothing else 1
        seed = self.options.seed
        kernel = 4

        # Model input placeholder: RGB image.
        self.input_rgb = tf.placeholder(tf.float32,
                                        shape=(None, None, None, 3),
                                        name='input_rgb')

        # Model input after preprocessing into the training color space (e.g. LAB).
        self.input_color = preprocess(self.input_rgb,
                                      colorspace_in=COLORSPACE_RGB,
                                      colorspace_out=self.options.color_space)

        # Grayscale input placeholder.
        # NOTE(review): the original code had identical `if mode == 1` / `else`
        # branches here, although the else-branch comment claimed grayscale is
        # extracted from the color image during training. The duplicate branch
        # is collapsed; in every mode the caller must feed 'input_gray'.
        # Confirm whether training was meant to use
        # tf.image.rgb_to_grayscale(self.input_rgb) instead.
        self.input_gray = tf.placeholder(tf.float32,
                                         shape=(None, None, None, 1),
                                         name='input_gray')

        # Conditional GAN: the discriminator always sees the grayscale input
        # concatenated with either the real colors or the generated colors.
        gen = gen_factory.create(self.input_gray, kernel, seed)
        dis_real = dis_factory.create(
            tf.concat([self.input_gray, self.input_color], 3), kernel, seed)
        dis_fake = dis_factory.create(tf.concat([self.input_gray, gen], 3),
                                      kernel,
                                      seed,
                                      reuse_variables=True)

        # Standard non-saturating GAN cross-entropy terms.
        gen_ce = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=dis_fake, labels=tf.ones_like(dis_fake))
        dis_real_ce = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=dis_real, labels=tf.ones_like(dis_real) * smoothing)
        dis_fake_ce = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=dis_fake, labels=tf.zeros_like(dis_fake))

        self.dis_loss_real = tf.reduce_mean(dis_real_ce)
        self.dis_loss_fake = tf.reduce_mean(dis_fake_ce)
        self.dis_loss = tf.reduce_mean(dis_real_ce + dis_fake_ce)

        # Generator loss = adversarial term + weighted L1 reconstruction.
        self.gen_loss_gan = tf.reduce_mean(gen_ce)
        self.gen_loss_l1 = tf.reduce_mean(
            tf.abs(self.input_color - gen)) * self.options.l1_weight
        self.gen_loss = self.gen_loss_gan + self.gen_loss_l1

        # Inference-time generator sharing the trained variables; named
        # 'output' so the tensor can be located in an exported graph.
        self.sampler = tf.identity(gen_factory.create(self.input_gray,
                                                      kernel,
                                                      seed,
                                                      reuse_variables=True),
                                   name='output')
        self.accuracy = pixelwise_accuracy(self.input_color, gen,
                                           self.options.color_space,
                                           self.options.acc_thresh)
        self.learning_rate = tf.constant(self.options.lr)

        # Exponential learning-rate decay, floored at 1e-6.
        if self.options.lr_decay and self.options.lr_decay_rate > 0:
            self.learning_rate = tf.maximum(
                1e-6,
                tf.train.exponential_decay(
                    learning_rate=self.options.lr,
                    global_step=self.global_step,
                    decay_steps=self.options.lr_decay_steps,
                    decay_rate=self.options.lr_decay_rate))

        # Generator optimizer (only generator variables are updated).
        self.gen_train = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate,
            beta1=self.options.beta1).minimize(self.gen_loss,
                                               var_list=gen_factory.var_list)

        # Discriminator optimizer; trains at 1/10 of the generator's learning
        # rate (deliberate, keeps the discriminator from overpowering the
        # generator). This op also advances global_step.
        self.dis_train = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate / 10,
            beta1=self.options.beta1).minimize(self.dis_loss,
                                               var_list=dis_factory.var_list,
                                               global_step=self.global_step)

        self.saver = tf.train.Saver()