Code example #1
def _get_rewards_01(config, data_loader, x_fake_for_rewards, eof_code, sess,
                    first, all_bleu_metrics):
    batch_size = config['batch_size']
    gan_type = config['gan_type']
    seq_len = config['seq_len']
    vocab_size = config['vocab_size']
    rl_bleu_ref_count = data_loader.num_batch * batch_size  # all of training set # 3000
    # rl_n_grams = 4
    rl_mc_samples = 1
    gamma_discount = 0.9

    rewards = np.zeros((batch_size, seq_len), np.float32)

    if first:
        bleu_metric_2 = list()
        bleu_metric_3 = list()
        bleu_metric_4 = list()
        bleu_metric_5 = list()

        # train_refs = data_loader.random_some(rl_bleu_ref_count, seq_len+1)
        train_refs = data_loader.get_as_lol_no_padding()
        # np_train_refs = np.array(train_refs)

        for t in range(2, seq_len + 1):
            # train_refs = data_loader.random_some(rl_bleu_ref_count, t)
            # bleu_metric_2.append(Bleu.from_references_indices(2, train_refs))
            bleu_metric_2.append(
                Bleu.from_references_indices(2, [l[:t] for l in train_refs]))

        for t in range(3, seq_len + 1):
            # train_refs = data_loader.random_some(rl_bleu_ref_count, t)
            bleu_metric_3.append(
                Bleu.from_references_indices(3, [l[:t] for l in train_refs]))

        for t in range(4, seq_len + 1):
            # train_refs = data_loader.random_some(rl_bleu_ref_count, t)
            bleu_metric_4.append(
                Bleu.from_references_indices(4, [l[:t] for l in train_refs]))

        for t in range(5, seq_len + 1):
            # train_refs = data_loader.random_some(rl_bleu_ref_count, t)
            bleu_metric_5.append(
                Bleu.from_references_indices(5, [l[:t] for l in train_refs]))

        # put the 5
        all_bleu_metrics = [
            bleu_metric_2, bleu_metric_3, bleu_metric_4, bleu_metric_5
        ]

        first = False

    for _ in range(rl_mc_samples):
        # samples_for_rewards, _ = self.generator.generate_from_noise(self.sess, batch_size, self.current_tau, Config.args.BATCH_SIZE)
        samples_for_rewards = sess.run(x_fake_for_rewards)
        gen_seq_list = samples_no_padding(samples_for_rewards, eof_code)
        for b in range(len(gen_seq_list)):
            rewards[b, :] = rewards[b, :] + _compute_rl_rewards_01(
                gen_seq_list[b], all_bleu_metrics, gamma_discount, seq_len)
    rewards = rewards / (1.0 * rl_mc_samples)
    return samples_for_rewards, rewards, first, all_bleu_metrics
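The per-timestep reward helper `_compute_rl_rewards_01` is not shown above. Purely as an illustration of the idea (a hypothetical sketch, not the project's actual helper: the function name, the NLTK-based scoring, and the exact discount weighting are all assumptions), a prefix-BLEU reward with a discount factor could look like this:

import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def prefix_bleu_rewards(gen_seq, references, gamma_discount=0.9, seq_len=20, max_gram=5):
    # Hypothetical sketch: reward each prefix of the generated index sequence by its
    # smoothed BLEU score (averaged over 2- to max_gram-grams) against reference
    # prefixes of the same length, discounting rewards for shorter prefixes.
    smooth = SmoothingFunction().method1
    rewards = np.zeros(seq_len, np.float32)
    for t in range(2, min(len(gen_seq), seq_len) + 1):
        refs_t = [r[:t] for r in references]
        scores = [sentence_bleu(refs_t, gen_seq[:t],
                                weights=tuple([1.0 / n] * n),
                                smoothing_function=smooth)
                  for n in range(2, max_gram + 1) if t >= n]
        rewards[t - 1] = np.mean(scores) * gamma_discount ** (seq_len - t)
    return rewards

In the example above the actual scoring goes through the precomputed `Bleu.from_references_indices` objects stored in `all_bleu_metrics` rather than NLTK.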
Code example #2
  def getPosBleuScore(self):
    generated_file = open(self.generated_file_path, 'r')
    generated_pos = open("/content/generated_pos.txt", 'w')
    reference_pos = open('/content/reference_pos.txt', 'w')
    reference_file = open(self.reference_file_path, 'r')

    # Write the POS-tag sequence of each generated line.
    for line in generated_file:
      generated_tokenize = nltk.word_tokenize(line)
      generated = nltk.pos_tag(generated_tokenize)
      for elem in generated:
        generated_pos.write(elem[1] + " ")
      generated_pos.write("\n")
    generated_file.close()
    generated_pos.close()

    # Write the POS-tag sequence of each reference line.
    for line in reference_file:
      reference_tokenize = nltk.word_tokenize(line)
      reference = nltk.pos_tag(reference_tokenize)
      for elem in reference:
        reference_pos.write(elem[1] + " ")
      reference_pos.write("\n")
    reference_file.close()
    reference_pos.close()

    # Score BLEU over the POS-tag files written above, not the raw text files.
    PosBleu = Bleu("/content/generated_pos.txt", "/content/reference_pos.txt", self.gram)
    #print("PosBleu score with gram = %d is: %f" %(self.gram, PosBleu.get_score(is_fast=False)))
    return PosBleu.get_score(is_fast=False)
Code example #3
File: real_train.py  Project: lethaiq/RelGAN
def get_metrics(config, oracle_loader, test_file, gen_file, g_pretrain_loss,
                x_real, sess):
    # set up evaluation metric
    metrics = []
    if config['nll_gen']:
        nll_gen = Nll(oracle_loader,
                      g_pretrain_loss,
                      x_real,
                      sess,
                      name='nll_gen')
        metrics.append(nll_gen)
    if config['doc_embsim']:
        doc_embsim = DocEmbSim(test_file,
                               gen_file,
                               config['vocab_size'],
                               name='doc_embsim')
        metrics.append(doc_embsim)
    if config['bleu']:
        for i in range(2, 6):
            bleu = Bleu(test_text=gen_file,
                        real_text=test_file,
                        gram=i,
                        name='bleu' + str(i))
            metrics.append(bleu)
    if config['selfbleu']:
        for i in range(2, 6):
            selfbleu = SelfBleu(test_text=gen_file,
                                gram=i,
                                name='selfbleu' + str(i))
            metrics.append(selfbleu)

    return metrics
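The metric objects collected here expose a small common interface (`get_name()` and `get_score()`, as also used in examples #8 and #13). A minimal sketch of how such a list can be evaluated and logged, assuming only that interface (the function and parameter names below are hypothetical):

def evaluate_metrics(metrics, log_path=None):
    # Call every metric once and optionally append the scores as a CSV line.
    scores = []
    for metric in metrics:
        score = metric.get_score()
        scores.append(score)
        print(metric.get_name(), score)
    if log_path is not None:
        with open(log_path, 'a', encoding='utf-8') as log:
            log.write(','.join(str(s) for s in scores) + '\n')
    return scores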
Code example #4
def get_metrics(config, oracle_loader, test_file, gen_file, g_pretrain_loss, x_real, x_topic, sess, json_file):
    # set up evaluation metric
    metrics = []
    if config['nll_gen']:
        nll_gen = NllTopic(oracle_loader, g_pretrain_loss, x_real, sess, name='nll_gen', x_topic=x_topic)
        metrics.append(nll_gen)
    if config['doc_embsim']:
        doc_embsim = DocEmbSim(test_file, gen_file, config['vocab_size'], name='doc_embsim')
        metrics.append(doc_embsim)
    if config['bleu']:
        for i in [2, 4]:  # range(2, 6):
            bleu = Bleu(test_text=json_file, real_text=test_file, gram=i, name='bleu' + str(i))
            metrics.append(bleu)
    if config['selfbleu']:
        for i in [4]:
            selfbleu = SelfBleu(test_text=json_file, gram=i, name='selfbleu' + str(i))
            metrics.append(selfbleu)
    if config['KL']:
        KL_div = KL_divergence(oracle_loader, json_file, name='KL_divergence')
        metrics.append(KL_div)
    if config['earth_mover']:
        EM_div = EarthMover(oracle_loader, json_file, name='Earth_Mover_Distance')
        metrics.append(EM_div)

    return metrics
Code example #5
def get_metrics(config, oracle_loader, test_file, gen_file, g_pretrain_loss,
                x_real, sess, json_file):
    # set up evaluation metric
    metrics = []
    if config['nll_gen']:
        nll_gen = Nll(oracle_loader,
                      g_pretrain_loss,
                      x_real,
                      sess,
                      name='nll_gen')
        metrics.append(nll_gen)
    if config['bleu']:
        for i in range(2, 6):
            bleu = Bleu(test_text=json_file,
                        real_text=test_file,
                        gram=i,
                        name='bleu' + str(i))
            metrics.append(bleu)
    if config['selfbleu']:
        for i in range(2, 6):
            selfbleu = SelfBleu(test_text=json_file,
                                gram=i,
                                name='selfbleu' + str(i))
            metrics.append(selfbleu)
    if config['KL']:
        KL_div = KL_divergence(oracle_loader, json_file, name='KL_divergence')
        metrics.append(KL_div)

    return metrics
Code example #6
File: Seqgan.py  Project: skarthika18/Texygen
    def init_real_metric(self):
        from utils.metrics.DocEmbSim import DocEmbSim
        docsim = DocEmbSim(oracle_file=self.oracle_file,
                           generator_file=self.generator_file,
                           num_vocabulary=self.vocab_size)
        self.add_metric(docsim)

        inll = Nll(data_loader=self.gen_data_loader,
                   rnn=self.generator,
                   sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

        bleu = Bleu(test_text=self.test_file,
                    real_text='data/image_coco.txt',
                    gram=2)
        self.add_metric(bleu)

        sbleu = SelfBleu(test_text=self.test_file, gram=2)
        self.add_metric(sbleu)
Code example #7
File: LeakGan.py  Project: Liugawa/GAN_Poem_Generate
    def init_metric(self):
        # docsim = DocEmbSim(oracle_file=self.truth_file, generator_file=self.generator_file,
        #                    num_vocabulary=self.vocab_size)
        # self.add_metric(docsim)

        inll = Nll(data_loader=self.gen_data_loader,
                   rnn=self.generator,
                   sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

        bleu1 = Bleu(test_text=self.test_file,
                     real_text=self.trunc_train_file,
                     gram=1)
        bleu1.set_name('BLEU-1')
        self.add_metric(bleu1)

        bleu2 = Bleu(test_text=self.test_file,
                     real_text=self.trunc_train_file,
                     gram=2)
        bleu2.set_name('BLEU-2')
        self.add_metric(bleu2)
Code example #8
 def bleu_vs_sbleu(self, real_text, fake_text):
     bleu = Bleu(test_text=fake_text, real_text=real_text, gram=5)
     sbleu = SelfBleu(test_text=fake_text, gram=5)
     print("Bleu:", bleu.get_score(), "SelfBleu:", sbleu.get_score())
Code example #9
#             f1_np = np.asarray(l)
#             f1_filled = np.lib.pad(f1_np, (0, max_length-f1_np.shape[0]),
#                                     'constant', constant_values=0)
#             line = [str(x) for x in f1_filled]
#             line = ' '.join(line) + '\n'
#             fout.write(line)

# pdb.set_trace()

#real_text = './arxiv_result/arxiv_test.txt'  # arxiv_1w

#test_text = './arxiv_result/ot.txt'  # syn_val_words.txt


# NOTE: import assumed (a Texygen-style Bleu class providing get_bleu_parallel);
# adjust the module path to this repository's layout.
from utils.metrics.Bleu import Bleu

real_text = '/home/ml/lpagec/tensorflow/FM-GAN/src/text_news/sents/Real_out.txt'

for std in [0.01, 0.1, 0.5, 1, 1.1]:
    test_text = '../src/OUT_SAMPLES_STD/text_' + str(std) + '.txt' #../src/text_news/sents/Lucas_out.txt'

    for i in [5]:#range(2, 6):
        get_Bleu = Bleu(test_text=test_text, real_text=real_text, gram=i)
        score = get_Bleu.get_bleu_parallel()
        print(std, score)
    
#test_text = './arxiv_result/vae.txt'  # syn_val_words.txt

#for i in range(2, 6):
##    get_Bleu = Bleu(test_text=test_text, real_text=real_text, gram=i)
#    score = get_Bleu.get_bleu_parallel()
#    print(score)
Code example #10
def _get_rewards_02(config, data_loader, x_fake_for_rewards, given_num, r_x,
                    r_gen_x, r_gen_x_sample, eof_code, sess, first,
                    all_bleu_metrics):
    # print("Start computing rewards ...")
    batch_size = config['batch_size']
    gan_type = config['gan_type']
    seq_len = config['seq_len']
    vocab_size = config['vocab_size']
    rl_bleu_ref_count = data_loader.num_batch * batch_size  # all of training set
    # rl_n_grams = 4
    rl_mc_samples = config['mc_samples']
    gamma_discount = 0.5

    # rewards = np.zeros((batch_size, seq_len), np.float32)

    if first:
        train_refs = data_loader.get_as_lol_no_padding()

        # train_refs = data_loader.random_some(rl_bleu_ref_count, seq_len + 1)
        bleu_metric_2 = Bleu.from_references_indices(2, train_refs)

        # train_refs = data_loader.random_some(rl_bleu_ref_count, seq_len + 1)
        bleu_metric_3 = Bleu.from_references_indices(3, train_refs)

        # train_refs = data_loader.random_some(rl_bleu_ref_count, seq_len + 1)
        bleu_metric_4 = Bleu.from_references_indices(4, train_refs)

        # train_refs = data_loader.random_some(rl_bleu_ref_count, seq_len + 1)
        bleu_metric_5 = Bleu.from_references_indices(5, train_refs)

        all_bleu_metrics = [
            bleu_metric_2, bleu_metric_3, bleu_metric_4, bleu_metric_5
        ]

        first = False

    rewards = list()
    samples_for_rewards = sess.run(x_fake_for_rewards)

    for i in range(rl_mc_samples):
        for given_num_i in range(1, seq_len):
            feed = {r_x: samples_for_rewards, given_num: given_num_i}
            roll_out_samples = sess.run(r_gen_x, feed)
            # feed = {discriminator.input_x: samples}
            # ypred_for_auc = sess.run(discriminator.ypred_for_auc, feed)
            ypred = _compute_rl_rewards_02(roll_out_samples, all_bleu_metrics,
                                           gamma_discount, eof_code)
            # ypred = np.array([item[1] for item in ypred_for_auc])
            if i == 0:
                rewards.append(ypred)
            else:
                rewards[given_num_i - 1] += ypred

        # the last token reward
        # feed = {discriminator.input_x: input_x}
        # ypred_for_auc = sess.run(discriminator.ypred_for_auc, feed)
        # ypred = np.array([item[1] for item in ypred_for_auc])
        ypred = _compute_rl_rewards_02(samples_for_rewards, all_bleu_metrics,
                                       gamma_discount, eof_code)
        if i == 0:
            rewards.append(ypred)
        else:
            rewards[(len(samples_for_rewards[0]) - 1)] += ypred

    # for _ in range(rl_mc_samples):
    #     # samples_for_rewards, _ = self.generator.generate_from_noise(self.sess, batch_size, self.current_tau, Config.args.BATCH_SIZE)
    #     samples_for_rewards = sess.run(x_fake_for_rewards)
    #     for b in range(len(samples_for_rewards)):
    #         rewards[b, :] = rewards[b, :] + _compute_rl_rewards(samples_for_rewards[b], all_bleu_metrics, gamma_discount)
    # rewards = rewards / (1.0 * rl_mc_samples)
    # return samples_for_rewards, rewards

    reward_res = np.transpose(np.array(rewards)) / (
        1.0 * rl_mc_samples)  # batch_size x seq_length
    if config['pg_baseline']:
        reward_res -= config['pg_baseline_val']  # 2.0 for emnlp
    # print("Rewards computed.")
    return samples_for_rewards, reward_res, first, all_bleu_metrics
Code example #11
import os
import sys
import numpy as np

sys.path.append(r'H:\qnyh\Texygen\\')
from utils.metrics.Bleu import Bleu
from utils.metrics.EmbSim import EmbSim

bleu = [0 for x in range(4)]
for i in range(4):
    bleu[i] = Bleu(test_text=r'H:\qnyh\Texygen\save\test_file.txt',
                   real_text=r'H:\qnyh\Texygen\save\qnyh_ref_comb.txt',
                   gram=i + 2)
    result = bleu[i].get_bleu()
    print(result)
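The same standalone pattern also works for the Self-BLEU diversity metric used in examples #3, #6 and #8. A sketch that could be appended to the script above, assuming the standard Texygen layout (utils/metrics/SelfBleu.py) on the already-appended path:

from utils.metrics.SelfBleu import SelfBleu

# Self-BLEU scores the generated file against itself, so no reference file is needed.
self_bleu = [None] * 4
for i in range(4):
    self_bleu[i] = SelfBleu(test_text=r'H:\qnyh\Texygen\save\test_file.txt',
                            gram=i + 2)
    print(self_bleu[i].get_score())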
Code example #12
File: LeakGan.py  Project: Liugawa/GAN_Poem_Generate
    def __init__(self, wi_dict_path, iw_dict_path, train_data, val_data=None):
        super().__init__()

        self.vocab_size = 20
        self.emb_dim = 64
        self.hidden_dim = 64

        self.input_length = 8
        self.sequence_length = 32
        self.filter_size = [2, 3]
        self.num_filters = [100, 200]
        self.l2_reg_lambda = 0.2
        self.dropout_keep_prob = 0.75
        self.batch_size = 64
        self.generate_num = 256
        self.start_token = 0
        self.dis_embedding_dim = 64
        self.goal_size = 16

        self.save_path = 'save/model/LeakGan/LeakGan'
        self.model_path = 'save/model/LeakGan'
        self.best_path_pre = 'save/model/best-pre-gen/best-pre-gen'
        self.best_path = 'save/model/best-leak-gan/best-leak-gan'
        self.best_model_path = 'save/model/best-leak-gan'

        self.truth_file = 'save/truth.txt'
        self.generator_file = 'save/generator.txt'
        self.test_file = 'save/test_file.txt'

        self.trunc_train_file = 'save/trunc_train.txt'
        self.trunc_val_file = 'save/trunc_val.txt'
        trunc_data(train_data, self.trunc_train_file, self.input_length)
        trunc_data(val_data, self.trunc_val_file, self.input_length)

        if not os.path.isfile(wi_dict_path) or not os.path.isfile(
                iw_dict_path):
            print('Building word/index dictionaries...')
            self.sequence_length, self.vocab_size, word_index_dict, index_word_dict = text_precess(
                train_data, val_data)
            print('Vocab Size: %d' % self.vocab_size)
            print('Saving dictionaries to ' + wi_dict_path + ' ' +
                  iw_dict_path + '...')
            with open(wi_dict_path, 'wb') as f:
                pickle.dump(word_index_dict, f)
            with open(iw_dict_path, 'wb') as f:
                pickle.dump(index_word_dict, f)
        else:
            print('Loading word/index dictionaries...')
            with open(wi_dict_path, 'rb') as f:
                word_index_dict = pickle.load(f)
            with open(iw_dict_path, 'rb') as f:
                index_word_dict = pickle.load(f)
            self.vocab_size = len(word_index_dict) + 1
            print('Vocab Size: %d' % self.vocab_size)

        self.wi_dict = word_index_dict
        self.iw_dict = index_word_dict
        self.train_data = train_data
        self.val_data = val_data

        goal_out_size = sum(self.num_filters)
        self.discriminator = Discriminator(
            sequence_length=self.sequence_length,
            num_classes=2,
            vocab_size=self.vocab_size,
            dis_emb_dim=self.dis_embedding_dim,
            filter_sizes=self.filter_size,
            num_filters=self.num_filters,
            batch_size=self.batch_size,
            hidden_dim=self.hidden_dim,
            start_token=self.start_token,
            goal_out_size=goal_out_size,
            step_size=4,
            l2_reg_lambda=self.l2_reg_lambda)

        self.generator = Generator(num_classes=2,
                                   num_vocabulary=self.vocab_size,
                                   batch_size=self.batch_size,
                                   emb_dim=self.emb_dim,
                                   dis_emb_dim=self.dis_embedding_dim,
                                   goal_size=self.goal_size,
                                   hidden_dim=self.hidden_dim,
                                   sequence_length=self.sequence_length,
                                   input_length=self.input_length,
                                   filter_sizes=self.filter_size,
                                   start_token=self.start_token,
                                   num_filters=self.num_filters,
                                   goal_out_size=goal_out_size,
                                   D_model=self.discriminator,
                                   step_size=4)

        self.saver = tf.train.Saver()
        self.best_pre_saver = tf.train.Saver()
        self.best_saver = tf.train.Saver()

        self.val_bleu1 = Bleu(real_text=self.trunc_val_file, gram=1)
        self.val_bleu2 = Bleu(real_text=self.trunc_val_file, gram=2)
Code example #13
File: LeakGan.py  Project: Liugawa/GAN_Poem_Generate
class LeakGan(Gan):
    def __init__(self, wi_dict_path, iw_dict_path, train_data, val_data=None):
        super().__init__()

        self.vocab_size = 20
        self.emb_dim = 64
        self.hidden_dim = 64

        self.input_length = 8
        self.sequence_length = 32
        self.filter_size = [2, 3]
        self.num_filters = [100, 200]
        self.l2_reg_lambda = 0.2
        self.dropout_keep_prob = 0.75
        self.batch_size = 64
        self.generate_num = 256
        self.start_token = 0
        self.dis_embedding_dim = 64
        self.goal_size = 16

        self.save_path = 'save/model/LeakGan/LeakGan'
        self.model_path = 'save/model/LeakGan'
        self.best_path_pre = 'save/model/best-pre-gen/best-pre-gen'
        self.best_path = 'save/model/best-leak-gan/best-leak-gan'
        self.best_model_path = 'save/model/best-leak-gan'

        self.truth_file = 'save/truth.txt'
        self.generator_file = 'save/generator.txt'
        self.test_file = 'save/test_file.txt'

        self.trunc_train_file = 'save/trunc_train.txt'
        self.trunc_val_file = 'save/trunc_val.txt'
        trunc_data(train_data, self.trunc_train_file, self.input_length)
        trunc_data(val_data, self.trunc_val_file, self.input_length)

        if not os.path.isfile(wi_dict_path) or not os.path.isfile(
                iw_dict_path):
            print('Building word/index dictionaries...')
            self.sequence_length, self.vocab_size, word_index_dict, index_word_dict = text_precess(
                train_data, val_data)
            print('Vocab Size: %d' % self.vocab_size)
            print('Saving dictionaries to ' + wi_dict_path + ' ' +
                  iw_dict_path + '...')
            with open(wi_dict_path, 'wb') as f:
                pickle.dump(word_index_dict, f)
            with open(iw_dict_path, 'wb') as f:
                pickle.dump(index_word_dict, f)
        else:
            print('Loading word/index dictionaries...')
            with open(wi_dict_path, 'rb') as f:
                word_index_dict = pickle.load(f)
            with open(iw_dict_path, 'rb') as f:
                index_word_dict = pickle.load(f)
            self.vocab_size = len(word_index_dict) + 1
            print('Vocab Size: %d' % self.vocab_size)

        self.wi_dict = word_index_dict
        self.iw_dict = index_word_dict
        self.train_data = train_data
        self.val_data = val_data

        goal_out_size = sum(self.num_filters)
        self.discriminator = Discriminator(
            sequence_length=self.sequence_length,
            num_classes=2,
            vocab_size=self.vocab_size,
            dis_emb_dim=self.dis_embedding_dim,
            filter_sizes=self.filter_size,
            num_filters=self.num_filters,
            batch_size=self.batch_size,
            hidden_dim=self.hidden_dim,
            start_token=self.start_token,
            goal_out_size=goal_out_size,
            step_size=4,
            l2_reg_lambda=self.l2_reg_lambda)

        self.generator = Generator(num_classes=2,
                                   num_vocabulary=self.vocab_size,
                                   batch_size=self.batch_size,
                                   emb_dim=self.emb_dim,
                                   dis_emb_dim=self.dis_embedding_dim,
                                   goal_size=self.goal_size,
                                   hidden_dim=self.hidden_dim,
                                   sequence_length=self.sequence_length,
                                   input_length=self.input_length,
                                   filter_sizes=self.filter_size,
                                   start_token=self.start_token,
                                   num_filters=self.num_filters,
                                   goal_out_size=goal_out_size,
                                   D_model=self.discriminator,
                                   step_size=4)

        self.saver = tf.train.Saver()
        self.best_pre_saver = tf.train.Saver()
        self.best_saver = tf.train.Saver()

        self.val_bleu1 = Bleu(real_text=self.trunc_val_file, gram=1)
        self.val_bleu2 = Bleu(real_text=self.trunc_val_file, gram=2)

    def train_discriminator(self):
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.gen_data_loader,
                             self.generator_file)
        self.dis_data_loader.load_train_data(self.truth_file,
                                             self.generator_file)
        for _ in range(3):
            self.dis_data_loader.next_batch()
            x_batch, y_batch = self.dis_data_loader.next_batch()
            feed = {
                self.discriminator.D_input_x: x_batch,
                self.discriminator.D_input_y: y_batch,
            }
            _, _ = self.sess.run(
                [self.discriminator.D_loss, self.discriminator.D_train_op],
                feed)
            self.generator.update_feature_function(self.discriminator)

    def eval(self):
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.gen_data_loader,
                             self.generator_file)
        if self.log is not None:
            if self.epoch == 0 or self.epoch == 1:
                for metric in self.metrics:
                    self.log.write(metric.get_name() + ',')
                self.log.write('\n')
            scores = super().eval()
            for score in scores:
                self.log.write(str(score) + ',')
            self.log.write('\n')
            return scores
        return super().eval()

    def init_metric(self):
        # docsim = DocEmbSim(oracle_file=self.truth_file, generator_file=self.generator_file,
        #                    num_vocabulary=self.vocab_size)
        # self.add_metric(docsim)

        inll = Nll(data_loader=self.gen_data_loader,
                   rnn=self.generator,
                   sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

        bleu1 = Bleu(test_text=self.test_file,
                     real_text=self.trunc_train_file,
                     gram=1)
        bleu1.set_name('BLEU-1')
        self.add_metric(bleu1)

        bleu2 = Bleu(test_text=self.test_file,
                     real_text=self.trunc_train_file,
                     gram=2)
        bleu2.set_name('BLEU-2')
        self.add_metric(bleu2)

    def train(self, restore=False, model_path=None):
        self.gen_data_loader = DataLoader(batch_size=self.batch_size,
                                          seq_length=self.sequence_length,
                                          input_length=self.input_length)
        self.dis_data_loader = DisDataloader(batch_size=self.batch_size,
                                             seq_length=self.sequence_length)

        tokens = get_tokens(self.train_data)
        with open(self.truth_file, 'w', encoding='utf-8') as outfile:
            outfile.write(
                text_to_code(tokens, self.wi_dict, self.sequence_length))

        wi_dict, iw_dict = self.wi_dict, self.iw_dict

        self.init_metric()

        def get_real_test_file(dict=iw_dict):
            codes = get_tokens(self.generator_file)
            with open(self.test_file, 'w', encoding='utf-8') as outfile:
                outfile.write(
                    code_to_text(codes=codes[self.input_length:], dict=dict))

        if restore:
            self.pre_epoch_num = 0
            if model_path is not None:
                self.model_path = model_path
            savefile = tf.train.latest_checkpoint(self.model_path)
            self.saver.restore(self.sess, savefile)
        else:
            self.sess.run(tf.global_variables_initializer())
            self.pre_epoch_num = 80

        # self.adversarial_epoch_num = 100
        self.log = open('log/experiment-log.txt', 'w', encoding='utf-8')
        self.gen_data_loader.create_batches(self.truth_file)
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.gen_data_loader,
                             self.generator_file)

        self.gen_data_loader.reset_pointer()
        for a in range(1):
            inputs, target = self.gen_data_loader.next_batch()
            g = self.sess.run(self.generator.gen_x,
                              feed_dict={
                                  self.generator.drop_out: 1,
                                  self.generator.train: 1,
                                  self.generator.inputs: inputs
                              })

        print('start pre-train generator:')
        best = 0
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch_gen(self.sess, self.generator,
                                       self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.epoch += 1
            if epoch % 5 == 0:
                generate_samples_gen(self.sess, self.generator,
                                     self.batch_size, self.generate_num,
                                     self.gen_data_loader, self.generator_file)
                get_real_test_file()
                scores = self.eval()
                self.saver.save(self.sess, self.save_path, global_step=epoch)
                if scores[3] > best:
                    print('--- Saving best-pre-gen...')
                    best = scores[3]
                    self.best_pre_saver.save(self.sess,
                                             self.best_path_pre,
                                             global_step=epoch)

        print('start pre-train discriminator:')
        # self.epoch = 0
        for epoch in range(self.pre_epoch_num):
            print('epoch:' + str(epoch))
            self.train_discriminator()
        self.saver.save(self.sess,
                        self.save_path,
                        global_step=self.pre_epoch_num * 2)

        self.epoch = 0
        best = 0
        self.reward = Reward(model=self.generator,
                             dis=self.discriminator,
                             sess=self.sess,
                             rollout_num=4)
        for epoch in range(self.adversarial_epoch_num // 10):
            for epoch_ in range(10):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                start = time()
                for index in range(1):
                    inputs, target = self.gen_data_loader.next_batch()
                    samples = self.generator.generate(self.sess,
                                                      1,
                                                      inputs=inputs)
                    rewards = self.reward.get_reward(samples, inputs)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1,
                        self.generator.inputs: inputs
                    }
                    _, _, g_loss, w_loss = self.sess.run([
                        self.generator.manager_updates,
                        self.generator.worker_updates,
                        self.generator.goal_loss,
                        self.generator.worker_loss,
                    ],
                                                         feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss',
                          w_loss)
                end = time()
                self.epoch += 1
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                if self.epoch % 5 == 0 or self.epoch == self.adversarial_epoch_num - 1:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.gen_data_loader,
                                         self.generator_file)
                    get_real_test_file()
                    scores = self.eval()

                    print('--- Generating poem on val data... ')
                    target_file = 'save/gen_val/val_%d_%f.txt' % (self.epoch,
                                                                  scores[1])
                    self.infer(test_data=self.val_data,
                               target_path=target_file,
                               model_path=self.model_path,
                               restore=False,
                               trunc=True)
                    self.val_bleu1.test_data = target_file
                    self.val_bleu2.test_data = target_file
                    bleu1 = self.val_bleu1.get_score()
                    bleu2 = self.val_bleu2.get_score()
                    print('--- BLEU on val data: \t bleu1: %f \t bleu2: %f' %
                          (bleu1, bleu2))
                    if bleu2 > best:
                        best = bleu2
                        print('--- Saving best-leak-gan...')
                        self.best_saver.save(self.sess,
                                             self.best_path,
                                             global_step=epoch * 10 + epoch_)

                for _ in range(15):
                    self.train_discriminator()

            self.saver.save(self.sess,
                            self.save_path,
                            global_step=1 + epoch + self.pre_epoch_num * 2)
            for epoch_ in range(5):
                start = time()
                loss = pre_train_epoch_gen(self.sess, self.generator,
                                           self.gen_data_loader)
                end = time()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                if epoch % 5 == 0:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.gen_data_loader,
                                         self.generator_file)
                    get_real_test_file()
                    self.eval()

            for epoch_ in range(5):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()

    def infer(self,
              test_data,
              target_path,
              model_path,
              restore=True,
              trunc=False):
        if model_path is None:
            model_path = self.model_path

        if restore:
            savefile = tf.train.latest_checkpoint(model_path)
            self.saver.restore(self.sess, savefile)

        tokens = get_tokens(test_data)
        sentence_num = len(tokens)
        temp_file = 'save/infer_temp.txt'
        with open(temp_file, 'w', encoding='utf-8') as outfile:
            outfile.write(text_to_code(tokens, self.wi_dict,
                                       self.input_length))

        test_data_loader = TestDataloader(batch_size=self.batch_size,
                                          input_length=self.input_length)
        test_data_loader.create_batches(temp_file)

        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             test_data_loader.num_batch * self.batch_size,
                             test_data_loader, target_path)

        codes = get_tokens(target_path)[:sentence_num]
        with open(target_path, 'w', encoding='utf-8') as outfile:
            if trunc:
                outfile.write(
                    code_to_text(codes=codes[self.input_length:],
                                 dict=self.iw_dict))
            else:
                outfile.write(code_to_text(codes=codes, dict=self.iw_dict))

        print('Finished generating %d poems to %s' %
              (sentence_num, target_path))
Code example #14
 def getBleuScore(self):
   print(self.reference_file_path)
   BleuScore = Bleu(self.generated_file_path, self.reference_file_path, self.gram)
   #print("Bleu score with gram = %d is: %f" %(self.gram, BleuScore.get_score(is_fast=False)))
   return BleuScore.get_score(is_fast=False)