Example #1
def _einsum_grad(op, grad):
    """Gradient for a two-input einsum.

    The gradient with respect to each operand is an einsum of the upstream
    gradient with the *other* operand, under an equation in which that
    operand's subscripts and the output's are swapped.
    """
    equation = op.get_attr('equation')
    inputs, output = equation.split('->')
    left, right = inputs.split(',')

    return [
        gen_xla_ops.xla_einsum(grad,
                               op.inputs[1],
                               equation='{},{}->{}'.format(
                                   output, right, left),
                               name=None),
        gen_xla_ops.xla_einsum(grad,
                               op.inputs[0],
                               equation='{},{}->{}'.format(
                                   output, left, right),
                               name=None)
    ]
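The rule encoded above: the gradient of a two-input einsum with respect to one operand is itself an einsum of the upstream gradient with the other operand. Below is a minimal sketch checking the rule numerically with plain `tf.einsum` and `tf.GradientTape`; the tensors and the `ij,jk->ik` equation are illustrative, not taken from the module above.

```python
# A minimal sketch checking the subscript-swapping rule with plain tf.einsum;
# the tensors and equation here are illustrative.
import tensorflow as tf

a = tf.random.normal([2, 3])
b = tf.random.normal([3, 4])

with tf.GradientTape() as tape:
    tape.watch([a, b])
    c = tf.einsum('ij,jk->ik', a, b)
    loss = tf.reduce_sum(c)

grad_a, grad_b = tape.gradient(loss, [a, b])
upstream = tf.ones_like(c)  # d(loss)/dC for a plain sum

# The same swap as _einsum_grad: 'output,right->left' and 'output,left->right'.
manual_a = tf.einsum('ik,jk->ij', upstream, b)
manual_b = tf.einsum('ik,ij->jk', upstream, a)

tf.debugging.assert_near(grad_a, manual_a)
tf.debugging.assert_near(grad_b, manual_b)
```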
Example #2
def _einsum_grad(op, grad):
  """Gradient for a two-input einsum (same subscript-swapping rule as above)."""
  equation = op.get_attr('equation')
  inputs, output = equation.split('->')
  left, right = inputs.split(',')

  return [
      gen_xla_ops.xla_einsum(
          grad,
          op.inputs[1],
          equation='{},{}->{}'.format(output, right, left),
          name=None),
      gen_xla_ops.xla_einsum(
          grad,
          op.inputs[0],
          equation='{},{}->{}'.format(output, left, right),
          name=None)
  ]
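The equation rewriting itself is pure string manipulation, so it can be checked in isolation. A standalone sketch (the `backward_equations` helper is illustrative, not part of the module):

```python
# Standalone sketch of the equation rewriting done in _einsum_grad above
# (pure string work, no TensorFlow required).
def backward_equations(equation):
    inputs, output = equation.split('->')
    left, right = inputs.split(',')
    return ('{},{}->{}'.format(output, right, left),
            '{},{}->{}'.format(output, left, right))

# Matrix multiplication and outer product, respectively.
assert backward_equations('ij,jk->ik') == ('ik,jk->ij', 'ik,ij->jk')
assert backward_equations('i,j->ij') == ('ij,j->i', 'ij,i->j')
```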
Example #3
def _einsum_v1(equation, *inputs, **kwargs):
    """Legacy implementation of einsum without using EinsumOp."""
    name = kwargs.pop('name', None)
    if kwargs:
        raise TypeError(
            'invalid keyword arguments for this function: ' +
            ', '.join([format(key) for key in sorted(list(kwargs.keys()))]))
    with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
        inputs = list(inputs)
        input_shapes = [x.shape for x in inputs]
        input_axis_labels, output_axis_labels = (
            _einsum_v1_parse_and_resolve_equation(equation, input_shapes))

        axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

        for a in axis_labels:
            for input_labels in input_axis_labels:
                if (len(input_axis_labels) == 1 and input_labels.count(a) == 2
                        and input_labels == input_labels[::-1]
                        and '->' not in equation):
                    return math_ops.trace(inputs[0])
                if input_labels.count(a) > 1:
                    raise ValueError(
                        'Subscript not supported: an axis appears more than once: %s'
                        % input_labels)
        for a in axis_labels:
            input_count = sum(1 for s in input_axis_labels if a in s)
            if input_count > 2 and a not in output_axis_labels:
                logging.warn(
                    'Falling back to exponential-space implementation of einsum()'
                    ' because index "%s" is summed over more than two inputs.',
                    a)
                return _exponential_space_einsum_v1(equation, *inputs)

        # Use xla_einsum if executing on TPU and if the operation is a
        # two-input einsum supported by XlaEinsumOp.
        if _enclosing_tpu_context() is not None and len(inputs) == 2:
            return gen_xla_ops.xla_einsum(
                inputs[0], inputs[1], input_axis_labels[0] + ',' +
                input_axis_labels[1] + '->' + output_axis_labels)
        # Contract pairwise, left to right: keep a running result and sum out
        # labels that appear in both operands but not in the output.
        temp = inputs[0]
        temp_axis_labels = input_axis_labels[0]
        for i in xrange(len(inputs) - 1):
            axes_to_sum = (
                set(temp_axis_labels)
                & set(input_axis_labels[i + 1]) - set(output_axis_labels))
            temp, temp_axis_labels = _einsum_v1_reduction(
                temp, temp_axis_labels, inputs[i + 1],
                input_axis_labels[i + 1], axes_to_sum)

        # Sum out any remaining labels that do not appear in the output.
        missing_indices = set(temp_axis_labels) - set(output_axis_labels)
        if missing_indices:
            axis = [
                i for i, a in enumerate(temp_axis_labels)
                if a not in output_axis_labels
            ]
            temp = math_ops.reduce_sum(temp, axis=axis)
            temp_axis_labels = ''.join(a for a in temp_axis_labels
                                       if a in output_axis_labels)
        if sorted(temp_axis_labels) != sorted(output_axis_labels):
            raise ValueError('Invalid equation: %s' % equation)

        perm = [temp_axis_labels.index(a) for a in output_axis_labels]
        return _transpose_if_necessary(temp, perm)
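The heart of `_einsum_v1` is that left-to-right fold: keep a running `temp` tensor, contract it with one input at a time, and sum out each label as soon as nothing downstream needs it. Below is a minimal NumPy sketch of the same strategy; the `chained_einsum` helper is illustrative, and the real code above additionally handles traces, TPU dispatch, and the exponential-space fallback.

```python
# Illustrative NumPy sketch of the pairwise-reduction strategy above.
import numpy as np

def chained_einsum(equation, *inputs):
    in_specs, out_spec = equation.split('->')
    labels = in_specs.split(',')
    temp, temp_labels = np.asarray(inputs[0]), labels[0]
    for i in range(1, len(inputs)):
        # A label may be summed out now only if neither the final output
        # nor any operand still to come mentions it.
        still_needed = set(out_spec).union(*labels[i + 1:])
        merged = dict.fromkeys(temp_labels + labels[i])  # ordered, unique
        step_out = ''.join(l for l in merged if l in still_needed)
        step_eq = '{},{}->{}'.format(temp_labels, labels[i], step_out)
        temp = np.einsum(step_eq, temp, inputs[i])
        temp_labels = step_out
    # Final transpose into the requested output order.
    return np.einsum('{}->{}'.format(temp_labels, out_spec), temp)

a, b, c = np.ones((2, 3)), np.ones((3, 4)), np.ones((4, 5))
np.testing.assert_allclose(chained_einsum('ij,jk,kl->il', a, b, c),
                           np.einsum('ij,jk,kl->il', a, b, c))
```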
Example #4
def einsum(equation, *inputs, **kwargs):
    """Tensor contraction over specified indices and outer product.

  This function returns a tensor whose elements are defined by `equation`,
  which is written in a shorthand form inspired by the Einstein summation
  convention.  As an example, consider multiplying two matrices
  A and B to form a matrix C.  The elements of C are given by:

  ```
    C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The corresponding `equation` is:

  ```
    ij,jk->ik
  ```

  In general, the `equation` is obtained from the more familiar element-wise
  equation by
    1. removing variable names, brackets, and commas,
    2. replacing "*" with ",",
    3. dropping summation signs, and
    4. moving the output to the right, and replacing "=" with "->".

  Many common operations can be expressed in this way.  For example:

  ```python
  # Matrix multiplication
  >>> einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

  # Dot product
  >>> einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

  # Outer product
  >>> einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

  # Transpose
  >>> einsum('ij->ji', m)  # output[j,i] = m[i,j]

  # Trace
  >>> einsum('ii', m)  # output = trace(m) = sum_i m[i, i]

  # Batch matrix multiplication
  >>> einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
  ```

  To enable and control broadcasting, use an ellipsis.  For example, to do
  batch matrix multiplication, you could use:

  ```python
  >>> einsum('...ij,...jk->...ik', u, v)
  ```

  This function behaves like `numpy.einsum`, but does not support:

  * Subscripts where an axis appears more than once for a single input
    (e.g. `ijj,k->ik`) unless it is a trace (e.g. `ijji`).

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    name: A name for the operation (optional).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    ValueError: If
      - the format of `equation` is incorrect,
      - the number of inputs implied by `equation` does not match `len(inputs)`,
      - an axis appears in the output subscripts but not in any of the inputs,
      - the number of dimensions of an input differs from the number of
        indices in its subscript, or
      - the input shapes are inconsistent along a particular axis.
  """
    name = kwargs.pop('name', None)
    if kwargs:
        raise TypeError(
            'invalid keyword arguments for this function: ' +
            ', '.join([format(key) for key in sorted(list(kwargs.keys()))]))
    with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
        inputs = list(inputs)
        input_shapes = [x.get_shape() for x in inputs]
        input_axis_labels, output_axis_labels = _einsum_parse_and_resolve_equation(
            equation, input_shapes)

        axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

        for a in axis_labels:
            for input_labels in input_axis_labels:
                if (len(input_axis_labels) == 1 and input_labels.count(a) == 2
                        and input_labels == input_labels[::-1]
                        and '->' not in equation):
                    return math_ops.trace(inputs[0])
                if input_labels.count(a) > 1:
                    raise ValueError(
                        'Subscript not supported: an axis appears more than once: %s'
                        % input_labels)
        for a in axis_labels:
            input_count = sum(1 for s in input_axis_labels if a in s)
            if input_count > 2 and a not in output_axis_labels:
                logging.warn(
                    'Falling back to exponential-space implementation of einsum()'
                    ' because index "%s" is summed over more than two inputs.',
                    a)
                return _exponential_space_einsum(equation, *inputs)

        # Use xla_einsum if executing on TPU and if the operation is a
        # two-input einsum supported by XlaEinsumOp.
        if _enclosing_tpu_context() is not None and len(inputs) == 2:
            return gen_xla_ops.xla_einsum(
                inputs[0], inputs[1], input_axis_labels[0] + ',' +
                input_axis_labels[1] + '->' + output_axis_labels)
        temp = inputs[0]
        temp_axis_labels = input_axis_labels[0]
        for i in xrange(len(inputs) - 1):
            axes_to_sum = (
                set(temp_axis_labels)
                & set(input_axis_labels[i + 1]) - set(output_axis_labels))
            temp, temp_axis_labels = _einsum_reduction(
                temp, temp_axis_labels, inputs[i + 1],
                input_axis_labels[i + 1], axes_to_sum)

        missing_indices = set(temp_axis_labels) - set(output_axis_labels)
        if missing_indices:
            axis = [
                i for i, a in enumerate(temp_axis_labels)
                if a not in output_axis_labels
            ]
            temp = math_ops.reduce_sum(temp, axis=axis)
            temp_axis_labels = ''.join(a for a in temp_axis_labels
                                       if a in output_axis_labels)
        if sorted(temp_axis_labels) != sorted(output_axis_labels):
            raise ValueError('Invalid equation: %s' % equation)

        perm = [temp_axis_labels.index(a) for a in output_axis_labels]
        return _transpose_if_necessary(temp, perm)
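The ellipsis form mentioned in the docstring is easiest to see with concrete shapes: the `...` dimensions ride along unchanged, giving batch behaviour for free. A small usage sketch (the shapes are illustrative):

```python
# Small usage sketch of the ellipsis (broadcasting) form described in the
# docstring above.
import tensorflow as tf

u = tf.random.normal([8, 2, 3])  # a batch of eight 2x3 matrices
v = tf.random.normal([8, 3, 4])  # a batch of eight 3x4 matrices

out = tf.einsum('...ij,...jk->...ik', u, v)
print(out.shape)  # (8, 2, 4): the batch axis rides along under '...'
```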
Example #5
def einsum(equation, *inputs, **kwargs):
  """A generalized contraction between tensors of arbitrary dimension.

  This function returns a tensor whose elements are defined by `equation`,
  which is written in a shorthand form inspired by the Einstein summation
  convention.  As an example, consider multiplying two matrices
  A and B to form a matrix C.  The elements of C are given by:

  ```
    C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The corresponding `equation` is:

  ```
    ij,jk->ik
  ```

  In general, the `equation` is obtained from the more familiar element-wise
  equation by
    1. removing variable names, brackets, and commas,
    2. replacing "*" with ",",
    3. dropping summation signs, and
    4. moving the output to the right, and replacing "=" with "->".

  Many common operations can be expressed in this way.  For example:

  ```python
  # Matrix multiplication
  >>> einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

  # Dot product
  >>> einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

  # Outer product
  >>> einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

  # Transpose
  >>> einsum('ij->ji', m)  # output[j,i] = m[i,j]

  # Trace
  >>> einsum('ii', m)  # output = trace(m) = sum_i m[i, i]

  # Batch matrix multiplication
  >>> einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
  ```

  To enable and control broadcasting, use an ellipsis.  For example, to do
  batch matrix multiplication, you could use:

  ```python
  >>> einsum('...ij,...jk->...ik', u, v)
  ```

  This function behaves like `numpy.einsum`, but does not support:

  * Subscripts where an axis appears more than once for a single input
    (e.g. `ijj,k->ik`) unless it is a trace (e.g. `ijji`).

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    name: A name for the operation (optional).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    ValueError: If
      - the format of `equation` is incorrect,
      - the number of inputs implied by `equation` does not match `len(inputs)`,
      - an axis appears in the output subscripts but not in any of the inputs,
      - the number of dimensions of an input differs from the number of
        indices in its subscript, or
      - the input shapes are inconsistent along a particular axis.
  """
  name = kwargs.pop('name', None)
  if kwargs:
    raise TypeError('invalid keyword arguments for this function: ' + ', '.join(
        [format(key) for key in sorted(list(kwargs.keys()))]))
  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = [x.get_shape() for x in inputs]
    input_axis_labels, output_axis_labels = _einsum_parse_and_resolve_equation(
        equation, input_shapes)

    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    for a in axis_labels:
      for input_labels in input_axis_labels:
        if (len(input_axis_labels) == 1 and input_labels.count(a) == 2 and
            input_labels == input_labels[::-1] and '->' not in equation):
          return math_ops.trace(inputs[0])
        if input_labels.count(a) > 1:
          raise ValueError(
              'Subscript not supported: an axis appears more than once: %s' %
              input_labels)
    for a in axis_labels:
      input_count = sum(1 for s in input_axis_labels if a in s)
      if input_count > 2 and a not in output_axis_labels:
        logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', a)
        return _exponential_space_einsum(equation, *inputs)

    # Use xla_einsum if executing on TPU and if the operation is a
    # two-input einsum supported by XlaEinsumOp.
    if _enclosing_tpu_context() is not None and len(inputs) == 2:
      return gen_xla_ops.xla_einsum(
          inputs[0], inputs[1], input_axis_labels[0] + ',' +
          input_axis_labels[1] + '->' + output_axis_labels)
    temp = inputs[0]
    temp_axis_labels = input_axis_labels[0]
    for i in xrange(len(inputs) - 1):
      axes_to_sum = (
          set(temp_axis_labels) &
          set(input_axis_labels[i + 1]) - set(output_axis_labels))
      temp, temp_axis_labels = _einsum_reduction(
          temp, temp_axis_labels, inputs[i + 1], input_axis_labels[i + 1],
          axes_to_sum)

    missing_indices = set(temp_axis_labels) - set(output_axis_labels)
    if missing_indices:
      axis = [
          i for i, a in enumerate(temp_axis_labels)
          if a not in output_axis_labels
      ]
      temp = math_ops.reduce_sum(temp, axis=axis)
      temp_axis_labels = ''.join(
          a for a in temp_axis_labels if a in output_axis_labels)
    if sorted(temp_axis_labels) != sorted(output_axis_labels):
      raise ValueError('Invalid equation: %s' % equation)

    perm = [temp_axis_labels.index(a) for a in output_axis_labels]
    return _transpose_if_necessary(temp, perm)
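Two special cases handled near the top of this function can be exercised directly: the arrow-less trace shorthand `'ii'`, and an index summed over more than two inputs while absent from the output. A short sketch (the tensors are illustrative; with current `tf.einsum` both calls just work, whereas the v1 code above would log the fallback warning for the second):

```python
# Sketch exercising two special cases from the function above.
import tensorflow as tf

m = tf.constant([[1.0, 2.0], [3.0, 4.0]])
print(tf.einsum('ii', m))  # 5.0: the arrow-less form is handled as a trace

# 'i' is summed over three inputs and is absent from the (scalar) output;
# the v1 implementation above would warn and take its exponential-space
# fallback for this equation.
x = tf.ones([3])
print(tf.einsum('i,i,i->', x, x, x))  # 3.0
```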
Example #6
def _einsum_v2(equation, *inputs, **kwargs):
    """Implementation of einsum utilizing opt_einsum and EinsumOp."""
    name = kwargs.pop('name', None)
    optimize = kwargs.pop('optimize', 'greedy')
    if kwargs:
        msg = 'Invalid keyword arguments for einsum: {}'
        raise TypeError(msg.format(', '.join(kwargs)))

    with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
        inputs = list(inputs)
        input_shapes = []
        for operand in inputs:
            if isinstance(operand.shape, tensor_shape.TensorShape):
                input_shapes.append(
                    operand.shape.as_list() if operand.shape else None)
            else:
                input_shapes.append(list(operand.shape))
        # Validate and sanitize the equation and resolve static input shapes, as
        # opt_einsum requires that all shapes be a tuple of positive integers.
        # Also remove ellipses from the equation, since opt_einsum replaces them
        # with named labels, after which broadcasting between different shapes
        # or ranks would no longer work (e.g. [1, 1, 2] would not broadcast
        # with [3, 1]).
        resolved_equation, resolved_input_shapes, ellipsis_label = (
            _einsum_v2_parse_and_resolve_equation(equation, input_shapes))

        # Use xla_einsum if executing on TPU and if the operation is a
        # two-input einsum supported by XlaEinsumOp.
        has_enclosing_tpu_context = _enclosing_tpu_context() is not None

        if len(inputs) <= 2:  # No need to call opt_einsum.
            # Restore the ellipses that were removed for opt_einsum.
            if ellipsis_label:
                resolved_equation = resolved_equation.replace(
                    ellipsis_label, '...')
            if has_enclosing_tpu_context and len(inputs) == 2:
                return gen_xla_ops.xla_einsum(inputs[0], inputs[1],
                                              resolved_equation)
            return gen_linalg_ops.einsum(inputs, resolved_equation)

        # Send fully specified shapes to opt_einsum, since it cannot handle unknown
        # dimensions. For unknown dimensions, we guess that the dimension equals 1.
        # Instead of creating Tensors or NumPy arrays with the specified shape,
        # create a dummy `shaped` object with a `shape` property.
        shaped = collections.namedtuple('shaped', ['shape'])
        shaped_inputs = tuple(
            [shaped(tuple(shape)) for shape in resolved_input_shapes])
        # opt_einsum breaks down an n-ary einsum operation into n-1 binary einsums.
        # Obtain the sequence of equations and the indices of operands involved in
        # each einsum operation.
        indices_and_equations = _get_opt_einsum_contract_path(
            resolved_equation, shaped_inputs, optimize)
        for operand_indices, binary_equation in indices_and_equations:
            if ellipsis_label:
                # Restore the ellipses that were removed for opt_einsum.
                binary_equation = binary_equation.replace(
                    ellipsis_label, '...')
            operands = list(map(inputs.pop, operand_indices))
            # Use xla_einsum if executing on TPU and if the operation is a
            # two-input einsum supported by XlaEinsumOp.
            if has_enclosing_tpu_context and len(operands) == 2:
                inputs.append(
                    gen_xla_ops.xla_einsum(operands[0], operands[1],
                                           binary_equation))
            else:
                inputs.append(gen_linalg_ops.einsum(operands, binary_equation))
        return inputs[0]
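The `shaped` namedtuple trick works because `opt_einsum.contract_path` only ever reads a `.shape` attribute from its operands. Below is a standalone sketch of the n-ary-to-binary decomposition this function relies on, assuming `_get_opt_einsum_contract_path` wraps `opt_einsum.contract_path` in roughly this way (requires the `opt_einsum` package):

```python
# Standalone sketch of the decomposition; `shaped` mirrors the dummy-object
# trick above, since contract_path only inspects `.shape`.
import collections

import opt_einsum

shaped = collections.namedtuple('shaped', ['shape'])
operands = [shaped((2, 3)), shaped((3, 4)), shaped((4, 5))]

_, path_info = opt_einsum.contract_path(
    'ab,bc,cd->ad', *operands, optimize='greedy')
for contraction in path_info.contraction_list:
    operand_indices, _, binary_equation = contraction[:3]
    print(operand_indices, binary_equation)
# Typically two binary steps here: contracting 'ab,bc' first, then the
# intermediate 'ac' with 'cd'.
```

Each step reports which positions to remove from the running operand list, which is what lets the loop above consume them with successive `inputs.pop` calls before appending the intermediate result.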