def test_distributions(state_dim, obs_dim):
    data = Tensor(torch.randn(2, obs_dim))["time"]

    # Prior over a shared observation bias.
    bias = Variable("bias", reals(obs_dim))
    bias_dist = dist_to_funsor(random_mvn((), obs_dim))(value=bias)

    # Gaussian transition model, expressed via the residual curr - prev @ trans_mat.
    prev = Variable("prev", reals(state_dim))
    curr = Variable("curr", reals(state_dim))
    trans_mat = Tensor(torch.eye(state_dim) + 0.1 * torch.randn(state_dim, state_dim))
    trans_mvn = random_mvn((), state_dim)
    trans_dist = dist.MultivariateNormal(
        loc=trans_mvn.loc,
        scale_tril=trans_mvn.scale_tril,
        value=curr - prev @ trans_mat)

    # Gaussian observation model, expressed via the residual state @ obs_mat + bias - obs.
    state = Variable("state", reals(state_dim))
    obs = Variable("obs", reals(obs_dim))
    obs_mat = Tensor(torch.randn(state_dim, obs_dim))
    obs_mvn = random_mvn((), obs_dim)
    obs_dist = dist.MultivariateNormal(
        loc=obs_mvn.loc,
        scale_tril=obs_mvn.scale_tril,
        value=state @ obs_mat + bias - obs)

    # Unroll a two-step state space model and marginalize everything out.
    log_prob = 0
    log_prob += bias_dist

    state_0 = Variable("state_0", reals(state_dim))
    log_prob += obs_dist(state=state_0, obs=data(time=0))

    state_1 = Variable("state_1", reals(state_dim))
    log_prob += trans_dist(prev=state_0, curr=state_1)
    log_prob += obs_dist(state=state_1, obs=data(time=1))

    log_prob = log_prob.reduce(ops.logaddexp)
    assert isinstance(log_prob, Tensor), log_prob.pretty()
def test_mvn_gaussian(batch_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    loc = Tensor(torch.randn(batch_shape + (3,)), inputs)
    scale_tril = Tensor(_random_scale_tril(batch_shape + (3, 3)), inputs)
    value = Tensor(torch.randn(batch_shape + (3,)), inputs)

    expected = dist.MultivariateNormal(loc, scale_tril, value)
    assert isinstance(expected, Tensor)
    check_funsor(expected, inputs, reals())

    g = dist.MultivariateNormal(loc, scale_tril, 'value')
    assert isinstance(g, Contraction)
    actual = g(value=value)
    check_funsor(actual, inputs, reals())
    assert_close(actual, expected, atol=1e-3, rtol=1e-4)
def test_mvn_affine_matmul():
    x = Variable('x', reals(2))
    y = Variable('y', reals(3))
    m = Tensor(torch.randn(2, 3))
    data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(3)))
    with interpretation(lazy):
        d = random_mvn((), 3)
        d = dist.MultivariateNormal(loc=y, scale_tril=d.scale_tril, value=x @ m)
    _check_mvn_affine(d, data)
def test_mvn_sample(with_lazy, batch_shape, sample_inputs, event_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    loc = Tensor(torch.randn(batch_shape + event_shape), inputs)
    scale_tril = Tensor(_random_scale_tril(batch_shape + event_shape * 2), inputs)
    with interpretation(lazy if with_lazy else eager):
        funsor_dist = dist.MultivariateNormal(loc, scale_tril)

    _check_sample(funsor_dist, sample_inputs, inputs, atol=5e-2, num_samples=200000)
def test_mvn_density(batch_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    @funsor.torch.function(reals(3), reals(3, 3), reals(3), reals())
    def mvn(loc, scale_tril, value):
        return torch.distributions.MultivariateNormal(
            loc, scale_tril=scale_tril).log_prob(value)

    check_funsor(mvn, {'loc': reals(3), 'scale_tril': reals(3, 3), 'value': reals(3)}, reals())

    loc = Tensor(torch.randn(batch_shape + (3,)), inputs)
    scale_tril = Tensor(_random_scale_tril(batch_shape + (3, 3)), inputs)
    value = Tensor(torch.randn(batch_shape + (3,)), inputs)
    expected = mvn(loc, scale_tril, value)
    check_funsor(expected, inputs, reals())

    actual = dist.MultivariateNormal(loc, scale_tril, value)
    check_funsor(actual, inputs, reals())
    assert_close(actual, expected)
def test_mvn_defaults():
    loc = Variable('loc', reals(3))
    scale_tril = Variable('scale', reals(3, 3))
    value = Variable('value', reals(3))
    assert (dist.MultivariateNormal(loc, scale_tril)
            is dist.MultivariateNormal(loc, scale_tril, value))
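# NOTE: the MVN tests above rely on small test utilities (random_mvn,
# _random_scale_tril) defined elsewhere in the suite. The sketches below are
# assumptions about what such helpers could look like, not the suite's exact
# implementations; they only show the shapes and distributions involved.
def _random_scale_tril(shape):
    # Random lower-triangular Cholesky factor with a positive diagonal,
    # obtained by mapping unconstrained noise through the lower_cholesky transform.
    data = torch.randn(shape)
    return torch.distributions.transform_to(
        torch.distributions.constraints.lower_cholesky)(data)


def random_mvn(batch_shape, dim):
    # Batched torch MultivariateNormal with random loc and a random
    # positive definite covariance matrix.
    loc = torch.randn(batch_shape + (dim,))
    cov = torch.randn(batch_shape + (dim, 2 * dim))
    cov = cov @ cov.transpose(-1, -2) + 0.5 * torch.eye(dim)
    return torch.distributions.MultivariateNormal(loc, covariance_matrix=cov)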
def test_bart(analytic_kl):
    global call_count
    call_count = 0

    with interpretation(reflect):
        # Variational distribution q over the packed "gate_rate_t" variable.
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[-0.6077086925506592, -1.1546266078948975, -0.7021151781082153, -0.5303535461425781, -0.6365622282028198, -1.2423288822174072, -0.9941254258155823, -0.6287292242050171],
                                 [-0.6987162828445435, -1.0875964164733887, -0.7337473630905151, -0.4713417589664459, -0.6674002408981323, -1.2478348016738892, -0.8939017057418823, -0.5238542556762695]],
                                dtype=torch.float32),  # noqa
                            (('time_b4', bint(2)), ('_event_1_b2', bint(8))),
                            'real'),
                        Gaussian(
                            torch.tensor(
                                [[[-0.3536059558391571], [-0.21779225766658783], [0.2840439975261688], [0.4531521499156952], [-0.1220812276005745], [-0.05519985035061836], [0.10932210087776184], [0.6656699776649475]],
                                 [[-0.39107921719551086], [-0.20241987705230713], [0.2170514464378357], [0.4500560462474823], [0.27945515513420105], [-0.0490039587020874], [-0.06399798393249512], [0.846565842628479]]],
                                dtype=torch.float32),  # noqa
                            torch.tensor(
                                [[[[1.984686255455017]], [[0.6699360013008118]], [[1.6215802431106567]], [[2.372016668319702]], [[1.77385413646698]], [[0.526767373085022]], [[0.8722561597824097]], [[2.1879124641418457]]],
                                 [[[1.6996612548828125]], [[0.7535632252693176]], [[1.4946647882461548]], [[2.642792224884033]], [[1.7301604747772217]], [[0.5203893780708313]], [[1.055436372756958]], [[2.8370864391326904]]]],
                                dtype=torch.float32),  # noqa
                            (('time_b4', bint(2)), ('_event_1_b2', bint(8)), ('value_b1', reals()))),
                    )),
                'gate_rate_b3', '_event_1_b2', 'value_b1'),
            'gate_rate_t', 'time_b4', 'gate_rate_b3')

        # Prior over the latent state chain, marginalizing the boundary states.
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736, dtype=torch.float32),
                                (),
                                'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0], dtype=torch.float32),
                                torch.tensor(
                                    [[98.01002502441406, 0.0, -99.0000228881836, -0.0],
                                     [0.0, 98.01002502441406, -0.0, -99.0000228881836],
                                     [-99.0000228881836, -0.0, 100.0000228881836, 0.0],
                                     [-0.0, -99.0000228881836, 0.0, 100.0000228881836]],
                                    dtype=torch.float32),
                                (('state_b7', reals(2)), ('state(time=1)_b8', reals(2)))),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[0.03488487750291824, 0.07356668263673782, 0.19946961104869843, 0.5386509299278259, -0.708323061466217, 0.24411526322364807, -0.20855577290058136, -0.2421337217092514],
                                             [0.41762110590934753, 0.5272183418273926, -0.49835553765296936, -0.0363837406039238, -0.0005282597267068923, 0.2704298794269562, -0.155222088098526, -0.44802337884902954]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[-0.003566693514585495, -0.2848514914512634, 0.037103548645973206, 0.12648648023605347, -0.18501518666744232, -0.20899859070777893, 0.04121830314397812, 0.0054807960987091064],
                                             [0.0021788496524095535, -0.18700894713401794, 0.08187370002269745, 0.13554862141609192, -0.10477752983570099, -0.20848378539085388, -0.01393645629286766, 0.011670656502246857]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b9', bint(2)),),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[0.5974780917167664, 0.864071786403656, 1.0236268043518066, 0.7147538065910339, 0.7423890233039856, 0.9462157487869263, 1.2132389545440674, 1.0596832036972046],
                                             [0.5787821412086487, 0.9178534150123596, 0.9074794054031372, 0.6600189208984375, 0.8473222255706787, 0.8426999449729919, 1.194266438484192, 1.0471148490905762]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b9', bint(2)),),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2)),
                                    Variable('gate_rate_b6', reals(8))),
                                (('gate_rate_b6',
                                  Binary(ops.GetitemOp(0),
                                         Variable('gate_rate_t', reals(2, 8)),
                                         Variable('time_b9', bint(2)))),)),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32), (), 'real'),
                        Tensor(torch.tensor([[10.0, 0.0], [0.0, 10.0]], dtype=torch.float32), (), 'real'),
                        Variable('value_b5', reals(2))),
                    (('value_b5', Variable('state_b10', reals(2))),)),
            ))

        # Likelihood over gated Poisson counts, indexed by time, origin and destination.
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0,
                                            reals(2, 2, 2),
                                            (Variable('gate_rate_b12', reals(8)),)),
                                        (('gate_rate_b12',
                                          Binary(ops.GetitemOp(0),
                                                 Variable('gate_rate_t', reals(2, 8)),
                                                 Variable('time_b17', bint(2)))),)),
                                    Variable('origin_b15', bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2),
                                                    (Variable('gate_rate_b13', reals(8)),)),
                                                (('gate_rate_b13',
                                                  Binary(ops.GetitemOp(0),
                                                         Variable('gate_rate_t', reals(2, 8)),
                                                         Variable('time_b17', bint(2)))),)),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),
                                        (('time_b17', bint(2)),
                                         ('origin_b15', bint(2)),
                                         ('destin_b16', bint(2))),
                                        'real')),
                                dist.Delta(
                                    Tensor(torch.tensor(0.0, dtype=torch.float32), (), 'real'),
                                    Tensor(torch.tensor(0.0, dtype=torch.float32), (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),
                                        (('time_b17', bint(2)),
                                         ('origin_b15', bint(2)),
                                         ('destin_b16', bint(2))),
                                        'real')),
                            )),
                    )),
            ))

    if analytic_kl:
        # Compute the prior term exactly and the likelihood term by Monte Carlo.
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1
def forward(self, observations, add_bias=True):
    obs_dim = 2 * self.num_sensors
    bias_scale = self.log_bias_scale.exp()
    obs_noise = self.log_obs_noise.exp()
    trans_noise = self.log_trans_noise.exp()

    # bias distribution
    bias = Variable('bias', reals(obs_dim))
    assert not torch.isnan(bias_scale), "bias scale was nan"
    bias_dist = dist_to_funsor(
        dist.MultivariateNormal(
            torch.zeros(obs_dim),
            scale_tril=bias_scale * torch.eye(2 * self.num_sensors)))(value=bias)

    # initial state distribution
    init_dist = torch.distributions.MultivariateNormal(
        torch.zeros(4), scale_tril=100. * torch.eye(4))
    self.init = dist_to_funsor(init_dist)(value="state")

    # hidden states
    prev = Variable("prev", reals(4))
    curr = Variable("curr", reals(4))
    self.trans_dist = f_dist.MultivariateNormal(
        loc=prev @ NCV_TRANSITION_MATRIX,
        scale_tril=trans_noise * NCV_PROCESS_NOISE.cholesky(),
        value=curr)

    # observation model: each sensor observes the 2-d position components of the 4-d state
    state = Variable('state', reals(4))
    obs = Variable("obs", reals(obs_dim))
    observation_matrix = Tensor(
        torch.eye(4, 2).unsqueeze(-1).expand(-1, -1, self.num_sensors).reshape(4, -1))
    assert observation_matrix.output.shape == (4, obs_dim), observation_matrix.output.shape
    obs_loc = state @ observation_matrix
    if add_bias:
        obs_loc += bias
    self.observation_dist = f_dist.MultivariateNormal(
        loc=obs_loc, scale_tril=obs_noise * torch.eye(obs_dim), value=obs)

    logp = bias_dist
    curr = "state_init"
    logp += self.init(state=curr)
    for t, x in enumerate(observations):
        prev, curr = curr, f"state_{t}"
        logp += self.trans_dist(prev=prev, curr=curr)
        logp += self.observation_dist(state=curr, obs=x)
        # marginalize out the previous state
        logp = logp.reduce(ops.logaddexp, prev)
    # marginalize out the bias variable
    logp = logp.reduce(ops.logaddexp, "bias")

    # save the posterior over the final state
    assert set(logp.inputs) == {f'state_{len(observations) - 1}'}
    posterior = funsor_to_mvn(logp, ndims=0)

    # marginalize out remaining variables
    logp = logp.reduce(ops.logaddexp)
    assert isinstance(logp, Tensor) and logp.shape == (), logp.pretty()
    return logp.data, posterior
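# NOTE: forward() above reads attributes of its owning module (num_sensors,
# log_bias_scale, log_obs_noise, log_trans_noise) and module-level constants
# (NCV_TRANSITION_MATRIX, NCV_PROCESS_NOISE). The class below is a hypothetical
# sketch showing only how those pieces could be wired together; the class name
# and initial parameter values are assumptions, not the original model definition.
class SensorHMM(torch.nn.Module):
    def __init__(self, num_sensors):
        super().__init__()
        self.num_sensors = num_sensors
        # learnable log-scales read by forward()
        self.log_bias_scale = torch.nn.Parameter(torch.zeros(()))
        self.log_obs_noise = torch.nn.Parameter(torch.zeros(()))
        self.log_trans_noise = torch.nn.Parameter(torch.zeros(()))

    forward = forward  # bind the module-level forward() defined above as the method


# Usage sketch: two sensors give 4-dimensional observations per time step.
# model = SensorHMM(num_sensors=2)
# observations = [Tensor(torch.randn(4)) for _ in range(10)]
# log_marginal, final_state_posterior = model(observations)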