def test_matrix_and_mvn_to_funsor(batch_shape, event_shape, x_size, y_size):
    """Check that matrix_and_mvn_to_funsor and mvn_to_funsor compose additively.

    Builds a conditional Gaussian p(y|x) from (matrix, y_mvn), adds a joint
    Gaussian over (x, y), then compares the evaluated log-density against the
    direct torch computation ``xy_mvn.log_prob(xy) + y_mvn.log_prob(y - x @ matrix)``.
    """
    matrix = torch.randn(batch_shape + event_shape + (x_size, y_size))
    y_mvn = random_mvn(batch_shape + event_shape, y_size)
    xy_mvn = random_mvn(batch_shape + event_shape, x_size + y_size)
    int_inputs = OrderedDict(
        (k, bint(size)) for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict([("x", reals(x_size)), ("y", reals(y_size))])

    f = (matrix_and_mvn_to_funsor(matrix, y_mvn, tuple(int_inputs), "x", "y") +
         mvn_to_funsor(xy_mvn, tuple(int_inputs), real_inputs))
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            # BUGFIX: the original asserted ``d not in f.inputs``, which is
            # trivially true because f.inputs is keyed by input *name*, not by
            # domain. Size-1 batch inputs are squeezed away, so the check is
            # that the name itself is absent.
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    assert f.inputs["x"] == reals(x_size)
    assert f.inputs["y"] == reals(y_size)

    xy = torch.randn(x_size + y_size)
    x, y = xy[:x_size], xy[x_size:]
    y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    actual_log_prob = f(x=x, y=y)
    expected_log_prob = tensor_to_funsor(
        xy_mvn.log_prob(xy) + y_mvn.log_prob(y - y_pred), tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=1e-4)
def test_gaussian_hmm_log_prob(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    """Funsor-backed GaussianHMM.log_prob must agree with pyro's dist.GaussianHMM."""
    init = random_mvn(init_shape, hidden_dim)
    trans_matrix = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans = random_mvn(trans_mvn_shape, hidden_dim)
    obs_matrix = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs = random_mvn(obs_mvn_shape, obs_dim)

    actual = GaussianHMM(init, trans_matrix, trans, obs_matrix, obs)
    expected = dist.GaussianHMM(init, trans_matrix, trans, obs_matrix, obs)
    assert actual.batch_shape == expected.batch_shape
    assert actual.event_shape == expected.event_shape

    # Broadcast all parameter shapes; init contributes a length-1 time dim.
    full_shape = broadcast_shape(init_shape + (1,), trans_mat_shape,
                                 trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    data = obs.expand(full_shape).sample()
    assert data.shape == actual.shape()

    actual_log_prob = actual.log_prob(data)
    expected_log_prob = expected.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual, data)
def test_switching_linear_hmm_log_prob(exact, num_steps, hidden_dim, obs_dim,
                                       num_components):
    """SLDS with identical components must agree with the equivalent GaussianHMM.

    When every mixture component shares the same linear dynamics, the discrete
    latent can be marginalized out exactly, so the two log-probs must match.
    """
    torch.manual_seed(2)
    init_logits = torch.rand(num_components)
    init_mvn = random_mvn((), hidden_dim)
    trans_logits = torch.rand(num_components)
    trans_matrix = torch.randn(hidden_dim, hidden_dim)
    trans_mvn = random_mvn((), hidden_dim)
    obs_matrix = torch.randn(hidden_dim, obs_dim)
    obs_mvn = random_mvn((), obs_dim)

    hmm = GaussianHMM(init_mvn, trans_matrix.expand(num_steps, -1, -1),
                      trans_mvn, obs_matrix, obs_mvn)
    slds = SwitchingLinearHMM(
        init_logits, init_mvn, trans_logits,
        trans_matrix.expand(num_steps, num_components, -1, -1),
        trans_mvn, obs_matrix, obs_mvn, exact=exact)
    assert slds.batch_shape == hmm.batch_shape
    assert slds.event_shape == hmm.event_shape

    data = obs_mvn.sample(hmm.batch_shape + (num_steps,))
    assert data.shape == hmm.shape()
    expected_log_prob = hmm.log_prob(data)
    assert expected_log_prob.shape == hmm.batch_shape
    actual_log_prob = slds.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=None)
def test_pyro_convert():
    """Smoke-test a 2-step linear-Gaussian model with a shared observation bias.

    Accumulates the joint log-density over (bias, state_0, state_1), reduces
    out all free real variables with logaddexp, and expects a ground Tensor.
    """
    data = Tensor(torch.randn(2, 2), OrderedDict([("time", bint(2))]))

    bias_dist = dist_to_funsor(random_mvn((), 2))

    trans_mat = torch.randn(3, 3)
    trans_mvn = random_mvn((), 3)
    trans = matrix_and_mvn_to_funsor(trans_mat, trans_mvn, (), "prev", "curr")

    obs_mat = torch.randn(3, 2)
    obs_mvn = random_mvn((), 2)
    obs = matrix_and_mvn_to_funsor(obs_mat, obs_mvn, (), "state", "obs")

    bias = Variable("bias", reals(2))
    state_0 = Variable("state_0", reals(3))
    state_1 = Variable("state_1", reals(3))
    log_prob = (bias_dist(value=bias) +
                obs(state=state_0, obs=bias + data(time=0)) +
                trans(prev=state_0, curr=state_1) +
                obs(state=state_1, obs=bias + data(time=1)))

    log_prob = log_prob.reduce(ops.logaddexp)
    assert isinstance(log_prob, Tensor), log_prob.pretty()
def test_distributions(state_dim, obs_dim):
    """Build a 2-step model from funsor dist.MultivariateNormal factors directly.

    Same structure as test_pyro_convert, but the transition and observation
    factors are constructed as lazy MVN funsors over affine values rather than
    via matrix_and_mvn_to_funsor.
    """
    data = Tensor(torch.randn(2, obs_dim))["time"]

    bias = Variable("bias", reals(obs_dim))
    bias_dist = dist_to_funsor(random_mvn((), obs_dim))(value=bias)

    prev = Variable("prev", reals(state_dim))
    curr = Variable("curr", reals(state_dim))
    trans_mat = Tensor(torch.eye(state_dim) +
                       0.1 * torch.randn(state_dim, state_dim))
    trans_mvn = random_mvn((), state_dim)
    trans_dist = dist.MultivariateNormal(loc=trans_mvn.loc,
                                         scale_tril=trans_mvn.scale_tril,
                                         value=curr - prev @ trans_mat)

    state = Variable("state", reals(state_dim))
    obs = Variable("obs", reals(obs_dim))
    obs_mat = Tensor(torch.randn(state_dim, obs_dim))
    obs_mvn = random_mvn((), obs_dim)
    obs_dist = dist.MultivariateNormal(loc=obs_mvn.loc,
                                       scale_tril=obs_mvn.scale_tril,
                                       value=state @ obs_mat + bias - obs)

    state_0 = Variable("state_0", reals(state_dim))
    state_1 = Variable("state_1", reals(state_dim))
    log_prob = (bias_dist +
                obs_dist(state=state_0, obs=data(time=0)) +
                trans_dist(prev=state_0, curr=state_1) +
                obs_dist(state=state_1, obs=data(time=1)))
    log_prob = log_prob.reduce(ops.logaddexp)
    assert isinstance(log_prob, Tensor), log_prob.pretty()
def test_mvn_affine_getitem():
    """A difference of getitem slices substitutes into an MVN as an affine value."""
    x = Variable('x', reals(2, 2))
    point = dict(x=Tensor(torch.randn(2, 2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))(value=x[0] - x[1])
    _check_mvn_affine(d, point)
def test_mvn_to_funsor(batch_shape, event_shape, event_sizes):
    """mvn_to_funsor's density must match MultivariateNormal.log_prob.

    The single event dim of size ``sum(event_sizes)`` is split across real
    inputs "x", "y", "z"; substituting the matching slices of one sample must
    recover the original log-density.
    """
    event_size = sum(event_sizes)
    mvn = random_mvn(batch_shape + event_shape, event_size)
    int_inputs = OrderedDict(
        (k, bint(size)) for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict(
        (k, reals(size)) for k, size in zip("xyz", event_sizes))

    f = mvn_to_funsor(mvn, tuple(int_inputs), real_inputs)
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            # BUGFIX: was ``assert d not in f.inputs``, which is vacuous since
            # f.inputs is keyed by input *name*, not domain. Size-1 inputs are
            # squeezed away, so the name itself must be absent.
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    for k, d in real_inputs.items():
        assert k in f.inputs
        assert f.inputs[k] == d

    value = mvn.sample()
    subs = {}
    beg = 0
    for k, d in real_inputs.items():
        end = beg + d.num_elements
        subs[k] = tensor_to_funsor(value[..., beg:end], tuple(int_inputs), 1)
        beg = end
    actual_log_prob = f(**subs)
    expected_log_prob = tensor_to_funsor(mvn.log_prob(value), tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
def test_mvn_affine_one_var():
    """A scale-and-shift of a single variable substitutes into an MVN lazily."""
    x = Variable('x', reals(2))
    point = dict(x=Tensor(torch.randn(2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))(value=2 * x + 1)
    _check_mvn_affine(d, point)
def test_gaussian_mrf_log_prob(init_shape, trans_shape, obs_shape, hidden_dim, obs_dim):
    """Funsor-backed GaussianMRF.log_prob must agree with pyro's dist.GaussianMRF."""
    init = random_mvn(init_shape, hidden_dim)
    trans = random_mvn(trans_shape, hidden_dim + hidden_dim)
    obs = random_mvn(obs_shape, hidden_dim + obs_dim)

    actual = GaussianMRF(init, trans, obs)
    expected = dist.GaussianMRF(init, trans, obs)
    assert actual.event_shape == expected.event_shape
    assert actual.batch_shape == expected.batch_shape

    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    # Sample from the joint (hidden, obs) factor, keeping only the obs slice.
    data = obs.expand(batch_shape).sample()[..., hidden_dim:]
    actual_log_prob = actual.log_prob(data)
    expected_log_prob = expected.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=1e-4)
    check_expand(actual, data)
def test_mvn_affine_reshape():
    """Reshaping a matrix-shaped variable before subtraction stays affine."""
    x = Variable('x', reals(2, 2))
    y = Variable('y', reals(4))
    point = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(4)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 4))(value=x.reshape((4,)) - y)
    _check_mvn_affine(d, point)
def test_mvn_affine_two_vars():
    """The difference of two variables substitutes into an MVN lazily."""
    x = Variable('x', Reals[2])
    y = Variable('y', Reals[2])
    point = dict(x=Tensor(randn(2)), y=Tensor(randn(2)))
    with interpretation(lazy):
        d = to_funsor(random_mvn((), 2), Real)(value=x - y)
    _check_mvn_affine(d, point)
def test_mvn_affine_matmul():
    """A matmul-plus-difference expression substitutes into an MVN lazily."""
    x = Variable('x', reals(2))
    y = Variable('y', reals(3))
    proj = Tensor(torch.randn(2, 3))
    point = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(3)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 3))(value=x @ proj - y)
    _check_mvn_affine(d, point)
def test_mvn_affine_matmul_sub():
    """Matmul followed by subtraction substitutes into an MVN lazily."""
    x = Variable('x', Reals[2])
    y = Variable('y', Reals[3])
    proj = Tensor(randn(2, 3))
    point = dict(x=Tensor(randn(2)), y=Tensor(randn(3)))
    with interpretation(lazy):
        d = to_funsor(random_mvn((), 3), Real)(value=x @ proj - y)
    _check_mvn_affine(d, point)
def test_mvn_affine_einsum():
    """An Einsum contraction plus a scalar offset substitutes into an MVN."""
    coeffs = Tensor(torch.randn(3, 2, 2))
    x = Variable('x', reals(2, 2))
    y = Variable('y', reals())
    point = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(())))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 3))(
            value=Einsum("abc,bc->a", coeffs, x) + y)
    _check_mvn_affine(d, point)
def test_mvn_affine_matmul_loc():
    """Affine matmul substituted while the MVN's loc is itself a variable.

    BUGFIX: this function was named ``test_mvn_affine_matmul``, colliding with
    the earlier test of that name. Python keeps only the later definition, so
    one of the two tests was silently never collected by pytest. Renamed to
    make both run.
    """
    x = Variable('x', Reals[2])
    y = Variable('y', Reals[3])
    m = Tensor(randn(2, 3))
    data = dict(x=Tensor(randn(2)), y=Tensor(randn(3)))
    with interpretation(lazy):
        d = random_mvn((), 3)
        d = dist.MultivariateNormal(loc=y, scale_tril=d.scale_tril,
                                    value=x @ m)
    _check_mvn_affine(d, data)
def test_gaussian_hmm_log_prob_null_dynamics(init_shape, trans_mat_shape,
                                             trans_mvn_shape, obs_mvn_shape,
                                             hidden_dim):
    """GaussianHMM with zero dynamics, plus an analytic cross-check.

    With a zero transition matrix and an identity observation matrix, each
    observation is trans noise + obs noise, so the log-prob factorizes into
    independent Normals whose variances add.
    """
    obs_dim = hidden_dim
    init_dist = random_mvn(init_shape, hidden_dim)

    # Impose null dynamics: the latent state carries nothing forward.
    trans_mat = torch.zeros(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim, diag=True)

    # Trivial observation matrix (hidden_dim == obs_dim).
    obs_mat = torch.eye(hidden_dim)
    obs_dist = random_mvn(obs_mvn_shape, obs_dim, diag=True)

    actual = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    expected = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    assert actual.batch_shape == expected.batch_shape
    assert actual.event_shape == expected.event_shape

    shape = broadcast_shape(init_shape + (1,), trans_mat_shape,
                            trans_mvn_shape, obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual.shape()

    actual_log_prob = actual.log_prob(data)
    expected_log_prob = expected.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual, data)

    # Analytic check: diagonal variances of the two independent Gaussians add.
    obs_var = obs_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    trans_var = trans_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    sum_scale = (obs_var + trans_var).sqrt()
    sum_loc = trans_dist.loc + obs_dist.loc
    analytic = dist.Normal(sum_loc, sum_scale).log_prob(data).sum(-1).sum(-1)
    assert_close(analytic, actual_log_prob, atol=1.0e-5)
def test_switching_linear_hmm_shape(init_cat_shape, init_mvn_shape,
                                    trans_cat_shape, trans_mat_shape,
                                    trans_mvn_shape, obs_mat_shape,
                                    obs_mvn_shape):
    """Check batch/event shape broadcasting and .filter() of SwitchingLinearHMM."""
    hidden_dim, obs_dim = obs_mat_shape[-2:]
    assert trans_mat_shape[-2:] == (hidden_dim, hidden_dim)

    init_logits = torch.randn(init_cat_shape)
    init_mvn = random_mvn(init_mvn_shape, hidden_dim)
    trans_logits = torch.randn(trans_cat_shape)
    trans_matrix = torch.randn(trans_mat_shape)
    trans_mvn = random_mvn(trans_mvn_shape, hidden_dim)
    obs_matrix = torch.randn(obs_mat_shape)
    obs_mvn = random_mvn(obs_mvn_shape, obs_dim)

    # Expected shapes: the last two dims of the broadcast are (time, component).
    init_shape = broadcast_shape(init_cat_shape, init_mvn_shape)
    shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                            trans_cat_shape[:-1], trans_mat_shape[:-2],
                            trans_mvn_shape, obs_mat_shape[:-2], obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-2], shape[-2:-1]
    expected_event_shape = time_shape + (obs_dim,)

    actual_dist = SwitchingLinearHMM(init_logits, init_mvn, trans_logits,
                                     trans_matrix, trans_mvn, obs_matrix,
                                     obs_mvn)
    assert actual_dist.event_shape == expected_event_shape
    assert actual_dist.batch_shape == expected_batch_shape

    # Sample once per component and keep only component 0.
    data = obs_mvn.expand(shape).sample()[..., 0, :]
    actual_log_prob = actual_dist.log_prob(data)
    assert actual_log_prob.shape == expected_batch_shape
    check_expand(actual_dist, data)

    final_cat, final_mvn = actual_dist.filter(data)
    assert isinstance(final_cat, dist.Categorical)
    assert isinstance(final_mvn, dist.MultivariateNormal)
    assert final_cat.batch_shape == actual_dist.batch_shape
    assert final_mvn.batch_shape == (actual_dist.batch_shape +
                                     final_cat.logits.shape[-1:])
def test_funsor_to_mvn(batch_shape, event_shape, real_size):
    """Round-trip an MVN through dist_to_funsor and funsor_to_mvn."""
    expected = random_mvn(batch_shape + event_shape, real_size)
    event_dims = tuple("abc"[:len(event_shape)])
    ndims = len(expected.batch_shape)

    g = dist_to_funsor(expected, event_dims)(value="value")
    assert isinstance(g, Funsor)

    actual = funsor_to_mvn(g, ndims, event_dims)
    assert isinstance(actual, dist.MultivariateNormal)
    assert actual.batch_shape == expected.batch_shape
    assert_close(actual.loc, expected.loc, atol=1e-3, rtol=None)
    assert_close(actual.precision_matrix, expected.precision_matrix,
                 atol=1e-3, rtol=None)
def test_funsor_to_cat_and_mvn(batch_shape, event_shape, int_size, real_size):
    """Round-trip a (Categorical, MVN) pair through funsor_to_cat_and_mvn."""
    logits = torch.randn(batch_shape + event_shape + (int_size,))
    expected_cat = dist.Categorical(logits=logits)
    expected_mvn = random_mvn(batch_shape + event_shape + (int_size,), real_size)
    event_dims = tuple("abc"[:len(event_shape)]) + ("component",)
    ndims = len(expected_cat.batch_shape)

    # Joint funsor: categorical logits plus the MVN density in "value".
    g = (tensor_to_funsor(logits, event_dims) +
         dist_to_funsor(expected_mvn, event_dims)(value="value"))
    assert isinstance(g, Funsor)

    actual_cat, actual_mvn = funsor_to_cat_and_mvn(g, ndims, event_dims)
    assert isinstance(actual_cat, dist.Categorical)
    assert isinstance(actual_mvn, dist.MultivariateNormal)
    assert actual_cat.batch_shape == expected_cat.batch_shape
    assert actual_mvn.batch_shape == expected_mvn.batch_shape
    assert_close(actual_cat.logits, expected_cat.logits, atol=1e-4, rtol=None)
    assert_close(actual_mvn.loc, expected_mvn.loc, atol=1e-4, rtol=None)
    assert_close(actual_mvn.precision_matrix, expected_mvn.precision_matrix,
                 atol=1e-4, rtol=None)
def test_switching_linear_hmm_log_prob_alternating(exact, num_steps, num_components):
    """SLDS vs HMM when the discrete state alternates deterministically.

    The categorical logits are -inf everywhere except one component per step,
    so the SLDS follows a known discrete path; an equivalent GaussianHMM is
    assembled by selecting the continuous parameters along that path.
    """
    torch.manual_seed(0)
    hidden_dim = 4
    obs_dim = 3
    extra_components = num_components - 2

    # Start in discrete state 1, then alternate between states 0 and 1.
    init_logits = torch.tensor([float("-inf"), 0.0] +
                               extra_components * [float("-inf")])
    init_mvn = random_mvn((num_components,), hidden_dim)
    left_logits = torch.tensor([0.0, float("-inf")] +
                               extra_components * [float("-inf")])
    right_logits = torch.tensor([float("-inf"), 0.0] +
                                extra_components * [float("-inf")])
    trans_logits = torch.stack([
        left_logits if t % 2 == 0 else right_logits for t in range(num_steps)
    ])
    trans_logits = trans_logits.unsqueeze(-2)

    hmm_trans_matrix = torch.randn(num_steps, hidden_dim, hidden_dim)
    # BUGFIX: .expand() returns a view with stride 0 along the new component
    # dim, so the in-place scrambling below would either alias every component
    # (corrupting hmm_trans_matrix as well) or be rejected by PyTorch as a
    # write into overlapping memory. Clone to own independent storage.
    switching_trans_matrix = hmm_trans_matrix.unsqueeze(-3).expand(
        -1, num_components, -1, -1).clone()
    trans_mvn = random_mvn((num_steps, num_components,), hidden_dim)
    hmm_obs_matrix = torch.randn(num_steps, hidden_dim, obs_dim)
    # BUGFIX: clone for the same reason as switching_trans_matrix above.
    switching_obs_matrix = hmm_obs_matrix.unsqueeze(-3).expand(
        -1, num_components, -1, -1).clone()
    obs_mvn = random_mvn((num_steps, num_components), obs_dim)

    hmm_trans_mvn_loc = torch.empty(num_steps, hidden_dim)
    hmm_trans_mvn_cov = torch.empty(num_steps, hidden_dim, hidden_dim)
    hmm_obs_mvn_loc = torch.empty(num_steps, obs_dim)
    hmm_obs_mvn_cov = torch.empty(num_steps, obs_dim, obs_dim)
    for t in range(num_steps):
        # Select relevant bits for the hmm given deterministic dynamics
        # in discrete space.
        s = t % 2  # 0, 1, 0, 1, ...
        hmm_trans_mvn_loc[t] = trans_mvn.loc[t, s]
        hmm_trans_mvn_cov[t] = trans_mvn.covariance_matrix[t, s]
        hmm_obs_mvn_loc[t] = obs_mvn.loc[t, s]
        hmm_obs_mvn_cov[t] = obs_mvn.covariance_matrix[t, s]

        # Scramble matrices in places that should never be accessed given
        # deterministic dynamics in discrete space.
        s = 1 - (t % 2)  # 1, 0, 1, 0, ...
        switching_trans_matrix[t, s, :, :] = torch.rand(hidden_dim, hidden_dim)
        switching_obs_matrix[t, s, :, :] = torch.rand(hidden_dim, obs_dim)

    expected_dist = GaussianHMM(
        dist.MultivariateNormal(init_mvn.loc[1], init_mvn.covariance_matrix[1]),
        hmm_trans_matrix,
        dist.MultivariateNormal(hmm_trans_mvn_loc, hmm_trans_mvn_cov),
        hmm_obs_matrix,
        dist.MultivariateNormal(hmm_obs_mvn_loc, hmm_obs_mvn_cov))
    actual_dist = SwitchingLinearHMM(init_logits, init_mvn, trans_logits,
                                     switching_trans_matrix, trans_mvn,
                                     switching_obs_matrix, obs_mvn, exact=exact)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    data = obs_mvn.sample()[:, 0, :]
    assert data.shape == expected_dist.shape()
    expected_log_prob = expected_dist.log_prob(data)
    assert expected_log_prob.shape == expected_dist.batch_shape
    actual_log_prob = actual_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-2, rtol=None)