Пример #1
0
def normalize_minus( expr ):
    """Move unary minus signs to the front of products and cancel double minuses.

    The following rewrite rules are handed to ``replace_all``:

    * ``ld l___ (-md) r___``  ->  ``(-ld) l md r``  (move an inner minus onto
      the first factor)
    * ``(-(-ld)) l___``       ->  ``ld l``          (cancel a double negation)
    * ``-(ld l___)``          ->  ``(-ld) l``       (push a negated product onto
      its first factor)

    :param expr: expression tree to normalize.
    :returns: the result of ``replace_all(expr, rules)``.
    """
    first = PatternDot("ld")
    negated = PatternDot("md")
    left = PatternStar("l")
    right = PatternStar("r")

    def _minus_to_front(d):
        return Times([Minus([d["ld"]]), d["l"], d["md"], d["r"]])

    def _cancel_double_minus(d):
        return Times([d["ld"], d["l"]])

    def _distribute_minus(d):
        return Times([Minus([d["ld"]]), d["l"]])

    rules = [
        RewriteRule(Times([first, left, Minus([negated]), right]),
                    Replacement(_minus_to_front)),
        RewriteRule(Times([Minus([Minus([first])]), left]),
                    Replacement(_cancel_double_minus)),
        RewriteRule(Minus([Times([first, left])]),
                    Replacement(_distribute_minus)),
    ]
    return replace_all(expr, rules)
Пример #2
0
def initial_rewrite(operand, part, trav_shape):
    """Build rules that give every partition of ``operand`` its initial value.

    Before the loop starts, the parts the traversal has already "passed" are
    empty (``Zero``), and the single part it has not reached yet holds the
    whole ``operand``. For example, in a left-to-right traversal of ``[L|R]``,
    ``L`` is zero and ``R`` is the full operand.

    :param operand: the operand being partitioned.
    :param part: the partitioned operand. ``part.get_children()`` gives its
        rows of patterns; the rules are built by zipping over ``part`` itself
        (presumably yielding the same rows — confirm against the caller).
    :param trav_shape: traversal direction ``(row_step, col_step)``, e.g.
        ``(0, 1)`` for left-to-right or ``(-1, -1)`` for bottom-right to
        top-left.
    :returns: a flat, row-major list of ``RewriteRule(pattern,
        Replacement(initial_value))``; ``[]`` for an unpartitioned (1x1)
        operand.
    :raises ValueError: if the partition shape or the traversal direction is
        not a supported combination.
    """
    part_ch = part.get_children()
    part_shape = (len(part_ch), len(part_ch[0]))
    if part_shape == (1, 1):
        # Nothing is partitioned, so there is nothing to rewrite.
        return []

    # (row, col) of the part that initially holds the whole operand, per
    # partition shape and traversal direction; every other part is zero.
    operand_position = {
        (1, 2): {  # [L|R]
            (0, 1): (0, 1),    # Left to Right
            (0, -1): (0, 0),   # Right to Left
        },
        (2, 1): {  # [T;B]
            (1, 0): (1, 0),    # Top to Bottom
            (-1, 0): (0, 0),   # Bottom to Top
        },
        (2, 2): {  # [TL TR; BL BR]
            (1, 1): (1, 1),    # Top Left to Bottom Right
            (1, -1): (1, 0),   # Top Right to Bottom Left
            (-1, 1): (0, 1),   # Bottom Left to Top Right
            (-1, -1): (0, 0),  # Bottom Right to Top Left
        },
    }
    try:
        target = operand_position[part_shape][trav_shape]
    except KeyError:
        # Previously this fell through with `rewrite` unbound.
        raise ValueError(
            "unsupported partition shape %s with traversal %s"
            % (part_shape, trav_shape)) from None

    zero = Zero((sZERO, sZERO))
    n_rows, n_cols = part_shape
    rewrite = [[operand if (i, j) == target else zero for j in range(n_cols)]
               for i in range(n_rows)]

    return list(
        itertools.chain.from_iterable(
            [RewriteRule(p, Replacement(rew)) for p, rew in zip(p_row, rew_row)]
            for p_row, rew_row in zip(part, rewrite)))
Пример #3
0
 def expr_to_rule_lhs_rhs(self, predicates):
     """Turn single-output equations into ``lhs -> rhs`` rewrite rules.

     Each element of ``predicates`` is unpacked as ``(lhs, rhs)`` from its
     ``children``. Entries whose left-hand side has exactly one child yield
     the one-element group ``[RewriteRule(lhs_child, Replacement(rhs))]``;
     entries with any other number of outputs are skipped.

     :returns: a list of one-element rule groups, in input order.
     """
     return [
         [RewriteRule(outputs.children[0], Replacement(value))]
         for outputs, value in (pred.children for pred in predicates)
         if len(outputs.children) == 1
     ]
Пример #4
0
def compress_tiles(tiled_subpme):
    """Fold non-predicate tiles into the later tiles that read their output.

    Each element of ``tiled_subpme`` is a sequence of tiles ``lhs = rhs``.
    A tile whose ``rhs`` is a ``Predicate`` is always kept. For any other
    tile, ``lhs[0] -> rhs`` is substituted into the right-hand side of every
    following tile that contains ``lhs[0]``. The scan stops after the first
    following tile that has ``lhs[0]`` among its outputs. A tile that was
    substituted at least once is dropped; otherwise it is kept.

    Note: rewritten tiles are deep copies, but they are stored back into the
    input sequences, so ``tiled_subpme`` is modified in place. Kept tiles are
    deep-copied into the result.

    :returns: one list of tiles per input sequence.
    """
    result = []
    for tiles in tiled_subpme:
        kept = []
        result.append(kept)
        for i in range(len(tiles)):
            tile = tiles[i]
            lhs, rhs = tile.get_children()
            if isinstance(rhs, Predicate):
                kept.append(copy.deepcopy(tile))
                continue
            folded = False
            for j in range(i + 1, len(tiles)):
                other_tile = tiles[j]
                other_lhs, other_rhs = other_tile.get_children()
                target = lhs.get_children()[0]
                if contains(other_rhs, target):
                    rewritten = copy.deepcopy(other_tile)
                    rewritten.children[1] = replace(
                        rewritten.rhs(),
                        [RewriteRule(target, Replacement(rhs))])
                    tiles[j] = rewritten
                    folded = True
                # A later tile overwrites target (e.g. lhs = f(lhs)): from
                # here on, target no longer means this tile's value.
                if target in other_lhs.children:
                    break
            if not folded:
                kept.append(copy.deepcopy(tile))
    return result
Пример #5
0
 def tile( self, tree, node, match_dict ):
     """Extract one tile of ``tree`` into a fresh temporary.

     A new symbol ``T<n>`` (``n`` from ``TOS.new_temp()``) is created and
     marked ``TEMPORARY``.

     * If ``node`` is a ``Predicate``, the tile is ``temp = copy(node)``, and
       ``tree`` is rewritten with the rule ``node -> temp``.
     * Otherwise, the tile expression is built once with
       ``self.create_tile(match_dict)``. Its properties and storage info are
       propagated to ``temp``, and ``tree`` is rewritten with
       ``self.create_rewrite_rule(match_dict, temp)``.

     :returns: a pair ``(Equal([NList([temp]), tile_expr]), rewritten_tree)``.
     """
     _T = TOS.new_temp()
     temp = Symbol("T" + str(_T))
     temp.set_property( properties.TEMPORARY )
     if isinstance( node, Predicate ):
         if node == tree:
             print( "[Warning] If predicate has multiple outputs, may break if not careful from caller" )
         return Equal([ NList([ temp ]), copy.deepcopy( node ) ]), \
                 replace( tree, [RewriteRule( copy.deepcopy( node ), Replacement(temp) )] )
     tile_expr = self.create_tile( match_dict )
     propagate_properties( tile_expr, temp )
     propagate_st_info( tile_expr, temp )
     # Return the same tile whose properties were propagated to temp rather
     # than building a second, independent one with another create_tile call.
     return Equal([ NList([ temp ]), tile_expr ]), \
             replace( tree, [self.create_rewrite_rule( match_dict, temp )] )
Пример #6
0
 def opt_copy_propagation(self):
     """Forward-propagate single-temporary copies ``T = S``.

     For every ``Equal`` statement whose left-hand side is a single temporary
     ``T`` and whose right-hand side is a ``Symbol`` ``S``, occurrences of
     ``T`` in the right-hand sides of subsequent ``Equal`` statements are
     replaced by ``S``; those statements are updated in place.

     Propagation stops at the first later statement that writes ``T`` or
     ``S`` (in any output position), because past that point ``T`` and ``S``
     no longer hold the same value. Non-``Equal`` statements are skipped.
     """
     for i, st in enumerate(self.statements):
         if not isinstance(st, Equal):
             continue
         lhs, rhs = st.children
         outputs = lhs.children
         if not (len(outputs) == 1 and outputs[0].isTemporary()
                 and isinstance(st.rhs(), Symbol)):
             continue
         temp = outputs[0]
         for oth_st in self.statements[i + 1:]:
             # [CHECK] Not the case here, but generally speaking a non-Equal
             # statement (some part/repart) could also read or write temp/rhs.
             if not isinstance(oth_st, Equal):
                 continue
             oth_lhs, oth_rhs = oth_st.lhs(), oth_st.rhs()
             if contains(oth_rhs, temp):
                 # [FIXME] Could be less restrictive w.r.t. predicates that
                 # overwrite their operands (see pm.DB[...].overwrite).
                 oth_st.children[1] = replace(
                     oth_rhs, [RewriteRule(temp, Replacement(rhs))])
             # Stop once temp is redefined or its source symbol is
             # overwritten: later statements no longer see the copied value.
             if temp in oth_lhs.children or rhs in oth_lhs.children:
                 break
Пример #7
0
 def _PassStorage(st):
     """Rewrite the symbols of a statement to the storage symbols they live in.

     * ``lpla._while``: recurse into every statement of its body.
     * ``Equal``: every ``Symbol`` node ``n`` (``Zero`` nodes are skipped)
       whose ``n.st_info[1]`` differs from ``n`` gets a rule
       ``n -> n.st_info[1]``; the storage symbol's own ``st_info`` is first
       set to ``n.st_info``. The rules are applied with ``replace(st, ...)``.
       The return value of ``replace`` is discarded, so this presumably
       relies on ``replace`` rewriting ``st`` in place — confirm.
     * Any other statement is left alone.
     """
     if isinstance(st, lpla._while):
         for inner in st.body:
             _PassStorage(inner)
         return
     if not isinstance(st, Equal):
         return
     rules = []
     for node in st.iterate_preorder():
         # Zero has no st_info, so it must be filtered out first.
         if isinstance(node, Zero) or not isinstance(node, Symbol):
             continue
         storage = node.st_info[1]
         if storage != node:
             # E.g., in Cholesky, node is L_11 and storage is A_11.
             # (Triangular-aware variants using tril/triu were once tried.)
             storage.st_info = node.st_info
             rules.append(RewriteRule(node, Replacement(storage)))
     replace(st, rules)
Пример #8
0
 def before_to_rule(self, before):
     """Build rules that rewrite occurrences of ``before``'s rhs to its lhs.

     ``before`` is unpacked as ``(lhs, rhs)`` via ``get_children()``.

     * ``rhs`` is a ``Plus``:
       - ``rest___ + <terms of rhs>``  ->  ``rest + <children of lhs>``;
       - a sum whose terms are ``l___ term r___`` for every term of ``rhs``,
         both as written and after ``to_canonical``  ->
         ``l <children of lhs> r``.
     * ``rhs`` is a ``Times``:
       - ``l___ <factors of rhs> r___``  ->  ``l <children of lhs> r``;
       - ``l___ canonical(-(rhs)) r___``  ->  ``l canonical(-(lhs)) r``.
     * otherwise: ``rhs -> lhs.children[0]`` (only the first output is used).

     :returns: a list of ``RewriteRule`` objects.
     """
     lhs, rhs = before.get_children()

     def _splice_lhs(d):
         # l___ <children of lhs> r___
         return Times(d["l"].get_children() + lhs.get_children() +
                      d["r"].get_children())

     def _sum_of_products():
         # One term l___ ch r___ per term ch of rhs, with fresh pattern stars.
         return Plus([
             Times([PatternStar("l"), ch, PatternStar("r")])
             for ch in rhs.get_children()
         ])

     if isinstance(rhs, Plus):
         def _append_lhs(d):
             # rest___ + <children of lhs>
             return Plus(d["rest"].get_children() + lhs.get_children())

         return [
             RewriteRule(Plus([PatternStar("rest")] + rhs.get_children()),
                         Replacement(_append_lhs)),
             RewriteRule(_sum_of_products(), Replacement(_splice_lhs)),
             RewriteRule(to_canonical(_sum_of_products()),
                         Replacement(_splice_lhs)),
         ]

     if isinstance(rhs, Times):
         def _splice_negated_lhs(d):
             return Times(d["l"].get_children() +
                          [to_canonical(Minus([Times(lhs.get_children())]))] +
                          d["r"].get_children())

         return [
             RewriteRule(
                 Times([PatternStar("l")] + rhs.get_children() +
                       [PatternStar("r")]),
                 Replacement(_splice_lhs)),
             # [TODO] For chol loop invariants 2 and 3, should fix elsewhere
             # (maybe -1 instead of minus operator)
             RewriteRule(
                 Times([PatternStar("l"),
                        to_canonical(Minus([Times(rhs.get_children())])),
                        PatternStar("r")]),
                 Replacement(_splice_negated_lhs)),
         ]

     # [FIX] what if multiple outputs? Only the first output is used here.
     return [RewriteRule(rhs, Replacement(lhs.children[0]))]
Пример #9
0
    def partition(self):
        """Partition every operand and rewrite the postcondition accordingly.

        For each operand in ``self.operands``, two partitionings are built
        from ``self.part_shape[name]``:

        * ``self.basic_partitionings[name]`` via
          ``partition_shape_with_storage`` (used for code generation);
        * ``self.partitionings[name]`` via ``partition``, which also receives
          the operand's properties.

        ``self.partitioned_postcondition`` becomes a deep copy of each
        equation in ``self.equation``, with every operand replaced by its
        partitioning. The result is printed.
        """
        self.basic_partitionings = {}
        self.partitionings = {}
        rules = []
        for operand in self.operands:
            name = operand.get_name()
            shape = self.part_shape[name]
            # For code generation
            self.basic_partitionings[name] = \
                partition_shape_with_storage(operand, shape)
            partitioned = partition(operand, shape, operand.get_properties())
            self.partitionings[name] = partitioned
            rules.append(RewriteRule(operand, Replacement(partitioned)))

        self.partitioned_postcondition = [
            replace(copy.deepcopy(eq), rules) for eq in self.equation
        ]

        print("* Partitioned postcondition")
        for eq in self.partitioned_postcondition:
            print("*    ", eq)
Пример #10
0
    def solve_equations(self):
        """Solve the distributed partitioned postcondition into a PME.

        Steps, as implemented below:

        1. Register every ``lhs = rhs`` sub-equation of
           ``self.distributed_partitioned_postcondition`` with ``update_TOE``,
           in both directions.
        2. Collect the ``Symbol`` parts of every output operand's
           partitioning (taken from ``self.partitionings``).
        3. Repeatedly rewrite the pending sub-equations with
           ``self.known_ops``. A rewrite that yields an ``Equal`` whose
           left-hand side is all ``Symbol``s counts as solved; its outputs
           (and a ``Predicate`` rhs) are marked ``INPUT``. When a pass marks
           no new output part, a recursive instance is generated with
           ``rec_instance`` for the pending sub-equation with the fewest
           distinct output symbols (ties broken by node count). Its patterns
           and PMEs are then registered. The loop ends once every output part
           is an input or zero.
        4. Build "reuse" rules from each solved equation, and rewrite every
           solved equation with the rules of all the *other* equations, so
           that right-hand sides already computed elsewhere are replaced by
           their left-hand sides.
        5. Set ``OUTPUT`` again on the parts that ended up marked as inputs,
           and print the result.

        Sets ``self.solved_subequations``. When a recursive instance is
        needed, it also mutates ``self.known_ops``, ``self.known_pmes`` and
        ``pm.DB``.
        """
        dist = self.distributed_partitioned_postcondition
        # Update TOE
        # (each sub-equation is registered both as lhs -> rhs and rhs -> lhs)
        for eq in dist.flatten_children():
            for subeq in eq:
                lhs, rhs = subeq.get_children()
                update_TOE(RewriteRule(lhs, Replacement(rhs)))
                update_TOE(RewriteRule(rhs, Replacement(lhs)))
        # Collect all output parts to be computed
        # NOTE(review): this first assignment is immediately overwritten;
        # only self.partitionings is actually used below.
        basic_part = self.basic_partitionings
        basic_part = self.partitionings  #[TODO] Ok?
        all_outputs = []
        for op in self.operands:
            if op.isOutput():
                all_outputs.extend([
                    node
                    for node in basic_part[op.get_name()].iterate_preorder()
                    if isinstance(node, Symbol)
                ])

        # Iterate until PME is solved
        subeqs_to_solve = dist.flatten_children()
        subeqs_to_solve = filter_zero_zero(subeqs_to_solve)
        subeqs_solved = []
        solved_outputs = set([
            ""
        ])  # any initialization that cannot render the equality below true
        solved = False
        #
        while not solved:
            # Try each pending sub-equation against the known operations. The
            # rewrite is done on a deep copy, so subeqs_to_solve is not
            # rewritten here.
            for eq in subeqs_to_solve:
                new = replace(copy.deepcopy(eq), self.known_ops)
                if isinstance(new, Equal):
                    lhs, rhs = new.get_children()
                    # Set input property
                    if all([isinstance(elem, Symbol) for elem in lhs]):
                        if isinstance(rhs, Predicate):
                            rhs.set_property(INPUT)  # SOLVED?
                        for op in lhs:
                            op.set_property(INPUT)  # SOLVED?
                        subeqs_solved.append(new)
                    else:
                        #print( "Can it be?" )
                        # Yes, e.g., trans(X_BL) in lyapunov
                        #raise Exception
                        pass
            # Names of the output parts currently marked as inputs (solved).
            cur_solved_outputs = set(
                [op.get_name() for op in all_outputs if isInput(op)])
            #print( "SOLVED:", cur_solved_outputs )
            # No new output part was solved in this pass: fall back to a
            # recursive (nested) instance for one pending sub-equation.
            if solved_outputs == cur_solved_outputs:
                print("")
                print(
                    "[WARNING] PME Generation is stuck - Trying recursive instance"
                )
                print("")
                # Pick the sub-equation with the fewest distinct output
                # symbols, breaking ties by total node count.
                # NOTE(review): if subeqs_to_solve is empty, subeq stays None
                # and subeq.children below raises AttributeError.
                minimum = (100000, 100000)
                subeq = None
                for eq in subeqs_to_solve:
                    unks = []
                    cnt = 0
                    for node in eq.iterate_preorder():
                        cnt += 1
                        if isinstance(node, Symbol) and node.isOutput(
                        ) and node not in unks:
                            unks.append(node)
                    if (len(unks), cnt) < minimum:
                        subeq = eq
                        minimum = (len(unks), cnt)
                from NestedInstance import rec_instance
                ((md_name, md_data), new_patts,
                 new_pmes) = rec_instance(subeq.children[0])
                self.known_pmes.extend(new_pmes)
                print("NPMES:", len(new_pmes))
                for patt in new_patts:
                    self.known_ops.append(patt)
                pm.DB[md_name] = md_data
                #raise Exception
            solved_outputs = cur_solved_outputs
            # Done once every output part is an input (solved) or zero.
            if all([isInput(op) or op.isZero() for op in all_outputs]):
                solved = True
            else:
                # Keep the sub-equations that are not an Equal and still
                # mention an output symbol; canonicalize them for the next pass.
                new_to_solve = []
                for eq in subeqs_to_solve:
                    if not isinstance( eq, Equal ) and \
                            len([op for op in eq.iterate_preorder() if isinstance(op, Symbol) and isOutput(op)]) != 0:
                        new_to_solve.append(eq)
                subeqs_to_solve = [
                    simplify(to_canonicalIO(eq))._cleanup()
                    for eq in new_to_solve
                ]
        self.solved_subequations = subeqs_solved
        # Replace rhs with lhs if they appear as subexpressions of other rhss
        #
        # Create replacements when suited
        # Only for solved equations whose rhs is not a Predicate and is not
        # headed by Minus, Transpose or Inverse.
        reuse_rules = []
        for i, eq in enumerate(subeqs_solved):
            eq_rules = []
            lhs, rhs = eq.children
            if not isinstance(rhs, Predicate) and rhs.get_head() not in (
                    Minus, Transpose, Inverse):
                lhs_sym = lhs.children[0]
                t = PatternStar("t")
                l = PatternStar("l")
                r = PatternStar("r")
                # t + l rhs r  ->  t + l lhs_sym r
                # (the outer lambda binds lhs_sym now, avoiding Python's late
                # binding of closure variables inside this loop)
                repl_f = (lambda lhs_sym: lambda d: Plus(
                    [d["t"], Times([d["l"], lhs_sym, d["r"]])]))(lhs_sym)
                eq_rules.append(
                    RewriteRule(Plus([t, Times([l, rhs, r])]),
                                Replacement(repl_f)))
                # A - B C in  -L B C R + L A R  (minus pushed all the way to the left)
                if len(rhs.children) > 1:
                    repl_f = (lambda lhs_sym: lambda d: Times(
                        [Minus([lhs_sym])] + d["r"].children))(lhs_sym)
                    eq_rules.append(
                        RewriteRule(
                            Times([
                                Minus([rhs.children[0]]), *rhs.children[1:], r
                            ]), Replacement(repl_f)))
            reuse_rules.append((i, eq_rules))

        # Replace
        # Each solved equation is rewritten with the reuse rules of all the
        # other solved equations (never with its own).
        self.solved_subequations = []
        for i, eq in enumerate(subeqs_solved):
            rules = [r for j, r in reuse_rules if i != j]
            self.solved_subequations.append(
                replace_all(eq, list(itertools.chain(*rules))))

        # [TODO] Add T to operands and bind dimensions
        #self.bind_temporaries( )
        # Reset outputs
        # (solved_outputs holds the names of the output parts that ended up
        # marked as inputs; their TOS symbols get OUTPUT back)
        for op in solved_outputs:
            TOS[op][0].set_property(OUTPUT)

        print("* PME ")
        for eq in self.solved_subequations:
            print("*    ", eq)
Пример #11
0
# Overwrite metadata for the "rdiv_utu_ow" kernel.
# NOTE(review): the (1, 0) pair presumably identifies which operand the
# kernel overwrites; confirm against the pm.DB definitions.
pm.DB["rdiv_utu_ow"].overwrite = [(1, 0)]

# Pattern variables for the rewrite rules below. Each PatternDot binds one
# sub-expression, which replacements read back as d["A"], d["B"], d["X"].
A, B, X = PatternDot("A"), PatternDot("B"), PatternDot("X")

trsm2lgen_rules = [
    # X = i(t(A)) B -> ldiv_lni
    RewriteRule((
        Equal([NList([X]), Times([Inverse([A]), B])]),
        Constraint(
            "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name == X.name"
        )),
                Replacement(lambda d: Equal([
                    NList([d["X"]]),
                    Predicate("ldiv_lni", [d["A"], d["B"]],
                              [d["A"].get_size(), d["B"].get_size()])
                ]))),
    # X = i(t(A)) B -> ldiv_lni_ow
    RewriteRule(
        (Equal([NList([X]), Times([Inverse([A]), B])]),
         Constraint(
             "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name != X.name"
         )),
        Replacement(lambda d: Equal([
            NList([d["X"]]),
            Predicate("ldiv_lni_ow", [d["A"], d["B"]],
                      [d["A"].get_size(), d["B"].get_size()])
        ]))),
    # X = i(t(A)) B -> ldiv_lnu
    RewriteRule(
Пример #12
0
 def expr_to_rule_rhs_lhs(self, predicates):
     """Build, per predicate, rewrite rules that replace its rhs by its lhs.

     Each element of ``predicates`` is unpacked as ``(lhs, rhs)`` from its
     ``children``. For a single output ``lhs_sym = lhs.children[0]``, the
     rules depend on the form of ``rhs``:

     * ``Plus``: match the terms of ``rhs`` inside a larger sum
       (``t___ + ...``). The terms may appear as written, each wrapped as
       ``l___ term r___``, negated, normalized with ``normalize_minus``, or
       transposed. Each form is rewritten to ``t`` plus a single term built
       from ``lhs_sym``.
     * ``Times``: match ``l___ rhs r___``, the canonical transpose of ``rhs``,
       and the canonical negation / negated transpose of the product followed
       by ``r___``.
     * anything else: ``rhs -> lhs_sym``, plus ``canonical(rhs^T) -> lhs_sym^T``
       when the canonical transpose of ``rhs`` is not a plain operand.

     With multiple outputs, a single rule ``rhs -> lhs`` (the whole output
     list) is produced.

     :returns: a list with one list of ``RewriteRule`` objects per
         predicate, in input order.
     """
     rules = []
     # Pattern variables shared by all rules built below.
     t = PatternStar("t")
     l = PatternStar("l")
     ld = PatternDot("ld")
     r = PatternStar("r")
     for p in predicates:
         pr = []
         lhs, rhs = p.children
         if len(lhs.children) == 1:
             #lhs_sym = WrapOutBef( lhs.children[0] )
             lhs_sym = lhs.children[0]
             # Every repl_f below uses (lambda lhs: lambda d: ...)(lhs_sym) so
             # that the current lhs_sym is bound now; a plain closure would see
             # only the last lhs_sym of this loop (late binding).
             if isinstance(rhs, Plus):
                 # t___ + rhs -> t + lhs
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children +
                                                      [lhs]))(lhs_sym)
                 pr.append(
                     RewriteRule(Plus([t] + rhs.children),
                                 Replacement(repl_f)))
                 # t___ + l___ rhs_i r___ + ... -> t + l lhs r
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times(d["l"].children + [lhs] + d["r"].children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             Times([l] + [ch] + [r]) for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # t___ + canonical(-rhs_i) r___ + ... -> t + canonical(-lhs) r
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([simplify(to_canonical(Minus([lhs])))] + d["r"].
                           children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             Times([simplify(to_canonical(Minus([ch])))] +
                                   [r]) for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # A - B C in  L B C R + -L A R  (minus pushed all the way to the left, and whole thing negated)
                 repl_f = (lambda lhs: lambda d: normalize_minus(
                     Plus(d["t"].children + [
                         Times([d["ld"]] + d["l"].children + [Minus([lhs])]
                               + d["r"].children)
                     ])))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             normalize_minus(Times([ld, l,
                                                    Minus([ch]), r]))
                             for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # A - B C in  -L B C R + L A R  (minus pushed all the way to the left)
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([d["ld"]] + d["l"].children + [lhs] + d["r"].
                           children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             normalize_minus(Times([ld, l, ch, r]))
                             for ch in rhs.children
                         ]), Replacement(repl_f)))
                 #repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [Times([simplify(to_canonical(Minus(lhs.children)))] + d["r"].children)]))(lhs_sym)
                 #pr.append( RewriteRule( Plus([t] + [
                 #Times([ Minus([ld]), l, ch, r]) if not isinstance(ch, Minus) \
                 #else Times([ l, ch.children[0], r]) \
                 #for ch in rhs.children ]),
                 #Replacement(repl_f) ) )
                 # Transposed terms -> t + canonical(-(lhs^T)) r
                 # NOTE(review): unlike the commented-out variant above, the
                 # Minus branch below transposes ch itself rather than
                 # ch.children[0], and the replacement drops d["ld"] and d["l"];
                 # confirm both are intended.
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([
                         simplify(to_canonical(Minus([Transpose([lhs])])))
                     ] + d["r"].children)
                 ]))(lhs_sym)
                 pr.append( RewriteRule( Plus([t] + [
                             Times([ Minus([ld]), l, simplify(to_canonical(Transpose([ch]))), r]) if not isinstance(ch, Minus) \
                                     else Times([ l, simplify(to_canonical(Transpose([ch]))), r]) \
                                     for ch in rhs.children ]),
                                 Replacement(repl_f) ) )
             elif isinstance(rhs, Times):
                 # l___ rhs r___ -> l lhs r
                 repl_f = (lambda lhs: lambda d: Times(d[
                     "l"].children + [lhs] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(Times([l] + rhs.children + [r]),
                                 Replacement(repl_f)))
                 # l___ canonical(rhs^T) r___ -> l lhs^T r
                 repl_f = (lambda lhs: lambda d: Times(d[
                     "l"].children + [Transpose([lhs])] + d["r"].children)
                           )(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             l,
                             simplify(to_canonical(Transpose([rhs]))), r
                         ]), Replacement(repl_f)))
                 # [TODO] Minus is a b*tch. Should go for -1 and remove the operator internally?
                 # canonical(-(rhs)) r___ -> canonical(-(lhs)) r
                 repl_f = (lambda lhs: lambda d: Times([
                     simplify(to_canonical(Minus([Times([lhs])])))
                 ] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             simplify(
                                 to_canonical(
                                     Minus([Times(rhs.get_children())])))
                         ] + [r]), Replacement(repl_f)))
                 # canonical(-(rhs^T)) r___ -> canonical(-(lhs^T)) r
                 repl_f = (lambda lhs: lambda d: Times([
                     simplify(
                         to_canonical(Minus([Transpose([Times([lhs])])])))
                 ] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             simplify(
                                 to_canonical(
                                     Minus([
                                         Transpose(
                                             [Times(rhs.get_children())])
                                     ])))
                         ] + [r]), Replacement(repl_f)))
             else:
                 # Any other rhs: rhs -> lhs, and its transpose when that is
                 # not simply an operand.
                 pr.append(RewriteRule(rhs, Replacement(lhs_sym)))
                 new_rhs = simplify(to_canonical(Transpose([rhs])))
                 if not isOperand(new_rhs):
                     pr.append(
                         RewriteRule(
                             simplify(to_canonical(Transpose([rhs]))),
                             Replacement(Transpose([lhs_sym]))))
         else:
             # Multiple outputs: map the rhs to the whole output list.
             pr.append(RewriteRule(rhs, Replacement(lhs)))
         rules.append(pr)
     return rules
Пример #13
0
    def find_updates(self, before, after):
        """Derive the loop-body updates that turn the *before* state into *after*.

        ``before`` and ``after`` are lists of equations whose first child is
        the output list and whose second child is the right-hand side.

        Returns ``None``, skipping this invariant, when:

        * either side cannot be expressed in terms of the inputs;
        * a part defined in ``before`` is missing from ``after``; or
        * rewriting an ``after`` equation with the rules of the matching
          ``before`` equation leaves it unchanged ("would require undoing job").

        Otherwise:

        1. Outputs are wrapped: ``WrapOutBef`` for before values and
           ``WrapOutAft`` for after values. In the right-hand sides of
           ``after``, outputs whose input-expressed definitions match between
           before and after refer to the before value.
        2. The right-hand sides of ``after`` are rewritten to a fixpoint, first
           with the rules of ``before``, then with the rules of the other
           ``after`` equations.
        3. Trivial ``X_after = X_before`` equations are dropped.
        4. The remaining equations are tiled (keeping only the first tiling of
           each) and sorted, and the wrappers are stripped again.

        Note: ``before`` and ``after`` are modified in place.

        :returns: a list of update equations, or ``None`` if the invariant is
            skipped.
        """
        # If a part is (partially) computed in the before and
        #   does not appear in the after or
        # going from before to after requires undoing some computation
        # it is potentially unstable, and more expensive: ignore
        try:
            before_finputs = self.express_in_terms_of_input(before)
            after_finputs = self.express_in_terms_of_input(after)
        # NOTE(review): bare except also swallows KeyboardInterrupt and real
        # bugs; consider narrowing to the exception(s) actually expected.
        except:
            # [TODO] In LU's variant 5, parts of A appear as lhs's
            return None
        #
        # Index both sides by the string form of their output list.
        dict_bef = dict([(str(u.get_children()[0]), u)
                         for u in before_finputs])
        dict_aft = dict([(str(u.get_children()[0]), u) for u in after_finputs])
        # Output symbols whose input-expressed definition is the same in the
        # before and the after.
        same = []
        ignore = False
        for k, v in dict_bef.items():
            if k in dict_aft and matchq(v, dict_aft[k]):
                same.extend(v.children[0].children)
            if k not in dict_aft:
                ignore = True
                reason = "%s not in %s" % (k, dict_aft.keys())
                break
            else:
                # If the before's rhs -> lhs rules do not change the after
                # equation, reaching it would mean undoing work.
                rules = self.expr_to_rule_rhs_lhs([v])
                rules = list(itertools.chain(*rules))
                expr_copy = copy.deepcopy(dict_aft[k])
                t = replace(expr_copy, rules)
                #if v == replace( expr_copy, rules ):
                if dict_aft[k] == t:
                    ignore = True
                    reason = "%s would require undoing job" % k
                    break
        if ignore:
            print("[INFO] Skipping invariant: %s" % reason)
            return None
        #
        # Wrap outputs for before and after
        WrapBefOut = WrapOutBef
        lhss = []
        for u in before:
            lhss.extend(u.children[0])
            u.children[0] = NList([WrapBefOut(l) for l in u.children[0]])
        for u in before:
            u.children[1] = replace(
                u.children[1],
                [RewriteRule(l, Replacement(WrapBefOut(l))) for l in lhss])
        #
        lhss = []
        for u in after:
            lhss.extend(u.children[0])
            u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])
        # In after rhss, outputs in `same` refer to the before value and all
        # other outputs to the after value.
        wrap_rules_after = \
                [
                    RewriteRule(l, Replacement(WrapBefOut(l))) if l in same else
                    RewriteRule(l, Replacement(WrapOutAft(l))) for l in lhss
                ]
        for u in after:
            u.children[1] = replace(u.children[1], wrap_rules_after)
        # replace before in before
        # (one group of rhs -> lhs rules per before equation)
        wrap_rules_before = []
        for u in before:
            lhs, rhs = u.get_children()
            #if len(lhs.children) > 1:
            #wrap_rules_before.append([])
            #continue
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_before.append(list(itertools.chain(*rules)))
        #
        # Extend each group with copies whose patterns were rewritten by the
        # other groups' rules (kept only when the pattern actually changed).
        new_rules = []
        for i, rules in enumerate(wrap_rules_before):
            new_rules.append([])
            for rule in rules:
                new_r = copy.deepcopy(rule)
                new_r.pattern = replace_all(
                    new_r.pattern,
                    list(
                        itertools.chain.from_iterable(wrap_rules_before[:i] +
                                                      wrap_rules_before[i +
                                                                        1:])))
                if new_r.pattern != rule.pattern:
                    new_rules[-1].append(new_r)
        for r1, r2 in zip(new_rules, wrap_rules_before):
            r2.extend(r1)
        #
        # Apply the before rules to the after rhss until nothing changes.
        # NOTE(review): no iteration cap; a rule set that oscillates would
        # loop forever.
        wrap_rules_before = list(itertools.chain(*wrap_rules_before))
        done = False
        while not done:
            after_top = [copy.deepcopy(u) for u in after]
            for i, u in enumerate(after):
                _, rhs = u.get_children()
                u.children[1] = simplify(
                    to_canonical(
                        replace_all(copy.deepcopy(rhs), wrap_rules_before)))
            done = True
            for top, bot in zip(after_top, after):
                if top != bot:
                    done = False
                    break
        # replace after in after
        # (each after equation is rewritten with the rules of the other after
        # equations, until nothing changes)
        done = False
        while not done:
            # replace after in after
            wrap_rules_after = []
            for u in after:
                lhs, rhs = u.get_children()
                #if len(lhs.children) > 1:
                #wrap_rules_after.append([])
                #continue
                rules = self.expr_to_rule_rhs_lhs([u])
                wrap_rules_after.append(list(itertools.chain(*rules)))
            #
            after_top = [copy.deepcopy(u) for u in after]
            for i, u in enumerate(after):
                _, rhs = u.get_children()
                rules = list(
                    itertools.chain.from_iterable(wrap_rules_after[:i] +
                                                  wrap_rules_after[i + 1:]))
                u.children[1] = simplify(
                    to_canonical(replace_all(copy.deepcopy(rhs), rules)))
            done = True
            for top, bot in zip(after_top, after):
                if top != bot:
                    done = False
                    break
        # [TODO] Multiple lhss, won't work
        # Keep only real updates: drop equations whose outputs are just the
        # matching before values (X_after = X_before).
        updates = []
        for u in after:
            lhs, rhs = u.get_children()
            if len(lhs.children) == 1:
                lhs = lhs.children[0]  # NList[op] -> op
                if isinstance(rhs, WrapBefOut) and isinstance(lhs, WrapOutAft) and \
                        matchq(lhs.children[0], rhs.children[0]):
                    continue
            elif not isinstance(rhs,
                                NList):  # multiple outputs/predicate in rhs,
                # but not complete (otherwise it would be NList)
                pass
            else:
                to_skip = True
                for l, r in zip(lhs.children, rhs.children):
                    if not( isinstance(r, WrapBefOut) and isinstance(l, WrapOutAft) and \
                            matchq(l.children[0], r.children[0]) ):
                        to_skip = False
                        break
                if to_skip:
                    continue
            updates.append(u)
        #
        # NOTE(review): if tile_expr(u) yields no tiling, tilings[0] raises
        # IndexError.
        tiled_updates = []
        for u in updates:
            print("*   ", u)
            tilings = list(tile_expr(u))
            if len(tilings) > 1:
                print("[WARNING] Multiple (%d) tilings for expression %s" %
                      (len(tilings), u))
                print("          Discarding all but one")
            tiled_updates.extend(tilings[0])
        tiled_updates = sort(tiled_updates)
        print("* Tiled update")
        for t in tiled_updates:
            print("*   ", t)

        # Drop WrapOutBef's
        # Drop WrapOutAft's
        s = PatternDot("s")
        updates = []
        for u in tiled_updates:
            u = replace_all(
                u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
            u = replace_all(
                u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
            updates.append(u)

        return updates
Пример #14
0
    def find_updates_v2(self, before, after):
        """Derive the loop-body updates that take the *before* predicate to the *after* one.

        ``before`` and ``after`` are lists of equations whose first child is
        the list of outputs (NOTE(review): presumably ``Equal`` nodes with an
        ``NList`` lhs -- confirm). Both lists are mutated in place: their
        outputs get wrapped in ``WrapOutBef`` / ``WrapOutAft`` and the
        right-hand sides of ``after`` are rewritten in terms of ``before`` and
        of each other until a fixed point is reached.

        Returns:
            ``None`` when the invariant is skipped (a reason is printed);
            otherwise the list of tiled update equations with the output
            wrappers removed.
        """
        # If a part is (partially) computed in the before and
        #   does not appear in the after or
        # going from before to after requires undoing some computation
        # it is potentially unstable, and more expensive: ignore
        dict_bef = {str(u.get_children()[0]): u for u in before}
        dict_aft = {str(u.get_children()[0]): u for u in after}
        reason = None
        for k, v in dict_bef.items():
            if k not in dict_aft:
                reason = ("%s is computed in the before predicate but does "
                          "not appear in the after" % k)
                break
            # Rewrite the after-expression with v's rhs -> lhs rules; if
            # nothing changes, the after does not reuse the before value.
            rules = list(itertools.chain(*self.expr_to_rule_rhs_lhs([v])))
            rewritten = replace(copy.deepcopy(dict_aft[k]), rules)
            if dict_aft[k] == rewritten:
                reason = ("the before computation of %s is not reused in "
                          "the after" % k)
                break
        if reason is not None:
            print("[INFO] Skipping invariant: %s" % reason)
            return None
        #
        # Wrap outputs for before and after
        for u in before:
            u.children[0] = NList([WrapOutBef(l) for l in u.children[0]])
        for u in after:
            u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])
        # replace before in after: rhs -> lhs rules of single-output befores
        wrap_rules_before = []
        for u in before:
            print(u)
            lhs, _ = u.get_children()
            if len(lhs.children) > 1:
                continue
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_before.append(list(itertools.chain(*rules)))
        # Rewrite the patterns of each rule group with all earlier groups,
        # starting from the last group.
        for i, rule in enumerate(reversed(wrap_rules_before)):
            idx = len(wrap_rules_before) - i - 1
            for j in range(idx - 1, -1, -1):
                for _rule in rule:
                    _rule.pattern = replace_all(_rule.pattern,
                                                wrap_rules_before[j])
        wrap_rules_before = list(itertools.chain(*wrap_rules_before))
        #
        for u in after:
            _, rhs = u.get_children()
            u.children[1] = simplify(
                to_canonical(replace_all(copy.deepcopy(rhs),
                                         wrap_rules_before)))
        # replace after in after, until no right-hand side changes
        while True:
            wrap_rules_after = []
            for u in after:
                lhs, _ = u.get_children()
                if len(lhs.children) > 1:
                    # Keep positions aligned with ``after``
                    wrap_rules_after.append([])
                    continue
                rules = self.expr_to_rule_rhs_lhs([u])
                wrap_rules_after.append(list(itertools.chain(*rules)))
            #
            after_top = [copy.deepcopy(u) for u in after]
            for i, u in enumerate(after):
                _, rhs = u.get_children()
                # Rules of every other equation, never the equation's own
                rules = list(
                    itertools.chain.from_iterable(wrap_rules_after[:i] +
                                                  wrap_rules_after[i + 1:]))
                u.children[1] = simplify(
                    to_canonical(replace_all(copy.deepcopy(rhs), rules)))
            if not any(top != bot for top, bot in zip(after_top, after)):
                break
        # [TODO] Multiple lhss, won't work
        updates = []
        for u in after:
            lhs, rhs = u.get_children()
            lhs = lhs.children[0]  # NList[op] -> op
            # Skip trivial updates of the form "X_after = X_before"
            if isinstance(rhs, WrapOutBef) and isinstance(lhs, WrapOutAft) and \
                    matchq(lhs.children[0], rhs.children[0]):
                continue
            updates.append(u)
        #
        tiled_updates = []
        for u in updates:
            print("*   ", u)
            tilings = list(tile_expr(u))
            if len(tilings) > 1:
                print("[WARNING] Multiple (%d) tilings for expression %s" %
                      (len(tilings), u))
                print("          Discarding all but one")
            tiled_updates.extend(tilings[0])
        tiled_updates = sort(tiled_updates)
        print("* Tiled update")
        for t in tiled_updates:
            print("*   ", t)

        # Drop WrapOutBef's
        # Drop WrapOutAft's
        s = PatternDot("s")
        updates = []
        for u in tiled_updates:
            u = replace_all(
                u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
            u = replace_all(
                u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
            updates.append(u)

        return updates
Пример #15
0
    def learn_pattern(self):
        """Register rewrite rules that recognise this operation in equations.

        Effects on the module-level rule registries:

        * single assignment (exactly one equation whose lhs is a ``Symbol``):
          two rules go to the front of ``known_ops_single``. The first is the
          plain rule. The second matches the negated right-hand side
          (normalised with ``normalize_minus``) and yields ``Minus`` of the
          predicate.
        * otherwise: one rule goes to the front of ``known_ops`` and is
          pickled to ``OUTPUT/<name>_patterns``.
        * in both cases a rule mapping the predicate back to the explicit
          equation is appended to ``op_to_implicit``.

        The learnt pattern is also printed.
        """
        in_ops = [op for op in self.operands if op.isInput()]
        out_ops = [op for op in self.operands if op.isOutput()]
        #
        equations = self.equation.children
        # eq.lhs.single_entry_in_NL
        single_assignment = (len(equations) == 1
                             and isinstance(equations[0].children[0], Symbol))
        #
        # Every operand becomes a pattern variable with the operand's name
        op_to_pattern = [
            RewriteRule(op, Replacement("PatternDot(%s.name)" % op.name))
            for op in self.operands
        ]
        pattern = NList([
            replace_all(copy.deepcopy(eq), op_to_pattern)
            for eq in self.equation
        ])
        if single_assignment:
            fragments = [
                symbol_props_to_constraints_no_io(op) for op in self.operands
            ]
            fragments = [frag for frag in fragments if frag]
        else:
            fragments = [symbol_props_to_constraints(op) for op in self.operands]
        constraint = Constraint(" and ".join(fragments))
        # [TODO] Tuple for get_size
        out_names = ", ".join([op.name for op in out_ops])
        in_names = ", ".join([op.name for op in in_ops])
        out_sizes = ", ".join(
            ["%s.get_size()" % op.get_name() for op in out_ops])
        replacement = Replacement(
            "Equal([ NList([%s]), Predicate( \"%s\", [%s], [%s] ) ])" %
            (out_names, self.name, in_names, out_sizes))
        # [TODO] This should be part of the verbose option
        print("* Learnt pattern")
        print("*   ", pattern, end="")
        if constraint.to_eval:
            print("with        ", constraint.to_eval)
        print(" --> ")
        print("*          ", replacement.to_eval)
        print("**********************************")
        # [TODO] Maybe sort known ops by specificity (a la mathematica)

        if not single_assignment:
            known_ops.insert(0, RewriteRule((pattern, constraint),
                                            replacement))
            out_path = os.path.join("OUTPUT", self.name + "_patterns")
            with open(out_path, "wb") as patt_f:
                pickle.dump(known_ops[0], patt_f)
        else:
            single_eq = pattern.children[0]
            single_eq.children[0] = NList([single_eq.children[0]])
            known_ops_single.insert(
                0, RewriteRule((single_eq, constraint), replacement))
            # Same rule for the negated right-hand side
            negated_replacement = Replacement(
                "Equal([ NList([%s]), Minus([ Predicate( \"%s\", [%s], [%s] ) ]) ])"
                % (out_names, self.name, in_names, out_sizes))
            negated_eq = copy.deepcopy(single_eq)
            negated_eq.children[1] = Minus([negated_eq.children[1]])
            negated_eq.children[1] = normalize_minus(
                copy.deepcopy(negated_eq.children[1]))
            known_ops_single.insert(
                0, RewriteRule((negated_eq, constraint), negated_replacement))

        # Reverse direction: predicate -> explicit equation
        implicit_pattern = Equal([
            NList([PatternDot(op.get_name()) for op in out_ops]),
            Predicate(self.name, [PatternDot(op.get_name()) for op in in_ops],
                      [op.get_size() for op in out_ops])
        ])
        op_to_implicit.append(
            RewriteRule(implicit_pattern,
                        Replacement(equation2replacement(self.equation))))
Пример #16
0
    def generate_loop_based_algorithms(self):
        """Build loop-based algorithms for every PME and its loop invariants.

        For each ``(pme, linvs)`` pair and each loop invariant, this method:

        1. computes the initialization from the first traversal;
        2. builds the *before* predicate, then the *after* predicate with every
           traversal direction negated;
        3. derives the updates between them.

        Invariants with no updates (``None`` or empty) are skipped. The rest
        become ``Algorithm`` objects numbered consecutively by ``variant``.
        ``self.algs`` receives one list of algorithms per PME.
        """
        print("* Generating Loop-based algorithms...")
        self.algs = []
        variant = 1
        for pme, linvs in zip(self.pmes, self.linvs):
            algs = []
            for linv in linvs:
                print("* ")
                print("* Loop invariant", variant)
                for expr in linv.expressions:
                    print("*     ", expr)
                print("* ")
                # Only the first traversal is used (this would be another
                # for loop)
                trav, init_state, _ = linv.traversals[0]
                init = self.algorithm_initialization(init_state)
                print("* Init")
                print("*   ", init)
                s = PatternDot("s")
                init = [
                    replace_all(i, [
                        RewriteRule(WrapOutBef(s),
                                    Replacement(lambda d: d["s"]))
                    ]) for i in init
                ]
                print("* Before")
                repart, before = self.generate_predicate_before(
                    pme, trav, linv.expressions, linv)
                print("* After")
                # The after predicate uses every traversal direction reversed
                reversed_trav = {k: (-r, -c) for k, (r, c) in trav.items()}
                cont_with, after = self.generate_predicate_before(
                    pme, reversed_trav, linv.expressions, linv)
                # find updates
                print("* Updates")
                updates = self.find_updates(before, after)
                if updates is None:
                    continue
                if not updates:
                    print("No updates!! Should only happen in copy")
                    continue
                # Tile updates
                for u in updates:
                    infer_properties(u)
                # [DIEGO] Fixing some output pieces being labeled as input
                outputs = []
                for u in updates:
                    lhs, _ = u.children
                    for l in lhs:
                        if not l.isTemporary():
                            outputs.append(l)
                            l.set_property(OUTPUT)
                #
                final_updates = []
                for u in updates:
                    # Symbols on the rhs that are not outputs are inputs
                    for node in u.children[1].iterate_preorder():
                        if isinstance(node, Symbol) and node not in outputs:
                            node.set_property(INPUT)
                    # One tiling; list() so a lazily produced sequence of
                    # tilings works too (as in find_updates)
                    final_updates.extend(list(tile_expr(copy.deepcopy(u)))[0])

                algs.append(
                    Algorithm(linv, variant, init, repart, cont_with, before,
                              after, final_updates))
                algs[-1].prepare_for_code_generation()
                variant += 1
            self.algs.append(algs)
Пример #17
0
# Pattern variables for the simplification rules.
# NOTE(review): from how the rules use them, PatternDot appears to match a
# single subexpression and PatternStar a run of them -- confirm.
LHS, RHS, subexpr, PD1, PD2 = (
    PatternDot(var_name)
    for var_name in ("LHS", "RHS", "subexpr", "PD1", "PD2"))
PS1, PS2, PS3, PSLeft, PSRight = (
    PatternStar(var_name)
    for var_name in ("PS1", "PS2", "PS3", "PSLeft", "PSRight"))

simplify_rules = [
    # Minus
    # --expr -> expr
    RewriteRule(Minus([Minus([subexpr])]),
                Replacement(lambda d: d["subexpr"])),
    # -0 -> 0
    RewriteRule((Minus([subexpr]), Constraint(lambda d: isZero(d["subexpr"]))),
                Replacement(lambda d: d["subexpr"])),
    # Plus
    # Plus(a) -> a
    RewriteRule(Plus([subexpr]), Replacement(lambda d: d["subexpr"])),
    # Plus( a___, 0, b___ ) -> Plus( a, b )
    RewriteRule((Plus([PS1, subexpr, PS2
                       ]), Constraint(lambda d: isZero(d["subexpr"]))),
                Replacement(lambda d: Plus([d["PS1"], d["PS2"]]))),
    # a - a -> 0
    RewriteRule(
        Plus([PS1, subexpr, PS2, Minus([subexpr]), PS3]),
        Replacement(lambda d: Plus(
            [d["PS1"],
Пример #18
0
from Tiling import tile_expr
from utils import sort
from RewritingExtension import *

from BackEnd.MatlabCode import generate_matlab_code, click2matlab
#from BackEnd.LatexCode import generate_latex_code
from BackEnd.LGenCode import generate_lgen_files
#from BackEnd.CCode import generate_c_code

# Rewrite rules for operations the generator knows how to handle. The list is
# seeded with one rule, the straight assignment: a lone ``LHS = RHS`` with an
# output Symbol on the left and an input on the right.
known_ops = [
    RewriteRule(
        (
            NList([Equal([PatternDot("LHS"), PatternDot("RHS")])]),
            Constraint(
                "isinstance(LHS, Symbol) and isOutput(LHS) and isInput(RHS)"),
        ),
        Replacement("Equal([ NList([LHS]), RHS ])"),
    ),
]
# Further registries of learnt rules; they start out empty.
known_ops_single, known_pmes, op_to_implicit = [], [], []


# Add a verbose option
# For now it would help to illustrate the process
class Operation(object):
    def __init__(self, name, operands, equation, overwrite):
        self.name = name
        self.operands = operands
        self.equation = equation  # list because it may be a coupled equation
        self.overwrite = overwrite
Пример #19
0
            pass
            # [FIXME] Default?


# [TODO] Complete
instruction_set = [
    # Basic Ops - Completeness of the instruction set
    #
    [
        # Inverse
        Instruction(
            ( Inverse([ A ]), Constraint("isOperand(A)") ),
            lambda d, t: \
                    RewriteRule(
                        Inverse([ d["A"] ]),
                        Replacement( "%s" % t )
                    ),
            lambda d: Inverse([ d["A"] ])
        ),
        # Transpose
        Instruction(
            ( Transpose([ A ]), Constraint("isOperand(A)") ),
            lambda d, t: \
                    RewriteRule(
                        Transpose([ d["A"] ]),
                        Replacement( "%s" % t )
                    ),
            lambda d: Transpose([ d["A"] ])
        ),
        # Minus
        #( Minus([ A ]), Constraint("isOperand(A, Symbol)") ),
Пример #20
0
def flatten_blocked_operation_click(expr):
    """Push ``WrapOutBef`` wrappers into blocked operands, then flatten.

    A ``WrapOutBef(X)`` node matched by ``replace`` in a deep copy of *expr*
    is rewritten into a ``BlockedExpression`` with ``X``'s size and shape.
    Each of its blocks is the corresponding block of ``X`` wrapped in
    ``WrapOutBef`` (NOTE(review): presumably ``X`` is itself a
    BlockedExpression -- confirm).

    The rewritten copy is passed to ``flatten_blocked_operation``, whose
    result is returned. *expr* itself is not modified.
    """
    def _wrap_each_block(blocked):
        wrapped_blocks = map_thread(WrapOutBef, [blocked], 2)
        return BlockedExpression(wrapped_blocks, blocked.size, blocked.shape)

    push_wrapper_inside = RewriteRule(
        WrapOutBef(PatternDot("PD")),
        Replacement(lambda d: _wrap_each_block(d["PD"])))
    rewritten = replace(copy.deepcopy(expr), [push_wrapper_inside])
    return flatten_blocked_operation(rewritten)
Пример #21
0
class triu(Operator):
    """Unary operator node wrapping a single operand.

    NOTE(review): by its name, presumably the upper-triangular part of the
    operand (as MATLAB/NumPy ``triu``) -- confirm. The node takes its size
    from the operand.
    """

    def __init__(self, arg):
        operands = [arg]
        Operator.__init__(self, operands, [], UNARY)
        self.size = arg.get_size()


# Patterns for rewriting an explicit inverse into a triangular solve (trsm)
A, B, C = (PatternDot(var_name) for var_name in ("A", "B", "C"))
# [TODO] Complete the set of patterns
# Rule: C = B * inv(A)^T  ->  C = mrdiv(tril(A)^T, B)
# (A variant additionally constrained on A.st_info[0] == ST_LOWER was
#  drafted and left disabled.)
trsm_patterns = [
    RewriteRule(
        Equal([C, Times([B, Transpose([Inverse([A])])])]),
        Replacement(lambda d: Equal([
            d["C"],
            mrdiv([Transpose([tril(d["A"])]), d["B"]]),
        ])),
    ),
]


#
# Produces Matlab code for:
# - Loop-based code
#
def generate_matlab_code(operation, matlab_dir):
    # Recursive code
    out_path = os.path.join(matlab_dir, operation.name +
                            ".m")  # [FIX] At some this should be opname_rec.m
    with open(out_path, "w") as out:
        generate_matlab_recursive_code(operation, operation.pmes[-1],
                                       out)  # pmes[-1] should be the 2x2 one
    # Loop based code
Пример #22
0
    def generate_predicate_before(
            self, pme, trav, linv,
            linv_obj):  # [TODO] Cleanup, no need for linv AND linv_obj
        """Instantiate the loop-invariant expressions on repartitioned operands.

        generate_loop_based_algorithms calls this once with ``trav`` to build
        the *before* predicate, and once with the reversed traversal to build
        the *after* predicate.

        Args:
            pme: not referenced in this body.
            trav: maps operand name -> traversal direction, passed to
                ``repartition`` (NOTE(review): presumably a (row, col) step
                tuple -- confirm).
            linv: the loop-invariant expressions as a list of lists. They are
                deep-copied before any rewriting.
            linv_obj: loop-invariant object supplying ``linv_operands`` and
                their basic partitionings (``linv_operands_basic_part``).

        Returns:
            ``(reparts, final)``. ``reparts`` maps each operand name to its
            repartitioning. ``final`` is the list of simplified, canonicalised
            equations; each is also printed.
        """
        new = [copy.deepcopy(expr) for expr in itertools.chain(*linv)]
        # Repartition
        reparts = dict()
        repart_rules = []
        for op in linv_obj.linv_operands:
            part = linv_obj.linv_operands_basic_part[op.get_name()]
            # [CHECK] _shape or not? Regular one needed for inheritance
            repart = repartition(op, part.shape, trav[op.get_name()])
            #
            #repart_shape = {(1,1):(1,1), (1,2):(1,3), (2,1):(3,1), (2,2):(3,3)}[part.shape]
            #repart = repartition_shape( op, repart_shape )
            #repart = repart_group( repart, repart_shape, trav[op.get_name()] )
            #
            # Pair the blocks of the basic partitioning with those of the
            # repartitioning, position by position (both flattened by chain).
            for part_op, repart_op in zip(itertools.chain(*part),
                                          itertools.chain(*repart)):
                repart_rules.append(
                    RewriteRule(part_op, Replacement(repart_op)))
            reparts[op.get_name()] = repart
        # Apply repartitionings
        new = [replace(expr, repart_rules) for expr in new]

        # Explicit functions to BlockedExpression
        #  First flatten args, then replace
        for expr in new:
            lhs, rhs = expr.get_children()
            if isinstance(rhs, Predicate):
                for i, arg in enumerate(rhs.get_children()):
                    #rhs.set_children( i, flatten_blocked_operation(arg) )
                    rhs.set_children(i, flatten_blocked_operation_click(arg))
        # Expand predicates using the rules in the module-level known_pmes
        new = [replace(expr, known_pmes) for expr in new]

        # Operators applied to BlockedExpression, into BlockedExpressions
        for expr in new:
            if isinstance(expr, BlockedExpression):
                continue
            _, rhs = expr.get_children()
            #print( rhs ) # [TODO] Maybe "Sylv(...)"!!!
            #rhs = flatten_blocked_operation( rhs )
            rhs = flatten_blocked_operation_click(rhs)
            expr.set_children(1, rhs)

        # Flatten left-hand sides of the previous type of expressions
        for expr in new:
            if isinstance(expr, BlockedExpression):
                continue
            lhs, rhs = expr.get_children()
            new_lhs = []
            for out in lhs:
                if isinstance(out, Symbol):  # this is a temporary one
                    # Size and partition it to match the blocked rhs
                    out.size = rhs.get_size()
                    #part = partition_shape( out, rhs.shape )
                    part = partition_shape(out, tuple(rhs.shape))
                    new_lhs.append(part)
                else:
                    new_lhs.append(out)
            # Rebuild the lhs as a BlockedExpression with one NList of
            # outputs per block, shaped like the rhs.
            lhs = BlockedExpression(map_thread(NList, new_lhs, 2), (0, 0),
                                    rhs.shape)
            expr.set_children(0, lhs)
        # Flatten the last type of expressions
        final = []
        for expr in new:
            if isinstance(expr, BlockedExpression):
                final.extend([
                    simplify(to_canonical(eq)) for eq in itertools.chain(*expr)
                ])
            else:
                lhs, rhs = expr.get_children()
                # One Equal per block: lhs block = rhs block
                final.extend([
                    simplify(to_canonical(eq))
                    for eq in itertools.chain.from_iterable(
                        map_thread(Equal, [lhs, rhs], 2))
                ])
        # NOTE(review): by its name, presumably drops "0 = 0" equations -- confirm
        final = filter_zero_zero(final)
        # remove expressions of the type " B_10^T = ..." (e.g., in symv)"
        _final = final
        final = []
        for expr in _final:
            lhs, rhs = expr.children
            lhs = lhs.children
            # [FIXME] == 1 only to make sure it does not affect other cases.
            # Just want to let them break and study them in the future
            if len(lhs) == 1 and (not isOperand(lhs[0]) or lhs[0].isZero()):
                continue
            final.append(expr)
        #
        # expand in terms of input parts
        #
        #expand_rules = list(itertools.chain(*self.expr_to_rule_lhs_rhs( final )))
        #for expr in final:
        #expr.children[1] = simplify(to_canonical(replace_all(copy.deepcopy(expr.children[1]), expand_rules)))
        #
        # Print and return
        #
        for expr in final:
            print("*   ", expr)

        return (reparts, final)
Пример #23
0
    def learn_pattern(self):
        """Learn a rewrite rule that expands this operation's predicate into its PME.

        Pattern: ``Equal([NList(<outputs>), Predicate(self.name, <inputs>,
        <output sizes>)])``. Every input/output appears as a BlockedExpression
        over its basic partitioning, with each block replaced by a
        ``PatternDot`` named after the block.

        Replacement: built with ``equation2replacement`` from the blocked
        equations ``<output parts> = <parts of the first output>``.
        Leading/trailing zero entries are stripped from each lhs NList.
        Right-hand sides are taken from ``self.solved_subequations`` wherever
        the left-hand sides match.

        The rule is printed, appended to ``self.known_pmes`` and pickled
        (appended) to ``OUTPUT/<name>_pmes``.
        """
        outops = [op for op in self.operands if op.isOutput()]
        # pattern
        predicate_inops = []
        predicate_outops = []
        for op in self.operands:
            basic_part = self.basic_partitionings[op.get_name()]
            # Every block of the basic partitioning becomes a pattern variable
            rewrite_predicate_ops = [
                RewriteRule(part_op,
                            Replacement(PatternDot(part_op.get_name())))
                for part_op in itertools.chain(*basic_part)
            ]
            new = BlockedExpression(copy.deepcopy(basic_part),
                                    op.get_size(), basic_part.shape)
            if op.isInput():
                predicate_inops.append(replace(new, rewrite_predicate_ops))
            else:
                predicate_outops.append(replace(new, rewrite_predicate_ops))
        pattern = Equal([
            NList(predicate_outops),
            Predicate(self.name, predicate_inops,
                      [op.get_size() for op in outops])
        ])
        # replacement
        # [TODO] Tuple for get_size
        # [CHECK] basic_partitionings instead of partitionings?
        parts = self.partitionings
        lhss = map_thread(NList, [parts[op.get_name()] for op in outops], 2)

        # [TODO] This is a fix for lu (maybe coup sylv as well).
        #        Generalize and clean up
        # Drop a zero entry at the end or start of an lhs NList.
        for i, row in enumerate(lhss):
            for j, cell in enumerate(row):
                cell = replace(cell, [
                    RewriteRule(
                        (NList([PatternPlus("PP"),
                                PatternDot("PD")
                                ]), Constraint(lambda d: isZero(d["PD"]))),
                        Replacement(lambda d: NList(d["PP"].get_children()))),
                    RewriteRule(
                        (NList([PatternDot("PD"),
                                PatternPlus("PP")
                                ]), Constraint(lambda d: isZero(d["PD"]))),
                        Replacement(lambda d: NList(d["PP"].get_children())))
                ])
                lhss[i][j] = cell

        eqs = map_thread(Equal, [lhss, parts[outops[0].get_name()]], 2)

        output_shape = self.basic_partitionings[outops[0].get_name()].shape
        r, c = output_shape
        # Use the already solved rhs wherever a sub-equation's lhs matches
        for eq in self.solved_subequations:
            lhs, rhs = eq.get_children()
            for row in range(r):
                for col in range(c):
                    this_lhs, _ = eqs[row][col].get_children()
                    if lhs == this_lhs:
                        eqs[row][col].set_children(1, rhs)

        replacement_str = equation2replacement(
            BlockedExpression(eqs, (0, 0), output_shape))
        print("* Learnt PME pattern")
        rule = RewriteRule(pattern, Replacement(replacement_str))
        print("*     ", rule)
        self.known_pmes.append(rule)
        with open(os.path.join("OUTPUT", self.name + "_pmes"),
                  "ab+") as pmes_f:
            pickle.dump(rule, pmes_f)
Пример #24
0
    def fix_temporaries(self):
        """Partition loop-invariant temporaries and substitute their quadrants.

        A temporary is an lhs entry of an expression in ``self.expressions``
        that is neither an input nor an output. Every such entry is appended
        to ``self.linv_operands``.

        Each defining expression is then processed once, even when several of
        its outputs are temporaries; only its first lhs entry is handled.
        Processing that entry:

        * its row/column dimensions, derived from the rhs via
          ``utils.size_as_func_of_operands``, are added to the groups in
          ``self.linv_bound_dimensions`` that contain the matching operand
          dimension;
        * it is sized and partitioned like the operands it is bound to, and
          the partitioning and its shape are recorded in
          ``linv_operands_basic_part`` / ``linv_operands_part_shape``;
        * every occurrence of it in ``self.expressions`` is replaced with the
          quadrant selected by the bound operands' quadrants.
        """
        dim_index = {"R": 0, "C": 1}
        temp_exprs = []
        for expr in itertools.chain(*self.expressions):
            lhs, rhs = expr.get_children()
            defines_temporary = False
            for out in lhs:
                # [FIXME] Temporaries are not labeled. Now they are, fix.
                if not out.isInput() and not out.isOutput():
                    self.linv_operands.append(out)
                    defines_temporary = True
            # Record the expression once; processing it twice would append
            # duplicate bound dimensions and re-partition the temporary.
            if defines_temporary:
                temp_exprs.append(expr)
        for expr in temp_exprs:
            lhs, rhs = expr.get_children()
            lhs = lhs.get_children()[0]
            # Determine to which operands the temporary is bound
            # Also, which quadrant of the full temporary should be used
            (rows_op, r_dim), (cols_op,
                               c_dim) = utils.size_as_func_of_operands(rhs)
            rows_parent = rows_op.parent_op
            cols_parent = cols_op.parent_op

            # set in bound_dimensions
            r = rows_parent.get_name() + "_" + r_dim.lower()
            for group in self.linv_bound_dimensions:
                if r in group:
                    group.append(lhs.get_name() + "_r")
            c = cols_parent.get_name() + "_" + c_dim.lower()
            for group in self.linv_bound_dimensions:
                if c in group:
                    group.append(lhs.get_name() + "_c")

            # partition temporary like the operands it is bound to
            r_idx = dim_index[r_dim]
            c_idx = dim_index[c_dim]
            shape_rows = self.pme.part_shape[rows_parent.get_name()][r_idx]
            shape_cols = self.pme.part_shape[cols_parent.get_name()][c_idx]
            lhs.size = (rows_parent.get_size()[r_idx],
                        cols_parent.get_size()[c_idx])

            part = partition_shape(lhs, (shape_rows, shape_cols))
            self.linv_operands_basic_part[lhs.get_name()] = part
            self.linv_operands_part_shape[lhs.get_name()] = (
                len(part.children), len(part.children[0]))

            quadrant = part[rows_op.quadrant[r_idx]][cols_op.quadrant[c_idx]]

            # replace the temporary with the proper quadrant in every expression of the linv
            # [TODO] why is loop invariant a list of lists?
            self.expressions = [[
                replace(copy.deepcopy(e),
                        [RewriteRule(lhs, Replacement(quadrant))])
                for e in expr_l
            ] for expr_l in self.expressions]
Пример #25
0
                            TOS.push_back_temp( )
                            TOS._TOS.unset_operand( tile.children[0].children[0] )
            if matched_in_this_level: ###
                break


# Pattern variables for the grouping (factoring) rules
PD1, PD2 = PatternDot("PD1"), PatternDot("PD2")
PS1, PS2 = PatternStar("PS1"), PatternStar("PS2")

grouping_rules = [
    # A B + A C D -> A (B + C D)
    RewriteRule(
        Plus([ Times([ PD1, PS1 ]), Times([ PD1, PS2 ]) ]),
        Replacement(lambda d: Times([ d["PD1"], Plus([Times([d["PS1"]]), Times([d["PS2"]])]) ]))
    ),
    # A B - A C D -> A (B - C D)
    RewriteRule(
        Plus([ Times([ PD1, PS1 ]), Times([ Minus([PD1]), PD2, PS2 ]) ]),
        Replacement(lambda d: Times([ d["PD1"], Plus([ Times([d["PS1"]]), Times([Minus([d["PD2"]]), Times([d["PS2"]])])]) ]))
    ),
    # B A + C D A -> (B + C D) A
    RewriteRule(
        Plus([ Times([ PS1, PD1 ]), Times([ PS2, PD1 ]) ]),
        Replacement(lambda d: Times([ Plus([ Times([d["PS1"]]), Times([d["PS2"]]) ]), d["PD1"] ]))
    ),
    # B A - C D A -> (B - C D) A
    RewriteRule(
        Plus([ Times([ PS1, PD1 ]), Times([ Minus([PD2]), PS2, PD1 ]) ]),
        Replacement(lambda d: Times([ Plus([ Times([d["PS1"]]), Times([Minus([d["PD2"]]), d["PS2"]])]), d["PD1"] ]))
Пример #26
0
def _checkPredicateOverwrite(statement):
    """Insert the copy statements needed by in-place (overwriting) computations.

    Non-predicate rhs: if it contains exactly one temporary of the same size
    as the (non-temporary) output, and no rhs symbol shares the output's
    ``st_info[1]``, then the output is first assigned the temporary and the
    temporary is replaced by the output in ``statement``'s rhs.

    Predicate rhs: ``pm.DB[rhs.name].overwrite`` lists ``(inp, out)`` index
    pairs (rhs child ``inp`` is overwritten by lhs child ``out``). For each
    distinct ``inp`` whose operand differs from (and does not share storage
    with) the output:

    * an input/temporary operand is copied into the output before the
      statement, and the statement then reads the output;
    * otherwise the statement writes into the input operand, which is copied
      to the output afterwards.

    ``statement`` is mutated in place.

    Returns:
        A list of statements that contains ``statement`` exactly once.
    """
    lhs = statement.lhs()
    rhs = statement.rhs()
    if not isinstance(rhs, Predicate):
        rhs_ops = [
            node for node in rhs.iterate_preorder()
            if isinstance(node, Symbol)
        ]
        tmp_ops = [op for op in rhs_ops if op.isTemporary()]
        # [FIXME] Quick and dirty to play with temporaries
        # (the TRSM 2x2 case with several temporaries is not handled)
        if len(tmp_ops) == 1:
            tmp = tmp_ops[0]
            out = lhs.children[0]
            if out.size == tmp.size and not out.isTemporary():
                overwrites = False
                for op in rhs_ops:
                    try:
                        overwrites = out.st_info[1] == op.st_info[1]
                    except AttributeError:
                        pass
                    if overwrites:
                        break
                if not overwrites:
                    copy_stmt = Equal([NList(lhs.children), tmp])
                    statement.children[1] = replace(
                        copy.deepcopy(rhs),
                        [RewriteRule(tmp, Replacement(out))])
                    return [copy_stmt, statement]
        return [statement]

    overwrite_pairs = pm.DB[rhs.name].overwrite
    if not overwrite_pairs:
        return [statement]

    # [FIXME] Multiple overwritten operands are only partially supported:
    # copies are grouped before/after so the statement is emitted once.
    copies_before = []
    copies_after = []
    already_copied = []
    for inp, out in overwrite_pairs:
        if inp in already_copied:
            continue
        already_copied.append(inp)
        #
        if rhs.children[inp] != lhs.children[
                out]:  # [FIXME] All should have st_info
            try:
                overwrites = lhs.children[out].st_info[1] == rhs.children[
                    inp].st_info[1]
            except AttributeError:
                overwrites = False
            if overwrites:
                # Input and output already share storage: nothing to copy
                continue
            inpop = rhs.children[inp]
            outop = lhs.children[out]
            if inpop.isTemporary() or inpop.isInput():
                # if multiple outputs overwrite input (e.g., LU)
                if len([o for i, o in overwrite_pairs if i == inp]) > 1:
                    try:
                        outop = TOS._TOS[outop.st_info[1].name][
                            0]  # LU  (ABR = T3; [LBR,UBR] = LU(ABR))
                    except Exception:
                        # Best effort: keep outop when no TOS entry is found
                        pass
                    outop.st_info = (None, outop)
                #
                copies_before.append(Equal([NList([outop]), inpop]))
                rhs.children[inp] = outop
            else:
                lhs.children[out] = rhs.children[inp]
                copies_after.append(Equal([NList([inpop]), outop]))
    return copies_before + [statement] + copies_after
Пример #27
0
# Pattern variables for the canonicalisation rules
subexpr, PD1, PD2 = [
    PatternDot(label) for label in ("subexpr", "PD1", "PD2")
]
PS1, PS2, PS3, PSLeft, PSRight = [
    PatternStar(label) for label in ("PS1", "PS2", "PS3", "PSLeft", "PSRight")
]

# TODO: where do we place A*inv(A) -> I?

canonical_rules = [
    # a * ( b + c ) * e -> a*b*e + a*c*e
    RewriteRule(
        Times([PSLeft, Plus([PS1]), PSRight]),
        Replacement(lambda d: Plus(
            [Times([d["PSLeft"], term, d["PSRight"]]) for term in d["PS1"]]))),
    # -( a + b) -> (-a)+(-b)
    RewriteRule(
        Minus([Plus([PS1])]),
        Replacement(lambda d: Plus([Minus([term]) for term in d["PS1"]]))),
    # -( a * b) -> (-a) * b
    RewriteRule(Minus([Times([PD1, PS1])]),
                Replacement(lambda d: Times([Minus([d["PD1"]]), d["PS1"]]))),
    # a * b * (-c) -> (-a) * b * c
    RewriteRule(
        Times([PD1, PS1, Minus([PD2]), PS2]),
        Replacement(lambda d: Times(
            [Minus([d["PD1"]]), d["PS1"], d["PD2"], d["PS2"]]))),
    # Transpose( A + B + C ) -> A^T + B^T + C^T
    RewriteRule(
        Transpose([Plus([PS1])]),