Example #1
def eager_multinomial(total_count, probs, value):
    # Multinomial.log_prob() supports inhomogeneous total_count only by
    # avoiding passing total_count to the constructor.
    inputs, (total_count, probs,
             value) = align_tensors(total_count, probs, value)
    shape = broadcast_shape(total_count.shape + (1, ), probs.shape,
                            value.shape)
    probs = Tensor(ops.expand(probs, shape), inputs)
    value = Tensor(ops.expand(value, shape), inputs)
    if get_backend() == "torch":
        # Used by distributions validation code.
        total_count = Number(ops.amax(total_count, None).item())
    else:
        total_count = Tensor(ops.expand(total_count, shape[:-1]), inputs)
    backend_dist = import_module(
        BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return backend_dist.Multinomial.eager_log_prob(total_count, probs,
                                                   value)  # noqa: F821
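A minimal sketch (not part of the source) of the trick described in the comment above, using plain PyTorch: Multinomial.log_prob() reads each row's total from the value itself, so constructing the distribution with the maximum count is enough to satisfy validation while scoring rows with different totals.

# Hedged sketch assuming plain PyTorch, not the funsor wrappers above.
import torch
from torch.distributions import Multinomial

probs = torch.tensor([0.2, 0.3, 0.5])
value = torch.tensor([[2., 1., 0.],    # this row sums to 3
                      [1., 0., 1.]])   # this row sums to 2
d = Multinomial(total_count=3, probs=probs)  # max count, only needed for validation here
print(d.log_prob(value))                     # one log-probability per row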
Example #2
def einsum(equation, *operands):
    """
    Forward-max-sum backward-argmax implementation of einsum.
    This assumes all operands have a ``._pyro_dims`` attribute set.
    """
    inputs, output = equation.split('->')
    inputs = inputs.split(',')

    # Dims that appear in the inputs but not in the output are contracted away.
    contract_dims = ''.join(sorted(set().union(*inputs) - set(output)))
    dims = output + contract_dims
    # In the max-sum semiring, operands are combined by addition...
    result = reduce(operator.add, broadcast_all(*operands, inputs=inputs, dims=dims))
    if contract_dims:
        # ...and the contracted dims are reduced with a max.
        output_shape = result.shape[:len(output)]
        result = ops.amax(result.reshape(output_shape + (-1,)), -1)
    elif result is operands[0]:
        result = result[...]  # create a new object
    assert len(result.shape) == len(output)

    return result
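A hedged illustration (not from the source) of the max-sum ("tropical") semiring this backend computes in: operands are combined with + where ordinary einsum multiplies, and contracted dims are reduced with max where ordinary einsum sums. Written out by brute force with NumPy for the equation "ab,bc->ac":

# Hypothetical brute-force reference for the max-sum semiring.
import numpy as np

x = np.random.randn(2, 3)   # dims "ab"
y = np.random.randn(3, 4)   # dims "bc"
# result[a, c] = max over b of (x[a, b] + y[b, c])
result = (x[:, :, None] + y[None, :, :]).max(axis=1)
assert result.shape == (2, 4)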
Example #3
def einsum(equation, *operands):
    """
    Log-sum-exp implementation of einsum.
    """
    if get_backend() != "jax":
        # NB: rename symbols to support NumPy, which allows only the symbols a-z.
        symbols = sorted(set(equation) - set(',->'))
        rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz'))
        equation = ''.join(rename.get(s, s) for s in equation)

    inputs, output = equation.split('->')
    if inputs == output:
        return operands[0][...]  # create a new object
    inputs = inputs.split(',')

    shifts = []
    exp_operands = []
    for dims, operand in zip(inputs, operands):
        shift = operand
        for i, dim in enumerate(dims):
            if dim not in output:
                shift = ops.amax(shift, i, keepdims=True)
        # avoid nan due to -inf - -inf
        shift = ops.clamp(shift, ops.finfo(shift).min, None)
        exp_operands.append(ops.exp(operand - shift))

        # permute shift to match output
        shift = shift.reshape(
            [size for size, dim in zip(operand.shape, dims) if dim in output])
        if len(shift.shape) > 0:
            # Left-pad with singleton dims, then permute so the kept dims
            # line up with their positions in the output.
            shift = shift.reshape((1, ) * (len(output) - shift.ndim) +
                                  shift.shape)
            dims = [dim for dim in dims if dim in output]
            dims = [dim for dim in output if dim not in dims] + dims
            shift = ops.permute(shift, [dims.index(dim) for dim in output])
        shifts.append(shift)

    result = ops.log(ops.einsum(equation, *exp_operands))
    return sum(shifts + [result])
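The per-operand shift trick above can be checked in isolation (a hedged sketch in plain NumPy, not part of the source): subtracting each operand's max over its contracted dims keeps exp() from overflowing, and the shifts are added back after taking the log.

# Hedged sketch of the shift trick for "ab,bc->ac":
#   log sum_b exp(x[a,b] + y[b,c])
#     = sx[a] + sy[c] + log sum_b exp((x[a,b] - sx[a]) + (y[b,c] - sy[c]))
import numpy as np

x = np.random.randn(2, 3)           # dims "ab"
y = np.random.randn(3, 4)           # dims "bc"
sx = x.max(axis=1, keepdims=True)   # shift over the contracted dim "b"
sy = y.max(axis=0, keepdims=True)
result = np.log(np.einsum("ab,bc->ac", np.exp(x - sx), np.exp(y - sy)))
result = result + sx + sy           # add the shifts back in log space
expected = np.log(np.exp(x[:, :, None] + y[None, :, :]).sum(axis=1))
assert np.allclose(result, expected)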