def build(self):
    """Build the TF1 graph for the frame-conditioned colorization GAN.

    Idempotent: returns immediately if the graph was already built.
    Creates placeholders for the current RGB frame and the previous
    colorized frame, wires generator/discriminator, losses, learning-rate
    schedule, optimizers, accuracy metric, sampler, and saver.

    Side effects: sets many attributes on ``self`` (input_* tensors,
    *_loss tensors, gen_train/dis_train ops, sampler, accuracy, saver).
    """
    if self.is_built:
        return

    self.is_built = True

    gen_factory = self.create_generator()
    dis_factory = self.create_discriminator()

    # One-sided label smoothing: real labels become 0.9 instead of 1.0.
    smoothing = 0.9 if self.options.label_smoothing else 1
    seed = self.options.seed  # fixed: was a duplicated `seed = seed = ...`
    kernel = self.options.kernel_size

    # Current frame (RGB) and the previous frame (RGB) as conditioning input.
    self.input_rgb = tf.placeholder(
        tf.float32, shape=(None, None, None, 3), name='input_rgb')
    self.input_rgb_prev = tf.placeholder(
        tf.float32, shape=(None, None, None, 3), name='input_rgb_prev')

    # Grayscale of the current frame drives the generator; both frames are
    # converted to the working color space for losses and discrimination.
    self.input_gray = tf.image.rgb_to_grayscale(self.input_rgb)
    self.input_color = preprocess(
        self.input_rgb,
        colorspace_in=COLORSPACE_RGB,
        colorspace_out=self.options.color_space)
    self.input_color_prev = preprocess(
        self.input_rgb_prev,
        colorspace_in=COLORSPACE_RGB,
        colorspace_out=self.options.color_space)

    # Generator is conditioned on [gray current frame, previous color frame].
    gen = gen_factory.create(
        tf.concat([self.input_gray, self.input_color_prev], 3), kernel, seed)

    # Discriminator sees [candidate colorization, previous color frame];
    # fake branch reuses the real branch's variables.
    dis_real = dis_factory.create(
        tf.concat([self.input_color, self.input_color_prev], 3), kernel, seed)
    dis_fake = dis_factory.create(
        tf.concat([gen, self.input_color_prev], 3),
        kernel, seed, reuse_variables=True)

    # Standard non-saturating GAN cross-entropies on the logits.
    gen_ce = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=dis_fake, labels=tf.ones_like(dis_fake))
    dis_real_ce = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=dis_real, labels=tf.ones_like(dis_real) * smoothing)
    dis_fake_ce = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=dis_fake, labels=tf.zeros_like(dis_fake))

    self.dis_loss_real = tf.reduce_mean(dis_real_ce)
    self.dis_loss_fake = tf.reduce_mean(dis_fake_ce)
    self.dis_loss = tf.reduce_mean(dis_real_ce + dis_fake_ce)

    self.gen_loss_gan = tf.reduce_mean(gen_ce)
    self.gen_loss_l1 = tf.reduce_mean(
        tf.abs(self.input_color - gen)) * self.options.l1_weight

    # NOTE(review): the adversarial term is intentionally disabled here —
    # the generator trains on the weighted L1 loss only. Re-enable with
    # `self.gen_loss_gan + self.gen_loss_l1` if adversarial training is wanted.
    self.gen_loss = self.gen_loss_l1

    # Inference path: same generator graph with shared weights.
    self.sampler = gen_factory.create(
        tf.concat([self.input_gray, self.input_color_prev], 3),
        kernel, seed, reuse_variables=True)

    self.accuracy = pixelwise_accuracy(
        self.input_color, gen, self.options.color_space, self.options.acc_thresh)

    self.learning_rate = tf.constant(self.options.lr)

    # Optional exponential learning-rate decay, floored at 1e-8.
    if self.options.lr_decay_rate > 0:
        self.learning_rate = tf.maximum(1e-8, tf.train.exponential_decay(
            learning_rate=self.options.lr,
            global_step=self.global_step,
            decay_steps=self.options.lr_decay_steps,
            decay_rate=self.options.lr_decay_rate))

    # Generator optimizer (only generator variables).
    self.gen_train = tf.train.AdamOptimizer(
        learning_rate=self.learning_rate,
        beta1=self.options.beta1,
    ).minimize(self.gen_loss, var_list=gen_factory.var_list)

    # Discriminator optimizer (only discriminator variables); this one
    # advances the shared global step.
    self.dis_train = tf.train.AdamOptimizer(
        learning_rate=self.learning_rate,
        beta1=self.options.beta1,
    ).minimize(
        self.dis_loss,
        var_list=dis_factory.var_list,
        global_step=self.global_step)

    self.saver = tf.train.Saver()
def build(self):
    """Build the TF1 graph for the single-image colorization GAN.

    Idempotent: returns immediately if the graph was already built.
    Creates input placeholders, wires generator/discriminator, losses,
    learning-rate schedule, optimizers, accuracy metric, sampler
    (exported under the name 'output'), and saver.

    Side effects: sets many attributes on ``self`` (input_* tensors,
    *_loss tensors, gen_train/dis_train ops, sampler, accuracy, saver).
    """
    if self.is_built:
        return

    self.is_built = True

    gen_factory = self.create_generator()
    dis_factory = self.create_discriminator()

    # One-sided label smoothing: real labels become 0.9 instead of 1.0.
    smoothing = 0.9 if self.options.label_smoothing else 1
    seed = self.options.seed
    # NOTE(review): kernel size is hard-coded here instead of reading
    # self.options.kernel_size — confirm this is intentional.
    kernel = 4

    # Model input placeholder: RGB image.
    self.input_rgb = tf.placeholder(
        tf.float32, shape=(None, None, None, 3), name='input_rgb')

    # Model input after preprocessing: image in the working color space.
    self.input_color = preprocess(
        self.input_rgb,
        colorspace_in=COLORSPACE_RGB,
        colorspace_out=self.options.color_space)

    # Grayscale input placeholder. The original code branched on
    # `self.options.mode == 1` (test vs. train/turing-test) but both
    # branches created the identical placeholder, so the dead branch
    # was removed.
    self.input_gray = tf.placeholder(
        tf.float32, shape=(None, None, None, 1), name='input_gray')

    gen = gen_factory.create(self.input_gray, kernel, seed)

    # Discriminator sees [grayscale input, colorization]; the fake branch
    # reuses the real branch's variables.
    dis_real = dis_factory.create(
        tf.concat([self.input_gray, self.input_color], 3), kernel, seed)
    dis_fake = dis_factory.create(
        tf.concat([self.input_gray, gen], 3),
        kernel, seed, reuse_variables=True)

    # Standard non-saturating GAN cross-entropies on the logits.
    gen_ce = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=dis_fake, labels=tf.ones_like(dis_fake))
    dis_real_ce = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=dis_real, labels=tf.ones_like(dis_real) * smoothing)
    dis_fake_ce = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=dis_fake, labels=tf.zeros_like(dis_fake))

    self.dis_loss_real = tf.reduce_mean(dis_real_ce)
    self.dis_loss_fake = tf.reduce_mean(dis_fake_ce)
    self.dis_loss = tf.reduce_mean(dis_real_ce + dis_fake_ce)

    # Generator loss: adversarial term plus weighted L1 reconstruction.
    self.gen_loss_gan = tf.reduce_mean(gen_ce)
    self.gen_loss_l1 = tf.reduce_mean(
        tf.abs(self.input_color - gen)) * self.options.l1_weight
    self.gen_loss = self.gen_loss_gan + self.gen_loss_l1

    # Inference path: same generator weights, exported as 'output' so the
    # tensor can be located by name in a frozen/served graph.
    self.sampler = tf.identity(
        gen_factory.create(self.input_gray, kernel, seed, reuse_variables=True),
        name='output')

    self.accuracy = pixelwise_accuracy(
        self.input_color, gen, self.options.color_space, self.options.acc_thresh)

    self.learning_rate = tf.constant(self.options.lr)

    # Optional exponential learning-rate decay, floored at 1e-6.
    if self.options.lr_decay and self.options.lr_decay_rate > 0:
        self.learning_rate = tf.maximum(1e-6, tf.train.exponential_decay(
            learning_rate=self.options.lr,
            global_step=self.global_step,
            decay_steps=self.options.lr_decay_steps,
            decay_rate=self.options.lr_decay_rate))

    # Generator optimizer (only generator variables).
    self.gen_train = tf.train.AdamOptimizer(
        learning_rate=self.learning_rate,
        beta1=self.options.beta1,
    ).minimize(self.gen_loss, var_list=gen_factory.var_list)

    # Discriminator optimizer (only discriminator variables). It trains at
    # one tenth of the generator's learning rate — presumably to keep the
    # discriminator from overpowering the generator — and advances the
    # shared global step.
    self.dis_train = tf.train.AdamOptimizer(
        learning_rate=self.learning_rate / 10,
        beta1=self.options.beta1,
    ).minimize(
        self.dis_loss,
        var_list=dis_factory.var_list,
        global_step=self.global_step)

    self.saver = tf.train.Saver()