def test_cache_clearing():
    dispatch = Dispatcher()

    @dispatch(object)
    def f(x):
        return 1

    @dispatch(List(int))
    def f(x):
        return 1

    f(1)

    # Check that cache is used.
    assert len(f.methods) == 2
    assert len(f.precedences) == 2
    assert f._parametric

    dispatch.clear_cache()

    # Check that cache is cleared.
    assert len(f.methods) == 0
    assert len(f.precedences) == 0
    assert not f._parametric

    f(1)
    clear_all_cache()

    # Again check that cache is cleared.
    assert len(f.methods) == 0
    assert len(f.precedences) == 0
    assert len(subclasscheck_cache) == 0
    assert not f._parametric
Example #2
def test_cache_clearing():
    dispatch = Dispatcher()

    @dispatch
    def f(x: object):
        return 1

    @dispatch
    def f(x: List[int]):
        return 1

    f(1)

    # Check that cache is used.
    assert len(f._methods) == 2
    assert len(f._precedences) == 2
    assert f._runtime_type_of

    dispatch.clear_cache()

    # Check that cache is cleared.
    assert len(f._methods) == 0
    assert len(f._precedences) == 0
    assert not f._runtime_type_of

    f(1)
    clear_all_cache()

    # Again check that cache is cleared.
    assert len(f._methods) == 0
    assert len(f._precedences) == 0
    assert len(subclasscheck_cache) == 0
    assert not f._runtime_type_of
def test_invoke_in_class():
    dispatch = Dispatcher()

    class A:
        def do(self, x):
            return "fallback"

    class B(A):
        @dispatch
        def do(self, x: int):
            return "int"

    class C(B):
        @dispatch
        def do(self, x: str):
            return "str"

    c = C()

    # Test bound calls.
    assert c.do.invoke(str)("1") == "str"
    assert c.do.invoke(int)(1) == "int"
    assert c.do.invoke(float)(1.0) == "fallback"

    # Test unbound calls.
    assert C.do.invoke(C, str)(c, "1") == "str"
    assert C.do.invoke(C, int)(c, 1) == "int"
    assert C.do.invoke(C, float)(c, 1.0) == "fallback"
Example #4
class DerivativeFunction(WrappedFunction):
    """Compute the derivative of a function.

    Args:
        e (:class:`.field_function.Element`): Function to compute the
            derivative of.
        *derivs (tensor): Per input, the index of the dimension to take the
            derivative with respect to. Set to `None` to not take a
            derivative of that input.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, e, *derivs):
        WrappedFunction.__init__(self, e)
        self.derivs = derivs

    @_dispatch(object, Formatter)
    def display(self, e, formatter):
        if len(self.derivs) == 1:
            derivs = '({})'.format(self.derivs[0])
        else:
            derivs = self.derivs
        return 'd{} {}'.format(derivs, e)

    @_dispatch(Self)
    def __eq__(self, other):
        return self[0] == other[0] and \
               tuple_equal(self.derivs, other.derivs)
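The `Dispatcher(in_class=Self)` / `@_dispatch(Self)` pattern used here (and in many of the examples below) comes from older versions of `plum`. A minimal sketch of the same idea in the modern annotation-based API, assuming a recent `plum` that supports string forward references; the class `Scaled` is hypothetical:

from plum import Dispatcher

class Scaled:
    _dispatch = Dispatcher()

    def __init__(self, scale):
        self.scale = scale

    # "Scaled" is a forward reference to the enclosing class; it plays
    # the role that `Self` played in the old API.
    @_dispatch
    def __eq__(self, other: "Scaled"):
        return self.scale == other.scale

    @_dispatch
    def __eq__(self, other: object):
        return False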
def test_invoke():
    dispatch = Dispatcher()

    @dispatch()
    def f():
        return "fallback"

    @dispatch
    def f(x: int):
        return "int"

    @dispatch
    def f(x: str):
        return "str"

    @dispatch
    def f(x: Union[int, str, float]):
        return "int, str, or float"

    assert f() == "fallback"
    assert f(1) == "int"
    assert f("1") == "str"
    assert f(1.0) == "int, str, or float"
    assert f.invoke()() == "fallback"
    assert f.invoke(int)("1") == "int"
    assert f.invoke(str)(1) == "str"
    assert f.invoke(float)(1) == "int, str, or float"
    assert f.invoke(Union[int, str])(1) == "int, str, or float"
    assert f.invoke(Union[int, str, float])(1) == "int, str, or float"
Example #6
class Constant(LowRank, Referentiable):
    """Constant symmetric positive-definite matrix.

    Args:
        constant (scalar): Constant of the matrix.
        rows (scalar): Number of rows.
        cols (scalar, optional): Number of columns. Defaults to the number of
            rows.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, constant, rows, cols=None):
        self.constant = constant
        self.rows = rows
        self.cols = rows if cols is None else cols

        # Construct and initialise the low-rank representation.
        left = B.ones(B.dtype(self.constant), self.rows, 1)
        if self.rows is self.cols:
            right = left
        else:
            right = B.ones(B.dtype(self.constant), self.cols, 1)
        middle = B.expand_dims(B.expand_dims(self.constant, axis=0), axis=0)
        LowRank.__init__(self, left=left, right=right, middle=middle)

    @classmethod
    def from_(cls, constant, ref):
        return cls(B.cast(B.dtype(ref), constant), *B.shape(ref))

    def __eq__(self, other):
        return B.shape(self) == B.shape(other) \
               and self.constant == other.constant
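A quick numpy check of the construction above, assuming the `left @ middle @ right.T` representation that `LowRank` uses: a rank-one outer product of ones, scaled by the constant, reproduces the full constant matrix.

import numpy as np

constant, rows = 2.0, 3
left = np.ones((rows, 1))           # B.ones(dtype, rows, 1)
middle = np.array([[constant]])     # constant with two dims expanded
assert np.allclose(left @ middle @ left.T,
                   np.full((rows, rows), constant))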
Example #7
class TensorProductMean(Mean, TensorProductFunction):
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric)
    @uprank
    def __call__(self, x):
        return uprank(self.fs[0](x))
Example #8
class CorrectiveKernel(Kernel, Referentiable):
    """Kernel that adds the corrective variance in sparse conditioning.

    Args:
        k_zi (:class:`.kernel.Kernel`): Kernel between the processes
            corresponding to the left input and the inducing points
            respectively.
        k_zj (:class:`.kernel.Kernel`): Kernel between the processes
            corresponding to the right input and the inducing points
            respectively.
        z (input): Locations of the inducing points.
        A (tensor): Corrective matrix.
        K_z (tensor): Kernel matrix of the inducing points.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, k_zi, k_zj, z, A, K_z):
        self.k_zi = k_zi
        self.k_zj = k_zj
        self.z = z
        self.A = A
        self.L = B.cholesky(matrix(K_z))

    @_dispatch(object, object, Cache)
    @cache
    def __call__(self, x, y, B):
        return B.qf(self.A, B.trisolve(self.L, self.k_zi(self.z, x)),
                    B.trisolve(self.L, self.k_zj(self.z, y)))

    @_dispatch(object, object, Cache)
    @cache
    def elwise(self, x, y, B):
        return B.qf_diag(self.A, B.trisolve(self.L, self.k_zi(self.z, x)),
                         B.trisolve(self.L, self.k_zj(self.z, y)))[:, None]
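The `trisolve` calls above whiten the cross-covariances with the Cholesky factor of `K_z`. A small numpy/scipy sketch of why that works: with `K_z = L @ L.T`, inner products of `L^{-1}`-whitened vectors recover quadratic forms in `K_z^{-1}`.

import numpy as np
from scipy.linalg import cholesky, solve_triangular

K_z = np.array([[2.0, 0.5],
                [0.5, 1.0]])
L = cholesky(K_z, lower=True)               # K_z = L @ L.T
a, b = np.array([1.0, 2.0]), np.array([0.5, -1.0])
wa = solve_triangular(L, a, lower=True)     # L^{-1} a, i.e. trisolve(L, a)
wb = solve_triangular(L, b, lower=True)
assert np.isclose(wa @ wb, a @ np.linalg.solve(K_z, b))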
Example #9
class ZeroKernel(Kernel, ZeroFunction, Referentiable):
    """Constant kernel of `0`."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def __call__(self, x, y, B):
        if x is y:
            return Zero(B.dtype(x), B.shape(x)[0])
        else:
            return Zero(B.dtype(x), B.shape(x)[0], B.shape(y)[0])

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def elwise(self, x, y, B):
        return B.zeros([B.shape(x)[0], 1], dtype=B.dtype(x))

    @property
    def _stationary(self):
        return True

    @property
    def var(self):
        return 0

    @property
    def length_scale(self):
        return 0

    @property
    def period(self):
        return 0
Example #10
class Exp(Kernel):
    """Exponential kernel."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def __call__(self, x, y):
        return Dense(B.exp(-B.pw_dists(x, y)))

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def elwise(self, x, y):
        return B.exp(-B.ew_dists(x, y))

    @property
    def _stationary(self):
        return True

    @property
    def var(self):
        return 1

    @property
    def length_scale(self):
        return 1

    @property
    def period(self):
        return np.inf

    @_dispatch(Self)
    def __eq__(self, other):
        return True
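A numpy sketch of the two code paths above: `pw_dists` yields the full pairwise-distance matrix, while `ew_dists` yields only the distances between paired rows (names as used by the `B` backend here).

import numpy as np

x, y = np.random.randn(5, 2), np.random.randn(5, 2)
pw = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)  # (5, 5) matrix
ew = np.linalg.norm(x - y, axis=-1)                          # (5,) paired rows
K_full = np.exp(-pw)     # __call__: full Exp kernel matrix
k_diag = np.exp(-ew)     # elwise: matches the diagonal of K_full
assert np.allclose(k_diag, np.diag(K_full))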
Example #11
class Linear(Kernel):
    """Linear kernel."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric)
    def __call__(self, x, y):
        if x is y:
            return LowRank(uprank(x))
        else:
            return LowRank(left=uprank(x), right=uprank(y))

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def elwise(self, x, y):
        return B.expand_dims(B.sum(B.multiply(x, y), axis=1), axis=1)

    @property
    def _stationary(self):
        return False

    @property
    def period(self):
        return np.inf

    @_dispatch(Self)
    def __eq__(self, other):
        return True
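The linear kernel is k(x, y) = <x, y>. A numpy sketch showing that the row-wise sum in `elwise` above matches the diagonal of the full Gram matrix that `__call__`'s low-rank form represents:

import numpy as np

x = np.random.randn(4, 3)
gram = x @ x.T                                   # LowRank(x) as a dense matrix
elwise = np.sum(x * x, axis=1, keepdims=True)    # (4, 1), as in `elwise`
assert np.allclose(elwise[:, 0], np.diag(gram))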
Example #12
class ScaledKernel(Kernel, ScaledFunction):
    """Scaled kernel."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object)
    def __call__(self, x, y):
        return self._compute(self[0](x, y))

    @_dispatch(object, object)
    def elwise(self, x, y):
        return self._compute(self[0].elwise(x, y))

    def _compute(self, K):
        return B.multiply(self.scale, K)

    @property
    def _stationary(self):
        return self[0].stationary

    @property
    def var(self):
        return self.scale * self[0].var

    @property
    def length_scale(self):
        return self[0].length_scale

    @property
    def period(self):
        return self[0].period
Example #13
class InputTransformedKernel(Kernel, InputTransformedFunction):
    """Input-transformed kernel."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object)
    @uprank
    def __call__(self, x, y):
        return self[0](*self._compute(x, y))

    @_dispatch(object, object)
    @uprank
    def elwise(self, x, y):
        return self[0].elwise(*self._compute(x, y))

    def _compute(self, x, y):
        f1, f2 = expand(self.fs)
        x = x if f1 is None else uprank(f1(x))
        y = y if f2 is None else uprank(f2(y))
        return x, y

    @_dispatch(Self)
    def __eq__(self, other):
        return self[0] == other[0] and \
               tuple_equal(expand(self.fs), expand(other.fs))
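The `_compute` step above simply maps the inputs before calling the base kernel, i.e. k'(x, y) = k(f1(x), f2(y)). A minimal sketch with a hypothetical base kernel and transform:

import numpy as np

base = lambda x, y: np.exp(-np.abs(x - y))   # hypothetical base kernel
f = np.tanh                                  # hypothetical input transform

def k_transformed(x, y):
    # Transform each input, then evaluate the base kernel.
    return base(f(x), f(y))

assert np.isclose(k_transformed(0.5, 2.0), base(np.tanh(0.5), np.tanh(2.0)))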
Example #14
class ZeroKernel(Kernel, ZeroFunction):
    """Constant kernel of `0`."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric)
    def __call__(self, x, y):
        if x is y:
            return Zero(B.dtype(x), B.shape(uprank(x))[0])
        else:
            return Zero(B.dtype(x),
                        B.shape(uprank(x))[0],
                        B.shape(uprank(y))[0])

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def elwise(self, x, y):
        return B.zeros(B.dtype(x), B.shape(x)[0], 1)

    @property
    def _stationary(self):
        return True

    @property
    def var(self):
        return 0

    @property
    def length_scale(self):
        return 0

    @property
    def period(self):
        return 0
Example #15
class ScaledElement(WrappedElement, Referentiable):
    """Scaled element.

    Args:
        e (:class:`.field.Element`): Element to scale.
        scale (tensor): Scale.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, e, scale):
        WrappedElement.__init__(self, e)
        self.scale = scale

    @property
    def num_factors(self):
        return self[0].num_factors + 1

    @_dispatch(object, Formatter)
    def display(self, e, formatter):
        return '{} * {}'.format(formatter(self.scale), e)

    def factor(self, i):
        if i >= self.num_factors:
            raise IndexError('Index out of range.')
        else:
            return self.scale if i == 0 else self[0].factor(i - 1)

    @_dispatch(Self)
    def __eq__(self, other):
        return B.all(self.scale == other.scale) and self[0] == other[0]
Example #16
class MultiOutputMean(Mean, Referentiable):
    """A generic multi-output mean.

    Args:
        *ps (instance of :class:`.graph.GP`): Processes that make up the
            multi-valued process.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, *ps):
        self.means = ps[0].graph.means
        self.ps = ps

    @_dispatch(B.Numeric)
    def __call__(self, x):
        return self(MultiInput(*(p(x) for p in self.ps)))

    @_dispatch(At)
    def __call__(self, x):
        return self.means[type_parameter(x)](x.get())

    @_dispatch(MultiInput)
    def __call__(self, x):
        return B.concat(*[self(xi) for xi in x.get()], axis=0)

    def __str__(self):
        ks = [str(self.means[p]) for p in self.ps]
        return 'MultiOutputMean({})'.format(', '.join(ks))
def test_multi_in_class():
    dispatch = Dispatcher()

    class A:
        @dispatch
        def f(self, x):
            return "fallback"

        @dispatch.multi((object, int), (object, str))
        def f(self, x: Union[int, str]):
            return "int or str"

    a = A()
    assert a.f(1) == "int or str"
    assert a.f("1") == "int or str"
    assert a.f(1.0) == "fallback"
Example #18
class AbstractObservations(metaclass=Referentiable):
    """Abstract base class for observations."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch({B.Numeric, Input}, B.Numeric, [PromisedGP])
    def __init__(self, x, y, ref=None):
        self._ref = ref
        self.x = ensure_at(x, self._ref)
        self.y = y
        self.graph = type_parameter(self.x).graph

    @_dispatch([Union(tuple, list, PromisedGP)])
    def __init__(self, *pairs, **kw_args):
        # Check whether there's a reference.
        self._ref = kw_args['ref'] if 'ref' in kw_args else None

        # Ensure `At` for all pairs.
        pairs = [(ensure_at(x, self._ref), y) for x, y in pairs]

        # Get the graph from the first pair.
        self.graph = type_parameter(pairs[0][0]).graph

        # Extend the graph by the Cartesian product `p` of all processes.
        p = self.graph.cross(*self.graph.ps)

        # Condition on the newly created vector-valued GP.
        xs, ys = zip(*pairs)
        self.x = p(MultiInput(*xs))
        self.y = B.concat(*[uprank(y) for y in ys], axis=0)

    @_dispatch({tuple, list})
    def __ror__(self, ps):
        return self.graph.condition(ps, self)

    def posterior_kernel(self, p_i, p_j):  # pragma: no cover
        """Get the posterior kernel between two processes.

        Args:
            p_i (:class:`.graph.GP`): First process.
            p_j (:class:`.graph.GP`): Second process.

        Returns:
            :class:`.kernel.Kernel`: Posterior kernel between the first and
                second process.
        """
        raise NotImplementedError('Posterior kernel construction not '
                                  'implemented.')

    def posterior_mean(self, p):  # pragma: no cover
        """Get the posterior kernel of a process.

        Args:
            p (:class:`.graph.GP`): Process.

        Returns:
            :class:`.mean.Mean`: Posterior mean of `p`.
        """
        raise NotImplementedError('Posterior mean construction not '
                                  'implemented.')
Example #19
class PosteriorKernel(Kernel, Referentiable):
    """Posterior kernel.

    Args:
        k_ij (:class:`.kernel.Kernel`): Kernel between processes
            corresponding to the left input and the right input respectively.
        k_zi (:class:`.kernel.Kernel`): Kernel between processes
            corresponding to the data and the left input respectively.
        k_zj (:class:`.kernel.Kernel`): Kernel between processes
            corresponding to the data and the right input respectively.
        z (input): Locations of data.
        K_z (:class:`.matrix.Dense`): Kernel matrix of data.
    """

    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, k_ij, k_zi, k_zj, z, K_z):
        self.k_ij = k_ij
        self.k_zi = k_zi
        self.k_zj = k_zj
        self.z = z
        self.K_z = matrix(K_z)

    @_dispatch(object, object, Cache)
    @cache
    def __call__(self, x, y, B):
        return B.schur(self.k_ij(x, y, B), self.k_zi(self.z, x, B), self.K_z,
                       self.k_zj(self.z, y, B))

    @_dispatch(object, object, Cache)
    @cache
    def elwise(self, x, y, B):
        qf_diag = B.qf_diag(self.K_z, self.k_zi(self.z, x, B),
                            self.k_zj(self.z, y, B))
        return B.subtract(self.k_ij.elwise(x, y, B), B.expand_dims(qf_diag, 1))
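A numpy sketch of the Schur step above, assuming `B.schur(K_xy, K_zx, K_z, K_zy)` computes `K_xy - K_zx^T K_z^{-1} K_zy`; this is the standard GP posterior covariance.

import numpy as np

K_z = np.array([[2.0, 0.5],
                [0.5, 1.0]])       # kernel matrix of the data
K_zx = np.random.randn(2, 4)       # k_zi(z, x)
K_zy = np.random.randn(2, 3)       # k_zj(z, y)
K_xy = np.random.randn(4, 3)       # prior k_ij(x, y)
posterior = K_xy - K_zx.T @ np.linalg.solve(K_z, K_zy)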
Example #20
class SumKernel(Kernel, SumFunction, Referentiable):
    """Sum of kernels."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object, Cache)
    @cache
    def __call__(self, x, y, B):
        return B.add(self[0](x, y, B), self[1](x, y, B))

    @_dispatch(object, object, Cache)
    @cache
    def elwise(self, x, y, B):
        return B.add(self[0].elwise(x, y, B), self[1].elwise(x, y, B))

    @property
    def _stationary(self):
        return self[0].stationary and self[1].stationary

    @property
    def var(self):
        return self[0].var + self[1].var

    @property
    def length_scale(self):
        return (self[0].var * self[0].length_scale +
                self[1].var * self[1].length_scale) / self.var

    @property
    def period(self):
        return np.inf
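A worked instance of the variance-weighted average in `length_scale` above: with var1 = 1, l1 = 2 and var2 = 3, l2 = 0.5, the sum kernel reports (1 * 2 + 3 * 0.5) / (1 + 3) = 0.875.

# Worked check of SumKernel.length_scale (hypothetical standalone values).
var1, ls1 = 1.0, 2.0
var2, ls2 = 3.0, 0.5
assert (var1 * ls1 + var2 * ls2) / (var1 + var2) == 0.875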
Example #21
class TensorProductKernel(Kernel, TensorProductFunction, Referentiable):
    """Tensor product kernel."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def __call__(self, x, y, B):
        f1, f2 = expand(self.fs)
        if x is y and f1 is f2:
            return LowRank(apply_optional_arg(f1, x, B))
        else:
            return LowRank(left=apply_optional_arg(f1, x, B),
                           right=apply_optional_arg(f2, y, B))

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def elwise(self, x, y, B):
        f1, f2 = expand(self.fs)
        return B.multiply(apply_optional_arg(f1, x, B),
                          apply_optional_arg(f2, y, B))

    @_dispatch(Self)
    def __eq__(self, other):
        return tuple_equal(expand(self.fs), expand(other.fs))
Example #22
class InputTransformedKernel(Kernel, InputTransformedFunction, Referentiable):
    """Input-transformed kernel."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object, Cache)
    @cache
    @uprank
    def __call__(self, x, y, B):
        return self[0](*self._compute(x, y, B))

    @_dispatch(object, object, Cache)
    @cache
    @uprank
    def elwise(self, x, y, B):
        return self[0].elwise(*self._compute(x, y, B))

    def _compute(self, x, y, B):
        f1, f2 = expand(self.fs)
        x = x if f1 is None else apply_optional_arg(f1, x, B)
        y = y if f2 is None else apply_optional_arg(f2, y, B)
        return x, y, B

    @_dispatch(Self)
    def __eq__(self, other):
        return self[0] == other[0] and \
               tuple_equal(expand(self.fs), expand(other.fs))
Example #23
class ScaledKernel(Kernel, ScaledFunction, Referentiable):
    """Scaled kernel."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object, Cache)
    @cache
    def __call__(self, x, y, B):
        return self._compute(self[0](x, y, B), B)

    @_dispatch(object, object, Cache)
    @cache
    def elwise(self, x, y, B):
        return self._compute(self[0].elwise(x, y, B), B)

    def _compute(self, K, B):
        return B.multiply(B.cast(self.scale, dtype=B.dtype(K)), K)

    @property
    def _stationary(self):
        return self[0].stationary

    @property
    def var(self):
        return self.scale * self[0].var

    @property
    def length_scale(self):
        return self[0].length_scale

    @property
    def period(self):
        return self[0].period
Example #24
class LowRank(Dense, Referentiable):
    """Low-rank symmetric positive-definite matrix.

    The low-rank matrix is constructed via `left middle transpose(right)`.

    Args:
        left (tensor): Left part of the matrix.
        right (tensor, optional): Right part of the matrix. Defaults to `left`.
        middle (tensor, optional): Middle part of the matrix. Defaults to the
            identity matrix.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, left, right=None, middle=None):
        Dense.__init__(self, None)
        self.left = left
        self.right = left if right is None else right
        if middle is None:
            self.middle = B.eye(B.dtype(self.left), B.shape(self.left)[1])
        else:
            self.middle = middle

        # Shorthands:
        self.l = self.left
        self.r = self.right
        self.m = self.middle

    @_dispatch(Self)
    def __eq__(self, other):
        return self.left == other.left and \
               self.middle == other.middle and \
               self.right == other.right
Example #25
class ProductKernel(Kernel, ProductFunction, Referentiable):
    """Product of two kernels."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object, Cache)
    @cache
    def __call__(self, x, y, B):
        return B.multiply(self[0](x, y, B), self[1](x, y, B))

    @_dispatch(object, object, Cache)
    @cache
    def elwise(self, x, y, B):
        return B.multiply(self[0].elwise(x, y, B), self[1].elwise(x, y, B))

    @property
    def _stationary(self):
        return self[0].stationary and self[1].stationary

    @property
    def var(self):
        return self[0].var * self[1].var

    @property
    def length_scale(self):
        return B.minimum(self[0].length_scale, self[1].length_scale)

    @property
    def period(self):
        return np.inf
class A(metaclass=Referentiable):
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(int)
    def __call__(self, x):
        pass

    @_dispatch(str)
    def __call__(self, x):
        pass

    @_dispatch(int)
    def go(self, x):
        pass

    @_dispatch(str)
    def go(self, x):
        pass

    @_dispatch(int)
    def go_again(self, x):
        pass

    @_dispatch(str)
    def go_again(self, x):
        pass
Example #27
class Linear(Kernel, Referentiable):
    """Linear kernel."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def __call__(self, x, y, B):
        return LowRank(x) if x is y else LowRank(left=x, right=y)

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def elwise(self, x, y, B):
        return B.expand_dims(B.sum(B.multiply(x, y), axis=1), 1)

    @property
    def _stationary(self):
        return False

    @property
    def period(self):
        return np.inf

    @_dispatch(Self)
    def __eq__(self, other):
        return True
Example #28
class NoisyKernel(Kernel, Referentiable):
    """Noisy observations of a latent process.

    Uses :class:`.input.Latent` and :class:`.input.Observed`.

    Args:
        k_f (:class:`.kernel.Kernel`): Kernel of the latent process.
        k_n (:class:`.kernel.Kernel`): Kernel of the noise.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, k_f, k_n):
        self.k_f = k_f
        self.k_n = k_n
        self.k_y = k_f + k_n

    @_dispatch({Latent, Observed}, {Latent, Observed}, Cache)
    @cache
    def __call__(self, x, y, cache):
        return self.k_f(x.get(), y.get())

    @_dispatch(Observed, Observed, Cache)
    @cache
    def __call__(self, x, y, cache):
        return self.k_y(x.get(), y.get())

    def __str__(self):
        return 'NoisyKernel({}, {})'.format(self.k_f, self.k_n)
Example #29
class PosteriorMean(Mean):
    """Posterior mean.

    Args:
        m_i (:class:`.mean.Mean`): Mean of process corresponding to
            the input.
        m_z (:class:`.mean.Mean`): Mean of process corresponding to
            the data.
        k_zi (:class:`.kernel.Kernel`): Kernel between processes
            corresponding to the data and the input respectively.
        z (input): Locations of data.
        K_z (:class:`.matrix.Dense`): Kernel matrix of data.
        y (tensor): Observations to condition on.
    """

    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, m_i, m_z, k_zi, z, K_z, y):
        self.m_i = m_i
        self.m_z = m_z
        self.k_zi = k_zi
        self.z = z
        self.K_z = K_z
        self.y = uprank(y)

    @_dispatch(object)
    def __call__(self, x):
        diff = B.subtract(self.y, self.m_z(self.z))
        return B.add(self.m_i(x), B.qf(self.K_z, self.k_zi(self.z, x), diff))
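A numpy sketch of the formula in `__call__`, assuming `B.qf(K, a, b)` is the quadratic form `a^T K^{-1} b`: the posterior mean is `m_i(x) + k_zi(z, x)^T K_z^{-1} (y - m_z(z))`.

import numpy as np

K_z = np.array([[2.0, 0.5],
                [0.5, 1.0]])        # kernel matrix of the data
k_zx = np.array([[0.3], [0.7]])     # k_zi(z, x) for a single input x
y = np.array([[1.0], [2.0]])        # observations
m_z = np.zeros((2, 1))              # prior mean at the data
m_x = 0.0                           # prior mean at x
posterior_mean = m_x + k_zx.T @ np.linalg.solve(K_z, y - m_z)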
Example #30
def test_basic_arithmetic():
    dispatch = Dispatcher()

    @dispatch(Number)
    def f1(x):
        return np.array([[x**2]])

    @dispatch(object)
    def f1(x):
        return np.sum(x**2, axis=1)[:, None]

    @dispatch(Number)
    def f2(x):
        return np.array([[x**3]])

    @dispatch(object)
    def f2(x):
        return np.sum(x**3, axis=1)[:, None]

    m1 = TensorProductMean(f1)
    m2 = TensorProductMean(f2)
    m3 = ZeroMean()
    x1 = np.random.randn(10, 2)
    x2 = np.random.randn()

    yield ok, np.allclose((m1 * m2)(x1), m1(x1) * m2(x1)), 'prod'
    yield ok, np.allclose((m1 * m2)(x2), m1(x2) * m2(x2)), 'prod 2'
    yield ok, np.allclose((m1 + m3)(x1), m1(x1) + m3(x1)), 'sum'
    yield ok, np.allclose((m1 + m3)(x2), m1(x2) + m3(x2)), 'sum 2'
    yield ok, np.allclose((5. * m1)(x1), 5. * m1(x1)), 'prod 3'
    yield ok, np.allclose((5. * m1)(x2), 5. * m1(x2)), 'prod 4'
    yield ok, np.allclose((5. + m1)(x1), 5. + m1(x1)), 'sum 3'
    yield ok, np.allclose((5. + m1)(x2), 5. + m1(x2)), 'sum 4'