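# ------------------------------------------------------------------------------
# Import sketch: the original section does not show module-level imports, so the
# block below lists what the code in this file appears to need. `ops` and
# `utils` are assumed companion modules that provide the layer and data helpers
# used below (fully_conneted, resblock_up_condition, resblock_down,
# self_attention_2, ImageData, save_images, discriminator_loss, generator_loss,
# check_folder, ...); the module names are a guess. `easydict` is only required
# by the TPU model_fn at the end of the file.
# ------------------------------------------------------------------------------
import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
from tensorflow.contrib.opt import MovingAverageOptimizer
from easydict import EasyDict

from ops import *    # assumed: layer helpers (conv, resblocks, attention, ...)
from utils import *  # assumed: data helpers (ImageData, save_images, losses, ...)

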
class BigGAN_256(object):
    def __init__(self, sess, args):
        self.model_name = "BigGAN"  # name for checkpoint
        self.sess = sess
        self.dataset_name = args.dataset
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir

        self.epoch = args.epoch
        self.iteration = args.iteration
        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq
        self.img_size = args.img_size

        """ Generator """
        self.ch = args.ch
        self.z_dim = args.z_dim  # dimension of noise-vector
        self.gan_type = args.gan_type

        """ Discriminator """
        self.n_critic = args.n_critic
        self.sn = args.sn
        self.ld = args.ld

        self.sample_num = args.sample_num  # number of generated images to be saved
        self.test_num = args.test_num

        # train
        self.g_learning_rate = args.g_lr
        self.d_learning_rate = args.d_lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.moving_decay = args.moving_decay

        self.custom_dataset = False

        if self.dataset_name == 'mnist':
            self.c_dim = 1
            self.data = load_mnist()
        elif self.dataset_name == 'cifar10':
            self.c_dim = 3
            self.data = load_cifar10()
        else:
            self.c_dim = 3
            self.data = load_data(dataset_name=self.dataset_name)
            self.custom_dataset = True

        self.dataset_num = len(self.data)

        self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        print()
        print("##### Information #####")
        print("# BigGAN 256")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)

        print()
        print("##### Generator #####")
        print("# spectral normalization : ", self.sn)
        print("# learning rate : ", self.g_learning_rate)

        print()
        print("##### Discriminator #####")
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)
        print("# learning rate : ", self.d_learning_rate)

    ##################################################################################
    # Generator
    ##################################################################################

    def generator(self, z, is_training=True, reuse=False):
        with tf.variable_scope("generator", reuse=reuse):
            # split z into 7 chunks: one for the initial dense layer, the rest
            # conditioning each up-sampling residual block
            if self.z_dim == 128:
                split_dim = 18
                split_dim_remainder = self.z_dim - (split_dim * 6)

                z_split = tf.split(z, num_or_size_splits=[split_dim] * 6 + [split_dim_remainder], axis=-1)
            else:
                split_dim = self.z_dim // 7
                split_dim_remainder = self.z_dim - (split_dim * 7)

                if split_dim_remainder == 0:
                    z_split = tf.split(z, num_or_size_splits=[split_dim] * 7, axis=-1)
                else:
                    z_split = tf.split(z, num_or_size_splits=[split_dim] * 6 + [split_dim_remainder], axis=-1)

            ch = 16 * self.ch
            x = fully_conneted(z_split[0], units=4 * 4 * ch, sn=self.sn, scope='dense')
            x = tf.reshape(x, shape=[-1, 4, 4, ch])

            x = resblock_up_condition(x, z_split[1], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_16')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[2], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8_0')
            x = resblock_up_condition(x, z_split[3], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_8_1')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[4], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_4')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[5], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_2')

            # Non-Local Block
            x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention')
            ch = ch // 2

            x = resblock_up_condition(x, z_split[6], channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_up_1')

            x = batch_norm(x, is_training)
            x = relu(x)
            x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='G_logit')
            x = tanh(x)

            return x

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x, is_training=True, reuse=False):
        with tf.variable_scope("discriminator", reuse=reuse):
            ch = self.ch

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_1')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_2')

            # Non-Local Block
            x = self_attention_2(x, channels=ch, sn=self.sn, scope='self_attention')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_4')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8_0')
            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_8_1')
            ch = ch * 2

            x = resblock_down(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock_down_16')

            x = resblock(x, channels=ch, use_bias=False, is_training=is_training, sn=self.sn, scope='resblock')
            x = relu(x)

            x = global_sum_pooling(x)

            x = fully_conneted(x, units=1, sn=self.sn, scope='D_logit')

            return x

    def gradient_penalty(self, real, fake):
        if self.gan_type.__contains__('dragan'):
            eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
            _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region

            fake = real + 0.5 * x_std * eps

        alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
        interpolated = real + alpha * (fake - real)

        logit = self.discriminator(interpolated, reuse=True)

        grad = tf.gradients(logit, interpolated)[0]  # gradient of D(interpolated)
        grad_norm = tf.norm(flatten(grad), axis=1)  # l2 norm

        GP = 0

        # WGAN - LP
        if self.gan_type == 'wgan-lp':
            GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))

        elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
            GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))

        return GP

    ##################################################################################
    # Model
    ##################################################################################

    def build_model(self):
        """ Graph Input """
        # images
        Image_Data_Class = ImageData(self.img_size, self.c_dim, self.custom_dataset)
        inputs = tf.data.Dataset.from_tensor_slices(self.data)

        gpu_device = '/gpu:0'
        inputs = inputs. \
            apply(shuffle_and_repeat(self.dataset_num)). \
            apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)). \
            apply(prefetch_to_device(gpu_device, self.batch_size))

        inputs_iterator = inputs.make_one_shot_iterator()

        self.inputs = inputs_iterator.get_next()

        # noises
        self.z = tf.truncated_normal(shape=[self.batch_size, 1, 1, self.z_dim], name='random_z')

        """ Loss Function """
        # output of D for real images
        real_logits = self.discriminator(self.inputs)

        # output of D for fake images
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
        else:
            GP = 0

        # get loss for discriminator
        self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP

        # get loss for generator
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

        """ Training """
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # optimizers
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)

            self.opt = MovingAverageOptimizer(tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2),
                                              average_decay=self.moving_decay)
            self.g_optim = self.opt.minimize(self.g_loss, var_list=g_vars)

        """ Testing """
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        """ Summary """
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)

    ##################################################################################
    # Train
    ##################################################################################

    def train(self):
        # initialize all variables
        tf.global_variables_initializer().run()

        # saver to save model
        self.saver = self.opt.swapping_saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = int(checkpoint_counter / self.iteration)
            start_batch_id = checkpoint_counter - start_epoch * self.iteration
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        past_g_loss = -1.
        for epoch in range(start_epoch, self.epoch):
            # get batch data
            for idx in range(start_batch_id, self.iteration):
                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss])
                self.writer.add_summary(summary_str, counter)

                # update G network
                g_loss = None
                if (counter - 1) % self.n_critic == 0:
                    _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss])
                    self.writer.add_summary(summary_str, counter)
                    past_g_loss = g_loss

                # display training status
                counter += 1
                if g_loss is None:
                    g_loss = past_g_loss
                print("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                # save training results every print_freq steps
                if np.mod(idx + 1, self.print_freq) == 0:
                    samples = self.sess.run(self.fake_images)
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :],
                                [manifold_h, manifold_w],
                                './' + self.sample_dir + '/' + self.model_name + '_train_{:02d}_{:05d}.png'.format(epoch, idx + 1))

                if np.mod(idx + 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            # self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        if self.sn:
            sn = '_sn'
        else:
            sn = ''

        return "{}_{}_{}_{}_{}{}".format(self.model_name, self.dataset_name, self.gan_type, self.img_size, self.z_dim, sn)

    def save(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random condition, random noise """
        samples = self.sess.run(self.fake_images)

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                    [image_frame_dim, image_frame_dim],
                    self.sample_dir + '/' + self.model_name + '_epoch%02d' % epoch + '_visualize.png')

    def test(self):
        tf.global_variables_initializer().run()

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random condition, random noise """
        for i in range(self.test_num):
            samples = self.sess.run(self.fake_images)

            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                        [image_frame_dim, image_frame_dim],
                        result_dir + '/' + self.model_name + '_test_{}.png'.format(i))
    ##################################################################################
    # TPU
    ##################################################################################

    def tpu_model_fn(self, features, labels, mode, params):
        # NOTE: `base_model_fn` and `tpu_metric_fn` are assumed to be defined
        # elsewhere in this class; they are not shown in this section.
        params = EasyDict(**params)
        d_loss, d_vars, g_loss, g_vars, fake_images, fake_logits, z = self.base_model_fn(features, labels, mode, params)

        # --------------------------------------------------------------------------
        # Predict
        # --------------------------------------------------------------------------
        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {
                "z": z,
                "fake_image": fake_images,
                "fake_logits": fake_logits,
            }
            return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions)

        # --------------------------------------------------------------------------
        # Train or Eval
        # --------------------------------------------------------------------------
        loss = g_loss
        for i in range(params.n_critic):
            loss += d_loss

        if mode == tf.estimator.ModeKeys.EVAL:
            # Hack to allow it out of a fixed batch size TPU
            d_loss_batched = tf.tile(tf.expand_dims(d_loss, 0), [params.batch_size])
            g_loss_batched = tf.tile(tf.expand_dims(g_loss, 0), [params.batch_size])
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(lambda d_loss, g_loss, fake_logits: self.tpu_metric_fn(d_loss, g_loss, fake_logits),
                              [d_loss_batched, g_loss_batched, fake_logits]))

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Create training ops for both D and G
            d_optimizer = tf.train.AdamOptimizer(params.d_lr, beta1=params.beta1, beta2=params.beta2)
            if params.use_tpu:
                d_optimizer = tf.contrib.tpu.CrossShardOptimizer(d_optimizer)
            d_train_op = d_optimizer.minimize(d_loss, var_list=d_vars, global_step=tf.train.get_global_step())

            g_optimizer = MovingAverageOptimizer(
                tf.train.AdamOptimizer(params.g_lr, beta1=params.beta1, beta2=params.beta2),
                average_decay=params.moving_decay)
            if params.use_tpu:
                g_optimizer = tf.contrib.tpu.CrossShardOptimizer(g_optimizer)
            g_train_op = g_optimizer.minimize(g_loss, var_list=g_vars, global_step=tf.train.get_global_step())

            # For each training op of G, do n_critic training ops of D
            train_ops = [g_train_op]
            for i in range(params.n_critic):
                train_ops.append(d_train_op)
            train_op = tf.group(*train_ops)

            return tf.contrib.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)
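

# ------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original class). `args` is assumed
# to be an argparse-style namespace carrying the fields read in __init__ above
# (dataset, checkpoint_dir, epoch, batch_size, z_dim, gan_type, n_critic, sn,
# ld, g_lr, d_lr, beta1, beta2, moving_decay, ...); in the reference layout
# these come from a separate main.py argument parser. Function names below are
# hypothetical.
# ------------------------------------------------------------------------------
def example_run(args):
    """Minimal graph-mode driver: build the graph, train, then sample."""
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = BigGAN_256(sess, args)
        gan.build_model()  # input pipeline, losses, optimizers, summaries
        gan.train()        # alternating D/G updates with checkpointing
        gan.test()         # writes `test_num` grids of generated samples


# ------------------------------------------------------------------------------
# TPU sketch (illustrative): `tpu_model_fn` follows the tf.contrib.tpu
# TPUEstimator model_fn contract, so a hypothetical driver would hand the bound
# method to the estimator along with a params dict holding the hyper-parameters
# it reads (n_critic, d_lr, g_lr, beta1, beta2, moving_decay, use_tpu).
# `run_config` (a tf.contrib.tpu.RunConfig), `input_fn` and `train_steps` are
# assumptions here, not defined in this file.
# ------------------------------------------------------------------------------
def example_tpu_train(gan, run_config, params, batch_size, input_fn, train_steps):
    # TPUEstimator injects the per-shard `batch_size` into params itself, which
    # is how tpu_model_fn can read params.batch_size without it being passed in.
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=gan.tpu_model_fn,
        config=run_config,
        use_tpu=params['use_tpu'],
        train_batch_size=batch_size,
        eval_batch_size=batch_size,
        params=params)
    estimator.train(input_fn=input_fn, max_steps=train_steps)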