def _solve_relational(r):
    """Solve the relational ``r`` for the enclosing scope's ``sym``.

    A relational that does not contain ``sym`` is first handed to
    ``nonsymfail``.  The result of ``_solve_inequality`` is then
    normalised:

    * an equality ``Eq(sym, rhs)`` becomes ``S.false``,
    * ``Ne(sym, rhs)`` becomes ``(sym < rhs) | (sym > rhs)``, or ``S.true``
      when building those inequalities raises ``TypeError``,
    * the compound condition ``(-oo < sym) & (sym < oo)`` becomes ``S.true``,
    * anything else is returned as solved.

    Raises ``NotImplementedError`` when the solved relational does not have
    ``sym`` alone on its left-hand side.
    """
    if sym not in r.free_symbols:
        nonsymfail(r)
    solved = _solve_inequality(r, sym)

    if not isinstance(solved, Relational):
        # a doubly unbounded range places no restriction on sym
        whole_line = (S.NegativeInfinity < sym) & (sym < S.Infinity)
        return S.true if solved == whole_line else solved

    rhs_symbols = solved.args[1].free_symbols
    if solved.args[0] != sym or sym in rhs_symbols:
        raise NotImplementedError(filldedent(''' Unable to solve relational %s for %s.''' % (r, sym)))

    kind = solved.rel_op
    if kind == '==':
        # this equality has been affirmed to have the form
        # Eq(sym, rhs) where rhs is sym-free; it represents
        # a zero-width interval which will be ignored
        # whether it is an isolated condition or contained
        # within an And or an Or
        return S.false
    if kind == '!=':
        try:
            return Or(sym < solved.rhs, sym > solved.rhs)
        except TypeError:
            # e.g. x != I ==> all real x satisfy
            return S.true
    return solved
def _solve_relational(r):
    """Reduce the relational ``r`` to a condition on ``sym`` alone.

    ``sym``, ``nonsymfail`` and ``_solve_inequality`` come from the
    enclosing scope.  Equalities in ``sym`` collapse to ``S.false``,
    ``Ne`` is split into a strict ``Or`` (or ``S.true`` when the split
    raises ``TypeError``), and a condition equal to
    ``(-oo < sym) & (sym < oo)`` collapses to ``S.true``.

    Raises ``NotImplementedError`` if ``sym`` could not be isolated on the
    left-hand side of the solved relational.
    """
    if sym not in r.free_symbols:
        nonsymfail(r)
    result = _solve_inequality(r, sym)

    if not isinstance(result, Relational):
        if result == (S.NegativeInfinity < sym) & (sym < S.Infinity):
            return S.true
        return result

    free = result.args[1].free_symbols
    isolated = result.args[0] == sym and sym not in free
    if not isolated:
        raise NotImplementedError(
            filldedent(''' Unable to solve relational %s for %s.''' % (r, sym)))

    if result.rel_op == '==':
        # this equality has been affirmed to have the form
        # Eq(sym, rhs) where rhs is sym-free; it represents
        # a zero-width interval which will be ignored
        # whether it is an isolated condition or contained
        # within an And or an Or
        return S.false
    if result.rel_op != '!=':
        return result
    try:
        return Or(sym < result.rhs, sym > result.rhs)
    except TypeError:
        # e.g. x != I ==> all real x satisfy
        return S.true
def process_conds(cond):
    """
    Turn ``cond`` into a strip (a, b), and auxiliary conditions.

    ``cond`` is put into conjunctive normal form.  Each clause whose
    disjuncts can be solved for ``re(s)`` tightens the lower bound ``a`` or
    the upper bound ``b``; a clause that yields neither is and-ed into the
    auxiliary condition.  ``s`` is taken from the enclosing scope.

    Returns the triple ``(a, b, aux)``.
    """
    strip_lo = -oo
    strip_hi = oo
    extra = True
    clauses = conjuncts(to_cnf(cond))
    # real stand-in for re(s) so the inequalities can be solved for it
    t = Dummy('t', real=True)
    for clause in clauses:
        clause_lo = oo
        clause_hi = -oo
        unresolved = []
        for term in disjuncts(clause):
            real_term = term.replace(
                re, lambda z: z.as_real_imag()[0]).subs(re(s), t)
            usable = (term.is_Relational
                      and term.rel_op in ('<', '<=')
                      and not real_term.has(s)
                      and real_term.has(t))
            if not usable:
                unresolved.append(term)
                continue
            solved = _solve_inequality(real_term, t)
            if not (solved.is_Relational and solved.rel_op in ('<', '<=')):
                unresolved.append(term)
                continue
            if solved.lhs == t:
                # t < rhs: an upper bound
                clause_hi = Max(solved.rhs, clause_hi)
            else:
                # lhs < t: a lower bound
                clause_lo = Min(solved.lhs, clause_lo)
        if clause_lo != oo and clause_lo != strip_hi:
            strip_lo = Max(clause_lo, strip_lo)
        elif clause_hi != -oo and clause_hi != strip_lo:
            strip_hi = Min(clause_hi, strip_hi)
        else:
            extra = And(extra, Or(*unresolved))
    return strip_lo, strip_hi, extra
def range_from_relationals(
    and_expr: typing.Union[sympy.And, sympy.Rel], gen: sympy.Symbol
) -> typing.Tuple[typing.Optional[sympy.Basic], typing.Optional[sympy.Basic]]:
    """
    Extract the range of ``gen`` described by one relational or by a
    conjunction of relationals.

    :param and_expr: a single relational, or an ``And`` whose args are all
        relationals containing ``gen``
    :param gen: the symbol each relational is solved for
    :return: ``(start, end)``; both ends are inclusive, and an end is ``None``
        when no relational constrains it.  For an equality, ``start`` and
        ``end`` are the same value.  (assume integer; assume simplified)

    The input is expected to be simplified already: every relational must
    solve to ``gen >= rhs``, ``gen <= rhs`` or ``gen == rhs``, each operator
    may occur at most once, and ``==`` may not be combined with another
    operator.  Violations trip an assertion.
    """
    if isinstance(and_expr, sympy.Rel):
        args = [and_expr]
    else:
        assert isinstance(and_expr, sympy.And)
        args = and_expr.args
    assert all(
        isinstance(part, sympy.Rel) and gen in part.free_symbols
        for part in args)
    rel_ops = (">=", "<=", "==")
    rhs_by_c = {}
    for part in args:
        part = _solve_inequality(part, gen)
        assert isinstance(part, sympy.Rel)
        assert part.lhs == gen
        rel_op, rhs = part.rel_op, part.rhs
        assert rel_op in rel_ops
        assert rel_op not in rhs_by_c  # not simplified?
        rhs_by_c[rel_op] = rhs
    if "==" in rhs_by_c:
        assert set(rhs_by_c) == {"=="}  # only op. not simplified?
        return rhs_by_c["=="], rhs_by_c["=="]
    return rhs_by_c.get(">="), rhs_by_c.get("<=")
def _laplace_transform(f, t, s, simplify=True):
    """ The backend function for laplace transforms.

    Integrates ``exp(-s*t)*f`` over ``t`` from 0 to ``oo``.  When the result
    is conditional, the condition of its first branch is converted into a
    lower bound on ``re(s)`` plus leftover conditions.

    Returns the triple ``(F, a, aux)``: the transform (passed through
    ``_simplify``), the bound ``a`` (``-oo`` when none was found) and the
    conjunction ``aux`` of conditions that could not be turned into a bound.

    Raises ``IntegralTransformError`` if the integral is returned
    unevaluated, if its first branch still contains an ``Integral``, or if
    a condition bounds ``re(s)`` from above.
    """
    from sympy import (re, Max, exp, pi, Abs, Min, periodic_argument as arg,
                       cos, Wild, symbols)
    F = integrate(exp(-s*t) * f, (t, 0, oo))

    if not F.has(Integral):
        return _simplify(F, simplify), -oo, True

    if not F.is_Piecewise:
        raise IntegralTransformError('Laplace', f, 'could not compute integral')

    F, cond = F.args[0]
    if F.has(Integral):
        raise IntegralTransformError('Laplace', f, 'integral in unexpected form')

    a = -oo
    aux = True
    conds = conjuncts(to_cnf(cond))
    p, q, w1, w2, w3 = symbols('p q w1 w2 w3', cls=Wild, exclude=[s])
    for c in conds:
        a_ = oo
        aux_ = []
        for d in disjuncts(c):
            m = d.match(abs(arg((s + w3)**p*q, w1)) < w2)
            if m:
                # Query the assumption instead of truth-testing ``m[q] > 0``:
                # for a symbolic match the comparison stays unevaluated and
                # cannot be used as a Python boolean.
                if m[q].is_positive and m[w2]/m[p] == pi/2:
                    d = re(s + m[w3]) > 0
            m = d.match(0 < cos(abs(arg(s, q)))*abs(s) - p)
            if m:
                d = re(s) > m[p]
            # NOTE(review): ``t`` doubles as the stand-in for re(s) here;
            # presumably safe because the definite integral no longer
            # depends on t -- confirm.
            d_ = d.replace(
                re, lambda x: x.expand().as_real_imag()[0]).subs(re(s), t)
            if not d.is_Relational or d.rel_op not in ('<', '<=') \
                    or d_.has(s) or not d_.has(t):
                aux_.append(d)
                continue
            soln = _solve_inequality(d_, t)
            if not soln.is_Relational or soln.rel_op not in ('<', '<='):
                aux_.append(d)
                continue
            if soln.lhs == t:
                # an upper bound on re(s) is not a half-plane of convergence
                raise IntegralTransformError('Laplace', f,
                                             'convergence not in half-plane?')
            a_ = Min(soln.lhs, a_)
        if a_ != oo:
            a = Max(a_, a)
        else:
            aux = And(aux, Or(*aux_))

    return _simplify(F, simplify), a, aux
def _mellin_transform(f, x, s_, integrator=_default_integrator, simplify=True):
    """ Backend function to compute mellin transforms.

    Calls ``integrator(x**(s - 1)*f, x)`` with a fresh dummy ``s`` and, when
    the result is conditional, converts the condition of its first branch
    into a strip ``(a, b)`` for ``re(s)`` plus leftover conditions.

    Returns the triple ``(F, (a, b), aux)`` where ``F`` has ``s_``
    substituted for the dummy and has been passed through ``_simplify``.

    Raises ``IntegralTransformError`` if the integral is returned
    unevaluated, if its first branch still contains an ``Integral``, or if
    the auxiliary conditions reduce to false.
    """
    from sympy import re, Max, Min
    # We use a fresh dummy, because assumptions on s might drop conditions on
    # convergence of the integral.
    s = _dummy('s', 'mellin-transform', f)
    F = integrator(x**(s - 1) * f, x)

    if not F.has(Integral):
        return _simplify(F.subs(s, s_), simplify), (-oo, oo), True

    if not F.is_Piecewise:
        raise IntegralTransformError('Mellin', f, 'could not compute integral')

    F, cond = F.args[0]
    if F.has(Integral):
        raise IntegralTransformError('Mellin', f, 'integral in unexpected form')

    a = -oo
    b = oo
    aux = True
    conds = conjuncts(to_cnf(cond))
    # real stand-in for re(s) so the inequalities can be solved for it
    t = Dummy('t', real=True)
    for c in conds:
        a_ = oo
        b_ = -oo
        aux_ = []
        for d in disjuncts(c):
            d_ = d.replace(re, lambda x: x.as_real_imag()[0]).subs(re(s), t)
            if not d.is_Relational or d.rel_op not in ('<', '<=') \
                    or d_.has(s) or not d_.has(t):
                aux_.append(d)
                continue
            soln = _solve_inequality(d_, t)
            if not soln.is_Relational or soln.rel_op not in ('<', '<='):
                aux_.append(d)
                continue
            if soln.lhs == t:
                b_ = Max(soln.rhs, b_)
            else:
                a_ = Min(soln.lhs, a_)
        if a_ != oo and a_ != b:
            a = Max(a_, a)
        elif b_ != -oo and b_ != a:
            b = Min(b_, b)
        else:
            aux = And(aux, Or(*aux_))

    # Equality, not identity: ``And`` may return SymPy's ``S.false``, which
    # is a different object from Python's ``False`` but compares equal to it.
    if aux == False:
        raise IntegralTransformError('Mellin', f, 'no convergence found')

    return _simplify(F.subs(s, s_), simplify), (a, b), aux
def simp_heaviside(arg):
    """Rewrite ``Heaviside(arg)`` as a step in ``t``.

    ``exp(-t)`` in ``arg`` is replaced by ``u`` (both names come from the
    enclosing scope).  If ``t`` still occurs afterwards the step is returned
    unchanged.  Otherwise ``arg > 0`` is solved for ``u`` and the bound
    found is converted back through ``log``.
    """
    in_u = arg.subs(exp(-t), u)
    if in_u.has(t):
        return Heaviside(arg)
    bound = _solve_inequality(in_u > 0, u)
    if bound.lhs == u:
        shift = log(bound.rhs)
        return Heaviside(t + shift)
    shift = log(bound.lhs)
    return Heaviside(-(t + shift))
def process_conds(conds):
    """ Turn ``conds`` into a strip and auxiliary conditions.

    ``conds`` is put into conjunctive normal form.  Disjuncts that match a
    known ``arg``/``cos`` convergence pattern are rewritten as conditions on
    ``re(s)``; each clause then either raises the lower bound ``a`` or is
    and-ed into the auxiliary condition.  ``s``, ``t`` and ``f`` are taken
    from the enclosing scope.

    Returns the pair ``(a, aux)``.

    Raises ``IntegralTransformError`` if a condition bounds ``re(s)`` from
    above.
    """
    a = -oo
    aux = True
    conds = conjuncts(to_cnf(conds))
    p, q, w1, w2, w3, w4, w5 = symbols('p q w1 w2 w3 w4 w5', cls=Wild,
                                       exclude=[s])
    for c in conds:
        a_ = oo
        aux_ = []
        for d in disjuncts(c):
            m = d.match(abs(arg((s + w3)**p*q, w1)) < w2)
            if not m:
                m = d.match(abs(arg((s + w3)**p*q, w1)) <= w2)
            if not m:
                m = d.match(abs(arg((polar_lift(s + w3))**p*q, w1)) < w2)
            if not m:
                m = d.match(abs(arg((polar_lift(s + w3))**p*q, w1)) <= w2)
            if m:
                # Query the assumption instead of truth-testing ``m[q] > 0``:
                # for a symbolic match the comparison stays unevaluated and
                # cannot be used as a Python boolean.
                if m[q].is_positive and m[w2]/m[p] == pi/2:
                    d = re(s + m[w3]) > 0
            m = d.match(
                0 < cos(abs(arg(s**w1*w5, q))*w2)*abs(s**w3)**w4 - p)
            if not m:
                m = d.match(
                    0 < cos(abs(arg(polar_lift(s)**w1*w5, q))*w2)
                    * abs(s**w3)**w4 - p)
            if m and all(m[wild].is_positive
                         for wild in [w1, w2, w3, w4, w5]):
                d = re(s) > m[p]
            # NOTE(review): ``t`` from the enclosing scope doubles as the
            # stand-in for re(s) -- confirm it cannot occur in ``conds``.
            d_ = d.replace(
                re, lambda x: x.expand().as_real_imag()[0]).subs(re(s), t)
            if not d.is_Relational or d.rel_op not in ('<', '<=') \
                    or d_.has(s) or not d_.has(t):
                aux_.append(d)
                continue
            soln = _solve_inequality(d_, t)
            if not soln.is_Relational or soln.rel_op not in ('<', '<='):
                aux_.append(d)
                continue
            if soln.lhs == t:
                # an upper bound on re(s) is not a half-plane of convergence
                raise IntegralTransformError(
                    'Laplace', f, 'convergence not in half-plane?')
            a_ = Min(soln.lhs, a_)
        if a_ != oo:
            a = Max(a_, a)
        else:
            aux = And(aux, Or(*aux_))
    return a, aux
def _laplace_transform(f, t, s, simplify=True):
    """ The backend function for laplace transforms.

    Integrates ``exp(-s*t)*f`` over ``t`` from 0 to ``oo``.  When the result
    is conditional, the condition of its first branch is converted into a
    lower bound on ``re(s)`` plus leftover conditions.

    Returns the triple ``(F, a, aux)``: the transform (passed through
    ``_simplify``), the bound ``a`` (``-oo`` when none was found) and the
    conjunction ``aux`` of conditions that could not be turned into a bound.

    Raises ``IntegralTransformError`` if the integral is returned
    unevaluated, if its first branch still contains an ``Integral``, or if
    a condition bounds ``re(s)`` from above.
    """
    from sympy import (re, Max, exp, pi, Abs, Min, periodic_argument as arg,
                       cos, Wild, symbols)
    F = integrate(exp(-s*t) * f, (t, 0, oo))

    if not F.has(Integral):
        return _simplify(F, simplify), -oo, True

    if not F.is_Piecewise:
        raise IntegralTransformError('Laplace', f, 'could not compute integral')

    F, cond = F.args[0]
    if F.has(Integral):
        raise IntegralTransformError('Laplace', f, 'integral in unexpected form')

    a = -oo
    aux = True
    conds = conjuncts(to_cnf(cond))
    p, q, w1, w2, w3 = symbols('p q w1 w2 w3', cls=Wild, exclude=[s])
    for c in conds:
        a_ = oo
        aux_ = []
        for d in disjuncts(c):
            m = d.match(abs(arg((s + w3)**p*q, w1)) < w2)
            if m:
                # Query the assumption instead of truth-testing ``m[q] > 0``:
                # for a symbolic match the comparison stays unevaluated and
                # cannot be used as a Python boolean.
                if m[q].is_positive and m[w2]/m[p] == pi/2:
                    d = re(s + m[w3]) > 0
            m = d.match(0 < cos(abs(arg(s, q)))*abs(s) - p)
            if m:
                d = re(s) > m[p]
            # NOTE(review): ``t`` doubles as the stand-in for re(s) here;
            # presumably safe because the definite integral no longer
            # depends on t -- confirm.
            d_ = d.replace(
                re, lambda x: x.expand().as_real_imag()[0]).subs(re(s), t)
            if not d.is_Relational or d.rel_op not in ('<', '<=') \
                    or d_.has(s) or not d_.has(t):
                aux_.append(d)
                continue
            soln = _solve_inequality(d_, t)
            if not soln.is_Relational or soln.rel_op not in ('<', '<='):
                aux_.append(d)
                continue
            if soln.lhs == t:
                # an upper bound on re(s) is not a half-plane of convergence
                raise IntegralTransformError('Laplace', f,
                                             'convergence not in half-plane?')
            a_ = Min(soln.lhs, a_)
        if a_ != oo:
            a = Max(a_, a)
        else:
            aux = And(aux, Or(*aux_))

    return _simplify(F, simplify), a, aux
def _mellin_transform(f, x, s_, integrator=_default_integrator, simplify=True):
    """ Backend function to compute mellin transforms.

    Calls ``integrator(x**(s - 1)*f, x)`` with a fresh dummy ``s`` and, when
    the result is conditional, converts the condition of its first branch
    into a strip ``(a, b)`` for ``re(s)`` plus leftover conditions.

    Returns the triple ``(F, (a, b), aux)`` where ``F`` has ``s_``
    substituted for the dummy and has been passed through ``_simplify``.

    Raises ``IntegralTransformError`` if the integral is returned
    unevaluated, if its first branch still contains an ``Integral``, or if
    the auxiliary conditions reduce to false.
    """
    from sympy import re, Max, Min
    # We use a fresh dummy, because assumptions on s might drop conditions on
    # convergence of the integral.
    s = _dummy('s', 'mellin-transform', f)
    F = integrator(x**(s-1) * f, x)

    if not F.has(Integral):
        return _simplify(F.subs(s, s_), simplify), (-oo, oo), True

    if not F.is_Piecewise:
        raise IntegralTransformError('Mellin', f, 'could not compute integral')

    F, cond = F.args[0]
    if F.has(Integral):
        raise IntegralTransformError('Mellin', f, 'integral in unexpected form')

    a = -oo
    b = oo
    aux = True
    conds = conjuncts(to_cnf(cond))
    # real stand-in for re(s) so the inequalities can be solved for it
    t = Dummy('t', real=True)
    for c in conds:
        a_ = oo
        b_ = -oo
        aux_ = []
        for d in disjuncts(c):
            d_ = d.replace(re, lambda x: x.as_real_imag()[0]).subs(re(s), t)
            if not d.is_Relational or d.rel_op not in ('<', '<=') \
                    or d_.has(s) or not d_.has(t):
                aux_.append(d)
                continue
            soln = _solve_inequality(d_, t)
            if not soln.is_Relational or soln.rel_op not in ('<', '<='):
                aux_.append(d)
                continue
            if soln.lhs == t:
                b_ = Max(soln.rhs, b_)
            else:
                a_ = Min(soln.lhs, a_)
        if a_ != oo and a_ != b:
            a = Max(a_, a)
        elif b_ != -oo and b_ != a:
            b = Min(b_, b)
        else:
            aux = And(aux, Or(*aux_))

    # Equality, not identity: ``And`` may return SymPy's ``S.false``, which
    # is a different object from Python's ``False`` but compares equal to it.
    if aux == False:
        raise IntegralTransformError('Mellin', f, 'no convergence found')

    return _simplify(F.subs(s, s_), simplify), (a, b), aux
def process_conds(conds):
    """ Turn ``conds`` into a strip and auxiliary conditions.

    ``conds`` is put into conjunctive normal form.  Disjuncts that match a
    known ``arg``/``cos`` convergence pattern are rewritten as conditions on
    ``re(s)``; each clause then either raises the lower bound ``a`` or is
    and-ed into the auxiliary condition.  ``s``, ``t`` and ``f`` are taken
    from the enclosing scope.

    Returns the pair ``(a, aux)``.

    Raises ``IntegralTransformError`` if a condition bounds ``re(s)`` from
    above.
    """
    a = -oo
    aux = True
    conds = conjuncts(to_cnf(conds))
    p, q, w1, w2, w3, w4, w5 = symbols('p q w1 w2 w3 w4 w5', cls=Wild,
                                       exclude=[s])
    inequality_ops = ('>', '>=', '<', '<=')
    for c in conds:
        a_ = oo
        aux_ = []
        for d in disjuncts(c):
            m = d.match(abs(arg((s + w3)**p*q, w1)) < w2)
            if not m:
                m = d.match(abs(arg((s + w3)**p*q, w1)) <= w2)
            if not m:
                m = d.match(abs(arg((polar_lift(s + w3))**p*q, w1)) < w2)
            if not m:
                m = d.match(abs(arg((polar_lift(s + w3))**p*q, w1)) <= w2)
            if m:
                # Query the assumption instead of truth-testing ``m[q] > 0``:
                # for a symbolic match the comparison stays unevaluated and
                # cannot be used as a Python boolean.
                if m[q].is_positive and m[w2]/m[p] == pi/2:
                    d = re(s + m[w3]) > 0
            m = d.match(
                0 < cos(abs(arg(s**w1*w5, q))*w2)*abs(s**w3)**w4 - p)
            if not m:
                m = d.match(
                    0 < cos(abs(arg(polar_lift(s)**w1*w5, q))*w2)
                    * abs(s**w3)**w4 - p)
            if m and all(m[wild].is_positive
                         for wild in [w1, w2, w3, w4, w5]):
                d = re(s) > m[p]
            # NOTE(review): ``t`` from the enclosing scope doubles as the
            # stand-in for re(s) -- confirm it cannot occur in ``conds``.
            d_ = d.replace(
                re, lambda x: x.expand().as_real_imag()[0]).subs(re(s), t)
            if not d.is_Relational or d.rel_op not in inequality_ops \
                    or d_.has(s) or not d_.has(t):
                aux_.append(d)
                continue
            soln = _solve_inequality(d_, t)
            if not soln.is_Relational or soln.rel_op not in inequality_ops:
                aux_.append(d)
                continue
            if soln.lts == t:
                # t on the smaller side means re(s) is bounded from above,
                # which is not a half-plane of convergence
                raise IntegralTransformError(
                    'Laplace', f, 'convergence not in half-plane?')
            a_ = Min(soln.lts, a_)
        if a_ != oo:
            a = Max(a_, a)
        else:
            aux = And(aux, Or(*aux_))
    return a, aux
def _solve_relational(r):
    """Solve the relational ``r`` for the enclosing scope's ``sym``.

    A solved ``Ne(sym, rhs)`` is split into ``(sym < rhs) | (sym > rhs)``,
    or becomes ``S.true`` when building those inequalities raises
    ``TypeError``.  A result equal to ``(-oo < sym) & (sym < oo)`` becomes
    ``S.true``.

    Raises ``NotImplementedError`` when the solved relational still contains
    ``sym`` but not as its left-hand side.
    """
    solved = _solve_inequality(r, sym)

    still_in_sym = isinstance(solved, Relational) and sym in solved.free_symbols
    if still_in_sym:
        if solved.args[0] != sym:
            raise NotImplementedError(filldedent(''' Unable to solve relational %s for %s.''' % (r, sym)))
        if solved.rel_op == '!=':
            try:
                solved = Or(sym < solved.rhs, sym > solved.rhs)
            except TypeError:
                # e.g. x != I ==> all real x satisfy
                solved = S.true

    # a doubly unbounded range places no restriction on sym
    whole_line = (S.NegativeInfinity < sym) & (sym < S.Infinity)
    return S.true if solved == whole_line else solved
def _sort_expr_cond(self, sym, a, b, targetcond=None):
    """Determine what intervals the expr, cond pairs affect.

    1) If cond is True, then log it as default
    1.1) Currently if cond can't be evaluated, throw NotImplementedError.
    2) For each inequality, if previous cond defines part of the interval
       update the new conds interval.
       -  eg x < 1, x < 3 -> [oo,1],[1,3] instead of [oo,1],[oo,3]
    3) Sort the intervals to make it easier to find correct exprs

    Under normal use, we return the expr,cond pairs in increasing order
    along the real axis corresponding to the symbol sym.  If targetcond
    is given, we return a list of (lowerbound, upperbound) pairs for
    this condition.

    ``a`` and ``b`` are the outer bounds: every interval is clipped to
    ``[a, b]``.  The normal return value is a list of
    ``[lower, upper, expr]`` lists; with ``targetcond`` the entries are
    ``(lower, upper, None)`` tuples.

    Raises NotImplementedError when a condition cannot be brought into the
    form ``sym < bound`` or ``bound < sym``, and ValueError when part of
    ``[a, b]`` is not covered and there is no default expression."""
    from sympy.solvers.inequalities import _solve_inequality
    default = None
    int_expr = []
    expr_cond = []
    or_cond = False
    or_intervals = []
    independent_expr_cond = []
    # Flatten the args: an Or condition contributes one (expr, cond) pair
    # per disjunct.  Nothing after a True condition can apply.
    for expr, cond in self.args:
        if isinstance(cond, Or):
            for cond2 in sorted(cond.args, key=default_sort_key):
                expr_cond.append((expr, cond2))
        else:
            expr_cond.append((expr, cond))
        if cond == True:
            break
    for expr, cond in expr_cond:
        if cond == True:
            # the default is a piecewise of every sym-independent pair seen
            # so far plus this catch-all
            independent_expr_cond.append((expr, cond))
            default = self.func(*independent_expr_cond)
            break
        orig_cond = cond
        if sym not in cond.free_symbols:
            independent_expr_cond.append((expr, cond))
            continue
        elif isinstance(cond, Equality):
            # equalities are skipped: no interval is recorded for them
            continue
        elif isinstance(cond, And):
            # intersect the bounds contributed by each conjunct
            lower = S.NegativeInfinity
            upper = S.Infinity
            for cond2 in cond.args:
                if sym not in [cond2.lts, cond2.gts]:
                    cond2 = _solve_inequality(cond2, sym)
                if cond2.lts == sym:
                    upper = Min(cond2.gts, upper)
                elif cond2.gts == sym:
                    lower = Max(cond2.lts, lower)
                else:
                    raise NotImplementedError(
                        "Unable to handle interval evaluation of expression.")
        else:
            if sym not in [cond.lts, cond.gts]:
                cond = _solve_inequality(cond, sym)
            lower, upper = cond.lts, cond.gts  # part 1: initialize with givens
            if cond.lts == sym:                # part 1a: expand the side ...
                lower = S.NegativeInfinity     # e.g. x <= 0 ---> -oo <= 0
            elif cond.gts == sym:              # part 1a: ... that can be expanded
                upper = S.Infinity             # e.g. x >= 0 ---> oo >= 0
            else:
                raise NotImplementedError(
                    "Unable to handle interval evaluation of expression.")

        # part 1b: Reduce (-)infinity to what was passed in.
        lower, upper = Max(a, lower), Min(b, upper)

        for n in xrange(len(int_expr)):
            # Part 2: remove any interval overlap.  For any conflicts, the
            # interval already there wins, and the incoming interval updates
            # its bounds accordingly.
            if self.__eval_cond(lower < int_expr[n][1]) and \
                    self.__eval_cond(lower >= int_expr[n][0]):
                lower = int_expr[n][1]
            elif len(int_expr[n][1].free_symbols) and \
                    self.__eval_cond(lower >= int_expr[n][0]):
                if self.__eval_cond(lower == int_expr[n][0]):
                    lower = int_expr[n][1]
                else:
                    int_expr[n][1] = Min(lower, int_expr[n][1])
            elif len(int_expr[n][0].free_symbols) and \
                    self.__eval_cond(upper == int_expr[n][1]):
                upper = Min(upper, int_expr[n][0])
            elif len(int_expr[n][1].free_symbols) and \
                    (lower >= int_expr[n][0]) != True and \
                    (int_expr[n][1] == Min(lower, upper)) != True:
                upper = Min(upper, int_expr[n][0])
            elif self.__eval_cond(upper > int_expr[n][0]) and \
                    self.__eval_cond(upper <= int_expr[n][1]):
                upper = int_expr[n][0]
            elif len(int_expr[n][0].free_symbols) and \
                    self.__eval_cond(upper < int_expr[n][1]):
                int_expr[n][0] = Max(upper, int_expr[n][0])

        if self.__eval_cond(lower >= upper) != True:  # Is it still an interval?
            int_expr.append([lower, upper, expr])
        if orig_cond == targetcond:
            return [(lower, upper, None)]
        elif isinstance(targetcond, Or) and cond in targetcond.args:
            # collect the pieces of an Or target until all of them are seen
            or_cond = Or(or_cond, cond)
            or_intervals.append((lower, upper, None))
            if or_cond == targetcond:
                or_intervals.sort(key=lambda x: x[0])
                return or_intervals

    # Part 3: order by upper bound, then by lower bound; non-numeric bounds
    # get a fixed infinite sort key.
    int_expr.sort(key=lambda x: x[1].sort_key(
    ) if x[1].is_number else S.NegativeInfinity.sort_key())
    int_expr.sort(key=lambda x: x[0].sort_key(
    ) if x[0].is_number else S.Infinity.sort_key())

    # For intervals with a symbolic bound, collapse the pair of bounds into
    # a single Min/Max and keep the touching neighbour in step.
    for n in xrange(len(int_expr)):
        if len(int_expr[n][0].free_symbols) or len(int_expr[n][1].free_symbols):
            if isinstance(int_expr[n][1], Min) or int_expr[n][1] == b:
                newval = Min(*int_expr[n][:-1])
                if n > 0 and int_expr[n][0] == int_expr[n - 1][1]:
                    int_expr[n - 1][1] = newval
                int_expr[n][0] = newval
            else:
                newval = Max(*int_expr[n][:-1])
                if n < len(int_expr) - 1 and int_expr[n][1] == int_expr[n + 1][0]:
                    int_expr[n + 1][0] = newval
                int_expr[n][1] = newval

    # Add holes to list of intervals if there is a default value,
    # otherwise raise a ValueError.
    holes = []
    curr_low = a
    for int_a, int_b, expr in int_expr:
        # a gap is recorded both when it is known to exist and when the
        # comparison cannot be decided
        if (curr_low < int_a) == True:
            holes.append([curr_low, Min(b, int_a), default])
        elif (curr_low >= int_a) != True:
            holes.append([curr_low, Min(b, int_a), default])
        curr_low = Min(b, int_b)
    if (curr_low < b) == True:
        holes.append([Min(b, curr_low), b, default])
    elif (curr_low >= b) != True:
        holes.append([Min(b, curr_low), b, default])

    if holes and default is not None:
        int_expr.extend(holes)
        if targetcond == True:
            return [(h[0], h[1], None) for h in holes]
    elif holes and default is None:
        raise ValueError("Called interval evaluation over piecewise "
                         "function on undefined intervals %s" %
                         ", ".join([str((h[0], h[1])) for h in holes]))

    return int_expr
def _handle_irel(self, x, handler):
    """Return either None (if the conditions of self depend only on x) else
    a Piecewise expression whose expressions (handled by the handler that
    was passed) are paired with the governing x-independent relationals,
    e.g. Piecewise((A, a(x) & b(y)), (B, c(x) | c(y)) ->
    Piecewise(
        (handler(Piecewise((A, a(x) & True), (B, c(x) | True)), b(y) & c(y)),
        (handler(Piecewise((A, a(x) & True), (B, c(x) | False)), b(y)),
        (handler(Piecewise((A, a(x) & False), (B, c(x) | True)), c(y)),
        (handler(Piecewise((A, a(x) & False), (B, c(x) | False)), True))

    ``handler`` is called once per truth assignment of the x-independent
    relationals, on ``self`` with those relationals replaced by their
    assigned values.
    """
    # identify governing relationals
    rel = self.atoms(Relational)
    irel = list(ordered([r for r in rel if x not in r.free_symbols
                         and r not in (S.true, S.false)]))
    if irel:
        args = {}  # expr -> list of conditions under which it applies
        exprinorder = []
        # enumerate every truth assignment, all-true first, all-false last
        for truth in product((1, 0), repeat=len(irel)):
            reps = dict(zip(irel, truth))
            # only store the true conditions since the false are implied
            # when they appear lower in the Piecewise args
            if 1 not in truth:
                cond = None  # flag this one so it doesn't get combined
            else:
                andargs = Tuple(*[i for i in reps if reps[i]])
                free = list(andargs.free_symbols)
                if len(free) == 1:
                    from sympy.solvers.inequalities import (
                        reduce_inequalities, _solve_inequality)
                    try:
                        t = reduce_inequalities(andargs, free[0])
                        # ValueError when there are potentially
                        # nonvanishing imaginary parts
                    except (ValueError, NotImplementedError):
                        # at least isolate free symbol on left
                        t = And(*[_solve_inequality(
                            a, free[0], linear=True)
                            for a in andargs])
                else:
                    t = And(*andargs)
                if t is S.false:
                    continue  # an impossible combination
                cond = t
            expr = handler(self.xreplace(reps))
            if isinstance(expr, self.func) and len(expr.args) == 1:
                # unwrap a single-branch result and keep its condition
                expr, econd = expr.args[0]
                cond = And(econd, True if cond is None else cond)
            # the ec pairs are being collected since all possibilities
            # are being enumerated, but don't put the last one in since
            # its expr might match a previous expression and it
            # must appear last in the args
            if cond is not None:
                args.setdefault(expr, []).append(cond)
                # but since we only store the true conditions we must maintain
                # the order so that the expression with the most true values
                # comes first
                exprinorder.append(expr)
        # convert collected conditions as args of Or
        for k in args:
            args[k] = Or(*args[k])
        # take them in the order obtained
        args = [(e, args[e]) for e in uniq(exprinorder)]
        # add in the last arg
        args.append((expr, True))
        # if any condition reduced to True, it needs to go last
        # and there should only be one of them or else the exprs
        # should agree
        trues = [i for i in range(len(args)) if args[i][1] is S.true]
        if not trues:
            # make the last one True since all cases were enumerated
            e, c = args[-1]
            args[-1] = (e, S.true)
        else:
            assert len(set([e for e, c in [args[i] for i in trues]])) == 1
            # move one True-conditioned pair to the end, drop the others
            # (popping from the highest index down keeps indices valid)
            args.append(args.pop(trues.pop()))
            while trues:
                args.pop(trues.pop())
        return Piecewise(*args)
def test__solve_inequality():
    """Exercise ``_solve_inequality`` on relationals in ``x``."""
    # whichever side the symbol starts on, it ends up on the left
    for rel in (Gt, Lt, Le, Ge, Eq, Ne):
        for lhs, rhs in ((x, 1), (S.One, x)):
            assert _solve_inequality(rel(lhs, rhs), x).lhs == x

    # don't get tricked by symbol on right: solve it
    assert _solve_inequality(Eq(2 * x - 1, x), x) == Eq(x, 1)

    # a relational without the symbol is returned as given
    x_free = Eq(S.One, y)
    assert _solve_inequality(x_free, x) == x_free

    # linear=True only isolates the x-dependent term
    for fx in (x**2, exp(x), sin(x) + cos(x), x * (1 + x)):
        for c in (0, 1):
            expected = fx > c / S(2)
            assert _solve_inequality(2 * fx - c > 0, x, linear=True) == expected
    quadratic = 2 * x**2 + 2 * x - 1 < 0
    assert _solve_inequality(quadratic, x, linear=True) == (x * (x + 1) < S.Half)

    # coefficients of unknown sign or value block the division
    assert _solve_inequality(Eq(x * y, 1), x) == Eq(x * y, 1)
    nz = Symbol('nz', nonzero=True)
    assert _solve_inequality(Eq(x * nz, 1), x) == Eq(x, 1 / nz)
    assert _solve_inequality(x * nz < 1, x) == (x * nz < 1)

    a = Symbol('a', positive=True)
    assert _solve_inequality(a / x > 1, x) == (S.Zero < x) & (x < a)
    assert _solve_inequality(a / x > 1, x, linear=True) == (1 / x > 1 / a)

    # make sure to include conditions under which solution is valid
    conditional = Eq(1 - x, x * (1 / x - 1))
    assert _solve_inequality(conditional, x) == Ne(x, 0)
    assert _solve_inequality(x < x * (1 / x - 1), x) == (x < S.Half) & Ne(x, 0)
def test__solve_inequality():
    """Exercise ``_solve_inequality`` on relationals in ``x``."""
    for op in (Gt, Lt, Le, Ge, Eq, Ne):
        assert _solve_inequality(op(x, 1), x).lhs == x
        assert _solve_inequality(op(S.One, x), x).lhs == x
    # don't get tricked by symbol on right: solve it
    assert _solve_inequality(Eq(2*x - 1, x), x) == Eq(x, 1)
    ie = Eq(S.One, y)
    assert _solve_inequality(ie, x) == ie
    for fx in (x**2, exp(x), sin(x) + cos(x), x*(1 + x)):
        for c in (0, 1):
            e = 2*fx - c > 0
            # divide by S(2), not 2: plain ``c/2`` is a Python float, which
            # would put a Float into the expected relational where the
            # solver produces an exact Rational
            assert _solve_inequality(e, x, linear=True) == (
                fx > c/S(2))
    assert _solve_inequality(2*x**2 + 2*x - 1 < 0, x, linear=True) == (
        x*(x + 1) < S.Half)
    assert _solve_inequality(Eq(x*y, 1), x) == Eq(x*y, 1)
    nz = Symbol('nz', nonzero=True)
    assert _solve_inequality(Eq(x*nz, 1), x) == Eq(x, 1/nz)
    assert _solve_inequality(x*nz < 1, x) == (x*nz < 1)
    a = Symbol('a', positive=True)
    assert _solve_inequality(a/x > 1, x, linear=True) == (1/x > 1/a)
    # make sure to include conditions under which solution is valid
    e = Eq(1 - x, x*(1/x - 1))
    assert _solve_inequality(e, x) == Ne(x, 0)
    assert _solve_inequality(x < x*(1/x - 1), x) == (x < S.Half) & Ne(x, 0)
def simplify_and(
        x: sympy.Basic,
        gen: typing.Optional[sympy.Symbol] = None,
        extra_conditions: typing.Union[sympy.Basic, bool] = True
) -> sympy.Basic:
    """
    Some rules, because SymPy currently does not automatically simplify them...

    Every ``And`` found in ``x`` is reduced with respect to one main symbol:
    its relationals are solved for that symbol, redundant bounds are
    dropped, contradictory bounds turn the ``And`` into ``False``, and
    coinciding bounds are merged into an equality.  A few hard-coded rules
    for two symbols follow.

    :param x: the expression to simplify
    :param gen: preferred main symbol.  Without it, the first symbol found
        on either side of an equality in ``x`` is used, then any free
        symbol of ``x``.
    :param extra_conditions: additional condition that is and-ed in when
        two candidate bounds are compared against each other
    :return: the simplified expression; ``x`` itself if it has no symbols
    :raises NotImplementedError: if two equality right-hand sides can be
        shown neither equal nor different
    """
    assert isinstance(x, sympy.Basic), "type x: %r" % type(x)
    from sympy.solvers.inequalities import reduce_rational_inequalities
    from sympy.core.relational import Relational

    # Candidate symbols in priority order: gen, symbols that are one side of
    # an equality, then the remaining free symbols.
    syms = []
    if gen is not None:
        syms.append(gen)
    w1 = sympy.Wild("w1")
    w2 = sympy.Wild("w2")
    for sub_expr in x.find(sympy.Eq(w1, w2)):
        m = sub_expr.match(sympy.Eq(w1, w2))
        ws_ = m[w1], m[w2]
        for w_ in ws_:
            if isinstance(w_, sympy.Symbol) and w_ not in syms:
                syms.append(w_)
    for w_ in x.free_symbols:
        if w_ not in syms:
            syms.append(w_)
    if len(syms) >= 1:
        _c = syms[0]
        if len(syms) >= 2:
            n = syms[1]
        else:
            n = sympy.Wild("n")
    else:
        return x

    # ``assumptions0`` only lists assumptions with a known value, so index
    # access raises KeyError for a symbol without an integer assumption.
    # ``.get`` treats "unknown" as "not integer".
    c_is_integer = bool(_c.is_Integer or _c.assumptions0.get("integer"))

    x = x.replace(
        ((_c - 2 * n >= -1) & (_c - 2 * n <= -1)),
        sympy.Eq(_c, 2 * n - 1))  # probably not needed anymore...
    apply_rules = True
    while apply_rules:
        apply_rules = False
        for and_expr in x.find(sympy.And):
            assert isinstance(and_expr, sympy.And)

            and_expr_ = reduce_rational_inequalities([and_expr.args], _c)
            if and_expr_ != and_expr:
                x = x.replace(and_expr, and_expr_)
                and_expr = and_expr_
            if and_expr == sympy.sympify(False):
                continue
            if isinstance(and_expr, sympy.Rel):
                continue
            assert isinstance(and_expr, sympy.And)

            and_expr_args = list(and_expr.args)
            if all([
                    isinstance(part, Relational) and _c in part.free_symbols
                    for part in and_expr_args]):
                # No equality, as that should have been resolved above.
                rel_ops = ["==", ">=", "<="]
                if not c_is_integer:
                    rel_ops.extend(["<", ">"])
                # operator -> right-hand sides kept for ``_c <op> rhs``
                rhs_by_c = {op: [] for op in rel_ops}
                for part in and_expr_args:
                    assert isinstance(part, Relational)
                    part = _solve_inequality(part, _c)
                    assert isinstance(part, Relational)
                    assert part.lhs == _c
                    rel_op, rhs = part.rel_op, part.rhs
                    if c_is_integer:
                        # for integers a strict bound is the non-strict
                        # bound moved by one
                        if rel_op == "<":
                            rhs = rhs - 1
                            rel_op = "<="
                        elif rel_op == ">":
                            rhs = rhs + 1
                            rel_op = ">="
                    assert rel_op in rhs_by_c, "x: %r, _c: %r, and expr: %r, part %r" % (
                        x, _c, and_expr, part)
                    other_rhs = rhs_by_c[rel_op]
                    assert isinstance(other_rhs, list)
                    need_to_add = True
                    for rhs_ in other_rhs:
                        cmp = Relational.ValidRelationOperator[rel_op](rhs, rhs_)
                        if simplify_and(
                                sympy.And(sympy.Not(cmp), extra_conditions)
                        ) == sympy.sympify(False):  # checks True...
                            # the new bound supersedes the stored one
                            # (safe to remove: the loop ends right here)
                            other_rhs.remove(rhs_)
                            break
                        elif simplify_and(
                                sympy.And(cmp, extra_conditions)
                        ) == sympy.sympify(False):
                            # the stored bound already implies the new one
                            need_to_add = False
                            break
                    if need_to_add:
                        other_rhs.append(rhs)
                if rhs_by_c[">="] and rhs_by_c["<="]:
                    all_false = False
                    for lhs in rhs_by_c[">="]:
                        for rhs in rhs_by_c["<="]:
                            if sympy.Lt(lhs, rhs) == sympy.sympify(False):
                                all_false = True
                            if sympy.Eq(lhs, rhs) == sympy.sympify(True):
                                rhs_by_c["=="].append(lhs)
                    if all_false:
                        x = x.replace(and_expr, False)
                        continue
                if rhs_by_c["=="]:
                    all_false = False
                    # merge equal right-hand sides until one is left
                    while len(rhs_by_c["=="]) >= 2:
                        lhs, rhs = rhs_by_c["=="][:2]
                        if sympy.Eq(lhs, rhs) == sympy.sympify(False):
                            all_false = True
                            break
                        elif sympy.Eq(lhs, rhs) == sympy.sympify(True):
                            rhs_by_c["=="].pop(1)
                        else:
                            raise NotImplementedError(
                                "cannot cmp %r == %r. rhs_by_c %r" % (lhs, rhs, rhs_by_c))
                    if all_false:
                        x = x.replace(and_expr, False)
                        continue
                    new_parts = [sympy.Eq(_c, rhs_by_c["=="][0])]
                    for op in rel_ops:
                        for part in rhs_by_c[op]:
                            new_parts.append(
                                Relational.ValidRelationOperator[op](
                                    rhs_by_c["=="][0], part).simplify())
                else:  # no "=="
                    new_parts = []
                    for op in rel_ops:
                        for part in rhs_by_c[op]:
                            new_parts.append(
                                Relational.ValidRelationOperator[op](_c, part))
                assert new_parts
                and_expr_ = sympy.And(*new_parts)
                x = x.replace(and_expr, and_expr_)
                and_expr = and_expr_

            # Probably all the remaining hard-coded rules are not needed anymore with the more generic code above...
            if sympy.Eq(_c, 2 * n) in and_expr.args:
                if (_c - 2 * n <= -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
                if sympy.Eq(_c - 2 * n, -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
                if (_c - n <= -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
            if (_c >= n) in and_expr.args and (_c - n <= -1) in and_expr.args:
                x = x.replace(and_expr, False)
                continue
            if sympy.Eq(_c - 2 * n, -1) in and_expr.args:  # assume n>=1
                if (_c >= n) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(
                            *[arg for arg in and_expr.args if arg != (_c >= n)]))
                    apply_rules = True
                    break
                if (_c - n >= -1) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(*[
                            arg for arg in and_expr.args
                            if arg != (_c - n >= -1)]))
                    apply_rules = True
                    break
            if (_c >= n) in and_expr.args:
                if (_c - n >= -1) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(*[
                            arg for arg in and_expr.args
                            if arg != (_c - n >= -1)]))
                    apply_rules = True
                    break
            if (_c - n >= -1) in and_expr.args and (_c - n <= -1) in and_expr.args:
                args = list(and_expr.args)
                args.remove((_c - n >= -1))
                args.remove((_c - n <= -1))
                args.append(sympy.Eq(_c - n, -1))
                if (_c - 2 * n <= -1) in args:
                    args.remove((_c - 2 * n <= -1))
                x = x.replace(and_expr, sympy.And(*args))
                apply_rules = True
                break
    return x
def _sort_expr_cond(self, sym, a, b, targetcond=None):
    """Determine what intervals the expr, cond pairs affect.

    1) If cond is True, then log it as default
    1.1) Currently if cond can't be evaluated, throw NotImplementedError.
    2) For each inequality, if previous cond defines part of the interval
       update the new conds interval.
       -  eg x < 1, x < 3 -> [oo,1],[1,3] instead of [oo,1],[oo,3]
    3) Sort the intervals to make it easier to find correct exprs

    Under normal use, we return the expr,cond pairs in increasing order
    along the real axis corresponding to the symbol sym.  If targetcond
    is given, we return a list of (lowerbound, upperbound) pairs for
    this condition.

    ``a`` and ``b`` are the outer bounds: every interval is clipped to
    ``[a, b]``.  The normal return value is a list of
    ``[lower, upper, expr]`` lists; with ``targetcond`` the entries are
    ``(lower, upper, None)`` tuples.

    Raises NotImplementedError when a condition cannot be brought into the
    form ``sym < bound`` or ``bound < sym``, and ValueError when part of
    ``[a, b]`` is not covered and there is no default expression."""
    from sympy.solvers.inequalities import _solve_inequality
    default = None
    int_expr = []
    expr_cond = []
    or_cond = False
    or_intervals = []
    independent_expr_cond = []
    # Flatten the args: an Or condition contributes one (expr, cond) pair
    # per disjunct.  Nothing after a True condition can apply.
    for expr, cond in self.args:
        if isinstance(cond, Or):
            for cond2 in sorted(cond.args, key=default_sort_key):
                expr_cond.append((expr, cond2))
        else:
            expr_cond.append((expr, cond))
        if cond == True:
            break
    for expr, cond in expr_cond:
        if cond == True:
            # the default is a piecewise of every sym-independent pair seen
            # so far plus this catch-all
            independent_expr_cond.append((expr, cond))
            default = self.func(*independent_expr_cond)
            break
        orig_cond = cond
        if sym not in cond.free_symbols:
            independent_expr_cond.append((expr, cond))
            continue
        elif isinstance(cond, Equality):
            # equalities are skipped: no interval is recorded for them
            continue
        elif isinstance(cond, And):
            # intersect the bounds contributed by each conjunct
            lower = S.NegativeInfinity
            upper = S.Infinity
            for cond2 in cond.args:
                if sym not in [cond2.lts, cond2.gts]:
                    cond2 = _solve_inequality(cond2, sym)
                if cond2.lts == sym:
                    upper = Min(cond2.gts, upper)
                elif cond2.gts == sym:
                    lower = Max(cond2.lts, lower)
                else:
                    raise NotImplementedError(
                        "Unable to handle interval evaluation of expression."
                    )
        else:
            if sym not in [cond.lts, cond.gts]:
                cond = _solve_inequality(cond, sym)
            lower, upper = cond.lts, cond.gts  # part 1: initialize with givens
            if cond.lts == sym:                # part 1a: expand the side ...
                lower = S.NegativeInfinity     # e.g. x <= 0 ---> -oo <= 0
            elif cond.gts == sym:              # part 1a: ... that can be expanded
                upper = S.Infinity             # e.g. x >= 0 ---> oo >= 0
            else:
                raise NotImplementedError(
                    "Unable to handle interval evaluation of expression.")

        # part 1b: Reduce (-)infinity to what was passed in.
        lower, upper = Max(a, lower), Min(b, upper)

        for n in range(len(int_expr)):
            # Part 2: remove any interval overlap.  For any conflicts, the
            # interval already there wins, and the incoming interval updates
            # its bounds accordingly.
            if self.__eval_cond(lower < int_expr[n][1]) and \
                    self.__eval_cond(lower >= int_expr[n][0]):
                lower = int_expr[n][1]
            elif len(int_expr[n][1].free_symbols) and \
                    self.__eval_cond(lower >= int_expr[n][0]):
                if self.__eval_cond(lower == int_expr[n][0]):
                    lower = int_expr[n][1]
                else:
                    int_expr[n][1] = Min(lower, int_expr[n][1])
            elif len(int_expr[n][0].free_symbols) and \
                    self.__eval_cond(upper == int_expr[n][1]):
                upper = Min(upper, int_expr[n][0])
            elif len(int_expr[n][1].free_symbols) and \
                    (lower >= int_expr[n][0]) != True and \
                    (int_expr[n][1] == Min(lower, upper)) != True:
                upper = Min(upper, int_expr[n][0])
            elif self.__eval_cond(upper > int_expr[n][0]) and \
                    self.__eval_cond(upper <= int_expr[n][1]):
                upper = int_expr[n][0]
            elif len(int_expr[n][0].free_symbols) and \
                    self.__eval_cond(upper < int_expr[n][1]):
                int_expr[n][0] = Max(upper, int_expr[n][0])

        if self.__eval_cond(
                lower >= upper) != True:  # Is it still an interval?
            int_expr.append([lower, upper, expr])
        if orig_cond == targetcond:
            return [(lower, upper, None)]
        elif isinstance(targetcond, Or) and cond in targetcond.args:
            # collect the pieces of an Or target until all of them are seen
            or_cond = Or(or_cond, cond)
            or_intervals.append((lower, upper, None))
            if or_cond == targetcond:
                or_intervals.sort(key=lambda x: x[0])
                return or_intervals

    # Part 3: order by upper bound, then by lower bound; non-numeric bounds
    # get a fixed infinite sort key.
    int_expr.sort(key=lambda x: x[1].sort_key()
                  if x[1].is_number else S.NegativeInfinity.sort_key())
    int_expr.sort(key=lambda x: x[0].sort_key()
                  if x[0].is_number else S.Infinity.sort_key())

    # For intervals with a symbolic bound, collapse the pair of bounds into
    # a single Min/Max and keep the touching neighbour in step.
    for n in range(len(int_expr)):
        if len(int_expr[n][0].free_symbols) or len(
                int_expr[n][1].free_symbols):
            if isinstance(int_expr[n][1], Min) or int_expr[n][1] == b:
                newval = Min(*int_expr[n][:-1])
                if n > 0 and int_expr[n][0] == int_expr[n - 1][1]:
                    int_expr[n - 1][1] = newval
                int_expr[n][0] = newval
            else:
                newval = Max(*int_expr[n][:-1])
                if n < len(int_expr) - 1 and int_expr[n][1] == int_expr[
                        n + 1][0]:
                    int_expr[n + 1][0] = newval
                int_expr[n][1] = newval

    # Add holes to list of intervals if there is a default value,
    # otherwise raise a ValueError.
    holes = []
    curr_low = a
    for int_a, int_b, expr in int_expr:
        # a gap is recorded both when it is known to exist and when the
        # comparison cannot be decided
        if (curr_low < int_a) == True:
            holes.append([curr_low, Min(b, int_a), default])
        elif (curr_low >= int_a) != True:
            holes.append([curr_low, Min(b, int_a), default])
        curr_low = Min(b, int_b)
    if (curr_low < b) == True:
        holes.append([Min(b, curr_low), b, default])
    elif (curr_low >= b) != True:
        holes.append([Min(b, curr_low), b, default])

    if holes and default is not None:
        int_expr.extend(holes)
        if targetcond == True:
            return [(h[0], h[1], None) for h in holes]
    elif holes and default is None:
        raise ValueError("Called interval evaluation over piecewise "
                         "function on undefined intervals %s" %
                         ", ".join([str((h[0], h[1])) for h in holes]))

    return int_expr
def _handle_irel(self, x, handler):
    """Return either None (if the conditions of self depend only on x) else
    a Piecewise expression whose expressions (handled by the handler that
    was passed) are paired with the governing x-independent relationals,
    e.g. Piecewise((A, a(x) & b(y)), (B, c(x) | c(y)) ->
    Piecewise(
        (handler(Piecewise((A, a(x) & True), (B, c(x) | True)), b(y) & c(y)),
        (handler(Piecewise((A, a(x) & True), (B, c(x) | False)), b(y)),
        (handler(Piecewise((A, a(x) & False), (B, c(x) | True)), c(y)),
        (handler(Piecewise((A, a(x) & False), (B, c(x) | False)), True))

    ``handler`` is called once per truth assignment of the x-independent
    relationals, on ``self`` with those relationals replaced by their
    assigned values.
    """
    # identify governing relationals
    rel = self.atoms(Relational)
    irel = list(ordered([r for r in rel if x not in r.free_symbols
                         and r not in (S.true, S.false)]))
    if irel:
        args = {}  # expr -> list of conditions under which it applies
        exprinorder = []
        # enumerate every truth assignment, all-true first, all-false last
        for truth in product((1, 0), repeat=len(irel)):
            reps = dict(zip(irel, truth))
            # only store the true conditions since the false are implied
            # when they appear lower in the Piecewise args
            if 1 not in truth:
                cond = None  # flag this one so it doesn't get combined
            else:
                andargs = Tuple(*[i for i in reps if reps[i]])
                free = list(andargs.free_symbols)
                if len(free) == 1:
                    from sympy.solvers.inequalities import (
                        reduce_inequalities, _solve_inequality)
                    try:
                        t = reduce_inequalities(andargs, free[0])
                        # ValueError when there are potentially
                        # nonvanishing imaginary parts
                    except (ValueError, NotImplementedError):
                        # at least isolate free symbol on left
                        t = And(*[_solve_inequality(
                            a, free[0], linear=True)
                            for a in andargs])
                else:
                    t = And(*andargs)
                if t is S.false:
                    continue  # an impossible combination
                cond = t
            expr = handler(self.xreplace(reps))
            if isinstance(expr, self.func) and len(expr.args) == 1:
                # unwrap a single-branch result and keep its condition
                expr, econd = expr.args[0]
                cond = And(econd, True if cond is None else cond)
            # the ec pairs are being collected since all possibilities
            # are being enumerated, but don't put the last one in since
            # its expr might match a previous expression and it
            # must appear last in the args
            if cond is not None:
                args.setdefault(expr, []).append(cond)
                # but since we only store the true conditions we must maintain
                # the order so that the expression with the most true values
                # comes first
                exprinorder.append(expr)
        # convert collected conditions as args of Or
        for k in args:
            args[k] = Or(*args[k])
        # take them in the order obtained
        args = [(e, args[e]) for e in uniq(exprinorder)]
        # add in the last arg
        args.append((expr, True))
        # if any condition reduced to True, it needs to go last
        # and there should only be one of them or else the exprs
        # should agree
        trues = [i for i in range(len(args)) if args[i][1] is S.true]
        if not trues:
            # make the last one True since all cases were enumerated
            e, c = args[-1]
            args[-1] = (e, S.true)
        else:
            assert len(set([e for e, c in [args[i] for i in trues]])) == 1
            # move one True-conditioned pair to the end, drop the others
            # (popping from the highest index down keeps indices valid)
            args.append(args.pop(trues.pop()))
            while trues:
                args.pop(trues.pop())
        return Piecewise(*args)