def run(epoch, iter):
    """Render latent-traversal grids for every latent dimension and save them as JPGs.

    Args:
        epoch: current epoch number, used only in the output file names.
        iter: current iteration number, used only in the output file names.
            NOTE(review): shadows the builtin `iter`; kept for caller compatibility.

    Relies on module-level state: `args`, `sess`, `zs`, `eps`, `x_f`,
    `G_test`, `save_dir`, `tl`, `im`.
    """
    # One truncated-normal latent batch per latent group; row 0 is zeroed
    # below so the first traversal row shows the distribution "mode".
    zs_ipt_fixed = [
        scipy.stats.truncnorm.rvs(-args.truncation_threshold,
                                  args.truncation_threshold,
                                  size=[args.n_traversal, z_dim])
        for z_dim in args.z_dims
    ]
    eps_ipt = scipy.stats.truncnorm.rvs(
        -args.truncation_threshold, args.truncation_threshold,
        size=[args.n_traversal, args.eps_dim])
    # set the first sample as the "mode"
    for l in range(len(args.z_dims)):
        zs_ipt_fixed[l][0, ...] = 0.0
    eps_ipt[0, ...] = 0.0
    # Fetch the generator variables whose names match 'L' (per-dimension
    # importance weights) from the live session.
    L_opt = sess.run(tl.tensors_filter(G_test.func.variables, 'L'))
    for l in range(len(args.z_dims)):
        # Traverse dimensions in descending order of |L| (most important first);
        # j is the rank, i the actual dimension index.
        for j, i in enumerate(np.argsort(np.abs(L_opt[l]))[::-1]):
            x_f_opts = []
            # Symmetric sweep over [-4.5, 4.5] with a center point.
            vals = np.linspace(-4.5, 4.5, args.n_left_axis_point * 2 + 1)
            for v in vals:
                # Deep copy so only dimension i of group l varies per image.
                zs_ipt = copy.deepcopy(zs_ipt_fixed)
                zs_ipt[l][:, i] = v
                feed_dict = {z: z_ipt for z, z_ipt in zip(zs, zs_ipt)}
                feed_dict.update({eps: eps_ipt})
                x_f_opt = sess.run(x_f, feed_dict=feed_dict)
                x_f_opts.append(x_f_opt)
            # Concatenate along width (axis=2) so each output row shows one
            # sample traversed across all swept values.
            sample = im.immerge(np.concatenate(x_f_opts, axis=2),
                                n_rows=args.n_traversal)
            im.imwrite(
                sample, '%s/Epoch-%d_Iter-%d_Traversal-%d-%d-%.3f-%d.jpg' %
                (save_dir, epoch, iter, l, j, np.abs(L_opt[l][i]), i))
def main_loop(trainer, datasets, test_iterations, config, checkpoint, samples_dir):
    """Joint GAN/controller training loop over two paired domains (A and B).

    Args:
        trainer: object providing `joint_train_step`, `model`, and `controller`.
        datasets: nested pair ((a_train, a_test), (b_train, b_test)).
        test_iterations: number of iterations passed through to `test_model`.
        config: dict with 'hyperparameters'->'iterations' plus the
            'log_iterations' / 'image_save_iterations' /
            'image_display_iterations' / 'test_every_iterations' periods.
        checkpoint: object with `save(step)`.
        samples_dir: directory for sample images.
    """
    (a_train_dataset, a_test_dataset), (b_train_dataset, b_test_dataset) = datasets
    optimizer_iterations = config['hyperparameters']['iterations']
    # Running means of the controller losses, keyed by loss label; entries are
    # created lazily on first appearance.
    c_loss_mean_dict = {}
    a_dataset_iter = iter(a_train_dataset)
    b_dataset_iter = iter(b_train_dataset)
    for iterations in tqdm.tqdm(range(1, optimizer_iterations + 1)):
        images_a, actions_a = next(a_dataset_iter)
        images_b, _ = next(b_dataset_iter)
        # Training ops
        D_loss_dict, G_images, G_loss_dict, C_loss_dict = trainer.joint_train_step(
            images_a, actions_a, images_b)
        for c_loss_label, c_loss in C_loss_dict.items():
            if c_loss_label not in c_loss_mean_dict:
                c_loss_mean_dict[c_loss_label] = tf.keras.metrics.Mean()
            c_loss_mean_dict[c_loss_label].update_state(c_loss.numpy())
        # Logging ops
        if iterations % config['log_iterations'] == 0:
            # Replace the per-step controller losses with their running means,
            # then reset the accumulators for the next logging window.
            for c_loss_label, c_loss_mean in c_loss_mean_dict.items():
                C_loss_dict[c_loss_label] = c_loss_mean.result()
                c_loss_mean.reset_states()
            tf2lib.summary(D_loss_dict, step=iterations, name='discriminator')
            tf2lib.summary(G_loss_dict, step=iterations, name='generator')
            tf2lib.summary(C_loss_dict, step=iterations, name='controller')
        # Displaying ops
        # Persistent numbered snapshot takes priority over the rolling
        # 'train.jpg' preview; at most one image is written per iteration.
        if iterations % config['image_save_iterations'] == 0:
            img_filename = os.path.join(samples_dir, f'train_{iterations}.jpg')
        elif iterations % config['image_display_iterations'] == 0:
            img_filename = os.path.join(samples_dir, f'train.jpg')
        else:
            img_filename = None
        if img_filename:
            img = imlib.immerge(
                np.concatenate(
                    [row_a for row_a in images_a] +
                    [row_b for row_b in images_b] + G_images, axis=0),
                n_rows=8)
            imlib.imwrite(img, img_filename)
        # Testing and checkpointing ops
        if iterations % config[
                'test_every_iterations'] == 0 or iterations == optimizer_iterations:
            C_loss_dict = test_model(trainer.model, trainer.controller,
                                     a_test_dataset, b_test_dataset,
                                     test_iterations, samples_dir)
            tf2lib.summary(C_loss_dict, step=iterations, name='controller')
            checkpoint.save(iterations)
def snapshot(self, A, B, image_file_name, debug=False):
    """Write a 2-row sample grid to disk and return it as an in-memory PNG buffer.

    Args:
        A, B: input batches for `self.sample`.
        image_file_name: printf-style name taking the optimizer step number.
        debug: when True, also round-trip the PNG and show it on screen.

    Returns:
        io.BytesIO holding the rendered PNG (cursor position unspecified).
    """
    A2B, B2A, A2B2A, B2A2B = self.sample(A, B)
    rows = [A, A2B, A2B2A, B, B2A, B2A2B]
    grid = im.immerge(np.concatenate(rows, axis=0), n_rows=2)
    step = self.G_optimizer.iterations.numpy()
    im.imwrite(grid, py.join(self.sample_dir, image_file_name % step))
    # Re-render through matplotlib so the caller gets PNG bytes.
    png_buffer = io.BytesIO()
    if self.color_depth != 1:
        pyplot.imshow(grid)
    else:
        # Single-channel output: drop the channel axis and render grayscale.
        pyplot.imshow(grid.reshape(grid.shape[0], grid.shape[1]), cmap='gray')
    pyplot.savefig(png_buffer, format='png')
    if debug:
        png_buffer.seek(0)
        pyplot.imread(png_buffer)
        pyplot.show()
    return png_buffer
def test_model(model, controller, a_test_dataset, b_test_dataset, max_iterations, samples_dir):
    """Evaluate the controller on both test domains and return metric results.

    Iterates up to `max_iterations` batches; either dataset may be exhausted
    first, in which case evaluation continues with the remaining one. Returns
    a dict of MAE/MSE/BMAE/BMSE results for domains A and B.
    """
    training = False  # inference mode for every network call below
    mae_metric_a = metrics.MAEMetric()
    mae_metric_b = metrics.MAEMetric()
    mse_metric_a = metrics.MSEMetric()
    mse_metric_b = metrics.MSEMetric()
    bmae_metric_a = metrics.BMAEMetric()
    bmae_metric_b = metrics.BMAEMetric()
    bmse_metric_a = metrics.BMSEMetric()
    bmse_metric_b = metrics.BMSEMetric()
    a_test_iter = iter(a_test_dataset)
    b_test_iter = iter(b_test_dataset)
    for iterations in tqdm.tqdm(range(1, max_iterations + 1)):
        # A dataset running out is expected; mark it with None and carry on.
        try:
            images_a, actions_a = next(a_test_iter)
        except StopIteration:
            images_a, actions_a = None, None
        try:
            images_b, actions_b = next(b_test_iter)
        except StopIteration:
            images_b, actions_b = None, None
        # Inference ops
        G_images = None
        if images_a is not None and images_b is not None:
            # Both domains present: full cross-domain translation plus the
            # shared latent code for both batches concatenated.
            x_aa, x_ba, x_ab, x_bb, shared = model.encode_ab_decode_aabb(
                images_a, images_b, training=training)
            x_bab, _ = model.encode_a_decode_b(x_ba, training=training)
            x_aba, _ = model.encode_b_decode_a(x_ab, training=training)
            G_images = [x_aa, x_ba, x_ab, x_bb, x_aba, x_bab]
            # Split the stacked shared code back into per-domain halves.
            shared_a, shared_b = tf.split(
                shared, num_or_size_splits=[len(images_a), len(images_b)], axis=0)
        elif images_a is not None:
            # Only domain A left: encode it alone.
            encoded_a = model.encoder_a(images_a, training=training)
            shared_a = model.encoder_shared(encoded_a, training=training)
            shared_b = None
        elif images_b is not None:
            # Only domain B left: encode it alone.
            encoded_b = model.encoder_b(images_b, training=training)
            shared_b = model.encoder_shared(encoded_b, training=training)
            shared_a = None
        else:
            raise AssertionError(
                'There are no images either from A or B during the test.')
        # Displaying ops: dump a grid roughly 10 times over the run.
        if iterations % (max_iterations // 10) == 0:
            img_filename = os.path.join(samples_dir, f'test_{iterations}.jpg')
            images = []
            if images_a is not None:
                images.append(images_a)
            if images_b is not None:
                images.append(images_b)
            if G_images is not None:
                images.extend(G_images)
            img = imlib.immerge(np.concatenate(images, axis=0),
                                n_rows=len(images))
            imlib.imwrite(img, img_filename)
        # Control loss accumulation: run the controller on each available
        # domain's shared code and feed all four metric accumulators.
        for shared_x, actions_x, metrics_x in [
                (shared_a, actions_a,
                 (mae_metric_a, mse_metric_a, bmae_metric_a, bmse_metric_a)),
                (shared_b, actions_b,
                 (mae_metric_b, mse_metric_b, bmae_metric_b, bmse_metric_b)),
        ]:
            if shared_x is None:
                continue
            down_x = model.downstream_hidden(shared_x)
            predictions_x = controller(down_x, training=training)
            actions_x = actions_x.numpy()
            predictions_x = predictions_x.numpy()
            for metric_x in metrics_x:
                metric_x.update_state(actions_x, predictions_x)
    C_loss_dict = {
        'test_mae_metric_a': mae_metric_a.result(),
        'test_mse_metric_a': mse_metric_a.result(),
        'test_bmae_metric_a': bmae_metric_a.result(),
        'test_bmse_metric_a': bmse_metric_a.result(),
        'test_mae_metric_b': mae_metric_b.result(),
        'test_mse_metric_b': mse_metric_b.result(),
        'test_bmae_metric_b': bmae_metric_b.result(),
        'test_bmse_metric_b': bmse_metric_b.result(),
    }
    return C_loss_dict
# Latent-traversal sampling script (TF1): restores a checkpoint, then for each
# latent dimension sweeps [-3, 3] and writes one traversal grid per dimension.
sess = tl.session()
# initialization
ckpt_dir = './output/%s/checkpoints' % experiment_name
try:
    tl.load_checkpoint(ckpt_dir, sess)
except Exception as e:
    # Narrowed from a bare `except:` (which also swallowed KeyboardInterrupt)
    # and chained so the underlying load failure stays in the traceback.
    raise Exception(' [*] No checkpoint!') from e
# train
try:
    # Shared base latent batch: every traversal varies exactly one of its dims.
    z_ipt_sample_ = np.random.normal(size=[10, z_dim])
    # Output directory is loop-invariant; create it once up front.
    save_dir = './output/%s/sample_traversal' % experiment_name
    pylib.mkdir(save_dir)
    for i in range(z_dim):
        z_ipt_sample = np.copy(z_ipt_sample_)
        img_opt_samples = []
        # Sweep dimension i while all other dimensions stay fixed.
        for v in np.linspace(-3, 3, 10):
            z_ipt_sample[:, i] = v
            img_opt_samples.append(
                sess.run(img_sample, feed_dict={
                    z_sample: z_ipt_sample
                }).squeeze())
        # Concatenate along width: one row per base sample, one column per value.
        im.imwrite(im.immerge(np.concatenate(img_opt_samples, axis=2), 10),
                   '%s/traversal_d%d.jpg' % (save_dir, i))
except Exception:
    # Narrowed from a bare `except:`; still best-effort so `finally` always
    # closes the session.
    traceback.print_exc()
finally:
    sess.close()
# Inner training loop: one D update per batch, one G update every n_d D steps,
# with periodic sampling and loss plotting.
for x_real in tqdm.tqdm(dataset, desc='Inner Epoch Loop', total=len_dataset):
    # Comment by K.C:
    # run train_D means to update D once, D_loss can be printed here.
    D_loss_dict = train_D(x_real)
    tl.summary(D_loss_dict, step=D_optimizer.iterations, name='D_losses')
    # Comment by K.C:
    # Update the Discriminator for every n_d run of the Generator
    if D_optimizer.iterations.numpy() % args.n_d == 0:
        G_loss_dict = train_G()
        tl.summary(G_loss_dict, step=G_optimizer.iterations, name='G_losses')
    # sample
    if G_optimizer.iterations.numpy() % 100 == 0:
        x_fake = sample(z)
        img = im.immerge(x_fake, n_rows=10)
        im.imwrite(img, py.join(sample_dir,
                                'iter-%09d.jpg' % G_optimizer.iterations.numpy()))
        # Added by K.C: update the mean loss functions every 100 iterations, and plot them out
        # NOTE(review): the `''` fallbacks are suspect — on a missing key the
        # default is a str and `''.numpy()` raises AttributeError. Letting the
        # KeyError surface (plain indexing) would fail more clearly.
        D_loss_summary.append(D_loss_dict.get('d_loss','').numpy())
        D_GP_summary.append(D_loss_dict.get('gp', '').numpy())
        iteration_summary.append(D_optimizer.iterations.numpy())
        # NOTE(review): G_loss_dict is only bound after the `% args.n_d`
        # branch has run at least once; presumably that always happens before
        # iteration 100 — verify for small n_d / tiny datasets.
        G_loss_summary.append(G_loss_dict.get('g_loss', '').numpy())
        D_loss_mean = take_mean(D_loss_summary)
        D_GP_mean = take_mean(D_GP_summary)
        G_loss_mean = take_mean(G_loss_summary)
        G_figure = plt.figure()
        plt.plot(iteration_summary, G_loss_summary)
        plt.xlabel('iterations')
for A, B in tqdm.tqdm(A_B_dataset, desc='Inner Epoch Loop', total=len_dataset): G_loss_dict, D_loss_dict = train_step(A, B) # # summary tl.summary(G_loss_dict, step=G_optimizer.iterations, name='G_losses') tl.summary(D_loss_dict, step=G_optimizer.iterations, name='D_losses') tl.summary({'learning rate': G_lr_scheduler.current_learning_rate}, step=G_optimizer.iterations, name='learning rate') # sample if G_optimizer.iterations.numpy() % 100 == 0: A, B = next(test_iter) A2B, B2A, A2B2A, B2A2B = sample(A, B) img = im.immerge(np.concatenate([A, A2B, A2B2A, B, B2A, B2A2B], axis=0), n_rows=2) im.imwrite( img, py.join(sample_dir, 'iter-%09d.jpg' % G_optimizer.iterations.numpy())) # save checkpoint checkpoint.save(ep)
if (it + 1) % 1 == 0: print("Epoch: (%3d) (%5d/%5d) Time: %s!" % (epoch, it_in_epoch, it_per_epoch, t)) # save if (it + 1) % 1000 == 0: save_path = saver.save(sess, '%s/Epoch_(%d)_(%dof%d).ckpt' % (ckpt_dir, epoch, it_in_epoch, it_per_epoch)) print('Model is saved at %s!' % save_path) # sample if (it + 1) % 100 == 0: x_sample_opt_list = [xa_sample_ipt, np.full((n_sample, img_size, img_size // 10, 3), -1.0)] for i, b_sample_ipt in enumerate(b_sample_ipt_list): _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int if i > 0: # i == 0 is for reconstruction _b_sample_ipt[..., i - 1] = _b_sample_ipt[..., i - 1] * test_int / thres_int x_sample_opt_list.append(sess.run(x_sample, feed_dict={xa_sample: xa_sample_ipt, _b_sample: _b_sample_ipt})) sample = np.concatenate(x_sample_opt_list, 2) save_dir = './output/%s/sample_training' % experiment_name pylib.mkdir(save_dir) im.imwrite(im.immerge(sample, n_sample, 1), '%s/Epoch_(%d)_(%dof%d).jpg' % (save_dir, epoch, it_in_epoch, it_per_epoch)) except Exception: traceback.print_exc() finally: save_path = saver.save(sess, '%s/Epoch_(%d)_(%dof%d).ckpt' % (ckpt_dir, epoch, it_in_epoch, it_per_epoch)) print('Model is saved at %s!' % save_path) sess.close()
def train(self):
    """AttGAN-style TF1 training loop: alternates n_d discriminator steps with
    one generator step, decays the LR every `lrDecayEpoch` epochs, and
    periodically saves checkpoints and attribute-edited sample grids.
    """
    it_cnt, update_cnt = tl.counter()
    # saver
    saver = tf.train.Saver(max_to_keep=10)
    # summary writer
    summary_writer = tf.summary.FileWriter(self.config["projectSummary"],
                                           self.sess.graph)
    # initialization
    ckpt_dir = self.config["projectCheckpoints"]
    epoch = self.config["totalEpoch"]
    n_d = self.config["dStep"]
    atts = self.config["selectedAttrs"]
    thres_int = self.config["thresInt"]
    test_int = self.config["sampleThresInt"]
    n_sample = self.config["sampleNum"]
    img_size = self.config["imsize"]
    sample_freq = self.config["sampleEpoch"]
    save_freq = self.config["modelSaveEpoch"]
    lr_base = self.config["gLr"]
    lrDecayEpoch = self.config["lrDecayEpoch"]
    try:
        # NOTE(review): `clear` is not defined in this view — presumably a
        # module-level flag; any failure here (including the assert) falls
        # through to fresh initialization. `assert` is stripped under -O.
        assert clear == False
        tl.load_checkpoint(ckpt_dir, self.sess)
    except:
        print('NOTE: Initializing all parameters...')
        self.sess.run(tf.global_variables_initializer())
    # train
    try:
        # data for sampling
        xa_sample_ipt, a_sample_ipt = self.val_loader.get_next()
        b_sample_ipt_list = [a_sample_ipt]  # the first is for reconstruction
        for i in range(len(atts)):
            tmp = np.array(a_sample_ipt, copy=True)
            tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
            tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
            b_sample_ipt_list.append(tmp)
        it_per_epoch = len(self.data_loader) // (self.config["batchSize"] *
                                                 (n_d + 1))
        max_it = epoch * it_per_epoch
        # Resume from the persisted iteration counter.
        for it in range(self.sess.run(it_cnt), max_it):
            with pylib.Timer(is_output=False) as t:
                self.sess.run(update_cnt)
                # which epoch
                epoch = it // it_per_epoch
                it_in_epoch = it % it_per_epoch + 1
                # learning rate: divide by 10 every lrDecayEpoch epochs
                lr_ipt = lr_base / (10**(epoch // lrDecayEpoch))
                # train D
                for i in range(n_d):
                    d_summary_opt, _ = self.sess.run(
                        [self.d_summary, self.d_step],
                        feed_dict={self.lr: lr_ipt})
                summary_writer.add_summary(d_summary_opt, it)
                # train G
                g_summary_opt, _ = self.sess.run(
                    [self.g_summary, self.g_step],
                    feed_dict={self.lr: lr_ipt})
                summary_writer.add_summary(g_summary_opt, it)
                # display
                if (it + 1) % 100 == 0:
                    print("Epoch: (%3d) (%5d/%5d) Time: %s!"
                          % (epoch, it_in_epoch, it_per_epoch, t))
                # save
                if (it + 1) % (save_freq if save_freq else it_per_epoch) == 0:
                    save_path = saver.save(
                        self.sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                        (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
                    print('Model is saved at %s!' % save_path)
                # sample
                if (it + 1) % (sample_freq if sample_freq else it_per_epoch) == 0:
                    # Leading columns: inputs plus a thin -1 separator strip.
                    x_sample_opt_list = [
                        xa_sample_ipt,
                        np.full((n_sample, img_size, img_size // 10, 3), -1.0)
                    ]
                    raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 -
                                        1) * thres_int
                    for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                        _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                        if i > 0:  # i == 0 is for reconstruction
                            # Amplify only the flipped attribute.
                            _b_sample_ipt[..., i - 1] = _b_sample_ipt[
                                ..., i - 1] * test_int / thres_int
                        x_sample_opt_list.append(
                            self.sess.run(self.x_sample,
                                          feed_dict={
                                              self.xa_sample: xa_sample_ipt,
                                              self._b_sample: _b_sample_ipt,
                                              self.raw_b_sample: raw_b_sample_ipt
                                          }))
                        last_images = x_sample_opt_list[-1]
                        if i > 0:  # add a mark (+/-) in the upper-left corner to identify add/remove an attribute
                            for nnn in range(last_images.shape[0]):
                                last_images[nnn, 2:5, 0:7, :] = 1.
                                if _b_sample_ipt[nnn, i - 1] > 0:
                                    last_images[nnn, 0:7, 2:5, :] = 1.
                                    last_images[nnn, 1:6, 3:4, :] = -1.
                                last_images[nnn, 3:4, 1:6, :] = -1.
                    sample = np.concatenate(x_sample_opt_list, 2)
                    im.imwrite(im.immerge(sample, n_sample, 1),
                               '%s/Epoch_(%d)_(%dof%d).jpg' % \
                               (self.config["projectSamples"], epoch,
                                it_in_epoch, it_per_epoch))
    except:
        traceback.print_exc()
    finally:
        # Always checkpoint and release the session on exit.
        save_path = saver.save(
            self.sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
            (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
        print('Model is saved at %s!' % save_path)
        self.sess.close()
def train():
    """End-to-end CycleGAN training entry point (TF2).

    Builds the datasets, two generators, two discriminators, optimizers with
    linear LR decay, the @tf.function train steps, checkpointing, and runs
    the epoch loop with periodic test-set sampling.
    """
    # ===================================== Args =====================================
    args = parse_args()
    output_dir = os.path.join('output', args.dataset)
    os.makedirs(output_dir, exist_ok=True)
    settings_path = os.path.join(output_dir, 'settings.json')
    pylib.args_to_json(settings_path, args)
    # ===================================== Data =====================================
    A_img_paths = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'trainA'), '*.png')
    B_img_paths = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'trainB'), '*.png')
    print(f'len(A_img_paths) = {len(A_img_paths)}')
    print(f'len(B_img_paths) = {len(B_img_paths)}')
    load_size = [args.load_size_height, args.load_size_width]
    crop_size = [args.crop_size_height, args.crop_size_width]
    A_B_dataset, len_dataset = data.make_zip_dataset(A_img_paths, B_img_paths,
                                                     args.batch_size, load_size,
                                                     crop_size, training=True,
                                                     repeat=False)
    # History pools of previously generated fakes fed to the discriminators.
    A2B_pool = data.ItemPool(args.pool_size)
    B2A_pool = data.ItemPool(args.pool_size)
    A_img_paths_test = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'testA'), '*.png')
    B_img_paths_test = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'testB'), '*.png')
    A_B_dataset_test, _ = data.make_zip_dataset(A_img_paths_test,
                                                B_img_paths_test,
                                                args.batch_size, load_size,
                                                crop_size, training=False,
                                                repeat=True)
    # ===================================== Models =====================================
    model_input_shape = crop_size + [
        3
    ]  # [args.crop_size_height, args.crop_size_width, 3]
    G_A2B = module.ResnetGenerator(input_shape=model_input_shape, n_blocks=6)
    G_B2A = module.ResnetGenerator(input_shape=model_input_shape, n_blocks=6)
    D_A = module.ConvDiscriminator(input_shape=model_input_shape)
    D_B = module.ConvDiscriminator(input_shape=model_input_shape)
    d_loss_fn, g_loss_fn = tf2gan.get_adversarial_losses_fn(
        args.adversarial_loss_mode)
    cycle_loss_fn = tf.losses.MeanAbsoluteError()
    identity_loss_fn = tf.losses.MeanAbsoluteError()
    # LR decays linearly to zero after `epoch_decay` epochs.
    G_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset,
                                        args.epoch_decay * len_dataset)
    D_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset,
                                        args.epoch_decay * len_dataset)
    G_optimizer = tf.keras.optimizers.Adam(learning_rate=G_lr_scheduler,
                                           beta_1=args.beta_1)
    D_optimizer = tf.keras.optimizers.Adam(learning_rate=D_lr_scheduler,
                                           beta_1=args.beta_1)

    # ===================================== Training steps =====================================
    @tf.function
    def train_generators(A, B):
        # Joint update of both generators: adversarial + cycle + identity losses.
        with tf.GradientTape() as t:
            A2B = G_A2B(A, training=True)
            B2A = G_B2A(B, training=True)
            A2B2A = G_B2A(A2B, training=True)
            B2A2B = G_A2B(B2A, training=True)
            A2A = G_B2A(A, training=True)  # identity mappings
            B2B = G_A2B(B, training=True)
            A2B_d_logits = D_B(A2B, training=True)
            B2A_d_logits = D_A(B2A, training=True)
            A2B_g_loss = g_loss_fn(A2B_d_logits)
            B2A_g_loss = g_loss_fn(B2A_d_logits)
            A2B2A_cycle_loss = cycle_loss_fn(A, A2B2A)
            B2A2B_cycle_loss = cycle_loss_fn(B, B2A2B)
            A2A_id_loss = identity_loss_fn(A, A2A)
            B2B_id_loss = identity_loss_fn(B, B2B)
            G_loss = (A2B_g_loss + B2A_g_loss) + (
                A2B2A_cycle_loss + B2A2B_cycle_loss) * args.cycle_loss_weight + (
                    A2A_id_loss + B2B_id_loss) * args.identity_loss_weight
        G_grad = t.gradient(
            G_loss, G_A2B.trainable_variables + G_B2A.trainable_variables)
        G_optimizer.apply_gradients(
            zip(G_grad, G_A2B.trainable_variables + G_B2A.trainable_variables))
        return A2B, B2A, {
            'A2B_g_loss': A2B_g_loss,
            'B2A_g_loss': B2A_g_loss,
            'A2B2A_cycle_loss': A2B2A_cycle_loss,
            'B2A2B_cycle_loss': B2A2B_cycle_loss,
            'A2A_id_loss': A2A_id_loss,
            'B2B_id_loss': B2B_id_loss
        }

    @tf.function
    def train_discriminators(A, B, A2B, B2A):
        # Joint update of both discriminators on real batches plus pooled fakes.
        with tf.GradientTape() as t:
            A_d_logits = D_A(A, training=True)
            B2A_d_logits = D_A(B2A, training=True)
            B_d_logits = D_B(B, training=True)
            A2B_d_logits = D_B(A2B, training=True)
            A_d_loss, B2A_d_loss = d_loss_fn(A_d_logits, B2A_d_logits)
            B_d_loss, A2B_d_loss = d_loss_fn(B_d_logits, A2B_d_logits)
            D_A_gp = tf2gan.gradient_penalty(functools.partial(D_A,
                                                               training=True),
                                             A, B2A,
                                             mode=args.gradient_penalty_mode)
            D_B_gp = tf2gan.gradient_penalty(functools.partial(D_B,
                                                               training=True),
                                             B, A2B,
                                             mode=args.gradient_penalty_mode)
            D_loss = (A_d_loss + B2A_d_loss) + (B_d_loss + A2B_d_loss) + (
                D_A_gp + D_B_gp) * args.gradient_penalty_weight
        D_grad = t.gradient(D_loss,
                            D_A.trainable_variables + D_B.trainable_variables)
        D_optimizer.apply_gradients(
            zip(D_grad, D_A.trainable_variables + D_B.trainable_variables))
        return {
            'A_d_loss': A_d_loss + B2A_d_loss,
            'B_d_loss': B_d_loss + A2B_d_loss,
            'D_A_gp': D_A_gp,
            'D_B_gp': D_B_gp
        }

    def train_step(A, B):
        # Deliberately NOT a tf.function: the history pool is Python state.
        A2B, B2A, G_loss_dict = train_generators(A, B)
        # cannot autograph `A2B_pool`
        A2B = A2B_pool(
            A2B)  # or A2B = A2B_pool(A2B.numpy()), but it is much slower
        B2A = B2A_pool(B2A)  # because of the communication between CPU and GPU
        D_loss_dict = train_discriminators(A, B, A2B, B2A)
        return G_loss_dict, D_loss_dict

    @tf.function
    def sample(A, B):
        # Inference-mode translations for the periodic sample grids.
        A2B = G_A2B(A, training=False)
        B2A = G_B2A(B, training=False)
        A2B2A = G_B2A(A2B, training=False)
        B2A2B = G_A2B(B2A, training=False)
        return A2B, B2A, A2B2A, B2A2B

    # ===================================== Runner code =====================================
    # epoch counter
    ep_cnt = tf.Variable(initial_value=0, trainable=False, dtype=tf.int64)
    # checkpoint
    checkpoint = tf2lib.Checkpoint(dict(G_A2B=G_A2B,
                                        G_B2A=G_B2A,
                                        D_A=D_A,
                                        D_B=D_B,
                                        G_optimizer=G_optimizer,
                                        D_optimizer=D_optimizer,
                                        ep_cnt=ep_cnt),
                                   os.path.join(output_dir, 'checkpoints'),
                                   max_to_keep=5)
    try:
        # restore checkpoint including the epoch counter
        checkpoint.restore().assert_existing_objects_matched()
    except Exception as e:
        # First run: nothing to restore — report and start from scratch.
        print(e)
    # summary
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(output_dir, 'summaries', 'train'))
    # sample
    test_iter = iter(A_B_dataset_test)
    sample_dir = os.path.join(output_dir, 'samples_training')
    os.makedirs(sample_dir, exist_ok=True)
    # main loop
    with train_summary_writer.as_default():
        for ep in tqdm.trange(args.epochs, desc='Epoch Loop'):
            # Skip epochs already completed in a restored run.
            if ep < ep_cnt:
                continue
            # update epoch counter
            ep_cnt.assign_add(1)
            # train for an epoch
            for A, B in tqdm.tqdm(A_B_dataset, desc='Inner Epoch Loop',
                                  total=len_dataset):
                G_loss_dict, D_loss_dict = train_step(A, B)
                # # summary
                tf2lib.summary(G_loss_dict, step=G_optimizer.iterations,
                               name='G_losses')
                tf2lib.summary(D_loss_dict, step=G_optimizer.iterations,
                               name='D_losses')
                tf2lib.summary(
                    {'learning rate': G_lr_scheduler.current_learning_rate},
                    step=G_optimizer.iterations,
                    name='learning rate')
                # sample
                if G_optimizer.iterations.numpy() % 100 == 0:
                    A, B = next(test_iter)
                    A2B, B2A, A2B2A, B2A2B = sample(A, B)
                    img = imlib.immerge(np.concatenate(
                        [A, A2B, A2B2A, B, B2A, B2A2B], axis=0), n_rows=6)
                    imlib.imwrite(
                        img,
                        os.path.join(
                            sample_dir, 'iter-%09d.jpg' %
                            G_optimizer.iterations.numpy()))
            # save checkpoint
            checkpoint.save(ep)
# NOTE(review): fragment — continues an expression started outside this view
# (presumably the matching normalization of A2B_diff into [-1, 1]).
1)
B2A_diff = B2A - B2A_dn
# Rescale the residual linearly into [-1, 1] for display.
B2A_diff = (
    2.0 * (B2A_diff - np.min(B2A_diff)) / np.ptp(B2A_diff) - 1)
img = im.immerge(
    np.concatenate(
        [
            A,
            A2B,
            A2B_diff,
            A2B_dn,
            A2B2A,
            B,
            B2A,
            B2A_diff,
            B2A_dn,
            B2A2B,
        ],
        axis=0,
    ),
    n_rows=2,
)
else:
    # No denoised variants available: plain 6-panel CycleGAN grid.
    A2B, B2A, A2B2A, B2A2B = sample(A, B)
    img = im.immerge(
        np.concatenate([A, A2B, A2B2A, B, B2A, B2A2B], axis=0),
        n_rows=2,
    )
# Generator step plus periodic sampling for a mask -> (PET, CT) generator.
G_loss_dict = train_G(x_real)
tl.summary(G_loss_dict, step=G_optimizer.iterations, name='G_losses')
# sample
if G_optimizer.iterations.numpy() % 100 == 0:
    ground_truth = get_Mask(x_real)
    x1_real = get_PET(x_real)
    x2_real = get_CT(x_real)
    # NOTE(review): training=True while sampling — presumably intentional
    # (e.g. batch-norm behavior); confirm against the model definition.
    x1_fake, x2_fake = G(ground_truth, training=True)
    # Keep only the first output channel of each generated modality.
    x1_fake = x1_fake[:,:,:,0]
    x2_fake = x2_fake[:,:,:,0]
    # Real-image reference strip: PET | CT | mask side by side.
    img1_real = im.immerge(x1_real.numpy(),n_rows=10)
    img2_real = im.immerge(x2_real.numpy(),n_rows=10)
    img3_real = im.immerge(ground_truth.numpy(),n_rows=10)
    img4_real = tf.concat([img1_real[:,:,0],img2_real[:,:,0],img3_real[:,:,0]],-1)
    im.imwrite(img4_real.numpy(),
               py.join(sample_dir,
                       'img4R-iter-%09d.jpg' % G_optimizer.iterations.numpy()))
    #print('\n Shape of the generated images:')
    #print('x1_fake.shape = ', x1_fake.shape)
    #print('x2_fake.shape = ', x2_fake.shape)
    # Generated strip: fake PET | fake CT | mask.
    img1 = im.immerge(x1_fake.numpy(), n_rows=10)
    img2 = im.immerge(x2_fake.numpy(), n_rows=10)
    img3 = im.immerge(ground_truth.numpy(), n_rows=10)
    img4 = tf.concat([img1,img2,img3[:,:,0]],-1)
    im.imwrite(img4.numpy(),
               py.join(sample_dir,
                       'img4-iter-%09d.jpg' % G_optimizer.iterations.numpy()))
# NOTE(review): fragment — continues a sess.run(..., feed_dict={...}) call
# whose opening lies outside this view.
_b_sample: b_sample_ipt })
x_sample_opt_list.append(x_sample_result)
mask_sample_opt_list.append(mask_sample_result)
last_images = x_sample_opt_list[-1]
print('last_images shape: ', last_images.shape)
if i > 0:  # add a mark (+/-) in the upper-left corner to identify add/remove an attribute
    for nnn in range(last_images.shape[0]):
        last_images[nnn, 2:5, 0:7, :] = 1.
        if b_sample_ipt[nnn] > 0:
            last_images[nnn, 0:7, 2:5, :] = 1.
            last_images[nnn, 1:6, 3:4, :] = -1.
        last_images[nnn, 3:4, 1:6, :] = -1.
sample = np.concatenate(x_sample_opt_list, 2)
masks = np.concatenate(mask_sample_opt_list, 2)
save_dir = './output/%s/sample_training' % experiment_name
pylib.mkdir(save_dir)
im.imwrite(im.immerge(sample, n_sample, 1),
           '%s/Epoch_(%d)_(%dof%d).jpg' % \
           (save_dir, epoch, it_in_epoch, it_per_epoch))
im.imwrite(im.immerge(masks, n_sample, 1),
           '%s/Mask_Epoch_(%d)_(%dof%d).jpg' % \
           (save_dir, epoch, it_in_epoch, it_per_epoch))
except:
    traceback.print_exc()
finally:
    # Always checkpoint and close the session on exit.
    save_path = saver.save(
        sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
        (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
    print('Model is saved at %s!' % save_path)
    sess.close()
# batch data z_ipt = np.random.normal(size=[batch_size, z_dim]) g_summary_opt, _ = sess.run([g_summary, g_step], feed_dict={z: z_ipt}) summary_writer.add_summary(g_summary_opt, it) # display if it % 1 == 0: print("Epoch: (%3d) (%5d/%5d)" % (ep, i + 1, it_per_epoch)) # sample if it % 1000 == 0: f_sample_opt = sess.run(f_sample, feed_dict={ z_sample: z_ipt_sample }).squeeze() save_dir = './output/%s/sample_training' % experiment_name pylib.mkdir(save_dir) im.imwrite( im.immerge(f_sample_opt), '%s/Epoch_(%d)_(%dof%d).jpg' % (save_dir, ep, i + 1, it_per_epoch)) save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep)) print('Model is saved in file: %s' % save_path) except: traceback.print_exc() finally: sess.close()
# Dump `args.num_samples_per_class` single-image JPGs from the small-ImageNet
# loader (NCHW torch tensors -> NHWC numpy for the image writer).
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
py.mkdir(args.out_dir)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
img_paths = 'data/imagenet_small/train'
data_loader, shape = data.make_custom_dataset(img_paths, 1, resize=32,
                                              pin_memory=use_gpu)
num_samples = args.num_samples_per_class
out_path = py.join(args.out_dir, args.output_dir)
os.makedirs(out_path, exist_ok=True)
# Keep drawing batches — restarting the loader if it is exhausted — until the
# requested number of samples has been written.
while num_samples > 0:
    for x_real, _labels in data_loader:  # labels are unused
        # (N, C, H, W) -> (N, H, W, C) for the image writer.
        x_real = np.transpose(x_real.data.cpu().numpy(), (0, 2, 3, 1))
        img = im.immerge(x_real, n_rows=1).squeeze()
        im.imwrite(
            img,
            py.join(out_path, 'img-%d-%d.jpg' % (num_samples, args.num)))
        num_samples -= 1
        print('saving ', num_samples)
        if num_samples <= 0:
            # Fix: previously the inner loop kept iterating the whole loader
            # as a no-op after the quota was met; stop as soon as we're done.
            break
def train(self):
    """STGAN training (TF1 graph mode).

    Builds the CelebA input pipelines, the Genc/Gstu/Gdec/D graph, the
    WGAN-GP + attribute-classification losses, summaries, then runs the
    session training loop with periodic checkpointing and sampling.
    """
    ckpt_dir = self.config["projectCheckpoints"]
    epoch = self.config["totalEpoch"]
    n_d = self.config["dStep"]
    atts = self.config["selectedAttrs"]
    thres_int = self.config["thresInt"]
    test_int = self.config["sampleThresInt"]
    n_sample = self.config["sampleNum"]
    img_size = self.config["imsize"]
    sample_freq = self.config["sampleEpoch"]
    save_freq = self.config["modelSaveEpoch"]
    lr_base = self.config["gLr"]
    lrDecayEpoch = self.config["lrDecayEpoch"]
    n_att = len(self.config["selectedAttrs"])
    # Session: a non-negative "threads" pins an explicit CPU configuration;
    # otherwise fall back to the library default session.
    if self.config["threads"] >= 0:
        cpu_config = tf.ConfigProto(
            intra_op_parallelism_threads=self.config["threads"] // 2,
            inter_op_parallelism_threads=self.config["threads"] // 2,
            device_count={'CPU': self.config["threads"]})
        cpu_config.gpu_options.allow_growth = True
        sess = tf.Session(config=cpu_config)
    else:
        sess = tl.session()
    data_loader = Celeba(self.config["dataset_path"],
                         self.config["selectedAttrs"],
                         self.config["imsize"],
                         self.config["batchSize"],
                         part='train',
                         sess=sess,
                         crop=(self.config["imCropSize"] > 0))
    val_loader = Celeba(self.config["dataset_path"],
                        self.config["selectedAttrs"],
                        self.config["imsize"],
                        self.config["sampleNum"],
                        part='val',
                        shuffle=False,
                        sess=sess,
                        crop=(self.config["imCropSize"] > 0))
    # Model components are resolved dynamically from the configured scripts.
    package = __import__("components." + self.config["modelScriptName"],
                         fromlist=True)
    GencClass = getattr(package, 'Genc')
    GdecClass = getattr(package, 'Gdec')
    DClass = getattr(package, 'D')
    GP = getattr(package, "gradient_penalty")
    package = __import__("components.STU." + self.config["stuScriptName"],
                         fromlist=True)
    GstuClass = getattr(package, 'Gstu')
    Genc = partial(GencClass,
                   dim=self.config["GConvDim"],
                   n_layers=self.config["GLayerNum"],
                   multi_inputs=1)
    Gdec = partial(GdecClass,
                   dim=self.config["GConvDim"],
                   n_layers=self.config["GLayerNum"],
                   shortcut_layers=self.config["skipNum"],
                   inject_layers=self.config["injectLayers"],
                   one_more_conv=self.config["oneMoreConv"])
    Gstu = partial(GstuClass,
                   dim=self.config["stuDim"],
                   n_layers=self.config["skipNum"],
                   inject_layers=self.config["skipNum"],
                   kernel_size=self.config["stuKS"],
                   norm=None,
                   pass_state='stu')
    D = partial(DClass,
                n_att=n_att,
                dim=self.config["DConvDim"],
                fc_dim=self.config["DFcDim"],
                n_layers=self.config["DLayerNum"])
    # inputs
    xa = data_loader.batch_op[0]
    a = data_loader.batch_op[1]
    # Target attributes are a random shuffle of the source batch's labels.
    b = tf.random_shuffle(a)
    # Map {0,1} labels into [-thresInt, thresInt].
    _a = (tf.to_float(a) * 2 - 1) * self.config["thresInt"]
    _b = (tf.to_float(b) * 2 - 1) * self.config["thresInt"]
    xa_sample = tf.placeholder(
        tf.float32,
        shape=[None, self.config["imsize"], self.config["imsize"], 3])
    _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
    raw_b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
    lr = tf.placeholder(tf.float32, shape=[])
    # generate: edit towards b, and reconstruct with a zero attribute delta.
    z = Genc(xa)
    zb = Gstu(z, _b - _a)
    xb_ = Gdec(zb, _b - _a)
    with tf.control_dependencies([xb_]):
        za = Gstu(z, _a - _a)
        xa_ = Gdec(za, _a - _a)
    # discriminate
    xa_logit_gan, xa_logit_att = D(xa)
    xb__logit_gan, xb__logit_att = D(xb_)
    # WGAN critic loss + gradient penalty + attribute classification on reals.
    wd = tf.reduce_mean(xa_logit_gan) - tf.reduce_mean(xb__logit_gan)
    d_loss_gan = -wd
    gp = GP(D, xa, xb_)
    xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
    d_loss = d_loss_gan + gp * 10.0 + xa_loss_att
    # Generator: fool the critic, hit the target attributes, reconstruct xa.
    xb__loss_gan = -tf.reduce_mean(xb__logit_gan)
    xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
    xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
    g_loss = xb__loss_gan + xb__loss_att * 10.0 + xa__loss_rec * self.config[
        "recWeight"]
    d_var = tl.trainable_variables('D')
    d_step = tf.train.AdamOptimizer(
        lr, beta1=self.config["beta1"]).minimize(d_loss, var_list=d_var)
    g_var = tl.trainable_variables('G')
    g_step = tf.train.AdamOptimizer(
        lr, beta1=self.config["beta1"]).minimize(g_loss, var_list=g_var)
    d_summary = tl.summary(
        {
            d_loss_gan: 'd_loss_gan',
            gp: 'gp',
            xa_loss_att: 'xa_loss_att',
        }, scope='D')
    lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')
    g_summary = tl.summary(
        {
            xb__loss_gan: 'xb__loss_gan',
            xb__loss_att: 'xb__loss_att',
            xa__loss_rec: 'xa__loss_rec',
        }, scope='G')
    # Log the learning rate alongside the discriminator summaries.
    d_summary = tf.summary.merge([d_summary, lr_summary])
    # sample graph (inference mode)
    test_label = _b_sample - raw_b_sample
    x_sample = Gdec(Gstu(Genc(xa_sample, is_training=False),
                         test_label, is_training=False),
                    test_label, is_training=False)
    it_cnt, update_cnt = tl.counter()
    # saver
    saver = tf.train.Saver(max_to_keep=self.config["max2Keep"])
    # summary writer
    summary_writer = tf.summary.FileWriter(self.config["projectSummary"],
                                           sess.graph)
    # initialization
    if self.config["mode"] == "finetune":
        print("Continute train the model")
        tl.load_checkpoint(ckpt_dir, sess)
        print("Load previous model successfully!")
    else:
        print('Initializing all parameters...')
        sess.run(tf.global_variables_initializer())
    # train
    try:
        # data for sampling
        xa_sample_ipt, a_sample_ipt = val_loader.get_next()
        b_sample_ipt_list = [a_sample_ipt]  # the first is for reconstruction
        for i in range(len(atts)):
            tmp = np.array(a_sample_ipt, copy=True)
            tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
            tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
            b_sample_ipt_list.append(tmp)
        it_per_epoch = len(data_loader) // (self.config["batchSize"] *
                                            (n_d + 1))
        max_it = epoch * it_per_epoch
        print("Start to train the graph!")
        # Resume from the persisted iteration counter.
        for it in range(sess.run(it_cnt), max_it):
            with pylib.Timer(is_output=False) as t:
                sess.run(update_cnt)
                # which epoch
                epoch = it // it_per_epoch
                it_in_epoch = it % it_per_epoch + 1
                # learning rate: divide by 10 every lrDecayEpoch epochs
                lr_ipt = lr_base / (10**(epoch // lrDecayEpoch))
                # train D
                for i in range(n_d):
                    d_summary_opt, _ = sess.run([d_summary, d_step],
                                                feed_dict={lr: lr_ipt})
                summary_writer.add_summary(d_summary_opt, it)
                # train G
                g_summary_opt, _ = sess.run([g_summary, g_step],
                                            feed_dict={lr: lr_ipt})
                summary_writer.add_summary(g_summary_opt, it)
                # display
                if (it + 1) % 100 == 0:
                    print("Epoch: (%3d) (%5d/%5d) Time: %s!" %
                          (epoch, it_in_epoch, it_per_epoch, t))
                # save
                if (it + 1) % (save_freq if save_freq else it_per_epoch) == 0:
                    save_path = saver.save(
                        sess, '%s/Epoch_(%d).ckpt' % (ckpt_dir, epoch))
                    print('Model is saved at %s!' % save_path)
                # sample
                if (it + 1) % (sample_freq if sample_freq else it_per_epoch) == 0:
                    # Leading columns: inputs plus a thin -1 separator strip.
                    x_sample_opt_list = [
                        xa_sample_ipt,
                        np.full((n_sample, img_size, img_size // 10, 3), -1.0)
                    ]
                    raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 -
                                        1) * thres_int
                    for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                        _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                        if i > 0:  # i == 0 is for reconstruction
                            # Amplify only the flipped attribute.
                            _b_sample_ipt[..., i - 1] = _b_sample_ipt[
                                ..., i - 1] * test_int / thres_int
                        x_sample_opt_list.append(
                            sess.run(x_sample,
                                     feed_dict={
                                         xa_sample: xa_sample_ipt,
                                         _b_sample: _b_sample_ipt,
                                         raw_b_sample: raw_b_sample_ipt
                                     }))
                        last_images = x_sample_opt_list[-1]
                        if i > 0:  # add a mark (+/-) in the upper-left corner to identify add/remove an attribute
                            for nnn in range(last_images.shape[0]):
                                last_images[nnn, 2:5, 0:7, :] = 1.
                                if _b_sample_ipt[nnn, i - 1] > 0:
                                    last_images[nnn, 0:7, 2:5, :] = 1.
                                    last_images[nnn, 1:6, 3:4, :] = -1.
                                last_images[nnn, 3:4, 1:6, :] = -1.
                    sample = np.concatenate(x_sample_opt_list, 2)
                    im.imwrite(im.immerge(sample, n_sample, 1),
                               '%s/Epoch_(%d)_(%dof%d).jpg' % \
                               (self.config["projectSamples"], epoch,
                                it_in_epoch, it_per_epoch))
    except:
        traceback.print_exc()
    finally:
        # Always checkpoint and release the session on exit.
        save_path = saver.save(
            sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
            (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
        print('Model is saved at %s!' % save_path)
        sess.close()
def train_CycleGAN(): import logGPU_RAM # summary train_summary_writer = tf.summary.create_file_writer( py.join(output_dir, 'summaries', 'train')) logGPU_RAM.init_gpu_writers(py.join(output_dir, 'summaries', 'GPUs')) # sample test_iter = iter(A_B_dataset_test) sample_dir = py.join(output_dir, 'samples_training') py.mkdir(sample_dir) test_sample = next(test_iter) # timeing import time start_time = time.time() # main loop with train_summary_writer.as_default(): for ep in tqdm.trange(args.epochs, desc='Epoch Loop'): if ep < ep_cnt: continue # update epoch counter ep_cnt.assign_add(1) # train for an epoch for A, B in tqdm.tqdm(A_B_dataset, desc='Inner Epoch Loop', total=len_dataset): G_loss_dict, D_loss_dict = train_step(A, B) iteration = G_optimizer.iterations.numpy() # # summary tl.summary(G_loss_dict, step=iteration, name='G_losses') tl.summary(D_loss_dict, step=iteration, name='D_losses') tl.summary( {'learning rate': G_lr_scheduler.current_learning_rate}, step=iteration, name='learning rate') tl.summary( {'second since start': np.array(time.time() - start_time)}, step=iteration, name='second_Per_Iteration') logGPU_RAM.log_gpu_memory_to_tensorboard() # sample if iteration % 1000 == 0: A, B = next(test_iter) A2B, B2A, A2B2A, B2A2B = sample(A, B) img = im.immerge(np.concatenate( [A, A2B, A2B2A, B, B2A, B2A2B], axis=0), n_rows=2) im.imwrite( img, py.join(sample_dir, 'iter-%09d-sample-test-random.jpg' % iteration)) if iteration % 100 == 0: A, B = test_sample A2B, B2A, A2B2A, B2A2B = sample(A, B) img = im.immerge(np.concatenate( [A, A2B, A2B2A, B, B2A, B2A2B], axis=0), n_rows=2) im.imwrite( img, py.join( sample_dir, 'iter-%09d-sample-test-specific.jpg' % iteration)) # save checkpoint checkpoint.save(ep)
# NOTE(review): orphaned fragment collapsed onto one physical line — the
# enclosing training-loop header is not in this file, and the text ends
# mid-call at `py.join(` so the final im.imwrite(...) is truncated; not
# runnable as-is. What is visible: per-iteration G/D-loss and learning-rate
# summaries, then every 100 generator iterations a test batch is translated
# and MSE/NCC/SSIM are printed before/after translation (A vs A2B against B).
# Also note: because everything sits on one line, all code after the first
# inline '#' parses as a comment. Kept byte-identical pending reconstruction
# of the missing source.
tl.summary(G_loss_dict, step=G_optimizer.iterations, name='G_losses') tl.summary(D_loss_dict, step=G_optimizer.iterations, name='D_losses') tl.summary({'learning rate': G_lr_scheduler.current_learning_rate}, step=G_optimizer.iterations, name='learning rate') # sample if G_optimizer.iterations.numpy() % 100 == 0: A, B = next(test_iter) A2B, B2A, A2B2A, B2A2B = sample(A, B) img_sum = im.immerge(np.concatenate( [A, A2B, A2B2A, B, B2A, B2A2B], axis=0), n_rows=2) print('MSE before GAN: ', MSE(im.immerge(B), im.immerge(A))) print('MSE after GAN: ', MSE(im.immerge(B), im.immerge(A2B))) print('NCC before GAN: ', NCC(im.immerge(B), im.immerge(A))) print('NCC after GAN: ', NCC(im.immerge(B), im.immerge(A2B))) print('SSIM before GAN: ', SSIM(im.immerge(B), im.immerge(A))) print('SSIM after GAN: ', SSIM(im.immerge(B), im.immerge(A2B))) im.imwrite( img_sum, py.join( sample_dir, 'iter-%09d-overview.png' % G_optimizer.iterations.numpy())) im.imwrite( im.immerge(A), py.join(
# NOTE(review): orphaned fragment collapsed onto one physical line — every
# 2000 iterations it saves a reconstruction/interpolation image and a random
# sample image, then (when fid_stats_path is set) begins computing FID
# activation statistics over 5 sampled batches. The text ends mid-call inside
# fid.calculate_activation_statistics(...), so the fragment is truncated and
# not runnable. All names (sess, img_rec_sample, img_intp_sample, img_sample,
# fid_sample, experiment_name, ep, it_in_epoch, it_per_epoch, ...) come from
# the missing enclosing scope — confirm against the original source. Kept
# byte-identical.
if (it + 1) % 2000 == 0: save_dir = './output/%s/sample_training' % experiment_name pylib.mkdir(save_dir) img_rec_opt_sample, img_intp_opt_sample = sess.run( [img_rec_sample, img_intp_sample], feed_dict={img: img_ipt_sample}) img_rec_opt_sample, img_intp_opt_sample = img_rec_opt_sample.squeeze( ), img_intp_opt_sample.squeeze() # ipt_rec = np.concatenate((img_ipt_sample, img_rec_opt_sample), axis=2).squeeze() img_opt_sample = sess.run(img_sample).squeeze() # im.imwrite(im.immerge(ipt_rec, padding=img_shape[0] // 8), # '%s/Epoch_(%d)_(%dof%d)_img_rec.png' % (save_dir, ep, it_in_epoch, it_per_epoch)) im.imwrite( im.immerge(img_intp_opt_sample, n_col=1, padding=0), '%s/Epoch_(%d)_(%dof%d)_img_intp.png' % (save_dir, ep, it_in_epoch, it_per_epoch)) im.imwrite( im.immerge(img_opt_sample), '%s/Epoch_(%d)_(%dof%d)_img_sample.png' % (save_dir, ep, it_in_epoch, it_per_epoch)) if fid_stats_path: try: mu_gen, sigma_gen = fid.calculate_activation_statistics( im.im2uint( np.concatenate([ sess.run(fid_sample).squeeze() for _ in range(5) ], 0)),
# NOTE(review): orphaned fragment collapsed onto one physical line — every
# 1000 iterations it saves an input/reconstruction grid and a random-sample
# grid, then checkpoints per epoch. The trailing `except:`/`finally:` belong
# to a `try:` outside this chunk, so the fragment does not parse standalone;
# the bare `except:` that only prints the traceback is a broad-catch smell,
# flagged but left untouched here. Because the line starts with '# sample',
# the entire collapsed line currently parses as a comment. Kept
# byte-identical pending reconstruction of the missing source.
# sample if (it + 1) % 1000 == 0: save_dir = './output/%s/sample_training' % experiment_name pylib.mkdir(save_dir) img_rec_opt_sample = sess.run(img_rec_sample, feed_dict={img: img_ipt_sample}) ipt_rec = np.concatenate((img_ipt_sample, img_rec_opt_sample), axis=2).squeeze() img_opt_sample = sess.run(img_sample, feed_dict={ z_sample: z_ipt_sample }).squeeze() im.imwrite( im.immerge(ipt_rec, padding=img_shape[0] // 8), '%s/Epoch_(%d)_(%dof%d)_img_rec.jpg' % (save_dir, ep, it_in_epoch, it_per_epoch)) im.imwrite( im.immerge(img_opt_sample), '%s/Epoch_(%d)_(%dof%d)_img_sample.jpg' % (save_dir, ep, it_in_epoch, it_per_epoch)) save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep)) print('Model is saved in file: %s' % save_path) except: traceback.print_exc() finally: sess.close()
# --- GAN training-script tail: epoch loop with per-iteration D updates, ---
# --- G updates every n_d steps, periodic sampling, and checkpointing.    ---
# NOTE(review): references names defined earlier in the script (output_dir,
# args, dataset, len_dataset, ep_cnt, train_D, train_G, sample,
# train_summary_writer, D_optimizer, G_optimizer, checkpoint, and the
# py/tf/tl/im/tqdm modules) — confirm upstream.
sample_dir = py.join(output_dir, 'samples_training')
py.mkdir(sample_dir)

# main loop
z = tf.random.normal((100, 1, 1, args.z_dim))  # a fixed noise for sampling
with train_summary_writer.as_default():
    for ep in tqdm.trange(args.epochs, desc='Epoch Loop'):
        # skip epochs already completed in a previous run (resume support)
        if ep < ep_cnt:
            continue

        # update epoch counter
        ep_cnt.assign_add(1)

        # train for an epoch
        for x_real in tqdm.tqdm(dataset, desc='Inner Epoch Loop',
                                total=len_dataset):
            # discriminator step every iteration
            D_loss_dict = train_D(x_real)
            tl.summary(D_loss_dict, step=D_optimizer.iterations,
                       name='D_losses')

            # generator step only once per n_d discriminator steps
            if D_optimizer.iterations.numpy() % args.n_d == 0:
                G_loss_dict = train_G()
                tl.summary(G_loss_dict, step=G_optimizer.iterations,
                           name='G_losses')

            # sample: a 10x10 grid from the fixed noise every 100 G steps
            if G_optimizer.iterations.numpy() % 100 == 0:
                x_fake = sample(z)
                img = im.immerge(x_fake, n_rows=10).squeeze()
                im.imwrite(img,
                           py.join(sample_dir,
                                   'iter-%09d.jpg'
                                   % G_optimizer.iterations.numpy()))

        # save checkpoint once per epoch
        checkpoint.save(ep)
# NOTE(review): orphaned fragment collapsed onto one physical line — AttGAN-
# style sampling: for each attribute-edited label vector in b_sample_ipt_list
# (index 0 = reconstruction) it scales the edited attribute by
# test_int/thres_int, runs x_sample, concatenates all outputs into one wide
# grid, writes it, and saves a checkpoint in `finally`. The `except:`/
# `finally:` tail belongs to a `try:` outside this chunk, so the fragment
# does not parse standalone; the bare `except:` only prints the traceback
# (broad-catch smell, flagged but untouched). Kept byte-identical pending
# reconstruction of the missing source.
x_sample_opt_list = [ xa_sample_ipt, np.full((n_sample, img_size, img_size // 10, 3), -1.0) ] for i, b_sample_ipt in enumerate(b_sample_ipt_list): _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int if i > 0: # i == 0 is for reconstruction _b_sample_ipt[..., i - 1] = _b_sample_ipt[ ..., i - 1] * test_int / thres_int x_sample_opt_list.append( sess.run(x_sample, feed_dict={ xa_sample: xa_sample_ipt, _b_sample: _b_sample_ipt })) sample = np.concatenate(x_sample_opt_list, 2) save_dir = './' + experiment_dir + '/%s/sample_training' % experiment_name pylib.mkdir(save_dir) im.imwrite( im.immerge(sample, n_sample, 1), '%s/Epoch_%d_%dof%d.jpg' % (save_dir, epoch, it_in_epoch, it_per_epoch)) except: traceback.print_exc() finally: save_path = saver.save( sess, '%s/Epoch_(%d)_(%dof%d).ckpt' % (ckpt_dir, epoch, it_in_epoch, it_per_epoch)) print('Model is saved at %s!' % save_path) sess.close()
# NOTE(review): orphaned fragment collapsed onto one physical line — AttGAN
# training-loop tail: logs the G summary, prints progress every iteration
# (`% 1`), checkpoints every 1000 iterations, and every 100 iterations runs
# the same attribute-sweep sampling as the fragment above. The `except:`/
# `finally:` tail belongs to a `try:` outside this chunk, so the fragment
# does not parse standalone; the bare `except:` only prints the traceback
# (broad-catch smell, flagged but untouched). Kept byte-identical pending
# reconstruction of the missing source.
summary_writer.add_summary(g_summary_opt, it) # display if (it + 1) % 1 == 0: print("Epoch: (%3d) (%5d/%5d) Time: %s!" % (epoch, it_in_epoch, it_per_epoch, t)) # save if (it + 1) % 1000 == 0: save_path = saver.save(sess, '%s/Epoch_(%d)_(%dof%d).ckpt' % (ckpt_dir, epoch, it_in_epoch, it_per_epoch)) print('Model is saved at %s!' % save_path) # sample if (it + 1) % 100 == 0: x_sample_opt_list = [xa_sample_ipt, np.full((n_sample, img_size, img_size // 10, 3), -1.0)] for i, b_sample_ipt in enumerate(b_sample_ipt_list): _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int if i > 0: # i == 0 is for reconstruction _b_sample_ipt[..., i - 1] = _b_sample_ipt[..., i - 1] * test_int / thres_int x_sample_opt_list.append(sess.run(x_sample, feed_dict={xa_sample: xa_sample_ipt, _b_sample: _b_sample_ipt})) sample = np.concatenate(x_sample_opt_list, 2) save_dir = './output/%s/sample_training' % experiment_name pylib.mkdir(save_dir) im.imwrite(im.immerge(sample, n_sample, 1), '%s/Epoch_(%d)_(%dof%d).jpg' % (save_dir, epoch, it_in_epoch, it_per_epoch)) except: traceback.print_exc() finally: save_path = saver.save(sess, '%s/Epoch_(%d)_(%dof%d).ckpt' % (ckpt_dir, epoch, it_in_epoch, it_per_epoch)) print('Model is saved at %s!' % save_path) sess.close()
def run(epoch, iter):
    """Generate one batch from the model and save it as a square image grid.

    Uses script-level globals: `sess` and `x_f` (the generator output
    tensor), `im`, `args.n_samples`, and `save_dir`. The grid has
    int(sqrt(n_samples)) rows; the file is named after the epoch/iteration.

    NOTE(review): the parameter `iter` shadows the builtin; kept to
    preserve the existing call signature.
    """
    grid_rows = int(args.n_samples**0.5)
    generated = sess.run(x_f)
    merged = im.immerge(generated, n_rows=grid_rows)
    im.imwrite(merged, '%s/Epoch-%d_Iter-%d.jpg' % (save_dir, epoch, iter))