def create_many_oracle(from_a, to_b, num=1, save_path='../pretrain/'):
    """Build random oracles until their NLL lands in ``[from_a, to_b]``.

    For each of the ``num`` rounds a fresh ``Oracle`` is constructed and
    sampled over and over until ``NLL.cal_nll`` evaluated on its own large
    sample set falls inside the closed interval ``[from_a, to_b]``. The
    accepted oracle's state dict and both sample sets are then written
    under ``save_path``.

    Note that every round writes the same three file names
    (``oracle_lstm.pt`` and the two ``oracle_lstm_samples_*.pt`` files), so
    with ``num > 1`` later rounds overwrite the output of earlier ones.

    :param from_a: lower bound (inclusive) of the accepted NLL value.
    :param to_b: upper bound (inclusive) of the accepted NLL value.
    :param num: number of accepted oracles to produce.
    :param save_path: string prefix prepended verbatim to each file name.
    """
    for _ in range(num):
        while True:
            candidate = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim,
                               cfg.vocab_size, cfg.max_seq_len,
                               cfg.padding_idx, gpu=cfg.CUDA)
            if cfg.CUDA:
                candidate = candidate.cuda()

            # Draw the large set first, then the half-sized one.
            sample_batch = 8 * cfg.batch_size
            full_set = candidate.sample(cfg.samples_num, sample_batch)
            half_set = candidate.sample(cfg.samples_num // 2, sample_batch)

            data_iter = GenDataIter(full_set)
            criterion = nn.NLLLoss()
            nll = NLL.cal_nll(candidate, data_iter.loader, criterion)
            if not (from_a <= nll <= to_b):
                # Out of range: discard this oracle and try another one.
                continue

            print('save ground truth: ', nll)
            prefix = 'oracle_lstm'
            outputs = (
                (candidate.state_dict(), '{}.pt'.format(prefix)),
                (full_set,
                 '{}_samples_{}.pt'.format(prefix, cfg.samples_num)),
                (half_set,
                 '{}_samples_{}.pt'.format(prefix, cfg.samples_num // 2)),
            )
            for payload, file_name in outputs:
                torch.save(payload, save_path + file_name)
            break
def create_multi_oracle(number):
    """Create ``number`` independent oracles and save each with its samples.

    For every index ``i`` in ``range(number)`` a new ``Oracle`` is built
    (moved to GPU when ``cfg.CUDA`` is set), a large and a half-sized sample
    set are drawn, and the state dict plus both sample sets are saved to the
    paths obtained by formatting ``cfg.multi_oracle_state_dict_path`` and
    ``cfg.multi_oracle_samples_path``. Finally the oracle's NLL on its own
    large sample set is printed.

    :param number: how many oracles to create.
    """
    for idx in range(number):
        print('Creating Oracle %d...' % idx)
        model = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim,
                       cfg.vocab_size, cfg.max_seq_len,
                       cfg.padding_idx, gpu=cfg.CUDA)
        if cfg.CUDA:
            model = model.cuda()

        # Draw the large set first, then the half-sized one.
        sample_batch = 4 * cfg.batch_size
        half_count = cfg.samples_num // 2
        full_set = model.sample(cfg.samples_num, sample_batch)
        half_set = model.sample(half_count, sample_batch)

        torch.save(model.state_dict(),
                   cfg.multi_oracle_state_dict_path.format(idx))
        for samples, count in ((full_set, cfg.samples_num),
                               (half_set, half_count)):
            torch.save(samples,
                       cfg.multi_oracle_samples_path.format(idx, count))

        # Report the oracle's NLL on the large sample set.
        data_iter = GenDataIter(full_set)
        criterion = nn.NLLLoss()
        nll = NLL.cal_nll(model, data_iter.loader, criterion)
        print('Oracle %d Groud Truth: %.4f' % (idx, nll))
def create_oracle():
    """Create a fresh Oracle, save it and its samples, and print its NLL.

    The state dict is written to ``cfg.oracle_state_dict_path`` before any
    sampling. A large sample set (``cfg.samples_num``) and a half-sized one
    are then drawn and saved to ``cfg.oracle_samples_path`` formatted with
    the respective sample count. The NLL of the oracle on the large sample
    set is printed at the end.
    """
    import config as cfg
    from models.Oracle import Oracle

    print('Creating Oracle...')
    model = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                   cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
    if cfg.CUDA:
        model = model.cuda()

    torch.save(model.state_dict(), cfg.oracle_state_dict_path)

    sample_batch = 4 * cfg.batch_size

    # Large sample set: also reused below for the NLL evaluation.
    full_set = model.sample(cfg.samples_num, sample_batch)
    torch.save(full_set, cfg.oracle_samples_path.format(cfg.samples_num))

    # Small sample set: half the size, drawn after the large one is saved.
    half_count = cfg.samples_num // 2
    half_set = model.sample(half_count, sample_batch)
    torch.save(half_set, cfg.oracle_samples_path.format(half_count))

    data_iter = GenDataIter(full_set)
    criterion = nn.NLLLoss()
    nll = NLL.cal_nll(model, data_iter.loader, criterion)
    print('NLL_Oracle Groud Truth: %.4f' % nll)
def create_oracle():
    """Create a new Oracle model and Oracle's samples.

    Builds an ``Oracle`` from the ``cfg`` hyper-parameters, saves its state
    dict to ``cfg.oracle_state_dict_path``, then draws and saves a large
    sample set (``cfg.samples_num``) and a half-sized one to
    ``cfg.oracle_samples_path`` formatted with the respective sample count.
    """
    from models.Oracle import Oracle

    print('Creating Oracle...')
    oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                    cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
    # Only move to the GPU when CUDA is enabled: the model is constructed
    # with gpu=cfg.CUDA, so an unconditional .cuda() breaks CPU-only runs.
    if cfg.CUDA:
        oracle = oracle.cuda()

    torch.save(oracle.state_dict(), cfg.oracle_state_dict_path)

    # large
    torch.save(oracle.sample(cfg.samples_num, 4 * cfg.batch_size),
               cfg.oracle_samples_path.format(cfg.samples_num))
    # small
    torch.save(oracle.sample(cfg.samples_num // 2, 4 * cfg.batch_size),
               cfg.oracle_samples_path.format(cfg.samples_num // 2))
def create_oracle():
    """Create a new Oracle model and save it together with its samples.

    Builds an ``Oracle`` from the ``cfg`` hyper-parameters, saves its state
    dict to ``cfg.oracle_state_dict_path``, then draws and saves a large
    sample set (``cfg.samples_num``) and a half-sized one to
    ``cfg.oracle_samples_path`` formatted with the respective sample count.
    """
    oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim, cfg.vocab_size,
                    cfg.max_seq_len, cfg.padding_idx, gpu=cfg.CUDA)
    # Only move to the GPU when CUDA is enabled: the model is constructed
    # with gpu=cfg.CUDA, so an unconditional .cuda() breaks CPU-only runs.
    if cfg.CUDA:
        oracle = oracle.cuda()

    torch.save(oracle.state_dict(), cfg.oracle_state_dict_path)

    # large
    torch.save(oracle.sample(cfg.samples_num, 4 * cfg.batch_size),
               cfg.oracle_samples_path.format(cfg.samples_num))
    # small
    torch.save(oracle.sample(cfg.samples_num // 2, 4 * cfg.batch_size),
               cfg.oracle_samples_path.format(cfg.samples_num // 2))
def create_specific_oracle(from_a, to_b, num=1, save_path='../pretrain/'):
    """Build random oracles until their NLL lands in ``[from_a, to_b]``.

    For each of the ``num`` rounds a fresh ``Oracle`` is constructed and
    sampled over and over until ``NLL.cal_nll`` evaluated on its own large
    sample set falls inside the closed interval ``[from_a, to_b]``. Each
    accepted oracle is stored in its own directory
    ``<save_path>oracle_data_gt<NLL>_<MMDD_HHMMSS>`` containing the state
    dict (``oracle_lstm.pt``) and the large and half-sized sample sets.

    :param from_a: lower bound (inclusive) of the accepted NLL value.
    :param to_b: upper bound (inclusive) of the accepted NLL value.
    :param num: number of accepted oracles to produce.
    :param save_path: string prefix prepended verbatim to the directory
        name; missing parent directories are created.
    """
    for _ in range(num):
        while True:
            oracle = Oracle(cfg.gen_embed_dim, cfg.gen_hidden_dim,
                            cfg.vocab_size, cfg.max_seq_len,
                            cfg.padding_idx, gpu=cfg.CUDA)
            if cfg.CUDA:
                oracle = oracle.cuda()

            big_samples = oracle.sample(cfg.samples_num, 8 * cfg.batch_size)
            small_samples = oracle.sample(cfg.samples_num // 2,
                                          8 * cfg.batch_size)

            oracle_data = GenDataIter(big_samples)
            mle_criterion = nn.NLLLoss()
            groud_truth = NLL.cal_nll(oracle, oracle_data.loader,
                                      mle_criterion)

            if from_a <= groud_truth <= to_b:
                dir_path = save_path + 'oracle_data_gt{:.2f}_{}'.format(
                    groud_truth, strftime("%m%d_%H%M%S", localtime()))
                # makedirs(exist_ok=True) avoids the check-then-create race
                # and also works when save_path itself does not exist yet.
                os.makedirs(dir_path, exist_ok=True)

                print('save ground truth: ', groud_truth)
                prefix = dir_path + '/oracle_lstm'
                torch.save(oracle.state_dict(), '{}.pt'.format(prefix))
                torch.save(big_samples,
                           '{}_samples_{}.pt'.format(prefix, cfg.samples_num))
                torch.save(small_samples,
                           '{}_samples_{}.pt'.format(prefix,
                                                     cfg.samples_num // 2))
                break