def squeeze_as_einsum(x: JaxExpression, params: Params) -> Einsum:
  """Expresses a squeeze as an equivalent `Einsum`.

  Each axis of `x` is labelled with one einsum letter. The output subscript
  keeps every letter except those of the axes listed in
  `params['dimensions']`.

  Args:
    x: the expression being squeezed.
    params: primitive parameters; `params['dimensions']` holds the axes to
      remove.

  Returns:
    An `Einsum` mapping the input subscripts to the retained subscripts.
  """
  removed = params['dimensions']
  rank = len(x.shape)
  in_subscripts = ''.join(it.islice(einsum.einsum_letters(), rank))
  out_subscripts = ''.join(
      in_subscripts[axis] for axis in range(rank) if axis not in removed)
  return Einsum(f'{in_subscripts}->{out_subscripts}', (x,))
def reduce_sum_as_einsum(x: JaxExpression, params: Params) -> Einsum:
  """Expresses a sum reduction as an equivalent `Einsum`.

  Summed axes are simply omitted from the output subscript. Einsum sums over
  any input letter that does not appear on the right-hand side.

  Args:
    x: the expression being reduced.
    params: primitive parameters; `params['axes']` holds the axes summed out.

  Returns:
    An `Einsum` whose output subscript lists only the non-reduced axes.
  """
  reduced_axes = params['axes']
  rank = len(x.shape)
  in_subscripts = ''.join(it.islice(einsum.einsum_letters(), rank))
  kept_letters = (
      in_subscripts[axis] for axis in range(rank) if axis not in reduced_axes)
  out_subscripts = ''.join(kept_letters)
  return Einsum(f'{in_subscripts}->{out_subscripts}', (x,))
def dot_as_einsum(x: JaxExpression, y: JaxExpression,
                  params: Params) -> Einsum:
  """Expresses a general dot product as an equivalent `Einsum`.

  Both operands start with distinct letters drawn from one shared letter
  stream. Each contracting or batch axis of `y` is then relabelled with the
  letter of its paired axis in `x`.

  Args:
    x: left-hand operand.
    y: right-hand operand.
    params: primitive parameters; `params['dimension_numbers']` is
      `((x_contract, y_contract), (x_batch, y_batch))`.

  Returns:
    An `Einsum` over `(x, y)`.
  """
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = (
      params['dimension_numbers'])
  lhs_rank = len(x.shape)
  rhs_rank = len(y.shape)
  letters = einsum.einsum_letters()
  # Drawing from the same iterator keeps the initial x and y letters disjoint.
  lhs_subscripts = ''.join(it.islice(letters, lhs_rank))
  rhs_letters = list(it.islice(letters, rhs_rank))
  # Paired axes must share a letter so einsum contracts or batches them.
  for lhs_axis, rhs_axis in zip(lhs_contract + lhs_batch,
                                rhs_contract + rhs_batch):
    rhs_letters[rhs_axis] = lhs_subscripts[lhs_axis]
  rhs_subscripts = ''.join(rhs_letters)
  # The output is built as: batch letters (in x's batch order), then x's
  # letters absent from y, then y's letters absent from x.
  batch_part = ''.join(lhs_subscripts[axis] for axis in lhs_batch)
  lhs_free = ''.join(c for c in lhs_subscripts if c not in rhs_subscripts)
  rhs_free = ''.join(c for c in rhs_subscripts if c not in lhs_subscripts)
  out_subscripts = batch_part + lhs_free + rhs_free
  formula = f'{lhs_subscripts},{rhs_subscripts}->{out_subscripts}'
  return Einsum(formula, (x, y))
def transpose_as_einsum(x: JaxExpression, params: Params) -> Einsum:
  """Expresses a transpose as an equivalent `Einsum`.

  The output subscript lists the input letters in the order given by
  `params['permutation']`.

  Args:
    x: the expression being transposed.
    params: primitive parameters; `params['permutation']` gives, for each
      output axis, the input axis it comes from.

  Returns:
    An `Einsum` that permutes the axes of `x`.
  """
  in_subscripts = ''.join(it.islice(einsum.einsum_letters(), len(x.shape)))
  out_subscripts = ''.join(
      in_subscripts[axis] for axis in params['permutation'])
  return Einsum(f'{in_subscripts}->{out_subscripts}', (x,))