Code Example #1
File: test.py  Project: dkloz/query_completion
import os

import numpy as np
import pandas
import tensorflow as tf

# MetaModel and the module-level expdir and threads are assumed to be
# provided by the surrounding query_completion project.


def test_model(dataset_name, context, testdata):
    tf.reset_default_graph()
    exp_dir = os.path.join(expdir, dataset_name, context)

    metamodel = MetaModel(exp_dir)
    model_loaded = metamodel.model
    metamodel.MakeSessionAndRestore(threads)

    total_word_count = 0
    total_log_prob = 0
    results = []

    for idx in range(len(testdata.df) // testdata.batch_size):
        feed_dict = testdata.GetFeedDict(model_loaded)
        c, words_in_batch, sentence_costs = metamodel.session.run([
            model_loaded.avg_loss, model_loaded.words_in_batch,
            model_loaded.per_sentence_loss
        ], feed_dict)

        total_word_count += words_in_batch
        total_log_prob += float(c * words_in_batch)
        # Running perplexity over all batches seen so far.
        print('{0}\t{1:.3f}'.format(
            idx, np.exp(total_log_prob / total_word_count)))

        lens = feed_dict[model_loaded.query_lengths]
        for length, sentence_cost in zip(lens, sentence_costs):
            data_row = {'length': length, 'cost': sentence_cost}
            results.append(data_row)

    results = pandas.DataFrame(results)
    results.to_csv(os.path.join(exp_dir, 'pplstats.csv'))

    idx = len(testdata.df) // testdata.batch_size
    print('{0}\t{1:.3f}'.format(idx, np.exp(total_log_prob / total_word_count)))
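The printed figure is the running perplexity, exp(total log-loss / total word count). A standalone toy check of that arithmetic, with made-up batch numbers (illustrative values only, not model output):

import numpy as np

# Two hypothetical batches: (average per-word loss, words in batch).
batches = [(2.1, 500), (1.9, 480)]
total_log_prob = sum(loss * words for loss, words in batches)
total_word_count = sum(words for _, words in batches)
# Perplexity is exp of the average per-word loss.
print(np.exp(total_log_prob / total_word_count))  # ≈ 7.40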
Code Example #2
import argparse

import pandas

# MetaModel, GetPrefixLen and GetCompletions are assumed to be provided by
# the surrounding query_completion project.

parser = argparse.ArgumentParser()
parser.add_argument('expdir', help='experiment directory')
parser.add_argument('--threads',
                    type=int,
                    default=12,
                    help='how many threads to use in tensorflow')
args = parser.parse_args()

df = pandas.read_csv('/g/ssli/data/LowResourceLM/aol/queries01.dev.txt.gz',
                     sep='\t',
                     header=None)
df.columns = ['user', 'query_', 'date']
df['user'] = df.user.apply(lambda x: 's' + str(x))

m = MetaModel(args.expdir)  # Load the model
m.MakeSessionAndRestore(args.threads)

for i in range(23000):
    row = df.iloc[i]
    query_len = len(row.query_)

    if query_len <= 3:
        continue

    prefix_len = GetPrefixLen(row.user, row.query_)
    prefix = row.query_[:prefix_len]
    b = GetCompletions(['<S>'] + list(prefix),
                       m.user_vocab[row.user],
                       m,
                       branching_factor=4)
    # Drop the start/end tokens and join each completion's characters
    # back into a plain string.
    qlist = [''.join(q.words[1:-1]) for q in reversed(list(b))]
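Nothing in this excerpt consumes qlist. A plausible follow-up step (purely illustrative; reciprocal_rank is a hypothetical helper, not part of the project) is to score how highly the true query ranks among the completions:

# Hypothetical helper: reciprocal rank of the true query within a
# ranked list of completions.
def reciprocal_rank(true_query, completions):
    for rank, completion in enumerate(completions, start=1):
        if completion == true_query:
            return 1.0 / rank
    return 0.0

print(reciprocal_rank('barnes and noble', ['barnes and noble', 'b&n']))  # 1.0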
Code Example #3
import argparse

# MetaModel, Dataset, LoadData and data_dir are assumed to be provided by
# the surrounding query_completion script; they are not defined here.

parser = argparse.ArgumentParser()
parser.add_argument('expdir', help='experiment directory')
parser.add_argument('--data',
                    type=str,
                    action='append',
                    dest='data',
                    default=[data_dir + "queries07.test.txt.gz"],
                    help='where to load the data')
parser.add_argument('--threads',
                    type=int,
                    default=12,
                    help='how many threads to use in tensorflow')
args = parser.parse_args()
expdir = args.expdir

# Load the model
metamodel = MetaModel(expdir)
model = metamodel.model
metamodel.MakeSessionAndRestore(args.threads)
# Load the data
df = LoadData(args.data)
dataset = Dataset(df,
                  metamodel.char_vocab,
                  metamodel.user_vocab,
                  max_len=metamodel.params.max_len)

total_word_count = 0
total_log_prob = 0
print(len(dataset.df), dataset.batch_size)  # 20999    24
for idx in range(len(dataset.df) // dataset.batch_size):
    feed_dict = dataset.GetFeedDict(model)
    # The session here runs against the model restored from the checkpoint
    c, words_in_batch = metamodel.session.run(
        [model.avg_loss, model.words_in_batch], feed_dict)
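    # The excerpt stops mid-loop. A plausible continuation, mirroring the
    # accumulation in Example #1 (a sketch, not the original source); it
    # additionally needs `import numpy as np` at the top of the script.
    total_word_count += words_in_batch
    total_log_prob += float(c * words_in_batch)
    print('{0}\t{1:.3f}'.format(idx, np.exp(total_log_prob / total_word_count)))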