def test_constants():
    """Check how SymPyModule splits an expression into parameters and buffers.

    The plain float coefficient ``2.0`` must show up as a parameter. The
    constant wrapped in ``sympy.UnevaluatedExpr`` must show up as a buffer.
    ``sympy()`` must hand back the original expression unchanged.
    """
    sym = sympy.symbols('x')
    expression = 2.0 * sym + sympy.UnevaluatedExpr(1.0)
    module = sympytorch.SymPyModule(expressions=[expression])

    # Round-trip back to sympy.
    assert module.sympy() == [expression]

    # Free coefficient -> trainable parameter.
    parameter_values = {param.item() for param in module.parameters()}
    assert parameter_values == {2.0}

    # UnevaluatedExpr-wrapped constant -> buffer.
    buffer_values = {buf.item() for buf in module.buffers()}
    assert buffer_values == {1.0}
def _aggregate_sympy_constants(expr):
    """
    Aggregate constants and symbolic components within a sympy expression to
    separate sub-expressions.

    For an ``Add`` or ``Mul`` the arguments are split into two groups:

    * constant terms, i.e. those with no free symbols;
    * symbolic terms.

    Each group is re-combined with the expression's own operation and wrapped
    in :class:`sympy.UnevaluatedExpr`. The two wrapped groups are then joined
    with that same operation. For example, ``2*x*y`` becomes
    ``UnevaluatedExpr(2) * UnevaluatedExpr(x*y)``, and ``x + 3`` becomes the
    sum of ``UnevaluatedExpr(3)`` and ``UnevaluatedExpr(x)``.

    The constant group is dropped when it reduces to the operation's identity
    (``0`` for ``Add``, ``1`` for ``Mul``). When there are no symbolic terms,
    only the wrapped constant is returned. Any other kind of expression
    (atoms, functions, powers, ...) is wrapped as a whole.

    Parameters
    -----------
    expr : :class:`sympy.core.expr.Expr`
        Expression to aggregate. For matrices, use :func:`~sympy.Matrix.applyfunc`.

    Returns
    -------
    :class:`sympy.core.expr.Expr`
    """
    # Only associative n-ary operations can be split into two sub-groups;
    # calling e.g. ``Symbol()`` or ``Pow(2)`` on a partition would fail.
    if not isinstance(expr, (sympy.Add, sympy.Mul)):
        return sympy.UnevaluatedExpr(expr)

    op = expr.func
    const = op(*[term for term in expr.args if not term.free_symbols])
    symbolic_terms = [term for term in expr.args if term.free_symbols]

    if not symbolic_terms:
        return sympy.UnevaluatedExpr(const)

    symbolic = op(*symbolic_terms)
    # Compare with the identity rather than using truthiness: a constant of
    # 0 in a Mul is significant and must not be dropped.
    if const == op.identity:
        return sympy.UnevaluatedExpr(symbolic)
    # Recombine with the original operation (not always ``*``), so sums stay sums.
    return op(sympy.UnevaluatedExpr(const), sympy.UnevaluatedExpr(symbolic))
def handle_gcd_lcm(f, args):
    """
    Evaluate ``sympy.<f>`` on ``args`` and wrap the result in an UnevaluatedExpr.

    f: str - name of the sympy function to call ("gcd" or "lcm")
    args: List[Expr] - function arguments; each one is passed through
        sympy.nsimplify before the call, and all of them are handed to the
        function together as a single tuple
    """
    simplified = tuple(sympy.nsimplify(arg) for arg in args)
    # gcd() and lcm() don't support evaluate=False, so evaluate eagerly and
    # shield the result from further evaluation instead.
    func = getattr(sympy, f)
    result = func(simplified)
    return sympy.UnevaluatedExpr(result)
def __to_njit_function_k(sf, hsymbols, kx, ky, dtype=np.complex128):
    """
    Turn the sympy expression ``sf`` into a numerical function of ``hsymbols``.

    The expression is lambdified with the ``"numpy"`` backend, and its
    positional arguments are the symbols in ``list(hsymbols)`` order. The
    resulting function is wrapped with ``njit`` unless ``dtype`` is
    ``np.complex256``; in that case the plain lambdified function is returned.

    If ``sf`` depends on neither ``kx`` nor ``ky``, the term
    ``kx * ky * UnevaluatedExpr(0)`` is added first, so the generated code
    still references the momentum symbols.
    NOTE(review): presumably this keeps constant entries broadcasting over
    k-arrays instead of returning a scalar; confirm.

    Parameters
    ----------
    sf : sympy expression
    hsymbols : iterable of sympy symbols
        Function parameters, in ``list(hsymbols)`` order. NOTE(review): if
        this is a set, callers must pass arguments in the same iteration
        order; confirm.
    kx, ky : sympy symbols
        Momentum symbols.
    dtype : numpy dtype, optional
        ``np.complex256`` disables ``njit`` compilation.

    Returns
    -------
    callable
    """
    if not sf.free_symbols & {kx, ky}:
        # Here we have no k variables in sf. Expand sf by 0*kx*ky.
        sf = sf + kx * ky * sp.UnevaluatedExpr(0)

    func = lambdify(list(hsymbols), sf, "numpy")

    # np.complex256 only exists on platforms with an extended-precision long
    # double (it is absent e.g. on Windows); referencing it directly would
    # raise AttributeError there on every call.
    complex256 = getattr(np, "complex256", None)
    if complex256 is not None and dtype == complex256:
        # Presumably njit cannot compile extended-precision complex code.
        return func
    return njit(func)
def __to_njit_function_kp(sf, hsymbols, kx, ky, kxp, kyp, dtype=np.complex128):
    """
    Turn ``sf`` into a numerical function of ``hsymbols`` plus ``kxp``/``kyp``.

    ``kxp`` and ``kyp`` are added to ``hsymbols`` via ``hsymbols.union``, so
    ``hsymbols`` must support ``union`` (e.g. a set). NOTE(review): ``kx`` and
    ``ky`` are presumably already members of ``hsymbols``; confirm.

    The expression is lambdified with the ``"numpy"`` backend, and its
    positional arguments are the symbols in ``list(hsymbols)`` order. The
    resulting function is wrapped with ``njit`` unless ``dtype`` is
    ``np.complex256``; in that case the plain lambdified function is returned.

    If ``sf`` depends on none of ``kx``, ``ky``, ``kxp`` or ``kyp``, the term
    ``kx * ky * kxp * kyp * UnevaluatedExpr(0)`` is added first, so the
    generated code still references all momentum symbols.

    Parameters
    ----------
    sf : sympy expression
    hsymbols : set of sympy symbols
    kx, ky, kxp, kyp : sympy symbols
        Momentum symbols.
    dtype : numpy dtype, optional
        ``np.complex256`` disables ``njit`` compilation.

    Returns
    -------
    callable
    """
    hsymbols = hsymbols.union({kxp, kyp})

    if not sf.free_symbols & {kx, ky, kxp, kyp}:
        # No k variables in sf: expand sf by 0*kx*ky*kxp*kyp.
        sf = sf + kx * ky * kxp * kyp * sp.UnevaluatedExpr(0)

    func = lambdify(list(hsymbols), sf, "numpy")

    # np.complex256 only exists on platforms with an extended-precision long
    # double (it is absent e.g. on Windows); referencing it directly would
    # raise AttributeError there on every call.
    complex256 = getattr(np, "complex256", None)
    if complex256 is not None and dtype == complex256:
        # Presumably njit cannot compile extended-precision complex code.
        return func
    return njit(func)
def _pretty(self, printer):
    """Pretty-print hook: render ``num / den`` built from ``self.parameters``.

    The numerator is the sum of ``2**k * b[i] * z**-i`` over the nonzero
    coefficients ``b[i]`` of ``self.parameters.b``. The denominator is
    ``a[0]`` plus the sum of ``2**k * a[i] * z**-i`` over the nonzero
    ``a[i]`` (``i >= 1``) of ``self.parameters.a``. ``k`` is read from
    ``self.parameters.k`` and defaults to 0 when absent.
    NOTE(review): presumably a filter transfer function in z**-1; confirm.

    Each term is printed individually with ``printer._print``. The printed
    forms are then combined with ``+`` and ``/``, so this assumes the printer
    returns objects that overload those operators (e.g. sympy's prettyForm).
    With a printer returning plain ``str``, ``+`` would concatenate without an
    operator and ``/`` would raise TypeError. TODO confirm which printers call
    this.
    """
    # Wrapped so sympy keeps z as an opaque factor in products/powers.
    z = sympy.UnevaluatedExpr(sympy.Symbol('z'))
    k = getattr(self.parameters, 'k', 0)
    num = None
    for i, c in enumerate(self.parameters.b):
        if c:  # zero coefficients are skipped entirely
            # np.ldexp(x, k) == x * 2**k
            term = printer._print(
                np.ldexp(float(c), k) * z**-i)
            # NOTE(review): relies on truthiness of the printed form; a falsy
            # accumulated form would be replaced rather than extended.
            num = num + term if num else term
    if num is None:
        # All numerator coefficients were zero.
        num = 0
    # NOTE(review): a[0] is printed without the 2**k scaling applied to every
    # other coefficient (including b[0]); confirm this is intentional.
    den = printer._print(float(self.parameters.a[0]))
    for i, c in enumerate(self.parameters.a[1:], start=1):
        if c:
            den += printer._print(np.ldexp(float(c), k) * z**-i)
    # num/den are already printed forms (or the int 0); they are passed
    # through printer._print once more before dividing.
    return printer._print(num) / printer._print(den)
def _hide_floats(expression, _memodict=None):
    """
    Return a copy of ``expression`` with every ``sympy.Float`` wrapped in
    ``sympy.UnevaluatedExpr``.

    Leaf nodes other than floats (Integer, Rational, Symbol, constants, ...)
    are returned unchanged. Compound expressions are rebuilt with
    ``expression.func`` from their recursively transformed arguments.

    Parameters
    ----------
    expression : sympy.Basic
        Expression tree to transform.
    _memodict : dict, optional
        Cache mapping already-processed sub-expressions to their results. It
        is shared across the recursion so repeated sub-trees are transformed
        only once. A fresh dict is used when omitted.

    Returns
    -------
    sympy.Basic
    """
    if _memodict is None:
        _memodict = {}
    try:
        return _memodict[expression]
    except KeyError:
        pass

    if issubclass(expression.func, sympy.Float):
        new_expression = sympy.UnevaluatedExpr(expression)
    elif not expression.args:
        # Any other leaf: keep it as-is. Rebuilding via ``expression.func()``
        # fails for leaves whose constructor needs arguments (e.g. Rational(1, 3)).
        new_expression = expression
    else:
        new_expression = expression.func(
            *[_hide_floats(arg, _memodict) for arg in expression.args])

    _memodict[expression] = new_expression
    return new_expression
def pull_out(self, expr):
    """Factor ``expr`` out of this named expression.

    Returns a new ``NamedExpression`` with the same name. Its value is
    ``expr`` times ``sp.simplify(self.expr / expr)``, with both factors
    wrapped in ``sp.UnevaluatedExpr`` so they stay visibly separate.
    """
    # NB: This ignores the subclass and just returns a NamedExpression:
    name = self.name
    factor = sp.UnevaluatedExpr(expr)
    remainder = sp.UnevaluatedExpr(sp.simplify(self.expr / expr))
    return NamedExpression(name, factor * remainder)
#
# Generate macro
#


def pow_to_mul(base, exp):
    """Expand ``base**exp`` into a chain of unevaluated multiplications.

    Returns ``base`` itself for ``exp == 1``. Otherwise returns the
    right-nested ``Mul(base, Mul(base, ...), evaluate=False)`` with ``exp``
    factors of ``base``.

    Raises
    ------
    ValueError
        If ``exp`` is less than 1.
    """
    # Explicit check instead of ``assert``: asserts vanish under ``-O``, and a
    # non-positive exponent would then never terminate.
    if exp < 1:
        raise ValueError(f"pow_to_mul requires exp >= 1, got {exp}")
    result = base
    # Build inside-out so the nesting matches Mul(base, Mul(base, base)).
    for _ in range(exp - 1):
        result = sympy.Mul(base, result, evaluate=False)
    return result


# Expand positive integer powers into multiplications. Negative powers
# (e.g. 1/x == x**-1) cannot be expressed by pow_to_mul and are left as Pow.
pow_to_mul_optim = sympy_rewriting.ReplaceOptim(
    lambda p: p.is_Pow and p.exp.is_Integer and p.exp.is_positive,
    lambda p: sympy.UnevaluatedExpr(pow_to_mul(p.base, int(p.exp))))


def preprocess(expr):
    """Prepare ``expr`` for C code generation.

    * ``pi`` is replaced by the symbol ``SH_PI``.
    * Every ``sin(...)`` is replaced by ``SH_SIN_THETA`` and every
      ``cos(...)`` by ``SH_COS_THETA``, whatever their argument.
    * Positive integer powers are expanded into multiplications.
    """
    expr = expr.replace(sympy.pi, sympy.symbols('SH_PI'))
    expr = expr.replace(sympy.sin, lambda _: sympy.symbols('SH_SIN_THETA'))
    expr = expr.replace(sympy.cos, lambda _: sympy.symbols('SH_COS_THETA'))
    expr = pow_to_mul_optim(expr)
    return expr


def generate_list(name, exprs):
    """Return a C X-macro ``#define name(_)`` applying ``_`` to each expression.

    Each expression is rendered with ``sympy.printing.ccode`` as ``_(code)``.
    Entries are joined with backslash-newline line continuations.
    """
    lines = [f"#define {name}(_)"]
    lines.extend(f" _({sympy.printing.ccode(expr)})" for expr in exprs)
    return " \\\n".join(lines)