def _toforest(self, lmforest, sent, cache, level=0): if self in cache: return cache[self] this_node = Node("", "X [0-0]", 1, Vector(), sent) # no node fv, temporary iden="" is_root = level == 0 # don't include final </s> </s> for prev_states, extra_fv, extra_words in self.backptrs: prev_nodes = [p._toforest(lmforest, sent, cache, level+1) for p in prev_states if p is not None] if is_root: extra_words = extra_words[:-LMState.lm.order+1] edge = Hyperedge(this_node, prev_nodes, extra_fv, prev_nodes + LMState.lm.ppqstr(extra_words).split()) edge.rule = Rule("a(\"a\")", "b", "") edge.rule.ruleid = 1 this_node.add_edge(edge) cache[self] = this_node this_node.iden = str(len(cache)) # post-order id lmforest.add_node(this_node) if is_root: lmforest.root = this_node return this_node
def recover_oracle(self):
    '''Read back the oracle derivation stored implicitly in this forest.

    The oracle hyperedges come from ``self.root.get_oracle_edgelist()``;
    the derived tree is rescored with ``self.bleu.rescore``.

    Returns:
        (bleu_score, tree, fvector, edgelist):
          - bleu_score: the value returned by ``self.bleu.rescore(tree)``
          - tree: ``Hyperedge.deriv2tree(edgelist)``
          - fvector: ``Hyperedge.deriv2fvector(edgelist)``
          - edgelist: the oracle hyperedges
        N.B.: there is no parseval element, and the tree comes before the
        feature vector.
    '''
    edgelist = self.root.get_oracle_edgelist()
    fv = Hyperedge.deriv2fvector(edgelist)
    tr = Hyperedge.deriv2tree(edgelist)
    bleu_p1 = self.bleu.rescore(tr)
    return bleu_p1, tr, fv, edgelist
def extract_oracle(forest):
    """Return the oracle derivation that is already marked in ``forest``.

    Side effect: sets the module-level flag ``implicit_oracle`` to True.

    Returns a 4-tuple (score, parseval, tree, edgelist) where:
      - score is element 0 of ``Hyperedge.deriv2fvector(edgelist)``
      - parseval is a fresh, empty ``Parseval()``
      - tree is ``Hyperedge.deriv2tree(edgelist)``
      - edgelist is ``get_edgelist(forest.root)``
    """
    global implicit_oracle
    implicit_oracle = True

    oracle_edges = get_edgelist(forest.root)
    derivation_fv = Hyperedge.deriv2fvector(oracle_edges)
    oracle_tree = Hyperedge.deriv2tree(oracle_edges)
    return derivation_fv[0], Parseval(), oracle_tree, oracle_edges
def start_state(root):
    """Create the initial LM state for decoding from ``root``.

    The state holds a single pseudo-edge
        None -> <s>^{g-1} . TOP </s>^{g-1}
    whose dot sits right after the start symbols.  The state's score is
    ``root.bestres[0]`` (future cost) when FLAGS.futurecost is set, else 0.
    """
    ## LMState.cache = {}
    lm = LMState.lm
    raw_start = lm.raw_startsyms()
    surface = raw_start + [root] + lm.raw_stopsyms()
    root_edge = Hyperedge(None, [root], Vector(), surface)
    root_edge.lmlhsstr = lm.startsyms() + [root] + lm.stopsyms()
    root_edge.rule = Rule.parse("ROOT(TOP) -> x0 ### ")

    if FLAGS.futurecost:
        initial_score = root.bestres[0]  # future cost
    else:
        initial_score = 0

    dotted = DottedRule(root_edge, dot=len(raw_start))
    return LMState(None, [dotted], lm.startsyms(), step=0, score=initial_score)
def _oracle(self, forest):
    """Compute the oracle derivation of ``forest`` against
    ``forest.goldtree`` and store it on the forest.

    Sets:
        forest.oracle_tree: oracle tree returned by forest_oracle().
        forest.oracle_fvector: Hyperedge.deriv2fvector of the oracle edges.
        forest.oracle_size_ratio: number of labeled spans in the oracle tree
            divided by Decoder.MAX_NUM_BRACKETS, or 1 when that limit is
            negative.  A limit of exactly 0 raises ZeroDivisionError.
    """
    _, _, tree, edgelist = forest_oracle(forest, forest.goldtree)
    forest.oracle_tree = tree
    forest.oracle_fvector = Hyperedge.deriv2fvector(edgelist)
    max_brackets = Decoder.MAX_NUM_BRACKETS
    if max_brackets < 0:
        forest.oracle_size_ratio = 1
    else:
        # float(): under Python 2, int / int truncates, which would turn
        # this ratio into 0 (or 1) almost always.
        forest.oracle_size_ratio = len(tree.all_label_spans()) / float(max_brackets)
def add_lex_th(self, lhs, node, isBP):
    """Attach lexical translation hyperedges to ``node``.

    If ``lhs`` has rules in ``self.ruleset``, one hyperedge per rule is
    appended to ``node.edges`` (each RHS token has its first and last
    character -- the quotes -- stripped).  Otherwise, unless ``isBP`` is
    true, a default monotonic rule copying ``node.word`` is created,
    registered with ``self.ruleset.add_rule`` and attached.
    """
    rules_by_lhs = self.ruleset

    if lhs not in rules_by_lhs:
        if isBP:
            return
        # default (monotonic) translation: copy the source word
        quoted_word = '"%s"' % node.word
        default_rule = Rule(lhs, [quoted_word], self.deffields)
        lex_edge = Hyperedge(node, [], Vector(self.deffields), [quoted_word[1:-1]])
        lex_edge.rule = default_rule
        rules_by_lhs.add_rule(default_rule)
        node.edges.append(lex_edge)
        return

    # one translation hyperedge per matching rule
    for rule in rules_by_lhs[lhs]:
        unquoted = [token[1:-1] for token in rule.rhs]
        lex_edge = Hyperedge(node, [], Vector(rule.fields), unquoted)
        lex_edge.rule = rule
        node.edges.append(lex_edge)
def load(filename, lower=True, sentid=0):
    '''Lazily read parse forests from ``filename``; returns a generator.

    Use ``load(...).next()`` for a single forest.  Each record consists of:
      * a header line "<tag>TAB<sentence>" (the tag is often empty);
      * a line with the number of nodes;
      * per node, a line "iden TAB label-span TAB num_edges ||| features"
        followed by num_edges hyperedge lines "[*]tail ... ||| features[~]"
        ("*" marks the oracle edge; a trailing "~" adds the features of the
        cached twin edge when cache_same is on);
      * optionally a gold tree line starting with "(";
      * blank line(s) before the next record.
    The last node read becomes the root.

    Args:
        filename: passed to getfile().
        lower: lowercase the words of the sentence (a cased copy is kept).
        sentid: sentence counter, incremented per forest.

    Yields:
        Forest objects.  Forest.load_time is set once the input is exhausted.
    '''
    infile = getfile(filename)
    line = None
    total_time = 0
    num_sents = 0
    while True:
        start_time = time.time()
        ## header: '\tThe complicated language in ...\n' (tag is often missing)
        if line is None or line == "\n":
            line = "\n"
            while line == "\n":
                line = infile.readline()  # emulate seek: skip blank separators
        try:
            tag, sent = line.split("\t")
        except ValueError:
            ## EOF ("") or a line without exactly one tab: no more forests
            break

        num_sents += 1
        sent = sent.split()
        cased_sent = sent[:]
        if lower:
            sent = [w.lower() for w in sent]  # mark johnson: lowercase all words

        num = int(infile.readline())
        forest = Forest(num, sent, cased_sent, tag)
        forest.labelspans = {}
        forest.short_edges = {}

        for _ in xrange(num):
            ## '2\tDT* [0-1]\t1 ||| 1232=2 ...\n'
            ## node-based features here: wordedges, greedyheavy, word(1), [word(2)], ...
            line = infile.readline()
            try:
                keys, fields = line.split(" ||| ")
            except ValueError:
                # node line without a feature part
                keys = line
                fields = ""
            iden, labelspan, size = keys.split("\t")  ## iden can be non-ints
            size = int(size)
            node = Node(iden, labelspan, size, FVector.parse(fields), sent)
            forest.add_node(node)

            if cache_same:
                if labelspan in forest.labelspans:
                    # share the feature vector of an earlier node with this label span
                    node.same = forest.labelspans[labelspan]
                    node.fvector = node.same.fvector
                else:
                    forest.labelspans[labelspan] = node

            for _ in xrange(size):
                ## '\t1 ||| 0=8.86276 1=2 3\n'
                tails, fields = infile.readline().strip().split(" ||| ")
                is_oracle = tails.startswith("*")  # oracle edge
                if is_oracle:
                    tails = tails[1:]

                tailnodes = []
                for x in tails.split():  ## could be non-integers
                    assert x in forest.nodes, "BAD TOPOL ORDER: node #%s is referred to " % x + \
                        "(in a hyperedge of node #%s) before being defined" % iden
                    ## topological ordering
                    tailnodes.append(forest.nodes[x])

                use_same = fields[-1] == "~"
                if use_same:
                    fields = fields[:-1]
                edge = Hyperedge(node, tailnodes, FVector.parse(fields))

                if cache_same:
                    short_edge = edge.shorter()
                    if short_edge in forest.short_edges:
                        edge.same = forest.short_edges[short_edge]
                        if use_same:
                            edge.fvector += edge.same.fvector
                    else:
                        forest.short_edges[short_edge] = edge

                node.add_edge(edge)
                if is_oracle:
                    node.oracle_edge = edge

            if node.sp_terminal():
                node.word = node.edges[0].subs[0].word

        ## split nodes sort numerically by parts: "12-3-4" => (12, 3, 4)
        forest.nodeorder = sorted(forest.nodeorder,
                                  key=lambda n: [int(p) for p in n.iden.split("-")])
        forest.rehash()
        sentid += 1

        # the last node read is the root
        forest.root = node
        node.set_root(True)

        line = infile.readline()
        if line.strip() != "":
            if line[0] == "(":
                forest.goldtree = Tree.parse(line.strip(), trunc=True, lower=True)
                line = infile.readline()
            # any other non-blank line is kept and parsed as the next header
        else:
            line = None

        total_time += time.time() - start_time
        if num_sents % 100 == 0:
            print >> logs, "... %d sents loaded (%.2lf secs per sent) ..." \
                % (num_sents, total_time/num_sents)

        yield forest

    Forest.load_time = total_time
    # guard: an empty input used to crash here with ZeroDivisionError
    if num_sents > 0:
        print >> logs, "%d forests loaded in %.2lf secs (avg %.2lf per sent)" \
            % (num_sents, total_time, total_time/num_sents)
def load(filename, lower=False, sentid=0):
    '''Lazily read parse forests from ``filename``; returns a generator.

    Use ``load(...).next()`` for a single forest.  Each record consists of:
      * a header line "<tag>TAB<sentence>" (the tag is often empty);
      * a line with the number of nodes;
      * per node, a line "iden TAB label-span TAB num_edges ||| features"
        followed by num_edges hyperedge lines "[*]tail ... ||| features[~]"
        ("*" marks the oracle edge; a trailing "~" adds the features of the
        cached twin edge when cache_same is on);
      * optionally a line starting with "(" -- it is skipped here, since
        gold-tree parsing is disabled in this variant;
      * blank line(s) before the next record.
    The last node read becomes the root.

    Args:
        filename: passed to getfile().
        lower: lowercase the words of the sentence (a cased copy is kept).
        sentid: sentence counter, incremented per forest.

    Yields:
        Forest objects.  Forest.load_time is set once the input is exhausted.
    '''
    infile = getfile(filename)
    line = None
    total_time = 0
    num_sents = 0
    while True:
        start_time = time.time()
        ## header: '\tThe complicated language in ...\n' (tag is often missing)
        if line is None or line == "\n":
            line = "\n"
            while line == "\n":
                line = infile.readline()  # emulate seek: skip blank separators
        try:
            tag, sent = line.split("\t")
        except ValueError:
            ## EOF ("") or a line without exactly one tab: no more forests
            break

        num_sents += 1
        sent = sent.split()
        cased_sent = sent[:]
        if lower:
            sent = [w.lower() for w in sent]  # mark johnson: lowercase all words

        num = int(infile.readline())
        forest = Forest(num, sent, cased_sent, tag)
        forest.labelspans = {}
        forest.short_edges = {}

        for _ in xrange(num):
            ## '2\tDT* [0-1]\t1 ||| 1232=2 ...\n'
            ## node-based features here: wordedges, greedyheavy, word(1), [word(2)], ...
            line = infile.readline()
            try:
                keys, fields = line.split(" ||| ")
            except ValueError:
                # node line without a feature part
                keys = line
                fields = ""
            iden, labelspan, size = keys.split("\t")  ## iden can be non-ints
            size = int(size)
            node = Node(iden, labelspan, size, FVector(fields), sent)  # TODO: myvector
            forest.add_node(node)

            if cache_same:
                if labelspan in forest.labelspans:
                    # share the feature vector of an earlier node with this label span
                    node.same = forest.labelspans[labelspan]
                    node.fvector = node.same.fvector
                else:
                    forest.labelspans[labelspan] = node

            for _ in xrange(size):
                ## '\t1 ||| 0=8.86276 1=2 3\n'
                tails, fields = infile.readline().strip().split(" ||| ")
                is_oracle = tails.startswith("*")  # oracle edge
                if is_oracle:
                    tails = tails[1:]

                tailnodes = []
                for x in tails.split():  ## could be non-integers
                    assert x in forest.nodes, "BAD TOPOL ORDER: node #%s is referred to " % x + \
                        "(in a hyperedge of node #%s) before being defined" % iden
                    ## topological ordering
                    tailnodes.append(forest.nodes[x])

                use_same = fields[-1] == "~"
                if use_same:
                    fields = fields[:-1]
                edge = Hyperedge(node, tailnodes, FVector(fields))

                if cache_same:
                    short_edge = edge.shorter()
                    if short_edge in forest.short_edges:
                        edge.same = forest.short_edges[short_edge]
                        if use_same:
                            edge.fvector += edge.same.fvector
                    else:
                        forest.short_edges[short_edge] = edge

                node.add_edge(edge)
                if is_oracle:
                    node.oracle_edge = edge

            if node.sp_terminal():
                node.word = node.edges[0].subs[0].word

        ## split nodes sort numerically by parts: "12-3-4" => (12, 3, 4)
        forest.nodeorder = sorted(forest.nodeorder,
                                  key=lambda n: [int(p) for p in n.iden.split("-")])
        forest.rehash()
        sentid += 1

        # the last node read is the root
        forest.root = node
        node.set_root(True)

        line = infile.readline()
        if line.strip() != "":
            if line[0] == "(":
                ## gold tree parsing is disabled; the tree line is only skipped
                ## forest.goldtree = Tree.parse(line.strip(), trunc=True, lower=False)
                line = infile.readline()
            # any other non-blank line is kept and parsed as the next header
        else:
            line = None

        total_time += time.time() - start_time
        if num_sents % 100 == 0:
            print >> logs, "... %d sents loaded (%.2lf secs per sent) ..." \
                % (num_sents, total_time/num_sents)

        yield forest

    Forest.load_time = total_time
    if num_sents > 0:
        print >> logs, "%d forests loaded in %.2lf secs (avg %.2lf per sent)" \
            % (num_sents, total_time, total_time/num_sents)
def forest_oracle(forest, goldtree, del_puncs=False, prune_results=False):
    """Find the oracle (best-parseval) derivation in ``forest`` w.r.t. ``goldtree``.

    Non-recursive, topol-sort style: nodes of ``newforest`` are visited in
    iteration order, and each node's ``oracles`` table is built from its
    children's tables.  The assert in the edge loop requires every child to
    have been visited first, i.e. the forest must iterate in topological
    order.

    If the root already has an ``oracle_edge`` (oracle preloaded or annotated
    by an earlier call), the result is read back via extract_oracle() instead.

    Args:
        forest: forest to search; every visited node gets an ``oracles`` attribute.
        goldtree: reference tree; its labeled spans are the gold brackets.
        del_puncs: if true, check_puncs() supplies an index mapping and the
            forest to search (per the note below, this modifies ``forest`` too).
        prune_results: if true, prune() is called on each non-terminal's table.

    Returns:
        (score, parseval, tree, edgelist) for the chosen derivation:
          - score is the negation of the stored score.  Each edge stores
            -edge.fvector[0]; presumably RES accumulates these -- confirm.
          - parseval is the smallest under ``<`` among
            Parseval.get_parseval(...) over the root table's entries.
        Side effect: ``edge.head.oracle_edge = edge`` for every edge of the
        chosen derivation.
        NOTE(review): if the root table is empty, best_edgelist is never
        bound and a NameError is raised.
    """
    if hasattr(forest.root, "oracle_edge"):
        # oracle already annotated (see the end of this function)
        return extract_oracle(forest)

    ## modifies forest also!!
    if del_puncs:
        idx_mapping, newforest = check_puncs(forest, goldtree.tag_seq)
    else:
        # identity index mapping, search the forest itself
        idx_mapping, newforest = lambda x: x, forest

    goldspans = merge_labels(goldtree.all_label_spans(), idx_mapping)
    goldbrs = set(goldspans) ## including TOP

    for node in newforest:
        if node.is_terminal():
            results = Oracles.unit("(%s %s)" % (node.label, node.word)) ## multiplication unit
        else:
            # a: 1 if this node counts as a test bracket (0 for spurious nodes)
            # b: 1 if it also matches a gold bracket
            a, b = ( (0, 0) if node.is_spurious()
                     else ((1, 1) if (merge_label((node.label, node.span), idx_mapping) in goldbrs)
                           else (1, 0)) )
            # NOTE(review): `label` is only used by the commented-out nodehead below
            label = "" if node.is_spurious() else node.label
            results = Oracles() ## addition unit
            for edge in node.edges:
                edgeres = Oracles.unit() ## multiplication unit
                # combine ("multiply") the oracle tables of all children
                for sub in edge.subs:
                    assert hasattr(sub, "oracles"), "%s ; %s ; %s" % (node, sub, edge)
                    edgeres = edgeres * sub.oracles
                ## nodehead = (a, RES((b, -edge.fvector[0], label, [edge]))) ## originally there is label
                assert 0 in edge.fvector, edge
                # this edge's own contribution: (test count, (match count, -feature 0, [edge]))
                nodehead = (a, RES((b, -edge.fvector[0], [edge])))
                results += nodehead * edgeres ## mul
            if prune_results:
                prune(results)
        node.oracles = results
        if debug:
            print >> logs, node.labelspan(), "\n", results, "----------"

    res = (-1, RES((-1, 0, []))) * newforest.root.oracles ## scale, remove TOP match
    num_gold = len(goldspans) - 1 ## omit TOP. N.B. goldspans, not brackets! (NP (NP ...))

    # the root table is keyed by number of test brackets; keep the best parseval
    best_parseval = None
    for num_test in res:
        ## num_matched, score, tree_str, edgelist = res[num_test]
        num_matched, score, edgelist = res[num_test]
        this = Parseval.get_parseval(num_matched, num_test, num_gold)
        if best_parseval is None or this < best_parseval:
            best_parseval = this
            best_score = score
            ## best_tree = tree_str
            best_edgelist = edgelist

    best_tree = Hyperedge.deriv2tree(best_edgelist)

    ## annotate the forest for oracle so that next-time you can preload oracle
    for edge in best_edgelist:
        edge.head.oracle_edge = edge

    ## very careful here: desymbol !
    ## return -best_score, best_parseval, Tree.parse(desymbol(best_tree)), best_edgelist
    return -best_score, best_parseval, best_tree, best_edgelist
def load(filename, is_tforest=False, lower=False, sentid=0, first=None, lm=None):
    '''Lazily read translation forests from ``filename``; returns a generator.

    Use ``load(...).next()`` for a single forest.  Strict record format (no
    consecutive blank lines):
      * header "<tag>TAB<foreign sentence>" (the tag is often empty);
      * the number of references, then that many reference lines;
      * a sizes line "num_nodes[TAB num_edges]" (the edge count is optional
        and unused);
      * per node, "iden TAB label-span TAB num_edges ||| features" followed by
        num_edges lines "[*]tails ||| <ruleid> [rule text] ||| features",
        where tails are quoted words or idens of earlier nodes ("*" marks
        the oracle edge);
      * an optional gold tree line starting with "(";
      * a separator line.
    The last node read becomes the root.

    Args:
        filename: passed to getfile().
        is_tforest: passed to Forest().
        lower: lowercase the words of the sentence (a cased copy is kept).
        sentid: sentence counter, incremented per forest.
        first: stop after this many forests; defaults to FLAGS.first.
        lm: optional language model.  If given, words are mapped with
            lm.word2index, stored in edge.lmlhsstr, and the summed
            lm.ngram.wordprob of an edge's words is stored as feature "lm1".

    Yields:
        Forest objects.  A header line that does not split into exactly two
        tab-separated fields yields None instead.  If no forest was read at
        all, a message is logged and sys.exit(1) is called.
    '''
    if first is None:  # N.B.: must be here, not in the param line (after program initializes)
        first = FLAGS.first

    infile = getfile(filename)
    total_time = 0
    num_sents = 0
    while True:
        start_time = time.time()
        ## header: '\tThe complicated language in ...\n' (tag is often missing)
        line = infile.readline()
        if len(line) == 0:
            break  # EOF
        try:
            tag, sent = line.split("\t")  # foreign sentence
        except ValueError:
            ## not a well-formed header: emit a placeholder and keep reading
            yield None
            continue

        num_sents += 1
        # caching the original, word-based, true-case sentence
        sent = sent.split()  ## no splitting with " "
        cased_sent = sent[:]
        if lower:
            sent = [w.lower() for w in sent]  # mark johnson: lowercase all words

        ## read in references
        refnum = int(infile.readline().strip())
        refs = [infile.readline().strip() for _ in xrange(refnum)]

        ## sizes: number of nodes, number of edges (optional, unused)
        num = int(infile.readline().split("\t")[0])

        forest = Forest(sent, cased_sent, tag, is_tforest)
        forest.tag = tag
        forest.refs = refs
        forest.bleu = Bleu(refs=refs)  ## initial (empty test) bleu; used repeatedly later
        forest.labelspans = {}
        forest.short_edges = {}
        forest.rules = {}  # rule id -> Rule, for edges that give only the id

        for _ in xrange(num):
            ## '2\tDT* [0-1]\t1 ||| 1232=2 ...\n'
            ## node-based features here: wordedges, greedyheavy, word(1), [word(2)], ...
            line = infile.readline()
            try:
                keys, fields = line.split(" ||| ")
            except ValueError:
                # node line without a feature part
                keys = line
                fields = ""
            iden, labelspan, size = keys.split("\t")  ## iden can be non-ints
            size = int(size)
            ## remove_blacklist(fvector)
            node = Node(iden, labelspan, size, Vector(fields), sent)
            forest.add_node(node)

            if cache_same:
                if labelspan in forest.labelspans:
                    node.same = forest.labelspans[labelspan]
                    node.fvector = node.same.fvector
                else:
                    forest.labelspans[labelspan] = node

            for _ in xrange(size):
                ## '\t1 ||| 0=8.86276 1=2 3\n'
                ## N.B.: can't just strip! "\t... ||| ... ||| \n" => 2 fields instead of 3
                tails, rule, fields = infile.readline().strip("\t\n").split(" ||| ")
                is_oracle = tails.startswith("*")  # oracle edge
                if is_oracle:
                    tails = tails[1:]

                tailnodes = []
                lhsstr = []    # e.g. [node123, "thank", node456]
                lmstr = []     # LM context: word indices since the last variable
                lmscore = 0
                lmlhsstr = []  # like lhsstr, but with words as LM indices
                for x in tails.split():  ## N.B.: don't split by " "!
                    if x[0] == '"':  # word
                        word = desymbol(x[1:-1])
                        lhsstr.append(word)  ## desymbol here and only here; ump will call quoteattr
                        if lm is not None:
                            widx = lm.word2index(word)
                            lmscore += lm.ngram.wordprob(widx, lmstr)
                            lmlhsstr.append(widx)
                            lmstr += [widx]
                    else:  # variable
                        assert x in forest.nodes, "BAD TOPOL ORDER: node #%s is referred to " % x + \
                            "(in a hyperedge of node #%s) before being defined" % iden
                        tail = forest.nodes[x]
                        tailnodes.append(tail)
                        lhsstr.append(tail)
                        if lm is not None:
                            lmstr = []  # "..." "..." x0 "...": context restarts after a variable
                            lmlhsstr.append(tail)  # sync with lhsstr

                fvector = Vector(fields)
                if lm is not None:
                    fvector["lm1"] = lmscore  # hack
                edge = Hyperedge(node, tailnodes, fvector, lhsstr)
                edge.lmlhsstr = lmlhsstr

                ## rule field: "<id> [rule text]"; id-only fields reuse a rule
                ## defined earlier in this same forest
                rule_tokens = rule.split()
                edge.ruleid = int(rule_tokens[0])
                if len(rule_tokens) > 1:
                    edge.rule = Rule.parse(" ".join(rule_tokens[1:]) + " ### " + fields)
                    forest.rules[edge.ruleid] = edge.rule
                else:
                    edge.rule = forest.rules[edge.ruleid]  # cached rule

                node.add_edge(edge)
                if is_oracle:
                    node.oracle_edge = edge

            if node.sp_terminal():
                node.word = node.edges[0].subs[0].word

        ## split nodes sort numerically by parts: "12-3-4" => (12, 3, 4)
        forest.nodeorder = sorted(forest.nodeorder,
                                  key=lambda n: [int(p) for p in n.iden.split("-")])
        forest.rehash()
        sentid += 1

        # the last node read is the root
        forest.root = node
        node.set_root(True)

        # optional gold tree; whatever line is read last here is discarded,
        # because the next iteration always reads a fresh header line
        line = infile.readline()
        if line.strip() != "" and line[0] == "(":
            forest.goldtree = Tree.parse(line.strip(), trunc=True, lower=False)
            line = infile.readline()

        forest.number_nodes()
        total_time += time.time() - start_time
        if num_sents % 100 == 0:
            print >> logs, "... %d sents loaded (%.2lf secs per sent) ..." \
                % (num_sents, total_time/num_sents)

        forest.subtree()  # compute the subtree string for each node
        yield forest

        if first is not None and num_sents >= first:
            break

    # better check here instead of zero-division exception
    if num_sents == 0:
        print >> logs, "NO FORESTS FOUND!!! (empty input file?)"
        sys.exit(1)

    Forest.load_time = total_time
    # num_sents > 0 is guaranteed here, so the average is exact
    print >> logs, "%d forests loaded in %.2lf secs (avg %.2lf per sent)" \
        % (num_sents, total_time, total_time/num_sents)
def add_nonter_th(self, node):
    ''' add translation hyperedges to non-terminal node

    For every parse hyperedge of ``node``, tree fragments rooted at ``node``
    are enumerated by combining the children's ``frags`` with
    PatternMatching.combinetwofrags.  A fragment is a tuple
    (lhs string, rhs list of nodes, height).  Fragments of height
    <= self.max_height - 1 are appended to ``node.frags``.  Then:
      * filter mode (self.filter true): only collect each fragment's lhs
        string into self.all_lhss; node.edges is left untouched;
      * otherwise: add one translation hyperedge per rule in
        self.ruleset[lhs], record descendant node idens in
        self.descendants[node.iden], add a default rule/hyperedge when
        ``tfedges`` is still empty, and finally replace node.edges with the
        translation hyperedges.

    NOTE(review): ``tfedges`` accumulates across all of node's parse edges,
    so the default hyperedge is added for a parse edge only while no
    earlier edge produced a translation hyperedge -- confirm this is intended.
    '''
    ruleset = self.ruleset
    tfedges = []
    for edge in node.edges:
        # enumerate all the possible frags
        basefrags = [("%s(" % node.label, [], 1)]
        lastchild = len(edge.subs) - 1
        if len(edge.subs) >= 5:
            # this guy has too many children! it cannot be matched!
            # -> skip enumeration; use only the flat one-level fragment
            deflhs = "%s(%s)" % (node.label, " ".join(sub.label for sub in edge.subs))
            defrhs = edge.subs
            defheight = 1
            basefrags = [(deflhs, defrhs, defheight)]
        else:
            # NOTE: `id` shadows the builtin inside this loop
            for (id, sub) in enumerate(edge.subs):
                oldfrags = basefrags
                # cross-product of the fragments so far with this child's fragments
                basefrags = [PatternMatching.combinetwofrags(oldfrag, frag, id, lastchild, self.max_height) \
                             for oldfrag in oldfrags for frag in sub.frags]
        # for each frag add translation hyperedges
        for extfrag in basefrags:
            (extlhs, extrhs, extheight) = extfrag
            # add frags (only those that can still grow within max_height)
            if extheight <= self.max_height - 1:
                node.frags.append(extfrag)
            if self.filter:
                # filter pass: just remember which lhs strings occur
                self.all_lhss.add(extlhs)
            else:
                # add translation hyperedges
                if extlhs in ruleset:
                    for des in extrhs:
                        self.descendants[node.iden].add(des.iden)
                    #unit(set(extrhs))
                    #print self.descendants[node.iden]
                    rules = ruleset[extlhs]
                    # add all translation hyperedges
                    for rule in rules:
                        # quoted rhs tokens are target words (quotes stripped);
                        # "xN" tokens refer to the N-th node of the fragment rhs
                        rhsstr = [x[1:-1] if x[0]=='"' \
                                  else extrhs[int(x.split('x')[1])] \
                                  for x in rule.rhs]
                        tfedge = Hyperedge(node, extrhs,\
                                           Vector(rule.fields), rhsstr)
                        tfedge.rule = rule
                        tfedges.append(tfedge)
        if (not self.filter) and (len(tfedges) == 0):
            # no translation hyperedge
            for des in edge.subs:
                self.descendants[node.iden].add(des.iden)
            #unit(set(edge.subs))
            # add a default translation hyperedge (monotonic copy of the children)
            deflhs = "%s(%s)" % (node.label, " ".join(sub.label for sub in edge.subs))
            defrhs = ["x%d" % i for i, _ in enumerate(edge.subs)] # N.B.: do not supply str
            defrule = Rule(deflhs, defrhs, self.deffields)
            tfedge = Hyperedge(node, edge.subs,\
                               Vector(self.deffields), edge.subs)
            tfedge.rule = defrule
            ruleset.add_rule(defrule)
            tfedges.append(tfedge)
    if not self.filter:
        # inside replace: the parse edges become translation hyperedges
        node.edges = tfedges