예제 #1
0
 def to_canonical_IO(self):
     """Bring ``self.equation`` to canonical IO form and prune ``self.operands``.

     Steps, in order:
       1. Every equation is rewritten with the canonical, canonical-IO and
          simplification rules.
       2. For each equation, three equivalent forms are compared (as is,
          both sides negated, both sides transposed) and the one with the
          fewest nodes is kept.
       3. ``self.equation`` is replaced by an ``NList`` of the chosen forms.
       4. ``self.operands`` keeps only the operands that still occur as a
          node of some equation.
     """
     rules = canonical_rules + canonicalIO_rules + simplify_rules
     self.equation = [replace_all(eq, rules) for eq in self.equation]
     # Minimize number of nodes among:
     #   1) as is,
     #   2) applying minus to both sides, and
     #   3) applying transpose to both sides
     minimal = []
     for eq in self.equation:
         alternative_forms = [
             eq,
             simplify(
                 to_canonical(Equal([Minus([eq.lhs()]),
                                     Minus([eq.rhs()])]))),
             simplify(
                 to_canonical(
                     Equal([Transpose([eq.lhs()]),
                            Transpose([eq.rhs()])]))),
         ]
         # Compare on the node count only. Comparing (count, expression)
         # tuples would fall back to ordering the expressions themselves
         # whenever two forms tie, which is not guaranteed to be supported.
         # With ``key=`` the first minimal form wins, so the unmodified
         # equation is preferred on ties.
         minimal.append(min(alternative_forms, key=lambda alt: alt.num_nodes()))
     # Set minimal forms
     self.equation = NList(minimal)
     # Drop operands that no longer appear anywhere in the equations.
     nodes = list(
         itertools.chain.from_iterable(
             eq.iterate_preorder() for eq in self.equation))
     self.operands = [op for op in self.operands if op in nodes]
예제 #2
0
def normalize_minus( expr ):
    """Rewrite *expr* so that minus signs sit on the leading factor of products.

    Three rewrite rules are handed to ``replace_all``, in this order:

    * ``Times([ld, l, Minus([md]), r])``  ->  ``Times([Minus([ld]), l, md, r])``
    * ``Times([Minus([Minus([ld])]), l])`` ->  ``Times([ld, l])``
    * ``Minus([Times([ld, l])])``          ->  ``Times([Minus([ld]), l])``

    ``ld``/``md`` are ``PatternDot`` placeholders and ``l``/``r`` are
    ``PatternStar`` placeholders (presumably single-node and sequence
    wildcards respectively -- confirm against the pattern-matching module).

    Returns whatever ``replace_all`` returns for *expr* and these rules.
    """
    first = PatternDot("ld")
    negated = PatternDot("md")
    left_rest = PatternStar("l")
    right_rest = PatternStar("r")

    # A negated factor somewhere after the first one: move the minus up front.
    hoist_inner_minus = RewriteRule(
        Times([ first, left_rest, Minus([negated]), right_rest ]),
        Replacement( lambda d: Times([ Minus([d["ld"]]), d["l"], d["md"], d["r"]]) ))

    # A doubly negated leading factor: drop both minus signs.
    cancel_double_minus = RewriteRule(
        Times([ Minus([Minus([first])]), left_rest ]),
        Replacement( lambda d: Times([ d["ld"], d["l"]]) ))

    # A negated product: push the minus onto its leading factor.
    push_outer_minus = RewriteRule(
        Minus([ Times([ first, left_rest ]) ]),
        Replacement( lambda d: Times([ Minus([d["ld"]]), d["l"]]) ))

    minus_rules = [hoist_inner_minus, cancel_double_minus, push_outer_minus]
    return replace_all( expr, minus_rules )
예제 #3
0
 def express_in_terms_of_input(self, assignments):
     """Return copies of *assignments* whose rhs has been expanded and simplified.

     Each assignment is expected to expose ``children`` with the lhs at
     index 0 (itself having ``children``) and the rhs at index 1.

     The expansion rules are obtained from
     ``self.expr_to_rule_lhs_rhs(assignments)``; each rhs is rewritten with
     them and then passed through ``to_canonical`` and ``simplify``.
     The assignments passed in are not modified: the rewriting is done on
     deep copies.

     Raises:
         Exception: if the first lhs entry of any assignment is an input
             operand (callers catch this to skip the invariant).
     """
     # In LU, there may be (overwritable) inputs as lhs
     for a in assignments:
         target = a.children[0].children[0]
         if target.isInput():
             raise Exception(
                 "input operand %s appears as lhs; "
                 "cannot express assignments in terms of input" % (target,))
     # Validation above guarantees no assignment has an input lhs, so no
     # filtering is needed; just work on a plain-list snapshot.
     assignments = list(assignments)
     expand_rules = list(
         itertools.chain.from_iterable(
             self.expr_to_rule_lhs_rhs(assignments)))
     new_ass = []
     for expr in assignments:
         new = copy.deepcopy(expr)
         new.children[1] = simplify(
             to_canonical(replace_all(new.children[1], expand_rules)))
         new_ass.append(new)
     return new_ass
예제 #4
0
파일: Tiling.py 프로젝트: shreyas42/slingen
def all_tilings( expr ):
    """Generate every tiling of *expr* found with ``instruction_set``.

    A partial algorithm is a list whose last element is the expression that
    still has to be tiled; the elements before it are the tiles produced so
    far. Partial algorithms are kept on a stack (``ongoing``), so the most
    recently created one is explored first.

    Yields:
        A list (tiles followed by the final operand) each time the pending
        expression of a partial algorithm turns out to be an operand.
    """
    ongoing = [ [expr] ]
    while len( ongoing ) > 0:
        alg = ongoing.pop()
        # The last entry is the expression still pending; the rest are tiles.
        to_tile = alg.pop()
        if isOperand( to_tile ):
            # Nothing left to tile: this algorithm is complete.
            alg.append( to_tile )
            yield alg
            continue
        # Work on a regrouped copy so the popped expression is left untouched.
        to_tile = replace_all( copy.deepcopy( to_tile), grouping_rules )
        # Collections are tried from last to first; the search stops at the
        # first collection that produced at least one match (see the break
        # at the bottom of this loop).
        for collection in reversed(instruction_set):
            matched_in_this_level = False ### These are just control vars
                                          ### to avoid redundancies
            for instr in collection:
                for node in to_tile.iterate_preorder():
                    # To avoid redundancies: right-hand sides already
                    # produced for this (instruction, node) pair.
                    matched_this_node = [] ###
                    if isinstance( instr, tuple ): # A way to deactivate some for quick development/debugging?
                        continue
                    for _m in instr.match( node ):
                        matched_in_this_level = True ###
                        #_T = TOS.new_temp()
                        # Each match is applied to its own copy of the expression.
                        new = copy.deepcopy( to_tile )
                        #tile, new = instr.tile( new, _m, _T )
                        tile, new = instr.tile( new, node, _m )
                        lhs, rhs = tile.get_children()
                        if rhs not in matched_this_node: ###
                            matched_this_node.append( rhs ) ###
                            # New partial algorithm: tiles so far, the new
                            # tile, and the expression left after tiling.
                            ongoing.append( alg[:] + [tile, new] )
                            # Set size of new temporary
                            lhs.children[0].size = rhs.get_size()
                        else:
                            # Aesthetic. Simply to avoid missing T? values:
                            # the duplicate tile is discarded and its
                            # temporary is handed back to TOS.
                            TOS.push_back_temp( )
                            TOS._TOS.unset_operand( tile.children[0].children[0] )
            if matched_in_this_level: ###
                break
예제 #5
0
    def find_updates_v2(self, before, after):
        """Derive the tiled updates that lead from ``before`` to ``after``.

        Args:
            before: iterable of two-child expressions (lhs, rhs) describing
                the state before the update. Modified in place: every lhs
                entry is wrapped in ``WrapOutBef``.
            after: iterable of two-child expressions (lhs, rhs) describing
                the state after the update. Modified in place: every lhs
                entry is wrapped in ``WrapOutAft`` and the rhs is rewritten.

        Returns:
            ``None`` when the invariant is skipped, i.e. when a lhs of
            ``before`` has no counterpart in ``after`` or when applying the
            rules derived from a ``before`` expression leaves the matching
            ``after`` expression unchanged. Otherwise, the list of tiled
            updates with the ``WrapOutBef``/``WrapOutAft`` wrappers removed.
        """
        # If a part is (partially) computed in the before and
        #   does not appear in the after or
        # going from before to after requires undoing some computation
        # it is potentially unstable, and more expensive: ignore
        dict_bef = {str(u.get_children()[0]): u for u in before}
        dict_aft = {str(u.get_children()[0]): u for u in after}
        ignore = False
        reason = None  # always set together with ``ignore``
        for k, v in dict_bef.items():
            if k not in dict_aft:
                ignore = True
                reason = "%s not in %s" % (k, dict_aft.keys())
                break
            else:
                rules = self.expr_to_rule_rhs_lhs([v])
                rules = list(itertools.chain(*rules))
                expr_copy = copy.deepcopy(dict_aft[k])
                t = replace(expr_copy, rules)
                #if v == replace( expr_copy, rules ):
                if dict_aft[k] == t:
                    ignore = True
                    reason = "%s would require undoing job" % k
                    break
        if ignore:
            print("[INFO] Skipping invariant: %s" % reason)
            return None
        #
        # Wrap outputs for before and after
        WrapBefOut = WrapOutBef
        for u in before:
            u.children[0] = NList([WrapBefOut(l) for l in u.children[0]])
        #
        for u in after:
            u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])
        # replace before in after
        wrap_rules_before = []
        for u in before:
            print(u)
            lhs, _ = u.get_children()
            # Expressions with several lhs entries contribute no rules.
            if len(lhs.children) > 1:
                continue
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_before.append(list(itertools.chain(*rules)))
        # Walking the rule groups from last to first, rewrite each rule's
        # pattern with all the groups that precede it (nearest first).
        for i, rule in enumerate(reversed(wrap_rules_before)):
            idx = len(wrap_rules_before) - i - 1
            for j in range(idx - 1, -1, -1):
                for _rule in rule:
                    _rule.pattern = replace_all(_rule.pattern,
                                                wrap_rules_before[j])
        wrap_rules_before = list(itertools.chain(*wrap_rules_before))
        #
        for u in after:
            _, rhs = u.get_children()
            u.children[1] = simplify(
                to_canonical(replace_all(copy.deepcopy(rhs),
                                         wrap_rules_before)))
        # replace after in after, repeating until nothing changes
        done = False
        while not done:
            wrap_rules_after = []
            for u in after:
                lhs, _ = u.get_children()
                if len(lhs.children) > 1:
                    # Keep the positions aligned with ``after``.
                    wrap_rules_after.append([])
                    continue
                rules = self.expr_to_rule_rhs_lhs([u])
                wrap_rules_after.append(list(itertools.chain(*rules)))
            #
            after_top = [copy.deepcopy(u) for u in after]
            for i, u in enumerate(after):
                _, rhs = u.get_children()
                # Use the rules of every other expression, never its own.
                rules = list(
                    itertools.chain.from_iterable(wrap_rules_after[:i] +
                                                  wrap_rules_after[i + 1:]))
                u.children[1] = simplify(
                    to_canonical(replace_all(copy.deepcopy(rhs), rules)))
            done = True
            for top, bot in zip(after_top, after):
                if top != bot:
                    done = False
                    break
        # [TODO] Multiple lhss, won't work
        updates = []
        for u in after:
            lhs, rhs = u.get_children()
            lhs = lhs.children[0]  # NList[op] -> op
            # Skip expressions whose "after" value is just the "before"
            # value of the same operand.
            if isinstance(rhs, WrapBefOut) and isinstance(lhs, WrapOutAft) and \
                    matchq(lhs.children[0], rhs.children[0]):
                continue
            updates.append(u)
        #
        tiled_updates = []
        for u in updates:
            print("*   ", u)
            tilings = list(tile_expr(u))
            if len(tilings) > 1:
                print("[WARNING] Multiple (%d) tilings for expression %s" %
                      (len(tilings), u))
                print("          Discarding all but one")
            tiled_updates.extend(tilings[0])
        tiled_updates = sort(tiled_updates)
        print("* Tiled update")
        for t in tiled_updates:
            print("*   ", t)

        # Drop WrapOutBef's
        # Drop WrapOutAft's
        s = PatternDot("s")
        updates = []
        for u in tiled_updates:
            u = replace_all(
                u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
            u = replace_all(
                u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
            updates.append(u)

        return updates
예제 #6
0
    def find_updates(self, before, after):
        """Derive the tiled updates that lead from ``before`` to ``after``.

        Args:
            before: iterable of two-child expressions (lhs, rhs) describing
                the state before the update. Modified in place: lhs entries
                and their occurrences in the rhs get wrapped in
                ``WrapOutBef``.
            after: iterable of two-child expressions (lhs, rhs) describing
                the state after the update. Modified in place: lhs entries
                are wrapped in ``WrapOutAft`` and the rhs is rewritten.

        Returns:
            ``None`` when the invariant is skipped: either
            ``express_in_terms_of_input`` raised, a lhs of ``before`` has no
            counterpart in ``after``, or applying the rules derived from a
            ``before`` expression leaves the matching ``after`` expression
            unchanged. Otherwise, the list of tiled updates with the
            ``WrapOutBef``/``WrapOutAft`` wrappers removed.
        """
        # If a part is (partially) computed in the before and
        #   does not appear in the after or
        # going from before to after requires undoing some computation
        # it is potentially unstable, and more expensive: ignore
        try:
            before_finputs = self.express_in_terms_of_input(before)
            after_finputs = self.express_in_terms_of_input(after)
        except Exception:
            # [TODO] In LU's variant 5, parts of A appear as lhs's
            # Only ordinary errors are treated as "skip this invariant";
            # KeyboardInterrupt/SystemExit are left to propagate.
            return None
        #
        dict_bef = {str(u.get_children()[0]): u for u in before_finputs}
        dict_aft = {str(u.get_children()[0]): u for u in after_finputs}
        # lhs entries whose before and after expressions match each other
        same = []
        ignore = False
        for k, v in dict_bef.items():
            if k in dict_aft and matchq(v, dict_aft[k]):
                same.extend(v.children[0].children)
            if k not in dict_aft:
                ignore = True
                reason = "%s not in %s" % (k, dict_aft.keys())
                break
            else:
                rules = self.expr_to_rule_rhs_lhs([v])
                rules = list(itertools.chain(*rules))
                expr_copy = copy.deepcopy(dict_aft[k])
                t = replace(expr_copy, rules)
                #if v == replace( expr_copy, rules ):
                if dict_aft[k] == t:
                    ignore = True
                    reason = "%s would require undoing job" % k
                    break
        if ignore:
            print("[INFO] Skipping invariant: %s" % reason)
            return None
        #
        # Wrap outputs for before and after
        WrapBefOut = WrapOutBef
        lhss = []
        for u in before:
            lhss.extend(u.children[0])
            u.children[0] = NList([WrapBefOut(l) for l in u.children[0]])
        for u in before:
            u.children[1] = replace(
                u.children[1],
                [RewriteRule(l, Replacement(WrapBefOut(l))) for l in lhss])
        #
        lhss = []
        for u in after:
            lhss.extend(u.children[0])
            u.children[0] = NList([WrapOutAft(l) for l in u.children[0]])
        # In the rhs of ``after``, unchanged operands are tagged as "before"
        # values and the remaining ones as "after" values.
        wrap_rules_after = [
            RewriteRule(l, Replacement(WrapBefOut(l))) if l in same else
            RewriteRule(l, Replacement(WrapOutAft(l))) for l in lhss
        ]
        for u in after:
            u.children[1] = replace(u.children[1], wrap_rules_after)
        # replace before in before
        wrap_rules_before = []
        for u in before:
            rules = self.expr_to_rule_rhs_lhs([u])
            wrap_rules_before.append(list(itertools.chain(*rules)))
        # For every rule, also add a variant whose pattern has been rewritten
        # with the rules of all the other expressions (when that changes it).
        new_rules = []
        for i, rules in enumerate(wrap_rules_before):
            new_rules.append([])
            for rule in rules:
                new_r = copy.deepcopy(rule)
                new_r.pattern = replace_all(
                    new_r.pattern,
                    list(
                        itertools.chain.from_iterable(
                            wrap_rules_before[:i] +
                            wrap_rules_before[i + 1:])))
                if new_r.pattern != rule.pattern:
                    new_rules[-1].append(new_r)
        for r1, r2 in zip(new_rules, wrap_rules_before):
            r2.extend(r1)
        #
        wrap_rules_before = list(itertools.chain(*wrap_rules_before))
        # replace before in after, repeating until nothing changes
        done = False
        while not done:
            after_top = [copy.deepcopy(u) for u in after]
            for u in after:
                _, rhs = u.get_children()
                u.children[1] = simplify(
                    to_canonical(
                        replace_all(copy.deepcopy(rhs), wrap_rules_before)))
            done = True
            for top, bot in zip(after_top, after):
                if top != bot:
                    done = False
                    break
        # replace after in after, repeating until nothing changes
        done = False
        while not done:
            wrap_rules_after = []
            for u in after:
                rules = self.expr_to_rule_rhs_lhs([u])
                wrap_rules_after.append(list(itertools.chain(*rules)))
            #
            after_top = [copy.deepcopy(u) for u in after]
            for i, u in enumerate(after):
                _, rhs = u.get_children()
                # Use the rules of every other expression, never its own.
                rules = list(
                    itertools.chain.from_iterable(wrap_rules_after[:i] +
                                                  wrap_rules_after[i + 1:]))
                u.children[1] = simplify(
                    to_canonical(replace_all(copy.deepcopy(rhs), rules)))
            done = True
            for top, bot in zip(after_top, after):
                if top != bot:
                    done = False
                    break
        # [TODO] Multiple lhss, won't work
        updates = []
        for u in after:
            lhs, rhs = u.get_children()
            if len(lhs.children) == 1:
                lhs = lhs.children[0]  # NList[op] -> op
                # "after" value is just the "before" value: nothing to do.
                if isinstance(rhs, WrapBefOut) and isinstance(lhs, WrapOutAft) and \
                        matchq(lhs.children[0], rhs.children[0]):
                    continue
            elif not isinstance(rhs, NList):
                # multiple outputs/predicate in rhs,
                # but not complete (otherwise it would be NList)
                pass
            else:
                # Skip only if every (lhs, rhs) pair is an unchanged operand.
                to_skip = True
                for l, r in zip(lhs.children, rhs.children):
                    if not (isinstance(r, WrapBefOut) and
                            isinstance(l, WrapOutAft) and
                            matchq(l.children[0], r.children[0])):
                        to_skip = False
                        break
                if to_skip:
                    continue
            updates.append(u)
        #
        tiled_updates = []
        for u in updates:
            print("*   ", u)
            tilings = list(tile_expr(u))
            if len(tilings) > 1:
                print("[WARNING] Multiple (%d) tilings for expression %s" %
                      (len(tilings), u))
                print("          Discarding all but one")
            tiled_updates.extend(tilings[0])
        tiled_updates = sort(tiled_updates)
        print("* Tiled update")
        for t in tiled_updates:
            print("*   ", t)

        # Drop WrapOutBef's
        # Drop WrapOutAft's
        s = PatternDot("s")
        updates = []
        for u in tiled_updates:
            u = replace_all(
                u, [RewriteRule(WrapOutAft(s), Replacement(lambda d: d["s"]))])
            u = replace_all(
                u, [RewriteRule(WrapOutBef(s), Replacement(lambda d: d["s"]))])
            updates.append(u)

        return updates
예제 #7
0
    def generate_loop_based_algorithms(self):
        """Build ``self.algs``: one list of ``Algorithm`` objects per PME.

        ``self.pmes`` and ``self.linvs`` are walked in lockstep. For each
        loop invariant the initialization, the before/after predicates and
        the updates are generated; invariants for which ``find_updates``
        returns ``None`` or an empty list are skipped. The variant counter
        only advances when an algorithm is actually produced. Progress is
        reported on stdout.
        """
        print("* Generating Loop-based algorithms...")
        self.algs = []
        variant = 1
        for pme, linvs in zip(self.pmes, self.linvs):
            pme_algs = []
            for linv in linvs:
                print("* ")
                print("* Loop invariant", variant)
                for expr in linv.expressions:
                    print("*     ", expr)
                print("* ")
                # Only the first traversal is considered
                # (this would be another for loop).
                trav, init_state, _ = linv.traversals[0]
                init = self.algorithm_initialization(init_state)
                print("* Init")
                #print( init_state )
                print("*   ", init)
                # Strip the WrapOutBef wrappers from the initialization.
                s = PatternDot("s")
                unwrapped_init = []
                for stmt in init:
                    unwrap_rule = RewriteRule(WrapOutBef(s),
                                              Replacement(lambda d: d["s"]))
                    unwrapped_init.append(replace_all(stmt, [unwrap_rule]))
                init = unwrapped_init
                print("* Before")
                repart, before = self.generate_predicate_before(
                    pme, trav, linv.expressions, linv)
                print("* After")
                # The "after" predicate is generated with the traversal
                # directions negated.
                reversed_trav = {
                    k: (r * -1, c * -1) for k, (r, c) in trav.items()
                }
                cont_with, after = self.generate_predicate_before(
                    pme, reversed_trav, linv.expressions, linv)
                # find updates
                print("* Updates")
                updates = self.find_updates(before, after)
                if updates is None:
                    #variant += 1
                    continue
                # Tile updates
                for update in updates:
                    infer_properties(update)
                final_updates = []
                # [DIEGO] Fixing some output pieces being labeled as input
                outputs = []
                for update in updates:
                    lhs, _ = update.children
                    for target in lhs:
                        if target.isTemporary():
                            continue
                        outputs.append(target)
                        target.set_property(OUTPUT)
                #
                for update in updates:
                    #[DIEGO] u.children[0].children[0].set_property( OUTPUT )
                    # Any symbol of the rhs that is not a known output is
                    # labeled as input.
                    for node in update.children[1].iterate_preorder():
                        if not isinstance(node, Symbol):
                            continue
                        if node not in outputs:
                            node.set_property(INPUT)
                    #
                    #copy_u = replace( copy.deepcopy(u), known_ops_single )
                    snapshot = copy.deepcopy(update)
                    one_tiling = tile_expr(copy.deepcopy(snapshot))[0]  # One tiling
                    final_updates.extend(one_tiling)
                if len(updates) == 0:
                    print("No updates!! Should only happen in copy")
                    continue

                algorithm = Algorithm(linv, variant, init, repart, cont_with,
                                      before, after, final_updates)
                pme_algs.append(algorithm)
                algorithm.prepare_for_code_generation()
                variant += 1
            self.algs.append(pme_algs)
예제 #8
0
    def learn_pattern(self):
        """Register rewrite rules that recognize this operation.

        A pattern is built from ``self.equation`` by replacing every operand
        with a ``PatternDot`` placeholder, and paired with a constraint
        derived from the operands' properties. The replacement is an
        ``Equal`` between the output operands and a ``Predicate`` named
        ``self.name`` over the input operands.

        Side effects:
          * single assignment (one equation whose first lhs entry is a
            ``Symbol``): two rules are inserted at the front of
            ``known_ops_single`` -- the plain one and a negated variant;
          * otherwise: one rule is inserted at the front of ``known_ops``
            and pickled to ``OUTPUT/<name>_patterns``;
          * in both cases a rule mapping the predicate back to the equation
            is appended to ``op_to_implicit``;
          * the learnt pattern is printed to stdout.
        """
        inops = [op for op in self.operands if op.isInput()]
        outops = [op for op in self.operands if op.isOutput()]
        #
        equations = self.equation.children
        # eq.lhs.single_entry_in_NL
        single_assignment = len(equations) == 1 and \
            isinstance(equations[0].children[0], Symbol)
        #
        op_to_pattern = []
        for op in self.operands:
            op_to_pattern.append(
                RewriteRule(op, Replacement("PatternDot(%s.name)" % op.name)))
        pattern = NList([
            replace_all(copy.deepcopy(eq), op_to_pattern)
            for eq in self.equation
        ])
        if single_assignment:
            # Operands without constraints yield an empty string: drop those.
            props_str = [
                symbol_props_to_constraints_no_io(op) for op in self.operands
            ]
            non_empty = [prop for prop in props_str if prop]
            constraint = Constraint(" and ".join(non_empty))
        else:
            all_props = [
                symbol_props_to_constraints(op) for op in self.operands
            ]
            constraint = Constraint(" and ".join(all_props))
        # Pieces of the replacement source text.
        # [TODO] Tuple for get_size
        out_names = ", ".join([op.name for op in outops])
        in_names = ", ".join([op.name for op in inops])
        out_sizes = ", ".join(
            ["%s.get_size()" % op.get_name() for op in outops])
        replacement = Replacement(
            "Equal([ NList([%s]), Predicate( \"%s\", [%s], [%s] ) ])" %
            (out_names, self.name, in_names, out_sizes))
        # [TODO] This should be part of the verbose option
        print("* Learnt pattern")
        print("*   ", pattern, end="")
        if constraint.to_eval:
            print("with        ", constraint.to_eval)
        print(" --> ")
        print("*          ", replacement.to_eval)
        print("**********************************")
        # [TODO] Maybe sort known ops by specificity (a la mathematica)
        #known_ops.insert( 0, RewriteRule( (pattern, constraint), replacement ) )

        if not single_assignment:
            known_ops.insert(0, RewriteRule((pattern, constraint),
                                            replacement))
            patterns_path = os.path.join("OUTPUT", self.name + "_patterns")
            with open(patterns_path, "wb") as patt_f:
                pickle.dump(known_ops[0], patt_f)
        else:
            expr = pattern.children[0]
            expr.children[0] = NList([expr.children[0]])
            known_ops_single.insert(
                0, RewriteRule((expr, constraint), replacement))
            # With minus
            minus_replacement = Replacement(
                "Equal([ NList([%s]), Minus([ Predicate( \"%s\", [%s], [%s] ) ]) ])"
                % (out_names, self.name, in_names, out_sizes))
            negated = copy.deepcopy(expr)
            negated.children[1] = Minus([negated.children[1]])
            negated.children[1] = normalize_minus(
                copy.deepcopy(negated.children[1]))
            known_ops_single.insert(
                0, RewriteRule((negated, constraint), minus_replacement))
            #with open(os.path.join("OUTPUT", self.name+"_patterns"), "wb") as patt_f:
            #pickle.dump( known_ops_single[1], patt_f )
            #pickle.dump( known_ops_single[0], patt_f )

        # Reverse direction: from the predicate back to the equation.
        out_dots = [PatternDot(op.get_name()) for op in outops]
        in_dots = [PatternDot(op.get_name()) for op in inops]
        sizes = [op.get_size() for op in outops]
        implicit_pattern = Equal([
            NList(out_dots),
            Predicate(self.name, in_dots, sizes)
        ])
        implicit_replacement = Replacement(equation2replacement(self.equation))
        op_to_implicit.append(
            RewriteRule(implicit_pattern, implicit_replacement))
예제 #9
0
def simplify(expr):
    """Return the result of applying ``simplify_rules`` to *expr* via ``replace_all``."""
    simplified = replace_all(expr, simplify_rules)
    return simplified
예제 #10
0
def to_canonicalIO(expr):
    """Return *expr* rewritten with ``canonical_rules`` followed by ``canonicalIO_rules``."""
    io_rules = canonical_rules + canonicalIO_rules
    return replace_all(expr, io_rules)
예제 #11
0
def to_canonical(expr):
    """Return *expr* rewritten with ``canonical_rules`` followed by ``simplify_rules_base``."""
    base_rules = canonical_rules + simplify_rules_base
    return replace_all(expr, base_rules)