def eager_cat_homogeneous(name, part_name, *parts):
    assert parts
    output = parts[0].output
    inputs = OrderedDict([(part_name, None)])
    for part in parts:
        assert part.output == output
        assert part_name in part.inputs
        inputs.update(part.inputs)
    # Sort inputs so that all discrete (int) inputs precede the real inputs.
    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != "real")
    real_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype == "real")
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    # Align each part's parameters to the common batch shape.
    discretes = []
    info_vecs = []
    precisions = []
    for part in parts:
        inputs[part_name] = part.inputs[part_name]
        int_inputs[part_name] = inputs[part_name]
        shape = tuple(d.size for d in int_inputs.values())
        if isinstance(part, Gaussian):
            discrete = None
            gaussian = part
        elif issubclass(type(part), GaussianMixture):  # TODO figure out why isinstance isn't working
            discrete, gaussian = part.terms[0], part.terms[1]
            discrete = ops.expand(align_tensor(int_inputs, discrete), shape)
        else:
            raise NotImplementedError("TODO")
        discretes.append(discrete)
        info_vec, precision = align_gaussian(inputs, gaussian)
        info_vecs.append(ops.expand(info_vec, shape + (-1,)))
        precisions.append(ops.expand(precision, shape + (-1, -1)))
    if part_name != name:
        del inputs[part_name]
        del int_inputs[part_name]

    # Concatenate the aligned parts along the dim indexed by ``name``.
    dim = 0
    info_vec = ops.cat(dim, *info_vecs)
    precision = ops.cat(dim, *precisions)
    inputs[name] = Bint[info_vec.shape[dim]]
    int_inputs[name] = inputs[name]
    result = Gaussian(info_vec, precision, inputs)

    # If any part carried a discrete factor, cat those as well, padding
    # pure-Gaussian parts with zeros.
    if any(d is not None for d in discretes):
        for i, d in enumerate(discretes):
            if d is None:
                discretes[i] = ops.new_zeros(info_vecs[i], info_vecs[i].shape[:-1])
        discrete = ops.cat(dim, *discretes)
        result = result + Tensor(discrete, int_inputs)
    return result
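# A stripped-down numpy sketch (hypothetical, not part of this module) of the
# expand-then-cat pattern used above: each part's info_vec is broadcast to the
# full batch shape, with -1 preserving the trailing event dim as in
# ops.expand(info_vec, shape + (-1,)), before ops.cat joins the parts along dim 0.
import numpy as np

shape = (2, 3)                             # sizes of the aligned int inputs
a = np.ones((1, 3, 4))                     # an info_vec missing a batch dim
a = np.broadcast_to(a, shape + (4,))       # plays the role of ops.expand
b = np.zeros(shape + (4,))
info_vec = np.concatenate([a, b], axis=0)  # plays the role of ops.cat(dim=0, ...)
assert info_vec.shape == (4, 3, 4)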
def as_tensor(self):
    # Fill gaps with zeros.
    prototype = next(iter(self.parts.values()))
    for i in _find_intervals(self.parts.keys(), self.shape[-1]):
        if i not in self.parts:
            self.parts[i] = ops.new_zeros(prototype, self.shape[:-1] + (i[1] - i[0],))

    # Concatenate parts.
    parts = [v for k, v in sorted(self.parts.items())]
    result = ops.cat(-1, *parts)
    if not get_tracing_state():
        assert result.shape == self.shape
    return result
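# A minimal numpy sketch (hypothetical names, not part of this module) of the
# gap-filling strategy above, assuming interval keys are (start, stop) pairs and
# that _find_intervals cuts [0, end) at every known endpoint.
import numpy as np

def find_intervals(intervals, end):
    cuts = sorted({0, end}.union(*intervals))
    return list(zip(cuts[:-1], cuts[1:]))

parts = {(0, 2): np.ones((3, 2)), (5, 6): np.full((3, 1), 7.0)}
shape = (3, 6)
for i in find_intervals(parts.keys(), shape[-1]):
    if i not in parts:
        parts[i] = np.zeros(shape[:-1] + (i[1] - i[0],))  # fill the (2, 5) gap
dense = np.concatenate([v for _, v in sorted(parts.items())], axis=-1)
assert dense.shape == shape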
def eager_normal(loc, scale, value):
    assert loc.output == Real
    assert scale.output == Real
    assert value.output == Real
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Build a 1-dimensional Gaussian over the fresh variable ``var``, centered
    # at zero with precision scale**-2, then substitute var = value - loc.
    info_vec = ops.new_zeros(scale.data, scale.data.shape + (1,))
    precision = ops.pow(scale.data, -2).reshape(scale.data.shape + (1, 1))
    log_prob = -0.5 * math.log(2 * math.pi) - ops.log(scale).sum()
    inputs = scale.inputs.copy()
    var = gensym('value')
    inputs[var] = Real
    gaussian = log_prob + Gaussian(info_vec, precision, inputs)
    return gaussian(**{var: value - loc})
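# A hypothetical sanity check (not part of this module) of the parametrization
# above, assuming the Gaussian funsor's log-density is info_vec . x - 0.5 x' P x:
# with info_vec = 0 and precision P = scale**-2, evaluating at x = value - loc
# and adding log_prob recovers the usual Normal log-density.
import math

x, loc, scale = 1.3, 0.4, 2.0
precision = scale ** -2
log_prob = -0.5 * math.log(2 * math.pi) - math.log(scale)
lhs = log_prob - 0.5 * precision * (x - loc) ** 2
rhs = -0.5 * math.log(2 * math.pi * scale ** 2) - 0.5 * ((x - loc) / scale) ** 2
assert abs(lhs - rhs) < 1e-12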
def _check_sample(funsor_dist_class, params, sample_inputs, inputs,
                  atol=1e-2, num_samples=100000, statistic="mean",
                  skip_grad=False, with_lazy=None):
    """utility that compares a Monte Carlo estimate of a distribution mean with the true mean"""
    samples_per_dim = int(num_samples ** (1. / max(1, len(sample_inputs))))
    sample_inputs = OrderedDict((k, Bint[samples_per_dim]) for k in sample_inputs)

    _get_stat_diff_fn = functools.partial(
        _get_stat_diff, funsor_dist_class, sample_inputs, inputs, num_samples,
        statistic, with_lazy)

    if get_backend() == "torch":
        import torch

        for param in params:
            param.requires_grad_()

        res = _get_stat_diff_fn(params)
        if sample_inputs:
            diff_sum, diff = res
            assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None)
            if not skip_grad:
                diff_grads = torch.autograd.grad(diff_sum, params, allow_unused=True)
                for diff_grad in diff_grads:
                    assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape),
                                 atol=atol, rtol=None)
    elif get_backend() == "jax":
        import jax

        if sample_inputs:
            if skip_grad:
                _, diff = _get_stat_diff_fn(params)
                assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None)
            else:
                (_, diff), diff_grads = jax.value_and_grad(
                    _get_stat_diff_fn, has_aux=True)(params)
                assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None)
                for diff_grad in diff_grads:
                    assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape),
                                 atol=atol, rtol=None)
        else:
            _get_stat_diff_fn(params)
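# A quick illustration of the samples_per_dim computation above: with two sample
# dims and num_samples = 100000, each dim gets int(100000 ** 0.5) = 316 samples,
# so the full sample grid stays within the overall Monte Carlo budget.
num_samples, sample_dims = 100000, ("particle", "plate")
samples_per_dim = int(num_samples ** (1. / max(1, len(sample_dims))))
assert samples_per_dim == 316 and samples_per_dim ** 2 <= num_samples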
def eager_mvn(loc, scale_tril, value):
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Build a zero-mean Gaussian over the fresh variable ``var`` with precision
    # (scale_tril @ scale_tril.T)^-1, then substitute var = value - loc.
    info_vec = ops.new_zeros(scale_tril.data, scale_tril.data.shape[:-1])
    precision = ops.cholesky_inverse(scale_tril.data)
    scale_diag = Tensor(ops.diagonal(scale_tril.data, -1, -2), scale_tril.inputs)
    log_prob = -0.5 * scale_diag.shape[0] * math.log(2 * math.pi) - ops.log(scale_diag).sum()
    inputs = scale_tril.inputs.copy()
    var = gensym('value')
    inputs[var] = Reals[scale_diag.shape[0]]
    gaussian = log_prob + Gaussian(info_vec, precision, inputs)
    return gaussian(**{var: value - loc})
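# A hypothetical numpy check (not part of this module) of the MVN parametrization
# above: precision = (L L^T)^-1 for L = scale_tril (what ops.cholesky_inverse
# computes), and the normalizer needs only the log-diagonal of L, since
# log|Sigma| = 2 * sum(log(diag(L))).
import math
import numpy as np

k = 3
rng = np.random.default_rng(0)
A = rng.normal(size=(k, k))
L = np.linalg.cholesky(A @ A.T + k * np.eye(k))  # a valid scale_tril
x = rng.normal(size=k)                           # plays the role of value - loc

precision = np.linalg.inv(L @ L.T)
log_prob = -0.5 * k * math.log(2 * math.pi) - np.log(np.diag(L)).sum()
lhs = log_prob - 0.5 * x @ precision @ x

cov = L @ L.T
rhs = -0.5 * (k * math.log(2 * math.pi)
              + np.linalg.slogdet(cov)[1]
              + x @ np.linalg.inv(cov) @ x)
assert abs(lhs - rhs) < 1e-8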
def test_tensor_distribution(event_inputs, batch_inputs, test_grad):
    num_samples = 50000
    sample_inputs = OrderedDict(n=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)
    p_data = random_tensor(be_inputs).data
    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)

    # Probe against a random direction so that gradients are nontrivial.
    probe = randn(p_data.shape)

    def diff_fn(p_data):
        p = Tensor(p_data, be_inputs)
        q = p.sample(sampled_vars, sample_inputs, rng_key=rng_key)
        # Reducing the materialized sample over the sample dim should
        # approximately recover p.
        mq = p.materialize(q).reduce(ops.logaddexp, 'n')
        mq = mq.align(tuple(p.inputs))

        _, (p_data, mq_data) = align_tensors(p, mq)
        assert p_data.shape == mq_data.shape
        return (ops.exp(mq_data) * probe).sum() - (ops.exp(p_data) * probe).sum(), mq

    if test_grad:
        if get_backend() == "jax":
            import jax
            diff_grad, mq = jax.grad(diff_fn, has_aux=True)(p_data)
        else:
            import torch
            p_data.requires_grad_(True)
            diff_grad = torch.autograd.grad(diff_fn(p_data)[0], [p_data])[0]
        assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=0.1, rtol=None)
    else:
        _, mq = diff_fn(p_data)
        assert_close(mq, Tensor(p_data, be_inputs), atol=0.1, rtol=None)
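# A hypothetical numpy analogue (not funsor code) of the check inside diff_fn:
# sampling a discrete distribution and reducing the resulting uniformly-weighted
# point masses over the sample dim should recover the original probabilities up
# to Monte Carlo error.
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.2, 0.3, 0.5])
n = 50000
idx = rng.choice(3, size=n, p=p)
# Each of the n samples carries log-weight -log(n) at its sampled index.
log_freq = np.log(np.bincount(idx, minlength=3)) - np.log(n)
assert np.allclose(np.exp(log_freq), p, atol=0.01)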
def as_tensor(self):
    # Fill gaps with zeros.
    arbitrary_row = next(iter(self.parts.values()))
    prototype = next(iter(arbitrary_row.values()))
    js = set().union(*(part.keys() for part in self.parts.values()))
    rows = _find_intervals(self.parts.keys(), self.shape[-2])
    cols = _find_intervals(js, self.shape[-1])
    for i in rows:
        for j in cols:
            if j not in self.parts[i]:
                shape = self.shape[:-2] + (i[1] - i[0], j[1] - j[0])
                self.parts[i][j] = ops.new_zeros(prototype, shape)

    # Concatenate parts.
    # TODO This could be optimized into a single .reshape().cat().reshape() if
    # all inputs are contiguous, thereby saving a memcopy.
    columns = {i: ops.cat(-1, *[v for j, v in sorted(part.items())])
               for i, part in self.parts.items()}
    result = ops.cat(-2, *[v for i, v in sorted(columns.items())])
    if not get_tracing_state():
        assert result.shape == self.shape
    return result
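# A small numpy illustration (hypothetical, not part of this module) of the
# two-stage concatenation above: each row of blocks is joined along the last dim,
# then the assembled rows are joined along the second-to-last dim.
import numpy as np

parts = {
    (0, 2): {(0, 2): np.ones((2, 2)), (2, 3): np.zeros((2, 1))},
    (2, 3): {(0, 2): np.zeros((1, 2)), (2, 3): np.full((1, 1), 5.0)},
}
columns = {i: np.concatenate([v for _, v in sorted(row.items())], axis=-1)
           for i, row in parts.items()}
dense = np.concatenate([v for _, v in sorted(columns.items())], axis=-2)
assert dense.shape == (3, 3)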