def test_smoke(expr, expected_type):
    dx = Delta('x', Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2))])))
    assert isinstance(dx, Delta)

    dy = Delta('y', Tensor(torch.randn(3, 4), OrderedDict([('j', bint(3))])))
    assert isinstance(dy, Delta)

    t = Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2)), ('j', bint(3))]))
    assert isinstance(t, Tensor)

    g = Gaussian(
        info_vec=torch.tensor([[0.0, 0.1, 0.2],
                               [2.0, 3.0, 4.0]]),
        precision=torch.tensor([[[1.0, 0.1, 0.2],
                                 [0.1, 1.0, 0.3],
                                 [0.2, 0.3, 1.0]],
                                [[1.0, 0.1, 0.2],
                                 [0.1, 1.0, 0.3],
                                 [0.2, 0.3, 1.0]]]),
        inputs=OrderedDict([('i', bint(2)), ('x', reals(3))]))
    assert isinstance(g, Gaussian)

    i0 = Number(1, 2)
    assert isinstance(i0, Number)

    x0 = Tensor(torch.tensor([0.5, 0.6, 0.7]))
    assert isinstance(x0, Tensor)

    result = eval(expr)
    assert isinstance(result, expected_type)

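# Illustrative note (not part of the original tests): funsor's Gaussian is
# stored in information form, where info_vec = precision @ mean rather than
# the mean itself, so the mean is recovered by a linear solve. A minimal
# sketch in plain PyTorch:
import torch

_precision = torch.tensor([[1.0, 0.1, 0.2], [0.1, 1.0, 0.3], [0.2, 0.3, 1.0]])
_info_vec = torch.tensor([0.0, 0.1, 0.2])
_mean = torch.linalg.solve(_precision, _info_vec)
assert torch.allclose(_precision @ _mean, _info_vec, atol=1e-6)
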
def model(size, position=0):
    if size == 1:
        name = str(position)
        return Uniform((Delta(name, Number(0, 2)),
                        Delta(name, Number(1, 2))))
    return Uniform(
        model(t, position) + model(size - t, t + position)
        for t in range(1, size))

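# A minimal plain-Python sketch (not funsor) of the branching structure above:
# each Uniform weighs the (t, size - t) splits equally, and the base case
# assigns an independent coin flip to each unit block. The number of distinct
# split trees the recursion can produce follows the Catalan recurrence.
# count_split_trees is a hypothetical helper mirroring model()'s recursion.
def count_split_trees(size):
    if size == 1:
        return 1
    return sum(count_split_trees(t) * count_split_trees(size - t)
               for t in range(1, size))

assert [count_split_trees(n) for n in range(1, 6)] == [1, 1, 2, 5, 14]
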
def test_reduce_logaddexp_deltas_lazy():
    a = Delta('a', Tensor(torch.randn(3, 2), OrderedDict(i=bint(3))))
    b = Delta('b', Tensor(torch.randn(3), OrderedDict(i=bint(3))))
    x = a + b
    assert isinstance(x, Delta)
    assert set(x.inputs) == {'a', 'b', 'i'}

    y = x.reduce(ops.logaddexp, 'i')
    # assert isinstance(y, Reduce)
    assert set(y.inputs) == {'a', 'b'}

    assert_close(x.reduce(ops.logaddexp), y.reduce(ops.logaddexp))

def test_reduce_logaddexp_deltas_discrete_lazy():
    a = Delta('a', Tensor(randn(3, 2), OrderedDict(i=bint(3))))
    b = Delta('b', Tensor(randn(3), OrderedDict(i=bint(3))))
    c = Tensor(randn(3), OrderedDict(i=bint(3)))
    x = a + b + c
    assert isinstance(x, Contraction)
    assert set(x.inputs) == {'a', 'b', 'i'}

    y = x.reduce(ops.logaddexp, 'i')
    # assert isinstance(y, Reduce)
    assert set(y.inputs) == {'a', 'b'}

    assert_close(x.reduce(ops.logaddexp), y.reduce(ops.logaddexp))

def test_reduce_logaddexp(int_inputs, real_inputs):
    int_inputs = OrderedDict(sorted(int_inputs.items()))
    real_inputs = OrderedDict(sorted(real_inputs.items()))
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    t = random_tensor(int_inputs)
    g = random_gaussian(inputs)
    truth = {name: random_tensor(int_inputs, domain)
             for name, domain in real_inputs.items()}

    state = 0
    state += g
    state += t
    for name, point in truth.items():
        with xfail_if_not_implemented():
            state += Delta(name, point)
    actual = state.reduce(ops.logaddexp, frozenset(truth))

    expected = t + g(**truth)
    assert_close(actual, expected,
                 atol=1e-5,
                 rtol=1e-4 if get_backend() == "jax" else 1e-5)

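# The identity exercised above, written out in log space: reducing over the
# delta-bound variables integrates each delta(name = point) against the
# Gaussian, which amounts to substituting the points,
#     reduce(logaddexp, t + g + sum_k delta(name_k = point_k))
#         = t + g(**truth).
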
def unscaled_sample(self, sampled_vars, sample_inputs):
    assert self.output == reals()
    sampled_vars = sampled_vars.intersection(self.inputs)
    if not sampled_vars:
        return self

    # Partition inputs into sample_inputs + batch_inputs + event_inputs.
    sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items()
                                if k not in self.inputs)
    sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
    batch_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                               if k not in sampled_vars)
    event_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                               if k in sampled_vars)
    be_inputs = batch_inputs.copy()
    be_inputs.update(event_inputs)
    sb_inputs = sample_inputs.copy()
    sb_inputs.update(batch_inputs)

    # Sample all variables in a single Categorical call.
    logits = align_tensor(be_inputs, self)
    batch_shape = logits.shape[:len(batch_inputs)]
    flat_logits = logits.reshape(batch_shape + (-1,))
    flat_sample = torch.distributions.Categorical(logits=flat_logits).sample(sample_shape)
    assert flat_sample.shape == sample_shape + batch_shape

    results = []
    mod_sample = flat_sample
    for name, domain in reversed(list(event_inputs.items())):
        size = domain.dtype
        point = Tensor(mod_sample % size, sb_inputs, size)
        mod_sample = mod_sample // size  # integer floor-div, keeping an integer index
        results.append(Delta(name, point))

    # Account for the log normalizer factor.
    # Derivation: Let f be a nonnormalized distribution (a funsor), and
    #   consider operations in linear space (source code is in log space).
    #   Let x0 ~ f/|f| be a monte carlo sample from a normalized f/|f|.
    #                              f(x0) / |f|       # dice numerator
    #   Let g = delta(x=x0) |f| -----------------
    #                           detach(f(x0)/|f|)    # dice denominator
    #                       |detach(f)| f(x0)
    #         = delta(x=x0) -----------------  be a dice approximation of f.
    #                         detach(f(x0))
    #   Then g is an unbiased estimator of f in value and all derivatives.
    #   In the special case f = detach(f), we can simplify to
    #       g = delta(x=x0) |f|.
    if flat_logits.requires_grad:
        # Apply a dice factor to preserve differentiability.
        index = [torch.arange(n).reshape((n,) + (1,) * (flat_logits.dim() - i - 2))
                 for i, n in enumerate(flat_logits.shape[:-1])]
        index.append(flat_sample)
        log_prob = flat_logits[tuple(index)]
        assert log_prob.shape == flat_sample.shape
        results.append(Tensor(flat_logits.detach().logsumexp(-1) +
                              (log_prob - log_prob.detach()), sb_inputs))
    else:
        # This is the special case f = detach(f).
        results.append(Tensor(flat_logits.logsumexp(-1), batch_inputs))

    return reduce(ops.add, results)

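# A minimal sketch of the dice factor in plain PyTorch (not funsor): the
# surrogate below equals logsumexp(logits) in value but carries the gradient
# of log_prob at the sampled index, keeping downstream estimates differentiable.
import torch

_logits = torch.randn(5, requires_grad=True)
_x0 = torch.distributions.Categorical(logits=_logits).sample()
_log_prob = _logits[_x0]
_surrogate = _logits.detach().logsumexp(-1) + (_log_prob - _log_prob.detach())
_surrogate.backward()
# The gradient is exactly d(log_prob)/d(logits): one-hot at the sampled index.
assert torch.allclose(_logits.grad,
                      torch.nn.functional.one_hot(_x0, 5).float())
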
def test_add_delta_funsor():
    x = Variable('x', reals(3))
    y = Variable('y', reals(3))
    d = Delta('x', y)

    expr = -(1 + x**2).log()
    assert d + expr is d + expr(x=y)
    assert expr + d is expr(x=y) + d

def test_reduce_moment_matching_shape(interp):
    delta = Delta('x', random_tensor(OrderedDict([('h', bint(7))])))
    discrete = random_tensor(OrderedDict(
        [('h', bint(7)), ('i', bint(6)), ('j', bint(5)), ('k', bint(4))]))
    gaussian = random_gaussian(OrderedDict(
        [('k', bint(4)), ('l', bint(3)), ('m', bint(2)),
         ('y', reals()), ('z', reals(2))]))
    reduced_vars = frozenset(['i', 'k', 'l'])
    joint = delta + discrete + gaussian
    with interpretation(interp):
        actual = joint.reduce(ops.logaddexp, reduced_vars)
    assert set(actual.inputs) == set(joint.inputs) - reduced_vars

def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
    sampled_vars = sampled_vars.intersection(self.inputs)
    if not sampled_vars:
        return self
    if any(self.inputs[k].dtype != 'real' for k in sampled_vars):
        raise ValueError(
            'Sampling from non-normalized Gaussian mixtures is intentionally '
            'not implemented. You probably want to normalize. To work around, '
            'add a zero Tensor/Array with given inputs.')

    # Partition inputs into sample_inputs + int_inputs + real_inputs.
    sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items()
                                if k not in self.inputs)
    sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
    int_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                             if d.dtype != 'real')
    real_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                              if d.dtype == 'real')
    inputs = sample_inputs.copy()
    inputs.update(int_inputs)

    if sampled_vars == frozenset(real_inputs):
        shape = sample_shape + self.info_vec.shape
        backend = get_backend()
        if backend != "numpy":
            from importlib import import_module
            dist = import_module(
                funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend])
            sample_args = (shape,) if rng_key is None else (rng_key, shape)
            white_noise = dist.Normal.dist_class(0, 1).sample(*sample_args)
        else:
            white_noise = np.random.randn(*shape)
        white_noise = ops.unsqueeze(white_noise, -1)

        white_vec = ops.triangular_solve(self.info_vec[..., None],
                                         self._precision_chol)
        sample = ops.triangular_solve(white_noise + white_vec,
                                      self._precision_chol,
                                      transpose=True)[..., 0]
        offsets, _ = _compute_offsets(real_inputs)
        results = []
        for key, domain in real_inputs.items():
            data = sample[..., offsets[key]:offsets[key] + domain.num_elements]
            data = data.reshape(shape[:-1] + domain.shape)
            point = Tensor(data, inputs)
            assert point.output == domain
            results.append(Delta(key, point))
        results.append(self.log_normalizer)
        return reduce(ops.add, results)

    raise NotImplementedError('TODO implement partial sampling of real variables')

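# A quick numerical sketch (plain PyTorch, not funsor) of the identity behind
# the two triangular solves above: with precision = L @ L.T (L the lower
# Cholesky factor), the sample mean L^{-T} (L^{-1} @ info_vec) equals
# precision^{-1} @ info_vec, and L^{-T} @ eps has covariance precision^{-1}.
import torch

_precision = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
_info_vec = torch.tensor([[0.5], [-1.0]])
_L = torch.linalg.cholesky(_precision)
_white_vec = torch.linalg.solve_triangular(_L, _info_vec, upper=False)
_mean = torch.linalg.solve_triangular(_L.mT, _white_vec, upper=True)
assert torch.allclose(_mean, torch.linalg.solve(_precision, _info_vec), atol=1e-5)
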
def test_reduce_moment_matching_finite():
    delta = Delta('x', random_tensor(OrderedDict([('h', bint(7))])))
    discrete = random_tensor(OrderedDict(
        [('i', bint(6)), ('j', bint(5)), ('k', bint(3))]))
    gaussian = random_gaussian(OrderedDict(
        [('k', bint(3)), ('l', bint(2)), ('y', reals()), ('z', reals(2))]))

    discrete.data[1:, :] = -float('inf')
    discrete.data[:, 1:] = -float('inf')

    reduced_vars = frozenset(['j', 'k'])
    joint = delta + discrete + gaussian
    with interpretation(moment_matching):
        joint.reduce(ops.logaddexp, reduced_vars)

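# Note on the test above: the -inf masking leaves only the (i=0, j=0) slice of
# `discrete` finite, so the reduction is exercised on a mixture in which most
# components carry zero probability.
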
def test_reduce_moment_matching_shape(interp):
    delta = Delta('x', random_tensor(OrderedDict([('h', Bint[7])])))
    discrete = random_tensor(OrderedDict(
        [('h', Bint[7]), ('i', Bint[6]), ('j', Bint[5]), ('k', Bint[4])]))
    gaussian = random_gaussian(OrderedDict(
        [('k', Bint[4]), ('l', Bint[3]), ('m', Bint[2]),
         ('y', Real), ('z', Reals[2])]))
    reduced_vars = frozenset(['i', 'k', 'l'])
    real_vars = frozenset(k for k, d in gaussian.inputs.items()
                          if d.dtype == "real")
    joint = delta + discrete + gaussian
    with interpretation(interp):
        actual = joint.reduce(ops.logaddexp, reduced_vars)
    assert set(actual.inputs) == set(joint.inputs) - reduced_vars
    assert_close(actual.reduce(ops.logaddexp, real_vars),
                 joint.reduce(ops.logaddexp, real_vars | reduced_vars))

def unscaled_sample(self, sampled_vars, sample_inputs):
    sampled_vars = sampled_vars.intersection(self.inputs)
    if not sampled_vars:
        return self
    if any(self.inputs[k].dtype != 'real' for k in sampled_vars):
        raise ValueError(
            'Sampling from non-normalized Gaussian mixtures is intentionally '
            'not implemented. You probably want to normalize. To work around, '
            'add a zero Tensor/Array with given inputs.')

    # Partition inputs into sample_inputs + int_inputs + real_inputs.
    sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items()
                                if k not in self.inputs)
    sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
    int_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                             if d.dtype != 'real')
    real_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                              if d.dtype == 'real')
    inputs = sample_inputs.copy()
    inputs.update(int_inputs)

    if sampled_vars == frozenset(real_inputs):
        shape = sample_shape + self.info_vec.shape
        # TODO: revise the logic here; `key` is required for the JAX normal sampler.
        white_noise = funsor.testing.randn(shape + (1,))
        white_vec = ops.triangular_solve(self.info_vec[..., None],
                                         self._precision_chol)
        sample = ops.triangular_solve(white_noise + white_vec,
                                      self._precision_chol,
                                      transpose=True)[..., 0]
        offsets, _ = _compute_offsets(real_inputs)
        results = []
        for key, domain in real_inputs.items():
            data = sample[..., offsets[key]:offsets[key] + domain.num_elements]
            data = data.reshape(shape[:-1] + domain.shape)
            point = Tensor(data, inputs)
            assert point.output == domain
            results.append(Delta(key, point))
        results.append(self.log_normalizer)
        return reduce(ops.add, results)

    raise NotImplementedError('TODO implement partial sampling of real variables')

def unscaled_sample(self, sampled_vars, sample_inputs):
    # Sample only the real variables.
    sampled_vars = frozenset(k for k, v in self.inputs.items()
                             if k in sampled_vars if v.dtype == 'real')
    if not sampled_vars:
        return self

    # Partition inputs into sample_inputs + int_inputs + real_inputs.
    sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items()
                                if k not in self.inputs)
    sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
    int_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                             if d.dtype != 'real')
    real_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                              if d.dtype == 'real')
    inputs = sample_inputs.copy()
    inputs.update(int_inputs)

    if sampled_vars == frozenset(real_inputs):
        scale_tri = torch.inverse(torch.cholesky(self.precision)).transpose(-1, -2)
        if not torch._C._get_tracing_state():
            assert self.loc.shape == scale_tri.shape[:-1]
        shape = sample_shape + self.loc.shape
        white_noise = torch.randn(shape)
        sample = self.loc + _mv(scale_tri, white_noise)
        offsets, _ = _compute_offsets(real_inputs)
        results = []
        for key, domain in real_inputs.items():
            data = sample[..., offsets[key]:offsets[key] + domain.num_elements]
            data = data.reshape(shape[:-1] + domain.shape)
            point = Tensor(data, inputs)
            assert point.output == domain
            results.append(Delta(key, point))
        results.append(self._log_normalizer)
        return reduce(ops.add, results)

    raise NotImplementedError('TODO implement partial sampling of real variables')

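# A quick check (plain PyTorch, not funsor) of the scale_tri identity above:
# with precision = L @ L.T, scale_tri = inv(L).T satisfies
# scale_tri @ scale_tri.T == inv(precision), so loc + scale_tri @ eps with
# eps ~ N(0, I) has covariance precision^{-1}.
import torch

_precision = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
_scale_tri = torch.inverse(torch.linalg.cholesky(_precision)).transpose(-1, -2)
assert torch.allclose(_scale_tri @ _scale_tri.transpose(-1, -2),
                      torch.inverse(_precision), atol=1e-5)
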
def test_reduce_density(log_density):
    point = Tensor(randn(3))
    d = Delta('foo', point, log_density)
    # Note that log_density affects ground substitution but does not affect reduction.
    assert d.reduce(ops.logaddexp, frozenset(['foo'])) is Number(0)

def test_transform_log(shape):
    point = Tensor(randn(shape))
    x = Variable('x', reals(*shape))
    actual = Delta('y', point)(y=ops.log(x))
    expected = Delta('x', point.exp(), -point.sum())
    assert_close(actual, expected)

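# The change of variables exercised here and in test_transform_exp below: for
# an invertible transform y = f(x), substituting into a delta gives
#     Delta('y', y0)(y=f(x)) = Delta('x', f_inv(y0), log|det df/dx| at f_inv(y0)).
# For y = log(x) the point becomes exp(point), and log|d(log x)/dx| at exp(point)
# is -point, summed over the event shape, matching `expected` above.
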
def test_delta_delta():
    v = Variable('v', Reals[2])
    point = Tensor(randn(2))
    log_density = Tensor(numeric_array(0.5))
    d = dist.Delta(point, log_density, v)
    assert d is Delta('v', point, log_density)

def test_reduce():
    point = Tensor(randn(3))
    d = Delta('foo', point)
    assert d.reduce(ops.logaddexp, frozenset(['foo'])) is Number(0)

def test_eager_subs_ground(log_density):
    point1 = Tensor(randn(3))
    point2 = Tensor(randn(3))
    d = Delta('foo', point1, log_density)
    check_funsor(d(foo=point1), {}, reals(), numeric_array(float(log_density)))
    check_funsor(d(foo=point2), {}, reals(), numeric_array(float('-inf')))

def test_eager_subs_variable():
    v = Variable('v', reals(3))
    point = Tensor(randn(3))
    d = Delta('foo', v)
    assert d(v=point) is Delta('foo', point)

def test_delta_delta():
    v = Variable('v', reals(2))
    point = Tensor(torch.randn(2))
    log_density = Tensor(torch.tensor(0.5))
    d = dist.Delta(point, log_density, v)
    assert d is Delta('v', point, log_density)

def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
    assert self.output == Real
    sampled_vars = sampled_vars.intersection(self.inputs)
    if not sampled_vars:
        return self

    # Partition inputs into sample_inputs + batch_inputs + event_inputs.
    sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items()
                                if k not in self.inputs)
    sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
    batch_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                               if k not in sampled_vars)
    event_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                               if k in sampled_vars)
    be_inputs = batch_inputs.copy()
    be_inputs.update(event_inputs)
    sb_inputs = sample_inputs.copy()
    sb_inputs.update(batch_inputs)

    # Sample all variables in a single Categorical call.
    logits = align_tensor(be_inputs, self)
    batch_shape = logits.shape[:len(batch_inputs)]
    flat_logits = logits.reshape(batch_shape + (-1,))
    backend = get_backend()
    if backend != "numpy":
        from importlib import import_module
        dist = import_module(
            funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend])
        sample_args = (sample_shape,) if rng_key is None else (rng_key, sample_shape)
        flat_sample = dist.CategoricalLogits.dist_class(
            logits=flat_logits).sample(*sample_args)
    else:  # default numpy backend
        assert backend == "numpy"
        shape = sample_shape + flat_logits.shape[:-1]
        logit_max = np.amax(flat_logits, -1, keepdims=True)
        probs = np.exp(flat_logits - logit_max)
        probs = probs / np.sum(probs, -1, keepdims=True)
        s = np.cumsum(probs, -1)
        r = np.random.rand(*shape)
        flat_sample = np.sum(s < np.expand_dims(r, -1), axis=-1)
    assert flat_sample.shape == sample_shape + batch_shape

    results = []
    mod_sample = flat_sample
    for name, domain in reversed(list(event_inputs.items())):
        size = domain.dtype
        point = Tensor(mod_sample % size, sb_inputs, size)
        mod_sample = mod_sample // size
        results.append(Delta(name, point))

    # Account for the log normalizer factor.
    # Derivation: Let f be a nonnormalized distribution (a funsor), and
    #   consider operations in linear space (source code is in log space).
    #   Let x0 ~ f/|f| be a monte carlo sample from a normalized f/|f|.
    #                              f(x0) / |f|       # dice numerator
    #   Let g = delta(x=x0) |f| -----------------
    #                           detach(f(x0)/|f|)    # dice denominator
    #                       |detach(f)| f(x0)
    #         = delta(x=x0) -----------------  be a dice approximation of f.
    #                         detach(f(x0))
    #   Then g is an unbiased estimator of f in value and all derivatives.
    #   In the special case f = detach(f), we can simplify to
    #       g = delta(x=x0) |f|.
    if (backend == "torch" and flat_logits.requires_grad) or backend == "jax":
        # Apply a dice factor to preserve differentiability.
        index = [ops.new_arange(self.data, n).reshape(
                     (n,) + (1,) * (len(flat_logits.shape) - i - 2))
                 for i, n in enumerate(flat_logits.shape[:-1])]
        index.append(flat_sample)
        log_prob = flat_logits[tuple(index)]
        assert log_prob.shape == flat_sample.shape
        results.append(Tensor(
            ops.logsumexp(ops.detach(flat_logits), -1) +
            (log_prob - ops.detach(log_prob)), sb_inputs))
    else:
        # This is the special case f = detach(f).
        results.append(Tensor(ops.logsumexp(flat_logits, -1), batch_inputs))

    return reduce(ops.add, results)

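# A minimal sketch of the mixed-radix decoding in the loop above: a flat index
# over the cartesian product of event sizes is split back into per-variable
# coordinates by repeated mod / floor-div, with the last event input varying
# fastest. The sizes and flat index below are made-up examples.
_sizes = {'a': 2, 'b': 3, 'c': 4}
_flat = 17  # flat index into the 2 * 3 * 4 = 24 joint cells
_coords = {}
for _name, _size in reversed(list(_sizes.items())):
    _coords[_name] = _flat % _size
    _flat //= _size
assert _coords == {'c': 1, 'b': 1, 'a': 1}  # 17 == (1 * 3 + 1) * 4 + 1
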
def test_transform_exp(shape):
    point = Tensor(torch.randn(shape).abs())
    x = Variable('x', reals(*shape))
    actual = Delta('y', point)(y=ops.exp(x))
    expected = Delta('x', point.log(), point.log().sum())
    assert_close(actual, expected)