Example 1: eager_gamma_gamma
def eager_gamma_gamma(red_op, bin_op, reduced_vars, x, y):
    gamma_reduction = frozenset(x.inputs).intersection(reduced_vars)
    if gamma_reduction:
        unnormalized = (y.concentration - 1) * ops.log(y.value) \
            - (y.concentration + x.concentration) * ops.log(y.value + x.rate)
        const = -x.concentration * ops.log(x.rate) + _log_beta(
            y.concentration, x.concentration)
        return unnormalized - const
    else:
        return eager(Contraction, red_op, bin_op, reduced_vars, (x, y))
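The closed form above is the Gamma-Gamma compound marginal: a Gamma(x.concentration, x.rate) prior is integrated out of the rate of a Gamma(y.concentration, rate) likelihood evaluated at y.value. Below is a standalone numerical sketch of that identity using NumPy/SciPy instead of funsor; conc0, rate0, conc, y are illustrative values, not from the source.

import numpy as np
from scipy import integrate, stats
from scipy.special import betaln

conc0, rate0 = 2.5, 1.5   # prior Gamma over the rate (plays the role of x)
conc, y = 3.0, 0.7        # likelihood concentration and observed value (y)

# Closed form used above: (conc - 1)*log(y) - (conc + conc0)*log(y + rate0)
# + conc0*log(rate0) - log B(conc, conc0).
closed_form = ((conc - 1) * np.log(y)
               - (conc + conc0) * np.log(y + rate0)
               + conc0 * np.log(rate0)
               - betaln(conc, conc0))

# Marginalize the rate numerically.
numerical, _ = integrate.quad(
    lambda r: stats.gamma.pdf(r, conc0, scale=1 / rate0)
    * stats.gamma.pdf(y, conc, scale=1 / r),
    0, np.inf)

assert np.isclose(closed_form, np.log(numerical))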
Example 2: test_reduce_moment_matching_multivariate
def test_reduce_moment_matching_multivariate():
    int_inputs = [('i', bint(4))]
    real_inputs = [('x', reals(2))]
    inputs = OrderedDict(int_inputs + real_inputs)
    int_inputs = OrderedDict(int_inputs)
    real_inputs = OrderedDict(real_inputs)

    loc = numeric_array([[-10., -1.], [+10., -1.], [+10., +1.], [-10., +1.]])
    precision = zeros(4, 1, 1) + ops.new_eye(loc, (2, ))
    discrete = Tensor(zeros(4), int_inputs)
    gaussian = Gaussian(loc, precision, inputs)
    gaussian -= gaussian.log_normalizer
    joint = discrete + gaussian
    with interpretation(moment_matching):
        actual = joint.reduce(ops.logaddexp, 'i')
    assert_close(actual.reduce(ops.logaddexp), joint.reduce(ops.logaddexp))

    expected_loc = zeros(2)
    expected_covariance = numeric_array([[101., 0.], [0., 2.]])
    expected_precision = _inverse(expected_covariance)
    expected_gaussian = Gaussian(expected_loc, expected_precision, real_inputs)
    expected_gaussian -= expected_gaussian.log_normalizer
    expected_discrete = Tensor(ops.log(numeric_array(4.)))
    expected = expected_discrete + expected_gaussian
    assert_close(actual, expected, atol=1e-5, rtol=None)
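The expected covariance [[101., 0.], [0., 2.]] follows from the law of total covariance: the four equally weighted components each have identity covariance, and the spread of their means contributes the rest. A quick NumPy-only check, independent of funsor:

import numpy as np

loc = np.array([[-10., -1.], [+10., -1.], [+10., +1.], [-10., +1.]])
assert np.allclose(loc.mean(0), np.zeros(2))                  # expected_loc
assert np.allclose(np.eye(2) + np.cov(loc.T, bias=True),
                   np.array([[101., 0.], [0., 2.]]))          # expected_covariance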
Example 3: test_reduce_moment_matching_univariate
def test_reduce_moment_matching_univariate():
    int_inputs = [('i', bint(2))]
    real_inputs = [('x', reals())]
    inputs = OrderedDict(int_inputs + real_inputs)
    int_inputs = OrderedDict(int_inputs)
    real_inputs = OrderedDict(real_inputs)

    p = 0.8
    t = 1.234
    s1, s2, s3 = 2.0, 3.0, 4.0
    loc = numeric_array([[-s1], [s1]])
    precision = numeric_array([[[s2**-2]], [[s3**-2]]])
    info_vec = (precision @ ops.unsqueeze(loc, -1)).squeeze(-1)
    discrete = Tensor(ops.log(numeric_array([1 - p, p])) + t, int_inputs)
    gaussian = Gaussian(info_vec, precision, inputs)
    gaussian -= gaussian.log_normalizer
    joint = discrete + gaussian
    with interpretation(moment_matching):
        actual = joint.reduce(ops.logaddexp, 'i')
    assert_close(actual.reduce(ops.logaddexp), joint.reduce(ops.logaddexp))

    expected_loc = numeric_array([(2 * p - 1) * s1])
    expected_variance = (4 * p * (1 - p) * s1**2 + (1 - p) * s2**2 + p * s3**2)
    expected_precision = numeric_array([[1 / expected_variance]])
    expected_info_vec = (
        expected_precision @ ops.unsqueeze(expected_loc, -1)).squeeze(-1)
    expected_gaussian = Gaussian(expected_info_vec, expected_precision,
                                 real_inputs)
    expected_gaussian -= expected_gaussian.log_normalizer
    expected_discrete = Tensor(numeric_array(t))
    expected = expected_discrete + expected_gaussian
    assert_close(actual, expected, atol=1e-5, rtol=None)
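Here the expected moments again come from the law of total variance for a two-component mixture. A NumPy-only sketch of the arithmetic, reusing the test's constants:

import numpy as np

p, s1, s2, s3 = 0.8, 2.0, 3.0, 4.0
weights = np.array([1 - p, p])
means = np.array([-s1, s1])
variances = np.array([s2**2, s3**2])

mean = weights @ means                                       # E[E[x|i]]
var = weights @ variances + weights @ (means - mean) ** 2    # E[Var] + Var[E]

assert np.isclose(mean, (2 * p - 1) * s1)
assert np.isclose(var, 4 * p * (1 - p) * s1**2 + (1 - p) * s2**2 + p * s3**2)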
Example 4: eager_reduce
    def eager_reduce(self, op, reduced_vars):
        if op is ops.logaddexp:
            # Keep mixture parameters lazy.
            mixture_vars = frozenset(k
                                     for k, d in self.gaussian.inputs.items()
                                     if d.dtype != 'real')
            mixture_vars = mixture_vars.union(*(x.point.inputs
                                                for x in self.deltas))
            lazy_vars = reduced_vars & mixture_vars
            reduced_vars -= lazy_vars

            # Integrate out degenerate variables, i.e. drop selected delta.
            deltas = []
            remaining_vars = set(reduced_vars)
            for d in self.deltas:
                if d.name in reduced_vars:
                    remaining_vars.remove(d.name)
                else:
                    deltas.append(d)
            deltas = tuple(deltas)
            reduced_vars = frozenset(remaining_vars)

            # Integrate out delayed discrete variables.
            discrete_vars = reduced_vars.intersection(self.discrete.inputs)
            discrete = self.discrete.reduce(op, discrete_vars)
            reduced_vars -= discrete_vars

            # Integrate out delayed gaussian variables.
            gaussian_vars = reduced_vars.intersection(self.gaussian.inputs)
            gaussian = self.gaussian.reduce(ops.logaddexp, gaussian_vars)
            reduced_vars -= gaussian_vars

            # Scale to account for remaining reduced_vars that were inputs to dropped deltas.
            eager_result = Joint(deltas, discrete)
            if gaussian is not Number(0):
                eager_result += gaussian
            reduced_vars |= lazy_vars.difference(eager_result.inputs)
            lazy_vars = lazy_vars.intersection(eager_result.inputs)
            if reduced_vars:
                eager_result += ops.log(
                    reduce(ops.mul,
                           [self.inputs[v].dtype for v in reduced_vars]))

            # Return a value only if progress has been made.
            if eager_result is self:
                return None  # defer to default implementation
            else:
                return eager_result.reduce(ops.logaddexp, lazy_vars)

        if op is ops.add:
            terms = list(self.deltas) + [self.discrete, self.gaussian]
            for i, term in enumerate(terms):
                terms[i] = term.reduce(ops.add,
                                       reduced_vars.intersection(term.inputs))
            return reduce(ops.add, terms)

        return None  # defer to default implementation
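The scaling step near the end adds log(cardinality) for each reduced discrete input that the surviving terms no longer mention; this is just logaddexp-reduction of a constant over its domain. A one-line SciPy check of that identity (n and c are illustrative):

import numpy as np
from scipy.special import logsumexp

n, c = 4, -1.3  # domain size and a constant log-value
assert np.isclose(logsumexp(np.full(n, c)), c + np.log(n))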
Example 5: eager_normal
def eager_normal(loc, scale, value):
    assert loc.output == Real
    assert scale.output == Real
    assert value.output == Real
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    info_vec = ops.new_zeros(scale.data, scale.data.shape + (1, ))
    precision = ops.pow(scale.data, -2).reshape(scale.data.shape + (1, 1))
    log_prob = -0.5 * math.log(2 * math.pi) - ops.log(scale).sum()
    inputs = scale.inputs.copy()
    var = gensym('value')
    inputs[var] = Real
    gaussian = log_prob + Gaussian(info_vec, precision, inputs)
    return gaussian(**{var: value - loc})
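The log_prob constant above is the Normal normalizer; the Gaussian funsor carries the quadratic term. A SciPy sanity check of that split, with illustrative values for loc, scale, x:

import math
from scipy import stats

loc, scale, x = 0.3, 1.7, -0.9
log_prob = -0.5 * math.log(2 * math.pi) - math.log(scale)
quadratic = -0.5 * ((x - loc) / scale) ** 2
assert math.isclose(log_prob + quadratic, stats.norm.logpdf(x, loc, scale))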
Example 6: eager_mvn
def eager_mvn(loc, scale_tril, value):
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    info_vec = ops.new_zeros(scale_tril.data, scale_tril.data.shape[:-1])
    precision = ops.cholesky_inverse(scale_tril.data)
    scale_diag = Tensor(ops.diagonal(scale_tril.data, -1, -2),
                        scale_tril.inputs)
    log_prob = -0.5 * scale_diag.shape[0] * math.log(
        2 * math.pi) - ops.log(scale_diag).sum()
    inputs = scale_tril.inputs.copy()
    var = gensym('value')
    inputs[var] = Reals[scale_diag.shape[0]]
    gaussian = log_prob + Gaussian(info_vec, precision, inputs)
    return gaussian(**{var: value - loc})
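Here the constant relies on log det(L @ L.T) = 2 * sum(log(diag(L))) for a Cholesky factor L. A SciPy check at the mean, where only the constant survives (L is an illustrative lower-triangular factor):

import numpy as np
from scipy import stats

L = np.array([[1.5, 0.0], [0.3, 0.8]])  # plays the role of scale_tril
const = -0.5 * 2 * np.log(2 * np.pi) - np.log(np.diag(L)).sum()
logpdf_at_mean = stats.multivariate_normal(np.zeros(2), L @ L.T).logpdf(np.zeros(2))
assert np.isclose(const, logpdf_at_mean)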
Example 7: einsum
def einsum(equation, *operands):
    """
    Log-sum-exp implementation of einsum.
    """
    if get_backend() != "jax":
        # NB: rename symbols to support NumPy, which allows only the symbols a-z.
        symbols = sorted(set(equation) - set(',->'))
        rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz'))
        equation = ''.join(rename.get(s, s) for s in equation)

    inputs, output = equation.split('->')
    if inputs == output:
        return operands[0][...]  # create a new object
    inputs = inputs.split(',')

    shifts = []
    exp_operands = []
    for dims, operand in zip(inputs, operands):
        shift = operand
        for i, dim in enumerate(dims):
            if dim not in output:
                shift = ops.amax(shift, i, keepdims=True)
        # avoid nan due to -inf - -inf
        shift = ops.clamp(shift, ops.finfo(shift).min, None)
        exp_operands.append(ops.exp(operand - shift))

        # permute shift to match output
        shift = shift.reshape(
            [size for size, dim in zip(operand.shape, dims) if dim in output])
        if len(shift.shape) > 0:
            shift = shift.reshape((1, ) * (len(output) - shift.ndim) +
                                  shift.shape)
            dims = [dim for dim in dims if dim in output]
            dims = [dim for dim in output if dim not in dims] + dims
            shift = ops.permute(shift, [dims.index(dim) for dim in output])
        shifts.append(shift)

    result = ops.log(ops.einsum(equation, *exp_operands))
    return sum(shifts + [result])
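The max-shifts exist because a naive exp-einsum-log round trip overflows once log-values get large. A small NumPy/SciPy demonstration of the failure mode being avoided:

import numpy as np
from scipy.special import logsumexp

x = np.full((3, 4), 1000.0)
y = np.full((4, 5), 1000.0)

with np.errstate(over='ignore'):  # exp(1000) overflows to inf by design
    naive = np.log(np.einsum('ab,bc->ac', np.exp(x), np.exp(y)))

stable = logsumexp(x[:, :, None] + y[None, :, :], axis=1)  # 2000 + log(4)
assert np.isinf(naive).all() and np.allclose(stable, 2000.0 + np.log(4))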
Example 8: test_transform_log
def test_transform_log(shape):
    point = Tensor(randn(shape))
    x = Variable('x', reals(*shape))
    actual = Delta('y', point)(y=ops.log(x))
    expected = Delta('x', point.exp(), -point.sum())
    assert_close(actual, expected)
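The -point.sum() appears to be the log-Jacobian of y = log(x) at the transported point x = exp(point), where d(log x)/dx = exp(-point). A trivial NumPy check of that sign:

import numpy as np

point = np.random.randn(4)
jac_diag = 1.0 / np.exp(point)  # d(log x)/dx evaluated at x = exp(point)
assert np.isclose(np.log(jac_diag).sum(), -point.sum())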
Example 9: _log_det_tri
def _log_det_tri(x):
    return ops.log(ops.diagonal(x, -1, -2)).sum(-1)
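A triangular matrix's determinant is the product of its diagonal, so the log-determinant is the sum of log-diagonals over the last two (possibly batched) axes, which is what this helper computes. A NumPy check with a random lower-triangular matrix:

import numpy as np

a = np.tril(np.random.randn(4, 4))
np.fill_diagonal(a, np.abs(np.diag(a)) + 0.1)  # keep the diagonal positive

assert np.isclose(np.log(np.linalg.det(a)), np.log(np.diagonal(a)).sum())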
Example 10: delta
def delta(v: Reals[event_shape], log_density: Real, value: Reals[event_shape]) -> Real:
    # NB: `event_shape` is a free variable, bound in the enclosing scope where
    # this function is defined (it also appears in the type annotations).
    eq = (v == value)
    for _ in range(len(event_shape)):
        eq = ops.all(eq, -1)
    return ops.log(ops.astype(eq, 'float32')) + log_density
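Taking the log of the boolean cast yields 0.0 where v matches value and -inf elsewhere, i.e. the log-density of a point mass, shifted by log_density. A NumPy illustration of that pattern:

import numpy as np

point = np.array([1.0, 2.0])
with np.errstate(divide='ignore'):  # log(0) -> -inf is intended here
    for v in (np.array([1.0, 2.0]), np.array([1.0, 2.5])):
        eq = (v == point).all(-1)
        print(np.log(eq.astype('float32')))  # 0.0 on a match, -inf otherwise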