def testing():
    with markov():
        v1 = to_data(
            Tensor(jnp.ones(2), OrderedDict([("1", bint(2))]), 'real'))
        print(1, v1.shape)  # shapes should alternate
        assert v1.shape == (2,)

        with markov():
            v2 = to_data(
                Tensor(jnp.ones(2), OrderedDict([("2", bint(2))]), 'real'))
            print(2, v2.shape)  # shapes should alternate
            assert v2.shape == (2, 1)

            with markov():
                v3 = to_data(
                    Tensor(jnp.ones(2), OrderedDict([("3", bint(2))]), 'real'))
                print(3, v3.shape)  # shapes should alternate
                assert v3.shape == (2,)

                with markov():
                    v4 = to_data(
                        Tensor(jnp.ones(2), OrderedDict([("4", bint(2))]), 'real'))
                    print(4, v4.shape)  # shapes should alternate
                    assert v4.shape == (2, 1)
def model(data):
    log_prob = funsor.to_funsor(0.)

    trans = dist.Categorical(probs=funsor.Tensor(
        trans_probs,
        inputs=OrderedDict([('prev', funsor.bint(args.hidden_dim))]),
    ))

    emit = dist.Categorical(probs=funsor.Tensor(
        emit_probs,
        inputs=OrderedDict([('latent', funsor.bint(args.hidden_dim))]),
    ))

    x_curr = funsor.Number(0, args.hidden_dim)
    for t, y in enumerate(data):
        x_prev = x_curr

        # A delayed sample statement.
        x_curr = funsor.Variable('x_{}'.format(t), funsor.bint(args.hidden_dim))
        log_prob += trans(prev=x_prev, value=x_curr)

        if not args.lazy and isinstance(x_prev, funsor.Variable):
            log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

        log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2))

    log_prob = log_prob.reduce(ops.logaddexp)
    return log_prob
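# Hedged driver sketch (an addition, not from the original source): the model
# above is written so it can run under either eager or lazy interpretation.
# Assuming the same `interpretation`/`lazy` helpers used by the lazy-indexing
# test elsewhere in this collection, plus `funsor.interpreter.reinterpret`,
# a caller might toggle the two modes like this:
from funsor.interpreter import reinterpret

if args.lazy:
    with interpretation(lazy):
        log_prob = model(data)
    log_prob = reinterpret(log_prob)  # evaluate the deferred expression graph
else:
    log_prob = model(data)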
def expand_inputs(self, name, size):
    if name in self.funsor_dist.inputs:
        assert self.funsor_dist.inputs[name] == funsor.bint(int(size))
        return self
    inputs = OrderedDict([(name, funsor.bint(int(size)))])
    if self.sample_inputs:
        inputs.update(self.sample_inputs)
    return Distribution(self.funsor_dist, sample_inputs=inputs)
def expand_inputs(self, name, size):
    if name in self.funsor_dist.inputs:
        assert self.funsor_dist.inputs[name] == funsor.bint(int(size))
        return self
    inputs = OrderedDict([(name, funsor.bint(int(size)))])
    funsor_dist = self.funsor_dist + funsor.torch.Tensor(
        torch.zeros(size), inputs)
    return Distribution(funsor_dist)
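# Minimal runnable sketch (an added example, not part of the original code) of
# the expansion trick used in the second expand_inputs variant above: with the
# old-style funsor API (`bint`, PyTorch backend) visible in these snippets,
# adding a zeros Tensor that carries a new input broadcasts a funsor along it.
from collections import OrderedDict

import torch
import funsor

x = funsor.Tensor(torch.randn(()))  # scalar funsor with no inputs
pad = funsor.Tensor(torch.zeros(3), OrderedDict(n=funsor.bint(3)))
y = x + pad  # y now has a free input 'n' of size 3
assert y.inputs['n'] == funsor.bint(3)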
def test_align():
    x = Array(np.random.randn(2, 3, 4), OrderedDict([
        ('i', bint(2)),
        ('j', bint(3)),
        ('k', bint(4)),
    ]))
    y = x.align(('j', 'k', 'i'))
    assert isinstance(y, Array)
    assert tuple(y.inputs) == ('j', 'k', 'i')
    for i in range(2):
        for j in range(3):
            for k in range(4):
                assert x(i=i, j=j, k=k) == y(i=i, j=j, k=k)
def testing():
    for i in markov(range(5)):
        v1 = to_data(Tensor(jnp.ones(2), OrderedDict([(str(i), bint(2))]), 'real'))
        v2 = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
        fv1 = to_funsor(v1, reals())
        fv2 = to_funsor(v2, reals())
        print(i, v1.shape)  # shapes should alternate
        if i % 2 == 0:
            assert v1.shape == (2,)
        else:
            assert v1.shape == (2, 1, 1)
        assert v2.shape == (2, 1)
        print(i, fv1.inputs)
        print('a', v2.shape)  # shapes should stay the same
        print('a', fv2.inputs)
def get_tensors_and_dists(self):
    # normalize the transition probabilities
    trans_logits = self.transition_logits - self.transition_logits.logsumexp(
        dim=-1, keepdim=True)
    trans_probs = funsor.Tensor(
        trans_logits,
        OrderedDict([("s", funsor.bint(self.num_components))]))

    trans_mvn = torch.distributions.MultivariateNormal(
        torch.zeros(self.hidden_dim),
        self.log_transition_noise.exp().diag_embed())
    obs_mvn = torch.distributions.MultivariateNormal(
        torch.zeros(self.obs_dim),
        self.log_obs_noise.exp().diag_embed())

    event_dims = ("s",) if self.fine_transition_matrix or self.fine_transition_noise else ()
    x_trans_dist = matrix_and_mvn_to_funsor(self.transition_matrix, trans_mvn,
                                            event_dims, "x", "y")
    event_dims = ("s",) if self.fine_observation_matrix or self.fine_observation_noise else ()
    y_dist = matrix_and_mvn_to_funsor(self.observation_matrix, obs_mvn,
                                      event_dims, "x", "y")

    return trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist
def model(data):
    log_prob = funsor.Number(0.)

    # s is the discrete latent state,
    # x is the continuous latent state,
    # y is the observed state.
    s_curr = funsor.Tensor(torch.tensor(0), dtype=2)
    x_curr = funsor.Tensor(torch.tensor(0.))
    for t, y in enumerate(data):
        s_prev = s_curr
        x_prev = x_curr

        # A delayed sample statement.
        s_curr = funsor.Variable('s_{}'.format(t), funsor.bint(2))
        log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)

        # A delayed sample statement.
        x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
        log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)

        # Marginalize out previous delayed sample statements.
        if t > 0:
            log_prob = log_prob.reduce(ops.logaddexp, {s_prev.name, x_prev.name})

        # An observe statement.
        log_prob += dist.Normal(x_curr, emit_noise, value=y)

    log_prob = log_prob.reduce(ops.logaddexp)
    return log_prob
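# Minimal sketch (an added example, same assumed funsor/PyTorch API as above)
# of the marginalization primitive this model relies on: reducing a free
# discrete input with logaddexp is exactly a log-sum-exp over that latent.
from collections import OrderedDict

import torch
import funsor
from funsor import ops

lp = funsor.Tensor(torch.randn(2), OrderedDict(s=funsor.bint(2)))
total = lp.reduce(ops.logaddexp, 's')  # marginalize out the latent 's'
assert torch.allclose(total.data, lp.data.logsumexp(0))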
def testing():
    for i in markov(range(12)):
        if i % 4 == 0:
            v2 = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
            fv2 = to_funsor(v2, reals())
            assert v2.shape == (2,)
            print('a', v2.shape)
            print('a', fv2.inputs)
def log_prob(self, data):
    trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = self.get_tensors_and_dists()

    log_prob = funsor.Number(0.)

    s_vars = {-1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)}
    x_vars = {}

    for t, y in enumerate(data):
        # construct free variables for s_t and x_t
        s_vars[t] = funsor.Variable(f's_{t}', funsor.bint(self.num_components))
        x_vars[t] = funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))

        # incorporate the discrete switching dynamics
        log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]), value=s_vars[t])

        # incorporate the prior term p(x_t | x_{t-1})
        if t == 0:
            log_prob += self.x_init_mvn(value=x_vars[t])
        else:
            log_prob += x_trans_dist(s=s_vars[t], x=x_vars[t - 1], y=x_vars[t])

        # do a moment-matching reduction. at this point log_prob depends on
        # (moment_matching_lag + 1)-many pairs of free variables.
        if t > self.moment_matching_lag - 1:
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([s_vars[t - self.moment_matching_lag].name,
                           x_vars[t - self.moment_matching_lag].name]))

        # incorporate the observation p(y_t | x_t, s_t)
        log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

    T = data.shape[0]
    # reduce any remaining free variables
    for t in range(self.moment_matching_lag):
        log_prob = log_prob.reduce(
            ops.logaddexp,
            frozenset([s_vars[T - self.moment_matching_lag + t].name,
                       x_vars[T - self.moment_matching_lag + t].name]))

    # assert that we've reduced all the free variables in log_prob
    assert not log_prob.inputs, 'unexpected free variables remain'

    # return the PyTorch tensor behind log_prob (which we can directly differentiate)
    return log_prob.data
def main(args):
    # Generate fake data.
    data = funsor.Tensor(torch.randn(100),
                         inputs=OrderedDict([('data', funsor.bint(100))]),
                         output=funsor.reals())

    # Train.
    optim = pyro.Adam({'lr': args.learning_rate})
    svi = pyro.SVI(model, pyro.deferred(guide), optim, pyro.elbo)
    for step in range(args.steps):
        svi.step(data)
def __init__(self, name, size, subsample_size=None, dim=None):
    self.name = name
    self.size = size
    self.subsample_size = size if subsample_size is None else subsample_size
    if dim is not None and dim >= 0:
        raise ValueError('dim arg must be negative.')
    self.dim = dim
    self._indices = funsor.Tensor(
        funsor.ops.new_arange(funsor.tensor.get_default_prototype(), self.size),
        OrderedDict([(self.name, funsor.bint(self.size))]),
        self.size)
    super(plate, self).__init__(None)
def __init__(self, name, size, subsample_size=None, dim=None):
    self.name = name
    self.size = size
    if dim is not None and dim >= 0:
        raise ValueError('dim arg must be negative.')
    self.dim, indices = OrigPlateMessenger._subsample(
        self.name, self.size, subsample_size, dim)
    self.subsample_size = indices.shape[0]
    self._indices = funsor.Tensor(
        indices,
        OrderedDict([(self.name, funsor.bint(self.subsample_size))]),
        self.subsample_size)
    super(plate, self).__init__(None)
def test_indexing():
    data = np.random.normal(size=(4, 5))
    inputs = OrderedDict([('i', bint(4)),
                          ('j', bint(5))])
    x = Array(data, inputs)
    check_funsor(x, inputs, reals(), data)

    assert x() is x
    assert x(k=3) is x
    check_funsor(x(1), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(1, 2), {}, reals(), data[1, 2])
    check_funsor(x(1, 2, k=3), {}, reals(), data[1, 2])
    check_funsor(x(1, j=2), {}, reals(), data[1, 2])
    check_funsor(x(1, j=2, k=3), (), reals(), data[1, 2])
    check_funsor(x(1, k=3), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(i=1), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(i=1, j=2), (), reals(), data[1, 2])
    check_funsor(x(i=1, j=2, k=3), (), reals(), data[1, 2])
    check_funsor(x(i=1, k=3), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(j=2), {'i': bint(4)}, reals(), data[:, 2])
    check_funsor(x(j=2, k=3), {'i': bint(4)}, reals(), data[:, 2])
def tensor_to_funsor(value, cond_indep_stack, output):
    assert isinstance(value, torch.Tensor)
    event_shape = output.shape
    batch_shape = value.shape[:value.dim() - len(event_shape)]
    if torch._C._get_tracing_state():
        with funsor.tensor.ignore_jit_warnings():
            batch_shape = tuple(map(int, batch_shape))
    inputs = OrderedDict()
    data = value
    for dim, size in enumerate(batch_shape):
        if size == 1:
            data = data.squeeze(dim - value.dim())
        else:
            frame = cond_indep_stack[dim - len(batch_shape)]
            assert size == frame.size, (size, frame)
            inputs[frame.name] = funsor.bint(int(size))
    value = funsor.tensor.Tensor(data, inputs, output.dtype)
    assert value.output == output
    return value
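# Hypothetical usage sketch for tensor_to_funsor (an added example): the Frame
# namedtuple below is a stand-in for a Pyro CondIndepStackFrame (real frames
# carry more fields); it only exercises the .name/.size attributes the
# function reads.
from collections import namedtuple

Frame = namedtuple('Frame', ['name', 'size'])

value = torch.randn(3, 1, 4)  # batch shape (3, 1), event shape (4,)
stack = (Frame('outer', 3), Frame('inner', 1))
f = tensor_to_funsor(value, stack, funsor.reals(4))
assert set(f.inputs) == {'outer'}  # the size-1 batch dim is squeezed away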
def generate_HMM_dataset(model, args):
    """
    Generates a sequence of observations from a given funsor model.
    """
    data = [
        funsor.Variable('y_{}'.format(t), funsor.bint(args.hidden_dim))
        for t in range(args.time_steps)
    ]
    log_prob = model(data)

    var = [key for key, value in log_prob.inputs.items()]
    # TODO: move sample to model definition, to avoid memory explosion
    r = log_prob.sample(frozenset(var))
    data = torch.tensor([
        r.deltas[i].point.data
        for i in range(len(r.deltas))
        if r.deltas[i].name.startswith('y')
    ])
    return data
def test_advanced_indexing_lazy(output_shape):
    x = Array(np.random.normal(size=(2, 3, 4) + output_shape), OrderedDict([
        ('i', bint(2)),
        ('j', bint(3)),
        ('k', bint(4)),
    ]))
    u = Variable('u', bint(2))
    v = Variable('v', bint(3))
    with interpretation(lazy):
        i = Number(1, 2) - u
        j = Number(2, 3) - v
        k = u + v

    expected_data = np.empty((2, 3) + output_shape)
    i_data = funsor.numpy.materialize(i).data.astype(np.int64)
    j_data = funsor.numpy.materialize(j).data.astype(np.int64)
    k_data = funsor.numpy.materialize(k).data.astype(np.int64)
    for u in range(2):
        for v in range(3):
            expected_data[u, v] = x.data[i_data[u], j_data[v], k_data[u, v]]
    expected = Array(expected_data, OrderedDict([
        ('u', bint(2)),
        ('v', bint(3)),
    ]))

    assert_equiv(expected, x(i, j, k))
    assert_equiv(expected, x(i=i, j=j, k=k))
    assert_equiv(expected, x(i=i, j=j)(k=k))
    assert_equiv(expected, x(j=j, k=k)(i=i))
    assert_equiv(expected, x(k=k, i=i)(j=j))
    assert_equiv(expected, x(i=i)(j=j, k=k))
    assert_equiv(expected, x(j=j)(k=k, i=i))
    assert_equiv(expected, x(k=k)(i=i, j=j))
    assert_equiv(expected, x(i=i)(j=j)(k=k))
    assert_equiv(expected, x(i=i)(k=k)(j=j))
    assert_equiv(expected, x(j=j)(i=i)(k=k))
    assert_equiv(expected, x(j=j)(k=k)(i=i))
    assert_equiv(expected, x(k=k)(i=i)(j=j))
    assert_equiv(expected, x(k=k)(j=j)(i=i))
def process_message(self, msg):
    if msg["type"] != "sample" or \
            msg.get("done", False) or msg["is_observed"] or msg["infer"].get("expand", False) or \
            msg["infer"].get("enumerate") != "parallel" or (not msg["fn"].has_enumerate_support):
        if msg["type"] == "control_flow":
            msg["kwargs"]["enum"] = True
        return super().process_message(msg)

    if msg["infer"].get("num_samples", None) is not None:
        raise NotImplementedError("TODO implement multiple sampling")
    if msg["infer"].get("expand", False):
        raise NotImplementedError("expand=True not implemented")

    size = msg["fn"].enumerate_support(expand=False).shape[0]
    raw_value = jnp.arange(0, size)
    funsor_value = funsor.Tensor(
        raw_value,
        OrderedDict([(msg["name"], funsor.bint(size))]),
        size)

    msg["value"] = to_data(funsor_value)
    msg["done"] = True
def test_advanced_indexing_shape():
    I, J, M, N = 4, 4, 2, 3
    x = Array(np.random.normal(size=(I, J)), OrderedDict([
        ('i', bint(I)),
        ('j', bint(J)),
    ]))
    m = Array(np.array([2, 3]), OrderedDict([('m', bint(M))]), I)
    n = Array(np.array([0, 1, 1]), OrderedDict([('n', bint(N))]), J)
    assert x.data.shape == (I, J)

    check_funsor(x(i=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(i=m, j=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(i=m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(i=m, k=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(i=n), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(i=n, k=m), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(j=m), {'i': bint(I), 'm': bint(M)}, reals())
    check_funsor(x(j=m, i=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(j=m, i=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(j=m, k=m), {'i': bint(I), 'm': bint(M)}, reals())
    check_funsor(x(j=n), {'i': bint(I), 'n': bint(N)}, reals())
    check_funsor(x(j=n, k=m), {'i': bint(I), 'n': bint(N)}, reals())
    check_funsor(x(m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(m, j=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, k=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(m, n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(n), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(n, k=m), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(n, m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(n, m, k=m), {'m': bint(M), 'n': bint(N)}, reals())
def test_to_data_error():
    data = np.zeros((3, 3))
    x = Array(data, OrderedDict(i=bint(3)))
    with pytest.raises(ValueError):
        funsor.to_data(x)
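# Companion sketch (an added example, same numpy-backend assumptions as the
# test above): to_data succeeds once the funsor has no free named inputs.
def test_to_data_ok():
    data = np.zeros((3, 3))
    x = Array(data, OrderedDict())
    assert (funsor.to_data(x) == data).all()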
def test_advanced_indexing_array(output_shape):
    #      u   v
    #     / \ / \
    #    i   j   k
    #     \  |  /
    #      \ | /
    #        x
    output = reals(*output_shape)
    x = random_array(OrderedDict([
        ('i', bint(2)),
        ('j', bint(3)),
        ('k', bint(4)),
    ]), output)
    i = random_array(OrderedDict([
        ('u', bint(5)),
    ]), bint(2))
    j = random_array(OrderedDict([
        ('v', bint(6)),
        ('u', bint(5)),
    ]), bint(3))
    k = random_array(OrderedDict([
        ('v', bint(6)),
    ]), bint(4))

    expected_data = np.empty((5, 6) + output_shape)
    for u in range(5):
        for v in range(6):
            expected_data[u, v] = x.data[i.data[u], j.data[v, u], k.data[v]]
    expected = Array(expected_data, OrderedDict([
        ('u', bint(5)),
        ('v', bint(6)),
    ]))

    assert_equiv(expected, x(i, j, k))
    assert_equiv(expected, x(i=i, j=j, k=k))
    assert_equiv(expected, x(i=i, j=j)(k=k))
    assert_equiv(expected, x(j=j, k=k)(i=i))
    assert_equiv(expected, x(k=k, i=i)(j=j))
    assert_equiv(expected, x(i=i)(j=j, k=k))
    assert_equiv(expected, x(j=j)(k=k, i=i))
    assert_equiv(expected, x(k=k)(i=i, j=j))
    assert_equiv(expected, x(i=i)(j=j)(k=k))
    assert_equiv(expected, x(i=i)(k=k)(j=j))
    assert_equiv(expected, x(j=j)(i=i)(k=k))
    assert_equiv(expected, x(j=j)(k=k)(i=i))
    assert_equiv(expected, x(k=k)(i=i)(j=j))
    assert_equiv(expected, x(k=k)(j=j)(i=i))
def filter_and_predict(self, data, smoothing=False):
    trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = self.get_tensors_and_dists()

    log_prob = funsor.Number(0.)

    s_vars = {-1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)}
    x_vars = {-1: None}

    predictive_x_dists, predictive_y_dists, filtering_dists = [], [], []
    test_LLs = []

    for t, y in enumerate(data):
        s_vars[t] = funsor.Variable(f's_{t}', funsor.bint(self.num_components))
        x_vars[t] = funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))

        log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]), value=s_vars[t])

        if t == 0:
            log_prob += self.x_init_mvn(value=x_vars[t])
        else:
            log_prob += x_trans_dist(s=s_vars[t], x=x_vars[t - 1], y=x_vars[t])

        if t > 0:
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([s_vars[t - 1].name, x_vars[t - 1].name]))

        # do 1-step prediction and compute test LL
        if t > 0:
            predictive_x_dists.append(log_prob)
            _log_prob = log_prob - log_prob.reduce(ops.logaddexp)
            predictive_y_dist = y_dist(s=s_vars[t], x=x_vars[t]) + _log_prob
            test_LLs.append(
                predictive_y_dist(y=y).reduce(ops.logaddexp).data.item())
            predictive_y_dist = predictive_y_dist.reduce(
                ops.logaddexp, frozenset([f"x_{t}", f"s_{t}"]))
            predictive_y_dists.append(funsor_to_mvn(predictive_y_dist, 0, ()))

        log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

        # save filtering dists for forward-backward smoothing
        if smoothing:
            filtering_dists.append(log_prob)

    # do the backward recursion using previously computed ingredients
    if smoothing:
        # seed the backward recursion with the filtering distribution at t=T
        smoothing_dists = [filtering_dists[-1]]
        T = data.size(0)

        s_vars = {t: funsor.Variable(f's_{t}', funsor.bint(self.num_components))
                  for t in range(T)}
        x_vars = {t: funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))
                  for t in range(T)}

        # do the backward recursion.
        # let p[t|t-1] be the predictive distribution at time step t.
        # let p[t|t] be the filtering distribution at time step t.
        # let f[t] denote the prior (transition) density at time step t.
        # then the smoothing distribution p[t|T] at time step t is
        # given by the following recursion.
        # p[t-1|T] = p[t-1|t-1] <p[t|T] f[t] / p[t|t-1]>
        # where <...> denotes integration of the latent variables at time step t.
        for t in reversed(range(T - 1)):
            integral = smoothing_dists[-1] - predictive_x_dists[t]
            integral += dist.Categorical(trans_probs(s=s_vars[t]), value=s_vars[t + 1])
            integral += x_trans_dist(s=s_vars[t], x=x_vars[t], y=x_vars[t + 1])
            integral = integral.reduce(
                ops.logaddexp,
                frozenset([s_vars[t + 1].name, x_vars[t + 1].name]))
            smoothing_dists.append(filtering_dists[t] + integral)

    # compute predictive test MSE and predictive variances
    predictive_means = torch.stack([d.mean for d in predictive_y_dists])  # T-1 ydim
    predictive_vars = torch.stack([
        d.covariance_matrix.diagonal(dim1=-1, dim2=-2)
        for d in predictive_y_dists
    ])
    predictive_mse = (predictive_means - data[1:, :]).pow(2.0).mean(-1)

    if smoothing:
        # compute smoothed mean function
        smoothing_dists = [
            funsor_to_cat_and_mvn(d, 0, (f"s_{t}",))
            for t, d in enumerate(reversed(smoothing_dists))
        ]
        means = torch.stack([d[1].mean for d in smoothing_dists])  # T 2 xdim
        means = torch.matmul(means.unsqueeze(-2),
                             self.observation_matrix).squeeze(-2)  # T 2 ydim

        probs = torch.stack([d[0].logits for d in smoothing_dists]).exp()
        probs = probs / probs.sum(-1, keepdim=True)  # T 2

        smoothing_means = (probs.unsqueeze(-1) * means).sum(-2)  # T ydim
        smoothing_probs = probs[:, 1]

        return predictive_mse, torch.tensor(np.array(test_LLs)), predictive_means, predictive_vars, \
            smoothing_means, smoothing_probs
    else:
        return predictive_mse, torch.tensor(np.array(test_LLs))