示例#1
0
def create_many_oracle(from_a, to_b, num=1, save_path='../pretrain/'):
    """Draw random oracles until one scores inside ``[from_a, to_b]``, then save it.

    In each of the ``num`` rounds a new ``Oracle`` is built from the ``cfg``
    settings, ``cfg.samples_num`` and ``cfg.samples_num // 2`` samples are
    drawn from it, and ``NLL.cal_nll`` is evaluated for the oracle on the
    larger sample set. An oracle whose value lies outside the inclusive range
    is discarded and the round retries with a new one.

    Note: the output file names do not depend on the round, so with
    ``num > 1`` every accepted oracle overwrites the files of the previous one.

    Args:
        from_a: Inclusive lower bound on the accepted value.
        to_b: Inclusive upper bound on the accepted value.
        num: Number of oracles to accept.
        save_path: String prepended verbatim to each output file name, so it
            must carry its own trailing path separator.
    """
    file_stem = 'oracle_lstm'
    sample_batch = 8 * cfg.batch_size
    half_num = cfg.samples_num // 2

    for _ in range(num):
        while True:
            candidate = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim,
                               cfg.vocab_size, cfg.max_seq_len,
                               cfg.padding_idx, gpu=cfg.CUDA)
            if cfg.CUDA:
                candidate = candidate.cuda()

            # Both sample sets are drawn before scoring.
            full_set = candidate.sample(cfg.samples_num, sample_batch)
            half_set = candidate.sample(half_num, sample_batch)

            full_iter = GenDataIter(full_set)
            criterion = nn.NLLLoss()
            nll = NLL.cal_nll(candidate, full_iter.loader, criterion)

            # Outside the requested range: throw this oracle away and retry.
            if not (from_a <= nll <= to_b):
                continue

            print('save ground truth: ', nll)
            outputs = (
                (candidate.state_dict(), '{}.pt'.format(file_stem)),
                (full_set,
                 '{}_samples_{}.pt'.format(file_stem, cfg.samples_num)),
                (half_set, '{}_samples_{}.pt'.format(file_stem, half_num)),
            )
            for payload, file_name in outputs:
                torch.save(payload, save_path + file_name)
            break
示例#2
0
def create_multi_oracle(number):
    """Create ``number`` oracles and save each one's weights and sample sets.

    For every index ``i`` in ``range(number)`` a new ``Oracle`` is built from
    the ``cfg`` settings and two sample sets are drawn from it
    (``cfg.samples_num`` and ``cfg.samples_num // 2`` samples). The state dict
    is saved to ``cfg.multi_oracle_state_dict_path.format(i)`` and each sample
    set to ``cfg.multi_oracle_samples_path.format(i, <sample count>)``.
    Finally the value returned by ``NLL.cal_nll`` for the oracle on its larger
    sample set is printed.

    Args:
        number: How many oracles to create.
    """
    sample_batch = 4 * cfg.batch_size
    half_num = cfg.samples_num // 2

    for idx in range(number):
        print('Creating Oracle %d...' % idx)
        model = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                       cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
        if cfg.CUDA:
            model = model.cuda()

        full_set = model.sample(cfg.samples_num, sample_batch)
        half_set = model.sample(half_num, sample_batch)

        # Persist the weights first, then the two sample sets.
        torch.save(model.state_dict(),
                   cfg.multi_oracle_state_dict_path.format(idx))
        for samples, count in ((full_set, cfg.samples_num),
                               (half_set, half_num)):
            torch.save(samples,
                       cfg.multi_oracle_samples_path.format(idx, count))

        # Score the oracle on its own larger sample set.
        full_iter = GenDataIter(full_set)
        criterion = nn.NLLLoss()
        nll = NLL.cal_nll(model, full_iter.loader, criterion)
        print('Oracle %d Groud Truth: %.4f' % (idx, nll))
示例#3
0
def create_oracle():
    """Build a new Oracle, save its weights and two sample sets, and print its NLL.

    The state dict is written to ``cfg.oracle_state_dict_path``. Sample sets
    of ``cfg.samples_num`` and ``cfg.samples_num // 2`` sequences are written
    to ``cfg.oracle_samples_path`` formatted with the respective sample count.
    The value returned by ``NLL.cal_nll`` for the oracle on the larger set is
    printed at the end.
    """
    import config as cfg
    from models.Oracle import Oracle

    print('Creating Oracle...')
    model = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                   cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
    if cfg.CUDA:
        model = model.cuda()

    # Weights are saved before any sampling happens.
    torch.save(model.state_dict(), cfg.oracle_state_dict_path)

    sample_batch = 4 * cfg.batch_size
    half_num = cfg.samples_num // 2

    # Large sample set.
    full_set = model.sample(cfg.samples_num, sample_batch)
    torch.save(full_set, cfg.oracle_samples_path.format(cfg.samples_num))

    # Small sample set (half the size), drawn after the large one is saved.
    half_set = model.sample(half_num, sample_batch)
    torch.save(half_set, cfg.oracle_samples_path.format(half_num))

    full_iter = GenDataIter(full_set)
    criterion = nn.NLLLoss()
    nll = NLL.cal_nll(model, full_iter.loader, criterion)
    print('NLL_Oracle Groud Truth: %.4f' % nll)
示例#4
0
    def evaluation(self, eval_type):
        """Score the current generator and return ``(Fq, Fd, score)``.

        New samples are drawn from ``self.gen`` and used for both terms:

        * ``Fd`` is ``NLL.cal_nll`` of ``self.gen`` on those samples, or 0
          when ``cfg.lambda_fd`` is 0.
        * ``Fq`` depends on ``eval_type``:
            - ``'standard'``: mean of ``self.eval_d_out_fake``.
            - ``'rsgan'``: the second value returned by
              ``get_losses(real, fake, 'rsgan')``.
            - ``'nll'``: negated ``NLL.cal_nll`` of ``self.oracle`` on the
              samples, or 0 when ``cfg.lambda_fq`` is 0.
            - ``'Ra'``: sum of ``sigmoid(fake - mean(real))``.

        Args:
            eval_type: One of ``'standard'``, ``'rsgan'``, ``'nll'``, ``'Ra'``.

        Returns:
            Tuple ``(Fq, Fd, cfg.lambda_fq * Fq + cfg.lambda_fd * Fd)``.

        Raises:
            NotImplementedError: If ``eval_type`` is not one of the above.
        """
        samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size,
                                  cfg.max_bn * cfg.batch_size)
        sample_iter = GenDataIter(samples)

        # Fd (NLL_div); skipped entirely when its weight is zero.
        diversity = (NLL.cal_nll(self.gen, sample_iter.loader,
                                 self.mle_criterion)
                     if cfg.lambda_fd != 0 else 0)

        # Fq
        if eval_type == 'standard':
            quality = self.eval_d_out_fake.mean().cpu().item()
        elif eval_type == 'rsgan':
            _, disc_loss = get_losses(self.eval_d_out_real,
                                      self.eval_d_out_fake, 'rsgan')
            quality = disc_loss.item()
        elif eval_type == 'nll':
            # NLL_Oracle, negated; skipped when its weight is zero.
            quality = (0 if cfg.lambda_fq == 0 else
                       -NLL.cal_nll(self.oracle, sample_iter.loader,
                                    self.mle_criterion))
        elif eval_type == 'Ra':
            relative_out = self.eval_d_out_fake - torch.mean(
                self.eval_d_out_real)
            quality = torch.sigmoid(relative_out).sum().item()
        else:
            raise NotImplementedError("Evaluation '%s' is not implemented" %
                                      eval_type)

        weighted = cfg.lambda_fq * quality + cfg.lambda_fd * diversity
        return quality, diversity, weighted
示例#5
0
    def evaluation(self, eval_type):
        """Score the current generator and return ``(Fq, Fd, score)``.

        New samples are drawn from ``self.gen`` and used for both terms:

        * ``Fd`` is ``NLL.cal_nll`` of ``self.gen`` on those samples, or 0
          when ``cfg.lambda_fd`` is 0.
        * ``Fq`` depends on ``eval_type`` (checked in this order):
            - ``'standard'``: mean of ``self.eval_d_out_fake``.
            - ``'rsgan'``: the second value returned by
              ``get_losses(real, fake, 'rsgan')``.
            - contains ``'bleu'``: ``self.bleu`` is reset with the sampled
              text, then ``Fq`` is ``self.bleu.get_score`` with the gram size
              taken from the last character of ``eval_type``, or 0 when
              ``cfg.lambda_fq`` is 0.
            - contains ``'Ra'``: sum of ``sigmoid(fake - mean(real))``.

        Args:
            eval_type: Name of the quality metric, as described above.

        Returns:
            Tuple ``(Fq, Fd, cfg.lambda_fq * Fq + cfg.lambda_fd * Fd)``.

        Raises:
            NotImplementedError: If ``eval_type`` matches none of the above.
        """
        samples = self.gen.sample(cfg.eval_b_num * cfg.batch_size,
                                  cfg.max_bn * cfg.batch_size)
        sample_iter = GenDataIter(samples)

        # Fd (NLL_div); skipped entirely when its weight is zero.
        diversity = (NLL.cal_nll(self.gen, sample_iter.loader,
                                 self.mle_criterion)
                     if cfg.lambda_fd != 0 else 0)

        # Fq
        if eval_type == 'standard':
            quality = self.eval_d_out_fake.mean().cpu().item()
        elif eval_type == 'rsgan':
            _, disc_loss = get_losses(self.eval_d_out_real,
                                      self.eval_d_out_fake, 'rsgan')
            quality = disc_loss.item()
        elif 'bleu' in eval_type:
            # The BLEU metric is reset with the new text even when the score
            # itself is skipped below.
            sampled_text = tensor_to_tokens(samples, self.idx2word_dict)
            self.bleu.reset(test_text=sampled_text)
            if cfg.lambda_fq == 0:
                quality = 0
            else:
                gram = int(eval_type[-1])
                quality = self.bleu.get_score(given_gram=gram)
        elif 'Ra' in eval_type:
            relative_out = self.eval_d_out_fake - torch.mean(
                self.eval_d_out_real)
            quality = torch.sigmoid(relative_out).sum().item()
        else:
            raise NotImplementedError("Evaluation '%s' is not implemented" %
                                      eval_type)

        weighted = cfg.lambda_fq * quality + cfg.lambda_fd * diversity
        return quality, diversity, weighted
示例#6
0
def create_specific_oracle(from_a, to_b, num=1, save_path='../pretrain/'):
    """Draw random oracles until one scores inside ``[from_a, to_b]``, then save it.

    In each of the ``num`` rounds a new ``Oracle`` is built from the ``cfg``
    settings, ``cfg.samples_num`` and ``cfg.samples_num // 2`` samples are
    drawn from it, and ``NLL.cal_nll`` is evaluated for the oracle on the
    larger sample set. An oracle whose value lies outside the inclusive range
    is discarded and the round retries with a new one.

    An accepted oracle is written to its own directory,
    ``<save_path>oracle_data_gt<value>_<MMDD_HHMMSS>``, containing
    ``oracle_lstm.pt`` (state dict) and one ``oracle_lstm_samples_<n>.pt``
    file per sample set.

    Args:
        from_a: Inclusive lower bound on the accepted value.
        to_b: Inclusive upper bound on the accepted value.
        num: Number of oracles to accept and save.
        save_path: String prepended verbatim to the output directory name, so
            it must carry its own trailing path separator.
    """
    for _ in range(num):
        while True:
            oracle = Oracle(cfg.gen_embed_dim,
                            cfg.gen_hidden_dim,
                            cfg.vocab_size,
                            cfg.max_seq_len,
                            cfg.padding_idx,
                            gpu=cfg.CUDA)
            if cfg.CUDA:
                oracle = oracle.cuda()

            big_samples = oracle.sample(cfg.samples_num, 8 * cfg.batch_size)
            small_samples = oracle.sample(cfg.samples_num // 2,
                                          8 * cfg.batch_size)

            oracle_data = GenDataIter(big_samples)
            mle_criterion = nn.NLLLoss()
            ground_truth = NLL.cal_nll(oracle, oracle_data.loader,
                                       mle_criterion)

            if from_a <= ground_truth <= to_b:
                dir_path = save_path + 'oracle_data_gt{:.2f}_{}'.format(
                    ground_truth, strftime("%m%d_%H%M%S", localtime()))
                # makedirs(exist_ok=True) replaces the exists()/mkdir() pair:
                # it has no check-then-create race and also creates missing
                # parent directories (e.g. when save_path does not exist yet).
                os.makedirs(dir_path, exist_ok=True)
                print('save ground truth: ', ground_truth)
                prefix = os.path.join(dir_path, 'oracle_lstm')
                torch.save(oracle.state_dict(), '{}.pt'.format(prefix))
                torch.save(big_samples,
                           '{}_samples_{}.pt'.format(prefix, cfg.samples_num))
                torch.save(
                    small_samples,
                    '{}_samples_{}.pt'.format(prefix, cfg.samples_num // 2))
                break