def get_model(input_height, input_width, channels=3, input_count=3):
    """Wire the encoder, decoder and discriminator onto one stacked Keras input.

    The input carries ``input_count`` images per sample, stacked on axis 1
    (after the batch axis) in the order X (noisy image), Y (real image), Mask.

    Args:
        input_height: Height of every stacked image.
        input_width: Width of every stacked image.
        channels: Channel count of every stacked image.
        input_count: Number of stacked images per sample. The body unpacks
            exactly three tensors, so any other value raises ``ValueError``.

    Returns:
        None.  NOTE(review): the function builds tensors but never returns a
        model or the loss -- confirm what the caller is supposed to receive.
    """
    # X, Y, Mask
    model_input = tf.keras.Input(
        (input_count, input_height, input_width, channels))

    # Axis 0 of a Keras input is the batch axis (size unknown at build time),
    # so the symbolic tensor cannot be tuple-unpacked directly.  The three
    # images live on axis 1, hence the explicit unstack along that axis.
    input_noise, input_real, input_mask = tf.unstack(model_input, axis=1)

    encoded = mm.encoder(cat(input_noise, input_mask))
    decoded = mm.decoder(encoded)

    # Take generated pixels where (1 - mask) is set, real pixels where mask is.
    image_result = decoded * (1 - input_mask) + input_real * input_mask

    probabilities = mm.discriminator_red(cat(image_result, input_real))
    # NOTE(review): this unpacking only works if `mm.discriminator_red`
    # returns a 2-element sequence; if it returns a single tensor it needs an
    # explicit split along whichever axis `cat` joined -- confirm.
    image_loss, input_loss = probabilities

    # Hinge terms: relu(1 - image_loss) and relu(1 + input_loss).
    # NOTE(review): `hinge_loss` is currently unused.
    hinge_loss = tf.keras.activations.relu(
        cat(1 - image_loss, 1 + input_loss))
def __init__(self, input_height=256, input_width=256, batch_size=1,
             bar_model_name=None, bar_checkpoint_name=None,
             mosaic_model_name=None, mosaic_checkpoint_name=None,
             is_mosaic=False):
    """Store the configuration, build the graph, and create Keras sub-models.

    Args:
        input_height: Height of the images fed to the network.
        input_width: Width of the images fed to the network.
        batch_size: Fixed batch size used for every input tensor.
        bar_model_name: Meta-graph path restored when ``is_mosaic`` is false.
        bar_checkpoint_name: Checkpoint location paired with
            ``bar_model_name``.
        mosaic_model_name: Meta-graph path restored when ``is_mosaic`` is
            true.
        mosaic_checkpoint_name: Checkpoint location paired with
            ``mosaic_model_name``.
        is_mosaic: Selects the mosaic model/checkpoint pair instead of the
            bar pair.
    """
    self.bar_model_name = bar_model_name
    self.bar_checkpoint_name = bar_checkpoint_name
    self.mosaic_model_name = mosaic_model_name
    self.mosaic_checkpoint_name = mosaic_checkpoint_name
    self.is_mosaic = is_mosaic

    self.input_height = input_height
    self.input_width = input_width
    self.batch_size = batch_size

    # Both calls rely on the attributes assigned above.
    self.check_model_file()
    self.build_model()

    # `tf.keras.Input` takes the per-sample shape as one tuple and the batch
    # size as a separate keyword.  Passing the four dimensions positionally
    # would bind them to `shape`, `batch_size`, `name` and `dtype`.
    self.discriminator = mm.discriminator_red(
        tf.keras.Input(shape=(input_height, input_width, 3),
                       batch_size=batch_size))
    # 6 channels -- presumably image (3) + mask (3) concatenated; confirm
    # against `mm.encoder`.
    self.encoder = mm.encoder(
        tf.keras.Input(shape=(input_height, input_width, 6),
                       batch_size=batch_size))
def build_model(self):
    """Build the TF1-style inpainting graph, open a session, restore weights.

    Side effects:
        * Disables eager execution for the process.
        * Sets ``self.X``, ``self.Y`` and ``self.MASK`` (float32 placeholders
          of shape ``[batch_size, input_height, input_width, 3]``).
        * Sets ``self.image_result`` (generator output blended with ``Y``
          through the mask).
        * Sets ``self.sess`` (initialised session with restored weights).

    The mosaic or bar meta-graph/checkpoint pair is chosen by
    ``self.is_mosaic``.
    """
    # ------- variables
    tf.compat.v1.disable_eager_execution()

    image_shape = [self.batch_size, self.input_height, self.input_width, 3]
    self.X = tf.compat.v1.placeholder(tf.float32, image_shape)
    self.Y = tf.compat.v1.placeholder(tf.float32, image_shape)
    self.MASK = tf.compat.v1.placeholder(tf.float32, image_shape)
    IT = tf.compat.v1.placeholder(tf.float32)

    # ------- structure
    generator_input = tf.concat([self.X, self.MASK], 3)

    vec_en = mm.encoder(generator_input, reuse=False, name='G_en')
    vec_con = mm.contextual_block(vec_en, vec_en, self.MASK, 3, 50.0, 'CB1',
                                  stride=1)

    # The decoder output is blended with `self.Y` below, so it has to match
    # the [height, width] of the placeholders.  The second size argument is
    # therefore the width (it used to repeat the height, which only works
    # for square inputs).
    # NOTE(review): assumes `mm.decoder` takes (height, width) -- confirm.
    I_co = mm.decoder(vec_en, self.input_height, self.input_width,
                      reuse=False, name='G_de')
    I_ge = mm.decoder(vec_con, self.input_height, self.input_width,
                      reuse=True, name='G_de')

    # Generated pixels where (1 - MASK) is set, original pixels where MASK is.
    self.image_result = I_ge * (1 - self.MASK) + self.Y * self.MASK

    D_real_red = mm.discriminator_red(self.Y, reuse=False, name='disc_red')
    D_fake_red = mm.discriminator_red(self.image_result, reuse=True,
                                      name='disc_red')

    # ------- Loss
    # Hinge loss for the discriminator.
    Loss_D_red = tf.reduce_mean(
        input_tensor=tf.nn.relu(1 + D_fake_red)) + tf.reduce_mean(
            input_tensor=tf.nn.relu(1 - D_real_red))
    Loss_D = Loss_D_red

    Loss_gan_red = -tf.reduce_mean(input_tensor=D_fake_red)
    Loss_gan = Loss_gan_red

    # L1 reconstruction terms for the contextual and the plain decoder path.
    Loss_s_re = tf.reduce_mean(input_tensor=tf.abs(I_ge - self.Y))
    Loss_hat = tf.reduce_mean(input_tensor=tf.abs(I_co - self.Y))

    # SSIM on the first YUV channel; (x + 1) / 2 presumably maps images from
    # [-1, 1] to [0, 1] -- confirm against the caller's preprocessing.
    A = tf.image.rgb_to_yuv((self.image_result + 1) / 2.0)
    A_Y = tf.cast(A[:, :, :, 0:1] * 255.0, dtype=tf.int32)
    B = tf.image.rgb_to_yuv((self.Y + 1) / 2.0)
    B_Y = tf.cast(B[:, :, :, 0:1] * 255.0, dtype=tf.int32)
    # NOTE(review): `ssim` is built but not used further in this method.
    ssim = tf.reduce_mean(input_tensor=tf.image.ssim(A_Y, B_Y, 255.0))

    alpha = IT / 1000000
    Loss_G = 0.1 * Loss_gan + 10 * Loss_s_re + 5 * (1 - alpha) * Loss_hat

    # --------------------- variable & optimizer
    var_D = [
        v for v in tf.compat.v1.global_variables()
        if v.name.startswith('disc_red')
    ]
    var_G = [
        v for v in tf.compat.v1.global_variables()
        if v.name.startswith(('G_en', 'G_de', 'CB1'))
    ]

    update_ops = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimize_D = tf.compat.v1.train.AdamOptimizer(
            learning_rate=0.0004, beta1=0.5,
            beta2=0.9).minimize(Loss_D, var_list=var_D)
        optimize_G = tf.compat.v1.train.AdamOptimizer(
            learning_rate=0.0001, beta1=0.5,
            beta2=0.9).minimize(Loss_G, var_list=var_G)

    config = tf.compat.v1.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.4
    # config.gpu_options.allow_growth = False

    self.sess = tf.compat.v1.Session(config=config)

    init = tf.compat.v1.global_variables_initializer()
    self.sess.run(init)
    # NOTE(review): `saver` is never used afterwards; the construction is
    # kept so the graph contents stay exactly as before -- confirm whether it
    # can be dropped.
    saver = tf.compat.v1.train.Saver()

    # Pick the meta-graph / checkpoint pair, then restore the latest weights.
    if self.is_mosaic:
        meta_graph_path = self.mosaic_model_name
        checkpoint_location = self.mosaic_checkpoint_name
    else:
        meta_graph_path = self.bar_model_name
        checkpoint_location = self.bar_checkpoint_name

    Restore = tf.compat.v1.train.import_meta_graph(meta_graph_path)
    Restore.restore(self.sess,
                    tf.train.latest_checkpoint(checkpoint_location))
# Graph-construction fragment: generator, discriminator and their losses.
# Expects `X`, `Y` and `MASK` tensors to be defined by the surrounding code.
# `tf.placeholder` does not exist in TF2; the `tf.compat.v1` alias is the
# same op on TF1 and matches the `tf.compat.v1` usage elsewhere in this file.
IT = tf.compat.v1.placeholder(tf.float32)

# ------- structure
# NOTE(review): `input` shadows the builtin; left unchanged because code
# outside this fragment may refer to it.
input = tf.concat([X, MASK], 3)

vec_en = mm.encoder(input, reuse=False)
vec_con = mm.contextual_block(vec_en, vec_en, MASK)

# Two decoder passes: one on the raw encoding, one on the contextual block
# output (the second call passes reuse=True).
I_co = mm.decoder(vec_en, reuse=False)
I_ge = mm.decoder(vec_con, reuse=True)

# Generated pixels where (1 - MASK) is set, original pixels where MASK is.
image_result = I_ge * (1 - MASK) + Y * MASK

D_real_red = mm.discriminator_red(Y, reuse=False)
D_fake_red = mm.discriminator_red(image_result, reuse=True)

# ------- Loss
# Hinge loss for the discriminator.
Loss_D_red = tf.reduce_mean(tf.nn.relu(1 + D_fake_red)) + tf.reduce_mean(
    tf.nn.relu(1 - D_real_red))
Loss_D = Loss_D_red

Loss_gan_red = -tf.reduce_mean(D_fake_red)
Loss_gan = Loss_gan_red

# L1 reconstruction terms for both decoder outputs.
Loss_s_re = tf.reduce_mean(tf.abs(I_ge - Y))
Loss_hat = tf.reduce_mean(tf.abs(I_co - Y))
# Graph-construction fragment (named-scope variant): generator, discriminator
# and their losses.  Expects `X`, `Y`, `MASK` and `Height` to be defined by
# the surrounding code.
# `tf.placeholder` does not exist in TF2; the `tf.compat.v1` alias is the
# same op on TF1 and matches the `tf.compat.v1` usage elsewhere in this file.
IT = tf.compat.v1.placeholder(tf.float32)

# ------- structure
# NOTE(review): `input` shadows the builtin; left unchanged because code
# outside this fragment may refer to it.
input = tf.concat([X, MASK], 3)

vec_en = mm.encoder(input, reuse=False, name='G_en')
vec_con = mm.contextual_block(vec_en, vec_en, MASK, 3, 50.0, 'CB1', stride=1)

# Two decoder passes under the same 'G_de' name: one on the raw encoding,
# one on the contextual block output (the second call passes reuse=True).
I_co = mm.decoder(vec_en, Height, reuse=False, name='G_de')
I_ge = mm.decoder(vec_con, Height, reuse=True, name='G_de')

# Generated pixels where (1 - MASK) is set, original pixels where MASK is.
image_result = I_ge * (1 - MASK) + Y * MASK

D_real_red = mm.discriminator_red(Y, reuse=False, name='disc_red')
D_fake_red = mm.discriminator_red(image_result, reuse=True, name='disc_red')

# ------- Loss
# Hinge loss for the discriminator.
Loss_D_red = tf.reduce_mean(tf.nn.relu(1 + D_fake_red)) + tf.reduce_mean(
    tf.nn.relu(1 - D_real_red))
Loss_D = Loss_D_red

Loss_gan_red = -tf.reduce_mean(D_fake_red)
Loss_gan = Loss_gan_red

# L1 reconstruction terms for both decoder outputs.
Loss_s_re = tf.reduce_mean(tf.abs(I_ge - Y))
Loss_hat = tf.reduce_mean(tf.abs(I_co - Y))