Example #1
def get_metrics(config, oracle_loader, gen_loader, oracle_file, gen_file,
                oracle_model, g_pretrain_loss, x_real, sess):
    # set up evaluation metrics
    metrics = []
    if config['nll_oracle']:
        nll_oracle = Nll(gen_loader,
                         oracle_model.pretrain_loss,
                         oracle_model.x,
                         sess,
                         name='nll_oracle')
        metrics.append(nll_oracle)
    if config['nll_gen']:
        nll_gen = Nll(oracle_loader,
                      g_pretrain_loss,
                      x_real,
                      sess,
                      name='nll_gen')
        metrics.append(nll_gen)
    if config['doc_embsim']:
        doc_embsim = DocEmbSim(oracle_file,
                               gen_file,
                               config['vocab_size'],
                               name='doc_embsim')
        metrics.append(doc_embsim)

    return metrics
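
Every metric on this page follows the same small interface (set_name/get_name/get_score). As a minimal sketch of how a list returned by get_metrics might be consumed (the interface itself is an assumption drawn from the surrounding examples):

def evaluate_metrics(metrics):
    # Query each configured metric once and collect {name: score}.
    # Assumes every metric exposes get_name() and get_score(), as the
    # examples on this page do.
    return {metric.get_name(): metric.get_score() for metric in metrics}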
Example #2
    def init_metric(self):
        nll = Nll(data_loader=self.oracle_data_loader, rnn=self.oracle, sess=self.sess)
        self.add_metric(nll)

        inll = Nll(data_loader=self.gen_data_loader, rnn=self.generator, sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)
Example #3
    def init_real_metric(self):
        from utils.metrics.DocEmbSim import DocEmbSim
        docsim = DocEmbSim(oracle_file=self.oracle_file, generator_file=self.generator_file, num_vocabulary=self.vocab_size)
        self.add_metric(docsim)

        inll = Nll(data_loader=self.gen_data_loader, rnn=self.generator, sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)
Example #4
    def init_metric(self):

        # nll-oracle: use the oracle to score the data produced by the generator
        nll = Nll(data_loader=self.oracle_data_loader,
                  rnn=self.oracle,
                  sess=self.sess)
        self.add_metric(nll)

        # nll-test: use the generator to score the real data
        inll = Nll(data_loader=self.gen_data_loader,
                   rnn=self.generator,
                   sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)
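
For reference, a minimal sketch of what an Nll metric with this constructor could look like, following Texygen-style conventions; the Metrics base class, the model's pretrain_loss and x attributes, and the data-loader API (reset_pointer/num_batch/next_batch) are assumptions inferred from how the snippets above call it:

import numpy as np

class Nll(Metrics):
    def __init__(self, data_loader, rnn, sess):
        super().__init__()
        self.name = 'nll'
        self.data_loader = data_loader  # yields batches of token ids
        self.rnn = rnn                  # model exposing pretrain_loss and x
        self.sess = sess

    def get_score(self):
        # Average per-batch NLL over one pass through the loader.
        nll = []
        self.data_loader.reset_pointer()
        for _ in range(self.data_loader.num_batch):
            batch = self.data_loader.next_batch()
            loss = self.sess.run(self.rnn.pretrain_loss, {self.rnn.x: batch})
            nll.append(loss)
        return np.mean(nll)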
Example #5
def get_metrics(config, oracle_loader, test_file, gen_file, g_pretrain_loss,
                x_real, sess):
    # set up evaluation metrics
    metrics = []
    if config['nll_gen']:
        nll_gen = Nll(oracle_loader,
                      g_pretrain_loss,
                      x_real,
                      sess,
                      name='nll_gen')
        metrics.append(nll_gen)
    if config['doc_embsim']:
        doc_embsim = DocEmbSim(test_file,
                               gen_file,
                               config['vocab_size'],
                               name='doc_embsim')
        metrics.append(doc_embsim)
    if config['bleu']:
        for i in range(2, 6):
            bleu = Bleu(test_text=gen_file,
                        real_text=test_file,
                        gram=i,
                        name='bleu' + str(i))
            metrics.append(bleu)
    if config['selfbleu']:
        for i in range(2, 6):
            selfbleu = SelfBleu(test_text=gen_file,
                                gram=i,
                                name='selfbleu' + str(i))
            metrics.append(selfbleu)

    return metrics
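
A hypothetical config for this variant, with keys taken from the branches above; the values and file paths are purely illustrative:

config = {
    'nll_gen': True,     # NLL of real data under the generator
    'doc_embsim': True,  # document-embedding similarity
    'bleu': True,        # BLEU-2 .. BLEU-5 against the test file
    'selfbleu': True,    # Self-BLEU-2 .. Self-BLEU-5 of the samples
    'vocab_size': 5000,  # illustrative value
}
metrics = get_metrics(config, oracle_loader, 'data/test.txt', 'samples/gen.txt',
                      g_pretrain_loss, x_real, sess)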
Example #6
def get_metrics(config, oracle_loader, test_file, gen_file, g_pretrain_loss,
                x_real, sess, json_file):
    # set up evaluation metrics
    metrics = []
    if config['nll_gen']:
        nll_gen = Nll(oracle_loader,
                      g_pretrain_loss,
                      x_real,
                      sess,
                      name='nll_gen')
        metrics.append(nll_gen)
    if config['bleu']:
        for i in range(2, 6):
            bleu = Bleu(test_text=json_file,
                        real_text=test_file,
                        gram=i,
                        name='bleu' + str(i))
            metrics.append(bleu)
    if config['selfbleu']:
        for i in range(2, 6):
            selfbleu = SelfBleu(test_text=json_file,
                                gram=i,
                                name='selfbleu' + str(i))
            metrics.append(selfbleu)
    if config['KL']:
        KL_div = KL_divergence(oracle_loader, json_file, name='KL_divergence')
        metrics.append(KL_div)

    return metrics
Example #7
    def init_real_metric(self):
        from utils.metrics.DocEmbSim import DocEmbSim
        docsim = DocEmbSim(oracle_file=self.oracle_file,
                           generator_file=self.generator_file,
                           num_vocabulary=self.vocab_size)
        self.add_metric(docsim)

        inll = Nll(data_loader=self.gen_data_loader,
                   rnn=self.generator,
                   sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

        bleu = Bleu(test_text=self.test_file,
                    real_text='data/image_coco.txt',
                    gram=2)
        self.add_metric(bleu)

        sbleu = SelfBleu(test_text=self.test_file, gram=2)
        self.add_metric(sbleu)
Example #8
    def init_metric(self):
        nll = Nll(data_loader=self.oracle_data_loader, rnn=self.oracle, sess=self.sess)
        self.add_metric(nll)

        inll = Nll(data_loader=self.gen_data_loader, rnn=self.generator, sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

        from utils.metrics.DocEmbSim import DocEmbSim
        docsim = DocEmbSim(oracle_file=self.oracle_file, generator_file=self.generator_file, num_vocabulary=self.vocab_size)
        self.add_metric(docsim)
        
        print("Metrics Applied: " + nll.get_name() + ", " + inll.get_name() + ", " + docsim.get_name())
Example #9
    def init_real_metric(self):

        from utils.metrics.Nll import Nll
        from utils.metrics.PPL import PPL
        from utils.metrics.DocEmbSim import DocEmbSim
        from utils.others.Bleu import Bleu
        from utils.metrics.SelfBleu import SelfBleu

        if self.valid_ppl:
            valid_ppl = PPL(self.valid_data_loader, self.generator, self.sess)
            valid_ppl.set_name('valid_ppl')
            self.add_metric(valid_ppl)
        if self.nll_gen:
            nll_gen = Nll(self.gen_data_loader, self.generator, self.sess)
            nll_gen.set_name('nll_gen')
            self.add_metric(nll_gen)
        if self.doc_embsim:
            doc_embsim = DocEmbSim(self.oracle_file, self.generator_file,
                                   self.vocab_size)
            doc_embsim.set_name('doc_embsim')
            self.add_metric(doc_embsim)
        if self.bleu:
            FLAGS = tf.app.flags.FLAGS
            dataset = FLAGS.data
            if dataset == "image_coco":
                real_text = 'data/testdata/test_image_coco.txt'
            elif dataset == "emnlp_news":
                real_text = 'data/testdata/test_emnlp_news.txt'
            else:
                raise ValueError(f"unsupported dataset: {dataset}")
            for i in range(3, 4):  # only BLEU-3 is evaluated here
                bleu = Bleu(test_text=self.text_file,
                            real_text=real_text,
                            gram=i)
                bleu.set_name(f"Bleu{i}")
                self.add_metric(bleu)
        if self.selfbleu:
            for i in range(2, 6):
                selfbleu = SelfBleu(test_text=self.text_file, gram=i)
                selfbleu.set_name(f"Selfbleu{i}")
                self.add_metric(selfbleu)
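
As a rule of thumb in these codebases, BLEU against held-out text tracks sample quality while Self-BLEU of the generated file tracks diversity (lower Self-BLEU means more varied samples). A standalone check, assuming the constructor signatures used above and an illustrative sample path:

bleu3 = Bleu(test_text='samples/generated.txt',
             real_text='data/testdata/test_image_coco.txt',
             gram=3)
selfbleu3 = SelfBleu(test_text='samples/generated.txt', gram=3)
print('BLEU-3:', bleu3.get_score())
print('Self-BLEU-3:', selfbleu3.get_score())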
Example #10
    def init_metric(self):
        # docsim = DocEmbSim(oracle_file=self.truth_file, generator_file=self.generator_file,
        #                    num_vocabulary=self.vocab_size)
        # self.add_metric(docsim)

        inll = Nll(data_loader=self.gen_data_loader,
                   rnn=self.generator,
                   sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

        bleu1 = Bleu(test_text=self.test_file,
                     real_text=self.trunc_train_file,
                     gram=1)
        bleu1.set_name('BLEU-1')
        self.add_metric(bleu1)

        bleu2 = Bleu(test_text=self.test_file,
                     real_text=self.trunc_train_file,
                     gram=2)
        bleu2.set_name('BLEU-2')
        self.add_metric(bleu2)
Example #11
    def rlm_scores(self, re_generator, fake_data_loader, sess, test_data_loader, valid_data_loader, writer):
        test_rlm = Nll(test_data_loader, re_generator, sess)
        valid_rlm = Nll(valid_data_loader, re_generator, sess)
        print('start train re-generator:')
        valid_rlm_best = 1000  # sentinel upper bound; assumes validation NLL stays below 1000
        test_rlm_best = 0
        self.re_gen_num = 80
        for epoch in range(self.re_gen_num):
            start = time()
            loss = re_generator.run_epoch(sess, fake_data_loader, writer, epoch)
            end = time()
            print('epoch:' + str(epoch) + ' loss: ' + str(loss) + ' \t time:' + str(end - start))

            test_rlm_score = test_rlm.get_score()
            valid_rlm_score = valid_rlm.get_score()
            print('valid_rlm_score:' + str(valid_rlm_score) + '   test_rlm_score: ' + str(test_rlm_score))
            if (epoch + 1) % self.ntest == 0:
                if valid_rlm_score < valid_rlm_best:
                    valid_rlm_best = valid_rlm_score
                    test_rlm_best = test_rlm_score
        print('valid_rlm_best: ' + str(valid_rlm_best) + '  test_rlm_best: ' + str(test_rlm_best))

        return valid_rlm_best, test_rlm_best
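
rlm_scores is a reverse-language-model style evaluation: a fresh generator is trained on generated samples, and its NLL on real validation and test data is tracked, keeping the test score from the evaluation epoch with the best validation score (lower is better). A hypothetical call from inside the same trainer (all argument objects are assumed to exist there):

valid_rlm, test_rlm = self.rlm_scores(re_generator, fake_data_loader, self.sess,
                                      test_data_loader, valid_data_loader, writer)
print('RLM valid:', valid_rlm, 'RLM test:', test_rlm)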
Example #12
    def train(self):
        self.init_real_trainng()
        ###
        # init_metric: NLL on the validation set
        ###
        gen_valid_nll = Nll(self.valid_data_loader, self.generator, self.sess)

        FLAGS = tf.app.flags.FLAGS
        if FLAGS.data == 'image_coco':
            self.valid_code = self.test_code
        self.valid_data_loader.create_batches_train_list(self.valid_code)
        self.train_data_loader.create_batches_train_list(self.train_code)
        self.test_data_loader.create_batches_train_list(self.test_code)
        self.sess.run(tf.global_variables_initializer())

        # ++ Saver
        # saver_variables = tf.global_variables
        saver_variables = slim.get_variables_to_restore(include=["generator_gpt2"])
        saver = tf.train.Saver(saver_variables, max_to_keep=20)
        # ++ ====================
        
        # summary writer
        self.writer = self.save_summary()

        if self.restore:
            restore_from = tf.train.latest_checkpoint(os.path.join(self.save_path, "gen"))
            saver.restore(self.sess, restore_from)
            print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")
        else:
            best_nll = 1000
            print('start train generator:')
            for epoch in range(self.train_gen_num):
                start = time()
                loss = self.generator.run_epoch(self.sess, self.train_data_loader, self.writer, epoch)
                end = time()
                print('epoch:' + str(epoch) + ' loss: ' + str(loss) + ' \t time:' + str(end - start))
                if (epoch + 1) % self.ntest == 0:
                    values = gen_valid_nll.get_score()
                    if values < best_nll:
                        best_nll = values
                        # save pre_train
                        saver.save(self.sess, os.path.join(self.save_path, "gen", 'train_best'))
                        print('gen store')
                self.add_epoch()
            restore_from = tf.train.latest_checkpoint(os.path.join(self.save_path, "gen"))
            saver.restore(self.sess, restore_from)
            print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")


            print('start train discriminator:')
            saver_variables = slim.get_variables_to_restore(include=["discriminator_base"])
            saver = tf.train.Saver(saver_variables, max_to_keep=20)

            self.generate_samples(self.temperature)
            self.get_real_test_file()
            with open(self.generator_file_pkl, 'rb') as inf:
                self.generator_code = pickle.load(inf)
            self.dis_train_data_loader.load_train_data_list(self.train_code, self.generator_code)
            self.dis_valid_data_loader.load_train_data_list_file(self.valid_code, self.generator_valid_file)
            acc_valid_best = 0
            for epoch in range(self.train_dis_num):
                print("base epoch:" + str(epoch))
                self.train_discriminator(self.discriminator)

                accuracy_valid, loss_valid = self.get_distance(
                    self.generator_valid_file, self.discriminator,
                    "base_valid", self.valid_data_loader, epoch, self.writer)
                if accuracy_valid > acc_valid_best:
                    acc_valid_best = accuracy_valid
                    saver.save(self.sess, os.path.join(self.save_path, 'train_base'))
            print("acc_valid_best base:", acc_valid_best)

            restore_from = os.path.join(self.save_path, "train_base")
            saver.restore(self.sess, restore_from)
            print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")
            accuracy_test, loss_test = self.get_distance(
                self.generator_test_file, self.discriminator,
                "base_test", self.test_data_loader, epoch, self.writer)
            print("acc_test base:", accuracy_test)

            #####
            # Filter generation
            #####
            for c_give in [0.2, 0.5, 0.8]:
                print("*" * 50)
                print("c", c_give)
                # c_give is the acceptance ratio
                uc_get = self.uc_sampleing(self.generator, self.discriminator, "d1", self.temperature, c=c_give)
                print("uc=", uc_get)

                self.total_num = 0
                while self.keep_num < self.num_generate_train + 20000 + self.batch_size:
                    inp = self.generator.generate(self.sess, self.temperature)
                    self.generate_sample_by_c_uc(inp, self.discriminator, 'd1', c_give, uc_get)
                    self.total_num += 1
                self.num_to_0()
                self.real_text_samples(f'd1_fake_keep_c{int(c_give * 100)}')
                self.real_text_samples(f'd1_fake_filter_c{int(c_give * 100)}')
                print("==" * 50)

            #####
            # Test the accuracy of filtered samples
            #####
            print('start train discriminator_d1:')
            filname = 20  # acceptance ratio c = 0.2, stored as a percentage for file names
            saver_variables = slim.get_variables_to_restore(include=["discriminator_d1"])
            saver = tf.train.Saver(saver_variables, max_to_keep=20)

            self.d1_fake_codes = self.get_fake_code(
                os.path.join(self.output_path, f"d1_fake_keep_c{filname}_train.pkl"))
            self.dis_train_data_loader.load_train_data_list(self.train_code, self.d1_fake_codes)
            self.dis_valid_data_loader.load_train_data_list_file(self.valid_code, os.path.join(self.output_path,
                                                                                               f"d1_fake_keep_c{filname}_valid.txt"))
            acc_valid_best = 0
            for epoch in range(self.train_dis_num):
                print("d1 epoch:" + str(epoch))
                self.train_discriminator(self.discriminator_d1)

                accuracy_valid, loss_valid = self.get_distance(
                    os.path.join(self.output_path, f"d1_fake_keep_c{filname}_valid.txt"),
                    self.discriminator_d1, "d1_valid", self.valid_data_loader, epoch, self.writer)
                if accuracy_valid > acc_valid_best:
                    acc_valid_best = accuracy_valid
                    saver.save(self.sess, os.path.join(self.save_path, 'train_d1'))
            print("acc_valid_best d1:", acc_valid_best)
            restore_from = os.path.join(self.save_path, "train_d1")
            saver.restore(self.sess, restore_from)
            print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")
            accuracy_test, loss_test = self.get_distance(
                os.path.join(self.output_path, f"d1_fake_keep_c{filname}_test.txt"),
                self.discriminator_d1, "d1_test", self.test_data_loader, epoch, self.writer)
            print(f"c={filname / 100}", "acc_test d1:", accuracy_test)
Example #13
    def init_oracle_metric(self):
        from utils.metrics.Nll import Nll

        nll = Nll(data_loader=self.oracle_data_loader, rnn=self.oracle, sess=self.sess)
        self.add_metric(nll)
Example #14
    def init_metric(self):
        nll = Nll(data_loader=self.oracle_data_loader, rnn=self.oracle, sess=self.sess)
        self.add_metric(nll)
Example #15
    def lm_scores(self, generator, fake_data_loader, sess):
        lm = Nll(fake_data_loader, generator, sess)
        lm_score = lm.get_score()
        print("lm_score:", lm_score)
        return lm_score