def model(data):
    """Return the log marginal likelihood of ``data`` under a discrete HMM.

    The hidden chain starts in state 0. Each later hidden state is a delayed
    sample (a free ``funsor.Variable``) that is eventually summed out with
    ``ops.logaddexp``. ``trans_probs``, ``emit_probs`` and ``args`` are taken
    from the enclosing scope.

    Args:
        data: sequence of observations; each element is wrapped as a
            ``funsor.Tensor`` over a 2-valued domain.

    Returns:
        The accumulated log-density funsor reduced with ``ops.logaddexp``
        over all of its remaining inputs.
    """
    hidden = funsor.bint(args.hidden_dim)

    # p(x_t | x_{t-1}): the previous hidden state is exposed as input 'prev'.
    trans = dist.Categorical(probs=funsor.Tensor(trans_probs, inputs=OrderedDict(prev=hidden)))
    # p(y_t | x_t): the current hidden state is exposed as input 'latent'.
    emit = dist.Categorical(probs=funsor.Tensor(emit_probs, inputs=OrderedDict(latent=hidden)))

    total = funsor.to_funsor(0.)
    state = funsor.Number(0, args.hidden_dim)
    for step, obs in enumerate(data):
        # A delayed sample statement for this step's hidden state.
        prev_state, state = state, funsor.Variable(f'x_{step}', hidden)
        total += trans(prev=prev_state, value=state)

        # Eager elimination: no factor added after this transition mentions
        # prev_state, so it can be summed out right away. The very first
        # prev_state is a constant (not a Variable) and needs no reduction.
        if not args.lazy and isinstance(prev_state, funsor.Variable):
            total = total.reduce(ops.logaddexp, prev_state.name)

        total += emit(latent=state, value=funsor.Tensor(obs, dtype=2))

    return total.reduce(ops.logaddexp)
def model(data):
    """Return the log marginal likelihood of ``data`` under a switching
    linear dynamical system.

    Per time step t the model has:
      * s_t -- a discrete switching state over 2 values,
      * x_t -- a real-valued continuous latent state,
      * y_t -- the observation ``data[t]``.
    The chain starts from the constants s = 0 and x = 0.0. ``trans_probs``,
    ``trans_noise`` and ``emit_noise`` are taken from the enclosing scope.

    Returns:
        The accumulated log-density funsor reduced with ``ops.logaddexp``
        over all of its remaining inputs.
    """
    s_curr = funsor.Tensor(torch.tensor(0), dtype=2)
    x_curr = funsor.Tensor(torch.tensor(0.))
    log_prob = funsor.Number(0.)

    for t, y in enumerate(data):
        s_prev, x_prev = s_curr, x_curr

        # Delayed sample statements: this step's latents are free variables.
        s_curr = funsor.Variable(f's_{t}', funsor.bint(2))
        x_curr = funsor.Variable(f'x_{t}', funsor.reals())

        # p(s_t | s_{t-1}), then p(x_t | x_{t-1}, s_t) where s_t selects the
        # transition noise scale.
        log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)
        log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)

        # From the second step on, the previous latents appear in no later
        # factor, so marginalize them out now. On the first step they are
        # the initial constants and there is nothing to reduce.
        if t != 0:
            log_prob = log_prob.reduce(ops.logaddexp, {s_prev.name, x_prev.name})

        # Observe y_t ~ Normal(x_t, emit_noise).
        log_prob += dist.Normal(x_curr, emit_noise, value=y)

    return log_prob.reduce(ops.logaddexp)
def log_prob(self, data):
    """Return the log-likelihood of ``data`` computed with moment-matching reductions.

    For each time step a discrete switching variable ``s_t`` (over
    ``self.num_components`` values) and a continuous latent ``x_t`` (of
    dimension ``self.hidden_dim``) are introduced as free funsor variables.
    A latent pair stays free for ``self.moment_matching_lag`` steps and is
    then reduced with ``ops.logaddexp``.

    NOTE(review): assumes ``moment_matching_lag >= 1``. With 0, observation
    factors reference variables that were already reduced, and the final
    assert fires. Confirm that this is the intended contract.

    Args:
        data: observations indexed by time along dim 0 (``T = data.shape[0]``).

    Returns:
        ``log_prob.data`` of the fully reduced funsor, i.e. the PyTorch tensor
        behind it, which can be differentiated directly. For empty ``data``
        nothing is reduced and the data of the initial ``funsor.Number(0.)``
        is returned.

    Raises:
        AssertionError: if any free variable survives the final reduction.
    """
    trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = self.get_tensors_and_dists()
    lag = self.moment_matching_lag

    log_prob = funsor.Number(0.)

    # s_vars[-1] is a fixed initial discrete state (component 0) that
    # conditions the first switching transition; it is a constant, not a
    # Variable, and therefore has no .name.
    s_vars = {-1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)}
    x_vars = {}

    for t, y in enumerate(data):
        # construct free variables for s_t and x_t
        s_vars[t] = funsor.Variable(f's_{t}', funsor.bint(self.num_components))
        x_vars[t] = funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))

        # incorporate the discrete switching dynamics
        log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]), value=s_vars[t])

        # incorporate the prior term p(x_t | x_{t-1})
        if t == 0:
            log_prob += self.x_init_mvn(value=x_vars[t])
        else:
            log_prob += x_trans_dist(s=s_vars[t], x=x_vars[t - 1], y=x_vars[t])

        # do a moment-matching reduction. at this point log_prob depends on
        # (lag + 1)-many pairs of free variables; reduce the oldest pair.
        if t >= lag:
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([s_vars[t - lag].name, x_vars[t - lag].name]))

        # incorporate the observation p(y_t | x_t, s_t)
        log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

    T = data.shape[0]
    # Reduce any remaining free variables, oldest first. Only the last
    # min(lag, T) steps are still free. Clamping the start at 0 keeps
    # sequences shorter than the lag from touching s_vars[-1] (no .name)
    # or nonexistent negative keys.
    for t in range(max(T - lag, 0), T):
        log_prob = log_prob.reduce(
            ops.logaddexp,
            frozenset([s_vars[t].name, x_vars[t].name]))

    # assert that we've reduced all the free variables in log_prob
    assert not log_prob.inputs, 'unexpected free variables remain'

    # return the PyTorch tensor behind log_prob (which we can directly differentiate)
    return log_prob.data
def filter_and_predict(self, data, smoothing=False):
    """Run forward filtering over ``data`` and, optionally, backward smoothing.

    At every step t > 0, the following happens:
      * The previous latent pair (s_{t-1}, x_{t-1}) is marginalized out.
      * The resulting predictive distribution over (s_t, x_t) is recorded.
      * That distribution is pushed through ``y_dist`` to obtain a
        predictive distribution over y_t.
      * The actual observation is scored with it (the "test LL").

    Args:
        data: tensor of observations indexed by time along dim 0. Because
            ``data[1:, :]`` is compared to the predictive means, a 2-D
            (T, ydim) tensor is expected.
        smoothing: if True, also run the forward-filtering /
            backward-smoothing recursion. This needs at least one time step.

    Returns:
        If ``smoothing`` is False: ``(predictive_mse, test_LLs)``.
          * ``predictive_mse`` holds the per-step mean squared error of the
            T-1 one-step predictions.
          * ``test_LLs`` is a tensor of the T-1 one-step predictive
            log-likelihoods.

        If ``smoothing`` is True, the tuple is extended with:
          * ``predictive_means`` -- shape (T-1, ydim).
          * ``predictive_vars`` -- the diagonals of the predictive
            covariances, shape (T-1, ydim).
          * ``smoothing_means`` -- shape (T, ydim).
          * ``smoothing_probs`` -- the smoothed probability of component 1
            at each step.

        When T < 2, the predictive tensors have zero rows instead of raising.
    """
    trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = self.get_tensors_and_dists()

    log_prob = funsor.Number(0.)

    # s_vars[-1] is a fixed initial discrete state (component 0) that
    # conditions the first switching transition.
    s_vars = {-1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)}
    x_vars = {}

    # Predictions are only made for t > 0, so predictive_x_dists[k] and
    # predictive_y_dists[k] refer to time step k + 1.
    predictive_x_dists, predictive_y_dists, filtering_dists = [], [], []
    test_LLs = []

    for t, y in enumerate(data):
        s_vars[t] = funsor.Variable(f's_{t}', funsor.bint(self.num_components))
        x_vars[t] = funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))

        log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]), value=s_vars[t])

        if t == 0:
            log_prob += self.x_init_mvn(value=x_vars[t])
        else:
            log_prob += x_trans_dist(s=s_vars[t], x=x_vars[t - 1], y=x_vars[t])

            # Marginalize out the previous latent pair. The only free latents
            # left are (s_t, x_t), and all observations before t are included.
            log_prob = log_prob.reduce(ops.logaddexp, frozenset([s_vars[t - 1].name, x_vars[t - 1].name]))

            # do 1-step prediction and compute test LL
            predictive_x_dists.append(log_prob)
            # normalize the predictive over (s_t, x_t) before combining with y_dist
            _log_prob = log_prob - log_prob.reduce(ops.logaddexp)
            predictive_y_dist = y_dist(s=s_vars[t], x=x_vars[t]) + _log_prob
            test_LLs.append(predictive_y_dist(y=y).reduce(ops.logaddexp).data.item())
            predictive_y_dist = predictive_y_dist.reduce(ops.logaddexp, frozenset([f"x_{t}", f"s_{t}"]))
            predictive_y_dists.append(funsor_to_mvn(predictive_y_dist, 0, ()))

        log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

        # save filtering dists for forward-backward smoothing
        if smoothing:
            filtering_dists.append(log_prob)

    # do the backward recursion using previously computed ingredients
    if smoothing:
        # seed the backward recursion with the filtering distribution at t=T
        smoothing_dists = [filtering_dists[-1]]
        T = data.size(0)

        # s_vars / x_vars from the forward pass already hold the free
        # variables for every t in range(T).

        # do the backward recursion.
        # let p[t|t-1] be the predictive distribution at time step t.
        # let p[t|t] be the filtering distribution at time step t.
        # let f[t] denote the prior (transition) density at time step t.
        # then the smoothing distribution p[t|T] at time step t is
        # given by the following recursion.
        # p[t-1|T] = p[t-1|t-1] <p[t|T] f[t] / p[t|t-1]>
        # where <...> denotes integration of the latent variables at time step t.
        for t in reversed(range(T - 1)):
            # predictive_x_dists[t] is p[t+1|t] (see the offset note above)
            integral = smoothing_dists[-1] - predictive_x_dists[t]
            integral += dist.Categorical(trans_probs(s=s_vars[t]), value=s_vars[t + 1])
            integral += x_trans_dist(s=s_vars[t], x=x_vars[t], y=x_vars[t + 1])
            integral = integral.reduce(ops.logaddexp, frozenset([s_vars[t + 1].name, x_vars[t + 1].name]))
            smoothing_dists.append(filtering_dists[t] + integral)

    # compute predictive test MSE and predictive variances
    if predictive_y_dists:
        predictive_means = torch.stack([d.mean for d in predictive_y_dists])  # T-1 ydim
        predictive_vars = torch.stack([d.covariance_matrix.diagonal(dim1=-1, dim2=-2)
                                       for d in predictive_y_dists])
    else:
        # Fewer than two time steps means no one-step predictions were made,
        # and torch.stack raises on an empty list. Return zero-row tensors
        # shaped like data[1:] instead.
        empty_shape = (0,) + tuple(data.shape[1:])
        predictive_means = data.new_zeros(empty_shape)
        predictive_vars = data.new_zeros(empty_shape)
    predictive_mse = (predictive_means - data[1:, :]).pow(2.0).mean(-1)

    if smoothing:
        # compute smoothed mean function
        smoothing_dists = [funsor_to_cat_and_mvn(d, 0, (f"s_{t}",))
                           for t, d in enumerate(reversed(smoothing_dists))]
        means = torch.stack([d[1].mean for d in smoothing_dists])  # T 2 xdim
        means = torch.matmul(means.unsqueeze(-2), self.observation_matrix).squeeze(-2)  # T 2 ydim

        probs = torch.stack([d[0].logits for d in smoothing_dists]).exp()
        probs = probs / probs.sum(-1, keepdim=True)  # T 2

        smoothing_means = (probs.unsqueeze(-1) * means).sum(-2)  # T ydim
        # NOTE(review): selects component 1; the "T 2" shape comments suggest
        # this assumes num_components == 2. Confirm before using with more.
        smoothing_probs = probs[:, 1]

        return predictive_mse, torch.tensor(np.array(test_LLs)), predictive_means, predictive_vars, \
            smoothing_means, smoothing_probs
    else:
        return predictive_mse, torch.tensor(np.array(test_LLs))