# Example no. 1
class Leakgan(Gan):
    def __init__(self, oracle=None):
        """Construct the LeakGAN wrapper.

        Args:
            oracle: optional pre-built oracle model; not used here — the
                oracle is wired up later in ``init_oracle_trainng``.
        """
        super().__init__()
        # you can change parameters, generator here

    def init_oracle_trainng(self, oracle=None):
        """Build the oracle, discriminator, generator, data loaders and the
        TensorFlow session for synthetic (oracle-based) training.

        Args:
            oracle: optional pre-built oracle model; a default
                ``OracleLstm`` is created when omitted.
        """
        # Size of the leaked feature vector = total number of CNN filters.
        feature_dim = sum(self.num_filters)

        if oracle is None:
            oracle = OracleLstm(num_vocabulary=self.vocab_size,
                                emb_dim=self.emb_dim,
                                hidden_dim=self.hidden_dim,
                                batch_size=self.batch_size,
                                sequence_length=self.sequence_length,
                                start_token=self.start_token)
        self.set_oracle(oracle)

        d_model = Discriminator(sequence_length=self.sequence_length,
                                num_classes=2,
                                vocab_size=self.vocab_size,
                                dis_emb_dim=self.dis_embedding_dim,
                                filter_sizes=self.filter_size,
                                num_filters=self.num_filters,
                                batch_size=self.batch_size,
                                hidden_dim=self.hidden_dim,
                                start_token=self.start_token,
                                goal_out_size=feature_dim,
                                step_size=4,
                                l2_reg_lambda=self.l2_reg_lambda)
        self.set_discriminator(d_model)

        # The generator receives the discriminator so it can read its
        # leaked features (D_model argument).
        g_model = Generator(num_classes=2,
                            num_vocabulary=self.vocab_size,
                            batch_size=self.batch_size,
                            emb_dim=self.emb_dim,
                            dis_emb_dim=self.dis_embedding_dim,
                            goal_size=self.goal_size,
                            hidden_dim=self.hidden_dim,
                            sequence_length=self.sequence_length,
                            filter_sizes=self.filter_size,
                            start_token=self.start_token,
                            num_filters=self.num_filters,
                            goal_out_size=feature_dim,
                            D_model=d_model,
                            step_size=4)
        self.set_generator(g_model)

        gen_loader = DataLoader(batch_size=self.batch_size,
                                seq_length=self.sequence_length)
        oracle_loader = DataLoader(batch_size=self.batch_size,
                                   seq_length=self.sequence_length)
        dis_loader = DisDataloader(batch_size=self.batch_size,
                                   seq_length=self.sequence_length)

        # Let GPU memory grow on demand, capped at half of the device.
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
        self.sess = tf.Session(config=sess_config)
        self.set_data_loader(gen_loader=gen_loader,
                             dis_loader=dis_loader,
                             oracle_loader=oracle_loader)

    def init_metric(self):
        """Register evaluation metrics: oracle NLL, generator NLL
        (reported as 'nll-test') and document-embedding similarity."""
        from utils.metrics.DocEmbSim import DocEmbSim

        # NLL of generator samples under the oracle.
        oracle_nll = Nll(data_loader=self.oracle_data_loader,
                         rnn=self.oracle,
                         sess=self.sess)
        self.add_metric(oracle_nll)

        # NLL of oracle data under the generator.
        gen_nll = Nll(data_loader=self.gen_data_loader,
                      rnn=self.generator,
                      sess=self.sess)
        gen_nll.set_name('nll-test')
        self.add_metric(gen_nll)

        doc_sim = DocEmbSim(oracle_file=self.oracle_file,
                            generator_file=self.generator_file,
                            num_vocabulary=self.vocab_size)
        self.add_metric(doc_sim)

    def train_discriminator(self):
        """Run three discriminator update steps on fresh generator samples.

        Generates ``self.generate_num`` sequences into
        ``self.generator_file``, loads them together with the oracle data
        through the discriminator loader, and performs 3 gradient steps,
        refreshing the generator's leaked-feature function after each step.
        """
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        self.dis_data_loader.load_train_data(self.oracle_file,
                                             self.generator_file)
        for _ in range(3):
            # BUG FIX: the original fetched a batch and discarded it before
            # fetching the batch actually used, silently wasting half of
            # the discriminator's training data each call.
            x_batch, y_batch = self.dis_data_loader.next_batch()
            feed = {
                self.discriminator.D_input_x: x_batch,
                self.discriminator.D_input_y: y_batch,
            }
            self.sess.run(
                [self.discriminator.D_loss, self.discriminator.D_train_op],
                feed)
            # Keep the generator's leaked-feature extractor in sync with
            # the freshly updated discriminator weights.
            self.generator.update_feature_function(self.discriminator)

    def train_oracle(self):
        """Full synthetic-data pipeline: MLE pre-training of the generator,
        discriminator pre-training, then LeakGAN adversarial training with
        interleaved MLE and discriminator phases.
        """
        self.init_oracle_trainng()
        self.init_metric()
        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-leakgan.csv', 'w')
        # Dump oracle "ground truth" samples and initial generator samples,
        # then batch them for the two loaders.
        generate_samples(self.sess, self.oracle, self.batch_size,
                         self.generate_num, self.oracle_file)
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)
        self.oracle_data_loader.create_batches(self.generator_file)

        # Warm-up: one sampling pass through the generator graph.
        # NOTE(review): `g` is unused and the loop runs once; presumably
        # this only primes the sampling ops — confirm before removing.
        for a in range(1):
            g = self.sess.run(self.generator.gen_x,
                              feed_dict={
                                  self.generator.drop_out: 1,
                                  self.generator.train: 1
                              })

        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch_gen(self.sess, self.generator,
                                       self.gen_data_loader)
            end = time()
            # NOTE(review): prints the global self.epoch counter, not the
            # local loop variable `epoch`.
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples_gen(self.sess, self.generator,
                                     self.batch_size, self.generate_num,
                                     self.generator_file)
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        print('adversarial training:')

        # Monte-Carlo rollout reward model used to score generator samples.
        self.reward = Reward(model=self.generator,
                             dis=self.discriminator,
                             sess=self.sess,
                             rollout_num=4)
        # Schedule per outer epoch: 10 adversarial G-steps (each followed
        # by 15 D-steps), then 5 MLE epochs, then 5 more D epochs.
        for epoch in range(self.adversarial_epoch_num // 10):
            for epoch_ in range(10):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                for index in range(1):
                    start = time()
                    samples = self.generator.generate(self.sess, 1)
                    rewards = self.reward.get_reward(samples)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1
                    }
                    # Joint manager (goal) and worker policy updates.
                    _, _, g_loss, w_loss = self.sess.run([
                        self.generator.manager_updates,
                        self.generator.worker_updates,
                        self.generator.goal_loss,
                        self.generator.worker_loss,
                    ],
                                                         feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss',
                          w_loss)
                    end = time()
                    print('epoch:' + str(epoch) + '--' + str(epoch_) +
                          '\t time:' + str(end - start))
                    # Evaluation gated on the *global* epoch counter.
                    if self.epoch % 5 == 0 or self.epoch == self.adversarial_epoch_num - 1:
                        generate_samples_gen(self.sess, self.generator,
                                             self.batch_size,
                                             self.generate_num,
                                             self.generator_file)
                        self.evaluate()
                    self.add_epoch()

                for _ in range(15):
                    self.train_discriminator()
            # Interleaved MLE phase (teacher forcing on oracle data).
            for epoch_ in range(5):
                start = time()
                loss = pre_train_epoch_gen(self.sess, self.generator,
                                           self.gen_data_loader)
                end = time()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                self.add_epoch()
                if epoch % 5 == 0:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    # self.evaluate()
            # Extra discriminator epochs to keep D ahead of G.
            for epoch_ in range(5):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()

    def init_cfg_training(self, grammar=None):
        """Create an oracle from a context-free grammar and build the
        discriminator, generator and data loaders sized for its vocabulary.

        Args:
            grammar: CFG description string handed to ``OracleCfg``.

        Returns:
            The oracle's (word->index, index->word) dictionary locations.
        """
        from utils.oracle.OracleCfg import OracleCfg
        cfg_oracle = OracleCfg(sequence_length=self.sequence_length,
                               cfg_grammar=grammar)
        self.set_oracle(cfg_oracle)
        self.oracle.generate_oracle()
        # One extra id beyond the grammar vocabulary.
        self.vocab_size = self.oracle.vocab_size + 1
        # Leaked feature vector size = total number of CNN filters.
        feature_dim = sum(self.num_filters)

        d_model = Discriminator(sequence_length=self.sequence_length,
                                num_classes=2,
                                vocab_size=self.vocab_size,
                                dis_emb_dim=self.dis_embedding_dim,
                                filter_sizes=self.filter_size,
                                num_filters=self.num_filters,
                                batch_size=self.batch_size,
                                hidden_dim=self.hidden_dim,
                                start_token=self.start_token,
                                goal_out_size=feature_dim,
                                step_size=4,
                                l2_reg_lambda=self.l2_reg_lambda)
        self.set_discriminator(d_model)

        # Generator reads the discriminator's leaked features (D_model).
        g_model = Generator(num_classes=2,
                            num_vocabulary=self.vocab_size,
                            batch_size=self.batch_size,
                            emb_dim=self.emb_dim,
                            dis_emb_dim=self.dis_embedding_dim,
                            goal_size=self.goal_size,
                            hidden_dim=self.hidden_dim,
                            sequence_length=self.sequence_length,
                            filter_sizes=self.filter_size,
                            start_token=self.start_token,
                            num_filters=self.num_filters,
                            goal_out_size=feature_dim,
                            D_model=d_model,
                            step_size=4)
        self.set_generator(g_model)

        gen_loader = DataLoader(batch_size=self.batch_size,
                                seq_length=self.sequence_length)
        oracle_loader = DataLoader(batch_size=self.batch_size,
                                   seq_length=self.sequence_length)
        dis_loader = DisDataloader(batch_size=self.batch_size,
                                   seq_length=self.sequence_length)

        self.set_data_loader(gen_loader=gen_loader,
                             dis_loader=dis_loader,
                             oracle_loader=oracle_loader)
        return cfg_oracle.wi_dict, cfg_oracle.iw_dict

    def init_cfg_metric(self, grammar=None):
        """Register the CFG-validity metric for the given grammar."""
        from utils.metrics.Cfg import Cfg
        cfg_metric = Cfg(test_file=self.test_file, cfg_grammar=grammar)
        self.add_metric(cfg_metric)

    def train_cfg(self):
        """Train LeakGAN on sequences sampled from a toy context-free
        grammar and evaluate samples with a CFG-validity metric."""
        import json
        from utils.text_process import get_tokenlized
        from utils.text_process import code_to_text
        # Toy arithmetic-expression grammar acting as the data distribution.
        cfg_grammar = """
          S -> S PLUS x | S SUB x |  S PROD x | S DIV x | x | '(' S ')'
          PLUS -> '+'
          SUB -> '-'
          PROD -> '*'
          DIV -> '/'
          x -> 'x' | 'y'
        """

        wi_dict_loc, iw_dict_loc = self.init_cfg_training(cfg_grammar)
        # Index->word dictionary used to decode generated sequences.
        with open(iw_dict_loc, 'r') as file:
            iw_dict = json.load(file)

        def get_cfg_test_file(dict=iw_dict):
            # Decode generator output codes to text for the CFG metric.
            # NOTE(review): the opened `file` handle is unused;
            # get_tokenlized re-opens the path itself.
            with open(self.generator_file, 'r') as file:
                codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes, dictionary=dict))

        self.init_cfg_metric(grammar=cfg_grammar)
        # NOTE(review): init_cfg_training creates no tf.Session here;
        # self.sess is presumably provided by the base class — confirm.
        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-leakganbasic-cfg.csv', 'w')
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)
        self.oracle_data_loader.create_batches(self.generator_file)
        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch_gen(self.sess, self.generator,
                                       self.gen_data_loader)
            end = time()
            # NOTE(review): prints the global self.epoch, not local `epoch`.
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples_gen(self.sess, self.generator,
                                     self.batch_size, self.generate_num,
                                     self.generator_file)
                get_cfg_test_file()
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        # Discriminator gets 3x the generator's pre-training epochs here.
        for epoch in range(self.pre_epoch_num * 3):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        self.reset_epoch()
        print('adversarial training:')
        # Monte-Carlo rollout reward model for scoring generated samples.
        self.reward = Reward(model=self.generator,
                             dis=self.discriminator,
                             sess=self.sess,
                             rollout_num=4)
        # Schedule per outer epoch: 10 adversarial G-steps (each followed
        # by 15 D-steps), then 5 MLE epochs, then 5 more D epochs.
        for epoch in range(self.adversarial_epoch_num // 10):
            for epoch_ in range(10):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                start = time()
                for index in range(1):
                    samples = self.generator.generate(self.sess, 1)
                    rewards = self.reward.get_reward(samples)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1
                    }
                    # Joint manager (goal) and worker policy updates.
                    _, _, g_loss, w_loss = self.sess.run([
                        self.generator.manager_updates,
                        self.generator.worker_updates,
                        self.generator.goal_loss,
                        self.generator.worker_loss,
                    ],
                                                         feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss',
                          w_loss)
                end = time()
                self.add_epoch()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    get_cfg_test_file()
                    self.evaluate()

                for _ in range(15):
                    self.train_discriminator()
            # Interleaved MLE phase.
            for epoch_ in range(5):
                start = time()
                loss = pre_train_epoch_gen(self.sess, self.generator,
                                           self.gen_data_loader)
                end = time()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                self.add_epoch()
                if epoch % 5 == 0:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    get_cfg_test_file()
                    self.evaluate()
            # Extra discriminator epochs.
            for epoch_ in range(5):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()

    def init_real_trainng(self, data_loc=None):
        """Build discriminator, generator and loaders for a real text corpus
        and encode the corpus into ``self.oracle_file``.

        Args:
            data_loc: path to the training corpus; defaults to the COCO
                image-caption file.
        """
        from utils.text_process import text_precess, text_to_code
        from utils.text_process import get_tokenlized, get_word_list, get_dict
        if data_loc is None:
            data_loc = 'data/image_coco.txt'
        # Sequence length and vocabulary size are derived from the corpus.
        self.sequence_length, self.vocab_size = text_precess(data_loc)

        # Leaked feature vector size = total number of CNN filters.
        feature_dim = sum(self.num_filters)
        d_model = Discriminator(sequence_length=self.sequence_length,
                                num_classes=2,
                                vocab_size=self.vocab_size,
                                dis_emb_dim=self.dis_embedding_dim,
                                filter_sizes=self.filter_size,
                                num_filters=self.num_filters,
                                batch_size=self.batch_size,
                                hidden_dim=self.hidden_dim,
                                start_token=self.start_token,
                                goal_out_size=feature_dim,
                                step_size=4,
                                l2_reg_lambda=self.l2_reg_lambda,
                                splited_steps=self.splited_steps)
        self.set_discriminator(d_model)

        # Generator reads the discriminator's leaked features (D_model).
        g_model = Generator(num_classes=2,
                            num_vocabulary=self.vocab_size,
                            batch_size=self.batch_size,
                            emb_dim=self.emb_dim,
                            dis_emb_dim=self.dis_embedding_dim,
                            goal_size=self.goal_size,
                            hidden_dim=self.hidden_dim,
                            sequence_length=self.sequence_length,
                            filter_sizes=self.filter_size,
                            start_token=self.start_token,
                            num_filters=self.num_filters,
                            goal_out_size=feature_dim,
                            D_model=d_model,
                            step_size=4)
        self.set_generator(g_model)

        gen_loader = DataLoader(batch_size=self.batch_size,
                                seq_length=self.sequence_length)
        dis_loader = DisDataloader(batch_size=self.batch_size,
                                   seq_length=self.sequence_length)
        # No oracle loader in real-data mode.
        oracle_loader = None

        self.set_data_loader(gen_loader=gen_loader,
                             dis_loader=dis_loader,
                             oracle_loader=oracle_loader)
        # Encode the tokenized corpus with the word->index dictionary.
        # NOTE(review): self.wi_dict is never assigned in this class as
        # shown — presumably set elsewhere; confirm.
        corpus_tokens = get_tokenlized(data_loc)
        with open(self.oracle_file, 'w') as outfile:
            outfile.write(
                text_to_code(corpus_tokens, self.wi_dict,
                             self.sequence_length))

    def train_real(self, data_loc=None):
        """Train LeakGAN on a real corpus: generator MLE pre-training and
        discriminator pre-training (or checkpoint restore), then adversarial
        training with interleaved MLE and discriminator phases.

        Args:
            data_loc: corpus path forwarded to ``init_real_trainng``.
        """
        self.init_real_trainng(data_loc)
        self.init_real_metric()

        self.sess.run(tf.global_variables_initializer())

        #++ Saver
        saver_variables = tf.global_variables()
        saver = tf.train.Saver(saver_variables)
        #++ ====================

        # summary writer
        self.save_summary()

        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)

        # for a in range(1):
        #     g = self.sess.run(self.generator.gen_x, feed_dict={self.generator.drop_out: 1, self.generator.train: 1})

        if self.restore:
            # Resume from the latest checkpoint and skip pre-training.
            restore_from = tf.train.latest_checkpoint(self.save_path)
            saver.restore(self.sess, restore_from)
            print(f"{Fore.BLUE}Restore from : {restore_from}{Fore.RESET}")
            self.epoch = self.pre_epoch_num
        else:

            print('start pre-train generator:')
            for epoch in range(self.pre_epoch_num):
                start = time()
                loss = pre_train_epoch_gen(self.sess, self.generator,
                                           self.gen_data_loader)
                end = time()
                print(
                    f"pre-G(global epoch:{self.epoch}): epoch:{epoch} \t time: {end - start:.1f}s"
                )
                self.add_epoch()
                # Evaluate every ntest_pre epochs during pre-training.
                if epoch % self.ntest_pre == 0:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    self.get_real_test_file()
                    self.evaluate()

            print('start pre-train discriminator:')
            for epoch in range(self.pre_epoch_num):
                start = time()
                self.train_discriminator()
                end = time()
                print(f"pre-D: epoch:{epoch} \t time: {end - start:.1f}s")

            # save pre_train
            saver.save(self.sess, os.path.join(self.save_path, 'pre_train'))

        # stop after pretrain
        if self.pretrain:
            generate_samples_gen(self.sess, self.generator, self.batch_size,
                                 self.generate_num, self.generator_file)
            self.get_real_test_file()
            self.evaluate()
            exit()

        print('start adversarial:')
        # Monte-Carlo rollout reward model for scoring generated samples.
        self.reward = Reward(model=self.generator,
                             dis=self.discriminator,
                             sess=self.sess,
                             rollout_num=4)
        # Schedule per outer epoch: 10 adversarial G-steps (each followed
        # by 15 D-steps), then 5 MLE epochs, then 5 more D epochs.
        for epoch in range(self.adversarial_epoch_num // 10):
            for epoch_ in range(10):
                start = time()
                for index in range(1):
                    samples = self.generator.generate(self.sess, 1)
                    rewards = self.reward.get_reward(samples)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1
                    }
                    # Joint manager (goal) and worker policy updates.
                    _, _, g_loss, w_loss = self.sess.run([
                        self.generator.manager_updates,
                        self.generator.worker_updates,
                        self.generator.goal_loss,
                        self.generator.worker_loss,
                    ],
                                                         feed_dict=feed)
                    print(
                        f"epoch {epoch} \t g_loss: {g_loss} w_loss: {w_loss}")
                end = time()
                self.add_epoch()
                print(
                    f"adv-G(global epoch:{self.epoch}): epoch:{epoch}--{epoch_} \t time: {end - start:.1f}s"
                )
                if self.epoch % self.ntest == 0 or epoch == self.adversarial_epoch_num - 1:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    self.get_real_test_file()
                    self.evaluate()

                # 15 discriminator steps with an in-place progress bar.
                start = time()
                for epoch__ in range(15):
                    print(f"adv-D: epoch:{epoch}--{epoch_}: " + '>' * epoch__ +
                          f"({epoch__}/15)",
                          end='\r')
                    self.train_discriminator()
                end = time()
                print(f"adv-D: epoch:{epoch}--{epoch_}: " + '>' * 15 +
                      f"(15/15) \t time: {end - start:.1f}s")

            # Interleaved MLE phase (teacher forcing on corpus data).
            for epoch_ in range(5):
                start = time()
                loss = pre_train_epoch_gen(self.sess, self.generator,
                                           self.gen_data_loader)
                end = time()
                self.add_epoch()
                print(
                    f"mle-G(global epoch:{self.epoch}): epoch:{epoch}--{epoch_} \t time: {end - start:.1f}s"
                )
                if self.epoch % self.ntest == 0:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    self.get_real_test_file()
                    self.evaluate()
            # Extra discriminator epochs.
            for epoch_ in range(5):
                start = time()
                self.train_discriminator()
                end = time()
                print(
                    f"mle-D(global epoch:{self.epoch}): epoch:{epoch}--{epoch_} \t time: {end - start:.1f}s"
                )
# Example no. 2
class Leakgan(Gan):
    """LeakGAN experiment runner built on the project's ``Gan`` base class.

    Wires a hierarchical (manager/worker) LeakGAN generator to a CNN
    discriminator whose feature vector is "leaked" back to the generator,
    together with data loaders and metrics, and drives training in three
    stages: generator MLE pre-training, discriminator pre-training, then
    interleaved adversarial training.  TensorFlow 1.x style: graph
    construction followed by an explicit ``tf.Session``.
    """
    def __init__(self, oracle=None):
        """Set default hyper-parameters and output file locations.

        :param oracle: accepted for interface symmetry but not used here;
            an oracle may instead be passed to ``init_oracle_trainng``.
        """
        super().__init__()
        # you can change parameters, generator here
        self.vocab_size = 20
        self.emb_dim = 32
        self.hidden_dim = 32
        # NOTE(review): DEFINE_* mutates process-global tf.app.flags state;
        # constructing a second Leakgan instance in the same process would
        # raise a duplicate-flag error — confirm single instantiation.
        flags = tf.app.flags
        FLAGS = flags.FLAGS
        flags.DEFINE_boolean('restore', False, 'Training or testing a model')
        flags.DEFINE_boolean('resD', False, 'Training or testing a D model')
        flags.DEFINE_integer('length', 20, 'The length of toy data')
        flags.DEFINE_string('model', "", 'Model NAME')
        self.sequence_length = FLAGS.length
        self.filter_size = [2, 3]  # passed to Discriminator as filter_sizes
        self.num_filters = [100, 200]  # feature maps per filter size
        self.l2_reg_lambda = 0.2
        self.dropout_keep_prob = 0.75
        self.batch_size = 64
        self.generate_num = 256  # samples written per generation pass
        self.start_token = 0
        self.dis_embedding_dim = 64
        self.goal_size = 16  # passed to Generator as goal_size

        # Text files exchanged between oracle, generator and metrics.
        self.oracle_file = 'save/oracle.txt'
        self.generator_file = 'save/generator.txt'
        self.test_file = 'save/test_file.txt'

    def init_oracle_trainng(self, oracle=None):
        """Build oracle, discriminator, generator, loaders and session for
        the synthetic-oracle experiment.

        (sic: "trainng" — spelling kept; callers use this name.)

        :param oracle: optional pre-built oracle model; a default
            ``OracleLstm`` is created when ``None``.
        """
        # Size of the goal vector leaked from the discriminator equals the
        # total number of its CNN feature maps.
        goal_out_size = sum(self.num_filters)

        if oracle is None:
            oracle = OracleLstm(num_vocabulary=self.vocab_size, batch_size=self.batch_size, emb_dim=self.emb_dim,
                                hidden_dim=self.hidden_dim, sequence_length=self.sequence_length,
                                start_token=self.start_token)
        self.set_oracle(oracle)

        discriminator = Discriminator(sequence_length=self.sequence_length, num_classes=2, vocab_size=self.vocab_size,
                                      dis_emb_dim=self.dis_embedding_dim, filter_sizes=self.filter_size,
                                      num_filters=self.num_filters,
                                      batch_size=self.batch_size, hidden_dim=self.hidden_dim,
                                      start_token=self.start_token,
                                      goal_out_size=goal_out_size, step_size=4,
                                      l2_reg_lambda=self.l2_reg_lambda)
        self.set_discriminator(discriminator)

        # The generator receives the discriminator (D_model) so it can read
        # the leaked feature vector during training.
        generator = Generator(num_classes=2, num_vocabulary=self.vocab_size, batch_size=self.batch_size,
                              emb_dim=self.emb_dim, dis_emb_dim=self.dis_embedding_dim, goal_size=self.goal_size,
                              hidden_dim=self.hidden_dim, sequence_length=self.sequence_length,
                              filter_sizes=self.filter_size, start_token=self.start_token,
                              num_filters=self.num_filters, goal_out_size=goal_out_size, D_model=discriminator,
                              step_size=4)
        self.set_generator(generator)

        gen_dataloader = DataLoader(batch_size=self.batch_size, seq_length=self.sequence_length)
        oracle_dataloader = DataLoader(batch_size=self.batch_size, seq_length=self.sequence_length)
        dis_dataloader = DisDataloader(batch_size=self.batch_size, seq_length=self.sequence_length)

        # Limit GPU memory: grow on demand but never beyond 50% of the card.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.5
        self.sess = tf.Session(config=config)
        self.set_data_loader(gen_loader=gen_dataloader, dis_loader=dis_dataloader, oracle_loader=oracle_dataloader)

    def init_metric(self):
        """Register metrics: oracle NLL, generator NLL ('nll-test') and
        document-embedding similarity between oracle and generator files."""
        nll = Nll(data_loader=self.oracle_data_loader, rnn=self.oracle, sess=self.sess)
        self.add_metric(nll)

        inll = Nll(data_loader=self.gen_data_loader, rnn=self.generator, sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

        from utils.metrics.DocEmbSim import DocEmbSim
        docsim = DocEmbSim(oracle_file=self.oracle_file, generator_file=self.generator_file, num_vocabulary=self.vocab_size)
        self.add_metric(docsim)

    def train_discriminator(self):
        """Run three discriminator update steps on freshly generated samples.

        Regenerates ``generator_file`` first so negative examples track the
        current generator, then re-syncs the generator's leaked-feature hook
        after each step.
        """
        generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num, self.generator_file)
        self.dis_data_loader.load_train_data(self.oracle_file, self.generator_file)
        for _ in range(3):
            # NOTE(review): the result of this first next_batch() call is
            # discarded, so every other batch is skipped — confirm intended.
            self.dis_data_loader.next_batch()
            x_batch, y_batch = self.dis_data_loader.next_batch()
            feed = {
                self.discriminator.D_input_x: x_batch,
                self.discriminator.D_input_y: y_batch,
            }
            _, _ = self.sess.run([self.discriminator.D_loss, self.discriminator.D_train_op], feed)
            # Keep the generator's view of D's features in sync with the
            # just-updated discriminator weights.
            self.generator.update_feature_function(self.discriminator)

    def evaluate(self):
        """Generate samples, refresh the oracle loader and score all metrics.

        When a CSV log is open, a header row of metric names is written on
        the first epochs (0 or 1) and one comma-separated score row is
        appended per call.

        :return: the list of metric scores from the base-class evaluate().
        """
        generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num, self.generator_file)
        if self.oracle_data_loader is not None:
            self.oracle_data_loader.create_batches(self.generator_file)
        if self.log is not None:
            if self.epoch == 0 or self.epoch == 1:
                for metric in self.metrics:
                    self.log.write(metric.get_name() + ',')
                self.log.write('\n')
            scores = super().evaluate()
            for score in scores:
                self.log.write(str(score) + ',')
            self.log.write('\n')
            return scores
        return super().evaluate()

    def train_oracle(self):
        """Full synthetic-oracle experiment: build models, pre-train G and D,
        then run interleaved adversarial training."""
        self.init_oracle_trainng()
        self.init_metric()
        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-leakgan.csv', 'w')
        # Seed the data files: oracle samples serve as "real" data, generator
        # samples feed the metric loaders.
        generate_samples(self.sess, self.oracle, self.batch_size, self.generate_num, self.oracle_file)
        generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)
        self.oracle_data_loader.create_batches(self.generator_file)

        # One warm-up pass through the sampling graph (result unused).
        for a in range(1):
            g = self.sess.run(self.generator.gen_x, feed_dict={self.generator.drop_out: 1, self.generator.train: 1})

        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch_gen(self.sess, self.generator, self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num, self.generator_file)
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        print('adversarial training:')

        self.reward = Reward(model=self.generator, dis=self.discriminator, sess=self.sess, rollout_num=4)
        # Each outer cycle: 10 adversarial G steps (each followed by 15 D
        # epochs), then 5 generator MLE epochs and 5 more D epochs.
        for epoch in range(self.adversarial_epoch_num//10):
            for epoch_ in range(10):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                for index in range(1):
                    start = time()
                    samples = self.generator.generate(self.sess, 1)
                    rewards = self.reward.get_reward(samples)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1
                    }
                    # One policy-gradient step for both manager and worker.
                    _, _, g_loss, w_loss = self.sess.run(
                        [self.generator.manager_updates, self.generator.worker_updates, self.generator.goal_loss,
                         self.generator.worker_loss, ], feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss', w_loss)
                    end = time()
                    print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' + str(end - start))
                    if self.epoch % 5 == 0 or self.epoch == self.adversarial_epoch_num - 1:
                        generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num,
                                             self.generator_file)
                        self.evaluate()
                    self.add_epoch()


                for _ in range(15):
                    self.train_discriminator()
            for epoch_ in range(5):
                start = time()
                loss = pre_train_epoch_gen(self.sess, self.generator, self.gen_data_loader)
                end = time()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' + str(end - start))
                self.add_epoch()
                if epoch % 5 == 0:
                    generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num,
                                         self.generator_file)
                    # self.evaluate()
            for epoch_ in range(5):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()

    def init_cfg_training(self, grammar=None):
        """Build models and loaders for the context-free-grammar experiment.

        The CFG oracle data is generated first because the vocabulary size
        is derived from it.

        :param grammar: CFG grammar string handed to ``OracleCfg``.
        :return: the oracle's word->index and index->word dict locations.
        """
        from utils.oracle.OracleCfg import OracleCfg
        oracle = OracleCfg(sequence_length=self.sequence_length, cfg_grammar=grammar)
        self.set_oracle(oracle)
        self.oracle.generate_oracle()
        # +1 reserves an extra token id (presumably padding/EOS — confirm
        # against OracleCfg).
        self.vocab_size = self.oracle.vocab_size + 1
        goal_out_size = sum(self.num_filters)

        discriminator = Discriminator(sequence_length=self.sequence_length, num_classes=2, vocab_size=self.vocab_size,
                                      dis_emb_dim=self.dis_embedding_dim, filter_sizes=self.filter_size,
                                      num_filters=self.num_filters,
                                      batch_size=self.batch_size, hidden_dim=self.hidden_dim,
                                      start_token=self.start_token,
                                      goal_out_size=goal_out_size, step_size=4,
                                      l2_reg_lambda=self.l2_reg_lambda)
        self.set_discriminator(discriminator)

        generator = Generator(num_classes=2, num_vocabulary=self.vocab_size, batch_size=self.batch_size,
                              emb_dim=self.emb_dim, dis_emb_dim=self.dis_embedding_dim, goal_size=self.goal_size,
                              hidden_dim=self.hidden_dim, sequence_length=self.sequence_length,
                              filter_sizes=self.filter_size, start_token=self.start_token,
                              num_filters=self.num_filters, goal_out_size=goal_out_size, D_model=discriminator,
                              step_size=4)
        self.set_generator(generator)

        gen_dataloader = DataLoader(batch_size=self.batch_size, seq_length=self.sequence_length)
        oracle_dataloader = DataLoader(batch_size=self.batch_size, seq_length=self.sequence_length)
        dis_dataloader = DisDataloader(batch_size=self.batch_size, seq_length=self.sequence_length)

        self.set_data_loader(gen_loader=gen_dataloader, dis_loader=dis_dataloader, oracle_loader=oracle_dataloader)
        return oracle.wi_dict, oracle.iw_dict

    def init_cfg_metric(self, grammar=None):
        """Register the CFG-validity metric over the decoded test file."""
        from utils.metrics.Cfg import Cfg
        cfg = Cfg(test_file=self.test_file, cfg_grammar=grammar)
        self.add_metric(cfg)

    def train_cfg(self):
        """Run the arithmetic-expression CFG experiment end to end."""
        import json
        from utils.text_process import get_tokenlized
        from utils.text_process import code_to_text
        cfg_grammar = """
          S -> S PLUS x | S SUB x |  S PROD x | S DIV x | x | '(' S ')'
          PLUS -> '+'
          SUB -> '-'
          PROD -> '*'
          DIV -> '/'
          x -> 'x' | 'y'
        """

        wi_dict_loc, iw_dict_loc = self.init_cfg_training(cfg_grammar)
        with open(iw_dict_loc, 'r') as file:
            iw_dict = json.load(file)

        # Decode the generator's token-id output into text for the Cfg metric.
        # (Binds iw_dict at definition time; the parameter name shadows the
        # builtin `dict` — kept as-is.)
        def get_cfg_test_file(dict=iw_dict):
            with open(self.generator_file, 'r') as file:
                codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes, dictionary=dict))

        self.init_cfg_metric(grammar=cfg_grammar)
        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-leakganbasic-cfg.csv', 'w')
        generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)
        self.oracle_data_loader.create_batches(self.generator_file)
        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch_gen(self.sess, self.generator, self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num, self.generator_file)
                get_cfg_test_file()
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        # Discriminator gets 3x the generator's pre-training epochs here.
        for epoch in range(self.pre_epoch_num * 3):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        self.reset_epoch()
        print('adversarial training:')
        self.reward = Reward(model=self.generator, dis=self.discriminator, sess=self.sess, rollout_num=4)
        for epoch in range(self.adversarial_epoch_num//10):
            for epoch_ in range(10):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                start = time()
                for index in range(1):
                    samples = self.generator.generate(self.sess, 1)
                    rewards = self.reward.get_reward(samples)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1
                    }
                    _, _, g_loss, w_loss = self.sess.run(
                        [self.generator.manager_updates, self.generator.worker_updates, self.generator.goal_loss,
                         self.generator.worker_loss, ], feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss', w_loss)
                end = time()
                self.add_epoch()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' + str(end - start))
                # NOTE(review): epoch only reaches adversarial_epoch_num//10 - 1,
                # so the second condition below can never be true — confirm.
                if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                    generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num, self.generator_file)
                    get_cfg_test_file()
                    self.evaluate()

                for _ in range(15):
                    self.train_discriminator()
            for epoch_ in range(5):
                start = time()
                loss = pre_train_epoch_gen(self.sess, self.generator, self.gen_data_loader)
                end = time()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' + str(end - start))
                self.add_epoch()
                if epoch % 5 == 0:
                    generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num,
                                         self.generator_file)
                    get_cfg_test_file()
                    self.evaluate()
            for epoch_ in range(5):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()

    def init_real_trainng(self, data_loc=None):
        """Build models and loaders for training on a real text corpus.

        (sic: "trainng" — spelling kept; callers use this name.)
        Derives sequence length and vocabulary size from the corpus, then
        writes the tokenized corpus as token ids to ``oracle_file``.

        :param data_loc: path to the training text; defaults to the COCO
            image-caption corpus.
        :return: (word->index dict, index->word dict).
        """
        from utils.text_process import text_precess, text_to_code
        from utils.text_process import get_tokenlized, get_word_list, get_dict
        if data_loc is None:
            data_loc = 'data/image_coco.txt'
        self.sequence_length, self.vocab_size = text_precess(data_loc)

        goal_out_size = sum(self.num_filters)
        discriminator = Discriminator(sequence_length=self.sequence_length, num_classes=2, vocab_size=self.vocab_size,
                                      dis_emb_dim=self.dis_embedding_dim, filter_sizes=self.filter_size,
                                      num_filters=self.num_filters,
                                      batch_size=self.batch_size, hidden_dim=self.hidden_dim,
                                      start_token=self.start_token,
                                      goal_out_size=goal_out_size, step_size=4,
                                      l2_reg_lambda=self.l2_reg_lambda)
        self.set_discriminator(discriminator)

        generator = Generator(num_classes=2, num_vocabulary=self.vocab_size, batch_size=self.batch_size,
                              emb_dim=self.emb_dim, dis_emb_dim=self.dis_embedding_dim, goal_size=self.goal_size,
                              hidden_dim=self.hidden_dim, sequence_length=self.sequence_length,
                              filter_sizes=self.filter_size, start_token=self.start_token,
                              num_filters=self.num_filters, goal_out_size=goal_out_size, D_model=discriminator,
                              step_size=4)
        self.set_generator(generator)
        gen_dataloader = DataLoader(batch_size=self.batch_size, seq_length=self.sequence_length)
        # No synthetic oracle for real data, hence no oracle loader.
        oracle_dataloader = None
        dis_dataloader = DisDataloader(batch_size=self.batch_size, seq_length=self.sequence_length)

        self.set_data_loader(gen_loader=gen_dataloader, dis_loader=dis_dataloader, oracle_loader=oracle_dataloader)
        tokens = get_tokenlized(data_loc)
        word_set = get_word_list(tokens)
        [word_index_dict, index_word_dict] = get_dict(word_set)
        with open(self.oracle_file, 'w') as outfile:
            outfile.write(text_to_code(tokens, word_index_dict, self.sequence_length))
        return word_index_dict, index_word_dict

    def init_real_metric(self):
        """Register real-data metrics: doc-embedding similarity and the
        generator's self NLL ('nll-test')."""
        from utils.metrics.DocEmbSim import DocEmbSim
        docsim = DocEmbSim(oracle_file=self.oracle_file, generator_file=self.generator_file, num_vocabulary=self.vocab_size)
        self.add_metric(docsim)

        inll = Nll(data_loader=self.gen_data_loader, rnn=self.generator, sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

    def train_real(self, data_loc=None):
        """Run the real-corpus experiment end to end.

        :param data_loc: corpus path, forwarded to ``init_real_trainng``.
        """
        from utils.text_process import code_to_text
        from utils.text_process import get_tokenlized
        wi_dict, iw_dict = self.init_real_trainng(data_loc)
        self.init_real_metric()

        # Decode generator token ids back to text for inspection/metrics.
        # (Binds iw_dict at definition time; the parameter name shadows the
        # builtin `dict` — kept as-is.)
        def get_real_test_file(dict=iw_dict):
            with open(self.generator_file, 'r') as file:
                codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes, dictionary=dict))

        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-leakgan-real.csv', 'w')
        generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)

        # One warm-up pass through the sampling graph (result unused).
        for a in range(1):
            g = self.sess.run(self.generator.gen_x, feed_dict={self.generator.drop_out: 1, self.generator.train: 1})

        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch_gen(self.sess, self.generator, self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num, self.generator_file)
                get_real_test_file()
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num):
            print('epoch:' + str(epoch))
            self.train_discriminator()


        self.reset_epoch()
        self.reward = Reward(model=self.generator, dis=self.discriminator, sess=self.sess, rollout_num=4)
        for epoch in range(self.adversarial_epoch_num//10):
            for epoch_ in range(10):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                start = time()
                for index in range(1):
                    samples = self.generator.generate(self.sess, 1)
                    rewards = self.reward.get_reward(samples)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1
                    }
                    _, _, g_loss, w_loss = self.sess.run(
                        [self.generator.manager_updates, self.generator.worker_updates, self.generator.goal_loss,
                         self.generator.worker_loss, ], feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss', w_loss)
                end = time()
                self.add_epoch()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' + str(end - start))
                # NOTE(review): epoch only reaches adversarial_epoch_num//10 - 1,
                # so the second condition below can never be true — confirm.
                if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                    generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num, self.generator_file)
                    get_real_test_file()
                    self.evaluate()

                for _ in range(15):
                    self.train_discriminator()
            # NOTE(review): unlike train_oracle's corresponding MLE loop,
            # self.add_epoch() is not called here, so the global epoch
            # counter stalls during this phase — confirm intentional.
            for epoch_ in range(5):
                start = time()
                loss = pre_train_epoch_gen(self.sess, self.generator, self.gen_data_loader)
                end = time()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' + str(end - start))
                if epoch % 5 == 0:
                    generate_samples_gen(self.sess, self.generator, self.batch_size, self.generate_num,
                                         self.generator_file)
                    get_real_test_file()
                    # self.evaluate()
            for epoch_ in range(5):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()
# Example no. 3 (score: 0)
class Leakgan(Gan):
    """LeakGAN runner for real-corpus training, built on the ``Gan`` base.

    Same design as the oracle variant: a hierarchical LeakGAN generator
    paired with a CNN discriminator whose feature vector is leaked back to
    the generator.  TensorFlow 1.x style (graph + explicit ``tf.Session``).
    """
    def __init__(self, oracle=None):
        """Set default hyper-parameters and output file locations.

        :param oracle: accepted for interface symmetry but not used here.
        """
        super().__init__()
        # you can change parameters, generator here
        self.vocab_size = 20
        self.emb_dim = 32
        self.hidden_dim = 32
        # NOTE(review): DEFINE_* mutates process-global tf.app.flags state;
        # constructing a second Leakgan instance in the same process would
        # raise a duplicate-flag error — confirm single instantiation.
        flags = tf.app.flags
        FLAGS = flags.FLAGS
        flags.DEFINE_boolean('restore', False, 'Training or testing a model')
        flags.DEFINE_boolean('resD', False, 'Training or testing a D model')
        flags.DEFINE_integer('length', 20, 'The length of toy data')
        flags.DEFINE_string('model', "", 'Model NAME')
        self.sequence_length = FLAGS.length
        self.filter_size = [2, 3]  # passed to Discriminator as filter_sizes
        self.num_filters = [100, 200]  # feature maps per filter size
        self.l2_reg_lambda = 0.2
        self.dropout_keep_prob = 0.75
        self.batch_size = 64
        self.generate_num = 256  # samples written per generation pass
        self.start_token = 0
        self.dis_embedding_dim = 64
        self.goal_size = 16  # passed to Generator as goal_size

        # Text files exchanged between corpus, generator and metrics.
        self.oracle_file = 'save/oracle.txt'
        self.generator_file = 'save/generator.txt'
        self.test_file = 'save/test_file.txt'

    def init_metric(self):
        """Register metrics: oracle NLL, generator NLL ('nll-test') and
        document-embedding similarity between oracle and generator files."""
        nll = Nll(data_loader=self.oracle_data_loader,
                  rnn=self.oracle,
                  sess=self.sess)
        self.add_metric(nll)

        inll = Nll(data_loader=self.gen_data_loader,
                   rnn=self.generator,
                   sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

        from utils.metrics.DocEmbSim import DocEmbSim
        docsim = DocEmbSim(oracle_file=self.oracle_file,
                           generator_file=self.generator_file,
                           num_vocabulary=self.vocab_size)
        self.add_metric(docsim)

    def train_discriminator(self):
        """Run three discriminator update steps on freshly generated samples.

        Regenerates ``generator_file`` first so negative examples track the
        current generator, then re-syncs the generator's leaked-feature hook
        after each step.
        """
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        self.dis_data_loader.load_train_data(self.oracle_file,
                                             self.generator_file)
        for _ in range(3):
            # NOTE(review): the result of this first next_batch() call is
            # discarded, so every other batch is skipped — confirm intended.
            self.dis_data_loader.next_batch()
            x_batch, y_batch = self.dis_data_loader.next_batch()
            feed = {
                self.discriminator.D_input_x: x_batch,
                self.discriminator.D_input_y: y_batch,
            }
            _, _ = self.sess.run(
                [self.discriminator.D_loss, self.discriminator.D_train_op],
                feed)
            # Keep the generator's view of D's features in sync with the
            # just-updated discriminator weights.
            self.generator.update_feature_function(self.discriminator)

    def evaluate(self):
        """Generate samples, refresh the oracle loader and score all metrics.

        When a CSV log is open, a header row of metric names is written on
        the first epochs (0 or 1) and one comma-separated score row is
        appended per call.

        :return: the list of metric scores from the base-class evaluate().
        """
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        if self.oracle_data_loader is not None:
            self.oracle_data_loader.create_batches(self.generator_file)
        if self.log is not None:
            if self.epoch == 0 or self.epoch == 1:
                for metric in self.metrics:
                    self.log.write(metric.get_name() + ',')
                self.log.write('\n')
            scores = super().evaluate()
            for score in scores:
                self.log.write(str(score) + ',')
            self.log.write('\n')
            return scores
        return super().evaluate()

    def init_real_trainng(self, data_loc=None):
        """Build models and loaders for training on a real text corpus.

        (sic: "trainng" — spelling kept; callers use this name.)
        Derives sequence length and vocabulary size from the corpus, then
        writes the tokenized corpus as token ids to ``oracle_file``.

        :param data_loc: path to the training text; defaults to the COCO
            image-caption corpus.
        :return: (word->index dict, index->word dict).
        """
        from utils.text_process import text_precess, text_to_code
        from utils.text_process import get_tokenlized, get_word_list, get_dict
        if data_loc is None:
            data_loc = 'data/image_coco.txt'
        self.sequence_length, self.vocab_size = text_precess(data_loc)

        # Goal vector size = total number of discriminator feature maps.
        goal_out_size = sum(self.num_filters)
        discriminator = Discriminator(sequence_length=self.sequence_length,
                                      num_classes=2,
                                      vocab_size=self.vocab_size,
                                      dis_emb_dim=self.dis_embedding_dim,
                                      filter_sizes=self.filter_size,
                                      num_filters=self.num_filters,
                                      batch_size=self.batch_size,
                                      hidden_dim=self.hidden_dim,
                                      start_token=self.start_token,
                                      goal_out_size=goal_out_size,
                                      step_size=4,
                                      l2_reg_lambda=self.l2_reg_lambda)
        self.set_discriminator(discriminator)

        # The generator receives the discriminator (D_model) so it can read
        # the leaked feature vector during training.
        generator = Generator(num_classes=2,
                              num_vocabulary=self.vocab_size,
                              batch_size=self.batch_size,
                              emb_dim=self.emb_dim,
                              dis_emb_dim=self.dis_embedding_dim,
                              goal_size=self.goal_size,
                              hidden_dim=self.hidden_dim,
                              sequence_length=self.sequence_length,
                              filter_sizes=self.filter_size,
                              start_token=self.start_token,
                              num_filters=self.num_filters,
                              goal_out_size=goal_out_size,
                              D_model=discriminator,
                              step_size=4)
        self.set_generator(generator)
        gen_dataloader = DataLoader(batch_size=self.batch_size,
                                    seq_length=self.sequence_length)
        # No synthetic oracle for real data, hence no oracle loader.
        oracle_dataloader = None
        dis_dataloader = DisDataloader(batch_size=self.batch_size,
                                       seq_length=self.sequence_length)

        self.set_data_loader(gen_loader=gen_dataloader,
                             dis_loader=dis_dataloader,
                             oracle_loader=oracle_dataloader)
        tokens = get_tokenlized(data_loc)
        word_set = get_word_list(tokens)
        [word_index_dict, index_word_dict] = get_dict(word_set)
        with open(self.oracle_file, 'w') as outfile:
            outfile.write(
                text_to_code(tokens, word_index_dict, self.sequence_length))
        return word_index_dict, index_word_dict

    def init_real_metric(self):
        """Register real-data metrics: doc-embedding similarity and the
        generator's self NLL ('nll-test')."""
        from utils.metrics.DocEmbSim import DocEmbSim
        docsim = DocEmbSim(oracle_file=self.oracle_file,
                           generator_file=self.generator_file,
                           num_vocabulary=self.vocab_size)
        self.add_metric(docsim)

        inll = Nll(data_loader=self.gen_data_loader,
                   rnn=self.generator,
                   sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

    def train_real(self, data_loc=None):
        """Run the real-corpus experiment end to end: build models,
        pre-train G and D, then interleaved adversarial training.

        :param data_loc: corpus path, forwarded to ``init_real_trainng``.
        """
        from utils.text_process import code_to_text
        from utils.text_process import get_tokenlized
        wi_dict, iw_dict = self.init_real_trainng(data_loc)
        self.init_real_metric()

        # Decode generator token ids back to text for inspection/metrics.
        # (Binds iw_dict at definition time; the parameter name shadows the
        # builtin `dict` — kept as-is.)
        def get_real_test_file(dict=iw_dict):
            with open(self.generator_file, 'r') as file:
                codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes, dictionary=dict))

        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-leakgan-real.csv', 'w')
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)

        # One warm-up pass through the sampling graph (result unused).
        for a in range(1):
            g = self.sess.run(self.generator.gen_x,
                              feed_dict={
                                  self.generator.drop_out: 1,
                                  self.generator.train: 1
                              })

        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch_gen(self.sess, self.generator,
                                       self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples_gen(self.sess, self.generator,
                                     self.batch_size, self.generate_num,
                                     self.generator_file)
                get_real_test_file()
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        self.reset_epoch()
        self.reward = Reward(model=self.generator,
                             dis=self.discriminator,
                             sess=self.sess,
                             rollout_num=4)
        # Each outer cycle: 10 adversarial G steps (each followed by 15 D
        # epochs), then 5 generator MLE epochs and 5 more D epochs.
        for epoch in range(self.adversarial_epoch_num // 10):
            for epoch_ in range(10):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                start = time()
                for index in range(1):
                    samples = self.generator.generate(self.sess, 1)
                    rewards = self.reward.get_reward(samples)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1
                    }
                    # One policy-gradient step for both manager and worker.
                    _, _, g_loss, w_loss = self.sess.run([
                        self.generator.manager_updates,
                        self.generator.worker_updates,
                        self.generator.goal_loss,
                        self.generator.worker_loss,
                    ],
                                                         feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss',
                          w_loss)
                end = time()
                self.add_epoch()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                # NOTE(review): epoch only reaches adversarial_epoch_num//10 - 1,
                # so the second condition below can never be true — confirm.
                if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    get_real_test_file()
                    self.evaluate()

                for _ in range(15):
                    self.train_discriminator()
            # NOTE(review): self.add_epoch() is not called in this MLE loop,
            # so the global epoch counter stalls during this phase — confirm
            # intentional.
            for epoch_ in range(5):
                start = time()
                loss = pre_train_epoch_gen(self.sess, self.generator,
                                           self.gen_data_loader)
                end = time()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                if epoch % 5 == 0:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    get_real_test_file()
                    # self.evaluate()
            for epoch_ in range(5):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()
# Example no. 4 (score: 0)
class Leakgan(Gan):
    """LeakGAN wrapper around the generic ``Gan`` trainer.

    Builds the hierarchical generator (manager/worker) together with the
    feature-leaking discriminator and drives three training entry points:
    ``train_oracle`` (synthetic LSTM oracle), ``train_cfg`` (context-free
    grammar data) and ``train_real`` (text corpus).  TF1 graph/session code.
    """

    def __init__(self, config, oracle=None):
        super().__init__(config)
        # you can change parameters, generator here
        self.vocab_size = 20
        self.emb_dim = 32
        self.hidden_dim = 32
        # NOTE(review): DEFINE_* raises on a second Leakgan instantiation in
        # the same process (duplicate flag definitions) — confirm this class
        # is only constructed once per run.
        flags = tf.app.flags
        FLAGS = flags.FLAGS
        flags.DEFINE_boolean('restore', False, 'Training or testing a model')
        flags.DEFINE_boolean('resD', False, 'Training or testing a D model')
        flags.DEFINE_integer('length', 20, 'The length of toy data')
        flags.DEFINE_string('model', "", 'Model NAME')
        self.sequence_length = 20  # FLAGS.length
        self.filter_size = [2, 3]
        self.num_filters = [100, 200]
        self.l2_reg_lambda = 0.2
        self.dropout_keep_prob = 0.75
        self.batch_size = 64
        self.generate_num = 256
        self.start_token = 0
        self.dis_embedding_dim = 64
        self.goal_size = 16

        # Scratch files exchanged between the models and the metrics.
        self.oracle_file = 'save/oracle.txt'
        self.generator_file = 'save/generator.txt'
        self.test_file = 'save/test_file.txt'

    def init_oracle_trainng(self, oracle=None):
        """Build oracle, discriminator, generator, data loaders and session.

        (Method name spelling kept as-is — external callers use it.)

        :param oracle: optional pre-built oracle; a default OracleLstm is
            created when None.
        """
        # The manager's goal vector matches the total number of conv
        # features the discriminator leaks to the generator.
        goal_out_size = sum(self.num_filters)

        if oracle is None:
            oracle = OracleLstm(num_vocabulary=self.vocab_size,
                                batch_size=self.batch_size,
                                emb_dim=self.emb_dim,
                                hidden_dim=self.hidden_dim,
                                sequence_length=self.sequence_length,
                                start_token=self.start_token)
        self.set_oracle(oracle)

        discriminator = Discriminator(sequence_length=self.sequence_length,
                                      num_classes=2,
                                      vocab_size=self.vocab_size,
                                      dis_emb_dim=self.dis_embedding_dim,
                                      filter_sizes=self.filter_size,
                                      num_filters=self.num_filters,
                                      batch_size=self.batch_size,
                                      hidden_dim=self.hidden_dim,
                                      start_token=self.start_token,
                                      goal_out_size=goal_out_size,
                                      step_size=4,
                                      l2_reg_lambda=self.l2_reg_lambda)
        self.set_discriminator(discriminator)

        # The generator receives the discriminator (D_model) so it can read
        # the leaked feature vector at every step.
        generator = Generator(num_classes=2,
                              num_vocabulary=self.vocab_size,
                              batch_size=self.batch_size,
                              emb_dim=self.emb_dim,
                              dis_emb_dim=self.dis_embedding_dim,
                              goal_size=self.goal_size,
                              hidden_dim=self.hidden_dim,
                              sequence_length=self.sequence_length,
                              filter_sizes=self.filter_size,
                              start_token=self.start_token,
                              num_filters=self.num_filters,
                              goal_out_size=goal_out_size,
                              D_model=discriminator,
                              step_size=4)
        self.set_generator(generator)

        gen_dataloader = DataLoader(batch_size=self.batch_size,
                                    seq_length=self.sequence_length)
        oracle_dataloader = DataLoader(batch_size=self.batch_size,
                                       seq_length=self.sequence_length)
        dis_dataloader = DisDataloader(batch_size=self.batch_size,
                                       seq_length=self.sequence_length)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.5
        self.sess = tf.Session(config=config)
        self.set_data_loader(gen_loader=gen_dataloader,
                             dis_loader=dis_dataloader,
                             oracle_loader=oracle_dataloader)

    def init_metric(self):
        """Register oracle-NLL, test-NLL and doc-embedding similarity metrics."""
        nll = Nll(data_loader=self.oracle_data_loader,
                  rnn=self.oracle,
                  sess=self.sess)
        self.add_metric(nll)

        inll = Nll(data_loader=self.gen_data_loader,
                   rnn=self.generator,
                   sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

        from utils.metrics.DocEmbSim import DocEmbSim
        docsim = DocEmbSim(oracle_file=self.oracle_file,
                           generator_file=self.generator_file,
                           num_vocabulary=self.vocab_size)
        self.add_metric(docsim)

    def train_discriminator(self):
        """Refresh fake samples, then run three discriminator update steps.

        After every step the generator's copy of the feature extractor is
        re-synchronised with the discriminator (LeakGAN leaks D's features).
        """
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        self.dis_data_loader.load_train_data(self.oracle_file,
                                             self.generator_file)
        for _ in range(3):
            x_batch, y_batch, conv_batch = \
                self.dis_data_loader.next_data_batch()
            feed = {
                self.discriminator.D_input_x: x_batch,
                self.discriminator.D_input_y: y_batch,
                self.discriminator.D_input_conv: conv_batch
            }
            _, _ = self.sess.run(
                [self.discriminator.D_loss, self.discriminator.D_train_op],
                feed)
            self.generator.update_feature_function(self.discriminator)

    def evaluate(self):
        """Generate fresh samples, then score all registered metrics.

        Writes a CSV header on the first call and one comma-separated score
        row per call when ``self.log`` is open; always returns the scores.
        """
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        if self.oracle_data_loader is not None:
            self.oracle_data_loader.create_batches(self.generator_file)
        if self.log is not None:
            # Header row only at the very start of training.
            if self.epoch == 0 or self.epoch == 1:
                for metric in self.metrics:
                    self.log.write(metric.get_name() + ',')
                self.log.write('\n')
            scores = super().evaluate()
            for score in scores:
                self.log.write(str(score) + ',')
            self.log.write('\n')
            return scores
        return super().evaluate()

    def train_oracle(self):
        """Full synthetic-oracle pipeline: pre-train G, pre-train D, then
        adversarial training interleaved with supervised (teacher-forcing)
        epochs and extra discriminator updates."""
        self.init_oracle_trainng()
        self.init_metric()
        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-leakgan.csv', 'w')
        generate_samples(self.sess, self.oracle, self.batch_size,
                         self.generate_num, self.oracle_file)
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)
        self.oracle_data_loader.create_batches(self.generator_file)

        # Warm-up run of the sampling op (output intentionally discarded).
        for _ in range(1):
            self.sess.run(self.generator.gen_x,
                          feed_dict={
                              self.generator.drop_out: 1,
                              self.generator.train: 1
                          })

        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            pre_train_epoch_gen(self.sess, self.generator,
                                self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples_gen(self.sess, self.generator,
                                     self.batch_size, self.generate_num,
                                     self.generator_file)
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        print('adversarial training:')

        self.reward = Reward(model=self.generator,
                             dis=self.discriminator,
                             sess=self.sess,
                             rollout_num=4)
        for epoch in range(self.adversarial_epoch_num // 10):
            for epoch_ in range(10):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                for index in range(1):
                    start = time()
                    samples = self.generator.generate(self.sess, 1)
                    rewards = self.reward.get_reward(samples)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1
                    }
                    # One policy-gradient step for manager and worker.
                    _, _, g_loss, w_loss = self.sess.run(
                        [
                            self.generator.manager_updates,
                            self.generator.worker_updates,
                            self.generator.goal_loss,
                            self.generator.worker_loss,
                        ],
                        feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss',
                          w_loss)
                    end = time()
                    print('epoch:' + str(epoch) + '--' + str(epoch_) +
                          '\t time:' + str(end - start))
                    if self.epoch % 5 == 0 or self.epoch == self.adversarial_epoch_num - 1:
                        generate_samples_gen(self.sess, self.generator,
                                             self.batch_size,
                                             self.generate_num,
                                             self.generator_file)
                        self.evaluate()
                    self.add_epoch()

                for _ in range(15):
                    self.train_discriminator()
            # Interleaved supervised (teacher-forcing) epochs stabilise G.
            for epoch_ in range(5):
                start = time()
                pre_train_epoch_gen(self.sess, self.generator,
                                    self.gen_data_loader)
                end = time()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                self.add_epoch()
                if epoch % 5 == 0:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    # self.evaluate()
            for epoch_ in range(5):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()

    def init_cfg_training(self, grammar=None):
        """Build models and loaders for context-free-grammar training.

        :param grammar: CFG grammar string handed to the oracle.
        :returns: ``(wi_dict, iw_dict)`` locations produced by the oracle.
        """
        from utils.oracle.OracleCfg import OracleCfg
        oracle = OracleCfg(sequence_length=self.sequence_length,
                           cfg_grammar=grammar)
        self.set_oracle(oracle)
        self.oracle.generate_oracle()
        # +1 reserves an index for the padding/unknown token.
        self.vocab_size = self.oracle.vocab_size + 1
        goal_out_size = sum(self.num_filters)

        discriminator = Discriminator(sequence_length=self.sequence_length,
                                      num_classes=2,
                                      vocab_size=self.vocab_size,
                                      dis_emb_dim=self.dis_embedding_dim,
                                      filter_sizes=self.filter_size,
                                      num_filters=self.num_filters,
                                      batch_size=self.batch_size,
                                      hidden_dim=self.hidden_dim,
                                      start_token=self.start_token,
                                      goal_out_size=goal_out_size,
                                      step_size=4,
                                      l2_reg_lambda=self.l2_reg_lambda)
        self.set_discriminator(discriminator)

        generator = Generator(num_classes=2,
                              num_vocabulary=self.vocab_size,
                              batch_size=self.batch_size,
                              emb_dim=self.emb_dim,
                              dis_emb_dim=self.dis_embedding_dim,
                              goal_size=self.goal_size,
                              hidden_dim=self.hidden_dim,
                              sequence_length=self.sequence_length,
                              filter_sizes=self.filter_size,
                              start_token=self.start_token,
                              num_filters=self.num_filters,
                              goal_out_size=goal_out_size,
                              D_model=discriminator,
                              step_size=4)
        self.set_generator(generator)

        gen_dataloader = DataLoader(batch_size=self.batch_size,
                                    seq_length=self.sequence_length)
        oracle_dataloader = DataLoader(batch_size=self.batch_size,
                                       seq_length=self.sequence_length)
        dis_dataloader = DisDataloader(batch_size=self.batch_size,
                                       seq_length=self.sequence_length)

        # Fix: create the session here as init_oracle_trainng does —
        # train_cfg calls self.sess.run() right after this method, which
        # otherwise has no session to use.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.5
        self.sess = tf.Session(config=config)

        self.set_data_loader(gen_loader=gen_dataloader,
                             dis_loader=dis_dataloader,
                             oracle_loader=oracle_dataloader)
        return oracle.wi_dict, oracle.iw_dict

    def init_cfg_metric(self, grammar=None):
        """Register the CFG-validity metric over the decoded test file."""
        from utils.metrics.Cfg import Cfg
        cfg = Cfg(test_file=self.test_file, cfg_grammar=grammar)
        self.add_metric(cfg)

    def train_cfg(self):
        """Train on synthetic arithmetic-expression data drawn from a CFG."""
        import json
        from utils.text_process import get_tokenlized
        from utils.text_process import code_to_text
        cfg_grammar = """
          S -> S PLUS x | S SUB x |  S PROD x | S DIV x | x | '(' S ')'
          PLUS -> '+'
          SUB -> '-'
          PROD -> '*'
          DIV -> '/'
          x -> 'x' | 'y'
        """

        wi_dict_loc, iw_dict_loc = self.init_cfg_training(cfg_grammar)
        with open(iw_dict_loc, 'r') as file:
            iw_dict = json.load(file)

        def get_cfg_test_file(dictionary=iw_dict):
            # Decode generated index sequences back to text for the metric.
            # (get_tokenlized takes the path itself; no handle needed.)
            codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes,
                                           dictionary=dictionary))

        self.init_cfg_metric(grammar=cfg_grammar)
        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-leakganbasic-cfg.csv', 'w')
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.generator_file)
        self.gen_data_loader.create_batches(self.oracle_file)
        self.oracle_data_loader.create_batches(self.generator_file)
        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            pre_train_epoch_gen(self.sess, self.generator,
                                self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples_gen(self.sess, self.generator,
                                     self.batch_size, self.generate_num,
                                     self.generator_file)
                get_cfg_test_file()
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num * 3):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        self.reset_epoch()
        print('adversarial training:')
        self.reward = Reward(model=self.generator,
                             dis=self.discriminator,
                             sess=self.sess,
                             rollout_num=4)
        for epoch in range(self.adversarial_epoch_num // 10):
            for epoch_ in range(10):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                start = time()
                for index in range(1):
                    samples = self.generator.generate(self.sess, 1)
                    rewards = self.reward.get_reward(samples)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1
                    }
                    _, _, g_loss, w_loss = self.sess.run(
                        [
                            self.generator.manager_updates,
                            self.generator.worker_updates,
                            self.generator.goal_loss,
                            self.generator.worker_loss,
                        ],
                        feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss',
                          w_loss)
                end = time()
                self.add_epoch()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    get_cfg_test_file()
                    self.evaluate()

                for _ in range(15):
                    self.train_discriminator()
            # Interleaved supervised epochs, mirroring train_oracle.
            for epoch_ in range(5):
                start = time()
                pre_train_epoch_gen(self.sess, self.generator,
                                    self.gen_data_loader)
                end = time()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                self.add_epoch()
                if epoch % 5 == 0:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    get_cfg_test_file()
                    self.evaluate()
            for epoch_ in range(5):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()

    def init_real_training(self, data_loc=None):
        """Build models and image-caption data loaders for real-text training.

        :param data_loc: path to the training corpus (defaults to COCO
            captions).
        :returns: ``(word_index_dict, index_word_dict)`` from the dataset.
        """
        from dataset import prepare_train_data

        if data_loc is None:
            data_loc = 'data/image_coco.txt'

        # Tokenization / vocab building is delegated to the dataset module.
        dataset, self.sequence_length = prepare_train_data(
            self.config, data_loc)
        self.vocab_size = dataset.vocab_size()
        # Fix: the second element is the index->word dict.  It was misspelled
        # 'idnex_work_dict', leaving 'index_word_dict' (returned below)
        # undefined and raising NameError at the return statement.
        word_index_dict, index_word_dict = dataset.get_dict()

        goal_out_size = sum(self.num_filters)
        discriminator = Discriminator(sequence_length=self.sequence_length,
                                      num_classes=2,
                                      vocab_size=self.vocab_size,
                                      dis_emb_dim=self.dis_embedding_dim,
                                      filter_sizes=self.filter_size,
                                      num_filters=self.num_filters,
                                      batch_size=self.batch_size,
                                      hidden_dim=self.hidden_dim,
                                      start_token=self.start_token,
                                      goal_out_size=goal_out_size,
                                      step_size=4,
                                      l2_reg_lambda=self.l2_reg_lambda)
        self.set_discriminator(discriminator)

        generator = Generator(num_classes=2,
                              num_vocabulary=self.vocab_size,
                              batch_size=self.batch_size,
                              emb_dim=self.emb_dim,
                              dis_emb_dim=self.dis_embedding_dim,
                              goal_size=self.goal_size,
                              hidden_dim=self.hidden_dim,
                              sequence_length=self.sequence_length,
                              filter_sizes=self.filter_size,
                              start_token=self.start_token,
                              num_filters=self.num_filters,
                              goal_out_size=goal_out_size,
                              D_model=discriminator,
                              step_size=4)
        self.set_generator(generator)
        gen_dataloader = DataLoader(batch_size=self.batch_size,
                                    seq_length=self.sequence_length,
                                    dataset=dataset)
        # No oracle exists for real data; metrics relying on it are skipped.
        oracle_dataloader = None
        dis_dataloader = DisDataloader(batch_size=self.batch_size,
                                       seq_length=self.sequence_length,
                                       dataset=dataset)

        self.set_data_loader(gen_loader=gen_dataloader,
                             dis_loader=dis_dataloader,
                             oracle_loader=oracle_dataloader)

        return word_index_dict, index_word_dict

    def init_real_metric(self):
        """Register doc-embedding similarity and test-NLL for real data."""
        from utils.metrics.DocEmbSim import DocEmbSim
        docsim = DocEmbSim(oracle_file=self.oracle_file,
                           generator_file=self.generator_file,
                           num_vocabulary=self.vocab_size)
        self.add_metric(docsim)

        inll = Nll(data_loader=self.gen_data_loader,
                   rnn=self.generator,
                   sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

    def train_real(self, data_loc=None):
        """Train on a real corpus.

        NOTE(review): this method currently short-circuits with ``return``
        right after the first sample generation (debug state); everything
        below that point is unreachable.  Kept as-is to preserve behavior —
        remove the early return to enable the full training loop.

        :param data_loc: corpus path forwarded to ``init_real_training``.
        """
        from utils.text_process import code_to_text
        from utils.text_process import get_tokenlized
        wi_dict, iw_dict = self.init_real_training(data_loc)

        self.init_real_metric()

        def get_real_test_file(dictionary=iw_dict):
            # Decode generated index sequences back to text for evaluation.
            # (get_tokenlized takes the path itself; no handle needed.)
            codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(code_to_text(codes=codes,
                                           dictionary=dictionary))

        self.sess.run(tf.global_variables_initializer())

        self.pre_epoch_num = 80
        self.adversarial_epoch_num = 100
        self.log = open('experiment-log-leakgan-real.csv', 'w')
        codes = generate_samples_gen(self.sess, self.generator,
                                     self.batch_size, self.generate_num,
                                     self.generator_file)
        print(codes)
        return  # debug short-circuit — see docstring NOTE
        #print(code_to_text(codes=codes, dictionary=iw_dict))

        #self.gen_data_loader.create_batches(self.oracle_file)

        for a in range(1):
            self.generator.generate(self.sess, 1)

        print('start pre-train generator:')
        for epoch in range(self.pre_epoch_num):
            start = time()
            pre_train_epoch_gen(self.sess, self.generator,
                                self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            if epoch % 5 == 0:
                generate_samples_gen(self.sess, self.generator,
                                     self.batch_size, self.generate_num,
                                     self.generator_file)
                get_real_test_file()
                self.evaluate()

        print('start pre-train discriminator:')
        self.reset_epoch()
        for epoch in range(self.pre_epoch_num):
            print('epoch:' + str(epoch))
            self.train_discriminator()

        self.reset_epoch()
        self.reward = Reward(model=self.generator,
                             dis=self.discriminator,
                             sess=self.sess,
                             rollout_num=4)
        for epoch in range(self.adversarial_epoch_num // 10):
            for epoch_ in range(10):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                start = time()
                for index in range(1):
                    samples = self.generator.generate(self.sess, 1)
                    rewards = self.reward.get_reward(samples)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1
                    }
                    _, _, g_loss, w_loss = self.sess.run(
                        [
                            self.generator.manager_updates,
                            self.generator.worker_updates,
                            self.generator.goal_loss,
                            self.generator.worker_loss,
                        ],
                        feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss',
                          w_loss)
                end = time()
                self.add_epoch()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    get_real_test_file()
                    self.evaluate()

                for _ in range(15):
                    self.train_discriminator()
            # Interleaved supervised epochs, mirroring train_oracle.
            for epoch_ in range(5):
                start = time()
                pre_train_epoch_gen(self.sess, self.generator,
                                    self.gen_data_loader)
                end = time()
                print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' +
                      str(end - start))
                if epoch % 5 == 0:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size, self.generate_num,
                                         self.generator_file)
                    get_real_test_file()
                    # self.evaluate()
            for epoch_ in range(5):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()
# ---- Esempio n. 5 (Example no. 5, score: 0) — scraped example separator ----
class Leakgan(Gan):
    def __init__(self, config, oracle=None):
        """Set up LeakGAN hyper-parameters and file paths from *config*.

        Args:
            config: project configuration object; supplies embedding sizes,
                sequence length, batch size, regularisation constants and
                the temp-file locations used below.
            oracle: unused in real-data training; kept for interface
                compatibility with the oracle-based GAN variants.
        """
        super().__init__(config)
        # you can change parameters, generator here
        #self.vocab_size = 20
        self.emb_dim = config.dim_embedding
        self.hidden_dim = self.emb_dim
        flags = tf.app.flags
        FLAGS = flags.FLAGS
        # Guard the command-line flag registration: tf.app.flags raises a
        # DuplicateFlagError when the same flag is defined twice, which made
        # any second Leakgan instantiation in one process crash.
        if 'restore' not in FLAGS:
            flags.DEFINE_boolean('restore', False, 'Training or testing a model')
            flags.DEFINE_boolean('resD', False, 'Training or testing a D model')
            flags.DEFINE_integer('length', 70, 'The length of toy data')
            flags.DEFINE_string('model', "", 'Model NAME')
        self.sequence_length = config.max_caption_length  #FLAGS.length
        self.filter_size = config.filter_size
        self.num_filters = config.num_filters
        self.l2_reg_lambda = config.l2_reg_lambda
        self.dropout_keep_prob = config.dropout_keep_prob
        self.batch_size = config.batch_size
        self.generate_num = config.generate_num
        self.start_token = config._START_
        self.dis_embedding_dim = self.emb_dim
        self.goal_size = config.goal_size

        self.oracle_file = 'save/oracle.txt'
        self.generator_file = 'save/generator.txt'
        self.context_file = config.temp_generate_file
        self.test_file = 'save/test_file.txt'
        self.save_loc = 'save/checkpoints'
        # Step counter used by save_model when writing checkpoints.
        self.global_step = tf.Variable(0, trainable=False)

    def train_discriminator(self):
        """Run three discriminator update steps on a fresh pool of samples."""
        # Regenerate the fake-sample file before training D on it.
        generate_samples_gen(self.sess, self.generator, self.gen_data_loader,
                             self.batch_size, self.generate_num,
                             self.generator_file, self.context_file)
        self.dis_data_loader.load_train_data(with_image=True)

        fetches = [
            self.discriminator.D_loss,
            self.discriminator.D_train_op,
            self.discriminator.D_summary,
        ]
        for _ in range(3):
            inputs, labels, features = self.dis_data_loader.next_batch()
            feed_dict = {
                self.discriminator.D_input_x: inputs,
                self.discriminator.D_input_y: labels,
                self.discriminator.conv_features: features,
            }
            _, _, summary = self.sess.run(fetches, feed_dict)
            self.writer.add_summary(summary, self.epoch)

            # Refresh the leaked-feature hook from D into the generator.
            self.generator.update_feature_function(self.discriminator)

    def evaluate(self):
        """Intentionally a no-op.

        TODO: in-training evaluation metrics throw an out-of-memory error,
        so evaluation is disabled for now.
        """
        return None

    def init_real_training(self, data_loc=None, with_image=True):
        """Preprocess the real training data and wire up G, D and loaders.

        Args:
            data_loc: location of the raw training data, forwarded to
                process_train_data.
            with_image: whether image conv-features accompany the captions.

        Returns:
            The vocabulary built from the training corpus.
        """
        self.sequence_length, self.vocab_size, vocabulary = process_train_data(
            self.config, data_loc, has_image=with_image)
        print("sequence length:", self.sequence_length, " vocab size:",
              self.vocab_size)

        # Build D first: the generator needs D's feature extractor (leaked
        # features) at construction time.
        discriminator = Discriminator(self.config)
        self.set_discriminator(discriminator)

        generator = Generator(self.config, D_model=discriminator)
        self.set_generator(generator)

        # Data loaders for generator and discriminator.
        gen_dataloader = DataLoader(self.config,
                                    batch_size=self.batch_size,
                                    seq_length=self.sequence_length)
        gen_dataloader.create_shuffled_batches(with_image)

        # No oracle when training on real data.
        oracle_dataloader = None
        dis_dataloader = DisDataloader(self.config,
                                       batch_size=self.batch_size,
                                       seq_length=self.sequence_length)

        self.set_data_loader(gen_loader=gen_dataloader,
                             dis_loader=dis_dataloader,
                             oracle_loader=oracle_dataloader)

        return vocabulary

    def init_real_metric(self):
        """Register the NLL-on-test metric for in-training evaluation."""
        nll_metric = Nll(data_loader=self.gen_data_loader,
                         rnn=self.generator,
                         sess=self.sess)
        nll_metric.set_name('nll-test')
        self.add_metric(nll_metric)
        print("done initializing metric")

    def train_real(self, data_loc=None, with_image=True):
        """Full real-data training loop.

        Phases: (1) optional generator pretraining, (2) optional discriminator
        pretraining, then (3) repeated rounds of adversarial G updates,
        teacher-forcing (MLE) epochs and discriminator training.

        Args:
            data_loc: location of the raw training data.
            with_image: whether image conv-features accompany the captions.
        """
        from utils.text_process import get_tokenlized
        vocabulary = self.init_real_training(data_loc, with_image)

        #self.init_real_metric()

        def get_real_test_file(codes, vocab=vocabulary):
            # NOTE(review): the bare return disables this helper entirely —
            # the write below is dead code. Confirm this is intentional.
            return
            #with open(self.generator_file, 'r') as file:
            #    codes = get_tokenlized(self.generator_file)
            with open(self.test_file, 'w') as outfile:
                outfile.write(vocab.code_to_text(codes))

        self.sess.run(tf.global_variables_initializer())
        #self.restore_model(self.sess)

        if not os.path.exists(self.config.summary_dir):
            os.mkdir(self.config.summary_dir)
        self.writer = tf.summary.FileWriter(self.config.summary_dir,
                                            self.sess.graph)

        # NOTE(review): pre_epoch_num == 0 skips both pretraining loops below;
        # presumably a debug setting — confirm before a real training run.
        self.pre_epoch_num = 0
        self.adversarial_epoch_num = 40
        # NOTE(review): self.log is opened here but never written or closed
        # in this method.
        self.log = open('experiment-log-leakgan-real.csv', 'w')

        # Dead loop (range(0)): previously sampled the untrained generator
        # with zeroed conv features for debugging.
        for a in range(0):
            g = self.sess.run(self.generator.gen_x,
                              feed_dict={
                                  self.generator.drop_out:
                                  1,
                                  self.generator.train:
                                  1,
                                  self.generator.conv_features:
                                  np.zeros((self.generator.batch_size,
                                            self.generator.image_feat_dim),
                                           dtype=np.float32)
                              })
            #print(g)
        #print('start pre-train generator:')
        # Phase 1: MLE pretraining of the generator (currently 0 epochs).
        for epoch in tqdm(list(range(self.pre_epoch_num)),
                          desc='Pretraining generator'):
            start = time()
            loss = pre_train_epoch_gen(self.sess, self.generator,
                                       self.gen_data_loader, self.writer,
                                       self.epoch)
            end = time()
            #print('epoch:' + str(self.epoch) + '\t time:' + str(end - start))
            self.add_epoch()
            # Sample and (nominally) dump text every 5 epochs.
            if epoch % 5 == 0:
                codes = generate_samples_gen(self.sess, self.generator,
                                             self.gen_data_loader,
                                             self.batch_size,
                                             self.generate_num,
                                             self.generator_file,
                                             self.context_file)
                print(vocabulary.code_to_text(codes))
                get_real_test_file(codes)
                self.evaluate()
        self.save_model(self.sess, self.save_loc)

        #print('start pre-train discriminator:')
        # Phase 2: discriminator pretraining (currently 0 epochs).
        self.reset_epoch()
        for epoch in tqdm(list(range(self.pre_epoch_num)),
                          desc='Pretraining discriminator'):

            self.train_discriminator()
            self.add_epoch()

        self.save_model(self.sess, self.save_loc)

        # Phase 3: adversarial training with REINFORCE-style rewards from
        # Monte-Carlo rollouts (rollout_num=4).
        self.reset_epoch()
        self.reward = Reward(model=self.generator,
                             dis=self.discriminator,
                             sess=self.sess,
                             rollout_num=4)
        for epoch in tqdm(list(range(self.adversarial_epoch_num // 10)),
                          desc='Adversarial training'):
            # 10 adversarial G steps per outer round.
            for epoch_ in range(10):
                #print('epoch:' + str(epoch) + '--' + str(epoch_))
                start = time()
                for index in range(1):
                    _, _, _, conv_features = self.gen_data_loader.next_batch()
                    samples = self.generator.generate(self.sess, conv_features,
                                                      1)
                    rewards = self.reward.get_reward(samples, conv_features)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1.0,
                        self.generator.conv_features: conv_features
                    }
                    # Update manager and worker jointly; fetch both losses.
                    _, _, g_loss, w_loss = self.sess.run([
                        self.generator.manager_updates,
                        self.generator.worker_updates,
                        self.generator.goal_loss,
                        self.generator.worker_loss,
                    ],
                                                         feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss, 'w_loss',
                          w_loss)
                end = time()
                self.add_epoch()
                #print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' + str(end - start))
                # NOTE(review): epoch only reaches adversarial_epoch_num // 10
                # - 1 here, so the second condition can never fire — confirm.
                if epoch % 5 == 0 or epoch == self.adversarial_epoch_num - 1:
                    codes = generate_samples_gen(self.sess, self.generator,
                                                 self.gen_data_loader,
                                                 self.batch_size,
                                                 self.generate_num,
                                                 self.generator_file,
                                                 self.context_file)
                    print(vocabulary.code_to_text(codes))
                    get_real_test_file(codes)
                    self.evaluate()

                # 3 discriminator rounds after each G step.
                for _ in range(3):
                    self.train_discriminator()
            self.save_model(self.sess, self.save_loc)
            # Interleaved teacher-forcing (MLE) epochs to stabilise training.
            for epoch_ in range(5):
                start = time()
                loss = pre_train_epoch_gen(self.sess, self.generator,
                                           self.gen_data_loader, self.writer,
                                           self.epoch)
                end = time()
                #print('epoch:' + str(epoch) + '--' + str(epoch_) + '\t time:' + str(end - start))
                self.add_epoch()
                if epoch % 5 == 0:
                    codes = generate_samples_gen(self.sess, self.generator,
                                                 self.gen_data_loader,
                                                 self.batch_size,
                                                 self.generate_num,
                                                 self.generator_file,
                                                 self.context_file)
                    print(vocabulary.code_to_text(codes))
                    get_real_test_file(codes)
                    self.evaluate()
            # Extra discriminator-only epochs to close the round.
            for epoch_ in range(5):
                #print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()

        self.save_model(self.sess, self.save_loc)
        self.writer.close()

    def test(self, data_loc=None, with_image=True):
        """Generate captions for the test split and write them to the results file.

        Args:
            data_loc: unused; kept for signature parity with train_real/val.
            with_image: whether the eval loader should batch image features.
        """
        self.sequence_length, self.vocab_size, vocabulary = process_test_data(
            self.config)

        discriminator = Discriminator(self.config)
        self.set_discriminator(discriminator)

        generator = Generator(self.config, D_model=discriminator)
        self.set_generator(generator)

        # Data loaders for generator and discriminator.
        gen_dataloader = DataEvalLoader(self.config,
                                        batch_size=self.batch_size)
        gen_dataloader.create_batches(with_image)

        oracle_dataloader = None
        dis_dataloader = DisDataloader(self.config,
                                       batch_size=self.batch_size,
                                       seq_length=self.sequence_length)

        self.set_data_loader(gen_loader=gen_dataloader,
                             dis_loader=dis_dataloader,
                             oracle_loader=oracle_dataloader)

        # Load trained weights instead of initializing fresh variables.
        self.restore_model(self.sess)
        self.context_file = self.config.temp_generate_eval_file
        # NOTE(review): batch_size is passed for both the batch size and the
        # number of samples, i.e. exactly one batch is generated — confirm
        # that is intended.
        codes = generate_samples_gen(self.sess,
                                     self.generator,
                                     self.gen_data_loader,
                                     self.batch_size,
                                     self.batch_size,
                                     self.generator_file,
                                     test=True)
        samples = vocabulary.code_to_text(codes)
        print(np.array(samples).shape)
        samples = self.remove_padding(samples)
        print(np.array(samples).shape)
        print(samples)

        # Context manager so the file is closed even if a write fails
        # (the original leaked the handle on error).
        with open(self.config.test_result_file, 'w') as results_writer:
            for samp in samples:
                results_writer.write(samp)

    def val(self, data_loc=None, with_image=True):
        """Generate captions for the validation split and dump eval artifacts.

        Saves the decoded captions and the matching image ids as .npy files,
        then builds the evaluation JSON via prepare_json.

        Args:
            data_loc: unused; kept for signature parity with train_real/test.
            with_image: whether the eval loader should batch image features.
        """
        self.sequence_length, self.vocab_size, vocabulary = process_val_data(
            self.config)

        discriminator = Discriminator(self.config)
        self.set_discriminator(discriminator)

        generator = Generator(self.config, D_model=discriminator)
        self.set_generator(generator)

        # Data loaders for generator and discriminator.
        gen_dataloader = DataEvalLoader(self.config,
                                        batch_size=self.batch_size)
        gen_dataloader.create_batches(with_image)

        oracle_dataloader = None
        dis_dataloader = DisDataloader(self.config,
                                       batch_size=self.batch_size,
                                       seq_length=self.sequence_length)

        self.set_data_loader(gen_loader=gen_dataloader,
                             dis_loader=dis_dataloader,
                             oracle_loader=oracle_dataloader)

        # Load trained weights instead of initializing fresh variables.
        self.restore_model(self.sess)
        image_files, codes = generate_samples_gen(self.sess,
                                                  self.generator,
                                                  self.gen_data_loader,
                                                  self.batch_size,
                                                  self.config.num_eval_samples,
                                                  eval=True,
                                                  test=True)

        # Decode each caption individually and strip <...> padding tags.
        generated_samples = []
        for code in codes:
            code = vocabulary.code_to_text([code])
            code = self.remove_padding(code)
            generated_samples.append(code)

        np.save(self.config.temp_generate_eval_file, generated_samples)

        # NOTE(review): img[45:jpg_idx] assumes a fixed-length path prefix of
        # exactly 45 characters before the numeric image id — fragile;
        # confirm it matches the dataset layout before moving directories.
        ids = []
        for img in image_files:
            jpg_idx = img.find('.jpg')
            ids.append(int(img[45:jpg_idx]))
        np.save(self.config.temp_eval_id, ids)
        prepare_json(self.config)

    def remove_padding(self, samples):
        """Stringify *samples*, drop everything up to the first ':' (if any),
        and strip every '<...>' tag pair from what remains.

        Returns the cleaned string.
        """
        text = str(samples)
        # find(':') is -1 when absent, so +1 keeps the whole string then.
        text = text[text.find(':') + 1:]
        while True:
            open_idx, close_idx = text.find('<'), text.find('>')
            if open_idx < 0 or close_idx < 0:
                break
            text = text[:open_idx] + text[close_idx + 1:]
        return text

    def save_model(self, sess, checkpoint_dir):
        """Save all global variables to *checkpoint_dir* at the current
        global step.

        Args:
            sess: the live tf.Session holding the variables.
            checkpoint_dir: directory the checkpoint files are written to.
        """
        # The original also opened a tf.summary.FileWriter here on every save
        # and never used or closed it, leaking writers and event files;
        # summaries are already handled by self.writer in train_real.
        saver = tf.train.Saver(tf.global_variables())
        # os.path.join fixes the missing separator: the old
        # `checkpoint_dir + "model.ckpt"` produced e.g.
        # 'save/checkpointsmodel.ckpt' because save_loc has no trailing slash.
        saver.save(sess,
                   os.path.join(checkpoint_dir, "model.ckpt"),
                   global_step=self.global_step)

    def restore_model(self, sess):
        """Restore the most recent checkpoint from config.save_dir into *sess*."""
        latest_ckpt = tf.train.latest_checkpoint(self.config.save_dir)
        restorer = tf.train.Saver(tf.global_variables())
        restorer.restore(sess, latest_ckpt)