Example #1
class LazyMatrix(LazyTensor):
    """A lazy matrix."""

    _dispatch = Dispatcher(in_class=Self)

    def __init__(self):
        LazyTensor.__init__(self, 2)
        self._left_rules = []
        self._right_rules = []
        self._rules = []

    def add_rule(self, indices, builder):
        """Add a building rule.

        Note:
            For performance reasons, `indices` must already be resolved!

        Args:
            indices (set): Domain of the rule.
            builder (function): Function that takes in an index and gives back the
                corresponding element.
        """
        self._rules.append((frozenset(indices), builder))

    def add_left_rule(self, i_left, indices, builder):
        """Add a building rule for a given left index.

        Note:
            For performance reasons, `indices` must already be resolved!

        Args:
            i_left (int): Fixed left index.
            indices (set): Domain of the rule.
            builder (function): Function that takes in a right index and gives back the
                corresponding element.
        """
        self._left_rules.append((i_left, frozenset(indices), builder))

    def add_right_rule(self, i_right, indices, builder):
        """Add a building rule for a given right index.

        Note:
            For performance reasons, `indices` must already be resolved!

        Args:
            i_right (int): Fixed right index.
            indices (set): Domain of the rule.
            builder (function): Function that takes in a left index and gives back the
                corresponding element.
        """
        self._right_rules.append((i_right, frozenset(indices), builder))

    def _build(self, i):
        i_left, i_right = i

        # Check universal rules.
        for indices, builder in self._rules:
            if i_left in indices and i_right in indices:
                return builder(i_left, i_right)

        # Check left rules.
        for i_left_rule, indices, builder in self._left_rules:
            if i_left == i_left_rule and i_right in indices:
                return builder(i_right)

        # Check right rules.
        for i_right_rule, indices, builder in self._right_rules:
            if i_left in indices and i_right == i_right_rule:
                return builder(i_left)

        raise RuntimeError(f"Could not build value for index {i}.")
Example #2
def test_sequence():
    # Standard type tests.
    assert hash(Sequence[int]) == hash(Sequence[int])
    assert hash(Sequence[int]) != hash(Sequence[str])
    assert hash(Sequence[int]) != hash(Sequence())
    assert hash(Sequence[Sequence[int]]) == hash(Sequence[Sequence[int]])
    assert hash(Sequence[Sequence[int]]) != hash(Sequence[Sequence[str]])
    assert repr(Sequence()) == "Sequence"
    assert repr(Sequence[int]) == f"Sequence[{Type(int)!r}]"

    # Test instance check.
    assert isinstance([], Sequence())
    assert isinstance([], Sequence[object])
    assert isinstance((1, 2.0), Sequence())
    assert isinstance((1, 2.0), Sequence[Union[int, float]])
    assert isinstance([1, 2], Sequence())
    assert isinstance([1, 2], Sequence[int])
    assert not isinstance((x for x in [1, 2]), Sequence())

    # Test subclass check.
    assert issubclass(ptype(list), Sequence())
    assert issubclass(List[int], Sequence())
    assert issubclass(List[int], Sequence[int])
    assert not issubclass(ptype(list), Sequence[int])
    assert issubclass(Sequence[int], Sequence[object])

    # Check tracking of parametric.
    assert Sequence[int].parametric
    assert ptype(Sequence[Sequence[int]]).parametric
    assert ptype(Union[Sequence[int]]).parametric
    promise = PromisedType()
    promise.deliver(Sequence[int])
    assert promise.resolve().parametric

    # Check tracking of runtime `type_of`.
    assert Sequence[int].runtime_type_of
    assert ptype(Sequence[Sequence[int]]).runtime_type_of
    assert ptype(Union[Sequence[int]]).runtime_type_of
    promise = PromisedType()
    promise.deliver(Sequence[int])
    assert promise.resolve().runtime_type_of

    assert not Sequence().runtime_type_of
    assert ptype(Sequence[Sequence()]).runtime_type_of
    assert not ptype(Union[Sequence()]).runtime_type_of
    promise = PromisedType()
    promise.deliver(Sequence())
    assert not promise.resolve().runtime_type_of

    # Test correctness.
    dispatch = Dispatcher()

    @parametric
    class A:
        def __init__(self, el_type):
            pass

        def __len__(self):
            pass

        def __getitem__(self, item):
            pass

    @parametric
    class B(A):
        @classmethod
        def __getitem_el_type__(cls):
            return cls.type_parameter

    @dispatch
    def f(x):
        return "fallback"

    @dispatch
    def f(x: Sequence()):
        return "seq"

    @dispatch
    def f(x: Sequence[int]):
        return "seq of int"

    @dispatch
    def f(x: Sequence[Sequence[object]]):
        return "seq of seq"

    assert f(1) == "fallback"
    assert f((x for x in [1, 2])) == "fallback"
    # Test various sequences:
    assert f(A(1)) == "seq"
    assert f(A(1.0)) == "seq"
    assert f(B(1)) == "seq of int"
    assert f(B(1.0)) == "seq"
    assert f([1]) == "seq of int"
    assert f([1.0]) == "seq"
    assert f({1: 1}) == "seq of int"
    assert f({1: 1.0}) == "seq"
    assert f((1, 1)) == "seq of int"
    assert f((1.0, 1)) == "seq"
    # Test nested sequences:
    assert f([[1]]) == "seq of seq"
    assert f(([1], [1, 2, "3"])) == "seq of seq"
Example #3
    class B(A):
        _dispatch = Dispatcher()

        @_dispatch
        def do(self, x: Union[int, "B", str]) -> Union[int, "B"]:
            return x
Example #4
import logging

from plum import Dispatcher

from .cache import cache, Cache, uprank
from .field import add, mul, broadcast, apply_optional_arg, get_field, \
    Formatter, need_parens
from .input import Input, Unique
from .matrix import Dense, LowRank, UniformlyDiagonal, One, Zero, \
    dense, matrix

__all__ = [
    'Kernel', 'OneKernel', 'ZeroKernel', 'ScaledKernel', 'EQ', 'RQ',
    'Matern12', 'Exp', 'Matern32', 'Matern52', 'Delta', 'Linear',
    'DerivativeKernel', 'DecayingKernel'
]

log = logging.getLogger(__name__)

_dispatch = Dispatcher()


def expand(xs):
    """Expand a sequence to the same element repeated twice if there is only
    one element.

    Args:
        xs (sequence): Sequence to expand.

    Returns:
        object: `xs * 2` or `xs`.
    """
    return xs * 2 if len(xs) == 1 else xs
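
# For example (illustrative; `expand` is typically called on one- or two-element
# tuples such as a kernel's `dims` or `derivs`):
#
#   expand((3,)) == (3, 3)    # A single element is repeated.
#   expand((3, 5)) == (3, 5)  # Two elements are returned unchanged.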

Example #5
import sys

from plum import Dispatcher

B = sys.modules[__name__]  # Allow both import styles.
dispatch = Dispatcher()  # This dispatch namespace will be used everywhere.

from .generic import *
from .shaping import *
from .linear_algebra import *
from .random import *

from .numpy import *

from .types import *
from .control_flow import *

# Fix namespace issues with `B.bvn_cdf` simply by setting it explicitly.
B.bvn_cdf = B.generic.bvn_cdf
Example #6
class Graph(metaclass=Referentiable):
    """A GP model."""
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self):
        self.ps = []
        self.pids = set()
        self.kernels = LazyMatrix()
        self.means = LazyVector()

        # Store named GPs in both ways.
        self.gps_by_name = {}
        self.names_by_gp = {}

    @_dispatch(str)
    def __getitem__(self, name):
        return self.gps_by_name[name]

    @_dispatch(PromisedGP)
    def __getitem__(self, p):
        return self.names_by_gp[id(p)]

    @_dispatch(PromisedGP, str)
    def name(self, p, name):
        """Name a GP.

        Args:
            p (:class:`.graph.GP`): GP to name.
            name (str): Name. Must be unique.
        """
        # Delete any existing names and back-references for the GP.
        if id(p) in self.names_by_gp:
            del self.gps_by_name[self.names_by_gp[id(p)]]
            del self.names_by_gp[id(p)]

        # Check that name is not in use.
        if name in self.gps_by_name:
            raise RuntimeError('Name "{}" for "{}" already taken by "{}".'
                               ''.format(name, p, self[name]))

        # Set the name and the back-reference.
        self.gps_by_name[name] = p
        self.names_by_gp[id(p)] = name

    def _add_p(self, p):
        self.ps.append(p)
        self.pids.add(id(p))

    def _update(self, mean, k_ii_generator, k_ij_generator):
        p = GP(self)
        self.means[p] = mean
        self.kernels.add_rule((p, p), self.pids, k_ii_generator)
        self.kernels.add_rule((p, None), self.pids, k_ij_generator)
        self.kernels.add_rule((None, p), self.pids,
                              lambda pi: reversed(self.kernels[p, pi]))
        self._add_p(p)
        return p

    def add_independent_gp(self, p, kernel, mean):
        """Add an independent GP to the model.

        Args:
            p (:class:`.graph.GP`): GP object to add.
            kernel (:class:`.kernel.Kernel`): Kernel function of GP.
            mean (:class:`.mean.Mean`): Mean function of GP.

        Returns:
            :class:`.graph.GP`: The newly added independent GP.
        """
        # Update means.
        self.means[p] = mean
        # Add rule to kernels.
        self.kernels[p] = kernel
        self.kernels.add_rule((p, None), self.pids, lambda pi: ZeroKernel())
        self.kernels.add_rule((None, p), self.pids, lambda pi: ZeroKernel())
        self._add_p(p)
        return p

    @_dispatch(object, PromisedGP)
    def sum(self, other, p):
        """Sum a GP from the graph with another object.

        Args:
            other (object): Other object in the sum.
            p (:class:`.graph.GP`): GP in the sum.

        Returns:
            :class:`.graph.GP`: The GP corresponding to the sum.
        """
        return self.sum(p, other)

    @_dispatch(PromisedGP, object)
    def sum(self, p, other):
        return self._update(self.means[p] + other,
                            lambda: self.kernels[p],
                            lambda pi: self.kernels[p, pi])

    @_dispatch(PromisedGP, PromisedGP)
    def sum(self, p1, p2):
        # Check that the GPs are on the same graph.
        if p1.graph != p2.graph:
            raise RuntimeError('Can only add GPs from the same graph.')

        return self._update(self.means[p1] + self.means[p2],
                            (lambda: self.kernels[p1] +
                                     self.kernels[p2] +
                                     self.kernels[p1, p2] +
                                     self.kernels[p2, p1]),
                            lambda pi: self.kernels[p1, pi] +
                                       self.kernels[p2, pi])

    @_dispatch(PromisedGP, B.Numeric)
    def mul(self, p, other):
        """Multiply a GP from the graph with another object.

        Args:
            p (:class:`.graph.GP`): GP in the product.
            other (object): Other object in the product.

        Returns:
            :class:`.graph.GP`: The GP corresponding to the product.
        """
        return self._update(self.means[p] * other,
                            lambda: self.kernels[p] * other ** 2,
                            lambda pi: self.kernels[p, pi] * other)

    @_dispatch(PromisedGP, FunctionType)
    def mul(self, p, f):
        def ones(x):
            return B.ones(B.dtype(x), B.shape(x)[0], 1)

        return self._update(f * self.means[p],
                            lambda: f * self.kernels[p],
                            (lambda pi: TensorProductKernel(f, ones) *
                                        self.kernels[p, pi]))

    def shift(self, p, shift):
        """Shift a GP.

        Args:
            p (:class:`.graph.GP`): GP to shift.
            shift (object): Amount to shift by.

        Returns:
            :class:`.graph.GP`: The shifted GP.
        """
        return self._update(self.means[p].shift(shift),
                            lambda: self.kernels[p].shift(shift),
                            lambda pi: self.kernels[p, pi].shift(shift, 0))

    def stretch(self, p, stretch):
        """Stretch a GP.

        Args:
            p (:class:`.graph.GP`): GP to stretch.
            stretch (object): Extent of stretch.

        Returns:
            :class:`.graph.GP`: The stretched GP.
        """
        return self._update(self.means[p].stretch(stretch),
                            lambda: self.kernels[p].stretch(stretch),
                            lambda pi: self.kernels[p, pi].stretch(stretch, 1))

    def select(self, p, *dims):
        """Select input dimensions.

        Args:
            p (:class:`.graph.GP`): GP to select input
                dimensions from.
            *dims (object): Dimensions to select.

        Returns:
            :class:`.graph.GP`: GP with the specific input dimensions.
        """
        return self._update(self.means[p].select(dims),
                            lambda: self.kernels[p].select(dims),
                            lambda pi: self.kernels[p, pi].select(dims, None))

    def transform(self, p, f):
        """Transform the inputs of a GP.

        Args:
            p (:class:`.graph.GP`): GP to input transform.
            f (function): Input transform.

        Returns:
            :class:`.graph.GP`: Input-transformed GP.
        """
        return self._update(self.means[p].transform(f),
                            lambda: self.kernels[p].transform(f),
                            lambda pi: self.kernels[p, pi].transform(f, None))

    def diff(self, p, dim=0):
        """Differentiate a GP.

        Args:
            p (:class:`.graph.GP`): GP to differentiate.
            dim (int, optional): Dimension of the feature with respect to which
                to take the derivative. Defaults to `0`.

        Returns:
            :class:`.graph.GP`: Derivative of GP.
        """
        return self._update(self.means[p].diff(dim),
                            lambda: self.kernels[p].diff(dim),
                            lambda pi: self.kernels[p, pi].diff(dim, None))

    @_dispatch({list, tuple}, AbstractObservations)
    def condition(self, ps, obs):
        """Condition the graph on observations.

        Args:
            ps (list[:class:`.graph.GP`]): Processes to condition.
            obs (:class:`.graph.AbstractObservations`): Observations to
                condition on.

        Returns:
            list[:class:`.graph.GP`]: Posterior processes.
        """

        # A construction like this is necessary to properly close over `p`.
        def build_gens(p):
            def k_ij_generator(pi):
                return obs.posterior_kernel(p, pi)

            def k_ii_generator():
                return obs.posterior_kernel(p, p)

            return k_ii_generator, k_ij_generator

        return [self._update(obs.posterior_mean(p), *build_gens(p)) for p in ps]

    def cross(self, *ps):
        """Construct the Cartesian product of a collection of processes.

        Args:
            *ps (:class:`.graph.GP`): Processes to construct the
                Cartesian product of.

        Returns:
            :class:`.graph.GP`: The Cartesian product of `ps`.
        """
        mok = MOK(*ps)
        return self._update(MOM(*ps),
                            lambda: mok,
                            lambda pi: mok.transform(None, lambda y: At(pi)(y)))

    @_dispatch(int, [At])
    def sample(self, n, *xs):
        """Sample multiple processes simultaneously.

        Args:
            n (int, optional): Number of samples. Defaults to `1`.
            *xs (:class:`.graph.At`): Locations to sample at.

        Returns:
            tuple: Tuple of samples.
        """
        sample = GP(MOK(*self.ps),
                    MOM(*self.ps),
                    graph=Graph())(MultiInput(*xs)).sample(n)

        # To unpack `x`, just keep `.get()`ing.
        def unpack(x):
            while isinstance(x, Input):
                x = x.get()
            return x

        # Unpack sample.
        lengths = [B.shape(uprank(unpack(x)))[0] for x in xs]
        i, samples = 0, []
        for length in lengths:
            samples.append(sample[i:i + length, :])
            i += length
        return samples[0] if len(samples) == 1 else samples

    @_dispatch([At])
    def sample(self, *xs):
        return self.sample(1, *xs)

    @_dispatch([{list, tuple}])
    def logpdf(self, *pairs):
        xs, ys = zip(*pairs)

        # Check that all processes are specified.
        if not all([isinstance(x, At) for x in xs]):
            raise ValueError('Must explicitly specify the processes for which '
                             'to compute the log-pdf.')

        # Uprank all outputs and concatenate.
        y = B.concat(*[uprank(y) for y in ys], axis=0)

        # Return composite log-pdf.
        return GP(MOK(*self.ps),
                  MOM(*self.ps),
                  graph=Graph())(MultiInput(*xs)).logpdf(y)

    @_dispatch(At, B.Numeric)
    def logpdf(self, x, y):
        return x.logpdf(y)

    @_dispatch(Observations)
    def logpdf(self, obs):
        return obs.x.logpdf(obs.y)

    @_dispatch(SparseObservations)
    def logpdf(self, obs):
        return obs.elbo
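
# A usage sketch of the Graph API above. It assumes the surrounding definitions
# from the same code base (GP, EQ, Observations, ...) are importable, so treat
# it as an illustration rather than runnable code:
#
#   model = Graph()
#   f = GP(EQ(), graph=model)   # Registers an independent GP on the graph.
#   model.name(f, "f")          # model["f"] now returns `f`.
#   g = model.sum(f, 1)         # What `f + 1` dispatches to.
#   h = model.stretch(f, 2)     # What `f > 2` (the stretch shorthand) dispatches to.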
Example #7
class Kernel(algebra.Function):
    """Kernel function.

    Kernels can be added and multiplied.
    """
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object)
    def __call__(self, x, y):
        """Construct the kernel matrix between all `x` and `y`.

        Args:
            x (input): First argument.
            y (input, optional): Second argument. Defaults to first
                argument.

        Returns:
            matrix: Kernel matrix.
        """
        raise RuntimeError('For kernel "{}", could not resolve '
                           'arguments "{}" and "{}".'.format(self, x, y))

    @_dispatch(object)
    def __call__(self, x):
        return self(x, x)

    @_dispatch(Input, Input)
    def __call__(self, x, y):
        # Neither input type was used. Unwrap.
        return self(x.get(), y.get())

    @_dispatch(Input, object)
    def __call__(self, x, y):
        # Left input type was not used. Unwrap.
        return self(x.get(), y)

    @_dispatch(object, Input)
    def __call__(self, x, y):
        # Right input type was not used. Unwrap.
        return self(x, y.get())

    @_dispatch(object, object)
    def elwise(self, x, y):
        """Construct the kernel vector `x` and `y` element-wise.

        Args:
            x (input): First argument.
            y (input, optional): Second argument. Defaults to first
                argument.

        Returns:
            tensor: Kernel vector as a rank 2 column vector.
        """
        # TODO: throw warning
        return B.expand_dims(B.diag(self(x, y)), axis=1)

    @_dispatch(object)
    def elwise(self, x):
        return self.elwise(x, x)

    @_dispatch(Input, Input)
    def elwise(self, x, y):
        # Neither input type was used. Unwrap.
        return self.elwise(x.get(), y.get())

    @_dispatch(Input, object)
    def elwise(self, x, y):
        # Left input type was not used. Unwrap.
        return self.elwise(x.get(), y)

    @_dispatch(object, Input)
    def elwise(self, x, y):
        # Right input type was not used. Unwrap.
        return self.elwise(x, y.get())

    def periodic(self, period=1):
        """Map to a periodic space.

        Args:
            period (tensor, optional): Period. Defaults to `1`.

        Returns:
            :class:`.kernel.Kernel`: Periodic version of the kernel.
        """
        return periodicise(self, period)

    @property
    def stationary(self):
        """Stationarity of the kernel."""
        try:
            return self._stationary_cache
        except AttributeError:
            self._stationary_cache = self._stationary
            return self._stationary_cache

    @property
    def _stationary(self):
        return False
Example #8
    class C(B):
        _dispatch = Dispatcher(in_class=Self)

        @_dispatch(str)
        def do(self, x):
            return 'str'
Example #9
class Delta(Kernel, Referentiable):
    """Kronecker delta kernel.

    Args:
        epsilon (float, optional): Tolerance for equality in squared distance.
            Defaults to `1e-10`.
    """

    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, epsilon=1e-10):
        self.epsilon = epsilon

    @_dispatch(B.Numeric, B.Numeric)
    def __call__(self, x, y):
        if x is y:
            return self._eye(uprank(x))
        else:
            return Dense(self._compute(B.pw_dists2(uprank(x), uprank(y))))

    @_dispatch(Unique, Unique)
    def __call__(self, x, y):
        x, y = x.get(), y.get()
        if x is y:
            return self._eye(uprank(x))
        else:
            x, y = uprank(x), uprank(y)
            return Zero(B.dtype(x), B.shape(x)[0], B.shape(y)[0])

    @_dispatch(Unique, object)
    def __call__(self, x, y):
        x, y = uprank(x.get()), uprank(y)
        return Zero(B.dtype(x), B.shape(x)[0], B.shape(y)[0])

    @_dispatch(object, Unique)
    def __call__(self, x, y):
        x, y = uprank(x), uprank(y.get())
        return Zero(B.dtype(x), B.shape(x)[0], B.shape(y)[0])

    @_dispatch(Unique, Unique)
    def elwise(self, x, y):
        x, y = x.get(), y.get()
        if x is y:
            return One(B.dtype(x), B.shape(uprank(x))[0], 1)
        else:
            return Zero(B.dtype(x), B.shape(uprank(x))[0], 1)

    @_dispatch(Unique, object)
    def elwise(self, x, y):
        x = x.get()
        return Zero(B.dtype(x), B.shape(uprank(x))[0], 1)

    @_dispatch(object, Unique)
    def elwise(self, x, y):
        return Zero(B.dtype(x), B.shape(uprank(x))[0], 1)

    @_dispatch(B.Numeric, B.Numeric)
    def elwise(self, x, y):
        if x is y:
            return One(B.dtype(x), B.shape(uprank(x))[0], 1)
        else:
            return self._compute(B.ew_dists2(uprank(x), uprank(y)))

    def _eye(self, x):
        return UniformlyDiagonal(B.cast(B.dtype(x), 1), B.shape(x)[0])

    def _compute(self, dists2):
        dtype = B.dtype(dists2)
        return B.cast(dtype, B.lt(dists2, B.cast(dtype, self.epsilon)))

    @property
    def _stationary(self):
        return True

    @property
    def var(self):
        return 1

    @property
    def length_scale(self):
        return 0

    @property
    def period(self):
        return np.inf

    @_dispatch(Self)
    def __eq__(self, other):
        return self.epsilon == other.epsilon
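
# The thresholding in `_compute` above maps squared distances below `epsilon` to
# ones and everything else to zeros. A standalone NumPy stand-in (illustrative;
# `np.less` plays the role of `B.lt` and `astype` the role of `B.cast`):
import numpy as np

def delta_compute(dists2, epsilon=1e-10):
    return np.less(dists2, epsilon).astype(dists2.dtype)

x = np.array([[0.0], [1.0], [2.0]])
dists2 = (x - x.T) ** 2       # Pairwise squared distances.
print(delta_compute(dists2))  # Identity matrix: only the diagonal is within tolerance.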
Example #10
class DerivativeKernel(Kernel, algebra.DerivativeFunction):
    """Derivative of kernel."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def __call__(self, x, y):
        i, j = expand(self.derivs)
        k = self[0]

        # Prevent `x` from exactly equalling `y` to stabilise nested gradients.
        y = perturb(y)

        if i is not None and j is not None:
            # Derivative with respect to both `x` and `y`.
            return Dense(dky(dkx_elwise(k.elwise, i), j)(x, y))

        elif i is not None and j is None:
            # Derivative with respect to `x`.
            return Dense(dkx(k.elwise, i)(x, y))

        elif i is None and j is not None:
            # Derivative with respect to `y`.
            return Dense(dky(k.elwise, j)(x, y))

        else:
            raise RuntimeError('No derivative specified.')

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def elwise(self, x, y):
        i, j = expand(self.derivs)
        k = self[0]

        # Prevent `x` from exactly equalling `y` to stabilise nested gradients.
        y = perturb(y)

        if i is not None and j is not None:
            # Derivative with respect to both `x` and `y`.
            return dky_elwise(dkx_elwise(k.elwise, i), j)(x, y)

        elif i is not None and j is None:
            # Derivative with respect to `x`.
            return dkx_elwise(k.elwise, i)(x, y)

        elif i is None and j is not None:
            # Derivative with respect to `y`.
            return dky_elwise(k.elwise, j)(x, y)

        else:
            raise RuntimeError('No derivative specified.')

    @property
    def _stationary(self):
        # NOTE: In the one-dimensional case, if derivatives with respect to both
        #     arguments are taken, then the result is in fact stationary.
        return False

    @_dispatch(Self)
    def __eq__(self, other):
        return self[0] == other[0] and \
               tuple_equal(expand(self.derivs), expand(other.derivs))
Example #11
class Kernel(Function, Referentiable):
    """Kernel function.

    Kernels can be added and multiplied.
    """
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object)
    def __call__(self, x, y):
        """Construct the kernel matrix between all `x` and `y`.

        Args:
            x (input): First argument.
            y (input, optional): Second argument. Defaults to first
                argument.

        Returns:
            :class:`.matrix.Dense`: Kernel matrix.
        """
        raise RuntimeError('For kernel "{}", could not resolve '
                           'arguments "{}" and "{}".'.format(self, x, y))

    @_dispatch(object)
    def __call__(self, x):
        return self(x, x)

    @_dispatch(Input, Input)
    def __call__(self, x, y):
        # Neither input type was used. Unwrap.
        return self(x.get(), y.get())

    @_dispatch(Input, object)
    def __call__(self, x, y):
        # Left input type was not used. Unwrap.
        return self(x.get(), y)

    @_dispatch(object, Input)
    def __call__(self, x, y):
        # Right input type was not used. Unwrap.
        return self(x, y.get())

    @_dispatch(object, object)
    def elwise(self, x, y):
        """Construct the kernel vector `x` and `y` element-wise.

        Args:
            x (input): First argument.
            y (input, optional): Second argument. Defaults to first
                argument.

        Returns:
            tensor: Kernel vector as a rank 2 column vector.
        """
        # TODO: throw warning
        return B.expand_dims(B.diag(self(x, y)), axis=1)

    @_dispatch(object)
    def elwise(self, x):
        return self.elwise(x, x)

    @_dispatch(Input, Input)
    def elwise(self, x, y):
        # Neither input type was used. Unwrap.
        return self.elwise(x.get(), y.get())

    @_dispatch(Input, object)
    def elwise(self, x, y):
        # Left input type was not used. Unwrap.
        return self.elwise(x.get(), y)

    @_dispatch(object, Input)
    def elwise(self, x, y):
        # Right input type was not used. Unwrap.
        return self.elwise(x, y.get())

    def periodic(self, period=1):
        """Map to a periodic space.

        Args:
            period (tensor, optional): Period. Defaults to `1`.

        Returns:
            :class:`.kernel.Kernel`: Periodic version of the kernel.
        """
        return periodicise(self, period)

    def __reversed__(self):
        """Reverse the arguments of the kernel."""
        return reverse(self)

    @_dispatch(int)
    def __pow__(self, power, modulo=None):
        if power < 0:
            raise ValueError('Cannot raise to a negative power.')
        elif power == 0:
            return 1
        else:
            k = self
            for _ in range(power - 1):
                k *= self
        return k

    @property
    def stationary(self):
        """Stationarity of the kernel."""
        try:
            return self._stationary_cache
        except AttributeError:
            self._stationary_cache = self._stationary
            return self._stationary_cache

    @property
    def _stationary(self):
        return False

    @property
    def var(self):
        """Variance of the kernel."""
        raise RuntimeError('The variance of "{}" could not be determined.'
                           ''.format(self.__class__.__name__))

    @property
    def length_scale(self):
        """Approximation of the length scale of the kernel."""
        raise RuntimeError('The length scale of "{}" could not be determined.'
                           ''.format(self.__class__.__name__))

    @property
    def period(self):
        """Period of the kernel."""
        raise RuntimeError('The period of "{}" could not be determined.'
                           ''.format(self.__class__.__name__))
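
# `__pow__` above implements `k ** n` as repeated multiplication, so `k ** 3` is
# `k * k * k`. A standalone sketch of the same loop with plain numbers
# (illustrative function name):
def power_by_repeated_mul(k, power):
    if power < 0:
        raise ValueError('Cannot raise to a negative power.')
    elif power == 0:
        return 1
    result = k
    for _ in range(power - 1):
        result *= k
    return result

assert power_by_repeated_mul(2, 3) == 8  # 2 * 2 * 2
assert power_by_repeated_mul(5, 1) == 5  # The loop body does not run for power 1.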
Example #12
class SelectedKernel(Kernel, SelectedFunction, Referentiable):
    """Kernel with particular input dimensions selected."""

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def __call__(self, x, y):
        return self[0](*self._compute(x, y))

    @_dispatch(B.Numeric, B.Numeric)
    @uprank
    def elwise(self, x, y):
        return self[0].elwise(*self._compute(x, y))

    def _compute(self, x, y):
        dims1, dims2 = expand(self.dims)
        x = x if dims1 is None else B.take(x, dims1, axis=1)
        y = y if dims2 is None else B.take(y, dims2, axis=1)
        return x, y

    @property
    def _stationary(self):
        if len(self.dims) == 1:
            return self[0].stationary
        else:
            # NOTE: Can do something more clever here.
            return False

    @property
    def var(self):
        return self[0].var

    @property
    def length_scale(self):
        length_scale = self[0].length_scale
        if B.isscalar(length_scale):
            return length_scale
        else:
            if len(self.dims) == 1:
                return B.take(length_scale, self.dims[0])
            else:
                # NOTE: Can do something more clever here.
                return Kernel.length_scale.fget(self)

    @property
    def period(self):
        period = self[0].period
        if B.isscalar(period):
            return period
        else:
            if len(self.dims) == 1:
                return B.take(period, self.dims[0])
            else:
                # NOTE: Can do something more clever here.
                return Kernel.period.fget(self)

    @_dispatch(Self)
    def __eq__(self, other):
        return self[0] == other[0] and \
               tuple_equal(expand(self.dims), expand(other.dims))
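
# `_compute` above just slices out the selected feature dimensions of each
# argument. A standalone NumPy stand-in (`np.take` plays the role of `B.take`):
import numpy as np

x = np.arange(6.0).reshape(2, 3)  # Two inputs with three features each.
dims = [0, 2]                     # Keep the first and the last feature.
print(np.take(x, dims, axis=1))   # [[0., 2.], [3., 5.]]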
Example #13
class Kernel(algebra.Function):
    """Kernel function.

    Kernels can be added and multiplied.
    """

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(object, object)
    def __call__(self, x, y):
        """Construct the kernel matrix between all `x` and `y`.

        Args:
            x (input): First argument.
            y (input, optional): Second argument. Defaults to first
                argument.

        Returns:
            matrix: Kernel matrix.
        """
        raise RuntimeError(
            f'For kernel "{self}", could not resolve arguments "{x}" and "{y}".'
        )

    @_dispatch(object)
    def __call__(self, x):
        return self(x, x)

    @_dispatch(Union(Input, FDD), Union(Input, FDD))
    def __call__(self, x, y):
        return self(unwrap(x), unwrap(y))

    @_dispatch(Union(Input, FDD), object)
    def __call__(self, x, y):
        return self(unwrap(x), y)

    @_dispatch(object, Union(Input, FDD))
    def __call__(self, x, y):
        return self(x, unwrap(y))

    @_dispatch(MultiInput, object, precedence=1)
    def __call__(self, x, y):
        return self(x, MultiInput(y))

    @_dispatch(object, MultiInput, precedence=1)
    def __call__(self, x, y):
        return self(MultiInput(x), y)

    @_dispatch(MultiInput, MultiInput)
    def __call__(self, x, y):
        return B.block(*[[self(xi, yi) for yi in y.get()] for xi in x.get()])

    @_dispatch(object, object)
    def elwise(self, x, y):
        """Construct the kernel vector `x` and `y` element-wise.

        Args:
            x (input): First argument.
            y (input, optional): Second argument. Defaults to first
                argument.

        Returns:
            tensor: Kernel vector as a rank 2 column vector.
        """
        # TODO: Throw warning.
        return B.expand_dims(B.diag(self(x, y)), axis=1)

    @_dispatch(object)
    def elwise(self, x):
        return self.elwise(x, x)

    @_dispatch(Union(Input, FDD), Union(Input, FDD))
    def elwise(self, x, y):
        return self.elwise(unwrap(x), unwrap(y))

    @_dispatch(Union(Input, FDD), object)
    def elwise(self, x, y):
        return self.elwise(unwrap(x), y)

    @_dispatch(object, Union(Input, FDD))
    def elwise(self, x, y):
        return self.elwise(x, unwrap(y))

    @_dispatch(MultiInput, object, precedence=1)
    def elwise(self, x, y):
        raise ValueError(
            "Unclear combination of arguments given to Kernel.elwise.")

    @_dispatch(object, MultiInput, precedence=1)
    def elwise(self, x, y):
        raise ValueError(
            "Unclear combination of arguments given to Kernel.elwise.")

    @_dispatch(MultiInput, MultiInput)
    def elwise(self, x, y):
        if len(x.get()) != len(y.get()):
            raise ValueError(
                "Kernel.elwise must be called with similarly sized MultiInputs."
            )
        return B.concat(
            *[self.elwise(xi, yi) for xi, yi in zip(x.get(), y.get())], axis=0)

    def periodic(self, period=1):
        """Map to a periodic space.

        Args:
            period (tensor, optional): Period. Defaults to `1`.

        Returns:
            :class:`.kernel.Kernel`: Periodic version of the kernel.
        """
        return periodicise(self, period)

    @property
    def stationary(self):
        """Stationarity of the kernel."""
        try:
            return self._stationary_cache
        except AttributeError:
            self._stationary_cache = self._stationary
            return self._stationary_cache

    @property
    def _stationary(self):
        return False
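
# A sketch of the block construction used for `MultiInput` arguments above:
# every pair of components contributes one block, and the blocks are assembled
# into a single matrix. Plain NumPy stand-in (`np.block` plays the role of
# `B.block`; the lambda stands in for `self(xi, yi)`):
import numpy as np

components = [np.zeros((2, 1)), np.zeros((3, 1))]       # Two input components.
k = lambda xi, yi: np.ones((xi.shape[0], yi.shape[0]))  # Dummy pairwise kernel.
K = np.block([[k(xi, yi) for yi in components] for xi in components])
print(K.shape)  # (5, 5): block sizes 2 and 3 along both axes.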
Example #14
class Function(Element, Referentiable):
    """A function.

    Crucially, this is not a field, so that it can be inherited.
    """
    _dispatch = Dispatcher(in_class=Self)

    def stretch(self, *stretches):
        """Stretch the function.

        Args:
            *stretches (tensor): Per input, extent to stretch by.

        Returns:
            :class:`.function_field.Function`: Stretched function.
        """
        return stretch(self, *stretches)

    def __gt__(self, stretch):
        """Shorthand for :meth:`.function_field.Function.stretch`."""
        return self.stretch(stretch)

    def shift(self, *amounts):
        """Shift the inputs of an function by a certain amount.

        Args:
            *amounts (tensor): Per input, amount to shift by.

        Returns:
            :class:`.function_field.Function`: Shifted function.
        """
        return shift(self, *amounts)

    def select(self, *dims):
        """Select particular dimensions of the input features.

        Args:
            *dims (int, sequence, or None): Per input, dimensions to select.
                Set to `None` to select all.

        Returns:
            :class:`.function_field.Function`: Function with dimensions of the
                input features selected.
        """
        return select(self, *dims)

    def transform(self, *fs):
        """Transform the inputs of a function.

        Args:
            *fs (function or None): Per input, transformation. Set to `None` to
                not perform a transformation.

        Returns:
            :class:`.function_field.Function`: Function with its inputs
                transformed.
        """
        return transform(self, *fs)

    def diff(self, *derivs):
        """Differentiate a function.

        Args:
            *derivs (int): Per input, dimension of the feature with respect to
                which to take the derivative. Set to `None` to not take a
                derivative.

        Returns:
            :class:`.function_field.Function`: Derivative of the function.
        """
        return differentiate(self, *derivs)
Example #15
class Delta(Kernel):
    """Kronecker delta kernel.

    Args:
        epsilon (float, optional): Tolerance for equality in squared distance.
            Defaults to `1e-10`.
    """

    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, epsilon=1e-10):
        self.epsilon = epsilon

    @_dispatch(B.Numeric, B.Numeric)
    def __call__(self, x, y):
        if x is y:
            return B.fill_diag(B.one(x), B.shape(uprank(x))[0])
        else:
            return Dense(self._compute(B.pw_dists2(uprank(x), uprank(y))))

    @_dispatch(Unique, Unique)
    def __call__(self, x, y):
        x, y = x.get(), y.get()
        if x is y:
            return B.fill_diag(B.one(x), B.shape(uprank(x))[0])
        else:
            x, y = uprank(x), uprank(y)
            return Zero(B.dtype(x), B.shape(x)[0], B.shape(y)[0])

    @_dispatch(WeightedUnique, WeightedUnique)
    def __call__(self, x, y):
        w_x, w_y = x.w, y.w
        x, y = x.get(), y.get()
        if x is y:
            return Diagonal(1 / w_x)
        else:
            x, y = uprank(x), uprank(y)
            return Zero(B.dtype(x), B.shape(x)[0], B.shape(y)[0])

    @_dispatch(Unique, object)
    def __call__(self, x, y):
        x, y = uprank(x.get()), uprank(y)
        return Zero(B.dtype(x), B.shape(x)[0], B.shape(y)[0])

    @_dispatch(object, Unique)
    def __call__(self, x, y):
        x, y = uprank(x), uprank(y.get())
        return Zero(B.dtype(x), B.shape(x)[0], B.shape(y)[0])

    @_dispatch(Unique, Unique)
    def elwise(self, x, y):
        x, y = x.get(), y.get()
        if x is y:
            return B.ones(B.dtype(x), B.shape(uprank(x))[0], 1)
        else:
            return B.zeros(B.dtype(x), B.shape(uprank(x))[0], 1)

    @_dispatch(WeightedUnique, WeightedUnique)
    def elwise(self, x, y):
        w_x, w_y = x.w, y.w
        x, y = x.get(), y.get()
        if x is y:
            return B.uprank(1 / w_x)
        else:
            return B.zeros(B.dtype(x), B.shape(uprank(x))[0], 1)

    @_dispatch(Unique, object)
    def elwise(self, x, y):
        x = x.get()
        return B.zeros(B.dtype(x), B.shape(uprank(x))[0], 1)

    @_dispatch(object, Unique)
    def elwise(self, x, y):
        return B.zeros(B.dtype(x), B.shape(uprank(x))[0], 1)

    @_dispatch(B.Numeric, B.Numeric)
    def elwise(self, x, y):
        if x is y:
            return B.ones(B.dtype(x), B.shape(uprank(x))[0], 1)
        else:
            return self._compute(B.ew_dists2(uprank(x), uprank(y)))

    def _compute(self, dists2):
        dtype = B.dtype(dists2)
        return B.cast(dtype, B.lt(dists2, B.cast(dtype, self.epsilon)))

    @property
    def _stationary(self):
        return True

    @_dispatch(Self)
    def __eq__(self, other):
        return self.epsilon == other.epsilon
Example #16
    class B(A):
        _dispatch = Dispatcher(in_class=Self)

        @_dispatch(int)
        def do(self, x):
            return 'int'
Example #17
    class A:
        _dispatch = Dispatcher()

        @_dispatch
        def g(self):
            """docstring of g"""
Example #18
    class A(metaclass=Referentiable):
        _dispatch = Dispatcher(in_class=Self)

        @_dispatch()
        def g(self):
            """docstring of g"""
Example #19
class MultiOutputKernel(Kernel, Referentiable):
    """A generic multi-output kernel.

    Args:
        *ps (instance of :class:`.graph.GP`): Processes that make up the
            multi-valued process.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, *ps):
        self.kernels = ps[0].graph.kernels
        self.ps = ps

    @_dispatch({B.Numeric, Input}, {B.Numeric, Input}, Cache)
    @cache
    def __call__(self, x, y, B):
        return self(MultiInput(*(p(x) for p in self.ps)),
                    MultiInput(*(p(y) for p in self.ps)), B)

    @_dispatch(At, {B.Numeric, Input}, Cache)
    @cache
    def __call__(self, x, y, B):
        return self(MultiInput(x), MultiInput(*(p(y) for p in self.ps)), B)

    @_dispatch({B.Numeric, Input}, At, Cache)
    @cache
    def __call__(self, x, y, B):
        return self(MultiInput(*(p(x) for p in self.ps)), MultiInput(y), B)

    @_dispatch(At, At, Cache)
    @cache
    def __call__(self, x, y, B):
        return self.kernels[type_parameter(x),
                            type_parameter(y)](x.get(), y.get(), B)

    @_dispatch(MultiInput, At, Cache)
    @cache
    def __call__(self, x, y, B):
        return self(x, MultiInput(y), B)

    @_dispatch(At, MultiInput, Cache)
    @cache
    def __call__(self, x, y, B):
        return self(MultiInput(x), y, B)

    @_dispatch(MultiInput, MultiInput, Cache)
    @cache
    def __call__(self, x, y, B):
        return B.block_matrix(*[[self(xi, yi, B) for yi in y.get()]
                                for xi in x.get()])

    @_dispatch({B.Numeric, Input}, {B.Numeric, Input}, Cache)
    @cache
    def elwise(self, x, y, B):
        return self.elwise(MultiInput(*(p(x) for p in self.ps)),
                           MultiInput(*(p(y) for p in self.ps)), B)

    @_dispatch(At, {B.Numeric, Input}, Cache)
    @cache
    def elwise(self, x, y, B):
        raise ValueError('Unclear combination of arguments given to '
                         'MultiOutputKernel.elwise.')

    @_dispatch({B.Numeric, Input}, At, Cache)
    @cache
    def elwise(self, x, y, B):
        raise ValueError('Unclear combination of arguments given to '
                         'MultiOutputKernel.elwise.')

    @_dispatch(At, At, Cache)
    @cache
    def elwise(self, x, y, B):
        return self.kernels[type_parameter(x),
                            type_parameter(y)].elwise(x.get(), y.get(), B)

    @_dispatch(MultiInput, At, Cache)
    @cache
    def elwise(self, x, y, B):
        raise ValueError('Unclear combination of arguments given to '
                         'MultiOutputKernel.elwise.')

    @_dispatch(At, MultiInput, Cache)
    @cache
    def elwise(self, x, y, B):
        raise ValueError('Unclear combination of arguments given to '
                         'MultiOutputKernel.elwise.')

    @_dispatch(MultiInput, MultiInput, Cache)
    @cache
    def elwise(self, x, y, B):
        if len(x.get()) != len(y.get()):
            raise ValueError('MultiOutputKernel.elwise must be called with '
                             'similarly sized MultiInputs.')
        return B.concat(
            [self.elwise(xi, yi, B) for xi, yi in zip(x.get(), y.get())],
            axis=0)

    def __str__(self):
        ks = [str(self.kernels[p]) for p in self.ps]
        return 'MultiOutputKernel({})'.format(', '.join(ks))
Example #20
class SparseObservations(AbstractObservations):
    """Observations through inducing points. Takes further arguments
    according to the constructor of :class:`.graph.Observations`.

    Attributes:
        elbo (scalar): ELBO.

    Args:
        z (input): Locations of the inducing points.
        e (:class:`.graph.GP`): Additive, independent noise process.
    """

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch({B.Numeric, Input, tuple, list},
               [Union(tuple, list, PromisedGP)])
    def __init__(self, z, *pairs, **kw_args):
        es, xs, ys = zip(*pairs)
        AbstractObservations.__init__(self, *zip(xs, ys), **kw_args)
        SparseObservations.__init__(self,
                                    z,
                                    self.graph.cross(*es),
                                    self.x,
                                    self.y,
                                    **kw_args)

    @_dispatch({list, tuple},
               PromisedGP,
               {B.Numeric, Input},
               B.Numeric,
               [PromisedGP])
    def __init__(self, zs, e, x, y, ref=None):
        # Ensure `At` everywhere.
        zs = [ensure_at(z, ref=ref) for z in zs]

        # Extract graph.
        graph = type_parameter(zs[0]).graph

        # Create a representative multi-output process.
        p_z = graph.cross(*(type_parameter(z) for z in zs))

        SparseObservations.__init__(self,
                                    p_z(MultiInput(*zs)),
                                    e, x, y, ref=ref)

    @_dispatch({B.Numeric, Input},
               PromisedGP,
               {B.Numeric, Input},
               B.Numeric,
               [PromisedGP])
    def __init__(self, z, e, x, y, ref=None):
        AbstractObservations.__init__(self, x, y, ref=ref)
        self.z = ensure_at(z, self._ref)
        self.e = e

        self._K_z = None
        self._elbo = None
        self._mu = None
        self._A = None

    @property
    def K_z(self):
        """Kernel matrix of the data."""
        if self._K_z is None:  # Cache computation.
            self._compute()
        return self._K_z

    @property
    def elbo(self):
        """ELBO."""
        if self._elbo is None:  # Cache computation.
            self._compute()
        return self._elbo

    @property
    def mu(self):
        """Mean of optimal approximating distribution."""
        if self._mu is None:  # Cache computation.
            self._compute()
        return self._mu

    @property
    def A(self):
        """Parameter of the corrective variance of the kernel of the optimal
        approximating distribution."""
        if self._A is None:  # Cache computation.
            self._compute()
        return self._A

    def _compute(self):
        # Extract processes.
        p_x, x = type_parameter(self.x), self.x.get()
        p_z, z = type_parameter(self.z), self.z.get()

        # Construct the necessary kernel matrices.
        K_zx = self.graph.kernels[p_z, p_x](z, x)
        self._K_z = convert(self.graph.kernels[p_z](z), AbstractMatrix)

        # Evaluating `e.kernel(x)` will yield incorrect results if `x` is a
        # `MultiInput`, because `x` then still designates the particular
        # components of `f`. Fix that by instead designating the elements of
        # `e`.
        if isinstance(x, MultiInput):
            x_n = MultiInput(*(p(xi.get())
                               for p, xi in zip(self.e.kernel.ps, x.get())))
        else:
            x_n = x

        # Construct the noise kernel matrix.
        K_n = self.e.kernel(x_n)

        # The approximation can only handle diagonal noise matrices.
        if not isinstance(K_n, Diagonal):
            raise RuntimeError('Kernel matrix of noise must be diagonal.')

        # And construct the components for the inducing point approximation.
        L_z = B.cholesky(self._K_z)
        self._A = B.add(B.eye(self._K_z),
                        B.iqf(K_n, B.transpose(B.solve(L_z, K_zx))))
        y_bar = uprank(self.y) - self.e.mean(x_n) - self.graph.means[p_x](x)
        prod_y_bar = B.solve(L_z, B.iqf(K_n, B.transpose(K_zx), y_bar))

        # Compute the optimal mean.
        self._mu = B.add(self.graph.means[p_z](z),
                         B.iqf(self._A, B.solve(L_z, self._K_z), prod_y_bar))

        # Compute the ELBO.
        # NOTE: The calculation of `trace_part` asserts that `K_n` is diagonal.
        #       The rest, however, is completely generic.
        trace_part = B.ratio(Diagonal(self.graph.kernels[p_x].elwise(x)[:, 0]) -
                             Diagonal(B.iqf_diag(self._K_z, K_zx)), K_n)
        det_part = B.logdet(2 * B.pi * K_n) + B.logdet(self._A)
        iqf_part = B.iqf(K_n, y_bar)[0, 0] - B.iqf(self._A, prod_y_bar)[0, 0]
        self._elbo = -0.5 * (trace_part + det_part + iqf_part)

    def posterior_kernel(self, p_i, p_j):
        p_z, z = type_parameter(self.z), self.z.get()
        return PosteriorKernel(self.graph.kernels[p_i, p_j],
                               self.graph.kernels[p_z, p_i],
                               self.graph.kernels[p_z, p_j],
                               z, self.K_z) + \
               CorrectiveKernel(self.graph.kernels[p_z, p_i],
                                self.graph.kernels[p_z, p_j],
                                z, self.A, self.K_z)

    def posterior_mean(self, p):
        p_z, z = type_parameter(self.z), self.z.get()
        return PosteriorMean(self.graph.means[p],
                             self.graph.means[p_z],
                             self.graph.kernels[p_z, p],
                             z, self.K_z, self.mu)
Example #21
class Normal(RandomVector):
    """Normal random variable.

    Args:
        mean (column vector, optional): Mean of the distribution. Defaults to zero.
        var (matrix): Variance of the distribution.
    """

    _dispatch = Dispatcher(in_class=Self)

    @_dispatch({B.Numeric, AbstractMatrix}, {B.Numeric, AbstractMatrix})
    def __init__(self, mean, var):
        self._mean = mean
        self._mean_is_zero = None
        self._var = var

    @_dispatch({B.Numeric, AbstractMatrix})
    def __init__(self, var):
        Normal.__init__(self, 0, var)

    @_dispatch(FunctionType, FunctionType)
    def __init__(self, construct_mean, construct_var):
        self._mean = None
        self._construct_mean = construct_mean
        self._mean_is_zero = None
        self._var = None
        self._construct_var = construct_var

    @_dispatch(FunctionType)
    def __init__(self, construct_var):
        Normal.__init__(self, lambda: 0, construct_var)

    def _resolve_mean(self, construct_zeros):
        if self._mean is None:
            self._mean = self._construct_mean()
        if self._mean_is_zero is None:
            self._mean_is_zero = self._mean is 0 or isinstance(self._mean, Zero)
        if self._mean is 0 and construct_zeros:
            self._mean = B.zeros(self.dtype, self.dim, 1)

    def _resolve_var(self):
        if self._var is None:
            self._var = self._construct_var()
        # Ensure that the variance is a structured matrix for efficient operations.
        self._var = convert(self._var, AbstractMatrix)

    @property
    def mean(self):
        """Mean."""
        self._resolve_mean(construct_zeros=True)
        return self._mean

    @property
    def mean_is_zero(self):
        """The mean is zero."""
        self._resolve_mean(construct_zeros=False)
        return self._mean_is_zero

    @property
    def var(self):
        """Variance."""
        self._resolve_var()
        return self._var

    @property
    def dtype(self):
        """Data type."""
        return B.dtype(self.var)

    @property
    def dim(self):
        """Dimensionality."""
        return B.shape(self.var)[0]

    @property
    def m2(self):
        """Second moment."""
        return self.var + B.outer(B.squeeze(self.mean))

    def marginals(self):
        """Get the marginals.

        Returns:
            tuple: A tuple containing the predictive means and lower and
                upper 95% central credible interval bounds.
        """
        mean = B.squeeze(B.dense(self.mean))
        # It can happen that the variances are slightly negative due to numerical noise.
        # Prevent NaNs from the following square root by taking the maximum with zero.
        error = 1.96 * B.sqrt(B.maximum(B.diag(self.var), B.cast(self.dtype, 0)))
        return mean, mean - error, mean + error

    def logpdf(self, x):
        """Compute the log-pdf.

        Args:
            x (input): Values to compute the log-pdf of.

        Returns:
            list[tensor]: Log-pdf for every input in `x`. If it can be
                determined that the list contains only a single log-pdf,
                then the list is flattened to a scalar.
        """
        logpdfs = (
            -(
                B.logdet(self.var)
                + B.cast(self.dtype, self.dim) * B.cast(self.dtype, B.log_2_pi)
                + B.iqf_diag(self.var, B.subtract(uprank(x), self.mean))
            )
            / 2
        )
        return logpdfs[0] if B.shape(logpdfs) == (1,) else logpdfs

    def entropy(self):
        """Compute the entropy.

        Returns:
            scalar: The entropy.
        """
        return (
            B.logdet(self.var)
            + B.cast(self.dtype, self.dim) * B.cast(self.dtype, B.log_2_pi + 1)
        ) / 2

    @_dispatch(Self)
    def kl(self, other):
        """Compute the KL divergence with respect to another normal
        distribution.

        Args:
            other (:class:`.random.Normal`): Other normal.

        Returns:
            scalar: KL divergence.
        """
        return (
            B.ratio(self.var, other.var)
            + B.iqf_diag(other.var, other.mean - self.mean)[0]
            - B.cast(self.dtype, self.dim)
            + B.logdet(other.var)
            - B.logdet(self.var)
        ) / 2

    @_dispatch(Self)
    def w2(self, other):
        """Compute the 2-Wasserstein distance with respect to another normal
        distribution.

        Args:
            other (:class:`.random.Normal`): Other normal.

        Returns:
            scalar: 2-Wasserstein distance.
        """
        var_root = B.root(self.var)
        root = B.root(B.matmul(var_root, other.var, var_root))
        var_part = B.trace(self.var) + B.trace(other.var) - 2 * B.trace(root)
        mean_part = B.sum((self.mean - other.mean) ** 2)
        # The sum of `mean_part` and `var_part` should be positive, but this
        # may not be the case due to numerical errors.
        return B.sqrt(B.maximum(mean_part + var_part, B.cast(self.dtype, 0)))

    def sample(self, num=1, noise=None):
        """Sample from the distribution.

        Args:
            num (int): Number of samples.
            noise (scalar, optional): Variance of noise to add to the
                samples. Must be positive.

        Returns:
            tensor: Samples as rank 2 column vectors.
        """
        var = self.var

        # Add noise.
        if noise is not None:
            var = B.add(var, B.fill_diag(noise, self.dim))

        # Perform sampling operation.
        sample = B.sample(var, num=num)
        if not self.mean_is_zero:
            sample = B.add(sample, self.mean)

        return B.dense(sample)

    @_dispatch(B.Numeric)
    def __add__(self, other):
        return Normal(self.mean + other, self.var)

    @_dispatch(Self)
    def __add__(self, other):
        return Normal(B.add(self.mean, other.mean), B.add(self.var, other.var))

    @_dispatch(B.Numeric)
    def __mul__(self, other):
        return Normal(B.multiply(self.mean, other), B.multiply(self.var, other ** 2))

    def lmatmul(self, other):
        return Normal(
            B.matmul(other, self.mean),
            B.matmul(B.matmul(other, self.var), other, tr_b=True),
        )

    def rmatmul(self, other):
        return Normal(
            B.matmul(other, self.mean, tr_a=True),
            B.matmul(B.matmul(other, self.var, tr_a=True), other),
        )
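
# The 95% central credible bounds in `marginals` above are the mean plus and
# minus 1.96 marginal standard deviations, with the variances clipped at zero to
# guard against numerical noise. A standalone NumPy sketch of that step:
import numpy as np

mean = np.array([0.0, 1.0])
var_diag = np.array([1.0, -1e-12])                 # Slightly negative due to round-off.
error = 1.96 * np.sqrt(np.maximum(var_diag, 0.0))  # Clipping avoids NaNs.
print(mean - error, mean + error)                  # Lower and upper bounds.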
Example #22
class GP(RandomProcess):
    """Gaussian process.

    Args:
        kernel (:class:`.kernel.Kernel`): Kernel of the
            process.
        mean (:class:`.mean.Mean`, optional): Mean function of the
            process. Defaults to zero.
        graph (:class:`.graph.Graph`, optional): Graph to attach to.
    """
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch([object])
    def __init__(self, kernel, mean=None, graph=model, name=None):
        # Resolve kernel.
        if isinstance(kernel, (B.Numeric, FunctionType)):
            kernel = kernel * OneKernel()

        # Resolve mean.
        if mean is None:
            mean = ZeroMean()
        elif isinstance(mean, (B.Numeric, FunctionType)):
            mean = mean * OneMean()

        # Then add a new `GP` to the graph with the resolved kernel and mean.
        self.graph = graph
        self.graph.add_independent_gp(self, kernel, mean)

        # If a name is given, set the name.
        if name:
            self.graph.name(self, name)

    @_dispatch(Graph)
    def __init__(self, graph):
        self.graph = graph

    @property
    def kernel(self):
        """Kernel of the GP."""
        return self.graph.kernels[self]

    @property
    def mean(self):
        """Mean function of the GP."""
        return self.graph.means[self]

    @property
    def name(self):
        """Name of the GP."""
        return self.graph[self]

    @name.setter
    @_dispatch(str)
    def name(self, name):
        self.graph.name(self, name)

    def __call__(self, x):
        """Construct a finite-dimensional distribution at specified locations.

        Args:
            x (input): Points to construct the distribution at.

        Returns:
            :class:`.random.Normal`: Finite-dimensional distribution.
        """
        return Normal(self, x)

    @_dispatch([object])
    def condition(self, *args):
        """Condition the GP. See :meth:`.graph.Graph.condition`."""
        return self.graph.condition((self,), Observations(*args, ref=self))[0]

    @_dispatch(AbstractObservations)
    def condition(self, obs):
        return self.graph.condition((self,), obs)[0]

    @_dispatch(object)
    def __add__(self, other):
        return self.graph.sum(self, other)

    @_dispatch(Random)
    def __add__(self, other):
        raise NotImplementedError('Cannot add a GP and a {}.'
                                  ''.format(type(other).__name__))

    @_dispatch(Self)
    def __add__(self, other):
        return self.graph.sum(self, other)

    @_dispatch(object)
    def __mul__(self, other):
        return self.graph.mul(self, other)

    @_dispatch(Random)
    def __mul__(self, other):
        raise NotImplementedError('Cannot multiply a GP and a {}.'
                                  ''.format(type(other).__name__))

    @_dispatch(Self)
    def __mul__(self, other):
        return (lambda x: self.graph.means[self](x)) * other + \
               self * (lambda x: self.graph.means[other](x)) + \
               GP(kernel=self.graph.kernels[self] *
                         self.graph.kernels[other] +
                         self.graph.kernels[self, other] *
                         self.graph.kernels[other, self],
                  mean=-self.graph.means[self] *
                       self.graph.means[other],
                  graph=self.graph)

    @_dispatch([object])
    def __or__(self, args):
        """Shorthand for conditioning."""
        return self.condition(Observations(*args, ref=self))

    @_dispatch(AbstractObservations)
    def __or__(self, obs):
        return self.condition(obs)

    def shift(self, shift):
        """Shift the GP. See :meth:`.graph.Graph.shift`."""
        return self.graph.shift(self, shift)

    def stretch(self, stretch):
        """Stretch the GP. See :meth:`.graph.Graph.stretch`."""
        return self.graph.stretch(self, stretch)

    def __gt__(self, stretch):
        """Shorthand for :meth:`.graph.GP.stretch`."""
        return self.stretch(stretch)

    def transform(self, f):
        """Input transform the GP. See :meth:`.graph.Graph.transform`."""
        return self.graph.transform(self, f)

    def select(self, *dims):
        """Select dimensions from the input. See :meth:`.graph.Graph.select`."""
        return self.graph.select(self, *dims)

    def __getitem__(self, *dims):
        """Shorthand for :meth:`.graph.GP.select`."""
        return self.select(*dims)

    def diff(self, dim=0):
        """Differentiate the GP. See :meth:`.graph.Graph.diff`."""
        return self.graph.diff(self, dim)

    def diff_approx(self, deriv=1, order=6):
        """Approximate the derivative of the GP by constructing a finite
        difference approximation.

        Args:
            deriv (int): Order of the derivative.
            order (int): Order of the estimate.

        Returns:
            Approximation of the derivative of the GP.
        """
        # Use the FDM library to figure out the coefficients.
        fdm = central_fdm(order, deriv, adapt=0, factor=1e8)
        fdm.estimate()  # Estimate step size.

        # Construct finite difference.
        df = 0
        for g, c in zip(fdm.grid, fdm.coefs):
            df += c * self.shift(-g * fdm.step)
        return df / fdm.step ** deriv

    @property
    def stationary(self):
        """Stationarity of the GP."""
        return self.kernel.stationary

    def __str__(self):
        return self.display()

    def __repr__(self):
        return self.display()

    def display(self, formatter=lambda x: x):
        """Display the GP.

        Args:
            formatter (function, optional): Function to format values.

        Returns:
            str: GP as a string.
        """
        return 'GP({}, {})'.format(self.kernel.display(formatter),
                                   self.mean.display(formatter))
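
The finite-difference construction in `diff_approx` can be illustrated with ordinary functions. Below is a hedged, NumPy-only analogue: a second-order central difference built from a grid of shifts and coefficients, which is the same grid/coefficient sum that `diff_approx` applies to shifted copies of the GP. The fixed grid, coefficients, and step size here are illustrative; `central_fdm` chooses its own.

import numpy as np

# Second-order central difference for a plain function, mirroring the
# grid/coefficient sum in `diff_approx`.
def central_difference(f, x, step=1e-4):
    grid = [-1.0, 0.0, 1.0]     # Shifts, in units of the step size.
    coefs = [-0.5, 0.0, 0.5]    # Central-difference coefficients.
    return sum(c * f(x + g * step) for g, c in zip(grid, coefs)) / step

# d/dx sin(x) at x = 0 is cos(0) = 1.
print(central_difference(np.sin, 0.0))  # Approximately 1.0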
Exemple #23
0
class Element(Referentiable):
    """A field over functions.

    Functions are also referred to as elements of the field. Elements can be
    added and multiplied.
    """

    _dispatch = Dispatcher(in_class=Self)

    def __eq__(self, other):
        return False

    def __mul__(self, other):
        return mul(self, other)

    def __rmul__(self, other):
        return mul(other, self)

    def __add__(self, other):
        return add(self, other)

    def __radd__(self, other):
        return add(other, self)

    def __neg__(self):
        return mul(-1, self)

    def __sub__(self, other):
        return add(self, -other)

    def __rsub__(self, other):
        return add(other, -self)

    @property
    def num_terms(self):
        """Number of terms"""
        return 1

    def term(self, i):
        """Get a specific term.

        Args:
            i (int): Index of term.

        Returns:
            :class:`.field.Element`: The referenced term.
        """
        if i == 0:
            return self
        else:
            raise IndexError('Index out of range.')

    @property
    def num_factors(self):
        """Number of factors"""
        return 1

    def factor(self, i):
        """Get a specific factor.

        Args:
            i (int): Index of factor.

        Returns:
            :class:`.field.Element`: The referenced factor.
        """
        if i == 0:
            return self
        else:
            raise IndexError('Index out of range.')

    @property
    def __name__(self):
        return self.__class__.__name__

    def __repr__(self):
        return self.display()

    def __str__(self):
        return self.display()

    @_dispatch(Formatter)
    def display(self, formatter):
        """Display the element.

        Args:
            formatter (function, optional): Function to format values.

        Returns:
            str: Element as a string.
        """
        # Due to multiple inheritance, we might arrive here before arriving at a
        # method in the appropriate subclass. The only case to consider is if
        # we're not in a leaf.
        if isinstance(self, (JoinElement, WrappedElement)):
            return pretty_print(self, formatter)
        else:
            return self.__class__.__name__ + '()'

    @_dispatch()
    def display(self):
        return self.display(lambda x: x)
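
All operator overloads above delegate to the free functions `mul` and `add`, so the algebraic simplification can live outside the element classes. The toy sketch below shows the same delegation pattern with placeholder functions that simply record the operation; `toy_add`, `toy_mul`, and `ToyElement` are illustrative names, not the library's `add`/`mul`.

# Toy delegation pattern: operators forward to free functions, which can
# then inspect argument types and simplify the resulting expression.
def toy_add(a, b):
    return ("add", a, b)

def toy_mul(a, b):
    return ("mul", a, b)

class ToyElement:
    def __add__(self, other):
        return toy_add(self, other)

    def __radd__(self, other):
        return toy_add(other, self)

    def __mul__(self, other):
        return toy_mul(self, other)

    def __rmul__(self, other):
        return toy_mul(other, self)

e = ToyElement()
print(2 * e)   # ("mul", 2, e), built via __rmul__.
print(e + e)   # ("add", e, e)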
Exemple #24
0
    class B(A):
        _dispatch = Dispatcher(in_class=Self)

        @_dispatch(Union(int, Self, str), return_type=Union(int, Self))
        def do(self, x):
            return x
Exemple #25
0
def test_tuple():
    # Standard type tests.
    assert hash(Tuple[int]) == hash(Tuple[int])
    assert hash(Tuple[int]) != hash(Tuple[str])
    assert hash(Tuple[Tuple[int]]) == hash(Tuple[Tuple[int]])
    assert hash(Tuple[Tuple[int]]) != hash(Tuple[Tuple[str]])
    assert repr(Tuple[int]) == f"Tuple[{Type(int)!r}]"
    assert issubclass(Tuple[int].get_types()[0], tuple)
    assert not issubclass(Tuple[int].get_types()[0], int)
    assert not issubclass(Tuple[int].get_types()[0], list)

    # Test instance check.
    assert isinstance((), Tuple())
    assert isinstance((1, 2), Tuple[int, int])

    # Check tracking of parametric.
    assert Tuple[int].parametric
    assert ptype(List[Tuple[int]]).parametric
    assert ptype(Union[Tuple[int]]).parametric
    promise = PromisedType()
    promise.deliver(Tuple[int])
    assert promise.resolve().parametric

    # Check tracking of runtime `type_of`.
    assert Tuple[int].runtime_type_of
    assert ptype(List[Tuple[int]]).runtime_type_of
    assert ptype(Union[Tuple[int]]).runtime_type_of
    promise = PromisedType()
    promise.deliver(Tuple[int])
    assert promise.resolve().runtime_type_of

    # Test correctness.
    dispatch = Dispatcher()

    @dispatch
    def f(x):
        return "fallback"

    @dispatch
    def f(x: tuple):
        return "tup"

    @dispatch
    def f(x: Tuple[int]):
        return "tup of int"

    @dispatch
    def f(x: Tuple[int, int]):
        return "tup of double int"

    @dispatch
    def f(x: Tuple[Tuple[int]]):
        return "tup of tup of int"

    @dispatch
    def f(x: Tuple[Tuple[int], Tuple[int]]):
        return "tup of double tup of int"

    @dispatch
    def f(x: Tuple[int, Tuple[int, int]]):
        return "tup of int and tup of double int"

    assert f((1, )) == "tup of int"
    assert f(1) == "fallback"
    assert f((1, 2)) == "tup of double int"
    assert f((1, 2, "3")) == "tup"
    assert f(((1, ), )) == "tup of tup of int"
    assert f(((1, ), (1, ))) == "tup of double tup of int"
    assert f((1, (1, 2))) == "tup of int and tup of double int"
    assert f(((1, ), (1, 2))) == "tup"
Exemple #26
0
class DerivativeKernel(Kernel, DerivativeFunction, Referentiable):
    """Derivative of kernel."""
    _dispatch = Dispatcher(in_class=Self)

    @_dispatch(B.Numeric, B.Numeric, Cache)
    @cache
    @uprank
    def __call__(self, x, y, B):
        i, j = expand(self.derivs)
        k = self[0]

        # Derivative with respect to both `x` and `y`.
        if i is not None and j is not None:
            z = B.concat([x[:, i], y[:, j]], axis=0)
            n = B.shape(x)[0]
            K = dense(
                k(B.concat([x[:, :i], z[:n, None], x[:, i + 1:]], axis=1),
                  B.concat([y[:, :j], z[n:, None], y[:, j + 1:]], axis=1)))
            return Dense(B.hessians(K, [z])[0][:n, n:])

        # Derivative with respect to `x`.
        elif i is not None and j is None:
            xi = x[:, i:i + 1]
            # Give every `B.identity` a unique cache ID to prevent caching.
            xis = [
                B.identity(xi, cache_id=n) for n in range(B.shape_int(y)[0])
            ]

            def f(z):
                return dense(
                    k(B.concat([x[:, :i], z[0], x[:, i + 1:]], axis=1), z[1]))

            res = B.map_fn(f, (B.stack(xis, axis=0), y[:, None, :]),
                           dtype=B.dtype(x))
            return Dense(B.concat(B.gradients(B.sum(res, axis=0), xis),
                                  axis=1))

        # Derivative with respect to `y`.
        elif i is None and j is not None:
            yj = y[:, j:j + 1]
            # Give every `B.identity` a unique cache ID to prevent caching.
            yjs = [
                B.identity(yj, cache_id=n) for n in range(B.shape_int(x)[0])
            ]

            def f(z):
                return dense(
                    k(z[0], B.concat([y[:, :j], z[1], y[:, j + 1:]], axis=1)))

            res = B.map_fn(f, (x[:, None, :], B.stack(yjs, axis=0)),
                           dtype=B.dtype(x))
            dKt = B.concat(B.gradients(B.sum(res, axis=0), yjs), axis=1)
            return Dense(B.transpose(dKt))

        else:
            raise RuntimeError('No derivative specified.')

    @property
    def _stationary(self):
        # NOTE: In the one-dimensional case, if derivatives with respect to both
        # arguments are taken, then the result is in fact stationary.
        return False

    @_dispatch(Self)
    def __eq__(self, other):
        return self[0] == other[0] and \
               tuple_equal(expand(self.derivs), expand(other.derivs))
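
For intuition about the mixed-derivative branch above (`i is not None and j is not None`), here is a hedged, NumPy-only check for one-dimensional inputs and the kernel k(x, y) = exp(-(x - y)^2 / 2): the derivative kernel is d^2 k / dx dy = (1 - (x - y)^2) exp(-(x - y)^2 / 2), which a central finite difference recovers. The kernel choice and helper names are assumptions for illustration only.

import numpy as np

def k(x, y):
    # Exponentiated-quadratic kernel in one dimension.
    return np.exp(-0.5 * (x - y) ** 2)

def dk_dx_dy(x, y):
    # Closed form of the mixed derivative d^2 k / dx dy.
    d = x - y
    return (1 - d ** 2) * np.exp(-0.5 * d ** 2)

def dk_dx_dy_fd(x, y, h=1e-4):
    # Central finite-difference approximation of the mixed derivative.
    return (k(x + h, y + h) - k(x + h, y - h)
            - k(x - h, y + h) + k(x - h, y - h)) / (4 * h ** 2)

x, y = 0.3, -0.5
print(np.isclose(dk_dx_dy(x, y), dk_dx_dy_fd(x, y)))  # True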
Exemple #27
0
class Vars(Provider):
    """Variable storage.

    Args:
        dtype (data type): Data type of the variables.
        source (tensor, optional): Tensor to source variables from. Defaults to
            not being used.
    """
    _dispatch = Dispatcher(in_class=Self)

    def __init__(self, dtype, source=None):
        self.dtype = dtype

        # Source:
        self.source = source
        self.source_index = 0

        # Storage:
        self.vars = []
        self.transforms = []
        self.inverse_transforms = []

        # Lookup:
        self.name_to_index = OrderedDict()
        self._get_vars_cache = {}

        # Packing:
        self.vector_packer = None

    def _get_var(self,
                 transform,
                 inverse_transform,
                 init,
                 generate_init,
                 shape,
                 dtype,
                 name):
        # If the name already exists, return that variable.
        try:
            return self[name]
        except KeyError:
            pass

        # A new variable will be added. Clear lookup cache.
        self._get_vars_cache.clear()

        # Resolve data type.
        dtype = self.dtype if dtype is None else dtype

        # If no source is provided, get the latent variable from the provided
        # initialiser.
        if self.source is None:
            # Resolve initialisation and inverse transform.
            if init is None:
                init = generate_init(shape=shape, dtype=dtype)
            else:
                init = B.cast(dtype, init)

            # Construct optimisable variable.
            latent = inverse_transform(init)
            if isinstance(self.dtype, B.TFDType):
                latent = tf.Variable(latent)
            elif isinstance(self.dtype, B.TorchDType):
                pass  # All is good in this case.
            else:
                # Must be a NumPy data type.
                assert isinstance(self.dtype, B.NPDType)
                latent = np.array(latent)
        else:
            # Get the latent variable from the source.
            length = reduce(mul, shape, 1)
            latent_flat = \
                self.source[self.source_index:self.source_index + length]
            self.source_index += length

            # Cast to the right data type.
            latent = B.cast(dtype, B.reshape(latent_flat, *shape))

        # Store transforms.
        self.vars.append(latent)
        self.transforms.append(transform)
        self.inverse_transforms.append(inverse_transform)

        # Get index of the variable.
        index = len(self.vars) - 1

        # Store name if given.
        if name is not None:
            self.name_to_index[name] = index

        # Generate the variable and return.
        return transform(latent)

    def unbounded(self, init=None, shape=(), dtype=None, name=None):
        def generate_init(shape, dtype):
            return B.randn(dtype, *shape)

        return self._get_var(transform=lambda x: x,
                             inverse_transform=lambda x: x,
                             init=init,
                             generate_init=generate_init,
                             shape=shape,
                             dtype=dtype,
                             name=name)

    def positive(self, init=None, shape=(), dtype=None, name=None):
        def generate_init(shape, dtype):
            return B.rand(dtype, *shape)

        return self._get_var(transform=lambda x: B.exp(x),
                             inverse_transform=lambda x: B.log(x),
                             init=init,
                             generate_init=generate_init,
                             shape=shape,
                             dtype=dtype,
                             name=name)

    def bounded(self,
                init=None,
                lower=1e-4,
                upper=1e4,
                shape=(),
                dtype=None,
                name=None):
        def transform(x):
            return lower + (upper - lower) / (1 + B.exp(x))

        def inverse_transform(x):
            return B.log(upper - x) - B.log(x - lower)

        def generate_init(shape, dtype):
            return lower + B.rand(dtype, *shape) * (upper - lower)

        return self._get_var(transform=transform,
                             inverse_transform=inverse_transform,
                             init=init,
                             generate_init=generate_init,
                             shape=shape,
                             dtype=dtype,
                             name=name)

    def __getitem__(self, name):
        index = self.name_to_index[name]
        return self.transforms[index](self.vars[index])

    def assign(self, name, value, differentiable=False):
        """Assign a value to a variable.

        Args:
            name (hashable): Name of variable to assign value to.
            value (tensor): Value to assign.
            differentiable (bool, optional): Do a differentiable assignment.

        Returns:
            tensor: Assignment result.
        """
        index = self.name_to_index[name]
        if differentiable:
            # Do a differentiable assignment.
            self.vars[index] = value
            return value
        else:
            # Overwrite data.
            return _assign(self.vars[index],
                           self.inverse_transforms[index](value))

    def copy(self, detach=False):
        """Create a copy of the variable manager that shares the variables.

        Args:
            detach (bool, optional): Detach the variables in PyTorch. Defaults
                to `False`.

        Returns:
            :class:`.vars.Vars`: Copy.
        """
        vs = Vars(dtype=self.dtype)
        vs.transforms = list(self.transforms)
        vs.inverse_transforms = list(self.inverse_transforms)
        vs.name_to_index = OrderedDict(self.name_to_index)
        vs.vector_packer = self.vector_packer
        if detach:
            for var in self.vars:
                vs.vars.append(var.detach())
        else:
            vs.vars = list(self.vars)
        return vs

    def detach(self):
        """Detach all variables held in PyTorch."""
        self.vars = [v.detach() for v in self.vars]

    def requires_grad(self, value, *names):
        """Set which variables require a gradient in PyTorch.

        Args:
            value (bool): Require a gradient.
            *names (hashable): Specify variables by name.
        """
        for var in self.get_vars(*names):
            var.requires_grad_(value)

    def get_vars(self, *names, **kw_args):
        """Get latent variables.

        If no names are supplied, then all latent variables are retrieved.
        Furthermore, for a given collection of names, the variables are
        guaranteed to be returned in the same order every time.

        Args:
            *names (hashable): Get variables by name.
            indices (bool, optional): Get the indices of the variables instead.
                Defaults to `False`.

        Returns:
            list: Matched latent variables or their indices, depending on the
                value of `indices`.
        """
        # If nothing is specified, return all latent variables.
        if len(names) == 0:
            if kw_args.get('indices', False):
                return list(range(len(self.vars)))
            else:
                return self.vars

        # Attempt to use cache.
        cache_key = (names, kw_args.get('indices', False))
        try:
            return self._get_vars_cache[cache_key]
        except KeyError:
            pass

        # Collect indices of matches.
        indices = set()
        for name in names:
            a_match = False
            for k, v in self.name_to_index.items():
                if match(name, k):
                    indices |= {v}
                    a_match = True

            # Check that there was a match.
            if not a_match:
                raise ValueError('No variable matching "{}".'.format(name))

        # Return indices if asked for. Otherwise, return variables.
        if kw_args.get('indices', False):
            res = sorted(indices)
        else:
            res = [self.vars[i] for i in sorted(indices)]

        # Store in cache before returning.
        self._get_vars_cache[cache_key] = res
        return res

    def get_vector(self, *names):
        """Get all the latent variables stacked in a vector.

        If no arguments are supplied, then all latent variables are retrieved.

        Args:
            *names (hashable): Get variables by name.

        Returns:
            tensor: Vector consisting of all latent values.
        """
        vars = self.get_vars(*names)
        self.vector_packer = Packer(*vars)
        return self.vector_packer.pack(*vars)

    def set_vector(self, values, *names, **kw_args):
        """Set all the latent variables by values from a vector.

        If no names are supplied, then all latent variables are set.

        Args:
            values (tensor): Vector to set the variables to.
            *names (hashable): Set variables by name.
            differentiable (bool, optional): Differentiable assignment. Defaults
                to `False`.

        Returns:
            list: Assignment results.
        """
        values = self.vector_packer.unpack(values)

        if kw_args.get('differentiable', False):
            # Do a differentiable assignment.
            for index, value in zip(self.get_vars(*names, indices=True),
                                    values):
                self.vars[index] = value
            return values
        else:
            # Overwrite data.
            assignments = []
            for var, value in zip(self.get_vars(*names), values):
                assignments.append(_assign(var, value))
            return assignments

    @property
    def names(self):
        """All available names."""
        return list(self.name_to_index.keys())

    def print(self):
        """Print all variables."""
        for name in self.names:
            wbml.out.kv(name, self[name])
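
As a sanity check on `Vars.bounded` above, the logistic-style `transform` and its `inverse_transform` are exact inverses on the open interval (lower, upper). The sketch below verifies the round trip for a single value; NumPy stands in for the library's backend `B` purely for illustration.

import numpy as np

# The transform pair from `Vars.bounded`, written with NumPy.
lower, upper = 1e-4, 1e4

def transform(x):
    return lower + (upper - lower) / (1 + np.exp(x))

def inverse_transform(y):
    return np.log(upper - y) - np.log(y - lower)

# Round trip: any y in (lower, upper) maps back to itself.
y = 3.7
print(np.isclose(transform(inverse_transform(y)), y))  # True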