def get_cse(self, expr, key=None):
    """Return the canonical common-subexpression wrapper for *expr*.

    Results are memoized in ``self.canonical_subexprs``. When *key* is
    ``None`` the cache key is computed with ``self.get_key(expr)``.

    On a cache miss, *expr* is mapped by the :class:`IdentityMapper`
    handler named by ``expr.mapper_method``. That handler is called on
    ``IdentityMapper`` directly, not through ``self``, so any override
    on this class is bypassed. The result is wrapped with
    ``prim.wrap_in_cse``, stored under the key, and returned.
    """
    if key is None:
        key = self.get_key(expr)

    try:
        existing = self.canonical_subexprs[key]
    except KeyError:
        base_handler = getattr(IdentityMapper, expr.mapper_method)
        existing = prim.wrap_in_cse(base_handler(self, expr))
        self.canonical_subexprs[key] = existing

    return existing
def get_cse(self, expr, key=None):
    """Look up (or build and cache) the CSE-wrapped form of *expr*.

    :param expr: expression to canonicalize.
    :param key: cache key; defaults to ``self.get_key(expr)``.
    :returns: the object stored in ``self.canonical_subexprs`` for the key.
        On first use, that object is ``prim.wrap_in_cse`` applied to the
        result of the ``IdentityMapper`` method named by
        ``expr.mapper_method``. The unbound ``IdentityMapper`` method is
        used deliberately so that this class's overrides are skipped.
    """
    lookup_key = key if key is not None else self.get_key(expr)
    cache = self.canonical_subexprs

    try:
        result = cache[lookup_key]
    except KeyError:
        mapped = getattr(IdentityMapper, expr.mapper_method)(self, expr)
        result = prim.wrap_in_cse(mapped)
        self.canonical_subexprs[lookup_key] = result

    return result
def map_common_subexpression(self, expr):
    """Map the child of a common-subexpression node without nesting CSEs.

    * For an exact ``prim.CommonSubexpression``, the mapped child is
      re-wrapped with ``prim.wrap_in_cse`` (keeping ``expr.prefix``)
      instead of being passed to the constructor. This avoids creating
      ``CSE(CSE(...))``.
    * For a subclass of ``CommonSubexpression``, a plain CSE produced by
      mapping the child is unwrapped. The node is then rebuilt with its
      own type, its prefix, and ``expr.get_extra_properties()``.
    """
    node_type = type(expr)

    if node_type is prim.CommonSubexpression:
        mapped = self.rec(expr.child)
        return prim.wrap_in_cse(mapped, expr.prefix)

    # Derived CSE type: exact type checks (not isinstance) are deliberate,
    # so only a *plain* CSE child is stripped.
    mapped = self.rec(expr.child)
    inner = mapped.child if type(mapped) is prim.CommonSubexpression else mapped
    return node_type(inner, expr.prefix, **expr.get_extra_properties())
def map_power(self, expr):
    """Rewrite integer and half-integer powers into CSE-friendly forms.

    Below, ``b`` is ``prim.wrap_in_cse(expr.base)``.

    Integer exponent ``n``:
      * ``n > 1`` and even: ``CSE(CSE(b*b) ** (n // 2))``, i.e. repeated squaring.
      * ``n > 1`` and odd: ``CSE(CSE(b*b) ** ((n - 1) // 2)) * b``.
      * ``n == 1``: ``b`` itself (returned without further mapping).
      * ``n < 0``: ``(1 / b) ** -n``.
      * ``n == 0``: falls through to ``IdentityMapper.map_power``.

    ``prim.Quotient`` exponent ``p/q`` with int numerator and denominator
    (signs normalized so that ``q > 0``):
      * ``q == 1``: ``b ** p``.
      * ``q == 2``: ``sqrt(b) ** p`` if ``p > 0``, else ``rsqrt(base) ** -p``.
        ``p == 0`` trips an assertion.

    Rewritten expressions are fed back through ``self.rec``. Everything
    else is delegated to ``IdentityMapper.map_power``.
    """
    exp = expr.exponent

    if isinstance(exp, int):
        new_base = prim.wrap_in_cse(expr.base)
        if exp > 1 and exp % 2 == 0:
            square = prim.wrap_in_cse(new_base * new_base)
            return self.rec(prim.wrap_in_cse(square**(exp // 2)))
        elif exp > 1 and exp % 2 == 1:
            square = prim.wrap_in_cse(new_base * new_base)
            return self.rec(
                prim.wrap_in_cse(square**((exp - 1) // 2)) * new_base)
        elif exp == 1:
            return new_base
        elif exp < 0:
            return self.rec((1 / new_base)**(-exp))
        # exp == 0: fall through to the generic handler below.

    if (isinstance(exp, prim.Quotient)
            and isinstance(exp.numerator, int)
            and isinstance(exp.denominator, int)):
        p, q = exp.numerator, exp.denominator
        if q < 0:
            q = -q
            p = -p

        if q == 1:
            # Previously used ``new_base`` here, which is only bound in the
            # integer-exponent branch above -> UnboundLocalError for p/1.
            return self.rec(prim.wrap_in_cse(expr.base)**p)

        if q == 2:
            assert p != 0
            if p > 0:
                orig_base = prim.wrap_in_cse(expr.base)
                new_base = prim.wrap_in_cse(prim.Variable("sqrt")(orig_base))
            else:
                new_base = prim.wrap_in_cse(prim.Variable("rsqrt")(expr.base))
                p = -p
            return self.rec(new_base**p)

    return IdentityMapper.map_power(self, expr)
def map_power(self, expr):
    """Rewrite integer and half-integer powers into CSE-friendly forms.

    Below, ``b`` is ``prim.wrap_in_cse(expr.base)``.

    Integer exponent ``n``:
      * ``n > 1`` and even: ``CSE(CSE(b*b) ** (n // 2))``, i.e. repeated squaring.
      * ``n > 1`` and odd: ``CSE(CSE(b*b) ** ((n - 1) // 2)) * b``.
      * ``n == 1``: ``b`` itself (returned without further mapping).
      * ``n < 0``: ``(1 / b) ** -n``.
      * ``n == 0``: falls through to ``IdentityMapper.map_power``.

    ``prim.Quotient`` exponent ``p/q`` with int numerator and denominator
    (signs normalized so that ``q > 0``):
      * ``q == 1``: ``b ** p``.
      * ``q == 2``: ``sqrt(b) ** p`` if ``p > 0``, else ``rsqrt(base) ** -p``.
        ``p == 0`` trips an assertion.

    Rewritten expressions are fed back through ``self.rec``. Everything
    else is delegated to ``IdentityMapper.map_power``.
    """
    exp = expr.exponent

    if isinstance(exp, int):
        new_base = prim.wrap_in_cse(expr.base)
        if exp > 1 and exp % 2 == 0:
            square = prim.wrap_in_cse(new_base*new_base)
            return self.rec(prim.wrap_in_cse(square**(exp//2)))
        elif exp > 1 and exp % 2 == 1:
            square = prim.wrap_in_cse(new_base*new_base)
            return self.rec(prim.wrap_in_cse(square**((exp-1)//2))*new_base)
        elif exp == 1:
            return new_base
        elif exp < 0:
            return self.rec((1/new_base)**(-exp))
        # exp == 0: fall through to the generic handler below.

    if (isinstance(exp, prim.Quotient)
            and isinstance(exp.numerator, int)
            and isinstance(exp.denominator, int)):
        p, q = exp.numerator, exp.denominator
        if q < 0:
            q = -q
            p = -p

        if q == 1:
            # Previously used ``new_base`` here, which is only bound in the
            # integer-exponent branch above -> UnboundLocalError for p/1.
            return self.rec(prim.wrap_in_cse(expr.base)**p)

        if q == 2:
            assert p != 0
            if p > 0:
                orig_base = prim.wrap_in_cse(expr.base)
                new_base = prim.wrap_in_cse(prim.Variable("sqrt")(orig_base))
            else:
                new_base = prim.wrap_in_cse(prim.Variable("rsqrt")(expr.base))
                p = -p
            return self.rec(new_base**p)

    return IdentityMapper.map_power(self, expr)
def wrap_in_cse(self, expr, prefix):
    """Return a CSE wrapper for *expr*, reusing a cached one when present.

    A fresh wrapper is built with ``prim.wrap_in_cse(expr, prefix)``. It is
    only kept if ``self.cse_cache`` has no entry for *expr* yet; otherwise
    the cached entry is returned.

    NOTE(review): the cache is keyed on *expr* alone, so a later call with
    a different *prefix* gets the earlier wrapper. This looks intentional
    (one canonical CSE per expression) but should be confirmed.
    """
    fresh_cse = prim.wrap_in_cse(expr, prefix)
    cache = self.cse_cache
    return cache.setdefault(expr, fresh_cse)