def test_cache_clearing():
    dispatch = Dispatcher()

    @dispatch(object)
    def f(x):
        return 1

    @dispatch(List(int))
    def f(x):
        return 1

    def assert_cleared():
        # After clearing, no methods or precedences are cached and the
        # parametric flag is reset.
        assert len(f.methods) == 0
        assert len(f.precedences) == 0
        assert not f._parametric

    f(1)

    # Calling `f` populates the cache.
    assert len(f.methods) == 2
    assert len(f.precedences) == 2
    assert f._parametric

    # Clearing the dispatcher's cache empties it.
    dispatch.clear_cache()
    assert_cleared()

    # Repopulate, then clear every cache at once.
    f(1)
    clear_all_cache()
    assert_cleared()
    assert len(subclasscheck_cache) == 0
def test_cache_clearing():
    dispatch = Dispatcher()

    @dispatch
    def f(x: object):
        return 1

    @dispatch
    def f(x: List[int]):
        return 1

    def assert_cleared():
        # After clearing, no methods or precedences are cached and no
        # runtime type resolver is stored.
        assert len(f._methods) == 0
        assert len(f._precedences) == 0
        assert not f._runtime_type_of

    f(1)

    # Calling `f` populates the cache.
    assert len(f._methods) == 2
    assert len(f._precedences) == 2
    assert f._runtime_type_of

    # Clearing the dispatcher's cache empties it.
    dispatch.clear_cache()
    assert_cleared()

    # Repopulate, then clear every cache at once.
    f(1)
    clear_all_cache()
    assert_cleared()
    assert len(subclasscheck_cache) == 0
def test_invoke_in_class():
    dispatch = Dispatcher()

    class A:
        def do(self, x):
            return "fallback"

    class B(A):
        @dispatch
        def do(self, x: int):
            return "int"

    class C(B):
        @dispatch
        def do(self, x: str):
            return "str"

    c = C()
    cases = [
        (str, "1", "str"),
        (int, 1, "int"),
        (float, 1.0, "fallback"),
    ]

    # Bound calls: the instance is already attached.
    for arg_type, arg, expected in cases:
        assert c.do.invoke(arg_type)(arg) == expected

    # Unbound calls: the instance type and the instance are passed explicitly.
    for arg_type, arg, expected in cases:
        assert C.do.invoke(C, arg_type)(c, arg) == expected
class DerivativeFunction(WrappedFunction):
    """Compute the derivative of a function.

    Args:
        e (:class:`.field_function.Element`): Function to compute the
            derivative of.
        *derivs (tensor): Per input, the index of the dimension which to
            take the derivative of. Set to `None` to not take a derivative.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, e, *derivs):
        WrappedFunction.__init__(self, e)
        self.derivs = derivs

    @_dispatch(object, Formatter)
    def display(self, e, formatter):
        # A single derivative index is shown as `(i)`; otherwise the whole
        # tuple of indices is shown.
        derivs = self.derivs
        label = '({})'.format(derivs[0]) if len(derivs) == 1 else derivs
        return 'd{} {}'.format(label, e)

    @_dispatch(Self)
    def __eq__(self, other):
        # Equal when the wrapped functions and the derivative indices match.
        return (self[0] == other[0]
                and tuple_equal(self.derivs, other.derivs))
def test_invoke():
    dispatch = Dispatcher()

    @dispatch()
    def f():
        return "fallback"

    @dispatch
    def f(x: int):
        return "int"

    @dispatch
    def f(x: str):
        return "str"

    @dispatch
    def f(x: Union[int, str, float]):
        return "int, str, or float"

    # Ordinary calls resolve on the runtime types of the arguments.
    direct_cases = [
        ((), "fallback"),
        ((1,), "int"),
        (("1",), "str"),
        ((1.0,), "int, str, or float"),
    ]
    for args, expected in direct_cases:
        assert f(*args) == expected

    # `invoke` selects the method from the given types, independently of the
    # runtime types of the arguments passed afterwards.
    invoke_cases = [
        ((), (), "fallback"),
        ((int,), ("1",), "int"),
        ((str,), (1,), "str"),
        ((float,), (1,), "int, str, or float"),
        ((Union[int, str],), (1,), "int, str, or float"),
        ((Union[int, str, float],), (1,), "int, str, or float"),
    ]
    for types, args, expected in invoke_cases:
        assert f.invoke(*types)(*args) == expected
class Constant(LowRank, Referentiable):
    """Matrix in which every entry equals a constant.

    Stored as a low-rank matrix: columns of ones on the left and right and
    the constant as the middle part.

    Args:
        constant (scalar): Constant of the matrix.
        rows (scalar): Number of rows.
        cols (scalar, optional): Number of columns. Defaults to the number
            of rows.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, constant, rows, cols=None):
        self.constant = constant
        self.rows = rows
        self.cols = rows if cols is None else cols

        # Build the low-rank factors. The left and right factors are shared
        # when `rows` and `cols` are the same object, which is always the
        # case when `cols` is omitted.
        dtype = B.dtype(self.constant)
        left = B.ones(dtype, self.rows, 1)
        if self.rows is self.cols:
            right = left
        else:
            right = B.ones(dtype, self.cols, 1)
        middle = B.expand_dims(B.expand_dims(self.constant, axis=0), axis=0)
        LowRank.__init__(self, left=left, right=right, middle=middle)

    @classmethod
    def from_(cls, constant, ref):
        """Construct a constant matrix with the dtype and shape of `ref`."""
        value = B.cast(B.dtype(ref), constant)
        return cls(value, *B.shape(ref))

    def __eq__(self, other):
        same_shape = B.shape(self) == B.shape(other)
        return same_shape and self.constant == other.constant
class TensorProductMean(Mean, TensorProductFunction):
    """Mean obtained by applying the wrapped function `self.fs[0]`."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric)
    @uprank
    def __call__(self, x):
        f = self.fs[0]
        return uprank(f(x))
class CorrectiveKernel(Kernel, Referentiable):
    """Kernel that adds the corrective variance in sparse conditioning.

    Args:
        k_zi (:class:`.kernel.Kernel`): Kernel between the processes
            corresponding to the left input and the inducing points
            respectively.
        k_zj (:class:`.kernel.Kernel`): Kernel between the processes
            corresponding to the right input and the inducing points
            respectively.
        z (input): Locations of the inducing points.
        A (tensor): Corrective matrix.
        K_z (tensor): Kernel matrix of the inducing points. Only its
            Cholesky factor is kept, stored as `self.L`.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, k_zi, k_zj, z, A, K_z):
        self.k_zi = k_zi
        self.k_zj = k_zj
        self.z = z
        self.A = A
        self.L = B.cholesky(matrix(K_z))

    def _solves(self, x, y, B):
        """Triangular solves of `self.L` against `k_zi(z, x)` and `k_zj(z, y)`.

        `B` is the `Cache` argument received by the dispatched methods; it is
        used in place of the backend module for `trisolve`.
        """
        return (B.trisolve(self.L, self.k_zi(self.z, x)),
                B.trisolve(self.L, self.k_zj(self.z, y)))

    @_dispatch(object, object, Cache)
    @cache
    def __call__(self, x, y, B):
        # Note: `B` here is the `Cache` argument, shadowing the backend module.
        return B.qf(self.A, *self._solves(x, y, B))

    @_dispatch(object, object, Cache)
    @cache
    def elwise(self, x, y, B):
        # Element-wise variant: diagonal of the quadratic form, as a column.
        return B.qf_diag(self.A, *self._solves(x, y, B))[:, None]
class ZeroKernel(Kernel, ZeroFunction, Referentiable):
    """Constant kernel of `0`."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def __call__(self, x, y, B):
        # `B` is the `Cache` argument here, shadowing the backend module.
        dtype, rows = B.dtype(x), B.shape(x)[0]
        if x is y:
            # Identical inputs: only a single size is given to `Zero`.
            return Zero(dtype, rows)
        return Zero(dtype, rows, B.shape(y)[0])

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def elwise(self, x, y, B):
        n = B.shape(x)[0]
        return B.zeros([n, 1], dtype=B.dtype(x))

    @property
    def _stationary(self):
        return True

    @property
    def var(self):
        return 0

    @property
    def length_scale(self):
        return 0

    @property
    def period(self):
        return 0
class Exp(Kernel):
    """Exponential kernel."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def __call__(self, x, y):
        # exp(-d) of the pairwise distances, wrapped as a dense matrix.
        distances = B.pw_dists(x, y)
        return Dense(B.exp(-distances))

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def elwise(self, x, y):
        distances = B.ew_dists(x, y)
        return B.exp(-distances)

    @property
    def _stationary(self):
        return True

    @property
    def var(self):
        return 1

    @property
    def length_scale(self):
        return 1

    @property
    def period(self):
        return np.inf

    @_dispatch(Self)
    def __eq__(self, other):
        # Nothing to compare: any two `Exp` kernels are equal.
        return True
class Linear(Kernel):
    """Linear kernel."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric)
    def __call__(self, x, y):
        # The kernel matrix is the low-rank product of the (upranked) inputs.
        if x is y:
            return LowRank(uprank(x))
        return LowRank(left=uprank(x), right=uprank(y))

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def elwise(self, x, y):
        products = B.multiply(x, y)
        return B.expand_dims(B.sum(products, axis=1), axis=1)

    @property
    def _stationary(self):
        return False

    @property
    def period(self):
        return np.inf

    @_dispatch(Self)
    def __eq__(self, other):
        # Nothing to compare: any two `Linear` kernels are equal.
        return True
class ScaledKernel(Kernel, ScaledFunction):
    """Scaled kernel."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object)
    def __call__(self, x, y):
        K = self[0](x, y)
        return self._compute(K)

    @_dispatch(object, object)
    def elwise(self, x, y):
        K = self[0].elwise(x, y)
        return self._compute(K)

    def _compute(self, K):
        """Multiply `K` by the scale."""
        return B.multiply(self.scale, K)

    @property
    def _stationary(self):
        return self[0].stationary

    @property
    def var(self):
        return self.scale * self[0].var

    @property
    def length_scale(self):
        return self[0].length_scale

    @property
    def period(self):
        return self[0].period
class InputTransformedKernel(Kernel, InputTransformedFunction):
    """Input-transformed kernel."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object)
    @uprank
    def __call__(self, x, y):
        x_t, y_t = self._compute(x, y)
        return self[0](x_t, y_t)

    @_dispatch(object, object)
    @uprank
    def elwise(self, x, y):
        x_t, y_t = self._compute(x, y)
        return self[0].elwise(x_t, y_t)

    def _compute(self, x, y):
        """Apply the input transforms, leaving an input as-is if its
        transform is `None`."""
        f1, f2 = expand(self.fs)
        if f1 is not None:
            x = uprank(f1(x))
        if f2 is not None:
            y = uprank(f2(y))
        return x, y

    @_dispatch(Self)
    def __eq__(self, other):
        return (self[0] == other[0]
                and tuple_equal(expand(self.fs), expand(other.fs)))
class ZeroKernel(Kernel, ZeroFunction):
    """Constant kernel of `0`."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric)
    def __call__(self, x, y):
        dtype = B.dtype(x)
        rows = B.shape(uprank(x))[0]
        if x is y:
            # Identical inputs: only a single size is given to `Zero`.
            return Zero(dtype, rows)
        return Zero(dtype, rows, B.shape(uprank(y))[0])

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def elwise(self, x, y):
        dtype = B.dtype(x)
        return B.zeros(dtype, B.shape(x)[0], 1)

    @property
    def _stationary(self):
        return True

    @property
    def var(self):
        return 0

    @property
    def length_scale(self):
        return 0

    @property
    def period(self):
        return 0
class ScaledElement(WrappedElement, Referentiable):
    """Scaled element.

    Args:
        e (:class:`.field.Element`): Element to scale.
        scale (tensor): Scale.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, e, scale):
        WrappedElement.__init__(self, e)
        self.scale = scale

    @property
    def num_factors(self):
        # The scale adds one factor on top of those of the wrapped element.
        return self[0].num_factors + 1

    @_dispatch(object, Formatter)
    def display(self, e, formatter):
        scale = formatter(self.scale)
        return '{} * {}'.format(scale, e)

    def factor(self, i):
        """Get the `i`th factor: the scale first, then the wrapped element's.

        Raises:
            IndexError: If `i` is at least `num_factors`.
        """
        # NOTE(review): negative `i` is not rejected here; it is forwarded
        # to the wrapped element.
        if i >= self.num_factors:
            raise IndexError('Index out of range.')
        if i == 0:
            return self.scale
        return self[0].factor(i - 1)

    @_dispatch(Self)
    def __eq__(self, other):
        return (B.all(self.scale == other.scale)
                and self[0] == other[0])
class MultiOutputMean(Mean, Referentiable):
    """A generic multi-output mean.

    Args:
        *ps (instance of :class:`.graph.GP`): Processes that make up the
            multi-valued process.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, *ps):
        self.means = ps[0].graph.means
        self.ps = ps

    @_dispatch(B.Numeric)
    def __call__(self, x):
        # Plain input: attach it to every process and evaluate all of them.
        inputs = [p(x) for p in self.ps]
        return self(MultiInput(*inputs))

    @_dispatch(At)
    def __call__(self, x):
        # Input attached to a process: use that process's mean.
        process = type_parameter(x)
        return self.means[process](x.get())

    @_dispatch(MultiInput)
    def __call__(self, x):
        # Several inputs: evaluate each and stack the results along axis 0.
        parts = [self(xi) for xi in x.get()]
        return B.concat(*parts, axis=0)

    def __str__(self):
        means = ', '.join(str(self.means[p]) for p in self.ps)
        return 'MultiOutputMean({})'.format(means)
def test_multi_in_class():
    dispatch = Dispatcher()

    class A:
        @dispatch
        def f(self, x):
            return "fallback"

        # One method registered under two signatures.
        @dispatch.multi((object, int), (object, str))
        def f(self, x: Union[int, str]):
            return "int or str"

    instance = A()
    for arg, expected in [(1, "int or str"),
                          ("1", "int or str"),
                          (1.0, "fallback")]:
        assert instance.f(arg) == expected
class AbstractObservations(metaclass=Referentiable):
    """Abstract base class for observations.

    Can be constructed from a single pair `x`, `y` (with an optional
    reference `ref`), or from several `(x, y)` pairs given as positional
    arguments (with `ref` optionally given as a keyword argument).
    """
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch({B.Numeric, Input}, B.Numeric, [PromisedGP])
    def __init__(self, x, y, ref=None):
        self._ref = ref
        self.x = ensure_at(x, self._ref)
        self.y = y
        # The graph is taken from the process that `x` is attached to.
        self.graph = type_parameter(self.x).graph

    @_dispatch([Union(tuple, list, PromisedGP)])
    def __init__(self, *pairs, **kw_args):
        # Optional reference; `None` when not given.
        self._ref = kw_args.get('ref')

        # Ensure `At` for all pairs.
        pairs = [(ensure_at(x, self._ref), y) for x, y in pairs]

        # Get the graph from the first pair.
        # NOTE(review): this raises IndexError if no pairs are given.
        self.graph = type_parameter(pairs[0][0]).graph

        # Extend the graph by the Cartesian product `p` of all processes.
        p = self.graph.cross(*self.graph.ps)

        # Represent all observations as one input/output pair of the
        # vector-valued process `p`; outputs are stacked along axis 0.
        xs, ys = zip(*pairs)
        self.x = p(MultiInput(*xs))
        self.y = B.concat(*[uprank(y) for y in ys], axis=0)

    @_dispatch({tuple, list})
    def __ror__(self, ps):
        # `ps | obs`: condition the processes `ps` on these observations.
        return self.graph.condition(ps, self)

    def posterior_kernel(self, p_i, p_j):  # pragma: no cover
        """Get the posterior kernel between two processes.

        Args:
            p_i (:class:`.graph.GP`): First process.
            p_j (:class:`.graph.GP`): Second process.

        Returns:
            :class:`.kernel.Kernel`: Posterior kernel between the first and
                second process.
        """
        raise NotImplementedError('Posterior kernel construction not '
                                  'implemented.')

    def posterior_mean(self, p):  # pragma: no cover
        """Get the posterior mean of a process.

        Args:
            p (:class:`.graph.GP`): Process.

        Returns:
            :class:`.mean.Mean`: Posterior mean of `p`.
        """
        raise NotImplementedError('Posterior mean construction not '
                                  'implemented.')
class PosteriorKernel(Kernel, Referentiable):
    """Posterior kernel.

    Args:
        k_ij (:class:`.kernel.Kernel`): Kernel between processes
            corresponding to the left input and the right input
            respectively.
        k_zi (:class:`.kernel.Kernel`): Kernel between processes
            corresponding to the data and the left input respectively.
        k_zj (:class:`.kernel.Kernel`): Kernel between processes
            corresponding to the data and the right input respectively.
        z (input): Locations of data.
        K_z (:class:`.matrix.Dense`): Kernel matrix of data.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, k_ij, k_zi, k_zj, z, K_z):
        self.k_ij = k_ij
        self.k_zi = k_zi
        self.k_zj = k_zj
        self.z = z
        self.K_z = matrix(K_z)

    @_dispatch(object, object, Cache)
    @cache
    def __call__(self, x, y, B):
        # `B` is the `Cache` argument here; it is also forwarded to the
        # component kernels.
        prior = self.k_ij(x, y, B)
        cross_i = self.k_zi(self.z, x, B)
        cross_j = self.k_zj(self.z, y, B)
        return B.schur(prior, cross_i, self.K_z, cross_j)

    @_dispatch(object, object, Cache)
    @cache
    def elwise(self, x, y, B):
        correction = B.qf_diag(self.K_z,
                               self.k_zi(self.z, x, B),
                               self.k_zj(self.z, y, B))
        prior = self.k_ij.elwise(x, y, B)
        return B.subtract(prior, B.expand_dims(correction, 1))
class SumKernel(Kernel, SumFunction, Referentiable):
    """Sum of kernels."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object, Cache)
    @cache
    def __call__(self, x, y, B):
        first = self[0](x, y, B)
        second = self[1](x, y, B)
        return B.add(first, second)

    @_dispatch(object, object, Cache)
    @cache
    def elwise(self, x, y, B):
        first = self[0].elwise(x, y, B)
        second = self[1].elwise(x, y, B)
        return B.add(first, second)

    @property
    def _stationary(self):
        return self[0].stationary and self[1].stationary

    @property
    def var(self):
        return self[0].var + self[1].var

    @property
    def length_scale(self):
        # Variance-weighted average of the summands' length scales.
        # NOTE(review): divides by `self.var`; a zero total variance fails.
        k1, k2 = self[0], self[1]
        weighted = k1.var * k1.length_scale + k2.var * k2.length_scale
        return weighted / self.var

    @property
    def period(self):
        return np.inf
class TensorProductKernel(Kernel, TensorProductFunction, Referentiable):
    """Tensor product kernel."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def __call__(self, x, y, B):
        f1, f2 = expand(self.fs)
        left = apply_optional_arg(f1, x, B)
        if x is y and f1 is f2:
            # Same input and same function: a single factor suffices.
            return LowRank(left)
        return LowRank(left=left, right=apply_optional_arg(f2, y, B))

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def elwise(self, x, y, B):
        f1, f2 = expand(self.fs)
        left = apply_optional_arg(f1, x, B)
        right = apply_optional_arg(f2, y, B)
        return B.multiply(left, right)

    @_dispatch(Self)
    def __eq__(self, other):
        return tuple_equal(expand(self.fs), expand(other.fs))
class InputTransformedKernel(Kernel, InputTransformedFunction, Referentiable):
    """Input-transformed kernel."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object, Cache)
    @cache
    @uprank
    def __call__(self, x, y, B):
        args = self._compute(x, y, B)
        return self[0](*args)

    @_dispatch(object, object, Cache)
    @cache
    @uprank
    def elwise(self, x, y, B):
        args = self._compute(x, y, B)
        return self[0].elwise(*args)

    def _compute(self, x, y, B):
        """Apply the input transforms, leaving an input as-is if its
        transform is `None`. `B` is passed through so the result can be
        unpacked directly into the wrapped kernel."""
        f1, f2 = expand(self.fs)
        if f1 is not None:
            x = apply_optional_arg(f1, x, B)
        if f2 is not None:
            y = apply_optional_arg(f2, y, B)
        return x, y, B

    @_dispatch(Self)
    def __eq__(self, other):
        return (self[0] == other[0]
                and tuple_equal(expand(self.fs), expand(other.fs)))
class ScaledKernel(Kernel, ScaledFunction, Referentiable):
    """Scaled kernel."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object, Cache)
    @cache
    def __call__(self, x, y, B):
        K = self[0](x, y, B)
        return self._compute(K, B)

    @_dispatch(object, object, Cache)
    @cache
    def elwise(self, x, y, B):
        K = self[0].elwise(x, y, B)
        return self._compute(K, B)

    def _compute(self, K, B):
        """Multiply `K` by the scale, cast to the dtype of `K`."""
        scale = B.cast(self.scale, dtype=B.dtype(K))
        return B.multiply(scale, K)

    @property
    def _stationary(self):
        return self[0].stationary

    @property
    def var(self):
        return self.scale * self[0].var

    @property
    def length_scale(self):
        return self[0].length_scale

    @property
    def period(self):
        return self[0].period
class LowRank(Dense, Referentiable):
    """Low-rank matrix.

    The matrix is represented as `left middle transpose(right)`.

    Args:
        left (tensor): Left part of the matrix.
        right (tensor, optional): Right part of the matrix. Defaults to
            `left`.
        middle (tensor, optional): Middle part of the matrix. Defaults to
            `B.eye` of size `B.shape(left)[1]` with the dtype of `left`.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, left, right=None, middle=None):
        Dense.__init__(self, None)
        self.left = left
        self.right = left if right is None else right
        if middle is None:
            self.middle = B.eye(B.dtype(self.left), B.shape(self.left)[1])
        else:
            self.middle = middle

        # Shorthands:
        self.l = self.left
        self.r = self.right
        self.m = self.middle

    @_dispatch(Self)
    def __eq__(self, other):
        # Reduce each element-wise comparison to a single truth value and
        # combine them. Returning a tuple of comparisons would always be
        # truthy, so every pair of matrices would compare equal.
        return (B.all(self.left == other.left)
                and B.all(self.middle == other.middle)
                and B.all(self.right == other.right))
class ProductKernel(Kernel, ProductFunction, Referentiable):
    """Product of two kernels."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object, Cache)
    @cache
    def __call__(self, x, y, B):
        first = self[0](x, y, B)
        second = self[1](x, y, B)
        return B.multiply(first, second)

    @_dispatch(object, object, Cache)
    @cache
    def elwise(self, x, y, B):
        first = self[0].elwise(x, y, B)
        second = self[1].elwise(x, y, B)
        return B.multiply(first, second)

    @property
    def _stationary(self):
        return self[0].stationary and self[1].stationary

    @property
    def var(self):
        return self[0].var * self[1].var

    @property
    def length_scale(self):
        # The shorter of the two length scales.
        return B.minimum(self[0].length_scale, self[1].length_scale)

    @property
    def period(self):
        return np.inf
class A(metaclass=Referentiable):
    """Class with dispatched `__call__`, `go`, and `go_again` methods.

    Each method has an `int` and a `str` overload, and every overload does
    nothing.
    """
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(int)
    def __call__(self, x):
        """`int` overload; does nothing."""

    @_dispatch(str)
    def __call__(self, x):
        """`str` overload; does nothing."""

    @_dispatch(int)
    def go(self, x):
        """`int` overload; does nothing."""

    @_dispatch(str)
    def go(self, x):
        """`str` overload; does nothing."""

    @_dispatch(int)
    def go_again(self, x):
        """`int` overload; does nothing."""

    @_dispatch(str)
    def go_again(self, x):
        """`str` overload; does nothing."""
class Linear(Kernel, Referentiable):
    """Linear kernel."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def __call__(self, x, y, B):
        # The kernel matrix is the low-rank product of the inputs.
        if x is y:
            return LowRank(x)
        return LowRank(left=x, right=y)

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def elwise(self, x, y, B):
        products = B.multiply(x, y)
        return B.expand_dims(B.sum(products, axis=1), 1)

    @property
    def _stationary(self):
        return False

    @property
    def period(self):
        return np.inf

    @_dispatch(Self)
    def __eq__(self, other):
        # Nothing to compare: any two `Linear` kernels are equal.
        return True
class NoisyKernel(Kernel, Referentiable):
    """Noisy observations of a latent process.

    Uses :class:`.input.Latent` and :class:`.input.Observed`.

    Args:
        k_f (:class:`.kernel.Kernel`): Kernel of the latent process.
        k_n (:class:`.kernel.Kernel`): Kernel of the noise.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, k_f, k_n):
        self.k_f = k_f
        self.k_n = k_n
        self.k_y = k_f + k_n

    @_dispatch({Latent, Observed}, {Latent, Observed}, Cache)
    @cache
    def __call__(self, x, y, cache):
        # Any latent/observed combination: latent kernel only. The
        # `(Observed, Observed)` overload below is more specific and
        # presumably takes precedence for that case.
        args = (x.get(), y.get())
        return self.k_f(*args)

    @_dispatch(Observed, Observed, Cache)
    @cache
    def __call__(self, x, y, cache):
        # Both inputs observed: latent kernel plus noise kernel.
        args = (x.get(), y.get())
        return self.k_y(*args)

    def __str__(self):
        parts = (self.k_f, self.k_n)
        return 'NoisyKernel({}, {})'.format(*parts)
class PosteriorMean(Mean):
    """Posterior mean.

    Args:
        m_i (:class:`.mean.Mean`): Mean of process corresponding to
            the input.
        m_z (:class:`.mean.Mean`): Mean of process corresponding to
            the data.
        k_zi (:class:`.kernel.Kernel`): Kernel between processes
            corresponding to the data and the input respectively.
        z (input): Locations of data.
        K_z (:class:`.matrix.Dense`): Kernel matrix of data.
        y (tensor): Observations to condition on.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, m_i, m_z, k_zi, z, K_z, y):
        self.m_i = m_i
        self.m_z = m_z
        self.k_zi = k_zi
        self.z = z
        self.K_z = K_z
        self.y = uprank(y)

    @_dispatch(object)
    def __call__(self, x):
        # Prior mean at `x` plus a correction driven by the residual of the
        # observations with respect to the prior mean at the data.
        residual = B.subtract(self.y, self.m_z(self.z))
        prior = self.m_i(x)
        correction = B.qf(self.K_z, self.k_zi(self.z, x), residual)
        return B.add(prior, correction)
def test_basic_arithmetic():
    dispatch = Dispatcher()

    @dispatch(Number)
    def f1(x):
        return np.array([[x**2]])

    @dispatch(object)
    def f1(x):
        return np.sum(x**2, axis=1)[:, None]

    @dispatch(Number)
    def f2(x):
        return np.array([[x**3]])

    @dispatch(object)
    def f2(x):
        return np.sum(x**3, axis=1)[:, None]

    m1 = TensorProductMean(f1)
    m2 = TensorProductMean(f2)
    m3 = ZeroMean()
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn()

    def check(combined, expected, x, label):
        # Builds the yielded `(ok, result, label)` tuple for one comparison.
        return ok, np.allclose(combined(x), expected(x)), label

    yield check(m1 * m2, lambda x: m1(x) * m2(x), x1, 'prod')
    yield check(m1 * m2, lambda x: m1(x) * m2(x), x2, 'prod 2')
    yield check(m1 + m3, lambda x: m1(x) + m3(x), x1, 'sum')
    yield check(m1 + m3, lambda x: m1(x) + m3(x), x2, 'sum 2')
    yield check(5. * m1, lambda x: 5. * m1(x), x1, 'prod 3')
    yield check(5. * m1, lambda x: 5. * m1(x), x2, 'prod 4')
    yield check(5. + m1, lambda x: 5. + m1(x), x1, 'sum 3')
    yield check(5. + m1, lambda x: 5. + m1(x), x2, 'sum 4')