def to_canonical_IO(self):
    """Rewrite every equation of this object into its canonical I/O form.

    Steps:
      1. Each equation in ``self.equation`` is rewritten with
         ``canonical_rules + canonicalIO_rules + simplify_rules``.
      2. For each rewritten equation, three equivalent forms are built:
         the equation as is, both sides negated, and both sides transposed.
         The form with the fewest nodes (``num_nodes()``) is kept. On a tie
         the earliest form wins, so the "as is" form is preferred.
      3. ``self.equation`` becomes an ``NList`` of the chosen forms, and
         ``self.operands`` is pruned to the operands that still occur in
         some equation.
    """
    self.equation = [
        replace_all(eq, canonical_rules + canonicalIO_rules + simplify_rules)
        for eq in self.equation
    ]
    minimal = []
    for eq in self.equation:
        alternative_forms = [
            eq,
            simplify(
                to_canonical(Equal([Minus([eq.lhs()]), Minus([eq.rhs()])]))),
            simplify(
                to_canonical(
                    Equal([Transpose([eq.lhs()]), Transpose([eq.rhs()])]))),
        ]
        # Compare by node count only. The old (count, expr) tuple comparison
        # fell back to ordering the expressions themselves on a tie.
        minimal.append(min(alternative_forms, key=lambda alt: alt.num_nodes()))
    # Set minimal forms
    self.equation = NList(minimal)
    # Keep only the operands that still appear somewhere in the equations.
    nodes = list(
        itertools.chain.from_iterable(
            eq.iterate_preorder() for eq in self.equation))
    self.operands = [op for op in self.operands if op in nodes]
def inherit_symmetric(operand, part, shape):
    """Propagate the symmetry of ``operand`` onto its square partition.

    Each diagonal block ``part[i][i]`` is marked ``properties.SYMMETRIC``.
    For each off-diagonal pair (i, j) with j > i, one block is replaced by
    the transpose of its mirror, which ties the two halves together. When
    ``operand.st_info[0] == storage.ST_LOWER``, the upper block is derived
    from the lower one; otherwise the lower block is derived from the upper.

    Raises ``WrongPartitioningShape`` if ``shape`` is not square.
    """
    m, n = shape
    storage_kind = operand.st_info[0]
    if m != n:
        raise WrongPartitioningShape
    for i in range(m):
        # Diagonal
        part[i][i].set_property(properties.SYMMETRIC)
        # Off-diagonal blocks to the right of the diagonal in row i
        for j in range(i + 1, n):
            if storage_kind == storage.ST_LOWER:
                part[i][j] = Transpose([part[j][i]])
            else:
                part[j][i] = Transpose([part[i][j]])
def isSymmetric(node):
    """Return True if ``node`` can be shown to be symmetric.

    The checks run in order:
      * the node's own ``isSymmetric()`` flag;
      * structural equality of the canonical forms of ``node`` and
        ``Transpose(node)``, computed on deep copies so ``node`` is untouched;
      * recursion: a ``Plus`` is symmetric if all its terms are, and a
        ``Minus``/``Transpose``/``Inverse`` is symmetric if its operand is.
    A ``Times`` is conservatively reported as not symmetric, and anything
    else falls through to False.
    """
    if node.isSymmetric():
        return True
    canonical_node = to_canonical(copy.deepcopy(node))._cleanup()
    canonical_transpose = to_canonical(
        Transpose([copy.deepcopy(node)]))._cleanup()
    if canonical_node == canonical_transpose:
        return True
    if isinstance(node, Plus):
        return all([isSymmetric(term) for term in node.get_children()])
    if isinstance(node, Times):
        # Would be symmetric iff the factors commute; not analysed here.
        return False
    if isinstance(node, (Minus, Transpose, Inverse)):
        return isSymmetric(node.get_children()[0])
    return False
def inherit_spd(operand, part, shape):
    """Propagate SPD-ness of ``operand`` onto its square partition.

    Every diagonal block ``part[i][i]`` is marked ``properties.SPD``. For a
    2x2 partition, the two Schur complements are also registered as SPD via
    ``TOE.set_property``:
      * TL - TR * inv(BR) * TR^T   (complement of BR)
      * BR - TR^T * inv(TL) * TR   (complement of TL)
    When ``operand.st_info[0] == storage.ST_LOWER``, the expressions use
    ``BL^T`` where the other branch uses ``TR``, so only the lower-stored
    blocks appear.

    ``shape`` may be any 2-element sequence; the 2x2 test compares the
    unpacked values, so ``[2, 2]`` behaves like ``(2, 2)``.

    Raises ``WrongPartitioningShape`` if ``shape`` is not square.
    """
    m, n = shape
    if m != n:
        raise WrongPartitioningShape
    for i in range(m):
        part[i][i].set_property(properties.SPD)
    # For 3,3 repart not really needed
    if (m, n) != (2, 2):
        return
    TL, TR, BL, BR = part.flatten_children()
    if operand.st_info[0] == storage.ST_LOWER:
        TOE.set_property(
            properties.SPD,
            Plus([TL, Times([Minus([Transpose([BL])]), Inverse([BR]), BL])]))
        TOE.set_property(
            properties.SPD,
            Plus([BR, Times([Minus([BL]), Inverse([TL]), Transpose([BL])])]))
    else:
        TOE.set_property(
            properties.SPD,
            Plus([TL, Times([Minus([TR]), Inverse([BR]), Transpose([TR])])]))
        TOE.set_property(
            properties.SPD,
            Plus([BR, Times([Minus([Transpose([TR])]), Inverse([TL]), TR])]))
def factor(self, ast):
    """Convert a parsed factor into an expression node.

    :param ast: a bare identifier (``str``), an already-built
        ``BaseExpression`` (returned unchanged), or an ``AST`` node whose
        ``func`` is ``"trans"`` or ``"inv"``.
    :returns: ``self.id2variable(ast)`` for an identifier, ``ast`` itself for
        an expression, or ``Transpose``/``Inverse`` wrapping ``ast.arg``.
    :raises Exception: if ``ast`` has an unsupported type or names an
        unsupported function.
    """
    if isinstance(ast, str):
        return self.id2variable(ast)
    if isinstance(ast, BaseExpression):
        return ast
    if isinstance(ast, AST):
        if ast.func == "trans":
            return Transpose([ast.arg])
        if ast.func == "inv":
            return Inverse([ast.arg])
        # An unknown function must not silently yield None.
        raise Exception("Unsupported function in factor: %r" % (ast.func,))
    raise Exception("Unsupported factor node of type %s" % (ast.__class__,))
def expr_to_rule_rhs_lhs(self, predicates):
    """Build rewrite rules that replace each predicate's rhs by its lhs.

    :param predicates: iterable of nodes whose ``children`` unpack into
        ``(lhs, rhs)``.
    :returns: a list with one entry per predicate; each entry is the list of
        ``RewriteRule`` objects generated for that predicate.

    When ``lhs`` has exactly one child (``lhs_sym``), several rule variants
    are generated depending on whether ``rhs`` is a ``Plus``, a ``Times``, or
    anything else. These variants also match negated/transposed occurrences.
    Otherwise a single rule ``rhs -> lhs`` is emitted.

    Every replacement closure is created with the factory idiom
    ``(lambda lhs: lambda d: ...)(lhs_sym)`` so it captures the current
    ``lhs_sym`` value rather than late-binding to the loop variable.
    """
    rules = []
    # Pattern variables. PatternStar presumably matches a (possibly empty)
    # sequence and PatternDot a single node -- confirm against the matcher.
    t = PatternStar("t")
    l = PatternStar("l")
    ld = PatternDot("ld")
    r = PatternStar("r")
    for p in predicates:
        pr = []
        lhs, rhs = p.children
        if len(lhs.children) == 1:
            #lhs_sym = WrapOutBef( lhs.children[0] )
            lhs_sym = lhs.children[0]
            if isinstance(rhs, Plus):
                # t___ + rhs -> t + lhs
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [lhs]))(lhs_sym)
                pr.append(
                    RewriteRule(Plus([t] + rhs.children), Replacement(repl_f)))
                # t___ + l___ rhs_i r___ + ... -> t + l lhs r
                # (every term of rhs wrapped by the same l___ ... r___ factors)
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times(d["l"].children + [lhs] + d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([l] + [ch] + [r]) for ch in rhs.children
                        ]), Replacement(repl_f)))
                # t___ + canon(-rhs_i) r___ + ... -> t + canon(-lhs) r
                # (negated terms, left-most, sharing the trailing r___ factors)
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([simplify(to_canonical(Minus([lhs])))] + d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([simplify(to_canonical(Minus([ch])))] + [r])
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # A - B C in L B C R + -L A R (minus pushed all the way to the left, and whole thing negated)
                repl_f = (lambda lhs: lambda d: normalize_minus(
                    Plus(d["t"].children + [
                        Times([d["ld"]] + d["l"].children + [Minus([lhs])] + d["r"].children)
                    ])))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            normalize_minus(Times([ld, l, Minus([ch]), r]))
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # A - B C in -L B C R + L A R (minus pushed all the way to the left)
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([d["ld"]] + d["l"].children + [lhs] + d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            normalize_minus(Times([ld, l, ch, r]))
                            for ch in rhs.children
                        ]), Replacement(repl_f)))
                # Disabled earlier variant of the transposed rule below:
                #repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [Times([simplify(to_canonical(Minus(lhs.children)))] + d["r"].children)]))(lhs_sym)
                #pr.append( RewriteRule( Plus([t] + [
                #Times([ Minus([ld]), l, ch, r]) if not isinstance(ch, Minus) \
                #else Times([ l, ch.children[0], r]) \
                #for ch in rhs.children ]),
                #Replacement(repl_f) ) )
                # Transposed variant:
                #   t___ + (-ld l___ canon(rhs_i^T) r___) + ... -> t + canon(-(lhs^T)) r
                # Terms whose rhs_i is a Minus omit the -ld factor.
                # NOTE(review): the replacement uses only d["t"] and d["r"]; the
                # nodes matched by ld and l are dropped -- confirm this is intended.
                repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                    Times([
                        simplify(to_canonical(Minus([Transpose([lhs])])))
                    ] + d["r"].children)
                ]))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Plus([t] + [
                            Times([ Minus([ld]), l, simplify(to_canonical(Transpose([ch]))), r]) if not isinstance(ch, Minus)
                            else Times([ l, simplify(to_canonical(Transpose([ch]))), r])
                            for ch in rhs.children
                        ]),
                        Replacement(repl_f) ) )
            elif isinstance(rhs, Times):
                # l___ rhs_1 ... rhs_k r___ -> l lhs r
                repl_f = (lambda lhs: lambda d: Times(
                    d["l"].children + [lhs] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(Times([l] + rhs.children + [r]),
                                Replacement(repl_f)))
                # l___ canon(rhs^T) r___ -> l lhs^T r
                # NOTE(review): unlike the rule above, the transposed rhs is
                # inserted as a single node rather than spliced into the
                # Times; whether that matches depends on the matcher -- confirm.
                repl_f = (lambda lhs: lambda d: Times(
                    d["l"].children + [Transpose([lhs])] + d["r"].children)
                          )(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            l, simplify(to_canonical(Transpose([rhs]))), r
                        ]), Replacement(repl_f)))
                # [TODO] Minus is a b*tch. Should go for -1 and remove the operator internally?
                # canon(-(rhs)) r___ -> canon(-(lhs)) r   (only at the left end; no l___)
                repl_f = (lambda lhs: lambda d: Times([
                    simplify(to_canonical(Minus([Times([lhs])])))
                ] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            simplify(
                                to_canonical(
                                    Minus([Times(rhs.get_children())])))
                        ] + [r]), Replacement(repl_f)))
                # canon(-(rhs^T)) r___ -> canon(-(lhs^T)) r   (only at the left end)
                repl_f = (lambda lhs: lambda d: Times([
                    simplify(
                        to_canonical(Minus([Transpose([Times([lhs])])])))
                ] + d["r"].children))(lhs_sym)
                pr.append(
                    RewriteRule(
                        Times([
                            simplify(
                                to_canonical(
                                    Minus([
                                        Transpose(
                                            [Times(rhs.get_children())])
                                    ])))
                        ] + [r]), Replacement(repl_f)))
            else:
                # rhs -> lhs_sym, plus canon(rhs^T) -> lhs_sym^T when the
                # transposed rhs does not reduce to a bare operand.
                pr.append(RewriteRule(rhs, Replacement(lhs_sym)))
                new_rhs = simplify(to_canonical(Transpose([rhs])))
                if not isOperand(new_rhs):
                    # NOTE(review): recomputes the same expression as new_rhs.
                    pr.append(
                        RewriteRule(
                            simplify(to_canonical(Transpose([rhs]))),
                            Replacement(Transpose([lhs_sym]))))
        else:
            # lhs with several children: replace rhs by the whole lhs.
            pr.append(RewriteRule(rhs, Replacement(lhs)))
        rules.append(pr)
    return rules
Constraint(lambda d: to_canonical(Inverse([d["PD2"]])) == d["PD1"])), Replacement(lambda d: Times( [d["PS1"], Identity(d["PD1"].get_size()), d["PS2"]]))), ## Times( a___, A, Inverse(A), b___ ) -> Times( a, I, b ) #RewriteRule( #Times([ PS1, PD1, Inverse([ PD1 ]), PS2 ]), #Replacement("Times([ PS1, Identity(PD1.get_size()), PS2 ])") #), ## Times( a___, Inverse(A), A, b___ ) -> Times( a, I, b ) #RewriteRule( #Times([ PS1, Inverse([ PD1 ]), PD1, PS2 ]), #Replacement("Times([ PS1, Identity( PD1.get_size() ), PS2 ])") #), # Transpose # T(T(expr)) -> expr RewriteRule(Transpose([Transpose([subexpr])]), Replacement(lambda d: d["subexpr"])), # T(0) -> 0 RewriteRule( (Transpose([subexpr]), Constraint(lambda d: isZero(d["subexpr"]))), Replacement(lambda d: Zero(Transpose([d["subexpr"]]).get_size()))), # Inverse # Inv(Inv(expr)) -> expr RewriteRule(Inverse([Inverse([subexpr])]), Replacement(lambda d: d["subexpr"])), # Identity # A * I (not scalar A) -> A RewriteRule((Times([PS1, PD1, PD2, PS2]), Constraint(lambda d: (d["PD1"].isMatrix() or d["PD1"]. isVector()) and isIdentity(d["PD2"]))), Replacement(lambda d: Times([d["PS1"], d["PD1"], d["PS2"]]))),
class triu(Operator):
    # Unary operator node over ``arg``; its size is copied from ``arg``.
    # Presumably denotes the upper-triangular part of ``arg`` (by name) -- confirm.
    def __init__(self, arg):
        Operator.__init__(self, [arg], [], UNARY)
        self.size = arg.get_size()


# Patterns for inv to trsm
A = PatternDot("A")
B = PatternDot("B")
C = PatternDot("C")
# [TODO] Complete the set of patterns
# Currently a single rule: C = B * inv(A)^T  ->  C = mrdiv([tril(A)^T, B]).
# NOTE(review): the disabled variant restricted this to A.st_info[0] == ST_LOWER;
# the active rule has no such constraint -- confirm that is intended.
trsm_patterns = [
    #RewriteRule( (Equal([ C, Times([ B, Transpose([Inverse([A])]) ]) ]), Constraint("A.st_info[0] == ST_LOWER")),
    RewriteRule(
        Equal([ C, Times([ B, Transpose([Inverse([A])]) ]) ]),
        Replacement( lambda d: Equal([ d["C"], mrdiv([ Transpose([tril(d["A"])]), d["B"] ]) ]) ) ),
]


#
# Produces Matlab code for:
# - Loop-based code
#
# NOTE(review): the visible body only writes the recursive variant to
# <matlab_dir>/<operation.name>.m via generate_matlab_recursive_code.
def generate_matlab_code(operation, matlab_dir):
    # Recursive code
    out_path = os.path.join(matlab_dir, operation.name + ".m") # [FIX] At some point this should be opname_rec.m
    with open(out_path, "w") as out:
        generate_matlab_recursive_code(operation, operation.pmes[-1], out) # pmes[-1] should be the 2x2 one
# Basic Ops - Completeness of the instruction set # [ # Inverse Instruction( ( Inverse([ A ]), Constraint("isOperand(A)") ), lambda d, t: \ RewriteRule( Inverse([ d["A"] ]), Replacement( "%s" % t ) ), lambda d: Inverse([ d["A"] ]) ), # Transpose Instruction( ( Transpose([ A ]), Constraint("isOperand(A)") ), lambda d, t: \ RewriteRule( Transpose([ d["A"] ]), Replacement( "%s" % t ) ), lambda d: Transpose([ d["A"] ]) ), # Minus #( Minus([ A ]), Constraint("isOperand(A, Symbol)") ), Instruction( ( Minus([ A ]), Constraint("isOperand(A)") ), lambda d, t: \ RewriteRule( Minus([ d["A"] ]), Replacement( "%s" % t )
Predicate("ldiv_lnn_ow", [d["A"], d["B"]], [d["A"].get_size(), d["B"].get_size()]) ]))), RewriteRule( (Equal([NList([X]), Times([Minus([Inverse([A])]), B])]), Constraint("A.isLowerTriangular() and X.st_info[1].name != X.name")), Replacement(lambda d: Equal([ NList([d["X"]]), Minus([ Predicate("ldiv_lnn_ow", [d["A"], d["B"]], [d["A"].get_size(), d["B"].get_size()]) ]) ]))), # X = i(t(A)) B -> ldiv_lti RewriteRule( (Equal([NList([X]), Times([Transpose([Inverse([A])]), B])]), Constraint( "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name == X.name" )), Replacement(lambda d: Equal([ NList([d["X"]]), Predicate("ldiv_lti", [d["A"], d["B"]], [d["A"].get_size(), d["B"].get_size()]) ]))), # X = i(t(A)) B -> ldiv_lti_ow RewriteRule( (Equal([NList([X]), Times([Transpose([Inverse([A])]), B])]), Constraint( "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name != X.name" )), Replacement(lambda d: Equal([
[Times([d["PSLeft"], term, d["PSRight"]]) for term in d["PS1"]]))), # -( a + b) -> (-a)+(-b) RewriteRule( Minus([Plus([PS1])]), Replacement(lambda d: Plus([Minus([term]) for term in d["PS1"]]))), # -( a * b) -> (-a) * b RewriteRule(Minus([Times([PD1, PS1])]), Replacement(lambda d: Times([Minus([d["PD1"]]), d["PS1"]]))), # a * b * (-c) -> (-a) * b * c RewriteRule( Times([PD1, PS1, Minus([PD2]), PS2]), Replacement(lambda d: Times( [Minus([d["PD1"]]), d["PS1"], d["PD2"], d["PS2"]]))), # Transpose( A + B + C ) -> A^T + B^T + C^T RewriteRule( Transpose([Plus([PS1])]), Replacement(lambda d: Plus([Transpose([term]) for term in d["PS1"]])), ), # Transpose( A * B * C ) -> C^T + B^T + A^T RewriteRule( Transpose([Times([PS1])]), Replacement(lambda d: Times( [Transpose([term]) for term in reversed(list(d["PS1"]))])), ), # Transpose( -A ) -> -(A^T) RewriteRule( Transpose([Minus([PD1])]), Replacement(lambda d: Minus([Transpose([d["PD1"]])])), ), # Inverse( -A ) -> -(Inverse(A)) RewriteRule(