def rlm_scores(self, re_generator, fake_data_loader, sess, test_data_loader, valid_data_loader, writer):
    """Train the reverse generator on generated (fake) data and track reverse-LM NLL scores.

    Each epoch fits `re_generator` on `fake_data_loader`, then measures NLL on the
    validation and test loaders. Every `self.ntest` epochs the validation score is
    compared against the best seen so far; the test score recorded at that same
    checkpoint is kept alongside it.

    Returns:
        (valid_rlm_best, test_rlm_best): best validation NLL observed at a
        checkpoint, and the test NLL measured at that same epoch.
    """
    test_rlm = Nll(test_data_loader, re_generator, sess)
    valid_rlm = Nll(valid_data_loader, re_generator, sess)
    print('start train re-generator:')
    valid_rlm_best = 1000  # NLL: lower is better, so start from a large sentinel
    test_rlm_best = 0
    # NOTE(review): hard-coded epoch count overrides any previously configured
    # self.re_gen_num — confirm this is intentional.
    self.re_gen_num = 80
    for epoch in range(self.re_gen_num):
        start = time()
        loss = re_generator.run_epoch(sess, fake_data_loader, writer, epoch)
        end = time()
        print('epoch:' + str(epoch) + ' loss: ' + str(loss) + ' \t time:' + str(end - start))
        test_rlm_score = test_rlm.get_score()
        valid_rlm_score = valid_rlm.get_score()
        print('valid_rlm_score:' + str(valid_rlm_score) + ' test_rlm_score: ' + str(test_rlm_score))
        # Only epochs at the self.ntest cadence are eligible as checkpoints.
        if (epoch + 1) % self.ntest == 0:
            if valid_rlm_score < valid_rlm_best:
                valid_rlm_best = valid_rlm_score
                test_rlm_best = test_rlm_score
    # Fixed: the original summary line was missing the space separating the
    # best-valid value from the 'test_rlm_best' label.
    print('valid_rlm_best:' + str(valid_rlm_best) + ' test_rlm_best: ' + str(test_rlm_best))
    return valid_rlm_best, test_rlm_best
def train(self):
    """Full real-data training pipeline.

    Stages (in order):
      1. Pre-train (or restore) the generator, checkpointing on best validation NLL.
      2. Train the base discriminator on real vs. generated data, keeping the
         checkpoint with best validation accuracy, then report test accuracy.
      3. Generate filtered sample sets for acceptance ratios c in {0.2, 0.5, 0.8}.
      4. Train a second discriminator (d1) on the c=0.2 filtered samples and
         report its test accuracy.

    Checkpoints are scoped per stage by rebinding `saver` to a new variable
    subset each time.
    """
    self.init_real_trainng()
    ### #init_metric:NLL ###
    gen_valid_nll = Nll(self.valid_data_loader, self.generator, self.sess)
    FLAGS = tf.app.flags.FLAGS
    # For image_coco the validation split reuses the test codes.
    if FLAGS.data == 'image_coco':
        self.valid_code = self.test_code
    self.valid_data_loader.create_batches_train_list(self.valid_code)
    self.train_data_loader.create_batches_train_list(self.train_code)
    self.test_data_loader.create_batches_train_list(self.test_code)
    self.sess.run(tf.global_variables_initializer())
    # ++ Saver
    # saver_variables = tf.global_variables
    # Save/restore only the generator's variables during pre-training.
    saver_variables = slim.get_variables_to_restore(include=["generator_gpt2"])
    saver = tf.train.Saver(saver_variables, max_to_keep=20)
    # ++ ====================
    # summary writer
    self.writer = self.save_summary()
    if self.restore:
        # Skip pre-training: load the latest generator checkpoint instead.
        restore_from = tf.train.latest_checkpoint(os.path.join(self.save_path, "gen"))
        saver.restore(self.sess, restore_from)
        print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")
    else:
        best_nll = 1000  # NLL: lower is better
        print('start train generator:')
        for epoch in range(self.train_gen_num):
            start = time()
            loss = self.generator.run_epoch(self.sess, self.train_data_loader, self.writer, epoch)
            end = time()
            print('epoch:' + str(epoch) + ' loss: ' + str(loss) + ' \t time:' + str(end - start))
            # Evaluate/checkpoint only every self.ntest epochs.
            if (epoch + 1) % self.ntest == 0:
                values = gen_valid_nll.get_score()
                if values < best_nll:
                    best_nll = values
                    # save pre_train
                    saver.save(self.sess, os.path.join(self.save_path, "gen", 'train_best'))
                    print('gen store')
            self.add_epoch()
        # Reload the best (latest-saved) generator checkpoint before moving on.
        restore_from = tf.train.latest_checkpoint(os.path.join(self.save_path, "gen"))
        saver.restore(self.sess, restore_from)
        print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")
    print('start train discriminator:')
    # Rebind saver to the base discriminator's variables only.
    saver_variables = slim.get_variables_to_restore(include=["discriminator_base"])
    saver = tf.train.Saver(saver_variables, max_to_keep=20)
    self.generate_samples(self.temperature)
    self.get_real_test_file()
    with open(self.generator_file_pkl, 'rb') as inf:
        self.generator_code = pickle.load(inf)
    self.dis_train_data_loader.load_train_data_list(self.train_code, self.generator_code)
    self.dis_valid_data_loader.load_train_data_list_file(self.valid_code, self.generator_valid_file)
    acc_valid_best = 0
    for epoch in range(self.train_dis_num):
        print("base epoch:" + str(epoch))
        self.train_discriminator(self.discriminator)
        accuracy_valid, loss_valid = self.get_distance(self.generator_valid_file, self.discriminator, \
            "base_valid", self.valid_data_loader, epoch, self.writer)
        # Keep the checkpoint with the best validation accuracy.
        if accuracy_valid > acc_valid_best:
            acc_valid_best = accuracy_valid
            saver.save(self.sess, os.path.join(self.save_path, f'train_base'))
    print("acc_valid_best base:", acc_valid_best)
    restore_from = os.path.join(self.save_path, "train_base")
    saver.restore(self.sess, restore_from)
    print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")
    # NOTE(review): `epoch` here is the leaked loop variable from the loop
    # above (its final value) — confirm get_distance only uses it for logging.
    accuracy_test, loss_test = self.get_distance(self.generator_test_file, self.discriminator, \
        "base_test", self.test_data_loader, epoch, self.writer)
    print("acc_test base:", accuracy_test)
    #####
    # Filter generation
    #####
    for c_give in [0.2, 0.5, 0.8]:
        print("*" * 50)
        print("c", c_give)
        # c_give is the acceptance ratio
        uc_get = self.uc_sampleing(self.generator, self.discriminator, "d1", self.temperature, c=c_give)
        print("uc=", uc_get)
        self.total_num = 0
        # NOTE(review): assumes self.keep_num is advanced inside
        # generate_sample_by_c_uc and reset by num_to_0() — verify, otherwise
        # this loop may not terminate.
        while self.keep_num < self.num_generate_train + 20000 + self.batch_size:
            inp = self.generator.generate(self.sess, self.temperature)
            self.generate_sample_by_c_uc(inp, self.discriminator, 'd1', c_give, uc_get)
            self.total_num += 1
        self.num_to_0()
        # Dump both the kept and the filtered-out sample sets as text.
        self.real_text_samples(f'd1_fake_keep_c{int(c_give * 100)}')
        self.real_text_samples(f'd1_fake_filter_c{int(c_give * 100)}')
        print("==" * 50)
    #####
    # Test the accuracy of filtered samples
    #####
    print(f'start train discriminator_d1 ')
    # NOTE(review): hard-coded to the c=0.2 (c*100 == 20) filtered set only —
    # confirm the other ratios are intentionally skipped here.
    filname = 20
    # Rebind saver again, now to the d1 discriminator's variables.
    saver_variables = slim.get_variables_to_restore(include=["discriminator_d1"])
    saver = tf.train.Saver(saver_variables, max_to_keep=20)
    self.d1_fake_codes = self.get_fake_code(
        os.path.join(self.output_path,
                     f"d1_fake_keep_c{filname}_train.pkl"))
    self.dis_train_data_loader.load_train_data_list(self.train_code, self.d1_fake_codes)
    self.dis_valid_data_loader.load_train_data_list_file(self.valid_code,
                                                         os.path.join(self.output_path, f"d1_fake_keep_c{filname}_valid.txt"))
    acc_valid_best = 0
    for epoch in range(self.train_dis_num):
        print("d1 epoch:" + str(epoch))
        self.train_discriminator(self.discriminator_d1)
        accuracy_valid, loss_valid = self.get_distance(
            os.path.join(self.output_path, f"d1_fake_keep_c{filname}_valid.txt"), self.discriminator_d1, \
            "d1_valid", self.valid_data_loader, epoch, self.writer)
        if accuracy_valid > acc_valid_best:
            acc_valid_best = accuracy_valid
            saver.save(self.sess, os.path.join(self.save_path, f'train_d1'))
    print("acc_valid_best d1:", acc_valid_best)
    restore_from = os.path.join(self.save_path, "train_d1")
    saver.restore(self.sess, restore_from)
    print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")
    accuracy_test, loss_test = self.get_distance(
        os.path.join(self.output_path, f"d1_fake_keep_c{filname}_test.txt"), self.discriminator_d1, \
        "d1_test", self.test_data_loader, epoch, self.writer)
    print(f"c={filname / 100}", "acc_test d1:", accuracy_test)
def lm_scores(self, generator, fake_data_loader, sess):
    """Compute, print, and return the language-model NLL score of the
    generator over the given fake-data loader."""
    score = Nll(fake_data_loader, generator, sess).get_score()
    print("lm_score:", score)
    return score