def visit_Call(self, node):
    """Translate a call node, special-casing bool(), hasattr() and .any().

    Calls that match none of the special forms are passed on to
    self.call_helper with isstmt=False.

    Raises:
        PythonToDARhiError: if a recognized special form has the wrong
            number of arguments, or if the second argument of hasattr()
            is not a string literal.
    """
    callee = node.func

    if isinstance(callee, ast.Name):
        if callee.id == 'bool':
            # bool(S) is interpreted as a set emptiness check: it becomes
            # a match of a fresh unbound variable against S.
            if len(node.args) != 1:
                raise PythonToDARhiError('Invalid call')
            source = self.visit(node.args[0])
            if not self.objectdomain:
                # This branch matches by name, so the operand must be a Name.
                dka.assertnodetype(source, dha.Name)
                fresh = dha.genUnboundVar(st.getFreshNSymbol())
                return dha.Match(dha.PatMatchName(fresh, source.id),
                                 dka.Symtab())
            fresh = dha.genUnboundVar(st.getFreshNSymbol())
            return dha.Match(dha.PatMatch(fresh, source), dka.Symtab())

        if callee.id == 'hasattr':
            if len(node.args) != 2:
                raise PythonToDARhiError('Invalid call')
            attr_arg = node.args[1]
            # Only a literal attribute name is accepted.
            if not isinstance(attr_arg, ast.Str):
                raise PythonToDARhiError('Reflection not allowed in hasattr')
            return dha.HasAttr(self.visit(node.args[0]), attr_arg.s)

    elif isinstance(callee, ast.Attribute):
        if callee.attr == 'any':
            # X.any() becomes the unary Any operator applied to X.
            if len(node.args) != 0:
                raise PythonToDARhiError('Invalid call')
            return dha.UnaryOp(dha.Any(), self.visit(callee.value))

    # Anything else is an ordinary call in expression position.
    return self.call_helper(node, isstmt=False)
def visit_SelAttr(self, node):
    """Replace an attribute selector with a SelName over its field variable.

    The children are visited first; afterwards node.path must be a
    dha.SelName. The variable comes from self.fieldvar_helper, and the
    returned node's _t is set from that variable's t attribute.
    """
    self.generic_visit(node)
    dka.assertnodetype(node.path, dha.SelName)

    fieldvar = self.fieldvar_helper(node.path.id, node.attr)
    replacement = dha.SelName(fieldvar)
    replacement._t = fieldvar.t
    return replacement
def isBoundPatVar(node):
    """Return True if the PatVar node is bound, False if it is unbound.

    The decision is made by identity comparison of node.bnd against
    P_BOUND and P_UNBOUND.

    Raises:
        AssertionError: if node.bnd is neither P_BOUND nor P_UNBOUND.
    """
    dka.assertnodetype(node, PatVar)
    if node.bnd is P_BOUND:
        return True
    if node.bnd is P_UNBOUND:
        return False
    # Raise explicitly instead of using `assert()`: assert statements are
    # removed under `python -O`, which would make this silently return None.
    raise AssertionError(
        'PatVar has unrecognized binding state: %r' % (node.bnd,))
def isAddUpdate(node):
    """Test update type.

    Return True if the SetUpdate node's op is an UpAdd, and False if it
    is an UpRemove.

    Raises:
        AssertionError: if node.op is neither UpAdd nor UpRemove.
    """
    dka.assertnodetype(node, SetUpdate)
    if dka.hasnodetype(node.op, UpAdd):
        return True
    if dka.hasnodetype(node.op, UpRemove):
        return False
    # Raise explicitly instead of using `assert()`: assert statements are
    # removed under `python -O`, which would make this silently return None.
    raise AssertionError(
        'SetUpdate has unrecognized operation: %r' % (node.op,))
def visit_RefUpdate(self, node):
    """Emit a Python call `<target>.incref(<value>)` or `<target>.decref(<value>)`.

    An UpAdd operation selects 'incref'; the only other accepted
    operation, UpRemove, selects 'decref'.
    """
    dka.assertnodetype(node.op, {dha.UpAdd, dha.UpRemove})

    if dka.hasnodetype(node.op, dha.UpAdd):
        func = 'incref'
    else:
        func = 'decref'

    receiver = ast.Name(node.target.name, ast.Load())
    method = ast.Attribute(receiver, func, ast.Load())
    arguments = [self.visit(node.value)]
    # NOTE(review): the trailing `[], None, None` presumably fill the
    # keywords/starargs/kwargs slots of the older ast.Call signature --
    # confirm against the Python version this targets.
    return ast.Call(method, arguments, [], None, None)
def visit_SetUpdate(self, node):
    """Emit a Python call `<target>.add(<value>)` or `<target>.remove(<value>)`.

    An UpAdd operation selects 'add'; the only other accepted operation,
    UpRemove, selects 'remove'.
    """
    dka.assertnodetype(node.op, {dha.UpAdd, dha.UpRemove})

    if dka.hasnodetype(node.op, dha.UpAdd):
        op = 'add'
    else:
        op = 'remove'

    receiver = ast.Name(node.target.name, ast.Load())
    method = ast.Attribute(receiver, op, ast.Load())
    arguments = [self.visit(node.value)]
    # NOTE(review): the trailing `[], None, None` presumably fill the
    # keywords/starargs/kwargs slots of the older ast.Call signature --
    # confirm against the Python version this targets.
    return ast.Call(method, arguments, [], None, None)
def eliminatedecomp(decomp, code):
    """Try to remove a decomposing assignment by substituting into code.

    decomp must be a dha.Assign. If tuplematcher succeeds on its target
    and value, the bindings it collected are applied to code with
    st.replaceSymbols and code is returned without the assignment.
    Otherwise a new list with decomp placed in front of code is returned.
    """
    dka.assertnodetype(decomp, dha.Assign)

    substitutions = {}
    matched = tuplematcher(decomp.target, decomp.value, substitutions)
    if not matched:
        # The assignment cannot be eliminated; keep it ahead of the code.
        return [decomp] + code

    # replaceSymbols is called for its effect on `code`; its return value
    # is not used (presumably it rewrites in place -- verify).
    st.replaceSymbols(code, substitutions)
    return code
def hasWildcards(pattern):
    """Return whether a pattern contains any wildcards.

    A PatIgnore counts as a wildcard. A PatTuple contains one if any of
    its elements does (checked recursively). Every other pattern kind
    contains none.
    """
    dka.assertnodetype(pattern, dha.pattern)

    if dka.hasnodetype(pattern, dha.PatIgnore):
        return True

    if dka.hasnodetype(pattern, dha.PatTuple):
        # Stop at the first element that contains a wildcard.
        for element in pattern.elts:
            if hasWildcards(element):
                return True
        return False

    return False
def visit_For(self, node):
    """Translate a for loop into a dha.For over a pattern match.

    When self.objectdomain is true, the iterated expression is matched
    directly with PatMatch. Otherwise the iterated expression must be a
    dha.Name and the match is made by name with PatMatchName.
    """
    target = self.pattern_helper(node.target)
    source = self.visit(node.iter)
    body = self.stmtlist_helper(node.body)
    orelse = self.stmtlist_helper(node.orelse)

    if not self.objectdomain:
        dka.assertnodetype(source, dha.Name)
        match = dha.PatMatchName(target, source.id)
    else:
        match = dha.PatMatch(target, source)

    return dha.For(match, body, orelse)
def visit_SetUpdate(self, node):
    """Pair an update to a membership set with an update on the container.

    After visiting children, if ismember(node.target) holds, the updated
    value is split into a (container, element) pair and a new SetUpdate
    on the container is returned ahead of the original node, as a
    two-element list. Otherwise the node is returned unchanged.
    """
    self.generic_visit(node)

    ssym = node.target
    if not ismember(ssym):
        return node

    # Work on copies so the original node keeps its own subtrees.
    cont, val = matchpair(dka.copy(node.value))
    dka.assertnodetype(cont, {dha.Name, dha.PatVar})
    container_update = dha.SetUpdate(cont.id, dka.copy(node.op), val)
    return [container_update, node]
def visit_AttrEnum(self, node):
    """Rebuild an attribute enumerator and prepend the collected field enums.

    self.fieldenums is reset before the children are visited. The
    enumerator is rebuilt as AttrEnum or NegAttrEnum according to
    dha.isPosEnum(node), and is returned as the last element of a list
    that starts with whatever self.fieldenums_helper() yields.
    """
    self.fieldenums = []
    self.generic_visit(node)
    dka.assertnodetype(node.iter, dha.SelName)

    if dha.isPosEnum(node):
        enum_type = dha.AttrEnum
    else:
        enum_type = dha.NegAttrEnum
    rebuilt = enum_type(node.target, dha.SelName(node.iter.id))

    leading = self.fieldenums_helper()
    return leading + [rebuilt]
def visit_AttrEnum(self, node):
    """Turn an attribute enumerator into a plain (Neg)Enum.

    If the iterated symbol is an MSymbol or dha.isfixset accepts it, the
    enumerator ranges directly over that symbol. Otherwise it ranges
    over the set given by self.getMemSym for the target's _t, with a
    pair pattern made of the bound iterated symbol and the visited target.
    """
    dka.assertnodetype(node.iter, dha.SelName)
    ssym = node.iter.id

    if dha.isPosEnum(node):
        enum_type = dha.Enum
    else:
        enum_type = dha.NegEnum

    if isinstance(ssym, MSymbol) or dha.isfixset(ssym):
        return enum_type(node.target, ssym)

    container = dha.genBoundVar(ssym)
    target = self.visit(node.target)
    pair = dha.PatTuple([container, target])
    return enum_type(pair, self.getMemSym(target._t))
def visit_If(self, node):
    """Emit nested Python ast.If nodes for a multi-case If.

    Every case must be a dha.CondCase. All conditions and bodies are
    visited in case order. The chain is then assembled from the last
    case outward, so that each ast.If carries the remaining cases (or
    the visited orelse) as its else part. The outermost element is returned.
    """
    branches = []
    for case in node.cases:
        dka.assertnodetype(case, dha.CondCase)
        cond = self.visit(case.cond)
        body = self.visit(case.body)
        branches.append((cond, body))

    # Construct the tree of nested if/elifs inside out.
    if node.orelse:
        tree = self.visit(node.orelse)
    else:
        tree = []
    for cond, body in reversed(branches):
        tree = [ast.If(cond, body, tree)]

    return tree[0]
def visit_Attribute(self, node):
    """Replace an attribute access, depending on whether we are in a comprehension.

    Outside a comprehension (self.incomp false) the access becomes a
    Pick2nd lookup into the field symbol from self.getFieldSym. Inside
    comprehension result or condition expressions it becomes a Name over
    the pair-domain variable from self.fieldvar_helper.
    """
    # Read the flag before visiting the children, as the original does.
    in_comprehension = self.incomp

    # Both cases process the children first (post-traversal).
    self.generic_visit(node)

    if in_comprehension:
        dka.assertnodetype(node.value, dha.Name)
        fieldvar = self.fieldvar_helper(node.value.id, node.attr)
        return dha.Name(fieldvar)

    fieldsym = self.getFieldSym(node.attr, node.value._t)
    return dha.Pick2nd(fieldsym, node.value)
def visit_SetUpdate(self, node):
    """Pair an update to a field set with an attribute update or deletion.

    After visiting children, if isfield(node.target) holds, the updated
    value is split into a (container, value) pair. An add update yields
    an AttrUpdate on the container and any other update yields a DelAttr.
    That new node is returned ahead of the original node, as a
    two-element list. Otherwise the node is returned unchanged.
    """
    self.generic_visit(node)

    ssym = node.target
    if not isfield(ssym):
        return node

    attr = ssym.attr
    # Split a copy so the original node keeps its own value subtree.
    cont, val = matchpair(dka.copy(node.value))
    dka.assertnodetype(cont, {dha.Name, dha.PatVar})

    if dha.isAddUpdate(node):
        attr_change = dha.AttrUpdate(cont.id, attr, val)
    else:
        attr_change = dha.DelAttr(cont.id, attr)
    return [attr_change, node]
def visit_InvDef(self, node):
    """Rewrite an invariant definition's comprehension around its parameters.

    node.value must be a dha.RelSetComp. When node.id.info.enumparams is
    non-empty, the comprehension's result element becomes a tuple of
    (parameter tuple, copy of the old element). The comprehension is then
    passed to st.cleanPatterns, and each parameter symbol is replaced by
    a VSymbol named '<name>_local'.

    Nothing is returned; the comprehension node is updated in place.
    """
    dka.assertnodetype(node.value, dha.RelSetComp)
    self.generic_visit(node)

    compnode = node.value
    params = node.id.info.enumparams

    # One '<name>_local' replacement symbol per parameter.
    repldict = {}
    for sym in params:
        repldict[sym] = dha.VSymbol(sym.name + '_local')

    if len(params) > 0:
        paramtup = du.genDTuple([dha.Name(sym) for sym in params])
        compnode.elt = dha.Tuple([paramtup, dka.copy(compnode.elt)])

    st.cleanPatterns(compnode)
    st.replaceSymbols(compnode, repldict)
def visit_While(self, node):
    """Translate a while loop into a dha.While or dha.PatWhile.

    A test that is a single `x in S` comparison becomes a pattern-match
    loop (PatWhile); any other test is visited and wrapped in a plain
    While. In the pattern-match case, when self.objectdomain is false,
    the right-hand side must be a dha.Name and is matched by name.
    """
    body = self.stmtlist_helper(node.body)
    orelse = self.stmtlist_helper(node.orelse)

    test = node.test
    is_membership = (isinstance(test, ast.Compare) and
                     len(test.ops) == len(test.comparators) == 1 and
                     isinstance(test.ops[0], ast.In))

    if not is_membership:
        return dha.While(self.visit(test), body, orelse)

    # Membership tests become pattern matches.
    left = self.visit(test.left)
    right = self.visit(test.comparators[0])

    if self.objectdomain:
        match = dha.PatMatch(dha.valueToPattern(left), right)
    else:
        # Check the visited right-hand operand. The original code passed
        # the builtin `iter` here, so the operand was never checked.
        dka.assertnodetype(right, dha.Name)
        match = dha.PatMatchName(dha.valueToPattern(left), right.id)

    return dha.PatWhile(match, body, orelse)
def visit_If(self, node):
    """Translate an if statement into a single-case dha.If.

    A test that is a single `x in S` comparison becomes a pattern-match
    case (PatCase); any other test is visited and wrapped in a CondCase.
    In the pattern-match case, when self.objectdomain is false, the
    right-hand side must be a dha.Name and is matched by name.
    """
    body = self.stmtlist_helper(node.body)
    orelse = self.stmtlist_helper(node.orelse)

    test = node.test
    is_membership = (isinstance(test, ast.Compare) and
                     len(test.ops) == len(test.comparators) == 1 and
                     isinstance(test.ops[0], ast.In))

    if is_membership:
        # Membership tests become pattern matches.
        left = self.visit(test.left)
        right = self.visit(test.comparators[0])
        if self.objectdomain:
            match = dha.PatMatch(dha.valueToPattern(left), right)
        else:
            dka.assertnodetype(right, dha.Name)
            # PatMatchName takes the pattern and the set's name as two
            # separate arguments, as in the other PatMatchName calls in
            # this file. The original passed `right.od` (a typo for
            # `right.id`) as a second argument to valueToPattern instead.
            match = dha.PatMatchName(dha.valueToPattern(left), right.id)
        case = dha.PatCase(match, body)
    else:
        case = dha.CondCase(self.visit(test), body)

    return dha.If([case], orelse)
def visit_Compare(self, node):
    """Translate a comparison into a BinOp or, for membership, a Match.

    Only a single comparator is allowed. `in` and `not in` become a
    dha.Match (wrapped in a Not for `not in`); every other operator
    becomes a dha.BinOp over the visited operands.

    Raises:
        PythonToDARhiError: if the comparison has more than one comparator.
    """
    # Make sure only simple comparisons are used.
    if len(node.comparators) > 1:
        raise PythonToDARhiError('Complex comparison used')

    op = self.visit(node.ops[0])

    if not dka.hasnodetype(op, {dha.In, dha.NotIn}):
        left = self.visit(node.left)
        right = self.visit(node.comparators[0])
        return dha.BinOp(left, op, right)

    # Turn In, Not In into Match nodes.
    target = self.pattern_helper(node.left)
    source = self.visit(node.comparators[0])
    if self.objectdomain:
        match = dha.PatMatch(target, source)
    else:
        dka.assertnodetype(source, dha.Name)
        match = dha.PatMatchName(target, source.id)
    code = dha.Match(match, dka.Symtab())

    if dka.hasnodetype(op, {dha.NotIn}):
        return dha.UnaryOp(dha.Not(), code)
    return code
def isPosEnum(node):
    """Test enumerator type.

    Return the result of checking that node is a RelEnum; the only other
    accepted node type is RelNegEnum.
    """
    dka.assertnodetype(node, {RelEnum, RelNegEnum})
    is_positive = dka.hasnodetype(node, RelEnum)
    return is_positive
def genTagInfo(comp):
    """Build the TagInfo for a comprehension and attach it to comp.info.

    Returns None without doing anything when comp.info.uset is None.
    Otherwise the first enumerator must range over the U-set and bind
    exactly the symbols in info.unconsparams. Every later enumerator
    must have a two-variable tuple pattern (container, element).

    For the i-th such enumerator, this records the container symbol
    (As), the element symbol (Bs), the iterated relation (Rs) and a
    fresh TagSym named 'Dm_<compsym id>_<i+1>' (DRs). It then generates
    one auxiliary demand comprehension per enumerator (tagcomps).

    The new TagInfo is stored as info.taginfo and returned.
    """
    info = comp.info
    compnode = comp.getCompNode()
    uset = info.uset
    if uset is None:
        return

    enums = compnode.enums
    assert(enums[0].iter is uset)
    rootsyms = set(st.gatherSymbols(enums[0].target))
    unconsparams = info.unconsparams
    assert(rootsyms == set(unconsparams))

    taginfo = TagInfo()
    info.taginfo = taginfo
    taginfo.compinfo = info

    pairenums = enums[1:]
    taginfo.As = As = []
    taginfo.Bs = Bs = []
    taginfo.Rs = Rs = []
    taginfo.DRs = DRs = []

    for i, enum in enumerate(pairenums):
        dka.assertnodetype(enum.target, dha.PatTuple)
        elts = enum.target.elts
        assert(len(elts) == 2)
        dka.assertnodetype(elts[0], dha.PatVar)
        dka.assertnodetype(elts[1], dha.PatVar)
        As.append(elts[0].id)
        Bs.append(elts[1].id)
        Rs.append(enum.iter)
        DRs.append(TagSym('Dm_' + info.compsym.id + '_' + str(i+1)))

    taginfo.tagcomps = tagcomps = []

    # Generate auxiliary demand comprehensions.
    for i, (a, b, r, dr) in enumerate(zip(As, Bs, Rs, DRs)):
        tagenums = []

        # Add U pred.
        if a in unconsparams:
            # The wildcard padding must follow the position of `a` within
            # unconsparams (k), not the enumerator index (i). The original
            # computed k but padded with i, which put `a` in the wrong
            # slot, or gave a tuple of the wrong arity, whenever i != k.
            k = unconsparams.index(a)
            before = [dha.PatIgnore() for _ in range(k)]
            after = [dha.PatIgnore()
                     for _ in range(k + 1, len(unconsparams))]
            tup = du.genDPatTuple(before + [dha.genUnboundVar(a)] + after)
            tagenums.append(dha.Enum(tup, uset))

        # Add other preds: each earlier enumerator whose element symbol
        # is this container contributes its demand relation.
        for b2, dr2 in zip(Bs[:i], DRs[:i]):
            if b2 is a:
                tup = dha.PatTuple([dha.PatIgnore(),
                                    dha.genUnboundVar(b2)])
                tagenums.append(dha.Enum(tup, dr2))

        # Add orig relation.
        assert(len(tagenums) > 0)
        tup = dha.PatTuple([dha.genUnboundVar(a), dha.genUnboundVar(b)])
        tagenums.append(dha.Enum(tup, r))

        # Named `tagcomp` so it does not shadow the `comp` parameter.
        tagcomp = dha.SetComp(dha.Tuple([dha.Name(a), dha.Name(b)]),
                              tagenums, None, dka.Symtab())

        # Give each tag comprehension its own '_t<i>' copies of the symbols.
        instmapping = {s: dha.VSymbol(s.name + '_t' + str(i))
                       for s in set(As) | set(Bs)}
        st.replaceSymbols(tagcomp, instmapping)

        tagcomps.append(dha.SetCompDef(dr, tagcomp))

    return taginfo
def matchpair(pair):
    """Split a two-element Tuple or PatTuple into its two elements.

    Asserts that pair has exactly two elements and returns them as a
    (first, second) tuple.
    """
    dka.assertnodetype(pair, {dha.Tuple, dha.PatTuple})
    elements = pair.elts
    assert len(elements) == 2
    first = elements[0]
    second = elements[1]
    return first, second
def desingletonize(node):
    """Unwrap a one-element Tuple or PatTuple.

    Returns the sole element when node has exactly one; otherwise
    returns node itself.
    """
    dka.assertnodetype(node, {dha.Tuple, dha.PatTuple})
    if len(node.elts) == 1:
        return node.elts[0]
    return node
def run(self, node):
    """Check that node is a pattern, then delegate to the base class run().

    Returns whatever the base implementation returns.
    """
    dka.assertnodetype(node, pattern)
    result = super().run(node)
    return result
def generic_visit(self, node):
    """Check that node is a pattern, then run the base class generic_visit().

    The base implementation's return value is discarded; this method
    always returns None.
    """
    dka.assertnodetype(node, pattern)
    super().generic_visit(node)
    return None