def D_train_graph():
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])
    xa, a = train_iter.get_next()
    b = tf.random_shuffle(a)
    a_ = a * 2 - 1
    b_ = b * 2 - 1

    # generate
    xb, _, ms, _ = G(xa, b_ - a_)

    # discriminate
    xa_logit_gan, xa_logit_att = D(xa)
    xb_logit_gan, xb_logit_att = D(xb)

    # discriminator losses
    xa_loss_gan, xb_loss_gan = d_loss_fn(xa_logit_gan, xb_logit_gan)
    gp = tfprob.gradient_penalty(lambda x: D(x)[0], xa, xb,
                                 args.gradient_penalty_mode,
                                 args.gradient_penalty_sample_mode)
    xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
    reg_loss = tf.reduce_sum(D.func.reg_losses)

    loss = (xa_loss_gan + xb_loss_gan +
            gp * args.d_gradient_penalty_weight +
            xa_loss_att * args.d_attribute_loss_weight +
            reg_loss)

    # optim
    step_cnt, _ = tl.counter()
    step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(
        loss, global_step=step_cnt, var_list=D.func.trainable_variables)

    # summary
    with tf.contrib.summary.create_file_writer('./output/%s/summaries/D' % args.experiment_name).as_default(),\
         tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = [
            tl.summary_v2(
                {
                    'loss_gan': xa_loss_gan + xb_loss_gan,
                    'gp': gp,
                    'xa_loss_att': xa_loss_att,
                    'reg_loss': reg_loss
                },
                step=step_cnt, name='D'),
            tl.summary_v2({'lr': lr}, step=step_cnt, name='learning_rate')
        ]

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
        xb__loss_att: 'xb__loss_att',
        xa__loss_rec: 'xa__loss_rec',
    }, scope='G')
d_summary = tf.summary.merge([d_summary, lr_summary])

# sample
x_sample = Gdec(Genc(xa_sample, is_training=False), _b_sample, is_training=False)


# ==============================================================================
# =                                    train                                   =
# ==============================================================================

# iteration counter
it_cnt, update_cnt = tl.counter()

# saver
saver = tf.train.Saver(max_to_keep=1)

# summary writer
summary_writer = tf.summary.FileWriter('./output/%s/summaries' % experiment_name, sess.graph)

# initialization
ckpt_dir = './output/%s/checkpoints' % experiment_name
pylib.mkdir(ckpt_dir)
try:
    tl.load_checkpoint(ckpt_dir, sess)
except:
    sess.run(tf.global_variables_initializer())
def G_train_graph():
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])
    xa, a = train_iter.get_next()
    b = tf.random_shuffle(a)
    a_ = a * 2 - 1
    b_ = b * 2 - 1

    # generate
    xb, _, ms, ms_multi = G(xa, b_ - a_)

    # discriminate
    xb_logit_gan, xb_logit_att = D(xb)

    # generator losses
    xb_loss_gan = g_loss_fn(xb_logit_gan)
    xb_loss_att = tf.losses.sigmoid_cross_entropy(b, xb_logit_att)
    spasity_loss = tf.reduce_sum([
        tf.reduce_mean(m) * w
        for m, w in zip(ms, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    ])
    full_overlap_mask_pair_loss, non_overlap_mask_pair_loss = module.overlap_loss_fn(
        ms_multi, args.att_names)
    reg_loss = tf.reduce_sum(G.func.reg_losses)

    loss = (xb_loss_gan +
            xb_loss_att * args.g_attribute_loss_weight +
            spasity_loss * args.g_spasity_loss_weight +
            full_overlap_mask_pair_loss * args.g_full_overlap_mask_pair_loss_weight +
            non_overlap_mask_pair_loss * args.g_non_overlap_mask_pair_loss_weight +
            reg_loss)

    # optim
    step_cnt, _ = tl.counter()
    step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(
        loss, global_step=step_cnt, var_list=G.func.trainable_variables)

    # summary
    with tf.contrib.summary.create_file_writer('./output/%s/summaries/G' % args.experiment_name).as_default(),\
         tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = tl.summary_v2(
            {
                'xb_loss_gan': xb_loss_gan,
                'xb_loss_att': xb_loss_att,
                'spasity_loss': spasity_loss,
                'full_overlap_mask_pair_loss': full_overlap_mask_pair_loss,
                'non_overlap_mask_pair_loss': non_overlap_mask_pair_loss,
                'reg_loss': reg_loss
            },
            step=step_cnt, name='G')

    # ======================================
    # =           generator size           =
    # ======================================

    n_params, n_bytes = tl.count_parameters(G.func.variables)
    print('Generator Size: n_parameters = %d = %.2fMB' % (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
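# Hedged usage sketch (not part of the original script): D_train_graph() and
# G_train_graph() above each return a `run(**pl_ipts)` closure that feeds the
# learning rate and executes one optimization step. The helper below shows how
# such closures might be driven; its name, arguments, and the step-decay
# schedule are illustrative assumptions, not taken from the code above.
def _example_training_loop(max_it, it_per_epoch, lr_base, lr_decay_epoch, n_d):
    d_train_step = D_train_graph()
    g_train_step = G_train_graph()
    sess.run(tf.global_variables_initializer())
    for it in range(max_it):
        epoch = it // it_per_epoch
        lr_ipt = lr_base / (10 ** (epoch // lr_decay_epoch))  # assumed: decay lr by 10x every lr_decay_epoch epochs
        for _ in range(n_d):          # assumed: several D updates per G update
            d_train_step(lr=lr_ipt)
        g_train_step(lr=lr_ipt)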
def G_train_graph():
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])
    xa, a = train_iter.get_next()
    b = tf.random_shuffle(a)
    a_ = a * 2 - 1
    b_ = b * 2 - 1

    # generate
    z = Genc(xa)
    xa_ = Gdec(z, a_)
    xb_ = Gdec(z, b_)

    # discriminate
    xb__logit_gan, xb__logit_att = D(xb_)

    # generator losses
    xb__loss_gan = g_loss_fn(xb__logit_gan)
    xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
    xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
    reg_loss = tf.reduce_sum(Genc.func.reg_losses + Gdec.func.reg_losses)

    loss = (xb__loss_gan +
            xb__loss_att * args.g_attribute_loss_weight +
            xa__loss_rec * args.g_reconstruction_loss_weight +
            reg_loss)

    # optim
    step_cnt, _ = tl.counter()
    step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(
        loss, global_step=step_cnt,
        var_list=Genc.func.trainable_variables + Gdec.func.trainable_variables)

    # summary
    with tf.contrib.summary.create_file_writer('./output/%s/summaries/G' % args.experiment_name).as_default(),\
         tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
        summary = tl.summary_v2(
            {
                'xb__loss_gan': xb__loss_gan,
                'xb__loss_att': xb__loss_att,
                'xa__loss_rec': xa__loss_rec,
                'reg_loss': reg_loss
            },
            step=step_cnt, name='G')

    # ======================================
    # =           generator size           =
    # ======================================

    n_params, n_bytes = tl.count_parameters(Genc.func.variables + Gdec.func.variables)
    print('Generator Size: n_parameters = %d = %.2fMB' % (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
def train(self):
    it_cnt, update_cnt = tl.counter()

    # saver
    saver = tf.train.Saver(max_to_keep=10)

    # summary writer
    summary_writer = tf.summary.FileWriter(self.config["projectSummary"], self.sess.graph)

    # initialization
    ckpt_dir = self.config["projectCheckpoints"]
    epoch = self.config["totalEpoch"]
    n_d = self.config["dStep"]
    atts = self.config["selectedAttrs"]
    thres_int = self.config["thresInt"]
    test_int = self.config["sampleThresInt"]
    n_sample = self.config["sampleNum"]
    img_size = self.config["imsize"]
    sample_freq = self.config["sampleEpoch"]
    save_freq = self.config["modelSaveEpoch"]
    lr_base = self.config["gLr"]
    lrDecayEpoch = self.config["lrDecayEpoch"]

    try:
        # `clear` is expected to be defined in the enclosing scope; if it is
        # missing or True, the except branch falls back to fresh initialization
        assert clear == False
        tl.load_checkpoint(ckpt_dir, self.sess)
    except:
        print('NOTE: Initializing all parameters...')
        self.sess.run(tf.global_variables_initializer())

    # train
    try:
        # data for sampling
        xa_sample_ipt, a_sample_ipt = self.val_loader.get_next()
        b_sample_ipt_list = [a_sample_ipt]  # the first is for reconstruction
        for i in range(len(atts)):
            tmp = np.array(a_sample_ipt, copy=True)
            tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
            tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
            b_sample_ipt_list.append(tmp)

        it_per_epoch = len(self.data_loader) // (self.config["batchSize"] * (n_d + 1))
        max_it = epoch * it_per_epoch
        for it in range(self.sess.run(it_cnt), max_it):
            with pylib.Timer(is_output=False) as t:
                self.sess.run(update_cnt)

                # which epoch
                epoch = it // it_per_epoch
                it_in_epoch = it % it_per_epoch + 1

                # learning rate
                lr_ipt = lr_base / (10**(epoch // lrDecayEpoch))

                # train D
                for i in range(n_d):
                    d_summary_opt, _ = self.sess.run(
                        [self.d_summary, self.d_step],
                        feed_dict={self.lr: lr_ipt})
                summary_writer.add_summary(d_summary_opt, it)

                # train G
                g_summary_opt, _ = self.sess.run(
                    [self.g_summary, self.g_step],
                    feed_dict={self.lr: lr_ipt})
                summary_writer.add_summary(g_summary_opt, it)

                # display
                if (it + 1) % 100 == 0:
                    print("Epoch: (%3d) (%5d/%5d) Time: %s!" %
                          (epoch, it_in_epoch, it_per_epoch, t))

                # save
                if (it + 1) % (save_freq if save_freq else it_per_epoch) == 0:
                    save_path = saver.save(
                        self.sess,
                        '%s/Epoch_(%d)_(%dof%d).ckpt' %
                        (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
                    print('Model is saved at %s!' % save_path)

                # sample
                if (it + 1) % (sample_freq if sample_freq else it_per_epoch) == 0:
                    x_sample_opt_list = [
                        xa_sample_ipt,
                        np.full((n_sample, img_size, img_size // 10, 3), -1.0)
                    ]
                    raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 - 1) * thres_int
                    for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                        _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                        if i > 0:  # i == 0 is for reconstruction
                            _b_sample_ipt[..., i - 1] = _b_sample_ipt[..., i - 1] * test_int / thres_int
                        x_sample_opt_list.append(
                            self.sess.run(self.x_sample,
                                          feed_dict={
                                              self.xa_sample: xa_sample_ipt,
                                              self._b_sample: _b_sample_ipt,
                                              self.raw_b_sample: raw_b_sample_ipt
                                          }))
                        last_images = x_sample_opt_list[-1]
                        if i > 0:  # add a mark (+/-) in the upper-left corner to identify add/remove an attribute
                            for nnn in range(last_images.shape[0]):
                                last_images[nnn, 2:5, 0:7, :] = 1.
                                if _b_sample_ipt[nnn, i - 1] > 0:
                                    last_images[nnn, 0:7, 2:5, :] = 1.
                                    last_images[nnn, 1:6, 3:4, :] = -1.
                                last_images[nnn, 3:4, 1:6, :] = -1.
                    sample = np.concatenate(x_sample_opt_list, 2)
                    im.imwrite(im.immerge(sample, n_sample, 1),
                               '%s/Epoch_(%d)_(%dof%d).jpg' %
                               (self.config["projectSamples"], epoch, it_in_epoch, it_per_epoch))
    except:
        traceback.print_exc()
    finally:
        save_path = saver.save(
            self.sess,
            '%s/Epoch_(%d)_(%dof%d).ckpt' %
            (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
        print('Model is saved at %s!' % save_path)
        self.sess.close()
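# Hedged illustration (not part of the original trainer): the label arithmetic
# used in the sampling branch above maps binary attributes in {0, 1} to
# {-thres_int, +thres_int}, then rescales only the edited attribute to
# +/-test_int. The helper name and values below are illustrative.
def _example_label_scaling():
    import numpy as np
    a = np.array([[1, 0, 1]], dtype=np.float32)   # source attributes
    thres_int, test_int = 0.5, 1.0
    _a = (a * 2 - 1) * thres_int                  # [[ 0.5, -0.5,  0.5]]
    b = a.copy()
    b[:, 1] = 1 - b[:, 1]                         # flip attribute 1
    _b = (b * 2 - 1) * thres_int                  # [[ 0.5,  0.5,  0.5]]
    _b[:, 1] = _b[:, 1] * test_int / thres_int    # edited attribute scaled to 1.0
    return _a, _b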
def train(self):
    ckpt_dir = self.config["projectCheckpoints"]
    epoch = self.config["totalEpoch"]
    n_d = self.config["dStep"]
    atts = self.config["selectedAttrs"]
    thres_int = self.config["thresInt"]
    test_int = self.config["sampleThresInt"]
    n_sample = self.config["sampleNum"]
    img_size = self.config["imsize"]
    sample_freq = self.config["sampleEpoch"]
    save_freq = self.config["modelSaveEpoch"]
    lr_base = self.config["gLr"]
    lrDecayEpoch = self.config["lrDecayEpoch"]
    n_att = len(self.config["selectedAttrs"])

    # session
    if self.config["threads"] >= 0:
        cpu_config = tf.ConfigProto(
            intra_op_parallelism_threads=self.config["threads"] // 2,
            inter_op_parallelism_threads=self.config["threads"] // 2,
            device_count={'CPU': self.config["threads"]})
        cpu_config.gpu_options.allow_growth = True
        sess = tf.Session(config=cpu_config)
    else:
        sess = tl.session()

    # data
    data_loader = Celeba(self.config["dataset_path"], self.config["selectedAttrs"],
                         self.config["imsize"], self.config["batchSize"],
                         part='train', sess=sess, crop=(self.config["imCropSize"] > 0))
    val_loader = Celeba(self.config["dataset_path"], self.config["selectedAttrs"],
                        self.config["imsize"], self.config["sampleNum"],
                        part='val', shuffle=False, sess=sess, crop=(self.config["imCropSize"] > 0))

    # models
    package = __import__("components." + self.config["modelScriptName"], fromlist=True)
    GencClass = getattr(package, 'Genc')
    GdecClass = getattr(package, 'Gdec')
    DClass = getattr(package, 'D')
    GP = getattr(package, "gradient_penalty")

    package = __import__("components.STU." + self.config["stuScriptName"], fromlist=True)
    GstuClass = getattr(package, 'Gstu')

    Genc = partial(GencClass,
                   dim=self.config["GConvDim"],
                   n_layers=self.config["GLayerNum"],
                   multi_inputs=1)
    Gdec = partial(GdecClass,
                   dim=self.config["GConvDim"],
                   n_layers=self.config["GLayerNum"],
                   shortcut_layers=self.config["skipNum"],
                   inject_layers=self.config["injectLayers"],
                   one_more_conv=self.config["oneMoreConv"])
    Gstu = partial(GstuClass,
                   dim=self.config["stuDim"],
                   n_layers=self.config["skipNum"],
                   inject_layers=self.config["skipNum"],
                   kernel_size=self.config["stuKS"],
                   norm=None,
                   pass_state='stu')
    D = partial(DClass,
                n_att=n_att,
                dim=self.config["DConvDim"],
                fc_dim=self.config["DFcDim"],
                n_layers=self.config["DLayerNum"])

    # inputs
    xa = data_loader.batch_op[0]
    a = data_loader.batch_op[1]
    b = tf.random_shuffle(a)
    _a = (tf.to_float(a) * 2 - 1) * self.config["thresInt"]
    _b = (tf.to_float(b) * 2 - 1) * self.config["thresInt"]

    xa_sample = tf.placeholder(
        tf.float32, shape=[None, self.config["imsize"], self.config["imsize"], 3])
    _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
    raw_b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
    lr = tf.placeholder(tf.float32, shape=[])

    # generate
    z = Genc(xa)
    zb = Gstu(z, _b - _a)
    xb_ = Gdec(zb, _b - _a)
    with tf.control_dependencies([xb_]):
        za = Gstu(z, _a - _a)
        xa_ = Gdec(za, _a - _a)

    # discriminate
    xa_logit_gan, xa_logit_att = D(xa)
    xb__logit_gan, xb__logit_att = D(xb_)

    # discriminator losses
    wd = tf.reduce_mean(xa_logit_gan) - tf.reduce_mean(xb__logit_gan)
    d_loss_gan = -wd
    gp = GP(D, xa, xb_)
    xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
    d_loss = d_loss_gan + gp * 10.0 + xa_loss_att

    # generator losses
    xb__loss_gan = -tf.reduce_mean(xb__logit_gan)
    xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
    xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
    g_loss = (xb__loss_gan + xb__loss_att * 10.0 +
              xa__loss_rec * self.config["recWeight"])

    # optim
    d_var = tl.trainable_variables('D')
    d_step = tf.train.AdamOptimizer(
        lr, beta1=self.config["beta1"]).minimize(d_loss, var_list=d_var)
    g_var = tl.trainable_variables('G')
    g_step = tf.train.AdamOptimizer(
        lr, beta1=self.config["beta1"]).minimize(g_loss, var_list=g_var)

    # summary
    d_summary = tl.summary({
        d_loss_gan: 'd_loss_gan',
        gp: 'gp',
        xa_loss_att: 'xa_loss_att',
    }, scope='D')
    lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')
    g_summary = tl.summary({
        xb__loss_gan: 'xb__loss_gan',
        xb__loss_att: 'xb__loss_att',
        xa__loss_rec: 'xa__loss_rec',
    }, scope='G')
    d_summary = tf.summary.merge([d_summary, lr_summary])

    # sample
    test_label = _b_sample - raw_b_sample
    x_sample = Gdec(Gstu(Genc(xa_sample, is_training=False),
                         test_label, is_training=False),
                    test_label, is_training=False)

    # iteration counter
    it_cnt, update_cnt = tl.counter()

    # saver
    saver = tf.train.Saver(max_to_keep=self.config["max2Keep"])

    # summary writer
    summary_writer = tf.summary.FileWriter(self.config["projectSummary"], sess.graph)

    # initialization
    if self.config["mode"] == "finetune":
        print("Continue training the model")
        tl.load_checkpoint(ckpt_dir, sess)
        print("Load previous model successfully!")
    else:
        print('Initializing all parameters...')
        sess.run(tf.global_variables_initializer())

    # train
    try:
        # data for sampling
        xa_sample_ipt, a_sample_ipt = val_loader.get_next()
        b_sample_ipt_list = [a_sample_ipt]  # the first is for reconstruction
        for i in range(len(atts)):
            tmp = np.array(a_sample_ipt, copy=True)
            tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
            tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
            b_sample_ipt_list.append(tmp)

        it_per_epoch = len(data_loader) // (self.config["batchSize"] * (n_d + 1))
        max_it = epoch * it_per_epoch
        print("Start to train the graph!")
        for it in range(sess.run(it_cnt), max_it):
            with pylib.Timer(is_output=False) as t:
                sess.run(update_cnt)

                # which epoch
                epoch = it // it_per_epoch
                it_in_epoch = it % it_per_epoch + 1

                # learning rate
                lr_ipt = lr_base / (10**(epoch // lrDecayEpoch))

                # train D
                for i in range(n_d):
                    d_summary_opt, _ = sess.run([d_summary, d_step],
                                                feed_dict={lr: lr_ipt})
                summary_writer.add_summary(d_summary_opt, it)

                # train G
                g_summary_opt, _ = sess.run([g_summary, g_step],
                                            feed_dict={lr: lr_ipt})
                summary_writer.add_summary(g_summary_opt, it)

                # display
                if (it + 1) % 100 == 0:
                    print("Epoch: (%3d) (%5d/%5d) Time: %s!" %
                          (epoch, it_in_epoch, it_per_epoch, t))

                # save
                if (it + 1) % (save_freq if save_freq else it_per_epoch) == 0:
                    save_path = saver.save(
                        sess, '%s/Epoch_(%d).ckpt' % (ckpt_dir, epoch))
                    print('Model is saved at %s!' % save_path)

                # sample
                if (it + 1) % (sample_freq if sample_freq else it_per_epoch) == 0:
                    x_sample_opt_list = [
                        xa_sample_ipt,
                        np.full((n_sample, img_size, img_size // 10, 3), -1.0)
                    ]
                    raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 - 1) * thres_int
                    for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                        _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                        if i > 0:  # i == 0 is for reconstruction
                            _b_sample_ipt[..., i - 1] = _b_sample_ipt[..., i - 1] * test_int / thres_int
                        x_sample_opt_list.append(
                            sess.run(x_sample,
                                     feed_dict={
                                         xa_sample: xa_sample_ipt,
                                         _b_sample: _b_sample_ipt,
                                         raw_b_sample: raw_b_sample_ipt
                                     }))
                        last_images = x_sample_opt_list[-1]
                        if i > 0:  # add a mark (+/-) in the upper-left corner to identify add/remove an attribute
                            for nnn in range(last_images.shape[0]):
                                last_images[nnn, 2:5, 0:7, :] = 1.
                                if _b_sample_ipt[nnn, i - 1] > 0:
                                    last_images[nnn, 0:7, 2:5, :] = 1.
                                    last_images[nnn, 1:6, 3:4, :] = -1.
                                last_images[nnn, 3:4, 1:6, :] = -1.
                    sample = np.concatenate(x_sample_opt_list, 2)
                    im.imwrite(im.immerge(sample, n_sample, 1),
                               '%s/Epoch_(%d)_(%dof%d).jpg' %
                               (self.config["projectSamples"], epoch, it_in_epoch, it_per_epoch))
    except:
        traceback.print_exc()
    finally:
        save_path = saver.save(
            sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
            (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
        print('Model is saved at %s!' % save_path)
        sess.close()
def train_on_fake(dataset, c_dim, result_dir, gpu_id, use_real=0, epoch_=200):
    """ param """
    batch_size = 64
    batch_size_fake = batch_size
    lr = 0.0002

    ''' data '''
    if use_real == 1:
        print('======Using real data======')
        batch_size_real = batch_size // 2
        batch_size_fake = batch_size - batch_size_real
        train_tfrecord_path_real = './tfrecords/celeba_tfrecord_train'
        train_data_pool_real = tl.TfrecordData(train_tfrecord_path_real, batch_size_real, shuffle=True)
    train_tfrecord_path_fake = os.path.join(result_dir, 'synthetic_tfrecord')
    train_data_pool_fake = tl.TfrecordData(train_tfrecord_path_fake, batch_size_fake, shuffle=True)
    if dataset == 'CelebA':
        test_tfrecord_path = './tfrecords/celeba_tfrecord_test'
    elif dataset == 'RaFD':
        test_tfrecord_path = './tfrecords/rafd_test'
    test_data_pool = tl.TfrecordData(test_tfrecord_path, 120)
    att_dim = c_dim

    """ graphs """
    with tf.device('/gpu:{}'.format(gpu_id)):
        ''' models '''
        classifier = models.classifier

        ''' graph '''
        # inputs
        x_255 = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
        x = x_255 / 127.5 - 1
        if dataset == 'CelebA':
            att = tf.placeholder(tf.int64, shape=[None, att_dim])
        elif dataset == 'RaFD':
            att = tf.placeholder(tf.float32, shape=[None, att_dim])

        # classify
        logits = classifier(x, att_dim=att_dim, reuse=False)

        # loss
        reg_loss = tf.losses.get_regularization_loss()
        if dataset == 'CelebA':
            loss = tf.losses.sigmoid_cross_entropy(att, logits) + reg_loss
            acc = mean_accuracy_multi_binary_label_with_logits(att, logits)
        elif dataset == 'RaFD':
            loss = tf.losses.softmax_cross_entropy(att, logits) + reg_loss
            acc = mean_accuracy_one_hot_label_with_logits(att, logits)
        lr_ = tf.placeholder(tf.float32, shape=[])

        # optim
        # with tf.variable_scope('Adam', reuse=tf.AUTO_REUSE):
        step = tf.train.AdamOptimizer(lr_, beta1=0.9).minimize(loss)

        # test
        test_logits = classifier(x, att_dim=att_dim, training=False)
        if dataset == 'CelebA':
            test_acc = mean_accuracy_multi_binary_label_with_logits(att, test_logits)
        elif dataset == 'RaFD':
            test_acc = mean_accuracy_one_hot_label_with_logits(att, test_logits)
        mean_acc = tf.placeholder(tf.float32, shape=())

        # summary
        summary = tl.summary({loss: 'loss', acc: 'acc'})
        test_summary = tl.summary({mean_acc: 'test_acc'})

    """ train """
    ''' init '''
    # session
    sess = tf.Session()
    # iteration counter
    it_cnt, update_cnt = tl.counter()
    # saver
    saver = tf.train.Saver(max_to_keep=None)
    # summary writer
    sum_dir = os.path.join(result_dir, 'summaries_train_on_fake')
    if use_real == 1:
        sum_dir += '_real'
    summary_writer = tf.summary.FileWriter(sum_dir, sess.graph)

    ''' initialization '''
    ckpt_dir = os.path.join(result_dir, 'checkpoints_train_on_fake')
    if use_real == 1:
        ckpt_dir += '_real'
    if not os.path.exists(ckpt_dir):
        os.mkdir(ckpt_dir + '/')
    if not tl.load_checkpoint(ckpt_dir, sess):
        sess.run(tf.global_variables_initializer())

    ''' train '''
    try:
        batch_epoch = len(train_data_pool_fake) // batch_size
        max_it = epoch_ * batch_epoch
        for it in range(sess.run(it_cnt), max_it):
            bth = it // batch_epoch - 8
            lr__ = lr * (1 - max(bth, 0) / epoch_) ** 0.75
            if it % batch_epoch == 0:
                print('======learning rate:', lr__, '======')
            sess.run(update_cnt)

            # which epoch
            epoch = it // batch_epoch
            it_epoch = it % batch_epoch + 1

            # batch data
            x_255_ipt, att_ipt = train_data_pool_fake.batch(['img', 'attr'])
            if dataset == 'RaFD':
                att_ipt = ToOnehot(att_ipt, att_dim)
            if use_real == 1:
                x_255_ipt_real, att_ipt_real = train_data_pool_real.batch(['img', 'class'])
                x_255_ipt = np.concatenate([x_255_ipt, x_255_ipt_real])
                att_ipt = np.concatenate([att_ipt, att_ipt_real])
            summary_opt, _ = sess.run([summary, step],
                                      feed_dict={x_255: x_255_ipt, att: att_ipt, lr_: lr__})
            summary_writer.add_summary(summary_opt, it)

            # display
            if (it + 1) % batch_epoch == 0:
                print("Epoch: (%3d) (%5d/%5d)" % (epoch, it_epoch, batch_epoch))

            # save
            if (it + 1) % (batch_epoch * 50) == 0:
                save_path = saver.save(sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                                       (ckpt_dir, epoch, it_epoch, batch_epoch))
                print('Model saved in file: %s' % save_path)

            # test
            if it % 100 == 0:
                test_it = 100 if dataset == 'CelebA' else 7
                test_acc_opt_list = []
                for i in range(test_it):
                    key = 'class' if dataset == 'CelebA' else 'attr'
                    x_255_ipt, att_ipt = test_data_pool.batch(['img', key])
                    if dataset == 'RaFD':
                        att_ipt = ToOnehot(att_ipt, att_dim)
                    test_acc_opt = sess.run(test_acc, feed_dict={x_255: x_255_ipt, att: att_ipt})
                    test_acc_opt_list.append(test_acc_opt)
                test_summary_opt = sess.run(test_summary,
                                            feed_dict={mean_acc: np.mean(test_acc_opt_list)})
                summary_writer.add_summary(test_summary_opt, it)
    except Exception:
        traceback.print_exc()
    finally:
        print(" [*] Close main session!")
        sess.close()
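# Hedged illustration (not part of the original function): the schedule used in
# train_on_fake keeps the base rate for the first 8 epochs, then decays it
# polynomially. The helper name and numbers below are illustrative only.
def _example_classifier_lr(lr, it, batch_epoch, epoch_):
    bth = it // batch_epoch - 8
    return lr * (1 - max(bth, 0) / epoch_) ** 0.75

# _example_classifier_lr(2e-4, 0, 100, 200)          -> 2e-4   (unchanged during the first 8 epochs)
# _example_classifier_lr(2e-4, 100 * 100, 100, 200)  -> ~1.26e-4 at epoch 100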
    {l: 'g_f_tree_loss_%d' % i for i, l in enumerate(g_f_tree_losses)},
    scope='G_Tree')
g_summary = tf.summary.merge([g_summary, g_tree_summary])

# sample
z_sample = tf.placeholder(tf.float32, [None, z_dim])
c_sample = tf.placeholder(tf.float32, [None, c_dim])
f_sample = G(z_sample, c_sample, is_training=False)


# ==============================================================================
# =                                    train                                   =
# ==============================================================================

# epoch counter
ep_cnt, update_cnt = tl.counter(start=1)

# session
sess = tl.session()

# saver
saver = tf.train.Saver(max_to_keep=1)

# summary writer
summary_writer = tf.summary.FileWriter('./output/%s/summaries' % experiment_name, sess.graph)

# initialization
ckpt_dir = './output/%s/checkpoints' % experiment_name
pylib.mkdir(ckpt_dir)
try:
    scope='D')
lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')
generator_summary = tl.summary(
    {
        discriminator_label_loss: 'discriminator_label_loss',
        label_decoder_loss: 'label_decoder_loss',
        image_decoder_loss: 'image_decoder_loss',
    }, scope='G')
discriminator_summary = tf.summary.merge([discriminator_summary, lr_summary])

# Iteration counter
iteration_counter, update_counter = tl.counter()

# Saver for model
saver = tf.train.Saver(max_to_keep=1)

# Logging information
summary_writer = tf.summary.FileWriter('./output/%s/summaries' % experiment_name, sess.graph)

# Check if training was already begun
checkpoint_dir = './output/%s/checkpoints' % experiment_name
pylib.mkdir(checkpoint_dir)
try:
    tl.load_checkpoint(checkpoint_dir, sess)
except:
    sess.run(tf.global_variables_initializer())
def G_train_graph():
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])
    zs = [tf.random.normal([args.batch_size, z_dim]) for z_dim in args.z_dims]
    eps = tf.random.normal([args.batch_size, args.eps_dim])

    # counter
    step_cnt, _ = tl.counter()

    # optimizer
    optimizer = tf.train.AdamOptimizer(lr, beta1=args.beta_1)

    def graph_per_gpu(zs, eps):
        # generate
        x_f = G(zs, eps)

        # discriminate
        x_f_logit = D(x_f)

        # loss
        x_f_loss = g_loss_fn(x_f_logit)
        orth_loss = tf.reduce_sum(tl.tensors_filter(G.func.reg_losses, 'orthogonal_regularizer'))
        reg_loss = tf.reduce_sum(tl.tensors_filter(G.func.reg_losses, 'l2_regularizer'))
        loss = (x_f_loss * args.g_loss_weight_x_gan +
                orth_loss * args.g_loss_weight_orth_loss +
                reg_loss * args.weight_decay)

        # optim
        grads = optimizer.compute_gradients(loss, var_list=G.func.trainable_variables)

        return grads, x_f_loss, orth_loss, reg_loss

    split_grads, split_x_f_loss, split_orth_loss, split_reg_loss = zip(
        *tl.parellel_run(tl.gpus(), graph_per_gpu, tl.split_nest((zs, eps), len(tl.gpus()))))
    # split_grads, split_x_f_loss, split_orth_loss, split_reg_loss = zip(
    #     *tl.parellel_run(['cpu:0'], graph_per_gpu, tl.split_nest((zs, eps), 1)))
    grads = tl.average_gradients(split_grads)
    x_f_loss, orth_loss, reg_loss = [
        tf.reduce_mean(t) for t in [split_x_f_loss, split_orth_loss, split_reg_loss]
    ]
    step = optimizer.apply_gradients(grads, global_step=step_cnt)

    # moving average
    with tf.control_dependencies([step]):
        step = G_ema.apply(G.func.trainable_variables)

    # summary
    summary_dict = {
        'x_f_loss': x_f_loss,
        'orth_loss': orth_loss,
        'reg_loss': reg_loss
    }
    summary_dict.update({
        'L_%d' % i: t
        for i, t in enumerate(tl.tensors_filter(G.func.variables, 'L'))
    })
    summary_loss = tl.create_summary_statistic_v2(
        summary_dict,
        './output/%s/summaries/G' % args.experiment_name,
        step=step_cnt, n_steps_per_record=10, name='G_loss')
    summary_image = tl.create_summary_image_v2(
        {
            'orth_U_%d' % i: t[None, :, :, None]
            for i, t in enumerate(tf.get_collection('orth', G.func.scope + '/'))
        },
        './output/%s/summaries/G' % args.experiment_name,
        step=step_cnt, n_steps_per_record=10, name='G_image')

    # ======================================
    # =             model size             =
    # ======================================

    n_params, n_bytes = tl.count_parameters(G.func.variables)
    print('Model Size: n_parameters = %d = %.2fMB' % (n_params, n_bytes / 1024 / 1024))

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        sess.run([step, summary_loss, summary_image], feed_dict={lr: pl_ipts['lr']})

    return run
def D_train_graph():
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])
    x_r = train_iter.get_next()
    zs = [tf.random.normal([args.batch_size, z_dim]) for z_dim in args.z_dims]
    eps = tf.random.normal([args.batch_size, args.eps_dim])

    # counter
    step_cnt, _ = tl.counter()

    # optimizer
    optimizer = tf.train.AdamOptimizer(lr, beta1=args.beta_1)

    def graph_per_gpu(x_r, zs, eps):
        # generate
        x_f = G(zs, eps)

        # discriminate
        x_r_logit = D(x_r)
        x_f_logit = D(x_f)

        # loss
        x_r_loss, x_f_loss = d_loss_fn(x_r_logit, x_f_logit)
        x_gp = tf.cond(
            tf.equal(step_cnt % args.d_lazy_reg_period, 0),
            lambda: tfprob.gradient_penalty(
                D, x_r, x_f, args.gradient_penalty_mode,
                args.gradient_penalty_sample_mode) * args.d_lazy_reg_period,
            lambda: tf.constant(0.0))
        if args.d_loss_weight_x_gp == 0:
            x_gp = tf.constant(0.0)
        reg_loss = tf.reduce_sum(D.func.reg_losses)
        loss = ((x_r_loss + x_f_loss) * args.d_loss_weight_x_gan +
                x_gp * args.d_loss_weight_x_gp +
                reg_loss * args.weight_decay)

        # optim
        grads = optimizer.compute_gradients(loss, var_list=D.func.trainable_variables)

        return grads, x_r_loss, x_f_loss, x_gp, reg_loss

    split_grads, split_x_r_loss, split_x_f_loss, split_x_gp, split_reg_loss = zip(
        *tl.parellel_run(tl.gpus(), graph_per_gpu, tl.split_nest((x_r, zs, eps), len(tl.gpus()))))
    # split_grads, split_x_r_loss, split_x_f_loss, split_x_gp, split_reg_loss = zip(
    #     *tl.parellel_run(['cpu:0'], graph_per_gpu, tl.split_nest((x_r, zs, eps), 1)))
    grads = tl.average_gradients(split_grads)
    x_r_loss, x_f_loss, x_gp, reg_loss = [
        tf.reduce_mean(t) for t in [split_x_r_loss, split_x_f_loss, split_x_gp, split_reg_loss]
    ]
    step = optimizer.apply_gradients(grads, global_step=step_cnt)

    # summary
    summary = tl.create_summary_statistic_v2(
        {
            'x_gan_loss': x_r_loss + x_f_loss,
            'x_gp': x_gp,
            'reg_loss': reg_loss,
            'lr': lr
        },
        './output/%s/summaries/D' % args.experiment_name,
        step=step_cnt, n_steps_per_record=10, name='D')

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        for _ in range(args.n_d):
            sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})

    return run
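# Hedged usage sketch (not part of the original script): unlike the earlier pair
# of graph builders, the run() returned by this D_train_graph() already performs
# args.n_d discriminator updates internally, so an assumed driver calls each
# closure once per iteration. The helper name and its arguments are illustrative.
def _example_training_loop_multi_gpu(max_it, lr_ipt):
    d_train_step = D_train_graph()
    g_train_step = G_train_graph()
    sess.run(tf.global_variables_initializer())
    for _ in range(max_it):
        d_train_step(lr=lr_ipt)   # internally loops over args.n_d D steps
        g_train_step(lr=lr_ipt)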