Example #1
class RelbarGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(RelbarGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = RelbarGAN_G(cfg.mem_slots,
                               cfg.num_heads,
                               cfg.head_size,
                               cfg.gen_embed_dim,
                               cfg.gen_hidden_dim,
                               cfg.vocab_size,
                               cfg.max_seq_len,
                               cfg.padding_idx,
                               cfg.temperature,
                               cfg.eta,
                               gpu=cfg.CUDA)
        self.dis = RelbarGAN_D(cfg.dis_embed_dim,
                               cfg.max_seq_len,
                               cfg.num_rep,
                               cfg.vocab_size,
                               cfg.padding_idx,
                               gpu=cfg.CUDA)
        self.init_model()

        # Optimizer
        self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
        self.gen_adv_opt = optim.Adam(itertools.chain(
            self.gen.parameters(), [self.gen.temperature, self.gen.eta]),
                                      lr=cfg.gen_adv_lr)
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        # Criterion
        self.mle_criterion = nn.NLLLoss()
        self.dis_pretrain_criterion = nn.BCEWithLogitsLoss()

        # DataLoader
        self.gen_data = GenDataIter(
            self.gen.sample(cfg.batch_size, cfg.batch_size))
        self.dis_data = DisDataIter(self.train_data.random_batch()['target'],
                                    self.gen_data.random_batch()['target'])

        # Metrics
        bleu_gram = list(range(2, cfg.max_seq_len +
                               1)) if cfg.max_seq_len < 5 else [2, 3, 4, 5]
        self.bleu = BLEU(test_text=tensor_to_tokens(self.gen_data.target,
                                                    self.index_word_dict),
                         real_text=tensor_to_tokens(
                             self.test_data.target,
                             self.test_data.index_word_dict),
                         gram=bleu_gram)
        self.self_bleu = BLEU(
            test_text=tensor_to_tokens(self.gen_data.target,
                                       self.index_word_dict),
            real_text=tensor_to_tokens(self.gen_data.target,
                                       self.index_word_dict),
            gram=3)

    def _run(self):
        # =====PRE-TRAINING (GENERATOR)=====
        if not cfg.gen_pretrain:
            self.log.info('Starting Generator MLE Training...')
            self.pretrain_generator(cfg.MLE_train_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.gen.state_dict(), cfg.pretrained_gen_path)
                print('Save pre-trained generator: {}'.format(
                    cfg.pretrained_gen_path))

        # =====PRE-TRAINING (DISCRIMINATOR)=====
        if not cfg.dis_pretrain:
            self.log.info('Starting Discriminator Training...')
            self.pretrain_discriminator(cfg.d_step, cfg.d_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.dis.state_dict(), cfg.pretrained_dis_path)
                print('Save pre-trained discriminator: {}'.format(
                    cfg.pretrained_dis_path))

        # =====ADVERSARIAL TRAINING=====
        self.log.info('Starting Adversarial Training...')
        self.log.info('Initial generator: %s' %
                      (self.cal_metrics(fmt_str=True)))

        for adv_epoch in range(cfg.ADV_train_epoch):
            if adv_epoch % cfg.adv_log_step == 0:
                self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch)
            self.sig.update()
            if self.sig.adv_sig:
                self.adv_train_generator(cfg.ADV_g_step,
                                         adv_epoch)  # Generator
                self.adv_train_discriminator(cfg.ADV_d_step,
                                             adv_epoch)  # Discriminator

                if adv_epoch % cfg.adv_log_step == 0:
                    if cfg.if_save and not cfg.if_test:
                        self._save('ADV', adv_epoch)
            else:
                self.log.info(
                    '>>> Stop by adv_signal! Finishing adversarial training...'
                )
                break

    def _test(self):
        print('>>> Begin test...')

        self._run()
        pass

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pre-training for the generator
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                # =====Train=====
                pre_loss = self.train_gen_epoch(self.gen,
                                                self.train_data.loader,
                                                self.mle_criterion,
                                                self.gen_opt)

                # =====Test=====
                if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1:
                    self.log.info(
                        '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' %
                        (epoch, pre_loss, self.cal_metrics(fmt_str=True)))

                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info(
                    '>>> Stop by pre signal, skip to adversarial training...')
                break
        if cfg.if_save and not cfg.if_test:
            self._save('MLE', epoch)

    def adv_train_generator(self, g_step, adv_epoch):
        rebar_ge = RebarGradientEstimator(
            discriminator=self.dis,
            batch_size=cfg.batch_size,
            real_samples=self.train_data.random_batch()['target'],
            gpu=cfg.CUDA,
            num_rep=cfg.num_rep)
        total_rebar_loss = 0
        old_temperature = self.gen.temperature.item()
        old_eta = self.gen.eta.item()
        for step in range(g_step):
            theta, z = self.gen.sample_theta(cfg.batch_size)
            estimated_gradient, temperature_grad, eta_grad = rebar_ge.estimate_gradient(
                theta, z,
                self.gen.temperature.clone().detach().requires_grad_(),
                self.gen.eta.clone().detach().requires_grad_())

            adv_loss = self.gen.computeRebarLoss(estimated_gradient)
            temperature_grad = temperature_grad if cfg.learn_temperature else torch.zeros_like(
                temperature_grad)
            eta_grad = eta_grad if cfg.learn_eta else torch.zeros_like(
                eta_grad)
            self.optimize(self.gen_adv_opt,
                          adv_loss,
                          callback=functools.partial(
                              self.gen.set_variance_loss_gradients,
                              temperature_grad, eta_grad))
            total_rebar_loss += adv_loss.item()

        # =====Test=====
        avg_rebar_loss = total_rebar_loss / g_step if g_step != 0 else 0
        if adv_epoch % cfg.adv_log_step == 0:
            self.log.info(
                '[ADV-GEN] rebar_loss = %.4f, temperature = %.4f, eta = %.4f, %s'
                % (avg_rebar_loss, old_temperature, old_eta,
                   self.cal_metrics(fmt_str=True)))

    def adv_train_discriminator(self, d_step, adv_epoch):
        total_loss = 0
        total_acc = 0
        for step in range(d_step):
            # TODO(ethanjiang) we may want to train a full epoch instead of a random batch
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(
                ), gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()
            gen_samples = F.one_hot(gen_samples, cfg.vocab_size).float()

            # =====Train=====
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.dis_opt.zero_grad()
            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.dis.parameters(),
                                           cfg.clip_norm)
            self.dis_opt.step()

            total_loss += d_loss.item()
            if cfg.loss_type != 'rsgan':
                predictions = torch.cat((d_out_real, d_out_fake))
                labels = torch.cat((torch.ones_like(d_out_real),
                                    torch.zeros_like(d_out_fake)))
                total_acc += torch.sum(
                    ((predictions > 0).float() == labels)).item()

        # =====Test=====
        avg_loss = total_loss / d_step if d_step != 0 else 0
        if cfg.loss_type != 'rsgan':
            avg_acc = total_acc / (d_step * cfg.batch_size *
                                   2) if d_step != 0 else 0
            if adv_epoch % cfg.adv_log_step == 0:
                self.log.info('[ADV-DIS] d_loss = %.4f, train_acc = %.4f,' %
                              (avg_loss, avg_acc))
        else:
            if adv_epoch % cfg.adv_log_step == 0:
                self.log.info('[ADV-DIS] d_loss = %.4f,' % (avg_loss))

    def pretrain_discriminator(self, d_step, d_epoch, phrase='MLE'):
        """
        Train the discriminator on real data samples (positive) and samples generated by the generator (negative).
        Samples are drawn d_step times, and the discriminator is trained for d_epoch epochs each time.
        """
        # prepare loader for validate
        for step in range(d_step):
            # prepare loader for training
            pos_samples = self.train_data.target
            neg_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size)
            self.dis_data.reset(pos_samples, neg_samples)

            for epoch in range(d_epoch):
                # =====Train=====
                d_loss, train_acc = self.train_dis_epoch(
                    self.dis,
                    self.dis_data.loader,
                    self.dis_pretrain_criterion,
                    self.dis_opt,
                    one_hot=True,
                    num_rep=cfg.num_rep)

            # =====Test=====
            self.log.info(
                '[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' %
                (phrase, step, d_loss, train_acc))
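The generator update in adv_train_generator applies a pre-computed REBAR gradient estimate through computeRebarLoss. A minimal sketch of the surrogate-loss pattern such a call typically relies on (hypothetical tensors and names, not the project's actual implementation): multiplying a detached gradient estimate by the parameters and summing yields a scalar whose backward pass writes exactly that estimate into .grad.

import torch

# Hypothetical shapes; this only illustrates the surrogate-loss trick.
theta = torch.randn(4, 10, requires_grad=True)    # relaxed generator outputs
estimated_gradient = torch.randn_like(theta)      # dL/dtheta from a REBAR-style estimator

surrogate = (estimated_gradient.detach() * theta).sum()
surrogate.backward()
assert torch.allclose(theta.grad, estimated_gradient)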
Example #2
class SeqGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(SeqGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = SeqGAN_G(cfg.gen_embed_dim,
                            cfg.gen_hidden_dim,
                            cfg.vocab_size,
                            cfg.max_seq_len,
                            cfg.padding_idx,
                            cfg.temperature,
                            gpu=cfg.CUDA)
        self.dis = SeqGAN_D(cfg.dis_embed_dim,
                            cfg.vocab_size,
                            cfg.padding_idx,
                            gpu=cfg.CUDA)
        self.init_model()

        # Optimizer
        self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
        self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        # Criterion
        self.mle_criterion = nn.NLLLoss()
        self.dis_criterion = nn.BCEWithLogitsLoss()

        # DataLoader
        self.gen_data = GenDataIter(
            self.gen.sample(cfg.batch_size, cfg.batch_size))
        self.dis_data = DisDataIter(self.train_data.random_batch()['target'],
                                    self.gen_data.random_batch()['target'])

        # Metrics
        self.bleu = BLEU(test_text=tensor_to_tokens(self.gen_data.target,
                                                    self.index_word_dict),
                         real_text=tensor_to_tokens(
                             self.test_data.target,
                             self.test_data.index_word_dict),
                         gram=3)
        self.self_bleu = BLEU(
            test_text=tensor_to_tokens(self.gen_data.target,
                                       self.index_word_dict),
            real_text=tensor_to_tokens(self.gen_data.target,
                                       self.index_word_dict),
            gram=3)

    def _run(self):
        # =====PRE-TRAINING=====
        # TRAIN GENERATOR
        if not cfg.gen_pretrain:
            self.log.info('Starting Generator MLE Training...')
            self.pretrain_generator(cfg.MLE_train_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.gen.state_dict(), cfg.pretrained_gen_path)
                print('Save pre-trained generator: {}'.format(
                    cfg.pretrained_gen_path))

        # =====TRAIN DISCRIMINATOR======
        if not cfg.dis_pretrain:
            self.log.info('Starting Discriminator Training...')
            self.train_discriminator(cfg.d_step, cfg.d_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.dis.state_dict(), cfg.pretrained_dis_path)
                print('Save pre-trained discriminator: {}'.format(
                    cfg.pretrained_dis_path))

        # =====ADVERSARIAL TRAINING=====
        self.log.info('Starting Adversarial Training...')
        self.log.info('Initial generator: %s' %
                      (self.cal_metrics(fmt_str=True)))

        for adv_epoch in range(cfg.ADV_train_epoch):
            self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch)
            self.sig.update()
            if self.sig.adv_sig:
                self.adv_train_generator(cfg.ADV_g_step)  # Generator
                self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch,
                                         'ADV')  # Discriminator

                if adv_epoch % cfg.adv_log_step == 0:
                    if cfg.if_save and not cfg.if_test:
                        self._save('ADV', adv_epoch)
            else:
                self.log.info(
                    '>>> Stop by adv_signal! Finishing adversarial training...'
                )
                break

    def _test(self):
        print('>>> Begin test...')

        self._run()
        pass

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pre-training for the generator
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                pre_loss = self.train_gen_epoch(self.gen,
                                                self.train_data.loader,
                                                self.mle_criterion,
                                                self.gen_opt)

                # =====Test=====
                if epoch % cfg.pre_log_step == 0:
                    self.log.info(
                        '[MLE-GEN] epoch %d : pre_loss = %.4f, %s' %
                        (epoch, pre_loss, self.cal_metrics(fmt_str=True)))
                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info(
                    '>>> Stop by pre signal, skip to adversarial training...')
                break
        if cfg.if_save and not cfg.if_test:
            self._save('MLE', epoch)

    def adv_train_generator(self, g_step):
        """
        The generator is trained with policy gradients, using rewards from the discriminator.
        Training is done for g_step batches.
        """
        rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA)
        total_g_loss = 0
        for step in range(g_step):
            inp, target = self.gen_data.prepare(self.gen.sample(
                cfg.batch_size, cfg.batch_size),
                                                gpu=cfg.CUDA)

            # =====Train=====
            rewards = rollout_func.get_reward(target, cfg.rollout_num,
                                              self.dis)
            adv_loss = self.gen.batchPGLoss(inp, target, rewards)
            self.optimize(self.gen_adv_opt, adv_loss)
            total_g_loss += adv_loss.item()

        # =====Test=====
        self.log.info('[ADV-GEN]: g_loss = %.4f, %s' %
                      (total_g_loss, self.cal_metrics(fmt_str=True)))

    def train_discriminator(self, d_step, d_epoch, phrase='MLE'):
        """
        Train the discriminator on real data samples (positive) and samples generated by the generator (negative).
        Samples are drawn d_step times, and the discriminator is trained for d_epoch epochs each time.
        """
        # prepare loader for validate
        for step in range(d_step):
            # prepare loader for training
            pos_samples = self.train_data.target
            neg_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size)
            self.dis_data.reset(pos_samples, neg_samples)

            for epoch in range(d_epoch):
                # =====Train=====
                d_loss, train_acc = self.train_dis_epoch(
                    self.dis, self.dis_data.loader, self.dis_criterion,
                    self.dis_opt)

            # =====Test=====
            self.log.info(
                '[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' %
                (phrase, step, d_loss, train_acc))

            if cfg.if_save and not cfg.if_test:
                torch.save(self.dis.state_dict(), cfg.pretrained_dis_path)
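adv_train_generator above turns rollout rewards into a policy-gradient update via the generator's batchPGLoss. A hedged sketch of the standard SeqGAN objective this corresponds to (the repository's implementation may differ in masking and reduction):

import torch

def seqgan_pg_loss_sketch(log_probs, target, rewards):
    # log_probs: (batch, seq_len, vocab) log-softmax generator outputs
    # target:    (batch, seq_len) sampled token ids
    # rewards:   (batch, seq_len) per-token rewards from discriminator rollouts
    picked = torch.gather(log_probs, 2, target.unsqueeze(-1)).squeeze(-1)  # log pi(a_t | s_1..t-1)
    return -(picked * rewards).sum(dim=1).mean()  # REINFORCE: reward-weighted negative log-likelihood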
Example #3
class LeakGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(LeakGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = LeakGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim,
                             cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx,
                             cfg.goal_size, cfg.step_size, cfg.CUDA)
        self.dis = LeakGAN_D(cfg.dis_embed_dim,
                             cfg.vocab_size,
                             cfg.padding_idx,
                             gpu=cfg.CUDA)

        #LSTM
        self.corpus = dataa.Corpus('dataset/emnlp_news/')
        self.lstm = LSTM.RNNModel('LSTM', len(self.corpus.dictionary), 200,
                                  600, 3, 0.2, False)
        if (cfg.CUDA):
            self.dis.cuda()
            self.gen.cuda()
        self.init_model()

        # optimizer
        mana_params, work_params = self.gen.split_params()
        mana_opt = optim.Adam(mana_params, lr=cfg.gen_lr)
        work_opt = optim.Adam(work_params, lr=cfg.gen_lr)

        self.gen_opt = [mana_opt, work_opt]
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        # Criterion
        self.mle_criterion = nn.NLLLoss()
        self.dis_criterion = nn.CrossEntropyLoss()

        # DataLoader
        self.gen_data = GenDataIter(
            self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis))
        self.dis_data = DisDataIter(self.gen_data.random_batch()['target'],
                                    self.oracle_data.random_batch()['target'])

        # Metrics
        self.bleu3 = BLEU(test_text=tensor_to_tokens(self.gen_data.target,
                                                     self.index_word_dict),
                          real_text=tensor_to_tokens(self.test_data.target,
                                                     self.index_word_dict),
                          gram=3)

    def _run(self):
        for inter_num in range(cfg.inter_epoch):
            self.log.info('>>> Interleaved Round %d...' % inter_num)
            self.sig.update()  # update signal
            if self.sig.pre_sig:
                # =====DISCRIMINATOR PRE-TRAINING=====
                if not cfg.dis_pretrain:
                    self.log.info('Starting Discriminator Training...')
                    self.train_discriminator(cfg.d_step, cfg.d_epoch)
                    if cfg.if_save and not cfg.if_test:
                        torch.save(self.dis.state_dict(),
                                   cfg.pretrained_dis_path)
                        print('Save pre-trained discriminator: {}'.format(
                            cfg.pretrained_dis_path))

                # =====GENERATOR MLE TRAINING=====
                if not cfg.gen_pretrain:
                    self.log.info('Starting Generator MLE Training...')
                    self.pretrain_generator(cfg.MLE_train_epoch)
                    if cfg.if_save and not cfg.if_test:
                        torch.save(self.gen.state_dict(),
                                   cfg.pretrained_gen_path)
                        print('Save pre-trained generator: {}'.format(
                            cfg.pretrained_gen_path))
            else:
                self.log.info(
                    '>>> Stop by pre_signal! Skip to adversarial training...')
                break

        # =====ADVERSARIAL TRAINING=====
        self.log.info('Starting Adversarial Training...')
        self.log.info('Initial generator: %s' %
                      (str(self.cal_metrics(fmt_str=True))))

        for adv_epoch in range(cfg.ADV_train_epoch):
            self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch)
            self.sig.update()
            if self.sig.adv_sig:
                self.adv_train_generator(cfg.ADV_g_step)  # Generator
                self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch,
                                         'ADV')  # Discriminator

                if adv_epoch % cfg.adv_log_step == 0:
                    if cfg.if_save and not cfg.if_test:
                        self._save('ADV', adv_epoch)
            else:
                self.log.info(
                    '>>> Stop by adv_signal! Finishing adversarial training...'
                )
                break

    def string2bins(self, bit_string, n_bins):
        n_bits = int(math.log(n_bins, 2))
        return [
            bit_string[i:i + n_bits] for i in range(0, len(bit_string), n_bits)
        ]
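    # Usage sketch (hypothetical values, not from the repository): with n_bins = 16,
    # n_bits = log2(16) = 4, so string2bins('110100101100', 16) returns
    # ['1101', '0010', '1100'], and int(chunk, 2) later maps each chunk to a bin
    # index in [0, 16); the final chunk may be shorter than n_bits when the bit
    # string length is not a multiple of n_bits.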

    def LSTM_layer_1(self, intermediate_file, bins_num):
        print('>>> Begin test...')
        print('Begin with LSTM Layer')
        #First layer- LSTM layer
        epoch_start_time = time.time()
        seed = 1111
        data_root = './decode/'
        #Reproducibility
        torch.manual_seed(seed)
        if cfg.CUDA:
            torch.cuda.manual_seed(seed)
        with open("leakGAN_instructor/real_data/emnlp_news.pt", 'rb') as f:
            self.lstm = torch.load(f)
        if cfg.CUDA:
            self.lstm.cuda()
        emnlp_data = 'dataset/emnlp_news/'
        corpus = dataa.Corpus(emnlp_data)
        ntokens = len(corpus.dictionary)
        idx2word_file = data_root + "idx2word_1.txt"
        word2idx_file = data_root + "word2idx_1.txt"
        with open(idx2word_file, "wb") as fp:  #Pickling
            pickle.dump(corpus.dictionary.idx2word, fp)
        with open(word2idx_file, "wb") as fp:  #Pickling
            pickle.dump(corpus.dictionary.word2idx, fp)
        hidden = self.lstm.init_hidden(1)
        input = torch.randint(ntokens, (1, 1), dtype=torch.long)

        if cfg.CUDA:
            input.data = input.data.cuda()
        print("Finished Initializing LSTM Model")
        #Step 1: Get secret data
        secret_file = open("leakGAN_instructor/real_data/secret_file.txt", 'r')
        secret_data = secret_file.read()
        #Step 2: Compress string into binary string
        bit_string = ''.join(
            bin(ord(letter))[2:].zfill(8) for letter in secret_data)
        #print(bit_string)
        # NOTE: the bit string derived from the secret file is overridden by this fixed test string
        bit_string = '111011100101000111000011110111101111110111000110011010110110'
        # With 256 bins (an 8-bit representation each), 64 secret bits can be converted into 8 words
        #Step 3: Divide into bins
        secret_text = [
            int(i, 2) for i in self.string2bins(bit_string, bins_num)
        ]  #convert to bins
        #Step 4: Divide vocabulary into bins, zero words not in the bin
        if bins_num >= 2:
            tokens = list(range(ntokens))  # indices of words
            random.shuffle(tokens)  #randomize

            #Words in each bin
            words_in_bin = int(ntokens / bins_num)
            # leftover tokens should also be distributed among the bins
            leftover = int(ntokens % bins_num)
            bins = [
                tokens[i:i + words_in_bin]
                for i in range(0, ntokens - leftover, words_in_bin)
            ]  # words to keep in each bin
            for i in range(len(bins)):
                if (i == leftover):
                    break
                bins[i].append(tokens[i + words_in_bin * bins_num])
            print("Len of bins in 1st layer: {}".format(len(bins)))
            #save bins into key 1
            key1 = data_root + "lstm_key1.txt"
            with open(key1, "wb") as fp:  #Pickling
                pickle.dump(bins, fp)
            zero = [list(set(tokens) - set(bin_)) for bin_ in bins]

        print('Finished Initializing First LSTM Layer')
        print('time: {:5.2f}s'.format(time.time() - epoch_start_time))
        print('-' * 89)

        intermediate_file = data_root + intermediate_file
        with open(intermediate_file, 'w') as outf:
            w = 0
            i = 1
            temperature = 1.5
            bin_sequence_length = len(secret_text)
            print("bin sequence length", bin_sequence_length)
            while i <= bin_sequence_length:
                epoch_start_time = time.time()
                output, hidden = self.lstm(input, hidden)

                zero_index = zero[secret_text[:][i - 1]]
                zero_index = torch.LongTensor(zero_index)
                word_weights = output.squeeze().data.div(
                    temperature).exp().cpu()
                word_weights.index_fill_(0, zero_index, 0)
                word_idx = torch.multinomial(word_weights, 1)[0]

                input.data.fill_(word_idx)
                word = corpus.dictionary.idx2word[word_idx]
                i += 1
                w += 1
                word = word.encode('ascii', 'ignore').decode('ascii')
                outf.write(word + ' ')
        print("Generated intermediate short steganographic text")
        print("Intermediate text saved in following file: {}".format(
            intermediate_file))

    def LSTM_layer_2(self, secret_file, final_file, bins_num):
        print('Final LSTM Layer')
        # Final layer - LSTM layer
        data_root = './decode/'
        epoch_start_time = time.time()
        seed = 1111
        #Reproducibility
        torch.manual_seed(seed)
        if cfg.CUDA:
            torch.cuda.manual_seed(seed)
        with open("leakGAN_instructor/real_data/emnlp_news.pt", 'rb') as f:
            self.lstm = torch.load(f)
        if cfg.CUDA:
            self.lstm.cuda()
        emnlp_data = 'dataset/emnlp_news/'
        corpus = dataa.Corpus(emnlp_data)
        #save dictionary
        idx2word_file = data_root + "idx2word_2.txt"
        word2idx_file = data_root + "word2idx_2.txt"
        with open(idx2word_file, "wb") as fp:  #Pickling
            pickle.dump(corpus.dictionary.idx2word, fp)
        with open(word2idx_file, "wb") as fp:  #Pickling
            pickle.dump(corpus.dictionary.word2idx, fp)
        ntokens = len(corpus.dictionary)
        hidden = self.lstm.init_hidden(1)
        input = torch.randint(ntokens, (1, 1), dtype=torch.long)

        if cfg.CUDA:
            input.data = input.data.cuda()
        print("Finished Initializing LSTM Model")
        #Step 1: Get secret data
        secret_file = open(data_root + secret_file, 'r')
        secret_data = secret_file.read().split()
        #Step 2: Compress string into binary string
        bit_string = ''
        for data in secret_data:
            print("Data: {}".format(data))
            idWord = corpus.dictionary.word2idx[data]
            bit_string += '{0:{fill}13b}'.format(int(idWord), fill='0')
        #print(ntokens)
        print("Bit String: {}".format(bit_string))
        print("Length of Bit String: {}".format(len(bit_string)))
        #print(bit_string)
        #bit_string = '111011100101000111000011110111101111110111000110011010110110'
        #In the first step we will use 256 bins (8 bit representation each) to convert so that we can convert 64 bits into 8 word
        #Step 3: Divide into bins
        secret_text = [
            int(i, 2) for i in self.string2bins(bit_string, bins_num)
        ]  #convert to bins
        #Step 4: Divide vocabulary into bins, zero words not in the bin
        if bins_num >= 2:
            tokens = list(range(ntokens))  # indices of words
            random.shuffle(tokens)  #randomize

            #Words in each bin
            words_in_bin = int(ntokens / bins_num)
            # leftover tokens should also be distributed among the bins
            leftover = int(ntokens % bins_num)
            bins = [
                tokens[i:i + words_in_bin]
                for i in range(0, ntokens - leftover, words_in_bin)
            ]  # words to keep in each bin

            for i in range(0, leftover):
                bins[i].append(tokens[i + words_in_bin * bins_num])

            #save bins into key 1
            key1 = data_root + "lstm_key2.txt"
            with open(key1, "wb") as fp:  #Pickling
                pickle.dump(bins, fp)
            zero = [list(set(tokens) - set(bin_)) for bin_ in bins]

        print('Finished Initializing Second LSTM Layer')
        print('time: {:5.2f}s'.format(time.time() - epoch_start_time))
        print('-' * 89)

        final_file = data_root + final_file
        with open(final_file, 'w') as outf:
            w = 0
            i = 1
            temperature = 1.5
            bin_sequence_length = len(secret_text)
            print("bin sequence length", bin_sequence_length)
            while i <= bin_sequence_length:
                epoch_start_time = time.time()
                output, hidden = self.lstm(input, hidden)

                zero_index = zero[secret_text[:][i - 1]]
                zero_index = torch.LongTensor(zero_index)
                word_weights = output.squeeze().data.div(
                    temperature).exp().cpu()
                word_weights.index_fill_(0, zero_index, 0)
                word_idx = torch.multinomial(word_weights, 1)[0]

                input.data.fill_(word_idx)
                word = corpus.dictionary.idx2word[word_idx]
                i += 1
                w += 1
                word = word.encode('ascii', 'ignore').decode('ascii')
                outf.write(word + ' ')
        print("Generated final steganographic text")
        print("Final text saved in following file: {}".format(
            str(data_root + final_file)))

    def leakGAN_layer(self, secret_file, final_file, bins_num):
        #Second Layer = LeakGAN layer
        print('>>> Begin Second Layer...')
        data_root = './decode/'
        torch.nn.Module.dump_patches = True
        epoch_start_time = time.time()
        # Set the random seed manually for reproducibility.
        seed = 1111
        #Step 1: load the most accurate model
        with open("leakGAN_instructor/real_data/gen_ADV_00028.pt", 'rb') as f:
            self.gen.load_state_dict(torch.load(f))
        print("Finish Loading")
        self.gen.eval()

        #Step 1: Get Intermediate text
        secret_file = data_root + secret_file
        secret_file = open(secret_file, 'r')
        secret_data = secret_file.read().split()
        #Step 2: Compress string into binary string
        bit_string = ''
        #You need LSTM Corpus for that
        emnlp_data = 'dataset/emnlp_news/'
        corpus = dataa.Corpus(emnlp_data)
        for data in secret_data:
            print("Data: {}".format(data))
            idWord = corpus.dictionary.word2idx[data]
            bit_string += '{0:{fill}13b}'.format(int(idWord), fill='0')

        secret_text = [
            int(i, 2) for i in self.string2bins(bit_string, bins_num)
        ]  #convert to bins
        corpus_leak = self.index_word_dict
        if bins_num >= 2:
            ntokens = len(corpus_leak)
            tokens = list(range(ntokens))  # * args.replication_factor
            #print(ntokens)
            random.shuffle(tokens)
            #Words in each bin
            words_in_bin = int(ntokens / bins_num)
            # leftover tokens should also be distributed among the bins
            leftover = int(ntokens % bins_num)
            bins = [
                tokens[i:i + words_in_bin]
                for i in range(0, ntokens - leftover, words_in_bin)
            ]  # words to keep in each bin
            for i in range(0, leftover):
                bins[i].append(tokens[i + words_in_bin * bins_num])
            #save bins into leakGAN key
            key2 = data_root + 'leakGAN_key.txt'
            with open(key2, "wb") as fp:  #Pickling
                pickle.dump(bins, fp)
            zero = [list(set(tokens) - set(bin_)) for bin_ in bins]
        print('Finished Initializing Second LeakGAN Layer')
        print('time: {:5.2f}s'.format(time.time() - epoch_start_time))
        print('-' * 89)
        out_file = data_root + final_file
        w = 0
        i = 1
        bin_sequence_length = len(secret_text[:])
        print("bin sequence length", bin_sequence_length)
        batch_size = cfg.batch_size
        seq_len = cfg.max_seq_len

        feature_array = torch.zeros(
            (batch_size, seq_len + 1, self.gen.goal_out_size))
        goal_array = torch.zeros(
            (batch_size, seq_len + 1, self.gen.goal_out_size))
        leak_out_array = torch.zeros((batch_size, seq_len + 1, cfg.vocab_size))
        samples = torch.zeros(batch_size, seq_len + 1).long()
        work_hidden = self.gen.init_hidden(batch_size)
        mana_hidden = self.gen.init_hidden(batch_size)
        leak_inp = torch.LongTensor([cfg.start_letter] * batch_size)
        real_goal = self.gen.goal_init[:batch_size, :]

        if cfg.CUDA:
            feature_array = feature_array.cuda()
            goal_array = goal_array.cuda()
            leak_out_array = leak_out_array.cuda()

        goal_array[:, 0, :] = real_goal  # g0 = goal_init
        if_sample = True
        no_log = False
        index = cfg.start_letter
        while i <= seq_len:

            dis_inp = torch.zeros(batch_size, bin_sequence_length).long()
            if i > 1:
                dis_inp[:, :i - 1] = samples[:, :i - 1]  # cut sentences
                leak_inp = samples[:, i - 2]

            if torch.cuda.is_available():
                dis_inp = dis_inp.cuda()
                leak_inp = leak_inp.cuda()
            feature = self.dis.get_feature(dis_inp).unsqueeze(0)
            #print(feature)
            feature_array[:, i - 1, :] = feature.squeeze(0)
            out, cur_goal, work_hidden, mana_hidden = self.gen(index,
                                                               leak_inp,
                                                               work_hidden,
                                                               mana_hidden,
                                                               feature,
                                                               real_goal,
                                                               no_log=no_log,
                                                               train=False)
            leak_out_array[:, i - 1, :] = out

            goal_array[:, i, :] = cur_goal.squeeze(1)
            if i > 0 and i % self.gen.step_size == 0:
                real_goal = torch.sum(goal_array[:, i - 3:i + 1, :], dim=1)
                if i / self.gen.step_size == 1:
                    real_goal += self.gen.goal_init[:batch_size, :]
            # Sample one token
            if not no_log:
                out = torch.exp(out)
            # indices that have to be zeroed, as they are not in the current bin
            zero_index = zero[secret_text[i - 1]]
            #zero_index.append(0)
            zero_index = torch.LongTensor(zero_index)
            if cfg.CUDA:
                zero_index = zero_index.cuda()
            temperature = 1.5
            word_weights = out
            # zero out the weights of all indices that are not in the current bin
            word_weights = word_weights.index_fill_(1, zero_index, 0)
            # sample one word per sequence from the remaining in-bin distribution
            word_weights = torch.multinomial(word_weights, 1).view(-1)
            #print("Out after: {}".format(word_weights))
            samples[:, i] = word_weights
            leak_inp = word_weights
            i += 1
            w += 1
        leak_out_array = leak_out_array[:, :seq_len, :]
        tokens = []
        write_tokens(out_file, tensor_to_tokens(samples, self.index_word_dict))
        print("Generated final steganographic text")
        print("Final steganographic text saved in following file: {}".format(
            out_file))

    def _test_2_layers(self):
        self.LSTM_layer_1("intermediate.txt", 4096)
        if cfg.leakGAN:
            self.leakGAN_layer("intermediate.txt", "final_leakgan.txt", 4)
        else:
            self.LSTM_layer_2("intermediate.txt", "final_lstm.txt", 4)

    def _test(self):
        print('>>> Begin test...')

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pretraining for the gen

        - gen_opt: [mana_opt, work_opt]
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                pre_mana_loss = 0
                pre_work_loss = 0

                # =====Train=====
                for i, data in enumerate(self.oracle_data.loader):
                    inp, target = data['input'], data['target']
                    if cfg.CUDA:
                        inp, target = inp.cuda(), target.cuda()

                    mana_loss, work_loss = self.gen.pretrain_loss(
                        target, self.dis)
                    self.optimize_multi(self.gen_opt, [mana_loss, work_loss])
                    pre_mana_loss += mana_loss.data.item()
                    pre_work_loss += work_loss.data.item()
                pre_mana_loss = pre_mana_loss / len(self.oracle_data.loader)
                pre_work_loss = pre_work_loss / len(self.oracle_data.loader)

                # =====Test=====
                if epoch % cfg.pre_log_step == 0:
                    self.log.info(
                        '[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s'
                        % (epoch, pre_mana_loss, pre_work_loss,
                           self.cal_metrics(fmt_str=True)))

                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info(
                    '>>> Stop by pre signal, skip to adversarial training...')
                break

    def adv_train_generator(self, g_step, current_k=0):
        """
        The generator is trained with policy gradients, using rewards from the discriminator.
        Training is done for g_step batches.
        """

        rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA)
        adv_mana_loss = 0
        adv_work_loss = 0
        for step in range(g_step):
            with torch.no_grad():
                gen_samples = self.gen.sample(
                    cfg.batch_size, cfg.batch_size, self.dis,
                    train=True)  # !!! train=True, the only place
                inp, target = self.gen_data.prepare(gen_samples, gpu=cfg.CUDA)

            # =====Train=====
            rewards = rollout_func.get_reward_leakgan(
                target, cfg.rollout_num, self.dis,
                current_k).cpu()  # reward with MC search
            mana_loss, work_loss = self.gen.adversarial_loss(
                target, rewards, self.dis)

            # update parameters
            self.optimize_multi(self.gen_opt, [mana_loss, work_loss])
            adv_mana_loss += mana_loss.data.item()
            adv_work_loss += work_loss.data.item()
        # =====Test=====
        self.log.info(
            '[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s' %
            (adv_mana_loss / g_step, adv_work_loss / g_step,
             self.cal_metrics(fmt_str=True)))

    def train_discriminator(self, d_step, d_epoch, phrase='MLE'):
        """
        Train the discriminator on real data samples (positive) and samples generated by the generator (negative).
        Samples are drawn d_step times, and the discriminator is trained for d_epoch epochs each time.
        """
        for step in range(d_step):
            # prepare loader for training
            pos_samples = self.oracle_data.target
            neg_samples = self.gen.sample(cfg.samples_num, cfg.batch_size,
                                          self.dis)
            self.dis_data.reset(pos_samples, neg_samples)

            for epoch in range(d_epoch):
                # =====Train=====
                d_loss, train_acc = self.train_dis_epoch(
                    self.dis, self.dis_data.loader, self.dis_criterion,
                    self.dis_opt)

            # =====Test=====
            self.log.info(
                '[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' %
                (phrase, step, d_loss, train_acc))

    def cal_metrics(self, fmt_str=False):
        self.gen_data.reset(
            self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis))
        self.bleu3.test_text = tensor_to_tokens(self.gen_data.target,
                                                self.index_word_dict)
        bleu3_score = self.bleu3.get_score(ignore=False)

        with torch.no_grad():
            gen_nll = 0
            for data in self.oracle_data.loader:
                inp, target = data['input'], data['target']
                if cfg.CUDA:
                    inp, target = inp.cuda(), target.cuda()
                loss = self.gen.batchNLLLoss(target, self.dis)
                gen_nll += loss.item()
            gen_nll /= len(self.oracle_data.loader)

        if fmt_str:
            return 'BLEU-3 = %.4f, gen_NLL = %.4f,' % (bleu3_score, gen_nll)
        return bleu3_score, gen_nll

    def _save(self, phrase, epoch):
        torch.save(
            self.gen.state_dict(),
            cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phrase, epoch))
        save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(
            phrase, epoch)
        samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis)
        write_tokens(save_sample_path,
                     tensor_to_tokens(samples, self.index_word_dict))
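The two LSTM layers and the LeakGAN layer above only implement the encoding direction; the pickled bin keys and word2idx dictionaries they write out are what a receiver would need to invert the mapping. A minimal decode-side sketch under that assumption (decode_sketch and its arguments are illustrative and not part of the repository):

import pickle

def decode_sketch(words, key_path, word2idx_path, n_bits):
    # Load the bin key and vocabulary mapping that the encoding layer pickled.
    with open(key_path, 'rb') as fp:
        bins = pickle.load(fp)
    with open(word2idx_path, 'rb') as fp:
        word2idx = pickle.load(fp)
    # Invert the bins: token index -> bin index.
    token_to_bin = {tok: b for b, bin_ in enumerate(bins) for tok in bin_}
    # Each received word contributes n_bits bits (its bin index) to the secret bit string.
    return ''.join(bin(token_to_bin[word2idx[w]])[2:].zfill(n_bits) for w in words)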
Example #4
class LeakGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(LeakGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = LeakGAN_G(cfg.gen_embed_dim, cfg.gen_hidden_dim,
                             cfg.vocab_size, cfg.max_seq_len, cfg.padding_idx,
                             cfg.goal_size, cfg.step_size, cfg.CUDA)
        self.dis = LeakGAN_D(cfg.dis_embed_dim,
                             cfg.vocab_size,
                             cfg.padding_idx,
                             gpu=cfg.CUDA)
        self.init_model()

        # optimizer
        mana_params, work_params = self.gen.split_params()
        mana_opt = optim.Adam(mana_params, lr=cfg.gen_lr)
        work_opt = optim.Adam(work_params, lr=cfg.gen_lr)

        self.gen_opt = [mana_opt, work_opt]
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        # Criterion
        self.mle_criterion = nn.NLLLoss()
        self.dis_criterion = nn.CrossEntropyLoss()

        # DataLoader
        self.gen_data = GenDataIter(
            self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis))
        self.dis_data = DisDataIter(self.gen_data.random_batch()['target'],
                                    self.oracle_data.random_batch()['target'])

        # Metrics
        self.bleu3 = BLEU(test_text=tensor_to_tokens(self.gen_data.target,
                                                     self.index_word_dict),
                          real_text=tensor_to_tokens(self.test_data.target,
                                                     self.index_word_dict),
                          gram=3)

    def _run(self):
        for inter_num in range(cfg.inter_epoch):
            self.log.info('>>> Interleaved Round %d...' % inter_num)
            self.sig.update()  # update signal
            if self.sig.pre_sig:
                # =====DISCRIMINATOR PRE-TRAINING=====
                if not cfg.dis_pretrain:
                    self.log.info('Starting Discriminator Training...')
                    self.train_discriminator(cfg.d_step, cfg.d_epoch)
                    if cfg.if_save and not cfg.if_test:
                        torch.save(self.dis.state_dict(),
                                   cfg.pretrained_dis_path)
                        print('Save pre-trained discriminator: {}'.format(
                            cfg.pretrained_dis_path))

                # =====GENERATOR MLE TRAINING=====
                if not cfg.gen_pretrain:
                    self.log.info('Starting Generator MLE Training...')
                    self.pretrain_generator(cfg.MLE_train_epoch)
                    if cfg.if_save and not cfg.if_test:
                        torch.save(self.gen.state_dict(),
                                   cfg.pretrained_gen_path)
                        print('Save pre-trained generator: {}'.format(
                            cfg.pretrained_gen_path))
            else:
                self.log.info(
                    '>>> Stop by pre_signal! Skip to adversarial training...')
                break

        # =====ADVERSARIAL TRAINING=====
        self.log.info('Starting Adversarial Training...')
        self.log.info('Initial generator: %s' %
                      (str(self.cal_metrics(fmt_str=True))))

        for adv_epoch in range(cfg.ADV_train_epoch):
            self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch)
            self.sig.update()
            if self.sig.adv_sig:
                self.adv_train_generator(cfg.ADV_g_step)  # Generator
                self.train_discriminator(cfg.ADV_d_step, cfg.ADV_d_epoch,
                                         'ADV')  # Discriminator

                if adv_epoch % cfg.adv_log_step == 0:
                    if cfg.if_save and not cfg.if_test:
                        self._save('ADV', adv_epoch)
            else:
                self.log.info(
                    '>>> Stop by adv_signal! Finishing adversarial training...'
                )
                break

    def _test(self):
        print('>>> Begin test...')
        self._run()
        pass

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pretraining for the gen

        - gen_opt: [mana_opt, work_opt]
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                pre_mana_loss = 0
                pre_work_loss = 0

                # =====Train=====
                for i, data in enumerate(self.oracle_data.loader):
                    inp, target = data['input'], data['target']
                    if cfg.CUDA:
                        inp, target = inp.cuda(), target.cuda()

                    mana_loss, work_loss = self.gen.pretrain_loss(
                        target, self.dis)
                    self.optimize_multi(self.gen_opt, [mana_loss, work_loss])
                    pre_mana_loss += mana_loss.data.item()
                    pre_work_loss += work_loss.data.item()
                pre_mana_loss = pre_mana_loss / len(self.oracle_data.loader)
                pre_work_loss = pre_work_loss / len(self.oracle_data.loader)

                # =====Test=====
                if epoch % cfg.pre_log_step == 0:
                    self.log.info(
                        '[MLE-GEN] epoch %d : pre_mana_loss = %.4f, pre_work_loss = %.4f, %s'
                        % (epoch, pre_mana_loss, pre_work_loss,
                           self.cal_metrics(fmt_str=True)))

                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info(
                    '>>> Stop by pre signal, skip to adversarial training...')
                break

    def adv_train_generator(self, g_step, current_k=0):
        """
        The generator is trained with policy gradients, using rewards from the discriminator.
        Training is done for g_step batches.
        """

        rollout_func = rollout.ROLLOUT(self.gen, cfg.CUDA)
        adv_mana_loss = 0
        adv_work_loss = 0
        for step in range(g_step):
            with torch.no_grad():
                gen_samples = self.gen.sample(
                    cfg.batch_size, cfg.batch_size, self.dis,
                    train=True)  # !!! train=True, the only place
                inp, target = self.gen_data.prepare(gen_samples, gpu=cfg.CUDA)

            # =====Train=====
            rewards = rollout_func.get_reward_leakgan(
                target, cfg.rollout_num, self.dis,
                current_k).cpu()  # reward with MC search
            mana_loss, work_loss = self.gen.adversarial_loss(
                target, rewards, self.dis)

            # update parameters
            self.optimize_multi(self.gen_opt, [mana_loss, work_loss])
            adv_mana_loss += mana_loss.data.item()
            adv_work_loss += work_loss.data.item()
        # =====Test=====
        self.log.info(
            '[ADV-GEN] adv_mana_loss = %.4f, adv_work_loss = %.4f, %s' %
            (adv_mana_loss / g_step, adv_work_loss / g_step,
             self.cal_metrics(fmt_str=True)))

    def train_discriminator(self, d_step, d_epoch, phrase='MLE'):
        """
        Train the discriminator on real data samples (positive) and samples generated by the generator (negative).
        Samples are drawn d_step times, and the discriminator is trained for d_epoch epochs each time.
        """
        for step in range(d_step):
            # prepare loader for training
            pos_samples = self.oracle_data.target
            neg_samples = self.gen.sample(cfg.samples_num, cfg.batch_size,
                                          self.dis)
            self.dis_data.reset(pos_samples, neg_samples)

            for epoch in range(d_epoch):
                # =====Train=====
                d_loss, train_acc = self.train_dis_epoch(
                    self.dis, self.dis_data.loader, self.dis_criterion,
                    self.dis_opt)

            # =====Test=====
            self.log.info(
                '[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' %
                (phrase, step, d_loss, train_acc))

    def cal_metrics(self, fmt_str=False):
        self.gen_data.reset(
            self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis))
        self.bleu3.test_text = tensor_to_tokens(self.gen_data.target,
                                                self.index_word_dict)
        bleu3_score = self.bleu3.get_score(ignore=False)

        with torch.no_grad():
            gen_nll = 0
            for data in self.oracle_data.loader:
                inp, target = data['input'], data['target']
                if cfg.CUDA:
                    inp, target = inp.cuda(), target.cuda()
                loss = self.gen.batchNLLLoss(target, self.dis)
                gen_nll += loss.item()
            gen_nll /= len(self.oracle_data.loader)

        if fmt_str:
            return 'BLEU-3 = %.4f, gen_NLL = %.4f,' % (bleu3_score, gen_nll)
        return bleu3_score, gen_nll

    def _save(self, phrase, epoch):
        torch.save(
            self.gen.state_dict(),
            cfg.save_model_root + 'gen_{}_{:05d}.pt'.format(phrase, epoch))
        save_sample_path = cfg.save_samples_root + 'samples_{}_{:05d}.txt'.format(
            phrase, epoch)
        samples = self.gen.sample(cfg.batch_size, cfg.batch_size, self.dis)
        write_tokens(save_sample_path,
                     tensor_to_tokens(samples, self.index_word_dict))
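Both LeakGAN examples step the manager and worker losses through optimize_multi with a pair of Adam optimizers. A hedged sketch of what such a helper plausibly does (the actual BasicInstructor.optimize_multi may differ, e.g. by also clipping gradients):

def optimize_multi_sketch(opts, losses):
    # Step each (optimizer, loss) pair in turn; retain the graph for all but the
    # last backward pass, since the losses can share parts of the computation graph.
    for i, (opt, loss) in enumerate(zip(opts, losses)):
        opt.zero_grad()
        loss.backward(retain_graph=(i < len(opts) - 1))
        opt.step()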
Example #5
class GumbelGANInstructor(BasicInstructor):
    def __init__(self, opt):
        super(GumbelGANInstructor, self).__init__(opt)

        # generator, discriminator
        self.gen = GumbelGAN_G(cfg.gen_embed_dim,
                               cfg.gen_hidden_dim,
                               cfg.vocab_size,
                               cfg.max_seq_len,
                               cfg.padding_idx,
                               gpu=cfg.CUDA)
        self.dis = GumbelGAN_D2(cfg.dis_embed_dim,
                                cfg.max_seq_len,
                                cfg.vocab_size,
                                cfg.padding_idx,
                                gpu=cfg.CUDA)
        self.init_model()

        # Optimizer
        self.gen_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_lr)
        self.gen_adv_opt = optim.Adam(self.gen.parameters(), lr=cfg.gen_adv_lr)
        self.dis_opt = optim.Adam(self.dis.parameters(), lr=cfg.dis_lr)

        # Criterion
        self.mle_criterion = nn.NLLLoss()
        self.dis_pretrain_criterion = nn.BCEWithLogitsLoss()

        # DataLoader
        self.gen_data = GenDataIter(
            self.gen.sample(cfg.batch_size, cfg.batch_size))
        self.dis_data = DisDataIter(self.train_data.random_batch()['target'],
                                    self.gen_data.random_batch()['target'])

        # Metrics
        bleu_gram = list(range(2, cfg.max_seq_len +
                               1)) if cfg.max_seq_len < 5 else [2, 3, 4, 5]
        self.bleu = BLEU(test_text=tensor_to_tokens(self.gen_data.target,
                                                    self.index_word_dict),
                         real_text=tensor_to_tokens(
                             self.test_data.target,
                             self.test_data.index_word_dict),
                         gram=bleu_gram)
        self.self_bleu = BLEU(
            test_text=tensor_to_tokens(self.gen_data.target,
                                       self.index_word_dict),
            real_text=tensor_to_tokens(self.gen_data.target,
                                       self.index_word_dict),
            gram=3)

    def _run(self):
        # =====PRE-TRAINING (GENERATOR)=====
        if not cfg.gen_pretrain:
            self.log.info('Starting Generator MLE Training...')
            self.pretrain_generator(cfg.MLE_train_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.gen.state_dict(), cfg.pretrained_gen_path)
                print('Save pre-trained generator: {}'.format(
                    cfg.pretrained_gen_path))

        # =====PRE-TRAINING (DISCRIMINATOR)=====
        if not cfg.dis_pretrain:
            self.log.info('Starting Discriminator Training...')
            self.pretrain_discriminator(cfg.d_step, cfg.d_epoch)
            if cfg.if_save and not cfg.if_test:
                torch.save(self.dis.state_dict(), cfg.pretrained_dis_path)
                print('Save pre-trained discriminator: {}'.format(
                    cfg.pretrained_dis_path))

        # =====ADVERSARIAL TRAINING=====
        self.log.info('Starting Adversarial Training...')

        for adv_epoch in range(cfg.ADV_train_epoch):
            if adv_epoch % cfg.adv_log_step == 0:
                self.log.info('-----\nADV EPOCH %d\n-----' % adv_epoch)
            self.sig.update()
            if self.sig.adv_sig:
                self.adv_train_generator(cfg.ADV_g_step,
                                         adv_epoch)  # Generator
                self.adv_train_discriminator(cfg.ADV_d_step,
                                             adv_epoch)  # Discriminator
                self.update_temperature(
                    adv_epoch, cfg.ADV_train_epoch)  # update temperature

                if adv_epoch % cfg.adv_log_step == 0:
                    if cfg.if_save and not cfg.if_test:
                        self._save('ADV', adv_epoch)
            else:
                self.log.info(
                    '>>> Stop by adv_signal! Finishing adversarial training...'
                )
                break

    def _test(self):
        print('>>> Begin test...')

        self._run()
        pass

    def pretrain_generator(self, epochs):
        """
        Max Likelihood Pre-training for the generator
        """
        for epoch in range(epochs):
            self.sig.update()
            if self.sig.pre_sig:
                # =====Train=====
                pre_loss = self.train_gen_epoch(self.gen,
                                                self.train_data.loader,
                                                self.mle_criterion,
                                                self.gen_opt)

                # =====Test=====
                if epoch % cfg.pre_log_step == 0 or epoch == epochs - 1:
                    self.log.info('[MLE-GEN] epoch %d : pre_loss = %.4f' %
                                  (epoch, pre_loss))
                    if cfg.if_save and not cfg.if_test:
                        self._save('MLE', epoch)
            else:
                self.log.info(
                    '>>> Stop by pre signal, skip to adversarial training...')
                break
        if cfg.if_save and not cfg.if_test:
            self._save('MLE', epoch)

    def adv_train_generator(self, g_step, adv_epoch):
        # true_ge = TrueGradientEstimator()  TODO

        total_loss = 0
        for step in range(g_step):
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size,
                                          cfg.batch_size,
                                          one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(
                ), gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # =====Train=====
            # vanilla_theta = self.gen.sample_vanilla_theta()
            # true_ge = true_ge.estimate_gradient(vanilla_theta...)  TODO

            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            g_loss, _ = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            theta_gradient = self.optimize(
                self.gen_adv_opt,
                g_loss,
                self.gen,
                theta_gradient_fetcher=self.gen.get_theta_gradient)
            theta_gradient_log_var = get_gradient_variance(theta_gradient)
            total_loss += g_loss.item()

        # =====Test=====
        avg_loss = total_loss / g_step if g_step != 0 else 0
        if adv_epoch % cfg.adv_log_step == 0:
            self.log.info(
                '[ADV-GEN] g_loss = %.4f, temperature = %.4f, theta_gradient_log_var = %.4f'
                % (avg_loss, self.gen.temperature, theta_gradient_log_var))

    def adv_train_discriminator(self, d_step, adv_epoch):
        total_loss = 0
        total_acc = 0
        for step in range(d_step):
            # TODO(ethanjiang) we may want to train a full epoch instead of a random batch
            real_samples = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size,
                                          cfg.batch_size,
                                          one_hot=True)
            if cfg.CUDA:
                real_samples, gen_samples = real_samples.cuda(
                ), gen_samples.cuda()
            real_samples = F.one_hot(real_samples, cfg.vocab_size).float()

            # =====Train=====
            d_out_real = self.dis(real_samples)
            d_out_fake = self.dis(gen_samples)
            _, d_loss = get_losses(d_out_real, d_out_fake, cfg.loss_type)

            self.optimize(self.dis_opt, d_loss, self.dis)

            total_loss += d_loss.item()
            predictions = torch.cat((d_out_real, d_out_fake))
            labels = torch.cat(
                (torch.ones_like(d_out_real), torch.zeros_like(d_out_fake)))
            total_acc += torch.sum(
                ((predictions > 0).float() == labels)).item()

        # =====Test=====
        avg_loss = total_loss / d_step if d_step != 0 else 0
        avg_acc = total_acc / (d_step * cfg.batch_size *
                               2) if d_step != 0 else 0
        if adv_epoch % cfg.adv_log_step == 0:
            self.log.info('[ADV-DIS] d_loss = %.4f, train_acc = %.4f,' %
                          (avg_loss, avg_acc))

    def update_temperature(self, i, N):
        self.gen.temperature = get_fixed_temperature(cfg.temperature, i, N,
                                                     cfg.temp_adpt)
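    # Illustrative annealing schedules that get_fixed_temperature might implement
    # (assumptions, not the project's exact code), ramping from 1 towards
    # cfg.temperature over the N adversarial epochs:
    #   'lin': 1 + i / (N - 1) * (temperature - 1)
    #   'exp': temperature ** (i / N)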

    @staticmethod
    def optimize(opt,
                 loss,
                 model=None,
                 retain_graph=False,
                 theta_gradient_fetcher=None):
        opt.zero_grad()
        loss.backward(retain_graph=retain_graph)
        theta_gradient = None if theta_gradient_fetcher is None else theta_gradient_fetcher()
        if model is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_norm)
        opt.step()
        return theta_gradient

    def pretrain_discriminator(self, d_step, d_epoch, phrase='MLE'):
        """
        Train the discriminator on real data samples (positive) and samples generated by the generator (negative).
        Samples are drawn d_step times, and the discriminator is trained for d_epoch epochs each time.
        """
        # prepare loader for validate
        for step in range(d_step):
            # prepare loader for training
            pos_samples = self.train_data.target
            neg_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size)
            self.dis_data.reset(pos_samples, neg_samples)

            for epoch in range(d_epoch):
                # =====Train=====
                d_loss, train_acc = self.train_dis_epoch(
                    self.dis,
                    self.dis_data.loader,
                    self.dis_pretrain_criterion,
                    self.dis_opt,
                    one_hot=True)

            # =====Test=====
            self.log.info(
                '[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' %
                (phrase, step, d_loss, train_acc))
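For reference, a hedged sketch of the two loss_type variants the adversarial steps above route through get_losses (a 'standard' non-saturating GAN loss and the relativistic 'rsgan' loss); the project's get_losses covers more variants and may differ in details:

import torch
import torch.nn.functional as F

def get_losses_sketch(d_out_real, d_out_fake, loss_type='standard'):
    ones = torch.ones_like(d_out_real)
    zeros = torch.zeros_like(d_out_fake)
    if loss_type == 'rsgan':  # relativistic standard GAN
        g_loss = F.binary_cross_entropy_with_logits(d_out_fake - d_out_real, ones)
        d_loss = F.binary_cross_entropy_with_logits(d_out_real - d_out_fake, ones)
    else:  # standard non-saturating GAN
        d_loss = (F.binary_cross_entropy_with_logits(d_out_real, ones) +
                  F.binary_cross_entropy_with_logits(d_out_fake, zeros))
        g_loss = F.binary_cross_entropy_with_logits(d_out_fake, ones)
    return g_loss, d_loss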