def rlm_scores(self, re_generator, fake_data_loader, sess, test_data_loader, valid_data_loader, writer):
        """Train the reverse generator on fake data and report RLM scores.

        Runs `self.re_gen_num` training epochs of `re_generator` on
        `fake_data_loader`, scoring it after every epoch on the validation
        and test loaders via `Nll`. Every `self.ntest` epochs the validation
        score is compared against the running best; when it improves, both
        the validation best and the corresponding test score are recorded.

        Args:
            re_generator: model with a `run_epoch(sess, loader, writer, epoch)` method.
            fake_data_loader: training data loader of generated samples.
            sess: TensorFlow session.
            test_data_loader / valid_data_loader: loaders used for scoring.
            writer: summary writer passed through to `run_epoch`.

        Returns:
            (valid_rlm_best, test_rlm_best) — best validation score (lower is
            better) and the test score from the same checkpoint epoch.
        """
        test_rlm = Nll(test_data_loader, re_generator, sess)
        valid_rlm = Nll(valid_data_loader, re_generator, sess)
        print('start train re-generator:')
        valid_rlm_best = 1000  # sentinel upper bound; NLL-style, lower is better
        test_rlm_best = 0
        # NOTE(review): hard-codes the epoch budget, clobbering any configured value.
        self.re_gen_num = 80
        for epoch in range(self.re_gen_num):
            start = time()
            loss = re_generator.run_epoch(sess, fake_data_loader, writer, epoch)
            end = time()
            print('epoch:' + str(epoch) + ' loss: ' + str(loss) + ' \t time:' + str(end - start))

            test_rlm_score = test_rlm.get_score()
            valid_rlm_score = valid_rlm.get_score()
            print('valid_rlm_score:' + str(valid_rlm_score) + '   test_rlm_score: ' + str(test_rlm_score))
            # Only consider a checkpoint for "best" every self.ntest epochs.
            if (epoch + 1) % self.ntest == 0:
                if valid_rlm_score < valid_rlm_best:
                    valid_rlm_best = valid_rlm_score
                    test_rlm_best = test_rlm_score
        # Fix: original message ran the two labels together
        # ('valid_rlm_best:<x>test_rlm_best: <y>'); add the separating space.
        print('valid_rlm_best:' + str(valid_rlm_best) + ' test_rlm_best: ' + str(test_rlm_best))

        return valid_rlm_best, test_rlm_best
    def train(self) -> None:
        """Full real-data training pipeline.

        Stages, in order:
          1. MLE-train the generator, checkpointing on best validation NLL
             (or restore the latest checkpoint when `self.restore` is set).
          2. Train a base discriminator on real vs. generated samples,
             checkpointing on best validation accuracy.
          3. Filtered generation: for several acceptance ratios `c`, sample
             from the generator and keep/filter samples by discriminator score.
          4. Train a second discriminator (`discriminator_d1`) on the
             filtered samples and report its test accuracy.

        Side effects: mutates TF graph/session state, writes checkpoints
        under `self.save_path`, sample files under `self.output_path`, and
        TensorBoard summaries via `self.writer`.
        """
        self.init_real_trainng()
        ###
        # init metric: validation NLL of the generator
        ###
        gen_valid_nll = Nll(self.valid_data_loader, self.generator, self.sess)

        FLAGS = tf.app.flags.FLAGS
        # For image_coco the validation split reuses the test codes.
        if FLAGS.data == 'image_coco':
            self.valid_code = self.test_code
        self.valid_data_loader.create_batches_train_list(self.valid_code)
        self.train_data_loader.create_batches_train_list(self.train_code)
        self.test_data_loader.create_batches_train_list(self.test_code)
        self.sess.run(tf.global_variables_initializer())

        # ++ Saver (generator variables only — scoped to "generator_gpt2")
        # saver_variables = tf.global_variables
        saver_variables = slim.get_variables_to_restore(include=["generator_gpt2"])
        saver = tf.train.Saver(saver_variables, max_to_keep=20)
        # ++ ====================

        # summary writer
        self.writer = self.save_summary()

        if self.restore:
            # Resume the generator from its latest checkpoint instead of retraining.
            restore_from = tf.train.latest_checkpoint(os.path.join(self.save_path, "gen"))
            saver.restore(self.sess, restore_from)
            print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")
        else:
            best_nll = 1000  # sentinel upper bound; lower validation NLL is better
            print('start train generator:')
            for epoch in range(self.train_gen_num):
                start = time()
                loss = self.generator.run_epoch(self.sess, self.train_data_loader, self.writer, epoch)
                end = time()
                print('epoch:' + str(epoch) + ' loss: ' + str(loss) + ' \t time:' + str(end - start))
                # Evaluate (and possibly checkpoint) only every self.ntest epochs.
                if (epoch + 1) % self.ntest == 0:
                    values = gen_valid_nll.get_score()
                    if values < best_nll:
                        best_nll = values
                        # save pre_train
                        saver.save(self.sess, os.path.join(self.save_path, "gen", 'train_best'))
                        print('gen store')
                self.add_epoch()
            # Reload the best generator checkpoint before discriminator training.
            restore_from = tf.train.latest_checkpoint(os.path.join(self.save_path, "gen"))
            saver.restore(self.sess, restore_from)
            print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")


            print('start train discriminator:')
            # Rebind the saver to the base discriminator's variables only.
            saver_variables = slim.get_variables_to_restore(include=["discriminator_base"])
            saver = tf.train.Saver(saver_variables, max_to_keep=20)

            # Build the discriminator's real/fake training and validation data
            # from freshly generated samples.
            self.generate_samples(self.temperature)
            self.get_real_test_file()
            with open(self.generator_file_pkl, 'rb')  as inf:
                self.generator_code = pickle.load(inf)
            self.dis_train_data_loader.load_train_data_list(self.train_code, self.generator_code)
            self.dis_valid_data_loader.load_train_data_list_file(self.valid_code, self.generator_valid_file)
            acc_valid_best = 0
            for epoch in range(self.train_dis_num):
                print("base epoch:" + str(epoch))
                self.train_discriminator(self.discriminator)

                accuracy_valid, loss_valid = self.get_distance(self.generator_valid_file, self.discriminator, \
                "base_valid", self.valid_data_loader, epoch, self.writer)
                # Checkpoint whenever validation accuracy improves.
                if accuracy_valid > acc_valid_best:
                        acc_valid_best = accuracy_valid
                        saver.save(self.sess, os.path.join(self.save_path, f'train_base'))
            print("acc_valid_best base:", acc_valid_best)

            restore_from = os.path.join(self.save_path, "train_base")
            saver.restore(self.sess, restore_from)
            print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")
            # NOTE(review): `epoch` here is the final index left over from the loop above.
            accuracy_test, loss_test = self.get_distance(self.generator_test_file, self.discriminator, \
                                                                  "base_test", self.test_data_loader, epoch,
                                                                  self.writer)
            print("acc_test base:", accuracy_test)

            #####
            # Filter generation: sample from the generator and accept/reject
            # by discriminator score at each acceptance ratio c.
            #####
            for c_give in [0.2, 0.5, 0.8]:
                print("*" * 50)
                print("c", c_give)
                # c_give is the acceptance ratio
                uc_get = self.uc_sampleing(self.generator, self.discriminator, "d1", self.temperature, c=c_give)
                print("uc=", uc_get)

                self.total_num = 0
                # Keep generating until enough samples survive the filter
                # (self.keep_num is advanced inside generate_sample_by_c_uc).
                while self.keep_num < self.num_generate_train + 20000 + self.batch_size:
                    inp = self.generator.generate(self.sess, self.temperature)
                    self.generate_sample_by_c_uc(inp, self.discriminator, 'd1', c_give, uc_get)
                    self.total_num += 1
                self.num_to_0()
                self.real_text_samples(f'd1_fake_keep_c{int(c_give * 100)}')
                self.real_text_samples(f'd1_fake_filter_c{int(c_give * 100)}')
                print("==" * 50)

            #####
            # Test the accuracy of filtered samples: train discriminator_d1
            # on the kept samples for one acceptance ratio.
            #####
            print(f'start train discriminator_d1 ')
            filname = 20  # selects the c=0.20 sample files (c = filname / 100)
            saver_variables = slim.get_variables_to_restore(include=["discriminator_d1"])
            saver = tf.train.Saver(saver_variables, max_to_keep=20)

            self.d1_fake_codes = self.get_fake_code(
                os.path.join(self.output_path, f"d1_fake_keep_c{filname}_train.pkl"))
            self.dis_train_data_loader.load_train_data_list(self.train_code, self.d1_fake_codes)
            self.dis_valid_data_loader.load_train_data_list_file(self.valid_code, os.path.join(self.output_path,
                                                                                               f"d1_fake_keep_c{filname}_valid.txt"))
            acc_valid_best = 0
            for epoch in range(self.train_dis_num):
                print("d1 epoch:" + str(epoch))
                self.train_discriminator(self.discriminator_d1)

                accuracy_valid, loss_valid = self.get_distance(
                    os.path.join(self.output_path, f"d1_fake_keep_c{filname}_valid.txt"), self.discriminator_d1, \
                    "d1_valid", self.valid_data_loader, epoch, self.writer)
                if accuracy_valid > acc_valid_best:
                    acc_valid_best = accuracy_valid
                    saver.save(self.sess, os.path.join(self.save_path, f'train_d1'))
            print("acc_valid_best d1:", acc_valid_best)
            restore_from = os.path.join(self.save_path, "train_d1")
            saver.restore(self.sess, restore_from)
            print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")
            accuracy_test, loss_test = self.get_distance(
                os.path.join(self.output_path, f"d1_fake_keep_c{filname}_test.txt"), self.discriminator_d1, \
                "d1_test", self.test_data_loader, epoch, self.writer)
            print(f"c={filname / 100}", "acc_test d1:", accuracy_test)
 def lm_scores(self, generator, fake_data_loader, sess):
     """Score `generator` on `fake_data_loader` with the NLL metric.

     Builds an `Nll` evaluator, prints the resulting language-model score,
     and returns it.
     """
     score = Nll(fake_data_loader, generator, sess).get_score()
     print("lm_score:", score)
     return score