Example #1
import itertools
from collections import OrderedDict

import opt_einsum

# Assumed imports from funsor; module paths may vary by funsor version.
from funsor import ops
from funsor.tensor import Tensor


def eager_einsum(equation, operands):
    if all(isinstance(x, Tensor) for x in operands):
        # Make fresh einsum symbols for the operands' named inputs.
        inputs = OrderedDict()
        for x in operands:
            inputs.update(x.inputs)
        symbols = set(equation)
        get_symbol = iter(map(opt_einsum.get_symbol, itertools.count()))
        new_symbols = {}
        for k in inputs:
            symbol = next(get_symbol)
            while symbol in symbols:
                symbol = next(get_symbol)
            symbols.add(symbol)
            new_symbols[k] = symbol

        # Manually broadcast using einsum symbols.
        assert '.' not in equation  # ellipsis syntax is not supported here
        ins, out = equation.split('->')
        ins = ins.split(',')
        ins = [
            ''.join(new_symbols[k] for k in x.inputs) + x_out
            for x, x_out in zip(operands, ins)
        ]
        out = ''.join(new_symbols[k] for k in inputs) + out
        equation = ','.join(ins) + '->' + out

        data = ops.einsum(equation, *[x.data for x in operands])
        return Tensor(data, inputs)

    return None  # defer to default implementation
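The key move above is manual broadcasting: each named batch input gets a fresh einsum symbol that is prepended to its operand's subscripts and to the output, so a single backend einsum call handles the batching. A minimal NumPy sketch of that idea, independent of funsor (the equation and shapes are illustrative):

import numpy as np

# 'ab,bc->ac' with one shared batch input of size 5: prefix a fresh
# symbol 'z' to each operand's subscripts and to the output, just as
# eager_einsum does via new_symbols.
batch = 5
x = np.random.randn(batch, 2, 3)  # batched 'ab'
y = np.random.randn(batch, 3, 4)  # batched 'bc'
out = np.einsum('zab,zbc->zac', x, y)
assert out.shape == (batch, 2, 4)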
Example #2
# Assumed imports from funsor; module paths may vary by funsor version.
from funsor import ops
from funsor.tensor import Einsum, Tensor
from funsor.testing import assert_close, randn


def test_einsum(equation):
    sizes = dict(a=2, b=3, c=4)
    inputs, outputs = equation.split('->')
    inputs = inputs.split(',')
    tensors = [randn(tuple(sizes[d] for d in dims)) for dims in inputs]
    funsors = [Tensor(x) for x in tensors]
    expected = Tensor(ops.einsum(equation, *tensors))
    actual = Einsum(equation, tuple(funsors))
    assert_close(actual, expected, atol=1e-5, rtol=None)
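A hypothetical driver for this test; in funsor's own suite it is presumably parametrized via pytest, and the equation list below is illustrative, not from the original:

EQUATIONS = ['a->a', 'ab->ba', 'ab,bc->ac', 'a,a->a']

for eq in EQUATIONS:
    test_einsum(eq)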
Example #3
from collections import OrderedDict

# Assumed imports from funsor; module paths may vary by funsor version.
from funsor import ops
from funsor.domains import bint, reals
from funsor.tensor import Einsum, Tensor, align_tensors
from funsor.testing import assert_close, random_tensor


def test_batched_einsum(equation, batch1, batch2):
    inputs, output = equation.split('->')
    inputs = inputs.split(',')

    sizes = dict(a=2, b=3, c=4, i=5, j=6)
    batch1 = OrderedDict([(k, bint(sizes[k])) for k in batch1])
    batch2 = OrderedDict([(k, bint(sizes[k])) for k in batch2])
    funsors = [
        random_tensor(batch, reals(*(sizes[d] for d in dims)))
        for batch, dims in zip([batch1, batch2], inputs)
    ]
    actual = Einsum(equation, tuple(funsors))

    _equation = ','.join('...' + i for i in inputs) + '->...' + output
    inputs, tensors = align_tensors(*funsors)
    batch = tuple(v.size for v in inputs.values())
    tensors = [
        ops.expand(x, batch + f.shape) for (x, f) in zip(tensors, funsors)
    ]
    expected = Tensor(ops.einsum(_equation, *tensors), inputs)
    assert_close(actual, expected, atol=1e-5, rtol=None)
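The expected value is built by aligning both funsors to a common batch shape and then contracting with an ellipsis equation. A self-contained NumPy analogue of that align-then-contract step (shapes and names are illustrative):

import numpy as np

i, j = 5, 6
x = np.random.randn(i, 1, 2, 3)       # batch ('i',), event dims 'ab'
y = np.random.randn(1, j, 3, 4)       # batch ('j',), event dims 'bc'
x = np.broadcast_to(x, (i, j, 2, 3))  # align to the union batch (i, j)
y = np.broadcast_to(y, (i, j, 3, 4))
out = np.einsum('...ab,...bc->...ac', x, y)
assert out.shape == (i, j, 2, 4)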
Example #4
# Assumed imports from funsor; module paths may vary by funsor version.
from funsor import ops
from funsor.util import get_backend


def einsum(equation, *operands):
    """
    Log-sum-exp implementation of einsum: contracts log-space operands
    in probability space via a numerically stable max-shift.
    """
    if get_backend() != "jax":
        # NB: rename symbols to support NumPy, which allows only symbols a-z.
        symbols = sorted(set(equation) - set(',->'))
        rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz'))
        equation = ''.join(rename.get(s, s) for s in equation)

    inputs, output = equation.split('->')
    if inputs == output:
        return operands[0][...]  # create a new object
    inputs = inputs.split(',')

    shifts = []
    exp_operands = []
    for dims, operand in zip(inputs, operands):
        shift = operand
        for i, dim in enumerate(dims):
            if dim not in output:
                shift = ops.amax(shift, i, keepdims=True)
        # avoid nan due to -inf - -inf = nan
        shift = ops.clamp(shift, ops.finfo(shift).min, None)
        exp_operands.append(ops.exp(operand - shift))

        # permute shift to match output
        shift = shift.reshape(
            [size for size, dim in zip(operand.shape, dims) if dim in output])
        if len(shift.shape) > 0:
            shift = shift.reshape((1, ) * (len(output) - shift.ndim) +
                                  shift.shape)
            dims = [dim for dim in dims if dim in output]
            dims = [dim for dim in output if dim not in dims] + dims
            shift = ops.permute(shift, [dims.index(dim) for dim in output])
        shifts.append(shift)

    result = ops.log(ops.einsum(equation, *exp_operands))
    return sum(shifts + [result])
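The per-operand max-shift is what makes this stable: each operand is shifted by its max over the dims being reduced, exponentiated, contracted, and the shifts are added back after the log. A minimal NumPy sketch of the same identity for a single 'ab,b->a' contraction (plain NumPy here rather than funsor.ops; the helper name is ours):

import numpy as np

def stable_log_einsum_ab_b(x, y):
    # log(einsum('ab,b->a', exp(x), exp(y))) without overflowing exp():
    # shift each operand by its max over the reduced dim 'b'.
    sx = np.clip(x.max(axis=1, keepdims=True), np.finfo(x.dtype).min, None)
    sy = np.clip(y.max(), np.finfo(y.dtype).min, None)
    result = np.log(np.einsum('ab,b->a', np.exp(x - sx), np.exp(y - sy)))
    return result + sx[:, 0] + sy

x = np.random.randn(3, 4)
y = np.random.randn(4)
expected = np.log(np.einsum('ab,b->a', np.exp(x), np.exp(y)))
assert np.allclose(stable_log_einsum_ab_b(x, y), expected)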