Пример #1
0
    def get_cse(self, expr, key=None):
        """Return the canonical CSE-wrapped form of *expr*, memoized by key.

        If *key* is None, it is computed with ``self.get_key(expr)``. On a
        cache miss, *expr* is mapped with IdentityMapper's own handler for
        ``expr.mapper_method`` (not any override defined on ``self``). The
        result is wrapped with ``prim.wrap_in_cse``, stored in
        ``self.canonical_subexprs`` under the key, and returned.
        """
        lookup_key = self.get_key(expr) if key is None else key

        try:
            return self.canonical_subexprs[lookup_key]
        except KeyError:
            # Look up the handler on IdentityMapper itself so that self's
            # (presumably CSE-creating) override is bypassed for this node.
            identity_handler = getattr(IdentityMapper, expr.mapper_method)
            wrapped = prim.wrap_in_cse(identity_handler(self, expr))
            self.canonical_subexprs[lookup_key] = wrapped
            return wrapped
Пример #2
0
    def get_cse(self, expr, key=None):
        """Look up (or build and memoize) the canonical CSE for *expr*.

        The memo table is ``self.canonical_subexprs``. When no *key* is
        supplied, ``self.get_key(expr)`` provides one. A miss is filled by
        running IdentityMapper's handler named by ``expr.mapper_method``
        (explicitly, not via ``self``) and wrapping the outcome with
        ``prim.wrap_in_cse``.
        """
        cache_key = key
        if cache_key is None:
            cache_key = self.get_key(expr)

        try:
            return self.canonical_subexprs[cache_key]
        except KeyError:
            handler_name = expr.mapper_method
            mapped = getattr(IdentityMapper, handler_name)(self, expr)
            result = prim.wrap_in_cse(mapped)
            self.canonical_subexprs[cache_key] = result
            return result
Пример #3
0
    def map_common_subexpression(self, expr):
        """Map the child of a CSE node without producing CSE(CSE(...)).

        A plain ``prim.CommonSubexpression`` is re-wrapped around the mapped
        child with ``prim.wrap_in_cse`` and the original prefix. A derived
        CSE type is rebuilt with its own constructor, prefix and extra
        properties. If the mapped child came back as a plain CSE, that
        wrapper is stripped first so the two do not nest.
        """
        if type(expr) is not prim.CommonSubexpression:
            # Derived CSE subclass: keep its concrete type.
            mapped_child = self.rec(expr.child)
            if type(mapped_child) is prim.CommonSubexpression:
                mapped_child = mapped_child.child
            return type(expr)(mapped_child, expr.prefix, **expr.get_extra_properties())

        # Exactly CommonSubexpression: let wrap_in_cse handle re-wrapping.
        return prim.wrap_in_cse(self.rec(expr.child), expr.prefix)
Пример #4
0
    def map_common_subexpression(self, expr):
        """Recurse into a CSE node's child while avoiding nested CSE wrappers.

        For an exact ``prim.CommonSubexpression``, the mapped child is
        wrapped again via ``prim.wrap_in_cse`` with the same prefix. For a
        subclass of it, a new instance of the same subclass is constructed
        from the mapped child, the prefix and ``expr.get_extra_properties()``.
        A plain CSE returned for the child is unwrapped one level first.
        """
        is_plain_cse = type(expr) is prim.CommonSubexpression
        if is_plain_cse:
            return prim.wrap_in_cse(self.rec(expr.child), expr.prefix)

        # expr is some subclass of CommonSubexpression.
        inner = self.rec(expr.child)
        if type(inner) is prim.CommonSubexpression:
            inner = inner.child

        cse_type = type(expr)
        return cse_type(inner, expr.prefix,
                        **expr.get_extra_properties())
Пример #5
0
    def map_power(self, expr):
        """Rewrite a power into CSE-tagged products, reciprocals and roots.

        - Integer exponent ``n > 1``: repeated squaring, with each
          intermediate power wrapped via ``prim.wrap_in_cse`` so it can be
          computed once.
        - Integer exponent ``1``: the base, wrapped in a CSE and mapped
          recursively.
        - Negative integer exponent: the positive power of the reciprocal.
        - ``prim.Quotient(p, q)`` exponent with integer ``p`` and ``q``
          (sign normalized so ``q > 0``): ``q == 1`` becomes the integer
          power ``p``; ``q == 2`` becomes ``sqrt(base)**p`` for ``p > 0``
          or ``rsqrt(base)**(-p)`` for ``p < 0``.

        Anything else (including an integer exponent of ``0``) is passed to
        ``IdentityMapper.map_power``.
        """
        exp = expr.exponent
        if isinstance(exp, int):
            new_base = prim.wrap_in_cse(expr.base)

            if exp > 1 and exp % 2 == 0:
                # x**(2k) -> (x*x)**k
                square = prim.wrap_in_cse(new_base * new_base)
                return self.rec(prim.wrap_in_cse(square**(exp // 2)))
            elif exp > 1 and exp % 2 == 1:
                # x**(2k+1) -> (x*x)**k * x
                square = prim.wrap_in_cse(new_base * new_base)
                return self.rec(
                    prim.wrap_in_cse(square**((exp - 1) // 2)) * new_base)
            elif exp == 1:
                # Recurse so the base itself is still rewritten. Returning
                # new_base as-is would leave powers nested inside it
                # unexpanded; the squaring recursion above always ends here.
                return self.rec(new_base)
            elif exp < 0:
                return self.rec((1 / new_base)**(-exp))

        if (isinstance(exp, prim.Quotient)
                and isinstance(exp.numerator, int)
                and isinstance(exp.denominator, int)):

            p, q = exp.numerator, exp.denominator
            if q < 0:
                # Move the sign onto the numerator.
                q *= -1
                p *= -1

            if q == 1:
                # new_base is only bound by the integer branch above, which
                # is never taken for a Quotient exponent, so bind it here.
                new_base = prim.wrap_in_cse(expr.base)
                return self.rec(new_base**p)

            if q == 2:
                # NOTE(review): a zero numerator is rejected rather than
                # simplified; this assert is stripped under ``python -O``.
                assert p != 0

                if p > 0:
                    orig_base = prim.wrap_in_cse(expr.base)
                    new_base = prim.wrap_in_cse(
                        prim.Variable("sqrt")(orig_base))
                else:
                    new_base = prim.wrap_in_cse(
                        prim.Variable("rsqrt")(expr.base))
                    p *= -1

                return self.rec(new_base**p)

        return IdentityMapper.map_power(self, expr)
Пример #6
0
    def map_power(self, expr):
        """Rewrite a power into CSE-tagged products, reciprocals and roots.

        Integer exponents > 1 use repeated squaring, with intermediates
        wrapped by ``prim.wrap_in_cse``. Exponent 1 yields the recursively
        mapped, CSE-wrapped base. Negative integers become powers of the
        reciprocal. For a ``prim.Quotient(p, q)`` exponent with integer
        parts (sign normalized onto ``p``), ``q == 1`` is an integer power
        and ``q == 2`` uses ``sqrt``/``rsqrt``. Everything else, including
        an integer exponent of 0, goes to ``IdentityMapper.map_power``.
        """
        exp = expr.exponent
        if isinstance(exp, int):
            new_base = prim.wrap_in_cse(expr.base)

            if exp > 1 and exp % 2 == 0:
                square = prim.wrap_in_cse(new_base*new_base)
                return self.rec(prim.wrap_in_cse(square**(exp//2)))
            elif exp > 1 and exp % 2 == 1:
                square = prim.wrap_in_cse(new_base*new_base)
                return self.rec(prim.wrap_in_cse(square**((exp-1)//2))*new_base)
            elif exp == 1:
                # Recurse so the base is still rewritten; the squaring
                # recursion above always terminates in this case, and
                # returning new_base unmapped would skip its subexpressions.
                return self.rec(new_base)
            elif exp < 0:
                return self.rec((1/new_base)**(-exp))

        if (isinstance(exp, prim.Quotient)
                and isinstance(exp.numerator, int)
                and isinstance(exp.denominator, int)):

            p, q = exp.numerator, exp.denominator
            if q < 0:
                q *= -1
                p *= -1

            if q == 1:
                # Bind new_base here: the integer branch that normally sets
                # it is not taken for a Quotient exponent.
                new_base = prim.wrap_in_cse(expr.base)
                return self.rec(new_base**p)

            if q == 2:
                # NOTE(review): zero numerator is rejected, not simplified.
                assert p != 0

                if p > 0:
                    orig_base = prim.wrap_in_cse(expr.base)
                    new_base = prim.wrap_in_cse(prim.Variable("sqrt")(orig_base))
                else:
                    new_base = prim.wrap_in_cse(prim.Variable("rsqrt")(expr.base))
                    p *= -1

                return self.rec(new_base**p)

        return IdentityMapper.map_power(self, expr)
Пример #7
0
 def wrap_in_cse(self, expr, prefix):
     """Return a CSE wrapper for *expr*, reusing the first one created.

     Wrappers are cached in ``self.cse_cache`` keyed on *expr* alone, so a
     later call with a different *prefix* returns the previously cached
     wrapper unchanged.
     """
     cache = self.cse_cache
     # Same membership semantics as dict.setdefault, but the wrapper is only
     # constructed on a miss instead of being built and discarded on hits.
     if expr in cache:
         return cache[expr]
     cse = prim.wrap_in_cse(expr, prefix)
     cache[expr] = cse
     return cse
Пример #8
0
 def wrap_in_cse(self, expr, prefix):
     """Return a CSE wrapper for *expr*, reusing the first one created.

     Wrappers are cached in ``self.cse_cache`` keyed on *expr* alone, so a
     later call with a different *prefix* returns the previously cached
     wrapper unchanged.
     """
     cache = self.cse_cache
     # Same membership semantics as dict.setdefault, but the wrapper is only
     # constructed on a miss instead of being built and discarded on hits.
     if expr in cache:
         return cache[expr]
     cse = prim.wrap_in_cse(expr, prefix)
     cache[expr] = cse
     return cse