Example #1: eager_cat_homogeneous
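The examples below appear to be excerpts from the funsor library (pyro-ppl/funsor); imports such as math, OrderedDict, ops, Tensor, Gaussian, Bint, and Reals are assumed to come from their original modules and are omitted here.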
def eager_cat_homogeneous(name, part_name, *parts):
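    # Concatenate homogeneous Gaussian (or Gaussian-mixture) parts along the
    # part_name input, returning a single funsor indexed by a new name input
    # whose size is the sum of the part sizes.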
    assert parts
    output = parts[0].output
    inputs = OrderedDict([(part_name, None)])
    for part in parts:
        assert part.output == output
        assert part_name in part.inputs
        inputs.update(part.inputs)

    int_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype != "real")
    real_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype == "real")
    inputs = int_inputs.copy()
    inputs.update(real_inputs)
    discretes = []
    info_vecs = []
    precisions = []
    for part in parts:
        inputs[part_name] = part.inputs[part_name]
        int_inputs[part_name] = inputs[part_name]
        shape = tuple(d.size for d in int_inputs.values())
        if isinstance(part, Gaussian):
            discrete = None
            gaussian = part
        elif issubclass(type(part), GaussianMixture):
            # TODO figure out why isinstance isn't working
            discrete, gaussian = part.terms[0], part.terms[1]
            discrete = ops.expand(align_tensor(int_inputs, discrete), shape)
        else:
            raise NotImplementedError("TODO")
        discretes.append(discrete)
        info_vec, precision = align_gaussian(inputs, gaussian)
        info_vecs.append(ops.expand(info_vec, shape + (-1, )))
        precisions.append(ops.expand(precision, shape + (-1, -1)))
    if part_name != name:
        del inputs[part_name]
        del int_inputs[part_name]

    dim = 0
    info_vec = ops.cat(dim, *info_vecs)
    precision = ops.cat(dim, *precisions)
    inputs[name] = Bint[info_vec.shape[dim]]
    int_inputs[name] = inputs[name]
    result = Gaussian(info_vec, precision, inputs)
    if any(d is not None for d in discretes):
        for i, d in enumerate(discretes):
            if d is None:
                discretes[i] = ops.new_zeros(info_vecs[i],
                                             info_vecs[i].shape[:-1])
        discrete = ops.cat(dim, *discretes)
        result = result + Tensor(discrete, int_inputs)
    return result
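
The loop above broadcasts each part to a common batch shape before concatenating along the part_name dimension. A minimal NumPy sketch of that pattern, with np.broadcast_to and np.concatenate standing in for funsor's ops.expand and ops.cat (the shapes here are made up):

import numpy as np

# Each part's info_vec is broadcast to a shared batch shape, then the parts
# are concatenated along the dim that plays the role of part_name.
shape = (2,)                                   # common int-input batch shape
parts = [np.zeros((3,)), np.ones((2, 3))]      # event size 3, ragged batches
expanded = [np.broadcast_to(p, shape + (3,)) for p in parts]
info_vec = np.concatenate(expanded, axis=0)
assert info_vec.shape == (4, 3)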
Example #2: as_tensor
    def as_tensor(self):
        # Fill gaps with zeros.
        prototype = next(iter(self.parts.values()))
        for i in _find_intervals(self.parts.keys(), self.shape[-1]):
            if i not in self.parts:
                self.parts[i] = ops.new_zeros(
                    prototype, self.shape[:-1] + (i[1] - i[0], ))

        # Concatenate parts.
        parts = [v for k, v in sorted(self.parts.items())]
        result = ops.cat(-1, *parts)
        if not get_tracing_state():
            assert result.shape == self.shape
        return result
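
as_tensor depends on the private helper _find_intervals to enumerate both the occupied slices and the gaps. A hypothetical reimplementation, assuming the keys are disjoint (start, stop) tuples, might look like:

# Hypothetical sketch of _find_intervals: return a sorted tiling of [0, end)
# that includes both the existing keys and the gaps between them.
def find_intervals(keys, end):
    points = sorted({0, end}.union(*keys))
    return [(points[i], points[i + 1]) for i in range(len(points) - 1)]

assert find_intervals({(0, 2), (4, 6)}, 7) == [(0, 2), (2, 4), (4, 6), (6, 7)]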
Example #3: eager_normal
def eager_normal(loc, scale, value):
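    # Rewrite Normal(loc, scale) as a Gaussian funsor in information form
    # (zero info_vec, precision = scale**-2) plus its log-normalizer, then
    # substitute the centered residual value - loc.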
    assert loc.output == Real
    assert scale.output == Real
    assert value.output == Real
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    info_vec = ops.new_zeros(scale.data, scale.data.shape + (1, ))
    precision = ops.pow(scale.data, -2).reshape(scale.data.shape + (1, 1))
    log_prob = -0.5 * math.log(2 * math.pi) - ops.log(scale).sum()
    inputs = scale.inputs.copy()
    var = gensym('value')
    inputs[var] = Real
    gaussian = log_prob + Gaussian(info_vec, precision, inputs)
    return gaussian(**{var: value - loc})
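
A plain-number sanity check of the parameterization above (not funsor code; SciPy's norm is assumed available as a reference): with a zero info_vec and precision = scale**-2, adding the quadratic term at x = value - loc to the log-normalizer recovers the Normal log-density.

import math
import numpy as np
from scipy.stats import norm  # assumption: SciPy available as reference

loc, scale, value = 1.0, 2.0, 0.5
x = value - loc
log_prob = -0.5 * math.log(2 * math.pi) - math.log(scale)
log_density = log_prob - 0.5 * scale ** -2 * x ** 2
assert np.isclose(log_density, norm.logpdf(value, loc=loc, scale=scale))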
Example #4: _check_sample
def _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=1e-2,
                  num_samples=100000, statistic="mean", skip_grad=False, with_lazy=None):
    """utility that compares a Monte Carlo estimate of a distribution mean with the true mean"""
    samples_per_dim = int(num_samples ** (1./max(1, len(sample_inputs))))
    sample_inputs = OrderedDict((k, Bint[samples_per_dim]) for k in sample_inputs)
    _get_stat_diff_fn = functools.partial(
        _get_stat_diff, funsor_dist_class, sample_inputs, inputs, num_samples, statistic, with_lazy)

    if get_backend() == "torch":
        import torch

        for param in params:
            param.requires_grad_()

        res = _get_stat_diff_fn(params)
        if sample_inputs:
            diff_sum, diff = res
            assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None)
            if not skip_grad:
                diff_grads = torch.autograd.grad(diff_sum, params, allow_unused=True)
                for diff_grad in diff_grads:
                    assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None)
    elif get_backend() == "jax":
        import jax

        if sample_inputs:
            if skip_grad:
                _, diff = _get_stat_diff_fn(params)
                assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None)
            else:
                (_, diff), diff_grads = jax.value_and_grad(_get_stat_diff_fn, has_aux=True)(params)
                assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None)
                for diff_grad in diff_grads:
                    assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None)
        else:
            _get_stat_diff_fn(params)
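
A minimal, self-contained illustration of the JAX branch (a toy loss, not the real _get_stat_diff): with has_aux=True, jax.value_and_grad returns ((value, aux), grads), which is why the code above unpacks (_, diff) alongside diff_grads.

import jax
import jax.numpy as jnp

def loss_fn(params):
    diff = jnp.stack([p ** 2 for p in params]).sum(0)
    return diff.sum(), diff          # scalar to differentiate, aux output

params = [jnp.ones(3), 2 * jnp.ones(3)]
(diff_sum, diff), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
assert len(grads) == len(params)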
Example #5: eager_mvn
def eager_mvn(loc, scale_tril, value):
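    # Multivariate analogue of eager_normal: the precision comes from
    # inverting the Cholesky factor, and the log-normalizer from the sum of
    # its log-diagonal.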
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    info_vec = ops.new_zeros(scale_tril.data, scale_tril.data.shape[:-1])
    precision = ops.cholesky_inverse(scale_tril.data)
    scale_diag = Tensor(ops.diagonal(scale_tril.data, -1, -2),
                        scale_tril.inputs)
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
                - ops.log(scale_diag).sum())
    inputs = scale_tril.inputs.copy()
    var = gensym('value')
    inputs[var] = Reals[scale_diag.shape[0]]
    gaussian = log_prob + Gaussian(info_vec, precision, inputs)
    return gaussian(**{var: value - loc})
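
A plain-PyTorch sanity check of the quantities used above (assumes the torch backend): torch.cholesky_inverse(L) inverts L @ L.T, and the sum of the log-diagonal of L equals half the log-determinant of the covariance.

import torch

L = torch.tensor([[2.0, 0.0], [0.5, 1.0]])
precision = torch.cholesky_inverse(L)
cov = L @ L.T
assert torch.allclose(precision @ cov, torch.eye(2), atol=1e-6)
assert torch.isclose(torch.log(torch.diagonal(L)).sum(), 0.5 * torch.logdet(cov))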
Example #6: test_tensor_distribution
def test_tensor_distribution(event_inputs, batch_inputs, test_grad):
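    # Draw Monte Carlo samples over event_inputs from a random tensor treated
    # as a log-measure, then check via a random probe that the empirical
    # estimate (and optionally its gradient) matches the exact tensor.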
    num_samples = 50000
    sample_inputs = OrderedDict(n=Bint[num_samples])
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)
    p_data = random_tensor(be_inputs).data
    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
    probe = randn(p_data.shape)

    def diff_fn(p_data):
        p = Tensor(p_data, be_inputs)
        q = p.sample(sampled_vars, sample_inputs, rng_key=rng_key)
        mq = p.materialize(q).reduce(ops.logaddexp, 'n')
        mq = mq.align(tuple(p.inputs))

        _, (p_data, mq_data) = align_tensors(p, mq)
        assert p_data.shape == mq_data.shape
        diff = (ops.exp(mq_data) * probe).sum() - (ops.exp(p_data) * probe).sum()
        return diff, mq

    if test_grad:
        if get_backend() == "jax":
            import jax

            diff_grad, mq = jax.grad(diff_fn, has_aux=True)(p_data)
        else:
            import torch

            p_data.requires_grad_(True)
            diff_grad = torch.autograd.grad(diff_fn(p_data)[0], [p_data])[0]

        assert_close(diff_grad,
                     ops.new_zeros(diff_grad, diff_grad.shape),
                     atol=0.1,
                     rtol=None)
    else:
        _, mq = diff_fn(p_data)
        assert_close(mq, Tensor(p_data, be_inputs), atol=0.1, rtol=None)
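
The probe trick used in diff_fn, in isolation: multiplying by a fixed random tensor and summing reduces a tensor-valued comparison to a scalar, so one autograd call checks the gradient of the whole difference. A toy PyTorch version:

import torch

p = torch.randn(4, requires_grad=True)
probe = torch.randn(4)
scalar = (p.exp() * probe).sum()
(grad,) = torch.autograd.grad(scalar, [p])
assert grad.shape == p.shape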
Example #7: as_tensor (two-dimensional parts)
    def as_tensor(self):
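        # Two-dimensional analogue of Example #2: fill the missing (row, col)
        # blocks with zeros before concatenating.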
        # Fill gaps with zeros.
        arbitrary_row = next(iter(self.parts.values()))
        prototype = next(iter(arbitrary_row.values()))
        js = set().union(*(part.keys() for part in self.parts.values()))
        rows = _find_intervals(self.parts.keys(), self.shape[-2])
        cols = _find_intervals(js, self.shape[-1])
        for i in rows:
            for j in cols:
                if j not in self.parts[i]:
                    shape = self.shape[:-2] + (i[1] - i[0], j[1] - j[0])
                    self.parts[i][j] = ops.new_zeros(prototype, shape)

        # Concatenate parts.
        # TODO This could be optimized into a single .reshape().cat().reshape() if
        #   all inputs are contiguous, thereby saving a memcopy.
        columns = {
            i: ops.cat(-1, *[v for j, v in sorted(part.items())])
            for i, part in self.parts.items()
        }
        result = ops.cat(-2, *[v for i, v in sorted(columns.items())])
        if not get_tracing_state():
            assert result.shape == self.shape
        return result
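
The two-level concatenation above, sketched in plain NumPy with made-up blocks: parts are joined along the last dim within each block-row, then the block-rows are joined along dim -2.

import numpy as np

parts = {(0, 2): {(0, 1): np.zeros((2, 1)), (1, 3): np.ones((2, 2))},
         (2, 3): {(0, 1): np.ones((1, 1)), (1, 3): np.zeros((1, 2))}}
rows = [np.concatenate([v for _, v in sorted(row.items())], axis=-1)
        for _, row in sorted(parts.items())]
result = np.concatenate(rows, axis=-2)
assert result.shape == (3, 3)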