def sisters_(x, y, equality=False):
    """Return whether *x* and *y* hang off the same parent node.

    Both arguments must be FunctionNodes and *x* must have a parent;
    otherwise the answer is False. *equality* is accepted for interface
    compatibility but is not consulted here.
    """
    if not isFunctionNode(x) or not isFunctionNode(y):
        return False

    mother = x.parent
    if mother is None:
        # A node without a parent has no sisters
        return False
    return mother == y.parent
def sisters_(x, y, equality=False):
    """Return whether *x* and *y* hang off the same parent node.

    Both arguments must be FunctionNodes and *x* must have a parent;
    otherwise the answer is False. *equality* is accepted for interface
    compatibility but is not consulted here.
    """
    if not isFunctionNode(x) or not isFunctionNode(y):
        return False

    mother = x.parent
    if mother is None:
        # A node without a parent has no sisters
        return False
    return mother == y.parent
def schemestring(x, d=0, bv_names=None):
    """Outputs a scheme string in (lambda (x) (+ x 3)) format.

    Arguments:
        x: We return the string for this FunctionNode. Plain strings are returned unchanged.
        d: Depth of x; it is incremented on every recursive call.
        bv_names: A dictionary from the uuids to nicer names.

    Returns:
        The scheme-style string, or None if x is neither a string nor a FunctionNode.
    """
    if isinstance(x, str):
        return x
    elif isFunctionNode(x):
        if bv_names is None:
            bv_names = dict()

        name = x.name
        if isinstance(x, BVUseFunctionNode):
            name = bv_names.get(x.name, x.name)

        if x.args is None:
            return name

        # Render the children and join them with spaces. Formatting the raw map()
        # result with %s would print a list repr instead of scheme syntax.
        args = ' '.join([schemestring(a, d + 1, bv_names=bv_names) for a in x.args])

        if isinstance(x, BVAddFunctionNode):
            # Compare by value; identity comparison on strings is not reliable
            assert name == 'lambda'
            # Show the binder under the same nicer name its uses are shown with
            bound = bv_names.get(x.added_rule.name, x.added_rule.name)
            return "(%s (%s) %s)" % (name, bound, args)

        # Every other node with arguments, including an applied bound variable
        return "(%s %s)" % (name, args)
def pystring(x, d=0, bv_names=None):
    """Render *x* as a string of python source; bound variables are named by their depth.

    Args:
        x: A string (returned as-is) or a FunctionNode to render.
        d: Depth of *x*; used to build the name of a bound variable introduced by a lambda.
        bv_names: A dictionary from the uuids to nicer names. An entry is added while a
            lambda body is rendered and removed again afterwards.

    Returns:
        The rendered string, or None if *x* is neither a string nor a FunctionNode.
    """
    if isinstance(x, str):
        return x
    if not isFunctionNode(x):
        return None

    if bv_names is None:
        bv_names = dict()

    def render(child, depth):
        # Recurse while sharing the same bv_names table
        return pystring(child, d=depth, bv_names=bv_names)

    label = x.name

    if label == "if_":
        assert len(x.args) == 3, "if_ requires 3 arguments!"
        # scheme order is (if bool s t); python wants (s if bool else t)
        cond = render(x.args[0], d + 1)
        yes = render(x.args[1], d + 1)
        no = render(x.args[2], d + 1)
        return '( %s if %s else %s )' % (yes, cond, no)

    if label == '':
        # An unnamed node is transparent: render its only child at the same depth
        assert len(x.args) == 1, "Null names must have exactly 1 argument"
        return render(x.args[0], d)

    if label == ',':
        # comma join
        return ', '.join([render(a, d) for a in x.args])

    if label == "apply_":
        assert x.args is not None and len(x.args)==2, "Apply requires exactly 2 arguments"
        return '( %s )( %s )' % tuple([render(a, d) for a in x.args])

    if label == 'lambda':
        # Register the introduced bound variable for the body, then unregister it
        introduces_bv = isinstance(x, BVAddFunctionNode) and x.added_rule is not None
        binder = ''
        if introduces_bv:
            binder = x.added_rule.bv_prefix + str(d)
            bv_names[x.added_rule.name] = binder
        assert len(x.args) == 1
        body = render(x.args[0], d + 1)
        rendered = 'lambda %s: %s' % (binder, body)
        if introduces_bv:
            try:
                del bv_names[x.added_rule.name]
            except KeyError:
                x.fullprint()
        return rendered

    if percent_s_regex.search(label):
        # The name itself is a %s template: substitute the rendered arguments
        return label % tuple([render(a, d + 1) for a in x.args])

    # Ordinary function call (or a bare name when there are no arguments)
    name = label
    if isinstance(x, BVUseFunctionNode):
        name = bv_names.get(label, label)
    if x.args is None:
        return name
    return name + '(' + ', '.join([render(a, d + 1) for a in x.args]) + ')'
def iterate_subnodes(self, t, d=0, predicate=lambdaTrue, do_bv=True, yield_depth=False):
    """
    Iterate through all subnodes of node *t*, while updating the added rules (bound variables)
    so that at each subnode, the grammar is accurate to what it was.

    *t*: a FunctionNode or a list of them; anything else yields nothing

    *d*: depth of *t*; lists do not add depth, only FunctionNodes do

    *do_bv*: if False, we don't do bound variables (useful for things like counting nodes,
    instead of having to update the grammar)

    *yield_depth*: if True, we return (node, depth) instead of node

    *predicate*: filter only the ones that match this

    NOTE: the rule of each node stays installed while its subtree is being iterated.
    TODO: Make this more elegant -- use BVCM
    """
    if isFunctionNode(t):
        # print "iterate subnode: ", t, t.added_rule

        if predicate(t):
            yield (t, d) if yield_depth else t

        if do_bv:
            # Define a new context that is the grammar with the rule added.
            # Then, when we exit, it's still right
            with BVRuleContextManager(self, t.added_rule):
                if t.args is not None:
                    for g in self.iterate_subnodes(t.args, d=d+1, do_bv=do_bv, yield_depth=yield_depth, predicate=predicate):
                        yield g  # pass up anything from below
        elif t.args is not None:
            # do_bv=False: walk the children without touching the grammar at all
            for g in self.iterate_subnodes(t.args, d=d+1, do_bv=do_bv, yield_depth=yield_depth, predicate=predicate):
                yield g

    elif isinstance(t, list):
        for a in t:
            for g in self.iterate_subnodes(a, d=d, do_bv=do_bv, yield_depth=yield_depth, predicate=predicate):
                yield g
def schemestring(x, d=0, bv_names=None):
    """Outputs a scheme string in (lambda (x) (+ x 3)) format.

    Arguments:
        x: We return the string for this FunctionNode. Plain strings are returned unchanged.
        d: Depth of x; it is incremented on every recursive call.
        bv_names: A dictionary from the uuids to nicer names.

    Returns:
        The scheme-style string, or None if x is neither a string nor a FunctionNode.
    """
    if isinstance(x, str):
        return x
    elif isFunctionNode(x):
        if bv_names is None:
            bv_names = dict()

        name = x.name
        if isinstance(x, BVUseFunctionNode):
            name = bv_names.get(x.name, x.name)

        if x.args is None:
            return name

        # Render the children and join them with spaces. Formatting the raw map()
        # result with %s would print a list repr instead of scheme syntax.
        args = ' '.join([schemestring(a, d + 1, bv_names=bv_names) for a in x.args])

        if isinstance(x, BVAddFunctionNode):
            # Compare by value; identity comparison on strings is not reliable
            assert name == 'lambda'
            # Show the binder under the same nicer name its uses are shown with
            bound = bv_names.get(x.added_rule.name, x.added_rule.name)
            return "(%s (%s) %s)" % (name, bound, args)

        # Every other node with arguments, including an applied bound variable
        return "(%s %s)" % (name, args)
def first_dominated_(x, t):
    # Returns the first node produced by iterating over x that is of nonterminal type t
    # And None otherwise (including when x is not a FunctionNode)
    if not isFunctionNode(x):
        return None

    for candidate in x:
        if is_nonterminal_type(candidate, t):
            return candidate

    return None
def first_dominated_(x, t):
    # Returns the first node produced by iterating over x that is of nonterminal type t
    # And None otherwise (including when x is not a FunctionNode)
    if not isFunctionNode(x):
        return None

    for candidate in x:
        if is_nonterminal_type(candidate, t):
            return candidate

    return None
def trim_leaves_(t):
    """
    Replace the terminal nodes (leaves) of tree t with their returntypes, e.g.

        next_(next_(((nine_ if True else four_) if equal_(ten_, ten_) else one_)))
    becomes
        next_(next_(((WORD if BOOL else WORD) if equal_(WORD, WORD) else WORD)))

    Anything that is not a FunctionNode is handed back untouched; a terminal
    FunctionNode is replaced by its returntype.

    NOTE: This modifies t! (its args list is replaced in place)
    """
    if not isFunctionNode(t):
        return t
    if t.is_terminal():
        return t.returntype

    kids = t.args
    if kids is not None:
        trimmed = []
        for kid in kids:
            if isFunctionNode(kid) and kid.is_terminal():
                trimmed.append(kid.returntype)
            else:
                trimmed.append(trim_leaves_(kid))
        t.args = trimmed

    return t
def recurse_down(y):
    # Depth-first search below y for the node x.
    # x and anc are not defined here: they are free variables taken from the enclosing scope.
    # Returns a true value when y (or, for a list, some FunctionNode in it) lies on the
    # path down to x, and a false value otherwise.
    # Side effect: every FunctionNode found on that path is appended to anc, deepest first.
    #print "RD", y, "\t", x
    if isinstance(y, list):
        # Only FunctionNode elements are searched.
        # NOTE(review): with Python 2's eager map() every element is visited even after a
        # hit; a lazy map() would let any() stop early -- confirm which is intended.
        return any(map(recurse_down, filter(isFunctionNode, y)))
    elif isFunctionNode(y):
        # Recurse first, so deeper ancestors land in anc before y itself
        if recurse_down(y.args) or immediately_dominates(y, x):
            anc.append(y) # put y on the end
            return True
    return False
def ancestors_(x):
    # Collect the chain of .parent links above x, nearest ancestor first.
    # Anything that is not a FunctionNode has no ancestors.
    if not isFunctionNode(x):
        return []

    chain = []
    node = x.parent
    while node is not None:
        chain.append(node)
        node = node.parent
    return chain
def recurse_down(y):
    # Depth-first search below y for the node x.
    # x and anc are not defined here: they are free variables taken from the enclosing scope.
    # Returns a true value when y (or, for a list, some FunctionNode in it) lies on the
    # path down to x, and a false value otherwise.
    # Side effect: every FunctionNode found on that path is appended to anc, deepest first.
    #print "RD", y, "\t", x
    if isinstance(y,list):
        # Only FunctionNode elements are searched.
        # NOTE(review): with Python 2's eager map() every element is visited even after a
        # hit; a lazy map() would let any() stop early -- confirm which is intended.
        return any(map(recurse_down, filter(isFunctionNode, y)))
    elif isFunctionNode(y):
        # Recurse first, so deeper ancestors land in anc before y itself
        if recurse_down(y.args) or immediately_dominates(y, x):
            anc.append(y) # put y on the end
            return True
    return False
def co_refers(x, y):
    # Check if two FunctionNodes or strings co-refer, i.e. coref_matcher finds an index in
    # both names and the first captured group is equal in both.

    # By stipulation, nothing co-refers to itself
    if x is y:
        return False

    # Weird corner cases: lists and None never co-refer with anything
    if isinstance(x, list) or isinstance(y, list) or x is None or y is None:
        return False

    # FunctionNodes are compared by name; anything else is used directly
    label_x = x.name if isFunctionNode(x) else x
    label_y = y.name if isFunctionNode(y) else y

    hit_x = coref_matcher.search(label_x)
    hit_y = coref_matcher.search(label_y)

    if hit_x is not None and hit_y is not None:
        # The defaults given to groups() differ on purpose, so that two groups that did
        # not participate in the match can never compare equal
        return (hit_x.groups("X")[0] == hit_y.groups("Y")[0])
    return False
def first_dominating_(T, x, t):
    # Walk upwards from x within tree T (via tree_up) and return the first node above x
    # that is of nonterminal type t. Returns None when there is none, or when x is not
    # a FunctionNode.
    if not isFunctionNode(x):
        return None

    node = tree_up(T, x)
    while node is not None and not is_nonterminal_type(node, t):
        node = tree_up(T, node)
    return node
def first_dominating_(T, x, t):
    # Walk upwards from x within tree T (via tree_up) and return the first node above x
    # that is of nonterminal type t. Returns None when there is none, or when x is not
    # a FunctionNode.
    if not isFunctionNode(x):
        return None

    node = tree_up(T, x)
    while node is not None and not is_nonterminal_type(node, t):
        node = tree_up(T, node)
    return node
def co_refers(x, y):
    # Check if two FunctionNodes or strings co-refer, i.e. coref_matcher finds an index in
    # both names and the first captured group is equal in both.

    # By stipulation, nothing co-refers to itself
    if x is y:
        return False

    # Weird corner cases: lists and None never co-refer with anything
    if isinstance(x, list) or isinstance(y, list) or x is None or y is None:
        return False

    # FunctionNodes are compared by name; anything else is used directly
    label_x = x.name if isFunctionNode(x) else x
    label_y = y.name if isFunctionNode(y) else y

    hit_x = coref_matcher.search(label_x)
    hit_y = coref_matcher.search(label_y)

    if hit_x is not None and hit_y is not None:
        # The defaults given to groups() differ on purpose, so that two groups that did
        # not participate in the match can never compare equal
        return (hit_x.groups("X")[0] == hit_y.groups("Y")[0])
    return False
def fullstring(x, d=0, bv_names=None):
    """
    A string mapping function that is for equality checking. This is necessary because
    pystring silently ignores FunctionNode.names that are ''. Here, we print out everything,
    including returntypes.

    :param x: a string (returned as-is) or the FunctionNode to render
    :param d: depth of x; used to name a bound variable introduced by a lambda
    :param bv_names: a dictionary from the uuids to nicer names; an entry is added while a
        lambda body is rendered and removed again afterwards
    :return: the rendered string, or None if x is neither a string nor a FunctionNode
    """
    if isinstance(x, str):
        return x
    if not isFunctionNode(x):
        return None

    if bv_names is None:
        bv_names = dict()

    if x.name == 'lambda':
        # Register the introduced bound variable for the body, then unregister it
        introduces_bv = isinstance(x, BVAddFunctionNode) and x.added_rule is not None
        binder = ''
        if introduces_bv:
            binder = x.added_rule.bv_prefix + str(d)
            bv_names[x.added_rule.name] = binder
        assert len(x.args) == 1
        body = fullstring(x.args[0], d=d + 1, bv_names=bv_names)
        rendered = 'lambda<%s> %s: %s' % (x.returntype, binder, body)
        if introduces_bv:
            try:
                del bv_names[x.added_rule.name]
            except KeyError:
                x.fullprint()
        return rendered

    # Every other node: name<returntype>, followed by the arguments if there are any
    name = x.name
    if isinstance(x, BVUseFunctionNode):
        name = bv_names.get(x.name, x.name)
    if x.args is None:
        return "%s<%s>" % (name, x.returntype)
    pieces = [fullstring(a, d=d + 1, bv_names=bv_names) for a in x.args]
    return "%s<%s>(%s)" % (name, x.returntype, ', '.join(pieces))
def is_terminal(self, x):
    """
    A terminal is not a nonterminal and either has no children or its children are
    terminals themselves.

    *x*: a string, a list, or a FunctionNode

    Returns False if x is a nonterminal, or if x is a list / FunctionNode containing
    something that is not a terminal; True otherwise.
    """
    if self.is_nonterminal(x):
        return False

    if isinstance(x, list):
        for k in x:
            if not self.is_terminal(k):
                return False

    if isFunctionNode(x):
        # if you are structured, you must not contain nonterminals below.
        # The arguments live on the node x; self is the object owning this method.
        if x.args is not None:
            for k in x.args:
                if not self.is_terminal(k):
                    return False

    # else we get here for strings, etc.
    return True
def to_regex(fn):
    """
    Translate a function node into the regular expression string it denotes
    (e.g. "(ab)*(c|d)"). Node names other than the ones handled below trip an assertion.
    """
    assert isFunctionNode(fn)

    op = fn.name

    # Unary operators: wrap the translated child in a group plus its quantifier
    for tag, template in (('star_', '(%s)*'), ('plus_', '(%s)+'), ('question_', '(%s)?')):
        if op == tag:
            return template % to_regex(fn.args[0])

    if op == 'or_':
        return '(%s|%s)' % tuple([to_regex(a) for a in fn.args])
    if op == 'str_append_':
        # The first argument is used as-is; only the second is translated
        return '%s%s' % (fn.args[0], to_regex(fn.args[1]))
    if op == 'terminal_':
        return '%s' % fn.args[0]
    if op == '':
        return to_regex(fn.args[0])

    assert False, fn
def fullstring(x, d=0, bv_names=None):
    """
    A string mapping function that is for equality checking. This is necessary because
    pystring silently ignores FunctionNode.names that are ''. Here, we print out everything,
    including returntypes.

    :param x: a string (returned as-is) or the FunctionNode to render
    :param d: depth of x; used to name a bound variable introduced by a lambda
    :param bv_names: a dictionary from the uuids to nicer names; an entry is added while a
        lambda body is rendered and removed again afterwards
    :return: the rendered string, or None if x is neither a string nor a FunctionNode
    """
    if isinstance(x, str):
        return x
    if not isFunctionNode(x):
        return None

    if bv_names is None:
        bv_names = dict()

    if x.name == 'lambda':
        # Register the introduced bound variable for the body, then unregister it
        introduces_bv = isinstance(x, BVAddFunctionNode) and x.added_rule is not None
        binder = ''
        if introduces_bv:
            binder = x.added_rule.bv_prefix + str(d)
            bv_names[x.added_rule.name] = binder
        assert len(x.args) == 1
        body = fullstring(x.args[0], d=d + 1, bv_names=bv_names)
        rendered = 'lambda<%s> %s: %s' % (x.returntype, binder, body)
        if introduces_bv:
            try:
                del bv_names[x.added_rule.name]
            except KeyError:
                x.fullprint()
        return rendered

    # Every other node: name<returntype>, followed by the arguments if there are any
    name = x.name
    if isinstance(x, BVUseFunctionNode):
        name = bv_names.get(x.name, x.name)
    if x.args is None:
        return "%s<%s>" % (name, x.returntype)
    pieces = [fullstring(a, d=d + 1, bv_names=bv_names) for a in x.args]
    return "%s<%s>(%s)" % (name, x.returntype, ', '.join(pieces))
def iterate_subnodes(self, t, d=0, predicate=lambdaTrue, do_bv=True, yield_depth=False):
    """
    Iterate through all subnodes of node *t*, while updating the added rules (bound variables)
    so that at each subnode, the grammar is accurate to what it was.

    *t*: a FunctionNode or a list of them; anything else yields nothing

    *d*: depth of *t*; lists do not add depth, only FunctionNodes do

    *do_bv*: if False, we don't do bound variables (useful for things like counting nodes,
    instead of having to update the grammar)

    *yield_depth*: if True, we return (node, depth) instead of node

    *predicate*: filter only the ones that match this

    NOTE: the rule of each node stays installed while its subtree is being iterated.
    TODO: Make this more elegant -- use BVCM
    """
    if isFunctionNode(t):
        # print "iterate subnode: ", t, t.added_rule

        if predicate(t):
            yield (t, d) if yield_depth else t

        if do_bv:
            # Define a new context that is the grammar with the rule added.
            # Then, when we exit, it's still right
            with BVRuleContextManager(self, t.added_rule):
                if t.args is not None:
                    for g in self.iterate_subnodes(t.args, d=d + 1, do_bv=do_bv, yield_depth=yield_depth, predicate=predicate):
                        yield g  # pass up anything from below
        elif t.args is not None:
            # do_bv=False: walk the children without touching the grammar at all
            for g in self.iterate_subnodes(t.args, d=d + 1, do_bv=do_bv, yield_depth=yield_depth, predicate=predicate):
                yield g

    elif isinstance(t, list):
        for a in t:
            for g in self.iterate_subnodes(a, d=d, do_bv=do_bv, yield_depth=yield_depth, predicate=predicate):
                yield g
def to_regex(fn):
    """
    Translate a function node into the regular expression string it denotes
    (e.g. "(ab)*(c|d)"). Node names other than the ones handled below trip an assertion.
    """
    assert isFunctionNode(fn)

    op = fn.name

    # Unary operators: wrap the translated child in a group plus its quantifier
    for tag, template in (('star_', '(%s)*'), ('plus_', '(%s)+'), ('question_', '(%s)?')):
        if op == tag:
            return template % to_regex(fn.args[0])

    if op == 'or_':
        return '(%s|%s)' % tuple([to_regex(a) for a in fn.args])
    if op == 'str_append_':
        # The first argument is used as-is; only the second is translated
        return '%s%s' % (fn.args[0], to_regex(fn.args[1]))
    if op == 'terminal_':
        return '%s' % fn.args[0]
    if op == '':
        return to_regex(fn.args[0])

    assert False, fn
def iterate_subnodes(self, t, d=0, predicate=lambdaTrue, do_bv=True, yield_depth=False):
    """
    Iterate through all subnodes of t, while updating my added rules (bound variables)
    so that at each subnode, the grammar is accurate to what it was

    t -- a FunctionNode or a list of them; anything else yields nothing
    d -- depth of t; lists do not add depth, only FunctionNodes do
    do_bv -- if False, we don't do bound variables (useful for things like counting nodes,
             instead of having to update the grammar)
    yield_depth -- if True, we return (node, depth) instead of node
    predicate -- filter only the ones that match this

    NOTE: a bound-variable rule stays in the grammar while the subtree of its node is
    being iterated. It is removed in a finally clause, i.e. when the subtree is exhausted,
    when a consumer's exception propagates through this generator, or when the generator
    is closed after the iteration was stopped in the middle.
    """
    if isFunctionNode(t):

        if predicate(t):
            yield (t, d) if yield_depth else t

        #print "iterate subnode: ", t.name, t.bv_type, t
        has_bv = do_bv and t.bv_type is not None
        added = None
        if has_bv:
            added = self.add_bv_rule(t.bv_type, t.bv_args, d)

        try:
            if t.args is not None:
                for g in self.iterate_subnodes(t.args, d=d+1, do_bv=do_bv, yield_depth=yield_depth, predicate=predicate):
                    yield g  # pass up anything from below
        finally:
            # And remove them, even if we never got to the end of the subtree
            if has_bv:
                self.remove_rule(added)

    elif isinstance(t, list):
        for a in t:
            for g in self.iterate_subnodes(a, d=d, do_bv=do_bv, yield_depth=yield_depth, predicate=predicate):
                yield g
def pystring(x, d=0, bv_names=None):
    """Render *x* as a string of python source; bound variables are named by their depth.

    Args:
        x: A string (returned as-is) or a FunctionNode to render.
        d: Depth of *x*; used to build the name of a bound variable introduced by a lambda.
        bv_names: A dictionary from the uuids to nicer names. An entry is added while a
            lambda body is rendered and removed again afterwards.

    Returns:
        The rendered string, or None if *x* is neither a string nor a FunctionNode.
    """
    if isinstance(x, str):
        return x
    if not isFunctionNode(x):
        return None

    if bv_names is None:
        bv_names = dict()

    def render(child, depth):
        # Recurse while sharing the same bv_names table
        return pystring(child, d=depth, bv_names=bv_names)

    label = x.name

    if label == "if_":
        assert len(x.args) == 3, "if_ requires 3 arguments!"
        # scheme order is (if bool s t); python wants (s if bool else t)
        cond = render(x.args[0], d + 1)
        yes = render(x.args[1], d + 1)
        no = render(x.args[2], d + 1)
        return '( %s if %s else %s )' % (yes, cond, no)

    if label == '':
        # An unnamed node is transparent: render its only child at the same depth
        assert len(x.args) == 1, "Null names must have exactly 1 argument"
        return render(x.args[0], d)

    if label == ',':
        # comma join
        return ', '.join([render(a, d) for a in x.args])

    if label == "apply_":
        assert x.args is not None and len(
            x.args) == 2, "Apply requires exactly 2 arguments"
        return '( %s )( %s )' % tuple([render(a, d) for a in x.args])

    if label == 'lambda':
        # Register the introduced bound variable for the body, then unregister it
        introduces_bv = isinstance(x, BVAddFunctionNode) and x.added_rule is not None
        binder = ''
        if introduces_bv:
            binder = x.added_rule.bv_prefix + str(d)
            bv_names[x.added_rule.name] = binder
        assert len(x.args) == 1
        body = render(x.args[0], d + 1)
        rendered = 'lambda %s: %s' % (binder, body)
        if introduces_bv:
            try:
                del bv_names[x.added_rule.name]
            except KeyError:
                x.fullprint()
        return rendered

    if percent_s_regex.search(label):
        # The name itself is a %s template: substitute the rendered arguments
        return label % tuple([render(a, d + 1) for a in x.args])

    # Ordinary function call (or a bare name when there are no arguments)
    name = label
    if isinstance(x, BVUseFunctionNode):
        name = bv_names.get(label, label)
    if x.args is None:
        return name
    return name + '(' + ', '.join([render(a, d + 1) for a in x.args]) + ')'
def increment_tree(self, x, depth):
    """
    A lazy version of tree enumeration. Here, we generate all trees, starting from a rule
    or a nonterminal symbol. This is constant memory

    *x*: either a FunctionNode whose args may contain nonterminals to be expanded, or a
    single nonterminal symbol

    *depth*: remaining depth budget; terminal rules are yielded while depth >= 0 and
    nonterminal rules are only expanded while depth > 0 (recursing with depth-1)

    Yields FunctionNodes. For a FunctionNode *x*, x.args is overwritten in place as the
    enumeration advances and a copy of x is yielded for each combination.

    NOTE(review): this relies on Python 2 generator semantics -- a StopIteration raised in
    the body (explicitly, or by .next() on an exhausted child iterator) simply ends this
    generator. Under PEP 479 (Python 3.7+) that becomes a RuntimeError.
    """
    # Bound variables are not supported here
    assert_or_die( self.bv_count==0, "Error: increment_tree not yet implemented for bound variables." )

    if LOTlib.SIG_INTERRUPTED: return # quit if interrupted

    if isFunctionNode(x) and depth >= 0 and x.args is not None:
        #print "FN:", x, depth

        # Short-circuit if we can

        # add the rules
        #addedrules = [ self.add_bv_rule(b,depth) for b in x.bv ]

        # Keep the untouched args around so that a carried position can be restarted
        original_x = copy(x)

        # go all odometer on the kids below::
        # one iterator per nonterminal argument, None for every other argument
        iters = [self.increment_tree(y,depth) if self.is_nonterminal(y) else None for y in x.args]
        if len(iters) == 0:
            yield copy(x)
        else:

            # First, initialize the arguments
            for i in xrange(len(iters)):
                if iters[i] is not None:
                    x.args[i] = iters[i].next()

            # the index of the last position that has an iterator (may not be len(iters)-1);
            # this is -1 when no argument is a nonterminal
            last_terminal_idx = max( [i if iters[i] is not None else -1 for i in xrange(len(iters))] )

            ## Now loop through the args, counting them up
            continue_counting = True
            while continue_counting: # while we continue incrementing

                yield copy(x) # yield the initial tree, and then each successive tree

                # and then process each carry:
                for carry_pos in xrange(len(iters)): # index into which tree we are incrementing
                    if iters[carry_pos] is not None: # we are not a terminal symbol (mixed in)

                        try:
                            x.args[carry_pos] = iters[carry_pos].next()
                            break # if we increment successfully, no carry, so break out of i loop
                        except StopIteration: # if so, then "carry"
                            if carry_pos == last_terminal_idx:
                                continue_counting = False # done counting here
                            elif iters[carry_pos] is not None:
                                # reset the incrementer since we just carried
                                iters[carry_pos] = self.increment_tree(original_x.args[carry_pos],depth)
                                x.args[carry_pos] = iters[carry_pos].next() # reset this
                                # and just continue your loop over i (which processes the carry)

        #print "REMOVING", addedrule
        #[ self.remove_rule(r) for r in addedrules ]# remove bv rule

    elif self.is_nonterminal(x): # just a single nonterminal

        ## TODO: somewhat inefficient since we do this each time:
        ## Here we change the order of rules to be terminals *first*
        ## else we don't enumerate small to large, which is clearly insane
        terminals = []
        nonterminals = []
        for k in self.rules[x]:
            if self.is_terminal(k.to):
                terminals.append(k)
            else:
                nonterminals.append(k)

        #print ">>", terminals
        #print ">>", nonterminals

        Z = logsumexp([ log(r.p) for r in self.rules[x]] ) # normalizer

        if depth >= 0:
            # yield each of the rules that lead to terminals
            for r in terminals:
                n = FunctionNode(returntype=r.nt, name=r.name, args=deepcopy(r.to), generation_probability=(log(r.p) - Z), bv_type=r.bv_type, bv_args=r.bv_args, ruleid=r.rid )
                yield n

        if depth > 0:
            # and expand each nonterminals
            for r in nonterminals:
                n = FunctionNode(returntype=r.nt, name=r.name, args=deepcopy(r.to), generation_probability=(log(r.p) - Z), bv_type=r.bv_type, bv_args=r.bv_args, ruleid=r.rid )
                for q in self.increment_tree(n, depth-1):
                    yield q
    else:
        # Nothing to enumerate for anything else
        raise StopIteration
def tree_up_(x):
    # The parent of x when x is a FunctionNode; None for anything else
    return x.parent if isFunctionNode(x) else None
def generate(self, x='START', d=0):
    """
    Generate from the PCFG -- default is to start from x - either a
    nonterminal or a FunctionNode

    *x*: what to generate from. A list is mapped over elementwise; '*gaussian*' and
    '*uniform*' produce a FunctionNode named after a freshly drawn number; None gives None;
    a nonterminal is expanded by sampling one of its rules; a FunctionNode is copied and
    regenerated below; anything else must be a terminal string and is returned as-is.

    *d*: the current depth. Lists do not count as depth -- only FunctionNodes do.

    TODO: We can make this limit the depth, if we want. Maybe that's dangerous?
    TODO: Add a check that we don't have any leftover bound variable rules, when d==0
    """
    if isinstance(x, list):
        # If we get a list, just map along it to generate. We don't count lists as depth--only FunctionNodes
        return [self.generate(x=xi, d=d) for xi in x]

    elif x == '*gaussian*':
        ## TODO: HIGHLY EXPERIMENTAL!! Wow this is really terrible for mixing...
        v = np.random.normal()
        gp = normlogpdf(v, 0.0, 1.0)
        return FunctionNode(returntype=x, name=str(v), args=None, generation_probability=gp, ruleid=0, resample_p=CONSTANT_RESAMPLE_P)  ##TODO: FIX THE ruleid

    elif x == '*uniform*':
        v = np.random.rand()
        gp = 0.0
        return FunctionNode(returntype=x, name=str(v), args=None, generation_probability=gp, ruleid=0, resample_p=CONSTANT_RESAMPLE_P)  ##TODO: FIX THE ruleid

    elif x is None:
        return None

    elif self.is_nonterminal(x):
        # if we generate a nonterminal, then sample a GrammarRule, convert it to a FunctionNode
        # via nt->returntype, name->name, to->args,
        # And recurse
        r, gp = weighted_sample(self.rules[x], probs=lambda rule: rule.p, return_probability=True, log=False)

        if r.bv_type is None:
            # No bound variable involved: just expand the "to" part of our rule
            args = self.generate(r.to, d=d+1)
            return FunctionNode(returntype=r.nt, name=r.name, args=args, generation_probability=gp, ruleid=r.rid)

        # The rule introduces a bound variable: its rule must be in the grammar while we
        # expand below, and must be taken out again even if that expansion fails
        added = self.add_bv_rule(r.bv_type, r.bv_args, d)
        try:
            args = self.generate(r.to, d=d+1)
        finally:
            self.remove_rule(added)

        ## UGH, bv_type=r.bv_type -- here bv_type is really bv_returntype. THIS SHOULD BE FIXED
        return FunctionNode(returntype=r.nt, name=r.name, args=args, generation_probability=gp, bv_type=r.bv_type, bv_name=added.name, bv_args=r.bv_args, ruleid=r.rid)

    elif isFunctionNode(x):
        # for function Nodes, we are able to generate by copying and expanding the children
        ret = copy(x)
        # NOTE(review): this reads and writes .to, while FunctionNodes are built with args=...
        # above -- confirm whether .args is meant here
        ret.to = self.generate(ret.to, d=d+1)  # re-generate below -- importantly the "to" points are re-generated, not copied
        return ret

    else:  # must be a terminal
        # str(x): building the message must not itself fail when x is not a string
        assert_or_die(isinstance(x, str), "Terminal must be a string! x=" + str(x))
        return x
def increment_tree_(self,
                    x=None,
                    depth=0,
                    max_depth=Infinity,
                    depthdict=None):
    """
    A lazy version of tree enumeration. Here, we generate all trees, starting from a rule or a nonterminal symbol and going up to max_depth

    This is constant memory and should produce each tree *once* (However: if a grammar has multiple derivations of the same
    str(tree), then you will see repeats!).

    TODO: CHANGE THIS TO ENUMERATE SHALLOW->DEEP

    *x*: A node in the tree -- a FunctionNode whose args may contain nonterminals to expand,
    or a single nonterminal symbol. Anything else is yielded back unchanged.

    *depth*: Depth of the tree

    *max_depth*: nothing is produced once depth >= max_depth

    *depthdict* : memoizes depth_to_terminal so that we can order rules in order to make enumeration small->large

    For a FunctionNode *x*, x.args is overwritten in place as the enumeration advances and
    a copy of x is yielded for each combination.

    NOTE(review): this relies on Python 2 generator semantics -- a StopIteration raised in
    the body (explicitly, or by .next() on an exhausted child iterator) simply ends this
    generator. Under PEP 479 (Python 3.7+) that becomes a RuntimeError.
    """
    # wrap no specification for x

    # Out of depth budget: end the generator without yielding anything
    if depth >= max_depth:
        raise StopIteration

    if isFunctionNode(x):
        # NOTE: WE don't need to handle BV here since they are handled below when we use the rule

        # Keep the untouched args around so that a carried position can be restarted
        original_x = copy(x)

        # go all odometer on the kids below::
        # one iterator per nonterminal argument, None for every other argument
        iters = [
            self.increment_tree_(
                x=y, depth=depth, max_depth=max_depth, depthdict=depthdict)
            if self.is_nonterminal(y) else None for y in x.args
        ]
        if len(iters) == 0:
            yield copy(x)
        else:
            #print "HERE", iters
            # Initialize every expandable argument with its first tree
            for i in xrange(len(iters)):
                if iters[i] is not None:
                    x.args[i] = iters[i].next()

            # the index of the last position that has an iterator (may not be len(iters)-1);
            # this is -1 when no argument is a nonterminal
            last_terminal_idx = max([
                i if iters[i] is not None else -1 for i in xrange(len(iters))
            ])

            ## Now loop through the args, counting them up
            while True:

                yield copy(
                    x
                )  # yield the initial tree, and then each successive tree

                # and then process each carry:
                for carry_pos in xrange(
                        len(iters)
                ):  # index into which tree we are incrementing
                    if iters[
                            carry_pos] is not None:  # we are not a terminal symbol (mixed in)

                        ## NOTE: This *MUST* go here in order to prevent adding a rule and then not removing it when you carry (thus introducing a bv of a1 into a2)
                        with BVRuleContextManager(self, x.added_rule):
                            try:
                                x.args[carry_pos] = iters[carry_pos].next()
                                break  # if we increment successfully, no carry, so break out of i loop
                            except StopIteration:  # if so, then "carry"
                                if carry_pos == last_terminal_idx:
                                    # the last position overflowed: enumeration is complete
                                    raise StopIteration
                                elif iters[carry_pos] is not None:
                                    # reset the incrementer since we just carried
                                    iters[
                                        carry_pos] = self.increment_tree_(
                                            x=original_x.args[carry_pos],
                                            depth=depth,
                                            max_depth=max_depth,
                                            depthdict=depthdict)
                                    x.args[carry_pos] = iters[
                                        carry_pos].next()  # reset this
                                    # and just continue your loop over i (which processes the carry)

    elif self.is_nonterminal(x):  # just a single nonterminal

        ## TODO: somewhat inefficient since we do this each time:
        ## Here we change the order of rules to be terminals *first*
        terminals = []
        nonterminals = []
        for k in self.rules[x]:
            if not self.is_terminal_rule(
                    k
            ):  #AAH this used to be called "x" and that ruined the scope of the outer "x"
                nonterminals.append(k)
            else:
                terminals.append(k)

        # sort each group by depth_to_terminal (memoized through depthdict), smallest first
        terminals = sorted(
            terminals,
            key=lambda r: self.depth_to_terminal(r, current_d=depthdict))
        nonterminals = sorted(
            nonterminals,
            key=lambda r: self.depth_to_terminal(r, current_d=depthdict))

        Z = logsumexp([log(r.p) for r in self.rules[x]])  # normalizer

        #print terminals
        #print nonterminals
        #print "---------------------------------------"

        # yield each of the rules that lead to terminals -- always do this since depth>=0 (above)
        for r in terminals:
            fn = r.make_FunctionNodeStub(self, (log(r.p) - Z))
            # Do not need to set added_rule since they can't exist here
            yield fn

        # NOTE: always true at this point, since depth >= max_depth already ended us above
        if depth < max_depth:  # if we can go deeper
            for r in nonterminals:  #expand each nonterminals
                fn = r.make_FunctionNodeStub(self, (log(r.p) - Z))

                for q in self.increment_tree_(x=fn,
                                              depth=depth + 1,
                                              max_depth=max_depth,
                                              depthdict=depthdict):
                    yield q
    else:
        # Neither a FunctionNode nor a nonterminal: hand it back as the only result
        yield x
def increment_tree_(self, x=None, depth=0, max_depth=Infinity, depthdict=None):
    """
    A lazy version of tree enumeration. Here, we generate all trees, starting from a rule or a nonterminal symbol and going up to max_depth

    This is constant memory and should produce each tree *once* (However: if a grammar has multiple derivations of the same
    str(tree), then you will see repeats!).

    TODO: CHANGE THIS TO ENUMERATE SHALLOW->DEEP

    *x*: A node in the tree -- a FunctionNode whose args may contain nonterminals to expand,
    or a single nonterminal symbol. Anything else is yielded back unchanged.

    *depth*: Depth of the tree

    *max_depth*: nothing is produced once depth >= max_depth

    *depthdict* : memoizes depth_to_terminal so that we can order rules in order to make enumeration small->large

    For a FunctionNode *x*, x.args is overwritten in place as the enumeration advances and
    a copy of x is yielded for each combination.

    NOTE(review): this relies on Python 2 generator semantics -- a StopIteration raised in
    the body (explicitly, or by .next() on an exhausted child iterator) simply ends this
    generator. Under PEP 479 (Python 3.7+) that becomes a RuntimeError.
    """
    # wrap no specification for x

    # Out of depth budget: end the generator without yielding anything
    if depth >= max_depth:
        raise StopIteration

    if isFunctionNode(x):
        # NOTE: WE don't need to handle BV here since they are handled below when we use the rule

        # Keep the untouched args around so that a carried position can be restarted
        original_x = copy(x)

        # go all odometer on the kids below::
        # one iterator per nonterminal argument, None for every other argument
        iters = [self.increment_tree_(x=y,depth=depth,max_depth=max_depth, depthdict=depthdict) if self.is_nonterminal(y) else None for y in x.args]
        if len(iters) == 0:
            yield copy(x)
        else:
            #print "HERE", iters
            # Initialize every expandable argument with its first tree
            for i in xrange(len(iters)):
                if iters[i] is not None:
                    x.args[i] = iters[i].next()

            # the index of the last position that has an iterator (may not be len(iters)-1);
            # this is -1 when no argument is a nonterminal
            last_terminal_idx = max( [i if iters[i] is not None else -1 for i in xrange(len(iters))] )

            ## Now loop through the args, counting them up
            while True:

                yield copy(x) # yield the initial tree, and then each successive tree

                # and then process each carry:
                for carry_pos in xrange(len(iters)): # index into which tree we are incrementing
                    if iters[carry_pos] is not None: # we are not a terminal symbol (mixed in)

                        ## NOTE: This *MUST* go here in order to prevent adding a rule and then not removing it when you carry (thus introducing a bv of a1 into a2)
                        with BVRuleContextManager(self, x.added_rule):
                            try:
                                x.args[carry_pos] = iters[carry_pos].next()
                                break # if we increment successfully, no carry, so break out of i loop
                            except StopIteration: # if so, then "carry"
                                if carry_pos == last_terminal_idx:
                                    # the last position overflowed: enumeration is complete
                                    raise StopIteration
                                elif iters[carry_pos] is not None:
                                    # reset the incrementer since we just carried
                                    iters[carry_pos] = self.increment_tree_(x=original_x.args[carry_pos],depth=depth,max_depth=max_depth, depthdict=depthdict)
                                    x.args[carry_pos] = iters[carry_pos].next() # reset this
                                    # and just continue your loop over i (which processes the carry)

    elif self.is_nonterminal(x): # just a single nonterminal

        ## TODO: somewhat inefficient since we do this each time:
        ## Here we change the order of rules to be terminals *first*
        terminals = []
        nonterminals = []
        for k in self.rules[x]:
            if not self.is_terminal_rule(k): #AAH this used to be called "x" and that ruined the scope of the outer "x"
                nonterminals.append(k)
            else:
                terminals.append(k)

        # sort each group by depth_to_terminal (memoized through depthdict), smallest first
        terminals = sorted(terminals, key=lambda r: self.depth_to_terminal(r, current_d=depthdict) )
        nonterminals = sorted(nonterminals, key=lambda r: self.depth_to_terminal(r, current_d=depthdict) )

        Z = logsumexp([ log(r.p) for r in self.rules[x]] ) # normalizer

        #print terminals
        #print nonterminals
        #print "---------------------------------------"

        # yield each of the rules that lead to terminals -- always do this since depth>=0 (above)
        for r in terminals:
            fn = r.make_FunctionNodeStub(self, (log(r.p) - Z))
            # Do not need to set added_rule since they can't exist here
            yield fn

        # NOTE: always true at this point, since depth >= max_depth already ended us above
        if depth < max_depth: # if we can go deeper
            for r in nonterminals:#expand each nonterminals
                fn = r.make_FunctionNodeStub(self, (log(r.p) - Z))

                for q in self.increment_tree_(x=fn, depth=depth+1,max_depth=max_depth, depthdict=depthdict):
                    yield q
    else:
        # Neither a FunctionNode nor a nonterminal: hand it back as the only result
        yield x