Example #1
0
def squeeze_as_einsum(x: JaxExpression, params: Params) -> Einsum:
  """Converts a squeeze into an `Einsum`.

  Each axis of `x` is labelled with its own einsum letter. The output
  subscript keeps the letters of every axis not listed in
  `params['dimensions']`, so the squeezed axes disappear from the formula.

  Args:
    x: The expression being squeezed.
    params: Primitive parameters; `params['dimensions']` holds the axes to
      remove. Negative axes are interpreted relative to the rank of `x`.

  Returns:
    An `Einsum` with the single operand `x`.
  """
  x_ndim = len(x.shape)
  # Normalize negative axes. Otherwise e.g. -1 never matches an index from
  # range(x_ndim), and that axis would silently be kept in the output.
  dimensions = {d + x_ndim if d < 0 else d for d in params['dimensions']}
  x_dims = ''.join(it.islice(einsum.einsum_letters(), x_ndim))
  out_dims = ''.join(x_dims[i] for i in range(x_ndim) if i not in dimensions)
  return Einsum(f'{x_dims}->{out_dims}', (x,))
Example #2
0
def reduce_sum_as_einsum(x: JaxExpression, params: Params) -> Einsum:
  """Converts a reduce sum into an `Einsum`.

  Every axis of `x` is labelled with its own einsum letter. Letters of the
  summed axes are omitted from the output subscript; under einsum semantics
  those axes are therefore summed over.

  Args:
    x: The expression being reduced.
    params: Primitive parameters; `params['axes']` holds the axes to sum over.
      Negative axes are interpreted relative to the rank of `x`.

  Returns:
    An `Einsum` with the single operand `x`.
  """
  ndim = len(x.shape)
  # Normalize negative axes. Otherwise e.g. -1 never matches an index from
  # range(ndim), and that axis would silently be left un-summed.
  axes = {a + ndim if a < 0 else a for a in params['axes']}
  x_dims = ''.join(it.islice(einsum.einsum_letters(), ndim))
  out_dims = ''.join(x_dims[i] for i in range(ndim) if i not in axes)
  formula = f'{x_dims}->{out_dims}'
  return Einsum(formula, (x,))
Example #3
0
def dot_as_einsum(x: JaxExpression, y: JaxExpression, params: Params) -> Einsum:
  """Converts a dot product into an `Einsum`.

  `params['dimension_numbers']` is unpacked as
  `((x_contract, y_contract), (x_batch, y_batch))`. Axes of `x` get fresh
  letters. Axes of `y` get fresh letters too, except that each contracted or
  batched axis of `y` reuses the letter of its paired axis in `x`.

  The output subscript lists, in this order:
    1. the batch letters, in `x_batch` order;
    2. the letters of `x` not shared with `y`, in axis order;
    3. the letters of `y` not shared with `x`, in axis order.
  Contracting letters appear in both operands but not in the output.

  Args:
    x: Left-hand operand.
    y: Right-hand operand.
    params: Primitive parameters containing `'dimension_numbers'`.

  Returns:
    An `Einsum` over the operands `(x, y)`.
  """
  (x_contract, y_contract), (x_batch, y_batch) = params['dimension_numbers']
  x_ndim, y_ndim = len(x.shape), len(y.shape)
  # A single shared iterator guarantees y's fresh letters never collide with
  # the letters already handed to x.
  letter_iter = einsum.einsum_letters()
  x_dims = ''.join(it.islice(letter_iter, x_ndim))
  y_letters = list(it.islice(letter_iter, y_ndim))
  # `it.chain` instead of `+`: concatenation raises TypeError for mixed
  # tuple/list inputs and adds element-wise for numpy arrays.
  for x_dim, y_dim in zip(it.chain(x_contract, x_batch),
                          it.chain(y_contract, y_batch)):
    y_letters[y_dim] = x_dims[x_dim]
  y_dims = ''.join(y_letters)

  x_set, y_set = set(x_dims), set(y_dims)
  batch_letters = [x_dims[dim] for dim in x_batch]
  x_free = [xd for xd in x_dims if xd not in y_set]
  y_free = [yd for yd in y_dims if yd not in x_set]
  out_dims = ''.join(batch_letters + x_free + y_free)
  return Einsum(f'{x_dims},{y_dims}->{out_dims}', (x, y))
Example #4
0
def transpose_as_einsum(x: JaxExpression, params: Params) -> Einsum:
  """Converts a transpose into an `Einsum`.

  The input subscript labels the axes of `x` with consecutive einsum letters.
  The output subscript lists those same letters in the order given by
  `params['permutation']`.
  """
  letters = ''.join(it.islice(einsum.einsum_letters(), len(x.shape)))
  permuted = ''.join(letters[axis] for axis in params['permutation'])
  formula = f'{letters}->{permuted}'
  return Einsum(formula, (x,))