forked from NLPQA/WikiQA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ask_pipeline.py
157 lines (139 loc) · 5.12 KB
/
ask_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import math
import doc_parser
import nltk
import tree_parser
import ask
import stanford_utils
import ginger_python2 as grammar_checker
import sys
# Shared Stanford NER tagger instance; built once at import time because
# loading the tagger model is expensive.
tagger = stanford_utils.new_NERtagger()
# Keywords whose presence marks a sentence as stating a reason ("why" questions).
why_keywords = ["because"]
def contains_reason(sent):
    """Return True if *sent* contains any causal keyword (e.g. "because")."""
    return any(keyword in sent for keyword in why_keywords)
def contains_time(tagged_sent):
    """Return True if the NER-tagged sentence mentions a DATE or TIME entity.

    tagged_sent: list of (token, tag) pairs from the NER tagger.
    """
    return any(pair[1] in ("DATE", "TIME") for pair in tagged_sent)
def contains_loc(tagged_sent):
    """Return True if the NER-tagged sentence mentions a LOCATION or
    ORGANIZATION entity.

    tagged_sent: list of (token, tag) pairs from the NER tagger.
    """
    return any(pair[1] in ("LOCATION", "ORGANIZATION") for pair in tagged_sent)
def contains_name(tagged_sent):
    """Return True if the sentence refers to a person: either a PERSON
    NER tag or a personal pronoun ("he"/"she").

    tagged_sent: list of (token, tag) pairs from the NER tagger.
    """
    return any(
        pair[1] == "PERSON" or pair[0].lower() in ("he", "she")
        for pair in tagged_sent
    )
def contains_quant(sent, tagged_sent):
    """Heuristic: return True if *sent* contains a bare integer token
    followed by a token whose tag ends in 's' (a quantity-like pattern).

    NOTE(review): tagged_sent[i+1][1] is the NER *tag* of the next token,
    and standard NER tags do not end in 's' -- presumably the intent was to
    test the next *word* for a plural suffix; confirm against the tagger
    output before changing.
    """
    words = nltk.tokenize.word_tokenize(sent)
    total = len(words)
    for idx, word in enumerate(words):
        if str(word).isdigit():
            if idx + 1 < total and tagged_sent[idx + 1][1].endswith('s'):
                return True
    return False
def preprocess_sents(sents):
    """Simplify raw sentences into predicate sentences.

    A sentence containing an appositive is expanded into several sentences;
    every other sentence is reduced to its main predicate.
    """
    results = []
    for raw in sents:
        parse = tree_parser.sent_to_tree(raw)
        if tree_parser.contains_appos(parse):
            results.extend(tree_parser.appps_to_sents(parse))
        else:
            results.append(tree_parser.sent_to_predicate(parse))
    return results
def main(wiki_path, n):
    """Generate and print up to *n* questions from the article at *wiki_path*.

    Pipeline: split the article into sentences, keep moderate-length ones,
    then for each sentence generate type-specific questions (why, how-many,
    when, where, who/what, binary), score them, and print the top n.

    Scoring: a base score rewards average sentence length and prepositional
    phrases; each question type adds a difficulty bonus, and grammar errors
    reported by the checker are deducted.
    """
    title, sents = doc_parser.doc_to_sents(wiki_path)
    questions = []
    # keep sentences of moderate length (space count is a cheap token proxy)
    sents = [sent for sent in sents if 10 <= sent.count(" ") <= 30]
    # bound the work: at most 3 candidate sentences per requested question
    sents = sents[:3*n]
    for sent in sents:
        parsed_sent = tree_parser.sent_to_tree(sent)
        pps = tree_parser.get_phrases(parsed_sent, "PP", False, False)
        tagged_sent = tagger.tag(nltk.tokenize.word_tokenize(sent))
        # bonus for average length: peaks when the sentence has ~10 spaces
        score = (20 - math.fabs(sent.count(" ")-10))*0.5
        # bonus for more prepositional phrases
        score += len(pps)-1
        # why -- BUG FIX: pass the raw sentence string, not the NER-tagged
        # (token, tag) list. "because" can never be an element of a list of
        # tuples, so why-questions were previously never generated.
        if contains_reason(sent):
            question = ask.get_why(sent).capitalize()
            # correct grammar; errs counts remaining errors (deducted below)
            question, errs = grammar_checker.correct_sent(question)
            questions.append((question, score-errs+5))
        # how-many
        elif contains_quant(sent, tagged_sent):
            question = ask.get_howmany(sent).capitalize()
            question, errs = grammar_checker.correct_sent(question)
            questions.append((question, score-errs+5))
        # when
        if contains_time(tagged_sent):
            question = ask.get_when(sent).capitalize()
            question, errs = grammar_checker.correct_sent(question)
            # very short when-questions tend to be degenerate; drop them
            if (len(question) > 29):
                questions.append((question, score-errs+4))
        # where
        if contains_loc(tagged_sent):
            question = ask.get_where(sent).capitalize()
            question, errs = grammar_checker.correct_sent(question)
            questions.append((question, score-errs+4))
        # who/what
        if contains_name(tagged_sent):
            question = ask.get_who(parsed_sent).capitalize()
            question, errs = grammar_checker.correct_sent(question)
            questions.append((question, score-errs+3))
        else:
            question = ask.get_what(parsed_sent).capitalize()
            question, errs = grammar_checker.correct_sent(question)
            questions.append((question, score-errs+2))
        # binary (yes/no) question from the statement itself
        binary_q = ask.get_binary(sent, twist=False).capitalize()
        binary_q, errs = grammar_checker.correct_sent(binary_q)
        questions.append((binary_q, score-errs+2))
    # rank by score descending, tie-break alphabetically; drop empty
    # questions and keep the top n
    ranked_questions = sorted(questions, key=lambda x:(-x[1],x[0]))
    ranked_questions = [q for q in ranked_questions if len(q[0]) > 0][:n]
    for question in ranked_questions:
        sys.stdout.write(question[0]+" "+"\n")
# import time
# for i in xrange(1, 9):
# start = time.time()
# if i == 4:
# continue
# print i
# wiki_path = "test/a"+str(i)+".htm"
# main(wiki_path, i)
# print time.time() - start
# Entry point: hard-coded demo article and question count. NOTE(review):
# this runs on import as well; consider the argv form below under an
# `if __name__ == "__main__":` guard.
main("test/a6.htm", 10)
# main(sys.argv[1], int(sys.argv[2]))