コード例 #1
0
ファイル: oracle.py プロジェクト: rupenp/transforest
def forest_oracle(forest, goldtree, del_puncs=False, prune_results=False):
	'''Find the derivation in the forest that scores best against goldtree.

	The forest is walked non-recursively, topological-sort style: every
	node reads the ``oracles`` attribute of its sub-nodes, so iterating
	the forest must yield sub-nodes before the nodes built from them.
	As a side effect each visited node gets an ``oracles`` attribute.

	Arguments:
	  forest        -- the forest to search.
	  goldtree      -- gold tree; its all_label_spans() (and tag_seq, when
	                   del_puncs is set) are compared against forest nodes.
	  del_puncs     -- if true, the forest is first passed through
	                   check_puncs(), WHICH ALSO MODIFIES THE FOREST, and
	                   the index mapping it returns is used to merge labels.
	  prune_results -- if true, prune() is applied to each node's oracle
	                   table before it is stored.

	Returns a 4-tuple:
	  (best_score, best_parseval, best_tree, best_edgelist)
	where best_tree is Hyperedge.deriv2tree(best_edgelist).

	Raises ValueError if the root oracle table contains no candidate.
	'''

	## modifies forest also!!
	if del_puncs:
		idx_mapping, newforest = check_puncs(forest, goldtree.tag_seq)
	else:
		idx_mapping, newforest = (lambda x: x), forest

	goldspans = merge_labels(goldtree.all_label_spans(), idx_mapping)
	goldbrs = set(goldspans) ## including TOP

	for node in newforest:
		if node.is_terminal():
			results = Oracles.unit("(%s %s)" % (node.label, node.word))  ## multiplication unit
		else:
			## test_inc is the key of the oracle entry and match_inc the first
			## field of its RES value; at the root these are read back as
			## num_test and num_matched.  Spurious nodes contribute neither.
			if node.is_spurious():
				test_inc, match_inc = 0, 0
			elif merge_label((node.label, node.span), idx_mapping) in goldbrs:
				test_inc, match_inc = 1, 1
			else:
				test_inc, match_inc = 1, 0

			results = Oracles()	 ## addition unit
			for edge in node.edges:
				edgeres = Oracles.unit()  ## multiplication unit
				for sub in edge.subs:
					edgeres = edgeres * sub.oracles

				## the edge's first feature value is stored negated; the
				## selected score is negated again on return.
				nodehead = (test_inc, RES((match_inc, -edge.fvector[0], [edge])))
				results += nodehead * edgeres   ## mul

		if prune_results:
			prune(results)
		node.oracles = results
		## debug and logs are module-level names defined outside this block.
		if debug:
			print >> logs, node.labelspan(), "\n", results, "----------"

	res = (-1, RES((-1, 0, []))) * newforest.root.oracles   ## scale, remove TOP match

	num_gold = len(goldspans) - 1 ## omit TOP.  N.B. goldspans, not brackets! (NP (NP ...))

	## Keep the candidate whose Parseval compares smallest under Parseval's
	## own ordering (presumably "smaller" means better there -- confirm).
	best_parseval = None
	best_score = None
	best_edgelist = None
	for num_test in res:
		num_matched, score, edgelist = res[num_test]
		this = Parseval.get_parseval(num_matched, num_test, num_gold)
		if best_parseval is None or this < best_parseval:
			best_parseval = this
			best_score = score
			best_edgelist = edgelist

	if best_edgelist is None:
		## previously this surfaced as an unbound-local error further down
		raise ValueError("forest_oracle: no oracle candidate at the forest root")

	best_tree = Hyperedge.deriv2tree(best_edgelist)

	return -best_score, best_parseval, best_tree, best_edgelist
コード例 #2
0
ファイル: remove_sp.py プロジェクト: rupenp/transforest
						 help="test pruning (e.g., 5:15:25)", metavar="RANGE", default=None)

	(opts, args) = optparser.parse_args()

	## Pruning thresholds arrive as "a:b:c" and are applied from the
	## largest value down to the smallest.
	prange = None
	if opts.range is not None:
		## NOTE(review): eval() executes whatever expression is passed on the
		## command line.  Acceptable only while the option is typed by a
		## trusted user; consider ast.literal_eval or explicit parsing.
		prange = eval("[%s]" % opts.range.replace(":", ","))
		prange.sort(reverse=True)

	if opts.quiet and opts.suffix is not None:
		optparser.error("-q and -s can not be present at the same time.")

	## Forest.load("-") presumably reads forests from stdin -- confirm.
	for i, forest in enumerate(Forest.load("-")):
		remove(forest)
		if not opts.quiet:
			if opts.suffix is not None:
				## close the per-forest output file explicitly instead of
				## leaving it to garbage collection
				outfile = open("%d.%s" % (i+opts.startid, opts.suffix), "wt")
				try:
					forest.dump(outfile)
				finally:
					outfile.close()
			else:
				forest.dump()

		if prange is not None:
			## thresholds are descending, so the same forest object is pruned
			## again at every step and dumped after each one
			for p in prange:
				prune(forest, p)
				forest.dump("%d.p%d" % (i+opts.startid, p))

		## every tenth forest, ask the monitor helper to collect garbage
		if i % 10 == 9:
			mymonitor.gc_collect()
		
コード例 #3
0
ファイル: oracle.py プロジェクト: rupenp/transforest
				  best_parseval, "%.4lf" % (best_score - base_score)

		
		## Add this forest's results to the running totals that are printed
		## in the summary: bres[1] scored against the gold tree goes into the
		## "1-best" total, the oracle tree scored against the gold tree into
		## the "real" forest total.
		onebest_parseval += Parseval(bres[1], f.goldtree)
		real_parseval = Parseval(best_tree, f.goldtree)
		all_real_parseval += real_parseval
		##assert real_parseval == best_parseval
		## N.B.: can't make this comparison work, so keep it separate.

		## best_parseval is the value computed by the oracle itself; it is
		## accumulated separately from the re-scored real_parseval above.
		all_parseval += best_parseval

		## Optionally prune the forest at each threshold in prange and
		## recompute the oracle on the pruned forest, accumulating the
		## results per threshold.
		if prange is not None:
			for p in prange:
				prune(f, p)
				## forest_oracle returns four values (score, parseval, tree,
				## edge list); unpacking only three raised ValueError here.
				sc, parseval, tr, _edgelist = forest_oracle(f, f.goldtree)
				pruned_parseval[p] += parseval
				pruned_real_parseval[p] += Parseval(tr, f.goldtree)

				if opts.suffix is not None:
					f.dump("%d.%s%d" % (i+1, opts.suffix, p))
				

	## Summary: print the three accumulated Parseval totals to stdout.
	print "1-best (real)", onebest_parseval
	print "forest (punc)", all_parseval
	print "forest (real)", all_real_parseval

	## Timing goes to the logs stream.  i is presumably the zero-based index
	## left over from the forest loop (so i+1 is the forest count) -- confirm;
	## it would be unbound here if that loop never ran.
	total_time = time.time() - start_time
	print >> logs, "%d forests oracles computed in %.2lf secs (avg %.2lf secs per sent)" % \
		  (i+1, total_time, total_time/(i+1))