Example #1
def distribution_to_data(funsor_dist, name_to_dim=None):
    params = [
        to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim)
        for param_name in funsor_dist._ast_fields if param_name != 'value'
    ]
    pyro_dist = funsor_dist.dist_class(
        **dict(zip(funsor_dist._ast_fields[:-1], params)))
    funsor_event_shape = funsor_dist.value.output.shape
    pyro_dist = pyro_dist.to_event(
        max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0))

    # TODO get this working for all backends
    if not isinstance(funsor_dist.value, Variable):
        if get_backend() != "torch":
            raise NotImplementedError(
                "transformed distributions not yet supported under this backend, "
                "try set_backend('torch')")
        inv_value = funsor.delta.solve(
            funsor_dist.value, Variable("value", funsor_dist.value.output))[1]
        transforms = to_data(inv_value, name_to_dim=name_to_dim)
        backend_dist = import_module(
            BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
        pyro_dist = backend_dist.TransformedDistribution(pyro_dist, transforms)

    if pyro_dist.event_shape != funsor_event_shape:
        raise ValueError("Event shapes don't match, something went wrong")
    return pyro_dist
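What the to_event call above does, in a minimal sketch assuming the torch backend with Pyro installed: it reinterprets trailing batch dimensions as event dimensions so the backend distribution's event shape matches the funsor value's output shape.

import torch
import pyro.distributions as dist

d = dist.Normal(torch.zeros(3, 4), torch.ones(3, 4))
assert d.batch_shape == (3, 4) and d.event_shape == ()
d2 = d.to_event(1)  # reinterpret the rightmost batch dim as an event dim
assert d2.batch_shape == (3,) and d2.event_shape == (4,)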
Example #2
def backend_to_einsum_backends(backend):
    backends = [
        BACKEND_TO_EINSUM_BACKEND[get_backend()],
        BACKEND_TO_LOGSUMEXP_BACKEND[get_backend()]
    ]
    map_backend = BACKEND_TO_MAP_BACKEND[get_backend()]
    if backend == "jax":
        map_backend = pytest.param(
            map_backend,
            marks=pytest.mark.xfail(
                reason="Can't set attribute '_pyro_dims' to DeviceArray"))
    backends.append(map_backend)
    return backends
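The pytest.param-with-xfail pattern used above, in isolation; the parameter names here are illustrative only.

import pytest

@pytest.mark.parametrize("map_backend", [
    "einsum",
    pytest.param("map", marks=pytest.mark.xfail(reason="not supported")),
])
def test_map_backend(map_backend):
    assert map_backend == "einsum"  # the "map" case is an expected failure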
Example #3
def test_delta_defaults():
    v = Variable('v', Real)
    log_density = Variable('log_density', Real)
    backend_dist_module = BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]
    assert isinstance(dist.Delta(v, log_density), import_module(backend_dist_module).Delta)
    value = Variable('value', Real)
    assert dist.Delta(v, log_density, 'value') is dist.Delta(v, log_density, value)
Example #4
def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs):
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs)
    sample_inputs = OrderedDict(sample_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    x = random_gaussian(be_inputs)
    rng_key = subkey = None if get_backend() == "torch" else np.array(
        [0, 0], dtype=np.uint32)

    xfail = False
    for num_sampled in range(len(event_inputs) + 1):
        for sampled_vars in itertools.combinations(list(event_inputs),
                                                   num_sampled):
            sampled_vars = frozenset(sampled_vars)
            print('sampled_vars: {}'.format(', '.join(sampled_vars)))
            try:
                if rng_key is not None:
                    import jax
                    rng_key, subkey = jax.random.split(rng_key)

                y = x.sample(sampled_vars, sample_inputs, rng_key=subkey)
            except NotImplementedError:
                xfail = True
                continue
            if num_sampled == len(event_inputs):
                assert isinstance(y, (Delta, Contraction))
            if sampled_vars:
                assert dict(y.inputs) == dict(expected_inputs), sampled_vars
            else:
                assert y is x
    if xfail:
        pytest.xfail(reason='Not implemented')
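The double loop above enumerates every subset of the event variables; here is that enumeration in isolation:

import itertools

event_vars = ["i", "j"]
subsets = [frozenset(c)
           for n in range(len(event_vars) + 1)
           for c in itertools.combinations(event_vars, n)]
# frozenset(), {'i'}, {'j'}, {'i', 'j'}
assert len(subsets) == 2 ** len(event_vars)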
Example #5
def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs):
    event_inputs = int_event_inputs + real_event_inputs
    discrete_inputs = OrderedDict(int_event_inputs)
    gaussian_inputs = OrderedDict(event_inputs)
    expected_inputs = OrderedDict(sample_inputs + event_inputs)
    sample_inputs = OrderedDict(sample_inputs)
    event_inputs = OrderedDict(event_inputs)
    t = random_tensor(discrete_inputs)
    g = random_gaussian(gaussian_inputs)
    x = t + g  # Joint(discrete=t, gaussian=g)

    rng_key = subkey = None if get_backend() == "torch" else np.array(
        [0, 0], dtype=np.uint32)
    xfail = False
    for num_sampled in range(len(event_inputs)):
        for sampled_vars in itertools.combinations(list(event_inputs),
                                                   num_sampled):
            sampled_vars = frozenset(sampled_vars)
            print('sampled_vars: {}'.format(', '.join(sampled_vars)))
            try:
                if rng_key is not None:
                    import jax
                    rng_key, subkey = jax.random.split(rng_key)

                y = x.sample(sampled_vars, sample_inputs, rng_key=subkey)
            except NotImplementedError:
                xfail = True
                continue
            if sampled_vars:
                assert dict(y.inputs) == dict(expected_inputs), sampled_vars
            else:
                assert y is x
    if xfail:
        pytest.xfail(reason='Not implemented')
Example #6
def _numeric_max_and_argmax(x):
    if get_backend() == "torch":
        import torch

        return torch.max(x, dim=-1)
    else:
        return np.max(x, axis=-1), np.argmax(x, axis=-1)
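A standalone illustration of the (values, indices) pair the numpy branch returns:

import numpy as np

x = np.array([[0.1, 2.0, -1.0],
              [3.0, 0.5, 0.2]])
values, indices = np.max(x, axis=-1), np.argmax(x, axis=-1)
assert (values == np.array([2.0, 3.0])).all()
assert (indices == np.array([1, 0])).all()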
Example #7
def test_gaussian_distribution(event_inputs, batch_inputs):
    num_samples = 100000
    sample_inputs = OrderedDict(particle=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)
    p = random_gaussian(be_inputs)

    rng_key = None if get_backend() == "torch" else np.array([0, 0],
                                                             dtype=np.uint32)
    q = p.sample(sampled_vars, sample_inputs, rng_key=rng_key)
    p_vars = sampled_vars
    q_vars = sampled_vars | frozenset(['particle'])
    # Check zeroth moment.
    assert_close(q.reduce(ops.logaddexp, q_vars),
                 p.reduce(ops.logaddexp, p_vars),
                 atol=1e-6)
    for k1, d1 in event_inputs.items():
        x = Variable(k1, d1)
        # Check first moments.
        assert_close(Integrate(q, x, q_vars),
                     Integrate(p, x, p_vars),
                     atol=0.5,
                     rtol=0.2)
        for k2, d2 in event_inputs.items():
            y = Variable(k2, d2)
            # Check second moments.
            continue  # FIXME: quadratic integration is not yet supported, so the second-moment check below is skipped
            assert_close(Integrate(q, x * y, q_vars),
                         Integrate(p, x * y, p_vars),
                         atol=1e-2)
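The moment check in miniature: a Monte Carlo estimate of a Gaussian's first moment approaches its mean at rate 1/sqrt(num_samples), which is why the test can use loose tolerances. A numpy-only sketch:

import numpy as np

rng = np.random.default_rng(0)
samples = rng.normal(loc=1.5, scale=2.0, size=100_000)
# standard error is 2 / sqrt(1e5) ~ 0.006, so 0.05 is a safe tolerance
assert abs(samples.mean() - 1.5) < 0.05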
Example #8
    def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):

        # note this should handle transforms correctly via distribution_to_data
        raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
        for d, name in zip(range(len(sample_inputs), 0, -1),
                           sample_inputs.keys()):
            dim_to_name[-d - len(raw_dist.batch_shape)] = name

        if value_name not in sampled_vars:
            return self

        sample_shape = tuple(v.size for v in sample_inputs.values())
        sample_args = (sample_shape, ) if get_backend() == "torch" else (
            rng_key, sample_shape)
        if self.has_rsample:
            raw_value = raw_dist.rsample(*sample_args)
        else:
            raw_value = ops.detach(raw_dist.sample(*sample_args))

        funsor_value = to_funsor(raw_value,
                                 output=value_output,
                                 dim_to_name=dim_to_name)
        funsor_value = funsor_value.align(
            tuple(sample_inputs) +
            tuple(inp for inp in self.inputs if inp in funsor_value.inputs))
        result = funsor.delta.Delta(value_name, funsor_value)
        if not self.has_rsample:
            # scaling of dice_factor by num samples should already be handled by Funsor.sample
            raw_log_prob = raw_dist.log_prob(raw_value)
            dice_factor = to_funsor(raw_log_prob - ops.detach(raw_log_prob),
                                    output=self.output,
                                    dim_to_name=dim_to_name)
            result = result + dice_factor
        return result
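The dice factor trick above, in a self-contained torch sketch (assuming the torch backend): log_prob - detach(log_prob) is numerically zero but still propagates gradients to the distribution's parameters.

import torch

logits = torch.tensor([0.0, 1.0], requires_grad=True)
d = torch.distributions.Categorical(logits=logits)
value = d.sample()
log_prob = d.log_prob(value)
dice = log_prob - log_prob.detach()
assert dice.item() == 0.0       # numerically zero...
dice.backward()
assert logits.grad is not None  # ...but gradients still flow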
Example #9
def test_reduce_logaddexp(int_inputs, real_inputs):
    int_inputs = OrderedDict(sorted(int_inputs.items()))
    real_inputs = OrderedDict(sorted(real_inputs.items()))
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    t = random_tensor(int_inputs)
    g = random_gaussian(inputs)
    truth = {
        name: random_tensor(int_inputs, domain)
        for name, domain in real_inputs.items()
    }

    state = 0
    state += g
    state += t
    for name, point in truth.items():
        with xfail_if_not_implemented():
            state += Delta(name, point)
    actual = state.reduce(ops.logaddexp, frozenset(truth))

    expected = t + g(**truth)
    assert_close(actual,
                 expected,
                 atol=1e-5,
                 rtol=1e-4 if get_backend() == "jax" else 1e-5)
Example #10
@contextmanager
def ignore_jit_warnings():
    with warnings.catch_warnings():
        if get_backend() == "torch":
            import torch

            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        yield
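The same generator-as-context-manager pattern, in a self-contained sketch with no torch dependency (the warning category here is illustrative):

import warnings
from contextlib import contextmanager

@contextmanager
def ignore_warnings(category=UserWarning):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=category)
        yield

with ignore_warnings(DeprecationWarning):
    warnings.warn("old API", DeprecationWarning)  # silently suppressed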
Example #11
def test_normalize_einsum(equation, plates, backend, einsum_impl):
    if get_backend() == "torch":
        import torch  # noqa: F401

    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)

    with interpretation(reflect):
        expr = einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with interpretation(normalize):
        transformed_expr = reinterpret(expr)

    assert isinstance(transformed_expr, Contraction)
    check_funsor(transformed_expr, expr.inputs, expr.output)

    assert all(isinstance(v, (Number, Tensor, Contraction)) for v in transformed_expr.terms)

    with interpretation(normalize):
        transformed_expr2 = reinterpret(transformed_expr)

    assert transformed_expr2 is transformed_expr  # check normalization

    with interpretation(eager):
        actual = reinterpret(transformed_expr)
        expected = reinterpret(expr)

    assert_close(actual, expected, rtol=1e-4)

    actual = eval(quote(expected))  # requires torch, bint
    assert_close(actual, expected)
Example #12
def test_tensor_shape(sample_inputs, batch_inputs, event_inputs):
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs)
    sample_inputs = OrderedDict(sample_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    x = random_tensor(be_inputs)
    rng_key = subkey = None if get_backend() == "torch" else np.array(
        [0, 0], dtype=np.uint32)

    for num_sampled in range(len(event_inputs) + 1):
        for sampled_vars in itertools.combinations(list(event_inputs),
                                                   num_sampled):
            sampled_vars = frozenset(sampled_vars)
            print('sampled_vars: {}'.format(', '.join(sampled_vars)))
            if rng_key is not None:
                import jax
                rng_key, subkey = jax.random.split(rng_key)

            y = x.sample(sampled_vars, sample_inputs, rng_key=subkey)
            if num_sampled == len(event_inputs):
                assert isinstance(y, (Delta, Contraction))
            if sampled_vars:
                assert dict(y.inputs) == dict(expected_inputs), sampled_vars
            else:
                assert y is x
Example #13
    def __getitem__(cls, dtype_shape):
        dtype, shape = dtype_shape
        assert dtype is not None
        assert shape is not None

        # in some JAX versions, shape can be np.int64 type
        if get_tracing_state() or get_backend() == "jax":
            if dtype not in (None, "real"):
                dtype = int(dtype)
            if shape is not None:
                shape = tuple(map(int, shape))

        assert cls.dtype in (None, dtype)
        assert cls.shape in (None, shape)
        key = dtype, shape
        result = ArrayType._type_cache.get(key, None)
        if result is None:
            if dtype == "real":
                assert all(
                    isinstance(size, int) and size >= 0 for size in shape)
                name = "Reals[{}]".format(",".join(map(
                    str, shape))) if shape else "Real"
                result = RealsType(name, (), {"shape": shape})
            elif isinstance(dtype, int):
                assert dtype >= 0
                name = "Bint[{}, {}]".format(dtype, ",".join(map(str, shape)))
                result = BintType(name, (), {"dtype": dtype, "shape": shape})
            else:
                raise ValueError("invalid dtype: {}".format(dtype))
            ArrayType._type_cache[key] = result
        return result
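A hypothetical use of this __getitem__, assuming it backs funsor's domain constructors (Reals and Bint, as in recent funsor releases):

from funsor.domains import Bint, Reals

assert Reals[2, 3].shape == (2, 3) and Reals[2, 3].dtype == "real"
assert Bint[5].dtype == 5 and Bint[5].shape == ()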
Example #14
def _numeric_tensordot(x, y, dim):
    if get_backend() == "torch":
        import torch

        return torch.tensordot(x, y, dim)
    else:
        return np.tensordot(x, y, axes=dim)
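The numpy branch in isolation: with axes=1, np.tensordot contracts the last axis of x against the first axis of y, matching torch.tensordot's single-integer convention.

import numpy as np

x = np.random.rand(2, 3, 4)
y = np.random.rand(4, 5)
out = np.tensordot(x, y, axes=1)
assert out.shape == (2, 3, 5)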
Example #15
def eager_binomial(total_count, probs, value):
    probs = stack((1 - probs, probs))
    value = stack((total_count - value, value))
    backend_dist = import_module(
        BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return backend_dist.Multinomial(total_count, probs,
                                    value=value)  # noqa: F821
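A numeric check of the binomial-to-multinomial reduction this pattern relies on, sketched with torch (assuming the torch backend): Binomial(n, p) at k has the same log-probability as Multinomial(n, [1-p, p]) at [n-k, k].

import torch
import torch.distributions as dist

n, p = 10, 0.3
lp_binom = dist.Binomial(n, torch.tensor(p)).log_prob(torch.tensor(4.0))
lp_multi = dist.Multinomial(n, torch.tensor([1 - p, p])).log_prob(
    torch.tensor([6.0, 4.0]))
assert torch.allclose(lp_binom, lp_multi)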
Example #16
def _check_mvn_affine(d1, data):
    backend_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    assert isinstance(d1, backend_module.MultivariateNormal)
    d2 = reinterpret(d1)
    assert issubclass(type(d2), GaussianMixture)
    actual = d2(**data)
    expected = d1(**data)
    assert_close(actual, expected)
Example #17
def gaussianmixture_to_data(funsor_dist, name_to_dim=None):
    discrete, gaussian = funsor_dist.terms
    backend_dist = import_module(
        BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    cat = backend_dist.CategoricalLogits.dist_class(logits=to_data(
        discrete + gaussian.log_normalizer, name_to_dim=name_to_dim))
    mvn = to_data(gaussian, name_to_dim=name_to_dim)
    return cat, mvn
Example #18
def numeric_array(x):
    backend = get_backend()
    if backend == "torch":
        import torch

        return torch.tensor(x)
    else:
        return np.array(x)
Example #19
def numeric_array(x, dtype=None, device=None):
    backend = get_backend()
    if backend == "torch":
        import torch

        return torch.tensor(x, dtype=dtype, device=device)
    else:
        return np.array(x, dtype=dtype)
Example #20
def get_default_prototype():
    backend = get_backend()
    if backend == "torch":
        import torch

        return torch.tensor([])
    else:
        return np.array([])
Example #21
def randint(low, high, size):
    backend = get_backend()
    if backend == "torch":
        import torch

        return torch.randint(low, high, size=size)
    else:
        return np.random.randint(low, high, size=size)
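Both branches use an exclusive upper bound; the numpy branch in isolation:

import numpy as np

draws = np.random.randint(0, 10, size=(3, 4))
assert draws.shape == (3, 4)
assert ((draws >= 0) & (draws < 10)).all()  # high=10 is exclusive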
Example #22
def astype(x, dtype):
    backend = get_backend()
    if backend == "torch":
        if dtype == 'uint8':
            return x.byte()
        return x.type(dtype)
    else:
        return x.astype(dtype)
Example #23
def _skip_for_numpyro_version(version="0.2.4"):
    if get_backend() == "jax":
        import numpyro

        if numpyro.__version__ <= version:
            return True

    return False
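Note that plain string comparison misorders versions once a component reaches two digits ("0.10.0" <= "0.2.4" is lexically true). A more robust sketch, assuming the packaging library is available; the helper name is hypothetical:

from packaging.version import Version

def skip_for_version(current, threshold="0.2.4"):
    return Version(current) <= Version(threshold)

assert skip_for_version("0.2.4")
assert not skip_for_version("0.10.0")  # correctly treated as newer than 0.2.4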
Example #24
    def __init__(self, raw_dist, raw_params, expected_value_domain, xfail_reason=""):
        self.raw_dist = raw_dist
        self.raw_params = raw_params
        self.expected_value_domain = expected_value_domain
        for name, raw_param in self.raw_params:
            if get_backend() != "numpy":
                # we need direct access to these tensors for gradient tests
                setattr(self, name, eval(raw_param))
        TEST_CASES.append(self if not xfail_reason else xfail_param(self, reason=xfail_reason))
Example #25
def _random_scale_tril(shape):
    if get_backend() == "torch":
        data = randn(shape)
        return backend_dist.transforms.transform_to(
            backend_dist.constraints.lower_cholesky)(data)
    else:
        data = randn(shape[:-2] + (shape[-1] * (shape[-1] + 1) // 2, ))
        return backend_dist.biject_to(
            backend_dist.constraints.lower_cholesky)(data)
Example #26
def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, y):
    dirichlet_reduction = frozenset(x.inputs).intersection(reduced_vars)
    if dirichlet_reduction:
        backend_dist = import_module(
            BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
        return backend_dist.DirichletMultinomial(concentration=x.concentration,
                                                 total_count=y.total_count,
                                                 value=y.value)
    else:
        return eager(Contraction, red_op, bin_op, reduced_vars, (x, y))
Example #27
def eager_gamma_poisson(red_op, bin_op, reduced_vars, x, y):
    gamma_reduction = frozenset(x.inputs).intersection(reduced_vars)
    if gamma_reduction:
        backend_dist = import_module(
            BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
        return backend_dist.GammaPoisson(concentration=x.concentration,
                                         rate=x.rate,
                                         value=y.value)
    else:
        return eager(Contraction, red_op, bin_op, reduced_vars, (x, y))
Example #28
def eager_dirichlet_posterior(op, c, z):
    if (z.concentration is c.terms[0].concentration) and (
            c.terms[1].total_count is z.total_count):
        backend_dist = import_module(
            BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
        return backend_dist.Dirichlet(concentration=z.concentration +
                                      c.terms[1].value,
                                      value=c.terms[0].value)
    else:
        return None
Example #29
def random_scale_tril(*args):
    if isinstance(args[0], tuple):
        assert len(args) == 1
        shape = args[0]
    else:
        shape = args

    from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
    backend_dist = importlib.import_module(
        BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist

    if get_backend() == "torch":
        data = randn(shape)
        return backend_dist.transforms.transform_to(
            backend_dist.constraints.lower_cholesky)(data)
    else:
        data = randn(shape[:-2] + (shape[-1] * (shape[-1] + 1) // 2, ))
        return backend_dist.biject_to(
            backend_dist.constraints.lower_cholesky)(data)
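Why the non-torch branch draws shape[-1] * (shape[-1] + 1) // 2 values: an n-by-n lower-triangular Cholesky factor has n(n+1)/2 free entries. A numpy sketch of scattering a flat draw into the lower triangle:

import numpy as np

n = 4
flat = np.random.randn(n * (n + 1) // 2)  # 10 free entries for n = 4
tril = np.zeros((n, n))
tril[np.tril_indices(n)] = flat  # lower triangle including the diagonal
assert np.allclose(tril, np.tril(tril))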
Example #30
def test_quote(output_shape, inputs):
    if get_backend() == "torch":
        import torch  # noqa: F401

    sizes = {'a': 4, 'b': 5, 'c': 6}
    inputs = OrderedDict((k, bint(sizes[k])) for k in inputs)
    x = random_tensor(inputs, reals(*output_shape))
    s = funsor.quote(x)
    assert isinstance(s, str)
    assert_close(eval(s), x)
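The quote/eval round trip this test exercises, in miniature with plain repr:

x = {"a": 1, "b": [2, 3]}
s = repr(x)
assert isinstance(s, str)
assert eval(s) == x  # repr of simple containers is evaluable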