Beispiel #1
0
 def _get_dist(self):
     """
     Construct the :class:`~pyro.distributions.GaussianHMM` distribution
     equivalent to this :class:`DependentMaternGP`.
     """
     trans_matrix, trans_dist, stat_covar = \
         self._trans_matrix_distribution_stat_covar(self.dt)
     init_dist = self._get_init_dist(stat_covar)
     return dist.GaussianHMM(init_dist, trans_matrix, trans_dist,
                             self._get_obs_matrix(), self._get_obs_dist())
Beispiel #2
0
def test_gaussian_hmm_log_prob(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim,
                               obs_dim):
    """Compare GaussianHMM.log_prob against pyro's reference implementation."""
    # Build one shared set of parameters for both implementations.
    args = (
        random_mvn(init_shape, hidden_dim),
        torch.randn(trans_mat_shape + (hidden_dim, hidden_dim)),
        random_mvn(trans_mvn_shape, hidden_dim),
        torch.randn(obs_mat_shape + (hidden_dim, obs_dim)),
        random_mvn(obs_mvn_shape, obs_dim),
    )
    actual_dist = GaussianHMM(*args)
    expected_dist = dist.GaussianHMM(*args)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    # Sample data at the fully broadcast shape.
    shape = broadcast_shape(init_shape + (1, ), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    data = args[-1].expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)
Beispiel #3
0
    def __call__(self, name, fn, obs):
        """
        Reparameterize each component of ``fn`` as conditionally Gaussian,
        then assemble them into a :class:`~pyro.distributions.GaussianHMM`.
        """
        init_dist = fn.initial_dist
        trans_dist = fn.transition_dist
        obs_dist = fn.observation_dist

        # Optionally reparameterize the initial distribution.
        if self.init is not None:
            init_dist, _ = self.init("{}_init".format(name), init_dist, None)

        # Optionally reparameterize the transition distribution.
        if self.trans is not None:
            trans_dist, _ = self.trans("{}_trans".format(name),
                                       trans_dist.to_event(1), None)
            trans_dist = trans_dist.to_event(-1)

        # Optionally reparameterize the observation distribution.
        if self.obs is not None:
            obs_dist, obs = self.obs("{}_obs".format(name),
                                     obs_dist.to_event(1), obs)
            obs_dist = obs_dist.to_event(-1)

        # Assemble the conditionally Gaussian HMM.
        hmm = dist.GaussianHMM(init_dist, fn.transition_matrix, trans_dist,
                               fn.observation_matrix, obs_dist)

        # Apply any observation transforms.
        if fn.transforms:
            hmm = dist.TransformedDistribution(hmm, fn.transforms)

        return hmm, obs
Beispiel #4
0
    def __call__(self, name, fn, obs):
        """
        Reparameterize a ``LinearHMM`` (or an ``IndependentHMM`` of
        ``LinearHMM`` s) as a conditionally Gaussian
        :class:`~pyro.distributions.GaussianHMM`.

        :param str name: Base name for auxiliary sample sites
            (``{name}_init``, ``{name}_trans``, ``{name}_obs``).
        :param fn: Distribution to reparameterize; must unwrap to a
            ``LinearHMM`` or ``IndependentHMM`` with ``duration`` set.
        :param obs: Optional observed value, passed through (and possibly
            updated) alongside the observation distribution.
        :returns: A ``(hmm, obs)`` pair.
        :raises ValueError: If the targeted ``LinearHMM`` has no ``duration``.
        """
        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
        if fn.duration is None:
            raise ValueError(
                "LinearHMMReparam requires duration to be specified "
                "on targeted LinearHMM distributions")

        # Unwrap IndependentHMM by recursing on its base distribution,
        # reshaping obs (transpose + trailing singleton dim) to match it.
        if isinstance(fn, dist.IndependentHMM):
            if obs is not None:
                obs = obs.transpose(-1, -2).unsqueeze(-1)
            hmm, obs = self(name, fn.base_dist.to_event(1), obs)
            hmm = dist.IndependentHMM(hmm.to_event(-1))
            if obs is not None:
                # Undo the reshaping applied above.
                obs = obs.squeeze(-1).transpose(-1, -2)
            return hmm, obs

        # Reparameterize the initial distribution as conditionally Gaussian.
        init_dist = fn.initial_dist
        if self.init is not None:
            init_dist, _ = self.init("{}_init".format(name),
                                     self._wrap(init_dist, event_dim - 1),
                                     None)
            init_dist = init_dist.to_event(1 - init_dist.event_dim)

        # Reparameterize the transition distribution as conditionally Gaussian.
        trans_dist = fn.transition_dist
        if self.trans is not None:
            # Expand homogeneous parameters along the time axis if needed.
            if trans_dist.batch_shape[-1] != fn.duration:
                trans_dist = trans_dist.expand(trans_dist.batch_shape[:-1] +
                                               (fn.duration, ))
            trans_dist, _ = self.trans("{}_trans".format(name),
                                       self._wrap(trans_dist, event_dim), None)
            trans_dist = trans_dist.to_event(1 - trans_dist.event_dim)

        # Reparameterize the observation distribution as conditionally Gaussian.
        obs_dist = fn.observation_dist
        if self.obs is not None:
            # Expand homogeneous parameters along the time axis if needed.
            if obs_dist.batch_shape[-1] != fn.duration:
                obs_dist = obs_dist.expand(obs_dist.batch_shape[:-1] +
                                           (fn.duration, ))
            obs_dist, obs = self.obs("{}_obs".format(name),
                                     self._wrap(obs_dist, event_dim), obs)
            obs_dist = obs_dist.to_event(1 - obs_dist.event_dim)

        # Reparameterize the entire HMM as conditionally Gaussian.
        hmm = dist.GaussianHMM(init_dist,
                               fn.transition_matrix,
                               trans_dist,
                               fn.observation_matrix,
                               obs_dist,
                               duration=fn.duration)
        hmm = self._wrap(hmm, event_dim)

        # Apply any observation transforms.
        if fn.transforms:
            hmm = dist.TransformedDistribution(hmm, fn.transforms)

        return hmm, obs
Beispiel #5
0
def test_gaussian_hmm_shape(diag, init_shape, trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    """Check batch/event shapes, log_prob, expand, and filter of GaussianHMM."""
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    if diag:
        # Replace the MVN observation noise with an equivalent diagonal Normal.
        diag_scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, diag_scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    shape = broadcast_shape(init_shape + (1,),
                            trans_mat_shape,
                            trans_mvn_shape,
                            obs_mat_shape,
                            obs_mvn_shape)
    expected_batch_shape = shape[:-1]
    expected_event_shape = shape[-1:] + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()
    assert d.log_prob(data).shape == expected_batch_shape
    check_expand(d, data)

    # Filtering should return an MVN over the hidden state.
    final = d.filter(data)
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim,)
Beispiel #6
0
    def _get_dist(self):
        """
        Get the `GaussianHMM` distribution that corresponds to `GenericLGSSMWithGPNoiseModel`.
        """
        gp_trans_matrix, gp_process_covar = \
            self.kernel.transition_matrix_and_covariance(dt=self.dt)
        gp_dim = self.full_gp_state_dim

        # Transition noise covariance: block-diagonal GP process noise in the
        # upper-left block, learned isotropic noise in the lower-right block.
        trans_covar = self.z_trans_matrix.new_zeros(self.full_state_dim,
                                                    self.full_state_dim)
        trans_covar[:gp_dim, :gp_dim] = block_diag_embed(gp_process_covar)
        eye = torch.eye(self.state_dim,
                        device=trans_covar.device,
                        dtype=trans_covar.dtype)
        trans_covar[gp_dim:, gp_dim:] = \
            self.log_trans_noise_scale_sq.exp() * eye
        trans_dist = MultivariateNormal(
            trans_covar.new_zeros(self.full_state_dim), trans_covar)

        # Transition matrix: block-diagonal GP dynamics plus z dynamics.
        full_trans_mat = trans_covar.new_zeros(self.full_state_dim,
                                               self.full_state_dim)
        full_trans_mat[:gp_dim, :gp_dim] = block_diag_embed(gp_trans_matrix)
        full_trans_mat[gp_dim:, gp_dim:] = self.z_trans_matrix

        return dist.GaussianHMM(self._get_init_dist(), full_trans_mat,
                                trans_dist, self._get_obs_matrix(),
                                self._get_obs_dist())
Beispiel #7
0
    def get_dist(self, duration=None):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
        to :class:`GenericLGSSMWithGPNoiseModel`.

        :param int duration: Optional size of the time axis ``event_shape[0]``.
            This is required when sampling from homogeneous HMMs whose parameters
            are not expanded along the time axis.
        """
        gp_trans_matrix, gp_process_covar = \
            self.kernel.transition_matrix_and_covariance(dt=self.dt)
        gp_dim = self.full_gp_state_dim

        # Transition noise covariance: block-diagonal GP process noise in the
        # upper-left block, learned diagonal noise in the lower-right block.
        trans_covar = self.z_trans_matrix.new_zeros(self.full_state_dim,
                                                    self.full_state_dim)
        trans_covar[:gp_dim, :gp_dim] = block_diag_embed(gp_process_covar)
        trans_covar[gp_dim:, gp_dim:] = self.trans_noise_scale_sq.diag_embed()
        trans_dist = MultivariateNormal(
            trans_covar.new_zeros(self.full_state_dim), trans_covar)

        # Transition matrix: block-diagonal GP dynamics plus z dynamics.
        full_trans_mat = trans_covar.new_zeros(self.full_state_dim,
                                               self.full_state_dim)
        full_trans_mat[:gp_dim, :gp_dim] = block_diag_embed(gp_trans_matrix)
        full_trans_mat[gp_dim:, gp_dim:] = self.z_trans_matrix

        return dist.GaussianHMM(self._get_init_dist(),
                                full_trans_mat,
                                trans_dist,
                                self._get_obs_matrix(),
                                self._get_obs_dist(),
                                duration=duration)
Beispiel #8
0
    def _get_dist(self):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that corresponds
        to :class:`GenericLGSSMWithGPNoiseModel`.
        """
        gp_trans_matrix, gp_process_covar = \
            self.kernel.transition_matrix_and_covariance(dt=self.dt)
        gp_dim = self.full_gp_state_dim

        # Transition noise covariance: GP process noise in the upper-left
        # block, learned diagonal noise in the lower-right block.
        trans_covar = self.z_trans_matrix.new_zeros(self.full_state_dim,
                                                    self.full_state_dim)
        trans_covar[:gp_dim, :gp_dim] = block_diag_embed(gp_process_covar)
        trans_covar[gp_dim:, gp_dim:] = self.trans_noise_scale_sq.diag_embed()
        trans_dist = MultivariateNormal(
            trans_covar.new_zeros(self.full_state_dim), trans_covar)

        # Transition matrix: GP dynamics and z dynamics on the diagonal.
        full_trans_mat = trans_covar.new_zeros(self.full_state_dim,
                                               self.full_state_dim)
        full_trans_mat[:gp_dim, :gp_dim] = block_diag_embed(gp_trans_matrix)
        full_trans_mat[gp_dim:, gp_dim:] = self.z_trans_matrix

        return dist.GaussianHMM(self._get_init_dist(), full_trans_mat,
                                trans_dist, self._get_obs_matrix(),
                                self._get_obs_dist())
Beispiel #9
0
 def _get_dist(self):
     """
     Construct the :class:`~pyro.distributions.GaussianHMM` corresponding to
     :class:`GenericLGSSM`.
     """
     return dist.GaussianHMM(self._get_init_dist(),
                             self.trans_matrix,
                             self._get_trans_dist(),
                             self.obs_matrix,
                             self._get_obs_dist())
Beispiel #10
0
    def _forward_pyro(self, features, trip_counts):
        """
        Perform exact inference over the observed hours, then delegate any
        remaining hours to ``self._forward_pyro_forecast``.
        """
        total_hours = len(features)
        observed_hours, num_origins, num_destins = trip_counts.shape
        assert observed_hours <= total_hours
        assert num_origins == self.num_stations
        assert num_destins == self.num_stations
        time_plate = pyro.plate("time", observed_hours, dim=-3)
        origins_plate = pyro.plate("origins", num_origins, dim=-2)
        destins_plate = pyro.plate("destins", num_destins, dim=-1)

        # Exact inference over the observed portion of the time series.
        hmm = dist.GaussianHMM(*self._dynamics(features[:observed_hours]))
        gate_rate = pyro.sample("gate_rate", hmm)
        gate, rate = self._unpack_gate_rate(gate_rate, event_dim=2)
        with time_plate, origins_plate, destins_plate:
            pyro.sample("trip_count",
                        dist.ZeroInflatedPoisson(gate, rate),
                        obs=trip_counts)

        # Forecast forward only when there are unobserved hours left.
        if total_hours <= observed_hours:
            return None
        state_dist = hmm.filter(gate_rate)
        return self._forward_pyro_forecast(features,
                                           trip_counts,
                                           origins_plate,
                                           destins_plate,
                                           state_dist=state_dist)
Beispiel #11
0
    def get_dist(self, duration=None):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that
        corresponds to a :class:`DependentMaternGP`.

        :param int duration: Optional size of the time axis ``event_shape[0]``.
            This is required when sampling from homogeneous HMMs whose parameters
            are not expanded along the time axis.
        """
        trans_matrix, trans_dist, stat_covar = \
            self._trans_matrix_distribution_stat_covar(self.dt)
        return dist.GaussianHMM(self._get_init_dist(stat_covar),
                                trans_matrix,
                                trans_dist,
                                self._get_obs_matrix(),
                                self._get_obs_dist(),
                                duration=duration)
Beispiel #12
0
 def _get_dist(self):
     """
     Construct the :class:`~pyro.distributions.GaussianHMM` corresponding to
     ``obs_dim``-many independent Matern GPs.
     """
     trans_matrix, process_covar = \
         self.kernel.transition_matrix_and_covariance(dt=self.dt)
     # Insert a singleton dim at -3 so parameters broadcast across the
     # obs_dim-many independent GPs.
     loc = self.obs_matrix.new_zeros(self.obs_dim, 1, self.kernel.state_dim)
     trans_dist = MultivariateNormal(loc, process_covar.unsqueeze(-3))
     return dist.GaussianHMM(self._get_init_dist(),
                             trans_matrix.unsqueeze(-3),
                             trans_dist,
                             self.obs_matrix,
                             self._get_obs_dist())
Beispiel #13
0
def test_gaussian_hmm_log_prob(diag, sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    """
    Validate GaussianHMM.log_prob against a hand-unrolled joint Gaussian
    built via mvn_to_gaussian / matrix_and_mvn_to_gaussian and combined
    with gaussian_tensordot().
    """
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        # Exercise the diagonal-noise code path with an equivalent Normal.
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    if diag:
        # Recover an MVN form of the observation noise for the reference below.
        obs_mvn = dist.MultivariateNormal(obs_dist.base_dist.loc,
                                          scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)

    # Pad each time step's factor out to the full unrolled event size.
    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat([torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
                     [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    # Marginalize out the hidden states by contracting init, trans, and obs.
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)
Beispiel #14
0
 def _get_dist(self):
     """
     Construct the :class:`~pyro.distributions.GaussianHMM` corresponding to
     a :class:`LinearlyCoupledMaternGP`.
     """
     trans_matrix, process_covar = \
         self.kernel.transition_matrix_and_covariance(dt=self.dt)
     # Stack the per-GP dynamics into one block-diagonal state space.
     loc = self.A.new_zeros(self.full_state_dim)
     trans_dist = MultivariateNormal(loc, block_diag_embed(process_covar))
     return dist.GaussianHMM(self._get_init_dist(),
                             block_diag_embed(trans_matrix),
                             trans_dist,
                             self._get_obs_matrix(),
                             self._get_obs_dist())
Beispiel #15
0
    def get_dist(self, duration=None):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that
        corresponds to :class:`GenericLGSSM`.

        :param int duration: Optional size of the time axis ``event_shape[0]``.
            This is required when sampling from homogeneous HMMs whose parameters
            are not expanded along the time axis.
        """
        return dist.GaussianHMM(self._get_init_dist(), self.trans_matrix,
                                self._get_trans_dist(), self.obs_matrix,
                                self._get_obs_dist(), duration=duration)
Beispiel #16
0
    def forward(self, features, trip_counts):
        """
        Generative model: exact GaussianHMM inference over the observed hours,
        followed by sequential hour-by-hour forecasting of the remaining hours.

        :param features: Per-hour feature sequence of length ``total_hours``
            (may extend past the observed window).
        :param trip_counts: Observed counts of shape
            ``(observed_hours, num_origins, num_destins)``.
        :returns: A list of forecasted trip-count samples, one per forecast
            hour (empty when there is nothing to forecast).
        """
        pyro.module("model", self)
        total_hours = len(features)
        observed_hours, num_origins, num_destins = trip_counts.shape
        assert observed_hours <= total_hours
        assert num_origins == self.num_stations
        assert num_destins == self.num_stations
        time_plate = pyro.plate("time", observed_hours, dim=-3)
        origins_plate = pyro.plate("origins", num_origins, dim=-2)
        destins_plate = pyro.plate("destins", num_destins, dim=-1)

        # The first half of the model performs exact inference over
        # the observed portion of the time series.
        hmm = dist.GaussianHMM(*self._dynamics(features[:observed_hours]))
        gate_rate = pyro.sample("gate_rate", hmm)
        gate, rate = self._unpack_gate_rate(gate_rate, event_dim=2)
        with time_plate, origins_plate, destins_plate:
            pyro.sample("trip_count",
                        dist.ZeroInflatedPoisson(gate, rate),
                        obs=trip_counts)

        # The second half of the model forecasts forward.
        forecast = []
        forecast_hours = total_hours - observed_hours
        if forecast_hours > 0:
            # Dynamics for the forecast window (initial dist is unused here).
            _, trans_matrix, trans_dist, obs_matrix, obs_dist = \
                self._dynamics(features[observed_hours:])
        state = None
        for t in range(forecast_hours):
            if state is None:  # on first step
                # Start from the filtered latent state of the observed window.
                state_dist = hmm.filter(gate_rate)
            else:
                # Propagate the previous state through the linear dynamics.
                loc = vm(state, trans_matrix) + trans_dist.loc
                scale_tril = trans_dist.scale_tril
                state_dist = dist.MultivariateNormal(loc,
                                                     scale_tril=scale_tril)
            state = pyro.sample("state_{}".format(t), state_dist)

            # Emit gate/rate through the observation model for hour t.
            loc = vm(state, obs_matrix) + obs_dist.base_dist.loc[..., t, :]
            scale = obs_dist.base_dist.scale[..., t, :]
            gate_rate = pyro.sample("gate_rate_{}".format(t),
                                    dist.Normal(loc, scale).to_event(1))
            gate, rate = self._unpack_gate_rate(gate_rate, event_dim=1)

            with origins_plate, destins_plate:
                forecast.append(
                    pyro.sample("trip_count_{}".format(t),
                                dist.ZeroInflatedPoisson(gate, rate)))
        return forecast
Beispiel #17
0
    def get_dist(self, duration=None):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that
        corresponds to ``obs_dim``-many independent Matern GPs.

        :param int duration: Optional size of the time axis ``event_shape[0]``.
            This is required when sampling from homogeneous HMMs whose parameters
            are not expanded along the time axis.
        """
        trans_matrix, process_covar = \
            self.kernel.transition_matrix_and_covariance(dt=self.dt)
        # Insert a singleton dim at -3 so parameters broadcast across the
        # obs_dim-many independent GPs.
        loc = self.obs_matrix.new_zeros(self.obs_dim, 1, self.kernel.state_dim)
        trans_dist = MultivariateNormal(loc, process_covar.unsqueeze(-3))
        return dist.GaussianHMM(self._get_init_dist(),
                                trans_matrix.unsqueeze(-3),
                                trans_dist,
                                self.obs_matrix,
                                self._get_obs_dist(),
                                duration=duration)
Beispiel #18
0
    def get_dist(self, duration=None):
        """
        Get the :class:`~pyro.distributions.GaussianHMM` distribution that
        corresponds to a :class:`LinearlyCoupledMaternGP`.

        :param int duration: Optional size of the time axis ``event_shape[0]``.
            This is required when sampling from homogeneous HMMs whose parameters
            are not expanded along the time axis.
        """
        trans_matrix, process_covar = \
            self.kernel.transition_matrix_and_covariance(dt=self.dt)
        # Stack the per-GP dynamics into one block-diagonal state space.
        loc = self.A.new_zeros(self.full_state_dim)
        trans_dist = MultivariateNormal(loc, block_diag_embed(process_covar))
        return dist.GaussianHMM(self._get_init_dist(),
                                block_diag_embed(trans_matrix),
                                trans_dist,
                                self._get_obs_matrix(),
                                self._get_obs_dist(),
                                duration=duration)
Beispiel #19
0
 def model(self, zero_data, covariates):
     """
     Model the series as univariate GaussianHMM noise around an all-zero
     periodically-repeated prediction.
     """
     duration = zero_data.size(-2)
     # One-dimensional latent state with identity dynamics and observation.
     init_dist = dist.Normal(0, 1).expand([1]).to_event(1)
     trans_matrix = torch.tensor([[1.]])
     trans_dist = dist.Normal(0, 1).expand([1]).to_event(1)
     obs_matrix = torch.tensor([[1.]])
     obs_dist = dist.Normal(0, 2).expand([1]).to_event(1)
     pre_dist = dist.GaussianHMM(init_dist, trans_matrix, trans_dist,
                                 obs_matrix, obs_dist, duration=duration)
     prediction = periodic_repeat(torch.zeros(1, 1), duration,
                                  dim=-1).unsqueeze(-1)
     self.predict(pre_dist, prediction)
Beispiel #20
0
def test_independent_hmm_shape(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim,
                               obs_dim):
    """Check shapes of an IndependentHMM wrapping a batched GaussianHMM."""
    duration = 6

    def lift(shape):
        # Append (obs_dim, time) in place of the trailing time size; empty
        # (homogeneous) shapes get the full duration as their time size.
        return shape[:-1] + (obs_dim, shape[-1] if shape else duration)

    init_dist = random_mvn(init_shape + (obs_dim, ), hidden_dim)
    trans_mat = torch.randn(lift(trans_mat_shape) + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(lift(trans_mvn_shape), hidden_dim)
    obs_mat = torch.randn(lift(obs_mat_shape) + (hidden_dim, 1))
    obs_dist = random_mvn(lift(obs_mvn_shape), 1)
    base = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                            obs_dist, duration=duration)
    d = dist.IndependentHMM(base)

    shape = broadcast_shape(init_shape + (duration, ), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim, )
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    data = torch.randn(shape + (obs_dim, ))
    assert data.shape == d.shape()
    assert d.log_prob(data).shape == expected_batch_shape
    check_expand(d, data)

    # rsample with and without a sample shape, and after expanding.
    for sample_shape in [(), (6, )]:
        x = d.rsample(sample_shape)
        assert x.shape == sample_shape + d.shape()
    x = d.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape
Beispiel #21
0
    def model(self, zero_data, covariates):
        """
        Model the series as a cumulative sum of Gaussian jumps (with a
        sampled jump scale) plus standard GaussianHMM observation noise.
        """
        with pyro.plate_stack("batch", zero_data.shape[:-2], rightmost_dim=-2):
            loc = zero_data[..., :1, :]
            scale = pyro.sample("scale", dist.LogNormal(loc, 1).to_event(1))

            with self.time_plate:
                jumps = pyro.sample("jumps", dist.Normal(0, scale).to_event(1))
            # The prediction is a random walk: cumulative sum of the jumps.
            prediction = jumps.cumsum(-2)

            duration, obs_dim = zero_data.shape[-2:]
            # Standard-normal components with identity matrices for both the
            # transition and observation models.
            std_normal = dist.Normal(0, 1).expand([obs_dim]).to_event(1)
            eye = torch.eye(obs_dim)
            noise_dist = dist.GaussianHMM(std_normal, eye, std_normal, eye,
                                          std_normal, duration=duration)
            self.predict(noise_dist, prediction)
Beispiel #22
0
def test_gaussian_hmm_elbo(batch_shape, num_steps, hidden_dim, obs_dim):
    """
    Check that reparameterizing a GaussianHMM prior with ConjugateReparam
    yields the same ELBO and parameter gradients as the original
    model/guide pair.
    """
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim),
                            requires_grad=True)
    trans_dist = random_mvn(batch_shape + (num_steps, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim),
                          requires_grad=True)
    obs_dist = random_mvn(batch_shape + (num_steps, ), obs_dim)

    data = obs_dist.sample()
    assert data.shape == batch_shape + (num_steps, obs_dim)
    prior = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                             obs_dist)
    likelihood = dist.Normal(data, 1).to_event(2)
    # Analytic posterior under the conjugate Gaussian likelihood.
    posterior, log_normalizer = prior.conjugate_update(likelihood)

    def model(data):
        with pyro.plate_stack("plates", batch_shape):
            z = pyro.sample("z", prior)
            pyro.sample("x", dist.Normal(z, 1).to_event(2), obs=data)

    def guide(data):
        with pyro.plate_stack("plates", batch_shape):
            pyro.sample("z", posterior)

    reparam_model = poutine.reparam(model, {"z": ConjugateReparam(likelihood)})

    def reparam_guide(data):
        # Trivial guide: the reparameterized model needs no latent sites.
        pass

    elbo = Trace_ELBO(num_particles=1000, vectorize_particles=True)
    expected_loss = elbo.differentiable_loss(model, guide, data)
    actual_loss = elbo.differentiable_loss(reparam_model, reparam_guide, data)
    assert_close(actual_loss, expected_loss, atol=0.01)

    # Gradients w.r.t. the learnable matrices should also agree.
    params = [trans_mat, obs_mat]
    expected_grads = torch.autograd.grad(expected_loss,
                                         params,
                                         retain_graph=True)
    actual_grads = torch.autograd.grad(actual_loss, params, retain_graph=True)
    for a, e in zip(actual_grads, expected_grads):
        assert_close(a, e, rtol=0.01)
Beispiel #23
0
def test_gaussian_hmm_high_obs_dim():
    """Smoke-test rsample with a very high observation dimension."""
    hidden_dim, obs_dim, duration = 1, 1000, 10
    sample_shape = (100, )
    init_dist = random_mvn((), hidden_dim)
    trans_mat = torch.randn((duration, hidden_dim, hidden_dim))
    trans_dist = random_mvn((duration, ), hidden_dim)
    obs_mat = torch.randn((duration, hidden_dim, obs_dim))
    # Diagonal observation noise with random loc and positive random scale.
    obs_dist = dist.Normal(torch.randn((duration, obs_dim)),
                           torch.randn((duration, obs_dim)).exp()).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                         duration=duration)
    samples = d.rsample(sample_shape)
    assert samples.shape == sample_shape + (duration, obs_dim)
Beispiel #24
0
def test_gaussian_hmm_log_prob_null_dynamics(init_shape, trans_mat_shape,
                                             trans_mvn_shape, obs_mvn_shape,
                                             hidden_dim):
    """Compare log_prob against pyro and an analytic form under null dynamics."""
    obs_dim = hidden_dim
    init_dist = random_mvn(init_shape, hidden_dim)

    # Null dynamics: zero transition matrix, diagonal transition noise.
    trans_mat = torch.zeros(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim, diag=True)

    # Identity observation matrix (hidden_dim == obs_dim), diagonal noise.
    obs_mat = torch.eye(hidden_dim)
    obs_dist = random_mvn(obs_mvn_shape, obs_dim, diag=True)

    actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                              obs_dist)
    expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                                     obs_dist)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    shape = broadcast_shape(init_shape + (1, ), trans_mat_shape,
                            trans_mvn_shape, obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    assert_close(actual_log_prob, expected_dist.log_prob(data),
                 atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)

    # With null dynamics each time step is an independent diagonal Normal
    # whose variance is the sum of transition and observation variances.
    obs_var = obs_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    trans_var = trans_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    analytic = dist.Normal(trans_dist.loc + obs_dist.loc,
                           (obs_var + trans_var).sqrt())
    analytic_log_prob = analytic.log_prob(data).sum(-1).sum(-1)
    assert_close(analytic_log_prob, actual_log_prob, atol=1.0e-5)
Beispiel #25
0
def test_gaussian_filter():
    """
    Check SMCFilter inference on a linear-Gaussian state space model against
    the exact filtering marginals computed by ``GaussianHMM.filter``.
    """
    dim = 4
    # Broad initial distribution; identity dynamics with unit transition
    # noise; observation noise twice as wide as the transition noise.
    init_dist = dist.MultivariateNormal(torch.zeros(dim),
                                        scale_tril=torch.eye(dim) * 10)
    trans_mat = torch.eye(dim)
    trans_dist = dist.MultivariateNormal(torch.zeros(dim),
                                         scale_tril=torch.eye(dim))
    obs_mat = torch.eye(dim)
    obs_dist = dist.MultivariateNormal(torch.zeros(dim),
                                       scale_tril=torch.eye(dim) * 2)
    # Exact reference model, used below for ground-truth filtering.
    hmm = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    class Model:
        # Sequential formulation of the same state space model, for SMCFilter.
        def init(self, state):
            state["z"] = pyro.sample("z_init", init_dist)
            self.t = 0

        def step(self, state, datum=None):
            # Random-walk latent dynamics (identity transition matrix).
            state["z"] = pyro.sample(
                "z_{}".format(self.t),
                dist.MultivariateNormal(state["z"],
                                        scale_tril=trans_dist.scale_tril))
            datum = pyro.sample(
                "obs_{}".format(self.t),
                dist.MultivariateNormal(state["z"],
                                        scale_tril=obs_dist.scale_tril),
                obs=datum)
            self.t += 1
            return datum

    class Guide:
        # Proposal: same structure as the model but with inflated noise.
        def init(self, state):
            pyro.sample("z_init", init_dist)
            self.t = 0

        def step(self, state, datum):
            # Deliberately wider (2x scale) than the model transition.
            pyro.sample(
                "z_{}".format(self.t),
                dist.MultivariateNormal(state["z"],
                                        scale_tril=trans_dist.scale_tril * 2))
            self.t += 1

    # Generate data by simulating the model forward.
    num_steps = 20
    model = Model()
    state = {}
    model.init(state)
    data = torch.stack([model.step(state) for _ in range(num_steps)])

    # Perform inference, comparing the SMC empirical posterior to the exact
    # Kalman filtering distribution at every time step.
    model = Model()
    guide = Guide()
    smc = SMCFilter(model, guide, num_particles=1000, max_plate_nesting=0)
    smc.init()
    for t, datum in enumerate(data):
        smc.step(datum)
        # Exact marginal p(z_t | data[:t+1]).
        expected = hmm.filter(data[:1 + t])
        actual = smc.get_empirical()["z"]
        # Loose tolerances: SMC with 1000 particles is a noisy estimator.
        assert_close(actual.variance**0.5,
                     expected.variance**0.5,
                     atol=0.1,
                     rtol=0.5)
        sigma = actual.variance.max().item()**0.5
        assert_close(actual.mean, expected.mean, atol=3 * sigma)
Beispiel #26
0
    def model(self, zero_data, covariates):
        """
        Forecasting model for bivariate (arrivals, departures) time series:
        a periodically repeated seasonal prediction with slow seasonal drift,
        plus a correlated GaussianHMM noise model over the residuals.
        """
        period = 24 * 7  # presumably hourly data with weekly seasonality
        duration, dim = zero_data.shape[-2:]
        assert dim == 2  # Data is bivariate: (arrivals, departures).

        # Sample global parameters.
        noise_scale = pyro.sample(
            "noise_scale",
            dist.LogNormal(torch.full((dim, ), -3.), 1.).to_event(1))
        assert noise_scale.shape[-1:] == (dim, )
        trans_timescale = pyro.sample(
            "trans_timescale",
            dist.LogNormal(torch.zeros(dim), 1).to_event(1))
        assert trans_timescale.shape[-1:] == (dim, )

        # A single scalar drift location, broadcast across both dims.
        trans_loc = pyro.sample("trans_loc", dist.Cauchy(0, 1 / period))
        trans_loc = trans_loc.unsqueeze(-1).expand(trans_loc.shape + (dim, ))
        assert trans_loc.shape[-1:] == (dim, )
        # Full-covariance transition noise: per-dim scales times the Cholesky
        # factor of an LKJ-distributed correlation matrix.
        trans_scale = pyro.sample(
            "trans_scale",
            dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
        trans_corr = pyro.sample("trans_corr",
                                 dist.LKJCorrCholesky(dim, torch.ones(())))
        trans_scale_tril = trans_scale.unsqueeze(-1) * trans_corr
        assert trans_scale_tril.shape[-2:] == (dim, dim)

        # Likewise a full covariance for the observation noise.
        obs_scale = pyro.sample(
            "obs_scale",
            dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
        obs_corr = pyro.sample("obs_corr",
                               dist.LKJCorrCholesky(dim, torch.ones(())))
        obs_scale_tril = obs_scale.unsqueeze(-1) * obs_corr
        assert obs_scale_tril.shape[-2:] == (dim, dim)

        # Note the initial seasonality should be sampled in a plate with the
        # same dim as the time_plate, dim=-1. That way we can repeat the dim
        # below using periodic_repeat().
        with pyro.plate("season_plate", period, dim=-1):
            season_init = pyro.sample(
                "season_init",
                dist.Normal(torch.zeros(dim), 1).to_event(1))
            assert season_init.shape[-2:] == (period, dim)

        # Sample independent noise at each time step.
        with self.time_plate:
            season_noise = pyro.sample("season_noise",
                                       dist.Normal(0, noise_scale).to_event(1))
            assert season_noise.shape[-2:] == (duration, dim)

        # Construct a prediction. This prediction has an exactly repeated
        # seasonal part plus slow seasonal drift. We use two deterministic,
        # linear functions to transform our diagonal Normal noise to nontrivial
        # samples from a Gaussian process.
        prediction = (periodic_repeat(season_init, duration, dim=-2) +
                      periodic_cumsum(season_noise, period, dim=-2))
        assert prediction.shape[-2:] == (duration, dim)

        # Construct a joint noise model. This model is a GaussianHMM, whose
        # .rsample() and .log_prob() methods are parallelized over time; thus
        # the entire model is parallelized over time.
        init_dist = dist.Normal(torch.zeros(dim), 100).to_event(1)
        # Diagonal decay matrix exp(-trans_timescale) per dim.
        trans_mat = trans_timescale.neg().exp().diag_embed()
        trans_dist = dist.MultivariateNormal(trans_loc,
                                             scale_tril=trans_scale_tril)
        obs_mat = torch.eye(dim)
        obs_dist = dist.MultivariateNormal(torch.zeros(dim),
                                           scale_tril=obs_scale_tril)
        noise_model = dist.GaussianHMM(init_dist,
                                       trans_mat,
                                       trans_dist,
                                       obs_mat,
                                       obs_dist,
                                       duration=duration)
        assert noise_model.event_shape == (duration, dim)

        # The final statement registers our noise model and prediction.
        self.predict(noise_model, prediction)
Beispiel #27
0
def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps,
                                   hidden_dim, obs_dim):
    """
    Exhaustive check of GaussianHMM: .log_prob() against a hand-unrolled
    joint Gaussian, .conjugate_update() consistency, and sample moments of
    prior and posterior against the analytic mean/covariance.
    """
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps, ), obs_dim)
    if diag:
        # Exercise the diagonal-noise code path: replace the observation MVN
        # with an equivalent factorized Normal.
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist,
                         trans_mat,
                         trans_dist,
                         obs_mat,
                         obs_dist,
                         duration=num_steps)
    if diag:
        # Recover an MVN form for matrix_and_mvn_to_gaussian() below.
        obs_mvn = dist.MultivariateNormal(
            obs_dist.base_dist.loc,
            scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    #    like |         O O O
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)
    # Extra observation-space likelihood, used later to test conjugate_update.
    like_dist = dist.Normal(torch.randn(data.shape), 1).to_event(2)
    like = mvn_to_gaussian(like_dist)

    # Pad each per-step factor out to the full unrolled event size, then sum.
    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(),
                              right=(T - t - 1) * obs.dim()) for t in range(T)
    ])
    unrolled_like = reduce(operator.add, [
        like[..., t].event_pad(left=t * obs_dim, right=(T - t - 1) * obs_dim)
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim, ))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    # Marginalize hidden states by contracting init @ trans @ obs.
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)

    # conjugate_update must satisfy p(x) * like(x) = posterior(x) * normalizer.
    d_posterior, log_normalizer = d.conjugate_update(like_dist)
    assert_close(
        d.log_prob(data) + like_dist.log_prob(data),
        d_posterior.log_prob(data) + log_normalizer)

    if batch_shape or sample_shape:
        return  # Moment checks below assume empty batch/sample shapes.

    # Test mean and covariance.
    prior = "prior", d, logp
    posterior = "posterior", d_posterior, logp + unrolled_like
    for name, d, g in [prior, posterior]:
        logging.info("testing {} moments".format(name))
        with torch.no_grad():
            num_samples = 100000
            samples = d.sample([num_samples]).reshape(num_samples, T * obs_dim)
            # Empirical moments from samples.
            actual_mean = samples.mean(0)
            delta = samples - actual_mean
            actual_cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).mean(0)
            actual_std = actual_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            actual_corr = actual_cov / (actual_std.unsqueeze(-1) *
                                        actual_std.unsqueeze(-2))

            # Analytic moments from the unrolled Gaussian's natural params.
            expected_cov = g.precision.cholesky().cholesky_inverse()
            expected_mean = expected_cov.matmul(
                g.info_vec.unsqueeze(-1)).squeeze(-1)
            expected_std = expected_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            expected_corr = expected_cov / (expected_std.unsqueeze(-1) *
                                            expected_std.unsqueeze(-2))

            assert_close(actual_mean, expected_mean, atol=0.05, rtol=0.02)
            assert_close(actual_std, expected_std, atol=0.05, rtol=0.02)
            assert_close(actual_corr, expected_corr, atol=0.02)
Beispiel #28
0
    def apply(self, msg):
        """
        Reparameterize a LinearHMM (or IndependentHMM wrapper) site as a
        conditionally Gaussian GaussianHMM, delegating the possibly
        non-Gaussian init/trans/obs distributions to the configured
        sub-reparameterizers (self.init, self.trans, self.obs).
        """
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
        # A concrete duration is required to expand trans/obs dists over time.
        if fn.duration is None:
            raise ValueError(
                "LinearHMMReparam requires duration to be specified "
                "on targeted LinearHMM distributions")

        # Unwrap IndependentHMM by recursing on its base HMM: swap the last
        # two dims and append a size-1 dim to the value, reparameterize the
        # base HMM, then re-wrap the result.
        if isinstance(fn, dist.IndependentHMM):
            indep_value = None
            if value is not None:
                indep_value = value.transpose(-1, -2).unsqueeze(-1)
            msg = self.apply({
                "name": name,
                "fn": fn.base_dist.to_event(1),
                "value": indep_value,
                "is_observed": is_observed,
            })
            hmm = msg["fn"]
            hmm = dist.IndependentHMM(hmm.to_event(-1))
            # Undo the reshaping only if the recursion replaced the value.
            if msg["value"] is not indep_value:
                value = msg["value"].squeeze(-1).transpose(-1, -2)
            return {"fn": hmm, "value": value, "is_observed": is_observed}

        # Reparameterize the initial distribution as conditionally Gaussian.
        init_dist = fn.initial_dist
        if self.init is not None:
            msg = self.init.apply({
                "name": f"{name}_init",
                "fn": self._wrap(init_dist, event_dim - 1),
                "value": None,
                "is_observed": False,
            })
            init_dist = msg["fn"]
            # Normalize to event_dim == 1 as GaussianHMM expects.
            init_dist = init_dist.to_event(1 - init_dist.event_dim)

        # Reparameterize the transition distribution as conditionally Gaussian.
        trans_dist = fn.transition_dist
        if self.trans is not None:
            # Expand along the time axis so the sub-reparam sees every step.
            if trans_dist.batch_shape[-1] != fn.duration:
                trans_dist = trans_dist.expand(trans_dist.batch_shape[:-1] +
                                               (fn.duration, ))
            msg = self.trans.apply({
                "name": f"{name}_trans",
                "fn": self._wrap(trans_dist, event_dim),
                "value": None,
                "is_observed": False,
            })
            trans_dist = msg["fn"]
            trans_dist = trans_dist.to_event(1 - trans_dist.event_dim)

        # Reparameterize the observation distribution as conditionally Gaussian.
        obs_dist = fn.observation_dist
        if self.obs is not None:
            if obs_dist.batch_shape[-1] != fn.duration:
                obs_dist = obs_dist.expand(obs_dist.batch_shape[:-1] +
                                           (fn.duration, ))
            # Unlike init/trans, the obs sub-site carries the observed value.
            msg = self.obs.apply({
                "name": f"{name}_obs",
                "fn": self._wrap(obs_dist, event_dim),
                "value": value,
                "is_observed": is_observed,
            })
            obs_dist = msg["fn"]
            obs_dist = obs_dist.to_event(1 - obs_dist.event_dim)
            value = msg["value"]
            is_observed = msg["is_observed"]

        # Reparameterize the entire HMM as conditionally Gaussian.
        hmm = dist.GaussianHMM(
            init_dist,
            fn.transition_matrix,
            trans_dist,
            fn.observation_matrix,
            obs_dist,
            duration=fn.duration,
        )
        hmm = self._wrap(hmm, event_dim)

        # Apply any observation transforms.
        if fn.transforms:
            hmm = dist.TransformedDistribution(hmm, fn.transforms)

        return {"fn": hmm, "value": value, "is_observed": is_observed}
Beispiel #29
0
def test_gaussian_hmm_shape(diag, init_shape, trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist,
                         trans_mat,
                         trans_dist,
                         obs_mat,
                         obs_dist,
                         duration=6)

    shape = broadcast_shape(init_shape + (6, ), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim, )
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()
    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    x = d.rsample()
    assert x.shape == d.shape()
    x = d.rsample((6, ))
    assert x.shape == (6, ) + d.shape()
    x = d.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape

    likelihood = dist.Normal(data, 1).to_event(2)
    p, log_normalizer = d.conjugate_update(likelihood)
    assert p.batch_shape == d.batch_shape
    assert p.event_shape == d.event_shape
    x = p.rsample()
    assert x.shape == d.shape()
    x = p.rsample((6, ))
    assert x.shape == (6, ) + d.shape()
    x = p.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape

    final = d.filter(data)
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim, )

    z = d.rsample_posterior(data)
    assert z.shape == expected_batch_shape + time_shape + (hidden_dim, )

    for t in range(1, d.duration - 1):
        f = d.duration - t
        d2 = d.prefix_condition(data[..., :t, :])
        assert d2.batch_shape == d.batch_shape
        assert d2.event_shape == (f, obs_dim)