Ejemplo n.º 1
0
 def to_canonical_IO(self):
     """Rewrite every equation into canonical I/O form and prune operands.

     Each equation in ``self.equation`` is first rewritten with
     ``replace_all`` using ``canonical_rules + canonicalIO_rules +
     simplify_rules``.  Then, per equation, the candidate with the fewest
     nodes (``num_nodes()``) is kept among:

       1) the equation as is,
       2) the equation with ``Minus`` applied to both sides, and
       3) the equation with ``Transpose`` applied to both sides.

     On a tie the earliest candidate in that list wins, so an equation is
     only replaced by an alternative form when that form is strictly smaller.
     The result is stored back as an ``NList``.

     Finally ``self.operands`` is filtered down to the operands that compare
     equal to some node visited by ``iterate_preorder()`` on the equations.
     """
     rules = canonical_rules + canonicalIO_rules + simplify_rules
     self.equation = [replace_all(eq, rules) for eq in self.equation]

     minimal = []
     for eq in self.equation:
         alternative_forms = [
             eq,
             simplify(
                 to_canonical(Equal([Minus([eq.lhs()]),
                                     Minus([eq.rhs()])]))),
             simplify(
                 to_canonical(
                     Equal([Transpose([eq.lhs()]),
                            Transpose([eq.rhs()])]))),
         ]
         # Select on node count alone.  Comparing (count, expr) tuples
         # would fall back to ordering the expression objects themselves on
         # a tie, which they need not support.  With key=, min() returns
         # the first minimal candidate.
         minimal.append(
             min(alternative_forms, key=lambda alt: alt.num_nodes()))
     self.equation = NList(minimal)

     # Keep only the operands that still occur somewhere in the equations.
     nodes = list(
         itertools.chain.from_iterable(
             eq.iterate_preorder() for eq in self.equation))
     self.operands = [op for op in self.operands if op in nodes]
Ejemplo n.º 2
0
def inherit_symmetric(operand, part, shape):
    """Propagate symmetry of ``operand`` onto its square partitioning ``part``.

    Every diagonal block ``part[i][i]`` is marked ``properties.SYMMETRIC``.
    For each off-diagonal pair (i < j) one block is replaced by the transpose
    of its mirror image:

    * if ``operand.st_info[0] == storage.ST_LOWER``, ``part[i][j]`` becomes
      ``Transpose([part[j][i]])``;
    * otherwise ``part[j][i]`` becomes ``Transpose([part[i][j]])``.

    ``part`` is modified in place.

    Raises:
        WrongPartitioningShape: if ``shape`` is not square.
    """
    n_rows, n_cols = shape
    storage_kind = operand.st_info[0]
    if n_rows != n_cols:
        raise WrongPartitioningShape

    for i in range(n_rows):
        # Diagonal block first, then the off-diagonal blocks right of it.
        part[i][i].set_property(properties.SYMMETRIC)
        for j in range(i + 1, n_cols):
            if storage_kind == storage.ST_LOWER:
                part[i][j] = Transpose([part[j][i]])
            else:
                part[j][i] = Transpose([part[i][j]])
Ejemplo n.º 3
0
def isSymmetric(node):
    """Return True when ``node`` can be shown to be symmetric.

    Checks are tried in order:

    1. ``node.isSymmetric()`` reports the node as symmetric;
    2. the canonical form of (a deep copy of) the node equals the canonical
       form of the transpose of (another deep copy of) the node;
    3. structural rules: a ``Plus`` is symmetric if every term is; a
       ``Times`` is conservatively reported as not symmetric (symmetry would
       depend on the factors commuting); ``Minus``, ``Transpose`` and
       ``Inverse`` are symmetric if their single child is.

    Anything else is reported as not symmetric.
    """
    if node.isSymmetric():
        return True

    # Compare node against its own transpose, both in canonical form.
    canonical = to_canonical(copy.deepcopy(node))._cleanup()
    canonical_transposed = to_canonical(
        Transpose([copy.deepcopy(node)]))._cleanup()
    if canonical == canonical_transposed:
        return True

    if isinstance(node, Plus):
        return all([isSymmetric(child) for child in node.get_children()])
    if isinstance(node, Times):
        return False
    if isinstance(node, (Minus, Transpose, Inverse)):
        return isSymmetric(node.get_children()[0])
    return False
Ejemplo n.º 4
0
def inherit_spd(operand, part, shape):
    """Propagate the SPD property onto a square partitioning ``part``.

    Every diagonal block ``part[k][k]`` is marked ``properties.SPD``.  For a
    2x2 partitioning, the two Schur complements are additionally registered
    as SPD via ``TOE.set_property``:

        TL - TR * inv(BR) * TR^T      and      BR - TR^T * inv(TL) * TR

    When ``operand.st_info[0] == storage.ST_LOWER`` the same expressions are
    written in terms of ``BL`` instead (``BL^T`` in place of ``TR``).

    Raises:
        WrongPartitioningShape: if ``shape`` is not square.
    """
    m, n = shape
    if m != n:
        raise WrongPartitioningShape

    def schur(kept, pivot, left, right):
        # kept + (-left) * inv(pivot) * right
        return Plus([kept, Times([Minus([left]), Inverse([pivot]), right])])

    for k in range(m):
        part[k][k].set_property(properties.SPD)

    # Only the 2x2 partitioning needs the complements (not the 3x3 repart).
    if shape != (2, 2):
        return

    TL, TR, BL, BR = part.flatten_children()
    if operand.st_info[0] == storage.ST_LOWER:
        TOE.set_property(properties.SPD, schur(TL, BR, Transpose([BL]), BL))
        TOE.set_property(properties.SPD, schur(BR, TL, BL, Transpose([BL])))
    else:
        TOE.set_property(properties.SPD, schur(TL, BR, TR, Transpose([TR])))
        TOE.set_property(properties.SPD, schur(BR, TL, Transpose([TR]), TR))
Ejemplo n.º 5
0
 def factor(self, ast):
     """Convert a parsed ``factor`` node into an expression.

     Args:
         ast: one of
             * ``str`` -- an identifier, resolved via ``self.id2variable``;
             * ``BaseExpression`` -- an already-built expression, returned
               unchanged;
             * ``AST`` -- a function application whose ``func`` is
               ``"trans"`` or ``"inv"``; its ``arg`` is wrapped in
               ``Transpose`` or ``Inverse`` respectively.

     Returns:
         The corresponding expression.

     Raises:
         ValueError: if ``ast`` is an ``AST`` whose ``func`` is not supported.
         TypeError: if ``ast`` is of any other type.
     """
     if isinstance(ast, str):
         return self.id2variable(ast)
     if isinstance(ast, BaseExpression):
         return ast
     if isinstance(ast, AST):
         if ast.func == "trans":
             return Transpose([ast.arg])
         if ast.func == "inv":
             return Inverse([ast.arg])
         # Fail loudly instead of silently returning None for unknown calls.
         raise ValueError("Unsupported function in factor: %r" % (ast.func,))
     raise TypeError(
         "Unexpected factor node of type %s" % type(ast).__name__)
Ejemplo n.º 6
0
 def expr_to_rule_rhs_lhs(self, predicates):
     """Turn each predicate ``lhs = rhs`` into rewrite rules ``rhs -> lhs``.

     For every predicate, a list of ``RewriteRule`` objects is built.  Their
     patterns match the right-hand side (and several negated / transposed
     variants of it) inside larger expressions, and their replacements
     substitute the left-hand side.

     Args:
         predicates: iterable of nodes whose ``children`` unpack into
             ``(lhs, rhs)``; ``lhs`` itself has a ``children`` list.

     Returns:
         A list with one entry per predicate; each entry is the list of
         ``RewriteRule`` objects generated for that predicate.
     """
     rules = []
     # Wildcards used in the patterns below.  t, l and r are PatternStar
     # (written t___, l___, r___ in the comments, presumably sequences of
     # operands); ld is a PatternDot (presumably a single operand).
     t = PatternStar("t")
     l = PatternStar("l")
     ld = PatternDot("ld")
     r = PatternStar("r")
     for p in predicates:
         pr = []
         lhs, rhs = p.children
         if len(lhs.children) == 1:
             # Single output: replacements substitute that lone symbol.
             #lhs_sym = WrapOutBef( lhs.children[0] )
             lhs_sym = lhs.children[0]
             # Every replacement below is built as
             # ``(lambda lhs: lambda d: ...)(lhs_sym)`` so the inner lambda
             # captures this iteration's ``lhs_sym`` instead of late-binding
             # to the loop variable.
             if isinstance(rhs, Plus):
                 # t___ + rhs -> t + lhs
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children +
                                                      [lhs]))(lhs_sym)
                 pr.append(
                     RewriteRule(Plus([t] + rhs.children),
                                 Replacement(repl_f)))
                 # t___ + l___ rhs_i r___ + ... -> t + l lhs r
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times(d["l"].children + [lhs] + d["r"].children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             Times([l] + [ch] + [r]) for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # t___ + (-rhs_1) r___ + (-rhs_2) r___ + ... -> t + (-lhs) r
                 # (each negated term is put in canonical/simplified form)
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([simplify(to_canonical(Minus([lhs])))] + d["r"].
                           children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             Times([simplify(to_canonical(Minus([ch])))] +
                                   [r]) for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # A - B C in  L B C R + -L A R  (minus pushed all the way to the left, and whole thing negated)
                 repl_f = (lambda lhs: lambda d: normalize_minus(
                     Plus(d["t"].children + [
                         Times([d["ld"]] + d["l"].children + [Minus([lhs])]
                               + d["r"].children)
                     ])))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             normalize_minus(Times([ld, l,
                                                    Minus([ch]), r]))
                             for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # A - B C in  -L B C R + L A R  (minus pushed all the way to the left)
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([d["ld"]] + d["l"].children + [lhs] + d["r"].
                           children)
                 ]))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Plus([t] + [
                             normalize_minus(Times([ld, l, ch, r]))
                             for ch in rhs.children
                         ]), Replacement(repl_f)))
                 # Disabled variant, kept for reference:
                 #repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [Times([simplify(to_canonical(Minus(lhs.children)))] + d["r"].children)]))(lhs_sym)
                 #pr.append( RewriteRule( Plus([t] + [
                 #Times([ Minus([ld]), l, ch, r]) if not isinstance(ch, Minus) \
                 #else Times([ l, ch.children[0], r]) \
                 #for ch in rhs.children ]),
                 #Replacement(repl_f) ) )
                 # Transposed variant: every term is canonical(rhs_i^T),
                 # negated through -ld unless rhs_i is already a Minus;
                 # replaced by t + canonical(-(lhs^T)) r.
                 # NOTE(review): the replacement uses only d["t"] and d["r"];
                 # whatever the pattern bound to ld and l is dropped.  Confirm
                 # this is intended.
                 repl_f = (lambda lhs: lambda d: Plus(d["t"].children + [
                     Times([
                         simplify(to_canonical(Minus([Transpose([lhs])])))
                     ] + d["r"].children)
                 ]))(lhs_sym)
                 pr.append( RewriteRule( Plus([t] + [
                             Times([ Minus([ld]), l, simplify(to_canonical(Transpose([ch]))), r]) if not isinstance(ch, Minus) \
                                     else Times([ l, simplify(to_canonical(Transpose([ch]))), r]) \
                                     for ch in rhs.children ]),
                                 Replacement(repl_f) ) )
             elif isinstance(rhs, Times):
                 # l___ rhs r___ -> l lhs r
                 repl_f = (lambda lhs: lambda d: Times(d[
                     "l"].children + [lhs] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(Times([l] + rhs.children + [r]),
                                 Replacement(repl_f)))
                 # l___ canonical(rhs^T) r___ -> l lhs^T r
                 repl_f = (lambda lhs: lambda d: Times(d[
                     "l"].children + [Transpose([lhs])] + d["r"].children)
                           )(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             l,
                             simplify(to_canonical(Transpose([rhs]))), r
                         ]), Replacement(repl_f)))
                 # [TODO] Minus is a b*tch. Should go for -1 and remove the operator internally?
                 # canonical(-rhs) r___ -> canonical(-lhs) r
                 # NOTE(review): unlike the two rules above, this pattern and
                 # the next have no leading l___ wildcard.  Confirm whether
                 # they are meant to match only at the start of a product.
                 repl_f = (lambda lhs: lambda d: Times([
                     simplify(to_canonical(Minus([Times([lhs])])))
                 ] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             simplify(
                                 to_canonical(
                                     Minus([Times(rhs.get_children())])))
                         ] + [r]), Replacement(repl_f)))
                 # canonical(-(rhs^T)) r___ -> canonical(-(lhs^T)) r
                 repl_f = (lambda lhs: lambda d: Times([
                     simplify(
                         to_canonical(Minus([Transpose([Times([lhs])])])))
                 ] + d["r"].children))(lhs_sym)
                 pr.append(
                     RewriteRule(
                         Times([
                             simplify(
                                 to_canonical(
                                     Minus([
                                         Transpose(
                                             [Times(rhs.get_children())])
                                     ])))
                         ] + [r]), Replacement(repl_f)))
             else:
                 # Any other rhs: rewrite it directly, plus its canonical
                 # transpose -> lhs^T when that transpose is not a bare
                 # operand (isOperand).
                 pr.append(RewriteRule(rhs, Replacement(lhs_sym)))
                 new_rhs = simplify(to_canonical(Transpose([rhs])))
                 if not isOperand(new_rhs):
                     pr.append(
                         RewriteRule(
                             simplify(to_canonical(Transpose([rhs]))),
                             Replacement(Transpose([lhs_sym]))))
         else:
             # Multiple outputs: rewrite rhs to the whole lhs node.
             pr.append(RewriteRule(rhs, Replacement(lhs)))
         rules.append(pr)
     return rules
Ejemplo n.º 7
0
      Constraint(lambda d: to_canonical(Inverse([d["PD2"]])) == d["PD1"])),
     Replacement(lambda d: Times(
         [d["PS1"], Identity(d["PD1"].get_size()), d["PS2"]]))),
 ## Times( a___, A, Inverse(A), b___ ) -> Times( a, I, b )
 #RewriteRule(
 #Times([ PS1, PD1, Inverse([ PD1 ]), PS2 ]),
 #Replacement("Times([ PS1, Identity(PD1.get_size()), PS2 ])")
 #),
 ## Times( a___, Inverse(A), A, b___ ) -> Times( a, I, b )
 #RewriteRule(
 #Times([ PS1, Inverse([ PD1 ]), PD1, PS2 ]),
 #Replacement("Times([ PS1, Identity( PD1.get_size() ), PS2 ])")
 #),
 # Transpose
 # T(T(expr)) -> expr
 RewriteRule(Transpose([Transpose([subexpr])]),
             Replacement(lambda d: d["subexpr"])),
 # T(0) -> 0
 RewriteRule(
     (Transpose([subexpr]), Constraint(lambda d: isZero(d["subexpr"]))),
     Replacement(lambda d: Zero(Transpose([d["subexpr"]]).get_size()))),
 # Inverse
 # Inv(Inv(expr)) -> expr
 RewriteRule(Inverse([Inverse([subexpr])]),
             Replacement(lambda d: d["subexpr"])),
 # Identity
 # A * I (not scalar A) -> A
 RewriteRule((Times([PS1, PD1, PD2, PS2]),
              Constraint(lambda d: (d["PD1"].isMatrix() or d["PD1"].
                                    isVector()) and isIdentity(d["PD2"]))),
             Replacement(lambda d: Times([d["PS1"], d["PD1"], d["PS2"]]))),
Ejemplo n.º 8
0

class triu(Operator):
    """Unary operator node ``triu(arg)``.

    Presumably denotes the upper-triangular part of ``arg`` (by analogy with
    MATLAB/NumPy ``triu``).  Its size is taken from ``arg.get_size()``.
    """

    def __init__(self, arg):
        operands = [arg]
        Operator.__init__(self, operands, [], UNARY)
        # The node has the same size as its single operand.
        self.size = arg.get_size()


# Patterns for rewriting inverses into triangular solves (trsm).
A = PatternDot("A")
B = PatternDot("B")
C = PatternDot("C")


def _trsm_right_transposed_inverse(d):
    """Replacement for ``C = B * inv(A)^T``: ``C = mrdiv(tril(A)^T, B)``."""
    return Equal([d["C"], mrdiv([Transpose([tril(d["A"])]), d["B"]])])


# TODO: complete the set of patterns.
# NOTE(review): if tril(A) extracts the lower triangle (as its name suggests),
# this rule only makes sense for lower-triangular A.  A commented-out variant
# guarded it with Constraint("A.st_info[0] == ST_LOWER"); confirm whether that
# guard should be restored.
trsm_patterns = [
    RewriteRule(
        Equal([C, Times([B, Transpose([Inverse([A])])])]),
        Replacement(_trsm_right_transposed_inverse),
    ),
]


#
# Matlab code generation.
#
def generate_matlab_code(operation, matlab_dir):
    """Write the recursive Matlab implementation of ``operation``.

    Only the recursive variant is produced.  It is written to
    ``<matlab_dir>/<operation.name>.m`` and generated by
    ``generate_matlab_recursive_code`` from ``operation.pmes[-1]``.

    Args:
        operation: the operation to generate code for; its ``name`` and
            ``pmes`` attributes are read.
        matlab_dir: directory in which the ``.m`` file is created/overwritten.
    """
    # FIXME: this file should eventually be named <opname>_rec.m.
    file_name = operation.name + ".m"
    out_path = os.path.join(matlab_dir, file_name)
    with open(out_path, "w") as out:
        # pmes[-1] is expected to be the 2x2 one.
        generate_matlab_recursive_code(operation, operation.pmes[-1], out)
Ejemplo n.º 9
0
 # Basic Ops - Completeness of the instruction set
 #
 [
     # Inverse
     Instruction(
         ( Inverse([ A ]), Constraint("isOperand(A)") ),
         lambda d, t: \
                 RewriteRule(
                     Inverse([ d["A"] ]),
                     Replacement( "%s" % t )
                 ),
         lambda d: Inverse([ d["A"] ])
     ),
     # Transpose
     Instruction(
         ( Transpose([ A ]), Constraint("isOperand(A)") ),
         lambda d, t: \
                 RewriteRule(
                     Transpose([ d["A"] ]),
                     Replacement( "%s" % t )
                 ),
         lambda d: Transpose([ d["A"] ])
     ),
     # Minus
     #( Minus([ A ]), Constraint("isOperand(A, Symbol)") ),
     Instruction(
         ( Minus([ A ]), Constraint("isOperand(A)") ),
         lambda d, t: \
                 RewriteRule(
                     Minus([ d["A"] ]),
                     Replacement( "%s" % t )
Ejemplo n.º 10
0
         Predicate("ldiv_lnn_ow", [d["A"], d["B"]],
                   [d["A"].get_size(), d["B"].get_size()])
     ]))),
 RewriteRule(
     (Equal([NList([X]), Times([Minus([Inverse([A])]), B])]),
      Constraint("A.isLowerTriangular() and X.st_info[1].name != X.name")),
     Replacement(lambda d: Equal([
         NList([d["X"]]),
         Minus([
             Predicate("ldiv_lnn_ow", [d["A"], d["B"]],
                       [d["A"].get_size(), d["B"].get_size()])
         ])
     ]))),
 # X = i(t(A)) B -> ldiv_lti
 RewriteRule(
     (Equal([NList([X]), Times([Transpose([Inverse([A])]), B])]),
      Constraint(
          "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name == X.name"
      )),
     Replacement(lambda d: Equal([
         NList([d["X"]]),
         Predicate("ldiv_lti", [d["A"], d["B"]],
                   [d["A"].get_size(), d["B"].get_size()])
     ]))),
 # X = i(t(A)) B -> ldiv_lti_ow
 RewriteRule(
     (Equal([NList([X]), Times([Transpose([Inverse([A])]), B])]),
      Constraint(
          "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name != X.name"
      )),
     Replacement(lambda d: Equal([
Ejemplo n.º 11
0
         [Times([d["PSLeft"], term, d["PSRight"]]) for term in d["PS1"]]))),
 # -( a + b) -> (-a)+(-b)
 RewriteRule(
     Minus([Plus([PS1])]),
     Replacement(lambda d: Plus([Minus([term]) for term in d["PS1"]]))),
 # -( a * b) -> (-a) * b
 RewriteRule(Minus([Times([PD1, PS1])]),
             Replacement(lambda d: Times([Minus([d["PD1"]]), d["PS1"]]))),
 # a * b * (-c) -> (-a) * b * c
 RewriteRule(
     Times([PD1, PS1, Minus([PD2]), PS2]),
     Replacement(lambda d: Times(
         [Minus([d["PD1"]]), d["PS1"], d["PD2"], d["PS2"]]))),
 # Transpose( A + B + C ) -> A^T + B^T + C^T
 RewriteRule(
     Transpose([Plus([PS1])]),
     Replacement(lambda d: Plus([Transpose([term]) for term in d["PS1"]])),
 ),
 # Transpose( A * B * C ) -> C^T * B^T * A^T
 RewriteRule(
     Transpose([Times([PS1])]),
     Replacement(lambda d: Times(
         [Transpose([term]) for term in reversed(list(d["PS1"]))])),
 ),
 # Transpose( -A ) -> -(A^T)
 RewriteRule(
     Transpose([Minus([PD1])]),
     Replacement(lambda d: Minus([Transpose([d["PD1"]])])),
 ),
 # Inverse( -A ) -> -(Inverse(A))
 RewriteRule(