def normalize_minus( expr ):
    """Move unary minus signs inside products onto the leading factor.

    Three rewrite rules are handed to ``replace_all`` (in this order):
      * ``ld l___ (-md) r___``  ->  ``(-ld) l___ md r___``
      * ``(-(-ld)) l___``       ->  ``ld l___``
      * ``-(ld l___)``          ->  ``(-ld) l___``

    Returns whatever ``replace_all`` produces for ``expr`` and these rules.
    """
    head = PatternDot("ld")
    negated = PatternDot("md")
    prefix = PatternStar("l")
    suffix = PatternStar("r")

    def pull_minus_to_front(d):
        # The minus found on an inner factor is re-attached to the first factor.
        return Times([ Minus([d["ld"]]), d["l"], d["md"], d["r"]])

    def cancel_double_minus(d):
        return Times([ d["ld"], d["l"]])

    def push_minus_into_product(d):
        return Times([ Minus([d["ld"]]), d["l"]])

    rules = [
        RewriteRule( Times([ head, prefix, Minus([negated]), suffix ]),
                     Replacement( pull_minus_to_front ) ),
        RewriteRule( Times([ Minus([Minus([head])]), prefix ]),
                     Replacement( cancel_double_minus ) ),
        RewriteRule( Minus([ Times([ head, prefix ]) ]),
                     Replacement( push_minus_into_product ) ),
    ]
    return replace_all( expr, rules )
Beispiel #2
0
 def expr_to_rule_rhs_lhs(self, predicates):
     """Build rewrite rules that replace a predicate's rhs by its lhs.

     Args:
         predicates: iterable of expressions whose ``children`` unpack into
             exactly ``(lhs, rhs)``; ``lhs`` must expose ``children``.

     Returns:
         A list with one entry per predicate (same order); each entry is the
         list of ``RewriteRule`` objects generated for that predicate.
     """
     rules = []
     t = PatternStar("t")
     l = PatternStar("l")
     ld = PatternDot("ld")
     r = PatternStar("r")
     for p in predicates:
         pr = []
         lhs, rhs = p.children
         # Single entry on the lhs: generate structural variants of the rhs
         # (plain, negated, transposed) that all rewrite to that entry.
         if len(lhs.children) == 1:
             #lhs_sym = WrapOutBef( lhs.children[0] )
             lhs_sym = lhs.children[0]
             # Every repl_f below is built as (lambda lhs: lambda d: ...)(lhs_sym)
             # so that the current lhs_sym is bound when the closure is created.
             if isinstance(rhs, Plus):
                 # t___ + rhs -> t + lhs
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children +
                                                      [lhs]))(lhs_sym)
                 pr.append(
                     RewriteRule(Plus([t] + rhs.children),
                                 Replacement(repl_f)))
                 # t___ + l___ rhs_i r___ + ... -> t + l lhs r
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times(d["l"].children + [lhs] + d["r"].children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             Times([l] + [ch] + [r]) for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # t___ + (-rhs_i) r___ + ... -> t + (-lhs) r
                 # (each negated term is canonicalized and simplified first)
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([simplify(to_canonical(Minus([lhs])))] + d["r"].
                           children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             Times([simplify(to_canonical(Minus([ch])))] +
                                   [r]) for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # A - B C in  L B C R + -L A R  (minus pushed all the way to the left, and whole thing negated)
                 repl_f = (lambda lhs: lambda d: normalize_minus(
                     Plus(d["t"].children + [
                         Times([d["ld"]] + d["l"].children + [Minus([lhs])]
                               + d["r"].children)
                     ])))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             normalize_minus(Times([ld, l,
                                                    Minus([ch]), r]))
                             for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # A - B C in  -L B C R + L A R  (minus pushed all the way to the left)
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([d["ld"]] + d["l"].children + [lhs] + d["r"].
                           children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             normalize_minus(Times([ld, l, ch, r]))
                             for ch in rhs.children
                         ]), Replacement(repl_f)))
                 #repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [Times([simplify(to_canonical(Minus(lhs.children)))] + d["r"].children)]))(lhs_sym)
                 #pr.append( RewriteRule( Plus([t] + [
                 #Times([ Minus([ld]), l, ch, r]) if not isinstance(ch, Minus) \
                 #else Times([ l, ch.children[0], r]) \
                 #for ch in rhs.children ]),
                 #Replacement(repl_f) ) )
                 # Transposed terms of the rhs -> -(lhs^T) r.
                 # NOTE(review): the pattern also binds ld and l, but the
                 # replacement only re-inserts t and r -- confirm this is intended.
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([
                         simplify(to_canonical(Minus([Transpose([lhs])])))
                     ] + d["r"].children)
                 ]))(lhs_sym)
                 pr.append( RewriteRule( Plus([t] + [
                             Times([ Minus([ld]), l, simplify(to_canonical(Transpose([ch]))), r]) if not isinstance(ch, Minus) \
                                     else Times([ l, simplify(to_canonical(Transpose([ch]))), r]) \
                                     for ch in rhs.children ]),
                                 Replacement(repl_f) ) )
             elif isinstance(rhs, Times):
                 # l___ rhs r___ -> l lhs r
                 repl_f = (lambda lhs: lambda d: Times(d[
                     "l"].children + [lhs] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(Times([l] + rhs.children + [r]),
                                 Replacement(repl_f)))
                 # l___ rhs^T r___ -> l lhs^T r
                 repl_f = (lambda lhs: lambda d: Times(d[
                     "l"].children + [Transpose([lhs])] + d["r"].children)
                           )(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             l,
                             simplify(to_canonical(Transpose([rhs]))), r
                         ]), Replacement(repl_f)))
                 # [TODO] Minus is awkward to handle. Should go for -1 and remove the operator internally?
                 # (-rhs) r___ -> (-lhs) r
                 repl_f = (lambda lhs: lambda d: Times([
                     simplify(to_canonical(Minus([Times([lhs])])))
                 ] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             simplify(
                                 to_canonical(
                                     Minus([Times(rhs.get_children())])))
                         ] + [r]), Replacement(repl_f)))
                 # -(rhs^T) r___ -> -(lhs^T) r
                 repl_f = (lambda lhs: lambda d: Times([
                     simplify(
                         to_canonical(Minus([Transpose([Times([lhs])])])))
                 ] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             simplify(
                                 to_canonical(
                                     Minus([
                                         Transpose(
                                             [Times(rhs.get_children())])
                                     ])))
                         ] + [r]), Replacement(repl_f)))
             else:
                 # rhs is neither Plus nor Times: rewrite it directly, and add
                 # a transposed variant unless the simplified transpose is
                 # itself an operand.
                 pr.append(RewriteRule(rhs, Replacement(lhs_sym)))
                 new_rhs = simplify(to_canonical(Transpose([rhs])))
                 if not isOperand(new_rhs):
                     pr.append(
                         RewriteRule(
                             simplify(to_canonical(Transpose([rhs]))),
                             Replacement(Transpose([lhs_sym]))))
         else:
             # lhs does not hold exactly one entry: rewrite rhs to the whole lhs.
             pr.append(RewriteRule(rhs, Replacement(lhs)))
         rules.append(pr)
     return rules
Beispiel #3
0
    def find_updates_v2(self, before, after):
        """Derive the updates that lead from the ``before`` predicates to ``after``.

        Both arguments are sequences of equations whose ``get_children()``
        returns ``(lhs, rhs)``. The equations are modified in place: their
        lhs entries get wrapped (``WrapOutBef`` / ``WrapOutAft``) and the rhs
        of every ``after`` equation is rewritten.

        Returns:
            ``None`` when the invariant is skipped (a ``before`` lhs is missing
            from ``after``, or rewriting ``after`` with the ``before`` rules
            leaves it unchanged); otherwise the list of tiled updates with the
            wrappers removed again.
        """
        # If a part is (partially) computed in the before and
        #   does not appear in the after or
        # going from before to after requires undoing some computation
        # it is potentially unstable, and more expensive: ignore
        dict_bef = {str(u.get_children()[0]): u for u in before}
        dict_aft = {str(u.get_children()[0]): u for u in after}
        ignore = False
        # Explanation printed when the invariant is skipped; it was previously
        # never assigned, which made the skip path fail with a NameError.
        reason = None
        for k, v in dict_bef.items():
            if k not in dict_aft:
                ignore = True
                reason = "%s not in %s" % (k, dict_aft.keys())
                break
            rules = self.expr_to_rule_rhs_lhs([v])
            rules = list(itertools.chain(*rules))
            expr_copy = copy.deepcopy(dict_aft[k])
            t = replace(expr_copy, rules)
            #if v == replace( expr_copy, rules ):
            if dict_aft[k] == t:
                ignore = True
                reason = "%s would require undoing job" % k
                break
        if ignore:
            print("[INFO] Skipping invariant: %s" % reason)
            return None
        #
        # Wrap outputs for before and after
        WrapBefOut = WrapOutBef
        for u in before:
            u.children[0] = NList([WrapBefOut(l) for l in u.children[0]])
        #
        for u in after:
            u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])
        # replace before in after
        wrap_rules_before = []
        for u in before:
            print(u)
            lhs, rhs = u.get_children()
            # Equations with several lhs entries contribute no rules here.
            if len(lhs.children) > 1:
                continue
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_before.append(list(itertools.chain(*rules)))
        #
        # Walk the rule sets from last to first and rewrite each rule's
        # pattern with the rule sets that precede it (closest one first).
        for i, rule in enumerate(reversed(wrap_rules_before)):
            idx = len(wrap_rules_before) - i - 1
            for j in range(idx - 1, -1, -1):
                for _rule in rule:
                    _rule.pattern = replace_all(_rule.pattern,
                                                wrap_rules_before[j])
        wrap_rules_before = list(itertools.chain(*wrap_rules_before))
        #
        for u in after:
            _, rhs = u.get_children()
            u.children[1] = simplify(
                to_canonical(replace_all(copy.deepcopy(rhs),
                                         wrap_rules_before)))
        # replace after in after
        # Iterate to a fixed point: stop once a full pass changes nothing.
        done = False
        while not done:
            # replace after in after
            wrap_rules_after = []
            for u in after:
                lhs, rhs = u.get_children()
                if len(lhs.children) > 1:
                    # Keep the positions aligned with ``after``.
                    wrap_rules_after.append([])
                    continue
                rules = self.expr_to_rule_rhs_lhs([u])
                wrap_rules_after.append(list(itertools.chain(*rules)))
            #
            after_top = [copy.deepcopy(u) for u in after]
            for i, u in enumerate(after):
                _, rhs = u.get_children()
                # Use the rules of every other equation, never its own.
                rules = list(
                    itertools.chain.from_iterable(wrap_rules_after[:i] +
                                                  wrap_rules_after[i + 1:]))
                u.children[1] = simplify(
                    to_canonical(replace_all(copy.deepcopy(rhs), rules)))
            done = True
            for top, bot in zip(after_top, after):
                if top != bot:
                    done = False
                    break
        # [TODO] Multiple lhss, won't work
        updates = []
        for u in after:
            lhs, rhs = u.get_children()
            lhs = lhs.children[0]  # NList[op] -> op
            # Drop equations that merely state "after value == before value".
            if isinstance(rhs, WrapBefOut) and isinstance(lhs, WrapOutAft) and \
                    matchq(lhs.children[0], rhs.children[0]):
                continue
            updates.append(u)
        #
        tiled_updates = []
        for u in updates:
            print("*   ", u)
            tilings = list(tile_expr(u))
            if len(tilings) > 1:
                print("[WARNING] Multiple (%d) tilings for expression %s" %
                      (len(tilings), u))
                print("          Discarding all but one")
            tiled_updates.extend(tilings[0])
        tiled_updates = sort(tiled_updates)
        print("* Tiled update")
        for t in tiled_updates:
            print("*   ", t)

        # Drop WrapOutBef's
        # Drop WrapOutAft's
        s = PatternDot("s")
        updates = []
        for u in tiled_updates:
            u = replace_all(
                u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
            u = replace_all(
                u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
            updates.append(u)

        return updates
Beispiel #4
0
    def find_updates(self, before, after):
        """Derive the updates that lead from the ``before`` predicates to ``after``.

        Both arguments are sequences of equations whose ``get_children()``
        returns ``(lhs, rhs)``. The equations are modified in place: lhs
        entries get wrapped (``WrapOutBef`` / ``WrapOutAft``) and the rhs of
        the ``after`` equations is rewritten until nothing changes.

        Returns:
            ``None`` when the invariant is skipped (expressing the predicates
            in terms of the input failed, a ``before`` lhs is missing from
            ``after``, or rewriting ``after`` with the ``before`` rules leaves
            it unchanged); otherwise the list of tiled updates with the
            wrappers removed again.
        """
        # If a part is (partially) computed in the before and
        #   does not appear in the after or
        # going from before to after requires undoing some computation
        # it is potentially unstable, and more expensive: ignore
        try:
            before_finputs = self.express_in_terms_of_input(before)
            after_finputs = self.express_in_terms_of_input(after)
        except Exception:
            # [TODO] In LU's variant 5, parts of A appear as lhs's
            # Deliberate best effort; ``Exception`` (not a bare except) so that
            # KeyboardInterrupt / SystemExit still propagate.
            return None
        #
        dict_bef = {str(u.get_children()[0]): u for u in before_finputs}
        dict_aft = {str(u.get_children()[0]): u for u in after_finputs}
        # lhs entries whose equation is the same before and after.
        same = []
        ignore = False
        for k, v in dict_bef.items():
            if k in dict_aft and matchq(v, dict_aft[k]):
                same.extend(v.children[0].children)
            if k not in dict_aft:
                ignore = True
                reason = "%s not in %s" % (k, dict_aft.keys())
                break
            else:
                rules = self.expr_to_rule_rhs_lhs([v])
                rules = list(itertools.chain(*rules))
                expr_copy = copy.deepcopy(dict_aft[k])
                t = replace(expr_copy, rules)
                #if v == replace( expr_copy, rules ):
                if dict_aft[k] == t:
                    ignore = True
                    reason = "%s would require undoing job" % k
                    break
        if ignore:
            print("[INFO] Skipping invariant: %s" % reason)
            return None
        #
        # Wrap outputs for before and after
        WrapBefOut = WrapOutBef
        lhss = []
        for u in before:
            lhss.extend(u.children[0])
            u.children[0] = NList([WrapBefOut(l) for l in u.children[0]])
        for u in before:
            u.children[1] = replace(
                u.children[1],
                [RewriteRule(l, Replacement(WrapBefOut(l))) for l in lhss])
        #
        lhss = []
        for u in after:
            lhss.extend(u.children[0])
            u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])
        # Entries collected in ``same`` are wrapped as "before" values inside
        # the after rhs; all others are wrapped as "after" values.
        wrap_rules_after = \
                [
                    RewriteRule(l, Replacement(WrapBefOut(l))) if l in same else
                    RewriteRule(l, Replacement(WrapOutAft(l))) for l in lhss
                ]
        for u in after:
            u.children[1] = replace(u.children[1], wrap_rules_after)
        # replace before in before
        wrap_rules_before = []
        for u in before:
            lhs, rhs = u.get_children()
            #if len(lhs.children) > 1:
            #wrap_rules_before.append([])
            #continue
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_before.append(list(itertools.chain(*rules)))
        #
        # For every rule, also try its pattern rewritten with the rules of all
        # other equations; keep the variant only when the pattern changed.
        new_rules = []
        for i, rules in enumerate(wrap_rules_before):
            new_rules.append([])
            for rule in rules:
                new_r = copy.deepcopy(rule)
                new_r.pattern = replace_all(
                    new_r.pattern,
                    list(
                        itertools.chain.from_iterable(wrap_rules_before[:i] +
                                                      wrap_rules_before[i +
                                                                        1:])))
                if new_r.pattern != rule.pattern:
                    new_rules[-1].append(new_r)
        for r1, r2 in zip(new_rules, wrap_rules_before):
            r2.extend(r1)
        #
        wrap_rules_before = list(itertools.chain(*wrap_rules_before))
        # Apply the before rules to the after rhs until a fixed point.
        done = False
        while not done:
            after_top = [copy.deepcopy(u) for u in after]
            for i, u in enumerate(after):
                _, rhs = u.get_children()
                u.children[1] = simplify(
                    to_canonical(
                        replace_all(copy.deepcopy(rhs), wrap_rules_before)))
            done = True
            for top, bot in zip(after_top, after):
                if top != bot:
                    done = False
                    break
        # replace after in after
        done = False
        while not done:
            # replace after in after
            wrap_rules_after = []
            for u in after:
                lhs, rhs = u.get_children()
                #if len(lhs.children) > 1:
                #wrap_rules_after.append([])
                #continue
                rules = self.expr_to_rule_rhs_lhs([u])
                wrap_rules_after.append(list(itertools.chain(*rules)))
            #
            after_top = [copy.deepcopy(u) for u in after]
            for i, u in enumerate(after):
                _, rhs = u.get_children()
                # Use the rules of every other equation, never its own.
                rules = list(
                    itertools.chain.from_iterable(wrap_rules_after[:i] +
                                                  wrap_rules_after[i + 1:]))
                u.children[1] = simplify(
                    to_canonical(replace_all(copy.deepcopy(rhs), rules)))
            done = True
            for top, bot in zip(after_top, after):
                if top != bot:
                    done = False
                    break
        # [TODO] Multiple lhss, won't work
        updates = []
        for u in after:
            lhs, rhs = u.get_children()
            if len(lhs.children) == 1:
                lhs = lhs.children[0]  # NList[op] -> op
                # Drop equations that merely state "after value == before value".
                if isinstance(rhs, WrapBefOut) and isinstance(lhs, WrapOutAft) and \
                        matchq(lhs.children[0], rhs.children[0]):
                    continue
            elif not isinstance(rhs,
                                NList):  # multiple outputs/predicate in rhs,
                # but not complete (otherwise it would be NList)
                pass
            else:
                # Several outputs: skip only if every pair is unchanged.
                to_skip = True
                for l, r in zip(lhs.children, rhs.children):
                    if not( isinstance(r, WrapBefOut) and isinstance(l, WrapOutAft) and \
                            matchq(l.children[0], r.children[0]) ):
                        to_skip = False
                        break
                if to_skip:
                    continue
            updates.append(u)
        #
        tiled_updates = []
        for u in updates:
            print("*   ", u)
            tilings = list(tile_expr(u))
            if len(tilings) > 1:
                print("[WARNING] Multiple (%d) tilings for expression %s" %
                      (len(tilings), u))
                print("          Discarding all but one")
            tiled_updates.extend(tilings[0])
        tiled_updates = sort(tiled_updates)
        print("* Tiled update")
        for t in tiled_updates:
            print("*   ", t)

        # Drop WrapOutBef's
        # Drop WrapOutAft's
        s = PatternDot("s")
        updates = []
        for u in tiled_updates:
            u = replace_all(
                u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
            u = replace_all(
                u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
            updates.append(u)

        return updates
Beispiel #5
0
    def generate_loop_based_algorithms(self):
        """Build one ``Algorithm`` per usable loop invariant and store them.

        ``self.algs`` is reset and then receives one list per
        ``(pme, loop invariants)`` pair taken from ``self.pmes`` and
        ``self.linvs``. Invariants for which ``find_updates`` returns ``None``
        or an empty list are skipped and do not consume a variant number.
        Progress is reported with ``print``.
        """
        print("* Generating Loop-based algorithms...")
        self.algs = []
        variant = 1
        for pme, invariants in zip(self.pmes, self.linvs):
            generated = []
            for linv in invariants:
                print("* ")
                print("* Loop invariant", variant)
                for expression in linv.expressions:
                    print("*     ", expression)
                print("* ")
                # Only the first traversal is used; iterating over all of them
                # would be another for loop.
                trav, init_state, _ = linv.traversals[0]
                init = self.algorithm_initialization(init_state)
                print("* Init")
                #print( init_state )
                print("*   ", init)
                # Strip the WrapOutBef markers from the initialization.
                s = PatternDot("s")
                unwrapped = []
                for statement in init:
                    strip_rule = RewriteRule(WrapOutBef(s),
                                             Replacement(lambda d: d["s"]))
                    unwrapped.append(replace_all(statement, [strip_rule]))
                init = unwrapped
                print("* Before")
                repart, before = self.generate_predicate_before(
                    pme, trav, linv.expressions, linv)
                print("* After")
                # Same traversal with both directions negated.
                reversed_trav = {
                    k: (r * -1, c * -1)
                    for k, (r, c) in trav.items()
                }
                cont_with, after = self.generate_predicate_before(
                    pme, reversed_trav, linv.expressions, linv)
                # find updates
                print("* Updates")
                updates = self.find_updates(before, after)
                if updates is None:
                    #variant += 1
                    continue
                # Tile updates
                for update in updates:
                    infer_properties(update)
                final_updates = []
                # [DIEGO] Fixing some output pieces being labeled as input
                outputs = []
                for update in updates:
                    lhs, _rhs = update.children
                    for piece in lhs:
                        if piece.isTemporary():
                            continue
                        outputs.append(piece)
                        piece.set_property(OUTPUT)
                #
                for update in updates:
                    #[DIEGO] u.children[0].children[0].set_property( OUTPUT )
                    for node in update.children[1].iterate_preorder():
                        #if isinstance(node, Symbol):
                        if isinstance(node, Symbol) and node not in outputs:
                            node.set_property(INPUT)
                    #
                    #copy_u = replace( copy.deepcopy(u), known_ops_single )
                    update_copy = copy.deepcopy(update)
                    tiled = tile_expr(copy.deepcopy(update_copy))[0]  # One tiling
                    final_updates.extend(tiled)
                if not updates:
                    print("No updates!! Should only happen in copy")
                    continue

                algorithm = Algorithm(linv, variant, init, repart, cont_with,
                                      before, after, final_updates)
                generated.append(algorithm)
                algorithm.prepare_for_code_generation()
                variant += 1
            self.algs.append(generated)
Beispiel #6
0
    def learn_pattern(self):
        """Register rewrite rules that recognise this operation.

        The operation's equation is turned into a pattern (operands replaced
        by ``PatternDot`` placeholders) that rewrites to a ``Predicate`` call.
        A single-assignment equation adds two rules (plain and negated) to the
        front of ``known_ops_single``; otherwise one rule is added to the front
        of ``known_ops`` and pickled to ``OUTPUT/<name>_patterns``. In both
        cases the inverse rule (predicate -> equation) is appended to
        ``op_to_implicit``.
        """
        inops = [op for op in self.operands if op.isInput()]
        outops = [op for op in self.operands if op.isOutput()]
        #
        equations = self.equation.children
        single_assignment = len(equations) == 1 and \
                            isinstance(equations[0].children[0], Symbol) # eq.lhs.single_entry_in_NL
        #
        op_to_pattern = [
            RewriteRule(op, Replacement("PatternDot(%s.name)" % op.name))
            for op in self.operands
        ]
        pattern = NList([
            replace_all(copy.deepcopy(eq), op_to_pattern)
            for eq in self.equation
        ])
        if single_assignment:
            # Input/output properties are left out; empty constraints dropped.
            clauses = []
            for op in self.operands:
                clause = symbol_props_to_constraints_no_io(op)
                if clause:
                    clauses.append(clause)
        else:
            clauses = [symbol_props_to_constraints(op) for op in self.operands]
        constraint = Constraint(" and ".join(clauses))
        # [TODO] Tuple for get_size
        # The same four pieces fill both replacement templates below.
        out_names = ", ".join([op.name for op in outops])
        in_names = ", ".join([op.name for op in inops])
        out_sizes = ", ".join(
            ["%s.get_size()" % op.get_name() for op in outops])
        template_args = (out_names, self.name, in_names, out_sizes)
        replacement = Replacement(
            "Equal([ NList([%s]), Predicate( \"%s\", [%s], [%s] ) ])" %
            template_args)
        # [TODO] This should be part of the verbose option
        print("* Learnt pattern")
        print("*   ", pattern, end="")
        if constraint.to_eval:
            print("with        ", constraint.to_eval)
        print(" --> ")
        print("*          ", replacement.to_eval)
        print("**********************************")
        # [TODO] Maybe sort known ops by specificity (a la mathematica)
        #known_ops.insert( 0, RewriteRule( (pattern, constraint), replacement ) )

        if not single_assignment:
            known_ops.insert(0, RewriteRule((pattern, constraint),
                                            replacement))
            patterns_path = os.path.join("OUTPUT", self.name + "_patterns")
            with open(patterns_path, "wb") as patt_f:
                pickle.dump(known_ops[0], patt_f)
        else:
            plain_eq = pattern.children[0]
            plain_eq.children[0] = NList([plain_eq.children[0]])
            known_ops_single.insert(
                0, RewriteRule((plain_eq, constraint), replacement))
            # With minus
            minus_replacement = Replacement(
                "Equal([ NList([%s]), Minus([ Predicate( \"%s\", [%s], [%s] ) ]) ])"
                % template_args)
            minus_eq = copy.deepcopy(plain_eq)
            minus_eq.children[1] = Minus([minus_eq.children[1]])
            minus_eq.children[1] = normalize_minus(
                copy.deepcopy(minus_eq.children[1]))
            known_ops_single.insert(
                0, RewriteRule((minus_eq, constraint), minus_replacement))
            #with open(os.path.join("OUTPUT", self.name+"_patterns"), "wb") as patt_f:
            #pickle.dump( known_ops_single[1], patt_f )
            #pickle.dump( known_ops_single[0], patt_f )

        # Inverse direction: predicate call -> original equation.
        implicit_pattern = Equal([
            NList([PatternDot(op.get_name()) for op in outops]),
            Predicate(self.name, [PatternDot(op.get_name()) for op in inops],
                      [op.get_size() for op in outops])
        ])
        implicit_replacement = Replacement(equation2replacement(self.equation))
        op_to_implicit.append(
            RewriteRule(implicit_pattern, implicit_replacement))
Beispiel #7
0
from BindDimensions import bindDimensions
from graph import build_dep_graph, subgraphs, zero_out_lower
from Partitioning import partition_shape, repartition, repartition_shape, repart_group
from Tiling import tile_expr
from utils import sort
from RewritingExtension import *

from BackEnd.MatlabCode import generate_matlab_code, click2matlab
#from BackEnd.LatexCode import generate_latex_code
from BackEnd.LGenCode import generate_lgen_files
#from BackEnd.CCode import generate_c_code

# Module-level rule registries.
# ``known_ops`` starts with a single rule: an equation list holding one
# ``LHS = RHS`` where LHS is an output Symbol and RHS is an input is rewritten
# to ``Equal([ NList([LHS]), RHS ])``.
# NOTE(review): the other three lists start empty and are presumably filled
# while operations/PMEs are learnt -- confirm against their users.
known_ops = [
    # Straight assignment
    RewriteRule(
        (NList([Equal([PatternDot("LHS"), PatternDot("RHS")])]),
         Constraint(
             "isinstance(LHS, Symbol) and isOutput(LHS) and isInput(RHS)")),
        Replacement("Equal([ NList([LHS]), RHS ])"))
]
known_ops_single = []
known_pmes = []
op_to_implicit = []


# Add a verbose option
# For now it would help to illustrate the process
class Operation(object):
    def __init__(self, name, operands, equation, overwrite):
        self.name = name
        self.operands = operands
Beispiel #8
0
from core.expression import Symbol, Matrix, Vector, \
                            Equal, Plus, Minus, Times, Transpose, \
                            PatternDot, PatternStar
from core.properties import *
from core.InferenceOfProperties import *

from core.functional import Constraint, Replacement, RewriteRule, replace_all

from core.builtin_operands import Zero, Identity

# Shared pattern placeholders for the rewrite rules of this module.
# Each placeholder is created with a name string equal to its variable name.
LHS = PatternDot("LHS")
RHS = PatternDot("RHS")
subexpr = PatternDot("subexpr")
PD1 = PatternDot("PD1")
PD2 = PatternDot("PD2")
PS1 = PatternStar("PS1")
PS2 = PatternStar("PS2")
PS3 = PatternStar("PS3")
PSLeft = PatternStar("PSLeft")
PSRight = PatternStar("PSRight")

simplify_rules = [
    # Minus
    # --expr -> expr
    RewriteRule(Minus([Minus([subexpr])]),
                Replacement(lambda d: d["subexpr"])),
    # -0 -> 0
    RewriteRule((Minus([subexpr]), Constraint(lambda d: isZero(d["subexpr"]))),
                Replacement(lambda d: d["subexpr"])),
    # Plus
    # Plus(a) -> a
def flatten_blocked_operation_click( expr ):
    """Flatten ``expr`` after distributing ``WrapOutBef`` over wrapped blocks.

    Works on a deep copy of ``expr``: every ``WrapOutBef(x)`` is replaced by a
    ``BlockedExpression`` built from ``map_thread(WrapOutBef, [x], 2)`` with
    ``x``'s size and shape; the result is passed to
    ``flatten_blocked_operation``.
    """
    wrapped = PatternDot("PD")

    def wrap_each_block(d):
        inner = d["PD"]
        return BlockedExpression( map_thread( WrapOutBef, [inner], 2 ),
                                  inner.size, inner.shape )

    distribute_wrapper = RewriteRule( WrapOutBef( wrapped ),
                                      Replacement( wrap_each_block ) )
    rewritten = replace( copy.deepcopy(expr), [distribute_wrapper] )
    return flatten_blocked_operation( rewritten )
Beispiel #10
0
    def learn_pattern(self):
        """Learn the rewrite rule that expands this operation's PME.

        The pattern is ``outputs = Predicate(name, inputs, output sizes)`` with
        every operand given in its basic partitioning and each part replaced by
        a ``PatternDot``. The replacement is the string produced by
        ``equation2replacement`` for the blocked equations, using the rhs of
        the solved subequations where the lhs matches. The rule is printed,
        appended to ``self.known_pmes`` and pickled (append mode) to
        ``OUTPUT/<name>_pmes``.
        """
        inops = [op for op in self.operands if op.isInput()]
        outops = [op for op in self.operands if op.isOutput()]
        # pattern
        predicate_inops = []
        predicate_outops = []
        for op in self.operands:
            basic_part = self.basic_partitionings[op.get_name()]
            # One rule per part of the operand: part -> PatternDot(part name).
            parts_to_dots = [
                RewriteRule(part_op,
                            Replacement(PatternDot(part_op.get_name())))
                for part_op in itertools.chain(*basic_part)
            ]
            blocked = BlockedExpression(
                copy.deepcopy(self.basic_partitionings[op.get_name()]),
                op.get_size(), basic_part.shape)
            destination = predicate_inops if op.isInput() else predicate_outops
            destination.append(replace(blocked, parts_to_dots))
        pattern = Equal([
            NList(predicate_outops),
            Predicate(self.name, predicate_inops,
                      [op.get_size() for op in outops])
        ])
        # replacement
        # [TODO] Tuple for get_size
        #basic_parts = self.basic_partitionings
        # [CHECK] parts = self.partitionings
        #parts = self.basic_partitionings
        parts = self.partitionings
        lhss = map_thread(NList, [parts[op.get_name()] for op in outops],
                          2)

        def zero_dropping_rules():
            # NList([x..., 0]) and NList([0, x...]) both become NList([x...]).
            return [
                RewriteRule(
                    (NList([PatternPlus("PP"),
                            PatternDot("PD")
                            ]), Constraint(lambda d: isZero(d["PD"]))),
                    Replacement(lambda d: NList(d["PP"].get_children()))),
                RewriteRule(
                    (NList([PatternDot("PD"),
                            PatternPlus("PP")
                            ]), Constraint(lambda d: isZero(d["PD"]))),
                    Replacement(lambda d: NList(d["PP"].get_children())))
            ]

        # [TODO] This is a fix for lu (maybe coup sylv as well).
        #        Generalize and clean up
        for i, row in enumerate(lhss):
            for j, cell in enumerate(row):
                lhss[i][j] = replace(cell, zero_dropping_rules())
        #

        eqs = map_thread(Equal, [lhss, parts[outops[0].get_name()]], 2)

        output_shape = self.basic_partitionings[outops[0].get_name()].shape
        r, c = output_shape
        # Plug the solved rhs into every block whose lhs matches.
        for solved in self.solved_subequations:
            lhs, rhs = solved.get_children()
            for row, col in itertools.product(range(r), range(c)):
                block_lhs, _block_rhs = eqs[row][col].get_children()
                if lhs == block_lhs:
                    eqs[row][col].set_children(1, rhs)

        replacement_str = equation2replacement(
            BlockedExpression(eqs, (0, 0), output_shape))
        print("* Learnt PME pattern")
        print("*     ", RewriteRule(pattern, Replacement(replacement_str)))
        self.known_pmes.append(
            RewriteRule(pattern, Replacement(replacement_str)))
        pmes_path = os.path.join("OUTPUT", self.name + "_pmes")
        with open(pmes_path, "ab+") as pmes_f:
            pickle.dump(self.known_pmes[-1], pmes_f)
Beispiel #11
0
                        tile, new = instr.tile( new, node, _m )
                        lhs, rhs = tile.get_children()
                        if rhs not in matched_this_node: ###
                            matched_this_node.append( rhs ) ###
                            ongoing.append( alg[:] + [tile, new] )
                            # Set size of new temporary
                            lhs.children[0].size = rhs.get_size()
                        else:
                            # Aesthetic only: simply avoids gaps in the T? temporary names.
                            TOS.push_back_temp( )
                            TOS._TOS.unset_operand( tile.children[0].children[0] )
            if matched_in_this_level: ###
                break


# Pattern placeholders used by the grouping rules that follow:
# PD* are PatternDot placeholders, PS* are PatternStar placeholders.
PD1 = PatternDot("PD1")
PD2 = PatternDot("PD2")
PS1 = PatternStar("PS1")
PS2 = PatternStar("PS2")

grouping_rules = [
    # A B + A C D -> A (B + C D)
    RewriteRule(
        Plus([ Times([ PD1, PS1 ]), Times([ PD1, PS2 ]) ]),
        Replacement(lambda d: Times([ d["PD1"], Plus([Times([d["PS1"]]), Times([d["PS2"]])]) ]))
    ),
    # A B - A C D -> A (B - C D)
    RewriteRule(
        Plus([ Times([ PD1, PS1 ]), Times([ Minus([PD1]), PD2, PS2 ]) ]),
        Replacement(lambda d: Times([ d["PD1"], Plus([ Times([d["PS1"]]), Times([Minus([d["PD2"]]), Times([d["PS2"]])])]) ]))
    ),
Beispiel #12
0

class tril(Operator):
    """Operator node built around a single argument.

    NOTE(review): presumably mirrors Matlab's ``tril`` (lower-triangular
    part of the argument) -- confirm against the code generator.
    """

    def __init__(self, arg):
        """Wrap ``arg`` and adopt its size.

        ``arg`` must provide ``get_size()``.
        """
        # Operator is initialised with [arg], an empty list and UNARY.
        Operator.__init__(self, [arg], [], UNARY)
        # The node reports the same size as the argument it wraps.
        self.size = arg.get_size()


class triu(Operator):
    """Operator node built around a single argument.

    NOTE(review): presumably mirrors Matlab's ``triu`` (upper-triangular
    part of the argument) -- confirm against the code generator.
    """

    def __init__(self, arg):
        """Wrap ``arg`` and adopt its size.

        ``arg`` must provide ``get_size()``.
        """
        # Operator is initialised with [arg], an empty list and UNARY.
        Operator.__init__(self, [arg], [], UNARY)
        # The node reports the same size as the argument it wraps.
        self.size = arg.get_size()


# Pattern variables for rewriting an explicit inverse into an mrdiv call.
A, B, C = (PatternDot(label) for label in ("A", "B", "C"))

# [TODO] Complete the set of patterns
trsm_patterns = [
    # C = B * transpose(inverse(A))  ->  C = mrdiv(transpose(tril(A)), B)
    #
    # A constrained variant of this rule is currently disabled:
    #   RewriteRule( (Equal([ C, Times([ B, Transpose([Inverse([A])]) ]) ]), Constraint("A.st_info[0] == ST_LOWER")),
    # so the active rule carries no constraint on A.
    RewriteRule(
        Equal([C, Times([B, Transpose([Inverse([A])])])]),
        Replacement(
            lambda d: Equal([
                d["C"],
                mrdiv([Transpose([tril(d["A"])]), d["B"]]),
            ])
        ),
    ),
]


#
# Produces Matlab code for:
# - Loop-based code
#
def generate_matlab_code(operation, matlab_dir):
Beispiel #13
0
import copy

from core.expression import Equal, Plus, Times, Minus, Transpose, Inverse, \
                            Symbol, PatternDot, PatternStar, NList, Predicate
from core.functional import Constraint, Replacement, RewriteRule, match, replace
import core.properties as properties
from core.InferenceOfProperties import isUpperTriangular
from core.prop_to_queryfunc import propagate_properties

from CoreExtension import isOperand
import storage

import core.TOS as TOS

# Module-level pattern variables. Each one is labelled with the name of the
# variable that holds it.
# PatternDot instances:
alpha, beta, A, B, C = (
    PatternDot(label) for label in ("alpha", "beta", "A", "B", "C")
)
# PatternStar instances:
D_PS, left, middle, right = (
    PatternStar(label) for label in ("D_PS", "left", "middle", "right")
)


class Instruction(object):
    """Plain record bundling a pattern with two factory callables.

    Attributes are stored exactly as given:
      pattern             -- the pattern associated with this instruction
      create_rewrite_rule -- stored as-is (presumably a callable building a
                             rewrite rule; confirm against the callers)
      create_tile         -- stored as-is (presumably a callable building a
                             tile; confirm against the callers)
    """

    def __init__(self, pattern, create_rewrite_rule, create_tile):
        """Store the three arguments on the instance without validation."""
        self.create_tile = create_tile
        self.create_rewrite_rule = create_rewrite_rule
        self.pattern = pattern
Beispiel #14
0
# Register predicate metadata for the rdiv_* predicates listed below, in the
# given order. Names ending in "_ow" get an overwrite list of [(1, 0)]; the
# others get an empty one.
# NOTE(review): (1, 0) presumably pairs an operand index with the output it
# overwrites -- confirm against pm.PredicateMetadata.
for _pred_name, _overwrite in (
    ("rdiv_unu_ow", [(1, 0)]),
    ("rdiv_uti", []),
    ("rdiv_uti_ow", [(1, 0)]),
    ("rdiv_utn", []),
    ("rdiv_utn_ow", [(1, 0)]),
    ("rdiv_utu", []),
    ("rdiv_utu_ow", [(1, 0)]),
):
    pm.DB[_pred_name] = pm.PredicateMetadata(_pred_name, tuple())
    pm.DB[_pred_name].overwrite = _overwrite
# Do not leave the loop helpers behind in the module namespace.
del _pred_name, _overwrite

# Pattern variables referenced by the trsm2lgen rules that follow.
A, B, X = (PatternDot(label) for label in ("A", "B", "X"))

trsm2lgen_rules = [
    # X = i(t(A)) B -> ldiv_lni
    RewriteRule((
        Equal([NList([X]), Times([Inverse([A]), B])]),
        Constraint(
            "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name == X.name"
        )),
                Replacement(lambda d: Equal([
                    NList([d["X"]]),
                    Predicate("ldiv_lni", [d["A"], d["B"]],
                              [d["A"].get_size(), d["B"].get_size()])
                ]))),