コード例 #1
0
class DefineLanguage_IdRewriter(pattern.PatternTransformer):
    """Give every Nt / BuiltInPat in a define-language a fresh symbol.

    Each non-terminal and built-in pattern keeps its kind and prefix but gets
    a new symbol drawn from one SymGen (using ``<prefix>_`` as the prefix).
    That SymGen is shared across the whole language definition.
    """

    def __init__(self, definelanguage):
        assert isinstance(definelanguage, tlform.DefineLanguage)
        super().__init__()
        self.definelanguage = definelanguage
        self.symgen = SymGen()

    def run(self):
        """Return a new DefineLanguage whose patterns carry fresh symbols."""
        lang = self.definelanguage
        rewritten_defs = []
        for ntdef in lang.nts.values():
            rewritten_pats = [self._rewrite_toplevel(p) for p in ntdef.patterns]
            rewritten_defs.append(
                tlform.DefineLanguage.NtDefinition(ntdef.nt, rewritten_pats))
        return tlform.DefineLanguage(lang.name, rewritten_defs)

    def _rewrite_toplevel(self, pat):
        # Transform one top-level pattern and carry over its attributes.
        rewritten = self.transform(pat)
        rewritten.copyattributesfrom(pat)
        return rewritten

    def _fresh_sym_for(self, node):
        # Fresh symbol derived from the node's prefix.
        return self.symgen.get(node.prefix + '_')

    def transformBuiltInPat(self, node):
        assert isinstance(node, pattern.BuiltInPat)
        fresh = self._fresh_sym_for(node)
        renamed = pattern.BuiltInPat(node.kind, node.prefix, fresh)
        return renamed.copyattributesfrom(node)

    def transformNt(self, node):
        assert isinstance(node, pattern.Nt)
        fresh = self._fresh_sym_for(node)
        renamed = pattern.Nt(node.prefix, fresh)
        return renamed.copyattributesfrom(node)
コード例 #2
0
class Pattern_ConstraintCheckInserter(pattern.PatternTransformer):
    """Rename repeated pattern variables and insert equality constraint checks.

    When a pattern binds the same symbol more than once, the first occurrence
    keeps its original name. Every later occurrence is renamed to a fresh
    ``<sym>#<n>`` symbol.

    A ``CheckConstraint(previous, renamed)`` is then added:
    - for sequences, right after the element that introduced the duplicate;
    - for InHole nodes, attached to the node itself.

    The renamed symbols are recorded in the PatternVariablesToRemove attribute
    of the resulting pattern, presumably so later stages can drop them from the
    final match.

    Every ``transform*`` method returns ``(new_node, symmap)``. ``symmap`` maps
    an original symbol to the name it was most recently bound to inside
    ``new_node``.
    """

    def __init__(self, pattern):
        super().__init__()
        self.pattern = pattern
        self.symgen = SymGen()
        # Renamed (non-first) occurrences of pattern variables.
        self.symstoremove = []

    def run(self):
        """Return the rewritten pattern annotated with the symbols to remove."""
        pat, _ = self.transform(self.pattern)
        pat.addattribute(pattern.PatternAttribute.PatternVariablesToRemove,
                         self.symstoremove)
        return pat

    def _merge_variable_maps(self, m1, m2):
        """Merge two symbol maps and return ``(merged, pairs_to_check)``.

        For every symbol bound on both sides, a ``(m1[sym], m2[sym])`` pair is
        produced; the two bindings must be checked for equality. The merged map
        prefers ``m2``'s (later) binding.

        Pairs follow ``m1``'s insertion order so the generated constraint
        checks are emitted deterministically. Iterating a set of strings would
        make that order depend on hash randomization.
        """
        syms2check = [(m1[k], m2[k]) for k in m1 if k in m2]
        return {**m1, **m2}, syms2check

    def _fresh_binding(self, sym):
        """Return the name `sym` is bound under at this occurrence.

        SymGen is expected to hand out ``<sym>#0`` for the first request. That
        first occurrence keeps the original `sym`. Later ones keep the
        generated name, and that name is scheduled for removal.
        """
        nsym = self.symgen.get('{}#'.format(sym))
        if nsym == '{}#0'.format(sym):
            return sym
        self.symstoremove.append(nsym)
        return nsym

    def transformInHole(self, node):
        assert isinstance(node, pattern.InHole)
        npat1, syms1 = self.transform(node.pat1)
        npat2, syms2 = self.transform(node.pat2)
        syms, syms2check = self._merge_variable_maps(syms1, syms2)
        constraintchecks = [
            pattern.CheckConstraint(sym1, sym2) for sym1, sym2 in syms2check
        ]
        return pattern.InHole(npat1, npat2,
                              constraintchecks).copyattributesfrom(node), syms

    def transformPatSequence(self, node):
        assert isinstance(node, pattern.PatSequence)
        if not node.seq:
            return node, {}
        npat, syms = self.transform(node.seq[0])
        nseq = [npat]
        for pat in node.seq[1:]:
            npat, nsyms = self.transform(pat)
            nseq.append(npat)
            syms, syms2check = self._merge_variable_maps(syms, nsyms)
            # Check duplicates right after the element that re-bound them.
            for sym1, sym2 in syms2check:
                nseq.append(pattern.CheckConstraint(sym1, sym2))
        return pattern.PatSequence(nseq).copyattributesfrom(node), syms

    def transformRepeat(self, node):
        assert isinstance(node, pattern.Repeat)
        pat, syms = self.transform(node.pat)
        return pattern.Repeat(pat,
                              node.matchmode).copyattributesfrom(node), syms

    def transformNt(self, node):
        assert isinstance(node, pattern.Nt)
        nsym = self._fresh_binding(node.sym)
        npat = pattern.Nt(node.prefix, nsym).copyattributesfrom(node)
        return npat, {node.sym: nsym}

    def transformBuiltInPat(self, node):
        assert isinstance(node, pattern.BuiltInPat)
        # Holes are never bound, so they pass through unchanged.
        if node.kind == pattern.BuiltInPatKind.Hole:
            return node, {}
        nsym = self._fresh_binding(node.sym)
        npat = pattern.BuiltInPat(node.kind, node.prefix,
                                  nsym).copyattributesfrom(node)
        return npat, {node.sym: nsym}

    def transformLit(self, pat):
        return pat, {}

    def transformCheckConstraint(self, node):
        return node, {}
コード例 #3
0
class Term_EllipsisDepthChecker(term.TermTransformer):
    """Resolve symbols in a term template and check their ellipsis depths.

    Every ``UnresolvedSym`` whose name is a key of ``variables`` is treated as
    a pattern variable and replaced by a ``PatternVariable``. Any other symbol
    becomes a ``TermLiteral`` of kind ``Variable``.

    For each pattern variable, the number of enclosing ``Repeat`` nodes must
    equal the depth stored in ``variables``; otherwise an ``Exception`` is
    raised.

    While walking from a pattern variable towards the root, the checker
    collects ``MatchRead`` / ``InArg`` / ``ForEach`` annotations for the
    nodes on the way. ``complete_annotation`` attaches them to the rebuilt
    nodes once each node's children have been transformed.
    """
    def __init__(self, variables, idof, context):
        # variables: maps a pattern-variable name to its expected ellipsis
        #   depth (compared with an integer count of enclosing Repeat nodes).
        # idof: stored but not read anywhere in this class.
        # context: stored; its only uses (add_lit_term) are commented out.
        self.idof = idof
        # Stack of the nodes from the root down to the node being transformed.
        self.path = []
        self.context = context
        self.variables = variables
        # Produces a fresh parameter name for every pattern-variable occurrence.
        self.symgen = SymGen()

        # stores annotations that will be injected into term-template after
        # visiting all the children.
        # Shape: attribute -> {original node: [annotation values]}.
        self.annotations = {
            term.TermAttribute.MatchRead: {},
            term.TermAttribute.InArg: {},
            term.TermAttribute.ForEach: {},
        }

    def add_annotation_to(self, node, attribute, value):
        """Append `value` to the `attribute` annotations collected for `node`."""
        attributedict = self.annotations[attribute]
        if node not in attributedict:
            attributedict[node] = []
        attributedict[node].append(value)

    def complete_annotation(self, oldnode, newnode):
        """Attach the annotations collected for `oldnode` to `newnode`.

        Repeat nodes receive only their ForEach list. Every other node receives
        its MatchRead and InArg lists, which are empty when nothing was
        collected.

        Returns the result of the last ``newnode.addattribute`` call.
        ``transform`` asserts that this result is a ``term.Term``.
        """
        if isinstance(oldnode, term.Repeat):
            attributedict = self.annotations[term.TermAttribute.ForEach]
            contents = attributedict.get(oldnode, [])
            return newnode.addattribute(term.TermAttribute.ForEach, contents)

        attributedict = self.annotations[term.TermAttribute.MatchRead]
        contents = attributedict.get(oldnode, [])
        newnode.addattribute(term.TermAttribute.MatchRead, contents)
        attributedict = self.annotations[term.TermAttribute.InArg]
        contents = attributedict.get(oldnode, [])
        return newnode.addattribute(term.TermAttribute.InArg, contents)

    def contains_nonzero_foreach_annotations(self, node):
        """Return True if some pattern variable recorded a ForEach on Repeat `node`."""
        assert isinstance(node, term.Repeat)
        attributedict = self.annotations[term.TermAttribute.ForEach]
        contents = attributedict.get(node, [])
        return len(contents) != 0

    def transform(self, element):
        """Dispatch to ``transform<ClassName>`` while tracking the path stack.

        `element` is pushed onto ``self.path`` for the duration of the call, so
        handlers see it as the topmost entry. The handler's result must be a
        ``term.Term``.
        """
        assert isinstance(element, term.Term)
        method_name = 'transform' + element.__class__.__name__
        method_ref = getattr(self, method_name)
        self.path.append(element)
        result = method_ref(element)
        assert isinstance(result, term.Term)
        self.path.pop()
        return result

    def transformTermLiteral(self, literal):
        """Literals are returned unchanged."""
        assert isinstance(literal, term.TermLiteral)
        # Literal registration is currently disabled.
        #self.context.add_lit_term(literal)
        return literal

    def transformPyCall(self, pycall):
        """Check each PyCall argument in isolation and rebuild the call.

        Each argument gets its own checker (with its own path, symgen and
        annotation tables), so annotations do not cross the call boundary.
        The new PyCall only receives the annotations collected for `pycall`
        in this checker.
        """
        assert isinstance(pycall, term.PyCall)
        terms = []
        for t in pycall.termargs:
            transformer = Term_EllipsisDepthChecker(self.variables, '',
                                                    self.context)
            terms.append(transformer.transform(t))
        return self.complete_annotation(
            pycall, term.PyCall(pycall.mode, pycall.functionname, terms))

    def transformRepeat(self, repeat):
        """Transform the repeated term and check the ellipsis is used.

        Raises ``Exception`` when no pattern variable below recorded a ForEach
        annotation on this Repeat, i.e. the ellipsis does not iterate over
        anything.
        """
        assert isinstance(repeat, term.Repeat)
        nrepeat = term.Repeat(self.transform(
            repeat.term)).copyattributesfrom(repeat)
        if not self.contains_nonzero_foreach_annotations(repeat):
            raise Exception('too many ellipses in template {}'.format(
                repr(nrepeat)))
        return self.complete_annotation(repeat, nrepeat)

    def transformTermSequence(self, termsequence):
        # Children first (they record annotations on `termsequence`), then attach.
        ntermsequence = super().transformTermSequence(termsequence)
        return self.complete_annotation(termsequence, ntermsequence)

    def transformInHole(self, inhole):
        # Children first (they record annotations on `inhole`), then attach.
        ninhole = super().transformInHole(inhole)
        return self.complete_annotation(inhole, ninhole)

    def transformUnresolvedSym(self, node):
        """Resolve `node` to a variable literal or a pattern variable.

        Symbols that are not in ``self.variables`` become variable term
        literals.

        For pattern variables, annotations are recorded on the nodes of
        ``self.path``. Raises ``Exception`` when the number of ``Repeat`` nodes
        crossed differs from the expected depth.
        """
        assert isinstance(node, term.UnresolvedSym)
        if node.sym not in self.variables:
            t = term.TermLiteral(term.TermLiteralKind.Variable, node.sym)
            #self.context.add_lit_term(t)
            return t
        expecteddepth = self.variables[node.sym]
        actualdepth = 0

        # Fresh parameter name for this particular occurrence of the variable.
        param = self.symgen.get(node.sym)
        # definitely a pattern variable now, topmost entry on the stack is this node.
        # Walk outwards from the variable towards the root:
        #  - the variable itself: if the expected depth is 0, record
        #    MatchRead (sym, param) on it and stop; otherwise record
        #    InArg (param).
        #  - TermSequence / InHole: once `actualdepth` Repeats have been crossed
        #    and it equals the expected depth, record MatchRead and stop;
        #    otherwise record InArg.
        #  - Repeat: one more ellipsis level crossed; record ForEach
        #    (param, depth).
        # NOTE(review): MatchRead presumably marks where the binding is read
        # from the match, and InArg where it is passed in as an argument;
        # confirm against the term codegen.
        for t in reversed(self.path):
            if isinstance(t, term.UnresolvedSym):
                if expecteddepth == 0:
                    self.add_annotation_to(t, term.TermAttribute.MatchRead,
                                           (node.sym, param))
                    break
                self.add_annotation_to(t, term.TermAttribute.InArg, param)
            if isinstance(t, term.TermSequence) or isinstance(t, term.InHole):
                if expecteddepth == actualdepth:
                    self.add_annotation_to(t, term.TermAttribute.MatchRead,
                                           (node.sym, param))
                    break
                else:
                    self.add_annotation_to(t, term.TermAttribute.InArg, param)
            if isinstance(t, term.Repeat):
                actualdepth += 1
                self.add_annotation_to(t, term.TermAttribute.ForEach,
                                       (param, actualdepth))

        # Reaching the root too early or crossing too many Repeats leaves the
        # depths unequal.
        if actualdepth != expecteddepth:
            raise Exception(
                'inconsistent ellipsis depth for pattern variable {}: expected {} actual {}'
                .format(node.sym, expecteddepth, actualdepth))

        return self.complete_annotation(node, term.PatternVariable(node.sym))
コード例 #4
0
class CompilationContext:
    """Shared state threaded through the compiler's passes.

    Records which generated function or identifier belongs to which language
    construct, so later passes can look it up by key instead of regenerating
    code. Covered constructs: variable-not-otherwise-mentioned sets, isa
    functions, pattern matchers, term templates, reduction relations,
    metafunctions and redex-match forms.

    ``get_*`` lookups return None when nothing has been registered for a key,
    except ``get_variables_mentioned``, ``get_sym_for_lit_term`` and
    ``get_redexmatch_for``, which require the key to exist.
    """

    def __init__(self):
        # languagename -> (identifier name, variables).
        self.__variables_mentioned = {}
        # (languagename, patrepr) -> isa function name.
        self.__isa_functions = {}
        # (languagename, patrepr) -> pattern-matching function name.
        self.__pattern_code = {}
        # term template -> generated term function name.
        self.__term_template_funcs = {}

        # literal term -> generated symbol.
        self._litterms = {}

        # (languagename, patrepr) -> top-level matching function name.
        self.__toplevel_patterns = {}

        self.__reductionrelations = {}
        self.__metafunctions = {}

        self.__redexmatches = {}

        self.symgen = SymGen()

    def add_variables_mentioned(self, languagename, variables):
        assert languagename not in self.__variables_mentioned
        self.__variables_mentioned[languagename] = (
            '{}_variables_mentioned'.format(languagename), variables)

    def get_variables_mentioned(self, languagename):
        assert languagename in self.__variables_mentioned
        return self.__variables_mentioned[languagename]

    def get_variables_mentioned_all(self):
        """Return a view of all registered ``(identifier, variables)`` pairs."""
        return self.__variables_mentioned.values()

    def add_lit_term(self, term):
        """Register literal `term`, allocating a symbol on first registration.

        Registering the same term again keeps the symbol chosen the first time.
        Regenerating it would leave earlier references to the old symbol
        pointing at a name that is never emitted.
        """
        if term not in self._litterms:
            self._litterms[term] = self.symgen.get('literal_term_')

    def add_isa_function_name(self, languagename, patrepr, functionname):
        k = (languagename, patrepr)
        assert k not in self.__isa_functions
        self.__isa_functions[k] = functionname

    def get_isa_function_name(self, languagename, patrepr):
        return self.__isa_functions.get((languagename, patrepr))

    def get_sym_for_lit_term(self, term):
        return self._litterms[term]

    def add_function_for_pattern(self, languagename, patrepr, functionname):
        k = (languagename, patrepr)
        assert k not in self.__pattern_code, 'function for {}-{}  is present'.format(
            languagename, patrepr)
        self.__pattern_code[k] = functionname

    def get_function_for_pattern(self, languagename, patrepr):
        return self.__pattern_code.get((languagename, patrepr))

    # Returns (procedurename, already_generated) for the given pattern. When the
    # pattern is new, a fresh name is allocated and registered, and the flag is
    # False so the caller knows it still has to generate the code.
    def get_function_for_pattern_2(self, languagename, patrepr):
        k = (languagename, patrepr)
        if k in self.__pattern_code:
            return self.__pattern_code[k], True
        self.__pattern_code[k] = self.symgen.get('pattern_match')
        return self.__pattern_code[k], False

    def add_toplevel_function_for_pattern(self, languagename, patrepr,
                                          functionname):
        k = (languagename, patrepr)
        assert k not in self.__toplevel_patterns, 'function for {}-{} is present'.format(
            languagename, patrepr)
        self.__toplevel_patterns[k] = functionname

    def get_toplevel_function_for_pattern(self, languagename, patrepr):
        return self.__toplevel_patterns.get((languagename, patrepr))

    def get_function_for_term_template(self, term_template):
        """Return the generated function name for `term_template`, allocating one if new."""
        assert isinstance(term_template, term.Term)
        if term_template not in self.__term_template_funcs:
            name = self.symgen.get('gen_term')
            self.__term_template_funcs[term_template] = name
        return self.__term_template_funcs[term_template]

    def add_reduction_relation(self, reductionrelationname, function):
        assert reductionrelationname not in self.__reductionrelations
        self.__reductionrelations[reductionrelationname] = function

    def get_reduction_relation(self, reductionrelationname):
        return self.__reductionrelations.get(reductionrelationname)

    def add_metafunction(self, mfname, function):
        assert mfname not in self.__metafunctions
        self.__metafunctions[mfname] = function

    def get_metafunction(self, mfname):
        return self.__metafunctions.get(mfname)

    def get_redexmatch_for(self, form):
        return self.__redexmatches[form]

    def add_redexmatch_for(self, form, name):
        self.__redexmatches[form] = name
コード例 #5
0
ファイル: tlform.py プロジェクト: mamysa/PyPltRedex
class TopLevelFormCodegen(tlform.TopLevelFormVisitor):
    """Emit the RPython module for a processed ``tlform.Module``.

    Top-level forms are visited in order. Forms that produce runnable checks
    register a zero-argument procedure in ``main_procedurecalls``; the
    generated ``entrypoint`` function calls those procedures.
    """

    def __init__(self, module, context):
        assert isinstance(module, tlform.Module)
        assert isinstance(context, CompilationContext)
        self.module = module
        self.context = context
        self.symgen = SymGen()
        self.modulebuilder = rpy.BlockBuilder()

        # Names of generated procedures that `entrypoint` calls, in order.
        self.main_procedurecalls = []

    def _ensure_toplevel_pattern_function(self, languagename, pat):
        """Return the top-level matching function for `pat`, generating it if missing.

        The function is looked up in the context by ``(languagename, repr(pat))``.
        If none is registered yet, PatternCodegen first emits it into the module.
        """
        nameof_matchfn = self.context.get_toplevel_function_for_pattern(
            languagename, repr(pat))
        if nameof_matchfn is None:
            PatternCodegen(self.modulebuilder, pat, self.context, languagename,
                           self.symgen).run()
            nameof_matchfn = self.context.get_toplevel_function_for_pattern(
                languagename, repr(pat))
        return nameof_matchfn

    def run(self):
        """Generate the complete module and return it as an ``rpy.Module``.

        The module contains, in order:
        - the runtime includes;
        - the per-language variable-not-otherwise-mentioned sets;
        - the code for every top-level form;
        - an ``entrypoint(argv)`` function that calls every registered
          procedure;
        - the ``target`` function required by RPython;
        - an ``if __name__ == '__main__'`` guard.
        """
        self.modulebuilder.IncludeFromPythonSource('runtime/term.py')
        self.modulebuilder.IncludeFromPythonSource('runtime/parser.py')
        self.modulebuilder.IncludeFromPythonSource('runtime/fresh.py')
        self.modulebuilder.IncludeFromPythonSource('runtime/match.py')

        # Parsing of term literals is disabled (26.07.2020) until the nt
        # caching acceleration technique is implemented. The previous code was:
        #   tmp0, tmp1 = rpy.gen_pyid_temporaries(2, self.symgen)
        #   for trm, sym1 in self.context._litterms.items():
        #       sym1 = rpy.gen_pyid_for(sym1)
        #       self.modulebuilder.AssignTo(tmp0).New('Parser', rpy.PyString(repr(trm)))
        #       self.modulebuilder.AssignTo(sym1).MethodCall(tmp0, 'parse')

        # variable-not-otherwise-mentioned of each define-language.
        for ident, variables in self.context.get_variables_mentioned_all():
            pyident = rpy.gen_pyid_for(ident)
            pystrings = [rpy.PyString(v) for v in variables]
            self.modulebuilder.AssignTo(pyident).PySet(*pystrings)

        for form in self.module.tlforms:
            self._visit(form)

        # generate main
        fb = rpy.BlockBuilder()
        symgen = SymGen()
        # Emit some dummy Terms to aid RPython with type inference.
        tmp0, tmp1, tmp2, tmp3, tmp4 = rpy.gen_pyid_temporaries(5, symgen)
        fb.AssignTo(tmp0).New('Integer', rpy.PyInt(0))
        fb.AssignTo(tmp1).New('Float', rpy.PyFloat(0.0))
        fb.AssignTo(tmp2).New('String', rpy.PyString("\"hello world!\""))
        fb.AssignTo(tmp3).New('Boolean', rpy.PyString("#f"))
        fb.AssignTo(tmp4).New('Variable', rpy.PyString("x"))
        for procedure in self.main_procedurecalls:
            tmpi = rpy.gen_pyid_temporaries(1, symgen)
            fb.AssignTo(tmpi).FunctionCall(procedure)
        fb.Return(rpy.PyInt(0))
        self.modulebuilder.Function('entrypoint').WithParameters(
            rpy.PyId('argv')).Block(fb)

        # Required entry procedure for RPython.
        fb = rpy.BlockBuilder()
        fb.Return(rpy.PyTuple(rpy.PyId('entrypoint'), rpy.PyNone()))
        self.modulebuilder.Function('target').WithParameters(
            rpy.PyVarArg('args')).Block(fb)

        # if __name__ == '__main__': entrypoint([])
        # (kept for python2.7 compatibility).
        ifb = rpy.BlockBuilder()
        tmp = rpy.gen_pyid_temporaries(1, self.symgen)
        ifb.AssignTo(tmp).FunctionCall('entrypoint', rpy.PyList())
        self.modulebuilder.If.Equal(rpy.PyId('__name__'),
                                    rpy.PyString('__main__')).ThenBlock(ifb)

        return rpy.Module(self.modulebuilder.build())

    def _codegenNtDefinition(self, languagename, ntdef):
        """Emit ``lang_<language>_isa_nt_<prefix>(term)`` for one non-terminal.

        The function returns True as soon as one of the non-terminal's
        patterns matches `term`, and False otherwise.
        """
        assert isinstance(ntdef, tlform.DefineLanguage.NtDefinition)
        # Emit all pattern matchers before the isa function that calls them.
        matchfuncs = [
            self._ensure_toplevel_pattern_function(languagename, pat)
            for pat in ntdef.patterns
        ]

        nameof_this_func = 'lang_{}_isa_nt_{}'.format(languagename,
                                                      ntdef.nt.prefix)
        term, matches = rpy.gen_pyid_for('term', 'matches')
        # Generated code, for each pattern of the nt definition:
        #   matches = <matchfn>(term)
        #   if len(matches) != 0:
        #     return True
        # followed by a final `return False`.
        fb = rpy.BlockBuilder()

        for func2call in matchfuncs:
            ifb = rpy.BlockBuilder()
            ifb.Return(rpy.PyBoolean(True))

            fb.AssignTo(matches).FunctionCall(func2call, term)
            fb.If.LengthOf(matches).NotEqual(rpy.PyInt(0)).ThenBlock(ifb)
        fb.Return(rpy.PyBoolean(False))

        self.modulebuilder.Function(nameof_this_func).WithParameters(
            term).Block(fb)

    def _visitDefineLanguage(self, form):
        assert isinstance(form, tlform.DefineLanguage)
        # generate hole for each language.  Need this for term annotation.
        hole = rpy.gen_pyid_for('{}_hole'.format(form.name))
        self.modulebuilder.AssignTo(hole).New('Hole')

        # Register every isa_nt function name in the context first, so pattern
        # code for one non-terminal can refer to another's isa function.
        # Derive both the key and the name from ntdef.nt.prefix, matching the
        # name under which _codegenNtDefinition defines the function.
        for ntdef in form.nts.values():
            prefix = ntdef.nt.prefix
            nameof_this_func = 'lang_{}_isa_nt_{}'.format(form.name, prefix)
            self.context.add_isa_function_name(form.name, prefix,
                                               nameof_this_func)

        for ntdef in form.nts.values():
            self._codegenNtDefinition(form.name, ntdef)

    def _visitRequirePythonSource(self, form):
        assert isinstance(form, tlform.RequirePythonSource)
        self.modulebuilder.IncludeFromPythonSource(form.filename)

    def _visitRedexMatch(self, form, callself=True):
        """Emit a procedure matching ``form.pat`` against ``form.termstr``.

        This form is currently disabled by ``assert False``. The code after
        the assert only runs when assertions are stripped (``python -O``) and
        is kept for reference.
        """
        assert isinstance(form, tlform.RedexMatch)
        assert False
        matchfunc = self._ensure_toplevel_pattern_function(form.languagename,
                                                           form.pat)

        TermCodegen(self.modulebuilder, self.context).transform(form.termstr)
        termfunc = self.context.get_function_for_term_template(form.termstr)

        symgen = SymGen()

        matches, term = rpy.gen_pyid_for('matches', 'term')
        tmp0 = rpy.gen_pyid_temporaries(1, symgen)

        fb = rpy.BlockBuilder()
        fb.AssignTo(tmp0).New('Match')
        fb.AssignTo(term).FunctionCall(termfunc, tmp0)
        fb.AssignTo(matches).FunctionCall(matchfunc, term)
        fb.Print(matches)
        fb.Return(matches)

        # call redex-match itself.
        nameof_this_func = self.symgen.get('redexmatch')
        self.context.add_redexmatch_for(form, nameof_this_func)
        self.modulebuilder.Function(nameof_this_func).Block(fb)

        if callself:
            tmp0 = rpy.gen_pyid_temporaries(1, self.symgen)
            self.modulebuilder.AssignTo(tmp0).FunctionCall(nameof_this_func)

    def _visitRedexMatchAssertEqual(self, form):
        """Emit and register a procedure that checks a redex-match result.

        The procedure matches ``form.pat`` against ``form.termtemplate``,
        compares the matches with ``form.expectedmatches``, prints them and
        returns them.
        """
        def gen_matches(expectedmatches, fb, symgen):
            # Build one Match per expected match and return the id of a list of them.
            processedmatches = []
            for m in expectedmatches:
                tmp0 = rpy.gen_pyid_temporaries(1, symgen)
                fb.AssignTo(tmp0).New('Match')
                processedmatches.append(tmp0)
                for sym, termx in m.bindings:
                    tmp1, tmp2, tmp3, tmp4 = rpy.gen_pyid_temporaries(
                        4, symgen)
                    TermCodegen(self.modulebuilder,
                                self.context).transform(termx)
                    termfunc = self.context.get_function_for_term_template(
                        termx)
                    fb.AssignTo(tmp1).New('Match')
                    fb.AssignTo(tmp2).FunctionCall(termfunc, tmp1)
                    fb.AssignTo(tmp3).MethodCall(tmp0, MatchMethodTable.AddKey,
                                                 rpy.PyString(sym))
                    fb.AssignTo(tmp4).MethodCall(tmp0,
                                                 MatchMethodTable.AddToBinding,
                                                 rpy.PyString(sym), tmp2)
            tmpi = rpy.gen_pyid_temporaries(1, symgen)
            fb.AssignTo(tmpi).PyList(*processedmatches)
            return tmpi

        assert isinstance(form, tlform.RedexMatchAssertEqual)
        matchfunc = self._ensure_toplevel_pattern_function(form.languagename,
                                                           form.pat)
        TermCodegen(self.modulebuilder,
                    self.context).transform(form.termtemplate)
        termfunc = self.context.get_function_for_term_template(
            form.termtemplate)
        symgen = SymGen()

        matches, term = rpy.gen_pyid_for('matches', 'term')
        fb = rpy.BlockBuilder()
        expectedmatches = gen_matches(form.expectedmatches, fb, symgen)
        tmp0, tmp1, tmp2 = rpy.gen_pyid_temporaries(3, symgen)
        fb.AssignTo(tmp0).New('Match')
        fb.AssignTo(term).FunctionCall(termfunc, tmp0)
        fb.AssignTo(matches).FunctionCall(matchfunc, term)
        fb.AssignTo(tmp1).FunctionCall('assert_compare_match_lists', matches,
                                       expectedmatches)
        fb.AssignTo(tmp2).FunctionCall(MatchHelperFuncs.PrintMatchList,
                                       matches)
        fb.Return(matches)

        nameof_this_func = self.symgen.get('redexmatchassertequal')
        self.context.add_redexmatch_for(form, nameof_this_func)
        self.modulebuilder.Function(nameof_this_func).Block(fb)

        self.main_procedurecalls.append(nameof_this_func)

    def _visitTermLetAssertEqual(self, form):
        """Emit and register a procedure that checks a term-let result.

        The procedure instantiates ``form.template`` with the match built from
        ``form.variableassignments``, asserts that the result equals
        ``form.expected``, and prints it.
        """
        assert isinstance(form, tlform.TermLetAssertEqual)

        template = form.template
        TermCodegen(self.modulebuilder, self.context).transform(template)
        templatetermfunc = self.context.get_function_for_term_template(
            template)

        TermCodegen(self.modulebuilder, self.context).transform(form.expected)
        expectedtermfunc = self.context.get_function_for_term_template(
            form.expected)

        fb = rpy.BlockBuilder()
        symgen = SymGen()

        expected, match = rpy.gen_pyid_for('expected', 'match')
        tmp0 = rpy.gen_pyid_temporaries(1, symgen)

        fb.AssignTo(tmp0).New('Match')
        fb.AssignTo(expected).FunctionCall(expectedtermfunc, tmp0)

        fb.AssignTo(match).New('Match')
        # `termx` rather than `term` so the loop does not shadow the term module.
        for variable, termx in form.variableassignments.items():
            tmp1, tmp2, tmp3, tmp4 = rpy.gen_pyid_temporaries(4, symgen)

            TermCodegen(self.modulebuilder, self.context).transform(termx)
            termfunc = self.context.get_function_for_term_template(termx)

            fb.AssignTo(tmp1).New('Match')
            fb.AssignTo(tmp2).FunctionCall(termfunc, tmp1)
            fb.AssignTo(tmp3).MethodCall(match, MatchMethodTable.AddKey,
                                         rpy.PyString(variable))
            fb.AssignTo(tmp4).MethodCall(match, MatchMethodTable.AddToBinding,
                                         rpy.PyString(variable), tmp2)

        tmp0, tmp1, tmp2 = rpy.gen_pyid_temporaries(3, symgen)
        fb.AssignTo(tmp0).FunctionCall(templatetermfunc, match)
        fb.AssignTo(tmp1).FunctionCall('asserttermsequal', tmp0, expected)
        fb.AssignTo(tmp2).FunctionCall(TermHelperFuncs.PrintTerm, tmp0)

        nameof_this_func = self.symgen.get('asserttermequal')
        self.modulebuilder.Function(nameof_this_func).Block(fb)
        self.main_procedurecalls.append(nameof_this_func)

    def _codegenReductionCase(self,
                              rc,
                              languagename,
                              reductionrelationname,
                              nameof_domaincheck=None):
        """Emit the procedure for one reduction case and return its name.

        When `nameof_domaincheck` is given, every produced term is checked
        against the domain pattern, and an exception is raised if it falls
        outside.
        """
        assert isinstance(rc, tlform.DefineReductionRelation.ReductionCase)

        nameof_matchfn = self._ensure_toplevel_pattern_function(
            languagename, rc.pattern)
        TermCodegen(self.modulebuilder,
                    self.context).transform(rc.termtemplate)
        nameof_termfn = self.context.get_function_for_term_template(
            rc.termtemplate)

        nameof_rc = self.symgen.get('{}_{}_case'.format(
            languagename, reductionrelationname))

        symgen = SymGen()
        # terms = []
        # matches = match(term)
        # if len(matches) != 0:
        #   for match in matches:
        #     tmp0 = gen_term(match)
        #     tmp2 = match_domain(tmp0)
        #     if len(tmp2) == 0:
        #       raise Exception('reduction-relation {}: term reduced from {} to {} via rule {} and is outside domain')
        #     tmp1 = terms.append(tmp0)
        # return terms

        terms, term, matches, match = rpy.gen_pyid_for('terms', 'term',
                                                       'matches', 'match')
        tmp0, tmp1, tmp2 = rpy.gen_pyid_temporaries(3, symgen)

        forb = rpy.BlockBuilder()
        forb.AssignTo(tmp0).FunctionCall(nameof_termfn, match)
        if nameof_domaincheck is not None:
            ifb = rpy.BlockBuilder()
            tmpa, tmpb = rpy.gen_pyid_temporaries(2, symgen)
            ifb.AssignTo(tmpa).MethodCall(term, TermMethodTable.ToString)
            ifb.AssignTo(tmpb).MethodCall(tmp0, TermMethodTable.ToString)
            ifb.RaiseException('reduction-relation \\"{}\\": term reduced from %s to %s via rule \\"{}\\" is outside domain' \
                    .format(reductionrelationname, rc.name),
                    tmpa, tmpb)
            forb.AssignTo(tmp2).FunctionCall(nameof_domaincheck, tmp0)
            forb.If.LengthOf(tmp2).Equal(rpy.PyInt(0)).ThenBlock(ifb)
        forb.AssignTo(tmp1).MethodCall(terms, 'append', tmp0)

        ifb = rpy.BlockBuilder()
        ifb.For(match).In(matches).Block(forb)

        fb = rpy.BlockBuilder()
        fb.AssignTo(terms).PyList()
        fb.AssignTo(matches).FunctionCall(nameof_matchfn, term)
        fb.If.LengthOf(matches).NotEqual(rpy.PyInt(0)).ThenBlock(ifb)
        fb.Return(terms)

        self.modulebuilder.Function(nameof_rc).WithParameters(term).Block(fb)
        return nameof_rc

    def _visitDefineReductionRelation(self, form):
        """Emit ``<language>_<name>(term)`` and register it in the context."""
        assert isinstance(form, tlform.DefineReductionRelation)
        # def reduction_relation_name(term):
        #   outterms = []
        # {for each case}
        # tmpi = rc(term)
        # outterms = outterms + tmp{i}
        # return outterms
        nameof_domaincheck = None
        if form.domain is not None:
            nameof_domaincheck = self._ensure_toplevel_pattern_function(
                form.languagename, form.domain)

        rcfuncs = [
            self._codegenReductionCase(rc, form.languagename, form.name,
                                       nameof_domaincheck)
            for rc in form.reductioncases
        ]

        terms, term = rpy.gen_pyid_for('terms', 'term')
        symgen = SymGen()

        fb = rpy.BlockBuilder()

        if nameof_domaincheck is not None:
            tmp0 = rpy.gen_pyid_temporaries(1, symgen)
            ifb = rpy.BlockBuilder()
            tmpa = rpy.gen_pyid_temporaries(1, symgen)
            ifb.AssignTo(tmpa).MethodCall(term, TermMethodTable.ToString)
            ifb.RaiseException('reduction-relation not defined for %s', tmpa)

            fb.AssignTo(tmp0).FunctionCall(nameof_domaincheck, term)
            fb.If.LengthOf(tmp0).Equal(rpy.PyInt(0)).ThenBlock(ifb)

        fb.AssignTo(terms).PyList()
        for rcfunc in rcfuncs:
            tmpi = rpy.gen_pyid_temporaries(1, symgen)
            fb.AssignTo(tmpi).FunctionCall(rcfunc, term)
            fb.AssignTo(terms).Add(terms, tmpi)
        fb.Return(terms)

        nameof_function = '{}_{}'.format(form.languagename, form.name)
        self.context.add_reduction_relation(form.name, nameof_function)
        self.modulebuilder.Function(nameof_function).WithParameters(
            term).Block(fb)
        return form

    # This generates call to reduction relation. Used by multiple other tlforms.
    def _genreductionrelation(self, fb, symgen, nameof_reductionrelation,
                              term):
        """Append code to `fb` that builds `term`, reduces it and prints the results.

        Returns the identifier holding the list of reduced terms.
        """
        TermCodegen(self.modulebuilder, self.context).transform(term)
        termfunc = self.context.get_function_for_term_template(term)

        # Separate names so the `term` template argument is not shadowed.
        pyterm, pyterms = rpy.gen_pyid_for('term', 'terms')

        tmp0, tmp1 = rpy.gen_pyid_temporaries(2, symgen)
        fb.AssignTo(tmp0).New('Match')
        fb.AssignTo(pyterm).FunctionCall(termfunc, tmp0)
        fb.AssignTo(pyterms).FunctionCall(nameof_reductionrelation, pyterm)
        fb.AssignTo(tmp1).FunctionCall(TermHelperFuncs.PrintTermList, pyterms)
        return pyterms

    def _visitApplyReductionRelationAssertEqual(self, form):
        """Emit and register a procedure comparing a reduction result with expected terms."""
        assert isinstance(form, tlform.ApplyReductionRelationAssertEqual)

        def gen_terms(termtemplates, fb, symgen):
            # Instantiate each expected term template and return the id of a list of them.
            processed = []
            for expectedterm in termtemplates:
                TermCodegen(self.modulebuilder,
                            self.context).transform(expectedterm)
                expectedtermfunc = self.context.get_function_for_term_template(
                    expectedterm)
                tmpi, tmpj = rpy.gen_pyid_temporaries(2, symgen)
                fb.AssignTo(tmpi).New('Match')
                fb.AssignTo(tmpj).FunctionCall(expectedtermfunc, tmpi)
                processed.append(tmpj)
            tmpi = rpy.gen_pyid_temporaries(1, symgen)
            fb.AssignTo(tmpi).PyList(*processed)
            return tmpi

        nameof_reductionrelation = self.context.get_reduction_relation(
            form.reductionrelationname)
        # Same check as _visitApplyReductionRelation: the relation must exist.
        assert nameof_reductionrelation is not None, \
            'reduction-relation {} is not defined'.format(form.reductionrelationname)

        fb = rpy.BlockBuilder()
        symgen = SymGen()
        tmp0 = rpy.gen_pyid_temporaries(1, symgen)
        expectedterms = gen_terms(form.expected_termtemplates, fb, symgen)
        terms = self._genreductionrelation(fb, symgen,
                                           nameof_reductionrelation, form.term)
        fb.AssignTo(tmp0).FunctionCall(TermHelperFuncs.AssertTermListsEqual,
                                       terms, expectedterms)

        nameof_function = self.symgen.get('applyreductionrelationassertequal')
        self.modulebuilder.Function(nameof_function).Block(fb)
        self.main_procedurecalls.append(nameof_function)

    def _visitApplyReductionRelation(self, form):
        """Emit a procedure applying a reduction relation and call it at module level."""
        assert isinstance(form, tlform.ApplyReductionRelation)
        nameof_reductionrelation = self.context.get_reduction_relation(
            form.reductionrelationname)
        assert nameof_reductionrelation is not None, \
            'reduction-relation {} is not defined'.format(form.reductionrelationname)

        fb = rpy.BlockBuilder()
        symgen = SymGen()
        self._genreductionrelation(fb, symgen, nameof_reductionrelation,
                                   form.term)
        tmp1 = rpy.gen_pyid_temporaries(1, symgen)
        nameof_function = self.symgen.get('applyreductionrelation')
        self.modulebuilder.Function(nameof_function).Block(fb)
        self.modulebuilder.AssignTo(tmp1).FunctionCall(nameof_function)

    # metafunction case may produce multiple matches but after term plugging all terms
    # must be the same.
    def _codegenMetafunctionCase(self, metafunction, case, caseid, mfname):
        assert isinstance(metafunction, tlform.DefineMetafunction)
        assert isinstance(case, tlform.DefineMetafunction.MetafunctionCase)
        #def mfcase(argterm):
        #  tmp0 = matchfunc(argterm)
        #  tmp1 = []
        #  if len(tmp0) == 0:
        #    return tmp1
        #  for tmp2 in tmp0:
        #    tmp3 = termfunc(tmp2)
        #    if tmp3 == None: continue
        #    tmp4 = tmp1.append(tmp3)
        #  tmp5 = aretermsequalpairwise(tmp1)
        #  if tmp5 != True:
        #    raise Exception('mfcase 1 matched (term) in len(tmp{i}) ways, single match is expected')
        #  tmp6 = tmp1[0]
        #  return tmp6
        self._ensure_toplevel_pattern_function(metafunction.languagename,
                                               case.patternsequence)
コード例 #6
0
ファイル: tlform.py プロジェクト: mamysa/PyPltRedex
class TopLevelProcessor(tlform.TopLevelFormVisitor):
    """Validate and rewrite the top-level forms of a Module.

    Each form in ``module.tlforms`` is passed, in order, to ``self._visit``.
    That method is inherited from ``tlform.TopLevelFormVisitor`` and
    presumably dispatches to the matching ``_visit<FormName>`` method below.
    Later forms look up languages, reduction relations and metafunctions
    registered by earlier forms, so a definition must precede every form that
    refers to it.
    """

    def __init__(self, module, context, debug_dump_ntgraph=False):
        """Create a processor for ``module``.

        Args:
            module: the tlform.Module whose top-level forms are processed.
            context: CompilationContext updated while processing and returned
                by run().
            debug_dump_ntgraph: when true, _visitDefineLanguage dumps each
                language's non-terminal graph and prints the hole count of
                every non-terminal.
        """
        assert isinstance(module, tlform.Module)
        assert isinstance(context, CompilationContext)
        self.module = module
        self.context = context
        self.symgen = SymGen()

        self.debug_dump_ntgraph = debug_dump_ntgraph

        # Registries filled as definitions are visited:
        # - definelanguages: processed DefineLanguage forms by name (also used
        #   by redex-match forms);
        # - definelanguageclosures: non-terminal closures per language name;
        # - reductionrelations: DefineReductionRelation forms by name;
        # - metafunctions: DefineMetafunction forms by contract name.
        self.definelanguages = {}
        self.definelanguageclosures = {}
        self.reductionrelations = {}
        self.metafunctions = {}

    def run(self):
        """Process every top-level form in order.

        Returns:
            A tuple ``(tlform.Module(processed_forms), self.context)``.
        """
        forms = [self._visit(form) for form in self.module.tlforms]
        return tlform.Module(forms), self.context

    def _visitDefineLanguage(self, form):
        """Run the language-definition passes over ``form`` and register it.

        Passes, in order: DefineLanguage_NtRewriter (the variables it reports
        are recorded in the context), DefineLanguage_IdRewriter,
        DefineLanguage_NtClosureSolver, DefineLanguage_NtCycleChecker,
        NtGraphBuilder plus DefineLanguage_HoleReachabilitySolver, and
        DefineLanguage_AssignableSymbolExtractor. The resulting form and its
        closures are stored under ``form.name``.

        Returns:
            The rewritten DefineLanguage form.
        """
        assert isinstance(form, tlform.DefineLanguage)
        form, variables = DefineLanguage_NtRewriter(form, form.ntsyms()).run()
        self.context.add_variables_mentioned(form.name, variables)
        form = DefineLanguage_IdRewriter(form).run()
        successors, closures = DefineLanguage_NtClosureSolver(form).run()
        DefineLanguage_NtCycleChecker(form, successors).run()

        graph = NtGraphBuilder(form).run()
        DefineLanguage_HoleReachabilitySolver(form, graph).run()
        if self.debug_dump_ntgraph:
            graph.dump(form.name)
            print('------ Debug Nt hole counts for language {}: ------'.format(form.name))
            for nt, ntdef in form.nts.items():
                print('{}: {}'.format(nt, ntdef.nt.getattribute(pattern.PatternAttribute.NumberOfHoles)))
            print('\n')
        # NOTE: DefineLanguage_EllipsisMatchModeRewriter(form, closures) is
        # currently disabled.
        form = DefineLanguage_AssignableSymbolExtractor(form).run()
        self.definelanguages[form.name] = form
        self.definelanguageclosures[form.name] = closures
        return form

    def __processpattern(self, pat, languagename):
        """Run the match-pattern passes over ``pat`` for ``languagename``.

        Raises:
            KeyError: if ``languagename`` was not defined by an earlier
                DefineLanguage form.
        """
        lang = self.definelanguages[languagename]
        pat = Pattern_NtRewriter(pat, lang.ntsyms()).run()
        pat = Pattern_EllipsisDepthChecker(pat).run()
        Pattern_InHoleChecker(lang, pat).run()
        # NOTE: Pattern_EllipsisMatchModeRewriter(lang, pat, closure) is
        # currently disabled. Re-enabling it needs
        # self.definelanguageclosures[languagename] as ``closure``.
        pat = Pattern_ConstraintCheckInserter(pat).run()
        pat = Pattern_AssignableSymbolExtractor(pat).run()
        return pat

    def __processdomaincheck(self, pat, languagename):
        """Run the domain/codomain-check passes over ``pat``.

        Unlike __processpattern, this uses Pattern_IdRewriter and skips
        Pattern_EllipsisDepthChecker and Pattern_ConstraintCheckInserter.

        Raises:
            KeyError: if ``languagename`` was not defined by an earlier
                DefineLanguage form.
        """
        lang = self.definelanguages[languagename]
        pat = Pattern_NtRewriter(pat, lang.ntsyms()).run()
        pat = Pattern_IdRewriter(pat).run()
        Pattern_InHoleChecker(lang, pat).run()
        # NOTE: Pattern_EllipsisMatchModeRewriter(lang, pat, closure) is
        # currently disabled. Re-enabling it needs
        # self.definelanguageclosures[languagename] as ``closure``.
        pat = Pattern_AssignableSymbolExtractor(pat).run()
        return pat

    def __processtermtemplate(self, termtemplate, assignments=None):
        """Run the term-template passes over ``termtemplate``.

        Args:
            termtemplate: the term template to process.
            assignments: mapping handed to Term_EllipsisDepthChecker. Callers
                pass a pattern's PatternVariableEllipsisDepths or
                ``TermLetAssertEqual.variabledepths``. Defaults to a new empty
                dict on each call.

        Returns:
            The template after Term_EllipsisDepthChecker and
            Term_MetafunctionApplicationRewriter.
        """
        if assignments is None:
            # Fresh dict per call. A shared ``{}`` default would carry any
            # mutation made by the checker over into later calls.
            assignments = {}
        idof = self.symgen.get('termtemplate')
        termtemplate = Term_EllipsisDepthChecker(assignments, idof, self.context).transform(termtemplate)
        termtemplate = Term_MetafunctionApplicationRewriter(termtemplate, self.metafunctions, self.symgen).run()
        return termtemplate

    def _lookup_reduction_relation(self, name, formname=None):
        """Return the reduction relation registered under ``name``.

        Raises:
            CompilationError: if no earlier DefineReductionRelation form used
                ``name``. The message is prefixed with ``'<formname> : '``
                when ``formname`` is given.
        """
        try:
            return self.reductionrelations[name]
        except KeyError:
            msg = 'unknown reduction relation {}'.format(name)
            if formname is not None:
                msg = '{} : {}'.format(formname, msg)
            raise CompilationError(msg) from None

    def _visitRedexMatch(self, form):
        """Process the pattern and term of a RedexMatch form in place."""
        assert isinstance(form, tlform.RedexMatch)
        form.pat = self.__processpattern(form.pat, form.languagename)
        form.termstr = self.__processtermtemplate(form.termstr)
        return form

    def _visitRedexMatchAssertEqual(self, form):
        """Process pattern, term and expected binding terms in place."""
        assert isinstance(form, tlform.RedexMatchAssertEqual)
        form.pat = self.__processpattern(form.pat, form.languagename)
        form.termtemplate = self.__processtermtemplate(form.termtemplate)
        for match in form.expectedmatches:
            assert isinstance(match, tlform.RedexMatchAssertEqual.Match)
            match.bindings = [(ident, self.__processtermtemplate(term))
                              for (ident, term) in match.bindings]
        return form

    def _visitTermLetAssertEqual(self, form):
        """Process the let-bound terms, template and expected term in place.

        The bound terms are processed without ellipsis-depth assignments. The
        template is processed with ``form.variabledepths``.
        """
        assert isinstance(form, tlform.TermLetAssertEqual)
        # Only existing keys are reassigned, so updating the dict while
        # iterating it is safe and keeps the same dict object on the form.
        for ident, term in form.variableassignments.items():
            form.variableassignments[ident] = self.__processtermtemplate(term)

        form.template = self.__processtermtemplate(form.template, assignments=form.variabledepths)
        form.expected = self.__processtermtemplate(form.expected)
        return form

    def processReductionCase(self, reductioncase, languagename):
        """Process one reduction case in place.

        The term template is checked against the ellipsis depths of the
        variables bound by the case's (processed) pattern.
        """
        assert isinstance(reductioncase, tlform.DefineReductionRelation.ReductionCase)
        reductioncase.pattern = self.__processpattern(reductioncase.pattern, languagename)
        assignablesymsdepths = reductioncase.pattern.getattribute(pattern.PatternAttribute.PatternVariableEllipsisDepths)
        reductioncase.termtemplate = self.__processtermtemplate(reductioncase.termtemplate, assignments=assignablesymsdepths)

    def _visitDefineReductionRelation(self, form):
        """Register the reduction relation, then process its cases and domain."""
        assert isinstance(form, tlform.DefineReductionRelation)
        self.reductionrelations[form.name] = form
        for rc in form.reductioncases:
            self.processReductionCase(rc, form.languagename)
        if form.domain is not None:
            form.domain = self.__processdomaincheck(form.domain, form.languagename)
        return form

    def _visitApplyReductionRelation(self, form):
        """Check the relation exists and process the input term in place.

        Raises:
            CompilationError: if the reduction relation is unknown.
        """
        assert isinstance(form, tlform.ApplyReductionRelation)
        self._lookup_reduction_relation(form.reductionrelationname)
        form.term = self.__processtermtemplate(form.term)
        return form

    def _visitDefineMetafunction(self, form):
        """Register the metafunction, then process its contract and cases.

        The metafunction is registered before its cases are processed.
        Term_MetafunctionApplicationRewriter receives ``self.metafunctions``,
        so the metafunction can presumably be applied recursively within its
        own cases.
        """
        assert isinstance(form, tlform.DefineMetafunction)
        self.metafunctions[form.contract.name] = form

        form.contract.domain = self.__processdomaincheck(form.contract.domain, form.languagename)
        form.contract.codomain = self.__processdomaincheck(form.contract.codomain, form.languagename)

        for case in form.cases:
            case.patternsequence = self.__processpattern(case.patternsequence, form.languagename)
            assignablesymsdepths = case.patternsequence.getattribute(pattern.PatternAttribute.PatternVariableEllipsisDepths)
            case.termtemplate = self.__processtermtemplate(case.termtemplate, assignments=assignablesymsdepths)
        return form

    def _visitApplyReductionRelationAssertEqual(self, form):
        """Check the relation exists and process the input and expected terms.

        Expected term templates go only through Term_EllipsisDepthChecker
        (with no assignments), not through metafunction-application rewriting.

        Raises:
            CompilationError: if the reduction relation is unknown.
        """
        assert isinstance(form, tlform.ApplyReductionRelationAssertEqual)
        self._lookup_reduction_relation(form.reductionrelationname)
        form.term = self.__processtermtemplate(form.term)
        for i, termtemplate in enumerate(form.expected_termtemplates):
            idof = self.symgen.get('term_assert_term_lists_equal')
            form.expected_termtemplates[i] = Term_EllipsisDepthChecker({}, idof, self.context).transform(termtemplate)
        return form

    def _visitParseAssertEqual(self, form):
        """Process the expected term template in place."""
        assert isinstance(form, tlform.ParseAssertEqual)
        form.expected_termtemplate = self.__processtermtemplate(form.expected_termtemplate)
        return form

    def _visitReadFromStdinAndApplyReductionRelation(self, form):
        """Check that the referenced metafunction (if any) and relation exist.

        Raises:
            CompilationError: if the metafunction or reduction relation is
                unknown.
        """
        assert isinstance(form, tlform.ReadFromStdinAndApplyReductionRelation)
        if form.metafunctionname is not None and form.metafunctionname not in self.metafunctions:
            raise CompilationError('read-from-stdin-and-apply-reduction-relation* : unknown metafunction {}'.format(form.metafunctionname))
        self._lookup_reduction_relation(form.reductionrelationname,
                                        'read-from-stdin-and-apply-reduction-relation*')
        return form