def __init__(self, sess: tf.Session, model: ConditionalGan, dataset: TextDataset, cfg):
    """Store the evaluation collaborators and create the retrieval helpers.

    Args:
        sess: TensorFlow session, stored as ``self.sess``.
        model: Conditional GAN instance, stored as ``self.model``.
        dataset: Text dataset, stored as ``self.dataset``.
        cfg: Configuration object. ``cfg.RETRIEVAL.DATA_PATH`` and
            ``cfg.EVAL.SAMPLE_SIZE`` are read here, and the whole object is
            handed to ``Retrieval``.
    """
    # Plain references.
    self.sess, self.model = sess, model

    # Constructed helpers: the retrieval model first, then the test loader
    # (same construction order as before).
    self.Retrieval = Retrieval(cfg)
    self.dataset = dataset
    test_data_path = cfg.RETRIEVAL.DATA_PATH
    self.test_data_loader = DatasetLoader(test_data_path, mode='test')

    # Configuration and the evaluation batch size derived from it.
    self.cfg = cfg
    self.bs = cfg.EVAL.SAMPLE_SIZE
def __init__(self, sess: tf.Session, model: ConditionalGan, dataset: TextDataset, cfg, cfg_stage_i):
    """Store the training collaborators and create the retrieval helpers.

    Args:
        sess: TensorFlow session, stored as ``self.sess``.
        model: Conditional GAN instance, stored as ``self.model``.
        dataset: Text dataset, stored as ``self.dataset``.
        cfg: Configuration object. ``cfg.RETRIEVAL.DATA_PATH`` and
            ``cfg.TRAIN.D_LR`` are read here, and the whole object is handed
            to ``Retrieval``.
        cfg_stage_i: Stage-I configuration, stored as ``self.cfg_stage_i``.
    """
    # Plain references.
    self.sess, self.model, self.dataset = sess, model, dataset

    # Both loaders read from the same retrieval data path; the train loader is
    # built first, then the test loader, then the retrieval model.
    retrieval_data_path = cfg.RETRIEVAL.DATA_PATH
    self.R_loader = DatasetLoader(retrieval_data_path, mode='train')
    self.test_data_loader = DatasetLoader(retrieval_data_path, mode='test')
    self.Retrieval = Retrieval(cfg)

    # Configuration objects and the base learning rate taken from them.
    self.cfg, self.cfg_stage_i = cfg, cfg_stage_i
    self.lr = cfg.TRAIN.D_LR
def __init__(self, cfg, batch_size, steps, check_dir_write, check_dir_read,
             dataset, sample_path, log_dir, stage, trans, build_model=True):
    """Store the run settings, create the retrieval helpers and TF state.

    Args:
        cfg: Configuration object; ``cfg.RETRIEVAL.DATA_PATH`` is read here
            and the whole object is handed to ``Retrieval``.
        batch_size: Stored as ``self.batch_size``.
        steps: Stored as ``self.steps``.
        check_dir_write: Stored as ``self.check_dir_write``.
        check_dir_read: Stored as ``self.check_dir_read``.
        dataset: Stored as ``self.dataset``.
        sample_path: Stored as ``self.sample_path``.
        log_dir: Stored as ``self.log_dir``.
        stage: Stage index; the image side length is ``4 * 2 ** (stage - 1)``.
        trans: Stored as ``self.trans``.
        build_model: When truthy, ``define_losses()`` and
            ``define_summaries()`` are called at the end of construction.
    """
    self.cfg = cfg

    # Helpers are constructed in the same order as before: train loader,
    # test loader, retrieval model.
    retrieval_data_path = cfg.RETRIEVAL.DATA_PATH
    self.R_loader = DatasetLoader(retrieval_data_path, mode='train')
    self.test_data_loader = DatasetLoader(retrieval_data_path, mode='test')
    self.Retrieval = Retrieval(cfg)

    # Run settings supplied by the caller.
    self.batch_size, self.steps = batch_size, steps
    self.check_dir_write, self.check_dir_read = check_dir_write, check_dir_read
    self.dataset = dataset
    self.sample_path, self.log_dir = sample_path, log_dir
    self.stage, self.trans = stage, trans

    # Fixed hyper-parameters.
    self.z_dim = 128
    self.embed_dim = 512
    self.channel = 3
    self.sample_num = 64
    self.compr_embed_dim = 128
    self.lr = 0.00005
    self.lr_inp = self.lr

    # Both attributes hold the same stage-dependent side length.
    side = 4 * 2 ** (stage - 1)
    self.out_size = side
    self.output_size = side

    # Non-trainable TF state; ``dt`` is created before ``alpha_tra``.
    self.dt = tf.Variable(0.0, trainable=False)
    self.alpha_tra = tf.Variable(initial_value=0.0, trainable=False, name='alpha_tra')

    if not build_model:
        return
    self.define_losses()
    self.define_summaries()
class StageIIEval(object):
    """Evaluation harness for the stage-II generator (FID and Inception Score).

    The constructor only stores its collaborators. Each ``evaluate_*`` method
    builds the TensorFlow ops it needs, restores the required checkpoints and
    prints the resulting score.
    """

    def __init__(self, sess: tf.Session, model: ConditionalGan, dataset: TextDataset, cfg):
        """Store the collaborators and create the retrieval helpers.

        Args:
            sess: TensorFlow session used to run every op.
            model: Conditional GAN whose ``generator`` (and ``stagei``
                sub-model) is used by the evaluation methods.
            dataset: Text dataset; ``dataset.test.next_batch`` is used.
            cfg: Configuration object (``EVAL``, ``RETRIEVAL`` and
                ``CHECKPOINT_DIR`` entries are read).
        """
        self.sess = sess
        self.model = model
        self.Retrieval = Retrieval(cfg)
        self.dataset = dataset
        self.test_data_loader = DatasetLoader(cfg.RETRIEVAL.DATA_PATH, mode='test')
        self.cfg = cfg
        self.bs = self.cfg.EVAL.SAMPLE_SIZE

    def _restore_or_raise(self, var_list, checkpoint_dir, error_message):
        """Restore ``var_list`` from ``checkpoint_dir`` or raise ``RuntimeError``.

        Args:
            var_list: Variables handed to ``tf.train.Saver``.
            checkpoint_dir: Directory passed to the project ``load`` helper.
            error_message: Message of the ``RuntimeError`` raised on failure.

        Raises:
            RuntimeError: If ``load`` reports that nothing could be restored.
        """
        saver = tf.train.Saver(var_list)
        could_load, _ = load(saver, self.sess, checkpoint_dir)
        if not could_load:
            print(" [!] Load failed...")
            raise RuntimeError(error_message)
        print(" [*] Load SUCCESS")

    def evaluate_fid(self):
        """Compute and print the FID between generated and real images.

        Real-image activation statistics are computed and cached at
        ``cfg.EVAL.ACT_STAT_PATH`` when that file does not exist yet.

        Raises:
            RuntimeError: If the generator checkpoint cannot be restored.
        """
        incep_batch_size = self.cfg.EVAL.INCEP_BATCH_SIZE
        # The class count comes from the config (it used to be a literal 20),
        # matching evaluate_inception, which loads the same checkpoint dir.
        _, layers = load_inception_inference(self.sess, self.cfg.EVAL.NUM_CLASSES, incep_batch_size,
                                             self.cfg.EVAL.INCEP_CHECKPOINT_DIR)
        pool3 = layers['PreLogits']
        act_op = tf.reshape(pool3, shape=[incep_batch_size, -1])

        if not os.path.exists(self.cfg.EVAL.ACT_STAT_PATH):
            print('Computing activation statistics for real x')
            fid.compute_and_save_activation_statistics(self.cfg.EVAL.R_IMG_PATH, self.sess, incep_batch_size,
                                                       act_op, self.cfg.EVAL.ACT_STAT_PATH, verbose=True)

        print('Loading activation statistics for the real x')
        stats = np.load(self.cfg.EVAL.ACT_STAT_PATH)
        mu_real = stats['mu']
        sigma_real = stats['sigma']

        z = tf.placeholder(tf.float32, [self.bs, self.model.z_dim], name='real_images')
        cond = tf.placeholder(tf.float32, [self.bs] + [self.model.embed_dim], name='cond')
        eval_gen, _, _ = self.model.generator(z, cond, reuse=False)

        self._restore_or_raise(tf.global_variables('g_net'), self.cfg.CHECKPOINT_DIR,
                               'Could not load the checkpoints of the generator')

        print('Generating batches...')
        fid_size = self.cfg.EVAL.SIZE
        n_batches = fid_size // self.bs

        # NOTE(review): `samples` is overwritten on every iteration, so the
        # statistics below appear to be computed from the last batch only.
        # Confirm whether the activations were meant to be accumulated.
        for i in range(n_batches):
            sample_z = np.random.normal(0, 1, size=(self.bs, self.model.z_dim))
            _, _, embed, _, _ = self.dataset.test.next_batch(self.bs, 4, embeddings=True)
            samples = denormalize_images(self.sess.run(eval_gen, feed_dict={z: sample_z, cond: embed}))

        print('Computing activation statistics for generated x...')
        mu_gen, sigma_gen = fid.calculate_activation_statistics(samples, self.sess, incep_batch_size, act_op,
                                                                verbose=True)

        print("calculate FID:", end=" ", flush=True)
        try:
            FID = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
        except Exception as e:
            # Best effort: report the failure and fall back to a sentinel score.
            print(e)
            FID = 500
        print(FID)

    def evaluate_inception(self):
        """Compute and print the Inception Score of stage-II samples.

        Sentence embeddings produced by the retrieval model condition the
        stage-I generator, whose output is refined by the stage-II generator.

        Raises:
            RuntimeError: If the stage-I or stage-II checkpoint cannot be
                restored.
            ValueError: If ``cfg.EVAL.SIZE`` is smaller than the sample size,
                so that no batch would be generated.
        """
        incep_batch_size = self.cfg.EVAL.INCEP_BATCH_SIZE
        logits, _ = load_inception_inference(self.sess, self.cfg.EVAL.NUM_CLASSES, incep_batch_size,
                                             self.cfg.EVAL.INCEP_CHECKPOINT_DIR)
        pred_op = tf.nn.softmax(logits)

        stagei_z_dim = self.model.stagei.z_dim
        z = tf.placeholder(tf.float32, [self.bs, stagei_z_dim], name='z')
        cond = tf.placeholder(tf.float32, [self.bs] + [self.model.stagei.embed_dim], name='cond')
        stagei_gen, _, _ = self.model.stagei.generator(z, cond, reuse=False, is_training=False)
        eval_gen, _, _ = self.model.generator(stagei_gen, cond, reuse=False, is_training=False)
        self.Retrieval.eval(self.bs)

        stagei_vars = (tf.global_variables('g_net') + tf.global_variables('vf_') +
                       tf.global_variables('sf_') + tf.global_variables('att'))
        self._restore_or_raise(stagei_vars, self.model.stagei.cfg.CHECKPOINT_DIR,
                               'Could not load the checkpoints of stage I')
        self._restore_or_raise(tf.global_variables('stageII_g_net'), self.cfg.CHECKPOINT_DIR,
                               'Could not load the checkpoints of stage II')

        print('Generating batches...')
        size = self.cfg.EVAL.SIZE
        n_batches = size // self.bs
        if n_batches < 1:
            # np.concatenate on an empty list would raise a far less helpful
            # ValueError further down.
            raise ValueError('cfg.EVAL.SIZE (%d) must be at least cfg.EVAL.SAMPLE_SIZE (%d)'
                             % (size, self.bs))

        all_preds = []
        for i in range(n_batches):
            print("\rGenerating batch %d/%d" % (i + 1, n_batches), end="", flush=True)

            # The noise must have the shape of the `z` placeholder, which is
            # sized with the stage-I latent dimension.
            sample_z = np.random.normal(0, 1, size=(self.bs, stagei_z_dim))
            # NOTE(review): the batch returned here is not used below; the call
            # is kept in case its effect on the dataset state is relied upon.
            self.dataset.test.next_batch(self.bs, 1, embeddings=True)
            im_feats, sent_feats, _ = self.test_data_loader.get_batch(i, self.bs, phase = 'incep')

            sent_emb = self.sess.run(self.Retrieval.sent_embed_tensor, feed_dict={
                self.Retrieval.image_placeholder_test: im_feats,
                self.Retrieval.sent_placeholder_test: sent_feats,
            })

            # Generate a batch and scale it up for inception
            gen_batch = self.sess.run(eval_gen, feed_dict={z: sample_z, cond: sent_emb})
            samples = denormalize_images(gen_batch)
            incep_samples = np.empty((self.bs, 299, 299, 3))
            for sample_idx in range(self.bs):
                incep_samples[sample_idx] = prep_incep_img(samples[sample_idx])

            # Run prediction for current batch
            pred = self.sess.run(pred_op, feed_dict={'inputs:0': incep_samples})
            all_preds.append(pred)

        # Get rid of the first dimension
        all_preds = np.concatenate(all_preds, 0)

        print('\nComputing inception score...')
        mean, std = inception_score.get_inception_from_predictions(all_preds, 10)
        print('Inception Score | mean:', "%.2f" % mean, 'std:', "%.2f" % std)
class StageIEval(object):
    """Evaluation harness for the stage-I generator (FID and Inception Score).

    The constructor only stores its collaborators. Each ``evaluate_*`` method
    builds the TensorFlow ops it needs, restores the required checkpoint and
    prints the resulting score.
    """

    def __init__(self, sess: tf.Session, model: ConditionalGan, dataset: TextDataset, cfg):
        """Store the collaborators and create the retrieval helpers.

        Args:
            sess: TensorFlow session used to run every op.
            model: Conditional GAN whose ``generator`` is evaluated.
            dataset: Text dataset; ``dataset.test.next_batch`` is used.
            cfg: Configuration object (``EVAL``, ``RETRIEVAL`` and
                ``CHECKPOINT_DIR`` entries are read).
        """
        self.sess = sess
        self.model = model
        self.Retrieval = Retrieval(cfg)
        self.dataset = dataset
        self.cfg = cfg
        self.test_data_loader = DatasetLoader(cfg.RETRIEVAL.DATA_PATH, mode='test')
        self.classes = self.cfg.EVAL.NUM_CLASSES
        self.bs = self.cfg.EVAL.SAMPLE_SIZE

    def _restore_or_raise(self, var_list, checkpoint_dir, error_message):
        """Restore ``var_list`` from ``checkpoint_dir`` or raise ``RuntimeError``.

        Args:
            var_list: Variables handed to ``tf.train.Saver``.
            checkpoint_dir: Directory passed to the project ``load`` helper.
            error_message: Message of the ``RuntimeError`` raised on failure.

        Raises:
            RuntimeError: If ``load`` reports that nothing could be restored.
        """
        saver = tf.train.Saver(var_list)
        could_load, _ = load(saver, self.sess, checkpoint_dir)
        if not could_load:
            print(" [!] Load failed...")
            raise RuntimeError(error_message)
        print(" [*] Load SUCCESS")

    def evaluate_fid(self):
        """Compute and print the FID between generated and real images.

        Real-image activation statistics are computed and cached at
        ``cfg.EVAL.ACT_STAT_PATH`` when that file does not exist yet. The
        generator is conditioned on the embeddings returned by
        ``dataset.test.next_batch``.

        Raises:
            RuntimeError: If the generator checkpoint cannot be restored.
        """
        incep_batch_size = self.cfg.EVAL.INCEP_BATCH_SIZE
        _, layers = load_inception_inference(
            self.sess, self.classes, incep_batch_size,
            self.cfg.EVAL.INCEP_CHECKPOINT_DIR)
        pool3 = layers['PreLogits']
        act_op = tf.reshape(pool3, shape=[incep_batch_size, -1])

        if not os.path.exists(self.cfg.EVAL.ACT_STAT_PATH):
            print('Computing activation statistics for real x')
            fid.compute_and_save_activation_statistics(
                self.cfg.EVAL.R_IMG_PATH, self.sess, incep_batch_size, act_op,
                self.cfg.EVAL.ACT_STAT_PATH, verbose=True)

        print('Loading activation statistics for the real x')
        stats = np.load(self.cfg.EVAL.ACT_STAT_PATH)
        mu_real = stats['mu']
        sigma_real = stats['sigma']

        z = tf.placeholder(tf.float32, [self.bs, self.model.z_dim], name='real_images')
        cond = tf.placeholder(tf.float32, [self.bs] + [self.model.embed_dim], name='cond')
        eval_gen, _, _ = self.model.generator(z, cond, reuse=False)

        self._restore_or_raise(tf.global_variables('g_net'), self.cfg.CHECKPOINT_DIR,
                               'Could not load the checkpoints of the generator')

        print('Generating x...')
        fid_size = self.cfg.EVAL.SIZE
        n_batches = fid_size // self.bs
        w, h, c = self.model.image_dims[0], self.model.image_dims[1], self.model.image_dims[2]
        samples = np.zeros((n_batches * self.bs, w, h, c))

        for i in range(n_batches):
            start = i * self.bs
            end = start + self.bs

            # A retrieval-embedding block used to sit here. It referenced the
            # undefined name `test_num_samples` (NameError on the first batch)
            # and its outputs were never used, so it has been removed; the
            # generator is conditioned on the dataset embeddings only.
            sample_z = np.random.normal(0, 1, size=(self.bs, self.model.z_dim))
            _, _, embed, _, _ = self.dataset.test.next_batch(self.bs, 4, embeddings=True)
            samples[start:end] = denormalize_images(
                self.sess.run(eval_gen, feed_dict={z: sample_z, cond: embed}))

        print('Computing activation statistics for generated x...')
        mu_gen, sigma_gen = fid.calculate_activation_statistics(
            samples, self.sess, incep_batch_size, act_op, verbose=True)

        print("calculate FID:", end=" ", flush=True)
        try:
            FID = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
        except Exception as e:
            # Best effort: report the failure and fall back to a sentinel score.
            print(e)
            FID = 500
        print(FID)

    def evaluate_inception(self):
        """Compute and print the Inception Score of stage-I samples.

        The generator is conditioned on sentence embeddings produced by the
        retrieval model from ``test_data_loader`` batches.

        Raises:
            RuntimeError: If the generator/retrieval checkpoint cannot be
                restored.
        """
        incep_batch_size = self.cfg.EVAL.INCEP_BATCH_SIZE
        logits, _ = load_inception_inference(
            self.sess, self.classes, incep_batch_size,
            self.cfg.EVAL.INCEP_CHECKPOINT_DIR)
        pred_op = tf.nn.softmax(logits)

        z = tf.placeholder(tf.float32, [self.bs, self.model.z_dim], name='z')
        cond = tf.placeholder(tf.float32, [self.bs] + [self.model.embed_dim], name='cond')
        eval_gen, _, _ = self.model.generator(z, cond, reuse=False, is_training=False)
        self.Retrieval.eval(self.bs)

        vars_to_restore = (tf.global_variables('g_net') + tf.global_variables('vf_') +
                           tf.global_variables('sf_') + tf.global_variables('att'))
        self._restore_or_raise(vars_to_restore, self.cfg.CHECKPOINT_DIR,
                               'Could not load the checkpoints of the generator')

        print('Generating x...')
        size = self.cfg.EVAL.SIZE
        n_batches = size // self.bs
        w, h, c = self.model.image_dims[0], self.model.image_dims[1], self.model.image_dims[2]
        samples = np.zeros((n_batches * self.bs, w, h, c))

        for i in range(n_batches):
            print("\rGenerating batch %d/%d" % (i + 1, n_batches), end="", flush=True)

            sample_z = np.random.normal(0, 1, size=(self.bs, self.model.z_dim))
            # NOTE(review): the batch returned here is not used below; the call
            # is kept in case its effect on the dataset state is relied upon.
            self.dataset.test.next_batch(self.bs, 1, embeddings=True)
            im_feats, sent_feats, _ = self.test_data_loader.get_batch(i, self.bs, phase='incep')

            start = i * self.bs
            end = start + self.bs

            sent_emb = self.sess.run(self.Retrieval.sent_embed_tensor, feed_dict={
                self.Retrieval.image_placeholder_test: im_feats,
                self.Retrieval.sent_placeholder_test: sent_feats,
            })
            gen_batch = self.sess.run(eval_gen, feed_dict={z: sample_z, cond: sent_emb})
            samples[start:end] = denormalize_images(gen_batch)

        print('\nComputing inception score...')
        mean, std = inception_score.get_inception_score(samples, self.sess, incep_batch_size, 10, pred_op,
                                                        verbose=True)
        print('Inception Score | mean:', "%.2f" % mean, 'std:', "%.2f" % std)
class ConditionalGanTrainer(object):
    """Trains the stage-II conditional GAN jointly with the retrieval model.

    ``train`` builds the losses and summaries, restores whatever checkpoints
    are available and then alternates discriminator, generator and retrieval
    updates on every step.
    """

    def __init__(self, sess: tf.Session, model: ConditionalGan, dataset: TextDataset, cfg, cfg_stage_i):
        """Store the collaborators and create the retrieval loaders and model.

        Args:
            sess: TensorFlow session used to run every op.
            model: Conditional GAN to train.
            dataset: Text dataset providing ``train`` and ``test`` batches.
            cfg: Configuration object of this stage.
            cfg_stage_i: Stage-I configuration; its ``CHECKPOINT_DIR`` is used
                in ``train`` to restore the stage-I generator.
        """
        self.sess = sess
        self.model = model
        self.dataset = dataset
        self.R_loader = DatasetLoader(cfg.RETRIEVAL.DATA_PATH, mode='train')
        self.test_data_loader = DatasetLoader(cfg.RETRIEVAL.DATA_PATH, mode='test')
        self.Retrieval = Retrieval(cfg)
        self.cfg = cfg
        self.cfg_stage_i = cfg_stage_i
        self.lr = self.cfg.TRAIN.D_LR

    def define_losses(self):
        """Build the models, the D/G/R losses, the savers and the optimizers."""
        # The retrieval model is built first; its text embedding is the
        # conditioning input handed to the GAN.
        self.img_emb, self.txt_emb, self.R_loss = self.Retrieval.build_model()
        self.model.build_model(self.txt_emb)
        self.learning_rate = tf.placeholder(dtype=tf.float32, shape=None, name='lr')

        # Discriminator terms: synthetic images are labelled 0, real matching
        # pairs 0.95 (rather than 1.0) and real mismatching pairs 0.
        self.D_synthetic_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.model.D_synthetic_logits,
                                                    labels=tf.zeros_like(self.model.D_synthetic)))
        self.D_real_match_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.model.D_real_match_logits,
                                                    labels=tf.fill(self.model.D_real_match.get_shape(), 0.95)))
        self.D_real_mismatch_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.model.D_real_mismatch_logits,
                                                    labels=tf.zeros_like(self.model.D_real_mismatch)))

        # Generator terms: KL term on the conditioning statistics plus the
        # adversarial term (synthetic images labelled 1).
        self.G_kl_loss = self.kl_loss(self.model.embed_mean, self.model.embed_log_sigma)
        self.G_gan_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.model.D_synthetic_logits,
                                                    labels=tf.ones_like(self.model.D_synthetic)))

        # Define the final losses
        alpha_coeff = self.cfg.TRAIN.COEFF.ALPHA_MISMATCH_LOSS
        kl_coeff = self.cfg.TRAIN.COEFF.KL
        self.D_loss = self.D_real_match_loss + alpha_coeff * self.D_real_mismatch_loss \
            + (1.0 - alpha_coeff) * self.D_synthetic_loss
        self.G_loss = self.G_gan_loss + kl_coeff * self.G_kl_loss
        # The retrieval loss additionally carries half of each GAN loss.
        self.R_loss = self.R_loss + 0.5*self.G_loss + 0.5*self.D_loss

        # NOTE(review): define_summaries() creates "g_loss" and "d_loss" scalar
        # summaries again and rebinds these two attributes.
        self.G_loss_summ = tf.summary.scalar("g_loss", self.G_loss)
        self.D_loss_summ = tf.summary.scalar("d_loss", self.D_loss)
        self.R_loss_summ = tf.summary.scalar("R_loss", self.R_loss)

        # Separate savers: one restores the stage-I generator, the other
        # saves/restores the stage-II generator and discriminator.
        stagei_vars = tf.global_variables('g_net')
        stageii_vars = tf.global_variables('stageII_g_net') + tf.global_variables('stageII_d_net')
        self.stagei_g_saver = tf.train.Saver(stagei_vars)
        self.stageii_saver = tf.train.Saver(var_list=stageii_vars, max_to_keep=self.cfg.TRAIN.CHECKPOINTS_TO_KEEP)

        # Ops in the UPDATE_OPS collection run before each optimizer step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.D_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.cfg.TRAIN.D_BETA_DECAY) \
                .minimize(self.D_loss, var_list=self.model.d_vars)
            self.G_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.cfg.TRAIN.G_BETA_DECAY) \
                .minimize(self.G_loss, var_list=self.model.g_vars)
            # NOTE(review): confirm R_optim is meant to share the update-op
            # dependency; it uses its own fixed learning rate from the config.
            self.R_optim = tf.train.AdamOptimizer(learning_rate=self.cfg.RETRIEVAL.R_LR)\
                .minimize(self.R_loss, var_list=self.Retrieval.var_list)

    def kl_loss(self, mean, log_sigma):
        """Return the mean of ``-log_sigma + 0.5 * (exp(2*log_sigma) + mean**2 - 1)``.

        Args:
            mean: Tensor of means.
            log_sigma: Tensor of log standard deviations, same shape as ``mean``.

        Returns:
            Scalar tensor: the expression above averaged over all elements.
        """
        loss = -log_sigma + .5 * (-1 + tf.exp(2. * log_sigma) + tf.square(mean))
        loss = tf.reduce_mean(loss)
        return loss

    def define_summaries(self):
        """Create the TensorBoard summaries, the merged G/D ops and the writer.

        Must run after ``define_losses`` because it reads the loss tensors.
        """
        self.D_synthetic_summ = tf.summary.histogram('d_synthetic_sum', self.model.D_synthetic)
        self.D_real_match_summ = tf.summary.histogram('d_real_match_sum', self.model.D_real_match)
        self.D_real_mismatch_summ = tf.summary.histogram('d_real_mismatch_sum', self.model.D_real_mismatch)
        self.G_img_summ = tf.summary.image("g_sum", self.model.G)
        self.z_sum = tf.summary.histogram("z", self.model.z)

        self.D_synthetic_loss_summ = tf.summary.scalar('d_synthetic_sum_loss', self.D_synthetic_loss)
        self.D_real_match_loss_summ = tf.summary.scalar('d_real_match_sum_loss', self.D_real_match_loss)
        self.D_real_mismatch_loss_summ = tf.summary.scalar('d_real_mismatch_sum_loss', self.D_real_mismatch_loss)
        self.D_loss_summ = tf.summary.scalar("d_loss", self.D_loss)

        self.G_gan_loss_summ = tf.summary.scalar("g_gan_loss", self.G_gan_loss)
        self.G_kl_loss_summ = tf.summary.scalar("g_kl_loss", self.G_kl_loss)
        self.G_loss_summ = tf.summary.scalar("g_loss", self.G_loss)

        # One merged op per network so each update fetches only its own summaries.
        self.G_merged_summ = tf.summary.merge([self.G_img_summ, self.G_loss_summ, self.G_gan_loss_summ,
                                               self.G_kl_loss_summ])
        self.D_merged_summ = tf.summary.merge([self.D_real_mismatch_summ, self.D_real_match_summ,
                                               self.D_synthetic_summ, self.D_synthetic_loss_summ,
                                               self.D_real_mismatch_loss_summ, self.D_real_match_loss_summ,
                                               self.D_loss_summ])

        self.writer = tf.summary.FileWriter(self.cfg.LOGS_DIR, self.sess.graph)

    def train(self):
        """Run the full training loop.

        Builds losses and summaries, restores the stage-II and stage-I
        checkpoints when available, then for every step updates D, G and R in
        turn. Sample images are written whenever ``counter`` is a multiple of
        1000, a checkpoint is saved whenever ``counter % 500 == 2``, and
        ``ret_eval`` runs on every 50th epoch except epoch 0.
        """
        self.define_losses()
        self.define_summaries()

        # Fixed noise and test data reused for every sample image.
        sample_z = np.random.normal(0, 1, (self.model.sample_num, self.model.z_dim))
        _, sample_embed, _, captions = self.dataset.test.next_batch_test(self.model.sample_num, 0, 1)
        im_feats_test, sent_feats_test, labels_test = self.test_data_loader.get_batch(0,self.cfg.RETRIEVAL.SAMPLE_NUM,\
                        image_aug = self.cfg.RETRIEVAL.IMAGE_AUG, phase='test')
        # sample_embed is only printed below; sampling uses the retrieval
        # sentence embeddings instead.
        sample_embed = np.squeeze(sample_embed, axis=0)
        print(sample_embed.shape)

        save_captions(self.cfg.SAMPLE_DIR, captions)

        counter = 1
        start_time = time.time()

        # Resume stage II if a checkpoint exists; the step counter continues
        # from the restored value.
        could_load, checkpoint_counter = load(self.stageii_saver, self.sess, self.cfg.CHECKPOINT_DIR)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS: Stage II networks are loaded.")
        else:
            print(" [!] Load failed for stage II networks...")

        # The stage-I generator is restored independently; a failure only warns.
        could_load, checkpoint_counter = load(self.stagei_g_saver, self.sess, self.cfg_stage_i.CHECKPOINT_DIR)
        if could_load:
            print(" [*] Load SUCCESS: Stage I generator is loaded")
        else:
            print(" [!] WARNING!!! Failed to load the parameters for stage I generator...")

        initialize_uninitialized(self.sess)

        # Updates per epoch are given by the training data size / batch size
        updates_per_epoch = self.dataset.train.num_examples // self.model.batch_size
        epoch_start = counter // updates_per_epoch

        for epoch in range(epoch_start, self.cfg.TRAIN.EPOCH):
            # The learning rate is halved once per 100 epochs (see feed_dict).
            cen_epoch = epoch // 100

            for idx in range(0, updates_per_epoch):
                images, wrong_images, embed, _, _ = self.dataset.train.next_batch(self.model.batch_size, 1,
                                                                                  embeddings=True,
                                                                                  wrong_img=True)
                batch_z = np.random.normal(0, 1, (self.model.batch_size, self.model.z_dim))

                # Retrieval data loader
                # The retrieval indices are reshuffled at the first step of every epoch.
                if idx % updates_per_epoch == 0:
                    self.R_loader.shuffle_inds()
                im_feats, sent_feats, labels = self.R_loader.get_batch(idx % updates_per_epoch,\
                        self.cfg.RETRIEVAL.BATCH_SIZE, image_aug = self.cfg.RETRIEVAL.IMAGE_AUG)

                feed_dict = {
                    self.learning_rate: self.lr * (0.5**cen_epoch),
                    self.model.inputs: images,
                    self.model.wrong_inputs: wrong_images,
                    # self.model.embed_inputs: embed,
                    # self.model.embed_inputs: self.txt_emb,
                    self.model.z: batch_z,
                    self.Retrieval.image_placeholder : im_feats,
                    self.Retrieval.sent_placeholder : sent_feats,
                    self.Retrieval.label_placeholder : labels
                }

                # Update D network
                _, err_d, summary_str = self.sess.run([self.D_optim, self.D_loss, self.D_merged_summ],
                                                      feed_dict=feed_dict)
                self.writer.add_summary(summary_str, counter)

                # Update G network
                _, err_g, summary_str = self.sess.run([self.G_optim, self.G_loss, self.G_merged_summ],
                                                      feed_dict=feed_dict)
                self.writer.add_summary(summary_str, counter)

                # Update R network
                _, err_r, summary_str = self.sess.run([self.R_optim, self.R_loss, self.R_loss_summ],
                                                      feed_dict=feed_dict)
                self.writer.add_summary(summary_str, counter)

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, r_loss: %.8f"
                      % (epoch, idx, updates_per_epoch, time.time() - start_time, err_d, err_g, err_r))

                if np.mod(counter, 1000) == 0:
                    # Sample generation is best effort: any failure is printed
                    # and training continues.
                    try:
                        # pdb.set_trace()
                        # NOTE(review): Retrieval.eval() and model.eval() are
                        # called on every sampling step; confirm they do not
                        # add new ops to the graph each time.
                        self.Retrieval.eval()
                        sent_emb = self.sess.run(self.Retrieval.sent_embed_tensor,
                                                 feed_dict={
                                                     self.Retrieval.image_placeholder_test: im_feats_test,
                                                     self.Retrieval.sent_placeholder_test: sent_feats_test,
                                                 })
                        self.model.eval(sent_emb)
                        samples = self.sess.run(self.model.sampler,
                                                feed_dict={
                                                    self.model.z_sample: sample_z,
                                                    # self.model.embed_sample: sample_embed,
                                                    self.model.embed_sample: sent_emb,
                                                })
                        save_images(samples, get_balanced_factorization(samples.shape[0]),
                                    '{}train_{:02d}_{:04d}.png'.format(self.cfg.SAMPLE_DIR, epoch, idx))
                    except Exception as e:
                        print("Failed to generate sample image")
                        print(type(e))
                        print(e.args)
                        print(e)

                if np.mod(counter, 500) == 2:
                    save(self.stageii_saver, self.sess, self.cfg.CHECKPOINT_DIR, counter)

            # Retrieval evaluation on every 50th epoch, skipping epoch 0.
            if np.mod(epoch, 50) == 0 and epoch!=0:
                self.ret_eval(epoch)

    def ret_eval(self, epoch):
        """Evaluate image/sentence retrieval on the test loader and save the scores.

        Fills one row of a square matrix per test sample with the pairwise
        distances between that sample's image embedding and the sentence
        embeddings, then scores the matrix and its transpose with
        ``compMapScore``.

        Args:
            epoch: Value forwarded to ``save_scores`` together with the scores.
        """
        test_num_samples = self.test_data_loader.no_samples
        all_labels = []
        sim_mat = np.zeros((test_num_samples,test_num_samples))

        self.Retrieval.eval(sample_size=test_num_samples)
        for i in range(test_num_samples):
            # im_feats, sent_feats, labels = test_data_loader.get_batch(i, params.batchSize, phase = 'test')
            im_feats, sent_feats, labels = self.test_data_loader.get_batch(i, 1, phase = 'eval')
            # The single image feature row is repeated test_num_samples times
            # before being fed to the image placeholder.
            im_feats_rep = np.repeat(im_feats,test_num_samples,0)
            # NOTE(review): labels_rep is not used after this line.
            labels_rep = np.repeat(labels,test_num_samples,0)
            image_embed, sent_embed = self.sess.run((self.Retrieval.image_embed_tensor,
                                                     self.Retrieval.sent_embed_tensor),
                                                    feed_dict={
                                                        self.Retrieval.image_placeholder_test: im_feats_rep,
                                                        self.Retrieval.sent_placeholder_test: sent_feats,
                                                    })
            # Only the first image embedding row is compared against the
            # sentence embeddings.
            sim_mat[i,:] = pairwise_distances(image_embed[0,:].reshape(1,-1),sent_embed)#,'cosine')
            all_labels.extend(labels)
            if i % 100 == 0:
                print('Done: '+str(i)+' of '+str(test_num_samples))

        i2s_mapk = compMapScore(sim_mat, all_labels)
        s2i_mapk = compMapScore(sim_mat.T, all_labels)
        print('Image to Sentence mAP@50: ',i2s_mapk,'\n',
              'Sentence to Image mAP@50: ',s2i_mapk,'\n',)
        save_scores(self.cfg.SCORE_DIR, i2s_mapk, s2i_mapk, epoch)
class PGGAN(object): # build model def __init__(self, cfg, batch_size, steps, check_dir_write, check_dir_read, dataset, sample_path, log_dir, stage, trans, build_model=True): self.cfg = cfg self.R_loader = DatasetLoader(cfg.RETRIEVAL.DATA_PATH, mode='train') self.test_data_loader = DatasetLoader(cfg.RETRIEVAL.DATA_PATH, mode='test') self.Retrieval = Retrieval(cfg) self.batch_size = batch_size self.steps = steps self.check_dir_write = check_dir_write self.check_dir_read = check_dir_read self.dataset = dataset self.sample_path = sample_path self.log_dir = log_dir self.stage = stage self.trans = trans self.z_dim = 128 # self.embed_dim = 1024 self.embed_dim = 512 self.out_size = 4 * pow(2, stage - 1) self.channel = 3 self.sample_num = 64 self.compr_embed_dim = 128 self.lr = 0.00005 self.lr_inp = self.lr self.output_size = 4 * pow(2, stage - 1) self.dt = tf.Variable(0.0, trainable=False) self.alpha_tra = tf.Variable(initial_value=0.0, trainable=False, name='alpha_tra') if build_model: # self.build_model() self.define_losses() self.define_summaries() def build_model(self, emb): # Define the input tensor by appending the batch size dimension to the image dimension self.iter = tf.placeholder(tf.int32, shape=None) self.learning_rate = tf.placeholder(tf.float32, shape=None) self.x = tf.placeholder(tf.float32, [ self.batch_size, self.output_size, self.output_size, self.channel ], name='x') self.x_mismatch = tf.placeholder(tf.float32, [ self.batch_size, self.output_size, self.output_size, self.channel ], name='x_mismatch') # self.cond = tf.placeholder(tf.float32, [self.batch_size, self.embed_dim], name='cond') self.cond = emb self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z') self.epsilon = tf.placeholder(tf.float32, [self.batch_size, 1, 1, 1], name='eps') self.z_sample = tf.placeholder(tf.float32, [self.sample_num] + [self.z_dim], name='z_sample') self.cond_sample = tf.placeholder(tf.float32, [self.sample_num] + [self.embed_dim], name='cond_sample') 
self.G, self.mean, self.log_sigma = self.generator(self.z, self.cond, stages=self.stage, t=self.trans) self.Dg_logit = self.discriminator(self.G, self.cond, reuse=False, stages=self.stage, t=self.trans) self.Dx_logit = self.discriminator(self.x, self.cond, reuse=True, stages=self.stage, t=self.trans) self.Dxmi_logit = self.discriminator(self.x_mismatch, self.cond, reuse=True, stages=self.stage, t=self.trans) self.epsilon = tf.random_uniform([self.batch_size, 1, 1, 1], 0., 1.) self.x_hat = self.epsilon * self.G + (1. - self.epsilon) * self.x self.cond_inp = self.cond + 0.0 self.Dx_hat_logit = self.discriminator(self.x_hat, self.cond_inp, reuse=True, stages=self.stage, t=self.trans) # self.sampler, _, _ = self.generator(self.z_sample, self.cond_sample, reuse=True, stages=self.stage, # t=self.trans) self.alpha_assign = tf.assign( self.alpha_tra, (tf.cast(tf.cast(self.iter, tf.float32) / self.steps, tf.float32))) self.d_vars = tf.trainable_variables('d_net') self.g_vars = tf.trainable_variables('g_net') show_all_variables() def eval(self, emb): self.sampler, _, _ = self.generator(self.z_sample, emb, reuse=True, stages=self.stage, t=self.trans) def get_gradient_penalty(self, x, y): grad_y = tf.gradients(y, [x])[0] slopes = tf.sqrt( tf.reduce_sum(tf.square(grad_y), reduction_indices=[1, 2, 3])) return tf.reduce_mean(tf.maximum(0.0, slopes - 1.)**2) def get_gradient_penalty2(self, x, y): grad_y = tf.gradients(y, [x])[0] slopes = tf.sqrt( tf.reduce_sum(tf.square(grad_y), reduction_indices=[1])) return tf.reduce_mean(tf.maximum(0.0, slopes - 1.)**2) def define_losses(self): self.img_emb, self.txt_emb, self.R_loss = self.Retrieval.build_model() self.build_model(self.txt_emb) self.D_loss_real = tf.reduce_mean(self.Dx_logit) self.D_loss_fake = tf.reduce_mean(self.Dg_logit) self.D_loss_mismatch = tf.reduce_mean(self.Dxmi_logit) self.wdist = self.D_loss_real - self.D_loss_fake self.wdist2 = self.D_loss_real - self.D_loss_mismatch self.reg_loss = 
tf.reduce_mean(tf.square(self.Dxmi_logit)) self.G_kl_loss = self.kl_std_normal_loss(self.mean, self.log_sigma) self.real_gp = self.get_gradient_penalty(self.x_hat, self.Dx_hat_logit) self.real_gp2 = self.get_gradient_penalty2(self.cond_inp, self.Dx_hat_logit) self.D_loss = -self.wdist - self.wdist2 + 200.0 * (self.real_gp + self.real_gp2) # self.D_loss = -self.wdist - self.wdist2 + 50000000.0 * (self.real_gp + self.real_gp2) self.G_loss = -self.D_loss_fake + 5.0 * self.G_kl_loss # self.G_loss = -self.D_loss_fake + 5.0 * self.G_kl_loss self.R_loss = self.R_loss + 0.5 * self.G_loss + 0.5 * self.D_loss self.D_optimizer = tf.train.AdamOptimizer(0.000002, beta1=0.0, beta2=0.99) self.G_optimizer = tf.train.AdamOptimizer(0.000002, beta1=0.0, beta2=0.99) self.R_optimizer = tf.train.AdamOptimizer(0.000002, beta1=0.0, beta2=0.99) with tf.control_dependencies([self.alpha_assign]): self.D_optim = self.D_optimizer.minimize(self.D_loss, var_list=self.d_vars) self.G_optim = self.G_optimizer.minimize(self.G_loss, var_list=self.g_vars) self.R_optim = self.R_optimizer.minimize( self.R_loss, var_list=self.Retrieval.var_list) # variables to save vars_to_save = self.get_variables_up_to_stage(self.stage) print('Length of the vars to save: %d' % len(vars_to_save)) print('\n\nVariables to save:') print_vars(vars_to_save) self.saver = tf.train.Saver(vars_to_save, max_to_keep=2) # variables to restore self.restore = None if self.stage > 1 and self.trans: vars_to_restore = self.get_variables_up_to_stage(self.stage - 1) print('Length of the vars to restore: %d' % len(vars_to_restore)) print('\n\nVariables to restore:') print_vars(vars_to_restore) self.restore = tf.train.Saver(vars_to_restore) def define_summaries(self): summaries = [ tf.summary.image('x', self.x), tf.summary.image('G_img', self.G), tf.summary.histogram('z', self.z), tf.summary.histogram('z_sample', self.z_sample), tf.summary.scalar('G_loss_wass', -self.D_loss_fake), tf.summary.scalar('kl_loss', self.G_kl_loss), 
tf.summary.scalar('G_loss', self.G_loss), # tf.summary.scalar('d_lr', self.d_lr), # tf.summary.scalar('g_lr', self.g_lr), tf.summary.scalar('D_loss_real', self.D_loss_real), tf.summary.scalar('D_loss_fake', self.D_loss_fake), tf.summary.scalar('real_gp', self.real_gp), tf.summary.scalar('D_loss', self.D_loss), tf.summary.scalar('reg_loss', self.reg_loss), tf.summary.scalar('wdist', self.wdist), tf.summary.scalar('wdist2', self.wdist2), tf.summary.scalar('d_loss_mismatch', self.D_loss_mismatch), tf.summary.scalar('real_gp2', self.real_gp2), ] self.summary_op = tf.summary.merge(summaries) # do train def train(self): config = tf.ConfigProto() config.gpu_options.allow_growth = True with tf.Session(config=config) as self.sess: summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph) start_point = 0 if self.stage != 1: if self.trans: could_load, _ = load(self.restore, self.sess, self.check_dir_read) if not could_load: raise RuntimeError( 'Could not load previous stage during transition') else: could_load, _ = load(self.saver, self.sess, self.check_dir_read) if not could_load: raise RuntimeError('Could not load current stage') # variables to init vars_to_init = initialize_uninitialized(self.sess) self.sess.run(tf.variables_initializer(vars_to_init)) sample_z = np.random.normal(0, 1, (self.sample_num, self.z_dim)) _, sample_cond, _, captions = self.dataset.test.next_batch_test( self.sample_num, 0, 1) sample_cond = np.squeeze(sample_cond, axis=0) print('Conditionals sampler shape: {}'.format(sample_cond.shape)) save_captions(self.sample_path, captions) start_time = time.time() for idx in range(start_point + 1, self.steps): if self.trans: # Reduce the learning rate during the transition period and slowly increase it p = idx / self.steps self.lr_inp = self.lr # * np.exp(-2 * np.square(1 - p)) epoch_size = self.dataset.train.num_examples // self.batch_size epoch = idx // epoch_size images, wrong_images, embed, _, _ = self.dataset.train.next_batch( self.batch_size, 
1, wrong_img=True, embeddings=True) im_feats_test, sent_feats_test, labels_test = self.test_data_loader.get_batch(0,self.cfg.RETRIEVAL.SAMPLE_NUM,\ image_aug = self.cfg.RETRIEVAL.IMAGE_AUG, phase='test') batch_z = np.random.normal(0, 1, (self.batch_size, self.z_dim)) eps = np.random.uniform(0., 1., size=(self.batch_size, 1, 1, 1)) import pdb # pdb.set_trace() # Retrieval data loader # epoch_size = 120 # if idx % epoch_size == 0: # self.R_loader.shuffle_inds() im_feats, sent_feats, labels = self.R_loader.get_batch(idx % epoch_size,\ self.cfg.RETRIEVAL.BATCH_SIZE, image_aug = self.cfg.RETRIEVAL.IMAGE_AUG) feed_dict = { self.x: images, self.learning_rate: self.lr_inp, self.x_mismatch: wrong_images, # self.cond: embed, self.z: batch_z, self.epsilon: eps, self.z_sample: sample_z, # self.cond_sample: sample_cond, self.iter: idx, self.Retrieval.image_placeholder: im_feats, self.Retrieval.sent_placeholder: sent_feats, self.Retrieval.label_placeholder: labels } _, err_d = self.sess.run([self.D_optim, self.D_loss], feed_dict=feed_dict) _, err_g = self.sess.run([self.G_optim, self.G_loss], feed_dict=feed_dict) _, err_r = self.sess.run([self.R_optim, self.R_loss], feed_dict=feed_dict) if np.mod(idx, 20) == 0: summary_str = self.sess.run(self.summary_op, feed_dict=feed_dict) summary_writer.add_summary(summary_str, idx) print( "Epoch: [%2d] [%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, r_loss: %.8f" % (epoch, idx, time.time() - start_time, err_d, err_g, err_r)) if np.mod(idx, 2000) == 0: try: self.Retrieval.eval() sent_emb = self.sess.run( self.Retrieval.sent_embed_tensor, feed_dict={ self.Retrieval.image_placeholder_test: im_feats_test, self.Retrieval.sent_placeholder_test: sent_feats_test, }) self.eval(sent_emb) samples = self.sess.run( self.sampler, feed_dict={ self.z_sample: sample_z, # self.model.embed_sample: sample_embed, self.cond_sample: sent_emb, }) # samples = sess.run(self.sampler, feed_dict={ # self.z_sample: sample_z, # self.cond_sample: sample_cond}) samples = 
np.clip(samples, -1., 1.) if self.out_size > 256: samples = samples[:4] save_images( samples, get_balanced_factorization(samples.shape[0]), '{}train_{:02d}_{:04d}.png'.format( self.sample_path, epoch, idx)) except Exception as e: print("Failed to generate sample image") print(type(e)) print(e.args) print(e) if np.mod(idx, 2000) == 0 or idx == self.steps - 1: save(self.saver, self.sess, self.check_dir_write, idx) sys.stdout.flush() if np.mod(idx, 20000) == 0: print('yes') self.ret_eval(idx) tf.reset_default_graph() def discriminator(self, inp, cond, stages, t, reuse=False): alpha_trans = self.alpha_tra with tf.variable_scope("d_net", reuse=reuse): x_iden = None if t: x_iden = pool(inp, 2) x_iden = self.from_rgb(x_iden, stages - 2) x = self.from_rgb(inp, stages - 1) for i in range(stages - 1, 0, -1): with tf.variable_scope(self.get_conv_scope_name(i), reuse=reuse): x = conv2d(x, f=self.get_dnf(i), ks=(3, 3), s=(1, 1), act=lrelu_act()) x = conv2d(x, f=self.get_dnf(i - 1), ks=(3, 3), s=(1, 1), act=lrelu_act()) x = pool(x, 2) if i == stages - 1 and t: x = tf.multiply(alpha_trans, x) + tf.multiply( tf.subtract(1., alpha_trans), x_iden) with tf.variable_scope(self.get_conv_scope_name(0), reuse=reuse): # Real/False branch cond_compress = fc(cond, units=128, act=lrelu_act()) concat = self.concat_cond4(x, cond_compress) x_b1 = conv2d(concat, f=self.get_dnf(0), ks=(3, 3), s=(1, 1), act=lrelu_act()) x_b1 = conv2d(x_b1, f=self.get_dnf(0), ks=(4, 4), s=(1, 1), padding='VALID', act=lrelu_act()) output_b1 = fc(x_b1, units=1) return output_b1 def generator(self, z_var, cond_inp, stages, t, reuse=False, cond_noise=True): alpha_trans = self.alpha_tra with tf.variable_scope('g_net', reuse=reuse): with tf.variable_scope(self.get_conv_scope_name(0), reuse=reuse): mean_lr, log_sigma_lr = self.generate_conditionals(cond_inp) cond = self.sample_normal_conditional(mean_lr, log_sigma_lr, cond_noise) # import pdb # pdb.set_trace() x = tf.concat([z_var, cond], axis=1) x = fc(x, units=4 * 4 * 
self.get_nf(0)) x = layer_norm(x) x = tf.reshape(x, [-1, 4, 4, self.get_nf(0)]) x = conv2d(x, f=self.get_nf(0), ks=(3, 3), s=(1, 1)) x = layer_norm(x, act=tf.nn.relu) x = conv2d(x, f=self.get_nf(0), ks=(3, 3), s=(1, 1)) x = layer_norm(x, act=tf.nn.relu) x_iden = None for i in range(1, stages): if (i == stages - 1) and t: x_iden = self.to_rgb(x, stages - 2) x_iden = upscale(x_iden, 2) with tf.variable_scope(self.get_conv_scope_name(i), reuse=reuse): x = upscale(x, 2) x = conv2d(x, f=self.get_nf(i), ks=(3, 3), s=(1, 1)) x = layer_norm(x, act=tf.nn.relu) x = conv2d(x, f=self.get_nf(i), ks=(3, 3), s=(1, 1)) x = layer_norm(x, act=tf.nn.relu) x = self.to_rgb(x, stages - 1) if t: x = tf.multiply(tf.subtract(1., alpha_trans), x_iden) + tf.multiply(alpha_trans, x) return x, mean_lr, log_sigma_lr def concat_cond4(self, x, cond): cond_compress = tf.expand_dims(tf.expand_dims(cond, 1), 1) cond_compress = tf.tile(cond_compress, [1, 4, 4, 1]) x = tf.concat([x, cond_compress], axis=3) return x def concat_cond128(self, x, cond_inp, cond_noise=True): mean, log_sigma = self.generate_conditionals(cond_inp, units=256) cond = self.sample_normal_conditional(mean, log_sigma, cond_noise) cond_compress = tf.reshape(cond, [-1, 16, 16, 1]) cond_compress = tf.tile(cond_compress, [1, 8, 8, 8]) x = tf.concat([x, cond_compress], axis=3) return x, mean, log_sigma def get_rgb_name(self, stage): return 'rgb_stage_%d' % stage def get_conv_scope_name(self, stage): return 'conv_stage_%d' % stage def get_dnf(self, stage): return min(1024 // (2**stage) * 2, 512) def get_nf(self, stage): return min(1024 // (2**stage) * 4, 512) def from_rgb(self, x, stage): with tf.variable_scope(self.get_rgb_name(stage)): return conv2d(x, f=self.get_dnf(stage), ks=(1, 1), s=(1, 1), act=lrelu_act()) def generate_conditionals(self, embeddings, units=128): """Takes the embeddings, compresses them and builds the statistics for a multivariate normal distribution""" mean = fc(embeddings, units, act=lrelu_act()) log_sigma = 
fc(embeddings, units, act=lrelu_act()) return mean, log_sigma def sample_normal_conditional(self, mean, log_sigma, cond_noise=True): if cond_noise: epsilon = tf.truncated_normal(tf.shape(mean)) stddev = tf.exp(log_sigma) return mean + stddev * epsilon return mean def kl_std_normal_loss(self, mean, log_sigma): loss = -log_sigma + .5 * (-1 + tf.exp(2. * log_sigma) + tf.square(mean)) loss = tf.reduce_mean(loss) return loss def to_rgb(self, x, stage): with tf.variable_scope(self.get_rgb_name(stage)): x = conv2d(x, f=9, ks=(2, 2), s=(1, 1), act=tf.nn.relu) x = conv2d(x, f=3, ks=(1, 1), s=(1, 1)) return x def get_adam_vars(self, opt, vars_to_train): opt_vars = [ opt.get_slot(var, name) for name in opt.get_slot_names() for var in vars_to_train if opt.get_slot(var, name) is not None ] opt_vars.extend(list(opt._get_beta_accumulators())) return opt_vars def get_variables_up_to_stage(self, stages): d_vars_to_save = tf.global_variables('d_net/%s' % self.get_rgb_name(stages - 1)) g_vars_to_save = tf.global_variables('g_net/%s' % self.get_rgb_name(stages - 1)) for stage in range(stages): d_vars_to_save += tf.global_variables( 'd_net/%s' % self.get_conv_scope_name(stage)) g_vars_to_save += tf.global_variables( 'g_net/%s' % self.get_conv_scope_name(stage)) return d_vars_to_save + g_vars_to_save def ret_eval(self, epoch): test_num_samples = self.test_data_loader.no_samples all_labels = [] sim_mat = np.zeros((test_num_samples, test_num_samples)) self.Retrieval.eval(sample_size=test_num_samples) for i in range(test_num_samples): # im_feats, sent_feats, labels = test_data_loader.get_batch(i, params.batchSize, phase = 'test') im_feats, sent_feats, labels = self.test_data_loader.get_batch( i, 1, phase='eval') im_feats_rep = np.repeat(im_feats, test_num_samples, 0) labels_rep = np.repeat(labels, test_num_samples, 0) image_embed, sent_embed = self.sess.run( (self.Retrieval.image_embed_tensor, self.Retrieval.sent_embed_tensor), feed_dict={ self.Retrieval.image_placeholder_test: 
im_feats_rep, self.Retrieval.sent_placeholder_test: sent_feats, }) sim_mat[i, :] = pairwise_distances(image_embed[0, :].reshape( 1, -1), sent_embed) #,'cosine') all_labels.extend(labels) if i % 100 == 0: print('Done: ' + str(i) + ' of ' + str(test_num_samples)) i2s_mapk = compMapScore(sim_mat, all_labels) s2i_mapk = compMapScore(sim_mat.T, all_labels) print( 'Image to Sentence mAP@50: ', i2s_mapk, '\n', 'Sentence to Image mAP@50: ', s2i_mapk, '\n', ) save_scores(self.cfg.SCORE_DIR, i2s_mapk, s2i_mapk, epoch)