def __init__(self, warn_on_digit_loss=True, int_type=np.int64,
             float_type=np.float64):
    """Configure conversion of out-of-range integer constants.

    :arg warn_on_digit_loss: stored as ``self.warn``; controls whether a
        warning is issued when a converted value loses digits.
    :arg int_type: integer type whose range (via ``np.iinfo``) decides which
        integers are left alone; stored as ``self.iinfo``.
    :arg float_type: type used for converted values; stored as
        ``self.float_type``.
    """
    IdentityMapper.__init__(self)
    self.float_type = float_type
    self.warn = warn_on_digit_loss
    # Only the range limits of the integer type are needed later.
    self.iinfo = np.iinfo(int_type)
def map_constant(self, expr):
    """Split complex constants whose imaginary part does not survive a
    round trip through ``self.float_type`` (for loopy).

    Such a constant ``c`` is rewritten as ``c.real + Product((c.imag, 1j))``.
    All other constants (non-complex, or complex with an exactly
    representable imaginary part) are handled by :class:`IdentityMapper`.

    NOTE(review): only the imaginary part is checked; the real part is not
    tested for representability.
    """
    if isinstance(expr, complex):
        imag_round_trip = complex(self.float_type(expr.imag))
        if imag_round_trip != expr.imag:
            return expr.real + prim.Product((expr.imag, 1j))
    return IdentityMapper.map_constant(self, expr)
def map_variable(self, expr):
    """Turn whitelisted indexed-name variables into subscripts.

    If ``INDEXED_VAR_RE`` (defined elsewhere) matches the variable's name,
    group 1 is taken as a base name and group 2 as an integer index. When the
    base name is in ``self.name_whitelist``, ``Variable(base)[index]`` is
    returned; otherwise the variable is handled by :class:`IdentityMapper`.
    """
    match = INDEXED_VAR_RE.match(expr.name)
    if match is None:
        return IdentityMapper.map_variable(self, expr)

    base_name = match.group(1)
    # Parsed before the whitelist check, so a non-integer group 2 raises
    # regardless of whitelisting.
    index = int(match.group(2))

    if base_name not in self.name_whitelist:
        return IdentityMapper.map_variable(self, expr)
    return prim.Variable(base_name).index(index)
def map_power(self, expr):
    """Distribute powers over products and expand integer powers of sums.

    - ``(a*b*...)**e`` becomes ``a**e * b**e * ...``, which is then mapped
      again.
    - ``s**n`` for a (mapped) sum ``s`` and an integer ``n > 1`` becomes the
      product of ``n`` copies of ``s``, which :meth:`map_product` multiplies
      out.
    - Everything else is handled by :class:`IdentityMapper`.

    NOTE(review): the product rule is applied for any exponent, including
    non-integer ones, where it is only valid under positivity assumptions.
    """
    from pymbolic.primitives import Product, Sum

    if isinstance(expr.base, Product):
        # Iterate over the product's own factors. (This previously referenced
        # ``newbase``, which is only assigned further down, so the branch
        # always raised UnboundLocalError.)
        return self.rec(pymbolic.flattened_product(
            child**expr.exponent for child in expr.base.children))

    # Only n > 1 yields a genuine Product of sums. For n <= 1,
    # ``n*(newbase,)`` is empty or a single Sum, and map_product would be
    # handed a non-Product.
    if isinstance(expr.exponent, int) and expr.exponent > 1:
        newbase = self.rec(expr.base)
        if isinstance(newbase, Sum):
            return self.map_product(
                pymbolic.flattened_product(expr.exponent*(newbase,)))

    return IdentityMapper.map_power(self, expr)
def rec(self, expr):
    """Recurse into *expr*, hoisting constant subexpressions.

    A non-atomic expression flagged truthy in ``self.is_constant`` is replaced
    by a fresh variable from ``self.new_var_func()``. The mapping
    ``variable -> expr`` is recorded in ``self.assignments``. Everything else
    is handed to :meth:`IdentityMapper.rec`.
    """
    if not _is_atomic(expr) and self.is_constant[expr]:
        hoisted = self.new_var_func()
        self.assignments[hoisted] = expr
        return hoisted
    return IdentityMapper.rec(self, expr)
def map_product(self, expr):
    """Multiply a product out over its sums, collecting with ``self.collector``.

    The factors of *expr* are first mapped by :class:`IdentityMapper`. The
    resulting product is then distributed: at each level, the factor list is
    split at its first :class:`Sum`, and every summand is multiplied by the
    non-sum prefix and the (already expanded) remaining factors.
    """
    from pymbolic.primitives import Sum, Product

    def expand(prod):
        if not isinstance(prod, Product):
            return prod

        # Non-sum factors that precede the first Sum factor.
        leading = []
        for factor in prod.children:
            if isinstance(factor, Sum):
                break
            leading.append(factor)

        if len(leading) == len(prod.children):
            # no more sums found
            return pymbolic.flattened_product(prod.children)

        first_sum = prod.children[len(leading)]
        assert isinstance(first_sum, Sum)
        rest = prod.children[len(leading)+1:]
        if rest:
            rest = expand(Product(rest))
        else:
            rest = 1

        # Fold the leading factors into each summand *before* expanding.
        # Multiplying them on afterwards (``leading * expand(term*rest)``)
        # left terms like ``a*(x*z + x*w)`` undistributed whenever leading
        # factors and a later sum were both present, e.g. a*(x+y)*(z+w).
        return self.collector(pymbolic.flattened_sum(
            expand(pymbolic.flattened_product(leading + [term]) * rest)
            for term in first_sum.children))

    return expand(IdentityMapper.map_product(self, expr))
def map_sum(self, expr):
    """Map the sum's terms, then run ``self.collector`` if the result is still
    a :class:`Sum`; otherwise return the mapped result unchanged."""
    from pymbolic.primitives import Sum

    mapped = IdentityMapper.map_sum(self, expr)
    return self.collector(mapped) if isinstance(mapped, Sum) else mapped
def map_substitution(self, expr):
    """Rewrite derivatives of ``hankel_1``/``bessel_j`` calls as sums of calls.

    For ``k`` derivative variables, the call ``C(order, arg)`` becomes::

        2**(-k) * sum_{j=0}^{k} (-1)**j * binom(k, j) * C(order - k + 2*j, arg)

    (AS 9.1.31 / DLMF 10.6.7). The result is wrapped in a
    CommonSubexpression tagged ``d<k>_<function name>_<order>``, with negative
    orders spelled ``m<|order|>``. Substitutions of any other function are
    handled by :class:`IdentityMapper`. The substitution's child is asserted
    to be a :class:`Derivative`.
    """
    assert isinstance(expr.child, prim.Derivative)
    call = expr.child.child

    if not (isinstance(call.function, prim.Variable)
            and call.function.name in ["hankel_1", "bessel_j"]):
        return IdentityMapper.map_substitution(self, expr)

    function = call.function
    order, _ = call.parameters
    arg, = expr.values
    n_derivs = len(expr.child.variables)

    import sympy as sym

    # AS (9.1.31)
    # http://dlmf.nist.gov/10.6.7
    order_str = str(order) if order >= 0 else "m"+str(-order)

    k = n_derivs
    shifted_orders = range(order-k, order+k+1, 2)
    terms = (
        (-1)**j*int(sym.binomial(k, j)) * function(shifted, arg)
        for j, shifted in enumerate(shifted_orders))

    return prim.CommonSubexpression(
        2**(-k)*sum(terms),
        "d%d_%s_%s" % (n_derivs, function.name, order_str))
def map_substitution(self, expr):
    """Expand k-fold derivatives of ``hankel_1``/``bessel_j`` calls.

    Implements AS (9.1.31) / DLMF 10.6.7: the k-th derivative of
    ``C(order, arg)`` equals ``2**(-k)`` times the alternating,
    binomially-weighted sum of ``C(order - k + 2*j, arg)`` for ``j = 0..k``.
    The expansion is returned as a CommonSubexpression tagged
    ``d<k>_<name>_<order>`` (negative orders written ``m<|order|>``).
    Any other substituted function falls through to
    :class:`IdentityMapper`. The substitution's child is asserted to be a
    :class:`Derivative`.
    """
    assert isinstance(expr.child, prim.Derivative)
    inner_call = expr.child.child
    func = inner_call.function

    is_bessel_family = (isinstance(func, prim.Variable)
            and func.name in ["hankel_1", "bessel_j"])
    if not is_bessel_family:
        return IdentityMapper.map_substitution(self, expr)

    nu, _ = inner_call.parameters
    z, = expr.values
    derivative_count = len(expr.child.variables)

    import sympy as sp

    # AS (9.1.31)
    # http://dlmf.nist.gov/10.6.7
    if nu >= 0:
        nu_label = str(nu)
    else:
        nu_label = "m"+str(-nu)

    weighted_calls = [
        (-1)**j*int(sp.binomial(derivative_count, j)) * func(shifted_nu, z)
        for j, shifted_nu in enumerate(
            range(nu-derivative_count, nu+derivative_count+1, 2))]

    return prim.CommonSubexpression(
        2**(-derivative_count)*sum(weighted_calls),
        "d%d_%s_%s" % (derivative_count, func.name, nu_label))
def map_product(self, expr):
    """Distribute a product over its sums, collecting with ``self.collect``.

    The factors of *expr* are first mapped by :class:`IdentityMapper`. The
    product is then split at its first :class:`Sum` factor, and each summand
    is multiplied by the preceding non-sum factors and the (already
    distributed) remaining factors, recursively.
    """
    def dist(prod):
        if not isinstance(prod, Product):
            return prod

        # Non-sum factors preceding the first Sum factor.
        leading = []
        for factor in prod.children:
            if isinstance(factor, Sum):
                break
            leading.append(factor)

        if len(leading) == len(prod.children):
            # no more sums found
            return pymbolic.flattened_product(prod.children)

        first_sum = prod.children[len(leading)]
        assert isinstance(first_sum, Sum)
        rest = prod.children[len(leading)+1:]
        if rest:
            rest = dist(Product(rest))
        else:
            rest = 1

        # Multiply the leading factors into each summand *before* recursing.
        # Applying them afterwards left e.g. a*(x+y)*(z+w) as
        # a*(x*z + x*w) + a*(y*z + y*w), i.e. not fully distributed.
        return self.collect(pymbolic.flattened_sum(
            dist(pymbolic.flattened_product(leading + [term]) * rest)
            for term in first_sum.children))

    return dist(IdentityMapper.map_product(self, expr))
def map_call(self, expr):
    """Route ``hankel_1``/``bessel_j`` calls through ``self.bessel_getter``.

    For such a call ``f(order, arg)``, returns
    ``getattr(self.bessel_getter, f.name)(order, self.rec(arg))``; the order
    is passed through unmapped. All other calls are handled by
    :class:`IdentityMapper`.
    """
    func = expr.function
    if isinstance(func, prim.Variable) and func.name in ["hankel_1", "bessel_j"]:
        order, arg = expr.parameters
        getter = getattr(self.bessel_getter, func.name)
        return getter(order, self.rec(arg))
    return IdentityMapper.map_call(self, expr)
def map_subscript(self, expr):
    """Extend integer subscripts of ``self.vec_name`` by extra indices.

    ``vec[i]`` with an ``int`` index ``i`` becomes
    ``CommonSubexpression(vec[(i,) + self.additional_indices])``.
    Every other subscript is handled by :class:`IdentityMapper`.
    """
    from pymbolic.primitives import CommonSubexpression

    # The aggregate need not be a named variable (e.g. the inner subscript of
    # ``a[i][j]``); such aggregates have no ``.name`` and previously raised
    # AttributeError here instead of falling through.
    aggregate_name = getattr(expr.aggregate, "name", None)
    if (aggregate_name is not None
            and aggregate_name == self.vec_name
            and isinstance(expr.index, int)):
        return CommonSubexpression(
            expr.aggregate.index((expr.index, ) + self.additional_indices))
    return IdentityMapper.map_subscript(self, expr)
def map_subscript(self, expr):
    """Wrap integer-indexed accesses to ``self.vec_name`` in a CSE.

    The integer index is extended with ``self.additional_indices``, i.e.
    ``vec[i]`` becomes ``CSE(vec[(i,) + self.additional_indices])``.
    All other subscripts go to :class:`IdentityMapper`.
    """
    from pymbolic.primitives import CommonSubexpression

    # Aggregates such as a nested subscript have no ``.name``; treat them as
    # non-matching rather than raising AttributeError.
    aggregate_name = getattr(expr.aggregate, "name", None)
    if (aggregate_name is not None
            and aggregate_name == self.vec_name
            and isinstance(expr.index, int)):
        return CommonSubexpression(expr.aggregate.index(
            (expr.index,) + self.additional_indices))
    return IdentityMapper.map_subscript(self, expr)
def map_quotient(self, expr):
    """Fold quotients of two Python ints into a number.

    An exact division yields an ``int``. Otherwise the result is the true
    (float) quotient. A zero denominator raises ZeroDivisionError. Quotients
    that are not int/int are handled by :class:`IdentityMapper`.
    """
    numerator, denominator = expr.numerator, expr.denominator
    if not (isinstance(numerator, int) and isinstance(denominator, int)):
        return IdentityMapper.map_quotient(self, expr)

    whole, remainder = divmod(numerator, denominator)
    if remainder == 0:
        return whole
    return int(numerator) / int(denominator)
def test_mappers():
    """Smoke-test stringification and several mappers on a call that has
    both positional and keyword arguments."""
    from pymbolic import variables
    f, x, y, z = variables("f x y z")

    from pymbolic.mapper import WalkMapper
    from pymbolic.mapper.dependency import DependencyMapper

    call_with_kwargs = f(x, (y, z), name=z**2)
    for expr in (call_with_kwargs,):
        str(expr)
        for mapper_class in (IdentityMapper, WalkMapper, DependencyMapper):
            mapper_class()(expr)
def map_power(self, expr):
    """Rewrite powers with integer or rational (``p/1``, ``p/2``) exponents.

    - integer ``n > 1``: repeated squaring, each square wrapped in a CSE
    - integer ``n == 1``: the CSE-wrapped base, mapped
    - integer ``n < 0``: ``(1/base)**(-n)``, mapped again
    - Quotient ``p/1``: the integer power ``base**p``
    - Quotient ``p/2``: ``sqrt(base)**p`` if ``p > 0``, else
      ``rsqrt(base)**(-p)``

    Everything else (including ``n == 0``) goes to :class:`IdentityMapper`.
    """
    exp = expr.exponent

    if isinstance(exp, int):
        new_base = prim.wrap_in_cse(expr.base)
        if exp > 1 and exp % 2 == 0:
            # x**(2m) == (x*x)**m
            square = prim.wrap_in_cse(new_base * new_base)
            return self.rec(prim.wrap_in_cse(square**(exp // 2)))
        elif exp > 1 and exp % 2 == 1:
            # x**(2m+1) == (x*x)**m * x
            square = prim.wrap_in_cse(new_base * new_base)
            return self.rec(
                prim.wrap_in_cse(square**((exp - 1) // 2)) * new_base)
        elif exp == 1:
            # Map the base too, like every other branch does; previously the
            # base of an explicit ``x**1`` escaped this mapper unmapped.
            return self.rec(new_base)
        elif exp < 0:
            return self.rec((1 / new_base)**(-exp))

    if (isinstance(expr.exponent, prim.Quotient)
            and isinstance(expr.exponent.numerator, int)
            and isinstance(expr.exponent.denominator, int)):
        p, q = expr.exponent.numerator, expr.exponent.denominator
        if q < 0:
            q *= -1
            p *= -1

        if q == 1:
            # ``new_base`` is only bound in the integer branch above, so this
            # path used to raise UnboundLocalError; bind it here.
            new_base = prim.wrap_in_cse(expr.base)
            return self.rec(new_base**p)

        if q == 2:
            assert p != 0
            if p > 0:
                orig_base = prim.wrap_in_cse(expr.base)
                new_base = prim.wrap_in_cse(
                    prim.Variable("sqrt")(orig_base))
            else:
                new_base = prim.wrap_in_cse(
                    prim.Variable("rsqrt")(expr.base))
                p *= -1

            return self.rec(new_base**p)

    return IdentityMapper.map_power(self, expr)
def map_power(self, expr):
    """Expand integer and half-integer powers using CSE-wrapped pieces.

    Integer exponents use repeated squaring (``n > 1``), the mapped base
    (``n == 1``) or ``(1/base)**(-n)`` (``n < 0``). Quotient exponents
    ``p/1`` become ``base**p``. Exponents ``p/2`` become ``sqrt(base)**p``
    (``p > 0``) or ``rsqrt(base)**(-p)`` (``p < 0``). Any other exponent,
    including integer 0, is handled by :class:`IdentityMapper`.
    """
    exp = expr.exponent

    if isinstance(exp, int):
        new_base = prim.wrap_in_cse(expr.base)
        if exp > 1 and exp % 2 == 0:
            square = prim.wrap_in_cse(new_base*new_base)
            return self.rec(prim.wrap_in_cse(square**(exp//2)))
        elif exp > 1 and exp % 2 == 1:
            square = prim.wrap_in_cse(new_base*new_base)
            return self.rec(prim.wrap_in_cse(square**((exp-1)//2))*new_base)
        elif exp == 1:
            # Recurse so the base itself is mapped (it was returned unmapped).
            return self.rec(new_base)
        elif exp < 0:
            return self.rec((1/new_base)**(-exp))

    if (isinstance(expr.exponent, prim.Quotient)
            and isinstance(expr.exponent.numerator, int)
            and isinstance(expr.exponent.denominator, int)):
        p, q = expr.exponent.numerator, expr.exponent.denominator
        if q < 0:
            q *= -1
            p *= -1

        if q == 1:
            # Bind the base here: ``new_base`` is only assigned in the
            # integer branch, so this used to raise UnboundLocalError.
            new_base = prim.wrap_in_cse(expr.base)
            return self.rec(new_base**p)

        if q == 2:
            assert p != 0
            if p > 0:
                orig_base = prim.wrap_in_cse(expr.base)
                new_base = prim.wrap_in_cse(prim.Variable("sqrt")(orig_base))
            else:
                new_base = prim.wrap_in_cse(prim.Variable("rsqrt")(expr.base))
                p *= -1

            return self.rec(new_base**p)

    return IdentityMapper.map_power(self, expr)
def map_constant(self, expr):
    """Convert integers outside the range recorded in ``self.iinfo`` to
    ``self.float_type``.

    Non-integers (per ``is_integer``) are handled by :class:`IdentityMapper`.
    In-range integers are returned unchanged. When ``self.warn`` is set and
    the conversion changes the integer value, a warning is issued once, and
    ``self.warn`` is then cleared to suppress further warnings.
    """
    if not is_integer(expr):
        return IdentityMapper.map_constant(self, expr)

    lowest, highest = self.iinfo.min, self.iinfo.max
    if lowest <= expr <= highest:
        return expr

    as_float = self.float_type(expr)
    if self.warn and int(as_float) != int(expr):
        from warnings import warn
        warn("Converting '%d' to '%s' loses digits"
                % (expr, self.float_type.__name__))
        # Suppress further warnings.
        self.warn = False
    return as_float
def map_variable(self, expr):
    """Rename the variable ``pi`` to ``M_PI`` (presumably the C ``math.h``
    constant); all other variables go to :class:`IdentityMapper`."""
    if expr.name != "pi":
        return IdentityMapper.map_variable(self, expr)
    return prim.Variable("M_PI")
def map_sum(self, expr):
    """Map the sum's terms; pass a resulting :class:`Sum` through
    ``self.collect`` and return anything else as-is."""
    mapped = IdentityMapper.map_sum(self, expr)
    if not isinstance(mapped, Sum):
        return mapped
    return self.collect(mapped)
def __call__(self, expr, new_var_func):
    """Map *expr* and return ``(mapped_expr, assignments)``.

    Before mapping, the constant-ness table for *expr* is computed with
    ``self.constant_finding_mapper`` and ``self.assignments`` is reset to an
    empty dict. The dict is returned after mapping; presumably it is filled
    during mapping via :meth:`rec` — confirm against the class.
    """
    self.new_var_func = new_var_func
    self.is_constant = self.constant_finding_mapper(expr)
    self.assignments = {}
    return IdentityMapper.__call__(self, expr), self.assignments
def test_structure_preservation():
    """An identity mapping of a constant-only Sum compares equal to the input."""
    original = prim.Sum((5, 7))

    from pymbolic.mapper import IdentityMapper
    mapped = IdentityMapper()(original)

    assert original == mapped
def __init__(self, float_type=np.float32):
    """Set up the mapper.

    :arg float_type: scalar type stored as ``self.float_type``.
    """
    IdentityMapper.__init__(self)
    self.float_type = float_type