def greedy_search(gold, test, classify, max_iters=100):
    """Greedily repair ``test`` towards ``gold``, one successor at a time.

    Starting from a clone of ``test``, each iteration computes the parse
    errors of the current tree against ``gold``. It then moves to the
    successor (from ``successors``) that removes the most errors. Moves that
    leave the error count unchanged are accepted; moves that increase it are
    never taken. The search stops when the current tree has no errors.

    Args:
        gold: The gold-standard tree.
        test: The tree to repair. It is cloned, not modified.
        classify: Callback invoked as ``classify(info, gold, test)`` once for
            every step of a successful path. This includes the initial
            ``{'type': 'init'}`` step.
        max_iters: Give up once more than this many steps have been taken.

    Returns:
        ``((0, iters), path)`` on success. ``path`` is the list of
        ``(tree, info, change)`` tuples visited, where ``change`` is the
        reduction in error count achieved by that step.

        ``((0, iters), None)`` if the step limit is exceeded or no successor
        avoids increasing the error count. ``classify`` is not called in
        either failure case.

    Raises:
        Exception: If a successor tree fails ``check_consistency()``.
    """
    # Each state is (tree, info describing how it was reached, error reduction).
    cur = (test.clone(), {'type': 'init'}, 0)

    iters = 0
    path = []
    while True:
        path.append(cur)
        if iters > max_iters:
            return (0, iters), None

        # Check for victory
        ctree = cur[0]
        cerrors = parse_errors.ParseErrorSet(gold, ctree)
        if len(cerrors) == 0:
            break

        # Pick the successor that removes the most errors. Ties keep the
        # first one seen; successors that add errors are skipped.
        # NOTE(review): the current tree is scored with ParseErrorSet while
        # successors are scored with get_errors -- confirm both count
        # errors the same way.
        best = None
        for _fixes, ntree, info in successors(ctree, cerrors, gold):
            if not ntree.check_consistency():
                raise Exception("Inconsistent tree! {}".format(ntree))
            nerrors = parse_errors.get_errors(ntree, gold)
            change = len(cerrors) - len(nerrors)
            if change < 0:
                continue
            if best is None or change > best[2]:
                best = (ntree, info, change)

        # Dead end: no successor keeps the error count from growing.
        # Report failure the same way as hitting the step limit, rather than
        # continuing with ``cur = None`` and failing on ``cur[0]``.
        if best is None:
            return (0, iters), None

        cur = best
        iters += 1

    for step in path:
        classify(step[1], gold, test)

    return (0, iters), path
def gen_move_successor(source_span, left, right, new_parent, cerrors, gold):
    """Build the successor produced by moving a run of subtrees.

    Moves ``source_span.subtrees[left:right + 1]`` under ``new_parent`` via
    ``tree_transform.move_nodes`` and records a description of the move in
    ``info``. If ``ParseErrorSet`` then reports a missing bracket that
    strictly contains the moved nodes and fits inside the new parent, that
    bracket is also added with ``tree_transform.add_node``.

    Args:
        source_span: Node whose children are being moved.
        left: Index of the first child to move.
        right: Index of the last child to move (inclusive).
        new_parent: Node the children are moved under.
        cerrors: Not used here. Presumably kept so all successor generators
            share one signature -- confirm against callers.
        gold: Gold-standard tree. Used for the POS-confusion check and to
            find brackets still missing after the move.

    Returns:
        ``(False, ntree, info)``. ``ntree`` is the transformed tree. ``info``
        is a dict describing the move:

        - labels of the old and new parents, the movers, and their siblings
          before and after the move;
        - preterminals under the lowest common ancestor of both parents;
        - when applicable, 'POS confusion', 'added and moved', 'added label'
          and 'adding node already present'.
    """
    # NOTE(review): the third argument (False) presumably requests a
    # non-in-place move; ``source_span`` is still read below for
    # 'old_family' -- confirm.
    success, response = tree_transform.move_nodes(
        source_span.subtrees[left:right + 1], new_parent, False)
    assert success, response
    # ``new_parent`` is rebound to the parent node returned by move_nodes,
    # whose children include the moved ``nodes``.
    ntree, nodes, new_parent = response
    new_left = new_parent.subtrees.index(nodes[0])
    new_right = new_parent.subtrees.index(nodes[-1])

    # Find the lowest common ancestor of the new and old parents: walk up
    # from the new parent until a node covers both of their spans.
    full_span = (min(source_span.span[0], new_parent.span[0]),
                 max(source_span.span[1], new_parent.span[1]))
    lca = new_parent
    while not (lca.span[0] <= full_span[0] and full_span[1] <= lca.span[1]):
        lca = lca.parent

    info = {
        'type': 'move',
        'old_parent': get_label(source_span),
        'new_parent': get_label(new_parent),
        'movers': [get_label(node) for node in nodes],
        'mover info': [(get_label(node), node.span) for node in nodes],
        'new_family': [get_label(subtree) for subtree in new_parent.subtrees],
        'old_family': [get_label(subtree) for subtree in source_span.subtrees],
        'start left siblings':
            [get_label(node) for node in source_span.subtrees[:left]],
        'start right siblings':
            [get_label(node) for node in source_span.subtrees[right + 1:]],
        'end left siblings':
            [get_label(node) for node in new_parent.subtrees[:new_left]],
        'end right siblings':
            [get_label(node) for node in new_parent.subtrees[new_right + 1:]],
        'auto preterminals': get_preterminals(lca),
        'auto preterminal span': lca.span
    }

    # A single node covering one word was moved. Descend through first
    # children to the node carrying the word, then compare its label with
    # the gold node returned for the same span.
    if left == right and nodes[-1].span[1] - nodes[-1].span[0] == 1:
        preterminal = nodes[-1]
        while preterminal.word is None:
            preterminal = preterminal.subtrees[0]
        gold_eq = gold.get_nodes('lowest', preterminal.span[0],
                                 preterminal.span[1])
        if gold_eq is not None:
            info['POS confusion'] = (get_label(preterminal),
                                     get_label(gold_eq))

    # Consider fixing a missing node in the new location as well. Choose a
    # missing bracket (error[1] is its span, error[2] its label) that:
    #   - strictly contains the movers' span, and
    #   - lies within the new parent.
    # Prefer narrower brackets (ones starting later or ending earlier).
    mover_span = (nodes[0].span[0], nodes[-1].span[1])
    nerrors = parse_errors.ParseErrorSet(gold, ntree)
    to_fix = None
    for error in nerrors.missing:
        if not (error[1][0] <= mover_span[0] and mover_span[1] <= error[1][1]):
            continue
        if error[1] == mover_span:
            continue
        if error[1][0] < new_parent.span[0] or error[1][1] > new_parent.span[1]:
            continue
        if (to_fix is None or to_fix[1][0] < error[1][0]
                or error[1][1] < to_fix[1][1]):
            to_fix = error

    if to_fix is not None:
        info['added and moved'] = True
        # Use the selected bracket's label. The loop variable ``error`` only
        # holds the last missing bracket examined, which may differ.
        info['added label'] = to_fix[2]

        # Record whether exactly one child of the new parent, other than the
        # movers, lies inside the new bracket and already has its label.
        # NOTE(review): the strict '<' comparisons exclude children touching
        # either edge of the bracket -- confirm this is intended.
        unmoved = [
            node for node in new_parent.subtrees
            if to_fix[1][0] < node.span[0] and node.span[1] < to_fix[1][1]
            and node not in nodes
        ]
        info['adding node already present'] = bool(
            len(unmoved) == 1 and unmoved[0].label == to_fix[2])

        success, response = tree_transform.add_node(
            ntree, to_fix[1], to_fix[2], in_place=False)
        assert success, response
        ntree, _nnode = response

    return (False, ntree, info)