def _fill_defaults(concentration, value='value'):
    concentration = to_funsor(concentration)
    assert concentration.dtype == "real"
    assert len(concentration.output.shape) == 1
    dim = concentration.output.shape[0]
    value = to_funsor(value, reals(dim))
    return concentration, value

def adjoint(self, red_op, bin_op, root, targets):
    bin_unit = to_funsor(ops.UNITS[bin_op])
    adjoint_values = defaultdict(lambda: bin_unit)

    reached_root = False
    while self.tape:
        output, fn, inputs = self.tape.pop()
        if not reached_root:
            if output is root:
                reached_root = True
            else:
                continue

        # reverse the effects of alpha-renaming
        with interpretation(reflect):
            other_subs = tuple((name, to_funsor(name.split("__BOUND")[0], domain))
                               for name, domain in output.inputs.items()
                               if "__BOUND" in name)
            inputs = _alpha_unmangle(substitute(fn(*inputs), other_subs))
            output = type(output)(*_alpha_unmangle(substitute(output, other_subs)))

        in_adjs = adjoint_ops(fn, red_op, bin_op, adjoint_values[output], *inputs)
        for v, adjv in in_adjs.items():
            adjoint_values[v] = bin_op(adjoint_values[v], adjv)

    target_adjs = {}
    for v in targets:
        target_adjs[v] = adjoint_values[v]
        if not isinstance(v, Variable):
            target_adjs[v] = bin_op(target_adjs[v], v)
    return target_adjs

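# A hedged usage sketch for the tape-based adjoint above (the names below are
# illustrative assumptions, not from this source): the surrounding AdjointTape
# is assumed to record (output, fn, inputs) triples during a forward pass.
#
#     tape = AdjointTape()
#     root = ...  # build a funsor expression while the tape's interpretation is active
#     adjoints = tape.adjoint(ops.logaddexp, ops.add, root, targets=(x, y))
#     # adjoints[x] is the (logaddexp, add)-adjoint of root with respect to x,
#     # accumulated via bin_op as the tape is replayed in reverse order.
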
def _fill_defaults(total_count, probs, value='value'):
    total_count = to_funsor(total_count, reals())
    probs = to_funsor(probs)
    assert probs.dtype == "real"
    assert len(probs.output.shape) == 1
    value = to_funsor(value, probs.output)
    return total_count, probs, value

def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
    # note this should handle transforms correctly via distribution_to_data
    raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
    for d, name in zip(range(len(sample_inputs), 0, -1), sample_inputs.keys()):
        dim_to_name[-d - len(raw_dist.batch_shape)] = name

    if value_name not in sampled_vars:
        return self

    sample_shape = tuple(v.size for v in sample_inputs.values())
    sample_args = (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape)
    if self.has_rsample:
        raw_value = raw_dist.rsample(*sample_args)
    else:
        raw_value = ops.detach(raw_dist.sample(*sample_args))

    funsor_value = to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
    funsor_value = funsor_value.align(
        tuple(sample_inputs) + tuple(inp for inp in self.inputs if inp in funsor_value.inputs))
    result = funsor.delta.Delta(value_name, funsor_value)
    if not self.has_rsample:
        # scaling of dice_factor by num samples should already be handled by Funsor.sample
        raw_log_prob = raw_dist.log_prob(raw_value)
        dice_factor = to_funsor(raw_log_prob - ops.detach(raw_log_prob),
                                output=self.output, dim_to_name=dim_to_name)
        result = result + dice_factor
    return result

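# Hedged example of the sampling contract above (assumes the torch backend and
# a funsor distribution `d` with a free "value" input; the "particle" name is
# an illustrative assumption):
#
#     from collections import OrderedDict
#     from funsor.domains import Bint
#     s = d.unscaled_sample(frozenset(["value"]), OrderedDict(particle=Bint[4]))
#     # s is Delta("value", sample) with a new Bint[4] input "particle"; when
#     # d.has_rsample is False it also carries log_prob - detach(log_prob), a
#     # zero-valued dice factor whose gradient supplies the score-function term.
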
def _fill_defaults(loc, concentration, value='value'):
    loc = to_funsor(loc)
    assert loc.dtype == "real"
    concentration = to_funsor(concentration)
    assert concentration.dtype == "real"
    value = to_funsor(value, reals())
    return loc, concentration, value

def _fill_defaults(concentration, rate, value='value'):
    concentration = to_funsor(concentration)
    assert concentration.dtype == "real"
    rate = to_funsor(rate)
    assert rate.dtype == "real"
    value = to_funsor(value, reals())
    return concentration, rate, value

def _fill_defaults(concentration, total_count=1, value='value'):
    concentration = to_funsor(concentration)
    assert concentration.dtype == "real"
    assert len(concentration.output.shape) == 1
    total_count = to_funsor(total_count, reals())
    dim = concentration.output.shape[0]
    value = to_funsor(value, reals(dim))  # Should this be bint(total_count)?
    return concentration, total_count, value

def __call__(cls, *args):
    if len(args) > 1:
        assert len(args) == 2 or len(args) == 3
        assert isinstance(args[0], str) and isinstance(args[1], Funsor)
        args = args + (Number(0.),) if len(args) == 2 else args
        args = (((args[0], (to_funsor(args[1]), to_funsor(args[2]))),),)
    assert isinstance(args[0], tuple)
    return super().__call__(args[0])

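# The normalization above makes these call patterns equivalent (an
# illustrative sketch of the resulting Delta constructor contract):
#
#     Delta("x", point)                       # log_density defaults to Number(0.)
#     Delta("x", point, log_density)
#     Delta((("x", (point, log_density)),))   # canonical tuple-of-pairs form
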
def deltadist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    v = to_funsor(pyro_dist.v, output=Reals[pyro_dist.event_shape], dim_to_name=dim_to_name)
    log_density = to_funsor(pyro_dist.log_density, output=Real, dim_to_name=dim_to_name)
    return Delta(v, log_density)  # noqa: F821

def maskeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
    mask = to_funsor(ops.astype(backend_dist._mask, 'float32'),
                     output=output, dim_to_name=dim_to_name)
    funsor_base_dist = to_funsor(backend_dist.base_dist, output=output, dim_to_name=dim_to_name)
    return mask * funsor_base_dist

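# Note on the return value above (illustrative): funsor distributions denote
# log densities, so multiplying by the {0., 1.}-valued mask funsor zeroes out
# the log probability wherever the mask is False, matching the semantics of a
# masked distribution.
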
def __call__(cls, *args, **kwargs):
    kwargs.update(zip(cls._ast_fields, args))
    value = kwargs.pop('value', 'value')
    kwargs = OrderedDict(
        (k, to_funsor(kwargs[k], output=cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))))
        for k in cls._ast_fields if k != 'value')
    value = to_funsor(value, output=cls._infer_value_domain(**{k: v.output for k, v in kwargs.items()}))
    args = numbers_to_tensors(*(tuple(kwargs.values()) + (value,)))
    return super(DistributionMeta, cls).__call__(*args)

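# Effect of the coercion above (an illustrative sketch): raw numbers and
# backend tensors are promoted to funsors with inferred domains, and `value`
# defaults to a free Variable named "value", e.g.
#
#     Normal(0., 1.)       # params become Tensors; inputs include {"value": Real}
#     Normal(0., 1., obs)  # observed: substituting obs leaves no free "value"
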
def _fill_defaults(loc, scale_tril, value='value'):
    loc = to_funsor(loc)
    scale_tril = to_funsor(scale_tril)
    assert loc.dtype == 'real'
    assert scale_tril.dtype == 'real'
    assert len(loc.output.shape) == 1
    dim = loc.output.shape[0]
    assert scale_tril.output.shape == (dim, dim)
    value = to_funsor(value, loc.output)
    return loc, scale_tril, value

def LogNormal(loc, scale, value='value'): """ Wraps :class:`pyro.distributions.LogNormal` . :param Funsor loc: Mean of the untransformed Normal distribution. :param Funsor scale: Standard deviation of the untransformed Normal distribution. :param Funsor value: Optional real observation. """ loc, scale = to_funsor(loc), to_funsor(scale) y = to_funsor(value, output=loc.output) t = ops.exp x = t.inv(y) log_abs_det_jacobian = t.log_abs_det_jacobian(x, y) return Normal(loc, scale, x) - log_abs_det_jacobian
def test_mvn_affine_one_var():
    x = Variable('x', Reals[2])
    data = dict(x=Tensor(randn(2)))
    with interpretation(lazy):
        d = to_funsor(random_mvn((), 2), Real)
        d = d(value=2 * x + 1)
    _check_mvn_affine(d, data)

def mvn_to_funsor(pyro_dist, event_inputs=(), real_inputs=OrderedDict()): """ Convert a joint :class:`torch.distributions.MultivariateNormal` distribution into a :class:`~funsor.terms.Funsor` with multiple real inputs. This should satisfy:: sum(d.num_elements for d in real_inputs.values()) == pyro_dist.event_shape[0] :param torch.distributions.MultivariateNormal pyro_dist: A multivariate normal distribution over one or more variables of real or vector or tensor type. :param tuple event_inputs: A tuple of names for rightmost dimensions. These will be assigned to ``result.inputs`` of type ``Bint``. :param OrderedDict real_inputs: A dict mapping real variable name to appropriately sized ``Real``. The sum of all ``.numel()`` of all real inputs should be equal to the ``pyro_dist`` dimension. :return: A funsor with given ``real_inputs`` and possibly additional Bint inputs. :rtype: funsor.terms.Funsor """ assert isinstance(pyro_dist, torch.distributions.MultivariateNormal) assert isinstance(event_inputs, tuple) assert isinstance(real_inputs, OrderedDict) dim_to_name = default_dim_to_name(pyro_dist.batch_shape, event_inputs) funsor_dist = to_funsor(pyro_dist, Real, dim_to_name) if len(real_inputs) == 0: return funsor_dist discrete, gaussian = funsor_dist(value="value").terms inputs = OrderedDict( (k, v) for k, v in gaussian.inputs.items() if v.dtype != 'real') inputs.update(real_inputs) return discrete + Gaussian(gaussian.info_vec, gaussian.precision, inputs)
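# Usage sketch (illustrative; assumes the torch backend):
#
#     import torch
#     from collections import OrderedDict
#     from funsor.domains import Real, Reals
#     mvn = torch.distributions.MultivariateNormal(torch.zeros(3), torch.eye(3))
#     f = mvn_to_funsor(mvn, real_inputs=OrderedDict(x=Real, y=Reals[2]))
#     # f has real inputs x and y with 1 + 2 == mvn.event_shape[0] elements,
#     # i.e. the joint Gaussian is split across the named real inputs.
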
def test_generic_distribution_to_funsor(case):
    with xfail_if_not_found():
        raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain

        dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
        with interpretation(normalize_with_subs):
            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
        assert funsor_dist.inputs["value"] == expected_value_domain

        while isinstance(funsor_dist, funsor.cnf.Contraction):
            funsor_dist = [term for term in funsor_dist.terms
                           if isinstance(term, (funsor.distribution.Distribution,
                                                funsor.terms.Independent))][0]

        actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim)
        assert isinstance(actual_dist, backend_dist.Distribution)
        assert issubclass(type(actual_dist), type(raw_dist))  # subclass to handle wrappers

        while isinstance(raw_dist, backend_dist.Independent) \
                or type(raw_dist) == backend_dist.TransformedDistribution:
            raw_dist = raw_dist.base_dist
            actual_dist = actual_dist.base_dist
            assert isinstance(actual_dist, backend_dist.Distribution)
            assert issubclass(type(actual_dist), type(raw_dist))  # subclass to handle wrappers

        for param_name, _ in case.raw_params:
            assert hasattr(raw_dist, param_name)
            assert_close(getattr(actual_dist, param_name), getattr(raw_dist, param_name))

def test_mvn_affine_getitem():
    x = Variable('x', Reals[2, 2])
    data = dict(x=Tensor(randn(2, 2)))
    with interpretation(lazy):
        d = to_funsor(random_mvn((), 2), Real)
        d = d(value=x[0] - x[1])
    _check_mvn_affine(d, data)

def adjoint(self, red_op, bin_op, root, targets):
    bin_unit = to_funsor(ops.UNITS[bin_op])
    adjoint_values = defaultdict(lambda: bin_unit)
    multiplicities = defaultdict(lambda: 0)

    reached_root = False
    while self.tape:
        output, fn, inputs = self.tape.pop()
        if not reached_root:
            if output is root:
                reached_root = True
            else:
                continue

        # reverse the effects of alpha-renaming
        with interpretation(lazy):
            other_subs = {name: name.split("__BOUND")[0]
                          for name in output.inputs if "__BOUND" in name}
            inputs = _alpha_unmangle(fn(*inputs)(**other_subs))
            output = type(output)(*_alpha_unmangle(output(**other_subs)))

        in_adjs = adjoint_ops(fn, red_op, bin_op, adjoint_values[output], *inputs)
        for v, adjv in in_adjs.items():
            multiplicities[v] += 1
            adjoint_values[v] = bin_op(adjoint_values[v], adjv)

    target_adjs = {}
    for v in targets:
        target_adjs[v] = adjoint_values[v] / multiplicities[v]  # TODO use correct op here with bin_op
        if not isinstance(v, Variable):
            target_adjs[v] = bin_op(target_adjs[v], v)
    return target_adjs

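# Note on the multiplicity bookkeeping above (illustrative): when a target v
# receives k adjoint contributions, the accumulated bin_op-sum is divided by
# k, averaging duplicates rather than summing them, e.g.
#
#     adjoint_values[v] == bin_op(a1, bin_op(a2, a3))  # multiplicities[v] == 3
#     target_adjs[v] == adjoint_values[v] / 3
#
# The TODO in the source notes that plain division is only correct for some bin_ops.
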
def transformeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
    dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
    base_dist, transforms = backend_dist, []
    while isinstance(base_dist, dist_module.TransformedDistribution):
        transforms = base_dist.transforms + transforms
        base_dist = base_dist.base_dist
    funsor_base_dist = to_funsor(base_dist, output=output, dim_to_name=dim_to_name)
    # TODO make this work with transforms that change the output type
    transform = to_funsor(dist_module.transforms.ComposeTransform(transforms),
                          funsor_base_dist.inputs["value"], dim_to_name)
    _, inv_transform, ldj = funsor.delta.solve(
        transform, to_funsor("value", funsor_base_dist.inputs["value"]))
    return -ldj + funsor_base_dist(value=inv_transform)

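# Usage sketch (illustrative; assumes the torch backend and that to_funsor
# dispatches TransformedDistribution to the converter above):
#
#     import torch.distributions as dist
#     d = dist.TransformedDistribution(dist.Normal(0., 1.),
#                                      [dist.transforms.ExpTransform()])
#     f = to_funsor(d, output=Real)
#     # f is the base Normal with "value" substituted by the symbolic inverse
#     # transform, minus the log|det J| of the composed transform.
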
def mvn_to_funsor(pyro_dist, event_inputs=(), real_inputs=OrderedDict()): """ Convert a joint :class:`torch.distributions.MultivariateNormal` distribution into a :class:`~funsor.terms.Funsor` with multiple real inputs. This should satisfy:: sum(d.num_elements for d in real_inputs.values()) == pyro_dist.event_shape[0] :param torch.distributions.MultivariateNormal pyro_dist: A multivariate normal distribution over one or more variables of real or vector or tensor type. :param tuple event_inputs: A tuple of names for rightmost dimensions. These will be assigned to ``result.inputs`` of type ``bint``. :param OrderedDict real_inputs: A dict mapping real variable name to appropriately sized ``reals()``. The sum of all ``.numel()`` of all real inputs should be equal to the ``pyro_dist`` dimension. :return: A funsor with given ``real_inputs`` and possibly additional bint inputs. :rtype: funsor.terms.Funsor """ assert isinstance(pyro_dist, torch.distributions.MultivariateNormal) assert isinstance(event_inputs, tuple) assert isinstance(real_inputs, OrderedDict) dim_to_name = default_dim_to_name(pyro_dist.batch_shape, event_inputs) return to_funsor(pyro_dist, reals(), dim_to_name, real_inputs=real_inputs)
def tensor_to_funsor(tensor, event_inputs=(), event_output=0, dtype="real"):
    """
    Convert a :class:`torch.Tensor` to a :class:`funsor.tensor.Tensor`.

    Note this should not touch data, but may trigger a
    :meth:`torch.Tensor.reshape` op.

    :param torch.Tensor tensor: A PyTorch tensor.
    :param tuple event_inputs: A tuple of names for rightmost tensor
        dimensions. If ``tensor`` has these names, they will be converted to
        ``result.inputs``.
    :param int event_output: The number of tensor dimensions assigned to
        ``result.output``. These must be on the right of any ``event_input``
        dimensions.
    :param dtype: The result's output dtype, either ``"real"`` or a positive
        integer size, defaulting to ``"real"``.
    :return: A funsor.
    :rtype: funsor.tensor.Tensor
    """
    assert isinstance(tensor, torch.Tensor)
    assert isinstance(event_inputs, tuple)
    assert isinstance(event_output, int) and event_output >= 0
    inputs_shape = tensor.shape[:tensor.dim() - event_output]
    output = Domain(dtype=dtype, shape=tensor.shape[tensor.dim() - event_output:])
    dim_to_name = default_dim_to_name(inputs_shape, event_inputs)
    return to_funsor(tensor, output, dim_to_name)

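# Usage sketch (illustrative):
#
#     f = tensor_to_funsor(torch.randn(3, 5), event_inputs=("i",), event_output=1)
#     # the batch dim of size 3 becomes a bint input named "i", and the
#     # rightmost event_output dims form the output: f.output.shape == (5,)
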
def test_mvn_affine_two_vars():
    x = Variable('x', Reals[2])
    y = Variable('y', Reals[2])
    data = dict(x=Tensor(randn(2)), y=Tensor(randn(2)))
    with interpretation(lazy):
        d = to_funsor(random_mvn((), 2), Real)
        d = d(value=x - y)
    _check_mvn_affine(d, data)

def enumerate_support(self, expand=False):
    assert self.has_enumerate_support and isinstance(self.value, Variable)
    raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
    raw_value = raw_dist.enumerate_support(expand=expand)
    dim_to_name[min(dim_to_name.keys(), default=0) - 1] = value_name
    return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)

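# Usage sketch (illustrative; assumes a discrete funsor distribution d whose
# `value` is still a free Variable):
#
#     support = d.enumerate_support(expand=False)
#     # support enumerates the distribution's support along a new leftmost
#     # dim that is named after d.value, so it can be reduced or contracted
#     # against d like any other funsor input.
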
def test_mvn_affine_reshape():
    x = Variable('x', Reals[2, 2])
    y = Variable('y', Reals[4])
    data = dict(x=Tensor(randn(2, 2)), y=Tensor(randn(4)))
    with interpretation(lazy):
        d = to_funsor(random_mvn((), 4), Real)
        d = d(value=x.reshape((4,)) - y)
    _check_mvn_affine(d, data)

def _alpha_convert(self, alpha_subs):
    assert self.bound.issuperset(alpha_subs)
    reduced_vars = frozenset(alpha_subs.get(k, k) for k in self.reduced_vars)
    alpha_subs = {k: to_funsor(v, self.integrand.inputs.get(k, self.log_measure.inputs.get(k)))
                  for k, v in alpha_subs.items()}
    log_measure = substitute(self.log_measure, alpha_subs)
    integrand = substitute(self.integrand, alpha_subs)
    return log_measure, integrand, reduced_vars

def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    subs = {k: materialize(to_funsor(v, self.inputs[k]))
            for k, v in subs if k in self.inputs}
    if not subs:
        return self

    # Compute result shapes.
    inputs = OrderedDict()
    for k, domain in self.inputs.items():
        if k in subs:
            inputs.update(subs[k].inputs)
        else:
            inputs[k] = domain

    # Construct a dict with each input's positional dim,
    # counting from the right so as to support broadcasting.
    total_size = len(inputs) + len(self.output.shape)  # Assumes only scalar indices.
    new_dims = {}
    for k, domain in inputs.items():
        assert not domain.shape
        new_dims[k] = len(new_dims) - total_size

    # Use advanced indexing to construct a simultaneous substitution.
    index = []
    for k, domain in self.inputs.items():
        if k in subs:
            v = subs.get(k)
            if isinstance(v, Number):
                index.append(int(v.data))
            else:
                # Permute and expand v.data to end up at new_dims.
                assert isinstance(v, Tensor)
                v = v.align(tuple(k2 for k2 in inputs if k2 in v.inputs))
                assert isinstance(v, Tensor)
                v_shape = [1] * total_size
                for k2, size in zip(v.inputs, v.data.shape):
                    v_shape[new_dims[k2]] = size
                index.append(v.data.reshape(tuple(v_shape)))
        else:
            # Construct a [:] slice for this preserved input.
            offset_from_right = -1 - new_dims[k]
            index.append(torch.arange(domain.dtype).reshape((-1,) + (1,) * offset_from_right))

    # Construct a [:] slice for the output.
    for i, size in enumerate(self.output.shape):
        offset_from_right = len(self.output.shape) - i - 1
        index.append(torch.arange(size).reshape((-1,) + (1,) * offset_from_right))

    data = self.data[tuple(index)]
    return Tensor(data, inputs, self.dtype)

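# Illustrative example of the advanced-indexing substitution above (assumes
# the old-style funsor API with bint domains; names are illustrative):
#
#     t = Tensor(torch.randn(3, 4), OrderedDict(i=bint(3), j=bint(4)))
#     t(i=Number(1, 3))   # scalar index: drops input "i", keeps "j"
#     k = Tensor(torch.tensor([0, 2]), OrderedDict(k=bint(2)), 3)
#     t(i=k)              # fancy index: result has inputs "k" and "j"
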
def LogNormal(loc, scale, value='value'): """ Wraps backend `LogNormal` distributions. :param Funsor loc: Mean of the untransformed Normal distribution. :param Funsor scale: Standard deviation of the untransformed Normal distribution. :param Funsor value: Optional real observation. """ loc, scale = to_funsor(loc), to_funsor(scale) y = to_funsor(value, output=loc.output) t = ops.exp x = t.inv(y) log_abs_det_jacobian = t.log_abs_det_jacobian(x, y) backend_dist = import_module( BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) return backend_dist.Normal(loc, scale, x) - log_abs_det_jacobian # noqa: F821
def test_mvn_affine_einsum():
    c = Tensor(randn(3, 2, 2))
    x = Variable('x', Reals[2, 2])
    y = Variable('y', Real)
    data = dict(x=Tensor(randn(2, 2)), y=Tensor(randn(())))
    with interpretation(lazy):
        d = to_funsor(random_mvn((), 3), Real)
        d = d(value=Einsum("abc,bc->a", c, x) + y)
    _check_mvn_affine(d, data)

def indepdist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    dim_to_name = OrderedDict((dim - pyro_dist.reinterpreted_batch_ndims, name)
                              for dim, name in dim_to_name.items())
    dim_to_name.update(OrderedDict((i, f"_pyro_event_dim_{i}")
                                   for i in range(-pyro_dist.reinterpreted_batch_ndims, 0)))
    result = to_funsor(pyro_dist.base_dist, dim_to_name=dim_to_name)
    for i in reversed(range(-pyro_dist.reinterpreted_batch_ndims, 0)):
        name = f"_pyro_event_dim_{i}"
        result = funsor.terms.Independent(result, "value", name, "value")
    return result

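# Sketch of the dim bookkeeping above (illustrative): for
# Independent(Normal(loc, scale), 1) with event_shape (3,), the event dim is
# temporarily exposed as a batch dim named "_pyro_event_dim_-1", the base
# Normal is converted with that extra input, and funsor.terms.Independent then
# rebinds it so the resulting funsor's "value" input has event shape (3,).
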
def test_mvn_affine_matmul_sub():
    x = Variable('x', Reals[2])
    y = Variable('y', Reals[3])
    m = Tensor(randn(2, 3))
    data = dict(x=Tensor(randn(2)), y=Tensor(randn(3)))
    with interpretation(lazy):
        d = to_funsor(random_mvn((), 3), Real)
        d = d(value=x @ m - y)
    _check_mvn_affine(d, data)
