def main(args):
    # Declare parameters.
    trans_probs = funsor.Tensor(
        torch.tensor([[0.9, 0.1], [0.1, 0.9]], requires_grad=True))
    trans_noise = funsor.Tensor(
        torch.tensor([
            0.1,  # low noise component
            1.0,  # high noise component
        ], requires_grad=True))
    emit_noise = funsor.Tensor(torch.tensor(0.5, requires_grad=True))
    params = [trans_probs.data, trans_noise.data, emit_noise.data]

    # A Gaussian HMM model.
    @funsor.interpreter.interpretation(funsor.terms.moment_matching)
    def model(data):
        log_prob = funsor.Number(0.)

        # s is the discrete latent state,
        # x is the continuous latent state,
        # y is the observed state.
        s_curr = funsor.Tensor(torch.tensor(0), dtype=2)
        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            s_prev = s_curr
            x_prev = x_curr

            # A delayed sample statement.
            s_curr = funsor.Variable('s_{}'.format(t), funsor.bint(2))
            log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
            log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)

            # Marginalize out previous delayed sample statements.
            if t > 0:
                log_prob = log_prob.reduce(ops.logaddexp,
                                           {s_prev.name, x_prev.name})

            # An observe statement.
            log_prob += dist.Normal(x_curr, emit_noise, value=y)

        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    torch.manual_seed(0)
    data = torch.randn(args.time_steps)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        optim.zero_grad()
        log_prob = model(data)
        assert not log_prob.inputs, 'free variables remain'
        loss = -log_prob.data
        loss.backward()
        optim.step()
        if args.verbose and step % 10 == 0:
            print('step {} loss = {}'.format(step, loss.item()))
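No command-line entry point is shown above. A minimal, hypothetical harness along the following lines would supply the attributes main() reads (args.time_steps, args.train_steps, args.learning_rate, args.verbose); the flag names and default values are assumptions.

# Hypothetical driver for main(args) above; flag names and defaults are assumptions.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description="Gaussian HMM with moment matching")
    parser.add_argument("--time-steps", type=int, default=10)
    parser.add_argument("--train-steps", type=int, default=101)
    parser.add_argument("--learning-rate", type=float, default=0.05)
    parser.add_argument("--verbose", action="store_true")
    main(parser.parse_args())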
def model(data):
    log_prob = funsor.Number(0.)

    # s is the discrete latent state,
    # x is the continuous latent state,
    # y is the observed state.
    s_curr = funsor.Tensor(torch.tensor(0), dtype=2)
    x_curr = funsor.Tensor(torch.tensor(0.))
    for t, y in enumerate(data):
        s_prev = s_curr
        x_prev = x_curr

        # A delayed sample statement.
        s_curr = funsor.Variable('s_{}'.format(t), funsor.bint(2))
        log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)

        # A delayed sample statement.
        x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
        log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)

        # Marginalize out previous delayed sample statements.
        if t > 0:
            log_prob = log_prob.reduce(ops.logaddexp,
                                       {s_prev.name, x_prev.name})

        # An observe statement.
        log_prob += dist.Normal(x_curr, emit_noise, value=y)

    log_prob = log_prob.reduce(ops.logaddexp)
    return log_prob
def model(data):
    log_prob = funsor.to_funsor(0.)

    trans = dist.Categorical(probs=funsor.Tensor(
        trans_probs,
        inputs=OrderedDict([('prev', funsor.bint(args.hidden_dim))]),
    ))

    emit = dist.Categorical(probs=funsor.Tensor(
        emit_probs,
        inputs=OrderedDict([('latent', funsor.bint(args.hidden_dim))]),
    ))

    x_curr = funsor.Number(0, args.hidden_dim)
    for t, y in enumerate(data):
        x_prev = x_curr

        # A delayed sample statement.
        x_curr = funsor.Variable('x_{}'.format(t),
                                 funsor.bint(args.hidden_dim))
        log_prob += trans(prev=x_prev, value=x_curr)

        if not args.lazy and isinstance(x_prev, funsor.Variable):
            log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

        log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2))

    log_prob = log_prob.reduce(ops.logaddexp)
    return log_prob
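For reference, the discrete HMM model above can be driven end to end with a synthetic binary observation sequence. This sketch assumes trans_probs, emit_probs, and args are defined in the enclosing scope as in the original example, and is illustrative only.

# Hypothetical smoke test for the discrete HMM model above; the data is synthetic.
data = torch.distributions.Bernoulli(0.5).sample((10,)).long()
log_Z = model(data)
if not args.lazy:
    print('log marginal likelihood = {}'.format(log_Z.data.item()))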
def _get_support_value_tensor(funsor_dist, name, **kwargs):
    assert name in funsor_dist.inputs
    return funsor.Tensor(
        funsor.ops.new_arange(funsor_dist.data,
                              funsor_dist.inputs[name].size),
        OrderedDict([(name, funsor_dist.inputs[name])]),
        funsor_dist.inputs[name].size,
    )
def get_tensors_and_dists(self):
    # normalize the transition probabilities
    trans_logits = self.transition_logits - self.transition_logits.logsumexp(
        dim=-1, keepdim=True)
    trans_probs = funsor.Tensor(
        trans_logits,
        OrderedDict([("s", funsor.bint(self.num_components))]))

    trans_mvn = torch.distributions.MultivariateNormal(
        torch.zeros(self.hidden_dim),
        self.log_transition_noise.exp().diag_embed())
    obs_mvn = torch.distributions.MultivariateNormal(
        torch.zeros(self.obs_dim),
        self.log_obs_noise.exp().diag_embed())

    event_dims = ("s",) if self.fine_transition_matrix or self.fine_transition_noise else ()
    x_trans_dist = matrix_and_mvn_to_funsor(self.transition_matrix, trans_mvn,
                                            event_dims, "x", "y")
    event_dims = ("s",) if self.fine_observation_matrix or self.fine_observation_noise else ()
    y_dist = matrix_and_mvn_to_funsor(self.observation_matrix, obs_mvn,
                                      event_dims, "x", "y")

    return trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist
def model(data):
    log_prob = funsor.to_funsor(0.)

    xs_curr = [funsor.Tensor(torch.tensor(0.)) for var in var_names]

    for t, y in enumerate(data):
        xs_prev = xs_curr

        # A delayed sample statement.
        xs_curr = [
            funsor.Variable(name + '_{}'.format(t), funsor.reals())
            for name in var_names
        ]

        for i, x_curr in enumerate(xs_curr):
            log_prob += dist.Normal(trans_eqs[var_names[i]](xs_prev),
                                    torch.exp(trans_noises[i]),
                                    value=x_curr)

        if t > 0:
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([x_prev.name for x_prev in xs_prev]))

        # An observe statement.
        log_prob += dist.Normal(emit_eq(xs_curr), torch.exp(emit_noise),
                                value=y)

    # Marginalize out all remaining delayed variables.
    return log_prob.reduce(ops.logaddexp), log_prob.gaussian
def _enum_strategy_mixture(dist, msg):
    sample_dim_name = "{}__PARTICLES".format(msg["name"])
    sample_inputs = OrderedDict(
        {sample_dim_name: funsor.Bint[msg['infer']['num_samples']]})
    plate_names = frozenset(f.name for f in msg["cond_indep_stack"]
                            if f.vectorized)
    ancestor_names = frozenset(
        k for k, v in dist.inputs.items()
        if v.dtype != 'real' and k != msg["name"] and k not in plate_names)
    plate_inputs = OrderedDict((k, dist.inputs[k]) for k in plate_names)
    # TODO should the ancestor_indices be pyro.sampled?
    ancestor_indices = {
        # TODO make this comprehension less gross
        name: _get_support_value(
            funsor.torch.distributions.CategoricalLogits(
                # sample different ancestors for each plate slice
                logits=funsor.Tensor(
                    # TODO avoid use of torch.zeros here in favor of funsor.ops.new_zeros
                    torch.zeros((1,)).expand(
                        tuple(v.dtype for v in plate_inputs.values()) +
                        (dist.inputs[name].dtype,)),
                    plate_inputs),
            )(value=name).sample(name, sample_inputs),
            name)
        for name in ancestor_names
    }
    sampled_dist = dist(**ancestor_indices).sample(
        msg["name"], sample_inputs if not ancestor_indices else None)
    if ancestor_indices:  # XXX is there a better way to account for this in funsor?
        sampled_dist = sampled_dist - math.log(msg["infer"]["num_samples"])
    return sampled_dist
def main(args): funsor.set_backend("torch") # XXX Temporary fix after https://github.com/pyro-ppl/pyro/pull/2701 import pyro pyro.enable_validation(False) encoder = Encoder() decoder = Decoder() encode = funsor.function(Reals[28, 28], (Reals[20], Reals[20]))(encoder) decode = funsor.function(Reals[20], Reals[28, 28])(decoder) @funsor.interpretation(funsor.montecarlo.MonteCarlo()) def loss_function(data, subsample_scale): # Lazily sample from the guide. loc, scale = encode(data) q = funsor.Independent(dist.Normal(loc['i'], scale['i'], value='z_i'), 'z', 'i', 'z_i') # Evaluate the model likelihood at the lazy value z. probs = decode('z') p = dist.Bernoulli(probs['x', 'y'], value=data['x', 'y']) p = p.reduce(ops.add, {'x', 'y'}) # Construct an elbo. This is where sampling happens. elbo = funsor.Integrate(q, p - q, 'z') elbo = elbo.reduce(ops.add, 'batch') * subsample_scale loss = -elbo return loss train_loader = torch.utils.data.DataLoader(datasets.MNIST( DATA_PATH, train=True, download=True, transform=transforms.ToTensor()), batch_size=args.batch_size, shuffle=True) encoder.train() decoder.train() optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3) for epoch in range(args.num_epochs): train_loss = 0 for batch_idx, (data, _) in enumerate(train_loader): subsample_scale = float(len(train_loader.dataset) / len(data)) data = data[:, 0, :, :] data = funsor.Tensor(data, OrderedDict(batch=Bint[len(data)])) optimizer.zero_grad() loss = loss_function(data, subsample_scale) assert isinstance(loss, funsor.Tensor), loss.pretty() loss.data.backward() train_loss += loss.item() optimizer.step() if batch_idx % 50 == 0: print(' loss = {}'.format(loss.item())) if batch_idx and args.smoke_test: return print('epoch {} train_loss = {}'.format(epoch, train_loss))
def log_prob(self, data):
    trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
        self.get_tensors_and_dists()

    log_prob = funsor.Number(0.)

    s_vars = {-1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)}
    x_vars = {}

    for t, y in enumerate(data):
        # construct free variables for s_t and x_t
        s_vars[t] = funsor.Variable(f's_{t}',
                                    funsor.bint(self.num_components))
        x_vars[t] = funsor.Variable(f'x_{t}',
                                    funsor.reals(self.hidden_dim))

        # incorporate the discrete switching dynamics
        log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]),
                                     value=s_vars[t])

        # incorporate the prior term p(x_t | x_{t-1})
        if t == 0:
            log_prob += self.x_init_mvn(value=x_vars[t])
        else:
            log_prob += x_trans_dist(s=s_vars[t], x=x_vars[t - 1],
                                     y=x_vars[t])

        # do a moment-matching reduction. at this point log_prob depends on
        # (moment_matching_lag + 1)-many pairs of free variables.
        if t > self.moment_matching_lag - 1:
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([
                    s_vars[t - self.moment_matching_lag].name,
                    x_vars[t - self.moment_matching_lag].name
                ]))

        # incorporate the observation p(y_t | x_t, s_t)
        log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

    T = data.shape[0]
    # reduce any remaining free variables
    for t in range(self.moment_matching_lag):
        log_prob = log_prob.reduce(
            ops.logaddexp,
            frozenset([
                s_vars[T - self.moment_matching_lag + t].name,
                x_vars[T - self.moment_matching_lag + t].name
            ]))

    # assert that we've reduced all the free variables in log_prob
    assert not log_prob.inputs, 'unexpected free variables remain'

    # return the PyTorch tensor behind log_prob (which we can directly differentiate)
    return log_prob.data
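Because log_prob() above returns a plain differentiable torch scalar, it plugs directly into a standard optimization loop. The following is a minimal, hypothetical sketch; `slds` is assumed to be the module that owns log_prob() and `data` a (T, obs_dim) tensor.

# Hypothetical training step for the switching linear dynamical system above.
def train_step(slds, data, optim):
    optim.zero_grad()
    loss = -slds.log_prob(data)
    loss.backward()
    optim.step()
    return loss.item()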
def main(args):
    # Generate fake data.
    data = funsor.Tensor(torch.randn(100),
                         inputs=OrderedDict([('data', funsor.bint(100))]),
                         output=funsor.reals())

    # Train.
    optim = pyro.Adam({'lr': args.learning_rate})
    svi = pyro.SVI(model, pyro.deferred(guide), optim, pyro.elbo)
    for step in range(args.steps):
        svi.step(data)
def test_bernoullilogits_enumerate_support(expand, batch_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
    logits = funsor.Tensor(rand(batch_shape), inputs, 'real')

    with interpretation(lazy):
        d = dist.BernoulliLogits(logits)
        x = d.enumerate_support(expand=expand)
        actual_log_prob = d(value='value2')(value2=x).reduce(
            ops.logaddexp, 'value')

    raw_dist = d.dist_class(logits=logits.data)
    raw_value = raw_dist.enumerate_support(expand=expand)
    expected_inputs = OrderedDict([('value', Bint[raw_value.shape[0]])])
    expected_inputs.update(inputs)
    expected_log_prob = funsor.Tensor(
        raw_dist.log_prob(raw_value),
        expected_inputs).reduce(ops.logaddexp, 'value')

    assert d.has_enumerate_support
    assert x.output == d.value.output
    assert set(x.inputs) == {'value'} | (set(batch_dims) if expand else set())
    assert_close(expected_log_prob, actual_log_prob)
def main(args):
    encoder = Encoder()
    decoder = Decoder()

    encode = funsor.torch.function(reals(28, 28),
                                   (reals(20), reals(20)))(encoder)
    decode = funsor.torch.function(reals(20), reals(28, 28))(decoder)

    @funsor.interpreter.interpretation(funsor.montecarlo.monte_carlo)
    def loss_function(data, subsample_scale):
        # Lazily sample from the guide.
        loc, scale = encode(data)
        q = funsor.Independent(dist.Normal(loc['i'], scale['i'], value='z'),
                               'z', 'i')

        # Evaluate the model likelihood at the lazy value z.
        probs = decode('z')
        p = dist.Bernoulli(probs['x', 'y'], value=data['x', 'y'])
        p = p.reduce(ops.add, frozenset(['x', 'y']))

        # Construct an elbo. This is where sampling happens.
        elbo = funsor.Integrate(q, p - q, frozenset(['z']))
        elbo = elbo.reduce(ops.add, 'batch') * subsample_scale
        loss = -elbo
        return loss

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(DATA_PATH, train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=args.batch_size,
        shuffle=True)

    encoder.train()
    decoder.train()
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(decoder.parameters()), lr=1e-3)
    for epoch in range(args.num_epochs):
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            subsample_scale = float(len(train_loader.dataset) / len(data))
            data = data[:, 0, :, :]
            data = funsor.Tensor(data, OrderedDict(batch=bint(len(data))))

            optimizer.zero_grad()
            loss = loss_function(data, subsample_scale)
            assert isinstance(loss, funsor.torch.Tensor), loss.pretty()
            loss.data.backward()
            train_loss += loss.item()
            optimizer.step()
            if batch_idx % 50 == 0:
                print(' loss = {}'.format(loss.item()))
                if batch_idx and args.smoke_test:
                    return
        print('epoch {} train_loss = {}'.format(epoch, train_loss))
def __init__(self, name, size, subsample_size=None, dim=None):
    self.name = name
    self.size = size
    self.subsample_size = size if subsample_size is None else subsample_size
    if dim is not None and dim >= 0:
        raise ValueError('dim arg must be negative.')
    self.dim = dim
    self._indices = funsor.Tensor(
        funsor.ops.new_arange(funsor.tensor.get_default_prototype(),
                              self.size),
        OrderedDict([(self.name, funsor.bint(self.size))]),
        self.size)
    super(plate, self).__init__(None)
def __init__(self, name, size, subsample_size=None, dim=None):
    self.name = name
    self.size = size
    if dim is not None and dim >= 0:
        raise ValueError('dim arg must be negative.')
    self.dim, indices = OrigPlateMessenger._subsample(
        self.name, self.size, subsample_size, dim)
    self.subsample_size = indices.shape[0]
    self._indices = funsor.Tensor(
        indices,
        OrderedDict([(self.name, funsor.bint(self.subsample_size))]),
        self.subsample_size)
    super(plate, self).__init__(None)
def __init__(self, name=None, size=None, dim=None, indices=None):
    assert dim is None or dim < 0
    super().__init__()
    # without a name or dim, treat as a "vectorize" effect and allocate a non-visible dim
    self.dim_type = DimType.GLOBAL if name is None and dim is None else DimType.VISIBLE
    self.name = name if name is not None else funsor.interpreter.gensym("PLATE")
    self.size = size
    self.dim = dim
    if not hasattr(self, "_full_size"):
        self._full_size = size
    if indices is None:
        indices = funsor.ops.new_arange(
            funsor.tensor.get_default_prototype(), self.size)
    assert len(indices) == size

    self._indices = funsor.Tensor(
        indices,
        OrderedDict([(self.name, funsor.Bint[self.size])]),
        self._full_size)
def model(data):
    log_prob = funsor.to_funsor(0.)

    x_curr = funsor.Tensor(torch.tensor(0.))
    for t, y in enumerate(data):
        x_prev = x_curr

        # A delayed sample statement.
        x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
        log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

        # Optionally marginalize out the previous state.
        if t > 0 and not args.lazy:
            log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

        # An observe statement.
        log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

    # Marginalize out all remaining delayed variables.
    log_prob = log_prob.reduce(ops.logaddexp)
    return log_prob
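This model follows the same delayed-sampling pattern as the Gaussian HMM example earlier in this file and can be trained the same way. A minimal sketch, assuming args.lazy is False so the returned log_prob is a ground funsor.Tensor, and that params collects the requires_grad noise tensors from the enclosing scope.

# Hypothetical training loop for model(data) above, mirroring the other examples here.
def train(params, data, num_steps=100, lr=0.05):
    optim = torch.optim.Adam(params, lr=lr)
    for step in range(num_steps):
        optim.zero_grad()
        log_prob = model(data)
        assert not log_prob.inputs, 'free variables remain'
        loss = -log_prob.data
        loss.backward()
        optim.step()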
def process_message(self, msg):
    if msg["type"] != "sample" or \
            msg.get("done", False) or msg["is_observed"] or msg["infer"].get("expand", False) or \
            msg["infer"].get("enumerate") != "parallel" or (not msg["fn"].has_enumerate_support):
        if msg["type"] == "control_flow":
            msg["kwargs"]["enum"] = True
        return super().process_message(msg)

    if msg["infer"].get("num_samples", None) is not None:
        raise NotImplementedError("TODO implement multiple sampling")

    if msg["infer"].get("expand", False):
        raise NotImplementedError("expand=True not implemented")

    size = msg["fn"].enumerate_support(expand=False).shape[0]
    raw_value = jnp.arange(0, size)
    funsor_value = funsor.Tensor(
        raw_value, OrderedDict([(msg["name"], funsor.bint(size))]), size)

    msg["value"] = to_data(funsor_value)
    msg["done"] = True
def differentiable_loss(self, model, guide, *args, **kwargs):
    # get batched, enumerated, to_funsor-ed traces from the guide and model
    with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \
            enum(first_available_dim=(-self.max_plate_nesting - 1)
                 if self.max_plate_nesting else None):
        guide_tr = trace(guide).get_trace(*args, **kwargs)
        model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)

    # extract from traces all metadata that we will need to compute the elbo
    guide_terms = terms_from_trace(guide_tr)
    model_terms = terms_from_trace(model_tr)

    # build up a lazy expression for the elbo
    with funsor.terms.lazy:
        # identify and contract out auxiliary variables in the model with partial_sum_product
        contracted_factors, uncontracted_factors = [], []
        for f in model_terms["log_factors"]:
            if model_terms["measure_vars"].intersection(f.inputs):
                contracted_factors.append(f)
            else:
                uncontracted_factors.append(f)
        # incorporate the effects of subsampling and handlers.scale through a common scale factor
        contracted_costs = [
            model_terms["scale"] * f
            for f in funsor.sum_product.partial_sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                model_terms["log_measures"] + contracted_factors,
                plates=model_terms["plate_vars"],
                eliminate=model_terms["measure_vars"],
            )
        ]

        # accumulate costs from model (logp) and guide (-logq)
        costs = contracted_costs + uncontracted_factors  # model costs: logp
        costs += [-f for f in guide_terms["log_factors"]]  # guide costs: -logq

        # compute expected cost
        # Cf. pyro.infer.util.Dice.compute_expectation()
        # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212
        # TODO Replace this with funsor.Expectation
        plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
        # compute the marginal logq in the guide corresponding to each cost term
        targets = dict()
        for cost in costs:
            input_vars = frozenset(cost.inputs)
            if input_vars not in targets:
                targets[input_vars] = funsor.Tensor(
                    funsor.ops.new_zeros(
                        funsor.tensor.get_default_prototype(),
                        tuple(v.size for v in cost.inputs.values()),
                    ),
                    cost.inputs,
                    cost.dtype,
                )
        with AdjointTape() as tape:
            logzq = funsor.sum_product.sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                guide_terms["log_measures"] + list(targets.values()),
                plates=plate_vars,
                eliminate=(plate_vars | guide_terms["measure_vars"]),
            )
        marginals = tape.adjoint(funsor.ops.logaddexp, funsor.ops.add,
                                 logzq, tuple(targets.values()))
        # finally, integrate out guide variables in the elbo and all plates
        elbo = to_funsor(0, output=funsor.Real)
        for cost in costs:
            target = targets[frozenset(cost.inputs)]
            logzq_local = marginals[target].reduce(
                funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars)
            log_prob = marginals[target] - logzq_local
            elbo_term = funsor.Integrate(
                log_prob,
                cost,
                guide_terms["measure_vars"] & frozenset(log_prob.inputs),
            )
            elbo += elbo_term.reduce(funsor.ops.add,
                                     plate_vars & frozenset(cost.inputs))

    # evaluate the elbo, using memoize to share tensor computation where possible
    with funsor.interpretations.memoize():
        return -to_data(apply_optimizer(elbo))
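For context, a differentiable_loss like the one above is typically consumed by an SVI-style training step. The sketch below is hypothetical and simplified relative to pyro.infer.SVI (it assumes the model/guide parameters are already registered with a torch optimizer rather than captured via Pyro's param store).

# Hypothetical usage of differentiable_loss above; `elbo` is an instance of the
# class defining it, `optimizer` a torch optimizer over model/guide parameters.
def svi_step(elbo, model, guide, optimizer, *args, **kwargs):
    optimizer.zero_grad()
    loss = elbo.differentiable_loss(model, guide, *args, **kwargs)
    loss.backward()
    optimizer.step()
    return loss.item()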
def filter_and_predict(self, data, smoothing=False):
    trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
        self.get_tensors_and_dists()

    log_prob = funsor.Number(0.)

    s_vars = {-1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)}
    x_vars = {-1: None}

    predictive_x_dists, predictive_y_dists, filtering_dists = [], [], []
    test_LLs = []

    for t, y in enumerate(data):
        s_vars[t] = funsor.Variable(f's_{t}',
                                    funsor.bint(self.num_components))
        x_vars[t] = funsor.Variable(f'x_{t}',
                                    funsor.reals(self.hidden_dim))

        log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]),
                                     value=s_vars[t])

        if t == 0:
            log_prob += self.x_init_mvn(value=x_vars[t])
        else:
            log_prob += x_trans_dist(s=s_vars[t], x=x_vars[t - 1],
                                     y=x_vars[t])

        if t > 0:
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([s_vars[t - 1].name, x_vars[t - 1].name]))

        # do 1-step prediction and compute test LL
        if t > 0:
            predictive_x_dists.append(log_prob)
            _log_prob = log_prob - log_prob.reduce(ops.logaddexp)
            predictive_y_dist = y_dist(s=s_vars[t], x=x_vars[t]) + _log_prob
            test_LLs.append(
                predictive_y_dist(y=y).reduce(ops.logaddexp).data.item())
            predictive_y_dist = predictive_y_dist.reduce(
                ops.logaddexp, frozenset([f"x_{t}", f"s_{t}"]))
            predictive_y_dists.append(funsor_to_mvn(predictive_y_dist, 0, ()))

        log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

        # save filtering dists for forward-backward smoothing
        if smoothing:
            filtering_dists.append(log_prob)

    # do the backward recursion using previously computed ingredients
    if smoothing:
        # seed the backward recursion with the filtering distribution at t=T
        smoothing_dists = [filtering_dists[-1]]
        T = data.size(0)

        s_vars = {
            t: funsor.Variable(f's_{t}', funsor.bint(self.num_components))
            for t in range(T)
        }
        x_vars = {
            t: funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))
            for t in range(T)
        }

        # do the backward recursion.
        # let p[t|t-1] be the predictive distribution at time step t.
        # let p[t|t] be the filtering distribution at time step t.
        # let f[t] denote the prior (transition) density at time step t.
        # then the smoothing distribution p[t|T] at time step t is
        # given by the following recursion.
        # p[t-1|T] = p[t-1|t-1] <p[t|T] f[t] / p[t|t-1]>
        # where <...> denotes integration of the latent variables at time step t.
        for t in reversed(range(T - 1)):
            integral = smoothing_dists[-1] - predictive_x_dists[t]
            integral += dist.Categorical(trans_probs(s=s_vars[t]),
                                         value=s_vars[t + 1])
            integral += x_trans_dist(s=s_vars[t], x=x_vars[t],
                                     y=x_vars[t + 1])
            integral = integral.reduce(
                ops.logaddexp,
                frozenset([s_vars[t + 1].name, x_vars[t + 1].name]))
            smoothing_dists.append(filtering_dists[t] + integral)

    # compute predictive test MSE and predictive variances
    predictive_means = torch.stack([d.mean for d in predictive_y_dists])  # T-1 ydim
    predictive_vars = torch.stack([
        d.covariance_matrix.diagonal(dim1=-1, dim2=-2)
        for d in predictive_y_dists
    ])
    predictive_mse = (predictive_means - data[1:, :]).pow(2.0).mean(-1)

    if smoothing:
        # compute smoothed mean function
        smoothing_dists = [
            funsor_to_cat_and_mvn(d, 0, (f"s_{t}",))
            for t, d in enumerate(reversed(smoothing_dists))
        ]
        means = torch.stack([d[1].mean for d in smoothing_dists])  # T 2 xdim
        means = torch.matmul(means.unsqueeze(-2),
                             self.observation_matrix).squeeze(-2)  # T 2 ydim

        probs = torch.stack([d[0].logits for d in smoothing_dists]).exp()
        probs = probs / probs.sum(-1, keepdim=True)  # T 2

        smoothing_means = (probs.unsqueeze(-1) * means).sum(-2)  # T ydim
        smoothing_probs = probs[:, 1]

        return predictive_mse, torch.tensor(np.array(test_LLs)), predictive_means, predictive_vars, \
            smoothing_means, smoothing_probs
    else:
        return predictive_mse, torch.tensor(np.array(test_LLs))
def test_gaussian_funsor(batch_shape):
    # This tests sample distribution, rsample gradients, log_prob, and log_prob
    # gradients for both Pyro's and Funsor's Gaussian.
    import funsor

    funsor.set_backend("torch")
    num_samples = 100000

    # Declare unconstrained parameters.
    loc = torch.randn(batch_shape + (3,)).requires_grad_()
    t = transform_to(constraints.positive_definite)
    m = torch.randn(batch_shape + (3, 3))
    precision_unconstrained = t.inv(m @ m.transpose(-1, -2)).requires_grad_()

    # Transform to constrained space.
    log_normalizer = torch.zeros(batch_shape)
    precision = t(precision_unconstrained)
    info_vec = (precision @ loc[..., None])[..., 0]

    def check_equal(actual, expected, atol=0.01, rtol=0):
        assert_close(actual.data, expected.data, atol=atol, rtol=rtol)
        grads = torch.autograd.grad(
            (actual - expected).abs().sum(),
            [loc, precision_unconstrained],
            retain_graph=True,
        )
        for grad in grads:
            assert grad.abs().max() < atol

    entropy = dist.MultivariateNormal(loc,
                                      precision_matrix=precision).entropy()

    # Monte carlo estimate entropy via pyro.
    p_gaussian = Gaussian(log_normalizer, info_vec, precision)
    p_log_Z = p_gaussian.event_logsumexp()
    p_rsamples = p_gaussian.rsample((num_samples,))
    pp_entropy = (p_log_Z - p_gaussian.log_density(p_rsamples)).mean(0)
    check_equal(pp_entropy, entropy)

    # Monte carlo estimate entropy via funsor.
    inputs = OrderedDict([(k, funsor.Bint[v])
                          for k, v in zip("ij", batch_shape)])
    inputs["x"] = funsor.Reals[3]
    f_gaussian = funsor.gaussian.Gaussian(mean=loc, precision=precision,
                                          inputs=inputs)
    f_log_Z = f_gaussian.reduce(funsor.ops.logaddexp, "x")
    sample_inputs = OrderedDict(particle=funsor.Bint[num_samples])
    deltas = f_gaussian.sample("x", sample_inputs)
    f_rsamples = funsor.montecarlo.extract_samples(deltas)["x"]
    ff_entropy = (f_log_Z - f_gaussian(x=f_rsamples)).reduce(
        funsor.ops.mean, "particle")
    check_equal(ff_entropy.data, entropy)

    # Check Funsor's .rsample against Pyro's .log_prob.
    pf_entropy = (p_log_Z - p_gaussian.log_density(f_rsamples.data)).mean(0)
    check_equal(pf_entropy, entropy)

    # Check Pyro's .rsample against Funsor's .log_prob.
    fp_rsamples = funsor.Tensor(p_rsamples)["particle"]
    for i in "ij"[:len(batch_shape)]:
        fp_rsamples = fp_rsamples[i]
    fp_entropy = (f_log_Z - f_gaussian(x=fp_rsamples)).reduce(
        funsor.ops.mean, "particle")
    check_equal(fp_entropy.data, entropy)
def compute_probs(self) -> torch.Tensor: theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F, self.Q) nbatch_size = self.nbatch_size N = sum(self.data.is_ontarget) for ndx in torch.split(torch.arange(N), nbatch_size): self.n = ndx self.nbatch_size = len(ndx) with torch.no_grad(), pyro.plate( "particles", size=5, dim=-4), handlers.enum(first_available_dim=-5): guide_tr = handlers.trace(self.guide).get_trace() model_tr = handlers.trace( handlers.replay(self.model, trace=guide_tr)).get_trace() model_tr.compute_log_prob() guide_tr.compute_log_prob() logp = {} result = {} for fsx in ("0", f"slice(1, {self.data.F}, None)"): logp[fsx] = 0 # collect log_prob terms p(z, theta, phi) for name in [ "z", "theta", "m_k0", "m_k1", "x_k0", "x_k1", "y_k0", "y_k1", ]: logp[fsx] += model_tr.nodes[f"{name}_f{fsx}"]["funsor"][ "log_prob"] if fsx == "0": # substitute MAP values of z into p(z=z_map, theta, phi) z_map = funsor.Tensor(self.z_map[ndx, 0].long(), dtype=2)["aois", "channels"] logp[fsx] = logp[fsx](**{f"z_f{fsx}": z_map}) # compute log_measure q for given z_map log_measure = ( guide_tr.nodes[f"m_k0_f{fsx}"]["funsor"]["log_measure"] + guide_tr.nodes[f"m_k1_f{fsx}"]["funsor"]["log_measure"] ) log_measure = log_measure(**{f"z_f{fsx}": z_map}) else: # substitute MAP values of z into p(z=z_map, theta, phi) z_map = funsor.Tensor(self.z_map[ndx, 1:].long(), dtype=2)["aois", "frames", "channels"] z_map_prev = funsor.Tensor(self.z_map[ndx, :-1].long(), dtype=2)["aois", "frames", "channels"] fsx_prev = f"slice(0, {self.data.F-1}, None)" logp[fsx] = logp[fsx](**{ f"z_f{fsx}": z_map, f"z_f{fsx_prev}": z_map_prev }) # compute log_measure q for given z_map log_measure = ( guide_tr.nodes[f"m_k0_f{fsx}"]["funsor"]["log_measure"] + guide_tr.nodes[f"m_k1_f{fsx}"]["funsor"]["log_measure"] ) log_measure = log_measure(**{ f"z_f{fsx}": z_map, f"z_f{fsx_prev}": z_map_prev }) # compute p(z_map, theta | phi) = p(z_map, theta, phi) - p(z_map, phi) logp[fsx] = logp[fsx] - logp[fsx].reduce( funsor.ops.logaddexp, f"theta_f{fsx}") # average over m in p * q result[fsx] = (logp[fsx] + log_measure).reduce( funsor.ops.logaddexp, frozenset({f"m_k0_f{fsx}", f"m_k1_f{fsx}"})) # average over particles result[fsx] = result[fsx].exp().reduce(funsor.ops.mean, "particles") theta_probs[:, ndx, 0] = result["0"].data[..., 1:].permute(2, 0, 1) theta_probs[:, ndx, 1:] = ( result[f"slice(1, {self.data.F}, None)"].data[..., 1:].permute( 3, 0, 1, 2)) self.n = None self.nbatch_size = nbatch_size return theta_probs