def run_model(mode, learning_rate=2e-4, beta1=0.5, l1_lambda=100, max_epochs=200,
              summary_freq=200, display_freq=50, save_freq=400,
              checkpoint_dir="summary/conGAN.ckpt"):
    """Build the conditional-GAN (pix2pix-style) graph and train or test it.

    Args:
        mode: "train" to train from scratch; anything else loads the test
            split, and "test" restores the checkpoint and translates
            20 batches of 5 test images.
        learning_rate: Adam learning rate passed to ``get_solver``.
        beta1: Adam beta1 passed to ``get_solver``.
        l1_lambda: weight of the L1 reconstruction term in the generator loss.
        max_epochs: number of passes over the training set.
        summary_freq: steps between TensorBoard summary / checkpoint saves.
        display_freq: steps between console loss printouts.
        save_freq: steps between saving sample images.
        checkpoint_dir: path prefix used by ``tf.train.Saver``.

    Returns:
        0 on completion.
    """
    # ----- load data -----
    if mode == "train":
        xs_train, ys_train = get_input("train")
        xs_val, ys_val = get_input("val")
        print("load train data successfully")
        print("input x shape is {}".format(xs_train.shape))
        print("input y shape is {}".format(ys_train.shape))
    else:
        xs_test, ys_test = get_input("test")
        print("load test data successfully")
        print("input x shape is {}".format(xs_test.shape))
        print("input y shape is {}".format(ys_test.shape))

    # ----- build model -----
    with tf.name_scope("input"):
        x = tf.placeholder(tf.float32, [None, 256, 256, 3], name="x-input")
        y_ = tf.placeholder(tf.float32, [None, 256, 256, 3], name="y-input")

    G_sample = models.generator(x)
    # The conditional discriminator sees the input image alongside the
    # real / generated output.
    logits_fake = models.con_discriminator(x, G_sample)
    logits_real = models.con_discriminator(x, y_)

    # Adversarial losses plus an L1 term tying G's output to the target.
    D_loss, G_loss_gan = loss.gan_loss(logits_fake=logits_fake,
                                       logits_real=logits_real)
    l1_loss = loss.l1_loss(y_, G_sample)
    with tf.variable_scope("G_loss"):
        G_loss = G_loss_gan + l1_lambda * l1_loss

    tf.summary.scalar("D_loss", D_loss)
    tf.summary.scalar("G_loss", G_loss)
    merged = tf.summary.merge_all()

    # Each solver updates only its own network's variables.
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "generator")
    D_solver, G_solver = get_solver(learning_rate=learning_rate, beta1=beta1)
    D_train_step = D_solver.minimize(D_loss, var_list=D_vars)
    G_train_step = G_solver.minimize(G_loss, var_list=G_vars)

    # ----- session / saver -----
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # don't grab all GPU memory up front
    sess = tf.InteractiveSession(config=config)
    saver = tf.train.Saver()

    # ----- training phase -----
    if mode == "train":
        train_writer = tf.summary.FileWriter("summary/", sess.graph)
        sess.run(tf.global_variables_initializer())
        for step in range(max_epochs * NUM_TRAIN_IMAGES):
            if step % NUM_TRAIN_IMAGES == 0:
                # Fix: use floor division so the epoch counter prints as an
                # integer under Python 3 (true division yields a float).
                print("Epoch: {}".format(step // NUM_TRAIN_IMAGES))

            # Batch size 1, as in the original pix2pix training recipe.
            mask = np.random.choice(NUM_TRAIN_IMAGES, 1)
            feed = {x: xs_train[mask], y_: ys_train[mask]}
            _, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict=feed)
            # The generator is stepped twice per discriminator step — a common
            # GAN heuristic to keep D from overpowering G (kept intentionally).
            _, G_loss_curr = sess.run([G_train_step, G_loss], feed_dict=feed)
            _, G_loss_curr = sess.run([G_train_step, G_loss], feed_dict=feed)

            if step % display_freq == 0:
                print("step {}: D_loss: {}, G_loss: {}".format(
                    step, D_loss_curr, G_loss_curr))

            # save summary and checkpoint
            if step % summary_freq == 0:
                mask = np.random.choice(NUM_TRAIN_IMAGES, 30)
                summary = sess.run(merged, feed_dict={
                    x: xs_train[mask],
                    y_: ys_train[mask]
                })
                # Fix: pass the global step so TensorBoard plots a time series
                # instead of stacking every point at step 0.
                train_writer.add_summary(summary, step)
                saver.save(sess, checkpoint_dir)

            # save 5 sample images from the train and validation splits
            if step % save_freq == 0:
                samples_train = sess.run(G_sample, feed_dict={
                    x: xs_train[0:5],
                    y_: ys_train[0:5]
                })
                save_sample_img(samples_train, step=step, mode="train")
                samples_val = sess.run(G_sample, feed_dict={
                    x: xs_val[0:5],
                    y_: ys_val[0:5]
                })
                save_sample_img(samples_val, step=step, mode="val")

    # ----- testing phase -----
    if mode == "test":
        saver.restore(sess, checkpoint_dir)
        for i in range(20):
            samples_test = sess.run(G_sample, feed_dict={
                x: xs_test[5 * i:5 * (i + 1)],
                y_: ys_test[5 * i:5 * (i + 1)]
            })
            save_sample_img(samples_test, step=i, mode="test")

    # close sess
    sess.close()
    return 0
def train(args):
    """Train a StarGAN model (nnabla) on CelebA attributes.

    Builds generator/discriminator graphs, trains with a WGAN-GP style
    objective plus domain-classification and reconstruction losses, saves
    parameters and the run config, then translates test images.

    Args:
        args: parsed config namespace (batch_size, image_size, c_dim,
            selected_attrs, learning rates, lambdas, paths, etc.).
    """
    # c_dim must match the number of target attributes; silently repair it.
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator,
        conv_dim=args.g_conv_dim,
        c_dim=args.c_dim,
        num_downsample=args.num_downsample,
        num_upsample=args.num_upsample,
        repeat_num=args.g_repeat_num)
    discriminator = functools.partial(model.discriminator,
                                      image_size=args.image_size,
                                      conv_dim=args.d_conv_dim,
                                      c_dim=args.c_dim,
                                      repeat_num=args.d_repeat_num)

    # Input image (NCHW) and one-hot-style attribute labels shaped for
    # spatial broadcast: (batch, c_dim, 1, 1).
    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    # D on real images: adversarial logit + attribute-classification logit.
    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    # G translates the real image to the target domain.
    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake so D's graph does not backprop into
    # G during the discriminator update.
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    # WGAN-style: maximize score on real (hence the -1), minimize on fake.
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True  # kept for monitoring after forward.

    # Gradient Penalty: penalize |grad D(x_hat)| deviating from 1 on random
    # interpolations between real and fake images.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)
    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])
    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True  # kept for monitoring.

    # Reconstruct Images (cycle consistency back to the original domain).
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True
    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True  # kept for monitoring.

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver (scoped so each solver only owns
    # its own network's parameters).
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir,
                                  image_size=args.image_size,
                                  crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False,
        with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func,
                                       dataset=test_dataset,
                                       image_dir=args.celeba_image_dir,
                                       image_size=args.image_size,
                                       crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False,
        with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    # reshape labels to (batch, c_dim, 1, 1) to match the label Variables.
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly (permute labels in-batch).
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()
        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        # Only once every n_critic discriminator updates (WGAN-GP schedule).
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            # zero both solvers: g_loss.backward also deposits gradients into
            # the discriminator parameters, which must not accumulate.
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            # propagate the gradient collected at x_fake_unlinked back
            # through the generator graph.
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

        if (i + 1) % args.sample_step == 0:
            # save image.
            save_results(i, args, x_real, x_fake,
                         label_org, label_trg, x_recon)
            if args.test_during_training:
                # translate images from test dataset.
                # NOTE(review): this overwrites x_real/label_org with test
                # data and reuses rand_idx from the training batch; the next
                # iteration re-feeds them, so training is unaffected.
                x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                label_trg.d = test_label_ndarray[rand_idx]
                x_fake.forward(clear_no_need_grad=True)
                save_results(i, args, x_real, x_fake, label_org,
                             label_trg, None, is_training=False)

        # Learning rates get decayed linearly over the second half of
        # training, as in the StarGAN paper.
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    with open(os.path.join(args.model_save_path, "training_conf_{}.json".format(datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]
        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
fake_target = netG_A(real_src, training=True) rec_source = netG_B(fake_target, training=True) fake_src = netG_B(real_target, training=True) rec_target = netG_A(fake_src, training=True) fake_tar_seg = netG_A_map(fake_target, training=True) rec_src_seg = netG_B_map(rec_source, training=True) disc_target_fake = netD_A(fake_target, training=True) disc_source_fake = netD_B(fake_src, training=True) disc_target_real = netD_A(real_target, training=True) disc_source_real = netD_B(fake_src, training=True) loss_G_A = gan_loss(disc_target_fake, tf.ones_like(disc_target_fake)) loss_G_B = gan_loss(disc_source_fake, tf.ones_like(disc_source_fake)) loss_seg_A = categorical_loss(real_src_label, fake_tar_seg, 5) loss_seg_B = categorical_loss(real_src_label, rec_src_seg, 5) loss_cycle_A = cycle_loss(real_src, rec_source, lambda_para) loss_cycle_B = cycle_loss(real_target, rec_target, lambda_para) loss_G_A = loss_G_A + loss_cycle_A + loss_cycle_B + loss_seg_A loss_G_B = loss_G_B + loss_cycle_A + loss_cycle_B + loss_seg_B loss_D_A_real = gan_loss(disc_target_real, tf.ones_like(disc_target_real)) loss_D_A_fake = gan_loss(disc_target_fake,