class DefineLanguage_IdRewriter(pattern.PatternTransformer):
    """Assign a fresh identifier to every Nt / BuiltInPat inside a language.

    Every non-terminal reference and built-in pattern appearing in the
    language's non-terminal definitions is rebuilt with a new symbol drawn
    from this instance's ``SymGen`` using the prefix ``<node.prefix>_``.
    """

    def __init__(self, definelanguage):
        assert isinstance(definelanguage, tlform.DefineLanguage)
        super().__init__()
        self.definelanguage = definelanguage
        self.symgen = SymGen()

    def run(self):
        """Return a new ``tlform.DefineLanguage`` with all patterns renamed."""
        def renamed(original):
            # transform() builds the new node; carry the original's attributes over.
            fresh = self.transform(original)
            fresh.copyattributesfrom(original)
            return fresh

        definitions = [
            tlform.DefineLanguage.NtDefinition(
                ntdef.nt, [renamed(original) for original in ntdef.patterns])
            for ntdef in self.definelanguage.nts.values()
        ]
        return tlform.DefineLanguage(self.definelanguage.name, definitions)

    def transformBuiltInPat(self, node):
        assert isinstance(node, pattern.BuiltInPat)
        fresh_sym = self.symgen.get(node.prefix + '_')
        rebuilt = pattern.BuiltInPat(node.kind, node.prefix, fresh_sym)
        return rebuilt.copyattributesfrom(node)

    def transformNt(self, node):
        assert isinstance(node, pattern.Nt)
        fresh_sym = self.symgen.get(node.prefix + '_')
        rebuilt = pattern.Nt(node.prefix, fresh_sym)
        return rebuilt.copyattributesfrom(node)
class Pattern_ConstraintCheckInserter(pattern.PatternTransformer):
    """Rename repeated pattern variables and insert equality constraint checks.

    The first occurrence of a variable keeps its original name. Every later
    occurrence gets a fresh ``<sym>#<n>`` symbol, is recorded in
    ``symstoremove``, and is tied to the previous occurrence through a
    ``pattern.CheckConstraint``.

    Every ``transform*`` method returns ``(new_node, varmap)``, where
    ``varmap`` maps an original variable name to the symbol used by its most
    recent occurrence inside ``new_node``.
    """

    def __init__(self, pattern):
        self.pattern = pattern
        self.symgen = SymGen()
        # Renamed (non-first) occurrences; attached to the result so that they
        # can be dropped from the final match.
        self.symstoremove = []

    def run(self):
        """Transform the pattern and attach ``PatternVariablesToRemove``."""
        pat, _ = self.transform(self.pattern)
        pat.addattribute(pattern.PatternAttribute.PatternVariablesToRemove, self.symstoremove)
        return pat

    def _merge_variable_maps(self, m1, m2):
        """Merge two variable maps, with ``m2`` taking precedence.

        Returns:
            ``(merged, syms2check)``. ``syms2check`` holds ``(m1[k], m2[k])``
            for every key present in both maps.

        The pairs are produced in ``m2``'s insertion order rather than by
        iterating a ``set`` of keys. Set order of strings depends on hash
        randomization, which made the order of the emitted constraint checks
        (and thus the generated code) vary between runs.
        """
        syms2check = [(m1[k], sym) for k, sym in m2.items() if k in m1]
        merged = dict(m1)
        merged.update(m2)
        return merged, syms2check

    def _symbol_for_occurrence(self, sym):
        """Return the symbol to use for this occurrence of variable ``sym``.

        The first occurrence (SymGen hands out ``<sym>#0``) keeps ``sym``
        itself. Later occurrences keep their fresh ``<sym>#<n>`` name, which is
        queued in ``symstoremove``.
        """
        nsym = self.symgen.get('{}#'.format(sym))
        if nsym == '{}#0'.format(sym):
            return sym
        self.symstoremove.append(nsym)
        return nsym

    def transformInHole(self, node):
        """Merge both sides' variables; shared ones get checks attached to the node."""
        assert isinstance(node, pattern.InHole)
        npat1, syms1 = self.transform(node.pat1)
        npat2, syms2 = self.transform(node.pat2)
        syms, syms2check = self._merge_variable_maps(syms1, syms2)
        constraintchecks = [pattern.CheckConstraint(sym1, sym2) for sym1, sym2 in syms2check]
        return pattern.InHole(npat1, npat2, constraintchecks).copyattributesfrom(node), syms

    def transformPatSequence(self, node):
        """Transform elements left to right; checks go right after the element that repeats a variable."""
        assert isinstance(node, pattern.PatSequence)
        if len(node.seq) == 0:
            return node, {}
        nseq = []
        npat, syms = self.transform(node.seq[0])
        nseq.append(npat)
        for pat in node.seq[1:]:
            npat, nsyms = self.transform(pat)
            nseq.append(npat)
            syms, syms2check = self._merge_variable_maps(syms, nsyms)
            for sym1, sym2 in syms2check:
                nseq.append(pattern.CheckConstraint(sym1, sym2))
        return pattern.PatSequence(nseq).copyattributesfrom(node), syms

    def transformRepeat(self, node):
        assert isinstance(node, pattern.Repeat)
        pat, syms = self.transform(node.pat)
        return pattern.Repeat(pat, node.matchmode).copyattributesfrom(node), syms

    def transformNt(self, node):
        assert isinstance(node, pattern.Nt)
        nsym = self._symbol_for_occurrence(node.sym)
        return pattern.Nt(node.prefix, nsym).copyattributesfrom(node), {node.sym: nsym}

    def transformBuiltInPat(self, node):
        assert isinstance(node, pattern.BuiltInPat)
        # Holes are never bound, so they are neither renamed nor reported.
        if node.kind != pattern.BuiltInPatKind.Hole:
            nsym = self._symbol_for_occurrence(node.sym)
            return pattern.BuiltInPat(node.kind, node.prefix, nsym).copyattributesfrom(node), {node.sym: nsym}
        return node, {}

    def transformLit(self, pat):
        return pat, {}

    def transformCheckConstraint(self, node):
        return node, {}
class Term_EllipsisDepthChecker(term.TermTransformer):
    """Validate ellipsis depths of pattern variables in a term template.

    ``variables`` maps pattern-variable names to the ellipsis depth they were
    bound at. An ``UnresolvedSym`` whose name is not in ``variables`` becomes a
    variable ``TermLiteral``. Otherwise it becomes a ``PatternVariable``, and
    MatchRead / InArg / ForEach annotations are recorded on the enclosing
    nodes. Those annotations are attached to the rebuilt nodes by
    ``complete_annotation`` once their children have been visited.
    """

    def __init__(self, variables, idof, context):
        self.idof = idof
        # Stack of nodes from the root down to the node being transformed.
        self.path = []
        self.context = context
        self.variables = variables
        self.symgen = SymGen()
        # Per-attribute maps {original node: [collected values]}, filled while
        # children are visited and consumed by complete_annotation().
        self.annotations = {
            attribute: {}
            for attribute in (
                term.TermAttribute.MatchRead,
                term.TermAttribute.InArg,
                term.TermAttribute.ForEach,
            )
        }

    def add_annotation_to(self, node, attribute, value):
        """Append ``value`` to the list collected for ``node`` under ``attribute``."""
        self.annotations[attribute].setdefault(node, []).append(value)

    def complete_annotation(self, oldnode, newnode):
        """Copy annotations collected for ``oldnode`` onto ``newnode``.

        Repeat nodes receive only ForEach; every other node receives MatchRead
        and InArg. Returns whatever the final ``addattribute`` call returns.
        """
        def collected(attribute):
            return self.annotations[attribute].get(oldnode, [])

        if isinstance(oldnode, term.Repeat):
            return newnode.addattribute(term.TermAttribute.ForEach,
                                        collected(term.TermAttribute.ForEach))
        newnode.addattribute(term.TermAttribute.MatchRead,
                             collected(term.TermAttribute.MatchRead))
        return newnode.addattribute(term.TermAttribute.InArg,
                                    collected(term.TermAttribute.InArg))

    def contains_nonzero_foreach_annotations(self, node):
        """True when some pattern variable iterates over the Repeat ``node``."""
        assert isinstance(node, term.Repeat)
        return bool(self.annotations[term.TermAttribute.ForEach].get(node, []))

    def transform(self, element):
        """Dispatch to ``transform<ClassName>`` while tracking the path stack."""
        assert isinstance(element, term.Term)
        handler = getattr(self, 'transform' + element.__class__.__name__)
        self.path.append(element)
        transformed = handler(element)
        assert isinstance(transformed, term.Term)
        self.path.pop()
        return transformed

    def transformTermLiteral(self, literal):
        assert isinstance(literal, term.TermLiteral)
        return literal

    def transformPyCall(self, pycall):
        """Check each argument with its own fresh checker (empty path, idof '')."""
        assert isinstance(pycall, term.PyCall)
        checked_args = [
            Term_EllipsisDepthChecker(self.variables, '', self.context).transform(arg)
            for arg in pycall.termargs
        ]
        rebuilt = term.PyCall(pycall.mode, pycall.functionname, checked_args)
        return self.complete_annotation(pycall, rebuilt)

    def transformRepeat(self, repeat):
        """Reject a Repeat that no pattern variable iterates over."""
        assert isinstance(repeat, term.Repeat)
        body = self.transform(repeat.term)
        nrepeat = term.Repeat(body).copyattributesfrom(repeat)
        if self.contains_nonzero_foreach_annotations(repeat):
            return self.complete_annotation(repeat, nrepeat)
        raise Exception('too many ellipses in template {}'.format(
            repr(nrepeat)))

    def transformTermSequence(self, termsequence):
        rebuilt = super().transformTermSequence(termsequence)
        return self.complete_annotation(termsequence, rebuilt)

    def transformInHole(self, inhole):
        rebuilt = super().transformInHole(inhole)
        return self.complete_annotation(inhole, rebuilt)

    def transformUnresolvedSym(self, node):
        """Resolve a symbol to a literal variable or a depth-checked PatternVariable.

        Raises:
            Exception: the number of enclosing Repeat nodes walked through does
                not equal the depth recorded in ``variables``.
        """
        assert isinstance(node, term.UnresolvedSym)
        sym = node.sym
        if sym not in self.variables:
            # Not bound by the pattern: emit a plain variable literal.
            return term.TermLiteral(term.TermLiteralKind.Variable, sym)
        expected = self.variables[sym]
        depth = 0
        param = self.symgen.get(sym)
        matchread = term.TermAttribute.MatchRead
        inarg = term.TermAttribute.InArg
        foreach = term.TermAttribute.ForEach
        # A pattern variable: walk from this node (top of the stack) to the root.
        for ancestor in reversed(self.path):
            if isinstance(ancestor, term.UnresolvedSym):
                if expected == 0:
                    self.add_annotation_to(ancestor, matchread, (sym, param))
                    break
                self.add_annotation_to(ancestor, inarg, param)
            if isinstance(ancestor, (term.TermSequence, term.InHole)):
                if expected == depth:
                    self.add_annotation_to(ancestor, matchread, (sym, param))
                    break
                self.add_annotation_to(ancestor, inarg, param)
            if isinstance(ancestor, term.Repeat):
                depth += 1
                self.add_annotation_to(ancestor, foreach, (param, depth))
        if depth != expected:
            raise Exception(
                'inconsistent ellipsis depth for pattern variable {}: expected {} actual {}'
                .format(sym, expected, depth))
        return self.complete_annotation(node, term.PatternVariable(sym))
class CompilationContext:
    """Symbol tables shared by the compilation passes.

    Maps languages, patterns, term templates, reduction relations,
    metafunctions and redex-match forms to the names of the functions
    generated for them. Generated names come from ``self.symgen``.
    """

    def __init__(self):
        self.__variables_mentioned = {}
        self.__isa_functions = {}
        self.__pattern_code = {}
        self.__term_template_funcs = {}
        self._litterms = {}
        self.__toplevel_patterns = {}
        self.__reductionrelations = {}
        self.__metafuctions = {}
        self.__redexmatches = {}
        self.symgen = SymGen()

    # -- variable-not-otherwise-mentioned sets -------------------------------

    def add_variables_mentioned(self, languagename, variables):
        """Record ``variables`` under ``<languagename>_variables_mentioned`` (once per language)."""
        assert languagename not in self.__variables_mentioned
        ident = '{}_variables_mentioned'.format(languagename)
        self.__variables_mentioned[languagename] = (ident, variables)

    def get_variables_mentioned(self, languagename):
        """Return ``(identifier, variables)`` for a registered language."""
        assert languagename in self.__variables_mentioned
        return self.__variables_mentioned[languagename]

    def get_variables_mentioned_all(self):
        """Return a view of every ``(identifier, variables)`` pair."""
        return self.__variables_mentioned.values()

    # -- literal terms ---------------------------------------------------------

    def add_lit_term(self, term):
        self._litterms[term] = self.symgen.get('literal_term_')

    def get_sym_for_lit_term(self, term):
        return self._litterms[term]

    # -- isa / pattern functions (None when absent) ----------------------------

    def add_isa_function_name(self, languagename, patrepr, functionname):
        key = (languagename, patrepr)
        assert key not in self.__isa_functions
        self.__isa_functions[key] = functionname

    def get_isa_function_name(self, languagename, patrepr):
        return self.__isa_functions.get((languagename, patrepr))

    def add_function_for_pattern(self, languagename, patrepr, functionname):
        key = (languagename, patrepr)
        assert key not in self.__pattern_code, 'function for {}-{} is present'.format(
            languagename, patrepr)
        self.__pattern_code[key] = functionname

    def get_function_for_pattern(self, languagename, patrepr):
        return self.__pattern_code.get((languagename, patrepr))

    def get_function_for_pattern_2(self, languagename, patrepr):
        """Return ``(procedurename, already_generated)`` for a pattern.

        When the pattern is unknown, a fresh ``pattern_match`` name is
        allocated and registered, and ``already_generated`` is False.
        """
        key = (languagename, patrepr)
        if key in self.__pattern_code:
            return self.__pattern_code[key], True
        allocated = self.symgen.get('pattern_match')
        self.__pattern_code[key] = allocated
        return allocated, False

    def add_toplevel_function_for_pattern(self, languagename, patrepr, functionname):
        key = (languagename, patrepr)
        assert key not in self.__toplevel_patterns, 'function for {}-{} is present'.format(
            languagename, patrepr)
        self.__toplevel_patterns[key] = functionname

    def get_toplevel_function_for_pattern(self, languagename, patrepr):
        return self.__toplevel_patterns.get((languagename, patrepr))

    # -- term templates --------------------------------------------------------

    def get_function_for_term_template(self, term_template):
        """Return the ``gen_term`` function name for a template, allocating it on first use."""
        assert isinstance(term_template, term.Term)
        funcs = self.__term_template_funcs
        if term_template not in funcs:
            funcs[term_template] = self.symgen.get('gen_term')
        return funcs[term_template]

    # -- reduction relations / metafunctions / redex-match ---------------------

    def add_reduction_relation(self, reductionrelationname, function):
        assert reductionrelationname not in self.__reductionrelations
        self.__reductionrelations[reductionrelationname] = function

    def get_reduction_relation(self, reductionrelationname):
        return self.__reductionrelations.get(reductionrelationname)

    def add_metafunction(self, mfname, function):
        assert mfname not in self.__metafuctions
        self.__metafuctions[mfname] = function

    def get_metafunction(self, mfname):
        return self.__metafuctions.get(mfname)

    def get_redexmatch_for(self, form):
        return self.__redexmatches[form]

    def add_redexmatch_for(self, form, name):
        self.__redexmatches[form] = name
class TopLevelFormCodegen(tlform.TopLevelFormVisitor):
    """Emit an RPython module (via ``rpy`` builders) for every top-level form.

    Forms are visited in order. Each visitor appends functions or statements
    to ``self.modulebuilder``. Names of generated procedures that must run at
    program start are collected in ``self.main_procedurecalls`` and called
    from the generated ``entrypoint`` function.
    """

    def __init__(self, module, context):
        assert isinstance(module, tlform.Module)
        assert isinstance(context, CompilationContext)
        self.module = module
        self.context = context
        self.symgen = SymGen()
        # Builder for the whole output module; all visitors append to it.
        self.modulebuilder = rpy.BlockBuilder()
        # Names of generated procedures invoked from entrypoint(), in visit order.
        self.main_procedurecalls = []

    def run(self):
        """Generate the complete module and return it as an ``rpy.Module``.

        The module consists of the included runtime sources, the
        variable-not-otherwise-mentioned sets, code for every form, the
        ``entrypoint``/``target`` functions and an ``if __name__ == '__main__'``
        guard.
        """
        self.modulebuilder.IncludeFromPythonSource('runtime/term.py')
        self.modulebuilder.IncludeFromPythonSource('runtime/parser.py')
        self.modulebuilder.IncludeFromPythonSource('runtime/fresh.py')
        self.modulebuilder.IncludeFromPythonSource('runtime/match.py')

        # parse all term literals.
        # ~~ 26.07.2020 disable lit terms for now, need to implement
        # nt caching acceleration technique first.
        # (The string below is the disabled code, kept for reference; it is a no-op.)
        """ tmp0, tmp1 = rpy.gen_pyid_temporaries(2, self.symgen) for trm, sym1 in self.context._litterms.items(): sym1 = rpy.gen_pyid_for(sym1) self.modulebuilder.AssignTo(tmp0).New('Parser', rpy.PyString(repr(trm))) self.modulebuilder.AssignTo(sym1).MethodCall(tmp0, 'parse') """

        # variable-not-otherwise-mentioned of given define language:
        # emit one module-level set of PyStrings per language.
        for ident, variables in self.context.get_variables_mentioned_all():
            ident = rpy.gen_pyid_for(ident)
            variables = map(lambda v: rpy.PyString(v), variables)
            self.modulebuilder.AssignTo(ident).PySet(*variables)

        for form in self.module.tlforms:
            self._visit(form)

        # generate main
        fb = rpy.BlockBuilder()
        symgen = SymGen()

        ## emit some dummy Terms to aid RPython with type inference.
        tmp0, tmp1, tmp2, tmp3, tmp4 = rpy.gen_pyid_temporaries(5, symgen)
        fb.AssignTo(tmp0).New('Integer', rpy.PyInt(0))
        fb.AssignTo(tmp1).New('Float', rpy.PyFloat(0.0))
        fb.AssignTo(tmp2).New('String', rpy.PyString("\"hello world!\""))
        fb.AssignTo(tmp3).New('Boolean', rpy.PyString("#f"))
        fb.AssignTo(tmp4).New('Variable', rpy.PyString("x"))
        # Call every procedure registered by the visitors, in visit order.
        for procedure in self.main_procedurecalls:
            tmpi = rpy.gen_pyid_temporaries(1, symgen)
            fb.AssignTo(tmpi).FunctionCall(procedure)
        fb.Return(rpy.PyInt(0))
        self.modulebuilder.Function('entrypoint').WithParameters(
            rpy.PyId('argv')).Block(fb)

        # required entry procedure for RPython: target(*args) -> (entrypoint, None).
        fb = rpy.BlockBuilder()
        fb.Return(rpy.PyTuple(rpy.PyId('entrypoint'), rpy.PyNone()))
        self.modulebuilder.Function('target').WithParameters(
            rpy.PyVarArg('args')).Block(fb)

        # if __name__ == '__main__': entrypoint([])
        # for python2.7 compatibility.
        ifb = rpy.BlockBuilder()
        tmp = rpy.gen_pyid_temporaries(1, self.symgen)
        ifb.AssignTo(tmp).FunctionCall('entrypoint', rpy.PyList())
        self.modulebuilder.If.Equal(rpy.PyId('__name__'), rpy.PyString('__main__')).ThenBlock(ifb)

        return rpy.Module(self.modulebuilder.build())

    def _codegenNtDefinition(self, languagename, ntdef):
        """Emit ``lang_<language>_isa_nt_<nt>(term)``.

        The generated function returns True as soon as the top-level match
        function of any alternative pattern returns a non-empty match list,
        and False otherwise. Match code for patterns not generated yet is
        produced first via ``PatternCodegen``.
        """
        assert isinstance(ntdef, tlform.DefineLanguage.NtDefinition)
        for pat in ntdef.patterns:
            if self.context.get_toplevel_function_for_pattern(
                    languagename, repr(pat)) is None:
                PatternCodegen(self.modulebuilder, pat, self.context,
                               languagename, self.symgen).run()

        nameof_this_func = 'lang_{}_isa_nt_{}'.format(languagename, ntdef.nt.prefix)
        term, match, matches = rpy.gen_pyid_for('term', 'match', 'matches')
        # for each pattern in ntdefinition
        #   matches = toplevel_match_fn(term)
        #   if len(matches) != 0:
        #     return True
        # return False
        fb = rpy.BlockBuilder()
        for pat in ntdef.patterns:
            func2call = self.context.get_toplevel_function_for_pattern(
                languagename, repr(pat))
            ifb = rpy.BlockBuilder()
            ifb.Return(rpy.PyBoolean(True))
            fb.AssignTo(matches).FunctionCall(func2call, term)
            fb.If.LengthOf(matches).NotEqual(rpy.PyInt(0)).ThenBlock(ifb)
        fb.Return(rpy.PyBoolean(False))
        self.modulebuilder.Function(nameof_this_func).WithParameters(
            term).Block(fb)

    def _visitDefineLanguage(self, form):
        """Emit the language's ``<name>_hole`` object and all isa_nt functions."""
        assert isinstance(form, tlform.DefineLanguage)
        # generate hole for each language. Need this for term annotation.
        hole = rpy.gen_pyid_for('{}_hole'.format(form.name))
        self.modulebuilder.AssignTo(hole).New('Hole')
        # First register the isa_nt function names in the context, so they are
        # known before any nt's code is generated.
        # NOTE(review): the name is built from the dict key ``ntsym`` here but
        # from ``ntdef.nt.prefix`` in _codegenNtDefinition; this assumes the two
        # always coincide -- confirm.
        for ntsym, ntdef in form.nts.items():
            nameof_this_func = 'lang_{}_isa_nt_{}'.format(form.name, ntsym)
            self.context.add_isa_function_name(form.name, ntdef.nt.prefix,
                                               nameof_this_func)
        for nt in form.nts.values():
            self._codegenNtDefinition(form.name, nt)

    def _visitRequirePythonSource(self, form):
        """Include the referenced Python source file verbatim in the module."""
        assert isinstance(form, tlform.RequirePythonSource)
        self.modulebuilder.IncludeFromPythonSource(form.filename)

    def _visitRedexMatch(self, form, callself=True):
        """Emit a function that builds the term, matches it and prints the matches.

        When ``callself`` is True, a module-level call to that function is
        also emitted.
        """
        assert isinstance(form, tlform.RedexMatch)
        # NOTE(review): this form is deliberately disabled -- every call fails
        # here. Under ``python -O`` the assert is stripped and the code below
        # would run; consider raising explicitly instead.
        assert False
        if self.context.get_toplevel_function_for_pattern(
                form.languagename, repr(form.pat)) is None:
            PatternCodegen(self.modulebuilder, form.pat, self.context,
                           form.languagename, self.symgen).run()
        TermCodegen(self.modulebuilder, self.context).transform(form.termstr)
        termfunc = self.context.get_function_for_term_template(form.termstr)
        matchfunc = self.context.get_toplevel_function_for_pattern(
            form.languagename, repr(form.pat))
        symgen = SymGen()
        matches, match, term = rpy.gen_pyid_for('matches', 'match', 'term')
        tmp0 = rpy.gen_pyid_temporaries(1, symgen)
        fb = rpy.BlockBuilder()
        fb.AssignTo(tmp0).New('Match')
        fb.AssignTo(term).FunctionCall(termfunc, tmp0)
        fb.AssignTo(matches).FunctionCall(matchfunc, term)
        fb.Print(matches)
        fb.Return(matches)
        # call redex-match itself.
        nameof_this_func = self.symgen.get('redexmatch')
        self.context.add_redexmatch_for(form, nameof_this_func)
        self.modulebuilder.Function(nameof_this_func).Block(fb)
        if callself:
            tmp0 = rpy.gen_pyid_temporaries(1, self.symgen)
            self.modulebuilder.AssignTo(tmp0).FunctionCall(nameof_this_func)

    def _visitRedexMatchAssertEqual(self, form):
        """Emit a procedure that matches the term and compares against expected matches.

        The procedure calls ``assert_compare_match_lists`` and then prints the
        actual matches. It is registered to run from ``entrypoint``.
        """
        def gen_matches(expectedmatches, fb, symgen):
            # Emit code building one Match object per expected match (adding
            # each binding's generated term under its symbol), and return the
            # identifier of a list holding all of them.
            processedmatches = []
            for m in expectedmatches:
                tmp0 = rpy.gen_pyid_temporaries(1, symgen)
                fb.AssignTo(tmp0).New('Match')
                processedmatches.append(tmp0)
                for sym, termx in m.bindings:
                    tmp1, tmp2, tmp3, tmp4 = rpy.gen_pyid_temporaries(
                        4, symgen)
                    TermCodegen(self.modulebuilder, self.context).transform(termx)
                    termfunc = self.context.get_function_for_term_template(
                        termx)
                    fb.AssignTo(tmp1).New('Match')
                    fb.AssignTo(tmp2).FunctionCall(termfunc, tmp1)
                    fb.AssignTo(tmp3).MethodCall(tmp0, MatchMethodTable.AddKey,
                                                 rpy.PyString(sym))
                    fb.AssignTo(tmp4).MethodCall(tmp0, MatchMethodTable.AddToBinding,
                                                 rpy.PyString(sym), tmp2)
            tmpi = rpy.gen_pyid_temporaries(1, symgen)
            fb.AssignTo(tmpi).PyList(*processedmatches)
            return tmpi

        assert isinstance(form, tlform.RedexMatchAssertEqual)
        if self.context.get_toplevel_function_for_pattern(
                form.languagename, repr(form.pat)) is None:
            PatternCodegen(self.modulebuilder, form.pat, self.context,
                           form.languagename, self.symgen).run()
        TermCodegen(self.modulebuilder, self.context).transform(form.termtemplate)
        matchfunc = self.context.get_toplevel_function_for_pattern(
            form.languagename, repr(form.pat))
        termfunc = self.context.get_function_for_term_template(
            form.termtemplate)
        symgen = SymGen()
        matches, match, term = rpy.gen_pyid_for('matches', 'match', 'term')
        fb = rpy.BlockBuilder()
        expectedmatches = gen_matches(form.expectedmatches, fb, symgen)
        tmp0, tmp1, tmp2 = rpy.gen_pyid_temporaries(3, symgen)
        fb.AssignTo(tmp0).New('Match')
        fb.AssignTo(term).FunctionCall(termfunc, tmp0)
        fb.AssignTo(matches).FunctionCall(matchfunc, term)
        fb.AssignTo(tmp1).FunctionCall('assert_compare_match_lists', matches,
                                       expectedmatches)
        fb.AssignTo(tmp2).FunctionCall(MatchHelperFuncs.PrintMatchList, matches)
        fb.Return(matches)
        nameof_this_func = self.symgen.get('redexmatchassertequal')
        self.context.add_redexmatch_for(form, nameof_this_func)
        self.modulebuilder.Function(nameof_this_func).Block(fb)
        self.main_procedurecalls.append(nameof_this_func)

    def _visitTermLetAssertEqual(self, form):
        """Emit a procedure that instantiates the template and compares it to the expected term.

        A Match is populated with the generated term of every variable
        assignment. The template is instantiated with it, checked with
        ``asserttermsequal`` and printed. The procedure is registered to run
        from ``entrypoint``.
        """
        assert isinstance(form, tlform.TermLetAssertEqual)
        template = form.template
        TermCodegen(self.modulebuilder, self.context).transform(template)
        templatetermfunc = self.context.get_function_for_term_template(
            template)
        TermCodegen(self.modulebuilder, self.context).transform(form.expected)
        expectedtermfunc = self.context.get_function_for_term_template(
            form.expected)
        fb = rpy.BlockBuilder()
        symgen = SymGen()
        expected, match = rpy.gen_pyid_for('expected', 'match')
        tmp0 = rpy.gen_pyid_temporaries(1, symgen)
        fb.AssignTo(tmp0).New('Match')
        fb.AssignTo(expected).FunctionCall(expectedtermfunc, tmp0)
        fb.AssignTo(match).New('Match')
        for variable, term in form.variableassignments.items():
            tmp1, tmp2, tmp3, tmp4 = rpy.gen_pyid_temporaries(4, symgen)
            TermCodegen(self.modulebuilder, self.context).transform(term)
            termfunc = self.context.get_function_for_term_template(term)
            fb.AssignTo(tmp1).New('Match')
            fb.AssignTo(tmp2).FunctionCall(termfunc, tmp1)
            fb.AssignTo(tmp3).MethodCall(match, MatchMethodTable.AddKey,
                                         rpy.PyString(variable))
            fb.AssignTo(tmp4).MethodCall(match, MatchMethodTable.AddToBinding,
                                         rpy.PyString(variable), tmp2)
        tmp0, tmp1, tmp2 = rpy.gen_pyid_temporaries(3, symgen)
        fb.AssignTo(tmp0).FunctionCall(templatetermfunc, match)
        fb.AssignTo(tmp1).FunctionCall('asserttermsequal', tmp0, expected)
        fb.AssignTo(tmp2).FunctionCall(TermHelperFuncs.PrintTerm, tmp0)
        nameof_this_func = self.symgen.get('asserttermequal')
        self.modulebuilder.Function(nameof_this_func).Block(fb)
        self.main_procedurecalls.append(nameof_this_func)

    def _codegenReductionCase(self, rc, languagename, reductionrelationname,
                              nameof_domaincheck=None):
        """Emit a function applying one reduction case; return its name.

        The generated function returns the list of terms produced by plugging
        every match of the case's pattern into its term template. When
        ``nameof_domaincheck`` is given, each produced term is checked against
        the domain, and an exception is raised if it falls outside.
        """
        assert isinstance(rc, tlform.DefineReductionRelation.ReductionCase)
        if self.context.get_toplevel_function_for_pattern(
                languagename, repr(rc.pattern)) is None:
            PatternCodegen(self.modulebuilder, rc.pattern, self.context,
                           languagename, self.symgen).run()
        TermCodegen(self.modulebuilder, self.context).transform(rc.termtemplate)

        nameof_matchfn = self.context.get_toplevel_function_for_pattern(
            languagename, repr(rc.pattern))
        nameof_termfn = self.context.get_function_for_term_template(
            rc.termtemplate)
        nameof_rc = self.symgen.get('{}_{}_case'.format(
            languagename, reductionrelationname))

        symgen = SymGen()
        # terms = []
        # matches = match(term)
        # if len(matches) != 0:
        #   for match in matches:
        #     tmp0 = gen_term(match)
        #     tmp2 = match_domain(tmp0)
        #     if len(tmp2) == 0:
        #       raise Exception('reduction-relation {}: term reduced from {} to {} via rule {} and is outside domain')
        #     tmp1 = terms.append(tmp0)
        # return terms
        terms, term, matches, match = rpy.gen_pyid_for('terms', 'term', 'matches', 'match')
        tmp0, tmp1, tmp2 = rpy.gen_pyid_temporaries(3, symgen)
        forb = rpy.BlockBuilder()
        forb.AssignTo(tmp0).FunctionCall(nameof_termfn, match)
        if nameof_domaincheck is not None:
            ifb = rpy.BlockBuilder()
            tmpa, tmpb = rpy.gen_pyid_temporaries(2, symgen)
            ifb.AssignTo(tmpa).MethodCall(term, TermMethodTable.ToString)
            ifb.AssignTo(tmpb).MethodCall(tmp0, TermMethodTable.ToString)
            ifb.RaiseException('reduction-relation \\"{}\\": term reduced from %s to %s via rule \\"{}\\" is outside domain' \
                .format(reductionrelationname, rc.name), tmpa, tmpb)
            forb.AssignTo(tmp2).FunctionCall(nameof_domaincheck, tmp0)
            forb.If.LengthOf(tmp2).Equal(rpy.PyInt(0)).ThenBlock(ifb)
        forb.AssignTo(tmp1).MethodCall(terms, 'append', tmp0)
        ifb = rpy.BlockBuilder()
        ifb.For(match).In(matches).Block(forb)
        fb = rpy.BlockBuilder()
        fb.AssignTo(terms).PyList()
        fb.AssignTo(matches).FunctionCall(nameof_matchfn, term)
        fb.If.LengthOf(matches).NotEqual(rpy.PyInt(0)).ThenBlock(ifb)
        fb.Return(terms)
        self.modulebuilder.Function(nameof_rc).WithParameters(term).Block(fb)
        return nameof_rc

    def _visitDefineReductionRelation(self, form):
        """Emit ``<language>_<relation>(term)`` and register it in the context.

        The generated function optionally rejects terms outside the domain.
        It then concatenates the term lists returned by every reduction case.
        """
        assert isinstance(form, tlform.DefineReductionRelation)
        # def reduction_relation_name(term):
        #   outterms = []
        #   {for each case}
        #   tmpi = rc(term)
        #   outterms = outterms + tmp{i}
        #   return outterms
        if form.domain != None:
            if self.context.get_toplevel_function_for_pattern(
                    form.languagename, repr(form.domain)) is None:
                PatternCodegen(self.modulebuilder, form.domain, self.context,
                               form.languagename, self.symgen).run()
        nameof_domaincheck = None
        if form.domain != None:
            nameof_domaincheck = self.context.get_toplevel_function_for_pattern(
                form.languagename, repr(form.domain))
        rcfuncs = []
        for rc in form.reductioncases:
            rcfunc = self._codegenReductionCase(rc, form.languagename,
                                                form.name, nameof_domaincheck)
            rcfuncs.append(rcfunc)
        terms, term = rpy.gen_pyid_for('terms', 'term')
        symgen = SymGen()
        fb = rpy.BlockBuilder()
        # Reject input terms that do not match the domain pattern.
        if nameof_domaincheck != None:
            tmp0 = rpy.gen_pyid_temporaries(1, symgen)
            ifb = rpy.BlockBuilder()
            tmpa = rpy.gen_pyid_temporaries(1, symgen)
            ifb.AssignTo(tmpa).MethodCall(term, TermMethodTable.ToString)
            ifb.RaiseException('reduction-relation not defined for %s', tmpa)
            fb.AssignTo(tmp0).FunctionCall(nameof_domaincheck, term)
            fb.If.LengthOf(tmp0).Equal(rpy.PyInt(0)).ThenBlock(ifb)
        fb.AssignTo(terms).PyList()
        for rcfunc in rcfuncs:
            tmpi = rpy.gen_pyid_temporaries(1, symgen)
            fb.AssignTo(tmpi).FunctionCall(rcfunc, term)
            fb.AssignTo(terms).Add(terms, tmpi)
        fb.Return(terms)
        nameof_function = '{}_{}'.format(form.languagename, form.name)
        self.context.add_reduction_relation(form.name, nameof_function)
        self.modulebuilder.Function(nameof_function).WithParameters(
            term).Block(fb)
        return form

    # This generates call to reduction relation. Used by multiple other tlforms.
    def _genreductionrelation(self, fb, symgen, nameof_reductionrelation, term):
        """Append code to ``fb`` that builds ``term``, reduces it once and prints the results.

        Returns the identifier holding the resulting list of terms.
        """
        TermCodegen(self.modulebuilder, self.context).transform(term)
        termfunc = self.context.get_function_for_term_template(term)
        # From here on ``term`` names the generated identifier, not the template.
        term, terms = rpy.gen_pyid_for('term', 'terms')
        tmp0, tmp1 = rpy.gen_pyid_temporaries(2, symgen)
        fb.AssignTo(tmp0).New('Match')
        fb.AssignTo(term).FunctionCall(termfunc, tmp0)
        fb.AssignTo(terms).FunctionCall(nameof_reductionrelation, term)
        fb.AssignTo(tmp1).FunctionCall(TermHelperFuncs.PrintTermList, terms)
        return terms

    def _visitApplyReductionRelationAssertEqual(self, form):
        """Emit a procedure comparing one reduction step against the expected terms.

        The comparison uses ``AssertTermListsEqual``. The procedure is
        registered to run from ``entrypoint``.
        """
        assert isinstance(form, tlform.ApplyReductionRelationAssertEqual)
        def gen_terms(termtemplates, fb, symgen):
            # Emit code generating each expected term; return the id of a list of them.
            processed = []
            for expectedterm in termtemplates:
                TermCodegen(self.modulebuilder, self.context).transform(expectedterm)
                expectedtermfunc = self.context.get_function_for_term_template(
                    expectedterm)
                tmpi, tmpj = rpy.gen_pyid_temporaries(2, symgen)
                fb.AssignTo(tmpi).New('Match')
                fb.AssignTo(tmpj).FunctionCall(expectedtermfunc, tmpi)
                processed.append(tmpj)
            tmpi = rpy.gen_pyid_temporaries(1, symgen)
            fb.AssignTo(tmpi).PyList(*processed)
            return tmpi

        # NOTE(review): unlike _visitApplyReductionRelation, a missing relation
        # (None) is not asserted against here.
        nameof_reductionrelation = self.context.get_reduction_relation(
            form.reductionrelationname)
        fb = rpy.BlockBuilder()
        symgen = SymGen()
        tmp0 = rpy.gen_pyid_temporaries(1, symgen)
        expectedterms = gen_terms(form.expected_termtemplates, fb, symgen)
        terms = self._genreductionrelation(fb, symgen, nameof_reductionrelation,
                                           form.term)
        fb.AssignTo(tmp0).FunctionCall(TermHelperFuncs.AssertTermListsEqual,
                                       terms, expectedterms)
        nameof_function = self.symgen.get('applyreductionrelationassertequal')
        self.modulebuilder.Function(nameof_function).Block(fb)
        self.main_procedurecalls.append(nameof_function)

    def _visitApplyReductionRelation(self, form):
        """Emit a procedure applying the relation once, plus a module-level call to it."""
        assert isinstance(form, tlform.ApplyReductionRelation)
        nameof_reductionrelation = self.context.get_reduction_relation(
            form.reductionrelationname)
        assert nameof_reductionrelation != None
        fb = rpy.BlockBuilder()
        symgen = SymGen()
        self._genreductionrelation(fb, symgen, nameof_reductionrelation,
                                   form.term)
        # NOTE(review): tmp1 comes from a local SymGen but is assigned at module
        # level, and the call runs at import time rather than via
        # main_procedurecalls -- confirm this is intended.
        tmp1 = rpy.gen_pyid_temporaries(1, symgen)
        nameof_function = self.symgen.get('applyreductionrelation')
        self.modulebuilder.Function(nameof_function).Block(fb)
        self.modulebuilder.AssignTo(tmp1).FunctionCall(nameof_function)

    # metafunction case may produce multiple matches but after term plugging all terms
    # must be the same.
    def _codegenMetafunctionCase(self, metafunction, case, caseid, mfname):
        """Generate code for one metafunction case.

        NOTE(review): in the visible source this method only ensures the case's
        pattern-match function exists. It emits nothing from the plan
        sketched below and returns None. It looks unfinished -- confirm against
        the rest of the file.
        """
        assert isinstance(metafunction, tlform.DefineMetafunction)
        assert isinstance(case, tlform.DefineMetafunction.MetafunctionCase)
        #def mfcase(argterm):
        #  tmp0 = matchfunc(argterm)
        #  tmp1 = []
        #  if len(tmp0) == 0:
        #    return tmp1
        #  for tmp2 in tmp0:
        #    tmp3 = termfunc(tmp2)
        #    if tmp3 == None: continue
        #    tmp4 = tmp1.append(tmp3)
        #  tmp5 = aretermsequalpairwise(tmp1)
        #  if tmp5 != True:
        #    raise Exception('mfcase 1 matched (term) in len(tmp{i}) ways, single match is expected')
        #  tmp6 = tmp1[0]
        #  return tmp6
        if self.context.get_toplevel_function_for_pattern(
                metafunction.languagename, repr(case.patternsequence)) is None:
            PatternCodegen(self.modulebuilder, case.patternsequence, self.context,
                           metafunction.languagename, self.symgen).run()
class TopLevelProcessor(tlform.TopLevelFormVisitor):
    """Front-end pass that rewrites and validates every top-level form.

    Languages, patterns and term templates are pushed through the rewriting
    and checking passes of this module. Processed languages, reduction
    relations and metafunctions are remembered so later forms can refer to
    them. ``run`` returns the rewritten module and the shared
    ``CompilationContext``.
    """

    def __init__(self, module, context, debug_dump_ntgraph=False):
        assert isinstance(module, tlform.Module)
        assert isinstance(context, CompilationContext)
        self.module = module
        self.context = context
        self.symgen = SymGen()
        self.debug_dump_ntgraph = debug_dump_ntgraph
        # Processed define-language forms and their nt closures, keyed by
        # language name, for use by forms that reference a language.
        self.definelanguages = {}
        self.definelanguageclosures = {}
        # Reduction relations / metafunctions seen so far, keyed by name.
        self.reductionrelations = {}
        self.metafunctions = {}

    def run(self):
        """Process all forms in order; return ``(tlform.Module, CompilationContext)``."""
        forms = [self._visit(form) for form in self.module.tlforms]
        return tlform.Module(forms), self.context

    def _visitDefineLanguage(self, form):
        """Rewrite and check a language definition, then remember it."""
        assert isinstance(form, tlform.DefineLanguage)
        form, variables = DefineLanguage_NtRewriter(form, form.ntsyms()).run()
        self.context.add_variables_mentioned(form.name, variables)
        form = DefineLanguage_IdRewriter(form).run()
        successors, closures = DefineLanguage_NtClosureSolver(form).run()
        DefineLanguage_NtCycleChecker(form, successors).run()
        graph = NtGraphBuilder(form).run()
        DefineLanguage_HoleReachabilitySolver(form, graph).run()
        if self.debug_dump_ntgraph:
            graph.dump(form.name)
            print('------ Debug Nt hole counts for language {}: ------'.format(form.name))
            for nt, ntdef in form.nts.items():
                print('{}: {}'.format(nt, ntdef.nt.getattribute(pattern.PatternAttribute.NumberOfHoles)))
            print('\n')
        # Ellipsis match-mode rewriting is currently disabled:
        # form = DefineLanguage_EllipsisMatchModeRewriter(form, closures).run()
        form = DefineLanguage_AssignableSymbolExtractor(form).run()
        self.definelanguages[form.name] = form
        self.definelanguageclosures[form.name] = closures
        return form

    def __processpattern(self, pat, languagename):
        """Run the pattern pipeline used for match patterns of ``languagename``."""
        lang = self.definelanguages[languagename]
        ntsyms = lang.ntsyms()
        pat = Pattern_NtRewriter(pat, ntsyms).run()
        pat = Pattern_EllipsisDepthChecker(pat).run()
        Pattern_InHoleChecker(lang, pat).run()
        # Disabled pass (would need self.definelanguageclosures[languagename]):
        # pat = Pattern_EllipsisMatchModeRewriter(lang, pat, closure).run()
        pat = Pattern_ConstraintCheckInserter(pat).run()
        pat = Pattern_AssignableSymbolExtractor(pat).run()
        return pat

    def __processdomaincheck(self, pat, languagename):
        """Run the pattern pipeline used for domain/contract patterns.

        Unlike ``__processpattern``, this pipeline uses ``Pattern_IdRewriter``
        and inserts no constraint checks.
        """
        lang = self.definelanguages[languagename]
        ntsyms = lang.ntsyms()
        pat = Pattern_NtRewriter(pat, ntsyms).run()
        pat = Pattern_IdRewriter(pat).run()
        Pattern_InHoleChecker(lang, pat).run()
        # Disabled pass (would need self.definelanguageclosures[languagename]):
        # pat = Pattern_EllipsisMatchModeRewriter(lang, pat, closure).run()
        pat = Pattern_AssignableSymbolExtractor(pat).run()
        return pat

    def __processtermtemplate(self, termtemplate, assignments=None):
        """Depth-check ``termtemplate`` and rewrite its metafunction applications.

        Args:
            termtemplate: the template to process.
            assignments: mapping of pattern-variable name to ellipsis depth.
                ``None`` means no pattern variables are in scope.
        """
        if assignments is None:
            # A fresh dict per call; a shared ``{}`` default would be aliased
            # into every Term_EllipsisDepthChecker created here.
            assignments = {}
        idof = self.symgen.get('termtemplate')
        termtemplate = Term_EllipsisDepthChecker(assignments, idof, self.context).transform(termtemplate)
        termtemplate = Term_MetafunctionApplicationRewriter(termtemplate, self.metafunctions, self.symgen).run()
        return termtemplate

    def _visitRedexMatch(self, form):
        assert isinstance(form, tlform.RedexMatch)
        form.pat = self.__processpattern(form.pat, form.languagename)
        form.termstr = self.__processtermtemplate(form.termstr)
        return form

    def _visitRedexMatchAssertEqual(self, form):
        assert isinstance(form, tlform.RedexMatchAssertEqual)
        form.pat = self.__processpattern(form.pat, form.languagename)
        form.termtemplate = self.__processtermtemplate(form.termtemplate)
        for match in form.expectedmatches:
            assert isinstance(match, tlform.RedexMatchAssertEqual.Match)
            match.bindings = [
                (ident, self.__processtermtemplate(trm))
                for ident, trm in match.bindings
            ]
        return form

    def _visitTermLetAssertEqual(self, form):
        assert isinstance(form, tlform.TermLetAssertEqual)
        for ident, trm in list(form.variableassignments.items()):
            form.variableassignments[ident] = self.__processtermtemplate(trm)
        form.template = self.__processtermtemplate(form.template, assignments=form.variabledepths)
        form.expected = self.__processtermtemplate(form.expected)
        return form

    def processReductionCase(self, reductioncase, languagename):
        """Process a reduction case's pattern, then its template using that pattern's variable depths."""
        assert isinstance(reductioncase, tlform.DefineReductionRelation.ReductionCase)
        reductioncase.pattern = self.__processpattern(reductioncase.pattern, languagename)
        assignablesymsdepths = reductioncase.pattern.getattribute(pattern.PatternAttribute.PatternVariableEllipsisDepths)
        reductioncase.termtemplate = self.__processtermtemplate(reductioncase.termtemplate, assignments=assignablesymsdepths)

    def _visitDefineReductionRelation(self, form):
        assert isinstance(form, tlform.DefineReductionRelation)
        self.reductionrelations[form.name] = form
        for rc in form.reductioncases:
            self.processReductionCase(rc, form.languagename)
        if form.domain is not None:
            form.domain = self.__processdomaincheck(form.domain, form.languagename)
        return form

    def _visitApplyReductionRelation(self, form):
        assert isinstance(form, tlform.ApplyReductionRelation)
        # Lookup only validates that the relation was defined (KeyError otherwise).
        self.reductionrelations[form.reductionrelationname]
        form.term = self.__processtermtemplate(form.term)
        return form

    def _visitDefineMetafunction(self, form):
        assert isinstance(form, tlform.DefineMetafunction)
        # Registered before the cases are processed. self.metafunctions is passed
        # to Term_MetafunctionApplicationRewriter, so the cases' templates can
        # presumably refer to this metafunction (e.g. recursion).
        self.metafunctions[form.contract.name] = form
        form.contract.domain = self.__processdomaincheck(form.contract.domain, form.languagename)
        form.contract.codomain = self.__processdomaincheck(form.contract.codomain, form.languagename)
        for case in form.cases:
            case.patternsequence = self.__processpattern(case.patternsequence, form.languagename)
            assignablesymsdepths = case.patternsequence.getattribute(pattern.PatternAttribute.PatternVariableEllipsisDepths)
            case.termtemplate = self.__processtermtemplate(case.termtemplate, assignments=assignablesymsdepths)
        return form

    def _visitApplyReductionRelationAssertEqual(self, form):
        assert isinstance(form, tlform.ApplyReductionRelationAssertEqual)
        # Lookup only validates that the relation was defined (KeyError otherwise).
        self.reductionrelations[form.reductionrelationname]
        form.term = self.__processtermtemplate(form.term)
        # Expected terms are only depth-checked (no pattern variables in scope);
        # metafunction applications in them are not rewritten.
        for i, termtemplate in enumerate(form.expected_termtemplates):
            idof = self.symgen.get('term_assert_term_lists_equal')
            form.expected_termtemplates[i] = Term_EllipsisDepthChecker({}, idof, self.context).transform(termtemplate)
        return form

    def _visitParseAssertEqual(self, form):
        assert isinstance(form, tlform.ParseAssertEqual)
        form.expected_termtemplate = self.__processtermtemplate(form.expected_termtemplate)
        return form

    def _visitReadFromStdinAndApplyReductionRelation(self, form):
        """Validate that the referenced metafunction (if any) and reduction relation exist.

        Raises:
            CompilationError: either name is unknown.
        """
        assert isinstance(form, tlform.ReadFromStdinAndApplyReductionRelation)
        if form.metafunctionname is not None and form.metafunctionname not in self.metafunctions:
            raise CompilationError('read-from-stdin-and-apply-reduction-relation* : unknown metafunction {}'.format(form.metafunctionname))
        if form.reductionrelationname not in self.reductionrelations:
            raise CompilationError('read-from-stdin-and-apply-reduction-relation* : unknown reduction relation {}'.format(form.reductionrelationname))
        return form