Esempio n. 1
0
    def __call__(self, name, fn, obs):
        """
        Swap a ``LinearHMM`` or ``IndependentHMM`` site for a ``GaussianHMM``.

        The initial, transition and observation noise distributions are each
        routed through ``self.init``, ``self.trans`` and ``self.obs`` (whenever
        the corresponding attribute is not ``None``) and then reassembled into
        a ``dist.GaussianHMM`` that reuses the original transition and
        observation matrices.

        :param name: Site name, used as the prefix of the ``_init``,
            ``_trans`` and ``_obs`` names given to the nested reparameterizers.
        :param fn: Site distribution; once unwrapped by ``self._unwrap`` it
            must be a ``dist.LinearHMM`` or ``dist.IndependentHMM``.
        :param obs: Observed value for the site, or ``None``.
        :returns: Tuple ``(new_fn, obs)``.
        :raises ValueError: If the unwrapped distribution has no ``duration``.
        """
        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
        duration = fn.duration
        if duration is None:
            raise ValueError(
                "LinearHMMReparam requires duration to be specified "
                "on targeted LinearHMM distributions")

        # An IndependentHMM is handled by recursing on its base distribution.
        if isinstance(fn, dist.IndependentHMM):
            base_obs = obs
            if base_obs is not None:
                # Swap the last two dims and append a size-1 dim before
                # handing the value to the base distribution.
                base_obs = base_obs.transpose(-1, -2).unsqueeze(-1)
            base_hmm, base_obs = self(name, fn.base_dist.to_event(1), base_obs)
            indep_hmm = dist.IndependentHMM(base_hmm.to_event(-1))
            if base_obs is not None:
                # Undo the layout change applied before recursing.
                base_obs = base_obs.squeeze(-1).transpose(-1, -2)
            return indep_hmm, base_obs

        def _cover_duration(part):
            # Expand the last batch dimension to ``duration`` when it differs.
            batch = part.batch_shape
            if batch[-1] == duration:
                return part
            return part.expand(batch[:-1] + (duration, ))

        def _as_unit_event(part):
            # Presumably leaves ``part`` with exactly one event dimension;
            # relies on ``to_event`` accepting negative counts -- TODO confirm.
            return part.to_event(1 - part.event_dim)

        # Initial distribution, reparameterized as conditionally Gaussian.
        init_dist = fn.initial_dist
        if self.init is not None:
            init_dist, _ = self.init(
                f"{name}_init", self._wrap(init_dist, event_dim - 1), None)
            init_dist = _as_unit_event(init_dist)

        # Transition noise, reparameterized as conditionally Gaussian.
        trans_dist = fn.transition_dist
        if self.trans is not None:
            trans_dist = _cover_duration(trans_dist)
            trans_dist, _ = self.trans(
                f"{name}_trans", self._wrap(trans_dist, event_dim), None)
            trans_dist = _as_unit_event(trans_dist)

        # Observation noise, reparameterized as conditionally Gaussian; this
        # is the only component that receives (and may replace) ``obs``.
        obs_dist = fn.observation_dist
        if self.obs is not None:
            obs_dist = _cover_duration(obs_dist)
            obs_dist, obs = self.obs(
                f"{name}_obs", self._wrap(obs_dist, event_dim), obs)
            obs_dist = _as_unit_event(obs_dist)

        # Reassemble the pieces into a single conditionally Gaussian HMM.
        gaussian_hmm = dist.GaussianHMM(init_dist,
                                        fn.transition_matrix,
                                        trans_dist,
                                        fn.observation_matrix,
                                        obs_dist,
                                        duration=duration)
        gaussian_hmm = self._wrap(gaussian_hmm, event_dim)

        # Re-apply any observation transforms carried by the original HMM.
        if not fn.transforms:
            return gaussian_hmm, obs
        return dist.TransformedDistribution(gaussian_hmm, fn.transforms), obs
Esempio n. 2
0
def test_independent_hmm_shape(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim,
                               obs_dim):
    """
    Check batch/event shapes, ``log_prob`` shape and ``rsample`` shapes of an
    ``IndependentHMM`` wrapping a ``GaussianHMM`` of duration 6.
    """

    def with_obs_dim(shape):
        # Insert ``obs_dim`` just before the last entry of ``shape``; an empty
        # shape is given a trailing entry of 6.
        last = shape[-1] if shape else 6
        return shape[:-1] + (obs_dim, last)

    # Random draws are made in the same order as the parameters are listed.
    init_dist = random_mvn(init_shape + (obs_dim, ), hidden_dim)
    trans_mat = torch.randn(
        with_obs_dim(trans_mat_shape) + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(with_obs_dim(trans_mvn_shape), hidden_dim)
    obs_mat = torch.randn(with_obs_dim(obs_mat_shape) + (hidden_dim, 1))
    obs_dist = random_mvn(with_obs_dim(obs_mvn_shape), 1)

    base_hmm = dist.GaussianHMM(init_dist,
                                trans_mat,
                                trans_dist,
                                obs_mat,
                                obs_dist,
                                duration=6)
    hmm = dist.IndependentHMM(base_hmm)

    # Expected shapes come from broadcasting the un-augmented input shapes.
    full_shape = broadcast_shape(init_shape + (6, ), trans_mat_shape,
                                 trans_mvn_shape, obs_mat_shape,
                                 obs_mvn_shape)
    expected_batch_shape = full_shape[:-1]
    expected_event_shape = full_shape[-1:] + (obs_dim, )
    assert hmm.batch_shape == expected_batch_shape
    assert hmm.event_shape == expected_event_shape
    assert hmm.support.event_dim == hmm.event_dim

    # Scoring data of the full shape yields one value per batch element.
    data = torch.randn(full_shape + (obs_dim, ))
    assert data.shape == hmm.shape()
    log_prob = hmm.log_prob(data)
    assert log_prob.shape == expected_batch_shape
    check_expand(hmm, data)

    # Sampling: no sample shape, an explicit sample shape, and after expand.
    sample = hmm.rsample()
    assert sample.shape == hmm.shape()
    sample = hmm.rsample((6, ))
    assert sample.shape == (6, ) + hmm.shape()
    sample = hmm.expand((6, 5)).rsample()
    assert sample.shape == (6, 5) + hmm.event_shape
Esempio n. 3
0
def test_independent_hmm_shape(skew, batch_shape, duration, hidden_dim,
                               obs_dim):
    """
    Build an ``IndependentHMM`` over a stable-noise ``LinearHMM``, check its
    shapes, then smoke-test ``LinearHMMReparam`` on an observed site.
    """
    series_shape = batch_shape + (obs_dim, )
    time_shape = series_shape + (duration, )

    # Sampling calls are kept in a fixed order so RNG consumption is unchanged.
    stability = dist.Uniform(0.5, 2).sample(series_shape)
    stability_1 = stability[..., None]
    stability_2 = stability[..., None, None]

    init_dist = random_stable(series_shape + (hidden_dim, ),
                              stability_1,
                              skew=skew).to_event(1)
    trans_mat = torch.randn(time_shape + (hidden_dim, hidden_dim))
    trans_dist = random_stable(time_shape + (hidden_dim, ),
                               stability_2,
                               skew=skew).to_event(1)
    obs_mat = torch.randn(time_shape + (hidden_dim, 1))
    obs_dist = random_stable(time_shape + (1, ), stability_2,
                             skew=skew).to_event(1)

    base_hmm = dist.LinearHMM(init_dist,
                              trans_mat,
                              trans_dist,
                              obs_mat,
                              obs_dist,
                              duration=duration)
    assert base_hmm.batch_shape == series_shape
    assert base_hmm.event_shape == (duration, 1)

    # Wrapping moves the trailing batch dim into the event shape.
    hmm = dist.IndependentHMM(base_hmm)
    assert hmm.batch_shape == batch_shape
    assert hmm.event_shape == (duration, obs_dim)

    def model(data=None):
        with pyro.plate_stack("plates", batch_shape):
            return pyro.sample("x", hmm, obs=data)

    data = torch.randn(duration, obs_dim)
    if skew == 0:
        rep = SymmetricStableReparam()
    else:
        rep = StableReparam()

    # The same reparameterizer is used for all three HMM components.
    with poutine.trace() as tr, poutine.reparam(
            config={"x": LinearHMMReparam(rep, rep, rep)}):
        model(data)
    assert isinstance(tr.trace.nodes["x"]["fn"], dist.IndependentHMM)
    tr.trace.compute_log_prob()  # smoke test only
Esempio n. 4
0
    def apply(self, msg):
        """
        Swap a ``LinearHMM`` or ``IndependentHMM`` site for a ``GaussianHMM``.

        The initial, transition and observation noise distributions are each
        passed to ``self.init.apply``, ``self.trans.apply`` and
        ``self.obs.apply`` (whenever the corresponding attribute is not
        ``None``) and then reassembled into a ``dist.GaussianHMM`` that reuses
        the original transition and observation matrices.

        :param msg: Mapping with keys ``"name"``, ``"fn"``, ``"value"`` and
            ``"is_observed"``. After ``self._unwrap``, ``msg["fn"]`` must be a
            ``dist.LinearHMM`` or ``dist.IndependentHMM``.
        :returns: Dict with keys ``"fn"``, ``"value"`` and ``"is_observed"``.
        :raises ValueError: If the unwrapped distribution has no ``duration``.
        """
        name, fn = msg["name"], msg["fn"]
        value, is_observed = msg["value"], msg["is_observed"]

        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
        duration = fn.duration
        if duration is None:
            raise ValueError(
                "LinearHMMReparam requires duration to be specified "
                "on targeted LinearHMM distributions")

        # An IndependentHMM is handled by recursing on its base distribution.
        if isinstance(fn, dist.IndependentHMM):
            if value is None:
                base_value = None
            else:
                # Swap the last two dims and append a size-1 dim before
                # handing the value to the base distribution.
                base_value = value.transpose(-1, -2).unsqueeze(-1)
            inner = self.apply({
                "name": name,
                "fn": fn.base_dist.to_event(1),
                "value": base_value,
                "is_observed": is_observed,
            })
            indep_hmm = dist.IndependentHMM(inner["fn"].to_event(-1))
            new_base_value = inner["value"]
            # Only convert back when the recursive call replaced the value;
            # otherwise the caller's original ``value`` object is returned.
            if new_base_value is not base_value:
                value = new_base_value.squeeze(-1).transpose(-1, -2)
            return {
                "fn": indep_hmm,
                "value": value,
                "is_observed": is_observed,
            }

        def _cover_duration(part):
            # Expand the last batch dimension to ``duration`` when it differs.
            batch = part.batch_shape
            if batch[-1] == duration:
                return part
            return part.expand(batch[:-1] + (duration, ))

        def _as_unit_event(part):
            # Presumably leaves ``part`` with exactly one event dimension;
            # relies on ``to_event`` accepting negative counts -- TODO confirm.
            return part.to_event(1 - part.event_dim)

        # Initial distribution, reparameterized as conditionally Gaussian.
        init_dist = fn.initial_dist
        if self.init is not None:
            init_msg = self.init.apply({
                "name": f"{name}_init",
                "fn": self._wrap(init_dist, event_dim - 1),
                "value": None,
                "is_observed": False,
            })
            init_dist = _as_unit_event(init_msg["fn"])

        # Transition noise, reparameterized as conditionally Gaussian.
        trans_dist = fn.transition_dist
        if self.trans is not None:
            trans_dist = _cover_duration(trans_dist)
            trans_msg = self.trans.apply({
                "name": f"{name}_trans",
                "fn": self._wrap(trans_dist, event_dim),
                "value": None,
                "is_observed": False,
            })
            trans_dist = _as_unit_event(trans_msg["fn"])

        # Observation noise, reparameterized as conditionally Gaussian; this
        # is the only component that receives the site's value and may
        # replace both ``value`` and ``is_observed``.
        obs_dist = fn.observation_dist
        if self.obs is not None:
            obs_dist = _cover_duration(obs_dist)
            obs_msg = self.obs.apply({
                "name": f"{name}_obs",
                "fn": self._wrap(obs_dist, event_dim),
                "value": value,
                "is_observed": is_observed,
            })
            obs_dist = _as_unit_event(obs_msg["fn"])
            value = obs_msg["value"]
            is_observed = obs_msg["is_observed"]

        # Reassemble the pieces into a single conditionally Gaussian HMM.
        gaussian_hmm = dist.GaussianHMM(
            init_dist,
            fn.transition_matrix,
            trans_dist,
            fn.observation_matrix,
            obs_dist,
            duration=duration,
        )
        gaussian_hmm = self._wrap(gaussian_hmm, event_dim)

        # Re-apply any observation transforms carried by the original HMM.
        if fn.transforms:
            gaussian_hmm = dist.TransformedDistribution(gaussian_hmm,
                                                        fn.transforms)

        return {
            "fn": gaussian_hmm,
            "value": value,
            "is_observed": is_observed,
        }
Esempio n. 5
0
def _(d, batch_shape):
    """
    Rebuild an ``IndependentHMM`` whose base distribution has been reshaped.

    The base distribution is reshaped (via ``reshape_batch``) to
    ``batch_shape`` extended by the last entry of ``d.event_shape``, then
    wrapped in a new ``dist.IndependentHMM``.
    """
    trailing_event = d.event_shape[-1:]
    target_shape = batch_shape + trailing_event
    return dist.IndependentHMM(reshape_batch(d.base_dist, target_shape))
Esempio n. 6
0
def _(d, data):
    """
    Rebuild an ``IndependentHMM`` whose base distribution has been conditioned.

    ``data`` has its last two dims swapped and a size-1 dim appended before
    being passed, with ``d.base_dist``, to ``prefix_condition``; the result is
    wrapped in a new ``dist.IndependentHMM``.
    """
    per_series_data = data.transpose(-1, -2).unsqueeze(-1)
    conditioned = prefix_condition(d.base_dist, per_series_data)
    return dist.IndependentHMM(conditioned)