def test_to_funsor(shape, dtype):
    t = np.random.normal(size=shape).astype(dtype)
    f = funsor.to_funsor(t)
    assert isinstance(f, Array)

    assert funsor.to_funsor(t, reals(*shape)) is f

    with pytest.raises(ValueError):
        funsor.to_funsor(t, reals(5, *shape))
def testing():
    for i in markov(range(5)):
        v1 = to_data(Tensor(jnp.ones(2), OrderedDict([(str(i), bint(2))]), 'real'))
        v2 = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
        fv1 = to_funsor(v1, reals())
        fv2 = to_funsor(v2, reals())
        print(i, v1.shape)  # shapes should alternate
        if i % 2 == 0:
            assert v1.shape == (2,)
        else:
            assert v1.shape == (2, 1, 1)
        assert v2.shape == (2, 1)
        print(i, fv1.inputs)
        print('a', v2.shape)  # shapes should stay the same
        print('a', fv2.inputs)
def model(data):
    log_prob = funsor.Number(0.)

    # s is the discrete latent state,
    # x is the continuous latent state,
    # y is the observed state.
    s_curr = funsor.Tensor(torch.tensor(0), dtype=2)
    x_curr = funsor.Tensor(torch.tensor(0.))
    for t, y in enumerate(data):
        s_prev = s_curr
        x_prev = x_curr

        # A delayed sample statement.
        s_curr = funsor.Variable('s_{}'.format(t), funsor.bint(2))
        log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)

        # A delayed sample statement.
        x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
        log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)

        # Marginalize out previous delayed sample statements.
        if t > 0:
            log_prob = log_prob.reduce(ops.logaddexp, {s_prev.name, x_prev.name})

        # An observe statement.
        log_prob += dist.Normal(x_curr, emit_noise, value=y)

    log_prob = log_prob.reduce(ops.logaddexp)
    return log_prob
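# A minimal usage sketch for the switching-state model above, assuming the
# snippet's own imports (funsor, torch, dist, ops) are in scope. The
# parameter values below are hypothetical placeholders, not values from the
# original example.
trans_probs = funsor.Tensor(torch.tensor([[0.9, 0.1], [0.1, 0.9]]))  # rows are p(s_t | s_{t-1})
trans_noise = funsor.Tensor(torch.tensor([0.1, 1.0]))                # per-state transition noise scale
emit_noise = funsor.Tensor(torch.tensor(0.5))
data = torch.randn(10)
print(model(data))  # a scalar funsor holding log p(data)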
def model(data):
    log_prob = funsor.to_funsor(0.)

    xs_curr = [funsor.Tensor(torch.tensor(0.)) for var in var_names]

    for t, y in enumerate(data):
        xs_prev = xs_curr

        # A delayed sample statement.
        xs_curr = [
            funsor.Variable(name + '_{}'.format(t), funsor.reals())
            for name in var_names
        ]

        for i, x_curr in enumerate(xs_curr):
            log_prob += dist.Normal(trans_eqs[var_names[i]](xs_prev),
                                    torch.exp(trans_noises[i]),
                                    value=x_curr)

        if t > 0:
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([x_prev.name for x_prev in xs_prev]))

        # An observe statement.
        log_prob += dist.Normal(emit_eq(xs_curr), torch.exp(emit_noise), value=y)

    # Marginalize out all remaining delayed variables.
    return log_prob.reduce(ops.logaddexp), log_prob.gaussian
def one_step_prediction(p_x_tp1, t, var_names, emit_eq, emit_noise):
    """Computes p(y_{t+1}) from p(x_{t+1}).

    We assume y_t is scalar, so there is only one emit_eq.
    """
    log_prob = p_x_tp1
    x_tp1s = [
        funsor.Variable(name + '_{}'.format(t + 1), funsor.reals())
        for name in var_names
    ]
    y_tp1 = funsor.Variable('y_{}'.format(t + 1), funsor.reals())
    log_prob += dist.Normal(emit_eq(x_tp1s), torch.exp(emit_noise), value=y_tp1)
    log_prob = log_prob.reduce(ops.logaddexp,
                               frozenset([x_tp1.name for x_tp1 in x_tp1s]))
    return log_prob
def testing():
    for i in markov(range(12)):
        if i % 4 == 0:
            v2 = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
            fv2 = to_funsor(v2, reals())
            assert v2.shape == (2,)
            print('a', v2.shape)
            print('a', fv2.inputs)
def log_prob(self, data):
    trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
        self.get_tensors_and_dists()

    log_prob = funsor.Number(0.)

    s_vars = {-1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)}
    x_vars = {}

    for t, y in enumerate(data):
        # construct free variables for s_t and x_t
        s_vars[t] = funsor.Variable(f's_{t}', funsor.bint(self.num_components))
        x_vars[t] = funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))

        # incorporate the discrete switching dynamics
        log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]), value=s_vars[t])

        # incorporate the prior term p(x_t | x_{t-1})
        if t == 0:
            log_prob += self.x_init_mvn(value=x_vars[t])
        else:
            log_prob += x_trans_dist(s=s_vars[t], x=x_vars[t - 1], y=x_vars[t])

        # do a moment-matching reduction. at this point log_prob depends on
        # (moment_matching_lag + 1)-many pairs of free variables.
        if t > self.moment_matching_lag - 1:
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([s_vars[t - self.moment_matching_lag].name,
                           x_vars[t - self.moment_matching_lag].name]))

        # incorporate the observation p(y_t | x_t, s_t)
        log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

    T = data.shape[0]
    # reduce any remaining free variables
    for t in range(self.moment_matching_lag):
        log_prob = log_prob.reduce(
            ops.logaddexp,
            frozenset([s_vars[T - self.moment_matching_lag + t].name,
                       x_vars[T - self.moment_matching_lag + t].name]))

    # assert that we've reduced all the free variables in log_prob
    assert not log_prob.inputs, 'unexpected free variables remain'

    # return the PyTorch tensor behind log_prob (which we can directly differentiate)
    return log_prob.data
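# Hedged usage sketch for the SLDS log_prob method above; the constructor
# arguments and data shape are illustrative, not taken from the original
# script.
slds = SLDS(num_components=2, hidden_dim=3, obs_dim=2, moment_matching_lag=1)
data = torch.randn(50, 2)
nll = -slds.log_prob(data)  # a scalar torch.Tensor
nll.backward()              # differentiable w.r.t. all nn.Parameters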
def compiled(*params_and_args):
    unconstrained_params = params_and_args[:len(self._param_trace)]
    args = params_and_args[len(self._param_trace):]
    for name, unconstrained_param in zip(self._param_trace, unconstrained_params):
        constrained_param = param(name)  # assume param has been initialized
        assert constrained_param.data.unconstrained() is unconstrained_param
        self._param_trace[name]["value"] = constrained_param
    result = replay(self.fn, guide_trace=self._param_trace)(*args)
    assert not result.inputs
    assert result.output == funsor.reals()
    return funsor.to_data(result)
def main(args):
    # Generate fake data.
    data = funsor.Tensor(torch.randn(100),
                         inputs=OrderedDict([('data', funsor.bint(100))]),
                         output=funsor.reals())

    # Train.
    optim = pyro.Adam({'lr': args.learning_rate})
    svi = pyro.SVI(model, pyro.deferred(guide), optim, pyro.elbo)
    for step in range(args.steps):
        svi.step(data)
def next_state(p_x_t, t, var_names, trans_eqs, trans_noises):
    """Computes p(x_{t+1}) from p(x_t)."""
    log_prob = p_x_t
    x_ts = [
        funsor.Variable(name + '_{}'.format(t), funsor.reals())
        for name in var_names
    ]
    x_tp1s = [
        funsor.Variable(name + '_{}'.format(t + 1), funsor.reals())
        for name in var_names
    ]
    for i, x_tp1 in enumerate(x_tp1s):
        log_prob += dist.Normal(trans_eqs[var_names[i]](x_ts),
                                torch.exp(trans_noises[i]),
                                value=x_tp1)
    log_prob = log_prob.reduce(ops.logaddexp,
                               frozenset([x_t.name for x_t in x_ts]))
    return log_prob
def test_advanced_indexing_array(output_shape):
    #      u   v
    #     / \ / \
    #    i   j   k
    #     \  |  /
    #      \ | /
    #        x
    output = reals(*output_shape)
    x = random_array(OrderedDict([
        ('i', bint(2)),
        ('j', bint(3)),
        ('k', bint(4)),
    ]), output)
    i = random_array(OrderedDict([
        ('u', bint(5)),
    ]), bint(2))
    j = random_array(OrderedDict([
        ('v', bint(6)),
        ('u', bint(5)),
    ]), bint(3))
    k = random_array(OrderedDict([
        ('v', bint(6)),
    ]), bint(4))

    expected_data = np.empty((5, 6) + output_shape)
    for u in range(5):
        for v in range(6):
            expected_data[u, v] = x.data[i.data[u], j.data[v, u], k.data[v]]
    expected = Array(expected_data, OrderedDict([
        ('u', bint(5)),
        ('v', bint(6)),
    ]))

    assert_equiv(expected, x(i, j, k))
    assert_equiv(expected, x(i=i, j=j, k=k))

    assert_equiv(expected, x(i=i, j=j)(k=k))
    assert_equiv(expected, x(j=j, k=k)(i=i))
    assert_equiv(expected, x(k=k, i=i)(j=j))

    assert_equiv(expected, x(i=i)(j=j, k=k))
    assert_equiv(expected, x(j=j)(k=k, i=i))
    assert_equiv(expected, x(k=k)(i=i, j=j))

    assert_equiv(expected, x(i=i)(j=j)(k=k))
    assert_equiv(expected, x(i=i)(k=k)(j=j))
    assert_equiv(expected, x(j=j)(i=i)(k=k))
    assert_equiv(expected, x(j=j)(k=k)(i=i))
    assert_equiv(expected, x(k=k)(i=i)(j=j))
    assert_equiv(expected, x(k=k)(j=j)(i=i))
def param(name,
          init_value=None,
          constraint=torch.distributions.constraints.real,
          event_dim=None):
    cond_indep_stack = {}
    output = None
    if init_value is not None:
        if event_dim is None:
            event_dim = init_value.dim()
        output = funsor.reals(*init_value.shape[init_value.dim() - event_dim:])

    def fn(init_value, constraint):
        if name in PARAM_STORE:
            unconstrained_value, constraint = PARAM_STORE[name]
        else:
            # Initialize with a constrained value.
            assert init_value is not None
            with torch.no_grad():
                constrained_value = init_value.detach()
                unconstrained_value = torch.distributions.transform_to(constraint).inv(constrained_value)
            unconstrained_value.requires_grad_()
            unconstrained_value._funsor_metadata = (cond_indep_stack, output)
            PARAM_STORE[name] = unconstrained_value, constraint

        # Transform from unconstrained space to constrained space.
        constrained_value = torch.distributions.transform_to(constraint)(unconstrained_value)
        constrained_value.unconstrained = weakref.ref(unconstrained_value)
        return tensor_to_funsor(constrained_value, *unconstrained_value._funsor_metadata)

    # if there are no active Messengers, we just draw a sample and return it as expected:
    if not PYRO_STACK:
        return fn(init_value, constraint)

    # Otherwise, we initialize a message...
    initial_msg = {
        "type": "param",
        "name": name,
        "fn": fn,
        "args": (init_value, constraint),
        "value": None,
        "cond_indep_stack": cond_indep_stack,  # maps dim to CondIndepStackFrame
        "output": output,
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    assert isinstance(msg["value"], funsor.Funsor)
    return msg["value"]
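# A minimal usage sketch for the param primitive above (hypothetical names
# and values; assumes torch and the surrounding minipyro machinery are in
# scope). Both calls return funsor values; calling param again with the
# same name reuses the stored unconstrained tensor from PARAM_STORE.
loc = param("loc", init_value=torch.zeros(2))
scale = param("scale", init_value=torch.ones(2),
              constraint=torch.distributions.constraints.positive)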
def __init__(self,
             num_components,                 # the number of switching states K
             hidden_dim,                     # the dimension of the continuous latent space
             obs_dim,                        # the dimension of the continuous outputs
             fine_transition_matrix=True,    # controls whether the transition matrix depends on s_t
             fine_transition_noise=False,    # controls whether the transition noise depends on s_t
             fine_observation_matrix=False,  # controls whether the observation matrix depends on s_t
             fine_observation_noise=False,   # controls whether the observation noise depends on s_t
             moment_matching_lag=1):         # controls the expense of the moment matching approximation

    self.num_components = num_components
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
    self.moment_matching_lag = moment_matching_lag
    self.fine_transition_noise = fine_transition_noise
    self.fine_observation_matrix = fine_observation_matrix
    self.fine_observation_noise = fine_observation_noise
    self.fine_transition_matrix = fine_transition_matrix

    assert moment_matching_lag > 0
    assert fine_transition_noise or fine_observation_matrix or fine_observation_noise or fine_transition_matrix, \
        "The continuous dynamics need to be coupled to the discrete dynamics in at least one way [use at " + \
        "least one of the arguments --ftn --ftm --fon --fom]"

    super(SLDS, self).__init__()

    # initialize the various parameters of the model
    self.transition_logits = nn.Parameter(0.1 * torch.randn(num_components, num_components))
    if fine_transition_matrix:
        transition_matrix = torch.eye(hidden_dim) + 0.05 * torch.randn(num_components, hidden_dim, hidden_dim)
    else:
        transition_matrix = torch.eye(hidden_dim) + 0.05 * torch.randn(hidden_dim, hidden_dim)
    self.transition_matrix = nn.Parameter(transition_matrix)
    if fine_transition_noise:
        self.log_transition_noise = nn.Parameter(0.1 * torch.randn(num_components, hidden_dim))
    else:
        self.log_transition_noise = nn.Parameter(0.1 * torch.randn(hidden_dim))
    if fine_observation_matrix:
        self.observation_matrix = nn.Parameter(0.3 * torch.randn(num_components, hidden_dim, obs_dim))
    else:
        self.observation_matrix = nn.Parameter(0.3 * torch.randn(hidden_dim, obs_dim))
    if fine_observation_noise:
        self.log_obs_noise = nn.Parameter(0.1 * torch.randn(num_components, obs_dim))
    else:
        self.log_obs_noise = nn.Parameter(0.1 * torch.randn(obs_dim))

    # define the prior distribution p(x_0) over the continuous latent at the initial time step t=0
    x_init_mvn = torch.distributions.MultivariateNormal(torch.zeros(self.hidden_dim),
                                                        torch.eye(self.hidden_dim))
    self.x_init_mvn = mvn_to_funsor(x_init_mvn,
                                    real_inputs=OrderedDict([('x_0', funsor.reals(self.hidden_dim))]))
def update(p_x_tp1, t, y, var_names, emit_eq, emit_noise):
    """Computes p(x_{t+1} | y_{t+1}) from p(x_{t+1}).

    This is useful for iterating 1-step-ahead predictions.
    """
    log_prob = p_x_tp1
    x_tp1s = [
        funsor.Variable(name + '_{}'.format(t + 1), funsor.reals())
        for name in var_names
    ]
    # emit_noise is a log-scale parameter, as in one_step_prediction above.
    log_prob += dist.Normal(emit_eq(x_tp1s), torch.exp(emit_noise), value=y)
    log_p_y = log_prob.reduce(ops.logaddexp,
                              frozenset([x_tp1.name for x_tp1 in x_tp1s]))
    # Bayes' rule: log p(x | y) = log p(x, y) - log p(y).
    log_p_x_y = log_prob - log_p_y
    return log_p_x_y
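# Hedged sketch of how next_state, one_step_prediction, and update compose
# into a forward filtering loop. p_x_0 (the initial state distribution),
# var_names, trans_eqs, trans_noises, emit_eq, emit_noise, and data are
# placeholders for objects defined elsewhere in the original example.
p_x = p_x_0
for t in range(len(data) - 1):
    p_x = next_state(p_x, t, var_names, trans_eqs, trans_noises)       # p(x_{t+1})
    p_y = one_step_prediction(p_x, t, var_names, emit_eq, emit_noise)  # p(y_{t+1})
    p_x = update(p_x, t, data[t + 1], var_names, emit_eq, emit_noise)  # p(x_{t+1} | y_{t+1})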
def model(data):
    log_prob = funsor.to_funsor(0.)

    x_curr = funsor.Tensor(torch.tensor(0.))
    for t, y in enumerate(data):
        x_prev = x_curr

        # A delayed sample statement.
        x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
        log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

        # Optionally marginalize out the previous state.
        if t > 0 and not args.lazy:
            log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

        # An observe statement.
        log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

    # Marginalize out all remaining delayed variables.
    log_prob = log_prob.reduce(ops.logaddexp)
    return log_prob
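# Hedged usage sketch: with args.lazy the intermediate reductions are kept
# symbolic, and funsor's optimizer then picks an efficient elimination
# order (data and args are placeholders defined elsewhere in the script).
with funsor.interpreter.interpretation(funsor.terms.lazy if args.lazy else funsor.terms.eager):
    log_prob = model(data)
log_prob = funsor.optimizer.apply_optimizer(log_prob)
print(log_prob.data)  # scalar log-evidence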
def test_advanced_indexing_shape():
    I, J, M, N = 4, 4, 2, 3
    x = Array(np.random.normal(size=(I, J)), OrderedDict([
        ('i', bint(I)),
        ('j', bint(J)),
    ]))
    m = Array(np.array([2, 3]), OrderedDict([('m', bint(M))]), I)
    n = Array(np.array([0, 1, 1]), OrderedDict([('n', bint(N))]), J)
    assert x.data.shape == (I, J)

    check_funsor(x(i=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(i=m, j=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(i=m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(i=m, k=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(i=n), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(i=n, k=m), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(j=m), {'i': bint(I), 'm': bint(M)}, reals())
    check_funsor(x(j=m, i=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(j=m, i=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(j=m, k=m), {'i': bint(I), 'm': bint(M)}, reals())
    check_funsor(x(j=n), {'i': bint(I), 'n': bint(N)}, reals())
    check_funsor(x(j=n, k=m), {'i': bint(I), 'n': bint(N)}, reals())
    check_funsor(x(m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(m, j=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, k=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(m, n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(n), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(n, k=m), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(n, m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(n, m, k=m), {'m': bint(M), 'n': bint(N)}, reals())
def test_indexing():
    data = np.random.normal(size=(4, 5))
    inputs = OrderedDict([('i', bint(4)),
                          ('j', bint(5))])
    x = Array(data, inputs)
    check_funsor(x, inputs, reals(), data)

    assert x() is x
    assert x(k=3) is x
    check_funsor(x(1), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(1, 2), {}, reals(), data[1, 2])
    check_funsor(x(1, 2, k=3), {}, reals(), data[1, 2])
    check_funsor(x(1, j=2), {}, reals(), data[1, 2])
    check_funsor(x(1, j=2, k=3), {}, reals(), data[1, 2])
    check_funsor(x(1, k=3), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(i=1), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(i=1, j=2), {}, reals(), data[1, 2])
    check_funsor(x(i=1, j=2, k=3), {}, reals(), data[1, 2])
    check_funsor(x(i=1, k=3), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(j=2), {'i': bint(4)}, reals(), data[:, 2])
    check_funsor(x(j=2, k=3), {'i': bint(4)}, reals(), data[:, 2])
def filter_and_predict(self, data, smoothing=False):
    trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
        self.get_tensors_and_dists()

    log_prob = funsor.Number(0.)

    s_vars = {-1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)}
    x_vars = {-1: None}

    predictive_x_dists, predictive_y_dists, filtering_dists = [], [], []
    test_LLs = []

    for t, y in enumerate(data):
        s_vars[t] = funsor.Variable(f's_{t}', funsor.bint(self.num_components))
        x_vars[t] = funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))

        log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]), value=s_vars[t])

        if t == 0:
            log_prob += self.x_init_mvn(value=x_vars[t])
        else:
            log_prob += x_trans_dist(s=s_vars[t], x=x_vars[t - 1], y=x_vars[t])

        if t > 0:
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([s_vars[t - 1].name, x_vars[t - 1].name]))

        # do 1-step prediction and compute test LL
        if t > 0:
            predictive_x_dists.append(log_prob)
            _log_prob = log_prob - log_prob.reduce(ops.logaddexp)
            predictive_y_dist = y_dist(s=s_vars[t], x=x_vars[t]) + _log_prob
            test_LLs.append(predictive_y_dist(y=y).reduce(ops.logaddexp).data.item())
            predictive_y_dist = predictive_y_dist.reduce(
                ops.logaddexp, frozenset([f"x_{t}", f"s_{t}"]))
            predictive_y_dists.append(funsor_to_mvn(predictive_y_dist, 0, ()))

        log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

        # save filtering dists for forward-backward smoothing
        if smoothing:
            filtering_dists.append(log_prob)

    # do the backward recursion using previously computed ingredients
    if smoothing:
        # seed the backward recursion with the filtering distribution at t=T
        smoothing_dists = [filtering_dists[-1]]
        T = data.size(0)

        s_vars = {t: funsor.Variable(f's_{t}', funsor.bint(self.num_components))
                  for t in range(T)}
        x_vars = {t: funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))
                  for t in range(T)}

        # do the backward recursion.
        # let p[t|t-1] be the predictive distribution at time step t.
        # let p[t|t] be the filtering distribution at time step t.
        # let f[t] denote the prior (transition) density at time step t.
        # then the smoothing distribution p[t|T] at time step t is
        # given by the following recursion:
        #     p[t-1|T] = p[t-1|t-1] <p[t|T] f[t] / p[t|t-1]>
        # where <...> denotes integration of the latent variables at time step t.
        for t in reversed(range(T - 1)):
            integral = smoothing_dists[-1] - predictive_x_dists[t]
            integral += dist.Categorical(trans_probs(s=s_vars[t]), value=s_vars[t + 1])
            integral += x_trans_dist(s=s_vars[t], x=x_vars[t], y=x_vars[t + 1])
            integral = integral.reduce(
                ops.logaddexp,
                frozenset([s_vars[t + 1].name, x_vars[t + 1].name]))
            smoothing_dists.append(filtering_dists[t] + integral)

    # compute predictive test MSE and predictive variances
    predictive_means = torch.stack([d.mean for d in predictive_y_dists])  # T-1 ydim
    predictive_vars = torch.stack([
        d.covariance_matrix.diagonal(dim1=-1, dim2=-2)
        for d in predictive_y_dists
    ])
    predictive_mse = (predictive_means - data[1:, :]).pow(2.0).mean(-1)

    if smoothing:
        # compute smoothed mean function
        smoothing_dists = [funsor_to_cat_and_mvn(d, 0, (f"s_{t}",))
                           for t, d in enumerate(reversed(smoothing_dists))]
        means = torch.stack([d[1].mean for d in smoothing_dists])  # T 2 xdim
        means = torch.matmul(means.unsqueeze(-2), self.observation_matrix).squeeze(-2)  # T 2 ydim
        probs = torch.stack([d[0].logits for d in smoothing_dists]).exp()
        probs = probs / probs.sum(-1, keepdim=True)  # T 2
        smoothing_means = (probs.unsqueeze(-1) * means).sum(-2)  # T ydim
        smoothing_probs = probs[:, 1]
        return predictive_mse, torch.tensor(np.array(test_LLs)), predictive_means, predictive_vars, \
            smoothing_means, smoothing_probs
    else:
        return predictive_mse, torch.tensor(np.array(test_LLs))
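# Hedged usage sketch for filter_and_predict above (shapes illustrative);
# with smoothing=True it additionally returns smoothed means and state
# probabilities, as in the method's final branch.
predictive_mse, test_LLs = slds.filter_and_predict(torch.randn(50, 2))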
def log_density(model, model_args, model_kwargs, params):
    """
    Similar to :func:`numpyro.infer.util.log_density` but works for models
    with discrete latent variables. Internally, this uses :mod:`funsor`
    to marginalize discrete latent sites and evaluate the joint log probability.

    :param model: Python callable containing NumPyro primitives. Typically,
        the model has been enumerated by using
        :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

            def model(*args, **kwargs):
                ...

            log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _init/... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob = funsor.to_funsor(log_prob, output=funsor.reals(), dim_to_name=dim_to_name)

            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(name, funsor.domains.bint(site["value"].shape[dim]))
                    time_to_factors[time_dim].append(log_prob)
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values() if s.startswith("_init"))
                    break
            if time_dim is None:
                log_factors.append(log_prob)

            if not site['is_observed']:
                sum_vars |= frozenset({site['name']})
            prod_vars |= frozenset(f.name for f in site['cond_indep_stack'] if f.dim is not None)

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = "/".join(var.split("/")[1:])
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _init (i.e. prev) is in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(name for name in dim_to_name.values())

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(time_to_factors, time_to_init_vars,
                                                time_to_markov_dims, sum_vars, prod_vars)
        log_factors = log_factors + markov_factors

    with funsor.interpreter.interpretation(funsor.terms.lazy):
        lazy_result = funsor.sum_product.sum_product(
            funsor.ops.logaddexp, funsor.ops.add, log_factors,
            eliminate=sum_vars | prod_vars, plates=prod_vars)
    result = funsor.optimizer.apply_optimizer(lazy_result)
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density to be a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".format(
                result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}))
    return result.data, model_trace