def train(n_channels=3, resolution=32, z_dim=128, n_labels=0, lr=1e-3,
          e_drift=1e-3, wgp_target=750, initial_resolution=4,
          total_kimg=25000, training_kimg=500, transition_kimg=500,
          iters_per_checkpoint=500, n_checkpoint_images=16,
          glob_str='cifar10', out_dir='cifar10'):
    """Train a progressively-growing GAN with a WGAN-GP style objective.

    Training proceeds in phases: each phase fades in one more resolution
    block (controlled via the models' ``cur_block`` variable) and may use a
    different minibatch size taken from ``MINIBATCH_OVERWRITES``.

    Parameters
    ----------
    n_channels : int
        Number of image channels.
    resolution : int
        Final (maximum) image resolution; assumed to be a power of two.
    z_dim : int
        Dimensionality of the generator's latent vector.
    n_labels : int
        Number of class labels (0 for unconditional training).
    lr : float
        Adam learning rate shared by G and D.
    e_drift : float
        Weight of the epsilon-drift ('mse') penalty on the discriminator.
    wgp_target : float
        Target value for the gradient-penalty loss term.
    initial_resolution : int
        Resolution at which progressive growing starts.
    total_kimg : int
        Total training length in thousands of real images shown to D.
    training_kimg, transition_kimg : int
        Duration (in kimg) of the stable and fade-in parts of each phase.
    iters_per_checkpoint : int
        Iterations between image logging and model saving.
    n_checkpoint_images : int
        Number of fixed-z samples rendered at every checkpoint.
    glob_str : str
        Dataset glob/identifier passed to ``iterate_minibatches``.
    out_dir : str
        Directory for TensorBoard logs and model checkpoints.
    """
    # instantiate logger
    logger = SummaryWriter(out_dir)

    # load data at the phase-0 minibatch size
    batch_size = MINIBATCH_OVERWRITES[0]
    train_iterator = iterate_minibatches(glob_str, batch_size, resolution)

    # build models
    G = Generator(n_channels, resolution, z_dim, n_labels)
    D = Discriminator(n_channels, resolution, n_labels)
    G_train, D_train = GAN(G, D, z_dim, n_labels, resolution, n_channels)

    D_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
    G_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)

    # define loss functions
    D_loss = [loss_mean, loss_gradient_penalty, 'mse']
    G_loss = [loss_wasserstein]

    # compile graphs used during training: D must be frozen while the
    # combined G_train graph is compiled, then unfrozen for its own updates
    G.compile(G_opt, loss=loss_wasserstein)
    D.trainable = False
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss,
                    loss_weights=[1, GP_WEIGHT, e_drift])

    # label tensors used by the losses
    ones = np.ones((batch_size, 1), dtype=np.float32)
    zeros = ones * 0.0

    # fix a z vector for training evaluation
    z_fixed = np.random.normal(0, 1, size=(n_checkpoint_images, z_dim))

    # progressive-growing bookkeeping: number of blocks between the initial
    # and the final resolution
    resolution_log2 = int(np.log2(resolution))
    starting_block = resolution_log2
    starting_block -= np.floor(np.log2(initial_resolution))
    cur_block = starting_block
    cur_nimg = 0

    # duration of each phase; the phase index doubles as a proxy for the
    # minibatch-size schedule
    phase_kdur = training_kimg + transition_kimg
    phase_idx_prev = 0

    # offset variable for transitioning between blocks
    offset = 0
    i = 0
    while cur_nimg < total_kimg * 1000:
        # which phase are we in, measured in kimg shown so far
        kimg = cur_nimg / 1000.0
        phase_idx = int(np.floor((kimg + transition_kimg) / phase_kdur))
        # BUGFIX: clamp with an int (was 0.0) so phase_idx remains a valid
        # integer index into MINIBATCH_OVERWRITES below
        phase_idx = max(phase_idx, 0)
        phase_kimg = phase_idx * phase_kdur

        # update batch size and label tensors if we switched phases
        if phase_idx_prev < phase_idx:
            batch_size = MINIBATCH_OVERWRITES[phase_idx]
            # BUGFIX: pass `resolution` as the initial call above does; the
            # original dropped it when re-creating the iterator
            train_iterator = iterate_minibatches(glob_str, batch_size,
                                                 resolution)
            ones = np.ones((batch_size, 1), dtype=np.float32)
            zeros = ones * 0.0
            phase_idx_prev = phase_idx

        # possibly gradually update current level of detail
        if transition_kimg > 0 and phase_idx > 0:
            offset = (kimg + transition_kimg - phase_kimg) / transition_kimg
            offset = min(offset, 1.0)
            offset = offset + phase_idx - 1
            cur_block = max(starting_block - offset, 0.0)

        # push the level of detail into both training graphs
        K.set_value(G_train.cur_block, np.float32(cur_block))
        K.set_value(D_train.cur_block, np.float32(cur_block))

        # train D for N_CRITIC_ITERS steps per generator step (WGAN style)
        for j in range(N_CRITIC_ITERS):
            z = np.random.normal(0, 1, size=(batch_size, z_dim))
            real_batch = next(train_iterator)
            fake_batch = G.predict_on_batch([z])
            interpolated_batch = get_interpolated_images(
                real_batch, fake_batch)
            losses_d = D_train.train_on_batch(
                [real_batch, fake_batch, interpolated_batch],
                [ones, ones * wgp_target, zeros])
            cur_nimg += batch_size

        # train G
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        loss_g = G_train.train_on_batch(z, -1 * ones)

        logger.add_scalar("cur_block", cur_block, i)
        logger.add_scalar("learning_rate", lr, i)
        logger.add_scalar("batch_size", z.shape[0], i)
        print("iter", i, "cur_block", cur_block, "lr", lr, "kimg", kimg,
              "losses_d", losses_d, "loss_g", loss_g)

        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_images = G.predict(z_fixed)
            # log fake images
            log_images(fake_images, 'fake', i, logger,
                       fake_images.shape[1], fake_images.shape[2],
                       int(np.sqrt(n_checkpoint_images)))
            # plot real images for reference
            log_images(real_batch[:n_checkpoint_images], 'real', i, logger,
                       real_batch.shape[1], real_batch.shape[2],
                       int(np.sqrt(n_checkpoint_images)))
            # save the model to eventually resume training or do inference
            save_model(G, out_dir + "/model.json", out_dir + "/model.h5")
            # BUGFIX: restore the flag flipped above; the original left G
            # marked non-trainable after the first checkpoint
            G.trainable = True

        log_losses(losses_d, loss_g, i, logger)
        i += 1
def train():
    """Command-line entry point: train a pix2pix-style conditional GAN.

    Parses CLI arguments, configures the GPU/TensorFlow session, builds the
    generator/discriminator/GAN graphs, then alternates discriminator and
    GAN updates per epoch, logging training and validation losses to
    ``result/<out>/log.txt`` and saving sample images and weights every
    50 epochs.

    NOTE(review): this function shadows an earlier module-level ``train``
    definition in this file — confirm which one callers expect.
    """
    parser = argparse.ArgumentParser(description="keras pix2pix")
    parser.add_argument('--batchsize', '-b', type=int, default=1)
    parser.add_argument('--patchsize', '-p', type=int, default=64)
    parser.add_argument('--epoch', '-e', type=int, default=500)
    parser.add_argument('--out', '-o', default='result')
    parser.add_argument('--lmd', '-l', type=int, default=100)
    parser.add_argument('--dark', '-d', type=float, default=0.01)
    parser.add_argument('--gpu', '-g', type=int, default=2)
    # BUGFIX: arguments were parsed twice in the original
    args = parser.parse_args()

    BATCH_SIZE = args.batchsize
    n_epochs = args.epoch  # renamed: the original reused `epoch` as both
    # the total count and the loop variable
    lmd = args.lmd

    # set gpu environment
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    K.set_session(sess)

    # make directory to save results
    if not os.path.exists("./result"):
        os.mkdir("./result")
    resultDir = "./result/" + args.out
    modelDir = resultDir + "/model/"
    if not os.path.exists(resultDir):
        os.mkdir(resultDir)
    if not os.path.exists(modelDir):
        os.mkdir(modelDir)

    # make a logfile and add colnames
    o = open(resultDir + "/log.txt", "w")
    o.write("batch:" + str(BATCH_SIZE) + " lambda:" + str(lmd) + "\n")
    o.write(
        "epoch,dis_loss,gan_mae,gan_entropy,vdis_loss,vgan_mae,vgan_entropy"
        + "\n")
    o.close()

    # load data: two datasets, 70/30 train/test split within each
    ds1_first, ds1_last, num_ds1 = 1, 1145, 1145
    ds2_first, ds2_last, num_ds2 = 2000, 6749, 4750
    # ds1_first, ds1_last, num_ds1 = 1, 100, 100
    # ds2_first, ds2_last, num_ds2 = 101, 200, 100
    train_data_i = np.concatenate([
        np.arange(ds1_first, ds1_last + 1)[:int(num_ds1 * 0.7)],
        np.arange(ds2_first, ds2_last + 1)[:int(num_ds2 * 0.7)]
    ])
    test_data_i = np.concatenate([
        np.arange(ds1_first, ds1_last + 1)[int(num_ds1 * 0.7):],
        np.arange(ds2_first, ds2_last + 1)[int(num_ds2 * 0.7):]
    ])
    train_gt, _, train_night = load_dataset(data_range=train_data_i,
                                            dark=args.dark)
    test_gt, _, test_night = load_dataset(data_range=test_data_i,
                                          dark=args.dark)

    # Create optimizers
    opt_Gan = Adam(lr=1E-3)
    opt_Discriminator = Adam(lr=1E-3)
    opt_Generator = Adam(lr=1E-3)

    # discriminator loss: -log|pred - true|, i.e. reward predictions far
    # from the wrong label (epsilon avoids log(0))
    def dis_entropy(y_true, y_pred):
        return -K.log(K.abs((y_pred - y_true)) + 1e-07)

    gan_loss = ['mae', dis_entropy]
    gan_loss_weights = [lmd, 1]

    # make models; D must be frozen while the combined GAN is compiled,
    # then unfrozen and compiled for its own updates
    Generator = generator()
    Generator.compile(loss='mae', optimizer=opt_Generator)
    Discriminator = discriminator()
    Discriminator.trainable = False
    Gan = GAN(Generator, Discriminator)
    Gan.compile(loss=gan_loss, loss_weights=gan_loss_weights,
                optimizer=opt_Gan)
    Discriminator.trainable = True
    Discriminator.compile(loss=dis_entropy, optimizer=opt_Discriminator)

    # start training
    n_train = train_gt.shape[0]
    n_test = test_gt.shape[0]
    print(n_train, n_test)
    p = ProgressBar()
    for epoch in p(range(n_epochs)):
        p.update(epoch + 1)
        out_file = open(resultDir + "/log.txt", "a")
        train_ind = np.random.permutation(n_train)
        test_ind = np.random.permutation(n_test)
        dis_losses = []
        gan_losses = []
        test_dis_losses = []
        test_gan_losses = []
        y_real = np.array([1] * BATCH_SIZE)
        y_fake = np.array([0] * BATCH_SIZE)
        y_gan = np.array([1] * BATCH_SIZE)

        # training
        for batch_i in range(int(n_train / BATCH_SIZE)):
            batch_ind = train_ind[(batch_i * BATCH_SIZE):
                                  ((batch_i + 1) * BATCH_SIZE)]
            gt_batch = train_gt[batch_ind, :, :, :]
            night_batch = train_night[batch_ind, :, :, :]
            generated_batch = Generator.predict(night_batch)

            # train Discriminator on the real and the generated pair
            dis_real_loss = np.array(
                Discriminator.train_on_batch([night_batch, gt_batch],
                                             y_real))
            dis_fake_loss = np.array(
                Discriminator.train_on_batch([night_batch, generated_batch],
                                             y_fake))
            dis_losses.append((dis_real_loss + dis_fake_loss) / 2)

            # train Generator through the combined graph
            gan_losses.append(np.array(
                Gan.train_on_batch(night_batch, [gt_batch, y_gan])))
        dis_loss = np.mean(np.array(dis_losses))
        # renamed from `gan_loss`, which shadowed the compile-time loss list
        gan_loss_epoch = np.mean(np.array(gan_losses), axis=0)

        # validation (test_on_batch: no weight updates)
        for batch_i in range(int(n_test / BATCH_SIZE)):
            batch_ind = test_ind[(batch_i * BATCH_SIZE):
                                 ((batch_i + 1) * BATCH_SIZE)]
            gt_batch = test_gt[batch_ind, :, :, :]
            night_batch = test_night[batch_ind, :, :, :]
            generated_batch = Generator.predict(night_batch)

            dis_real_loss = np.array(
                Discriminator.test_on_batch([night_batch, gt_batch],
                                            y_real))
            dis_fake_loss = np.array(
                Discriminator.test_on_batch([night_batch, generated_batch],
                                            y_fake))
            test_dis_losses.append((dis_real_loss + dis_fake_loss) / 2)

            test_gan_losses.append(np.array(
                Gan.test_on_batch(night_batch, [gt_batch, y_gan])))
        test_dis_loss = np.mean(np.array(test_dis_losses))
        # BUGFIX: average the *validation* losses; the original averaged
        # `gan_losses` again, so the logged validation GAN loss was wrong
        test_gan_loss = np.mean(np.array(test_gan_losses), axis=0)

        # write log of learning
        out_file.write(
            str(epoch) + "," + str(dis_loss) + "," +
            str(gan_loss_epoch[1]) + "," + str(gan_loss_epoch[2]) + "," +
            str(test_dis_loss) + "," + str(test_gan_loss[1]) + "," +
            str(test_gan_loss[2]) + "\n")

        # visualize and checkpoint every 50 epochs
        if epoch % 50 == 0:
            # for training data
            gt_batch = train_gt[train_ind[0:9], :, :, :]
            night_batch = train_night[train_ind[0:9], :, :, :]
            generated_batch = Generator.predict(night_batch)
            save_images(night_batch,
                        resultDir + "/label_" + str(epoch) + "epoch.png")
            save_images(gt_batch,
                        resultDir + "/gt_" + str(epoch) + "epoch.png")
            save_images(generated_batch,
                        resultDir + "/generated_" + str(epoch) +
                        "epoch.png")

            # for validation data
            gt_batch = test_gt[test_ind[0:9], :, :, :]
            night_batch = test_night[test_ind[0:9], :, :, :]
            generated_batch = Generator.predict(night_batch)
            save_images(night_batch,
                        resultDir + "/vlabel_" + str(epoch) + "epoch.png")
            save_images(gt_batch,
                        resultDir + "/vgt_" + str(epoch) + "epoch.png")
            save_images(generated_batch,
                        resultDir + "/vgenerated_" + str(epoch) +
                        "epoch.png")

            Gan.save_weights(modelDir + 'gan_weights' + "_lambda" +
                             str(lmd) + "_epoch" + str(epoch) + '.h5')

        # close per-epoch; the original also closed once more after the
        # loop, which was redundant
        out_file.close()