def before_to_rule(self, before):
    """Build the rewrite rules that replace the rhs of `before` by its lhs.

    `before` must expose `get_children()` returning the pair `(lhs, rhs)`.
    The shape of the returned rule list depends on the type of `rhs`:

    * `Plus`:  three rules (whole sum, factored sum, canonicalized
      factored sum).
    * `Times`: two rules (the product itself and its negated form).
    * anything else: a single rule `rhs -> lhs.children[0]`.

    Returns a list of `RewriteRule` objects.
    """
    lhs, rhs = before.get_children()

    def splice_lhs_into_product(d):
        # l___ <children of lhs> r___
        return Times(
            d["l"].get_children() + lhs.get_children() + d["r"].get_children()
        )

    if isinstance(rhs, Plus):

        def append_lhs_to_sum(d):
            # rest___ + <children of lhs>
            return Plus(d["rest"].get_children() + lhs.get_children())

        def factored_sum():
            # Every term of rhs wrapped as l___ * term * r___ (fresh
            # pattern objects are created on every call).
            return Plus([
                Times([PatternStar("l")] + [ch] + [PatternStar("r")])
                for ch in rhs.get_children()
            ])

        whole_sum_rule = RewriteRule(
            Plus([PatternStar("rest")] + rhs.get_children()),
            Replacement(append_lhs_to_sum),
        )
        factored_rule = RewriteRule(
            factored_sum(),
            Replacement(splice_lhs_into_product),
        )
        canonical_factored_rule = RewriteRule(
            to_canonical(factored_sum()),
            Replacement(splice_lhs_into_product),
        )
        return [whole_sum_rule, factored_rule, canonical_factored_rule]

    if isinstance(rhs, Times):

        def splice_negated_lhs(d):
            negated_lhs = to_canonical(Minus([Times(lhs.get_children())]))
            return Times(
                d["l"].get_children() + [negated_lhs] + d["r"].get_children()
            )

        product_rule = RewriteRule(
            Times([PatternStar("l")] + rhs.get_children() + [PatternStar("r")]),
            Replacement(splice_lhs_into_product),
        )
        # [TODO] For chol loop invariants 2 and 3, should fix elsewhere
        # (maybe -1 instead of minus operator)
        negated_product_rule = RewriteRule(
            Times(
                [PatternStar("l")]
                + [to_canonical(Minus([Times(rhs.get_children())]))]
                + [PatternStar("r")]
            ),
            Replacement(splice_negated_lhs),
        )
        return [product_rule, negated_product_rule]

    # [FIX] what if multiple outputs?
    return [RewriteRule(rhs, Replacement(lhs.children[0]))]
def term(self, ast):
    """Turn a parsed term into an expression node.

    * A `BaseExpression` is returned unchanged.
    * An `AST` whose first operator is "-" becomes `Minus(ast.args)`;
      one whose first operator is "*" becomes `Times(ast.args)`.
    * Any other `AST` operator yields `None`.
    * Any other input type has its class printed and yields `None`.
    """
    if isinstance(ast, BaseExpression):
        return ast

    if not isinstance(ast, AST):
        # Unrecognized node type: report it, nothing to return.
        print(ast.__class__)
        return None

    operator = ast.ops[0]
    if operator == "-":
        return Minus(ast.args)
    if operator == "*":
        return Times(ast.args)
    return None
def inherit_spd(operand, part, shape):
    """Propagate the SPD property from `operand` to its partitioning `part`.

    `shape` is the `(rows, cols)` shape of the partitioning.  Every diagonal
    block `part[i][i]` is flagged SPD.  For a `(2, 2)` partitioning the two
    Schur-complement expressions are additionally registered as SPD in
    `TOE`; which off-diagonal block is used (BL or TR) depends on whether
    `operand.st_info[0]` is `storage.ST_LOWER`.

    Raises `WrongPartitioningShape` when the partitioning is not square.
    """
    rows, cols = shape
    if rows != cols:
        raise WrongPartitioningShape

    for k in range(rows):
        part[k][k].set_property(properties.SPD)

    # For 3,3 repart not really needed
    if shape != (2, 2):
        return

    def register_spd(expression):
        TOE.set_property(properties.SPD, expression)

    # Schur complements of A_TL and A_BR are also SPD
    TL, TR, BL, BR = part.flatten_children()
    if operand.st_info[0] == storage.ST_LOWER:
        register_spd(
            Plus([TL, Times([Minus([Transpose([BL])]), Inverse([BR]), BL])])
        )
        register_spd(
            Plus([BR, Times([Minus([BL]), Inverse([TL]), Transpose([BL])])])
        )
    else:
        register_spd(
            Plus([TL, Times([Minus([TR]), Inverse([BR]), Transpose([TR])])])
        )
        register_spd(
            Plus([BR, Times([Minus([Transpose([TR])]), Inverse([TL]), TR])])
        )
def normalize_minus(expr):
    """Apply the minus-normalization rewrite rules to `expr`.

    Three rules are handed to `replace_all`, in this order:

    1. ld * l___ * (-md) * r___  ->  (-ld) * l___ * md * r___
    2. (-(-ld)) * l___           ->  ld * l___
    3. -(ld * l___)              ->  (-ld) * l___

    Returns whatever `replace_all` returns for `expr` and these rules.
    """
    head = PatternDot("ld")
    negated = PatternDot("md")
    left_rest = PatternStar("l")
    right_rest = PatternStar("r")

    def move_minus_to_head(d):
        return Times([Minus([d["ld"]]), d["l"], d["md"], d["r"]])

    def drop_double_minus(d):
        return Times([d["ld"], d["l"]])

    def push_minus_into_head(d):
        return Times([Minus([d["ld"]]), d["l"]])

    minus_rules = [
        RewriteRule(
            Times([head, left_rest, Minus([negated]), right_rest]),
            Replacement(move_minus_to_head),
        ),
        RewriteRule(
            Times([Minus([Minus([head])]), left_rest]),
            Replacement(drop_double_minus),
        ),
        RewriteRule(
            Minus([Times([head, left_rest])]),
            Replacement(push_minus_into_head),
        ),
    ]
    return replace_all(expr, minus_rules)
def expr_to_rule_rhs_lhs(self, predicates):
    """Build, per predicate, the rules rewriting its rhs into its lhs.

    Each element of `predicates` must expose `children == (lhs, rhs)`.
    For every predicate a list of `RewriteRule` objects is built; which
    rules are generated depends on the number of lhs children and on
    whether rhs is a `Plus`, a `Times`, or something else.

    Returns a list of lists: one rule list per predicate, in input order.
    """
    rules = []
    # Pattern variables shared by all generated rules.
    t = PatternStar("t")
    l = PatternStar("l")
    ld = PatternDot("ld")
    r = PatternStar("r")
    for p in predicates:
        pr = []  # rules generated for this predicate
        lhs, rhs = p.children
        if len(lhs.children) == 1:
            #lhs_sym = WrapOutBef( lhs.children[0] )
            lhs_sym = lhs.children[0]
            # Every repl_f below is built by an immediately-invoked outer
            # lambda so that the current lhs_sym is bound by value instead
            # of being looked up when the replacement eventually runs.
            if isinstance(rhs, Plus):
                # t___ + rhs -> t + lhs
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [lhs]))(lhs_sym)
                pr.append(
                    RewriteRule(Plus([t] + rhs.children),
                                Replacement(repl_f)))
                # t___ + l___ rhs_i r___ + ... -> t + l lhs r
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times(d["l"].children + [lhs] + d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([l] + [ch] + [r]) for ch in rhs.children
                        ]), Replacement(repl_f)))
                # t___ + (-rhs_i) r___ + ... -> t + (-lhs) r
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([simplify(to_canonical(Minus([lhs])))] + d["r"].
                          children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([simplify(to_canonical(Minus([ch])))] + [r])
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # A - B C in  L B C R + -L A R  (minus pushed all the way to the left, and whole thing negated)
                repl_f = (lambda lhs: lambda d: normalize_minus(
                    Plus(d["t"].children + [
                        Times([d["ld"]] + d["l"].children + [Minus([lhs])] + d["r"].children)
                    ])))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            normalize_minus(Times([ld, l, Minus([ch]), r]))
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # A - B C in  -L B C R + L A R  (minus pushed all the way to the left)
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([d["ld"]] + d["l"].children + [lhs] + d["r"].
                          children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            normalize_minus(Times([ld, l, ch, r]))
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # Disabled earlier variant of the rule below, kept for reference:
                #repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [Times([simplify(to_canonical(Minus(lhs.children)))] + d["r"].children)]))(lhs_sym)
                #pr.append( RewriteRule( Plus([t] + [
                                #Times([ Minus([ld]), l, ch, r]) if not isinstance(ch, Minus) \
                                        #else Times([ l, ch.children[0], r]) \
                                    #for ch in rhs.children ]),
                                #Replacement(repl_f) ) )
                # Transposed variant: each rhs term is matched as
                # (-ld) l rhs_i^T r, or as l rhs_i^T r when rhs_i is itself
                # a Minus; the match is replaced by t + (-lhs^T) r.
                # NOTE(review): the replacement uses only d["t"] and d["r"];
                # the matched d["ld"] and d["l"] are not reinserted -- confirm
                # this is intended.
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([ simplify(to_canonical(Minus([Transpose([lhs])]))) ] + d["r"].children)
                ]))(lhs_sym)
                pr.append( RewriteRule( Plus([t] + [
                                Times([ Minus([ld]), l, simplify(to_canonical(Transpose([ch]))), r]) if not isinstance(ch, Minus) \
                                        else Times([ l, simplify(to_canonical(Transpose([ch]))), r]) \
                                    for ch in rhs.children ]),
                                Replacement(repl_f) ) )
            elif isinstance(rhs, Times):
                # l___ rhs r___ -> l lhs r
                repl_f = (lambda lhs: lambda d: Times(d[
                    "l"].children + [lhs] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(Times([l] + rhs.children + [r]),
                                Replacement(repl_f)))
                # l___ rhs^T r___ -> l lhs^T r
                repl_f = (lambda lhs: lambda d: Times(d[
                    "l"].children + [Transpose([lhs])] + d["r"].children)
                          )(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            l, simplify(to_canonical(Transpose([rhs]))), r
                        ]), Replacement(repl_f)))
                # [TODO] Minus is awkward to handle. Should go for -1 and remove the operator internally?
                # (-rhs) r___ -> (-lhs) r   (only a trailing r___ is matched)
                repl_f = (lambda lhs: lambda d: Times([
                    simplify(to_canonical(Minus([Times([lhs])])))
                ] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            simplify(
                                to_canonical(
                                    Minus([Times(rhs.get_children())])))
                        ] + [r]), Replacement(repl_f)))
                # (-rhs^T) r___ -> (-lhs^T) r
                repl_f = (lambda lhs: lambda d: Times([
                    simplify(
                        to_canonical(Minus([Transpose([Times([lhs])])])))
                ] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            simplify(
                                to_canonical(
                                    Minus([
                                        Transpose(
                                            [Times(rhs.get_children())])
                                    ])))
                        ] + [r]), Replacement(repl_f)))
            else:
                # rhs is neither Plus nor Times: rewrite it directly, and
                # also its transpose unless that simplifies to an operand.
                pr.append(RewriteRule(rhs, Replacement(lhs_sym)))
                new_rhs = simplify(to_canonical(Transpose([rhs])))
                if not isOperand(new_rhs):
                    pr.append(
                        RewriteRule(
                            simplify(to_canonical(Transpose([rhs]))),
                            Replacement(Transpose([lhs_sym]))))
        else:
            # lhs does not have exactly one child: replace rhs by the whole lhs.
            pr.append(RewriteRule(rhs, Replacement(lhs)))
        rules.append(pr)
    return rules
# Plus # Plus(a) -> a RewriteRule(Plus([subexpr]), Replacement(lambda d: d["subexpr"])), # Plus( a___, 0, b___ ) -> Plus( a, b ) RewriteRule((Plus([PS1, subexpr, PS2 ]), Constraint(lambda d: isZero(d["subexpr"]))), Replacement(lambda d: Plus([d["PS1"], d["PS2"]]))), # a - a -> 0 RewriteRule( Plus([PS1, subexpr, PS2, Minus([subexpr]), PS3]), Replacement(lambda d: Plus( [d["PS1"], Zero(d["subexpr"].get_size()), d["PS2"], d["PS3"]]))), # Times # Times(a) -> a RewriteRule(Times([subexpr]), Replacement(lambda d: d["subexpr"])), # Times( a___, 0, b___ ) -> 0 RewriteRule((Times([PS1, subexpr, PS2 ]), Constraint(lambda d: isZero(d["subexpr"]))), Replacement(lambda d: Zero( Times([d["PS1"], d["subexpr"], d["PS2"]]).get_size()))), # Times( a___, A, B, b___ ) /; A == Inv(B) -> Times( a, I, b ) RewriteRule( ( Times([PS1, PD1, PD2, PS2]), Constraint(lambda d: to_canonical(Inverse([d["PD1"]])) == d["PD2"]) #Constraint(lambda d: simplify(to_canonical(Inverse([d["PD1"]]))) == d["PD2"]) ), Replacement(lambda d: Times( [d["PS1"], Identity(d["PD1"].get_size()), d["PS2"]]))), # cannot simplify, need this one as well
# Aestetic. Simply to avoid missing T? values. TOS.push_back_temp( ) TOS._TOS.unset_operand( tile.children[0].children[0] ) if matched_in_this_level: ### break PD1 = PatternDot("PD1") PD2 = PatternDot("PD2") PS1 = PatternStar("PS1") PS2 = PatternStar("PS2") grouping_rules = [ # A B + A C D -> A (B + C D) RewriteRule( Plus([ Times([ PD1, PS1 ]), Times([ PD1, PS2 ]) ]), Replacement(lambda d: Times([ d["PD1"], Plus([Times([d["PS1"]]), Times([d["PS2"]])]) ])) ), # A B - A C D -> A (B - C D) RewriteRule( Plus([ Times([ PD1, PS1 ]), Times([ Minus([PD1]), PD2, PS2 ]) ]), Replacement(lambda d: Times([ d["PD1"], Plus([ Times([d["PS1"]]), Times([Minus([d["PD2"]]), Times([d["PS2"]])])]) ])) ), # B A + C D A -> (B + C D) A RewriteRule( Plus([ Times([ PS1, PD1 ]), Times([ PS2, PD1 ]) ]), Replacement(lambda d: Times([ Plus([ Times([d["PS1"]]), Times([d["PS2"]]) ]), d["PD1"] ])) ), # B A - C D A -> (B - C D) A RewriteRule( Plus([ Times([ PS1, PD1 ]), Times([ Minus([PD2]), PS2, PD1 ]) ]),
class triu(Operator):
    """Unary operator node whose size is taken from its single argument.

    NOTE(review): presumably denotes the upper-triangular part of `arg`,
    by analogy with MATLAB's triu -- confirm against the code generator.
    """

    def __init__(self, arg):
        Operator.__init__(self, [arg], [], UNARY)
        self.size = arg.get_size()


# Patterns for inv to trsm
A = PatternDot("A")
B = PatternDot("B")
C = PatternDot("C")
# [TODO] Complete the set of patterns
trsm_patterns = [
    #RewriteRule( (Equal([ C, Times([ B, Transpose([Inverse([A])]) ]) ]), Constraint("A.st_info[0] == ST_LOWER")), \
    # C = B * inv(A)^T  ->  C = mrdiv(tril(A)^T, B)
    RewriteRule(
                    Equal([ C, Times([ B, Transpose([Inverse([A])]) ]) ]), \
                    Replacement( lambda d: Equal([ d["C"], mrdiv([ Transpose([tril(d["A"])]), d["B"] ]) ]) )
                ),
]

#
# Produces Matlab code for:
# - Loop-based code
#
def generate_matlab_code(operation, matlab_dir):
    """Write Matlab code for `operation` into the directory `matlab_dir`.

    The part visible here opens `<matlab_dir>/<operation.name>.m` for
    writing and passes it, together with `operation.pmes[-1]`, to
    `generate_matlab_recursive_code`.
    NOTE(review): the header comment mentions loop-based code, which is not
    produced by the statements visible here -- the function may continue
    beyond this excerpt.
    """
    # Recursive code
    out_path = os.path.join(matlab_dir, operation.name + ".m") # [FIX] At some point this should be opname_rec.m
    with open(out_path, "w") as out:
        generate_matlab_recursive_code(operation, operation.pmes[-1], out) # pmes[-1] should be the 2x2 one
), lambda d: Transpose([ d["A"] ]) ), # Minus #( Minus([ A ]), Constraint("isOperand(A, Symbol)") ), Instruction( ( Minus([ A ]), Constraint("isOperand(A)") ), lambda d, t: \ RewriteRule( Minus([ d["A"] ]), Replacement( "%s" % t ) ), lambda d: Minus([ d["A"] ]) ), # Times ( Times([ left, A, B, right ]), Constraint("isOperand(A) and isOperand(B)") ), # Plus ( Plus([ left, A, B, right ]), Constraint("isOperand(A) and isOperand(B)") ) ], # # Optimizations # # A B [ Instruction( ( Times([ left, A, B, right ]), Constraint("isOperand(A) and isOperand(B)") ), lambda d, t: \ RewriteRule( Times([ left, d["A"], d["B"], right ]), Replacement( "Times([ left, %s, right ])" % t ) ),
pm.DB["rdiv_utn"].overwrite = [] pm.DB["rdiv_utn_ow"] = pm.PredicateMetadata("rdiv_utn_ow", tuple()) pm.DB["rdiv_utn_ow"].overwrite = [(1, 0)] pm.DB["rdiv_utu"] = pm.PredicateMetadata("rdiv_utu", tuple()) pm.DB["rdiv_utu"].overwrite = [] pm.DB["rdiv_utu_ow"] = pm.PredicateMetadata("rdiv_utu_ow", tuple()) pm.DB["rdiv_utu_ow"].overwrite = [(1, 0)] A = PatternDot("A") B = PatternDot("B") X = PatternDot("X") trsm2lgen_rules = [ # X = i(t(A)) B -> ldiv_lni RewriteRule(( Equal([NList([X]), Times([Inverse([A]), B])]), Constraint( "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name == X.name" )), Replacement(lambda d: Equal([ NList([d["X"]]), Predicate("ldiv_lni", [d["A"], d["B"]], [d["A"].get_size(), d["B"].get_size()]) ]))), # X = i(t(A)) B -> ldiv_lni_ow RewriteRule( (Equal([NList([X]), Times([Inverse([A]), B])]), Constraint( "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name != X.name" )), Replacement(lambda d: Equal([
def flatten_blocked_operation(expr):
    """Flatten an operator tree over blocked operands into one blocked result.

    The tree is processed recursively according to the type of `expr`:

    * `BlockedExpression`: returned unchanged.
    * `Equal`, `Plus`, `Minus`: children are flattened and the operator is
      threaded over their blocks with `map_thread(op, flat_children, 2)`;
      size and shape are taken from the first flattened child.
    * `Transpose`: as above, then the result's `transpose()` is called.
    * `Times`: leading scalar children are collected, the remaining
      children are multiplied block-wise with
      `multiply_blocked_expressions`, and every resulting cell is wrapped
      as `Times(scalars + [cell])`.
    * `Inverse`: only 1x1 and 2x2 partitionings are handled
      ([TODO] only triangular operands).

    Returns a `BlockedExpression`, except that a `Times` consisting solely
    of scalars yields the plain nested list `[[Times(scalars)]]`.
    Unsupported node types, and inverses of other partitionings, yield
    `None`.
    """
    if isinstance(expr, BlockedExpression):
        return expr

    if isinstance(expr, Equal):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        return BlockedExpression(
            map_thread(Equal, flat_ch, 2), flat_ch[0].size, flat_ch[0].shape)

    if isinstance(expr, Plus):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        return BlockedExpression(
            map_thread(Plus, flat_ch, 2), flat_ch[0].size, flat_ch[0].shape)

    # [TODO] will ignore the inner scalar expressions for now
    #        Also the plain scalars: I guess it will suffice to check
    #        size of blocked == (1,1)
    if isinstance(expr, Times):
        factors = expr.get_children()
        # Find the first non-scalar factor.  The scan is bounded by the
        # number of factors: without the bound, a product made only of
        # scalars ran past the end of the list and raised IndexError, so
        # the all-scalars return below could never be reached.
        non_scalar_idx = 0
        while (non_scalar_idx < len(factors)
               and factors[non_scalar_idx].isScalar()):
            non_scalar_idx += 1
        scalars = [s.children[0][0] for s in factors[:non_scalar_idx]]
        non_scalars = factors[non_scalar_idx:]
        if not non_scalars:
            return [[Times(scalars)]]
        # The first factor is deep-copied before flattening, as in the
        # original code.
        prod = flatten_blocked_operation(copy.deepcopy(non_scalars[0]))
        for ch in non_scalars[1:]:
            prod = multiply_blocked_expressions(
                prod, flatten_blocked_operation(ch))
        prod.children = [[Times(scalars + [cell]) for cell in row]
                         for row in prod.children]
        return prod

    if isinstance(expr, Minus):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        return BlockedExpression(
            map_thread(Minus, flat_ch, 2), flat_ch[0].size, flat_ch[0].shape)

    if isinstance(expr, Transpose):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        new = BlockedExpression(
            map_thread(Transpose, flat_ch, 2), flat_ch[0].size,
            flat_ch[0].shape)
        new.transpose()
        return new

    if isinstance(expr, Inverse):  # [TODO] Only triangular
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        blocks = flat_ch[0].get_children()
        if len(blocks) == 1:  # 1x1 inverse
            return BlockedExpression(
                map_thread(Inverse, flat_ch, 2), flat_ch[0].size,
                flat_ch[0].shape)
        if len(blocks) == 2:  # 2x2 inverse
            TL = blocks[0][0]
            TR = blocks[0][1]
            BL = blocks[1][0]
            BR = blocks[1][1]
            return BlockedExpression(
                [[
                    Inverse([TL]),
                    Times([Minus([Inverse([TL])]), TR, Inverse([BR])])
                ], [
                    Times([Minus([Inverse([BR])]), BL, Inverse([TL])]),
                    Inverse([BR])
                ]], flat_ch[0].size, flat_ch[0].shape)

    # Unsupported node type or partitioning.
    return None
def multiply_blocked_expressions(a, b):
    """Block-wise product of two blocked expressions.

    Cell (i, j) of the result is the `Plus` of `Times([a_ik, b_kj])` over
    the blocks of row i of `a` paired with the blocks of column j of `b`;
    every block is deep-copied before being placed in the product.

    Returns a `BlockedExpression` whose size is
    `(a.get_size()[0], b.get_size()[1])` and whose shape is
    `(a.shape[0], b.shape[1])`.
    """
    # Transposing b's block grid gives direct access to its columns.
    columns_of_b = list(zip(*b.get_children()))

    cells = []
    for row_of_a in a:
        result_row = []
        for column in columns_of_b:
            terms = []
            for left, right in zip(row_of_a, column):
                terms.append(
                    Times([copy.deepcopy(left), copy.deepcopy(right)]))
            result_row.append(Plus(terms))
        cells.append(result_row)

    result_size = (a.get_size()[0], b.get_size()[1])
    result_shape = (a.shape[0], b.shape[1])
    return BlockedExpression(cells, result_size, result_shape)
RHS = PatternDot("RHS") subexpr = PatternDot("subexpr") PD1 = PatternDot("PD1") PD2 = PatternDot("PD2") PS1 = PatternStar("PS1") PS2 = PatternStar("PS2") PS3 = PatternStar("PS3") PSLeft = PatternStar("PSLeft") PSRight = PatternStar("PSRight") # TODO: where do we place A*inv(A) -> I? canonical_rules = [ # a * ( b + c ) * e -> a*b*e + a*c*e RewriteRule( Times([PSLeft, Plus([PS1]), PSRight]), Replacement(lambda d: Plus( [Times([d["PSLeft"], term, d["PSRight"]]) for term in d["PS1"]]))), # -( a + b) -> (-a)+(-b) RewriteRule( Minus([Plus([PS1])]), Replacement(lambda d: Plus([Minus([term]) for term in d["PS1"]]))), # -( a * b) -> (-a) * b RewriteRule(Minus([Times([PD1, PS1])]), Replacement(lambda d: Times([Minus([d["PD1"]]), d["PS1"]]))), # a * b * (-c) -> (-a) * b * c RewriteRule( Times([PD1, PS1, Minus([PD2]), PS2]), Replacement(lambda d: Times( [Minus([d["PD1"]]), d["PS1"], d["PD2"], d["PS2"]]))), # Transpose( A + B + C ) -> A^T + B^T + C^T