    def train_oracle(self):
        self.init_oracle_trainng()
        self.init_metric()
        self.sess.run(tf.global_variables_initializer())
        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-rankgan.csv', 'w')
        generate_samples(self.sess, self.oracle, self.batch_size,
                         self.generate_num, self.oracle_file)
        generate_samples(self.sess, self.generator, self.batch_size,
                         self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)
        self.oracle_data_loader.create_batches(self.generator_file)

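        # Reward wraps a rollout copy of the generator; 0.8 is presumably the
        # rate at which the rollout parameters track the generator. This local
        # rollout is unused below; self.reward is created again for the
        # adversarial phase.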
        rollout = Reward(self.generator, .8)
        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch(self.sess, self.generator,
                                   self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        print('adversarial training:')
        self.reward = Reward(self.generator, .8)
        for epoch in range(self.adversarial_epoch_num):
            start = time()
            # print('epoch:' + str(epoch))
            for index in range(1):
                samples = self.generator.generate(self.sess)
                self.dis_data_loader.load_train_data(self.oracle_file,
                                                     self.generator_file)
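                # 16 here is presumably the number of Monte Carlo rollouts per
                # partial sequence used to estimate token-level rewards.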
                rewards = self.reward.get_reward(self.sess, samples, 16,
                                                 self.discriminator,
                                                 self.dis_data_loader)
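                # Policy-gradient step: the sampled sequences and their
                # rewards are fed to the generator's g_updates op.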
                feed = {
                    self.generator.x: samples,
                    self.generator.rewards: rewards
                }
                _ = self.sess.run(self.generator.g_updates, feed_dict=feed)
            end = time()
            self.add_epoch()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                generate_samples(self.sess, self.generator, self.batch_size,
                                 self.generate_num, self.generator_file)
                self.evaluate()

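            # Move the rollout policy toward the current generator weights,
            # then run several discriminator updates per adversarial epoch.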
            self.reward.update_params()
            for _ in range(15):
                self.train_discriminator()
Example #2
    def train_oracle(self):
        self.init_oracle_trainng()
        self.init_metric()
        self.sess.run(tf.compat.v1.global_variables_initializer())
        self.log = open(self.log_file, 'w')
        generate_samples(self.sess, self.oracle, self.batch_size,
                         self.generate_num, self.oracle_file)
        generate_samples(self.sess, self.generator, self.batch_size,
                         self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)
        self.oracle_data_loader.create_batches(self.generator_file)

        rollout = Reward(self.generator, .8)
        print('Pre-training Generator...')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch(self.sess, self.generator,
                                   self.gen_data_loader)
            end = time()
            self.add_epoch()
            if epoch % 5 == 0:
                self.evaluate()

        print('Pre-training Discriminator...')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num):
            self.train_discriminator()

        print('Adversarial Training...')
        self.reward = Reward(self.generator, .8)
        for epoch in range(self.adversarial_epoch_num):
            start = time()
            for index in range(1):
                samples = self.generator.generate(self.sess)
                self.dis_data_loader.load_train_data(self.oracle_file,
                                                     self.generator_file)
                rewards = self.reward.get_reward(self.sess, samples, 16,
                                                 self.discriminator,
                                                 self.dis_data_loader)
                feed = {
                    self.generator.x: samples,
                    self.generator.rewards: rewards
                }
                _ = self.sess.run(self.generator.g_updates, feed_dict=feed)
            end = time()
            self.add_epoch()
            if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                generate_samples(self.sess, self.generator, self.batch_size,
                                 self.generate_num, self.generator_file)
                self.evaluate()

            self.reward.update_params()
            for _ in range(15):
                self.train_discriminator()
        self.log.close()

    def train_real(self, data_loc=None):
        from utils.text_process import code_to_text
        from utils.text_process import get_tokenlized
        wi_dict, iw_dict = self.init_real_trainng(data_loc)
        self.init_real_metric()

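        # Helper: decode the token-index samples in generator_file back to
        # text so that text-level metrics can read them.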
        def get_real_test_file(dict=iw_dict):
            # get_tokenlized reads generator_file itself, so no separate open is needed
            codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes, dictionary=dict))

        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-rankgan-real.csv', 'w')
        generate_samples(self.sess, self.generator, self.batch_size,
                         self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)

        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch(self.sess, self.generator,
                                   self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples(self.sess, self.generator, self.batch_size,
                                 self.generate_num, self.generator_file)
                get_real_test_file()
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        self.reset_epoch()
        print('adversarial training:')
        self.reward = Reward(self.generator, .8)
        for epoch in range(self.adversarial_epoch_num):
            # print('epoch:' + str(epoch))
            start = time()
            for index in range(1):
                samples = self.generator.generate(self.sess)
                rewards = self.reward.get_reward(self.sess, samples, 16,
                                                 self.discriminator,
                                                 self.dis_data_loader)
                feed = {
                    self.generator.x: samples,
                    self.generator.rewards: rewards
                }
                _ = self.sess.run(self.generator.g_updates, feed_dict=feed)
            end = time()
            self.add_epoch()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                generate_samples(self.sess, self.generator, self.batch_size,
                                 self.generate_num, self.generator_file)
                get_real_test_file()
                self.evaluate()

            self.reward.update_params()
            for _ in range(15):
                self.train_discriminator()

    def train_cfg(self):
        import json
        from utils.text_process import get_tokenlized
        from utils.text_process import code_to_text
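        # Toy arithmetic-expression grammar over the terminals x and y; it is
        # passed to init_cfg_training (presumably to build the oracle data)
        # and to init_cfg_metric to score generated samples.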
        cfg_grammar = """
          S -> S PLUS x | S SUB x |  S PROD x | S DIV x | x | '(' S ')'
          PLUS -> '+'
          SUB -> '-'
          PROD -> '*'
          DIV -> '/'
          x -> 'x' | 'y'
        """

        wi_dict_loc, iw_dict_loc = self.init_cfg_training(cfg_grammar)
        with open(iw_dict_loc, 'r') as file:
            iw_dict = json.load(file)

        def get_cfg_test_file(dict=iw_dict):
            codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes, dictionary=dict))

        self.init_cfg_metric(grammar=cfg_grammar)
        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-rankgan-cfg.csv', 'w')
        generate_samples(self.sess, self.generator, self.batch_size,
                         self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)
        self.oracle_data_loader.create_batches(self.generator_file)
        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch(self.sess, self.generator,
                                   self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples(self.sess, self.generator, self.batch_size,
                                 self.generate_num, self.generator_file)
                get_cfg_test_file()
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num * 3):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        self.reset_epoch()
        print('adversarial training:')
        self.reward = Reward(self.generator, .8)
        for epoch in range(self.adversarial_epoch_num):
            # print('epoch:' + str(epoch))
            start = time()
            for index in range(1):
                samples = self.generator.generate(self.sess)
                rewards = self.reward.get_reward(self.sess, samples, 16,
                                                 self.discriminator,
                                                 self.dis_data_loader)
                feed = {
                    self.generator.x: samples,
                    self.generator.rewards: rewards
                }
                _ = self.sess.run(self.generator.g_updates, feed_dict=feed)
            end = time()
            self.add_epoch()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                generate_samples(self.sess, self.generator, self.batch_size,
                                 self.generate_num, self.generator_file)
                get_cfg_test_file()
                self.evaluate()

            self.reward.update_params()
            for _ in range(15):
                self.train_discriminator()
        return
Example #5
    def train_real(self, data_loc=None):
        from utils.text_process import code_to_text
        from utils.text_process import get_tokenlized
        wi_dict, iw_dict = self.init_real_trainng(data_loc)
        self.init_real_metric()

        def get_real_test_file(dict=iw_dict):
            codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes, dictionary=dict))

        self.sess.run(tf.compat.v1.global_variables_initializer())

        self.log = open(self.log_file, 'w')
        generate_samples(self.sess, self.generator, self.batch_size,
                         self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)

        print('Pre-training Generator...')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch(self.sess, self.generator,
                                   self.gen_data_loader)
            end = time()
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples(self.sess, self.generator, self.batch_size,
                                 self.generate_num, self.generator_file)
                get_real_test_file()
                self.evaluate()

        print('Pre-training Discriminator...')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num):
            self.train_discriminator()

        self.reset_epoch()
        print('Adversarial Training...')
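        # The adversarial loop below is wrapped in a broad try/except: any
        # exception (e.g. during evaluation) is printed instead of aborting
        # the run.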
        try:
            self.reward = Reward(self.generator, .8)
            for epoch in range(self.adversarial_epoch_num):
                start = time()
                for index in range(1):
                    samples = self.generator.generate(self.sess)
                    rewards = self.reward.get_reward(self.sess, samples, 16,
                                                     self.discriminator,
                                                     self.dis_data_loader)
                    feed = {
                        self.generator.x: samples,
                        self.generator.rewards: rewards
                    }
                    _ = self.sess.run(self.generator.g_updates, feed_dict=feed)
                end = time()
                self.add_epoch()
                if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                    generate_samples(self.sess, self.generator,
                                     self.batch_size, self.generate_num,
                                     self.generator_file)
                    get_real_test_file()
                    self.evaluate()

                self.reward.update_params()
                for _ in range(15):
                    self.train_discriminator()
        except Exception as e:
            print("evaluation error")
            print(e)
Example #6
    def train_real(self, size, itr, data_loc=None):
        sizes = [200, 400, 600, 800, 1000]
        # size_files = ['1_200.txt','2_200.txt','3_200.txt']
        from utils.text_process import code_to_text
        from utils.text_process import get_tokenlized

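        # Note: any data_loc passed in is overwritten here; the training file
        # is rebuilt from the cross-validation index (itr) and subset size.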
        data_loc = 'train_data/coco/' + str(itr) + '_' + str(size) + '.txt'
        # data_loc = 'train/'+size_files[cv_itr]
        wi_dict, iw_dict = self.init_real_trainng(data_loc)
        self.init_real_metric()

        def get_real_test_file(dict=iw_dict):
            codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes, dictionary=dict))

        def get_real_test_file_temp(write_file, dict=iw_dict):
            codes = get_tokenlized(self.generator_file)
            with open(write_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes, dictionary=dict))

        self.sess.run(tf.global_variables_initializer())

        self.log = open('experiment-log-rankgan-real.csv', 'w')
        generate_samples(self.sess, self.generator, self.batch_size,
                         self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)

        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch(self.sess, self.generator,
                                   self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                # write generate.txt in indexes of the gen sentences
                generate_samples(self.sess, self.generator, self.batch_size,
                                 self.generate_num, self.generator_file)
                # write ^ in words
                get_real_test_file()
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        self.reset_epoch()
        print('adversarial training:')
        self.reward = Reward(self.generator, .8)
        for epoch in range(self.adversarial_epoch_num):
            # print('epoch:' + str(epoch))
            start = time()
            for index in range(1):
                samples = self.generator.generate(self.sess)
                rewards = self.reward.get_reward(self.sess, samples, 16,
                                                 self.discriminator,
                                                 self.dis_data_loader)
                feed = {
                    self.generator.x: samples,
                    self.generator.rewards: rewards
                }
                _ = self.sess.run(self.generator.g_updates, feed_dict=feed)
            end = time()
            self.add_epoch()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                generate_samples(self.sess, self.generator, self.batch_size,
                                 self.generate_num, self.generator_file)
                get_real_test_file()
                self.evaluate()

            self.reward.update_params()
            for _ in range(15):
                self.train_discriminator()

        # bleu = Bleu( self.test_file, self.real_file, self.sess)
        # sbleu = SelfBleu(self.test_file, self.sess)
        # scorefile = open(self.result_file, 'a+')

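        # Final sweep over sampling temperatures: for each alpha, regenerate
        # samples with generate_samples_temp (assumed to be a variant of
        # generate_samples that applies temperature alpha) and decode them
        # into a per-temperature output file.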
        for alpha in self.temps:
            gen_file = 'gen_data/rankgan/coco/' + str(itr) + '_' + str(
                size) + '_' + str(alpha) + '.txt'
            # gen_file = 'gen_data/seqgan/coco/' + 'ds.txt'
            print('alpha', alpha, gen_file)
            generate_samples_temp(self.sess, self.generator, self.batch_size,
                                  self.generate_num, alpha,
                                  self.generator_file)
            get_real_test_file_temp(gen_file)