def map_power(self, expr):
    """Expand the power expression ``expr`` where a rewrite is possible.

    * If the base is a product, the exponent is distributed over each of
      its factors and the resulting product is mapped recursively.
    * If the exponent is a non-negative ``int`` and the recursively mapped
      base is a sum, the power is rewritten as the product of
      ``expr.exponent`` copies of that sum, which is then mapped
      recursively.
    * Otherwise the expression is delegated to ``IdentityMapper.map_power``.

    :arg expr: the power being mapped (must provide ``.base`` and
        ``.exponent``).
    :returns: the mapped expression.
    """
    import pymbolic
    from pymbolic.primitives import Product, Sum

    if isinstance(expr.base, Product):
        # NOTE(review): (a*b)**x == a**x * b**x holds for integer x but not in
        # general for non-integer x -- confirm this is acceptable for callers.
        return self.rec(pymbolic.flattened_product(
            child**expr.exponent for child in expr.base.children))

    exponent = expr.exponent
    # A negative exponent would repeat the tuple zero times and silently turn
    # the power into an empty product, so only non-negative ones are expanded.
    if isinstance(exponent, int) and exponent >= 0:
        newbase = self.rec(expr.base)
        if isinstance(newbase, Sum):
            # flattened_product may hand back something that is not a Product
            # (e.g. the lone sum for exponent 1, or a constant for exponent 0),
            # so dispatch on its actual type instead of calling map_product.
            return self.rec(
                pymbolic.flattened_product(exponent * (newbase,)))

    return IdentityMapper.map_power(self, expr)
def map_power(self, expr):
    """Expand the power expression ``expr`` where a rewrite is possible.

    * If the base is a product, the exponent is distributed over each of
      its factors and the resulting product is mapped recursively.
    * If the exponent is a non-negative ``int`` and the recursively mapped
      base is a sum, the power is rewritten as the product of
      ``expr.exponent`` copies of that sum, which is then mapped
      recursively.
    * Otherwise the expression is delegated to ``IdentityMapper.map_power``.

    :arg expr: the power being mapped (must provide ``.base`` and
        ``.exponent``).
    :returns: the mapped expression.
    """
    import pymbolic
    from pymbolic.primitives import Product, Sum

    if isinstance(expr.base, Product):
        # NOTE(review): (a*b)**x == a**x * b**x holds for integer x but not in
        # general for non-integer x -- confirm this is acceptable for callers.
        return self.rec(pymbolic.flattened_product(
            child**expr.exponent for child in expr.base.children))

    exponent = expr.exponent
    # A negative exponent would repeat the tuple zero times and silently turn
    # the power into an empty product, so only non-negative ones are expanded.
    if isinstance(exponent, int) and exponent >= 0:
        newbase = self.rec(expr.base)
        if isinstance(newbase, Sum):
            # flattened_product may hand back something that is not a Product
            # (e.g. the lone sum for exponent 1, or a constant for exponent 0),
            # so dispatch on its actual type instead of calling map_product.
            return self.rec(
                pymbolic.flattened_product(exponent * (newbase,)))

    return IdentityMapper.map_power(self, expr)
def map_power(self, expr):
    """Lower the power ``expr`` into products, reciprocals and square roots.

    * Integer exponent ``n``:

      - ``n > 1`` is computed by repeated squaring; the base and its square
        are wrapped as common subexpressions so they are evaluated once.
      - ``n == 1`` returns the wrapped base.
      - ``n < 0`` is rewritten as ``(1/base)**(-n)``.
      - ``n == 0`` is left to the generic handler.
    * Exponent ``Quotient(p, q)`` with integer ``p`` and ``q``:

      - The sign is first moved into the numerator.
      - ``q == 1`` maps ``base**p``.
      - ``q == 2`` maps ``sqrt(base)**p`` for ``p > 0``, or
        ``rsqrt(base)**(-p)`` for ``p < 0``.
    * Everything else is delegated to ``IdentityMapper.map_power``.

    :arg expr: the power being mapped (must provide ``.base`` and
        ``.exponent``).
    :returns: the mapped expression.
    """
    exp = expr.exponent
    # Wrap the base once, before either branch: the rational-exponent branch
    # needs it too (it used to be bound only for integer exponents, so the
    # q == 1 path raised NameError).
    new_base = prim.wrap_in_cse(expr.base)

    if isinstance(exp, int):
        if exp > 1:
            square = prim.wrap_in_cse(new_base * new_base)
            if exp % 2 == 0:
                return self.rec(prim.wrap_in_cse(square**(exp // 2)))
            # Odd exponent: one extra factor of the base on top of the square.
            return self.rec(
                prim.wrap_in_cse(square**((exp - 1) // 2)) * new_base)
        elif exp == 1:
            return new_base
        elif exp < 0:
            return self.rec((1 / new_base)**(-exp))
        # exp == 0 falls through to the generic handler below.

    if (isinstance(exp, prim.Quotient)
            and isinstance(exp.numerator, int)
            and isinstance(exp.denominator, int)):
        p, q = exp.numerator, exp.denominator
        # Normalize so that the denominator is positive.
        if q < 0:
            p, q = -p, -q

        if q == 1:
            return self.rec(new_base**p)

        # p == 0 previously tripped an assert (skipped under -O); it now
        # simply falls through to the generic handler.
        if q == 2 and p != 0:
            if p > 0:
                root = prim.wrap_in_cse(prim.Variable("sqrt")(new_base))
            else:
                # "rsqrt" is presumably the reciprocal square root, so
                # rsqrt(b)**(-p) == b**(p/2) for negative p.
                root = prim.wrap_in_cse(prim.Variable("rsqrt")(expr.base))
                p = -p
            return self.rec(root**p)

    return IdentityMapper.map_power(self, expr)
def map_power(self, expr):
    """Lower the power ``expr`` into products, reciprocals and square roots.

    * Integer exponent ``n``:

      - ``n > 1`` is computed by repeated squaring; the base and its square
        are wrapped as common subexpressions so they are evaluated once.
      - ``n == 1`` returns the wrapped base.
      - ``n < 0`` is rewritten as ``(1/base)**(-n)``.
      - ``n == 0`` is left to the generic handler.
    * Exponent ``Quotient(p, q)`` with integer ``p`` and ``q``:

      - The sign is first moved into the numerator.
      - ``q == 1`` maps ``base**p``.
      - ``q == 2`` maps ``sqrt(base)**p`` for ``p > 0``, or
        ``rsqrt(base)**(-p)`` for ``p < 0``.
    * Everything else is delegated to ``IdentityMapper.map_power``.

    :arg expr: the power being mapped (must provide ``.base`` and
        ``.exponent``).
    :returns: the mapped expression.
    """
    exp = expr.exponent
    # Wrap the base once, before either branch: the rational-exponent branch
    # needs it too (it used to be bound only for integer exponents, so the
    # q == 1 path raised NameError).
    new_base = prim.wrap_in_cse(expr.base)

    if isinstance(exp, int):
        if exp > 1:
            square = prim.wrap_in_cse(new_base*new_base)
            if exp % 2 == 0:
                return self.rec(prim.wrap_in_cse(square**(exp//2)))
            # Odd exponent: one extra factor of the base on top of the square.
            return self.rec(prim.wrap_in_cse(square**((exp-1)//2))*new_base)
        elif exp == 1:
            return new_base
        elif exp < 0:
            return self.rec((1/new_base)**(-exp))
        # exp == 0 falls through to the generic handler below.

    if (isinstance(exp, prim.Quotient)
            and isinstance(exp.numerator, int)
            and isinstance(exp.denominator, int)):
        p, q = exp.numerator, exp.denominator
        # Normalize so that the denominator is positive.
        if q < 0:
            p, q = -p, -q

        if q == 1:
            return self.rec(new_base**p)

        # p == 0 previously tripped an assert (skipped under -O); it now
        # simply falls through to the generic handler.
        if q == 2 and p != 0:
            if p > 0:
                root = prim.wrap_in_cse(prim.Variable("sqrt")(new_base))
            else:
                # "rsqrt" is presumably the reciprocal square root, so
                # rsqrt(b)**(-p) == b**(p/2) for negative p.
                root = prim.wrap_in_cse(prim.Variable("rsqrt")(expr.base))
                p = -p
            return self.rec(root**p)

    return IdentityMapper.map_power(self, expr)