xa_sample = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3]) _b_sample = tf.placeholder(tf.float32, shape=[None, n_att]) # sample x_sample = Gdec(Genc(xa_sample, is_training=False), _b_sample, is_training=False) # ============================================================================== # = test = # ============================================================================== # initialization ckpt_dir = './output/%s/checkpoints' % experiment_name try: tl.load_checkpoint(ckpt_dir, sess) except: raise Exception(' [*] No checkpoint!') # sample try: # print(te_data) # for idx, batch in enumerate(te_data): # print(idx) # print(batch) for idx, batch in enumerate(te_data): xa_sample_ipt = batch[0] a_sample_ipt = batch[1] b_sample_ipt_list = [a_sample_ipt] # the first is for reconstruction for i in range(len(atts)): tmp = np.array(a_sample_ipt, copy=True)
# with tf.device('/gpu:%d' % gpu_id): ''' models ''' classifier = models.classifier ''' graph ''' # inputs x = tf.placeholder(tf.float32, shape=[None, 128, 128, 3]) # classify logits = classifier(x, reuse=False, training=False) pred = tf.cast(tf.round(tf.nn.sigmoid(logits)), tf.int64) """ train """ ''' init ''' # session sess = tl.session() ''' initialization ''' tl.load_checkpoint(ckpt_file, sess) ''' train ''' try: img_paths = glob(os.path.join(img_dir, '*.jpg')) img_paths.sort() cnt = np.zeros([len(att_id)]) err_cnt = np.zeros([len(att_id)]) err_each_cnt = np.zeros([len(att_id), len(att_id)]) for img_path in img_paths: imgs = im.imread(img_path) # imgs = im.resize(imgs, (128, 128)) print(imgs.shape) # imgs = np.concatenate([imgs[:, :img_size, :], imgs[:, img_size+img_size//10:, :]], axis=1) imgs = np.expand_dims(imgs, axis=0) # imgs = np.concatenate(np.split(imgs, 15, axis=2))
test_summary = tl.summary({mean_acc: 'test_acc'}) """ train """ ''' init ''' # session sess = tf.Session() # iteration counter it_cnt, update_cnt = tl.counter() # saver saver = tf.train.Saver(max_to_keep=None) # summary writer summary_writer = tf.summary.FileWriter('./summaries', sess.graph) ''' initialization ''' ckpt_dir = './checkpoints' if not os.path.exists(ckpt_dir): os.mkdir(ckpt_dir + '/') if not tl.load_checkpoint(ckpt_dir, sess): sess.run(tf.global_variables_initializer()) ''' train ''' try: batch_epoch = len(train_data_pool) // batch_size max_it = epoch_ * batch_epoch for it in range(sess.run(it_cnt), max_it): bth = it // batch_epoch - 8 lr__ = lr * (1 - max(bth, 0) / epoch_)**0.75 if it % batch_epoch == 0: print('======learning rate:', lr__, '======') sess.run(update_cnt) # which epoch epoch = it // batch_epoch it_epoch = it % batch_epoch + 1
def train(self):
    """Train the attribute-editing GAN whose ops are already stored on ``self``.

    Hyper-parameters are read from ``self.config``. The method expects the
    session (``self.sess``), the data loaders (``self.data_loader``,
    ``self.val_loader``) and the graph ops (``self.d_step``, ``self.g_step``,
    ``self.d_summary``, ``self.g_summary``, ``self.x_sample`` and the
    placeholders ``self.lr``, ``self.xa_sample``, ``self._b_sample``,
    ``self.raw_b_sample``) to have been built beforehand.

    Each iteration runs ``dStep`` discriminator updates followed by one
    generator update. The learning rate is ``gLr / 10 ** (epoch //
    lrDecayEpoch)``. Summaries are written every iteration. Checkpoints and
    sample grids are written every ``modelSaveEpoch`` / ``sampleEpoch``
    iterations (once per epoch when the value is 0/falsy).

    A final checkpoint is always attempted and the session is always closed
    on exit, also after an error or Ctrl+C.
    """
    it_cnt, update_cnt = tl.counter()

    # saver
    saver = tf.train.Saver(max_to_keep=10)

    # summary writer
    summary_writer = tf.summary.FileWriter(self.config["projectSummary"],
                                           self.sess.graph)

    # hyper-parameters
    ckpt_dir = self.config["projectCheckpoints"]
    n_epoch = self.config["totalEpoch"]
    n_d = self.config["dStep"]
    atts = self.config["selectedAttrs"]
    thres_int = self.config["thresInt"]
    test_int = self.config["sampleThresInt"]
    n_sample = self.config["sampleNum"]
    img_size = self.config["imsize"]
    sample_freq = self.config["sampleEpoch"]
    save_freq = self.config["modelSaveEpoch"]
    lr_base = self.config["gLr"]
    lrDecayEpoch = self.config["lrDecayEpoch"]

    # initialization: restore the latest checkpoint, otherwise start fresh
    try:
        # NOTE(review): `clear` is not defined in this method; presumably it
        # is a module-level flag. If it is missing, the NameError lands in
        # the `except` below and all parameters are initialized.
        # An explicit check replaces the former `assert`, which `python -O`
        # strips.
        if clear != False:  # noqa: E712 -- same test as `assert clear == False`
            raise RuntimeError('checkpoint restore disabled by `clear`')
        tl.load_checkpoint(ckpt_dir, self.sess)
    except Exception:
        print('NOTE: Initializing all parameters...')
        self.sess.run(tf.global_variables_initializer())

    # Bound before `try` so the final save in `finally` cannot fail with a
    # NameError when training aborts before the first iteration.
    epoch, it_in_epoch, it_per_epoch = 0, 0, 0

    # train
    try:
        # data for sampling: a fixed validation batch plus, for every
        # attribute, a copy of its labels with that attribute inverted
        xa_sample_ipt, a_sample_ipt = self.val_loader.get_next()
        b_sample_ipt_list = [a_sample_ipt]  # the first is for reconstruction
        for i in range(len(atts)):
            tmp = np.array(a_sample_ipt, copy=True)
            tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
            tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
            b_sample_ipt_list.append(tmp)

        it_per_epoch = len(self.data_loader) // (self.config["batchSize"] * (n_d + 1))
        max_it = n_epoch * it_per_epoch
        for it in range(self.sess.run(it_cnt), max_it):
            with pylib.Timer(is_output=False) as t:
                self.sess.run(update_cnt)

                # which epoch
                epoch = it // it_per_epoch
                it_in_epoch = it % it_per_epoch + 1

                # learning rate: divided by 10 every `lrDecayEpoch` epochs
                lr_ipt = lr_base / (10 ** (epoch // lrDecayEpoch))

                # train D (only the last D summary of the iteration is logged)
                for i in range(n_d):
                    d_summary_opt, _ = self.sess.run(
                        [self.d_summary, self.d_step],
                        feed_dict={self.lr: lr_ipt})
                summary_writer.add_summary(d_summary_opt, it)

                # train G
                g_summary_opt, _ = self.sess.run(
                    [self.g_summary, self.g_step],
                    feed_dict={self.lr: lr_ipt})
                summary_writer.add_summary(g_summary_opt, it)

                # display
                if (it + 1) % 100 == 0:
                    print("Epoch: (%3d) (%5d/%5d) Time: %s!"
                          % (epoch, it_in_epoch, it_per_epoch, t))

                # save
                if (it + 1) % (save_freq if save_freq else it_per_epoch) == 0:
                    save_path = saver.save(
                        self.sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                        (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
                    print('Model is saved at %s!' % save_path)

                # sample
                if (it + 1) % (sample_freq if sample_freq else it_per_epoch) == 0:
                    # input images, then a grey separator column
                    x_sample_opt_list = [
                        xa_sample_ipt,
                        np.full((n_sample, img_size, img_size // 10, 3), -1.0)
                    ]
                    raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 - 1) * thres_int
                    for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                        _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                        if i > 0:  # i == 0 is for reconstruction
                            _b_sample_ipt[..., i - 1] = (
                                _b_sample_ipt[..., i - 1] * test_int / thres_int)
                        x_sample_opt_list.append(
                            self.sess.run(self.x_sample,
                                          feed_dict={
                                              self.xa_sample: xa_sample_ipt,
                                              self._b_sample: _b_sample_ipt,
                                              self.raw_b_sample: raw_b_sample_ipt
                                          }))
                        last_images = x_sample_opt_list[-1]
                        if i > 0:
                            # add a mark (+/-) in the upper-left corner to
                            # identify add/remove an attribute
                            for nnn in range(last_images.shape[0]):
                                last_images[nnn, 2:5, 0:7, :] = 1.
                                if _b_sample_ipt[nnn, i - 1] > 0:
                                    last_images[nnn, 0:7, 2:5, :] = 1.
                                    last_images[nnn, 1:6, 3:4, :] = -1.
                                last_images[nnn, 3:4, 1:6, :] = -1.
                    sample = np.concatenate(x_sample_opt_list, 2)
                    im.imwrite(im.immerge(sample, n_sample, 1),
                               '%s/Epoch_(%d)_(%dof%d).jpg' %
                               (self.config["projectSamples"], epoch,
                                it_in_epoch, it_per_epoch))
    except (Exception, KeyboardInterrupt):
        # Ctrl+C ends training gracefully; `finally` still saves the model.
        traceback.print_exc()
    finally:
        try:
            save_path = saver.save(
                self.sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
            print('Model is saved at %s!' % save_path)
        finally:
            # close the session even if the final save fails
            self.sess.close()
def train(self):
    """Build the STGAN-style training graph from ``self.config`` and train it.

    The generator is Genc -> Gstu -> Gdec and is conditioned on the label
    difference ``_b - _a``. The discriminator ``D`` is a WGAN critic with a
    gradient penalty (weight 10) and an auxiliary attribute classifier.
    Model classes are loaded dynamically from
    ``components.<modelScriptName>`` and
    ``components.STU.<stuScriptName>``.

    With ``config["mode"] == "finetune"`` the latest checkpoint in
    ``projectCheckpoints`` is restored; otherwise all variables are
    initialized.

    Each iteration runs ``dStep`` discriminator updates and one generator
    update. The learning rate is ``gLr / 10 ** (epoch // lrDecayEpoch)``.
    Checkpoints and sample grids are written every ``modelSaveEpoch`` /
    ``sampleEpoch`` iterations (once per epoch when the value is 0/falsy).

    A final checkpoint is always attempted and the session is always closed
    after the training loop, also after an error or Ctrl+C.
    """
    import importlib  # local import: only needed for the dynamic model lookup

    ckpt_dir = self.config["projectCheckpoints"]
    n_epoch = self.config["totalEpoch"]
    n_d = self.config["dStep"]
    atts = self.config["selectedAttrs"]
    thres_int = self.config["thresInt"]
    test_int = self.config["sampleThresInt"]
    n_sample = self.config["sampleNum"]
    img_size = self.config["imsize"]
    sample_freq = self.config["sampleEpoch"]
    save_freq = self.config["modelSaveEpoch"]
    lr_base = self.config["gLr"]
    lrDecayEpoch = self.config["lrDecayEpoch"]
    n_att = len(self.config["selectedAttrs"])

    # session
    if self.config["threads"] >= 0:
        cpu_config = tf.ConfigProto(
            intra_op_parallelism_threads=self.config["threads"] // 2,
            inter_op_parallelism_threads=self.config["threads"] // 2,
            device_count={'CPU': self.config["threads"]})
        cpu_config.gpu_options.allow_growth = True
        sess = tf.Session(config=cpu_config)
    else:
        sess = tl.session()

    # data
    data_loader = Celeba(self.config["dataset_path"],
                         self.config["selectedAttrs"],
                         self.config["imsize"],
                         self.config["batchSize"],
                         part='train',
                         sess=sess,
                         crop=(self.config["imCropSize"] > 0))
    val_loader = Celeba(self.config["dataset_path"],
                        self.config["selectedAttrs"],
                        self.config["imsize"],
                        self.config["sampleNum"],
                        part='val',
                        shuffle=False,
                        sess=sess,
                        crop=(self.config["imCropSize"] > 0))

    # model classes, looked up by module name from the config
    package = importlib.import_module("components." + self.config["modelScriptName"])
    GencClass = getattr(package, 'Genc')
    GdecClass = getattr(package, 'Gdec')
    DClass = getattr(package, 'D')
    GP = getattr(package, "gradient_penalty")
    package = importlib.import_module("components.STU." + self.config["stuScriptName"])
    GstuClass = getattr(package, 'Gstu')

    Genc = partial(GencClass,
                   dim=self.config["GConvDim"],
                   n_layers=self.config["GLayerNum"],
                   multi_inputs=1)
    Gdec = partial(GdecClass,
                   dim=self.config["GConvDim"],
                   n_layers=self.config["GLayerNum"],
                   shortcut_layers=self.config["skipNum"],
                   inject_layers=self.config["injectLayers"],
                   one_more_conv=self.config["oneMoreConv"])
    Gstu = partial(GstuClass,
                   dim=self.config["stuDim"],
                   n_layers=self.config["skipNum"],
                   inject_layers=self.config["skipNum"],
                   kernel_size=self.config["stuKS"],
                   norm=None,
                   pass_state='stu')
    D = partial(DClass,
                n_att=n_att,
                dim=self.config["DConvDim"],
                fc_dim=self.config["DFcDim"],
                n_layers=self.config["DLayerNum"])

    # inputs: binary labels are mapped from {0, 1} to {-thresInt, +thresInt}
    xa = data_loader.batch_op[0]
    a = data_loader.batch_op[1]
    b = tf.random_shuffle(a)  # target labels: source labels of other images
    _a = (tf.to_float(a) * 2 - 1) * self.config["thresInt"]
    _b = (tf.to_float(b) * 2 - 1) * self.config["thresInt"]
    xa_sample = tf.placeholder(
        tf.float32,
        shape=[None, self.config["imsize"], self.config["imsize"], 3])
    _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
    raw_b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
    lr = tf.placeholder(tf.float32, shape=[])

    # generate: edited image xb_ and reconstruction xa_ (zero label difference)
    z = Genc(xa)
    zb = Gstu(z, _b - _a)
    xb_ = Gdec(zb, _b - _a)
    with tf.control_dependencies([xb_]):
        za = Gstu(z, _a - _a)
        xa_ = Gdec(za, _a - _a)

    # discriminate
    xa_logit_gan, xa_logit_att = D(xa)
    xb__logit_gan, xb__logit_att = D(xb_)

    # discriminator losses
    wd = tf.reduce_mean(xa_logit_gan) - tf.reduce_mean(xb__logit_gan)
    d_loss_gan = -wd
    gp = GP(D, xa, xb_)
    xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
    d_loss = d_loss_gan + gp * 10.0 + xa_loss_att

    # generator losses
    xb__loss_gan = -tf.reduce_mean(xb__logit_gan)
    xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
    xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
    g_loss = (xb__loss_gan + xb__loss_att * 10.0
              + xa__loss_rec * self.config["recWeight"])

    d_var = tl.trainable_variables('D')
    d_step = tf.train.AdamOptimizer(
        lr, beta1=self.config["beta1"]).minimize(d_loss, var_list=d_var)
    g_var = tl.trainable_variables('G')
    g_step = tf.train.AdamOptimizer(
        lr, beta1=self.config["beta1"]).minimize(g_loss, var_list=g_var)

    d_summary = tl.summary(
        {
            d_loss_gan: 'd_loss_gan',
            gp: 'gp',
            xa_loss_att: 'xa_loss_att',
        },
        scope='D')
    lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')
    g_summary = tl.summary(
        {
            xb__loss_gan: 'xb__loss_gan',
            xb__loss_att: 'xb__loss_att',
            xa__loss_rec: 'xa__loss_rec',
        },
        scope='G')
    d_summary = tf.summary.merge([d_summary, lr_summary])

    # sample: inference-mode graph fed with the label difference
    test_label = _b_sample - raw_b_sample
    x_sample = Gdec(Gstu(Genc(xa_sample, is_training=False),
                         test_label,
                         is_training=False),
                    test_label,
                    is_training=False)

    it_cnt, update_cnt = tl.counter()

    # saver
    saver = tf.train.Saver(max_to_keep=self.config["max2Keep"])

    # summary writer
    summary_writer = tf.summary.FileWriter(self.config["projectSummary"],
                                           sess.graph)

    # initialization
    if self.config["mode"] == "finetune":
        print("Continute train the model")
        tl.load_checkpoint(ckpt_dir, sess)
        print("Load previous model successfully!")
    else:
        print('Initializing all parameters...')
        sess.run(tf.global_variables_initializer())

    # Bound before `try` so the final save in `finally` cannot fail with a
    # NameError when training aborts before the first iteration.
    epoch, it_in_epoch, it_per_epoch = 0, 0, 0

    # train
    try:
        # data for sampling: a fixed validation batch plus, for every
        # attribute, a copy of its labels with that attribute inverted
        xa_sample_ipt, a_sample_ipt = val_loader.get_next()
        b_sample_ipt_list = [a_sample_ipt]  # the first is for reconstruction
        for i in range(len(atts)):
            tmp = np.array(a_sample_ipt, copy=True)
            tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
            tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
            b_sample_ipt_list.append(tmp)

        it_per_epoch = len(data_loader) // (self.config["batchSize"] * (n_d + 1))
        max_it = n_epoch * it_per_epoch
        print("Start to train the graph!")
        for it in range(sess.run(it_cnt), max_it):
            with pylib.Timer(is_output=False) as t:
                sess.run(update_cnt)

                # which epoch
                epoch = it // it_per_epoch
                it_in_epoch = it % it_per_epoch + 1

                # learning rate: divided by 10 every `lrDecayEpoch` epochs
                lr_ipt = lr_base / (10 ** (epoch // lrDecayEpoch))

                # train D (only the last D summary of the iteration is logged)
                for i in range(n_d):
                    d_summary_opt, _ = sess.run([d_summary, d_step],
                                                feed_dict={lr: lr_ipt})
                summary_writer.add_summary(d_summary_opt, it)

                # train G
                g_summary_opt, _ = sess.run([g_summary, g_step],
                                            feed_dict={lr: lr_ipt})
                summary_writer.add_summary(g_summary_opt, it)

                # display
                if (it + 1) % 100 == 0:
                    print("Epoch: (%3d) (%5d/%5d) Time: %s!"
                          % (epoch, it_in_epoch, it_per_epoch, t))

                # save
                if (it + 1) % (save_freq if save_freq else it_per_epoch) == 0:
                    save_path = saver.save(
                        sess, '%s/Epoch_(%d).ckpt' % (ckpt_dir, epoch))
                    print('Model is saved at %s!' % save_path)

                # sample
                if (it + 1) % (sample_freq if sample_freq else it_per_epoch) == 0:
                    # input images, then a grey separator column
                    x_sample_opt_list = [
                        xa_sample_ipt,
                        np.full((n_sample, img_size, img_size // 10, 3), -1.0)
                    ]
                    raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 - 1) * thres_int
                    for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                        _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                        if i > 0:  # i == 0 is for reconstruction
                            _b_sample_ipt[..., i - 1] = (
                                _b_sample_ipt[..., i - 1] * test_int / thres_int)
                        x_sample_opt_list.append(
                            sess.run(x_sample,
                                     feed_dict={
                                         xa_sample: xa_sample_ipt,
                                         _b_sample: _b_sample_ipt,
                                         raw_b_sample: raw_b_sample_ipt
                                     }))
                        last_images = x_sample_opt_list[-1]
                        if i > 0:
                            # add a mark (+/-) in the upper-left corner to
                            # identify add/remove an attribute
                            for nnn in range(last_images.shape[0]):
                                last_images[nnn, 2:5, 0:7, :] = 1.
                                if _b_sample_ipt[nnn, i - 1] > 0:
                                    last_images[nnn, 0:7, 2:5, :] = 1.
                                    last_images[nnn, 1:6, 3:4, :] = -1.
                                last_images[nnn, 3:4, 1:6, :] = -1.
                    sample = np.concatenate(x_sample_opt_list, 2)
                    im.imwrite(im.immerge(sample, n_sample, 1),
                               '%s/Epoch_(%d)_(%dof%d).jpg' %
                               (self.config["projectSamples"], epoch,
                                it_in_epoch, it_per_epoch))
    except (Exception, KeyboardInterrupt):
        # Ctrl+C ends training gracefully; `finally` still saves the model.
        traceback.print_exc()
    finally:
        try:
            save_path = saver.save(
                sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
            print('Model is saved at %s!' % save_path)
        finally:
            # close the session even if the final save fails
            sess.close()
def model_initialization(setting_path='./model/setting.txt',
                         ckpt_dir='./model/checkpoints'):
    """Model creation and weight load.

    Builds the STGAN inference graph from the hyper-parameters stored with
    the pretrained STGAN model
    (https://drive.google.com/open?id=1329IbLE6877DcDUut1reKxckijBJye7N)
    and restores its weights.

    Args:
        setting_path (str): JSON file holding the model hyper-parameters.
        ckpt_dir (str): Directory of the checkpoint to restore.

    Returns:
        sess (TF Session): Current session for inference, with restored weights.
        x_sample (tfTensor): Generated images for ``xa_sample``.
        xa_sample (tfTensor): Input placeholder of shape
            (n_img, img_size, img_size, 3).
        _b_sample (tfTensor): Target-label placeholder of shape (n_img, n_atts).
        raw_b_sample (tfTensor): Source-label placeholder of shape
            (n_img, n_atts). It is only wired into the graph when the
            settings use ``label == 'diff'``.

    ``img_size`` and ``n_atts = len(atts)`` come from the settings file.
    The original documentation gives 128 and 13 for the released model.
    """
    with open(setting_path, encoding='utf-8') as f:
        args = json.load(f)

    atts = args['atts']
    n_atts = len(atts)
    img_size = args['img_size']
    shortcut_layers = args['shortcut_layers']
    inject_layers = args['inject_layers']
    enc_dim = args['enc_dim']
    dec_dim = args['dec_dim']
    enc_layers = args['enc_layers']
    dec_layers = args['dec_layers']
    label = args['label']
    use_stu = args['use_stu']
    stu_dim = args['stu_dim']
    stu_layers = args['stu_layers']
    stu_inject_layers = args['stu_inject_layers']
    stu_kernel_size = args['stu_kernel_size']
    stu_norm = args['stu_norm']
    stu_state = args['stu_state']
    multi_inputs = args['multi_inputs']
    one_more_conv = args['one_more_conv']

    sess = tl.session()
    try:
        # Models
        Genc = partial(models.Genc,
                       dim=enc_dim,
                       n_layers=enc_layers,
                       multi_inputs=multi_inputs)
        Gdec = partial(models.Gdec,
                       dim=dec_dim,
                       n_layers=dec_layers,
                       shortcut_layers=shortcut_layers,
                       inject_layers=inject_layers,
                       one_more_conv=one_more_conv)
        Gstu = partial(models.Gstu,
                       dim=stu_dim,
                       n_layers=stu_layers,
                       inject_layers=stu_inject_layers,
                       kernel_size=stu_kernel_size,
                       norm=stu_norm,
                       pass_state=stu_state)

        # Inputs
        xa_sample = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3])
        _b_sample = tf.placeholder(tf.float32, shape=[None, n_atts])
        raw_b_sample = tf.placeholder(tf.float32, shape=[None, n_atts])

        # Sample: the generator is conditioned either on the label difference
        # or on the target labels directly, depending on the settings.
        test_label = _b_sample - raw_b_sample if label == 'diff' else _b_sample
        if use_stu:
            x_sample = Gdec(Gstu(Genc(xa_sample, is_training=False),
                                 test_label,
                                 is_training=False),
                            test_label,
                            is_training=False)
        else:
            x_sample = Gdec(Genc(xa_sample, is_training=False),
                            test_label,
                            is_training=False)

        # Initialization
        tl.load_checkpoint(ckpt_dir, sess)
    except Exception:
        # don't leak the session when graph building or restoring fails
        sess.close()
        raise

    return sess, x_sample, xa_sample, _b_sample, raw_b_sample
def train_on_fake(dataset, c_dim, result_dir, gpu_id, use_real=0, epoch_=200):
    """Train an attribute classifier on synthetic images.

    Training data is read from ``<result_dir>/synthetic_tfrecord``. With
    ``use_real == 1``, half of every batch is replaced by real images from
    ``./tfrecords/celeba_tfrecord_train``.

    Args:
        dataset: ``'CelebA'`` (multi-label targets, sigmoid cross-entropy) or
            ``'RaFD'`` (one-hot targets, softmax cross-entropy).
        c_dim: Number of attributes/classes predicted by the classifier.
        result_dir: Directory holding the synthetic data. Summaries and
            checkpoints are written below it.
        gpu_id: Index of the GPU the graph is placed on.
        use_real: 1 to mix real CelebA training images into every batch.
        epoch_: Number of training epochs.

    The learning rate (0.0002) stays constant for epochs 0-8 and then
    decays as ``(1 - (epoch - 8) / epoch_) ** 0.75``. Checkpoints are saved
    every 50 epochs. Test accuracy is logged every 100 iterations. Training
    errors are printed and the session is always closed.

    Raises:
        ValueError: If ``dataset`` is neither ``'CelebA'`` nor ``'RaFD'``.
    """
    # Fail fast: every dataset-dependent branch below only knows these two.
    if dataset not in ('CelebA', 'RaFD'):
        raise ValueError("dataset must be 'CelebA' or 'RaFD', got %r" % (dataset,))

    # param
    batch_size = 64
    batch_size_fake = batch_size
    lr = 0.0002

    # data
    if use_real == 1:
        # NOTE(review): the real data is always CelebA (key 'class', no
        # one-hot conversion); mixing it into RaFD training looks unsupported.
        print('======Using real data======')
        batch_size_real = batch_size // 2
        batch_size_fake = batch_size - batch_size_real
        train_tfrecord_path_real = './tfrecords/celeba_tfrecord_train'
        train_data_pool_real = tl.TfrecordData(train_tfrecord_path_real,
                                               batch_size_real,
                                               shuffle=True)
    train_tfrecord_path_fake = os.path.join(result_dir, 'synthetic_tfrecord')
    train_data_pool_fake = tl.TfrecordData(train_tfrecord_path_fake,
                                           batch_size_fake,
                                           shuffle=True)
    if dataset == 'CelebA':
        test_tfrecord_path = './tfrecords/celeba_tfrecord_test'
    else:
        test_tfrecord_path = './tfrecords/rafd_test'
    test_data_pool = tl.TfrecordData(test_tfrecord_path, 120)
    att_dim = c_dim

    # graphs
    with tf.device('/gpu:{}'.format(gpu_id)):
        # models
        classifier = models.classifier

        # inputs: uint8-range pixels are rescaled to [-1, 1]
        x_255 = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
        x = x_255 / 127.5 - 1
        if dataset == 'CelebA':
            att = tf.placeholder(tf.int64, shape=[None, att_dim])
        else:
            att = tf.placeholder(tf.float32, shape=[None, att_dim])

        # classify
        logits = classifier(x, att_dim=att_dim, reuse=False)

        # loss
        reg_loss = tf.losses.get_regularization_loss()
        if dataset == 'CelebA':
            loss = tf.losses.sigmoid_cross_entropy(att, logits) + reg_loss
            acc = mean_accuracy_multi_binary_label_with_logits(att, logits)
        else:
            loss = tf.losses.softmax_cross_entropy(att, logits) + reg_loss
            acc = mean_accuracy_one_hot_label_with_logits(att, logits)

        lr_ = tf.placeholder(tf.float32, shape=[])

        # optim
        step = tf.train.AdamOptimizer(lr_, beta1=0.9).minimize(loss)

        # test
        test_logits = classifier(x, att_dim=att_dim, training=False)
        if dataset == 'CelebA':
            test_acc = mean_accuracy_multi_binary_label_with_logits(att, test_logits)
        else:
            test_acc = mean_accuracy_one_hot_label_with_logits(att, test_logits)
        mean_acc = tf.placeholder(tf.float32, shape=())

        # summary
        summary = tl.summary({loss: 'loss', acc: 'acc'})
        test_summary = tl.summary({mean_acc: 'test_acc'})

    # session
    sess = tf.Session()

    # iteration counter
    it_cnt, update_cnt = tl.counter()

    # saver
    saver = tf.train.Saver(max_to_keep=None)

    # summary writer
    sum_dir = os.path.join(result_dir, 'summaries_train_on_fake')
    if use_real == 1:
        sum_dir += '_real'
    summary_writer = tf.summary.FileWriter(sum_dir, sess.graph)

    # initialization: resume from the checkpoint directory if possible
    ckpt_dir = os.path.join(result_dir, 'checkpoints_train_on_fake')
    if use_real == 1:
        ckpt_dir += '_real'
    # makedirs also creates a missing result_dir
    os.makedirs(ckpt_dir, exist_ok=True)
    if not tl.load_checkpoint(ckpt_dir, sess):
        sess.run(tf.global_variables_initializer())

    # train
    try:
        batch_epoch = len(train_data_pool_fake) // batch_size
        max_it = epoch_ * batch_epoch
        test_it = 100 if dataset == 'CelebA' else 7
        test_key = 'class' if dataset == 'CelebA' else 'attr'
        for it in range(sess.run(it_cnt), max_it):
            bth = it // batch_epoch - 8
            lr__ = lr * (1 - max(bth, 0) / epoch_) ** 0.75
            if it % batch_epoch == 0:
                print('======learning rate:', lr__, '======')
            sess.run(update_cnt)

            # which epoch
            epoch = it // batch_epoch
            it_epoch = it % batch_epoch + 1

            x_255_ipt, att_ipt = train_data_pool_fake.batch(['img', 'attr'])
            if dataset == 'RaFD':
                att_ipt = ToOnehot(att_ipt, att_dim)
            if use_real == 1:
                x_255_ipt_real, att_ipt_real = train_data_pool_real.batch(['img', 'class'])
                x_255_ipt = np.concatenate([x_255_ipt, x_255_ipt_real])
                att_ipt = np.concatenate([att_ipt, att_ipt_real])
            summary_opt, _ = sess.run([summary, step],
                                      feed_dict={x_255: x_255_ipt, att: att_ipt, lr_: lr__})
            summary_writer.add_summary(summary_opt, it)

            # display
            if (it + 1) % batch_epoch == 0:
                print("Epoch: (%3d) (%5d/%5d)" % (epoch, it_epoch, batch_epoch))

            # save
            if (it + 1) % (batch_epoch * 50) == 0:
                save_path = saver.save(sess, '%s/Epoch_(%d)_(%dof%d).ckpt'
                                       % (ckpt_dir, epoch, it_epoch, batch_epoch))
                print('Model saved in file: % s' % save_path)

            # evaluate on a fixed number of test batches
            if it % 100 == 0:
                test_acc_opt_list = []
                for _ in range(test_it):
                    x_255_ipt, att_ipt = test_data_pool.batch(['img', test_key])
                    if dataset == 'RaFD':
                        att_ipt = ToOnehot(att_ipt, att_dim)
                    test_acc_opt = sess.run(test_acc,
                                            feed_dict={x_255: x_255_ipt, att: att_ipt})
                    test_acc_opt_list.append(test_acc_opt)
                test_summary_opt = sess.run(
                    test_summary, feed_dict={mean_acc: np.mean(test_acc_opt_list)})
                summary_writer.add_summary(test_summary_opt, it)
    except Exception:
        traceback.print_exc()
    finally:
        print(" [*] Close main session!")
        summary_writer.close()  # flush pending events to disk
        sess.close()
def attr_cls(self):
    """Computes the GAN-test attribute classification accuracy.

    Restores the generator (``self.restore_model(self.test_iters)``) and a
    TensorFlow attribute classifier from a hard-coded checkpoint. Every
    batch of ``self.labelled_loader`` is classified as-is (j == 0),
    reconstructed with its own labels (j == 1), and translated to each
    target label from ``self.create_labels`` (j >= 2); each result is
    classified too.

    For every attribute k it counts:
      * requested additions (ca_req) / removals (cr_req): targets that flip
        attribute k. Successes are counted in cnt_pos / cnt_neg when the
        prediction shows the flipped value.
      * unchanged requests (co_req): targets that keep attribute k. Successes
        are counted in cnt_rec when the prediction still matches the ground
        truth.
    c_pos / c_neg / c_rec accumulate the matching classifier confidences
    but are not reported here.

    Running addition accuracy is printed every ``n_print`` images. The
    overall accuracy (successes / requests) is written to
    ``<self.result_dir>/GAN-test.txt`` and printed.

    NOTE(review): ``ckpt_file``, ``n_print`` and ``pred_s`` are only set
    when ``self.dataset`` is 'CelebA' or 'RaFD'; any other value fails with
    an UnboundLocalError. An attribute with zero requests of some kind
    yields a 0 denominator, so the printed/written ratios can be nan/inf.
    """
    # Load the trained generator.
    self.restore_model(self.test_iters)

    # Set data loader.
    data_loader = self.labelled_loader

    # Per-dataset classifier checkpoint, label names and progress interval.
    attr_list = []
    if self.dataset == 'CelebA':
        ckpt_file = './checkpoints_train_on_real/CelebA/Epoch_(127)_(2543of2543).ckpt'
        attr_list = self.selected_attrs
        n_print = 2000
    elif self.dataset == 'RaFD':
        ckpt_file = './checkpoints_train_on_real/RaFD/Epoch_(199)_(112of112).ckpt'
        attr_list = self.selected_emots
        n_print = 200

    classifier = models.classifier

    # Classifier graph
    x = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
    logits = classifier(x, att_dim=len(attr_list), reuse=False, training=False)
    # CelebA: independent per-attribute probabilities; RaFD: one class.
    if self.dataset == 'CelebA':
        pred_s = tf.cast(tf.nn.sigmoid(logits), tf.float64)
    elif self.dataset == 'RaFD':
        pred_s = tf.cast(tf.nn.softmax(logits), tf.float64)

    # Per-attribute success counts (cnt_*), summed confidences (c_*) and
    # request counts (*_req): a = addition, r = removal, o = unchanged.
    cnt_pos = np.zeros([self.c_dim]).astype(np.int64)
    cnt_neg = np.zeros([self.c_dim]).astype(np.int64)
    cnt_rec = np.zeros([self.c_dim]).astype(np.int64)
    c_pos = np.zeros([self.c_dim])
    c_neg = np.zeros([self.c_dim])
    c_rec = np.zeros([self.c_dim])
    ca_req = np.zeros([self.c_dim]).astype(np.int64)
    cr_req = np.zeros([self.c_dim]).astype(np.int64)
    co_req = np.zeros([self.c_dim]).astype(np.int64)

    with torch.no_grad():
        with tl.session() as sess:
            tl.load_checkpoint(ckpt_file, sess)
            attr_list = ['Reconstruction'] + attr_list
            total_count = 0
            for i, (x_real, c_org) in enumerate(data_loader):
                if self.dataset == 'RaFD':
                    c_org = self.label2onehot(c_org, self.c_dim)

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
                # Stacked along axis 1: c_trg_batch[n, t] is sample n's t-th target.
                c_trg_batch = torch.cat(
                    [c.unsqueeze(1) for c in c_trg_list], dim=1).cpu().numpy()
                # Slot 0: real image (None), slot 1: reconstruction, then targets.
                c_trg_list = [None] + [c_org.to(self.device)] + c_trg_list
                att_gt_batch = c_org.numpy()

                # Classify translate images.
                pred_score_list = []
                preds_list = []
                for j, c_trg in enumerate(c_trg_list):
                    # NCHW torch images -> NHWC for the TF classifier.
                    if j == 0:
                        feed = np.transpose(x_real.cpu().numpy(), [0, 2, 3, 1])
                    else:
                        x_fake = self.G(x_real, c_trg)
                        feed = np.transpose(x_fake.cpu().numpy(), [0, 2, 3, 1])
                    pred_score = sess.run(pred_s, feed_dict={x: feed})
                    pred_score_list.append(
                        np.expand_dims(pred_score, axis=1))
                    # Hard predictions: threshold at 0.5 (CelebA) or argmax
                    # one-hot (RaFD).
                    if self.dataset == 'CelebA':
                        preds = np.round(pred_score).astype(int)
                    elif self.dataset == 'RaFD':
                        max_id = np.argmax(pred_score, axis=1)
                        preds = np.zeros_like(pred_score).astype(int)
                        preds[np.arange(pred_score.shape[0]), max_id] = 1
                    preds_list.append(np.expand_dims(preds, axis=1))

                # Axis 1 indexes the slots of c_trg_list.
                pred_score_batch = np.concatenate(pred_score_list, axis=1)
                preds_opt_batch = np.concatenate(preds_list, axis=1)

                # Calculate accuracy.
                for pred_score, preds_opt, att_gt, c_trg in zip(
                        pred_score_batch, preds_opt_batch, att_gt_batch,
                        c_trg_batch):
                    # Slot k (k >= 2) is target k - 2, which is compared on
                    # attribute k - 2.
                    for k in range(2, len(preds_opt)):
                        if c_trg[k - 2, k - 2] == 1 - att_gt[k - 2]:
                            # The target flips attribute k - 2.
                            if att_gt[k - 2] == 0:
                                ca_req[k - 2] += 1
                            elif att_gt[k - 2] == 1:
                                cr_req[k - 2] += 1
                            if preds_opt[k, k - 2] == 1 - att_gt[k - 2]:
                                if preds_opt[k, k - 2] == 1:
                                    cnt_pos[k - 2] += 1
                                    c_pos[k - 2] += pred_score[k, k - 2]
                                elif preds_opt[k, k - 2] == 0:
                                    cnt_neg[k - 2] += 1
                                    c_neg[k - 2] += 1 - pred_score[k, k - 2]
                        else:
                            # The target keeps attribute k - 2 unchanged.
                            co_req[k - 2] += 1
                            if preds_opt[k, k - 2] == att_gt[k - 2]:
                                cnt_rec[k - 2] += 1
                                if preds_opt[k, k - 2] == 1:
                                    c_rec[k - 2] += pred_score[k, k - 2]
                                elif preds_opt[k, k - 2] == 0:
                                    c_rec[k - 2] += 1 - pred_score[k, k - 2]

                total_count += x_real.shape[0]
                # Progress report (only fires when total_count hits an exact
                # multiple of n_print).
                if total_count % n_print == 0:
                    print('{} images classified.'.format(total_count))
                    print('\tAcc. Addition')
                    print('\t', cnt_pos / ca_req)
                    print('\t', np.mean(cnt_pos / ca_req))

    # Overall accuracy per attribute and its mean over attributes.
    attr_cls_path = os.path.join(self.result_dir, 'GAN-test.txt')
    with open(attr_cls_path, 'w') as f:
        f.write('Overall accuracy,{},average,{}\n'.format(
            arr_2_str((cnt_pos + cnt_neg + cnt_rec) /
                      (ca_req + cr_req + co_req)),
            arr_2_str(
                np.mean((cnt_pos + cnt_neg + cnt_rec) /
                        (ca_req + cr_req + co_req)))))
    print('GAN-test accuracy: {}'.format(
        arr_2_str(
            np.mean((cnt_pos + cnt_neg + cnt_rec) /
                    (ca_req + cr_req + co_req)))))
def test_train_on_fake(dataset, c_dim, result_dir, gpu_id, epoch_=200,
                       attr_names=('Black_Hair', 'Blond_Hair', 'Brown_Hair',
                                   'Male', 'Young')):
    """Evaluate a classifier trained on synthetic images on the real test set.

    The classifier is restored from
    ``<result_dir>/checkpoints_train_on_fake/Epoch_(<epoch_ - 1>)_(...).ckpt``.
    Presumably this checkpoint was written by ``train_on_fake``.
    Accuracies are printed and written to ``<result_dir>/GAN_train.txt``.
    Evaluation errors are printed and the session is always closed.

    Args:
        dataset: ``'CelebA'`` (multi-label) or ``'RaFD'`` (one-hot classes).
        c_dim: Number of attributes/classes predicted by the classifier.
        result_dir: Directory holding the checkpoint; receives the report.
        gpu_id: Index of the GPU the graph is placed on.
        epoch_: Number of training epochs; the last epoch's checkpoint is used.
        attr_names: CelebA attribute names, in label order, used to label
            the per-attribute accuracies in the report.

    Raises:
        ValueError: If ``dataset`` is neither ``'CelebA'`` nor ``'RaFD'``.
    """
    # data: per-dataset checkpoint name, test set and test batch size
    if dataset == 'CelebA':
        ckpt_file = 'checkpoints_train_on_fake/Epoch_({})_(2513of2513).ckpt'.format(
            epoch_ - 1)
        test_tfrecord_path = './tfrecords_test/celeba_tfrecord_test'
        test_batch_size = 18
    elif dataset == 'RaFD':
        ckpt_file = 'checkpoints_train_on_fake/Epoch_({})_(112of112).ckpt'.format(
            epoch_ - 1)
        test_tfrecord_path = './tfrecords_test/rafd_test'
        test_batch_size = 120
    else:
        raise ValueError("dataset must be 'CelebA' or 'RaFD', got %r" % (dataset,))
    test_data_pool = tl.TfrecordData(test_tfrecord_path, test_batch_size, shuffle=False)
    ckpt_file = os.path.join(result_dir, ckpt_file)

    # graphs
    with tf.device('/gpu:{}'.format(gpu_id)):
        # models
        classifier = models.classifier

        # inputs: uint8-range pixels are rescaled to [-1, 1]
        x_255 = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
        x = x_255 / 127.5 - 1
        if dataset == 'CelebA':
            label = tf.placeholder(tf.int64, shape=[None, c_dim])
        else:
            label = tf.placeholder(tf.float32, shape=[None, c_dim])

        # classify
        logits = classifier(x, att_dim=c_dim, reuse=False, training=False)
        if dataset == 'CelebA':
            accuracy = mean_accuracy_multi_binary_label_with_logits(label, logits)
        else:
            accuracy = mean_accuracy_one_hot_label_with_logits(label, logits)

    # session
    sess = tl.session()

    # initialization
    tl.load_checkpoint(ckpt_file, sess)

    # evaluate
    try:
        all_accuracies = []
        key = 'class' if dataset == 'CelebA' else 'attr'
        # whole batches only
        n_batches = len(test_data_pool) // test_batch_size
        for batch_idx in range(n_batches):
            img, label_gt = test_data_pool.batch(['img', key])
            if dataset == 'RaFD':
                label_gt = ToOnehot(label_gt, c_dim)
            print('Test batch {}'.format(batch_idx), end='\r')
            batch_accuracy = sess.run(accuracy, feed_dict={
                x_255: img,
                label: label_gt
            })
            all_accuracies.append(batch_accuracy)

        if dataset == 'CelebA':
            mean_accuracies = np.mean(np.concatenate(all_accuracies), axis=0)
            mean_accuracy = np.mean(mean_accuracies)
            print('\nIndividual accuracies: {} Average: {:.4f}'.format(
                mean_accuracies, mean_accuracy))
            with open(os.path.join(result_dir, 'GAN_train.txt'), 'w') as f:
                # NOTE(review): zip stops at the shorter sequence, so extra
                # attributes beyond attr_names are silently left out.
                for attr, acc in zip(attr_names, mean_accuracies):
                    f.write('{}: {}\n'.format(attr, acc))
                f.write('Average: {}'.format(mean_accuracy))
        else:
            mean_accuracy = np.mean(all_accuracies)
            print('\nAverage accuracies: {:.4f}'.format(mean_accuracy))
            with open(os.path.join(result_dir, 'GAN_train.txt'), 'w') as f:
                f.write('Average accuracy: {}'.format(mean_accuracy))
    except Exception:
        traceback.print_exc()
    finally:
        print(" [*] Close main session!")
        sess.close()
def runModel(image_url, file_name, test_att, n_slide, image_labels, model_type):
    """Download an image, edit one attribute with AttGAN and return its URL.

    Args:
        image_url: URL downloaded with ``wget`` into the CelebA image folder
            used by the selected experiment.
        file_name: Name of the downloaded file. The edited image is saved
            under the same name, and the downloaded copy is deleted afterwards.
        test_att: Attribute to edit; must be one of the experiment's ``atts``.
        n_slide: Number of slider steps between ``--test_int_min`` and
            ``--test_int_max``. Only the last step, i.e. ``--test_int_max``,
            is rendered.
        image_labels: Attribute labels forwarded to ``data.Celeba``.
        model_type: 0 selects ``--experiment_name``; any other value selects
            the ``128_custom`` experiment.

    Returns:
        ``"http://129.32.22.10:7001" + save_location + file_name``.
        ``save_location`` stays empty if sampling failed; the error is
        printed, not raised.

    Raises:
        ValueError: If ``test_att`` is not an attribute of the experiment or
            ``n_slide < 1``.
        Exception: If the checkpoint cannot be loaded.
    """
    import subprocess  # run wget without a shell; image_url is caller input

    # ==============================================================================
    # =                                    param                                   =
    # ==============================================================================
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name', dest='experiment_name',
                        default="384_shortcut1_inject1_none_hd", help='experiment_name')
    parser.add_argument('--test_att', dest='test_att', help='test_att')
    parser.add_argument('--test_int_min', dest='test_int_min', type=float,
                        default=-1.0, help='test_int_min')
    parser.add_argument('--test_int_max', dest='test_int_max', type=float,
                        default=1.0, help='test_int_max')
    # parse_known_args: this runs inside a hosting process whose own
    # command-line options would make parse_args() exit.
    args_, _ = parser.parse_known_args()

    experiment_name = args_.experiment_name if model_type == 0 else "128_custom"
    print("EXPERIMENT NAME WORKING:" + experiment_name)
    with open('./output/%s/setting.txt' % experiment_name) as f:
        args = json.load(f)

    # model
    atts = args['atts']
    n_att = len(atts)
    img_size = args['img_size']
    shortcut_layers = args['shortcut_layers']
    inject_layers = args['inject_layers']
    enc_dim = args['enc_dim']
    dec_dim = args['dec_dim']
    enc_layers = args['enc_layers']
    dec_layers = args['dec_layers']

    # testing
    thres_int = args['thres_int']
    test_int_min = args_.test_int_min
    test_int_max = args_.test_int_max

    # others
    use_cropped_img = args['use_cropped_img']
    n_slide = int(n_slide)

    # Validate before downloading anything. Previously an unknown attribute
    # only failed inside the (error-swallowing) sampling loop.
    if test_att not in atts:
        raise ValueError('test_att should be chosen in %s' % (str(atts)))
    if n_slide < 1:
        raise ValueError('n_slide should be at least 1, got %d' % n_slide)

    # Intensity of the last slide, the only one rendered. For a single slide
    # the step formula would divide by zero; the last slide is the maximum.
    if n_slide > 1:
        last_slide = n_slide - 1
        test_int = (test_int_max - test_int_min) / (n_slide - 1) * last_slide + test_int_min
    else:
        test_int = test_int_max

    # ==============================================================================
    # =                                  get image                                 =
    # ==============================================================================
    if experiment_name == "128_custom":
        download_dir = "/home/tug44606/AttGAN-Tensorflow-master/data/img_align_celeba"
        local_dir = "./data/img_align_celeba"
    else:
        download_dir = "/home/tug44606/AttGAN-Tensorflow-master/data/img_crop_celeba"
        local_dir = "./data/img_crop_celeba"
    print(image_url)
    # argument list, no shell: the URL cannot inject shell commands
    subprocess.run(["wget", "-P", download_dir, image_url])
    print("Working")

    # ==============================================================================
    # =                                    graphs                                  =
    # ==============================================================================
    sess = tl.session()
    save_location = ""
    try:
        # pass image with labels to dataset
        te_data = data.Celeba('./data', atts, img_size, 1, part='val', sess=sess,
                              crop=not use_cropped_img, image_labels=image_labels,
                              file_name=file_name)

        # models
        Genc = partial(models.Genc, dim=enc_dim, n_layers=enc_layers)
        Gdec = partial(models.Gdec, dim=dec_dim, n_layers=dec_layers,
                       shortcut_layers=shortcut_layers, inject_layers=inject_layers)

        # inputs
        xa_sample = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 3])
        _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])

        # sample
        x_sample = Gdec(Genc(xa_sample, is_training=False), _b_sample, is_training=False)

        # initialization
        ckpt_dir = './output/%s/checkpoints' % experiment_name
        print("CHECKPOINT DIR: " + ckpt_dir)
        try:
            tl.load_checkpoint(ckpt_dir, sess)
        except Exception as exc:
            raise Exception(' [*] No checkpoint!') from exc

        # ==============================================================================
        # =                                    test                                    =
        # ==============================================================================
        att_idx = atts.index(test_att)
        save_dir = './output/%s/sample_testing_slide_%s' % (experiment_name, test_att)
        try:
            for idx, batch in enumerate(te_data):
                xa_sample_ipt = batch[0]
                b_sample_ipt = batch[1]
                # labels {0, 1} -> {-thres_int, +thres_int}; the edited
                # attribute gets the slide intensity instead
                _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                _b_sample_ipt[..., att_idx] = test_int
                sample = sess.run(x_sample, feed_dict={
                    xa_sample: xa_sample_ipt,
                    _b_sample: _b_sample_ipt
                })
                save_location = '/output/%s/sample_testing_slide_%s/' % (
                    experiment_name, test_att)
                pylib.mkdir(save_dir)
                im.imwrite(sample.squeeze(0), '%s/%s' % (save_dir, file_name))
                print('%d.png done!' % idx)
                if idx + 1 == te_data._img_num:
                    break
        except Exception:
            traceback.print_exc()
    finally:
        sess.close()
        # basename keeps the deletion inside local_dir even if file_name
        # carries path components
        try:
            os.remove(os.path.join(local_dir, os.path.basename(file_name)))
        except OSError:
            traceback.print_exc()

    return "http://129.32.22.10:7001" + save_location + file_name