Beispiel #1
0
def create_many_oracle(from_a, to_b, num=1, save_path='../pretrain/'):
    """Build oracles until their ground-truth score lands in a target range.

    For each of ``num`` rounds, new ``Oracle`` models are constructed and
    sampled repeatedly until the value returned by ``NLL.cal_nll`` on the
    oracle's own large sample set lies within ``[from_a, to_b]``. The accepted
    oracle's state dict and both sample sets are then written with
    ``torch.save``.

    Args:
        from_a: Inclusive lower bound of the accepted score.
        to_b: Inclusive upper bound of the accepted score.
        num: Number of rounds, i.e. number of accepted oracles.
        save_path: String prepended verbatim to every output file name.

    Note:
        Output file names do not depend on the round, so with ``num > 1``
        each accepted oracle overwrites the files written by the previous one.
    """
    stem = 'oracle_lstm'
    for _ in range(num):
        while True:
            candidate = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim,
                               cfg.vocab_size, cfg.max_seq_len,
                               cfg.padding_idx, gpu=cfg.CUDA)
            if cfg.CUDA:
                candidate = candidate.cuda()

            # Both sample sets are drawn before scoring so that an accepted
            # oracle can be saved together with exactly these samples.
            samples_large = candidate.sample(cfg.samples_num,
                                             8 * cfg.batch_size)
            samples_small = candidate.sample(cfg.samples_num // 2,
                                             8 * cfg.batch_size)

            data_iter = GenDataIter(samples_large)
            criterion = nn.NLLLoss()
            score = NLL.cal_nll(candidate, data_iter.loader, criterion)

            # Out of range: discard this candidate and try another one.
            if not (from_a <= score <= to_b):
                continue

            print('save ground truth: ', score)
            model_file = save_path + '{}.pt'.format(stem)
            large_file = save_path + '{}_samples_{}.pt'.format(
                stem, cfg.samples_num)
            small_file = save_path + '{}_samples_{}.pt'.format(
                stem, cfg.samples_num // 2)
            torch.save(candidate.state_dict(), model_file)
            torch.save(samples_large, large_file)
            torch.save(samples_small, small_file)
            break
Beispiel #2
0
def create_multi_oracle(number):
    """Create ``number`` oracles, save each one and report its score.

    For every index ``i`` in ``range(number)`` a new ``Oracle`` is built, a
    large and a half-size sample set are drawn, and the state dict plus both
    sample sets are saved to the paths obtained by formatting
    ``cfg.multi_oracle_state_dict_path`` / ``cfg.multi_oracle_samples_path``.
    Finally the value returned by ``NLL.cal_nll`` on the large sample set is
    printed.

    Args:
        number: How many oracles to create.
    """
    for idx in range(number):
        print('Creating Oracle %d...' % idx)
        model = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                       cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
        if cfg.CUDA:
            model = model.cuda()

        # Draw the large set first, then the half-size set.
        samples_large = model.sample(cfg.samples_num, 4 * cfg.batch_size)
        samples_small = model.sample(cfg.samples_num // 2,
                                     4 * cfg.batch_size)

        torch.save(model.state_dict(),
                   cfg.multi_oracle_state_dict_path.format(idx))
        for samples, count in ((samples_large, cfg.samples_num),
                               (samples_small, cfg.samples_num // 2)):
            torch.save(samples,
                       cfg.multi_oracle_samples_path.format(idx, count))

        data_iter = GenDataIter(samples_large)
        criterion = nn.NLLLoss()
        score = NLL.cal_nll(model, data_iter.loader, criterion)
        print('Oracle %d Groud Truth: %.4f' % (idx, score))
Beispiel #3
0
def create_oracle():
    """Build one Oracle, persist it with its samples and print its score.

    The oracle's state dict is saved to ``cfg.oracle_state_dict_path``. A
    large sample set (``cfg.samples_num``) and a half-size one are drawn and
    saved to ``cfg.oracle_samples_path`` formatted with the respective sample
    count. The value returned by ``NLL.cal_nll`` on the large sample set is
    printed at the end.
    """
    import config as cfg
    from models.Oracle import Oracle

    print('Creating Oracle...')
    model = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                   cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
    if cfg.CUDA:
        model = model.cuda()

    torch.save(model.state_dict(), cfg.oracle_state_dict_path)

    # Large sample set: kept in a variable because it is reused for scoring.
    samples_large = model.sample(cfg.samples_num, 4 * cfg.batch_size)
    large_file = cfg.oracle_samples_path.format(cfg.samples_num)
    torch.save(samples_large, large_file)

    # Small sample set: half as many samples, only written to disk.
    samples_small = model.sample(cfg.samples_num // 2, 4 * cfg.batch_size)
    small_file = cfg.oracle_samples_path.format(cfg.samples_num // 2)
    torch.save(samples_small, small_file)

    data_iter = GenDataIter(samples_large)
    criterion = nn.NLLLoss()
    score = NLL.cal_nll(model, data_iter.loader, criterion)
    print('NLL_Oracle Groud Truth: %.4f' % score)
def create_oracle():
    """Create a new Oracle model and save it together with its samples.

    The state dict is written to ``cfg.oracle_state_dict_path``; a large
    sample set (``cfg.samples_num``) and a half-size one are written to
    ``cfg.oracle_samples_path`` formatted with the respective sample count.
    """
    from models.Oracle import Oracle
    print('Creating Oracle...')
    oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                    cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
    # Only move the model to the GPU when CUDA is enabled in the config; an
    # unconditional ``.cuda()`` contradicts ``gpu=cfg.CUDA`` and fails on
    # CPU-only setups.
    if cfg.CUDA:
        oracle = oracle.cuda()

    torch.save(oracle.state_dict(), cfg.oracle_state_dict_path)

    # large
    torch.save(oracle.sample(cfg.samples_num, 4 * cfg.batch_size),
               cfg.oracle_samples_path.format(cfg.samples_num))
    # small
    torch.save(oracle.sample(cfg.samples_num // 2, 4 * cfg.batch_size),
               cfg.oracle_samples_path.format(cfg.samples_num // 2))
Beispiel #5
0
def create_oracle():
    """Create a new Oracle model and save it together with its samples.

    The state dict is written to ``cfg.oracle_state_dict_path``; a large
    sample set (``cfg.samples_num``) and a half-size one are written to
    ``cfg.oracle_samples_path`` formatted with the respective sample count.
    """
    oracle = Oracle(cfg.gen_embed_dim,
                    cfg.gen_hidden_dim,
                    cfg.vocab_size,
                    cfg.max_seq_len,
                    cfg.padding_idx,
                    gpu=cfg.CUDA)
    # Only move the model to the GPU when CUDA is enabled in the config; an
    # unconditional ``.cuda()`` contradicts ``gpu=cfg.CUDA`` and fails on
    # CPU-only setups.
    if cfg.CUDA:
        oracle = oracle.cuda()

    torch.save(oracle.state_dict(), cfg.oracle_state_dict_path)

    # large
    torch.save(oracle.sample(cfg.samples_num, 4 * cfg.batch_size),
               cfg.oracle_samples_path.format(cfg.samples_num))
    # small
    torch.save(oracle.sample(cfg.samples_num // 2, 4 * cfg.batch_size),
               cfg.oracle_samples_path.format(cfg.samples_num // 2))
Beispiel #6
0
def create_specific_oracle(from_a, to_b, num=1, save_path='../pretrain/'):
    """Build oracles whose ground-truth score lies in a range and save each.

    For each of ``num`` rounds, new ``Oracle`` models are constructed and
    sampled repeatedly until the value returned by ``NLL.cal_nll`` on the
    oracle's own large sample set lies within ``[from_a, to_b]``. The accepted
    oracle is stored in its own directory under ``save_path`` named
    ``oracle_data_gt<score>_<MMDD_HHMMSS>``, containing the state dict and
    both sample sets.

    Args:
        from_a: Inclusive lower bound of the accepted score.
        to_b: Inclusive upper bound of the accepted score.
        num: Number of rounds, i.e. number of accepted oracles.
        save_path: String prepended verbatim to the output directory name.
    """
    for _ in range(num):
        while True:
            oracle = Oracle(cfg.gen_embed_dim,
                            cfg.gen_hidden_dim,
                            cfg.vocab_size,
                            cfg.max_seq_len,
                            cfg.padding_idx,
                            gpu=cfg.CUDA)
            if cfg.CUDA:
                oracle = oracle.cuda()

            big_samples = oracle.sample(cfg.samples_num, 8 * cfg.batch_size)
            small_samples = oracle.sample(cfg.samples_num // 2,
                                          8 * cfg.batch_size)

            oracle_data = GenDataIter(big_samples)
            mle_criterion = nn.NLLLoss()
            groud_truth = NLL.cal_nll(oracle, oracle_data.loader,
                                      mle_criterion)

            if from_a <= groud_truth <= to_b:
                dir_path = save_path + 'oracle_data_gt{:.2f}_{}'.format(
                    groud_truth, strftime("%m%d_%H%M%S", localtime()))
                # makedirs(exist_ok=True) avoids the exists()/mkdir() race
                # and also creates ``save_path`` itself when it is missing.
                os.makedirs(dir_path, exist_ok=True)
                print('save ground truth: ', groud_truth)
                prefix = dir_path + '/oracle_lstm'
                torch.save(oracle.state_dict(), '{}.pt'.format(prefix))
                torch.save(big_samples,
                           '{}_samples_{}.pt'.format(prefix, cfg.samples_num))
                torch.save(
                    small_samples,
                    '{}_samples_{}.pt'.format(prefix, cfg.samples_num // 2))
                break