Ejemplo n.º 1
0
 def to_canonical_IO(self):
     """Rewrite ``self.equation`` into canonical I/O form and prune operands.

     Steps, in order:

     1. Apply ``canonical_rules + canonicalIO_rules + simplify_rules`` to
        every equation via ``replace_all``.
     2. For each resulting equation keep whichever of three forms has the
        fewest nodes (``num_nodes()``): the equation as is, both sides
        wrapped in ``Minus``, or both sides wrapped in ``Transpose`` (the
        last two are passed through ``to_canonical`` and ``simplify``).
     3. Store the chosen forms as an ``NList`` in ``self.equation``.
     4. Drop from ``self.operands`` every operand that no longer appears
        in any equation.
     """
     # The rule set does not depend on the equation, so build it once.
     io_rules = canonical_rules + canonicalIO_rules + simplify_rules
     self.equation = [replace_all(eq, io_rules) for eq in self.equation]
     # Minimize number of nodes among:
     #   1) as is,
     #   2) applying minus to both sides, and
     #   3) applying transpose to both sides
     minimal = []
     for eq in self.equation:
         alternative_forms = [
             eq,
             simplify(
                 to_canonical(
                     Equal([Minus([eq.lhs()]),
                            Minus([eq.rhs()])]))),
             simplify(
                 to_canonical(
                     Equal([Transpose([eq.lhs()]),
                            Transpose([eq.rhs()])]))),
         ]
         # Select by node count only.  Comparing (count, expr) tuples would
         # fall back to ordering the expressions themselves on a tie, which
         # raises TypeError when they define no ordering.  With key=, ties
         # resolve to the earliest (least transformed) alternative.
         minimal.append(
             min(alternative_forms, key=lambda alt: alt.num_nodes()))
     # Set minimal forms
     self.equation = NList(minimal)
     # Keep only the operands still referenced by some equation.  A list is
     # used (not a set) so membership relies on == exactly as before.
     nodes = list(
         itertools.chain.from_iterable(
             eq.iterate_preorder() for eq in self.equation))
     self.operands = [op for op in self.operands if op in nodes]
Ejemplo n.º 2
0
def tile_expr( _expr ):
    """Enumerate the tilings of the equation ``_expr``.

    ``_expr`` is expected to have exactly two children (lhs, rhs).  The
    result is a list of tilings; each tiling is a list of statements.

    * rhs is a ``Predicate``: every non-operand argument is tiled with
      ``all_tilings`` and the cartesian product of the per-argument tilings
      is taken.  For each combination, the predicate is deep-copied, its
      arguments are replaced by the last entry of each argument tiling, and
      the remaining entries are emitted before the final
      ``Equal([lhs, predicate])`` statement.
    * rhs is an operand: the single tiling ``[_expr]`` (the original
      object, not the copy) is returned.
    * otherwise: for each tiling of rhs, the trailing entry is dropped and
      the one before it has its child 0 replaced by the actual lhs.
    """
    expr = copy.deepcopy( _expr )
    lhs, rhs = expr.get_children()
    tiles = []

    if isinstance( rhs, Predicate ):
        # One list of candidate tilings per predicate argument; operands
        # are their own (single, one-element) tiling.
        per_argument = [
            [[arg]] if isOperand( arg ) else list( all_tilings( arg ) )
            for arg in rhs.get_children()
        ]
        for combination in itertools.product( *per_argument ):
            predicate = copy.deepcopy( rhs )
            replacements = [ tiling[-1] for tiling in combination ]
            for position, child in enumerate( replacements ):
                predicate.set_children( position, child )
            # Everything but the last entry of each argument tiling has to
            # be computed before the predicate itself.
            updates = []
            for tiling in combination:
                updates.extend( tiling[:-1] )
            tiles.append( updates + [Equal([lhs, predicate])] )
        return tiles

    if isOperand( rhs ):
        tiles.append( [_expr] )
        return tiles

    for tiling in all_tilings( rhs ):
        # The trailing entry is just a (temporary) symbol: the output of
        # the statement right before it.
        tiling.pop()
        final_statement = tiling.pop()
        lhs, rhs = expr.get_children()
        # Make that statement assign to the actual lhs of the equation.
        final_statement.set_children( 0, lhs )
        tiling.append( final_statement )
        tiles.append( tiling )
    return tiles
Ejemplo n.º 3
0
 def algorithm_initialization(self, init_state):
     """Build the initialization statements for ``init_state``.

     Each expression in ``init_state`` is split into (lhs, rhs).  For every
     child of lhs that is not zero (per ``isZero``), one statement
     ``Equal([NList([child]), rhs])`` is produced.  The rhs is not checked
     for zero.

     Returns the statements as a flat list, in iteration order.
     """
     statements = []
     for expr in init_state:
         targets, value = expr.get_children()
         for target in targets.get_children():
             if isZero(target):
                 continue
             statements.append(Equal([NList([target]), value]))
     return statements
Ejemplo n.º 4
0
 def tile( self, tree, node, match_dict ):
     """Replace the matched part of ``tree`` by a fresh temporary.

     A new symbol ``T<n>`` (``n`` from ``TOS.new_temp()``) is created and
     flagged ``properties.TEMPORARY``.

     Returns a pair ``(assignment, new_tree)``:

     * ``assignment`` is ``Equal([NList([temp]), <tile>])`` where the tile
       is a deep copy of ``node`` for predicates, or the result of
       ``self.create_tile(match_dict)`` otherwise.
     * ``new_tree`` is ``tree`` after ``replace`` with the rule that
       substitutes the temporary for the tiled subexpression.
     """
     temp = Symbol("T" + str(TOS.new_temp()))
     temp.set_property( properties.TEMPORARY )

     if not isinstance( node, Predicate ):
         # Transfer properties and storage info from the tile onto the
         # temporary that will hold its value.
         tile_body = self.create_tile( match_dict )
         propagate_properties( tile_body, temp )
         propagate_st_info( tile_body, temp )
         ## [FIXME] Quick and dirty test (disabled): flag temp as upper
         ## triangular when the tile is.
         # The returned statement uses a freshly created tile.
         assignment = Equal([ NList([ temp ]), self.create_tile( match_dict ) ])
         rewritten = replace( tree, [self.create_rewrite_rule( match_dict, temp )] )
         return assignment, rewritten

     if node == tree:
         print( "[Warning] If predicate has multiple outputs, may break if not careful from caller" )
     assignment = Equal([ NList([ temp ]), copy.deepcopy( node ) ])
     substitute_temp = RewriteRule( copy.deepcopy( node ), Replacement(temp) )
     return assignment, replace( tree, [substitute_temp] )
Ejemplo n.º 5
0
def _checkPredicateOverwrite(statement):
    """Expand ``statement`` with the copies implied by operand overwriting.

    Returns a list of statements that always contains ``statement`` itself
    (possibly mutated in place) and may contain additional ``Equal`` copy
    statements before or after it.

    Non-predicate right-hand side:
        If the rhs contains exactly one temporary symbol whose ``size``
        equals that of the first lhs operand, the first lhs operand is not
        itself a temporary, and no rhs symbol shares ``st_info[1]`` with
        it, then the temporary is first copied into the lhs
        (``Equal([NList(lhs.children), tmp])``) and the rhs of
        ``statement`` is rewritten to use the lhs operand in place of the
        temporary.  Otherwise ``[statement]`` is returned unchanged.

    Predicate right-hand side:
        ``pm.DB[rhs.name].overwrite`` is read as ``(inp, out)`` index pairs
        into ``rhs.children`` and ``lhs.children``.  An empty list returns
        ``[statement]``.  Otherwise, for the first pair of each distinct
        ``inp``, a copy statement is emitted when the input and output
        operands differ and do not already share ``st_info[1]``.

    NOTE(review): ``statement`` may be appended once per processed pair, so
    the result can contain it more than once (see the FIXME notes below).
    """
    lhs = statement.lhs()
    rhs = statement.rhs()
    if not isinstance(rhs, Predicate):
        rhs_ops = [
            node for node in rhs.iterate_preorder()
            if isinstance(node, Symbol)
        ]
        tmp_ops = [op for op in rhs_ops if op.isTemporary()]
        # [FIXME] Quick and dirty to play with temporaries
        if len(tmp_ops) == 1:
            tmp = tmp_ops[0]
            if lhs.children[0].size == tmp.size:
                if not lhs.children[0].isTemporary():
                    # Does any rhs operand share st_info[1] with the
                    # output?  Operands without st_info are ignored.
                    overwrites = False
                    for op in rhs_ops:
                        try:
                            overwrites = (lhs.children[0].st_info[1] ==
                                          op.st_info[1])
                        except AttributeError:
                            pass
                        if overwrites:
                            break
                    if not overwrites:
                        statements = []
                        # Copy the temporary into the output first ...
                        statements.append(Equal([NList(lhs.children), tmp]))
                        # ... then compute on the output instead of on
                        # the temporary.
                        statement.children[1] = replace(
                            copy.deepcopy(rhs),
                            [RewriteRule(tmp, Replacement(lhs.children[0]))])
                        statements.append(statement)
                        return statements
        else:
            # TRSM 2x2 ...
            pass
        return [statement]
    if not pm.DB[rhs.name].overwrite:  # []
        return [statement]

    statements = []
    # [FIXME] Assumes one single operands get overwritten. Will break in the future
    already_copied = []
    for inp, out in pm.DB[rhs.name].overwrite:
        # Handle each overwritten input only once.
        if inp in already_copied:
            continue
        already_copied.append(inp)
        #
        if rhs.children[inp] != lhs.children[out]:
            # [FIXME] All should have st_info
            try:
                overwrites = (lhs.children[out].st_info[1] ==
                              rhs.children[inp].st_info[1])
            except AttributeError:
                overwrites = False
            if overwrites:
                statements.append(statement)  #[FIXME] Gosh...
                continue
            inpop = rhs.children[inp]
            outop = lhs.children[out]
            if inpop.isTemporary() or inpop.isInput():
                # if multiple outputs overwrite input (e.g., LU)
                if len([o for i, o in pm.DB[rhs.name].overwrite
                        if i == inp]) > 1:
                    # Best-effort lookup: keep the current outop when it
                    # fails.  Only Exception is caught so that
                    # KeyboardInterrupt / SystemExit still propagate.
                    try:
                        # LU  (ABR = T3; [LBR,UBR] = LU(ABR))
                        outop = TOS._TOS[outop.st_info[1].name][0]
                    except Exception:
                        pass
                    outop.st_info = (None, outop)
                #
                # Copy the input into the output, then run the predicate
                # on the output operand.
                statements.append(Equal([NList([outop]), inpop]))
                rhs.children[inp] = outop
                statements.append(statement)
            else:
                # Let the predicate write into its input operand, then
                # copy the result into the intended output.
                lhs.children[out] = rhs.children[inp]
                statements.append(statement)
                statements.append(Equal([NList([inpop]), outop]))
        else:
            statements.append(statement)
    return statements
Ejemplo n.º 6
0
    def learn_pattern(self):
        """Register rewrite rules that recognise this operation.

        The operation's equations are turned into a pattern by replacing
        each operand with a ``PatternDot``, and a constraint string is
        built from the operands' properties.  The learnt rule rewrites a
        match into ``Equal([NList([outputs]), Predicate(name, [inputs],
        [output sizes])])``.

        * Single assignment (one equation whose lhs is a ``Symbol``): two
          rules are inserted at the front of ``known_ops_single`` -- the
          pattern itself and its negated variant (predicate wrapped in
          ``Minus``).
        * Otherwise: one rule is inserted at the front of ``known_ops`` and
          pickled to ``OUTPUT/<name>_patterns``.

        In both cases the inverse rule (predicate -> implicit equations) is
        appended to ``op_to_implicit``.  The learnt pattern is printed.
        """
        inops = [op for op in self.operands if op.isInput()]
        outops = [op for op in self.operands if op.isOutput()]
        #
        # eq.lhs.single_entry_in_NL
        single_assignment = len( self.equation.children ) == 1 and \
                            isinstance(self.equation.children[0].children[0], Symbol)
        #
        op_to_pattern = []
        for op in self.operands:
            op_to_pattern.append(
                RewriteRule(op, Replacement("PatternDot(%s.name)" % op.name)))
        pattern = NList([
            replace_all(copy.deepcopy(eq), op_to_pattern)
            for eq in self.equation
        ])
        if single_assignment:
            # Empty constraint strings are dropped here only.
            props_str = [
                symbol_props_to_constraints_no_io(op) for op in self.operands
            ]
            constraint_text = " and ".join(
                [prop for prop in props_str if prop])
        else:
            constraint_text = " and ".join(
                [symbol_props_to_constraints(op) for op in self.operands])
        constraint = Constraint(constraint_text)
        # [TODO] Tuple for get_size
        # Fragments shared by the plain and the negated replacement.
        out_names = ", ".join([op.name for op in outops])
        in_names = ", ".join([op.name for op in inops])
        out_sizes = ", ".join(
            ["%s.get_size()" % op.get_name() for op in outops])
        replacement = Replacement(
            "Equal([ NList([%s]), Predicate( \"%s\", [%s], [%s] ) ])" %
            (out_names, self.name, in_names, out_sizes))
        # [TODO] This should be part of the verbose option
        print("* Learnt pattern")
        print("*   ", pattern, end="")
        if constraint.to_eval:
            print("with        ", constraint.to_eval)
        print(" --> ")
        print("*          ", replacement.to_eval)
        print("**********************************")
        # [TODO] Maybe sort known ops by specificity (a la mathematica)

        if not single_assignment:
            known_ops.insert(0, RewriteRule((pattern, constraint),
                                            replacement))
            patterns_path = os.path.join("OUTPUT", self.name + "_patterns")
            with open(patterns_path, "wb") as patt_f:
                pickle.dump(known_ops[0], patt_f)
        else:
            # Match against the single equation, with its lhs wrapped in
            # an NList.
            expr = pattern.children[0]
            expr.children[0] = NList([expr.children[0]])
            known_ops_single.insert(
                0, RewriteRule((expr, constraint), replacement))
            # With minus
            minus_replacement = Replacement(
                "Equal([ NList([%s]), Minus([ Predicate( \"%s\", [%s], [%s] ) ]) ])"
                % (out_names, self.name, in_names, out_sizes))
            negated = copy.deepcopy(expr)
            negated.children[1] = Minus([negated.children[1]])
            negated.children[1] = normalize_minus(
                copy.deepcopy(negated.children[1]))
            known_ops_single.insert(
                0, RewriteRule((negated, constraint), minus_replacement))

        # Inverse direction: from the predicate back to the equations.
        implicit_pattern = Equal([
            NList([PatternDot(op.get_name()) for op in outops]),
            Predicate(self.name, [PatternDot(op.get_name()) for op in inops],
                      [op.get_size() for op in outops])
        ])
        implicit_replacement = Replacement(equation2replacement(self.equation))
        op_to_implicit.append(
            RewriteRule(implicit_pattern, implicit_replacement))
Ejemplo n.º 7
0
from BindDimensions import bindDimensions
from graph import build_dep_graph, subgraphs, zero_out_lower
from Partitioning import partition_shape, repartition, repartition_shape, repart_group
from Tiling import tile_expr
from utils import sort
from RewritingExtension import *

from BackEnd.MatlabCode import generate_matlab_code, click2matlab
#from BackEnd.LatexCode import generate_latex_code
from BackEnd.LGenCode import generate_lgen_files
#from BackEnd.CCode import generate_c_code

# Module-level registries of rewrite rules.
#
# known_ops starts with a single built-in rule; the other registries start
# empty and are presumably filled at runtime as operations are learnt --
# TODO confirm against the code that mutates them.
known_ops = [
    # Straight assignment: an equation LHS = RHS where LHS is an output
    # Symbol and RHS is an input is rewritten to Equal([NList([LHS]), RHS]).
    RewriteRule(
        (NList([Equal([PatternDot("LHS"), PatternDot("RHS")])]),
         Constraint(
             "isinstance(LHS, Symbol) and isOutput(LHS) and isInput(RHS)")),
        Replacement("Equal([ NList([LHS]), RHS ])"))
]
# Rules for single-assignment operations.
known_ops_single = []
known_pmes = []
# Rules for rewriting a predicate back into its implicit equations.
op_to_implicit = []


# Add a verbose option
# For now it would help to illustrate the process
class Operation(object):
    def __init__(self, name, operands, equation, overwrite):
        self.name = name
        self.operands = operands
Ejemplo n.º 8
0
canonicalIO_rules = [
    ## Equal[Minus, Minus] -> remove minuses
    ## add support for head_[_]
    #RewriteRule(
    #(
    #Equal([ Minus([LHS]), Minus([RHS]) ]),
    #Constraint("True")
    #),
    #Replacement("Equal([ LHS, RHS ])")
    #),
    # Input to the right
    # Plus
    RewriteRule(
        (
            Equal([ Plus([PS1, subexpr, PS2]), RHS ]),
            Constraint(lambda d: isInput(d["subexpr"]))
        ),
        Replacement(lambda d: Equal([ Plus([ d["PS1"], d["PS2"] ]), \
                                      Plus([ d["RHS"], Minus([d["subexpr"]])]) ]))
    ),
    # Minus
    RewriteRule(
        (
            Equal([ Minus([subexpr]), RHS ]),
            Constraint(lambda d: isInput(d["subexpr"]))
        ),
        Replacement(lambda d: Equal([ Zero(d["subexpr"].get_size()), \
                                      Plus([d["RHS"], d["subexpr"]]) ]))
    ),
    # Transpose
Ejemplo n.º 9
0
 def equation(self, ast):
     """Build an ``Equal`` node from the parsed ``lhs`` and ``rhs``.

     ``ast`` is indexed with the keys ``'lhs'`` and ``'rhs'``.

     Returns ``Equal([lhs, rhs])``.

     Raises ``SizeError`` when the two sides report different sizes; the
     message includes both sizes.
     """
     lhs = ast['lhs']
     rhs = ast['rhs']
     lhs_size = lhs.get_size()
     rhs_size = rhs.get_size()
     if lhs_size != rhs_size:
         # The message needs one placeholder per formatted value;
         # %-formatting a placeholder-free string with a tuple raises
         # TypeError before SizeError can be constructed.
         raise SizeError(
             'Equation\'s lhs and rhs have different sizes: %s vs %s'
             % (lhs_size, rhs_size))
     return Equal([lhs, rhs])
Ejemplo n.º 10
0

class triu(Operator):
    """Unary operator node ``triu`` wrapping a single argument.

    The node's ``size`` is taken from the argument's ``get_size()``.
    """

    def __init__(self, arg):
        children = [arg]
        Operator.__init__(self, children, [], UNARY)
        # Set after the base initializer so this value is the final one.
        self.size = arg.get_size()


# Patterns for inv to trsm
# Pattern variables shared by the rules below.
A = PatternDot("A")
B = PatternDot("B")
C = PatternDot("C")
# [TODO] Complete the set of patterns
trsm_patterns = [
    # Equal([C, Times([B, Transpose([Inverse([A])])])])
    #   --> Equal([C, mrdiv([Transpose([tril(A)]), B])])
    # A constrained variant of the pattern was left disabled:
    #   (<pattern>, Constraint("A.st_info[0] == ST_LOWER"))
    RewriteRule(
        Equal([
            C,
            Times([B, Transpose([Inverse([A])])]),
        ]),
        Replacement(
            lambda d: Equal([
                d["C"],
                mrdiv([Transpose([tril(d["A"])]), d["B"]]),
            ])
        ),
    ),
]


#
# Produces Matlab code for:
# - Loop-based code
#
def generate_matlab_code(operation, matlab_dir):
    """Write Matlab code for ``operation`` into ``matlab_dir``.

    The visible part opens ``<matlab_dir>/<operation.name>.m`` for writing
    and passes it, together with the last entry of ``operation.pmes``, to
    ``generate_matlab_recursive_code``.

    NOTE(review): the comment above this function mentions loop-based
    code, which is not generated in the lines shown here -- confirm
    whether the function continues.
    """
    # Recursive code
    out_path = os.path.join(matlab_dir, operation.name +
                            ".m")  # [FIX] At some point this should be opname_rec.m
    # Opening in "w" mode truncates any existing file of the same name.
    with open(out_path, "w") as out:
        generate_matlab_recursive_code(operation, operation.pmes[-1],
                                       out)  # pmes[-1] should be the 2x2 one
Ejemplo n.º 11
0
# Predicate metadata registration.  Each "<name>_ow" entry mirrors the
# plain "<name>" entry but declares a non-empty ``overwrite`` list.
# NOTE(review): the pairs are presumably (input index, output index),
# meaning that input operand is overwritten by that output -- confirm
# against the consumer of ``overwrite``.
pm.DB["rdiv_utn"].overwrite = []
pm.DB["rdiv_utn_ow"] = pm.PredicateMetadata("rdiv_utn_ow", tuple())
pm.DB["rdiv_utn_ow"].overwrite = [(1, 0)]
pm.DB["rdiv_utu"] = pm.PredicateMetadata("rdiv_utu", tuple())
pm.DB["rdiv_utu"].overwrite = []
pm.DB["rdiv_utu_ow"] = pm.PredicateMetadata("rdiv_utu_ow", tuple())
pm.DB["rdiv_utu_ow"].overwrite = [(1, 0)]

# Pattern variables used by the rewrite rules that follow.
A = PatternDot("A")
B = PatternDot("B")
X = PatternDot("X")

trsm2lgen_rules = [
    # X = i(t(A)) B -> ldiv_lni
    RewriteRule((
        Equal([NList([X]), Times([Inverse([A]), B])]),
        Constraint(
            "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name == X.name"
        )),
                Replacement(lambda d: Equal([
                    NList([d["X"]]),
                    Predicate("ldiv_lni", [d["A"], d["B"]],
                              [d["A"].get_size(), d["B"].get_size()])
                ]))),
    # X = i(t(A)) B -> ldiv_lni_ow
    RewriteRule(
        (Equal([NList([X]), Times([Inverse([A]), B])]),
         Constraint(
             "A.isLowerTriangular() and A.isImplicitUnitDiagonal() and X.st_info[1].name != X.name"
         )),
        Replacement(lambda d: Equal([