Exemplo n.º 1
0
def inherit_spd(operand, part, shape):
    """Propagate the SPD property of ``operand`` to parts of its partition.

    Every diagonal block ``part[k][k]`` of a square partitioning is marked
    ``properties.SPD``.  For a 2x2 partitioning, the two Schur complements
    (of the top-left and of the bottom-right block) are additionally
    registered as SPD via ``TOE.set_property``.  The off-diagonal block used
    in those formulas is ``BL`` when ``operand.st_info[0]`` equals
    ``storage.ST_LOWER`` and ``TR`` otherwise.

    Raises:
        WrongPartitioningShape: if ``shape`` is not square.
    """
    rows, cols = shape
    if rows != cols:
        raise WrongPartitioningShape

    for k in range(rows):
        part[k][k].set_property(properties.SPD)

    # Only the 2x2 partitioning needs the Schur complements (not really
    # needed for a 3x3 repartitioning).
    if shape != (2, 2):
        return

    def schur_complement(diag, pivot, off, transpose_left):
        # Build diag + (-L) * inv(pivot) * R, with (L, R) = (off^T, off) when
        # transpose_left is true and (off, off^T) otherwise.  The nodes are
        # created in the same order as the explicit formulas they stand for.
        if transpose_left:
            left = Minus([Transpose([off])])
            middle = Inverse([pivot])
            right = off
        else:
            left = Minus([off])
            middle = Inverse([pivot])
            right = Transpose([off])
        return Plus([diag, Times([left, middle, right])])

    top_left, top_right, bottom_left, bottom_right = part.flatten_children()
    lower = operand.st_info[0] == storage.ST_LOWER
    off_diag = bottom_left if lower else top_right
    # lower:  TL - BL^T inv(BR) BL   and   BR - BL inv(TL) BL^T
    # upper:  TL - TR inv(BR) TR^T   and   BR - TR^T inv(TL) TR
    TOE.set_property(
        properties.SPD,
        schur_complement(top_left, bottom_right, off_diag, lower))
    TOE.set_property(
        properties.SPD,
        schur_complement(bottom_right, top_left, off_diag, not lower))
Exemplo n.º 2
0
 def factor(self, ast):
     """Convert a parsed factor into an expression object.

     - ``str``: resolved through ``self.id2variable``.
     - ``BaseExpression``: returned unchanged.
     - ``AST`` with ``func == "trans"`` / ``"inv"``: ``ast.arg`` wrapped in
       ``Transpose`` / ``Inverse``.

     Raises:
         ValueError: if an ``AST`` node names any other function (instead
             of silently returning ``None``).
         TypeError: if ``ast`` is of any other type.
     """
     if isinstance(ast, str):
         return self.id2variable(ast)
     if isinstance(ast, BaseExpression):
         return ast
     if isinstance(ast, AST):
         if ast.func == "trans":
             return Transpose([ast.arg])
         if ast.func == "inv":
             return Inverse([ast.arg])
         raise ValueError("Unsupported function %r in factor" % (ast.func,))
     # TypeError subclasses Exception, so callers catching Exception still
     # work; the offending class now travels in the message instead of
     # being printed to stdout.
     raise TypeError(
         "Unexpected factor node of type %s" % ast.__class__.__name__)
Exemplo n.º 3
0

class triu(Operator):
    """Unary ``triu`` operator node wrapping a single argument.

    Presumably denotes the upper-triangular part of ``arg`` (Matlab
    ``triu``) -- TODO confirm.  The node's ``size`` is ``arg.get_size()``.
    """

    def __init__(self, arg):
        operands = [arg]
        Operator.__init__(self, operands, [], UNARY)
        # Assigned after the base initializer so it wins over anything
        # Operator.__init__ may have set.
        self.size = arg.get_size()


# Patterns for rewriting products with inverses into trsm-style operations
A = PatternDot("A")
B = PatternDot("B")
C = PatternDot("C")

# [TODO] Complete the set of patterns
# NOTE(review): an earlier version of the rule below was guarded by
# Constraint("A.st_info[0] == ST_LOWER"); the active rule is unconstrained
# and always rewrites using tril(A) -- confirm this is intended.
trsm_patterns = [
    # C = B * inv(A)^T  ->  C = mrdiv(tril(A)^T, B)
    RewriteRule(
        Equal([C, Times([B, Transpose([Inverse([A])])])]),
        Replacement(
            lambda d: Equal([
                d["C"],
                mrdiv([Transpose([tril(d["A"])]), d["B"]]),
            ])
        ),
    ),
]


#
# Produces Matlab code for:
# - Recursive code (the only variant currently emitted)
#
def generate_matlab_code(operation, matlab_dir):
    """Write the recursive Matlab implementation of ``operation``.

    The file ``<matlab_dir>/<operation.name>.m`` is created (or truncated)
    and filled by ``generate_matlab_recursive_code`` using
    ``operation.pmes[-1]``.

    operation  -- object exposing ``name`` and a non-empty ``pmes`` sequence.
    matlab_dir -- directory that receives the ``.m`` file; ``open`` does not
                  create it, so it must already exist.
    """
    # [FIX] At some point this should be named opname_rec.m
    out_path = os.path.join(matlab_dir, operation.name + ".m")
    # Explicit encoding so the generated file does not depend on the
    # platform's locale default.
    with open(out_path, "w", encoding="utf-8") as out:
        # pmes[-1] should be the 2x2 one
        generate_matlab_recursive_code(operation, operation.pmes[-1], out)
Exemplo n.º 4
0
        if symm_ops:
            print( "Propagating st info from:", symm_ops )
            out.st_info = (symm_ops[0].st_info[0], out )
        else:
            pass
            # [FIXME] Default?


# [TODO] Complete
instruction_set = [
    # Basic Ops - Completeness of the instruction set
    #
    [
        # Inverse
        Instruction(
            ( Inverse([ A ]), Constraint("isOperand(A)") ),
            lambda d, t: \
                    RewriteRule(
                        Inverse([ d["A"] ]),
                        Replacement( "%s" % t )
                    ),
            lambda d: Inverse([ d["A"] ])
        ),
        # Transpose
        Instruction(
            ( Transpose([ A ]), Constraint("isOperand(A)") ),
            lambda d, t: \
                    RewriteRule(
                        Transpose([ d["A"] ]),
                        Replacement( "%s" % t )
                    ),
Exemplo n.º 5
0
pm.DB["rdiv_utn"].overwrite = []
# Register the remaining rdiv_* predicate metadata in the original order:
# the "_ow" variants get overwrite = [(1, 0)], the others an empty list.
for _pred_name, _pred_overwrite in (
        ("rdiv_utn_ow", [(1, 0)]),
        ("rdiv_utu", []),
        ("rdiv_utu_ow", [(1, 0)]),
):
    pm.DB[_pred_name] = pm.PredicateMetadata(_pred_name, tuple())
    pm.DB[_pred_name].overwrite = _pred_overwrite
# Keep the module namespace free of the loop temporaries.
del _pred_name, _pred_overwrite

A = PatternDot("A")
B = PatternDot("B")
X = PatternDot("X")

trsm2lgen_rules = [
    # X = i(t(A)) B -> ldiv_lni
    RewriteRule((
        Equal([NList([X]), Times([Inverse([A]), B])]),
        Constraint(
            "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name == X.name"
        )),
                Replacement(lambda d: Equal([
                    NList([d["X"]]),
                    Predicate("ldiv_lni", [d["A"], d["B"]],
                              [d["A"].get_size(), d["B"].get_size()])
                ]))),
    # X = i(t(A)) B -> ldiv_lni_ow
    RewriteRule(
        (Equal([NList([X]), Times([Inverse([A]), B])]),
         Constraint(
             "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name != X.name"
         )),
        Replacement(lambda d: Equal([
Exemplo n.º 6
0
def flatten_blocked_operation(expr):
    """Push an expression over blocked operands down to the individual blocks.

    - ``BlockedExpression``: returned unchanged.
    - ``Equal`` / ``Plus`` / ``Minus`` / ``Transpose``: children are
      flattened recursively and the operator is threaded over the resulting
      block grids with ``map_thread``; ``Transpose`` additionally calls
      ``transpose()`` on the new blocked expression.
    - ``Times``: the leading scalar factors are collected, the remaining
      factors are multiplied block-wise with ``multiply_blocked_expressions``,
      and each resulting cell becomes ``Times(scalars + [cell])``.
    - ``Inverse``: grids with one row are threaded element-wise; 2x2 grids
      use the block-triangular inverse formula.

    Any other expression type, or an ``Inverse`` of a grid with another
    number of rows, falls through and returns ``None``.
    NOTE(review): callers presumably never hit that path -- consider raising.
    """
    if isinstance(expr, BlockedExpression):
        return expr
    if isinstance(expr, Equal):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        return BlockedExpression(map_thread(Equal, flat_ch, 2),
                                 flat_ch[0].size, flat_ch[0].shape)
    if isinstance(expr, Plus):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        return BlockedExpression(map_thread(Plus, flat_ch, 2), flat_ch[0].size,
                                 flat_ch[0].shape)
    # [TODO] will ignore the inner scalar expressions for now
    # Also the plain scalars: I guess it will suffice to check size of blocked == (1,1)
    if isinstance(expr, Times):
        # Count the leading factors for which isScalar() holds.  The scan is
        # bounded: previously a product made only of scalars ran off the end
        # of the child list (IndexError) and never reached the all-scalar
        # branch below.
        n_children = len(expr.children)
        non_scalar_idx = 0
        while (non_scalar_idx < n_children
               and expr.children[non_scalar_idx].isScalar()):
            non_scalar_idx += 1
        # Each scalar factor contributes its single cell [0][0]; presumably
        # scalars are 1x1 blocked expressions here -- TODO confirm.
        scalars = [
            s.children[0][0] for s in expr.get_children()[:non_scalar_idx]
        ]
        non_scalars = expr.get_children()[non_scalar_idx:]
        if non_scalars:
            # The first factor is deep-copied, presumably because
            # prod.children is reassigned below and flattening a
            # BlockedExpression returns the very same object.
            prod = flatten_blocked_operation(copy.deepcopy(non_scalars[0]))
            for ch in non_scalars[1:]:
                prod = multiply_blocked_expressions(
                    prod, flatten_blocked_operation(ch))
            prod.children = [[Times(scalars + [cell]) for cell in row]
                             for row in prod.children]
            return prod
        # NOTE(review): unlike the other branches this returns a plain nested
        # list rather than a BlockedExpression -- confirm callers expect it.
        return [[Times(scalars)]]
    if isinstance(expr, Minus):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        return BlockedExpression(map_thread(Minus, flat_ch, 2),
                                 flat_ch[0].size, flat_ch[0].shape)
    if isinstance(expr, Transpose):
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        new = BlockedExpression(map_thread(Transpose, flat_ch, 2),
                                flat_ch[0].size, flat_ch[0].shape)
        # Each cell is transposed above; presumably this also swaps the
        # block layout (cell (i, j) -> (j, i)).
        new.transpose()
        return new
    if isinstance(expr, Inverse):  # [TODO] Only triangular
        flat_ch = [flatten_blocked_operation(ch) for ch in expr.get_children()]
        if len(flat_ch[0].get_children()) == 1:  # 1x1 inverse
            return BlockedExpression(map_thread(Inverse, flat_ch, 2),
                                     flat_ch[0].size, flat_ch[0].shape)
        if len(flat_ch[0].get_children()) == 2:  # 2x2 inverse
            children = flat_ch[0].get_children()
            TL = children[0][0]
            TR = children[0][1]
            BL = children[1][0]
            BR = children[1][1]
            # Block inverse of a triangular matrix: one of TR / BL is zero,
            # so only one of the off-diagonal terms is non-trivial.
            #   [[inv(TL),              -inv(TL) TR inv(BR)],
            #    [-inv(BR) BL inv(TL),   inv(BR)           ]]
            return BlockedExpression(
                [[
                    Inverse([TL]),
                    Times([Minus([Inverse([TL])]), TR,
                           Inverse([BR])])
                ],
                 [
                     Times([Minus([Inverse([BR])]), BL,
                            Inverse([TL])]),
                     Inverse([BR])
                 ]], flat_ch[0].size, flat_ch[0].shape)
Exemplo n.º 7
0
        Replacement(lambda d: Plus([Transpose([term]) for term in d["PS1"]])),
    ),
    # Transpose( A * B * C ) -> C^T + B^T + A^T
    RewriteRule(
        Transpose([Times([PS1])]),
        Replacement(lambda d: Times(
            [Transpose([term]) for term in reversed(list(d["PS1"]))])),
    ),
    # Transpose( -A ) -> -(A^T)
    RewriteRule(
        Transpose([Minus([PD1])]),
        Replacement(lambda d: Minus([Transpose([d["PD1"]])])),
    ),
    # Inverse( -A ) -> -(Inverse(A))
    RewriteRule(
        Inverse([Minus([PD1])]),
        Replacement(lambda d: Minus([Inverse([d["PD1"]])])),
    ),
    # Inverse( A^T ) -> (Inverse(A))^T
    RewriteRule(
        Inverse([Transpose([PD1])]),
        Replacement(lambda d: Transpose([Inverse([d["PD1"]])])),
    ),
    # Inverse( A * B * C ) -> inv(C) + inv(B) + inv(A)
    RewriteRule(
        Inverse([Times([PS1])]),
        Replacement(lambda d: Times(
            [Inverse([term]) for term in reversed(list(d["PS1"]))])),
    )
]