def __init__(self, generator: Generator, discriminator: Discriminator, optimizer: Optimizer, random_dimension, model_path, loss_dir, images_directory):
    """Wire up a GAN trainer: compile both networks and record output paths.

    Args:
        generator: Un-built generator; compiled here via ``generator.build(optimizer)``.
        discriminator: Un-built discriminator; compiled the same way.
        optimizer: Optimizer instance shared by both networks.
        random_dimension: Size of the latent noise vector fed to the generator.
        model_path: Destination for saved model weights.
        loss_dir: Directory where loss records are written.
        images_directory: Directory where generated sample images are written.
    """
    # Compile both halves of the GAN up front with the shared optimizer
    # (generator first, matching the original construction order).
    self.generator = generator.build(optimizer)
    self.discriminator = discriminator.build(optimizer)
    self.optimizer = optimizer
    # Loss histories, appended to as training progresses.
    self.discriminator_losses = []
    self.generator_losses = []
    self.random_dim = random_dimension
    # Output locations.
    self.model_path = model_path
    self.loss_dir = loss_dir
    self.img_dir = images_directory
class WGANGP(Model):
    """WGAN with gradient penalty (Gulrajani et al., 2017), TF1 graph-mode.

    The critic ("real" network) is trained to separate real images from
    generated ones under a gradient penalty; the generator ("fake" network)
    is trained to maximize the critic's score on generated images.
    """

    def __init__(self, scope_name, channel_min, img_size, generator_size=100, channel_rate=2):
        """Create the input placeholders and record the model hyperparameters.

        Args:
            scope_name: Variable scope under which all graph nodes are built.
            channel_min: Base channel count passed to both sub-networks.
            img_size: Image dimensions as a list, e.g. ``[H, W, C]``; the
                image placeholder gets shape ``[None] + img_size``.
            generator_size: Length of the latent noise vector ``z``.
            channel_rate: Channel growth rate (stored; used by sub-networks).
        """
        super(WGANGP, self).__init__(scope_name)
        self.scope_name = scope_name
        self.channel_min = channel_min
        self.channel_rate = channel_rate
        self.img_size = img_size
        self.input_img = tf.placeholder(tf.float32, shape=[None] + img_size, name="input_img")
        self.input_z = tf.placeholder(tf.float32, shape=[None, generator_size], name="input_z")

    # NOTE(review): "buind" is a typo for "build", kept because it is the
    # public interface (WGAN uses the same spelling) — renaming would break callers.
    def buind(self, train_fn=tf_tools.adam_fn, real_lr=1e-5, fake_lr=1e-5):
        """Build both networks, their optimizers, and the train ops."""
        with tf.variable_scope(self.scope_name):
            self.fake_img, self.real_output, self.fake_output = self.buind_network()
            self.var_list = tf.trainable_variables(scope=self.scope_name)
            # Adam with beta1=0, beta2=0.9 — the settings recommended by the
            # WGAN-GP paper.
            self.real_train_fn = train_fn(real_lr, 0, 0.9)
            self.fake_train_fn = train_fn(fake_lr, 0, 0.9)
            self.build_optimization()

    def buind_network(self, fake_normal=True):
        """Instantiate critic and generator and wire both critic passes.

        Args:
            fake_normal: When True, squash generator output through a sigmoid
                so fake images live in [0, 1] like (presumably) the real ones.

        Returns:
            Tuple ``(fake_img, real_output, fake_output)`` — the generated
            image tensor and the critic scores for real and fake batches.
        """
        self.real_network = Discriminator(self.channel_min, 1, name="discriminator")
        self.fake_network = Generator(self.channel_min, self.img_size, name="generator")
        fake_img = self.fake_network.build(self.input_z, times=3)
        if fake_normal:
            fake_img = tf.nn.sigmoid(fake_img)
        real_output = self.real_network.build(self.input_img, times=3, normal=False)
        # Second critic pass reuses the variables created by the first.
        fake_output = self.real_network.build(fake_img, times=3, reuse=tf.AUTO_REUSE, normal=False)
        return fake_img, real_output, fake_output

    def build_optimization(self):
        """Create the WGAN-GP losses and one train op per network.

        Fixes over the naive formulation:
        * ``epsilon`` is sampled per example (shape ``[batch, 1, ..., 1]``)
          so every interpolate gets its own mixing coefficient, as the
          WGAN-GP paper prescribes — a single scalar would move the whole
          batch along the same interpolation fraction.
        * The gradient norm is the L2 norm over ALL non-batch axes, not just
          axis 1 (the inputs are rank ``1 + len(img_size)``).
        """
        # One mixing coefficient per sample, broadcast over the image axes.
        batch_size = tf.shape(self.input_img)[0]
        epsilon = tf.random_uniform([batch_size] + [1] * len(self.img_size), 0.0, 1.0)
        x_hat = self.input_img * epsilon + (1 - epsilon) * self.fake_img
        d_hat = self.real_network.build(x_hat, times=3, reuse=tf.AUTO_REUSE, normal=False)
        gradients = tf.gradients(d_hat, x_hat)[0]
        # Per-sample gradient L2 norm: reduce over every non-batch axis.
        reduce_axes = list(range(1, len(self.img_size) + 1))
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=reduce_axes))
        # Penalty coefficient lambda = 10, as in the paper.
        gradient_penalty = 10 * tf.reduce_mean((slopes - 1.0) ** 2)
        # Critic loss (maximize real score, minimize fake score) + penalty.
        self.real_loss_op = (tf.reduce_mean(self.fake_output)
                             - tf.reduce_mean(self.real_output)
                             + gradient_penalty)
        # Generator loss: maximize the critic's score on fakes.
        self.fake_loss_op = -tf.reduce_mean(self.fake_output)
        self.real_index = tf.Variable(0)
        self.real_train_op = self.real_train_fn.minimize(
            self.real_loss_op,
            var_list=self.real_network.var_list,
            global_step=self.real_index)
        self.fake_index = tf.Variable(0)
        self.fake_train_op = self.fake_train_fn.minimize(
            self.fake_loss_op,
            var_list=self.fake_network.var_list,
            global_step=self.fake_index)

    def predict(self, z):
        """Generate images from latent vectors ``z`` using the default session."""
        session = tf.get_default_session()
        feed_dict = {self.input_z: z}
        return session.run(self.fake_img, feed_dict=feed_dict)

    def train(self, img, z, mode='D'):
        """Run one optimizer step.

        Args:
            img: Batch of real images.
            z: Batch of latent vectors.
            mode: ``'D'`` trains the critic; anything else trains the generator.

        Returns:
            ``[critic_loss, generator_loss]`` evaluated after the step.
        """
        session = tf.get_default_session()
        feed_dict = {self.input_img: img, self.input_z: z}
        if mode == 'D':
            session.run(self.real_train_op, feed_dict=feed_dict)
        else:
            session.run(self.fake_train_op, feed_dict=feed_dict)
        return self.loss(img, z)

    def loss(self, img, z):
        """Evaluate ``[critic_loss, generator_loss]`` without training."""
        session = tf.get_default_session()
        feed_dict = {self.input_img: img, self.input_z: z}
        return session.run([self.real_loss_op, self.fake_loss_op], feed_dict=feed_dict)
class WGAN(Model):
    """Weight-clipped Wasserstein GAN (Arjovsky et al., 2017), TF1 graph-mode.

    Enforces the critic's Lipschitz constraint by clipping its weights into
    ``[-0.01, 0.01]`` after each critic update, per WGAN Algorithm 1.
    """

    def __init__(self, scope_name, channel_min, img_size, generator_size=100, channel_rate=2):
        """Create the input placeholders and record the model hyperparameters.

        Args:
            scope_name: Variable scope under which all graph nodes are built.
            channel_min: Base channel count passed to both sub-networks.
            img_size: Image dimensions as a list, e.g. ``[H, W, C]``; the
                image placeholder gets shape ``[None] + img_size``.
            generator_size: Length of the latent noise vector ``z``.
            channel_rate: Channel growth rate (stored; used by sub-networks).
        """
        super(WGAN, self).__init__(scope_name)
        self.scope_name = scope_name
        self.channel_min = channel_min
        self.channel_rate = channel_rate
        self.img_size = img_size
        self.input_img = tf.placeholder(tf.float32, shape=[None] + img_size, name="input_img")
        self.input_z = tf.placeholder(tf.float32, shape=[None, generator_size], name="input_z")

    # NOTE(review): "buind" is a typo for "build", kept because it is the
    # public interface (WGANGP uses the same spelling) — renaming would break callers.
    def buind(self, train_fn=tf_tools.adam_fn, real_lr=1e-5, fake_lr=1e-5):
        """Build both networks, their optimizers, and the train/clip ops."""
        with tf.variable_scope(self.scope_name):
            self.fake_img, self.real_output, self.fake_output = self.buind_network()
            self.var_list = tf.trainable_variables(scope=self.scope_name)
            self.real_train_fn = train_fn(real_lr, 0, 0.9)
            self.fake_train_fn = train_fn(fake_lr, 0, 0.9)
            self.build_optimization()

    def buind_network(self, fake_normal=True):
        """Instantiate critic and generator and wire both critic passes.

        Args:
            fake_normal: When True, squash generator output through a sigmoid
                so fake images live in [0, 1] like (presumably) the real ones.

        Returns:
            Tuple ``(fake_img, real_output, fake_output)`` — the generated
            image tensor and the critic scores for real and fake batches.
        """
        self.real_network = Discriminator(self.channel_min, 1, name="discriminator")
        self.fake_network = Generator(self.channel_min, self.img_size, name="generator")
        fake_img = self.fake_network.build(self.input_z, times=3)
        if fake_normal:
            fake_img = tf.nn.sigmoid(fake_img)
        real_output = self.real_network.build(self.input_img, times=3, normal=False)
        # Second critic pass reuses the variables created by the first.
        fake_output = self.real_network.build(fake_img, times=3, reuse=tf.AUTO_REUSE, normal=False)
        return fake_img, real_output, fake_output

    def build_optimization(self):
        """Create WGAN losses, train ops, and the critic weight-clip op."""
        # Critic loss: push fake scores down relative to real scores.
        self.real_loss_op = tf.reduce_mean(self.fake_output) - tf.reduce_mean(self.real_output)
        # Generator loss: maximize the critic's score on fakes.
        self.fake_loss_op = -tf.reduce_mean(self.fake_output)
        self.real_index = tf.Variable(0)
        self.real_train_op = self.real_train_fn.minimize(
            self.real_loss_op,
            var_list=self.real_network.var_list,
            global_step=self.real_index)
        # Lipschitz constraint: clamp every critic weight into [-0.01, 0.01].
        self.clip_op = [param.assign(tf.clip_by_value(param, -0.01, 0.01))
                        for param in self.real_network.var_list]
        self.fake_index = tf.Variable(0)
        self.fake_train_op = self.fake_train_fn.minimize(
            self.fake_loss_op,
            var_list=self.fake_network.var_list,
            global_step=self.fake_index)

    def predict(self, z):
        """Generate images from latent vectors ``z`` using the default session."""
        session = tf.get_default_session()
        feed_dict = {self.input_z: z}
        return session.run(self.fake_img, feed_dict=feed_dict)

    def train(self, img, z, mode='D'):
        """Run one optimizer step.

        Args:
            img: Batch of real images.
            z: Batch of latent vectors.
            mode: ``'D'`` trains the critic; anything else trains the generator.

        Returns:
            ``[critic_loss, generator_loss]`` evaluated after the step.
        """
        session = tf.get_default_session()
        feed_dict = {self.input_img: img, self.input_z: z}
        if mode == 'D':
            session.run(self.real_train_op, feed_dict=feed_dict)
            # FIX: clip AFTER the critic update (WGAN Algorithm 1), so the
            # critic's weights are always inside [-c, c] when the generator
            # trains.  The previous order (clip, then update) left the last
            # update unclipped.  The assign ops need no placeholder feeds.
            session.run(self.clip_op)
        else:
            session.run(self.fake_train_op, feed_dict=feed_dict)
        return self.loss(img, z)

    def loss(self, img, z):
        """Evaluate ``[critic_loss, generator_loss]`` without training."""
        session = tf.get_default_session()
        feed_dict = {self.input_img: img, self.input_z: z}
        return session.run([self.real_loss_op, self.fake_loss_op], feed_dict=feed_dict)