Example #1
def broadcast_all(*values, **kwargs):
    """
    Packed broadcasting of multiple tensors.

    :param values: tensors to broadcast.
    :param kwargs: ``inputs`` is a sequence of dim-name strings, one per
        value; ``dims`` is an optional string giving the desired ordering
        of all dim names.
    """
    inputs = kwargs.get('inputs')
    dims = kwargs.get('dims')
    # Collect the size of every named dim across all values.
    sizes = {
        dim: size
        for value, old_dims in zip(values, inputs)
        for dim, size in zip(old_dims, value.shape)
    }
    if dims is None:
        dims = ''.join(sorted(sizes))
    else:
        assert set(dims) == set(sizes)
    shape = tuple(sizes[dim] for dim in dims)
    values = list(values)
    # Permute each value to the target order, insert size-1 axes for any dim
    # names it lacks, then expand to the common broadcast shape.
    for i, (x, old_dims) in enumerate(zip(values, inputs)):
        if old_dims != dims:
            x = ops.permute(
                x,
                tuple(old_dims.index(dim) for dim in dims if dim in old_dims))
            x = x.reshape(
                tuple(sizes[dim] if dim in old_dims else 1 for dim in dims))
            x = ops.expand(x, shape)
            assert len(x.shape) == len(dims)
            values[i] = x
    return tuple(values)
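
The permute / reshape / expand pattern above is the whole trick: every value is brought to one shared dim order, with size-1 axes for names it lacks. Below is a minimal NumPy sketch of the same idea; broadcast_packed is a hypothetical helper name, not part of the library.

import numpy as np

def broadcast_packed(values, inputs, dims=None):
    # inputs[i] is a string of dim names for values[i], e.g. "ab" for a 2-D array.
    sizes = {d: s for v, old in zip(values, inputs) for d, s in zip(old, v.shape)}
    dims = ''.join(sorted(sizes)) if dims is None else dims
    shape = tuple(sizes[d] for d in dims)
    out = []
    for x, old in zip(values, inputs):
        x = np.transpose(x, [old.index(d) for d in dims if d in old])  # shared order
        x = x.reshape([sizes[d] if d in old else 1 for d in dims])     # fill missing dims
        out.append(np.broadcast_to(x, shape))                          # expand
    return tuple(out)

a, b = np.ones((2, 3)), np.ones((4,))
x, y = broadcast_packed([a, b], ["ab", "c"])
assert x.shape == y.shape == (2, 3, 4)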
Example #2
def tensor_to_data(x, name_to_dim=None):
    """
    Convert a Tensor funsor to raw data, mapping each named input to the
    negative batch dim given in ``name_to_dim``.
    """
    if not name_to_dim or not x.inputs:
        if x.inputs:
            raise ValueError(
                "cannot convert Tensor to data due to lazy inputs: {}".format(
                    set(x.inputs)))
        return x.data
    else:
        assert all(
            isinstance(k, str) and isinstance(v, int) and v < 0
            for k, v in name_to_dim.items())
        # logic very similar to pyro.ops.packed.unpack
        # first collapse input domains into single dimensions
        data = x.data.reshape(
            tuple(d.dtype for d in x.inputs.values()) + x.output.shape)
        # permute packed dimensions to correct order
        unsorted_dims = [name_to_dim[name] for name in x.inputs]
        dims = sorted(unsorted_dims)
        permutation = [unsorted_dims.index(dim) for dim in dims] + \
            list(range(len(dims), len(dims) + len(x.output.shape)))
        data = ops.permute(data, permutation)
        # expand
        batch_shape = [1] * -min(dims)
        for dim, size in zip(dims, data.shape):
            batch_shape[dim] = size
        return data.reshape(tuple(batch_shape) + x.output.shape)
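
A small NumPy walk-through of the permute-and-expand step above, using the hypothetical input names "i" and "j": the packed dims are moved to the (negative) positions requested in name_to_dim and the gaps are padded with size-1 dims.

import numpy as np

packed = np.zeros((3, 4, 5))       # inputs: i (size 3), j (size 4); event shape (5,)
name_to_dim = {"i": -1, "j": -3}   # requested batch positions, all negative
event_shape = (5,)

unsorted_dims = [name_to_dim[n] for n in ("i", "j")]   # [-1, -3]
dims = sorted(unsorted_dims)                           # [-3, -1]
perm = [unsorted_dims.index(d) for d in dims] + [2]    # [1, 0, 2]; event dims stay last
data = np.transpose(packed, perm)                      # shape (4, 3, 5)

batch_shape = [1] * -min(dims)                         # [1, 1, 1]
for d, size in zip(dims, data.shape):
    batch_shape[d] = size                              # [4, 1, 3]
data = data.reshape(tuple(batch_shape) + event_shape)
assert data.shape == (4, 1, 3, 5)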
Example #3
    def align(self, names):
        """
        Permute this tensor's inputs so that the names in ``names`` come
        first, in the given order; remaining inputs keep their relative order.
        """
        assert isinstance(names, tuple)
        assert all(name in self.inputs for name in names)
        if not names or names == tuple(self.inputs):
            return self

        # Put the requested names first; the update then appends the remaining
        # inputs in their original order without moving keys already present.
        inputs = OrderedDict((name, self.inputs[name]) for name in names)
        inputs.update(self.inputs)
        old_dims = tuple(self.inputs)
        new_dims = tuple(inputs)
        permutation = tuple(old_dims.index(d) for d in new_dims)
        permutation = permutation + tuple(
            range(len(permutation),
                  len(permutation) + len(self.output.shape)))
        data = ops.permute(self.data, permutation)
        return Tensor(data, inputs, self.dtype)
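
The only subtle lines are the two OrderedDict steps: the requested names go first and the update appends only the missing inputs. A small stand-alone check of the resulting permutation, with hypothetical names "a", "b", "c" and a single event dim:

from collections import OrderedDict

old_dims = ("a", "b", "c")                   # current input order; output shape has 1 dim
names = ("c", "a")                           # requested prefix

inputs = OrderedDict((n, None) for n in names)
inputs.update((n, None) for n in old_dims)   # appends only the missing "b"
new_dims = tuple(inputs)                     # ("c", "a", "b")

permutation = tuple(old_dims.index(d) for d in new_dims)
permutation += tuple(range(len(permutation), len(permutation) + 1))  # keep the event dim last
assert permutation == (2, 0, 1, 3)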
Example #4
def eager_getitem_tensor_variable(op, lhs, rhs):
    assert op.offset < len(lhs.output.shape)
    assert rhs.output == Bint[lhs.output.shape[op.offset]]
    assert rhs.name not in lhs.inputs

    # Convert a positional event dimension to a named batch dimension.
    inputs = lhs.inputs.copy()
    inputs[rhs.name] = rhs.output
    data = lhs.data
    target_dim = len(lhs.inputs)
    source_dim = target_dim + op.offset
    if target_dim != source_dim:
        # Build a permutation that moves the event dim at source_dim forward to
        # target_dim, immediately after the existing batch dims.
        perm = list(range(len(data.shape)))
        del perm[source_dim]
        perm.insert(target_dim, source_dim)
        data = ops.permute(data, perm)
    return Tensor(data, inputs, lhs.dtype)
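
The del/insert dance builds a permutation that moves exactly one axis: the selected event dim is pulled forward to the first slot after the batch dims, where the new named input lives. A NumPy sketch with concrete (hypothetical) shapes:

import numpy as np

data = np.zeros((2, 3, 4, 5))        # 2 batch dims, event shape (4, 5)
target_dim = 2                       # len(lhs.inputs): first slot after the batch dims
source_dim = target_dim + 1          # op.offset == 1 selects the event dim of size 5

perm = list(range(data.ndim))        # [0, 1, 2, 3]
del perm[source_dim]                 # [0, 1, 2]
perm.insert(target_dim, source_dim)  # [0, 1, 3, 2]
moved = np.transpose(data, perm)
assert moved.shape == (2, 3, 5, 4)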
Example #5
def align_tensor(new_inputs, x, expand=False):
    r"""
    Permute and add dims to a tensor to match desired ``new_inputs``.

    :param OrderedDict new_inputs: A target set of inputs.
    :param funsor.terms.Funsor x: A :class:`Tensor` or
        :class:`~funsor.terms.Number`.
    :param bool expand: If False (default), set result size to 1 for any input
        of ``x`` not in ``new_inputs``; if True expand to ``new_inputs`` size.
    :return: a number or :class:`torch.Tensor` or :class:`np.ndarray` that can be broadcast to other
        tensors with inputs ``new_inputs``.
    :rtype: int or float or torch.Tensor or np.ndarray
    """
    assert isinstance(new_inputs, OrderedDict)
    assert isinstance(x, (Number, Tensor))
    assert all(isinstance(d.dtype, int) for d in x.inputs.values())

    data = x.data
    if isinstance(x, Number):
        return data

    old_inputs = x.inputs
    if old_inputs == new_inputs:
        return data

    # Permute squashed input dims.
    x_keys = tuple(old_inputs)
    data = ops.permute(
        data,
        tuple(x_keys.index(k) for k in new_inputs if k in old_inputs) +
        tuple(range(len(old_inputs), len(data.shape))))

    # Unsquash multivariate input dims by filling in ones.
    data = data.reshape(
        tuple(old_inputs[k].dtype if k in old_inputs else 1
              for k in new_inputs) + x.output.shape)

    # Optionally expand new dims.
    if expand:
        data = ops.expand(
            data,
            tuple(d.dtype for d in new_inputs.values()) + x.output.shape)
    return data
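
A NumPy sketch of the three steps (permute, unsquash with size-1 dims, optional expand), using hypothetical inputs: x has inputs ("j", "i") with a scalar output, and the target order is ("i", "j", "k").

import numpy as np
from collections import OrderedDict

old_inputs = OrderedDict(j=4, i=2)     # name -> size, in x's current order
new_inputs = OrderedDict(i=2, j=4, k=3)
data = np.zeros((4, 2))

x_keys = tuple(old_inputs)
data = np.transpose(data, [x_keys.index(k) for k in new_inputs if k in old_inputs])
data = data.reshape([old_inputs[k] if k in old_inputs else 1 for k in new_inputs])
assert data.shape == (2, 4, 1)         # "k" filled in with size 1

expanded = np.broadcast_to(data, tuple(new_inputs.values()))   # expand=True analogue
assert expanded.shape == (2, 4, 3)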
Example #6
def einsum(equation, *operands):
    """
    Log-sum-exp implementation of einsum.
    """
    if get_backend() != "jax":
        # NB: rename symbols to support NumPy, which allows only symbols a-z.
        symbols = sorted(set(equation) - set(',->'))
        rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz'))
        equation = ''.join(rename.get(s, s) for s in equation)

    inputs, output = equation.split('->')
    if inputs == output:
        return operands[0][...]  # create a new object
    inputs = inputs.split(',')

    shifts = []
    exp_operands = []
    for dims, operand in zip(inputs, operands):
        shift = operand
        for i, dim in enumerate(dims):
            if dim not in output:
                shift = ops.amax(shift, i, keepdims=True)
        # avoid nan due to -inf - -inf
        shift = ops.clamp(shift, ops.finfo(shift).min, None)
        exp_operands.append(ops.exp(operand - shift))

        # permute shift to match output
        shift = shift.reshape(
            [size for size, dim in zip(operand.shape, dims) if dim in output])
        if len(shift.shape) > 0:
            shift = shift.reshape((1, ) * (len(output) - shift.ndim) +
                                  shift.shape)
            dims = [dim for dim in dims if dim in output]
            dims = [dim for dim in output if dim not in dims] + dims
            shift = ops.permute(shift, [dims.index(dim) for dim in output])
        shifts.append(shift)

    result = ops.log(ops.einsum(equation, *exp_operands))
    return sum(shifts + [result])
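
The stabilization works because each shift depends only on dims that survive into the output, so it can be subtracted before exponentiating and added back after the log. Below is a self-contained NumPy sketch of the same trick for two operands; it omits the clamp, so it assumes finite inputs, and logsumexp_einsum_2 is a hypothetical name.

import numpy as np

def logsumexp_einsum_2(eq, a, b):
    lhs, out = eq.split("->")
    da, db = lhs.split(",")
    shifts, exps = [], []
    for dims, op in ((da, a), (db, b)):
        contracted = tuple(i for i, d in enumerate(dims) if d not in out)
        shift = op.max(axis=contracted, keepdims=True)   # per-slice max over summed dims
        exps.append(np.exp(op - shift))
        # drop the contracted axes, pad with leading 1s, and permute to the output order
        shift = shift.reshape([s for s, d in zip(op.shape, dims) if d in out])
        shift = shift.reshape((1,) * (len(out) - shift.ndim) + shift.shape)
        kept = [d for d in dims if d in out]
        order = [d for d in out if d not in kept] + kept
        shifts.append(np.transpose(shift, [order.index(d) for d in out]))
    return sum(shifts) + np.log(np.einsum(eq, *exps))

a = np.log(np.random.rand(2, 3))      # log-space operands
b = np.log(np.random.rand(3, 4))
expected = np.log(np.einsum("ab,bc->ac", np.exp(a), np.exp(b)))
assert np.allclose(logsumexp_einsum_2("ab,bc->ac", a, b), expected)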