def distribution_to_data(funsor_dist, name_to_dim=None):
    """Convert a funsor distribution term into a backend distribution object.

    Every AST field of ``funsor_dist`` except ``value`` is converted with
    :func:`to_data` and passed to ``funsor_dist.dist_class`` as a keyword
    argument of the same name. The result is widened with ``to_event`` so its
    event rank matches the funsor value's output shape. If ``funsor_dist.value``
    is not a plain :class:`Variable`, the value expression is inverted with
    ``funsor.delta.solve`` and the distribution is wrapped in the backend's
    ``TransformedDistribution`` (torch backend only).

    :param funsor_dist: a funsor distribution term.
    :param name_to_dim: optional mapping from input names to batch dims,
        forwarded to :func:`to_data`.
    :returns: the corresponding backend distribution.
    :raises NotImplementedError: if ``value`` is transformed and the active
        backend is not torch.
    :raises ValueError: if the final event shape does not equal the funsor
        value's output shape.
    """
    # Key each converted parameter by its own field name instead of assuming
    # that 'value' is the last entry of ``_ast_fields``.
    params = {
        param_name: to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim)
        for param_name in funsor_dist._ast_fields
        if param_name != 'value'
    }
    pyro_dist = funsor_dist.dist_class(**params)
    funsor_event_shape = funsor_dist.value.output.shape
    # Raise the backend event rank up to the funsor value's rank (never lower it).
    pyro_dist = pyro_dist.to_event(
        max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0))

    # TODO get this working for all backends
    if not isinstance(funsor_dist.value, Variable):
        if get_backend() != "torch":
            raise NotImplementedError(
                "transformed distributions not yet supported under this backend, "
                "try set_backend('torch')")
        # The second element returned by solve is used as the inverse of the
        # value expression, expressed in terms of a fresh "value" variable.
        inv_value = funsor.delta.solve(
            funsor_dist.value, Variable("value", funsor_dist.value.output))[1]
        transforms = to_data(inv_value, name_to_dim=name_to_dim)
        backend_dist = import_module(
            BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
        pyro_dist = backend_dist.TransformedDistribution(pyro_dist, transforms)

    if pyro_dist.event_shape != funsor_event_shape:
        raise ValueError("Event shapes don't match, something went wrong")
    return pyro_dist
def backend_to_einsum_backends(backend):
    """Return the einsum backends to parametrize tests over for ``backend``.

    The list contains the einsum backend, the logsumexp backend and the map
    backend registered for ``backend``. Under ``"jax"`` the map backend is
    wrapped in an xfail ``pytest.param`` because it cannot set ``_pyro_dims``
    on a ``DeviceArray``.

    :param str backend: backend name (e.g. ``"jax"``); used for every lookup
        so the result is consistent with the jax special case below.
    :returns: a list suitable for ``pytest.mark.parametrize``.
    """
    backends = [
        BACKEND_TO_EINSUM_BACKEND[backend],
        BACKEND_TO_LOGSUMEXP_BACKEND[backend],
    ]
    map_backend = BACKEND_TO_MAP_BACKEND[backend]
    if backend == "jax":
        map_backend = pytest.param(
            map_backend,
            marks=pytest.mark.xfail(
                reason="Can't set attribute '_pyro_dims' to DeviceArray"))
    backends.append(map_backend)
    return backends
def test_delta_defaults():
    """Check the defaults of ``dist.Delta``.

    ``dist.Delta(v, log_density)`` must be an instance of the active backend's
    ``Delta`` class. Passing the string ``'value'`` must give the identical term
    as passing ``Variable('value', Real)``.
    """
    v = Variable('v', Real)
    log_density = Variable('log_density', Real)
    module_name = BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]
    default_delta = dist.Delta(v, log_density)
    expected_cls = import_module(module_name).Delta
    assert isinstance(default_delta, expected_cls)

    value_var = Variable('value', Real)
    by_name = dist.Delta(v, log_density, 'value')
    assert by_name is dist.Delta(v, log_density, value_var)
def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs):
    """Check the inputs of samples drawn from a random Gaussian.

    Every subset of the event variables is sampled in turn:

    * Sampling all of them must produce a ``Delta`` or ``Contraction``.
    * Sampling a non-empty subset must produce a funsor whose inputs are the
      sample, batch and event inputs combined.
    * Sampling nothing must return ``x`` itself.

    Any ``NotImplementedError`` from ``sample`` marks the test xfail after the
    loop has finished.
    """
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs)
    sample_inputs = OrderedDict(sample_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    x = random_gaussian(be_inputs)

    # torch samples without a key; other backends thread an explicit jax PRNG key.
    if get_backend() == "torch":
        rng_key = subkey = None
    else:
        rng_key = subkey = np.array([0, 0], dtype=np.uint32)

    saw_not_implemented = False
    event_names = list(event_inputs)
    num_events = len(event_names)
    for num_sampled in range(num_events + 1):
        for names in itertools.combinations(event_names, num_sampled):
            sampled_vars = frozenset(names)
            print('sampled_vars: {}'.format(', '.join(sampled_vars)))
            try:
                if rng_key is not None:
                    import jax
                    rng_key, subkey = jax.random.split(rng_key)
                y = x.sample(sampled_vars, sample_inputs, rng_key=subkey)
            except NotImplementedError:
                saw_not_implemented = True
                continue
            if num_sampled == num_events:
                assert isinstance(y, (Delta, Contraction))
            if not sampled_vars:
                assert y is x
            else:
                assert dict(y.inputs) == dict(expected_inputs), sampled_vars
    if saw_not_implemented:
        pytest.xfail(reason='Not implemented')
def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs):
    """Check the inputs of samples drawn from a joint ``Tensor + Gaussian``.

    Only proper subsets of the event variables are sampled
    (``num_sampled < len(event_inputs)``). A non-empty sample must carry the
    sample and event inputs; an empty one must return ``x`` itself. Any
    ``NotImplementedError`` from ``sample`` marks the test xfail after the loop.
    """
    event_inputs = int_event_inputs + real_event_inputs
    discrete_inputs = OrderedDict(int_event_inputs)
    gaussian_inputs = OrderedDict(event_inputs)
    expected_inputs = OrderedDict(sample_inputs + event_inputs)
    sample_inputs = OrderedDict(sample_inputs)
    event_inputs = OrderedDict(event_inputs)
    t = random_tensor(discrete_inputs)
    g = random_gaussian(gaussian_inputs)
    x = t + g  # Joint(discrete=t, gaussian=g)

    # torch samples without a key; other backends thread an explicit jax PRNG key.
    if get_backend() == "torch":
        rng_key = subkey = None
    else:
        rng_key = subkey = np.array([0, 0], dtype=np.uint32)

    saw_not_implemented = False
    event_names = list(event_inputs)
    for num_sampled in range(len(event_names)):
        for names in itertools.combinations(event_names, num_sampled):
            sampled_vars = frozenset(names)
            print('sampled_vars: {}'.format(', '.join(sampled_vars)))
            try:
                if rng_key is not None:
                    import jax
                    rng_key, subkey = jax.random.split(rng_key)
                y = x.sample(sampled_vars, sample_inputs, rng_key=subkey)
            except NotImplementedError:
                saw_not_implemented = True
                continue
            if not sampled_vars:
                assert y is x
            else:
                assert dict(y.inputs) == dict(expected_inputs), sampled_vars
    if saw_not_implemented:
        pytest.xfail(reason='Not implemented')
def _numeric_max_and_argmax(x):
    """Return ``(max, argmax)`` of ``x`` along its last axis.

    Under torch this is the result of ``torch.max(x, dim=-1)``. Otherwise it is
    a plain tuple of ``np.max`` and ``np.argmax`` over ``axis=-1``.
    """
    if get_backend() != "torch":
        return np.max(x, axis=-1), np.argmax(x, axis=-1)
    import torch
    return torch.max(x, dim=-1)
def test_gaussian_distribution(event_inputs, batch_inputs):
    """Compare moments of a random Gaussian with a Monte Carlo estimate.

    Draws 100000 ``particle`` samples of all event variables from ``p``. It then
    checks the zeroth moment (total mass via logaddexp) and the first moment of
    each event variable against exact integration under ``p``.
    """
    num_samples = 100000
    sample_inputs = OrderedDict(particle=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)
    p = random_gaussian(be_inputs)

    # torch samples without a key; other backends need an explicit PRNG key.
    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
    q = p.sample(sampled_vars, sample_inputs, rng_key=rng_key)
    p_vars = sampled_vars
    q_vars = sampled_vars | frozenset(['particle'])

    # Check zeroth moment.
    assert_close(q.reduce(ops.logaddexp, q_vars),
                 p.reduce(ops.logaddexp, p_vars), atol=1e-6)

    for k1, d1 in event_inputs.items():
        x = Variable(k1, d1)
        # Check first moments.
        assert_close(Integrate(q, x, q_vars),
                     Integrate(p, x, p_vars), atol=0.5, rtol=0.2)

    # FIXME: Second moments (Integrate(q, x * y, q_vars) vs. Integrate(p, x * y,
    # p_vars) for each pair of event variables) are not checked because
    # quadratic integration is not supported yet.
def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
    """Sample this distribution's value variable.

    Returns ``self`` unchanged when the value name is not in ``sampled_vars``.
    Otherwise it:

    1. Draws from the raw backend distribution with a sample shape given by the
       sizes of ``sample_inputs``, using ``rsample`` when ``self.has_rsample``
       and a detached ``sample`` otherwise.
    2. Wraps the draw as ``funsor.delta.Delta(value_name, ...)``.
    3. For non-reparametrized distributions, adds the factor
       ``log_prob - detach(log_prob)``.
    """
    # note this should handle transforms correctly via distribution_to_data
    raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
    num_sample_dims = len(sample_inputs)
    batch_ndims = len(raw_dist.batch_shape)
    # Sample dims occupy the positions just left of the raw batch dims,
    # in the order the sample inputs are listed.
    for position, sample_name in enumerate(sample_inputs.keys()):
        dim_to_name[position - num_sample_dims - batch_ndims] = sample_name

    if value_name not in sampled_vars:
        return self

    sample_shape = tuple(domain.size for domain in sample_inputs.values())
    if get_backend() == "torch":
        sample_args = (sample_shape,)
    else:
        sample_args = (rng_key, sample_shape)

    if self.has_rsample:
        raw_value = raw_dist.rsample(*sample_args)
    else:
        raw_value = ops.detach(raw_dist.sample(*sample_args))

    funsor_value = to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
    aligned_names = tuple(sample_inputs) + tuple(
        inp for inp in self.inputs if inp in funsor_value.inputs)
    funsor_value = funsor_value.align(aligned_names)
    result = funsor.delta.Delta(value_name, funsor_value)

    if self.has_rsample:
        return result
    # scaling of dice_factor by num samples should already be handled by Funsor.sample
    raw_log_prob = raw_dist.log_prob(raw_value)
    dice_factor = to_funsor(raw_log_prob - ops.detach(raw_log_prob),
                            output=self.output, dim_to_name=dim_to_name)
    return result + dice_factor
def test_reduce_logaddexp(int_inputs, real_inputs):
    """Check that logaddexp over Delta-pinned real variables evaluates ``g`` at them.

    Builds ``t + g + sum(Delta(name, truth[name]))`` and reduces it with
    logaddexp over the real variables. The result must equal
    ``t + g(**truth)``.
    """
    int_inputs = OrderedDict(sorted(int_inputs.items()))
    real_inputs = OrderedDict(sorted(real_inputs.items()))
    inputs = OrderedDict(int_inputs)
    inputs.update(real_inputs)

    t = random_tensor(int_inputs)
    g = random_gaussian(inputs)
    truth = {}
    for name, domain in real_inputs.items():
        truth[name] = random_tensor(int_inputs, domain)

    state = 0
    state += g
    state += t
    for name, point in truth.items():
        with xfail_if_not_implemented():
            state += Delta(name, point)
    actual = state.reduce(ops.logaddexp, frozenset(truth))

    expected = t + g(**truth)
    # jax needs a looser relative tolerance than the other backends.
    rtol = 1e-4 if get_backend() == "jax" else 1e-5
    assert_close(actual, expected, atol=1e-5, rtol=rtol)
def ignore_jit_warnings():
    """Generator that yields once with ``torch.jit.TracerWarning`` silenced.

    The ignore filter is only installed when the active backend is torch. It is
    registered inside ``warnings.catch_warnings()``, so the previous filter
    state is restored once the generator resumes past its ``yield``.
    """
    with warnings.catch_warnings():
        silence_tracer = get_backend() == "torch"
        if silence_tracer:
            import torch
            tracer_warning = torch.jit.TracerWarning
            warnings.filterwarnings("ignore", category=tracer_warning)
        yield
def test_normalize_einsum(equation, plates, backend, einsum_impl):
    """Check that normalizing a lazily built einsum behaves correctly.

    The normalized form must:

    * be a ``Contraction`` of Number/Tensor/Contraction terms;
    * be returned unchanged by a second normalization pass;
    * agree with eager evaluation of the original expression;
    * survive a quote/eval round trip.
    """
    if get_backend() == "torch":
        # Bound as a local on purpose: eval(quote(...)) below resolves names in
        # this frame, and the quoted source may refer to ``torch``.
        import torch  # noqa: F401

    _, _, _, _, funsor_operands = make_einsum_example(equation)

    with interpretation(reflect):
        lazy_expr = einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with interpretation(normalize):
        normalized = reinterpret(lazy_expr)

    assert isinstance(normalized, Contraction)
    check_funsor(normalized, lazy_expr.inputs, lazy_expr.output)
    assert all(isinstance(term, (Number, Tensor, Contraction)) for term in normalized.terms)

    # A second normalization pass must return the very same object.
    with interpretation(normalize):
        renormalized = reinterpret(normalized)
    assert renormalized is normalized

    with interpretation(eager):
        actual = reinterpret(normalized)
        expected = reinterpret(lazy_expr)
    assert_close(actual, expected, rtol=1e-4)

    roundtrip = eval(quote(expected))  # requires torch, bint
    assert_close(roundtrip, expected)
def test_tensor_shape(sample_inputs, batch_inputs, event_inputs):
    """Check the inputs of samples drawn from a random Tensor.

    Every subset of the event variables is sampled in turn:

    * Sampling all of them must produce a ``Delta`` or ``Contraction``.
    * Sampling a non-empty subset must produce a funsor whose inputs are the
      sample, batch and event inputs combined.
    * Sampling nothing must return ``x`` itself.
    """
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs)
    sample_inputs = OrderedDict(sample_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    x = random_tensor(be_inputs)

    # torch samples without a key; other backends thread an explicit jax PRNG key.
    if get_backend() == "torch":
        rng_key = subkey = None
    else:
        rng_key = subkey = np.array([0, 0], dtype=np.uint32)

    event_names = list(event_inputs)
    num_events = len(event_names)
    for num_sampled in range(num_events + 1):
        for names in itertools.combinations(event_names, num_sampled):
            sampled_vars = frozenset(names)
            print('sampled_vars: {}'.format(', '.join(sampled_vars)))
            if rng_key is not None:
                import jax
                rng_key, subkey = jax.random.split(rng_key)
            y = x.sample(sampled_vars, sample_inputs, rng_key=subkey)
            if num_sampled == num_events:
                assert isinstance(y, (Delta, Contraction))
            if not sampled_vars:
                assert y is x
            else:
                assert dict(y.inputs) == dict(expected_inputs), sampled_vars
def __getitem__(cls, dtype_shape):
    """Return the memoized array type for a ``(dtype, shape)`` pair.

    A ``dtype`` of ``"real"`` produces a ``RealsType``; every shape entry must
    then be a nonnegative int. An int ``dtype`` (which must be nonnegative)
    produces a ``BintType``. Results are cached in ``ArrayType._type_cache``,
    so repeated lookups return the same class object.

    :raises ValueError: if ``dtype`` is neither ``"real"`` nor an int.
    """
    dtype, shape = dtype_shape
    assert dtype is not None
    assert shape is not None

    # in some JAX versions, shape can be np.int64 type
    coerce_to_python_ints = get_tracing_state() or get_backend() == "jax"
    if coerce_to_python_ints:
        if dtype not in (None, "real"):
            dtype = int(dtype)
        if shape is not None:
            shape = tuple(int(size) for size in shape)

    assert cls.dtype in (None, dtype)
    assert cls.shape in (None, shape)

    key = (dtype, shape)
    cached = ArrayType._type_cache.get(key)
    if cached is not None:
        return cached

    shape_str = ",".join(map(str, shape))
    if dtype == "real":
        assert all(isinstance(size, int) and size >= 0 for size in shape)
        name = "Reals[{}]".format(shape_str) if shape else "Real"
        result = RealsType(name, (), {"shape": shape})
    elif isinstance(dtype, int):
        assert dtype >= 0
        name = "Bint[{}, {}]".format(dtype, shape_str)
        result = BintType(name, (), {"dtype": dtype, "shape": shape})
    else:
        raise ValueError("invalid dtype: {}".format(dtype))
    ArrayType._type_cache[key] = result
    return result
def _numeric_tensordot(x, y, dim):
    """Contract ``x`` and ``y`` with the active backend's ``tensordot``.

    ``dim`` is passed positionally to ``torch.tensordot``, or as ``axes=`` to
    ``np.tensordot``.
    """
    if get_backend() != "torch":
        return np.tensordot(x, y, axes=dim)
    import torch
    return torch.tensordot(x, y, dim)
def eager_binomial(total_count, probs, value):
    """Rewrite a binomial as a two-category backend ``Multinomial``.

    The category probabilities are ``(1 - probs, probs)``. The counts are
    ``(total_count - value, value)``.
    """
    two_class_probs = stack((1 - probs, probs))
    two_class_value = stack((total_count - value, value))
    dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return dist_module.Multinomial(total_count, two_class_probs, value=two_class_value)
def _check_mvn_affine(d1, data):
    """Check a backend MultivariateNormal against its reinterpretation.

    ``d1`` must be an instance of the backend's ``MultivariateNormal``. Its
    reinterpretation must be a ``GaussianMixture`` subclass, and both must
    agree when evaluated on ``data``.
    """
    dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    assert isinstance(d1, dist_module.MultivariateNormal)
    d2 = reinterpret(d1)
    assert issubclass(type(d2), GaussianMixture)
    assert_close(d2(**data), d1(**data))
def gaussianmixture_to_data(funsor_dist, name_to_dim=None):
    """Split a Gaussian mixture funsor into backend data.

    Returns a pair ``(cat, mvn)``:

    * ``cat`` is a backend categorical built from logits
      ``discrete + gaussian.log_normalizer``.
    * ``mvn`` is ``to_data`` of the Gaussian term.
    """
    discrete, gaussian = funsor_dist.terms
    dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    mixture_logits = to_data(discrete + gaussian.log_normalizer, name_to_dim=name_to_dim)
    cat = dist_module.CategoricalLogits.dist_class(logits=mixture_logits)
    mvn = to_data(gaussian, name_to_dim=name_to_dim)
    return cat, mvn
def numeric_array(x, dtype=None, device=None):
    """Create an array of the active backend from ``x``.

    :param x: array-like input data.
    :param dtype: optional element type, forwarded to ``torch.tensor`` or
        ``np.array``. The default ``None`` lets the backend infer it.
    :param device: optional torch device; ignored for non-torch backends.
    :returns: a ``torch.Tensor`` under torch, otherwise a numpy array.
    """
    backend = get_backend()
    if backend == "torch":
        import torch
        return torch.tensor(x, dtype=dtype, device=device)
    return np.array(x, dtype=dtype)
def numeric_array(x, dtype=None, device=None):
    """Create an array of the active backend from ``x``.

    ``dtype`` is forwarded to ``torch.tensor`` or ``np.array``. ``device`` is
    only used by torch.
    """
    if get_backend() != "torch":
        # numpy arrays have no device; ``device`` is intentionally dropped here.
        return np.array(x, dtype=dtype)
    import torch
    return torch.tensor(x, dtype=dtype, device=device)
def get_default_prototype():
    """Return an empty array of the active backend (``torch.tensor([])`` or ``np.array([])``)."""
    if get_backend() != "torch":
        return np.array([])
    import torch
    return torch.tensor([])
def randint(low, high, size):
    """Draw random integers in ``[low, high)`` of shape ``size`` with the active backend."""
    if get_backend() != "torch":
        return np.random.randint(low, high, size=size)
    import torch
    return torch.randint(low, high, size=size)
def astype(x, dtype):
    """Cast ``x`` to ``dtype`` using the active backend.

    :param x: a backend array.
    :param dtype: the target type.

        * Non-torch backends: passed straight to ``x.astype``.
        * torch: ``'uint8'`` maps to ``x.byte()``. A string naming a
          ``torch.dtype`` attribute (e.g. ``'float32'``, ``'int64'``) is mapped
          to that dtype. Anything else (a ``torch.dtype`` or a dotted
          tensor-type name such as ``'torch.FloatTensor'``) goes to ``x.type``.

    NOTE: bare names follow torch's meaning under torch (``'float'`` is
    ``torch.float32``), which can differ from numpy's (``float64``).
    """
    backend = get_backend()
    if backend == "torch":
        if dtype == 'uint8':
            return x.byte()
        if isinstance(dtype, str):
            import torch
            # Tensor.type() cannot parse numpy-style names like 'float32' (it
            # expects a dotted type name), so resolve them to a torch.dtype.
            torch_dtype = getattr(torch, dtype, None)
            if isinstance(torch_dtype, torch.dtype):
                return x.to(torch_dtype)
        return x.type(dtype)
    return x.astype(dtype)
def _version_tuple(version):
    """Return the leading numeric release components of a version string.

    Parsing stops at the first dot-separated piece that is not purely numeric,
    keeping that piece's numeric prefix if it has one. For example ``"0.10.0"``
    gives ``(0, 10, 0)`` and ``"0.3.0rc1"`` gives ``(0, 3, 0)``.
    """
    import re

    parts = []
    for piece in str(version).split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):  # e.g. "0rc1": keep 0, then stop
            break
    return tuple(parts)


def _skip_for_numpyro_version(version="0.2.4"):
    """Return True if the jax backend is active and numpyro is at or below ``version``.

    Versions are compared numerically, component by component. Plain string
    comparison would wrongly treat ``"0.10.0"`` as older than ``"0.2.4"``.
    """
    if get_backend() == "jax":
        import numpyro
        if _version_tuple(numpyro.__version__) <= _version_tuple(version):
            return True
    return False
def __init__(self, raw_dist, raw_params, expected_value_domain, xfail_reason=""):
    """Record a distribution test case and register it in ``TEST_CASES``.

    :param raw_dist: the raw distribution under test.
    :param raw_params: iterable of ``(name, source_string)`` pairs.
    :param expected_value_domain: the domain the value is expected to have.
    :param xfail_reason: if non-empty, the case is registered wrapped in
        ``xfail_param`` with this reason.
    """
    self.raw_dist = raw_dist
    self.raw_params = raw_params
    self.expected_value_domain = expected_value_domain
    # NOTE: eval() below runs in this frame, so no extra locals are introduced
    # here that could shadow names used by the parameter source strings.
    for name, raw_param in self.raw_params:
        if get_backend() == "numpy":
            continue
        # we need direct access to these tensors for gradient tests
        setattr(self, name, eval(raw_param))
    TEST_CASES.append(xfail_param(self, reason=xfail_reason) if xfail_reason else self)
def _random_scale_tril(shape):
    """Sample a random lower-Cholesky matrix batch of the given ``shape``.

    * torch: ``randn(shape)`` is mapped through
      ``transforms.transform_to(constraints.lower_cholesky)``.
    * Other backends: noise with ``n * (n + 1) // 2`` trailing entries
      (``n = shape[-1]``) is mapped through
      ``biject_to(constraints.lower_cholesky)``.
    """
    if get_backend() != "torch":
        n = shape[-1]
        packed = randn(shape[:-2] + (n * (n + 1) // 2,))
        to_tril = backend_dist.biject_to(backend_dist.constraints.lower_cholesky)
        return to_tril(packed)
    noise = randn(shape)
    to_tril = backend_dist.transforms.transform_to(backend_dist.constraints.lower_cholesky)
    return to_tril(noise)
def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, y):
    """Collapse ``x`` and ``y`` into a backend ``DirichletMultinomial`` when possible.

    If any of ``x``'s inputs are being reduced, the result is built from
    ``x.concentration`` and from ``y.total_count`` / ``y.value``. Presumably
    ``x`` is a Dirichlet and ``y`` a multinomial — confirm against the
    registration. Otherwise the generic eager ``Contraction`` rule is used.
    """
    if not frozenset(x.inputs).intersection(reduced_vars):
        return eager(Contraction, red_op, bin_op, reduced_vars, (x, y))
    dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return dist_module.DirichletMultinomial(concentration=x.concentration,
                                            total_count=y.total_count,
                                            value=y.value)
def eager_gamma_poisson(red_op, bin_op, reduced_vars, x, y):
    """Collapse ``x`` and ``y`` into a backend ``GammaPoisson`` when possible.

    If any of ``x``'s inputs are being reduced, the result is built from
    ``x.concentration``, ``x.rate`` and ``y.value``. Presumably ``x`` is a
    Gamma and ``y`` a Poisson — confirm against the registration. Otherwise the
    generic eager ``Contraction`` rule is used.
    """
    if not frozenset(x.inputs).intersection(reduced_vars):
        return eager(Contraction, red_op, bin_op, reduced_vars, (x, y))
    dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return dist_module.GammaPoisson(concentration=x.concentration,
                                    rate=x.rate,
                                    value=y.value)
def eager_dirichlet_posterior(op, c, z):
    """Return a Dirichlet posterior for matching terms, else ``None``.

    The rule applies only when ``z`` shares, by identity, its ``concentration``
    with ``c.terms[0]`` and its ``total_count`` with ``c.terms[1]``. The result
    is a backend ``Dirichlet`` with concentration
    ``z.concentration + c.terms[1].value``, evaluated at ``c.terms[0].value``.
    """
    prior, likelihood = c.terms[0], c.terms[1]
    if z.concentration is not prior.concentration or likelihood.total_count is not z.total_count:
        return None
    dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return dist_module.Dirichlet(concentration=z.concentration + likelihood.value,
                                 value=prior.value)
def random_scale_tril(*args):
    """Sample a random lower-Cholesky matrix batch.

    The shape may be given as a single tuple (``random_scale_tril((2, 3, 3))``)
    or as separate ints (``random_scale_tril(2, 3, 3)``).

    * torch: ``randn(shape)`` goes through ``transforms.transform_to``.
    * Other backends: ``n * (n + 1) // 2`` trailing noise entries
      (``n = shape[-1]``) go through ``biject_to``.

    Both targets are ``constraints.lower_cholesky``.
    """
    if isinstance(args[0], tuple):
        assert len(args) == 1
        shape = args[0]
    else:
        shape = args

    # Imported here rather than at module level, presumably to avoid an import cycle.
    from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
    backend_name = get_backend()
    backend_dist = importlib.import_module(
        BACKEND_TO_DISTRIBUTIONS_BACKEND[backend_name]).dist

    if backend_name != "torch":
        n = shape[-1]
        packed = randn(shape[:-2] + (n * (n + 1) // 2,))
        return backend_dist.biject_to(backend_dist.constraints.lower_cholesky)(packed)
    noise = randn(shape)
    return backend_dist.transforms.transform_to(
        backend_dist.constraints.lower_cholesky)(noise)
def test_quote(output_shape, inputs):
    """``funsor.quote`` of a random Tensor is a string that evals back to an equal Tensor."""
    if get_backend() == "torch":
        # Bound as a local on purpose: eval() below resolves names in this
        # frame, and the quoted source may refer to ``torch``.
        import torch  # noqa: F401
    size_of = {'a': 4, 'b': 5, 'c': 6}
    inputs = OrderedDict((name, bint(size_of[name])) for name in inputs)
    original = random_tensor(inputs, reals(*output_shape))
    source = funsor.quote(original)
    assert isinstance(source, str)
    assert_close(eval(source), original)