def _forward_funsor(self, features, trip_counts):
    """Build the funsor prior and likelihood terms over ``gate_rate_t``.

    Args:
        features: Per-hour features; only ``len(features)`` and the slice
            ``features[:observed_hours]`` (passed to ``self._dynamics``)
            are used here.
        trip_counts: Tensor whose shape unpacks as
            ``(observed_hours, num_origins, num_destins)``.

    Returns:
        A pair ``(prior, likelihood)`` of funsors, each of which has
        ``"gate_rate_t"`` as its only input (checked by assertion).

    Raises:
        AssertionError: If the number of observed hours differs from
            ``len(features)``, or either station axis differs from
            ``self.num_stations``.
    """
    total_hours = len(features)
    observed_hours, num_origins, num_destins = trip_counts.shape
    assert observed_hours == total_hours
    assert num_origins == self.num_stations
    assert num_destins == self.num_stations

    n = self.num_stations
    packed_size = 2 * n * n
    time_plate = ("time", )
    markov_step = {"state": "state(time=1)"}

    # One packed (gate, rate) vector per time step, indexed lazily by "time".
    gate_rate = funsor.Variable(
        "gate_rate_t", reals(observed_hours, packed_size))["time"]

    # NOTE: the function name and its parameter name are kept as-is since
    # the decorator wraps this callable into a funsor.
    @funsor.torch.function(reals(packed_size), (reals(n, n, 2), reals(n, n)))
    def unpack_gate_rate(gate_rate):
        leading_shape = gate_rate.shape[:-1]
        halves = gate_rate.reshape(leading_shape + (2, n, n))
        raw_gate, raw_rate = halves.unbind(-3)
        prob = raw_gate.sigmoid().clamp(min=0.01, max=0.99)
        rate = bounded_exp(raw_rate, bound=1e4)
        # Last axis holds (1 - gate, gate).
        gate = torch.stack((1 - prob, prob), dim=-1)
        return gate, rate

    # Gaussian latent dynamical system, converted piecewise to funsors.
    dynamics = self._dynamics(features[:observed_hours])
    init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist = dynamics
    init = dist_to_funsor(init_dist)(value="state")
    trans = matrix_and_mvn_to_funsor(
        trans_matrix, trans_dist, time_plate, "state", "state(time=1)")
    obs = matrix_and_mvn_to_funsor(
        obs_matrix, obs_dist, time_plate, "state(time=1)", "gate_rate")

    # Dynamic prior over gate_rate: chain the per-step factors along
    # "time", add the initial term, then eliminate both state variables.
    prior = trans + obs(gate_rate=gate_rate)
    prior = MarkovProduct(
        ops.logaddexp, ops.add, prior, "time", markov_step)
    prior += init
    prior = prior.reduce(ops.logaddexp, {"state", "state(time=1)"})

    # Zero-inflated Poisson likelihood: a Categorical over "gated" selects
    # between a Poisson term and a Delta at zero.
    gate, rate = unpack_gate_rate(gate_rate)
    likelihood = fdist.Categorical(gate["origin", "destin"], value="gated")
    trip_counts = tensor_to_funsor(
        trip_counts, ("time", "origin", "destin"))
    poisson_term = fdist.Poisson(rate["origin", "destin"], value=trip_counts)
    zero_term = fdist.Delta(0, value=trip_counts)
    likelihood += funsor.Stack("gated", (poisson_term, zero_term))
    likelihood = likelihood.reduce(ops.logaddexp, "gated")
    likelihood = likelihood.reduce(ops.add, {"time", "origin", "destin"})

    assert set(prior.inputs) == {"gate_rate_t"}, prior.inputs
    assert set(likelihood.inputs) == {"gate_rate_t"}, likelihood.inputs
    return prior, likelihood
def __init__(self, initial_dist, transition_matrix, transition_dist,
             observation_matrix, observation_dist, validate_args=None):
    """Construct the distribution from matrix/noise parameterized dynamics.

    Args:
        initial_dist: ``MultivariateNormal`` with event shape
            ``(hidden_dim,)``.
        transition_matrix: Tensor whose trailing shape is
            ``(hidden_dim, hidden_dim)``.
        transition_dist: ``MultivariateNormal`` with event shape
            ``(hidden_dim,)``.
        observation_matrix: Tensor whose trailing shape is
            ``(hidden_dim, obs_dim)``.
        observation_dist: ``MultivariateNormal`` with event shape
            ``(obs_dim,)``.
        validate_args: Forwarded unchanged to the parent constructor.

    Raises:
        AssertionError: If an argument has the wrong type or an
            inconsistent shape.
    """
    mvn_type = torch.distributions.MultivariateNormal
    typed_args = (
        (initial_dist, mvn_type),
        (transition_matrix, torch.Tensor),
        (transition_dist, mvn_type),
        (observation_matrix, torch.Tensor),
        (observation_dist, mvn_type),
    )
    for arg, expected_type in typed_args:
        assert isinstance(arg, expected_type)

    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
    assert initial_dist.event_shape == (hidden_dim, )
    assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
    assert transition_dist.event_shape == (hidden_dim, )
    assert observation_dist.event_shape == (obs_dim, )

    # The initial distribution gets a trailing singleton so that the last
    # broadcast dimension is the time axis shared with the other parts.
    shape = broadcast_shape(
        initial_dist.batch_shape + (1, ),
        transition_matrix.shape[:-2],
        transition_dist.batch_shape,
        observation_matrix.shape[:-2],
        observation_dist.batch_shape)
    batch_shape = shape[:-1]
    time_shape = shape[-1:]
    event_shape = time_shape + (obs_dim, )

    # Convert distributions to funsors.
    time_plate = ("time", )
    init = dist_to_funsor(initial_dist)(value="state")
    trans = matrix_and_mvn_to_funsor(
        transition_matrix, transition_dist,
        time_plate, "state", "state(time=1)")
    obs = matrix_and_mvn_to_funsor(
        observation_matrix, observation_dist,
        time_plate, "state(time=1)", "value")

    # Construct the joint funsor lazily: chain the per-step factors along
    # "time", then eliminate the final state followed by the initial state.
    with interpretation(lazy):
        value = Variable("value", Reals[time_shape[0], obs_dim])
        per_step = trans + obs(value=value["time"])
        chained = MarkovProduct(
            ops.logaddexp, ops.add, per_step, "time",
            {"state": "state(time=1)"})
        joint = init + chained.reduce(ops.logaddexp, "state(time=1)")
        funsor_dist = joint.reduce(ops.logaddexp, "state")

    super(GaussianHMM, self).__init__(
        funsor_dist, batch_shape, event_shape, "real", validate_args)
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
def __init__(self, initial_dist, transition_dist, observation_dist,
             validate_args=None):
    """Construct the distribution from joint Gaussian potentials.

    Args:
        initial_dist: ``MultivariateNormal``; its event size defines
            ``hidden_dim``.
        transition_dist: ``MultivariateNormal`` with event size
            ``2 * hidden_dim`` (the pair of consecutive states).
        observation_dist: ``MultivariateNormal`` with event size
            ``hidden_dim + obs_dim`` (state followed by observed value).
        validate_args: Forwarded unchanged to the parent constructor.

    Raises:
        AssertionError: If an argument is not a ``MultivariateNormal`` or
            the transition event size is not ``2 * hidden_dim``.
    """
    for part in (initial_dist, transition_dist, observation_dist):
        assert isinstance(part, torch.distributions.MultivariateNormal)

    hidden_dim = initial_dist.event_shape[0]
    assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
    obs_dim = observation_dist.event_shape[0] - hidden_dim

    # The initial distribution gets a trailing singleton so that the last
    # broadcast dimension is the time axis shared with the other parts.
    shape = broadcast_shape(
        initial_dist.batch_shape + (1, ),
        transition_dist.batch_shape,
        observation_dist.batch_shape)
    batch_shape = shape[:-1]
    time_shape = shape[-1:]
    event_shape = time_shape + (obs_dim, )

    # Convert distributions to funsors.
    time_plate = ("time", )
    init = dist_to_funsor(initial_dist)(value="state")
    trans_inputs = OrderedDict([
        ("state", Reals[hidden_dim]),
        ("state(time=1)", Reals[hidden_dim]),
    ])
    trans = mvn_to_funsor(transition_dist, time_plate, trans_inputs)
    obs_inputs = OrderedDict([
        ("state(time=1)", Reals[hidden_dim]),
        ("value", Reals[obs_dim]),
    ])
    obs = mvn_to_funsor(observation_dist, time_plate, obs_inputs)

    # Construct the joint funsor.
    # Compare with pyro.distributions.hmm.GaussianMRF.log_prob().
    with interpretation(lazy):
        time = Variable("time", Bint[time_shape[0]])
        value = Variable("value", Reals[time_shape[0], obs_dim])

        def eliminate_states(per_step):
            # Chain along time, add the initial term, then sum out both
            # state variables.
            logp = MarkovProduct(
                ops.logaddexp, ops.add, per_step, time,
                {"state": "state(time=1)"})
            logp += init
            return logp.reduce(
                ops.logaddexp, frozenset({"state", "state(time=1)"}))

        # Term that still depends on the observed value ...
        logp_oh = eliminate_states(trans + obs(value=value["time"]))
        # ... minus the same quantity with the value summed out of obs.
        logp_h = eliminate_states(trans + obs.reduce(ops.logaddexp, "value"))
        funsor_dist = logp_oh - logp_h

    super(GaussianMRF, self).__init__(
        funsor_dist, batch_shape, event_shape, "real", validate_args)
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
def parallel_loss_fn(model, guide, parallel=True):
    """Compute a loss by exactly eliminating every input of the model factors.

    The first factor returned by ``model()`` is eliminated along its ``"t"``
    input (stepping ``"y(t=1)"`` to ``"y"``), after which all factors are
    combined with a logaddexp/add sum-product over plates ``'g'`` and
    ``'i'``.

    Args:
        model: Zero-argument callable returning a non-empty, sliceable
            sequence (list or tuple) of funsor factors; the first factor
            must have a ``"t"`` input.
        guide: Unused. Inference is exact, so no guide is needed; the
            parameter is kept for signature compatibility.
        parallel: If true, eliminate time with ``MarkovProduct``; otherwise
            use ``naive_sequential_sum_product``.

    Returns:
        The negated ``.data`` of the fully reduced funsor.

    Raises:
        AssertionError: If the reduced funsor still has free inputs.
    """
    # We're doing exact inference, so we don't use the guide here.
    factors = model()
    t_term = factors[0]
    t = to_funsor("t", t_term.inputs["t"])
    if parallel:
        result = MarkovProduct(ops.logaddexp, ops.add,
                               t_term, t, {"y(t=1)": "y"})
    else:
        result = naive_sequential_sum_product(ops.logaddexp, ops.add,
                                              t_term, t, {"y(t=1)": "y"})

    # Build the list explicitly: ``[result] + factors[1:]`` raises TypeError
    # when the model returns a tuple, since list + tuple is not defined.
    new_factors = [result] + list(factors[1:])

    plates = frozenset(['g', 'i'])
    eliminate = frozenset().union(*(f.inputs for f in new_factors))
    with interpretation(lazy):
        loss = sum_product(ops.logaddexp, ops.add, new_factors, eliminate, plates)
    loss = apply_optimizer(loss)
    assert not loss.inputs
    return -loss.data
def test_bart(analytic_kl):
    """Check that an ELBO built from hard-coded funsor terms is a ``Tensor``.

    Three funsor expressions are constructed under ``interpretation(reflect)``:
    ``q``, ``p_prior`` and ``p_likelihood``. All of them mention the variable
    ``'gate_rate_t'``, which is the variable integrated over below.

    Args:
        analytic_kl: If truthy, ``Integrate(q, p_prior - q)`` is computed
            outside the ``monte_carlo`` interpretation and only the
            likelihood term is computed under ``monte_carlo``; otherwise the
            whole ``p - q`` integrand is computed under ``monte_carlo``.

    NOTE(review): the literal tensors and the ``_bNN`` name suffixes look like
    a machine-generated dump of funsors from a small (2 time steps, 2x2
    origin/destination) model -- confirm their origin before editing them.
    """
    # Module-level counter that is reset here and checked at the end.
    # NOTE(review): presumably incremented by the unpack_gate_rate_* helpers
    # defined elsewhere in this file -- confirm.
    global call_count
    call_count = 0

    with interpretation(reflect):
        # q: a Tensor plus a Gaussian over 'value_b1', wrapped in two nested
        # Independent terms that end at the name 'gate_rate_t'.
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[
                                    -0.6077086925506592, -1.1546266078948975,
                                    -0.7021151781082153, -0.5303535461425781,
                                    -0.6365622282028198, -1.2423288822174072,
                                    -0.9941254258155823, -0.6287292242050171
                                ],
                                 [
                                     -0.6987162828445435, -1.0875964164733887,
                                     -0.7337473630905151, -0.4713417589664459,
                                     -0.6674002408981323, -1.2478348016738892,
                                     -0.8939017057418823, -0.5238542556762695
                                 ]],
                                dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                            ),
                            'real'),
                        Gaussian(
                            torch.tensor([
                                [[-0.3536059558391571], [-0.21779225766658783],
                                 [0.2840439975261688], [0.4531521499156952],
                                 [-0.1220812276005745], [-0.05519985035061836],
                                 [0.10932210087776184], [0.6656699776649475]],
                                [[-0.39107921719551086], [
                                    -0.20241987705230713
                                ], [0.2170514464378357], [0.4500560462474823],
                                 [0.27945515513420105], [-0.0490039587020874],
                                 [-0.06399798393249512], [0.846565842628479]]
                            ],
                                         dtype=torch.float32),  # noqa
                            torch.tensor([
                                [[[1.984686255455017]], [[0.6699360013008118]],
                                 [[1.6215802431106567]], [[2.372016668319702]],
                                 [[1.77385413646698]], [[0.526767373085022]],
                                 [[0.8722561597824097]], [[2.1879124641418457]]
                                 ],
                                [[[1.6996612548828125]], [[
                                    0.7535632252693176
                                ]], [[1.4946647882461548]], [[2.642792224884033]],
                                 [[1.7301604747772217]], [[0.5203893780708313]],
                                 [[1.055436372756958]], [[2.8370864391326904]]]
                            ],
                                         dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                                (
                                    'value_b1',
                                    reals(),
                                ),
                            )),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        # p_prior: a MarkovProduct over 'time_b9' of (Tensor + Gaussian +
        # AffineNormal) combined with a MultivariateNormal on 'state_b10',
        # with both renamed state variables reduced by logaddexp.
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736,
                                             dtype=torch.float32), (),
                                'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0],
                                             dtype=torch.float32),
                                torch.tensor([[
                                    98.01002502441406, 0.0, -99.0000228881836,
                                    -0.0
                                ],
                                              [
                                                  0.0, 98.01002502441406, -0.0,
                                                  -99.0000228881836
                                              ],
                                              [
                                                  -99.0000228881836, -0.0,
                                                  100.0000228881836, 0.0
                                              ],
                                              [
                                                  -0.0, -99.0000228881836, 0.0,
                                                  100.0000228881836
                                              ]],
                                             dtype=torch.float32),  # noqa
                                (
                                    (
                                        'state_b7',
                                        reals(2, ),
                                    ),
                                    (
                                        'state(time=1)_b8',
                                        reals(2, ),
                                    ),
                                )),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.03488487750291824,
                                                0.07356668263673782,
                                                0.19946961104869843,
                                                0.5386509299278259,
                                                -0.708323061466217,
                                                0.24411526322364807,
                                                -0.20855577290058136,
                                                -0.2421337217092514
                                            ],
                                             [
                                                 0.41762110590934753,
                                                 0.5272183418273926,
                                                 -0.49835553765296936,
                                                 -0.0363837406039238,
                                                 -0.0005282597267068923,
                                                 0.2704298794269562,
                                                 -0.155222088098526,
                                                 -0.44802337884902954
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                -0.003566693514585495,
                                                -0.2848514914512634,
                                                0.037103548645973206,
                                                0.12648648023605347,
                                                -0.18501518666744232,
                                                -0.20899859070777893,
                                                0.04121830314397812,
                                                0.0054807960987091064
                                            ],
                                             [
                                                 0.0021788496524095535,
                                                 -0.18700894713401794,
                                                 0.08187370002269745,
                                                 0.13554862141609192,
                                                 -0.10477752983570099,
                                                 -0.20848378539085388,
                                                 -0.01393645629286766,
                                                 0.011670656502246857
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.5974780917167664,
                                                0.864071786403656,
                                                1.0236268043518066,
                                                0.7147538065910339,
                                                0.7423890233039856,
                                                0.9462157487869263,
                                                1.2132389545440674,
                                                1.0596832036972046
                                            ],
                                             [
                                                 0.5787821412086487,
                                                 0.9178534150123596,
                                                 0.9074794054031372,
                                                 0.6600189208984375,
                                                 0.8473222255706787,
                                                 0.8426999449729919,
                                                 1.194266438484192,
                                                 1.0471148490905762
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2, )),
                                    Variable('gate_rate_b6', reals(8, ))),
                                ((
                                    'gate_rate_b6',
                                    Binary(
                                        ops.GetitemOp(0),
                                        Variable('gate_rate_t', reals(2, 8)),
                                        Variable('time_b9', bint(2))),
                                ), )),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'), ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32),
                               (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]],
                                         dtype=torch.float32), (), 'real'),
                        Variable('value_b5', reals(2, ))),
                    ((
                        'value_b5',
                        Variable('state_b10', reals(2, )),
                    ), )),
            ))
        # p_likelihood: a Categorical over 'gated_b14' plus a Stack of a
        # Poisson term and a Delta-at-zero term on the same count tensor,
        # reduced by logaddexp over 'gated_b14' and then summed over the
        # time/origin/destin variables.
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0,
                                            reals(2, 2, 2),
                                            (Variable('gate_rate_b12',
                                                      reals(8, )), )),
                                        ((
                                            'gate_rate_b12',
                                            Binary(
                                                ops.GetitemOp(0),
                                                Variable(
                                                    'gate_rate_t',
                                                    reals(2, 8)),
                                                Variable('time_b17', bint(2))),
                                        ), )),
                                    Variable('origin_b15', bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2),
                                                    (Variable(
                                                        'gate_rate_b13',
                                                        reals(8, )), )),
                                                ((
                                                    'gate_rate_b13',
                                                    Binary(
                                                        ops.GetitemOp(0),
                                                        Variable(
                                                            'gate_rate_t',
                                                            reals(2, 8)),
                                                        Variable(
                                                            'time_b17',
                                                            bint(2))),
                                                ), )),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                                dist.Delta(
                                    Tensor(
                                        torch.tensor(0.0,
                                                     dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(0.0,
                                                     dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                            )),
                    )),
            ))

    if analytic_kl:
        # Prior/entropy part is integrated outside monte_carlo; only the
        # likelihood part is integrated under monte_carlo.
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        # Whole integrand p - q is integrated under monte_carlo.
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    # The result must have been fully evaluated to a Tensor, and the counter
    # reset at the top must have reached exactly one.
    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1