def test_delta_defaults():
    """Check Delta's two-argument form and its handling of a value given by name."""
    point_var = Variable('v', reals())
    log_weight = Variable('log_density', reals())

    # Leaving out the value argument must still yield a Delta instance.
    two_arg = dist.Delta(point_var, log_weight)
    assert isinstance(two_arg, dist.Delta)

    # Supplying the value as the string 'value' or as the equivalent
    # Variable must give back the very same object.
    value_var = Variable('value', reals())
    by_name = dist.Delta(point_var, log_weight, 'value')
    by_variable = dist.Delta(point_var, log_weight, value_var)
    assert by_name is by_variable
def test_delta_density(batch_shape, event_shape):
    """Compare ``dist.Delta`` against a hand-written torch reference density.

    Args:
        batch_shape: tuple of batch sizes (at most three entries, named
            'i', 'j', 'k' in that order).
        event_shape: tuple giving the event shape of the point and value.
    """
    # One bounded-integer input per batch dimension.
    dim_names = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict()
    for dim_name, dim_size in zip(dim_names, batch_shape):
        inputs[dim_name] = bint(dim_size)

    @funsor.torch.function(reals(*event_shape), reals(), reals(*event_shape), reals())
    def delta(v, log_density, value):
        # Collapse the elementwise comparison over every event dimension,
        # then take log(1) = 0 on a match and log(0) = -inf otherwise.
        matches = (v == value)
        for _ in event_shape:
            matches = matches.all(dim=-1)
        return matches.type(v.dtype).log() + log_density

    expected_signature = {
        'v': reals(*event_shape),
        'log_density': reals(),
        'value': reals(*event_shape),
    }
    check_funsor(delta, expected_signature, reals())

    full_shape = batch_shape + event_shape
    point = Tensor(torch.randn(full_shape), inputs)
    log_weight = Tensor(torch.randn(batch_shape).exp(), inputs)
    other_point = Tensor(torch.randn(full_shape), inputs)

    # Evaluate once at the support point itself and once at a fresh random point.
    for value in (point, other_point):
        expected = delta(point, log_weight, value)
        check_funsor(expected, inputs, reals())

        actual = dist.Delta(point, log_weight, value)
        check_funsor(actual, inputs, reals())

        assert_close(actual, expected)
def _forward_funsor(self, features, trip_counts):
    """Build prior and likelihood funsors over the ``gate_rate_t`` variable.

    Args:
        features: sequence whose length must equal the leading dimension
            of ``trip_counts``; passed (sliced) to ``self._dynamics``.
        trip_counts: object with a three-element ``.shape`` of
            ``(hours, self.num_stations, self.num_stations)``.

    Returns:
        tuple: ``(prior, likelihood)``, each asserted to have
        ``"gate_rate_t"`` as its only input.
    """
    total_hours = len(features)
    observed_hours, num_origins, num_destins = trip_counts.shape
    assert observed_hours == total_hours
    stations = self.num_stations
    assert num_origins == stations
    assert num_destins == stations

    # The free variable packs gate and rate for every (origin, destin) pair.
    packed_size = 2 * stations * stations
    packed_var = funsor.Variable("gate_rate_t", reals(observed_hours, packed_size))
    gate_rate = packed_var["time"]

    @funsor.torch.function(reals(packed_size), (reals(stations, stations, 2), reals(stations, stations)))
    def unpack_gate_rate(gate_rate):
        # Split the flat last dimension into two (stations, stations) halves.
        leading_shape = gate_rate.shape[:-1]
        halves = gate_rate.reshape(leading_shape + (2, stations, stations))
        raw_gate, raw_rate = halves.unbind(-3)
        clamped_gate = raw_gate.sigmoid().clamp(min=0.01, max=0.99)
        rate = bounded_exp(raw_rate, bound=1e4)
        # Last axis holds (1 - gate, gate).
        gate = torch.stack((1 - clamped_gate, clamped_gate), dim=-1)
        return gate, rate

    # Create a Gaussian latent dynamical system.
    dynamics = self._dynamics(features[:observed_hours])
    init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist = dynamics
    time_plate = ("time", )
    init = dist_to_funsor(init_dist)(value="state")
    trans = matrix_and_mvn_to_funsor(
        trans_matrix, trans_dist, time_plate, "state", "state(time=1)")
    obs = matrix_and_mvn_to_funsor(
        obs_matrix, obs_dist, time_plate, "state(time=1)", "gate_rate")

    # Compute dynamic prior over gate_rate: chain the per-step factors along
    # "time", add the initial factor, then eliminate both state variables.
    per_step = trans + obs(gate_rate=gate_rate)
    prior = MarkovProduct(
        ops.logaddexp, ops.add, per_step, "time", {"state": "state(time=1)"})
    prior += init
    prior = prior.reduce(ops.logaddexp, {"state", "state(time=1)"})

    # Compute zero-inflated Poisson likelihood: a "gated" switch selects
    # between a Poisson term and a Delta at zero, and is then summed out.
    gate, rate = unpack_gate_rate(gate_rate)
    likelihood = fdist.Categorical(gate["origin", "destin"], value="gated")
    counts = tensor_to_funsor(trip_counts, ("time", "origin", "destin"))
    poisson_term = fdist.Poisson(rate["origin", "destin"], value=counts)
    zero_term = fdist.Delta(0, value=counts)
    likelihood += funsor.Stack("gated", (poisson_term, zero_term))
    likelihood = likelihood.reduce(ops.logaddexp, "gated")
    likelihood = likelihood.reduce(ops.add, {"time", "origin", "destin"})

    assert set(prior.inputs) == {"gate_rate_t"}, prior.inputs
    assert set(likelihood.inputs) == {"gate_rate_t"}, likelihood.inputs
    return prior, likelihood
def test_delta_delta():
    """``dist.Delta`` at a Tensor point must be the same object as ``Delta('v', ...)``."""
    var_name = 'v'
    value_var = Variable(var_name, reals(2))
    location = Tensor(torch.randn(2))
    log_weight = Tensor(torch.tensor(0.5))

    from_distribution = dist.Delta(location, log_weight, value_var)
    expected = Delta(var_name, location, log_weight)

    # Identity, not just equality: both routes must produce one object.
    assert from_distribution is expected
def __call__(self):
    """Build the list of log-probability factors for the model.

    Reads sizes and switches from ``self.config``, parameter funsors from
    ``self.params``, observed values from ``self.observations`` and masks
    from ``self.raggedness_masks``.

    Returns:
        list: the combined HMM factor at index 0, followed by one factor
        for each group-level and individual-level random effect whose
        config is "discrete" or "continuous" (group first).
    """
    # Original author's note: calls pyro.param so that params are exposed
    # and constraints applied; should not create any new torch.Tensors
    # after __init__.
    self.initialize_params()

    num_states = self.config["sizes"]["state"]

    # initialize gamma to uniform (all zeros), indexed by the previous state
    gamma = Tensor(
        torch.zeros((num_states, num_states)),
        OrderedDict([("y_prev", bint(num_states))]),
    )

    num_random = self.config["sizes"]["random"]
    num_groups = self.config["sizes"]["group"]
    effect_factors = []

    plate_g = Tensor(torch.zeros(num_groups), OrderedDict([("g", bint(num_groups))]))

    # group-level random effects
    group_mode = self.config["group"]["random"]
    if group_mode == "discrete":
        # group-level discrete effect
        e_g = Variable("e_g", bint(num_random))
        group_prior = dist.Categorical(**self.params["e_g"])(value=e_g)
        effect_factors.append(plate_g + group_prior)
        eps_g = (plate_g + self.params["eps_g"]["theta"])(e_g=e_g)
    elif group_mode == "continuous":
        eps_g = Variable("eps_g", reals(num_states))
        group_prior = dist.Normal(**self.params["eps_g"])(value=eps_g)
        effect_factors.append(plate_g + group_prior)
    else:
        eps_g = to_funsor(0.)

    num_individuals = self.config["sizes"]["individual"]
    plate_i = Tensor(
        torch.zeros(num_individuals),
        OrderedDict([("i", bint(num_individuals))]),
    )

    # individual-level random effects
    individual_mode = self.config["individual"]["random"]
    if individual_mode == "discrete":
        # individual-level discrete effect, masked at t=0
        e_i = Variable("e_i", bint(num_random))
        individual_prior = dist.Categorical(**self.params["e_i"])(value=e_i)
        initial_mask = self.raggedness_masks["individual"](t=0)
        effect_factors.append(plate_g + plate_i + individual_prior * initial_mask)
        eps_i = plate_i + plate_g + self.params["eps_i"]["theta"](e_i=e_i)
    elif individual_mode == "continuous":
        eps_i = Variable("eps_i", reals(num_states))
        individual_prior = dist.Normal(**self.params["eps_i"])(value=eps_i)
        effect_factors.append(plate_g + plate_i + individual_prior)
    else:
        eps_i = to_funsor(0.)

    # add group-level and individual-level random effects to gamma
    gamma = gamma + eps_g + eps_i

    # we've accounted for all effects, now actually compute gamma_y
    gamma_y = gamma(y_prev="y(t=1)")

    y = Variable("y", bint(self.config["sizes"]["state"]))
    transition_probs = gamma_y.exp() / gamma_y.exp().sum()
    transition = dist.Categorical(probs=transition_probs)(value=y)
    y_dist = plate_g + plate_i + transition

    def at_current_state(param_funsors):
        # Substitute y for "y_curr" in every parameter funsor of one family.
        return {name: fn(y_curr=y) for name, fn in param_funsors.items()}

    # observation 1: step size
    step_obs = dist.Gamma(**at_current_state(self.params["step"]))(
        value=self.observations["step"])
    step_dist = plate_g + plate_i + step_obs

    # step size zero-inflation
    if self.config["zeroinflation"]:
        step_zi_probs = self.params["zi_step"]["zi_param"](y_curr=y)
        step_zi = dist.Categorical(probs=step_zi_probs)(value="zi_step")
        step_missing = dist.Delta(self.config["MISSING"], 0.)(
            value=self.observations["step"])
        step_zi_dist = plate_g + plate_i + step_missing
        step_mixture = step_zi + Stack("zi_step", (step_dist, step_zi_dist))
        step_dist = step_mixture.reduce(ops.logaddexp, "zi_step")

    # observation 2: step angle
    angle_obs = dist.VonMises(**at_current_state(self.params["angle"]))(
        value=self.observations["angle"])
    angle_dist = plate_g + plate_i + angle_obs

    # observation 3: dive activity
    omega_obs = dist.Beta(**at_current_state(self.params["omega"]))(
        value=self.observations["omega"])
    omega_dist = plate_g + plate_i + omega_obs

    # dive activity zero-inflation
    if self.config["zeroinflation"]:
        omega_zi_probs = self.params["zi_omega"]["zi_param"](y_curr=y)
        omega_zi = dist.Categorical(probs=omega_zi_probs)(value="zi_omega")
        omega_missing = dist.Delta(self.config["MISSING"], 0.)(
            value=self.observations["omega"])
        omega_zi_dist = plate_g + plate_i + omega_missing
        omega_mixture = omega_zi + Stack("zi_omega", (omega_dist, omega_zi_dist))
        omega_dist = omega_mixture.reduce(ops.logaddexp, "zi_omega")

    # finally, construct the term for parallel scan reduction
    emissions = step_dist + angle_dist + omega_dist
    masked = emissions * self.raggedness_masks["individual"]
    masked = masked * self.raggedness_masks["timestep"]
    # copy masking behavior of pyro.infer.TraceEnum_ELBO._compute_model_factors:
    # the transition term is added after masking, so it is left unmasked
    hmm_factor = masked + y_dist

    # The HMM factor goes first, ahead of any random-effect factors.
    return [hmm_factor] + effect_factors
def test_bart(analytic_kl):
    """Integrate hard-coded ``q``, ``p_prior`` and ``p_likelihood`` funsors into an ELBO.

    The three funsors are written out as explicit expression trees with
    literal tensor data, then integrated over ``"gate_rate_t"``.

    Args:
        analytic_kl: if truthy, ``p_prior - q`` is integrated under the
            ambient interpretation and only ``p_likelihood`` under
            ``monte_carlo``; otherwise ``p_prior + p_likelihood - q`` is
            integrated under ``monte_carlo`` in one step.

    The test passes when the result is a ``Tensor`` and the module-level
    ``call_count`` equals 1.
    """
    # Reset the module-level counter that is checked at the end.
    # NOTE(review): presumably incremented by the unpack_gate_rate_* helpers
    # defined elsewhere in this file -- confirm.
    global call_count
    call_count = 0

    # Build the expression trees under `reflect`.
    # NOTE(review): the extent of this `with` block was reconstructed from a
    # whitespace-mangled source; confirm that only q, p_prior and
    # p_likelihood belong inside it.
    with interpretation(reflect):
        # q: a Tensor term plus a Gaussian term, wrapped in two Independent layers.
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[
                                    -0.6077086925506592, -1.1546266078948975,
                                    -0.7021151781082153, -0.5303535461425781,
                                    -0.6365622282028198, -1.2423288822174072,
                                    -0.9941254258155823, -0.6287292242050171
                                ], [
                                    -0.6987162828445435, -1.0875964164733887,
                                    -0.7337473630905151, -0.4713417589664459,
                                    -0.6674002408981323, -1.2478348016738892,
                                    -0.8939017057418823, -0.5238542556762695
                                ]],
                                dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                            ),
                            'real'),
                        Gaussian(
                            torch.tensor([
                                [[-0.3536059558391571], [-0.21779225766658783],
                                 [0.2840439975261688], [0.4531521499156952],
                                 [-0.1220812276005745], [-0.05519985035061836],
                                 [0.10932210087776184], [0.6656699776649475]],
                                [[-0.39107921719551086], [
                                    -0.20241987705230713
                                ], [0.2170514464378357], [0.4500560462474823],
                                 [0.27945515513420105], [-0.0490039587020874],
                                 [-0.06399798393249512], [0.846565842628479]]
                            ], dtype=torch.float32),  # noqa
                            torch.tensor([
                                [[[1.984686255455017]], [[0.6699360013008118]],
                                 [[1.6215802431106567]], [[2.372016668319702]],
                                 [[1.77385413646698]], [[0.526767373085022]],
                                 [[0.8722561597824097]], [[2.1879124641418457]]
                                 ],
                                [[[1.6996612548828125]], [[
                                    0.7535632252693176
                                ]], [[1.4946647882461548]], [[2.642792224884033]],
                                 [[1.7301604747772217]], [[0.5203893780708313]],
                                 [[1.055436372756958]], [[2.8370864391326904]]]
                            ], dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                                (
                                    'value_b1',
                                    reals(),
                                ),
                            )),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        # p_prior: a MarkovProduct over 'time_b9' plus a MultivariateNormal
        # term, with both state variables reduced by logaddexp.
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736, dtype=torch.float32),
                                (),
                                'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0], dtype=torch.float32),
                                torch.tensor([[
                                    98.01002502441406, 0.0, -99.0000228881836, -0.0
                                ], [
                                    0.0, 98.01002502441406, -0.0, -99.0000228881836
                                ], [
                                    -99.0000228881836, -0.0, 100.0000228881836, 0.0
                                ], [
                                    -0.0, -99.0000228881836, 0.0, 100.0000228881836
                                ]], dtype=torch.float32),  # noqa
                                (
                                    (
                                        'state_b7',
                                        reals(2, ),
                                    ),
                                    (
                                        'state(time=1)_b8',
                                        reals(2, ),
                                    ),
                                )),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.03488487750291824, 0.07356668263673782,
                                                0.19946961104869843, 0.5386509299278259,
                                                -0.708323061466217, 0.24411526322364807,
                                                -0.20855577290058136, -0.2421337217092514
                                            ], [
                                                0.41762110590934753, 0.5272183418273926,
                                                -0.49835553765296936, -0.0363837406039238,
                                                -0.0005282597267068923, 0.2704298794269562,
                                                -0.155222088098526, -0.44802337884902954
                                            ]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                -0.003566693514585495, -0.2848514914512634,
                                                0.037103548645973206, 0.12648648023605347,
                                                -0.18501518666744232, -0.20899859070777893,
                                                0.04121830314397812, 0.0054807960987091064
                                            ], [
                                                0.0021788496524095535, -0.18700894713401794,
                                                0.08187370002269745, 0.13554862141609192,
                                                -0.10477752983570099, -0.20848378539085388,
                                                -0.01393645629286766, 0.011670656502246857
                                            ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.5974780917167664, 0.864071786403656,
                                                1.0236268043518066, 0.7147538065910339,
                                                0.7423890233039856, 0.9462157487869263,
                                                1.2132389545440674, 1.0596832036972046
                                            ], [
                                                0.5787821412086487, 0.9178534150123596,
                                                0.9074794054031372, 0.6600189208984375,
                                                0.8473222255706787, 0.8426999449729919,
                                                1.194266438484192, 1.0471148490905762
                                            ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2, )),
                                    Variable('gate_rate_b6', reals(8, ))),
                                ((
                                    'gate_rate_b6',
                                    Binary(
                                        ops.GetitemOp(0),
                                        Variable('gate_rate_t', reals(2, 8)),
                                        Variable('time_b9', bint(2))),
                                ), )),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'), ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32), (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]], dtype=torch.float32),
                            (),
                            'real'),
                        Variable('value_b5', reals(2, ))),
                    ((
                        'value_b5',
                        Variable('state_b10', reals(2, )),
                    ), )),
            ))
        # p_likelihood: a Categorical over 'gated_b14' plus a Stack of
        # Poisson and Delta terms; 'gated_b14' is reduced by logaddexp, then
        # time/origin/destin are reduced by add.
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0,
                                            reals(2, 2, 2),
                                            (Variable('gate_rate_b12', reals(8, )), )),
                                        ((
                                            'gate_rate_b12',
                                            Binary(
                                                ops.GetitemOp(0),
                                                Variable(
                                                    'gate_rate_t',
                                                    reals(2, 8)),
                                                Variable('time_b17', bint(2))),
                                        ), )),
                                    Variable('origin_b15', bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2),
                                                    (Variable(
                                                        'gate_rate_b13',
                                                        reals(8, )), )),
                                                ((
                                                    'gate_rate_b13',
                                                    Binary(
                                                        ops.GetitemOp(0),
                                                        Variable(
                                                            'gate_rate_t',
                                                            reals(2, 8)),
                                                        Variable(
                                                            'time_b17',
                                                            bint(2))),
                                                ), )),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]], [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                                dist.Delta(
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]], [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                            )),
                    )),
            ))

    if analytic_kl:
        # Split the ELBO: the prior-vs-guide part is integrated outside
        # monte_carlo, the likelihood part inside it.
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        # Integrate the whole integrand under monte_carlo in one step.
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    # On failure, the assertion message shows the unevaluated expression.
    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1