def train(train_datasetA, train_datasetB, epochs, lsgan=True, cyc_lambda=10):
    """Train the CycleGAN (genA2B/genB2A + discA/discB) for `epochs` steps.

    Args:
        train_datasetA: iterator yielding minibatches from domain A.
        train_datasetB: iterator yielding minibatches from domain B.
        epochs: number of training iterations to run.
        lsgan: if True, use least-squares GAN losses.
        cyc_lambda: weight of the cycle-consistency term.

    Relies on module-level models (genA2B, genB2A, discA, discB), their
    optimizers, and the loss helpers — TODO confirm these are defined at
    module scope.
    """
    for epoch in range(epochs):
        start = time.time()
        with tf.GradientTape() as genA2B_tape, tf.GradientTape() as genB2A_tape, \
                tf.GradientTape() as discA_tape, tf.GradientTape() as discB_tape:
            try:
                # Next training minibatches, default size 1
                trainA = next(train_datasetA)
                trainB = next(train_datasetB)
            except tf.errors.OutOfRangeError:
                print("Error, run out of data")
                break

            # Forward passes: translations, discriminator scores, and cycles.
            genA2B_output = genA2B(trainA, training=True)
            genB2A_output = genB2A(trainB, training=True)
            discA_real_output = discA(trainA, training=True)
            discB_real_output = discB(trainB, training=True)
            discA_fake_output = discA(genB2A_output, training=True)
            discB_fake_output = discB(genA2B_output, training=True)
            reconstructedA = genB2A(genA2B_output, training=True)
            reconstructedB = genA2B(genB2A_output, training=True)

            # Use history buffer of 50 for disc loss
            discA_loss = discriminator_loss(discA_real_output, discA_fake_output, lsgan=lsgan)
            discB_loss = discriminator_loss(discB_real_output, discB_fake_output, lsgan=lsgan)

            # Fix: the cycle-consistency term is identical for both generator
            # losses — compute it once instead of twice.
            cyc_loss = cycle_consistency_loss(trainA, trainB, reconstructedA,
                                              reconstructedB, cyc_lambda=cyc_lambda)
            genA2B_loss = generator_loss(discB_fake_output, lsgan=lsgan) + cyc_loss
            genB2A_loss = generator_loss(discA_fake_output, lsgan=lsgan) + cyc_loss

        # Gradients are taken outside the tape context (tapes are non-persistent).
        genA2B_gradients = genA2B_tape.gradient(genA2B_loss, genA2B.trainable_variables)
        genB2A_gradients = genB2A_tape.gradient(genB2A_loss, genB2A.trainable_variables)
        discA_gradients = discA_tape.gradient(discA_loss, discA.trainable_variables)
        discB_gradients = discB_tape.gradient(discB_loss, discB.trainable_variables)

        genA2B_optimizer.apply_gradients(zip(genA2B_gradients, genA2B.trainable_variables))
        genB2A_optimizer.apply_gradients(zip(genB2A_gradients, genB2A.trainable_variables))
        discA_optimizer.apply_gradients(zip(discA_gradients, discA.trainable_variables))
        discB_optimizer.apply_gradients(zip(discB_gradients, discB.trainable_variables))

        # Periodically dump sample translations for visual inspection.
        if epoch % 40 == 0:
            generate_images(trainA, trainB, genB2A_output, genA2B_output, epoch)
        print('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
def train_step(images):
    """Run one adversarial training step on a batch of real images.

    Returns the generator and discriminator losses as Python scalars.
    """
    z = tf.random.normal([args.batsize, noise_dim])
    # Separate tapes so the generator and discriminator learn independently.
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fakes = generator(z, training=True)
        d_real = discriminator(images, training=True)
        d_fake = discriminator(fakes, training=True)
        g_loss = generator_loss(d_fake)
        d_loss = discriminator_loss(d_real, d_fake, args.alpha)
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    return g_loss.numpy(), d_loss.numpy()
def train_step(input_image, target, epoch):
    """One pix2pix step: update generator and discriminator, log scalar summaries."""
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake = generator(input_image, training=True)
        d_real = discriminator([input_image, target], training=True)
        d_fake = discriminator([input_image, fake], training=True)
        g_total, g_gan, g_l1 = generator_loss(d_fake, fake, target)
        d_loss = discriminator_loss(d_real, d_fake)
    g_grads = g_tape.gradient(g_total, generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    # Record the loss components for TensorBoard, keyed by epoch.
    with summary_writer.as_default():
        tf.summary.scalar('gen_total_loss', g_total, step=epoch)
        tf.summary.scalar('gen_gan_loss', g_gan, step=epoch)
        tf.summary.scalar('gen_l1_loss', g_l1, step=epoch)
        tf.summary.scalar('disc_loss', d_loss, step=epoch)
def train_step(self, input_image, target):
    """One training step for the paired-image GAN: forward pass, losses, log, update."""
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake = self.generator(input_image, training=True)
        d_real = self.discriminator([input_image, target], training=True)
        d_fake = self.discriminator([input_image, fake], training=True)
        g_total, _, _ = generator_loss(d_fake, fake, target)
        d_loss = discriminator_loss(d_real, d_fake)
    g_grads = g_tape.gradient(g_total, self.generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables)
    # Losses are logged before the optimizer updates are applied.
    tf.print(
        'XX:XX:XX INFO trainer > Generator Loss: ', g_total)
    tf.print(
        'XX:XX:XX INFO trainer > Discriminator Loss: ', d_loss)
    self.generator_optimizer.apply_gradients(
        zip(g_grads, self.generator.trainable_variables))
    self.discriminator_optimizer.apply_gradients(
        zip(d_grads, self.discriminator.trainable_variables))
def train_step(images):
    """Standard DCGAN step: sample noise, score real/fake batches, update both nets."""
    z = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fakes = generator(z, training=True)
        real_scores = discriminator(images, training=True)
        fake_scores = discriminator(fakes, training=True)
        g_loss = generator_loss(fake_scores)
        d_loss = discriminator_loss(real_scores, fake_scores)
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
def train_step(input_image, target):
    """Perform one pix2pix training step.

    Args:
        input_image: input image fed to the generator.
        target: ground-truth output image.

    Returns:
        (generator loss, discriminator loss, generator output)
    """
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        # Generator output, then discriminator scores for real and fake pairs.
        gen_output = generator(input_image, training=True)
        d_real = discriminator([input_image, target], training=True)
        d_fake = discriminator([input_image, gen_output], training=True)
        g_loss = generator_loss(d_fake, gen_output, target)
        d_loss = discriminator_loss(d_real, d_fake)
    # One gradient-descent update per network.
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    return g_loss, d_loss, gen_output
def train_step(input_image, target):
    """Single pix2pix update: one gradient step each for generator and discriminator."""
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake = generator(input_image, training=True)
        real_score = discriminator([input_image, target], training=True)
        fake_score = discriminator([input_image, fake], training=True)
        g_loss = generator_loss(fake_score, fake, target)
        d_loss = discriminator_loss(real_score, fake_score)
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
def train_step(input_data, target):
    """One conditional-GAN step; returns the individual loss components for logging."""
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        generated = gen(input_data)
        real_score = dis([input_data, target])
        fake_score = dis([input_data, generated])
        total_g, adv_g, l1_g = model.generator_loss(fake_score, generated, target)
        total_d = model.discriminator_loss(real_score, fake_score)
    g_grads = g_tape.gradient(total_g, gen.trainable_variables)
    d_grads = d_tape.gradient(total_d, dis.trainable_variables)
    generator_optimizer.apply_gradients(zip(g_grads, gen.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(d_grads, dis.trainable_variables))
    return total_g, adv_g, l1_g, total_d
def train_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss,
              generator_loss, noise_size=96, num_epochs=10):
    """Alternating GAN training loop (PyTorch).

    Args:
        D_net, G_net: discriminator and generator modules.
        D_optimizer, G_optimizer: their optimizers.
        discriminator_loss: callable(logits_real, logits_fake) -> scalar loss.
        generator_loss: callable(logits_fake) -> scalar loss.
        noise_size: dimensionality of the generator's noise input.
        num_epochs: number of passes over `train_data`.

    Iterates the module-level `train_data` loader — TODO confirm it is
    defined at module scope.
    """
    # Fix: the original applied .cuda() only to the generator-pass noise
    # while the discriminator pass stayed on CPU, which crashes whichever
    # device the models actually live on. Derive the device once and move
    # every tensor there consistently. (torch.autograd.Variable is a
    # deprecated no-op since PyTorch 0.4 and is dropped.)
    device = next(D_net.parameters()).device
    for epoch in range(num_epochs):
        for x, _ in train_data:
            bs = x.shape[0]
            # ---- Discriminator update ----
            real_data = x.to(device)
            logits_real = D_net(real_data)
            # Uniform noise in [-1, 1)
            sample_noise = (torch.rand(bs, noise_size, device=device) - 0.5) / 0.5
            fake_images = G_net(sample_noise)
            # detach(): the D step must not backprop into the generator.
            logits_fake = D_net(fake_images.detach())
            d_total_error = discriminator_loss(logits_real, logits_fake)
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()
            # ---- Generator update (same noise batch, as in the original) ----
            fake_images = G_net(sample_noise)
            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()
def main():
    """Train the GP-GAN Blending GAN (TF1 graph mode, WGAN-style clipping).

    Builds train/val graphs, Adam optimizers with weight clipping on the
    critic, TensorBoard summaries, and runs the alternating update loop
    with periodic validation and checkpointing.
    """
    parser = argparse.ArgumentParser(description='Train Blending GAN')
    parser.add_argument('--nef', type=int, default=64, help='number of base filters in encoder')
    parser.add_argument('--ngf', type=int, default=64, help='number of base filters in decoder')
    parser.add_argument('--nc', type=int, default=3, help='number of output channels in decoder')
    parser.add_argument('--nBottleneck', type=int, default=4000, help='number of output channels in encoder')
    parser.add_argument('--ndf', type=int, default=64, help='number of base filters in D')
    parser.add_argument('--lr_d', type=float, default=0.0002, help='Learning rate for Critic, default=0.0002')
    parser.add_argument('--lr_g', type=float, default=0.002, help='Learning rate for Generator, default=0.002')
    parser.add_argument('--beta1', type=float, default=0.5, help='Beta for Adam, default=0.5')
    # Fix: help text said 0.999 but the actual default is 0.99.
    parser.add_argument('--l2_weight', type=float, default=0.99, help='Weight for l2 loss, default=0.99')
    # NOTE(review): no type=, so a CLI-supplied value arrives as str while the
    # default is a float — kept as-is for compatibility.
    parser.add_argument('--train_steps', default=float("58000"), help='Max amount of training cycles')
    parser.add_argument('--batch_size', type=int, default=64, help='Input batch size')
    parser.add_argument('--data_root', default='DataBase/TransientAttributes/cropped_images', help='Path to dataset')
    parser.add_argument('--train_data_root', default='DataBase/TransientAttributes/train.tfrecords', help='Path to train dataset')
    parser.add_argument('--val_data_root', default='DataBase/TransientAttributes/val.tfrecords', help='Path to val dataset')
    parser.add_argument('--image_size', type=int, default=64,
                        help='The height / width of the network\'s input image')
    parser.add_argument('--d_iters', type=int, default=5,
                        help='# of discriminator iters per each generator iter')
    parser.add_argument('--clamp_lower', type=float, default=-0.01, help='Lower bound for weight clipping')
    parser.add_argument('--clamp_upper', type=float, default=0.01, help='Upper bound for weight clipping')
    parser.add_argument('--experiment', default='blending_gan', help='Where to store samples and models')
    parser.add_argument('--save_folder', default='GP-GAN_training', help='location to save')
    parser.add_argument('--tboard_save_dir', default='tensorboard', help='location to save tboard records')
    parser.add_argument('--val_freq', type=int, default=500, help='frequency of validation')
    parser.add_argument('--snapshot_interval', type=int, default=500, help='Interval of snapshot (steps)')
    parser.add_argument('--weights_path', type=str, default=None, help='path to checkpoint')
    args = parser.parse_args()

    print('Input arguments:')
    for key, value in vars(args).items():
        print('\t{}: {}'.format(key, value))
    print('')

    # Set up generator & discriminator (separate is_training graphs that
    # share variables through their common scope names).
    print('Create & Init models ...')
    print('\tInit Generator network ...')
    generator = EncoderDecoder(encoder_filters=args.nef, encoded_dims=args.nBottleneck,
                               output_channels=args.nc, decoder_filters=args.ngf,
                               is_training=True, image_size=args.image_size,
                               skip=False, scope_name='generator')
    generator_val = EncoderDecoder(encoder_filters=args.nef, encoded_dims=args.nBottleneck,
                                   output_channels=args.nc, decoder_filters=args.ngf,
                                   is_training=False, image_size=args.image_size,
                                   skip=False, scope_name='generator')
    print('\tInit Discriminator network ...')
    discriminator = DCGAN_D(image_size=args.image_size, encoded_dims=1,
                            filters=args.ndf, is_training=True,
                            scope_name='discriminator')
    discriminator_val = DCGAN_D(image_size=args.image_size, encoded_dims=1,
                                filters=args.ndf, is_training=False,
                                scope_name='discriminator')

    # Set up training graph
    with tf.device('/gpu:0'):
        train_dataset = DataFeeder(tfrecords_path=args.train_data_root, dataset_flag='train')
        composed_image, real_image = train_dataset.inputs(
            batch_size=args.batch_size, name='train_dataset')
        shape = composed_image.get_shape().as_list()
        composed_image.set_shape([shape[0], args.image_size, args.image_size, shape[3]])
        real_image.set_shape([shape[0], args.image_size, args.image_size, shape[3]])

        validation_dataset = DataFeeder(tfrecords_path=args.val_data_root, dataset_flag='val')
        composed_image_val, real_image_val = validation_dataset.inputs(
            batch_size=args.batch_size, name='val_dataset')
        composed_image_val.set_shape([shape[0], args.image_size, args.image_size, shape[3]])
        real_image_val.set_shape([shape[0], args.image_size, args.image_size, shape[3]])

        # Compute losses:
        # Train tensors
        fake = generator(composed_image)
        prob_disc_real = discriminator.encode(real_image)
        prob_disc_fake = discriminator.encode(fake)

        # Validation tensors.
        # Fix: these previously consumed the TRAIN inputs (composed_image,
        # real_image, fake) — validation now uses the validation batch.
        fake_val = generator_val(composed_image_val)
        prob_disc_real_val = discriminator_val.encode(real_image_val)
        prob_disc_fake_val = discriminator_val.encode(fake_val)

        # Calculate losses
        gen_loss, l2_comp, disc_comp, fake_image_train = l2_generator_loss(
            fake=fake, target=real_image, prob_disc_fake=prob_disc_fake,
            l2_weight=args.l2_weight)
        disc_loss = discriminator_loss(prob_disc_real=prob_disc_real,
                                       prob_disc_fake=prob_disc_fake)
        gen_loss_val, _, _, fake_image_val = l2_generator_loss(
            fake=fake_val, target=real_image_val, prob_disc_fake=prob_disc_fake_val,
            l2_weight=args.l2_weight)
        disc_loss_val = discriminator_loss(prob_disc_real=prob_disc_real_val,
                                           prob_disc_fake=prob_disc_fake_val)

        # Set optimizers (batch-norm update ops must run before the steps).
        global_step = tf.Variable(0, name='global_step', trainable=False)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            discriminator_variables = [
                v for v in tf.trainable_variables()
                if v.name.startswith("discriminator")
            ]
            generator_variables = [
                v for v in tf.trainable_variables()
                if v.name.startswith("generator")
            ]
            optimizer_gen = tf.train.AdamOptimizer(
                learning_rate=args.lr_g, beta1=args.beta1).minimize(
                    loss=gen_loss, global_step=global_step,
                    var_list=generator_variables)
            optimizer_disc = tf.train.AdamOptimizer(
                learning_rate=args.lr_d, beta1=args.beta1).minimize(
                    loss=disc_loss, global_step=global_step,
                    var_list=discriminator_variables)
        # WGAN weight clipping op for the critic.
        with tf.name_scope("clip_weights"):
            clip_discriminator_var_op = [
                var.assign(tf.clip_by_value(var, args.clamp_lower, args.clamp_upper))
                for var in discriminator_variables
            ]

    # Set summaries for Tensorboard
    model_save_dir = os.path.join(args.save_folder, args.experiment)
    tboard_save_dir = os.path.join(model_save_dir, args.tboard_save_dir)
    os.makedirs(tboard_save_dir, exist_ok=True)
    sum_gen_train = tf.summary.scalar(name='train_gen_loss', tensor=gen_loss)
    sum_gen_disc_comp = tf.summary.scalar(name='train_gen_disc_component', tensor=disc_comp)
    sum_gen_l2_comp = tf.summary.scalar(name='train_gen_l2_component', tensor=l2_comp)
    sum_gen_val = tf.summary.scalar(name='val_gen_loss', tensor=gen_loss_val, collections='')
    sum_disc_train = tf.summary.scalar(name='train_disc_loss', tensor=disc_loss)
    sum_disc_val = tf.summary.scalar(name='val_disc_loss', tensor=disc_loss_val)
    sum_fake_image_train = tf.summary.image(name='train_image_generated', tensor=fake_image_train)
    sum_fake_image_val = tf.summary.image(name='val_image_generated', tensor=fake_image_val)
    sum_disc_real = tf.summary.scalar(name='train_disc_value_real',
                                      tensor=tf.reduce_mean(prob_disc_real))
    sum_disc_fake = tf.summary.scalar(name='train_disc_value_fake',
                                      tensor=tf.reduce_mean(prob_disc_fake))
    sum_composed = tf.summary.image(name='composed', tensor=composed_image)
    sum_real = tf.summary.image(name='real', tensor=real_image)
    train_merge = tf.summary.merge([
        sum_gen_train, sum_fake_image_train, sum_disc_train, sum_composed,
        sum_real, sum_gen_disc_comp, sum_gen_l2_comp, sum_disc_real,
        sum_disc_fake
    ])

    # Set saver configuration
    loader = tf.train.Saver()
    saver = tf.train.Saver()
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'GP-GAN_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = os.path.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=sess_config)

    # Write graph to tensorboard
    summary_writer = tf.summary.FileWriter(tboard_save_dir)
    summary_writer.add_graph(sess.graph)

    # Set the training parameters
    with sess.as_default():
        step = 0
        cycle = 0
        if args.weights_path is None:
            print('Training from scratch')
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            print('Restore model from {:s}'.format(args.weights_path))
            loader.restore(sess=sess, save_path=args.weights_path)
            step_cycle = args.weights_path.split('ckpt-')[-1]
            step, cycle = decode_step_cycle(step_cycle)
        gen_train_loss = '?'
        while cycle <= args.train_steps:
            # (1) Update discriminator network: train the critic Diters times
            # per generator step (longer warm-up early and every 500 cycles).
            if cycle < 25 or cycle % 500 == 0:
                Diters = 100
            else:
                Diters = args.d_iters
            for _ in range(Diters):
                # enforce Lipschitz constraint
                sess.run(clip_discriminator_var_op)
                _, disc_train_loss = sess.run([optimizer_disc, disc_loss])
                print('Step: ' + str(step) + ' Cycle: ' + str(cycle) +
                      ' Train discriminator loss: ' + str(disc_train_loss) +
                      ' Train generator loss: ' + str(gen_train_loss))
                step += 1

            # (2) Update generator network
            _, gen_train_loss, train_merge_value = sess.run(
                [optimizer_gen, gen_loss, train_merge])
            summary_writer.add_summary(summary=train_merge_value, global_step=cycle)

            if cycle != 0 and cycle % args.val_freq == 0:
                # NOTE(review): these runs also execute the *training*
                # optimizers, so weights are updated during validation —
                # kept as in the original, but worth confirming.
                # Fix: the fetched losses were assigned to swapped names
                # (the printed "Val discriminator loss" was actually the
                # generator's validation loss and vice versa).
                _, gen_val_loss, gen_val_value, fake_image_val_value = sess.run([
                    optimizer_disc, gen_loss_val, sum_gen_val, sum_fake_image_val
                ])
                _, disc_val_loss, disc_val_value = sess.run(
                    [optimizer_gen, disc_loss_val, sum_disc_val])
                print('Step: ' + str(step) + ' Cycle: ' + str(cycle) +
                      ' Val discriminator loss: ' + str(disc_val_loss) +
                      ' Val generator loss: ' + str(gen_val_loss))
                summary_writer.add_summary(summary=gen_val_value, global_step=cycle)
                summary_writer.add_summary(summary=disc_val_value, global_step=cycle)
                summary_writer.add_summary(summary=fake_image_val_value, global_step=cycle)

            if cycle != 0 and cycle % args.snapshot_interval == 0:
                saver.save(sess=sess, save_path=model_save_path,
                           global_step=encode_step_cycle(step, cycle))
            cycle += 1
def train_step(real_x, real_y, G_YtoX, G_XtoY, D_X, D_Y, G_YtoX_optimizer,
               G_XtoY_optimizer, D_X_optimizer, D_Y_optimizer, opt):
    """One CycleGAN training step over a pair of unaligned batches.

    Args:
        real_x, real_y: batches from domain X and domain Y.
        G_YtoX, G_XtoY: the two generators.
        D_X, D_Y: the two discriminators.
        *_optimizer: optimizers for the four networks.
        opt: options dict; reads "use_cycle_consistency_loss".

    Returns:
        (mean discriminator loss, mean generator loss) for logging.
    """
    # persistent=True because the tape is queried once per network.
    with tf.GradientTape(persistent=True) as tape:
        # G_XtoY translates X -> Y; G_YtoX translates Y -> X.
        fake_y = G_XtoY(real_x, training=True)
        cycled_x = G_YtoX(fake_y, training=True)
        fake_x = G_YtoX(real_y, training=True)
        cycled_y = G_XtoY(fake_x, training=True)

        # same_x and same_y are used for identity loss.
        # Fix: same_x must come from G_YtoX (the Y->X generator should leave
        # an X-domain image unchanged); it previously used G_XtoY.
        same_x = G_YtoX(real_x, training=True)
        same_y = G_XtoY(real_y, training=True)

        disc_real_x = D_X(real_x, training=True)
        disc_real_y = D_Y(real_y, training=True)
        disc_fake_x = D_X(fake_x, training=True)
        disc_fake_y = D_Y(fake_y, training=True)

        # Adversarial losses for each generator.
        G_XtoY_loss = model.generator_loss(disc_fake_y)
        G_YtoX_loss = model.generator_loss(disc_fake_x)

        if opt["use_cycle_consistency_loss"]:
            total_cycle_loss = model.calc_cycle_loss(
                real_x, cycled_x) + model.calc_cycle_loss(real_y, cycled_y)
        else:
            total_cycle_loss = 0

        # Total generator loss = adversarial loss + cycle loss + identity loss
        total_G_XtoY_loss = G_XtoY_loss + total_cycle_loss + model.identity_loss(
            real_y, same_y)
        total_G_YtoX_loss = G_YtoX_loss + total_cycle_loss + model.identity_loss(
            real_x, same_x)

        # Discriminator losses; the second return flags whether to update D.
        disc_x_loss, update_D_X = model.discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss, update_D_Y = model.discriminator_loss(disc_real_y, disc_fake_y)

        # Aggregate losses reported to the caller.
        total_disc_loss = (disc_x_loss + disc_y_loss) / 2
        total_gen_loss = (total_G_XtoY_loss + total_G_YtoX_loss) / 2

    # Calculate the gradients for generators and (conditionally) discriminators.
    G_XtoY_gradients = tape.gradient(total_G_XtoY_loss, G_XtoY.trainable_variables)
    G_YtoX_gradients = tape.gradient(total_G_YtoX_loss, G_YtoX.trainable_variables)
    if update_D_X:
        D_X_gradients = tape.gradient(disc_x_loss, D_X.trainable_variables)
    if update_D_Y:
        D_Y_gradients = tape.gradient(disc_y_loss, D_Y.trainable_variables)
    # Release the persistent tape's resources once all gradients are taken.
    del tape

    # Apply the gradients to the optimizers.
    G_XtoY_optimizer.apply_gradients(
        zip(G_XtoY_gradients, G_XtoY.trainable_variables))
    G_YtoX_optimizer.apply_gradients(
        zip(G_YtoX_gradients, G_YtoX.trainable_variables))
    if update_D_X:
        D_X_optimizer.apply_gradients(
            zip(D_X_gradients, D_X.trainable_variables))
    if update_D_Y:
        D_Y_optimizer.apply_gradients(
            zip(D_Y_gradients, D_Y.trainable_variables))
    return total_disc_loss, total_gen_loss