예제 #1
0
def normalize_minus(expr):
    """Rewrite ``expr`` so that unary minus signs inside products move to the front.

    ``expr`` is handed to ``replace_all`` together with three rewrite rules
    (``l___`` / ``r___`` stand for the PatternStar wildcards ``l`` / ``r``):

    * ``x l___ (-y) r___  ->  (-x) l___ y r___``  (move a negation onto the first factor)
    * ``(-(-x)) l___      ->  x l___``            (cancel a double negation on the first factor)
    * ``-(x l___)         ->  (-x) l___``         (push a negated product onto its first factor)

    Returns whatever ``replace_all`` produces for these rules.
    """
    first = PatternDot("ld")
    negated = PatternDot("md")
    left_run = PatternStar("l")
    right_run = PatternStar("r")

    move_minus_to_front = RewriteRule(
        Times([first, left_run, Minus([negated]), right_run]),
        Replacement(lambda m: Times([Minus([m["ld"]]), m["l"], m["md"], m["r"]])),
    )
    cancel_double_minus = RewriteRule(
        Times([Minus([Minus([first])]), left_run]),
        Replacement(lambda m: Times([m["ld"], m["l"]])),
    )
    push_minus_into_product = RewriteRule(
        Minus([Times([first, left_run])]),
        Replacement(lambda m: Times([Minus([m["ld"]]), m["l"]])),
    )

    rules = [move_minus_to_front, cancel_double_minus, push_minus_into_product]
    return replace_all(expr, rules)
예제 #2
0
 def before_to_rule(self, before):
     """Build rewrite rules that replace the rhs of ``before`` by its lhs.

     ``before`` is split with ``get_children()`` into ``(lhs, rhs)``. The rules
     returned depend on the head of ``rhs`` (``rest___``, ``l___`` and ``r___``
     denote PatternStar wildcards):

     * ``Plus``:  ``rest___ + rhs -> rest + lhs``, and a sum whose every term
       is ``l___ rhs_i r___`` -> ``l lhs r`` (once as written and once after
       ``to_canonical``).
     * ``Times``: ``l___ rhs r___ -> l lhs r`` and
       ``l___ canonical(-rhs) r___ -> l canonical(-lhs) r``.
     * anything else: ``rhs -> lhs.children[0]``.

     Returns:
         list of RewriteRule.
     """
     lhs, rhs = before.get_children()

     def lhs_in_context(d):
         # l___ lhs r___ : splice the lhs operands between the matched context.
         return Times(d["l"].get_children() + lhs.get_children() +
                      d["r"].get_children())

     def each_term_in_context():
         # rhs_1 + rhs_2 + ... with every term wrapped as l___ rhs_i r___
         # (fresh PatternStar objects per term, all named "l" / "r").
         return Plus([
             Times([PatternStar("l"), term, PatternStar("r")])
             for term in rhs.get_children()
         ])

     if isinstance(rhs, Plus):
         return [
             RewriteRule(
                 Plus([PatternStar("rest")] + rhs.get_children()),
                 Replacement(lambda d: Plus(d["rest"].get_children() +
                                            lhs.get_children()))),
             RewriteRule(each_term_in_context(), Replacement(lhs_in_context)),
             RewriteRule(to_canonical(each_term_in_context()),
                         Replacement(lhs_in_context)),
         ]

     if isinstance(rhs, Times):
         # [TODO] The negated variant is for chol loop invariants 2 and 3;
         # should be fixed elsewhere (maybe -1 instead of the minus operator).
         def negated_lhs_in_context(d):
             return Times(d["l"].get_children() +
                          [to_canonical(Minus([Times(lhs.get_children())]))] +
                          d["r"].get_children())

         return [
             RewriteRule(
                 Times([PatternStar("l"), *rhs.get_children(), PatternStar("r")]),
                 Replacement(lhs_in_context)),
             RewriteRule(
                 Times([PatternStar("l"),
                        to_canonical(Minus([Times(rhs.get_children())])),
                        PatternStar("r")]),
                 Replacement(negated_lhs_in_context)),
         ]

     # [FIX] what if multiple outputs?
     return [RewriteRule(rhs, Replacement(lhs.children[0]))]
예제 #3
0
 def expr_to_rule_rhs_lhs(self, predicates):
     """Build, per predicate ``lhs = rhs``, rewrite rules that turn rhs back into lhs.

     Each ``p`` in ``predicates`` is unpacked as ``lhs, rhs = p.children``. The
     rules generated for one predicate are collected in their own list, and the
     method returns a list of these per-predicate lists (same order as
     ``predicates``).

     If ``lhs`` has exactly one child (``lhs_sym``), the rule set depends on the
     head of ``rhs``:

     * ``Plus``  -- rhs as the tail of a larger sum, each term embedded in a
       ``l___ term r___`` context, and negated / transposed variants;
     * ``Times`` -- rhs as a run of factors inside a product, its transpose,
       and negated / negated-transposed variants;
     * otherwise -- ``rhs -> lhs_sym``, plus ``trans(rhs) -> trans(lhs_sym)``
       when ``isOperand`` is false for the simplified, canonicalized transpose.

     If ``lhs`` has more than one child, the single rule ``rhs -> lhs`` is used.

     ``t___``, ``l___``, ``r___`` denote the PatternStar wildcards ``t``, ``l``,
     ``r``; ``ld`` is a PatternDot. Every replacement closure is created as
     ``(lambda lhs: lambda d: ...)(lhs_sym)`` so that it captures the current
     predicate's symbol instead of the loop variable's final value.
     """
     rules = []
     # Pattern wildcards shared by all rules built below.
     t = PatternStar("t")
     l = PatternStar("l")
     ld = PatternDot("ld")
     r = PatternStar("r")
     for p in predicates:
         # pr collects the rules generated for this single predicate.
         pr = []
         lhs, rhs = p.children
         if len(lhs.children) == 1:
             #lhs_sym = WrapOutBef( lhs.children[0] )
             lhs_sym = lhs.children[0]
             if isinstance(rhs, Plus):
                 # t___ + rhs -> t + lhs
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children +
                                                      [lhs]))(lhs_sym)
                 pr.append(
                     RewriteRule(Plus([t] + rhs.children),
                                 Replacement(repl_f)))
                 # t___ + l___ rhs_i r___ + ... -> t + l lhs r
                 # (the same l / r wildcard objects are reused in every term)
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times(d["l"].children + [lhs] + d["r"].children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             Times([l] + [ch] + [r]) for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # t___ + (-rhs_1) r___ + (-rhs_2) r___ + ... -> t + (-lhs) r
                 # (each negated term is simplify(to_canonical(Minus(...))))
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([simplify(to_canonical(Minus([lhs])))] + d["r"].
                           children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             Times([simplify(to_canonical(Minus([ch])))] +
                                   [r]) for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # A - B C in  L B C R + -L A R  (minus pushed all the way to the left, and whole thing negated)
                 repl_f = (lambda lhs: lambda d: normalize_minus(
                     Plus(d["t"].children + [
                         Times([d["ld"]] + d["l"].children + [Minus([lhs])]
                               + d["r"].children)
                     ])))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             normalize_minus(Times([ld, l,
                                                    Minus([ch]), r]))
                             for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # A - B C in  -L B C R + L A R  (minus pushed all the way to the left)
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([d["ld"]] + d["l"].children + [lhs] + d["r"].
                           children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             normalize_minus(Times([ld, l, ch, r]))
                             for ch in rhs.children
                         ]), Replacement(repl_f)))
                 #repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [Times([simplify(to_canonical(Minus(lhs.children)))] + d["r"].children)]))(lhs_sym)
                 #pr.append( RewriteRule( Plus([t] + [
                 #Times([ Minus([ld]), l, ch, r]) if not isinstance(ch, Minus) \
                 #else Times([ l, ch.children[0], r]) \
                 #for ch in rhs.children ]),
                 #Replacement(repl_f) ) )
                 # Transposed variant: each term is matched as trans(rhs_i);
                 # non-Minus terms as (-ld) l___ trans(rhs_i) r___, Minus terms
                 # as l___ trans(rhs_i) r___. Replaced by t + (-trans(lhs)) r.
                 # NOTE(review): the pattern binds ld and l___, but the
                 # replacement only uses t and r, so the matched ld / l factors
                 # are dropped -- confirm this is intended.
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([
                         simplify(to_canonical(Minus([Transpose([lhs])])))
                     ] + d["r"].children)
                 ]))(lhs_sym)
                 pr.append( RewriteRule( Plus([t] + [
                             Times([ Minus([ld]), l, simplify(to_canonical(Transpose([ch]))), r]) if not isinstance(ch, Minus) \
                                     else Times([ l, simplify(to_canonical(Transpose([ch]))), r]) \
                                     for ch in rhs.children ]),
                                 Replacement(repl_f) ) )
             elif isinstance(rhs, Times):
                 # l___ rhs_factors r___ -> l lhs r
                 repl_f = (lambda lhs: lambda d: Times(d[
                     "l"].children + [lhs] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(Times([l] + rhs.children + [r]),
                                 Replacement(repl_f)))
                 # l___ trans(rhs) r___ -> l trans(lhs) r
                 repl_f = (lambda lhs: lambda d: Times(d[
                     "l"].children + [Transpose([lhs])] + d["r"].children)
                           )(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             l,
                             simplify(to_canonical(Transpose([rhs]))), r
                         ]), Replacement(repl_f)))
                 # [TODO] Minus is a b*tch. Should go for -1 and remove the operator internally?
                 # (-rhs) r___ -> (-lhs) r   (both sides simplified/canonicalized)
                 repl_f = (lambda lhs: lambda d: Times([
                     simplify(to_canonical(Minus([Times([lhs])])))
                 ] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             simplify(
                                 to_canonical(
                                     Minus([Times(rhs.get_children())])))
                         ] + [r]), Replacement(repl_f)))
                 # -(trans(rhs)) r___ -> -(trans(lhs)) r
                 repl_f = (lambda lhs: lambda d: Times([
                     simplify(
                         to_canonical(Minus([Transpose([Times([lhs])])])))
                 ] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             simplify(
                                 to_canonical(
                                     Minus([
                                         Transpose(
                                             [Times(rhs.get_children())])
                                     ])))
                         ] + [r]), Replacement(repl_f)))
             else:
                 # rhs -> lhs_sym, plus the transposed form when it is not an operand.
                 pr.append(RewriteRule(rhs, Replacement(lhs_sym)))
                 # NOTE(review): the same transpose is recomputed below
                 # instead of reusing new_rhs.
                 new_rhs = simplify(to_canonical(Transpose([rhs])))
                 if not isOperand(new_rhs):
                     pr.append(
                         RewriteRule(
                             simplify(to_canonical(Transpose([rhs]))),
                             Replacement(Transpose([lhs_sym]))))
         else:
             # Multiple outputs on the lhs: only the direct rhs -> lhs rule.
             pr.append(RewriteRule(rhs, Replacement(lhs)))
         rules.append(pr)
     return rules
예제 #4
0
from core.expression import Symbol, Matrix, Vector, \
                            Equal, Plus, Minus, Times, Transpose, \
                            PatternDot, PatternStar
from core.properties import *
from core.InferenceOfProperties import *

from core.functional import Constraint, Replacement, RewriteRule, replace_all

from core.builtin_operands import Zero, Identity

# Module-level pattern wildcards shared by the rewrite rules of this module.
# PatternDot wildcards:
LHS, RHS, subexpr = PatternDot("LHS"), PatternDot("RHS"), PatternDot("subexpr")
PD1, PD2 = PatternDot("PD1"), PatternDot("PD2")
# PatternStar wildcards:
PS1, PS2, PS3 = PatternStar("PS1"), PatternStar("PS2"), PatternStar("PS3")
PSLeft, PSRight = PatternStar("PSLeft"), PatternStar("PSRight")

simplify_rules = [
    # Minus
    # --expr -> expr
    RewriteRule(Minus([Minus([subexpr])]),
                Replacement(lambda d: d["subexpr"])),
    # -0 -> 0
    RewriteRule((Minus([subexpr]), Constraint(lambda d: isZero(d["subexpr"]))),
                Replacement(lambda d: d["subexpr"])),
    # Plus
    # Plus(a) -> a
예제 #5
0
    def solve_equations(self):
        """Solve the distributed partitioned postcondition into PME sub-equations.

        Steps, as implemented below:

        1. Register every sub-equation ``lhs = rhs`` of
           ``self.distributed_partitioned_postcondition`` with ``update_TOE`` in
           both directions (lhs -> rhs and rhs -> lhs).
        2. Collect the ``Symbol`` nodes of every output operand's partitioning.
        3. Repeatedly rewrite deep copies of the pending sub-equations with
           ``self.known_ops``. A rewritten ``Equal`` whose lhs consists only of
           Symbols is recorded as solved, and its lhs operands (and its rhs, if
           it is a Predicate) get the INPUT property. If a pass marks no new
           output as input, a recursive instance is generated via
           ``rec_instance`` for the pending entry with the fewest distinct
           output Symbols, and its patterns and PMEs are registered.
        4. Build "reuse" rules from each solved equation and apply, to every
           solved equation, the rules of all the *other* ones. Store the result
           in ``self.solved_subequations``.
        5. Put the OUTPUT property back on the operands named in
           ``solved_outputs``, then print the PME.

        NOTE(review): the while-loop has no iteration cap; if ``rec_instance``
        never lets another output become input, it does not terminate.
        """
        dist = self.distributed_partitioned_postcondition
        # Update TOE
        for eq in dist.flatten_children():
            for subeq in eq:
                lhs, rhs = subeq.get_children()
                update_TOE(RewriteRule(lhs, Replacement(rhs)))
                update_TOE(RewriteRule(rhs, Replacement(lhs)))
        # Collect all output parts to be computed
        # NOTE(review): the first assignment is immediately overwritten; only
        # self.partitionings is used afterwards.
        basic_part = self.basic_partitionings
        basic_part = self.partitionings  #[TODO] Ok?
        all_outputs = []
        for op in self.operands:
            if op.isOutput():
                # every Symbol node in this output's partitioning tree
                all_outputs.extend([
                    node
                    for node in basic_part[op.get_name()].iterate_preorder()
                    if isinstance(node, Symbol)
                ])

        # Iterate until PME is solved
        subeqs_to_solve = dist.flatten_children()
        subeqs_to_solve = filter_zero_zero(subeqs_to_solve)
        subeqs_solved = []
        solved_outputs = set([
            ""
        ])  # any initialization that cannot render the equality below true
        solved = False
        #
        while not solved:
            for eq in subeqs_to_solve:
                # rewrite a deep copy so the pending entry itself is left untouched
                new = replace(copy.deepcopy(eq), self.known_ops)
                if isinstance(new, Equal):
                    lhs, rhs = new.get_children()
                    # Set input property
                    if all([isinstance(elem, Symbol) for elem in lhs]):
                        if isinstance(rhs, Predicate):
                            rhs.set_property(INPUT)  # SOLVED?
                        for op in lhs:
                            op.set_property(INPUT)  # SOLVED?
                        subeqs_solved.append(new)
                    else:
                        #print( "Can it be?" )
                        # Yes, e.g., trans(X_BL) in lyapunov
                        #raise Exception
                        pass
            # names of the output Symbols that are now marked as input
            cur_solved_outputs = set(
                [op.get_name() for op in all_outputs if isInput(op)])
            #print( "SOLVED:", cur_solved_outputs )
            if solved_outputs == cur_solved_outputs:
                # No progress in this pass: build a recursive instance for the
                # pending entry with the fewest distinct output Symbols (ties
                # broken by node count).
                print("")
                print(
                    "[WARNING] PME Generation is stuck - Trying recursive instance"
                )
                print("")
                # NOTE(review): if subeqs_to_solve is empty, or no candidate
                # scores below this sentinel, subeq stays None and
                # subeq.children[0] raises AttributeError.
                minimum = (100000, 100000)
                subeq = None
                for eq in subeqs_to_solve:
                    unks = []
                    cnt = 0
                    for node in eq.iterate_preorder():
                        cnt += 1
                        if isinstance(node, Symbol) and node.isOutput(
                        ) and node not in unks:
                            unks.append(node)
                    if (len(unks), cnt) < minimum:
                        subeq = eq
                        minimum = (len(unks), cnt)
                from NestedInstance import rec_instance
                ((md_name, md_data), new_patts,
                 new_pmes) = rec_instance(subeq.children[0])
                self.known_pmes.extend(new_pmes)
                print("NPMES:", len(new_pmes))
                for patt in new_patts:
                    self.known_ops.append(patt)
                pm.DB[md_name] = md_data
                #raise Exception
            solved_outputs = cur_solved_outputs
            # Done once every collected output Symbol is input or zero.
            if all([isInput(op) or op.isZero() for op in all_outputs]):
                solved = True
            else:
                # Keep pending entries that are not an Equal and still contain
                # an output Symbol, then re-canonicalize them for the next pass.
                # NOTE(review): this inspects the entries of subeqs_to_solve,
                # not the rewritten `new` copies -- confirm that is intended.
                new_to_solve = []
                for eq in subeqs_to_solve:
                    if not isinstance( eq, Equal ) and \
                            len([op for op in eq.iterate_preorder() if isinstance(op, Symbol) and isOutput(op)]) != 0:
                        new_to_solve.append(eq)
                subeqs_to_solve = [
                    simplify(to_canonicalIO(eq))._cleanup()
                    for eq in new_to_solve
                ]
        self.solved_subequations = subeqs_solved
        # Replace rhs with lhs if they appear as subexpressions of other rhss
        #
        # Create replacements when suited
        # reuse_rules holds (index of solved equation, its list of rules).
        reuse_rules = []
        for i, eq in enumerate(subeqs_solved):
            eq_rules = []
            lhs, rhs = eq.children
            # Only for rhs that are not a Predicate and not headed by
            # Minus / Transpose / Inverse.
            if not isinstance(rhs, Predicate) and rhs.get_head() not in (
                    Minus, Transpose, Inverse):
                lhs_sym = lhs.children[0]
                t = PatternStar("t")
                l = PatternStar("l")
                r = PatternStar("r")
                # t___ + l___ rhs r___ -> t + l lhs r
                # (outer lambda binds lhs_sym now, not at call time)
                repl_f = (lambda lhs_sym: lambda d: Plus(
                    [d["t"], Times([d["l"], lhs_sym, d["r"]])]))(lhs_sym)
                eq_rules.append(
                    RewriteRule(Plus([t, Times([l, rhs, r])]),
                                Replacement(repl_f)))
                # A - B C in  -L B C R + L A R  (minus pushed all the way to the left)
                if len(rhs.children) > 1:
                    repl_f = (lambda lhs_sym: lambda d: Times(
                        [Minus([lhs_sym])] + d["r"].children))(lhs_sym)
                    eq_rules.append(
                        RewriteRule(
                            Times([
                                Minus([rhs.children[0]]), *rhs.children[1:], r
                            ]), Replacement(repl_f)))
            reuse_rules.append((i, eq_rules))

        # Replace
        # Each solved equation is rewritten with the rules of all *other*
        # solved equations (never its own).
        self.solved_subequations = []
        for i, eq in enumerate(subeqs_solved):
            rules = [r for j, r in reuse_rules if i != j]
            self.solved_subequations.append(
                replace_all(eq, list(itertools.chain(*rules))))

        # [TODO] Add T to operands and bind dimensions
        #self.bind_temporaries( )
        # Reset outputs
        # (solved_outputs holds names; TOS[name][0] is used to reach the operand)
        for op in solved_outputs:
            TOS[op][0].set_property(OUTPUT)

        print("* PME ")
        for eq in self.solved_subequations:
            print("*    ", eq)
예제 #6
0
파일: Tiling.py 프로젝트: shreyas42/slingen
                        if rhs not in matched_this_node: ###
                            matched_this_node.append( rhs ) ###
                            ongoing.append( alg[:] + [tile, new] )
                            # Set size of new temporary
                            lhs.children[0].size = rhs.get_size()
                        else:
                            # Aestetic. Simply to avoid missing T? values.
                            TOS.push_back_temp( )
                            TOS._TOS.unset_operand( tile.children[0].children[0] )
            if matched_in_this_level: ###
                break


# Wildcards used by the grouping rules (PatternDot: PD1, PD2; PatternStar: PS1, PS2).
PD1, PD2 = PatternDot("PD1"), PatternDot("PD2")
PS1, PS2 = PatternStar("PS1"), PatternStar("PS2")

grouping_rules = [
    # A B + A C D -> A (B + C D)
    RewriteRule(
        Plus([ Times([ PD1, PS1 ]), Times([ PD1, PS2 ]) ]),
        Replacement(lambda d: Times([ d["PD1"], Plus([Times([d["PS1"]]), Times([d["PS2"]])]) ]))
    ),
    # A B - A C D -> A (B - C D)
    RewriteRule(
        Plus([ Times([ PD1, PS1 ]), Times([ Minus([PD1]), PD2, PS2 ]) ]),
        Replacement(lambda d: Times([ d["PD1"], Plus([ Times([d["PS1"]]), Times([Minus([d["PD2"]]), Times([d["PS2"]])])]) ]))
    ),
    # B A + C D A -> (B + C D) A
    RewriteRule(
예제 #7
0
from core.functional import Constraint, Replacement, RewriteRule, match, replace
import core.properties as properties
from core.InferenceOfProperties import isUpperTriangular
from core.prop_to_queryfunc import propagate_properties

from CoreExtension import isOperand
import storage

import core.TOS as TOS

# Module-level pattern wildcards.
# PatternDot wildcards:
alpha, beta = PatternDot( "alpha" ), PatternDot( "beta" )
A, B, C = PatternDot( "A" ), PatternDot( "B" ), PatternDot( "C" )
# PatternStar wildcards:
D_PS = PatternStar( "D_PS")
left, middle, right = PatternStar( "left" ), PatternStar( "middle" ), PatternStar( "right" )


class Instruction( object ):
    def __init__( self, pattern, create_rewrite_rule, create_tile ):
        self.pattern = pattern
        self.create_rewrite_rule = create_rewrite_rule
        self.create_tile = create_tile

    def match( self, expr ):
        yield from match( expr, self.pattern )

    def tile( self, tree, node, match_dict ):