Example #1
File: Maligan.py  Project: axueya/Texygen
    def init_real_trainng(self, data_loc=None):
        from utils.text_process import text_precess, text_to_code
        from utils.text_process import get_tokenlized, get_word_list, get_dict
        if data_loc is None:
            data_loc = 'data/image_coco.txt'
        # Derive the maximum sequence length and vocabulary size from the corpus.
        self.sequence_length, self.vocab_size = text_precess(data_loc)

        generator = Generator(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim,
                              hidden_dim=self.hidden_dim, sequence_length=self.sequence_length,
                              start_token=self.start_token)
        self.set_generator(generator)

        discriminator = Discriminator(sequence_length=self.sequence_length, num_classes=2, vocab_size=self.vocab_size,
                                      emd_dim=self.emb_dim, filter_sizes=self.filter_size, num_filters=self.num_filters,
                                      l2_reg_lambda=self.l2_reg_lambda)
        self.set_discriminator(discriminator)

        gen_dataloader = DataLoader(batch_size=self.batch_size, seq_length=self.sequence_length)
        oracle_dataloader = None
        dis_dataloader = DisDataloader(batch_size=self.batch_size, seq_length=self.sequence_length)

        self.set_data_loader(gen_loader=gen_dataloader, dis_loader=dis_dataloader, oracle_loader=oracle_dataloader)
        # Tokenize the corpus, build the word<->index dictionaries, and write
        # the index-encoded text to the oracle file for training.
        tokens = get_tokenlized(data_loc)
        word_set = get_word_list(tokens)
        word_index_dict, index_word_dict = get_dict(word_set)
        with open(self.oracle_file, 'w') as outfile:
            outfile.write(text_to_code(tokens, word_index_dict, self.sequence_length))
        return word_index_dict, index_word_dict
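
For reference, a minimal sketch of what the utils.text_process helpers used above might look like (a hypothetical reimplementation; the actual Texygen utilities may differ, e.g. in how padding is handled):

def get_tokenlized(file_path):
    # One whitespace-tokenized sentence per line.
    with open(file_path) as f:
        return [line.strip().split() for line in f if line.strip()]

def get_word_list(tokens):
    # Collect the unique words across all sentences.
    return list(set(word for sentence in tokens for word in sentence))

def get_dict(word_set):
    # Map each word to a string index and back.
    word_index_dict = {w: str(i) for i, w in enumerate(word_set)}
    index_word_dict = {str(i): w for i, w in enumerate(word_set)}
    return word_index_dict, index_word_dict

def text_to_code(tokens, word_index_dict, seq_length):
    # Encode each sentence as space-separated indices, truncated/padded to
    # seq_length. Padding with index '0' is an assumption made for this sketch.
    lines = []
    for sentence in tokens:
        codes = [word_index_dict[w] for w in sentence][:seq_length]
        codes += ['0'] * (seq_length - len(codes))
        lines.append(' '.join(codes))
    return '\n'.join(lines) + '\n'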
Example #2
File: Textgan.py  Project: axueya/Texygen
        def get_real_code():
            # Read the index-encoded oracle file back in and convert each
            # line of string tokens to a list of integer codes.
            text = get_tokenlized(self.oracle_file)

            def toint_list(x):
                return list(map(int, x))

            codes = list(map(toint_list, text))
            return codes
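
A quick illustration of what get_real_code produces, assuming the oracle file holds space-separated integer codes (hypothetical data):

text = [['3', '17', '5', '0', '0'], ['9', '2', '11', '4', '0']]  # from get_tokenlized
codes = [list(map(int, x)) for x in text]                        # toint_list applied per line
print(codes)  # [[3, 17, 5, 0, 0], [9, 2, 11, 4, 0]]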
Example #3
File: Maligan.py  Project: axueya/Texygen
    def get_real_test_file(dict=iw_dict):
        # Note: the opened handle is unused; get_tokenlized takes the path itself.
        with open(self.generator_file, 'r') as file:
            codes = get_tokenlized(self.generator_file)
        with open(self.test_file, 'w') as outfile:
            outfile.write(code_to_text(codes=codes, dictionary=dict))
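
code_to_text is the inverse of text_to_code; a plausible sketch (hypothetical, the real helper may strip padding or unknown indices differently):

def code_to_text(codes, dictionary):
    # Decode each sequence of string indices back into space-separated words.
    lines = []
    for sequence in codes:
        lines.append(' '.join(dictionary[str(index)] for index in sequence))
    return '\n'.join(lines) + '\n'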
Example #4
    def train_real(self, data_loc=None):
        from utils.text_process import code_to_text
        from utils.text_process import get_tokenlized
        wi_dict, iw_dict = self.init_real_trainng(data_loc)
        self.init_real_metric()

        def get_real_test_file(dict=iw_dict):
            # Decode the generated index file into text for evaluation.
            # Note: the opened handle is unused; get_tokenlized takes the path itself.
            with open(self.generator_file, 'r') as file:
                codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes, dictionary=dict))

        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 200
        self.log = open('experiment-log-pepgan-real.csv', 'w')
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)

        # Run the generator graph once (the loop executes a single iteration).
        for a in range(1):
            g = self.sess.run(self.generator.gen_x,
                              feed_dict={
                                  self.generator.drop_out: 1,
                                  self.generator.train: 1
                              })

        print(list(wi_dict))
        self.reset_epoch()
        print('adversarial training:')
        self.reward = Reward(model=self.generator,
                             dis=self.discriminator,
                             sess=self.sess,
                             rollout_num=8)
        for epoch in range(self.adversarial_epoch_num):
            #for epoch_ in range(30):
            print('epoch:' + str(epoch))
            start = time()
            # One generator update per epoch: sample, score with rollout rewards.
            for index in range(1):
                samples = self.generator.generate(self.sess, 1)
                rewards = self.reward.get_reward(samples)
                feed = {
                    self.generator.x: samples,
                    self.generator.reward: rewards,
                    self.generator.drop_out: 1
                }
                _, _, g_loss, w_loss = self.sess.run(
                    [
                        self.generator.manager_updates,
                        self.generator.worker_updates,
                        self.generator.goal_loss,
                        self.generator.worker_loss,
                    ],
                    feed_dict=feed)
                print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss', w_loss)
            end = time()
            self.add_epoch()
            print('epoch:' + str(epoch) + '\t time:' + str(end - start))
            with open("save/test_file_scores_ave.txt", 'a') as outfile:
                outfile.write(str(np.mean(np.concatenate(rewards))) + "\n")

            generate_samples_gen(self.sess, self.generator, self.batch_size,
                                 self.generate_num, self.generator_file)
            get_real_test_file()
            self.evaluate()

            with open(self.generator_file, 'r') as file:
                codes = get_tokenlized(self.generator_file)
            with open("save/test_file" + str(epoch) + ".txt", 'w') as outfile:
                outfile.write(code_to_text(codes=codes, dictionary=iw_dict))

            for _ in range(5):
                self.train_discriminator()
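
generate_samples_gen is assumed here to dump generated batches to generator_file as index-encoded lines; a minimal sketch under that assumption (the real signature of generator.generate may differ):

def generate_samples_gen(sess, generator, batch_size, generate_num, output_file):
    # Draw enough batches to cover generate_num samples and write them out,
    # one sample per line, as space-separated token indices.
    samples = []
    for _ in range(generate_num // batch_size):
        samples.extend(generator.generate(sess))
    with open(output_file, 'w') as f:
        for sample in samples:
            f.write(' '.join(map(str, sample)) + '\n')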
Example #5
    def init_real_trainng(self, data_loc=None):
        from utils.text_process import text_precess, text_to_code
        from utils.text_process import get_tokenlized, get_word_list, get_dict
        if data_loc is None:
            data_loc = 'data/image_coco.txt'
        self.sequence_length, self.vocab_size = text_precess(data_loc)

        goal_out_size = sum(self.num_filters)
        discriminator = Discriminator(sequence_length=self.sequence_length,
                                      num_classes=2,
                                      vocab_size=self.vocab_size,
                                      dis_emb_dim=self.dis_embedding_dim,
                                      filter_sizes=self.filter_size,
                                      num_filters=self.num_filters,
                                      batch_size=self.batch_size,
                                      hidden_dim=self.hidden_dim,
                                      start_token=self.start_token,
                                      goal_out_size=goal_out_size,
                                      step_size=4,
                                      l2_reg_lambda=self.l2_reg_lambda)
        self.set_discriminator(discriminator)

        generator = Generator(num_classes=2,
                              num_vocabulary=self.vocab_size,
                              batch_size=self.batch_size,
                              emb_dim=self.emb_dim,
                              dis_emb_dim=self.dis_embedding_dim,
                              goal_size=self.goal_size,
                              hidden_dim=self.hidden_dim,
                              sequence_length=self.sequence_length,
                              filter_sizes=self.filter_size,
                              start_token=self.start_token,
                              num_filters=self.num_filters,
                              goal_out_size=goal_out_size,
                              D_model=discriminator,
                              step_size=4)
        self.set_generator(generator)
        gen_dataloader = DataLoader(batch_size=self.batch_size,
                                    seq_length=self.sequence_length)
        oracle_dataloader = None
        dis_dataloader = DisDataloader(batch_size=self.batch_size,
                                       seq_length=self.sequence_length)

        self.set_data_loader(gen_loader=gen_dataloader,
                             dis_loader=dis_dataloader,
                             oracle_loader=oracle_dataloader)
        tokens = get_tokenlized(data_loc)
        word_set = get_word_list(tokens)
        # [word_index_dict, index_word_dict] = get_dict(word_set)  # original

        # Hard-coded character-level vocabulary used in place of get_dict:
        [word_index_dict, index_word_dict] = [{
            'b': '0',
            'a': '1',
            'r': '2',
            'n': '3',
            'd': '4',
            'c': '5',
            'q': '6',
            'e': '7',
            'g': '8',
            'h': '9',
            'i': '10',
            'l': '11',
            'k': '12',
            'm': '13',
            'f': '14',
            'p': '15',
            's': '16',
            't': '17',
            'w': '18',
            'y': '19',
            'v': '20'
        }, {
            '0': 'b',
            '1': 'a',
            '2': 'r',
            '3': 'n',
            '4': 'd',
            '5': 'c',
            '6': 'q',
            '7': 'e',
            '8': 'g',
            '9': 'h',
            '10': 'i',
            '11': 'l',
            '12': 'k',
            '13': 'm',
            '14': 'f',
            '15': 'p',
            '16': 's',
            '17': 't',
            '18': 'w',
            '19': 'y',
            '20': 'v'
        }]

        with open(self.oracle_file, 'w') as outfile:
            outfile.write(
                text_to_code(tokens, word_index_dict, self.sequence_length))
        return word_index_dict, index_word_dict
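
The hard-coded character vocabulary above can be built programmatically; this sketch produces the same two dictionaries:

chars = ['b', 'a', 'r', 'n', 'd', 'c', 'q', 'e', 'g', 'h', 'i',
         'l', 'k', 'm', 'f', 'p', 's', 't', 'w', 'y', 'v']
word_index_dict = {c: str(i) for i, c in enumerate(chars)}
index_word_dict = {str(i): c for i, c in enumerate(chars)}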
Example #6
def main():
    print('program start')
    from utils.text_process import text_precess, text_to_code  # TODO: move?
    from utils.text_process import get_tokenlized, get_word_list, get_dict

    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    SEQ_LENGTH, vocab_size = text_precess(true_file, val_file)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    val_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)

    # Create training file and dicts
    tokens = get_tokenlized(true_file)
    val_tokens = get_tokenlized(val_file)
    word_set = get_word_list(tokens + val_tokens)
    [word_index_dict, index_word_dict] = get_dict(word_set)
    with open(oracle_file, 'w') as outfile:
        outfile.write(text_to_code(tokens, word_index_dict, SEQ_LENGTH))
    with open(val_oracle_file, 'w') as outfile:
        outfile.write(text_to_code(val_tokens, word_index_dict, SEQ_LENGTH))

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    #target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
    #target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, 32, 32, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model
    # replace target lstm with true data

    mediator = Generator(vocab_size,
                         BATCH_SIZE * 2,
                         EMB_DIM * 2,
                         HIDDEN_DIM * 2,
                         SEQ_LENGTH,
                         START_TOKEN,
                         name="mediator",
                         dropout_rate=M_DROPOUT_RATE)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    gen_data_loader.create_batches(oracle_file)
    val_data_loader.create_batches(val_oracle_file)

    log = open('save/experiment-log.txt', 'w')
    log_nll = open('save/experiment-log-nll.txt', 'w')

    # Pre-train the generator (default 0 epochs; not recommended)
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = mle_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             generator_file)
            #get_real_test_file(index_word_dict, generator_file, test_file) # only needed in debugging
            test_loss = target_loss(sess, generator, val_data_loader)
            print('pre-train epoch ', epoch, 'nll_test ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll_test:\t' + str(test_loss) + '\n'
            log_nll.write(buffer)

    print('#########################################################################')
    toc = time.time()
    print('Start Cooperative Training...')
    for iter_idx in range(TOTAL_BATCH):
        print('iteration: ' + str(iter_idx) + '\ntime: ' +
              str(time.time() - toc))
        toc = time.time()
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            rewards = mediator.get_reward(
                sess, np.concatenate([samples, samples], axis=0))
            feed = {
                generator.x: samples,
                generator.rewards: rewards[0:BATCH_SIZE]
            }
            loss, _ = sess.run([generator.g_loss, generator.g_updates],
                               feed_dict=feed)
        # Test, removed oracle test
        if iter_idx % gen_data_loader.num_batch == 0:  # epochs instead of batches
            test_loss = target_loss(sess, generator, val_data_loader)
            print('epoch:\t', iter_idx // gen_data_loader.num_batch,
                  'nll_test ', test_loss)
            buffer = ('epoch:\t' + str(iter_idx // gen_data_loader.num_batch) +
                      '\tnll_test:\t' + str(test_loss) + '\n')
            log_nll.write(buffer)
        if iter_idx == TOTAL_BATCH - 1:
            print('generating samples')
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             generator_file)
            get_real_test_file(index_word_dict, generator_file, test_file)
        # Train the mediator (one update per iteration) on a shuffled mix of
        # real and generated batches
        for _ in range(1):
            print('training mediator...')
            bnll_ = []
            collected_x = []
            ratio = 2
            for it in range(ratio):
                if it % 2 == 0:
                    x_batch = gen_data_loader.next_batch()
                else:
                    x_batch = generator.generate(sess)
                collected_x.append(x_batch)
            collected_x = np.reshape(collected_x, [-1, SEQ_LENGTH])
            np.random.shuffle(collected_x)
            collected_x = np.reshape(collected_x,
                                     [-1, BATCH_SIZE * 2, SEQ_LENGTH])
            for it in range(1):
                feed = {
                    mediator.x: collected_x[it],
                }
                print('running bnll sess')
                bnll = sess.run(mediator.likelihood_loss, feed)
                bnll_.append(bnll)
                print('running mediator and updating')
                sess.run(mediator.dropout_on)
                _ = sess.run(mediator.likelihood_updates, feed)
                sess.run(mediator.dropout_off)
            if iter_idx % 50 == 0:
                bnll = np.mean(bnll_)
                print("mediator cooptrain iter#%d, balanced_nll %f" %
                      (iter_idx, bnll))
                log.write("%d\t%f\n" % (iter_idx, bnll))

    log.close()
    log_nll.close()
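
target_loss above is assumed to average the generator's teacher-forcing NLL over the validation loader, as in the usual SeqGAN-style helpers; a sketch under that assumption (pretrain_loss, x, num_batch, next_batch, and reset_pointer are assumed attributes):

import numpy as np

def target_loss(sess, model, data_loader):
    # Mean per-batch MLE loss of the model over the whole validation set.
    nll = []
    data_loader.reset_pointer()
    for _ in range(data_loader.num_batch):
        batch = data_loader.next_batch()
        nll.append(sess.run(model.pretrain_loss, {model.x: batch}))
    return np.mean(nll)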