def test_union(self):
    """Union of two domains sharing a support contains every breakpoint of both."""
    left = Domain([-2, 0, 2])
    right = Domain([-2, -1, 1, 2])
    combined = left.union(right)
    # the union is a genuinely new domain, different from either operand
    self.assertNotEqual(combined, left)
    self.assertNotEqual(combined, right)
    expected = Domain([-2, -1, 0, 1, 2])
    self.assertEqual(combined, expected)
    # and the operation is symmetric in its arguments
    self.assertEqual(right.union(left), expected)
def test__eq__close(self):
    """Domain equality tolerates perturbations below HTOL but not above."""
    tol = .8 * HTOL

    def perturbed(scale):
        # shift every breakpoint by ``scale * tol`` (relative or absolute)
        delta = scale * tol
        return Domain([-2 * (1 + delta), 0 - delta, 1 + delta,
                       3 * (1 + delta), 5 * (1 - delta)])

    reference = Domain([-2, 0, 1, 3, 5])
    within_tol = perturbed(1)
    beyond_tol = perturbed(2)
    self.assertEqual(reference, within_tol)
    self.assertNotEqual(reference, beyond_tol)
def test_support(self):
    """The support of a domain is its first and last breakpoint."""
    cases = [
        (Domain([-2, 1]), [-2, 1]),
        (Domain([-2, 0, 1]), [-2, 1]),
        (Domain(np.linspace(-10, 10, 51)), [-10, 10]),
    ]
    for dom, endpoints in cases:
        as_array = dom.support.view(np.ndarray)
        self.assertTrue(np.all(as_array == endpoints))
def test_size(self):
    """Domain.size counts breakpoints, not intervals."""
    cases = [
        (Domain([-2, 1]), 2),
        (Domain([-2, 0, 1]), 3),
        (Domain(np.linspace(-10, 10, 51)), 51),
    ]
    for dom, count in cases:
        self.assertEqual(dom.size, count)
def test__contains__close(self):
    """Containment tolerates endpoint perturbations below HTOL."""
    tol = .8 * HTOL
    base = Domain([-1, 2])
    nearly_same = Domain([-1 - tol, 2 + 2 * tol])
    clearly_wider = Domain([-1 - 2 * tol, 2 + 4 * tol])
    self.assertTrue(base in nearly_same)
    self.assertTrue(nearly_same in base)
    self.assertFalse(clearly_wider in base)
def test_breakpoints_in_close(self):
    """breakpoints_in matches breakpoints that differ by less than HTOL."""
    tol = .8 * HTOL
    probe = Domain([-1, 0, 1])
    target = Domain([-2, 0 - tol, 1 + tol, 3])
    flags = probe.breakpoints_in(target)
    self.assertFalse(flags[0])
    for idx in (1, 2):
        self.assertTrue(flags[idx])
def test_breakpoints_in_close(self):
    """Near-coincident breakpoints (within HTOL) are reported as present."""
    tol = .8 * HTOL
    probe = Domain([-1, 0, 1])
    target = Domain([-2, 0 - tol, 1 + tol, 3])
    membership = probe.breakpoints_in(target)
    # -1 has no counterpart in target; 0 and 1 are within tolerance
    self.assertFalse(membership[0])
    self.assertTrue(membership[1])
    self.assertTrue(membership[2])
def test_domain(self):
    """The domain property is empty for f0 and a Domain for f1/f2."""
    expected_f1 = Domain([-1, 1])
    expected_f2 = Domain([-1, 0, 1, 2])
    empty_domain = self.f0.domain
    self.assertIsInstance(empty_domain, np.ndarray)
    self.assertIsInstance(self.f1.domain, Domain)
    self.assertIsInstance(self.f2.domain, Domain)
    self.assertEqual(empty_domain.size, 0)
    self.assertEqual(self.f1.domain, expected_f1)
    self.assertEqual(self.f2.domain, expected_f2)
def test__iter__(self):
    """Iterating a Domain yields exactly its breakpoints, in order."""
    cases = [
        (Domain([-2, 1]), (-2, 1)),
        (Domain([-2, 0, 1]), (-2, 0, 1)),
        (Domain([-1, 0, 1, 2]), (-1, 0, 1, 2)),
    ]
    for dom, expected in cases:
        items = list(dom)
        # zip() alone would silently truncate, hiding a Domain that yields
        # too few or too many breakpoints, so check the length explicitly.
        self.assertEqual(len(items), len(expected))
        self.assertTrue(all(x == y for x, y in zip(items, expected)))
def test_union(self):
    """A union of two domains merges their breakpoints and is symmetric."""
    coarse = Domain([-2, 0, 2])
    other = Domain([-2, -1, 1, 2])
    forward = coarse.union(other)
    for operand in (coarse, other):
        self.assertNotEqual(forward, operand)
    target = Domain([-2, -1, 0, 1, 2])
    self.assertEqual(forward, target)
    self.assertEqual(other.union(coarse), target)
def test_restrict(self):
    """Restricting a chebfun agrees with building it on the restricted domain."""
    dom1 = Domain(np.linspace(-2, 1.5, 13))
    dom2 = Domain(np.linspace(-1.7, 0.93, 17))
    dom3 = dom1.merge(dom2).restrict(dom2)
    f = chebfun(cos, dom1).restrict(dom2)
    g = chebfun(cos, dom3)
    # assertEquals is a deprecated alias (removed in Python 3.12)
    self.assertEqual(f.domain, g.domain)
    for n, fun in enumerate(f):
        # Each piece of f may carry at most 4 more coefficients than the
        # matching piece of g (this check is one-sided).
        # TODO: once standard chop is fixed, may be able to reduce 4 to 0
        self.assertLessEqual(fun.size - g.funs[n].size, 4)
def test_restrict(self):
    """Restricting a chebfun agrees with building it on the restricted domain."""
    dom1 = Domain(np.linspace(-2, 1.5, 13))
    dom2 = Domain(np.linspace(-1.7, 0.93, 17))
    dom3 = dom1.merge(dom2).restrict(dom2)
    f = chebfun(cos, dom1).restrict(dom2)
    g = chebfun(cos, dom3)
    # assertEquals is a deprecated alias (removed in Python 3.12)
    self.assertEqual(f.domain, g.domain)
    for n, fun in enumerate(f):
        # Each piece of f may carry at most 4 more coefficients than the
        # matching piece of g (this check is one-sided).
        # TODO: once standard chop is fixed, may be able to reduce 4 to 0
        self.assertLessEqual(fun.size - g.funs[n].size, 4)
def test_intervals(self):
    """Domain.intervals yields consecutive breakpoint pairs as Intervals."""
    cases = [
        (Domain([-2, 1]), [(-2, 1)]),
        (Domain([-2, 0, 1]), [(-2, 0), (0, 1)]),
        (Domain([-1, 0, 1, 2]), [(-1, 0), (0, 1), (1, 2)]),
    ]
    for dom, expected in cases:
        itvls = list(dom.intervals)
        # zip() would hide a wrong number of intervals; check it explicitly
        self.assertEqual(len(itvls), len(expected))
        self.assertTrue(all(itvl == Interval(a, b)
                            for itvl, (a, b) in zip(itvls, expected)))
def test__contains__(self):
    """``a in b`` tests whether a's support lies within b's support."""
    d1 = Domain([-2, 0, 1, 3, 5])
    d2 = Domain([-1, 2])
    d3 = Domain(np.linspace(-10, 10, 1000))
    d4 = Domain([-1, 0, 1, 2])
    self.assertTrue(d2 in d1)
    self.assertTrue(d1 in d3)
    self.assertTrue(d2 in d3)  # (duplicate of this assertion removed)
    # containment ignores interior breakpoints: same support both ways
    self.assertTrue(d2 in d4)
    self.assertTrue(d4 in d2)
    self.assertFalse(d1 in d2)
    self.assertFalse(d3 in d1)
    self.assertFalse(d3 in d2)
def piecewise_constant(domain=None, values=None):
    """Initialise a piecewise constant Chebfun.

    Parameters
    ----------
    domain : array-like, optional
        Breakpoints of the pieces; defaults to ``[-1, 0, 1]``.
    values : sequence, optional
        Constant value on each interval of ``domain``, in order; defaults
        to ``[0, 1]``. Pairing is done with ``zip``, so surplus intervals
        or values are ignored.

    Returns
    -------
    Chebfun
        One constant Bndfun per (interval, value) pair.
    """
    # None sentinels instead of list literals: avoid shared mutable defaults
    if domain is None:
        domain = [-1, 0, 1]
    if values is None:
        values = [0, 1]
    intervals = list(Domain(domain).intervals)
    funs = [Bndfun.initconst(value, interval)
            for interval, value in zip(intervals, values)]
    return Chebfun(funs)
def test__break_2(self):
    """_break onto a union domain preserves values of f1."""
    other = Domain([-2, 3])
    target = self.f1.domain.union(other)
    rebroken = self.f1._break(target)
    self.assertEqual(rebroken.domain, target)
    self.assertNotEqual(rebroken.domain, other)
    grid = np.linspace(-2, 3, 1000)
    max_err = infnorm(self.f1(grid) - rebroken(grid))
    self.assertLessEqual(max_err, 3 * eps)
def test__break_3(self):
    """_break onto a fine union domain preserves values of f2."""
    fine = Domain(np.linspace(-2, 3, 1000))
    target = self.f2.domain.union(fine)
    rebroken = self.f2._break(target)
    new_dom = rebroken.domain
    self.assertEqual(new_dom, target)
    self.assertNotEqual(new_dom, fine)
    self.assertNotEqual(new_dom, self.f2.domain)
    grid = np.linspace(-2, 3, 1000)
    max_err = infnorm(self.f2(grid) - rebroken(grid))
    self.assertLessEqual(max_err, 3 * eps)
def test_restrict(self):
    """Domain.restrict merges breakpoints inside the restricting support."""
    outer = Domain([-2, -1, 0, 1])
    inner = Domain([-1.5, -.5, 0.5])
    grid = Domain(np.linspace(-2, 1, 16))
    self.assertEqual(outer.restrict(inner), Domain([-1.5, -1, -.5, 0, .5]))
    self.assertEqual(outer.restrict(grid), grid)
    # restricting a domain to itself leaves it unchanged
    for dom in (outer, inner, grid):
        self.assertEqual(dom.restrict(dom), dom)
    # breakpoints differing only by rounding (linspace introduces such
    # effects) must still be treated as the same point
    pair = Domain(np.linspace(-.4, .4, 2))
    self.assertEqual(grid.restrict(pair), Domain([-.4, -.2, 0, .2, .4]))
def initfun_fixedlen(cls, f, n, domain=None):
    """Build a Chebfun from ``f`` using fixed-length pieces.

    A scalar ``n`` (fewer than two entries) is handed to ``generate_funs``
    together with ``domain``. Otherwise ``n`` gives one length per
    interval of ``domain`` (``prefs.domain`` when ``domain`` is None) and
    BadFunLengthArgument is raised on a count mismatch.
    """
    lengths = np.array(n)
    if lengths.size >= 2:
        dom = Domain(prefs.domain if domain is None else domain)
        if lengths.size != dom.size - 1:
            raise BadFunLengthArgument
        funs = [Bndfun.initfun_fixedlen(f, interval, length)
                for interval, length in zip(dom.intervals, lengths)]
    else:
        funs = generate_funs(domain, Bndfun.initfun_fixedlen,
                             {'f': f, 'n': n})
    return cls(funs)
def test_breakpoints_in(self):
    """breakpoints_in returns a boolean array flagging shared breakpoints."""
    d1 = Domain([-1, 0, 1])
    d2 = Domain([-2, 0.5, 1, 3])
    result1 = d1.breakpoints_in(d2)
    self.assertIsInstance(result1, np.ndarray)
    # assertTrue(x, 3) treats 3 as the failure message and never checks
    # the size; assertEqual actually compares it.
    self.assertEqual(result1.size, 3)
    self.assertFalse(result1[0])
    self.assertFalse(result1[1])
    self.assertTrue(result1[2])
    result2 = d2.breakpoints_in(d1)
    self.assertIsInstance(result2, np.ndarray)
    self.assertEqual(result2.size, 4)
    self.assertFalse(result2[0])
    self.assertFalse(result2[1])
    self.assertTrue(result2[2])
    self.assertFalse(result2[3])
    self.assertTrue(d1.breakpoints_in(d1).all())
    self.assertTrue(d2.breakpoints_in(d2).all())
    self.assertFalse(d1.breakpoints_in(Domain([-5, 5])).any())
    self.assertFalse(d2.breakpoints_in(Domain([-5, 5])).any())
def test_restrict_raises(self):
    """Restricting to a domain wider than self raises NotSubdomain."""
    wide = Domain([-2, -1, 0, 1])
    narrow = Domain([-1.5, -.5, 0.5])
    grid = Domain(np.linspace(-2, 1, 16))
    for target in (wide, grid):
        self.assertRaises(NotSubdomain, narrow.restrict, target)
def test_restrict(self):
    """Domain.restrict keeps the breakpoints lying within the target support."""
    base = Domain([-2, -1, 0, 1])
    sub = Domain([-1.5, -.5, 0.5])
    uniform = Domain(np.linspace(-2, 1, 16))
    expected_sub = Domain([-1.5, -1, -.5, 0, .5])
    self.assertEqual(base.restrict(sub), expected_sub)
    self.assertEqual(base.restrict(uniform), uniform)
    for dom in (base, sub, uniform):
        # self-restriction is the identity
        self.assertEqual(dom.restrict(dom), dom)
    # linspace produces breakpoints that differ by eps; they must coincide
    endpoints_only = Domain(np.linspace(-.4, .4, 2))
    self.assertEqual(uniform.restrict(endpoints_only),
                     Domain([-.4, -.2, 0, .2, .4]))
def domain(self):
    """Return the Domain built from this chebfun via Domain.from_chebfun."""
    dom = Domain.from_chebfun(self)
    return dom
def test_from_chebfun(self):
    """Domain.from_chebfun recovers the breakpoints the chebfun was built on."""
    breakpoints = np.linspace(-10, 10, 11)
    ff = chebfun(lambda x: np.cos(x), breakpoints)
    dom = Domain.from_chebfun(ff)
    # previously a pure smoke test; also check the result is meaningful
    self.assertEqual(dom, Domain(breakpoints))
def _restrict(self, subinterval):
    '''Restrict a chebfun to a subinterval, without simplifying'''
    current = self.domain
    target = current.restrict(Domain(subinterval))
    # resample onto the restricted breakpoints
    return self._break(target)
def test_breakpoints_in(self):
    """breakpoints_in returns a boolean array flagging shared breakpoints."""
    d1 = Domain([-1, 0, 1])
    d2 = Domain([-2, 0.5, 1, 3])
    result1 = d1.breakpoints_in(d2)
    self.assertIsInstance(result1, np.ndarray)
    # assertTrue(x, 3) treats 3 as the failure message and never checks
    # the size; assertEqual actually compares it.
    self.assertEqual(result1.size, 3)
    self.assertFalse(result1[0])
    self.assertFalse(result1[1])
    self.assertTrue(result1[2])
    result2 = d2.breakpoints_in(d1)
    self.assertIsInstance(result2, np.ndarray)
    self.assertEqual(result2.size, 4)
    self.assertFalse(result2[0])
    self.assertFalse(result2[1])
    self.assertTrue(result2[2])
    self.assertFalse(result2[3])
    self.assertTrue(d1.breakpoints_in(d1).all())
    self.assertTrue(d2.breakpoints_in(d2).all())
    self.assertFalse(d1.breakpoints_in(Domain([-5, 5])).any())
    self.assertFalse(d2.breakpoints_in(Domain([-5, 5])).any())
class Chebfun(object):
    """A piecewise function built from a sequence of Bndfun pieces ("funs").

    The pieces are held in ``self.funs``; ``self.breakdata`` maps each
    breakpoint to the value used when evaluating exactly at that point.
    """

    # --------------
    # constructors
    # --------------
    @classmethod
    def initempty(cls):
        """Return a Chebfun with no pieces."""
        return cls(np.array([]))

    @classmethod
    def initconst(cls, c, domain=DefaultPrefs.domain):
        """Return a Chebfun equal to the constant ``c`` on ``domain``."""
        funs = generate_funs(domain, Bndfun.initconst, [c])
        return cls(funs)

    @classmethod
    def initidentity(cls, domain=DefaultPrefs.domain):
        """Return a Chebfun representing the identity map on ``domain``."""
        funs = generate_funs(domain, Bndfun.initidentity)
        return cls(funs)

    @classmethod
    def initfun(cls, f, domain=DefaultPrefs.domain, n=None):
        """Construct from callable ``f``.

        Uses fixed-length construction when ``n`` is given, otherwise
        adaptive construction.
        """
        # Test against None: ``if n`` raises ValueError for a multi-element
        # numpy array of lengths. Dispatch through ``cls`` so subclasses
        # get instances of themselves.
        if n is None:
            return cls.initfun_adaptive(f, domain)
        return cls.initfun_fixedlen(f, n, domain)

    @classmethod
    def initfun_adaptive(cls, f, domain=DefaultPrefs.domain):
        """Construct from ``f`` with adaptively chosen piece lengths."""
        funs = generate_funs(domain, Bndfun.initfun_adaptive, [f])
        return cls(funs)

    @classmethod
    def initfun_fixedlen(cls, f, n, domain=DefaultPrefs.domain):
        """Construct from ``f`` with prescribed piece lengths.

        ``n`` is either a single length applied to every interval, or one
        length per interval of ``domain``.

        Raises
        ------
        BadDomainArgument
            If ``domain`` has fewer than two breakpoints.
        BadFunLengthArgument
            If ``n`` has several entries but not one per interval.
        """
        domain = np.array(domain)
        # Validate the breakpoints first: an empty domain would otherwise
        # reach np.ones(domain.size - 1) with a negative length.
        if domain.size < 2:
            raise BadDomainArgument
        nn = np.array(n)
        if nn.size == 1:
            # broadcast the single length to every interval
            nn = nn * np.ones(domain.size - 1)
        elif nn.size > 1 and nn.size != domain.size - 1:
            raise BadFunLengthArgument
        # accumulate in a list (np.append would copy the array every step)
        funs = []
        for (a, b), length in zip(zip(domain[:-1], domain[1:]), nn):
            funs.append(Bndfun.initfun_fixedlen(f, Interval(a, b), length))
        return cls(funs)

    # --------------------
    # operator overloads
    # --------------------
    def __add__(self, f):
        return self._apply_binop(f, operator.add)

    @self_empty(np.array([]))
    @float_argument
    def __call__(self, x):
        """Evaluate at points ``x``.

        Interior points are evaluated by the piece containing them, exact
        breakpoints take their value from ``breakdata``, and points outside
        the support are evaluated with the first/last piece.
        """
        # initialise output
        out = np.full(x.size, np.nan)
        # evaluate a fun when x is an interior point
        for fun in self:
            idx = fun.interval.isinterior(x)
            out[idx] = fun(x[idx])
        # evaluate the breakpoint data for x at a breakpoint
        breakpoints = self.breakpoints
        for breakpoint in breakpoints:
            out[x == breakpoint] = self.breakdata[breakpoint]
        # first and last funs used to evaluate outside of the chebfun domain
        lpts, rpts = x < breakpoints[0], x > breakpoints[-1]
        out[lpts] = self.funs[0](x[lpts])
        out[rpts] = self.funs[-1](x[rpts])
        return out

    def __init__(self, funs):
        self.funs = check_funs(funs)
        self.breakdata = compute_breakdata(self.funs)
        self.transposed = False

    def __iter__(self):
        return self.funs.__iter__()

    def __mul__(self, f):
        return self._apply_binop(f, operator.mul)

    def __neg__(self):
        return self.__class__(-self.funs)

    def __pos__(self):
        return self

    def __pow__(self, f):
        return self._apply_binop(f, operator.pow)

    def __rtruediv__(self, c):
        # Executed when truediv(c, self) fails, which is to say whenever c
        # is not a Chebfun. We proceed on the assumption c is a scalar.
        constfun = lambda x: .0 * x + c
        newfuns = []
        for fun in self:
            # bind ``fun`` as a default so the closure cannot late-bind to
            # a later loop value
            quotnt = lambda x, fun=fun: constfun(x) / fun(x)
            newfun = fun.initfun_adaptive(quotnt, fun.interval)
            newfuns.append(newfun)
        return self.__class__(newfuns)

    @self_empty('chebfun<empty>')
    def __repr__(self):
        rowcol = 'row' if self.transposed else 'column'
        numpcs = self.funs.size
        plural = '' if numpcs == 1 else 's'
        header = 'chebfun {} ({} smooth piece{})\n'.format(rowcol, numpcs, plural)
        toprow = ' interval length endpoint values\n'
        tmplat = '[{:8.2g},{:8.2g}] {:6} {:8.2g} {:8.2g}\n'
        rowdta = ''
        for fun in self:
            endpts = fun.support
            xl, xr = endpts
            fl, fr = fun(endpts)
            row = tmplat.format(xl, xr, fun.size, fl, fr)
            rowdta += row
        btmrow = 'vertical scale = {:3.2g}'.format(self.vscale)
        btmxtr = '' if numpcs == 1 else \
            ' total length = {}'.format(sum([f.size for f in self]))
        return header + toprow + rowdta + btmrow + btmxtr

    def __rsub__(self, f):
        return -(self - f)

    @cast_arg_to_chebfun
    def __rpow__(self, f):
        return f**self

    def __truediv__(self, f):
        return self._apply_binop(f, operator.truediv)

    __rmul__ = __mul__
    __div__ = __truediv__
    __rdiv__ = __rtruediv__
    __radd__ = __add__

    def __str__(self):
        rowcol = 'row' if self.transposed else 'col'
        out = '<chebfun-{},{},{}>\n'.format(
            rowcol, self.funs.size, sum([f.size for f in self]))
        return out

    def __sub__(self, f):
        return self._apply_binop(f, operator.sub)

    # ------------------
    # internal helpers
    # ------------------
    @self_empty()
    def _apply_binop(self, f, op):
        '''Funnel method used in the implementation of Chebfun binary
        operators. The high-level idea is to first break each chebfun into a
        series of pieces corresponding to the union of the domains of each
        before applying the supplied binary operator and simplifying. In the
        case of the second argument being a scalar we don't need to do the
        simplify step, since at the Tech-level these operations are defined
        such that there is no change in the number of coefficients.
        '''
        # An empty Chebfun operand short-circuits. Objects without an
        # ``isempty`` attribute (scalars, arrays) fall through; a bare
        # ``except`` here would also have hidden unrelated errors.
        try:
            if f.isempty:
                return f
        except AttributeError:
            pass
        if np.isscalar(f):
            chbfn1 = self
            chbfn2 = f * np.ones(self.funs.size)
            simplify = False
        else:
            newdom = self.domain.union(f.domain)
            chbfn1 = self._break(newdom)
            chbfn2 = f._break(newdom)
            simplify = True
        newfuns = []
        for fun1, fun2 in zip(chbfn1, chbfn2):
            newfun = op(fun1, fun2)
            if simplify:
                newfun = newfun.simplify()
            newfuns.append(newfun)
        return self.__class__(newfuns)

    def _break(self, targetdomain):
        '''Resamples self to the supplied Domain object, targetdomain. This
        method is intended as private since one will typically need to have
        called either Domain.union(f), or Domain.merge(f) prior to call.'''
        newfuns = []
        subintervals = targetdomain.intervals
        interval = next(subintervals)  # next(..) for Python2/3 compatibility
        for fun in self:
            while interval in fun.interval:
                newfun = fun.restrict(interval)
                newfuns.append(newfun)
                try:
                    interval = next(subintervals)
                except StopIteration:
                    break
        return self.__class__(newfuns)

    # ------------
    # properties
    # ------------
    @property
    def breakpoints(self):
        """Array of the keys of ``breakdata``, in iteration order."""
        return np.array([x for x in self.breakdata.keys()])

    @property
    @self_empty(np.array([]))
    def domain(self):
        '''Construct and return a Domain object corresponding to self'''
        return Domain.from_chebfun(self)

    @property
    @self_empty(Domain([]))
    def support(self):
        '''Return an array containing the first and last breakpoints'''
        return Domain(self.breakpoints[[0, -1]])

    @property
    @self_empty(0.)
    def hscale(self):
        """Largest absolute value of the support endpoints, as a float."""
        # builtin float: the np.float alias was removed in NumPy 1.24
        return float(np.abs(self.support).max())

    @property
    @self_empty(False)
    def isconst(self):
        # TODO: find an abstract way of referencing funs[0].coeffs[0]
        c = self.funs[0].coeffs[0]
        return all(fun.isconst and fun.coeffs[0] == c for fun in self)

    @property
    def isempty(self):
        return self.funs.size == 0

    @property
    @self_empty(0.)
    def vscale(self):
        """Maximum of the pieces' vscale values."""
        return np.max([fun.vscale for fun in self])

    @property
    @self_empty()
    def x(self):
        '''Return a Chebfun representing the identity the support of self'''
        return self.__class__.initidentity(self.support)

    # -----------
    # utilities
    # -----------
    def copy(self):
        return self.__class__([fun.copy() for fun in self])

    @self_empty()
    def _restrict(self, subinterval):
        '''Restrict a chebfun to a subinterval, without simplifying'''
        newdom = self.domain.restrict(Domain(subinterval))
        return self._break(newdom)

    @self_empty()
    def simplify(self):
        '''Simplify each fun in the chebfun'''
        return self.__class__([fun.simplify() for fun in self])

    def restrict(self, subinterval):
        '''Restrict a chebfun to a subinterval'''
        return self._restrict(subinterval).simplify()

    @cache
    @self_empty(np.array([]))
    def roots(self):
        '''Compute the roots of a Chebfun, i.e., the set of values x for which
        f(x) = 0.'''
        allrts = []
        prvrts = np.array([])
        htol = 1e2 * self.hscale * DefaultPrefs.eps
        for fun in self:
            rts = fun.roots()
            # ignore first root if equal to the last root of previous fun
            # TODO: there could be multiple roots at breakpoints
            if prvrts.size > 0 and rts.size > 0:
                if abs(prvrts[-1] - rts[0]) <= htol:
                    rts = rts[1:]
            allrts.append(rts)
            prvrts = rts
        return np.concatenate([x for x in allrts])

    # ----------
    # calculus
    # ----------
    def cumsum(self):
        """Indefinite integral, made continuous across breakpoints."""
        newfuns = []
        prevfun = None
        for fun in self:
            integral = fun.cumsum()
            # compare with None rather than relying on the truthiness of a
            # fun object
            if prevfun is not None:
                # enforce continuity by adding the function value
                # at the right endpoint of the previous fun
                _, fb = prevfun.endvalues
                integral = integral + fb
            newfuns.append(integral)
            prevfun = integral
        return self.__class__(newfuns)

    def diff(self):
        dfuns = np.array([fun.diff() for fun in self])
        return self.__class__(dfuns)

    def sum(self):
        return np.sum([fun.sum() for fun in self])

    # ----------
    # plotting
    # ----------
    def plot(self, ax=None, *args, **kwargs):
        ax = ax or plt.gca()
        a, b = self.support
        xx = np.linspace(a, b, 2001)
        ax.plot(xx, self(xx), *args, **kwargs)
        return ax

    def plotcoeffs(self, ax=None, *args, **kwargs):
        ax = ax or plt.gca()
        for fun in self:
            fun.plotcoeffs(ax=ax)
        return ax

    # ----------
    # utilities
    # ----------
    @self_empty()
    def absolute(self):
        '''Absolute value of a Chebfun'''
        newdom = self.domain.merge(self.roots())
        funs = [x.absolute() for x in self._break(newdom)]
        return self.__class__(funs)

    abs = absolute

    @self_empty()
    @cast_arg_to_chebfun
    def maximum(self, other):
        '''Pointwise maximum of self and another chebfun'''
        return self._maximum_minimum(other, operator.ge)

    @self_empty()
    @cast_arg_to_chebfun
    def minimum(self, other):
        '''Pointwise minimum of self and another chebfun'''
        return self._maximum_minimum(other, operator.lt)

    def _maximum_minimum(self, other, comparator):
        '''Method for computing the pointwise maximum/minimum of two
        Chebfuns'''
        roots = (self - other).roots()
        newdom = self.domain.union(other.domain).merge(roots)
        switch = newdom.support.merge(roots)
        # alternate 1, 0, 1, ... between consecutive roots of self - other
        keys = .5 * ((-1)**np.arange(switch.size - 1) + 1)
        if comparator(other(switch[0]), self(switch[0])):
            keys = 1 - keys
        funs = np.array([])
        for interval, use_self in zip(switch.intervals, keys):
            subdom = newdom.restrict(interval)
            if use_self:
                subfun = self.restrict(subdom)
            else:
                subfun = other.restrict(subdom)
            funs = np.append(funs, subfun.funs)
        return self.__class__(funs)
def test__ne___result_type(self):
    """Domain inequality yields a plain bool, whether or not domains differ."""
    reference = Domain([-2, 0, 1, 3, 5])
    same = Domain([-2, 0, 1, 3, 5])
    different = Domain([-1, 1])
    for other in (same, different):
        self.assertIsInstance(reference != other, bool)
def domain(self):
    """Build a Domain from this chebfun's breakpoints via Domain.from_chebfun."""
    result = Domain.from_chebfun(self)
    return result
def test_union_close(self):
    """Union snaps breakpoints that differ by less than HTOL."""
    tol = .8 * HTOL
    exact = Domain([-2, 0, 2])
    jittered = Domain([-2 - 2 * tol, -1 + tol, 1 + tol, 2 + 2 * tol])
    for lhs, rhs in ((exact, jittered), (jittered, exact)):
        self.assertEqual(lhs.union(rhs), Domain([-2, -1, 0, 1, 2]))
def support(self):
    """Return a Domain holding only the first and last breakpoints."""
    endpoints = self.breakpoints[[0, -1]]
    return Domain(endpoints)
def test_merge(self):
    """merge interleaves the breakpoints of two domains."""
    coarse = Domain([-2, -1, 0, 1])
    offset = Domain([-1.5, -.5, 0.5])
    merged = offset.merge(coarse)
    self.assertEqual(merged, Domain([-2, -1.5, -1, -.5, 0, .5, 1]))
def test__init__(self):
    """Domain accepts lists and numpy arrays of breakpoints without raising."""
    inputs = (
        [-2, 1],
        [-2, 0, 1],
        np.array([-2, 1]),
        np.array([-2, 0, 1]),
        np.linspace(-10, 10, 51),
    )
    for breakpoints in inputs:
        Domain(breakpoints)
def test_merge(self):
    """Merging keeps every breakpoint of both operands, sorted."""
    left = Domain([-2, -1, 0, 1])
    right = Domain([-1.5, -.5, 0.5])
    expected = [-2, -1.5, -1, -.5, 0, .5, 1]
    self.assertEqual(right.merge(left), Domain(expected))
def test_from_chebfun(self):
    """Domain.from_chebfun recovers the breakpoints the chebfun was built on."""
    breakpoints = np.linspace(-10, 10, 11)
    ff = chebfun(lambda x: np.cos(x), breakpoints)
    recovered = Domain.from_chebfun(ff)
    # previously a pure smoke test; also check the result is meaningful
    self.assertEqual(recovered, Domain(breakpoints))
def test__eq__(self):
    """Domains with identical breakpoints compare equal; others do not."""
    reference = Domain([-2, 0, 1, 3, 5])
    duplicate = Domain([-2, 0, 1, 3, 5])
    unrelated = Domain([-1, 1])
    self.assertEqual(reference, duplicate)
    self.assertNotEqual(reference, unrelated)
def test_union_close(self):
    """Near-coincident breakpoints are unified by union in both orders."""
    tol = .8 * HTOL
    dom_exact = Domain([-2, 0, 2])
    dom_noisy = Domain([-2 - 2 * tol, -1 + tol, 1 + tol, 2 + 2 * tol])
    forward = dom_exact.union(dom_noisy)
    self.assertEqual(forward, Domain([-2, -1, 0, 1, 2]))
    backward = dom_noisy.union(dom_exact)
    self.assertEqual(backward, Domain([-2, -1, 0, 1, 2]))
def test_union_raises(self):
    """union of domains with different supports raises SupportMismatch."""
    short = Domain([-2, 0])
    long_ = Domain([-2, 3])
    for lhs, rhs in ((short, long_), (long_, short)):
        self.assertRaises(SupportMismatch, lhs.union, rhs)
def test__ne__(self):
    """``!=`` is False for identical domains and True for different ones."""
    reference = Domain([-2, 0, 1, 3, 5])
    duplicate = Domain([-2, 0, 1, 3, 5])
    unrelated = Domain([-1, 1])
    self.assertFalse(reference != duplicate)
    self.assertTrue(reference != unrelated)