コード例 #1
0
ファイル: hmm.py プロジェクト: ordabayevy/funsor
    def filter(self, value):
        """
        Compute the posterior over the final latent state given observations.

        The transition and observation factors are combined and eliminated
        along the ``"time"`` dimension, the initial factor is added, and the
        ``"class"`` and ``"state"`` inputs are then reduced with
        ``ops.logaddexp``. When ``self.exact`` is true this uses
        ``naive_sequential_sum_product`` under the ``eager`` interpretation;
        otherwise it uses ``sequential_sum_product`` under
        ``moment_matching``.

        :param ~torch.Tensor value: A sequence of observations.
        :return: A posterior distribution over latent states at the final time
            step, represented as a pair ``(cat, mvn)``, where ``cat`` is a
            :class:`~pyro.distributions.Categorical` distribution over mixture
            components and ``mvn`` is a
            :class:`~pyro.distributions.MultivariateNormal` with rightmost
            batch dimension ranging over mixture components. This can then be
            used to initialize a sequential Pyro model for prediction.
        :rtype: tuple
        """
        # Number of batch dims requested from the funsor -> distribution conversion.
        batch_rank = max(len(self.batch_shape), value.dim() - 2)
        time_var = Variable("time", Bint[self.event_shape[0]])
        observations = tensor_to_funsor(value, ("time", ), 1)

        # Pick the elimination routine and the interpretation together.
        if self.exact:
            sum_product, interp = naive_sequential_sum_product, eager
        else:
            sum_product, interp = sequential_sum_product, moment_matching

        # Maps each current-step input name to its next-step counterpart.
        step_renaming = {
            "class": "class(time=1)",
            "state": "state(time=1)",
        }
        eliminated = frozenset(["class", "state"])

        with interpretation(interp):
            joint = self._trans + self._obs(value=observations)
            joint = sum_product(ops.logaddexp, ops.add, joint, time_var,
                                step_renaming)
            joint += self._init
            posterior = joint.reduce(ops.logaddexp, eliminated)

        mixture, components = funsor_to_cat_and_mvn(
            posterior, batch_rank, ("class(time=1)", ))
        mixture = mixture.expand(self.batch_shape)
        components = components.expand(
            self.batch_shape + mixture.logits.shape[-1:])
        return mixture, components
コード例 #2
0
def test_funsor_to_cat_and_mvn(batch_shape, event_shape, int_size, real_size):
    """
    Check that ``funsor_to_cat_and_mvn`` recovers the source distributions.

    A funsor is built as the sum of categorical logits and the log-density of
    a random multivariate normal (evaluated at the free variable ``"value"``).
    Converting it back must give a ``Categorical`` and a
    ``MultivariateNormal`` whose batch shapes and parameters match the
    originals to within ``atol=1e-4``.
    """
    full_shape = batch_shape + event_shape + (int_size, )

    # Draw the logits first, then the MVN, so the RNG stream is consumed in
    # a fixed order.
    logits = torch.randn(full_shape)
    expected_cat = dist.Categorical(logits=logits)
    expected_mvn = random_mvn(full_shape, real_size)

    # One named input per event dimension, plus the mixture component axis.
    event_dims = tuple("abc"[:len(event_shape)]) + ("component", )
    ndims = len(expected_cat.batch_shape)

    cat_term = tensor_to_funsor(logits, event_dims)
    mvn_term = dist_to_funsor(expected_mvn, event_dims)(value="value")
    funsor_ = cat_term + mvn_term
    assert isinstance(funsor_, Funsor)

    actual_cat, actual_mvn = funsor_to_cat_and_mvn(funsor_, ndims, event_dims)

    # Types and batch shapes.
    assert isinstance(actual_cat, dist.Categorical)
    assert isinstance(actual_mvn, dist.MultivariateNormal)
    assert actual_cat.batch_shape == expected_cat.batch_shape
    assert actual_mvn.batch_shape == expected_mvn.batch_shape

    # Parameters.
    assert_close(actual_cat.logits, expected_cat.logits, atol=1e-4, rtol=None)
    for param in ("loc", "precision_matrix"):
        assert_close(getattr(actual_mvn, param),
                     getattr(expected_mvn, param),
                     atol=1e-4,
                     rtol=None)
コード例 #3
0
    def filter_and_predict(self, data, smoothing=False):
        """
        Run a forward filtering pass over ``data`` while scoring one-step-ahead
        predictions, and optionally run a backward smoothing pass.

        For every time step ``t`` a discrete variable ``s_t`` and a real
        vector variable ``x_t`` are introduced, the transition terms are added
        to a running log-density, and (for ``t > 0``) the previous step's
        variables are reduced out with ``ops.logaddexp``. Before the
        observation at step ``t`` is added, the running log-density is used to
        form the one-step predictive distribution of ``y`` and its
        log-likelihood at the observed value.

        :param data: Observation sequence. It is iterated along its first
            dimension (one observation per time step) and sliced as
            ``data[1:, :]``; presumably a 2-D tensor of shape ``(T, ydim)``
            -- confirm against callers.
        :param bool smoothing: If true, also keep the per-step filtering
            log-densities and run the backward recursion to obtain smoothed
            estimates.
        :return: ``(predictive_mse, test_LLs)`` when ``smoothing`` is false,
            otherwise ``(predictive_mse, test_LLs, predictive_means,
            predictive_vars, smoothing_means, smoothing_probs)``. The
            predictive quantities only cover steps ``t > 0`` because no
            prediction is made at ``t == 0``.
        :rtype: tuple
        """
        # Only trans_probs, x_trans_dist and y_dist are used below; the other
        # unpacked values are not referenced in this method.
        trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = self.get_tensors_and_dists(
        )

        # Running (unnormalized) log-density over the current latent variables.
        log_prob = funsor.Number(0.)

        # The "previous" discrete state before t == 0 is pinned to index 0.
        s_vars = {
            -1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)
        }
        x_vars = {-1: None}

        predictive_x_dists, predictive_y_dists, filtering_dists = [], [], []
        test_LLs = []

        for t, y in enumerate(data):
            s_vars[t] = funsor.Variable(f's_{t}',
                                        funsor.bint(self.num_components))
            x_vars[t] = funsor.Variable(f'x_{t}',
                                        funsor.reals(self.hidden_dim))

            # discrete transition term from s_{t-1} to s_t
            log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]),
                                         value=s_vars[t])

            # continuous term: initial distribution at t == 0, otherwise the
            # transition from x_{t-1} to x_t given s_t
            if t == 0:
                log_prob += self.x_init_mvn(value=x_vars[t])
            else:
                log_prob += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t - 1],
                                         y=x_vars[t])

            # reduce out the previous step's latent variables so that
            # log_prob only depends on s_t and x_t
            if t > 0:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([s_vars[t - 1].name, x_vars[t - 1].name]))

            # do 1-step prediction and compute test LL
            if t > 0:
                # stored before y_t is added; reused by the backward recursion
                predictive_x_dists.append(log_prob)
                # subtract the total log-mass to normalize
                _log_prob = log_prob - log_prob.reduce(ops.logaddexp)
                predictive_y_dist = y_dist(s=s_vars[t],
                                           x=x_vars[t]) + _log_prob
                test_LLs.append(
                    predictive_y_dist(y=y).reduce(ops.logaddexp).data.item())
                predictive_y_dist = predictive_y_dist.reduce(
                    ops.logaddexp, frozenset([f"x_{t}", f"s_{t}"]))
                predictive_y_dists.append(
                    funsor_to_mvn(predictive_y_dist, 0, ()))

            # add the observation term for the actual y_t
            log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

            # save filtering dists for forward-backward smoothing
            if smoothing:
                filtering_dists.append(log_prob)

        # do the backward recursion using previously computed ingredients
        if smoothing:
            # seed the backward recursion with the filtering distribution at t=T
            smoothing_dists = [filtering_dists[-1]]
            T = data.size(0)

            # rebuild the per-step variables for steps 0..T-1 (no -1 entry)
            s_vars = {
                t: funsor.Variable(f's_{t}', funsor.bint(self.num_components))
                for t in range(T)
            }
            x_vars = {
                t: funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))
                for t in range(T)
            }

            # do the backward recursion.
            # let p[t|t-1] be the predictive distribution at time step t.
            # let p[t|t] be the filtering distribution at time step t.
            # let f[t] denote the prior (transition) density at time step t.
            # then the smoothing distribution p[t|T] at time step t is
            # given by the following recursion.
            # p[t-1|T] = p[t-1|t-1] <p[t|T] f[t] / p[t|t-1]>
            # where <...> denotes integration of the latent variables at time step t.
            for t in reversed(range(T - 1)):
                # predictive_x_dists has no entry for step 0, so index t holds
                # the predictive distribution for step t + 1
                integral = smoothing_dists[-1] - predictive_x_dists[t]
                integral += dist.Categorical(trans_probs(s=s_vars[t]),
                                             value=s_vars[t + 1])
                integral += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t],
                                         y=x_vars[t + 1])
                integral = integral.reduce(
                    ops.logaddexp,
                    frozenset([s_vars[t + 1].name, x_vars[t + 1].name]))
                # appended in reverse time order (last step first)
                smoothing_dists.append(filtering_dists[t] + integral)

        # compute predictive test MSE and predictive variances
        # NOTE(review): with fewer than two time steps predictive_y_dists is
        # empty and torch.stack would presumably fail -- confirm callers never
        # pass such a short sequence.
        predictive_means = torch.stack([d.mean for d in predictive_y_dists
                                        ])  # T-1 ydim
        predictive_vars = torch.stack([
            d.covariance_matrix.diagonal(dim1=-1, dim2=-2)
            for d in predictive_y_dists
        ])
        # squared error against the observations at steps 1.., averaged over
        # the last dimension
        predictive_mse = (predictive_means - data[1:, :]).pow(2.0).mean(-1)

        if smoothing:
            # compute smoothed mean function
            # reversed() restores forward time order, so the enumerate index t
            # matches the step whose discrete variable is named s_t
            smoothing_dists = [
                funsor_to_cat_and_mvn(d, 0, (f"s_{t}", ))
                for t, d in enumerate(reversed(smoothing_dists))
            ]
            means = torch.stack([d[1].mean
                                 for d in smoothing_dists])  # T 2 xdim
            # map latent means through the observation matrix
            means = torch.matmul(means.unsqueeze(-2),
                                 self.observation_matrix).squeeze(
                                     -2)  # T 2 ydim

            # exponentiate the logits and renormalize over the component axis
            probs = torch.stack([d[0].logits for d in smoothing_dists]).exp()
            probs = probs / probs.sum(-1, keepdim=True)  # T 2

            # probability-weighted average of the per-component means
            smoothing_means = (probs.unsqueeze(-1) * means).sum(-2)  # T ydim
            # probability of component index 1 at each step; the shape
            # comments above assume two components -- TODO confirm
            smoothing_probs = probs[:, 1]

            return predictive_mse, torch.tensor(np.array(test_LLs)), predictive_means, predictive_vars, \
                smoothing_means, smoothing_probs
        else:
            return predictive_mse, torch.tensor(np.array(test_LLs))