Пример #1
0
    def __init__(self,
                 initial_dist,
                 transition_dist,
                 observation_dist,
                 validate_args=None):
        """Build a funsor-backed Gaussian Markov random field distribution.

        :param initial_dist: a ``MultivariateNormal`` over the initial hidden
            state, with ``event_shape == (hidden_dim,)``.
        :param transition_dist: a ``MultivariateNormal`` over a pair of
            successive hidden states, with ``event_shape == (2 * hidden_dim,)``.
        :param observation_dist: a ``MultivariateNormal`` over the
            concatenation of one hidden state and one observation, with
            ``event_shape == (hidden_dim + obs_dim,)``.
        :param validate_args: forwarded to the base distribution constructor.
        """
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_dist,
                          torch.distributions.MultivariateNormal)
        assert isinstance(observation_dist,
                          torch.distributions.MultivariateNormal)
        # Infer sizes from the event shapes: the transition covers a
        # (state, next_state) pair, the observation covers (state, obs).
        hidden_dim = initial_dist.event_shape[0]
        assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
        obs_dim = observation_dist.event_shape[0] - hidden_dim
        # The rightmost batch dim of the transition/observation dists is time;
        # initial_dist gets a size-1 placeholder there for broadcasting.
        shape = broadcast_shape(initial_dist.batch_shape + (1, ),
                                transition_dist.batch_shape,
                                observation_dist.batch_shape)
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim, )

        # Convert distributions to funsors.
        init = dist_to_funsor(initial_dist)(value="state")
        trans = mvn_to_funsor(
            transition_dist, ("time", ),
            OrderedDict([("state", Reals[hidden_dim]),
                         ("state(time=1)", Reals[hidden_dim])]))
        obs = mvn_to_funsor(
            observation_dist, ("time", ),
            OrderedDict([("state(time=1)", Reals[hidden_dim]),
                         ("value", Reals[obs_dim])]))

        # Construct the joint funsor.
        # Compare with pyro.distributions.hmm.GaussianMRF.log_prob().
        with interpretation(lazy):
            time = Variable("time", Bint[time_shape[0]])
            value = Variable("value", Reals[time_shape[0], obs_dim])
            # logp_oh: log-density of (obs, hidden) with hidden states
            # contracted out over the whole time series.
            logp_oh = trans + obs(value=value["time"])
            logp_oh = MarkovProduct(ops.logaddexp, ops.add, logp_oh, time,
                                    {"state": "state(time=1)"})
            logp_oh += init
            logp_oh = logp_oh.reduce(ops.logaddexp,
                                     frozenset({"state", "state(time=1)"}))
            # logp_h: same contraction with the observation marginalized out;
            # this acts as the normalizer of the unnormalized joint above.
            logp_h = trans + obs.reduce(ops.logaddexp, "value")
            logp_h = MarkovProduct(ops.logaddexp, ops.add, logp_h, time,
                                   {"state": "state(time=1)"})
            logp_h += init
            logp_h = logp_h.reduce(ops.logaddexp,
                                   frozenset({"state", "state(time=1)"}))
            # Normalized log-probability of the observations.
            funsor_dist = logp_oh - logp_h

        dtype = "real"
        super(GaussianMRF, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
Пример #2
0
def test_mvn_to_funsor(batch_shape, event_shape, event_sizes):
    """Check that mvn_to_funsor preserves named inputs and log-density.

    Builds a random MVN whose event is split across several real inputs,
    converts it to a funsor, and verifies (1) the funsor exposes exactly the
    expected inputs and (2) substituting a sample reproduces ``mvn.log_prob``.
    """
    event_size = sum(event_sizes)
    mvn = random_mvn(batch_shape + event_shape, event_size)
    int_inputs = OrderedDict(
        (k, bint(size)) for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict(
        (k, reals(size)) for k, size in zip("xyz", event_sizes))

    f = mvn_to_funsor(mvn, tuple(int_inputs), real_inputs)
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            # Size-1 event dims are squeezed away, so the *name* must be
            # absent. (Bug fix: previously tested `d not in f.inputs`, which
            # checked the domain object as a key and was vacuously true.)
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    for k, d in real_inputs.items():
        assert k in f.inputs
        assert f.inputs[k] == d

    # Split one joint sample into the per-input slices and substitute.
    value = mvn.sample()
    subs = {}
    beg = 0
    for k, d in real_inputs.items():
        end = beg + d.num_elements
        subs[k] = tensor_to_funsor(value[..., beg:end], tuple(int_inputs), 1)
        beg = end
    actual_log_prob = f(**subs)
    expected_log_prob = tensor_to_funsor(mvn.log_prob(value),
                                         tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
Пример #3
0
def test_matrix_and_mvn_to_funsor(batch_shape, event_shape, x_size, y_size):
    """Check matrix_and_mvn_to_funsor against an explicit log-density.

    The conditional funsor for ``y | x`` (affine map ``matrix`` plus noise
    ``y_mvn``) is added to a joint funsor over ``(x, y)``; substituting
    concrete values must match the sum of the two MVN log-probs.
    """
    matrix = torch.randn(batch_shape + event_shape + (x_size, y_size))
    y_mvn = random_mvn(batch_shape + event_shape, y_size)
    xy_mvn = random_mvn(batch_shape + event_shape, x_size + y_size)
    int_inputs = OrderedDict(
        (k, bint(size)) for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict([("x", reals(x_size)), ("y", reals(y_size))])

    f = (matrix_and_mvn_to_funsor(matrix, y_mvn, tuple(int_inputs), "x", "y") +
         mvn_to_funsor(xy_mvn, tuple(int_inputs), real_inputs))
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            # Size-1 event dims are squeezed away, so the *name* must be
            # absent. (Bug fix: previously tested `d not in f.inputs`, which
            # checked the domain object as a key and was vacuously true.)
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    assert f.inputs["x"] == reals(x_size)
    assert f.inputs["y"] == reals(y_size)

    xy = torch.randn(x_size + y_size)
    x, y = xy[:x_size], xy[x_size:]
    # Predicted mean of y given x under the affine map.
    y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    actual_log_prob = f(x=x, y=y)
    expected_log_prob = tensor_to_funsor(
        xy_mvn.log_prob(xy) + y_mvn.log_prob(y - y_pred), tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=1e-4)
Пример #4
0
    def __init__(self,
                 num_components,   # the number of switching states K
                 hidden_dim,       # the dimension of the continuous latent space
                 obs_dim,          # the dimension of the continuous outputs
                 fine_transition_matrix=True,    # controls whether the transition matrix depends on s_t
                 fine_transition_noise=False,    # controls whether the transition noise depends on s_t
                 fine_observation_matrix=False,  # controls whether the observation matrix depends on s_t
                 fine_observation_noise=False,   # controls whether the observation noise depends on s_t
                 moment_matching_lag=1):         # controls the expense of the moment matching approximation
        """Initialize the parameters of a switching linear dynamical system.

        Each ``fine_*`` flag adds a leading ``num_components`` dimension to
        the corresponding parameter, giving every discrete switching state
        its own copy; at least one flag must be set so the continuous
        dynamics are actually coupled to the discrete state.

        :param int num_components: number of discrete switching states K.
        :param int hidden_dim: dimension of the continuous latent space.
        :param int obs_dim: dimension of the continuous observations.
        :param int moment_matching_lag: positive lag controlling the cost of
            the moment-matching approximation.
        """

        self.num_components = num_components
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
        self.moment_matching_lag = moment_matching_lag
        self.fine_transition_noise = fine_transition_noise
        self.fine_observation_matrix = fine_observation_matrix
        self.fine_observation_noise = fine_observation_noise
        self.fine_transition_matrix = fine_transition_matrix

        assert moment_matching_lag > 0
        assert fine_transition_noise or fine_observation_matrix or fine_observation_noise or fine_transition_matrix, \
            "The continuous dynamics need to be coupled to the discrete dynamics in at least one way [use at " + \
            "least one of the arguments --ftn --ftm --fon --fom]"

        super(SLDS, self).__init__()

        # initialize the various parameters of the model
        self.transition_logits = nn.Parameter(0.1 * torch.randn(num_components, num_components))
        # Transition matrix is initialized near the identity (small random
        # perturbation), i.e. near a random walk in the latent space.
        if fine_transition_matrix:
            transition_matrix = torch.eye(hidden_dim) + 0.05 * torch.randn(num_components, hidden_dim, hidden_dim)
        else:
            transition_matrix = torch.eye(hidden_dim) + 0.05 * torch.randn(hidden_dim, hidden_dim)
        self.transition_matrix = nn.Parameter(transition_matrix)
        # Noise scales are parameterized in log space to keep them positive.
        if fine_transition_noise:
            self.log_transition_noise = nn.Parameter(0.1 * torch.randn(num_components, hidden_dim))
        else:
            self.log_transition_noise = nn.Parameter(0.1 * torch.randn(hidden_dim))
        if fine_observation_matrix:
            self.observation_matrix = nn.Parameter(0.3 * torch.randn(num_components, hidden_dim, obs_dim))
        else:
            self.observation_matrix = nn.Parameter(0.3 * torch.randn(hidden_dim, obs_dim))
        if fine_observation_noise:
            self.log_obs_noise = nn.Parameter(0.1 * torch.randn(num_components, obs_dim))
        else:
            self.log_obs_noise = nn.Parameter(0.1 * torch.randn(obs_dim))

        # define the prior distribution p(x_0) over the continuous latent at the initial time step t=0
        x_init_mvn = pyro.distributions.MultivariateNormal(torch.zeros(self.hidden_dim), torch.eye(self.hidden_dim))
        self.x_init_mvn = mvn_to_funsor(x_init_mvn, real_inputs=OrderedDict([('x_0', funsor.Reals[self.hidden_dim])]))