コード例 #1
0
ファイル: test_dump.py プロジェクト: sthagen/hope
 def test_dump_complex(self):
     """Dumping a module transformed from ``DummyWrapper.dummy`` yields output."""
     dummy_wrapper = DummyWrapper()
     mod = Module(dummy_wrapper.dummy.__name__)
     call_args = [dummy_wrapper, np.arange(10)]
     transformer = ASTTransformer(mod)
     transformer.module_visit(dummy_wrapper.dummy, call_args)
     dumped = Dumper().visit(mod)
     assert dumped is not None
コード例 #2
0
 def __init__(self):
     """Create the dumper and the SymPy helper visitors; reset the temp counter."""
     self.dumper = Dumper()
     # Build every helper first (same order as before), then bind them.
     check = CheckOptimizeVisitor()
     create = CreateExprVisitor(self.dumper)
     to_ast = SympyToAstVisitor()
     pow_finder = SympyPowVisitor()
     self.checkVisitor = check
     self.createExpr = create
     self.sympyToAst = to_ast
     self.sympyPow = pow_finder
     self.next = 0
コード例 #3
0
class Optimizer(NodeVisitor):
    """Algebraically optimize the arithmetic of HOPE function bodies.

    For every ``Assign``/``AugAssign`` in a ``Block`` whose value is accepted
    by ``CheckOptimizeVisitor``, the value is converted to a SymPy expression,
    simplified, split with common-subexpression elimination (``sp.cse``), and
    integer powers are rewritten as chains of multiplications.  Every extracted
    subexpression and power is bound to a fresh temporary ``__sp<N>`` that is
    assigned in the block immediately before the rewritten statement.
    """

    def __init__(self):
        self.dumper = Dumper()
        self.checkVisitor = CheckOptimizeVisitor()
        self.createExpr = CreateExprVisitor(self.dumper)
        self.sympyToAst = SympyToAstVisitor()
        self.sympyPow = SympyPowVisitor()
        # Counter used to name generated temporaries (__sp0, __sp1, ...).
        self.next = 0
        # Symbol table: dumped name -> AST node.  Replaced by a fresh dict for
        # each function in visit_FunctionDef; created here as well so the
        # attribute always exists.
        self.symbols = {}

    # Leaf and simple-statement nodes need no work of their own: visit_Block
    # rewrites Assign/AugAssign values after dispatching to these no-ops.
    def visit_Number(self, node):
        pass

    def visit_NewVariable(self, node):
        pass

    def visit_Variable(self, node):
        pass

    def visit_Object(self, node):
        pass

    def visit_ObjectAttr(self, node):
        pass

    def visit_Dimension(self, node):
        pass

    def visit_View(self, node):
        pass

    def visit_Expr(self, node):
        pass

    def visit_Assign(self, node):
        pass

    def visit_AugAssign(self, node):
        pass

    def visit_UnaryOp(self, node):
        pass

    def visit_BinOp(self, node):
        pass

    def visit_Compare(self, node):
        pass

    def visit_If(self, node):
        """Optimize both branches of a conditional; the test is left as is."""
        # TODO: optimize condition
        # if condition is compile time -> remove!
        self.visit(node.body)
        if node.orelse is not None:
            self.visit(node.orelse)

    def visit_For(self, node):
        """Optimize a loop body with ``node.iter`` registered as a symbol.

        Any binding of the same name that existed before the loop is restored
        afterwards.  Without that, a nested loop reusing the variable name
        would delete the outer loop's entry and make the outer ``del`` raise
        ``KeyError``.
        """
        name = self.dumper.visit(node.iter)
        missing = object()
        previous = self.symbols.get(name, missing)
        self.symbols[name] = node.iter
        self.visit(node.body)
        if previous is missing:
            self.symbols.pop(name, None)
        else:
            self.symbols[name] = previous

    def visit_While(self, node):
        """Optimize the loop body; the condition is left as is."""
        self.visit(node.body)

    def visit_Call(self, node):
        pass

    def visit_GlobalFunction(self, node):
        pass

    def visit_HopeAttr(self, node):
        pass

    def visit_NumpyAttr(self, node):
        pass

    def visit_NumpyContraction(self, node):
        pass

    def visit_Allocate(self, node):
        """Register the allocated variable in the symbol table."""
        self.symbols[self.dumper.visit(node.variable)] = node.variable

    def visit_Return(self, node):
        pass

    def visit_Block(self, node):
        """Rewrite the statements of ``node.body`` in place.

        Assignment targets are registered in the symbol table.  Assignments
        whose value passes ``CheckOptimizeVisitor`` are replaced by their
        simplified form, preceded by ``__sp<N>`` temporaries holding common
        subexpressions and integer powers.  Power temporaries are shared by
        all statements of the block.

        Raises:
            Exception: if an ``Assign`` target is neither a ``View`` nor a
                ``Variable``, if CSE does not yield exactly one expression, or
                if the rewritten value's dtype or shape does not match.
        """
        body, powexprs = [], {}
        # NOTE(review): powexprs is reused across statements of this block
        # without invalidation, so a power of a variable that is reassigned
        # between two statements may reuse a stale temporary -- confirm this
        # cannot happen for code that reaches the optimizer.
        for astexpr in node.body:
            self.visit(astexpr)
            if isinstance(astexpr, Assign):
                if isinstance(astexpr.target, View):
                    self.symbols[self.dumper.visit(
                        astexpr.target.variable)] = astexpr.target
                elif isinstance(astexpr.target, Variable):
                    self.symbols[self.dumper.visit(
                        astexpr.target)] = astexpr.target
                else:
                    raise Exception("Unknown token {0}".format(
                        self.dumper.visit(astexpr.target)))
            # TODO: implement for expr
            # TODO: replace subexpressions over several lines
            if not (isinstance(astexpr, (Assign, AugAssign))
                    and self.checkVisitor.visit(astexpr.value)):
                body.append(astexpr)
                continue

            symexpr = sp.simplify(self.createExpr.visit(astexpr.value))
            subexprs, newexprs = sp.cse(symexpr, optimizations='basic')
            if len(newexprs) != 1:
                raise Exception(
                    "Error running Common Subexpression Detection for {0!s}"
                    .format(symexpr))
            newexpr = newexprs[0]
            # CSE symbol of *this* statement -> its temporary.  Kept per
            # statement because sp.cse names its symbols afresh on every call;
            # stale entries could otherwise rename an unrelated free symbol.
            knownexprs = {}
            for symbol, subexpr in subexprs:
                for subsymbol, newsymbol in knownexprs.items():
                    subexpr = subexpr.subs(subsymbol, newsymbol)
                for powexpr in self._pows_by_exponent(subexpr):
                    subexpr, _ = self.replace_pow(body, subexpr, powexprs,
                                                  powexpr,
                                                  np.abs(powexpr.exp.p))
                value = self.sympyToAst.visit(sp.simplify(subexpr))
                name = self._new_temp(body, value)
                knownexprs[symbol] = sp.Symbol(name)
                newexpr = newexpr.subs(symbol, knownexprs[symbol])
            for powexpr in self._pows_by_exponent(newexpr):
                newexpr, _ = self.replace_pow(body, newexpr, powexprs,
                                              powexpr, np.abs(powexpr.exp.p))
            newvalue = self.sympyToAst.visit(sp.simplify(newexpr))
            astexpr.value = self._conform_value(astexpr, newvalue)
            body.append(astexpr)
        node.body = body

    def visit_Body(self, node):
        """Optimize every block of a function body."""
        for block in node.blocks:
            self.visit(block)

    def visit_FunctionDef(self, node):
        """Optimize a function once.

        The symbol table is seeded with the function signature and shared
        with the SymPy converters before the body is visited.
        """
        if not node.optimized:
            self.symbols = {}
            for var in node.signature:
                self.add_symbol(var)
            self.createExpr.symbols, self.sympyToAst.symbols = self.symbols, self.symbols
            node.optimized = True
            self.visit(node.body)

    def visit_Module(self, node):
        """Optimize every function stored in the values of ``node.functions``."""
        for fktcls in list(node.functions.values()):
            for fkt in fktcls:
                self.visit(fkt)

    def add_symbol(self, symbol):
        """Register *symbol*; ``Object`` nodes are flattened into their attrs."""
        if isinstance(symbol, Object):
            for attr in list(symbol.attrs.values()):
                self.add_symbol(attr)
        else:
            self.symbols[self.dumper.visit(symbol)] = symbol

    def replace_pow(self, body, symexpr, powexprs, expr, exp):
        """Materialize ``expr.base ** exp`` by repeated multiplication.

        Temporaries for the needed intermediate powers are appended to *body*
        and cached in *powexprs* as ``(base, exp) -> name``.  When *exp* equals
        ``|expr.exp.p|`` the power *expr* inside *symexpr* is substituted by
        the temporary, or by ``dtype(1) / temporary`` for negative exponents.

        NOTE(review): only the numerator ``expr.exp.p`` is used, so this
        assumes SympyPowVisitor only yields integer exponents -- confirm.

        Args:
            body: statement list that receives new temporary assignments.
            symexpr: SymPy expression in which *expr* is replaced.
            powexprs: cache of already materialized powers.
            expr: SymPy power node whose base is raised.
            exp: positive integer exponent to materialize.

        Returns:
            ``(symexpr, name)``: the possibly rewritten expression and the name
            of the temporary holding ``base ** exp`` (``None`` if *exp* is 1).
        """
        if exp == 1:
            return (symexpr, None)
        key = (expr.base, exp)
        if key not in powexprs:
            if exp == 2:
                operand = sp.simplify(expr.base)
                value = BinOp("Mult", self.sympyToAst.visit(operand),
                              self.sympyToAst.visit(operand))
            elif exp % 2 == 1:
                # x**n == x**(n - 1) * x
                _, operand = self.replace_pow(body, symexpr, powexprs, expr,
                                              exp - 1)
                value = BinOp("Mult", self.symbols[operand],
                              self.sympyToAst.visit(sp.simplify(expr.base)))
            else:
                # x**n == x**(n/2) * x**(n/2); floor division keeps the
                # exponent (and thus the cache key) an integer.
                _, operand = self.replace_pow(body, symexpr, powexprs, expr,
                                              exp // 2)
                value = BinOp("Mult", self.symbols[operand],
                              self.symbols[operand])
            powexprs[key] = self._new_temp(body, value)
        name = powexprs[key]
        if np.abs(expr.exp.p) == exp:
            symbol = sp.Symbol(name)
            replacement = (self.symbols[name].dtype(1) / symbol
                           if expr.exp.is_negative else symbol)
            symexpr = symexpr.subs(expr, replacement)
        return (symexpr, name)

    def _pows_by_exponent(self, expr):
        """Return the powers found in *expr*, largest ``|exponent|`` first.

        The order matters because SymPy's ``subs`` can rewrite a pending
        ``x**4`` as ``tmp**2`` once ``x**2`` has been replaced.
        """
        return sorted(self.sympyPow.visit(expr),
                      key=lambda x: -np.abs(x.exp.p))

    def _new_temp(self, body, value):
        """Append ``__sp<N> = value`` to *body*, register it, return its name."""
        name = "__sp{0}".format(self.next)
        self.next += 1
        self.symbols[name] = Variable(name, copy.deepcopy(value.shape),
                                      value.dtype)
        body.append(Assign(self.symbols[name], value))
        return name

    def _conform_value(self, astexpr, newvalue):
        """Check that *newvalue* can replace ``astexpr.value`` and return it.

        A ``Number`` of a different dtype is cast to the original dtype.  Any
        other dtype mismatch raises.  The shape check is skipped when a scalar
        is assigned to a non-scalar target.  Otherwise rank and extents must
        match.
        """
        oldvalue = astexpr.value
        if oldvalue.dtype != newvalue.dtype:
            if isinstance(newvalue, Number):
                newvalue = Number(oldvalue.dtype(newvalue.value))
            else:
                raise Exception("dtype does not match {0} != {1}".format(
                    self.dumper.visit(oldvalue), self.dumper.visit(newvalue)))
        if len(astexpr.target.shape) > 0 and len(newvalue.shape) == 0:
            return newvalue
        if len(oldvalue.shape) != len(newvalue.shape):
            raise Exception(
                "length of shape does not match {0} != {1}".format(
                    self.dumper.visit(oldvalue), self.dumper.visit(newvalue)))
        for (lower1, upper1), (lower2, upper2) in zip(oldvalue.shape,
                                                      newvalue.shape):
            if not ((lower1 is None and lower2 is None)
                    or lower1 == lower2) or upper1 != upper2:
                raise Exception("shape does not match {0} != {1}".format(
                    self.dumper.visit(oldvalue), self.dumper.visit(newvalue)))
        return newvalue
コード例 #4
0
ファイル: _optimizer.py プロジェクト: BrainGrylls/hope
 def __init__(self):
     """Create the dumper, the SymPy helper visitors and the temp counter."""
     self.dumper = Dumper()
     helpers = (
         CheckOptimizeVisitor(),
         CreateExprVisitor(self.dumper),
         SympyToAstVisitor(),
         SympyPowVisitor(),
     )
     self.checkVisitor, self.createExpr, self.sympyToAst, self.sympyPow = helpers
     self.next = 0
コード例 #5
0
ファイル: _optimizer.py プロジェクト: BrainGrylls/hope
class Optimizer(NodeVisitor):
    """Algebraically optimize the arithmetic of HOPE function bodies.

    For every ``Assign``/``AugAssign`` in a ``Block`` whose value is accepted
    by ``CheckOptimizeVisitor``, the value is converted to SymPy, simplified,
    split with ``sp.cse``, and integer powers are rewritten as chains of
    multiplications.  Each extracted piece is bound to a fresh ``__sp<N>``
    temporary assigned right before the rewritten statement.
    """
    def __init__(self):
        self.dumper = Dumper()
        self.checkVisitor, self.createExpr, self.sympyToAst, self.sympyPow, self.next = CheckOptimizeVisitor(), CreateExprVisitor(self.dumper), SympyToAstVisitor(), SympyPowVisitor(), 0
        # Symbol table (dumped name -> AST node); replaced per function in
        # visit_FunctionDef, created here so the attribute always exists.
        self.symbols = {}

    # Leaf and simple-statement nodes need no work of their own: visit_Block
    # rewrites Assign/AugAssign values after dispatching to these no-ops.
    def visit_Number(self, node): pass
    def visit_NewVariable(self, node): pass
    def visit_Variable(self, node): pass
    def visit_Object(self, node): pass
    def visit_ObjectAttr(self, node): pass
    def visit_Dimension(self, node): pass
    def visit_View(self, node): pass
    def visit_Expr(self, node): pass
    def visit_Assign(self, node): pass
    def visit_AugAssign(self, node): pass
    def visit_UnaryOp(self, node): pass
    def visit_BinOp(self, node): pass
    def visit_Compare(self, node): pass

    def visit_If(self, node):
        """Optimize both branches of a conditional; the test is left as is."""
        # TODO: optimize condition
        # if condition is compile time -> remove!
        self.visit(node.body)
        if node.orelse is not None:
            self.visit(node.orelse)

    def visit_For(self, node):
        """Optimize a loop body with ``node.iter`` registered as a symbol.

        A binding of the same name that existed before the loop is restored
        afterwards.  A nested loop reusing the name would otherwise delete the
        outer entry and make the outer ``del`` raise ``KeyError``.
        """
        name = self.dumper.visit(node.iter)
        missing = object()
        previous = self.symbols.get(name, missing)
        self.symbols[name] = node.iter
        self.visit(node.body)
        if previous is missing:
            self.symbols.pop(name, None)
        else:
            self.symbols[name] = previous

    def visit_While(self, node):
        """Optimize the loop body; the condition is left as is."""
        self.visit(node.body)

    def visit_Call(self, node): pass
    def visit_GlobalFunction(self, node): pass
    def visit_HopeAttr(self, node): pass
    def visit_NumpyAttr(self, node): pass
    def visit_NumpyContraction(self, node): pass

    def visit_Allocate(self, node):
        """Register the allocated variable in the symbol table."""
        self.symbols[self.dumper.visit(node.variable)] = node.variable

    def visit_Return(self, node): pass

    def visit_Block(self, node):
        """Rewrite the statements of ``node.body`` in place.

        Assignment targets are registered in the symbol table.  Optimizable
        assignments are replaced by their simplified form, preceded by
        ``__sp<N>`` temporaries; power temporaries are shared block-wide.

        Raises:
            Exception: on an unknown assignment target, a CSE result with more
                than one expression, or a dtype/shape mismatch.
        """
        body, powexprs = [], {}
        # NOTE(review): powexprs is never invalidated, so a power of a variable
        # reassigned between two statements may reuse a stale temporary.
        for astexpr in node.body:
            self.visit(astexpr)
            if isinstance(astexpr, Assign):
                if isinstance(astexpr.target, View):
                    self.symbols[self.dumper.visit(astexpr.target.variable)] = astexpr.target
                elif isinstance(astexpr.target, Variable):
                    self.symbols[self.dumper.visit(astexpr.target)] = astexpr.target
                else:
                    raise Exception("Unknown token {0}".format(self.dumper.visit(astexpr.target)))
            # TODO: implement for expr
            # TODO: replace subexpressions over several lines
            if not (isinstance(astexpr, (Assign, AugAssign)) and self.checkVisitor.visit(astexpr.value)):
                body.append(astexpr)
                continue
            symexpr = sp.simplify(self.createExpr.visit(astexpr.value))
            subexprs, newexprs = sp.cse(symexpr, optimizations='basic')
            if len(newexprs) != 1:
                raise Exception("Error running Common Subexpression Detection for {0!s}".format(symexpr))
            newexpr = newexprs[0]
            # Per statement: sp.cse names its symbols afresh on every call, so
            # stale entries could rename an unrelated free symbol.
            knownexprs = {}
            for symbol, subexpr in subexprs:
                for subsymbol, newsymbol in knownexprs.items():
                    subexpr = subexpr.subs(subsymbol, newsymbol)
                for powexpr in self._pows_by_exponent(subexpr):
                    subexpr, _ = self.replace_pow(body, subexpr, powexprs, powexpr, np.abs(powexpr.exp.p))
                value = self.sympyToAst.visit(sp.simplify(subexpr))
                name = self._new_temp(body, value)
                knownexprs[symbol] = sp.Symbol(name)
                newexpr = newexpr.subs(symbol, knownexprs[symbol])
            for powexpr in self._pows_by_exponent(newexpr):
                newexpr, _ = self.replace_pow(body, newexpr, powexprs, powexpr, np.abs(powexpr.exp.p))
            newvalue = self.sympyToAst.visit(sp.simplify(newexpr))
            astexpr.value = self._conform_value(astexpr, newvalue)
            body.append(astexpr)
        node.body = body

    def visit_Body(self, node):
        """Optimize every block of a function body."""
        for block in node.blocks:
            self.visit(block)

    def visit_FunctionDef(self, node):
        """Optimize a function once, seeding the symbols from its signature."""
        if not node.optimized:
            self.symbols = {}
            for var in node.signature:
                self.add_symbol(var)
            self.createExpr.symbols, self.sympyToAst.symbols = self.symbols, self.symbols
            node.optimized = True
            self.visit(node.body)

    def visit_Module(self, node):
        """Optimize every function stored in the values of ``node.functions``."""
        for fktcls in list(node.functions.values()):
            for fkt in fktcls:
                self.visit(fkt)

    def add_symbol(self, symbol):
        """Register *symbol*; ``Object`` nodes are flattened into their attrs."""
        if isinstance(symbol, Object):
            for attr in list(symbol.attrs.values()):
                self.add_symbol(attr)
        else:
            self.symbols[self.dumper.visit(symbol)] = symbol

    def replace_pow(self, body, symexpr, powexprs, expr, exp):
        """Materialize ``expr.base ** exp`` by repeated multiplication.

        Intermediate powers are appended to *body* as temporaries and cached in
        *powexprs* as ``(base, exp) -> name``.  When *exp* equals
        ``|expr.exp.p|``, *expr* in *symexpr* is substituted by the temporary
        (or ``dtype(1) / temporary`` for negative exponents).
        NOTE(review): only the numerator ``expr.exp.p`` is used -- assumes
        integer exponents; confirm SympyPowVisitor guarantees that.

        Returns ``(symexpr, name)``; ``name`` is ``None`` when *exp* is 1.
        """
        if exp == 1:
            return (symexpr, None)
        key = (expr.base, exp)
        if key not in powexprs:
            if exp == 2:
                operand = sp.simplify(expr.base)
                value = BinOp("Mult", self.sympyToAst.visit(operand), self.sympyToAst.visit(operand))
            elif exp % 2 == 1:
                # x**n == x**(n - 1) * x
                _, operand = self.replace_pow(body, symexpr, powexprs, expr, exp - 1)
                value = BinOp("Mult", self.symbols[operand], self.sympyToAst.visit(sp.simplify(expr.base)))
            else:
                # x**n == (x**(n // 2))**2; floor division keeps the key integral.
                _, operand = self.replace_pow(body, symexpr, powexprs, expr, exp // 2)
                value = BinOp("Mult", self.symbols[operand], self.symbols[operand])
            powexprs[key] = self._new_temp(body, value)
        name = powexprs[key]
        if np.abs(expr.exp.p) == exp:
            symbol = sp.Symbol(name)
            symexpr = symexpr.subs(expr, self.symbols[name].dtype(1) / symbol if expr.exp.is_negative else symbol)
        return (symexpr, name)

    def _pows_by_exponent(self, expr):
        """Return the powers in *expr*, largest ``|exponent|`` first, so that
        substituting ``x**2`` cannot turn a pending ``x**4`` into ``tmp**2``."""
        return sorted(self.sympyPow.visit(expr), key=lambda x: -np.abs(x.exp.p))

    def _new_temp(self, body, value):
        """Append ``__sp<N> = value`` to *body*, register it, return its name."""
        name = "__sp{0}".format(self.next)
        self.next += 1
        self.symbols[name] = Variable(name, copy.deepcopy(value.shape), value.dtype)
        body.append(Assign(self.symbols[name], value))
        return name

    def _conform_value(self, astexpr, newvalue):
        """Check *newvalue* can replace ``astexpr.value`` and return it.

        A ``Number`` of another dtype is cast to the original dtype.  Shapes
        are compared unless a scalar is assigned to a non-scalar target.
        """
        oldvalue = astexpr.value
        if oldvalue.dtype != newvalue.dtype:
            if isinstance(newvalue, Number):
                newvalue = Number(oldvalue.dtype(newvalue.value))
            else:
                raise Exception("dtype does not match {0} != {1}".format(self.dumper.visit(oldvalue), self.dumper.visit(newvalue)))
        if len(astexpr.target.shape) > 0 and len(newvalue.shape) == 0:
            return newvalue
        if len(oldvalue.shape) != len(newvalue.shape):
            raise Exception("length of shape does not match {0} != {1}".format(self.dumper.visit(oldvalue), self.dumper.visit(newvalue)))
        for (lower1, upper1), (lower2, upper2) in zip(oldvalue.shape, newvalue.shape):
            if not ((lower1 is None and lower2 is None) or lower1 == lower2) or upper1 != upper2:
                raise Exception("shape does not match {0} != {1}".format(self.dumper.visit(oldvalue), self.dumper.visit(newvalue)))
        return newvalue
コード例 #6
0
 def __init__(self):
     """Reset per-instance state: next_loopid=0, merged=None, empty slicemap
     and library dicts, and a new Dumper."""
     # The Dumper is constructed before any attribute is bound.
     dumper = Dumper()
     self.next_loopid = 0
     self.merged = None
     self.slicemap = {}
     self.library = {}
     self.dumper = dumper
コード例 #7
0
 def __str__(self):
     """Return the result of ``hope._dump.Dumper().visit`` for this node."""
     # Imported lazily at call time (presumably to avoid an import cycle).
     from hope._dump import Dumper as _NodeDumper
     renderer = _NodeDumper()
     return renderer.visit(self)