def eager_einsum(equation, operands):
    if all(isinstance(x, Tensor) for x in operands):
        # Make new symbols for inputs of operands.
        inputs = OrderedDict()
        for x in operands:
            inputs.update(x.inputs)
        symbols = set(equation)
        get_symbol = iter(map(opt_einsum.get_symbol, itertools.count()))
        new_symbols = {}
        for k in inputs:
            symbol = next(get_symbol)
            while symbol in symbols:
                symbol = next(get_symbol)
            symbols.add(symbol)
            new_symbols[k] = symbol

        # Manually broadcast using einsum symbols.
        assert '.' not in equation
        ins, out = equation.split('->')
        ins = ins.split(',')
        ins = [''.join(new_symbols[k] for k in x.inputs) + x_out
               for x, x_out in zip(operands, ins)]
        out = ''.join(new_symbols[k] for k in inputs) + out
        equation = ','.join(ins) + '->' + out

        data = ops.einsum(equation, *[x.data for x in operands])
        return Tensor(data, inputs)

    return None  # defer to default implementation
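# A minimal, self-contained sketch (an illustration, not funsor code) of the
# symbol-renaming step above: named batch inputs are mapped to fresh opt_einsum
# symbols and prepended to every term of the plain equation. For simplicity this
# sketch assumes both operands share the same batch inputs, whereas eager_einsum
# prefixes each operand with only its own inputs. The batch names and sizes
# below are hypothetical.
import itertools
from collections import OrderedDict

import opt_einsum

batch_inputs = OrderedDict([("i", 5), ("j", 6)])  # hypothetical named batch dims
equation = "ab,bc->ac"

used = set(equation)
fresh = map(opt_einsum.get_symbol, itertools.count())
new_symbols = {}
for name in batch_inputs:
    symbol = next(fresh)
    while symbol in used:  # skip symbols already taken by the equation
        symbol = next(fresh)
    used.add(symbol)
    new_symbols[name] = symbol

ins, out = equation.split("->")
prefix = "".join(new_symbols[name] for name in batch_inputs)
batched = ",".join(prefix + term for term in ins.split(",")) + "->" + prefix + out
print(batched)  # "deab,debc->deac": every term now carries the batch symbols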
def test_einsum(equation):
    sizes = dict(a=2, b=3, c=4)
    inputs, outputs = equation.split('->')
    inputs = inputs.split(',')
    tensors = [randn(tuple(sizes[d] for d in dims)) for dims in inputs]
    funsors = [Tensor(x) for x in tensors]
    expected = Tensor(ops.einsum(equation, *tensors))
    actual = Einsum(equation, tuple(funsors))
    assert_close(actual, expected, atol=1e-5, rtol=None)
def test_batched_einsum(equation, batch1, batch2):
    inputs, output = equation.split('->')
    inputs = inputs.split(',')
    sizes = dict(a=2, b=3, c=4, i=5, j=6)
    batch1 = OrderedDict([(k, bint(sizes[k])) for k in batch1])
    batch2 = OrderedDict([(k, bint(sizes[k])) for k in batch2])
    funsors = [random_tensor(batch, reals(*(sizes[d] for d in dims)))
               for batch, dims in zip([batch1, batch2], inputs)]
    actual = Einsum(equation, tuple(funsors))

    # Expected value: align the funsors to a common set of batch inputs,
    # expand each raw tensor to the joint batch shape, then contract with an
    # ellipsis-broadcast einsum on the backend.
    _equation = ','.join('...' + i for i in inputs) + '->...' + output
    inputs, tensors = align_tensors(*funsors)
    batch = tuple(v.size for v in inputs.values())
    tensors = [ops.expand(x, batch + f.shape) for (x, f) in zip(tensors, funsors)]
    expected = Tensor(ops.einsum(_equation, *tensors), inputs)
    assert_close(actual, expected, atol=1e-5, rtol=None)
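# Illustrative NumPy-only sketch (not funsor code) of what the expected value
# above computes: two operands with different batch dims are expanded to a
# joint (5, 6) batch shape and contracted with an ellipsis equation. The sizes
# mirror the hypothetical ones in the test (i=5, j=6, a=2, b=3, c=4).
import numpy as np

x = np.random.randn(5, 2, 3)                         # batch "i", event dims "ab"
y = np.random.randn(6, 3, 4)                         # batch "j", event dims "bc"
x_full = np.broadcast_to(x[:, None], (5, 6, 2, 3))   # expand to the joint batch
y_full = np.broadcast_to(y[None], (5, 6, 3, 4))
out = np.einsum("...ab,...bc->...ac", x_full, y_full)
print(out.shape)  # (5, 6, 2, 4)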
def einsum(equation, *operands):
    """
    Log-sum-exp implementation of einsum.
    """
    if get_backend() != "jax":
        # NB: rename symbols to support NumPy, which allows only symbols a-z.
        symbols = sorted(set(equation) - set(',->'))
        rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz'))
        equation = ''.join(rename.get(s, s) for s in equation)

    inputs, output = equation.split('->')
    if inputs == output:
        return operands[0][...]  # create a new object
    inputs = inputs.split(',')

    shifts = []
    exp_operands = []
    for dims, operand in zip(inputs, operands):
        shift = operand
        for i, dim in enumerate(dims):
            if dim not in output:
                shift = ops.amax(shift, i, keepdims=True)
        # avoid nan due to -inf - -inf
        shift = ops.clamp(shift, ops.finfo(shift).min, None)
        exp_operands.append(ops.exp(operand - shift))

        # permute shift to match output
        shift = shift.reshape(
            [size for size, dim in zip(operand.shape, dims) if dim in output])
        if len(shift.shape) > 0:
            shift = shift.reshape((1,) * (len(output) - shift.ndim) + shift.shape)
            dims = [dim for dim in dims if dim in output]
            dims = [dim for dim in output if dim not in dims] + dims
            shift = ops.permute(shift, [dims.index(dim) for dim in output])
        shifts.append(shift)

    result = ops.log(ops.einsum(equation, *exp_operands))
    return sum(shifts + [result])
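# A NumPy sketch (an illustration, not the funsor backend) of the shift trick
# used above: each operand is shifted by its max over the contracted dimension
# before exponentiation, and the shifts are added back after the log. The
# result is unchanged mathematically, but exp no longer overflows for large
# inputs. Function names and the fixed "ab,bc->ac" equation are hypothetical.
import numpy as np

def naive_logeinsumexp(x, y):
    # Numerically unsafe reference: log of einsum of exponentials.
    return np.log(np.einsum("ab,bc->ac", np.exp(x), np.exp(y)))

def stable_logeinsumexp(x, y):
    x_shift = x.max(axis=1, keepdims=True)   # "b" is contracted away
    y_shift = y.max(axis=0, keepdims=True)
    result = np.log(np.einsum("ab,bc->ac",
                              np.exp(x - x_shift),
                              np.exp(y - y_shift)))
    return result + x_shift + y_shift        # broadcast shifts back onto (a, c)

x = np.random.randn(2, 3)
y = np.random.randn(3, 4)
assert np.allclose(naive_logeinsumexp(x, y), stable_logeinsumexp(x, y))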