Example #1
    def __init__(self, nu=1.5, dt=1.0, obs_dim=2, num_gps=1,
                 length_scale_init=None, kernel_scale_init=None,
                 obs_noise_scale_init=None):
        self.nu = nu
        self.dt = dt
        assert obs_dim > 1, "If obs_dim==1 you should use IndependentMaternGP"
        self.obs_dim = obs_dim
        self.num_gps = num_gps

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim,)

        super().__init__()

        self.kernel = MaternKernel(nu=nu, num_gps=num_gps,
                                   length_scale_init=length_scale_init,
                                   kernel_scale_init=kernel_scale_init)
        self.full_state_dim = num_gps * self.kernel.state_dim

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)
        self.A = nn.Parameter(0.3 * torch.randn(self.num_gps, self.obs_dim))
Example #2
    def __init__(
        self,
        nu=1.5,
        dt=1.0,
        obs_dim=1,
        length_scale_init=None,
        kernel_scale_init=None,
        obs_noise_scale_init=None,
    ):
        self.nu = nu
        self.dt = dt
        self.obs_dim = obs_dim

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim, )

        super().__init__()

        self.kernel = MaternKernel(
            nu=nu,
            num_gps=obs_dim,
            length_scale_init=length_scale_init,
            kernel_scale_init=kernel_scale_init,
        )

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)

        obs_matrix = [1.0] + [0.0] * (self.kernel.state_dim - 1)
        self.register_buffer("obs_matrix",
                             torch.tensor(obs_matrix).unsqueeze(-1))
Example #3
    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)

        self._event_dims = {}
        self._cond_indep_stacks = {}
        self.locs = PyroModule()
        self.scales = PyroModule()

        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # Collect unconstrained event_dims, which may differ from constrained event_dims.
            with helpful_support_errors(site):
                init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
            event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
            self._event_dims[name] = event_dim

            # Collect independence contexts.
            self._cond_indep_stacks[name] = site["cond_indep_stack"]

            # If subsampling, repeat init_value to full size.
            for frame in site["cond_indep_stack"]:
                full_size = getattr(frame, "full_size", frame.size)
                if full_size != frame.size:
                    dim = frame.dim - event_dim
                    init_loc = periodic_repeat(init_loc, full_size, dim).contiguous()
            init_scale = torch.full_like(init_loc, self._init_scale)

            _deep_setattr(self.locs, name, PyroParam(init_loc, constraints.real, event_dim))
            _deep_setattr(self.scales, name,
                          PyroParam(init_scale, self.scale_constraint, event_dim))
Example #4
    def __init__(
        self,
        n_input: int,
        n_conditions: int,
        lam_scale: float,
        bias_scale: float,
        alpha: float = 1.0,
    ):
        super().__init__()

        self.n_input = n_input
        self.n_conditions = n_conditions

        # weight on monotonic constraint
        self.register_buffer("alpha", torch.as_tensor(alpha))

        # scale of priors on weights and bias
        self.register_buffer("lam_p_scale", torch.as_tensor(lam_scale))
        self.register_buffer("bias_p_scale", torch.as_tensor(bias_scale))

        # parameters for guide
        self.weight_loc = nn.Parameter(
            torch.nn.init.normal_(torch.Tensor(n_input)))
        self.weight_scale = PyroParam(torch.full((n_input, ), lam_scale),
                                      constraint=constraints.positive)
        self.bias_loc = nn.Parameter(
            torch.nn.init.normal_(torch.Tensor(n_conditions)))
        self.bias_scale = PyroParam(torch.full((n_conditions, ), bias_scale),
                                    constraint=constraints.positive)
Example #5
    def __init__(self,
                 nu=1.5,
                 num_gps=1,
                 length_scale_init=None,
                 kernel_scale_init=None):
        if nu not in [0.5, 1.5, 2.5]:
            raise NotImplementedError(
                "The only supported values of nu are 0.5, 1.5 and 2.5")
        self.nu = nu
        self.state_dim = {0.5: 1, 1.5: 2, 2.5: 3}[nu]
        self.num_gps = num_gps

        if length_scale_init is None:
            length_scale_init = torch.ones(num_gps)
        assert length_scale_init.shape == (num_gps, )

        if kernel_scale_init is None:
            kernel_scale_init = torch.ones(num_gps)
        assert kernel_scale_init.shape == (num_gps, )

        super().__init__()

        self.length_scale = PyroParam(length_scale_init,
                                      constraint=constraints.positive)
        self.kernel_scale = PyroParam(kernel_scale_init,
                                      constraint=constraints.positive)

        if self.state_dim > 1:
            for x in range(self.state_dim):
                for y in range(self.state_dim):
                    mask = torch.zeros(self.state_dim, self.state_dim)
                    mask[x, y] = 1.0
                    self.register_buffer("mask{}{}".format(x, y), mask)
Example #6
 def _setup_prototype(self, *args, **kwargs):
     super()._setup_prototype(*args, **kwargs)
     # Initialize guide params
     self.loc = nn.Parameter(self._init_loc())
     self.scale_tril = PyroParam(
         eye_like(self.loc, self.latent_dim) * self._init_scale,
         constraints.lower_cholesky)
Example #7
    def init_mvn_guide(self):
        """ Initialize multivariate normal guide
        """
        init_loc = torch.full((self.n_params, ), 0.0)
        init_scale = eye_like(init_loc, self.n_params) * 0.1

        _deep_setattr(self, "mvn.loc", PyroParam(init_loc, constraints.real))
        _deep_setattr(self, "mvn.scale_tril",
                      PyroParam(init_scale, constraints.lower_cholesky))
Example #8
class PartialMultivariateNormalSamplingGroup(LocatedSamplingGroupWithPrior):
    def __init__(self,
                 sites,
                 name='',
                 diag=_nomatch,
                 init_scale_full: tp.Union[torch.Tensor, float] = 1.,
                 init_scale_diag: tp.Union[torch.Tensor, float] = 1.,
                 *args,
                 **kwargs):
        self.diag_pattern = re.compile(diag)
        self.sites_full, self.sites_diag = ({
            site['name']: site
            for site in _
        } for _ in partition(lambda _: self.diag_pattern.match(_['name']),
                             sites))

        super().__init__(
            dict_union(self.sites_full, self.sites_diag).values(), name, *args,
            **kwargs)

        self.size_full, self.size_diag = (sum(self.sizes[site] for site in _)
                                          for _ in (self.sites_full,
                                                    self.sites_diag))

        jac = self.jacobian(self.loc)

        self.scale_full = PyroParam(self._scale_matrix(init_scale_full,
                                                       jac[:self.size_full]),
                                    event_dim=2,
                                    constraint=constraints.lower_cholesky)
        self.scale_cross = PyroParam(self.loc.new_zeros(
            torch.Size((self.size_diag, self.size_full))),
                                     event_dim=2)
        self.scale_diag = PyroParam(self._scale_diagonal(
            init_scale_diag, jac[self.size_full:]),
                                    event_dim=1,
                                    constraint=constraints.positive)

        self.guide_z_aux = PyroSample(
            dist.Normal(self.loc.new_zeros(()),
                        1.).expand(self.event_shape).to_event(1))

    @property
    def half_log_det(self):
        return self.scale_full.diagonal(
            dim1=-2, dim2=-1).log().sum(-1) + self.scale_diag.log().sum(-1)

    def prior(self):
        z_aux = self.guide_z_aux
        zfull, zdiag = z_aux[..., :self.size_full,
                             None], z_aux[..., self.size_full:]
        return dist.Delta(self.loc + torch.cat(
            ((self.scale_full @ zfull).squeeze(-1),
             (self.scale_cross @ zfull).squeeze(-1) + self.scale_diag * zdiag),
            dim=-1),
                          log_density=-self.half_log_det,
                          event_dim=1)
Example #9
    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)

        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            value = PyroParam(site["value"].detach(), constraint=site["fn"].support)
            _deep_setattr(self, name, value)
Example #10
 def _setup_prototype(self, *args, **kwargs):
     super()._setup_prototype(*args, **kwargs)
     # Initialize guide params
     self.loc = nn.Parameter(self._init_loc())
     self.scale = PyroParam(
         self.loc.new_full((self.latent_dim, ), self._init_scale),
         constraints.positive)
Example #11
    def __init__(
        self,
        nu=1.5,
        dt=1.0,
        obs_dim=1,
        linearly_coupled=False,
        length_scale_init=None,
        obs_noise_scale_init=None,
    ):

        if nu != 1.5:
            raise NotImplementedError("The only supported value of nu is 1.5")

        self.dt = dt
        self.obs_dim = obs_dim

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim, )

        super().__init__()

        self.kernel = MaternKernel(nu=nu,
                                   num_gps=obs_dim,
                                   length_scale_init=length_scale_init)
        self.full_state_dim = self.kernel.state_dim * obs_dim

        # we demote self.kernel.kernel_scale from being a nn.Parameter
        # since the relevant scales are now encoded in the wiener noise matrix
        del self.kernel.kernel_scale
        self.kernel.register_buffer("kernel_scale", torch.ones(obs_dim))

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)
        self.wiener_noise_tril = PyroParam(
            torch.eye(obs_dim) + 0.03 * torch.randn(obs_dim, obs_dim).tril(-1),
            constraint=constraints.lower_cholesky,
        )

        if linearly_coupled:
            self.obs_matrix = nn.Parameter(
                0.3 * torch.randn(self.obs_dim, self.obs_dim))
        else:
            obs_matrix = torch.zeros(self.full_state_dim, obs_dim)
            for i in range(obs_dim):
                obs_matrix[self.kernel.state_dim * i, i] = 1.0
            self.register_buffer("obs_matrix", obs_matrix)
Example #12
 def _setup_prototype(self, *args, **kwargs):
     super()._setup_prototype(*args, **kwargs)
     # Initialize guide params
     self.loc = nn.Parameter(self._init_loc())
     if self.rank is None:
         self.rank = int(round(self.latent_dim ** 0.5))
     self.scale = PyroParam(
         self.loc.new_full((self.latent_dim,), 0.5 ** 0.5 * self._init_scale),
         constraint=constraints.positive)
     self.cov_factor = nn.Parameter(
         self.loc.new_empty(self.latent_dim, self.rank).normal_(0, 1 / self.rank ** 0.5))
Example #13
class AutoMultivariateNormal(AutoContinuous):
    """
    This implementation of :class:`AutoContinuous` uses a Cholesky
    factorization of a Multivariate Normal distribution to construct a guide
    over the entire latent space. The guide does not depend on the model's
    ``*args, **kwargs``.

    Usage::

        guide = AutoMultivariateNormal(model)
        svi = SVI(model, guide, ...)

    By default the mean vector is initialized by ``init_loc_fn()`` and the
    Cholesky factor is initialized to the identity times a small factor.

    :param callable model: A generative model.
    :param callable init_loc_fn: A per-site initialization function.
        See :ref:`autoguide-initialization` section for available functions.
    :param float init_scale: Initial scale for the standard deviation of each
        (unconstrained transformed) latent variable.
    """

    # TODO consider switching to constraints.softplus_lower_cholesky
    # See https://github.com/pyro-ppl/numpyro/issues/855
    scale_tril_constraint = constraints.lower_cholesky

    def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
        if not isinstance(init_scale, float) or not (init_scale > 0):
            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
        self._init_scale = init_scale
        super().__init__(model, init_loc_fn=init_loc_fn)

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        # Initialize guide params
        self.loc = nn.Parameter(self._init_loc())
        self.scale_tril = PyroParam(eye_like(self.loc, self.latent_dim) * self._init_scale,
                                    self.scale_tril_constraint)

    def get_base_dist(self):
        return dist.Normal(torch.zeros_like(self.loc), torch.ones_like(self.loc)).to_event(1)

    def get_transform(self, *args, **kwargs):
        return dist.transforms.LowerCholeskyAffine(self.loc, scale_tril=self.scale_tril)

    def get_posterior(self, *args, **kwargs):
        """
        Returns a MultivariateNormal posterior distribution.
        """
        return dist.MultivariateNormal(self.loc, scale_tril=self.scale_tril)

    def _loc_scale(self, *args, **kwargs):
        return self.loc, self.scale_tril.diag()
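
A runnable end-to-end sketch of the ``Usage::`` pattern above; the toy model, data, and optimizer settings are illustrative assumptions, not part of the class:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    scale = pyro.sample("scale", dist.LogNormal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)

data = 3.0 + 2.0 * torch.randn(100)  # synthetic observations
guide = AutoMultivariateNormal(model, init_scale=0.1)
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for step in range(1000):
    svi.step(data)
posterior = guide.get_posterior()  # MultivariateNormal over the packed latents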
Example #14
    def __init__(self,
                 sites,
                 name='',
                 init_scale: tp.Union[torch.Tensor, float] = 1.,
                 *args,
                 **kwargs):
        super().__init__(sites, name, *args, **kwargs)

        self.scale = PyroParam(self._scale_diagonal(init_scale,
                                                    self.jacobian(self.loc)),
                               event_dim=1,
                               constraint=constraints.positive)
Example #15
    def __init__(self,
                 sites,
                 name='',
                 diag=_nomatch,
                 init_scale_full: tp.Union[torch.Tensor, float] = 1.,
                 init_scale_diag: tp.Union[torch.Tensor, float] = 1.,
                 *args,
                 **kwargs):
        self.diag_pattern = re.compile(diag)
        self.sites_full, self.sites_diag = ({
            site['name']: site
            for site in _
        } for _ in partition(lambda _: self.diag_pattern.match(_['name']),
                             sites))

        super().__init__(
            dict_union(self.sites_full, self.sites_diag).values(), name, *args,
            **kwargs)

        self.size_full, self.size_diag = (sum(self.sizes[site] for site in _)
                                          for _ in (self.sites_full,
                                                    self.sites_diag))

        jac = self.jacobian(self.loc)

        self.scale_full = PyroParam(self._scale_matrix(init_scale_full,
                                                       jac[:self.size_full]),
                                    event_dim=2,
                                    constraint=constraints.lower_cholesky)
        self.scale_cross = PyroParam(self.loc.new_zeros(
            torch.Size((self.size_diag, self.size_full))),
                                     event_dim=2)
        self.scale_diag = PyroParam(self._scale_diagonal(
            init_scale_diag, jac[self.size_full:]),
                                    event_dim=1,
                                    constraint=constraints.positive)

        self.guide_z_aux = PyroSample(
            dist.Normal(self.loc.new_zeros(()),
                        1.).expand(self.event_shape).to_event(1))
Example #16
    def __init__(self,
                 obs_dim=1,
                 state_dim=2,
                 obs_noise_scale_init=None,
                 learnable_observation_loc=False):
        self.obs_dim = obs_dim
        self.state_dim = state_dim

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim, )

        super().__init__()

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)
        self.trans_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                              constraint=constraints.positive)
        self.trans_matrix = nn.Parameter(
            torch.eye(state_dim) + 0.03 * torch.randn(state_dim, state_dim))
        self.obs_matrix = nn.Parameter(0.3 * torch.randn(state_dim, obs_dim))
        self.init_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                             constraint=constraints.positive)

        if learnable_observation_loc:
            self.obs_loc = nn.Parameter(torch.zeros(obs_dim))
        else:
            self.register_buffer('obs_loc', torch.zeros(obs_dim))
Example #17
    def __init__(self,
                 sites,
                 name='',
                 init_scale: tp.Union[torch.Tensor, float] = 1.,
                 *args,
                 **kwargs):
        super().__init__(sites, name, *args, **kwargs)

        self.scale_tril = PyroParam(self._scale_matrix(init_scale,
                                                       self.jacobian(
                                                           self.loc)),
                                    event_dim=2,
                                    constraint=constraints.lower_cholesky)
Example #18
    def _setup_prototype(self, *args, **kwargs):
        # run the model so we can inspect its structure
        model = config_enumerate(self.model)
        self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
            *args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._discrete_sites = []
        self._cond_indep_stacks = {}
        self._plates = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            if site["infer"].get("enumerate") != "parallel":
                raise NotImplementedError(
                    'Expected sample site "{}" to be discrete and '
                    'configured for parallel enumeration'.format(name))

            # collect discrete sample sites
            fn = site["fn"]
            Dist = type(fn)
            if Dist in (dist.Bernoulli, dist.Categorical,
                        dist.OneHotCategorical):
                params = [("probs", fn.probs.detach().clone(),
                           fn.arg_constraints["probs"])]
            else:
                raise NotImplementedError("{} is not supported".format(
                    Dist.__name__))
            self._discrete_sites.append((site, Dist, params))

            # collect independence contexts
            self._cond_indep_stacks[name] = site["cond_indep_stack"]
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    self._plates[frame.name] = frame
                else:
                    raise NotImplementedError(
                        "AutoDiscreteParallel does not support sequential pyro.plate"
                    )
        # Initialize guide params
        for site, Dist, param_spec in self._discrete_sites:
            name = site["name"]
            for param_name, param_init, param_constraint in param_spec:
                _deep_setattr(
                    self, "{}_{}".format(name, param_name),
                    PyroParam(param_init, constraint=param_constraint))
Example #19
    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)

        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            value = site["value"].detach()
            event_dim = site["fn"].event_dim

            # If subsampling, repeat init_value to full size.
            for frame in site["cond_indep_stack"]:
                full_size = getattr(frame, "full_size", frame.size)
                if full_size != frame.size:
                    dim = frame.dim - event_dim
                    value = periodic_repeat(value, full_size, dim).contiguous()

            value = PyroParam(value, site["fn"].support, event_dim)
            _deep_setattr(self, name, value)
Example #20
    def __init__(
        self,
        obs_dim=1,
        state_dim=2,
        nu=1.5,
        obs_noise_scale_init=None,
        length_scale_init=None,
        kernel_scale_init=None,
        learnable_observation_loc=False,
    ):
        self.obs_dim = obs_dim
        self.state_dim = state_dim
        self.nu = nu

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim, )

        super().__init__()

        self.kernel = MaternKernel(
            nu=nu,
            num_gps=obs_dim,
            length_scale_init=length_scale_init,
            kernel_scale_init=kernel_scale_init,
        )
        self.dt = 1.0
        self.full_state_dim = self.kernel.state_dim * obs_dim + state_dim
        self.full_gp_state_dim = self.kernel.state_dim * obs_dim

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)
        self.trans_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                              constraint=constraints.positive)
        self.z_trans_matrix = nn.Parameter(
            torch.eye(state_dim) + 0.03 * torch.randn(state_dim, state_dim))
        self.z_obs_matrix = nn.Parameter(0.3 * torch.randn(state_dim, obs_dim))
        self.init_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                             constraint=constraints.positive)

        gp_obs_matrix = torch.zeros(self.kernel.state_dim * obs_dim, obs_dim)
        for i in range(obs_dim):
            gp_obs_matrix[self.kernel.state_dim * i, i] = 1.0
        self.register_buffer("gp_obs_matrix", gp_obs_matrix)

        self.obs_selector = torch.tensor(
            [self.kernel.state_dim * d for d in range(obs_dim)],
            dtype=torch.long)

        if learnable_observation_loc:
            self.obs_loc = nn.Parameter(torch.zeros(obs_dim))
        else:
            self.register_buffer("obs_loc", torch.zeros(obs_dim))
Example #21
 def __init__(self, create_plates=None):
     super().__init__()
     # we define parameters here; make sure that the shape is aligned
     # with the shapes of sample sites in model.
     self.ma_weight_loc = PyroParam(torch.zeros(10, 1, 1, 2, 3, 7),
                                    event_dim=3)
     self.ma_weight_scale = PyroParam(torch.ones(10, 1, 1, 2, 3, 7) * 0.1,
                                      dist.constraints.positive,
                                      event_dim=3)
     self.snap_weight_loc = PyroParam(torch.zeros(10, 1, 1, 2, 7),
                                      event_dim=2)
     self.snap_weight_scale = PyroParam(torch.ones(10, 1, 1, 2, 7) * 0.1,
                                        dist.constraints.positive,
                                        event_dim=2)
     self.seasonal_loc = PyroParam(torch.zeros(10, 1, 7, 2, 7), event_dim=2)
     self.seasonal_scale = PyroParam(torch.ones(10, 1, 7, 2, 7) * 0.1,
                                     dist.constraints.positive,
                                     event_dim=2)
     self.create_plates = create_plates
Example #22
class GenericLGSSMWithGPNoiseModel(TimeSeriesModel):
    """
    A generic Linear Gaussian State Space Model parameterized with arbitrary time invariant
    transition and observation dynamics together with separate Gaussian Process noise models
    for each output dimension. In more detail, the generative process is:

        :math:`y_i(t) = \\sum_j A_{ij} z_j(t) + f_i(t) + \\epsilon_i(t)`

    where the latent variables :math:`{\\bf z}(t)` follow generic time invariant Linear Gaussian dynamics
    and the :math:`f_i(t)` are Gaussian Processes with Matern kernels.

    The targets are (implicitly) assumed to be evenly spaced in time. In particular a timestep of
    :math:`dt=1.0` for the continuous-time GP dynamics corresponds to a single discrete step of
    the :math:`{\\bf z}`-space dynamics. Training and inference are logarithmic in the length of
    the time series T.

    :param int obs_dim: The dimension of the targets at each time step.
    :param int state_dim: The dimension of the :math:`{\\bf z}` latent state at each time step.
    :param float nu: The order of the Matern kernel; one of 0.5, 1.5 or 2.5.
    :param torch.Tensor length_scale_init: optional initial values for the kernel length scale
        given as a ``obs_dim``-dimensional tensor
    :param torch.Tensor kernel_scale_init: optional initial values for the kernel scale
        given as a ``obs_dim``-dimensional tensor
    :param torch.Tensor obs_noise_scale_init: optional initial values for the observation noise scale
        given as a ``obs_dim``-dimensional tensor
    :param bool learnable_observation_loc: whether the mean of the observation model should be learned or not;
            defaults to False.
    """
    def __init__(self,
                 obs_dim=1,
                 state_dim=2,
                 nu=1.5,
                 obs_noise_scale_init=None,
                 length_scale_init=None,
                 kernel_scale_init=None,
                 learnable_observation_loc=False):
        self.obs_dim = obs_dim
        self.state_dim = state_dim
        self.nu = nu

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim, )

        super().__init__()

        self.kernel = MaternKernel(nu=nu,
                                   num_gps=obs_dim,
                                   length_scale_init=length_scale_init,
                                   kernel_scale_init=kernel_scale_init)
        self.dt = 1.0
        self.full_state_dim = self.kernel.state_dim * obs_dim + state_dim
        self.full_gp_state_dim = self.kernel.state_dim * obs_dim

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)
        self.trans_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                              constraint=constraints.positive)
        self.z_trans_matrix = nn.Parameter(
            torch.eye(state_dim) + 0.03 * torch.randn(state_dim, state_dim))
        self.z_obs_matrix = nn.Parameter(0.3 * torch.randn(state_dim, obs_dim))
        self.init_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                             constraint=constraints.positive)

        gp_obs_matrix = torch.zeros(self.kernel.state_dim * obs_dim, obs_dim)
        for i in range(obs_dim):
            gp_obs_matrix[self.kernel.state_dim * i, i] = 1.0
        self.register_buffer("gp_obs_matrix", gp_obs_matrix)

        self.obs_selector = torch.tensor(
            [self.kernel.state_dim * d for d in range(obs_dim)],
            dtype=torch.long)

        if learnable_observation_loc:
            self.obs_loc = nn.Parameter(torch.zeros(obs_dim))
        else:
            self.register_buffer('obs_loc', torch.zeros(obs_dim))

    def _get_obs_matrix(self):
        # stack GP and z observation matrices => (gp_state_dim * obs_dim + state_dim, obs_dim)
        return torch.cat([self.gp_obs_matrix, self.z_obs_matrix], dim=0)

    def _get_init_dist(self):
        loc = self.z_trans_matrix.new_zeros(self.full_state_dim)
        covar = self.z_trans_matrix.new_zeros(self.full_state_dim,
                                              self.full_state_dim)
        covar[:self.full_gp_state_dim, :self.
              full_gp_state_dim] = block_diag_embed(
                  self.kernel.stationary_covariance())
        covar[self.full_gp_state_dim:,
              self.full_gp_state_dim:] = self.init_noise_scale_sq.diag_embed()
        return MultivariateNormal(loc, covar)

    def _get_obs_dist(self):
        return dist.Normal(self.obs_loc, self.obs_noise_scale).to_event(1)

    def _get_dist(self):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
        to :class:`GenericLGSSMWithGPNoiseModel`.
        """
        gp_trans_matrix, gp_process_covar = self.kernel.transition_matrix_and_covariance(
            dt=self.dt)

        trans_covar = self.z_trans_matrix.new_zeros(self.full_state_dim,
                                                    self.full_state_dim)
        trans_covar[:self.full_gp_state_dim, :self.
                    full_gp_state_dim] = block_diag_embed(gp_process_covar)
        trans_covar[
            self.full_gp_state_dim:,
            self.full_gp_state_dim:] = self.trans_noise_scale_sq.diag_embed()
        trans_dist = MultivariateNormal(
            trans_covar.new_zeros(self.full_state_dim), trans_covar)

        full_trans_mat = trans_covar.new_zeros(self.full_state_dim,
                                               self.full_state_dim)
        full_trans_mat[:self.full_gp_state_dim, :self.
                       full_gp_state_dim] = block_diag_embed(gp_trans_matrix)
        full_trans_mat[self.full_gp_state_dim:,
                       self.full_gp_state_dim:] = self.z_trans_matrix

        return dist.GaussianHMM(self._get_init_dist(), full_trans_mat,
                                trans_dist, self._get_obs_matrix(),
                                self._get_obs_dist())

    @pyro_method
    def log_prob(self, targets):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued ``targets`` at each time step
        :returns torch.Tensor: A (scalar) log probability.
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self._get_dist().log_prob(targets)

    @torch.no_grad()
    def _filter(self, targets):
        """
        Return the filtering state for the associated state space model.
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self._get_dist().filter(targets)

    @torch.no_grad()
    def _forecast(self,
                  N_timesteps,
                  filtering_state,
                  include_observation_noise=True):
        """
        Internal helper for forecasting.
        """
        dts = torch.arange(N_timesteps,
                           dtype=self.z_trans_matrix.dtype,
                           device=self.z_trans_matrix.device) + 1.0
        dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

        gp_trans_matrix, gp_process_covar = self.kernel.transition_matrix_and_covariance(
            dt=dts)
        gp_trans_matrix = block_diag_embed(gp_trans_matrix)
        gp_process_covar = block_diag_embed(gp_process_covar[..., 0:1, 0:1])

        N_trans_matrix = repeated_matmul(self.z_trans_matrix, N_timesteps)
        N_trans_obs = torch.matmul(N_trans_matrix, self.z_obs_matrix)

        # z-state contribution + gp contribution
        predicted_mean1 = torch.matmul(
            filtering_state.loc[-self.state_dim:].unsqueeze(-2),
            N_trans_obs).squeeze(-2)
        predicted_mean2 = torch.matmul(
            filtering_state.loc[:self.full_gp_state_dim].unsqueeze(-2),
            gp_trans_matrix[..., self.obs_selector]).squeeze(-2)
        predicted_mean = predicted_mean1 + predicted_mean2

        # first compute the contributions from filtering_state.covariance_matrix: z-space and gp
        fs_cov = filtering_state.covariance_matrix
        predicted_covar1z = torch.matmul(N_trans_obs.transpose(-1, -2),
                                         torch.matmul(
                                             fs_cov[self.full_gp_state_dim:,
                                                    self.full_gp_state_dim:],
                                             N_trans_obs))  # N O O
        gp_trans = gp_trans_matrix[..., self.obs_selector]
        predicted_covar1gp = torch.matmul(
            gp_trans.transpose(-1, -2),
            torch.matmul(
                fs_cov[:self.full_gp_state_dim, :self.full_gp_state_dim],
                gp_trans))

        # next compute the contribution from process noise that is injected at each timestep.
        # (we need to do a cumulative sum to integrate across time for the z-state contribution)
        z_process_covar = self.trans_noise_scale_sq.diag_embed()
        N_trans_obs_shift = torch.cat(
            [self.z_obs_matrix.unsqueeze(0), N_trans_obs[0:-1]])
        predicted_covar2z = torch.matmul(N_trans_obs_shift.transpose(
            -1, -2), torch.matmul(z_process_covar, N_trans_obs_shift))  # N O O

        predicted_covar = predicted_covar1z + predicted_covar1gp + gp_process_covar + \
            torch.cumsum(predicted_covar2z, dim=0)

        if include_observation_noise:
            predicted_covar = predicted_covar + self.obs_noise_scale.pow(
                2.0).diag_embed()

        return predicted_mean, predicted_covar

    @pyro_method
    def forecast(self, targets, N_timesteps):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued targets at each time step. These
            represent the training data that are conditioned on for the purpose of making
            forecasts.
        :param int N_timesteps: The number of timesteps to forecast into the future from
            the final target ``targets[-1]``.
        :returns torch.distributions.MultivariateNormal: Returns a predictive MultivariateNormal distribution
            with batch shape ``(N_timesteps,)`` and event shape ``(obs_dim,)``
        """
        filtering_state = self._filter(targets)
        predicted_mean, predicted_covar = self._forecast(
            N_timesteps, filtering_state)
        return MultivariateNormal(predicted_mean, predicted_covar)
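
A hedged usage sketch for GenericLGSSMWithGPNoiseModel: maximize ``log_prob`` with a plain PyTorch optimizer and then forecast. The synthetic data and hyperparameters below are assumptions for illustration:

import torch
from pyro.contrib.timeseries import GenericLGSSMWithGPNoiseModel

T, obs_dim = 200, 2
targets = torch.randn(T, obs_dim).cumsum(0)  # toy random-walk targets

model = GenericLGSSMWithGPNoiseModel(obs_dim=obs_dim, state_dim=3, nu=1.5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for step in range(500):
    optimizer.zero_grad()
    loss = -model.log_prob(targets) / T  # mean negative log likelihood
    loss.backward()
    optimizer.step()

# predictive MultivariateNormal: batch shape (10,), event shape (obs_dim,)
pred = model.forecast(targets, N_timesteps=10)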
Example #23
    def _setup_prototype(self, *args, **kwargs):

        super()._setup_prototype(*args, **kwargs)

        self._event_dims = {}
        self._cond_indep_stacks = {}
        self.hidden2locs = PyroModule()
        self.hidden2scales = PyroModule()

        if "multiple" in self.encoder_mode:
            # create module for collecting multiple encoder NN
            self.multiple_encoders = PyroModule()

        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # Collect unconstrained event_dims, which may differ from constrained event_dims.
            with helpful_support_errors(site):
                init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
            event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
            self._event_dims[name] = event_dim

            # Collect independence contexts.
            self._cond_indep_stacks[name] = site["cond_indep_stack"]

            # determine the number of hidden layers
            if "multiple" in self.encoder_mode:
                if "multiple" in self.n_hidden.keys():
                    n_hidden = self.n_hidden["multiple"]
                else:
                    n_hidden = self.n_hidden[name]
            elif "single" in self.encoder_mode:
                n_hidden = self.n_hidden["single"]
            # add linear layer for locs and scales
            param_dim = (n_hidden, self.amortised_plate_sites["sites"][name])
            init_param = np.random.normal(
                np.zeros(param_dim),
                (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
            ).astype("float32")
            _deep_setattr(
                self.hidden2locs,
                name,
                PyroParam(torch.tensor(init_param, device=site["value"].device, requires_grad=True)),
            )

            init_param = np.random.normal(
                np.zeros(param_dim),
                (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
            ).astype("float32")
            _deep_setattr(
                self.hidden2scales,
                name,
                PyroParam(torch.tensor(init_param, device=site["value"].device, requires_grad=True)),
            )

            if "multiple" in self.encoder_mode:
                # create multiple encoders
                if self.encoder_instance is not None:
                    # copy instances
                    encoder_ = deepcopy(self.encoder_instance).to(site["value"].device)
                    # convert to pyro module
                    to_pyro_module_(encoder_)
                    _deep_setattr(
                        self.multiple_encoders,
                        name,
                        encoder_,
                    )
                else:
                    # create instances
                    _deep_setattr(
                        self.multiple_encoders,
                        name,
                        self.encoder_class(n_in=self.multiple_n_in, n_out=n_hidden, **self.multi_encoder_kwargs).to(
                            site["value"].device
                        ),
                    )
Example #24
class GenericLGSSM(TimeSeriesModel):
    """
    A generic Linear Gaussian State Space Model parameterized with arbitrary time invariant
    transition and observation dynamics. The targets are (implicitly) assumed to be evenly
    spaced in time. Training and inference are logarithmic in the length of the time series T.

    :param int obs_dim: The dimension of the targets at each time step.
    :param int state_dim: The dimension of latent state at each time step.
    :param bool learnable_observation_loc: whether the mean of the observation model should be learned or not;
        defaults to False.
    """
    def __init__(self,
                 obs_dim=1,
                 state_dim=2,
                 obs_noise_scale_init=None,
                 learnable_observation_loc=False):
        self.obs_dim = obs_dim
        self.state_dim = state_dim

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim, )

        super().__init__()

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)
        self.trans_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                              constraint=constraints.positive)
        self.trans_matrix = nn.Parameter(
            torch.eye(state_dim) + 0.03 * torch.randn(state_dim, state_dim))
        self.obs_matrix = nn.Parameter(0.3 * torch.randn(state_dim, obs_dim))
        self.init_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                             constraint=constraints.positive)

        if learnable_observation_loc:
            self.obs_loc = nn.Parameter(torch.zeros(obs_dim))
        else:
            self.register_buffer('obs_loc', torch.zeros(obs_dim))

    def _get_init_dist(self):
        loc = self.obs_matrix.new_zeros(self.state_dim)
        return MultivariateNormal(loc, self.init_noise_scale_sq.diag_embed())

    def _get_obs_dist(self):
        return dist.Normal(self.obs_loc, self.obs_noise_scale).to_event(1)

    def _get_trans_dist(self):
        loc = self.obs_matrix.new_zeros(self.state_dim)
        return MultivariateNormal(loc, self.trans_noise_scale_sq.diag_embed())

    def _get_dist(self):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds to :class:`GenericLGSSM`.
        """
        return dist.GaussianHMM(self._get_init_dist(), self.trans_matrix,
                                self._get_trans_dist(), self.obs_matrix,
                                self._get_obs_dist())

    @pyro_method
    def log_prob(self, targets):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued ``targets`` at each time step
        :returns torch.Tensor: A (scalar) log probability.
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self._get_dist().log_prob(targets)

    @torch.no_grad()
    def _filter(self, targets):
        """
        Return the filtering state for the associated state space model.
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self._get_dist().filter(targets)

    @torch.no_grad()
    def _forecast(self,
                  N_timesteps,
                  filtering_state,
                  include_observation_noise=True):
        """
        Internal helper for forecasting.
        """
        N_trans_matrix = repeated_matmul(self.trans_matrix, N_timesteps)
        N_trans_obs = torch.matmul(N_trans_matrix, self.obs_matrix)
        predicted_mean = torch.matmul(filtering_state.loc, N_trans_obs)

        # first compute the contribution from filtering_state.covariance_matrix
        predicted_covar1 = torch.matmul(N_trans_obs.transpose(-1, -2),
                                        torch.matmul(
                                            filtering_state.covariance_matrix,
                                            N_trans_obs))  # N O O

        # next compute the contribution from process noise that is injected at each timestep.
        # (we need to do a cumulative sum to integrate across time)
        process_covar = self._get_trans_dist().covariance_matrix
        N_trans_obs_shift = torch.cat(
            [self.obs_matrix.unsqueeze(0), N_trans_obs[:-1]])
        predicted_covar2 = torch.matmul(N_trans_obs_shift.transpose(
            -1, -2), torch.matmul(process_covar, N_trans_obs_shift))  # N O O

        predicted_covar = predicted_covar1 + torch.cumsum(predicted_covar2,
                                                          dim=0)

        if include_observation_noise:
            predicted_covar = predicted_covar + self.obs_noise_scale.pow(
                2.0).diag_embed()

        return predicted_mean, predicted_covar

    @pyro_method
    def forecast(self, targets, N_timesteps):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued targets at each time step. These
            represent the training data that are conditioned on for the purpose of making
            forecasts.
        :param int N_timesteps: The number of timesteps to forecast into the future from
            the final target ``targets[-1]``.
        :returns torch.distributions.MultivariateNormal: Returns a predictive MultivariateNormal distribution
            with batch shape ``(N_timesteps,)`` and event shape ``(obs_dim,)``
        """
        filtering_state = self._filter(targets)
        predicted_mean, predicted_covar = self._forecast(
            N_timesteps, filtering_state)
        return torch.distributions.MultivariateNormal(predicted_mean,
                                                      predicted_covar)
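
The same recipe applies to GenericLGSSM; a minimal sketch with assumed toy data:

import torch
from pyro.contrib.timeseries import GenericLGSSM

targets = torch.randn(150, 1).cumsum(0)  # toy 1-d series
model = GenericLGSSM(obs_dim=1, state_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for step in range(300):
    optimizer.zero_grad()
    (-model.log_prob(targets)).backward()
    optimizer.step()
pred = model.forecast(targets, N_timesteps=5)  # batch shape (5,)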
Example #25
class LinearlyCoupledMaternGP(TimeSeriesModel):
    """
    A time series model in which each output dimension is modeled as a linear combination
    of shared univariate Gaussian Processes with Matern kernels.

    In more detail, the generative process is:

        :math:`y_i(t) = \\sum_j A_{ij} f_j(t) + \\epsilon_i(t)`

    The targets :math:`y_i` are assumed to be evenly spaced in time. Training and inference
    are logarithmic in the length of the time series T.

    :param float nu: The order of the Matern kernel; one of 0.5, 1.5 or 2.5.
    :param float dt: The time spacing between neighboring observations of the time series.
    :param int obs_dim: The dimension of the targets at each time step.
    :param int num_gps: The number of independent GPs that are mixed to model the time series.
        Typical values might be :math:`N_{\\rm gp} \\in [D_{\\rm obs} / 2, D_{\\rm obs}]`
    :param torch.Tensor length_scale_init: optional initial values for the kernel length scale
        given as a ``num_gps``-dimensional tensor
    :param torch.Tensor kernel_scale_init: optional initial values for the kernel scale
        given as a ``num_gps``-dimensional tensor
    :param torch.Tensor obs_noise_scale_init: optional initial values for the observation noise scale
        given as a ``obs_dim``-dimensional tensor
    """
    def __init__(self, nu=1.5, dt=1.0, obs_dim=2, num_gps=1,
                 length_scale_init=None, kernel_scale_init=None,
                 obs_noise_scale_init=None):
        self.nu = nu
        self.dt = dt
        assert obs_dim > 1, "If obs_dim==1 you should use IndependentMaternGP"
        self.obs_dim = obs_dim
        self.num_gps = num_gps

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim,)

        super().__init__()

        self.kernel = MaternKernel(nu=nu, num_gps=num_gps,
                                   length_scale_init=length_scale_init,
                                   kernel_scale_init=kernel_scale_init)
        self.full_state_dim = num_gps * self.kernel.state_dim

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)
        self.A = nn.Parameter(0.3 * torch.randn(self.num_gps, self.obs_dim))

    def _get_obs_matrix(self):
        # (num_gps, obs_dim) => (state_dim * num_gps, obs_dim)
        return self.A.repeat_interleave(self.kernel.state_dim, dim=0) * \
            self.A.new_tensor([1.0] + [0.0] * (self.kernel.state_dim - 1)).repeat(self.num_gps).unsqueeze(-1)

    def _stationary_covariance(self):
        return block_diag_embed(self.kernel.stationary_covariance())

    def _get_init_dist(self):
        loc = self.A.new_zeros(self.full_state_dim)
        return MultivariateNormal(loc, self._stationary_covariance())

    def _get_obs_dist(self):
        loc = self.A.new_zeros(self.obs_dim)
        return dist.Normal(loc, self.obs_noise_scale).to_event(1)

    def get_dist(self, duration=None):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
        to a :class:`LinearlyCoupledMaternGP`.

        :param int duration: Optional size of the time axis ``event_shape[0]``.
            This is required when sampling from homogeneous HMMs whose parameters
            are not expanded along the time axis.
        """
        trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt)
        trans_matrix = block_diag_embed(trans_matrix)
        process_covar = block_diag_embed(process_covar)
        loc = self.A.new_zeros(self.full_state_dim)
        trans_dist = MultivariateNormal(loc, process_covar)
        return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist,
                                self._get_obs_matrix(), self._get_obs_dist(), duration=duration)

    @pyro_method
    def log_prob(self, targets):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued ``targets`` at each time step
        :returns torch.Tensor: a (scalar) log probability
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self.get_dist().log_prob(targets)

    @torch.no_grad()
    def _filter(self, targets):
        """
        Return the filtering state for the associated state space model.
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self.get_dist().filter(targets)

    @torch.no_grad()
    def _forecast(self, dts, filtering_state, include_observation_noise=True, full_covar=True):
        """
        Internal helper for forecasting.
        """
        assert dts.dim() == 1
        dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        trans_mat, process_covar = self.kernel.transition_matrix_and_covariance(dt=dts)
        trans_mat = block_diag_embed(trans_mat)  # S x full_state_dim x full_state_dim
        process_covar = block_diag_embed(process_covar)  # S x full_state_dim x full_state_dim
        obs_matrix = self._get_obs_matrix()  # full_state_dim x obs_dim
        trans_obs = torch.matmul(trans_mat, obs_matrix)  # S x full_state_dim x obs_dim
        predicted_mean = torch.matmul(filtering_state.loc.unsqueeze(-2), trans_obs).squeeze(-2)
        predicted_function_covar = torch.matmul(trans_obs.transpose(-1, -2),
                                                torch.matmul(filtering_state.covariance_matrix,
                                                trans_obs))
        predicted_function_covar = predicted_function_covar + \
            torch.matmul(obs_matrix.transpose(-1, -2), torch.matmul(process_covar, obs_matrix))

        if include_observation_noise:
            obs_noise = self.obs_noise_scale.pow(2.0).diag_embed()
            predicted_function_covar = predicted_function_covar + obs_noise
        if not full_covar:
            predicted_function_covar = predicted_function_covar.diagonal(dim1=-1, dim2=-2)

        return predicted_mean, predicted_function_covar

    @pyro_method
    def forecast(self, targets, dts):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued targets at each time step. These
            represent the training data that are conditioned on for the purpose of making
            forecasts.
        :param torch.Tensor dts: A 1-dimensional tensor of times to forecast into the future,
            with zero corresponding to the time of the final target ``targets[-1]``.
        :returns torch.distributions.MultivariateNormal: Returns a predictive MultivariateNormal
            distribution with batch shape ``(S,)`` and event shape ``(obs_dim,)``,
            where ``S`` is the size of ``dts``.
        """
        filtering_state = self._filter(targets)
        predicted_mean, predicted_covar = self._forecast(dts, filtering_state)
        return MultivariateNormal(predicted_mean, predicted_covar)
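
A usage sketch for LinearlyCoupledMaternGP with assumed toy data. Note that here ``forecast`` takes a tensor of future times ``dts`` rather than an integer number of steps:

import torch
from pyro.contrib.timeseries import LinearlyCoupledMaternGP

targets = torch.randn(100, 3).cumsum(0)  # toy 3-d series
model = LinearlyCoupledMaternGP(nu=1.5, dt=1.0, obs_dim=3, num_gps=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for step in range(200):
    optimizer.zero_grad()
    (-model.log_prob(targets)).backward()
    optimizer.step()

dts = torch.tensor([1.0, 2.0, 5.0])  # times past targets[-1]
pred = model.forecast(targets, dts)  # batch shape (3,), event shape (3,)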
Example #26
class MaternKernel(PyroModule):
    """
    Provides the building blocks for representing univariate Gaussian Processes (GPs)
    with Matern kernels as state space models.

    :param float nu: The order of the Matern kernel (one of 0.5, 1.5 or 2.5)
    :param int num_gps: the number of GPs
    :param torch.Tensor length_scale_init: optional `num_gps`-dimensional vector of initializers
        for the length scale
    :param torch.Tensor kernel_scale_init: optional `num_gps`-dimensional vector of initializers
        for the kernel scale

    **References**

    [1] `Kalman Filtering and Smoothing Solutions to Temporal Gaussian Process Regression Models`,
        Jouni Hartikainen and Simo Sarkka.
    [2] `Stochastic Differential Equation Methods for Spatio-Temporal Gaussian Process Regression`,
        Arno Solin.
    """
    def __init__(self,
                 nu=1.5,
                 num_gps=1,
                 length_scale_init=None,
                 kernel_scale_init=None):
        if nu not in [0.5, 1.5, 2.5]:
            raise NotImplementedError(
                "The only supported values of nu are 0.5, 1.5 and 2.5")
        self.nu = nu
        self.state_dim = {0.5: 1, 1.5: 2, 2.5: 3}[nu]
        self.num_gps = num_gps

        if length_scale_init is None:
            length_scale_init = torch.ones(num_gps)
        assert length_scale_init.shape == (num_gps, )

        if kernel_scale_init is None:
            kernel_scale_init = torch.ones(num_gps)
        assert kernel_scale_init.shape == (num_gps, )

        super().__init__()

        self.length_scale = PyroParam(length_scale_init,
                                      constraint=constraints.positive)
        self.kernel_scale = PyroParam(kernel_scale_init,
                                      constraint=constraints.positive)

        if self.state_dim > 1:
            for x in range(self.state_dim):
                for y in range(self.state_dim):
                    mask = torch.zeros(self.state_dim, self.state_dim)
                    mask[x, y] = 1.0
                    self.register_buffer("mask{}{}".format(x, y), mask)

    @pyro_method
    def transition_matrix(self, dt):
        """
        Compute the (exponentiated) transition matrix of the GP latent space.
        The resulting matrix has layout (num_gps, old_state, new_state), i.e. this
        matrix multiplies states from the right.

        See section 5 in reference [1] for details.

        :param float dt: the time interval over which the GP latent space evolves.
        :returns torch.Tensor: a 3-dimensional tensor of transition matrices of shape
            (num_gps, state_dim, state_dim).
        """
        if self.nu == 0.5:
            rho = self.length_scale.unsqueeze(-1).unsqueeze(-1)
            return torch.exp(-dt / rho)
        elif self.nu == 1.5:
            rho = self.length_scale.unsqueeze(-1).unsqueeze(-1)
            dt_rho = dt / rho
            trans = (1.0 + root_three * dt_rho) * self.mask00 + \
                (-3.0 * dt_rho / rho) * self.mask01 + \
                dt * self.mask10 + \
                (1.0 - root_three * dt_rho) * self.mask11
            return torch.exp(-root_three * dt_rho) * trans
        elif self.nu == 2.5:
            rho = self.length_scale.unsqueeze(-1).unsqueeze(-1)
            dt_rho = root_five * dt / rho
            dt_rho_sq = dt_rho.pow(2.0)
            dt_rho_cu = dt_rho.pow(3.0)
            dt_rho_qu = dt_rho.pow(4.0)
            dt_sq = dt**2.0
            trans = (1.0 + dt_rho + 0.5 * dt_rho_sq) * self.mask00 + \
                (-0.5 * dt_rho_cu / dt) * self.mask01 + \
                ((0.5 * dt_rho_qu - dt_rho_cu) / dt_sq) * self.mask02 + \
                ((dt_rho + 1.0) * dt) * self.mask10 + \
                (1.0 + dt_rho - dt_rho_sq) * self.mask11 + \
                ((dt_rho_cu - 3.0 * dt_rho_sq) / dt) * self.mask12 + \
                (0.5 * dt_sq) * self.mask20 + \
                ((1.0 - 0.5 * dt_rho) * dt) * self.mask21 + \
                (1.0 - 2.0 * dt_rho + 0.5 * dt_rho_sq) * self.mask22
            return torch.exp(-dt_rho) * trans

    @pyro_method
    def stationary_covariance(self):
        """
        Compute the stationary state covariance. See Eqn. 3.26 in reference [2].

        :returns torch.Tensor: a 3-dimensional tensor of covariance matrices of shape
            (num_gps, state_dim, state_dim).
        """
        if self.nu == 0.5:
            sigmasq = self.kernel_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            return sigmasq
        elif self.nu == 1.5:
            sigmasq = self.kernel_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            rhosq = self.length_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            p_infinity = self.mask00 + (3.0 / rhosq) * self.mask11
            return sigmasq * p_infinity
        elif self.nu == 2.5:
            sigmasq = self.kernel_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            rhosq = self.length_scale.pow(2).unsqueeze(-1).unsqueeze(-1)
            p_infinity = self.mask00 + \
                (five_thirds / rhosq) * (self.mask11 - self.mask02 - self.mask20) + \
                (25.0 / rhosq.pow(2.0)) * self.mask22
            return sigmasq * p_infinity

    @pyro_method
    def process_covariance(self, A):
        """
        Given a transition matrix `A` computed with `transition_matrix`, compute the
        process covariance as described in Eqn. 3.11 in reference [2].

        :returns torch.Tensor: a batched covariance matrix of shape (num_gps, state_dim, state_dim)
        """
        assert A.shape[-3:] == (self.num_gps, self.state_dim, self.state_dim)
        p = self.stationary_covariance()
        q = p - torch.matmul(A.transpose(-1, -2), torch.matmul(p, A))
        return q

    @pyro_method
    def transition_matrix_and_covariance(self, dt):
        """
        Get the transition matrix and process covariance corresponding to a time interval `dt`.

        :param float dt: the time interval over which the GP latent space evolves.
        :returns tuple: (`transition_matrix`, `process_covariance`) both 3-dimensional tensors of
            shape (num_gps, state_dim, state_dim)
        """
        trans_matrix = self.transition_matrix(dt)
        process_covar = self.process_covariance(trans_matrix)
        return trans_matrix, process_covar
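
As a quick sanity check, here is a minimal usage sketch of the kernel above (the shapes and values are illustrative; it assumes ``torch`` and the ``MaternKernel`` class from this example are in scope):

import torch

# two independent Matern-3/2 GPs, each with a 2-dimensional latent state
kernel = MaternKernel(nu=1.5, num_gps=2)

# transition matrix and process covariance over a time step of dt=0.5
trans, covar = kernel.transition_matrix_and_covariance(dt=0.5)
assert trans.shape == (2, 2, 2)   # (num_gps, state_dim, state_dim)
assert covar.shape == (2, 2, 2)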
Example #27
    def __init__(self, sites, name='', *args, **kwargs):
        super().__init__(sites, name, *args, **kwargs)
        self.loc = PyroParam(self.init[self.mask], event_dim=1)
Example #28
class IndependentMaternGP(TimeSeriesModel):
    """
    A time series model in which each output dimension is modeled independently
    with a univariate Gaussian Process with a Matern kernel. The targets are assumed
    to be evenly spaced in time. Training and inference are logarithmic in the length
    of the time series T.

    :param float nu: The order of the Matern kernel; one of 0.5, 1.5 or 2.5.
    :param float dt: The time spacing between neighboring observations of the time series.
    :param int obs_dim: The dimension of the targets at each time step.
    :param torch.Tensor length_scale_init: optional initial values for the kernel length scale
        given as an ``obs_dim``-dimensional tensor
    :param torch.Tensor kernel_scale_init: optional initial values for the kernel scale
        given as an ``obs_dim``-dimensional tensor
    :param torch.Tensor obs_noise_scale_init: optional initial values for the observation noise scale
        given as an ``obs_dim``-dimensional tensor
    """
    def __init__(self, nu=1.5, dt=1.0, obs_dim=1,
                 length_scale_init=None, kernel_scale_init=None,
                 obs_noise_scale_init=None):
        self.nu = nu
        self.dt = dt
        self.obs_dim = obs_dim

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim,)

        super().__init__()

        self.kernel = MaternKernel(nu=nu, num_gps=obs_dim,
                                   length_scale_init=length_scale_init,
                                   kernel_scale_init=kernel_scale_init)

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)

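        # the observation matrix reads off the first component of each
        # latent state, i.e. the value of the GP function itself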
        obs_matrix = [1.0] + [0.0] * (self.kernel.state_dim - 1)
        self.register_buffer("obs_matrix", torch.tensor(obs_matrix).unsqueeze(-1))

    def _get_init_dist(self):
        return torch.distributions.MultivariateNormal(self.obs_matrix.new_zeros(self.obs_dim, self.kernel.state_dim),
                                                      self.kernel.stationary_covariance().squeeze(-3))

    def _get_obs_dist(self):
        return dist.Normal(self.obs_matrix.new_zeros(self.obs_dim, 1, 1),
                           self.obs_noise_scale.unsqueeze(-1).unsqueeze(-1)).to_event(1)

    def get_dist(self, duration=None):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
        to ``obs_dim``-many independent Matern GPs.

        :param int duration: Optional size of the time axis ``event_shape[0]``.
            This is required when sampling from homogeneous HMMs whose parameters
            are not expanded along the time axis.
        """
        trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=self.dt)
        trans_dist = MultivariateNormal(self.obs_matrix.new_zeros(self.obs_dim, 1, self.kernel.state_dim),
                                        process_covar.unsqueeze(-3))
        trans_matrix = trans_matrix.unsqueeze(-3)
        return dist.GaussianHMM(self._get_init_dist(), trans_matrix, trans_dist,
                                self.obs_matrix, self._get_obs_dist(), duration=duration)

    @pyro_method
    def log_prob(self, targets):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued ``targets`` at each time step
        :returns torch.Tensor: A 1-dimensional tensor of log probabilities of shape ``(obs_dim,)``
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
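        # transpose to (obs_dim, T, 1) so that each output dimension is scored
        # under its own univariate state space model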
        return self.get_dist().log_prob(targets.t().unsqueeze(-1))

    @torch.no_grad()
    def _filter(self, targets):
        """
        Return the filtering state for the associated state space model.
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self.get_dist().filter(targets.t().unsqueeze(-1))

    @torch.no_grad()
    def _forecast(self, dts, filtering_state, include_observation_noise=True):
        """
        Internal helper for forecasting.
        """
        assert dts.dim() == 1
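        # reshape dts to (S, 1, 1, 1) so it broadcasts against the
        # (num_gps, state_dim, state_dim) kernel matrices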
        dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        trans_matrix, process_covar = self.kernel.transition_matrix_and_covariance(dt=dts)
        trans_matrix = trans_matrix[..., 0:1]
        predicted_mean = torch.matmul(filtering_state.loc.unsqueeze(-2), trans_matrix).squeeze(-2)[..., 0]
        predicted_function_covar = torch.matmul(trans_matrix.transpose(-1, -2), torch.matmul(
                                                filtering_state.covariance_matrix, trans_matrix))[..., 0, 0] + \
            process_covar[..., 0, 0]

        if include_observation_noise:
            predicted_function_covar = predicted_function_covar + self.obs_noise_scale.pow(2.0)
        return predicted_mean, predicted_function_covar

    @pyro_method
    def forecast(self, targets, dts):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued targets at each time step. These
            represent the training data that are conditioned on for the purpose of making
            forecasts.
        :param torch.Tensor dts: A 1-dimensional tensor of times to forecast into the future,
            with zero corresponding to the time of the final target ``targets[-1]``.
        :returns torch.distributions.Normal: Returns a predictive Normal distribution with batch shape
            ``(S, obs_dim)``, where ``S`` is the size of ``dts``.
        """
        filtering_state = self._filter(targets)
        predicted_mean, predicted_covar = self._forecast(dts, filtering_state)
        return torch.distributions.Normal(predicted_mean, predicted_covar.sqrt())
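
For orientation, a minimal usage sketch (toy data; assumes ``torch`` and the ``IndependentMaternGP`` class above are in scope):

import torch

T, obs_dim = 100, 2
targets = torch.randn(T, obs_dim).cumsum(0)  # toy random-walk targets

gp = IndependentMaternGP(nu=1.5, dt=1.0, obs_dim=obs_dim)
log_probs = gp.log_prob(targets)             # shape (obs_dim,)

# forecast 1, 2 and 3 time units past the final observation
pred = gp.forecast(targets, torch.tensor([1.0, 2.0, 3.0]))
assert pred.mean.shape == (3, obs_dim)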
Example #29
class DependentMaternGP(TimeSeriesModel):
    """
    A time series model in which each output dimension is modeled as a univariate Gaussian Process
    with a Matern kernel. The different output dimensions become correlated because the Gaussian
    Processes are driven by a correlated Wiener process; see reference [1] for details.
    If, in addition, `linearly_coupled` is True, additional correlation is achieved through
    linear mixing as in :class:`LinearlyCoupledMaternGP`. The targets are assumed to be evenly
    spaced in time. Training and inference are logarithmic in the length of the time series T.

    :param float nu: The order of the Matern kernel; must be 1.5.
    :param float dt: The time spacing between neighboring observations of the time series.
    :param int obs_dim: The dimension of the targets at each time step.
    :param bool linearly_coupled: Whether to linearly mix the various Gaussian processes in the likelihood.
        Defaults to False.
    :param torch.Tensor length_scale_init: optional initial values for the kernel length scale
        given as an ``obs_dim``-dimensional tensor
    :param torch.Tensor obs_noise_scale_init: optional initial values for the observation noise scale
        given as an ``obs_dim``-dimensional tensor

    References
    [1] "Dependent Matern Processes for Multivariate Time Series," Alexander Vandenberg-Rodes, Babak Shahbaba.
    """
    def __init__(self, nu=1.5, dt=1.0, obs_dim=1, linearly_coupled=False,
                 length_scale_init=None, obs_noise_scale_init=None):

        if nu != 1.5:
            raise NotImplementedError("The only supported value of nu is 1.5")

        self.dt = dt
        self.obs_dim = obs_dim

        if obs_noise_scale_init is None:
            obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
        assert obs_noise_scale_init.shape == (obs_dim,)

        super().__init__()

        self.kernel = MaternKernel(nu=nu, num_gps=obs_dim,
                                   length_scale_init=length_scale_init)
        self.full_state_dim = self.kernel.state_dim * obs_dim

        # we demote self.kernel.kernel_scale from an nn.Parameter
        # since the relevant scales are now encoded in the Wiener noise matrix
        del self.kernel.kernel_scale
        self.kernel.register_buffer("kernel_scale", torch.ones(obs_dim))

        self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                         constraint=constraints.positive)
        self.wiener_noise_tril = PyroParam(torch.eye(obs_dim) +
                                           0.03 * torch.randn(obs_dim, obs_dim).tril(-1),
                                           constraint=constraints.lower_cholesky)

        if linearly_coupled:
            self.obs_matrix = nn.Parameter(0.3 * torch.randn(self.obs_dim, self.obs_dim))
        else:
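            # one-hot observation matrix: output i reads off the function value
            # (the first state component) of GP i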
            obs_matrix = torch.zeros(self.full_state_dim, obs_dim)
            for i in range(obs_dim):
                obs_matrix[self.kernel.state_dim * i, i] = 1.0
            self.register_buffer("obs_matrix", obs_matrix)

    def _get_obs_matrix(self):
        if self.obs_matrix.size(0) == self.obs_dim:
            # (num_gps, obs_dim) => (state_dim * num_gps, obs_dim)
            selector = [1.0] + [0.0] * (self.kernel.state_dim - 1)
            return self.obs_matrix.repeat_interleave(self.kernel.state_dim, dim=0) * \
                self.obs_matrix.new_tensor(selector).repeat(self.obs_dim).unsqueeze(-1)
        else:
            return self.obs_matrix

    def _get_init_dist(self, stationary_covariance):
        return torch.distributions.MultivariateNormal(self.obs_matrix.new_zeros(self.full_state_dim),
                                                      stationary_covariance)

    def _get_obs_dist(self):
        return dist.Normal(self.obs_matrix.new_zeros(self.obs_dim),
                           self.obs_noise_scale).to_event(1)

    def _get_wiener_cov(self):
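        # expand the (obs_dim, obs_dim) Wiener noise covariance blockwise to
        # the full (full_state_dim, full_state_dim) latent state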
        chol = self.wiener_noise_tril
        wiener_cov = torch.mm(chol, chol.t()).reshape(self.obs_dim, 1, self.obs_dim, 1)
        wiener_cov = wiener_cov * wiener_cov.new_ones(self.kernel.state_dim, 1, self.kernel.state_dim)
        return wiener_cov.reshape(self.full_state_dim, self.full_state_dim)

    def _stationary_covariance(self):
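        # blockwise analytic stationary covariance of the coupled Matern-3/2
        # SDE; see reference [1]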
        rho_j = math.sqrt(3.0) / self.kernel.length_scale.unsqueeze(-1).unsqueeze(-1)
        rho_i = rho_j.unsqueeze(-1)
        block = 2.0 * self.kernel.mask00 + \
            (rho_i - rho_j) * (self.kernel.mask01 - self.kernel.mask10) + \
            (2.0 * rho_i * rho_j) * self.kernel.mask11
        block = block / (rho_i + rho_j).pow(3.0)
        block = block.transpose(-2, -3).reshape(self.full_state_dim, self.full_state_dim)
        return self._get_wiener_cov() * block

    def _get_trans_dist(self, trans_matrix, stationary_covariance):
        covar = stationary_covariance - torch.matmul(trans_matrix.transpose(-1, -2),
                                                     torch.matmul(stationary_covariance, trans_matrix))
        return MultivariateNormal(covar.new_zeros(self.full_state_dim), covar)

    def _trans_matrix_distribution_stat_covar(self, dts):
        stationary_covariance = self._stationary_covariance()
        trans_matrix = self.kernel.transition_matrix(dt=dts)
        trans_matrix = block_diag_embed(trans_matrix)
        trans_dist = self._get_trans_dist(trans_matrix, stationary_covariance)
        return trans_matrix, trans_dist, stationary_covariance

    def get_dist(self, duration=None):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds to a :class:`DependentMaternGP`

        :param int duration: Optional size of the time axis ``event_shape[0]``.
            This is required when sampling from homogeneous HMMs whose parameters
            are not expanded along the time axis.
        """
        trans_matrix, trans_dist, stat_covar = self._trans_matrix_distribution_stat_covar(self.dt)
        return dist.GaussianHMM(self._get_init_dist(stat_covar), trans_matrix,
                                trans_dist, self._get_obs_matrix(), self._get_obs_dist(), duration=duration)

    @pyro_method
    def log_prob(self, targets):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued ``targets`` at each time step
        :returns torch.Tensor: A (scalar) log probability
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self.get_dist().log_prob(targets)

    @torch.no_grad()
    def _filter(self, targets):
        """
        Return the filtering state for the associated state space model.
        """
        assert targets.dim() == 2 and targets.size(-1) == self.obs_dim
        return self.get_dist().filter(targets)

    @torch.no_grad()
    def _forecast(self, dts, filtering_state, include_observation_noise=True):
        """
        Internal helper for forecasting.
        """
        assert dts.dim() == 1
        dts = dts.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        trans_matrix, trans_dist, _ = self._trans_matrix_distribution_stat_covar(dts)
        obs_matrix = self._get_obs_matrix()
        trans_obs = torch.matmul(trans_matrix, obs_matrix)

        predicted_mean = torch.matmul(filtering_state.loc.unsqueeze(-2), trans_obs).squeeze(-2)
        predicted_function_covar = torch.matmul(trans_obs.transpose(-1, -2),
                                                torch.matmul(filtering_state.covariance_matrix, trans_obs)) + \
            torch.matmul(obs_matrix.t(), torch.matmul(trans_dist.covariance_matrix, obs_matrix))

        if include_observation_noise:
            # add the observation noise to the diagonal of the covariance matrix
            predicted_function_covar = predicted_function_covar + \
                self.obs_noise_scale.pow(2.0).diag_embed()

        return predicted_mean, predicted_function_covar

    @pyro_method
    def forecast(self, targets, dts):
        """
        :param torch.Tensor targets: A 2-dimensional tensor of real-valued targets
            of shape ``(T, obs_dim)``, where ``T`` is the length of the time series and ``obs_dim``
            is the dimension of the real-valued targets at each time step. These
            represent the training data that are conditioned on for the purpose of making
            forecasts.
        :param torch.Tensor dts: A 1-dimensional tensor of times to forecast into the future,
            with zero corresponding to the time of the final target ``targets[-1]``.
        :returns torch.distributions.MultivariateNormal: Returns a predictive MultivariateNormal
            distribution with batch shape ``(S,)`` and event shape ``(obs_dim,)``, where ``S`` is the size of ``dts``.
        """
        filtering_state = self._filter(targets)
        predicted_mean, predicted_covar = self._forecast(dts, filtering_state)
        return MultivariateNormal(predicted_mean, predicted_covar)
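
A corresponding sketch for this class (toy data; assumes ``torch`` and the ``DependentMaternGP`` class above are in scope):

import torch

T, obs_dim = 50, 3
targets = torch.randn(T, obs_dim).cumsum(0)  # toy random-walk targets

gp = DependentMaternGP(nu=1.5, dt=1.0, obs_dim=obs_dim)
log_prob = gp.log_prob(targets)              # scalar

# joint (correlated) forecasts 0.5 and 1.0 time units ahead
pred = gp.forecast(targets, torch.tensor([0.5, 1.0]))
assert pred.mean.shape == (2, obs_dim)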