Example #1
def test_matrix_and_mvn_to_funsor(batch_shape, event_shape, x_size, y_size):
    matrix = torch.randn(batch_shape + event_shape + (x_size, y_size))
    y_mvn = random_mvn(batch_shape + event_shape, y_size)
    xy_mvn = random_mvn(batch_shape + event_shape, x_size + y_size)
    int_inputs = OrderedDict(
        (k, bint(size)) for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict([("x", reals(x_size)), ("y", reals(y_size))])

    f = (matrix_and_mvn_to_funsor(matrix, y_mvn, tuple(int_inputs), "x", "y") +
         mvn_to_funsor(xy_mvn, tuple(int_inputs), real_inputs))
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    assert f.inputs["x"] == reals(x_size)
    assert f.inputs["y"] == reals(y_size)

    xy = torch.randn(x_size + y_size)
    x, y = xy[:x_size], xy[x_size:]
    y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    actual_log_prob = f(x=x, y=y)
    expected_log_prob = tensor_to_funsor(
        xy_mvn.log_prob(xy) + y_mvn.log_prob(y - y_pred), tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=1e-4)
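Every example in this listing builds its inputs with a random_mvn helper that is not shown here. The following is only a plausible sketch, reconstructed from the call sites (random_mvn(batch_shape, dim), optionally with diag=True); the body is an assumption, not the tests' actual code:

import torch
from torch.distributions import MultivariateNormal

def random_mvn(batch_shape, dim, diag=False):
    # Hypothetical sketch: a random MVN with the requested batch shape
    # and event size. A @ A.T with A of shape (dim, 2*dim) is PSD
    # (and almost surely positive definite).
    loc = torch.randn(batch_shape + (dim,))
    a = torch.randn(batch_shape + (dim, 2 * dim))
    cov = a.matmul(a.transpose(-1, -2))
    if diag:
        cov = cov * torch.eye(dim)  # keep only the diagonal entries
    return MultivariateNormal(loc, covariance_matrix=cov)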
Example #2
def test_gaussian_hmm_log_prob(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim,
                               obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)

    actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                              obs_dist)
    expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                                     obs_dist)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    shape = broadcast_shape(init_shape + (1, ), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)
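Several tests also call a check_expand helper that does not appear in this listing. Below is a minimal sketch of what it presumably verifies (that .expand() and log_prob broadcast consistently); the body is an assumption:

import torch

def check_expand(old_dist, old_data):
    # Hypothetical sketch: prepend a batch dimension with .expand() and
    # check that log_prob of the expanded distribution matches the original.
    new_batch_shape = (2,) + old_dist.batch_shape
    new_dist = old_dist.expand(new_batch_shape)
    new_data = old_data.expand(new_batch_shape + new_dist.event_shape)
    new_log_prob = new_dist.log_prob(new_data)
    assert new_log_prob.shape == new_batch_shape
    assert torch.allclose(new_log_prob[0], old_dist.log_prob(old_data))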
Example #3
def test_switching_linear_hmm_log_prob(exact, num_steps, hidden_dim, obs_dim,
                                       num_components):
    # This tests agreement between an SLDS and an HMM when all components
    # are identical, so that the discrete latent state can be marginalized out.
    torch.manual_seed(2)
    init_logits = torch.rand(num_components)
    init_mvn = random_mvn((), hidden_dim)
    trans_logits = torch.rand(num_components)
    trans_matrix = torch.randn(hidden_dim, hidden_dim)
    trans_mvn = random_mvn((), hidden_dim)
    obs_matrix = torch.randn(hidden_dim, obs_dim)
    obs_mvn = random_mvn((), obs_dim)

    expected_dist = GaussianHMM(init_mvn,
                                trans_matrix.expand(num_steps, -1, -1),
                                trans_mvn, obs_matrix, obs_mvn)
    actual_dist = SwitchingLinearHMM(init_logits,
                                     init_mvn,
                                     trans_logits,
                                     trans_matrix.expand(
                                         num_steps, num_components, -1, -1),
                                     trans_mvn,
                                     obs_matrix,
                                     obs_mvn,
                                     exact=exact)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    data = obs_mvn.sample(expected_dist.batch_shape + (num_steps, ))
    assert data.shape == expected_dist.shape()
    expected_log_prob = expected_dist.log_prob(data)
    assert expected_log_prob.shape == expected_dist.batch_shape
    actual_log_prob = actual_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=None)
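The marginalization argument in the comment above rests on a simple identity: when every mixture component has the same density, the mixing weights sum out exactly. A standalone check of that identity in plain torch, independent of the SLDS code:

import torch

logits = torch.randn(5)                  # arbitrary mixing logits
log_p = torch.tensor(-1.234)             # shared per-component log density
mixture = torch.logsumexp(torch.log_softmax(logits, -1) + log_p, dim=-1)
assert torch.allclose(mixture, log_p)    # mixture collapses to log_p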
Example #4
def test_pyro_convert():
    data = Tensor(torch.randn(2, 2), OrderedDict([("time", bint(2))]))

    bias_dist = dist_to_funsor(random_mvn((), 2))

    trans_mat = torch.randn(3, 3)
    trans_mvn = random_mvn((), 3)
    trans = matrix_and_mvn_to_funsor(trans_mat, trans_mvn, (), "prev", "curr")

    obs_mat = torch.randn(3, 2)
    obs_mvn = random_mvn((), 2)
    obs = matrix_and_mvn_to_funsor(obs_mat, obs_mvn, (), "state", "obs")

    log_prob = 0
    bias = Variable("bias", reals(2))
    log_prob += bias_dist(value=bias)

    state_0 = Variable("state_0", reals(3))
    log_prob += obs(state=state_0, obs=bias + data(time=0))

    state_1 = Variable("state_1", reals(3))
    log_prob += trans(prev=state_0, curr=state_1)
    log_prob += obs(state=state_1, obs=bias + data(time=1))

    log_prob = log_prob.reduce(ops.logaddexp)
    assert isinstance(log_prob, Tensor), log_prob.pretty()
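The pattern above (accumulate log-density factors with +, then reduce(ops.logaddexp)) computes a log normalizer over the free variables. A discrete analogue in plain torch, shown only to illustrate the pattern; the funsor version performs Gaussian integrals instead of finite sums:

import torch

f = torch.randn(3, 4)   # log factor over discrete variables (a, b)
g = torch.randn(4)      # log factor over b; broadcasts against f
log_Z = torch.logsumexp(f + g, dim=(0, 1))  # marginalize both variables
assert log_Z.shape == ()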
Example #5
def test_distributions(state_dim, obs_dim):
    data = Tensor(torch.randn(2, obs_dim))["time"]

    bias = Variable("bias", reals(obs_dim))
    bias_dist = dist_to_funsor(random_mvn((), obs_dim))(value=bias)

    prev = Variable("prev", reals(state_dim))
    curr = Variable("curr", reals(state_dim))
    trans_mat = Tensor(
        torch.eye(state_dim) + 0.1 * torch.randn(state_dim, state_dim))
    trans_mvn = random_mvn((), state_dim)
    trans_dist = dist.MultivariateNormal(loc=trans_mvn.loc,
                                         scale_tril=trans_mvn.scale_tril,
                                         value=curr - prev @ trans_mat)

    state = Variable("state", reals(state_dim))
    obs = Variable("obs", reals(obs_dim))
    obs_mat = Tensor(torch.randn(state_dim, obs_dim))
    obs_mvn = random_mvn((), obs_dim)
    obs_dist = dist.MultivariateNormal(loc=obs_mvn.loc,
                                       scale_tril=obs_mvn.scale_tril,
                                       value=state @ obs_mat + bias - obs)

    log_prob = 0
    log_prob += bias_dist

    state_0 = Variable("state_0", reals(state_dim))
    log_prob += obs_dist(state=state_0, obs=data(time=0))

    state_1 = Variable("state_1", reals(state_dim))
    log_prob += trans_dist(prev=state_0, curr=state_1)
    log_prob += obs_dist(state=state_1, obs=data(time=1))

    log_prob = log_prob.reduce(ops.logaddexp)
    assert isinstance(log_prob, Tensor), log_prob.pretty()
Example #6
def test_mvn_affine_getitem():
    x = Variable('x', reals(2, 2))
    data = dict(x=Tensor(torch.randn(2, 2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))
        d = d(value=x[0] - x[1])
    _check_mvn_affine(d, data)
Example #7
def test_mvn_to_funsor(batch_shape, event_shape, event_sizes):
    event_size = sum(event_sizes)
    mvn = random_mvn(batch_shape + event_shape, event_size)
    int_inputs = OrderedDict(
        (k, bint(size)) for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict(
        (k, reals(size)) for k, size in zip("xyz", event_sizes))

    f = mvn_to_funsor(mvn, tuple(int_inputs), real_inputs)
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    for k, d in real_inputs.items():
        assert k in f.inputs
        assert f.inputs[k] == d

    value = mvn.sample()
    subs = {}
    beg = 0
    for k, d in real_inputs.items():
        end = beg + d.num_elements
        subs[k] = tensor_to_funsor(value[..., beg:end], tuple(int_inputs), 1)
        beg = end
    actual_log_prob = f(**subs)
    expected_log_prob = tensor_to_funsor(mvn.log_prob(value),
                                         tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
Example #8
def test_mvn_affine_one_var():
    x = Variable('x', reals(2))
    data = dict(x=Tensor(torch.randn(2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))
        d = d(value=2 * x + 1)
    _check_mvn_affine(d, data)
Example #9
def test_gaussian_mrf_log_prob(init_shape, trans_shape, obs_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_dist = random_mvn(trans_shape, hidden_dim + hidden_dim)
    obs_dist = random_mvn(obs_shape, hidden_dim + obs_dim)

    actual_dist = GaussianMRF(init_dist, trans_dist, obs_dist)
    expected_dist = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    data = obs_dist.expand(batch_shape).sample()[..., hidden_dim:]
    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=1e-4)
    check_expand(actual_dist, data)
Example #10
def test_mvn_affine_reshape():
    x = Variable('x', reals(2, 2))
    y = Variable('y', reals(4))
    data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(4)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 4))
        d = d(value=x.reshape((4, )) - y)
    _check_mvn_affine(d, data)
Example #11
def test_mvn_affine_two_vars():
    x = Variable('x', Reals[2])
    y = Variable('y', Reals[2])
    data = dict(x=Tensor(randn(2)), y=Tensor(randn(2)))
    with interpretation(lazy):
        d = to_funsor(random_mvn((), 2), Real)
        d = d(value=x - y)
    _check_mvn_affine(d, data)
Example #12
def test_mvn_affine_matmul():
    x = Variable('x', reals(2))
    y = Variable('y', reals(3))
    m = Tensor(torch.randn(2, 3))
    data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(3)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 3))
        d = d(value=x @ m - y)
    _check_mvn_affine(d, data)
Example #13
def test_mvn_affine_matmul_sub():
    x = Variable('x', Reals[2])
    y = Variable('y', Reals[3])
    m = Tensor(randn(2, 3))
    data = dict(x=Tensor(randn(2)), y=Tensor(randn(3)))
    with interpretation(lazy):
        d = to_funsor(random_mvn((), 3), Real)
        d = d(value=x @ m - y)
    _check_mvn_affine(d, data)
Example #14
def test_mvn_affine_einsum():
    c = Tensor(torch.randn(3, 2, 2))
    x = Variable('x', reals(2, 2))
    y = Variable('y', reals())
    data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(())))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 3))
        d = d(value=Einsum("abc,bc->a", c, x) + y)
    _check_mvn_affine(d, data)
Example #15
def test_mvn_affine_matmul():
    x = Variable('x', Reals[2])
    y = Variable('y', Reals[3])
    m = Tensor(randn(2, 3))
    data = dict(x=Tensor(randn(2)), y=Tensor(randn(3)))
    with interpretation(lazy):
        d = random_mvn((), 3)
        d = dist.MultivariateNormal(loc=y, scale_tril=d.scale_tril, value=x @ m)
    _check_mvn_affine(d, data)
Example #16
def test_gaussian_hmm_log_prob_null_dynamics(init_shape, trans_mat_shape,
                                             trans_mvn_shape, obs_mvn_shape,
                                             hidden_dim):
    obs_dim = hidden_dim
    init_dist = random_mvn(init_shape, hidden_dim)

    # impose null dynamics
    trans_mat = torch.zeros(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim, diag=True)

    # trivial observation matrix (hidden_dim = obs_dim)
    obs_mat = torch.eye(hidden_dim)
    obs_dist = random_mvn(obs_mvn_shape, obs_dim, diag=True)

    actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                              obs_dist)
    expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                                     obs_dist)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    shape = broadcast_shape(init_shape + (1, ), trans_mat_shape,
                            trans_mvn_shape, obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)

    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)

    obs_cov = obs_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    trans_cov = trans_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    sum_scale = (obs_cov + trans_cov).sqrt()
    sum_loc = trans_dist.loc + obs_dist.loc

    analytic_log_prob = dist.Normal(sum_loc,
                                    sum_scale).log_prob(data).sum(-1).sum(-1)
    assert_close(analytic_log_prob, actual_log_prob, atol=1.0e-5)
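The analytic check at the end works because, with a zero transition matrix and an identity observation matrix, each observation is a sum of two independent diagonal Gaussians, hence Normal with summed means and variances. A standalone Monte Carlo sanity check of that fact (plain torch, not part of the test):

import torch
from torch.distributions import Normal

z = Normal(0.5, 1.0).sample((200000,))   # hidden-state noise
e = Normal(-0.2, 2.0).sample((200000,))  # observation noise
x = z + e                                # x ~ Normal(0.3, sqrt(1 + 4))
assert abs(x.mean().item() - 0.3) < 0.05
assert abs(x.std().item() - 5 ** 0.5) < 0.05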
Example #17
def test_switching_linear_hmm_shape(init_cat_shape, init_mvn_shape,
                                    trans_cat_shape, trans_mat_shape, trans_mvn_shape,
                                    obs_mat_shape, obs_mvn_shape):
    hidden_dim, obs_dim = obs_mat_shape[-2:]
    assert trans_mat_shape[-2:] == (hidden_dim, hidden_dim)

    init_logits = torch.randn(init_cat_shape)
    init_mvn = random_mvn(init_mvn_shape, hidden_dim)
    trans_logits = torch.randn(trans_cat_shape)
    trans_matrix = torch.randn(trans_mat_shape)
    trans_mvn = random_mvn(trans_mvn_shape, hidden_dim)
    obs_matrix = torch.randn(obs_mat_shape)
    obs_mvn = random_mvn(obs_mvn_shape, obs_dim)

    init_shape = broadcast_shape(init_cat_shape, init_mvn_shape)
    shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                            trans_cat_shape[:-1],
                            trans_mat_shape[:-2],
                            trans_mvn_shape,
                            obs_mat_shape[:-2],
                            obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-2], shape[-2:-1]
    expected_event_shape = time_shape + (obs_dim,)

    actual_dist = SwitchingLinearHMM(init_logits, init_mvn,
                                     trans_logits, trans_matrix, trans_mvn,
                                     obs_matrix, obs_mvn)
    assert actual_dist.event_shape == expected_event_shape
    assert actual_dist.batch_shape == expected_batch_shape

    data = obs_mvn.expand(shape).sample()[..., 0, :]
    actual_log_prob = actual_dist.log_prob(data)
    assert actual_log_prob.shape == expected_batch_shape
    check_expand(actual_dist, data)

    final_cat, final_mvn = actual_dist.filter(data)
    assert isinstance(final_cat, dist.Categorical)
    assert isinstance(final_mvn, dist.MultivariateNormal)
    assert final_cat.batch_shape == actual_dist.batch_shape
    assert final_mvn.batch_shape == actual_dist.batch_shape + final_cat.logits.shape[-1:]
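The broadcast_shape helper used above presumably follows standard broadcasting semantics; for compatible shapes, torch.broadcast_shapes gives the same result, e.g.:

import torch

assert torch.broadcast_shapes((5, 1, 4), (3, 1), (5, 3, 4)) == (5, 3, 4)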
Example #18
def test_funsor_to_mvn(batch_shape, event_shape, real_size):
    expected = random_mvn(batch_shape + event_shape, real_size)
    event_dims = tuple("abc"[:len(event_shape)])
    ndims = len(expected.batch_shape)

    funsor_ = dist_to_funsor(expected, event_dims)(value="value")
    assert isinstance(funsor_, Funsor)

    actual = funsor_to_mvn(funsor_, ndims, event_dims)
    assert isinstance(actual, dist.MultivariateNormal)
    assert actual.batch_shape == expected.batch_shape
    assert_close(actual.loc, expected.loc, atol=1e-3, rtol=None)
    assert_close(actual.precision_matrix,
                 expected.precision_matrix,
                 atol=1e-3,
                 rtol=None)
Example #19
def test_funsor_to_cat_and_mvn(batch_shape, event_shape, int_size, real_size):
    logits = torch.randn(batch_shape + event_shape + (int_size, ))
    expected_cat = dist.Categorical(logits=logits)
    expected_mvn = random_mvn(batch_shape + event_shape + (int_size, ),
                              real_size)
    event_dims = tuple("abc"[:len(event_shape)]) + ("component", )
    ndims = len(expected_cat.batch_shape)

    funsor_ = (tensor_to_funsor(logits, event_dims) +
               dist_to_funsor(expected_mvn, event_dims)(value="value"))
    assert isinstance(funsor_, Funsor)

    actual_cat, actual_mvn = funsor_to_cat_and_mvn(funsor_, ndims, event_dims)
    assert isinstance(actual_cat, dist.Categorical)
    assert isinstance(actual_mvn, dist.MultivariateNormal)
    assert actual_cat.batch_shape == expected_cat.batch_shape
    assert actual_mvn.batch_shape == expected_mvn.batch_shape
    assert_close(actual_cat.logits, expected_cat.logits, atol=1e-4, rtol=None)
    assert_close(actual_mvn.loc, expected_mvn.loc, atol=1e-4, rtol=None)
    assert_close(actual_mvn.precision_matrix,
                 expected_mvn.precision_matrix,
                 atol=1e-4,
                 rtol=None)
Example #20
def test_switching_linear_hmm_log_prob_alternating(exact, num_steps,
                                                   num_components):
    # This tests agreement between an SLDS and an HMM in the case where the
    # two SLDS discrete states alternate deterministically between 0 and 1.

    torch.manual_seed(0)

    hidden_dim = 4
    obs_dim = 3
    extra_components = num_components - 2

    init_logits = torch.tensor([float("-inf"), 0.0] +
                               extra_components * [float("-inf")])
    init_mvn = random_mvn((num_components, ), hidden_dim)

    left_logits = torch.tensor([0.0, float("-inf")] +
                               extra_components * [float("-inf")])
    right_logits = torch.tensor([float("-inf"), 0.0] +
                                extra_components * [float("-inf")])
    trans_logits = torch.stack([
        left_logits if t % 2 == 0 else right_logits for t in range(num_steps)
    ])
    trans_logits = trans_logits.unsqueeze(-2)

    hmm_trans_matrix = torch.randn(num_steps, hidden_dim, hidden_dim)
    # .clone() so the in-place writes in the loop below are legal
    # (index assignment into an expanded view raises a RuntimeError)
    switching_trans_matrix = hmm_trans_matrix.unsqueeze(-3).expand(
        -1, num_components, -1, -1).clone()

    trans_mvn = random_mvn((
        num_steps,
        num_components,
    ), hidden_dim)
    hmm_obs_matrix = torch.randn(num_steps, hidden_dim, obs_dim)
    switching_obs_matrix = hmm_obs_matrix.unsqueeze(-3).expand(
        -1, num_components, -1, -1).clone()  # writable copy, as above
    obs_mvn = random_mvn((num_steps, num_components), obs_dim)

    hmm_trans_mvn_loc = torch.empty(num_steps, hidden_dim)
    hmm_trans_mvn_cov = torch.empty(num_steps, hidden_dim, hidden_dim)
    hmm_obs_mvn_loc = torch.empty(num_steps, obs_dim)
    hmm_obs_mvn_cov = torch.empty(num_steps, obs_dim, obs_dim)

    for t in range(num_steps):
        # select the entries the HMM sees, given the deterministic discrete dynamics
        s = t % 2  # 0, 1, 0, 1, ...
        hmm_trans_mvn_loc[t] = trans_mvn.loc[t, s]
        hmm_trans_mvn_cov[t] = trans_mvn.covariance_matrix[t, s]
        hmm_obs_mvn_loc[t] = obs_mvn.loc[t, s]
        hmm_obs_mvn_cov[t] = obs_mvn.covariance_matrix[t, s]

        # scramble matrices at indices the deterministic discrete dynamics should never access
        s = 1 - (t % 2)  # 1, 0, 1, 0, ...
        switching_trans_matrix[t, s, :, :] = torch.rand(hidden_dim, hidden_dim)
        switching_obs_matrix[t, s, :, :] = torch.rand(hidden_dim, obs_dim)

    expected_dist = GaussianHMM(
        dist.MultivariateNormal(init_mvn.loc[1],
                                init_mvn.covariance_matrix[1]),
        hmm_trans_matrix,
        dist.MultivariateNormal(hmm_trans_mvn_loc,
                                hmm_trans_mvn_cov), hmm_obs_matrix,
        dist.MultivariateNormal(hmm_obs_mvn_loc, hmm_obs_mvn_cov))

    actual_dist = SwitchingLinearHMM(init_logits,
                                     init_mvn,
                                     trans_logits,
                                     switching_trans_matrix,
                                     trans_mvn,
                                     switching_obs_matrix,
                                     obs_mvn,
                                     exact=exact)

    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    data = obs_mvn.sample()[:, 0, :]
    assert data.shape == expected_dist.shape()
    expected_log_prob = expected_dist.log_prob(data)
    assert expected_log_prob.shape == expected_dist.batch_shape
    actual_log_prob = actual_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-2, rtol=None)
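The deterministic alternation in this test is encoded purely through -inf logits: a row with a single finite entry puts all probability mass on that component. A quick standalone check:

import torch

logits = torch.tensor([0.0, float("-inf"), float("-inf")])
probs = torch.softmax(logits, dim=-1)
assert torch.allclose(probs, torch.tensor([1.0, 0.0, 0.0]))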