Example #1
def test_distributions(state_dim, obs_dim):
    data = Tensor(torch.randn(2, obs_dim))["time"]

    bias = Variable("bias", reals(obs_dim))
    bias_dist = dist_to_funsor(random_mvn((), obs_dim))(value=bias)

    prev = Variable("prev", reals(state_dim))
    curr = Variable("curr", reals(state_dim))
    trans_mat = Tensor(
        torch.eye(state_dim) + 0.1 * torch.randn(state_dim, state_dim))
    trans_mvn = random_mvn((), state_dim)
    trans_dist = dist.MultivariateNormal(loc=trans_mvn.loc,
                                         scale_tril=trans_mvn.scale_tril,
                                         value=curr - prev @ trans_mat)

    state = Variable("state", reals(state_dim))
    obs = Variable("obs", reals(obs_dim))
    obs_mat = Tensor(torch.randn(state_dim, obs_dim))
    obs_mvn = random_mvn((), obs_dim)
    obs_dist = dist.MultivariateNormal(loc=obs_mvn.loc,
                                       scale_tril=obs_mvn.scale_tril,
                                       value=state @ obs_mat + bias - obs)

    log_prob = 0
    log_prob += bias_dist

    state_0 = Variable("state_0", reals(state_dim))
    log_prob += obs_dist(state=state_0, obs=data(time=0))

    state_1 = Variable("state_1", reals(state_dim))
    log_prob += trans_dist(prev=state_0, curr=state_1)
    log_prob += obs_dist(state=state_1, obs=data(time=1))

    log_prob = log_prob.reduce(ops.logaddexp)
    assert isinstance(log_prob, Tensor), log_prob.pretty()
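These snippets are excerpted from funsor's test suite and omit their imports. A plausible preamble, inferred from the calls above rather than copied from the original module, is:

from collections import OrderedDict

import torch

import funsor
import funsor.distributions as dist  # funsor wrappers accepting value=...
import funsor.ops as ops
from funsor.domains import bint, reals
from funsor.interpreter import interpretation
from funsor.pyro.convert import (dist_to_funsor, matrix_and_mvn_to_funsor,
                                 tensor_to_funsor)
from funsor.terms import Variable, lazy
from funsor.testing import assert_close, random_mvn
from funsor.torch import Tensor

Note that in some later examples (the Categorical, Bernoulli, and Normal tests) dist instead refers to pyro.distributions, so the binding of dist varies by snippet.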
Example #2
def test_mvn_affine_one_var():
    x = Variable('x', reals(2))
    data = dict(x=Tensor(torch.randn(2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))
        d = d(value=2 * x + 1)
    _check_mvn_affine(d, data)
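The helper _check_mvn_affine used throughout the affine tests is not shown on this page. A minimal sketch of what such a checker could do, given only the pattern visible above (a lazily substituted MVN plus concrete data); this is a hypothetical reconstruction, not the original helper:

from funsor.interpreter import reinterpret

def _check_mvn_affine(d, data):
    # Hypothetical reconstruction: evaluate the lazy expression eagerly,
    # then check that substituting the concrete data into both forms
    # yields the same log-density Tensor.
    eager_d = reinterpret(d)
    actual = d(**data)
    expected = eager_d(**data)
    assert isinstance(actual, Tensor)
    assert_close(actual, expected, atol=1e-4, rtol=None)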
Example #3
def test_pyro_convert():
    data = Tensor(torch.randn(2, 2), OrderedDict([("time", bint(2))]))

    bias_dist = dist_to_funsor(random_mvn((), 2))

    trans_mat = torch.randn(3, 3)
    trans_mvn = random_mvn((), 3)
    trans = matrix_and_mvn_to_funsor(trans_mat, trans_mvn, (), "prev", "curr")

    obs_mat = torch.randn(3, 2)
    obs_mvn = random_mvn((), 2)
    obs = matrix_and_mvn_to_funsor(obs_mat, obs_mvn, (), "state", "obs")

    log_prob = 0
    bias = Variable("bias", reals(2))
    log_prob += bias_dist(value=bias)

    state_0 = Variable("state_0", reals(3))
    log_prob += obs(state=state_0, obs=bias + data(time=0))

    state_1 = Variable("state_1", reals(3))
    log_prob += trans(prev=state_0, curr=state_1)
    log_prob += obs(state=state_1, obs=bias + data(time=1))

    log_prob = log_prob.reduce(ops.logaddexp)
    assert isinstance(log_prob, Tensor), log_prob.pretty()
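matrix_and_mvn_to_funsor(matrix, mvn, event_dims, x_name, y_name) encodes the noisy affine relation y = x @ matrix + noise with noise ~ mvn. A quick consistency check one could run (the shapes and tolerances here are assumptions):

mat = torch.randn(3, 2)
mvn = random_mvn((), 2)
f = matrix_and_mvn_to_funsor(mat, mvn, (), "x", "y")
x, y = torch.randn(3), torch.randn(2)
actual = f(x=Tensor(x), y=Tensor(y))
expected = Tensor(mvn.log_prob(y - x @ mat))  # log p(y | x)
assert_close(actual, expected, atol=1e-4, rtol=None)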
Example #4
def test_mvn_affine_getitem():
    x = Variable('x', reals(2, 2))
    data = dict(x=Tensor(torch.randn(2, 2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))
        d = d(value=x[0] - x[1])
    _check_mvn_affine(d, data)
Example #5
def test_mvn_affine_two_vars():
    x = Variable('x', reals(2))
    y = Variable('y', reals(2))
    data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))
        d = d(value=x - y)
    _check_mvn_affine(d, data)
Example #6
def test_mvn_affine_reshape():
    x = Variable('x', reals(2, 2))
    y = Variable('y', reals(4))
    data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(4)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 4))
        d = d(value=x.reshape((4, )) - y)
    _check_mvn_affine(d, data)
Example #7
def test_dist_to_funsor_categorical(batch_shape, cardinality):
    logits = torch.randn(batch_shape + (cardinality, ))
    logits -= logits.logsumexp(dim=-1, keepdim=True)
    d = dist.Categorical(logits=logits)
    f = dist_to_funsor(d)
    assert isinstance(f, Tensor)
    expected = tensor_to_funsor(logits, ("value", ))
    assert_close(f, expected)
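tensor_to_funsor(logits, ("value",)) turns the trailing tensor dimension into a named funsor input, leaving any leading batch dimensions as batch inputs. A sketch of the resulting types for the unbatched case (shapes assumed):

logits = torch.randn(3)
f = tensor_to_funsor(logits, ("value",))
assert f.inputs["value"].dtype == 3  # "value" is a size-3 integer input
assert f.output.shape == ()          # scalar real output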
Example #8
def test_mvn_affine_einsum():
    c = Tensor(torch.randn(3, 2, 2))
    x = Variable('x', reals(2, 2))
    y = Variable('y', reals())
    data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(())))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 3))
        d = d(value=Einsum("abc,bc->a", c, x) + y)
    _check_mvn_affine(d, data)
Example #9
def test_mvn_affine_matmul():
    x = Variable('x', reals(2))
    y = Variable('y', reals(3))
    m = Tensor(torch.randn(2, 3))
    data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(3)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 3))
        d = d(value=x @ m - y)
    _check_mvn_affine(d, data)
Example #10
def test_dist_to_funsor_bernoulli(batch_shape):
    logits = torch.randn(batch_shape)
    d = dist.Bernoulli(logits=logits)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob)
Example #11
def test_dist_to_funsor_normal(batch_shape):
    loc = torch.randn(batch_shape)
    scale = torch.randn(batch_shape).exp()
    d = dist.Normal(loc, scale)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob, rtol=1e-5)
Example #12
    def __init__(self,
                 initial_dist,
                 transition_dist,
                 observation_dist,
                 validate_args=None):
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_dist,
                          torch.distributions.MultivariateNormal)
        assert isinstance(observation_dist,
                          torch.distributions.MultivariateNormal)
        hidden_dim = initial_dist.event_shape[0]
        assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
        obs_dim = observation_dist.event_shape[0] - hidden_dim
        shape = broadcast_shape(initial_dist.batch_shape + (1, ),
                                transition_dist.batch_shape,
                                observation_dist.batch_shape)
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim, )

        # Convert distributions to funsors.
        init = dist_to_funsor(initial_dist)(value="state")
        trans = mvn_to_funsor(
            transition_dist, ("time", ),
            OrderedDict([("state", Reals[hidden_dim]),
                         ("state(time=1)", Reals[hidden_dim])]))
        obs = mvn_to_funsor(
            observation_dist, ("time", ),
            OrderedDict([("state(time=1)", Reals[hidden_dim]),
                         ("value", Reals[obs_dim])]))

        # Construct the joint funsor.
        # Compare with pyro.distributions.hmm.GaussianMRF.log_prob().
        with interpretation(lazy):
            time = Variable("time", Bint[time_shape[0]])
            value = Variable("value", Reals[time_shape[0], obs_dim])
            logp_oh = trans + obs(value=value["time"])
            logp_oh = MarkovProduct(ops.logaddexp, ops.add, logp_oh, time,
                                    {"state": "state(time=1)"})
            logp_oh += init
            logp_oh = logp_oh.reduce(ops.logaddexp,
                                     frozenset({"state", "state(time=1)"}))
            logp_h = trans + obs.reduce(ops.logaddexp, "value")
            logp_h = MarkovProduct(ops.logaddexp, ops.add, logp_h, time,
                                   {"state": "state(time=1)"})
            logp_h += init
            logp_h = logp_h.reduce(ops.logaddexp,
                                   frozenset({"state", "state(time=1)"}))
            funsor_dist = logp_oh - logp_h

        dtype = "real"
        super(GaussianMRF, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
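A hypothetical smoke test for this constructor; the shapes follow the asserts above (the transition event size is 2 * hidden_dim, and the observation event size is hidden_dim + obs_dim):

hidden_dim, obs_dim, num_steps = 2, 3, 5
initial_dist = random_mvn((), hidden_dim)
transition_dist = random_mvn((num_steps,), 2 * hidden_dim)
observation_dist = random_mvn((num_steps,), hidden_dim + obs_dim)
d = GaussianMRF(initial_dist, transition_dist, observation_dist)
logp = d.log_prob(torch.randn(num_steps, obs_dim))  # scalar: batch_shape is ()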
Example #13
def test_dist_to_funsor_independent(batch_shape, event_shape):
    loc = torch.randn(batch_shape + event_shape)
    scale = torch.randn(batch_shape + event_shape).exp()
    d = dist.Normal(loc, scale).to_event(len(event_shape))
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    funsor_value = tensor_to_funsor(value, event_output=len(event_shape))
    actual_log_prob = f(value=funsor_value)
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob, rtol=1e-5)
Example #14
def test_dist_to_funsor_masked(batch_shape):
    loc = torch.randn(batch_shape)
    scale = torch.randn(batch_shape).exp()
    mask = torch.bernoulli(torch.full(batch_shape, 0.5)).byte()
    d = dist.Normal(loc, scale).mask(mask)
    assert isinstance(d, MaskedDistribution)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob)
Example #15
def test_dist_to_funsor_mvn(batch_shape, event_size):
    loc = torch.randn(batch_shape + (event_size, ))
    cov = torch.randn(batch_shape + (event_size, 2 * event_size))
    cov = cov.matmul(cov.transpose(-1, -2))
    scale_tril = torch.cholesky(cov)
    d = dist.MultivariateNormal(loc, scale_tril=scale_tril)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value, event_output=1))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob)
Example #16
    def _forward_funsor(self, features, trip_counts):
        total_hours = len(features)
        observed_hours, num_origins, num_destins = trip_counts.shape
        assert observed_hours == total_hours
        assert num_origins == self.num_stations
        assert num_destins == self.num_stations
        n = self.num_stations
        gate_rate = funsor.Variable("gate_rate_t",
                                    reals(observed_hours, 2 * n * n))["time"]

        @funsor.torch.function(reals(2 * n * n), (reals(n, n, 2), reals(n, n)))
        def unpack_gate_rate(gate_rate):
            batch_shape = gate_rate.shape[:-1]
            gate, rate = gate_rate.reshape(batch_shape + (2, n, n)).unbind(-3)
            gate = gate.sigmoid().clamp(min=0.01, max=0.99)
            rate = bounded_exp(rate, bound=1e4)
            gate = torch.stack((1 - gate, gate), dim=-1)
            return gate, rate

        # Create a Gaussian latent dynamical system.
        init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist = \
            self._dynamics(features[:observed_hours])
        init = dist_to_funsor(init_dist)(value="state")
        trans = matrix_and_mvn_to_funsor(trans_matrix, trans_dist, ("time", ),
                                         "state", "state(time=1)")
        obs = matrix_and_mvn_to_funsor(obs_matrix, obs_dist, ("time", ),
                                       "state(time=1)", "gate_rate")

        # Compute dynamic prior over gate_rate.
        prior = trans + obs(gate_rate=gate_rate)
        prior = MarkovProduct(ops.logaddexp, ops.add, prior, "time",
                              {"state": "state(time=1)"})
        prior += init
        prior = prior.reduce(ops.logaddexp, {"state", "state(time=1)"})

        # Compute zero-inflated Poisson likelihood.
        gate, rate = unpack_gate_rate(gate_rate)
        likelihood = fdist.Categorical(gate["origin", "destin"], value="gated")
        trip_counts = tensor_to_funsor(trip_counts,
                                       ("time", "origin", "destin"))
        likelihood += funsor.Stack(
            "gated",
            (fdist.Poisson(rate["origin", "destin"], value=trip_counts),
             fdist.Delta(0, value=trip_counts)))
        likelihood = likelihood.reduce(ops.logaddexp, "gated")
        likelihood = likelihood.reduce(ops.add, {"time", "origin", "destin"})

        assert set(prior.inputs) == {"gate_rate_t"}, prior.inputs
        assert set(likelihood.inputs) == {"gate_rate_t"}, likelihood.inputs
        return prior, likelihood
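The gating pattern above (a Categorical over the "gated" choice, a Stack of a Poisson and a point mass at zero, then a logaddexp reduce) is a zero-inflated Poisson likelihood. As a standalone scalar sketch with hypothetical values, where fdist is funsor.distributions:

gate = Tensor(torch.tensor([0.3, 0.7]))  # P(not gated), P(gated)
rate = Tensor(torch.tensor(2.5))
count = Tensor(torch.tensor(4.))
ll = fdist.Categorical(gate, value="gated")
ll += funsor.Stack("gated", (fdist.Poisson(rate, value=count),
                             fdist.Delta(0, value=count)))
ll = ll.reduce(ops.logaddexp, "gated")  # log p(count) under the mixture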
Example #17
    def __init__(self,
                 initial_dist,
                 transition_matrix,
                 transition_dist,
                 observation_matrix,
                 observation_dist,
                 validate_args=None):
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_matrix, torch.Tensor)
        assert isinstance(transition_dist,
                          torch.distributions.MultivariateNormal)
        assert isinstance(observation_matrix, torch.Tensor)
        assert isinstance(observation_dist,
                          torch.distributions.MultivariateNormal)
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
        assert initial_dist.event_shape == (hidden_dim, )
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_dist.event_shape == (hidden_dim, )
        assert observation_dist.event_shape == (obs_dim, )
        shape = broadcast_shape(initial_dist.batch_shape + (1, ),
                                transition_matrix.shape[:-2],
                                transition_dist.batch_shape,
                                observation_matrix.shape[:-2],
                                observation_dist.batch_shape)
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim, )

        # Convert distributions to funsors.
        init = dist_to_funsor(initial_dist)(value="state")
        trans = matrix_and_mvn_to_funsor(transition_matrix, transition_dist,
                                         ("time", ), "state", "state(time=1)")
        obs = matrix_and_mvn_to_funsor(observation_matrix, observation_dist,
                                       ("time", ), "state(time=1)", "value")
        dtype = "real"

        # Construct the joint funsor.
        with interpretation(lazy):
            value = Variable("value", Reals[time_shape[0], obs_dim])
            result = trans + obs(value=value["time"])
            result = MarkovProduct(ops.logaddexp, ops.add, result, "time",
                                   {"state": "state(time=1)"})
            result = init + result.reduce(ops.logaddexp, "state(time=1)")
            funsor_dist = result.reduce(ops.logaddexp, "state")

        super(GaussianHMM, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
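As with GaussianMRF above, a hypothetical smoke test with shapes chosen to satisfy the asserts:

hidden_dim, obs_dim, num_steps = 2, 2, 5
initial_dist = random_mvn((), hidden_dim)
transition_matrix = torch.randn(num_steps, hidden_dim, hidden_dim)
transition_dist = random_mvn((num_steps,), hidden_dim)
observation_matrix = torch.randn(num_steps, hidden_dim, obs_dim)
observation_dist = random_mvn((num_steps,), obs_dim)
d = GaussianHMM(initial_dist, transition_matrix, transition_dist,
                observation_matrix, observation_dist)
logp = d.log_prob(torch.randn(num_steps, obs_dim))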
Example #18
def test_funsor_to_mvn(batch_shape, event_shape, real_size):
    expected = random_mvn(batch_shape + event_shape, real_size)
    event_dims = tuple("abc"[:len(event_shape)])
    ndims = len(expected.batch_shape)

    funsor_ = dist_to_funsor(expected, event_dims)(value="value")
    assert isinstance(funsor_, Funsor)

    actual = funsor_to_mvn(funsor_, ndims, event_dims)
    assert isinstance(actual, dist.MultivariateNormal)
    assert actual.batch_shape == expected.batch_shape
    assert_close(actual.loc, expected.loc, atol=1e-3, rtol=None)
    assert_close(actual.precision_matrix,
                 expected.precision_matrix,
                 atol=1e-3,
                 rtol=None)
Example #19
    def __init__(self,
                 initial_logits,
                 transition_logits,
                 observation_dist,
                 validate_args=None):
        assert isinstance(initial_logits, torch.Tensor)
        assert isinstance(transition_logits, torch.Tensor)
        assert isinstance(observation_dist, torch.distributions.Distribution)
        assert initial_logits.dim() >= 1
        assert transition_logits.dim() >= 2
        assert len(observation_dist.batch_shape) >= 1
        shape = broadcast_shape(initial_logits.shape[:-1] + (1, ),
                                transition_logits.shape[:-2],
                                observation_dist.batch_shape[:-1])
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + observation_dist.event_shape
        self._has_rsample = observation_dist.has_rsample

        # Normalize.
        initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        transition_logits = transition_logits - transition_logits.logsumexp(
            -1, True)

        # Convert tensors and distributions to funsors.
        init = tensor_to_funsor(initial_logits, ("state", ))
        trans = tensor_to_funsor(transition_logits,
                                 ("time", "state", "state(time=1)"))
        obs = dist_to_funsor(observation_dist, ("time", "state(time=1)"))
        dtype = obs.inputs["value"].dtype

        # Construct the joint funsor.
        with interpretation(lazy):
            # TODO perform math here once sequential_sum_product has been
            #   implemented as a first-class funsor.
            funsor_dist = Variable("value",
                                   obs.inputs["value"])  # a bogus value
            # Until funsor_dist is defined, we save factors for hand-computation in .log_prob().
            self._init = init
            self._trans = trans
            self._obs = obs

        super(DiscreteHMM, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
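Since funsor_dist here is only a placeholder, .log_prob() presumably combines the saved factors by hand. A sketch of that computation, patterned on the lazy block in the GaussianHMM example above (the names and the exact substitution are assumptions):

from funsor.sum_product import sequential_sum_product

def _log_prob_sketch(self, data):
    # Hypothetical: plug the data into the observation factor, eliminate
    # time with an associative scan, then sum out both state variables.
    value = tensor_to_funsor(data, ("time",))
    result = self._trans + self._obs(value=value)
    result = sequential_sum_product(ops.logaddexp, ops.add, result,
                                    "time", {"state": "state(time=1)"})
    result = self._init + result.reduce(ops.logaddexp, "state(time=1)")
    return result.reduce(ops.logaddexp, "state")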
Example #20
def test_funsor_to_cat_and_mvn(batch_shape, event_shape, int_size, real_size):
    logits = torch.randn(batch_shape + event_shape + (int_size, ))
    expected_cat = dist.Categorical(logits=logits)
    expected_mvn = random_mvn(batch_shape + event_shape + (int_size, ),
                              real_size)
    event_dims = tuple("abc"[:len(event_shape)]) + ("component", )
    ndims = len(expected_cat.batch_shape)

    funsor_ = (tensor_to_funsor(logits, event_dims) +
               dist_to_funsor(expected_mvn, event_dims)(value="value"))
    assert isinstance(funsor_, Funsor)

    actual_cat, actual_mvn = funsor_to_cat_and_mvn(funsor_, ndims, event_dims)
    assert isinstance(actual_cat, dist.Categorical)
    assert isinstance(actual_mvn, dist.MultivariateNormal)
    assert actual_cat.batch_shape == expected_cat.batch_shape
    assert actual_mvn.batch_shape == expected_mvn.batch_shape
    assert_close(actual_cat.logits, expected_cat.logits, atol=1e-4, rtol=None)
    assert_close(actual_mvn.loc, expected_mvn.loc, atol=1e-4, rtol=None)
    assert_close(actual_mvn.precision_matrix,
                 expected_mvn.precision_matrix,
                 atol=1e-4,
                 rtol=None)
Example #21
    def forward(self, features, trip_counts):
        pyro.module("guide", self)
        assert features.dim() == 2
        assert trip_counts.dim() == 3
        observed_hours = len(trip_counts)
        log_counts = trip_counts.reshape(observed_hours, -1).log1p()
        loc_scale = ((self.diag_part * log_counts.unsqueeze(-2)).reshape(
            observed_hours, -1) + self.lowrank(
                torch.cat([features[:observed_hours], log_counts], dim=-1)))
        loc, scale = loc_scale.reshape(observed_hours, 2, -1).unbind(1)
        scale = bounded_exp(scale, bound=10.)

        if self.args.funsor:
            diag_normal = dist.Normal(loc, scale).to_event(2)
            return dist_to_funsor(diag_normal)(value="gate_rate_t")

        pyro.sample("gate_rate", dist.Normal(loc, scale).to_event(2))

        if self.args.mean_field:
            time = torch.arange(observed_hours,
                                dtype=features.dtype,
                                device=features.device)
            temp = torch.cat([
                time.unsqueeze(-1),
                (observed_hours - 1 - time).unsqueeze(-1),
                features[:observed_hours],
                log_counts,
            ],
                             dim=-1)
            temp = self.mf_layer_0(temp).sigmoid()
            temp = self.mf_layer_1(temp).sigmoid()
            temp = (self.mf_highpass * temp +
                    self.mf_lowpass * temp.mean(0, keepdim=True))
            temp = torch.cat([temp[:1], temp], dim=0)  # copy initial state.
            loc = temp[:, :self.args.state_dim]
            scale = bounded_exp(temp[:, self.args.state_dim:], bound=10.)
            pyro.sample("state", dist.Normal(loc, scale).to_event(2))
Example #22
    def forward(self, observations, add_bias=True):
        obs_dim = 2 * self.num_sensors
        bias_scale = self.log_bias_scale.exp()
        obs_noise = self.log_obs_noise.exp()
        trans_noise = self.log_trans_noise.exp()

        # bias distribution
        bias = Variable('bias', reals(obs_dim))
        assert not torch.isnan(bias_scale), "bias_scale was NaN"
        bias_dist = dist_to_funsor(
            dist.MultivariateNormal(
                torch.zeros(obs_dim),
                scale_tril=bias_scale *
                torch.eye(2 * self.num_sensors)))(value=bias)

        init_dist = torch.distributions.MultivariateNormal(torch.zeros(4),
                                                           scale_tril=100. *
                                                           torch.eye(4))
        self.init = dist_to_funsor(init_dist)(value="state")

        # hidden states
        prev = Variable("prev", reals(4))
        curr = Variable("curr", reals(4))
        self.trans_dist = f_dist.MultivariateNormal(
            loc=prev @ NCV_TRANSITION_MATRIX,
            scale_tril=trans_noise * NCV_PROCESS_NOISE.cholesky(),
            value=curr)

        state = Variable('state', reals(4))
        obs = Variable("obs", reals(obs_dim))
        observation_matrix = Tensor(
            torch.eye(4,
                      2).unsqueeze(-1).expand(-1, -1,
                                              self.num_sensors).reshape(4, -1))
        assert observation_matrix.output.shape == (
            4, obs_dim), observation_matrix.output.shape
        obs_loc = state @ observation_matrix
        if add_bias:
            obs_loc += bias
        self.observation_dist = f_dist.MultivariateNormal(
            loc=obs_loc, scale_tril=obs_noise * torch.eye(obs_dim), value=obs)

        logp = bias_dist
        curr = "state_init"
        logp += self.init(state=curr)
        for t, x in enumerate(observations):
            prev, curr = curr, f"state_{t}"
            logp += self.trans_dist(prev=prev, curr=curr)
            logp += self.observation_dist(state=curr, obs=x)
            # marginalize out previous state
            logp = logp.reduce(ops.logaddexp, prev)
        # marginalize out bias variable
        logp = logp.reduce(ops.logaddexp, "bias")

        # save posterior over the final state
        assert set(logp.inputs) == {f'state_{len(observations) - 1}'}
        posterior = funsor_to_mvn(logp, ndims=0)

        # marginalize out remaining variables
        logp = logp.reduce(ops.logaddexp)
        assert isinstance(logp, Tensor) and logp.shape == (), logp.pretty()
        return logp.data, posterior
Example #23
    def __init__(self,
                 initial_logits,
                 initial_mvn,
                 transition_logits,
                 transition_matrix,
                 transition_mvn,
                 observation_matrix,
                 observation_mvn,
                 exact=False,
                 validate_args=None):
        assert isinstance(initial_logits, torch.Tensor)
        assert isinstance(initial_mvn, torch.distributions.MultivariateNormal)
        assert isinstance(transition_logits, torch.Tensor)
        assert isinstance(transition_matrix, torch.Tensor)
        assert isinstance(transition_mvn,
                          torch.distributions.MultivariateNormal)
        assert isinstance(observation_matrix, torch.Tensor)
        assert isinstance(observation_mvn,
                          torch.distributions.MultivariateNormal)
        hidden_cardinality = initial_logits.size(-1)
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
        assert initial_mvn.event_shape[0] == hidden_dim
        assert transition_logits.size(-1) == hidden_cardinality
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_mvn.event_shape[0] == hidden_dim
        assert observation_mvn.event_shape[0] == obs_dim
        init_shape = broadcast_shape(initial_logits.shape,
                                     initial_mvn.batch_shape)
        shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                                transition_logits.shape[:-1],
                                transition_matrix.shape[:-2],
                                transition_mvn.batch_shape,
                                observation_matrix.shape[:-2],
                                observation_mvn.batch_shape)
        assert shape[-1] == hidden_cardinality
        batch_shape, time_shape = shape[:-2], shape[-2:-1]
        event_shape = time_shape + (obs_dim, )

        # Normalize.
        initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        transition_logits = transition_logits - transition_logits.logsumexp(
            -1, True)

        # Convert tensors and distributions to funsors.
        init = (tensor_to_funsor(initial_logits, ("class", )) +
                dist_to_funsor(initial_mvn, ("class", ))(value="state"))
        trans = (tensor_to_funsor(transition_logits,
                                  ("time", "class", "class(time=1)")) +
                 matrix_and_mvn_to_funsor(transition_matrix, transition_mvn,
                                          ("time", "class(time=1)"), "state",
                                          "state(time=1)"))
        obs = matrix_and_mvn_to_funsor(observation_matrix, observation_mvn,
                                       ("time", "class(time=1)"),
                                       "state(time=1)", "value")
        if "class(time=1)" not in set(trans.inputs).union(obs.inputs):
            raise ValueError(
                "neither transition nor observation depend on discrete state")
        dtype = "real"

        # Construct the joint funsor.
        with interpretation(lazy):
            # TODO perform math here once sequential_sum_product has been
            #   implemented as a first-class funsor.
            funsor_dist = Variable("value",
                                   obs.inputs["value"])  # a bogus value
            # Until funsor_dist is defined, we save factors for hand-computation in .log_prob().
            self._init = init
            self._trans = trans
            self._obs = obs

        super(SwitchingLinearHMM,
              self).__init__(funsor_dist, batch_shape, event_shape, dtype,
                             validate_args)
        self.exact = exact
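Finally, a hypothetical smoke test for this switching constructor; K is the discrete cardinality and the shapes follow the asserts above:

K, hidden_dim, obs_dim, num_steps = 2, 2, 2, 5
initial_logits = torch.randn(K)
initial_mvn = random_mvn((K,), hidden_dim)
transition_logits = torch.randn(num_steps, K, K)
transition_matrix = torch.randn(num_steps, K, hidden_dim, hidden_dim)
transition_mvn = random_mvn((num_steps, K), hidden_dim)
observation_matrix = torch.randn(num_steps, K, hidden_dim, obs_dim)
observation_mvn = random_mvn((num_steps, K), obs_dim)
d = SwitchingLinearHMM(initial_logits, initial_mvn, transition_logits,
                       transition_matrix, transition_mvn,
                       observation_matrix, observation_mvn)
logp = d.log_prob(torch.randn(num_steps, obs_dim))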