def distribution_to_data(funsor_dist, name_to_dim=None):
    """
    Convert a funsor Distribution term back to a raw backend distribution.

    :param funsor_dist: A funsor distribution whose ``_ast_fields`` end with
        ``'value'``; each non-value field is converted via :func:`to_data`.
    :param dict name_to_dim: Optional mapping from funsor input names to
        negative batch dims, forwarded to :func:`to_data`.
    :return: A backend distribution, wrapped in ``TransformedDistribution``
        when the funsor's value is not a plain ``Variable``.
    :raises NotImplementedError: For transformed values on non-torch backends.
    :raises ValueError: If the reconstructed event shape disagrees with the
        funsor value's output shape.
    """
    params = [
        to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim)
        for param_name in funsor_dist._ast_fields if param_name != 'value'
    ]
    # 'value' is the last _ast_field, so the remaining fields align with params.
    pyro_dist = funsor_dist.dist_class(
        **dict(zip(funsor_dist._ast_fields[:-1], params)))
    funsor_event_shape = funsor_dist.value.output.shape
    # Promote any extra rightmost batch dims to event dims so shapes line up.
    pyro_dist = pyro_dist.to_event(
        max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0))

    # TODO get this working for all backends
    if not isinstance(funsor_dist.value, Variable):
        if get_backend() != "torch":
            # BUGFIX: the original two-part literal rendered as
            # "...backend,try set_backend..." with no separating space.
            raise NotImplementedError(
                "transformed distributions not yet supported under this backend, "
                "try set_backend('torch')")
        # Invert the value expression to recover the transform applied to
        # a fresh "value" variable, then convert it to backend transforms.
        inv_value = funsor.delta.solve(
            funsor_dist.value, Variable("value", funsor_dist.value.output))[1]
        transforms = to_data(inv_value, name_to_dim=name_to_dim)
        backend_dist = import_module(
            BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
        pyro_dist = backend_dist.TransformedDistribution(pyro_dist, transforms)

    if pyro_dist.event_shape != funsor_event_shape:
        raise ValueError("Event shapes don't match, something went wrong")
    return pyro_dist
def gaussian_to_data(funsor_dist, name_to_dim=None, normalized=False):
    """
    Convert a Gaussian funsor to a ``MultivariateNormal``.

    By default the normalization constant is dropped; pass ``normalized=True``
    to fold ``log_normalizer`` back in and convert the normalized funsor.
    """
    if normalized:
        return to_data(funsor_dist.log_normalizer + funsor_dist,
                       name_to_dim=name_to_dim)
    # Recover the mean from the information form: loc = precision^{-1} @ info_vec,
    # computed stably via a Cholesky solve.
    chol = cholesky(funsor_dist.precision)
    mean = funsor_dist.info_vec.unsqueeze(-1).cholesky_solve(chol).squeeze(-1)
    # Keep only the non-real (batch) inputs when wrapping raw tensors.
    batch_inputs = OrderedDict(
        (name, domain) for name, domain in funsor_dist.inputs.items()
        if domain.dtype != "real")
    raw_loc = to_data(Tensor(mean, batch_inputs), name_to_dim)
    raw_precision = to_data(Tensor(funsor_dist.precision, batch_inputs), name_to_dim)
    return dist.MultivariateNormal(raw_loc, precision_matrix=raw_precision)
def gaussianmixture_to_data(funsor_dist, name_to_dim=None):
    """
    Split a discrete+Gaussian mixture funsor into a ``(cat, mvn)`` pair,
    where ``cat`` carries the mixture logits (including the Gaussian
    normalizer) and ``mvn`` is the component distribution.
    """
    discrete_term, gaussian_term = funsor_dist.terms
    backend = import_module(
        BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    mixing_logits = to_data(discrete_term + gaussian_term.log_normalizer,
                            name_to_dim=name_to_dim)
    cat = backend.CategoricalLogits.dist_class(logits=mixing_logits)
    mvn = to_data(gaussian_term, name_to_dim=name_to_dim)
    return cat, mvn
def gaussian_to_data(funsor_dist, name_to_dim=None, normalized=False):
    """
    Backend-agnostic conversion of a Gaussian funsor to a
    ``MultivariateNormal`` of the active backend, dropping the
    normalization constant unless ``normalized=True``.
    """
    if normalized:
        return to_data(funsor_dist.log_normalizer + funsor_dist,
                       name_to_dim=name_to_dim)
    # loc = precision^{-1} @ info_vec via Cholesky solve; ops dispatches to the
    # active backend's implementation.
    chol = ops.cholesky(funsor_dist.precision)
    mean = ops.cholesky_solve(ops.unsqueeze(funsor_dist.info_vec, -1), chol).squeeze(-1)
    # Only non-real inputs are batch inputs for the raw-tensor wrappers.
    batch_inputs = OrderedDict(
        (name, domain) for name, domain in funsor_dist.inputs.items()
        if domain.dtype != "real")
    raw_loc = to_data(Tensor(mean, batch_inputs), name_to_dim)
    raw_precision = to_data(Tensor(funsor_dist.precision, batch_inputs), name_to_dim)
    backend = import_module(
        BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return backend.MultivariateNormal.dist_class(
        raw_loc, precision_matrix=raw_precision)
def test_generic_distribution_to_funsor(case):
    # Round-trip a raw backend distribution through to_funsor/to_data and
    # check that types, wrappers, and parameters survive the conversion.
    with xfail_if_not_found():
        # NOTE(review): case.raw_dist is eval'd test-fixture code, not untrusted input.
        raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain
        dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
        with interpretation(normalize_with_subs):
            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
        assert funsor_dist.inputs["value"] == expected_value_domain
        # Normalization may wrap the distribution in a Contraction; peel it off
        # by picking out the Distribution/Independent term.
        while isinstance(funsor_dist, funsor.cnf.Contraction):
            funsor_dist = [term for term in funsor_dist.terms
                           if isinstance(term, (funsor.distribution.Distribution,
                                                funsor.terms.Independent))][0]
        actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim)
        assert isinstance(actual_dist, backend_dist.Distribution)
        assert issubclass(type(actual_dist), type(raw_dist))  # subclass to handle wrappers
        # Unwrap Independent/TransformedDistribution layers in lockstep on both
        # sides, re-checking the type relationship after each layer.
        while isinstance(raw_dist, backend_dist.Independent) or type(raw_dist) == backend_dist.TransformedDistribution:
            raw_dist = raw_dist.base_dist
            actual_dist = actual_dist.base_dist
            assert isinstance(actual_dist, backend_dist.Distribution)
            assert issubclass(type(actual_dist), type(raw_dist))  # subclass to handle wrappers
        # Finally, every declared raw parameter must round-trip numerically.
        for param_name, _ in case.raw_params:
            assert hasattr(raw_dist, param_name)
            assert_close(getattr(actual_dist, param_name), getattr(raw_dist, param_name))
def transform_to_data(expr, name_to_dim=None):
    """
    Convert a unary TransformOp funsor expression into a torch Transform,
    composing with the (recursively converted) inner transform when the
    argument is itself a Unary expression.
    """
    if not isinstance(expr.op, ops.TransformOp):
        raise NotImplementedError("cannot convert to data: {}".format(expr))
    tfm = op_to_torch_transform(expr.op, name_to_dim=name_to_dim)
    if isinstance(expr.arg, Unary):
        inner = to_data(expr.arg, name_to_dim=name_to_dim)
        # Inner transform applies first, then this one.
        tfm = torch.distributions.transforms.ComposeTransform([inner, tfm])
    return tfm
def distribution_to_data(funsor_dist, name_to_dim=None):
    """
    Convert a funsor Distribution term to a raw backend distribution,
    promoting extra event dims via ``to_event`` so the backend event shape
    matches the funsor value's output shape.
    """
    raw_params = []
    for field in funsor_dist._ast_fields:
        if field == 'value':
            continue
        raw_params.append(to_data(getattr(funsor_dist, field),
                                  name_to_dim=name_to_dim))
    # 'value' is the final _ast_field, so the leading fields pair with raw_params.
    pyro_dist = funsor_dist.dist_class(
        **dict(zip(funsor_dist._ast_fields[:-1], raw_params)))
    funsor_event_shape = funsor_dist.value.output.shape
    extra_event_dims = max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0)
    pyro_dist = pyro_dist.to_event(extra_event_dims)
    if pyro_dist.event_shape != funsor_event_shape:
        raise ValueError("Event shapes don't match, something went wrong")
    return pyro_dist
def indep_to_data(funsor_dist, name_to_dim=None):
    """
    Convert a funsor Independent to a backend ``Independent`` distribution.

    :param funsor_dist: An Independent funsor whose ``fn`` is itself an
        Independent, Distribution, or Gaussian.
    :param dict name_to_dim: Optional mapping from input names to negative
        batch dims; may be ``None``.
    :return: A backend ``Independent`` distribution.
    :raises NotImplementedError: If ``funsor_dist.fn`` has an unsupported type.
    """
    if not isinstance(funsor_dist.fn, (Independent, Distribution, Gaussian)):
        raise NotImplementedError(f"cannot convert {funsor_dist} to data")
    # BUGFIX: the declared default name_to_dim=None previously crashed on
    # .items(); treat a missing mapping as empty.
    # Shift every existing dim left by one to make room for the bint dim at -1.
    name_to_dim = OrderedDict(
        (name, dim - 1) for name, dim in (name_to_dim or {}).items())
    name_to_dim.update({funsor_dist.bint_var: -1})
    backend_dist = import_module(
        BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
    result = to_data(funsor_dist.fn, name_to_dim=name_to_dim)

    # collapse nested Independents into a single Independent for conversion
    reinterpreted_batch_ndims = 1
    while isinstance(result, backend_dist.Independent):
        result = result.base_dist
        reinterpreted_batch_ndims += 1
    return backend_dist.Independent(result, reinterpreted_batch_ndims)
def _get_raw_dist(self):
    """
    Internal method for working with underlying distribution attributes.

    Converts ``self`` back to a raw backend distribution using an arbitrary
    (but invertible) name-to-dim assignment, and returns everything needed
    to map results back into funsor land.

    :return: a tuple ``(raw_dist, value_name, value_output, dim_to_name)``.
    """
    # The value input is identified as the input whose domain equals this
    # distribution's output domain.
    value_name = [name for name, domain in self.value.inputs.items()  # TODO is this right?
                  if domain == self.value.output][0]
    # arbitrary name-dim mapping, since we're converting back to a funsor anyway
    # Only bint (integer-dtype) inputs other than the value become batch dims.
    name_to_dim = {name: -dim - 1 for dim, (name, domain) in enumerate(self.inputs.items())
                   if isinstance(domain.dtype, int) and name != value_name}
    raw_dist = to_data(self, name_to_dim=name_to_dim)
    # Invert the mapping for converting raw results back to funsors.
    dim_to_name = {dim: name for name, dim in name_to_dim.items()}
    # also return value output, dim_to_name for converting results back to funsor
    value_output = self.inputs[value_name]
    return raw_dist, value_name, value_output, dim_to_name
def funsor_to_tensor(funsor_, ndims, event_inputs=()):
    """
    Convert a :class:`funsor.tensor.Tensor` to a :class:`torch.Tensor`.

    Note this should not touch data, but may trigger a
    :meth:`torch.Tensor.reshape` op.

    :param funsor.tensor.Tensor funsor_: A funsor.
    :param int ndims: The number of result dims, ``== result.dim()``.
    :param tuple event_inputs: Names assigned to rightmost dimensions.
    :return: A PyTorch tensor.
    :rtype: torch.Tensor
    """
    assert isinstance(funsor_, Tensor)
    assert all(k.startswith("_pyro_dim_") or k in event_inputs
               for k in funsor_.inputs)
    tensor = to_data(funsor_, default_name_to_dim(event_inputs))
    # Left-pad the shape with singleton dims up to the requested rank.
    missing = ndims - tensor.dim()
    if missing:
        tensor = tensor.reshape((1,) * missing + tensor.shape)
    assert tensor.dim() == ndims
    return tensor
def funsor_to_cat_and_mvn(funsor_, ndims, event_inputs):
    """
    Converts a labeled gaussian mixture model to a pair of distributions.

    :param funsor.joint.Joint funsor_: A Gaussian mixture funsor.
    :param int ndims: The number of batch dimensions in the result.
    :return: A pair ``(cat, mvn)``, where ``cat`` is a
        :class:`~pyro.distributions.Categorical` distribution over mixture
        components and ``mvn`` is a
        :class:`~pyro.distributions.MultivariateNormal` with rightmost batch
        dimension ranging over mixture components.
    """
    # Preconditions: exactly one real input, at least one component name,
    # and no Delta terms in the mixture.
    assert isinstance(funsor_, Contraction), funsor_
    assert sum(1 for d in funsor_.inputs.values() if d.dtype == "real") == 1
    assert event_inputs, "no components name found"
    assert not any(isinstance(v, Delta) for v in funsor_.terms)
    cat, mvn = to_data(funsor_, name_to_dim=default_name_to_dim(event_inputs))
    # Left-pad both batch shapes with singleton dims; mvn has one extra
    # rightmost batch dim ranging over mixture components.
    if ndims != len(cat.batch_shape):
        cat = cat.expand((1,) * (ndims - len(cat.batch_shape)) + cat.batch_shape)
    if ndims + 1 != len(mvn.batch_shape):
        mvn = mvn.expand((1,) * (ndims + 1 - len(mvn.batch_shape)) + mvn.batch_shape)
    return cat, mvn
def funsor_to_mvn(gaussian, ndims, event_inputs=()):
    """
    Convert a :class:`~funsor.terms.Funsor` to a
    :class:`pyro.distributions.MultivariateNormal`, dropping the
    normalization constant.

    :param gaussian: A Gaussian funsor.
    :type gaussian: funsor.gaussian.Gaussian or funsor.joint.Joint
    :param int ndims: The number of batch dimensions in the result.
    :param tuple event_inputs: A tuple of names to assign to rightmost
        dimensions.
    :return: a multivariate normal distribution.
    :rtype: pyro.distributions.MultivariateNormal
    """
    assert sum(1 for d in gaussian.inputs.values() if d.dtype == "real") == 1
    if isinstance(gaussian, Contraction):
        # Pull the sole Gaussian term out of a joint contraction.
        gaussian = [v for v in gaussian.terms if isinstance(v, Gaussian)][0]
    assert isinstance(gaussian, Gaussian)
    mvn = to_data(gaussian, name_to_dim=default_name_to_dim(event_inputs))
    # Left-pad the batch shape with singleton dims up to the requested ndims.
    pad = ndims - len(mvn.batch_shape)
    if pad:
        mvn = mvn.expand((1,) * pad + mvn.batch_shape)
    return mvn
def test_to_data_error():
    # Free variables cannot be converted to data; both real and bint
    # variables must raise.
    for bad_funsor in (Variable('x', Real), Variable('y', Bint[12])):
        with pytest.raises(ValueError):
            to_data(bad_funsor)
def gaussianmixture_to_data(funsor_dist, name_to_dim=None):
    """
    Split a discrete+Gaussian mixture funsor into a ``(cat, mvn)`` pair.
    The Categorical logits absorb the Gaussian's log normalizer.
    """
    discrete_term, gaussian_term = funsor_dist.terms
    mixing_logits = to_data(discrete_term + gaussian_term.log_normalizer,
                            name_to_dim=name_to_dim)
    cat = dist.Categorical(logits=mixing_logits)
    mvn = to_data(gaussian_term, name_to_dim=name_to_dim)
    return cat, mvn
def test_generic_enumerate_support(case, expand):
    # Check that enumerate_support round-trips through to_funsor/to_data and
    # agrees with the raw backend distribution's support.
    with xfail_if_not_found():
        # NOTE(review): case.raw_dist is eval'd test-fixture code, not untrusted input.
        raw_dist = eval(case.raw_dist)
        dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
        with interpretation(normalize_with_subs):
            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
        assert getattr(raw_dist, "has_enumerate_support", False) == getattr(funsor_dist, "has_enumerate_support", False)
        if getattr(funsor_dist, "has_enumerate_support", False):
            # Place the enumerated "value" dim just left of all existing dims.
            name_to_dim["value"] = -1 if not name_to_dim else min(name_to_dim.values()) - 1
            with xfail_if_not_implemented("enumerate support not implemented"):
                raw_support = raw_dist.enumerate_support(expand=expand)
                funsor_support = funsor_dist.enumerate_support(expand=expand)
                assert_close(to_data(funsor_support, name_to_dim=name_to_dim), raw_support)


@pytest.mark.parametrize("case", TEST_CASES, ids=str)
@pytest.mark.parametrize("sample_shape", [(), (2,), (4, 3)], ids=str)
def test_generic_sample(case, sample_shape):
    # Convert a raw distribution to a funsor and build sample inputs covering
    # the requested sample_shape.
    # NOTE(review): this function appears truncated in this chunk — the body
    # ends after building sample_inputs; confirm against the full file.
    with xfail_if_not_found():
        raw_dist = eval(case.raw_dist)
        dim_to_name, name_to_dim = _default_dim_to_name(sample_shape + raw_dist.batch_shape)
        with interpretation(normalize_with_subs):
            funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
        # Map each leading sample dim (negative index) to its assigned name and size.
        sample_inputs = OrderedDict((dim_to_name[dim - len(raw_dist.batch_shape)], funsor.Bint[sample_shape[dim]])
                                    for dim in range(-len(sample_shape), 0))
def deltadist_to_data(funsor_dist, name_to_dim=None):
    """
    Convert a funsor Delta to a backend ``Delta`` distribution, preserving
    the point value, its log density, and the event dimensionality.
    """
    event_dim = len(funsor_dist.v.output.shape)
    point = to_data(funsor_dist.v, name_to_dim=name_to_dim)
    density = to_data(funsor_dist.log_density, name_to_dim=name_to_dim)
    return dist.Delta(point, density, event_dim=event_dim)
def multinomial_to_data(funsor_dist, name_to_dim=None):
    """
    Convert a funsor Multinomial to a backend ``Multinomial``.
    Only a homogeneous (scalar) total_count is supported.
    """
    probs = to_data(funsor_dist.probs, name_to_dim)
    total_count = to_data(funsor_dist.total_count, name_to_dim)
    is_scalar = isinstance(total_count, numbers.Number) or len(total_count.shape) == 0
    if not is_scalar:
        raise NotImplementedError("inhomogeneous total_count not supported")
    return dist.Multinomial(int(total_count), probs=probs)
def test_to_data_error():
    # Free variables cannot be converted to data; both real and bint
    # variables must raise.
    for bad_funsor in (Variable('x', reals()), Variable('y', bint(12))):
        with pytest.raises(ValueError):
            to_data(bad_funsor)
def test_to_data():
    # A scalar Number funsor converts to a plain Python float with the
    # same value.
    result = to_data(Number(0.))
    assert type(result) == type(0.)
    assert result == 0.