Esempio n. 1
0
 def before_to_rule(self, before):
     """Build rewrite rules that replace ``before``'s rhs by its lhs.

     ``before`` is unpacked into exactly two children, ``lhs`` and ``rhs``.
     The rules produced depend on the node type of ``rhs``:

     * ``Plus``: ``rest___ + rhs_1 + ... + rhs_k`` becomes ``rest + lhs``.
       Two further rules handle sums whose terms are products
       ``l___ rhs_i r___``; one uses the raw pattern and one uses its
       ``to_canonical`` form.
     * ``Times``: ``l___ rhs_1 ... rhs_k r___`` becomes ``l lhs r``.
       A second rule handles the canonicalised negated form.
     * anything else: ``rhs`` becomes ``lhs.children[0]``.

     Returns:
         list of ``RewriteRule`` objects.
     """
     lhs, rhs = before.get_children()
     if isinstance(rhs, Plus):
         rules = [
             # rest___ + rhs_1 + ... + rhs_k  ->  rest + lhs_children
             RewriteRule(
                 Plus([PatternStar("rest")] + rhs.get_children()),
                 Replacement(lambda d: Plus(d["rest"].get_children() + lhs.
                                            get_children()))),
             # Older string-based form of the replacement above:
             #Replacement("Plus(rest+lhs.get_children())") )
             # l___ rhs_1 r___ + ... + l___ rhs_k r___  ->  l lhs r
             # NOTE(review): this presumably relies on the matcher binding the
             # same "l"/"r" in every term; confirm against the pattern matcher.
             RewriteRule(
                 Plus([
                     Times([PatternStar("l")] + [ch] + [PatternStar("r")])
                     for ch in rhs.get_children()
                 ]),
                 Replacement(lambda d: Times(d["l"].get_children(
                 ) + lhs.get_children() + d["r"].get_children()))),
             # Same as the previous rule, but the pattern is brought to
             # canonical form first so that it matches canonicalised input.
             RewriteRule(
                 to_canonical(
                     Plus([
                         Times([PatternStar("l")] + [ch] +
                               [PatternStar("r")])
                         for ch in rhs.get_children()
                     ])),
                 Replacement(lambda d: Times(d["l"].get_children(
                 ) + lhs.get_children() + d["r"].get_children())))
         ]
     elif isinstance(rhs, Times):
         rules = [
             # l___ rhs_1 ... rhs_k r___  ->  l lhs_children r
             RewriteRule(
                 Times([PatternStar("l")] + rhs.get_children() +
                       [PatternStar("r")]),
                 Replacement(lambda d: Times(d["l"].get_children(
                 ) + lhs.get_children() + d["r"].get_children()))),
             # [TODO] For chol loop invariants 2 and 3, should fix elsewhere (maybe -1 instead of minus operator)
             # l___ canonical(-(rhs)) r___  ->  l canonical(-(lhs)) r
             RewriteRule(
                 Times([PatternStar("l")] +
                       [to_canonical(Minus([Times(rhs.get_children())]))] +
                       [PatternStar("r")]),
                 Replacement(lambda d: Times(d["l"].get_children() + [
                     to_canonical(Minus([Times(lhs.get_children())]))
                 ] + d["r"].get_children())))
         ]
     else:  # [FIX] what if multiple outputs?
         # Only the first child of lhs is used as the replacement.
         #rules = [ RewriteRule( rhs, Replacement( lhs ) ) ]
         rules = [RewriteRule(rhs, Replacement(lhs.children[0]))]
     return rules
Esempio n. 2
0
 def term(self, ast):
     """Convert a parsed term into an expression object.

     A ``BaseExpression`` is returned unchanged. An ``AST`` node is mapped
     on its first operator: ``"-"`` builds ``Minus(ast.args)`` and ``"*"``
     builds ``Times(ast.args)``. Any other operator yields ``None``. For
     any other kind of input, the input's class is printed and ``None`` is
     returned.
     """
     if isinstance(ast, BaseExpression):
         return ast
     if not isinstance(ast, AST):
         # NOTE(review): diagnostic print for unexpected input; the caller
         # receives None.
         print(ast.__class__)
         return None
     operator = ast.ops[0]
     if operator == "-":
         return Minus(ast.args)
     if operator == "*":
         return Times(ast.args)
     return None
Esempio n. 3
0
def inherit_spd(operand, part, shape):
    """Propagate the SPD property of ``operand`` to its partitioned blocks.

    Every diagonal block of the square partitioning ``part`` is marked SPD.
    For a 2x2 partitioning, the Schur complements of the bottom-right and
    top-left blocks are also registered as SPD with ``TOE``. With
    lower-triangular storage they are written in terms of BL; otherwise
    they are written in terms of TR.

    Raises:
        WrongPartitioningShape: if ``shape`` is not square.
    """
    n_rows, n_cols = shape
    if n_rows != n_cols:
        raise WrongPartitioningShape
    for idx in range(n_rows):
        part[idx][idx].set_property(properties.SPD)
    # A 3x3 repartitioning does not need the Schur complements.
    if shape != (2, 2):
        return
    top_left, top_right, bottom_left, bottom_right = part.flatten_children()
    if operand.st_info[0] == storage.ST_LOWER:
        # Lower storage: the off-diagonal coupling is expressed through BL.
        # TL - BL^T BR^-1 BL
        TOE.set_property(
            properties.SPD,
            Plus([top_left,
                  Times([Minus([Transpose([bottom_left])]),
                         Inverse([bottom_right]),
                         bottom_left])]))
        # BR - BL TL^-1 BL^T
        TOE.set_property(
            properties.SPD,
            Plus([bottom_right,
                  Times([Minus([bottom_left]),
                         Inverse([top_left]),
                         Transpose([bottom_left])])]))
    else:
        # Otherwise the off-diagonal coupling is expressed through TR.
        # TL - TR BR^-1 TR^T
        TOE.set_property(
            properties.SPD,
            Plus([top_left,
                  Times([Minus([top_right]),
                         Inverse([bottom_right]),
                         Transpose([top_right])])]))
        # BR - TR^T TL^-1 TR
        TOE.set_property(
            properties.SPD,
            Plus([bottom_right,
                  Times([Minus([Transpose([top_right])]),
                         Inverse([top_left]),
                         top_right])]))
Esempio n. 4
0
def normalize_minus(expr):
    """Move unary minus to the leftmost factor of products in ``expr``.

    The rules below are applied with ``replace_all`` and the result is
    returned:

    * ``ld l___ (-md) r___``  ->  ``(-ld) l md r``
    * ``(-(-ld)) l___``       ->  ``ld l``
    * ``-(ld l___)``          ->  ``(-ld) l``
    """
    head = PatternDot("ld")
    negated = PatternDot("md")
    left = PatternStar("l")
    right = PatternStar("r")

    minus_to_front = RewriteRule(
        Times([head, left, Minus([negated]), right]),
        Replacement(lambda d: Times([Minus([d["ld"]]), d["l"], d["md"], d["r"]])),
    )
    double_minus = RewriteRule(
        Times([Minus([Minus([head])]), left]),
        Replacement(lambda d: Times([d["ld"], d["l"]])),
    )
    minus_into_product = RewriteRule(
        Minus([Times([head, left])]),
        Replacement(lambda d: Times([Minus([d["ld"]]), d["l"]])),
    )
    return replace_all(expr, [minus_to_front, double_minus, minus_into_product])
Esempio n. 5
0
 def expr_to_rule_rhs_lhs(self, predicates):
     """Turn predicates ``lhs = rhs`` into rules that rewrite ``rhs`` to ``lhs``.

     For every predicate ``p`` in ``predicates`` (``p.children`` is the pair
     ``(lhs, rhs)``), a list of ``RewriteRule`` objects is built. The result
     is a list holding one such list per predicate, in input order.

     When ``lhs`` has a single child (one output), that child is substituted
     for ``rhs``. Extra rules cover the contexts in which ``rhs`` may appear:
     inside a larger sum (``rhs`` is a ``Plus``), inside a larger product
     (``rhs`` is a ``Times``), and in negated and/or transposed forms (built
     with ``to_canonical``/``simplify``). Otherwise a single rule
     ``rhs -> lhs`` is emitted.

     Pattern variables: ``t``, ``l`` and ``r`` are ``PatternStar``, and
     their bindings are used through ``.children``. ``ld`` is a
     ``PatternDot`` bound to a single expression.
     """
     rules = []
     t = PatternStar("t")
     l = PatternStar("l")
     ld = PatternDot("ld")
     r = PatternStar("r")
     for p in predicates:
         pr = []
         lhs, rhs = p.children
         if len(lhs.children) == 1:
             # Single output: replace rhs by that output symbol.
             #lhs_sym = WrapOutBef( lhs.children[0] )
             lhs_sym = lhs.children[0]
             # Each repl_f below is built as (lambda lhs: lambda d: ...)(lhs_sym).
             # This binds the current lhs_sym immediately instead of the value
             # the loop variable has when the replacement finally runs.
             if isinstance(rhs, Plus):
                 # t___ + rhs -> t + lhs
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children +
                                                      [lhs]))(lhs_sym)
                 pr.append(
                     RewriteRule(Plus([t] + rhs.children),
                                 Replacement(repl_f)))
                 # t___ + l___ rhs_i r___ + ... -> t + l lhs r
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times(d["l"].children + [lhs] + d["r"].children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             Times([l] + [ch] + [r]) for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # t___ + (-rhs_1) r___ + ... -> t + (-lhs) r
                 # (negations canonicalised and simplified on both sides)
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([simplify(to_canonical(Minus([lhs])))] + d["r"].
                           children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             Times([simplify(to_canonical(Minus([ch])))] +
                                   [r]) for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # A - B C in  L B C R + -L A R  (minus pushed all the way to the left, and whole thing negated)
                 repl_f = (lambda lhs: lambda d: normalize_minus(
                     Plus(d["t"].children + [
                         Times([d["ld"]] + d["l"].children + [Minus([lhs])]
                               + d["r"].children)
                     ])))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             normalize_minus(Times([ld, l,
                                                    Minus([ch]), r]))
                             for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # A - B C in  -L B C R + L A R  (minus pushed all the way to the left)
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([d["ld"]] + d["l"].children + [lhs] + d["r"].
                           children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             normalize_minus(Times([ld, l, ch, r]))
                             for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # Earlier version of the transposed rule below, kept for reference:
                 #repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [Times([simplify(to_canonical(Minus(lhs.children)))] + d["r"].children)]))(lhs_sym)
                 #pr.append( RewriteRule( Plus([t] + [
                 #Times([ Minus([ld]), l, ch, r]) if not isinstance(ch, Minus) \
                 #else Times([ l, ch.children[0], r]) \
                 #for ch in rhs.children ]),
                 #Replacement(repl_f) ) )
                 # Transposed terms: t___ + (-ld) l___ rhs_i^T r___ + ... -> t + (-lhs^T) r
                 # (terms whose rhs_i is itself a Minus are matched without the -ld).
                 # NOTE(review): the replacement keeps only d["t"] and d["r"];
                 # whatever "l" (and "ld") matched is dropped. Confirm this is intended.
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([
                         simplify(to_canonical(Minus([Transpose([lhs])])))
                     ] + d["r"].children)
                 ]))(lhs_sym)
                 pr.append( RewriteRule( Plus([t] + [
                             Times([ Minus([ld]), l, simplify(to_canonical(Transpose([ch]))), r]) if not isinstance(ch, Minus) \
                                     else Times([ l, simplify(to_canonical(Transpose([ch]))), r]) \
                                     for ch in rhs.children ]),
                                 Replacement(repl_f) ) )
             elif isinstance(rhs, Times):
                 # l___ rhs r___ -> l lhs r
                 repl_f = (lambda lhs: lambda d: Times(d[
                     "l"].children + [lhs] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(Times([l] + rhs.children + [r]),
                                 Replacement(repl_f)))
                 # l___ canonical(rhs^T) r___ -> l lhs^T r
                 repl_f = (lambda lhs: lambda d: Times(d[
                     "l"].children + [Transpose([lhs])] + d["r"].children)
                           )(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             l,
                             simplify(to_canonical(Transpose([rhs]))), r
                         ]), Replacement(repl_f)))
                 # [TODO] Minus is awkward to handle. Should go for -1 and remove the operator internally?
                 # canonical(-rhs) r___ -> canonical(-lhs) r
                 repl_f = (lambda lhs: lambda d: Times([
                     simplify(to_canonical(Minus([Times([lhs])])))
                 ] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             simplify(
                                 to_canonical(
                                     Minus([Times(rhs.get_children())])))
                         ] + [r]), Replacement(repl_f)))
                 # canonical(-(rhs^T)) r___ -> canonical(-(lhs^T)) r
                 repl_f = (lambda lhs: lambda d: Times([
                     simplify(
                         to_canonical(Minus([Transpose([Times([lhs])])])))
                 ] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             simplify(
                                 to_canonical(
                                     Minus([
                                         Transpose(
                                             [Times(rhs.get_children())])
                                     ])))
                         ] + [r]), Replacement(repl_f)))
             else:
                 # Any other rhs: direct replacement, plus a rule for its
                 # transpose unless the canonical transpose is a bare operand.
                 pr.append(RewriteRule(rhs, Replacement(lhs_sym)))
                 new_rhs = simplify(to_canonical(Transpose([rhs])))
                 # NOTE(review): the rule below recomputes new_rhs instead of reusing it.
                 if not isOperand(new_rhs):
                     pr.append(
                         RewriteRule(
                             simplify(to_canonical(Transpose([rhs]))),
                             Replacement(Transpose([lhs_sym]))))
         else:
             # Several outputs: replace rhs by the whole lhs.
             pr.append(RewriteRule(rhs, Replacement(lhs)))
         rules.append(pr)
     return rules
Esempio n. 6
0
 # Plus
 # Plus(a) -> a
 RewriteRule(Plus([subexpr]), Replacement(lambda d: d["subexpr"])),
 # Plus( a___, 0, b___ ) -> Plus( a, b )
 RewriteRule((Plus([PS1, subexpr, PS2
                    ]), Constraint(lambda d: isZero(d["subexpr"]))),
             Replacement(lambda d: Plus([d["PS1"], d["PS2"]]))),
 # a - a -> 0
 RewriteRule(
     Plus([PS1, subexpr, PS2, Minus([subexpr]), PS3]),
     Replacement(lambda d: Plus(
         [d["PS1"],
          Zero(d["subexpr"].get_size()), d["PS2"], d["PS3"]]))),
 # Times
 # Times(a) -> a
 RewriteRule(Times([subexpr]), Replacement(lambda d: d["subexpr"])),
 # Times( a___, 0, b___ ) -> 0
 RewriteRule((Times([PS1, subexpr, PS2
                     ]), Constraint(lambda d: isZero(d["subexpr"]))),
             Replacement(lambda d: Zero(
                 Times([d["PS1"], d["subexpr"], d["PS2"]]).get_size()))),
 # Times( a___, A, B, b___ ) /; A == Inv(B) -> Times( a, I, b )
 RewriteRule(
     (
         Times([PS1, PD1, PD2, PS2]),
         Constraint(lambda d: to_canonical(Inverse([d["PD1"]])) == d["PD2"])
         #Constraint(lambda d: simplify(to_canonical(Inverse([d["PD1"]]))) == d["PD2"])
     ),
     Replacement(lambda d: Times(
         [d["PS1"], Identity(d["PD1"].get_size()), d["PS2"]]))),
 # cannot simplify, need this one as well
Esempio n. 7
0
                            # Aestetic. Simply to avoid missing T? values.
                            TOS.push_back_temp( )
                            TOS._TOS.unset_operand( tile.children[0].children[0] )
            if matched_in_this_level: ###
                break


# Pattern wildcards for the grouping rules. PD* bind a single expression;
# PS* presumably bind a run of factors (matched as ``a___`` in the rule comments).
PD1, PD2 = PatternDot("PD1"), PatternDot("PD2")
PS1, PS2 = PatternStar("PS1"), PatternStar("PS2")

grouping_rules = [
    # A B + A C D -> A (B + C D)
    RewriteRule(
        Plus([ Times([ PD1, PS1 ]), Times([ PD1, PS2 ]) ]),
        Replacement(lambda d: Times([ d["PD1"], Plus([Times([d["PS1"]]), Times([d["PS2"]])]) ]))
    ),
    # A B - A C D -> A (B - C D)
    RewriteRule(
        Plus([ Times([ PD1, PS1 ]), Times([ Minus([PD1]), PD2, PS2 ]) ]),
        Replacement(lambda d: Times([ d["PD1"], Plus([ Times([d["PS1"]]), Times([Minus([d["PD2"]]), Times([d["PS2"]])])]) ]))
    ),
    # B A + C D A -> (B + C D) A
    RewriteRule(
        Plus([ Times([ PS1, PD1 ]), Times([ PS2, PD1 ]) ]),
        Replacement(lambda d: Times([ Plus([ Times([d["PS1"]]), Times([d["PS2"]]) ]), d["PD1"] ]))
    ),
    # B A - C D A -> (B - C D) A
    RewriteRule(
        Plus([ Times([ PS1, PD1 ]), Times([ Minus([PD2]), PS2, PD1 ]) ]),
Esempio n. 8
0

class triu(Operator):
    """Unary operator node wrapping a single operand ``arg``.

    NOTE(review): by its name this presumably denotes the upper-triangular
    part of ``arg`` (cf. MATLAB ``triu``); the semantics are not defined here.
    """
    def __init__(self, arg):
        Operator.__init__(self, [arg], [], UNARY)
        # The node reports the same size as its operand.
        self.size = arg.get_size()


# Patterns for inv to trsm
# Pattern wildcards for the inv -> trsm rewrites.
A, B, C = (PatternDot(name) for name in ("A", "B", "C"))
# [TODO] Complete the set of patterns
trsm_patterns = [
    # C = B inv(A)^T  ->  C = mrdiv(tril(A)^T, B)
    # An earlier version also constrained the rule to A.st_info[0] == ST_LOWER.
    RewriteRule(
        Equal([C, Times([B, Transpose([Inverse([A])])])]),
        Replacement(
            lambda d: Equal([d["C"], mrdiv([Transpose([tril(d["A"])]), d["B"]])])
        ),
    ),
]


#
# Produces Matlab code for an operation.
# Currently only the recursive variant is generated.
#
def generate_matlab_code(operation, matlab_dir):
    """Write the recursive Matlab implementation of ``operation``.

    The file ``<matlab_dir>/<operation.name>.m`` is opened for writing and
    filled by ``generate_matlab_recursive_code``, using the last entry of
    ``operation.pmes`` (expected to be the 2x2 PME).
    """
    # [FIX] At some point this should be opname_rec.m
    file_name = operation.name + ".m"
    target = os.path.join(matlab_dir, file_name)
    with open(target, "w") as out:
        # pmes[-1] should be the 2x2 one
        generate_matlab_recursive_code(operation, operation.pmes[-1], out)
Esempio n. 9
0
                 ),
         lambda d: Transpose([ d["A"] ])
     ),
     # Minus
     #( Minus([ A ]), Constraint("isOperand(A, Symbol)") ),
     Instruction(
         ( Minus([ A ]), Constraint("isOperand(A)") ),
         lambda d, t: \
                 RewriteRule(
                     Minus([ d["A"] ]),
                     Replacement( "%s" % t )
                 ),
         lambda d: Minus([ d["A"] ])
     ),
     # Times
     ( Times([ left, A, B, right ]), Constraint("isOperand(A) and isOperand(B)") ),
     # Plus
     ( Plus([ left, A, B, right ]), Constraint("isOperand(A) and isOperand(B)") )
 ],
 #
 # Optimizations
 #
 # A B
 [
     Instruction(
         ( Times([ left, A, B, right ]), Constraint("isOperand(A) and isOperand(B)") ),
         lambda d, t: \
                 RewriteRule(
                     Times([ left, d["A"], d["B"], right ]),
                     Replacement( "Times([ left, %s, right ])" % t )
                 ),
Esempio n. 10
0
# Overwrite metadata for the rdiv_utn / rdiv_utu predicates.
# The "_ow" variants get overwrite = [(1, 0)]; the plain ones get an empty list.
# NOTE(review): (1, 0) presumably pairs an input-operand index with the output
# index allowed to reuse its storage; confirm against pm.PredicateMetadata.
pm.DB["rdiv_utn"].overwrite = []
pm.DB["rdiv_utn_ow"] = pm.PredicateMetadata("rdiv_utn_ow", tuple())
pm.DB["rdiv_utn_ow"].overwrite = [(1, 0)]
pm.DB["rdiv_utu"] = pm.PredicateMetadata("rdiv_utu", tuple())
pm.DB["rdiv_utu"].overwrite = []
pm.DB["rdiv_utu_ow"] = pm.PredicateMetadata("rdiv_utu_ow", tuple())
pm.DB["rdiv_utu_ow"].overwrite = [(1, 0)]

# Single-expression pattern wildcards used by the trsm2lgen rules.
A = PatternDot("A")
B = PatternDot("B")
X = PatternDot("X")

trsm2lgen_rules = [
    # X = i(t(A)) B -> ldiv_lni
    RewriteRule((
        Equal([NList([X]), Times([Inverse([A]), B])]),
        Constraint(
            "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name == X.name"
        )),
                Replacement(lambda d: Equal([
                    NList([d["X"]]),
                    Predicate("ldiv_lni", [d["A"], d["B"]],
                              [d["A"].get_size(), d["B"].get_size()])
                ]))),
    # X = i(t(A)) B -> ldiv_lni_ow
    RewriteRule(
        (Equal([NList([X]), Times([Inverse([A]), B])]),
         Constraint(
             "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name != X.name"
         )),
        Replacement(lambda d: Equal([
Esempio n. 11
0
def flatten_blocked_operation(expr):
    """Expand an expression over partitioned operands into one blocked expression.

    * ``BlockedExpression`` leaves are returned unchanged.
    * ``Equal``, ``Plus`` and ``Minus`` are applied block-wise over their
      flattened children (via ``map_thread``).
    * ``Transpose`` is applied block-wise, and then the block grid itself is
      transposed with ``BlockedExpression.transpose``.
    * ``Times`` multiplies its flattened non-scalar factors block by block.
      Its leading scalar factors are not flattened: each is expected to be a
      1x1 blocked operand, and its single cell is multiplied into every cell
      of the product. A product made only of scalars yields a 1x1 blocked
      expression.
    * ``Inverse`` is expanded for 1x1 and 2x2 partitionings (2x2 assumes a
      triangular operand).

    Returns ``None`` for expression types, and ``Inverse`` partitionings,
    that are not handled above.
    """
    if isinstance(expr, BlockedExpression):
        return expr
    if isinstance(expr, Equal):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        return BlockedExpression(map_thread(Equal, flat_ch, 2),
                                 flat_ch[0].size, flat_ch[0].shape)
    if isinstance(expr, Plus):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        return BlockedExpression(map_thread(Plus, flat_ch, 2), flat_ch[0].size,
                                 flat_ch[0].shape)
    # [TODO] will ignore the inner scalar expressions for now
    if isinstance(expr, Times):
        factors = expr.get_children()
        # Count the leading scalar factors. The bound check is required:
        # without it, a product made only of scalars runs off the end of the
        # list and raises IndexError.
        non_scalar_idx = 0
        while (non_scalar_idx < len(factors)
               and factors[non_scalar_idx].isScalar()):
            non_scalar_idx += 1
        scalars = [s.children[0][0] for s in factors[:non_scalar_idx]]
        non_scalars = factors[non_scalar_idx:]
        if not non_scalars:
            # Pure scalar product: wrap it as a 1x1 blocked expression so that
            # callers can use .size / .shape / .children as in every other branch.
            return BlockedExpression([[Times(scalars)]], expr.get_size(),
                                     (1, 1))
        # Deep-copy the first factor: flattening may hand it back unchanged
        # (e.g. an already-blocked operand), and its cells are rewritten below.
        prod = flatten_blocked_operation(copy.deepcopy(non_scalars[0]))
        for ch in non_scalars[1:]:
            prod = multiply_blocked_expressions(
                prod, flatten_blocked_operation(ch))
        prod.children = [[Times(scalars + [cell]) for cell in row]
                         for row in prod.children]
        return prod
    if isinstance(expr, Minus):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        return BlockedExpression(map_thread(Minus, flat_ch, 2),
                                 flat_ch[0].size, flat_ch[0].shape)
    if isinstance(expr, Transpose):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        new = BlockedExpression(map_thread(Transpose, flat_ch, 2),
                                flat_ch[0].size, flat_ch[0].shape)
        # Transposing every cell is not enough: the block layout is
        # transposed as well.
        new.transpose()
        return new
    if isinstance(expr, Inverse):  # [TODO] Only triangular
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        blocks = flat_ch[0].get_children()
        if len(blocks) == 1:  # 1x1 inverse
            return BlockedExpression(map_thread(Inverse, flat_ch, 2),
                                     flat_ch[0].size, flat_ch[0].shape)
        if len(blocks) == 2:  # 2x2 inverse
            TL = blocks[0][0]
            TR = blocks[0][1]
            BL = blocks[1][0]
            BR = blocks[1][1]
            # Block inverse of a triangular matrix. For a lower (upper)
            # triangular operand TR (BL) is zero, so only one off-diagonal
            # term is non-trivial:
            #   [[TL^-1,             -TL^-1 TR BR^-1],
            #    [-BR^-1 BL TL^-1,    BR^-1         ]]
            return BlockedExpression(
                [[
                    Inverse([TL]),
                    Times([Minus([Inverse([TL])]), TR,
                           Inverse([BR])])
                ],
                 [
                     Times([Minus([Inverse([BR])]), BL,
                            Inverse([TL])]),
                     Inverse([BR])
                 ]], flat_ch[0].size, flat_ch[0].shape)
    # Unsupported expression type or partitioning.
    return None
Esempio n. 12
0
def multiply_blocked_expressions(a, b):
    """Multiply two blocked expressions block by block.

    Cell ``(i, j)`` of the result is ``Plus`` over ``k`` of
    ``Times([a[i][k], b[k][j]])``. Every block is deep-copied, so the
    result shares no sub-expressions with ``a`` or ``b``.

    Raises:
        ValueError: if a block row of ``a`` does not have as many blocks as
            ``b`` has block rows. Without this check, ``zip`` would silently
            drop the extra blocks.
    """
    # Use get_children() for both operands (the original iterated ``a``
    # directly but ``b.get_children()``).
    a_rows = a.get_children()
    b_rows = b.get_children()
    if any(len(a_row) != len(b_rows) for a_row in a_rows):
        raise ValueError(
            "blocked product: inner block dimensions do not match")
    b_cols = list(zip(*b_rows))
    product = [
        [
            Plus([Times([copy.deepcopy(za), copy.deepcopy(zb)])
                  for za, zb in zip(a_row, b_col)])
            for b_col in b_cols
        ]
        for a_row in a_rows
    ]
    return BlockedExpression(product, (a.get_size()[0], b.get_size()[1]),
                             (a.shape[0], b.shape[1]))
Esempio n. 13
0
# Pattern wildcards shared by the canonicalisation rules. PatternDot binds
# a single expression; PatternStar presumably binds a run of operands.
RHS = PatternDot("RHS")
subexpr = PatternDot("subexpr")
PD1, PD2 = PatternDot("PD1"), PatternDot("PD2")
PS1, PS2, PS3 = PatternStar("PS1"), PatternStar("PS2"), PatternStar("PS3")
PSLeft, PSRight = PatternStar("PSLeft"), PatternStar("PSRight")

# TODO: where do we place A*inv(A) -> I?

canonical_rules = [
    # a * ( b + c ) * e -> a*b*e + a*c*e
    RewriteRule(
        Times([PSLeft, Plus([PS1]), PSRight]),
        Replacement(lambda d: Plus(
            [Times([d["PSLeft"], term, d["PSRight"]]) for term in d["PS1"]]))),
    # -( a + b) -> (-a)+(-b)
    RewriteRule(
        Minus([Plus([PS1])]),
        Replacement(lambda d: Plus([Minus([term]) for term in d["PS1"]]))),
    # -( a * b) -> (-a) * b
    RewriteRule(Minus([Times([PD1, PS1])]),
                Replacement(lambda d: Times([Minus([d["PD1"]]), d["PS1"]]))),
    # a * b * (-c) -> (-a) * b * c
    RewriteRule(
        Times([PD1, PS1, Minus([PD2]), PS2]),
        Replacement(lambda d: Times(
            [Minus([d["PD1"]]), d["PS1"], d["PD2"], d["PS2"]]))),
    # Transpose( A + B + C ) -> A^T + B^T + C^T