def _einsum_grad(op, grad):
  """Gradient for XlaEinsum: contract `grad` against each remaining input."""
  equation = op.get_attr('equation')
  inputs, output = equation.split('->')
  left, right = inputs.split(',')
  # For a forward equation 'left,right->output', the gradient w.r.t. the
  # left operand uses 'output,right->left', and the gradient w.r.t. the
  # right operand uses 'output,left->right'.
  return [
      gen_xla_ops.xla_einsum(
          grad,
          op.inputs[1],
          equation='{},{}->{}'.format(output, right, left),
          name=None),
      gen_xla_ops.xla_einsum(
          grad,
          op.inputs[0],
          equation='{},{}->{}'.format(output, left, right),
          name=None)
  ]
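# Illustrative sketch, not part of the original file: it checks the
# swapped-equation gradient rule above with plain NumPy. For a forward
# einsum 'left,right->output', _einsum_grad contracts the upstream gradient
# against the other operand using 'output,right->left' and
# 'output,left->right'. The helper name below is hypothetical.
def _demo_einsum_grad_equations():
  """Numerically verifies the gradient equations for 'ij,jk->ik'."""
  import numpy as np  # Local import; demo only.
  a = np.random.randn(2, 3)
  b = np.random.randn(3, 4)
  grad = np.ones((2, 4))  # Upstream gradient of C = einsum('ij,jk->ik', a, b).
  grad_a = np.einsum('ik,jk->ij', grad, b)  # 'output,right->left'
  grad_b = np.einsum('ik,ij->jk', grad, a)  # 'output,left->right'
  # Analytic gradients of sum(A @ B): dA = grad @ B.T and dB = A.T @ grad.
  assert np.allclose(grad_a, grad.dot(b.T))
  assert np.allclose(grad_b, a.T.dot(grad))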
def _einsum_v1(equation, *inputs, **kwargs):
  """Legacy implementation of einsum without using EinsumOp."""
  name = kwargs.pop('name', None)
  if kwargs:
    raise TypeError(
        'invalid keyword arguments for this function: ' +
        ', '.join([format(key) for key in sorted(list(kwargs.keys()))]))
  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = [x.shape for x in inputs]
    input_axis_labels, output_axis_labels = (
        _einsum_v1_parse_and_resolve_equation(equation, input_shapes))

    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    for a in axis_labels:
      for input_labels in input_axis_labels:
        if (len(input_axis_labels) == 1 and input_labels.count(a) == 2 and
            input_labels == input_labels[::-1] and '->' not in equation):
          return math_ops.trace(inputs[0])
        if input_labels.count(a) > 1:
          raise ValueError(
              'Subscript not supported: an axis appears more than once: %s' %
              input_labels)

    for a in axis_labels:
      input_count = sum(1 for s in input_axis_labels if a in s)
      if input_count > 2 and a not in output_axis_labels:
        logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', a)
        return _exponential_space_einsum_v1(equation, *inputs)

    # Use xla_einsum if executing on TPU and if the operation is a 2 input
    # einsum supported by XlaEinsumOp.
    if _enclosing_tpu_context() is not None and len(inputs) == 2:
      return gen_xla_ops.xla_einsum(
          inputs[0], inputs[1], input_axis_labels[0] + ',' +
          input_axis_labels[1] + '->' + output_axis_labels)

    temp = inputs[0]
    temp_axis_labels = input_axis_labels[0]
    for i in xrange(len(inputs) - 1):
      axes_to_sum = (
          set(temp_axis_labels) &
          set(input_axis_labels[i + 1]) - set(output_axis_labels))
      temp, temp_axis_labels = _einsum_v1_reduction(temp, temp_axis_labels,
                                                    inputs[i + 1],
                                                    input_axis_labels[i + 1],
                                                    axes_to_sum)

    missing_indices = set(temp_axis_labels) - set(output_axis_labels)
    if missing_indices:
      axis = [
          i for i, a in enumerate(temp_axis_labels)
          if a not in output_axis_labels
      ]
      temp = math_ops.reduce_sum(temp, axis=axis)
      temp_axis_labels = ''.join(
          a for a in temp_axis_labels if a in output_axis_labels)
    if sorted(temp_axis_labels) != sorted(output_axis_labels):
      raise ValueError('Invalid equation: %s' % equation)
    perm = [temp_axis_labels.index(a) for a in output_axis_labels]
    return _transpose_if_necessary(temp, perm)
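# Illustrative sketch, not part of the original file: the loop in
# _einsum_v1 evaluates an n-ary einsum as a chain of pairwise reductions,
# summing at each step only over labels shared with the next operand and
# absent from the output. The demo below (hypothetical name) traces that
# same decomposition for 'ij,jk,kl->il' using NumPy.
def _demo_pairwise_reduction():
  """Evaluates 'ij,jk,kl->il' pairwise, as _einsum_v1's loop does."""
  import numpy as np  # Local import; demo only.
  a = np.random.randn(2, 3)
  b = np.random.randn(3, 4)
  c = np.random.randn(4, 5)
  temp = np.einsum('ij,jk->ik', a, b)       # axes_to_sum == {'j'}
  result = np.einsum('ik,kl->il', temp, c)  # axes_to_sum == {'k'}
  assert np.allclose(result, np.einsum('ij,jk,kl->il', a, b, c))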
def einsum(equation, *inputs, **kwargs):
  """Tensor contraction over specified indices and outer product.

  This function returns a tensor whose elements are defined by `equation`,
  which is written in a shorthand form inspired by the Einstein summation
  convention.  As an example, consider multiplying two matrices
  A and B to form a matrix C.  The elements of C are given by:

  ```
    C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The corresponding `equation` is:

  ```
    ij,jk->ik
  ```

  In general, the `equation` is obtained from the more familiar element-wise
  equation by
    1. removing variable names, brackets, and commas,
    2. replacing "*" with ",",
    3. dropping summation signs, and
    4. moving the output to the right, and replacing "=" with "->".

  Many common operations can be expressed in this way.  For example:

  ```python
  # Matrix multiplication
  >>> einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

  # Dot product
  >>> einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

  # Outer product
  >>> einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

  # Transpose
  >>> einsum('ij->ji', m)  # output[j,i] = m[i,j]

  # Trace
  >>> einsum('ii', m)  # output = trace(m) = sum_i m[i, i]

  # Batch matrix multiplication
  >>> einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
  ```

  To enable and control broadcasting, use an ellipsis.  For example, to do
  batch matrix multiplication, you could use:

  ```python
  >>> einsum('...ij,...jk->...ik', u, v)
  ```

  This function behaves like `numpy.einsum`, but does not support:

  * Subscripts where an axis appears more than once for a single input
    (e.g. `ijj,k->ik`) unless it is a trace (e.g. `ijji`).

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    name: A name for the operation (optional).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    ValueError: If
      - the format of `equation` is incorrect,
      - the number of inputs implied by `equation` does not match
        `len(inputs)`,
      - an axis appears in the output subscripts but not in any of the inputs,
      - the number of dimensions of an input differs from the number of
        indices in its subscript, or
      - the input shapes are inconsistent along a particular axis.
  """
  name = kwargs.pop('name', None)
  if kwargs:
    raise TypeError(
        'invalid keyword arguments for this function: ' +
        ', '.join([format(key) for key in sorted(list(kwargs.keys()))]))
  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = [x.get_shape() for x in inputs]
    input_axis_labels, output_axis_labels = _einsum_parse_and_resolve_equation(
        equation, input_shapes)

    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    for a in axis_labels:
      for input_labels in input_axis_labels:
        if (len(input_axis_labels) == 1 and input_labels.count(a) == 2 and
            input_labels == input_labels[::-1] and '->' not in equation):
          return math_ops.trace(inputs[0])
        if input_labels.count(a) > 1:
          raise ValueError(
              'Subscript not supported: an axis appears more than once: %s' %
              input_labels)

    for a in axis_labels:
      input_count = sum(1 for s in input_axis_labels if a in s)
      if input_count > 2 and a not in output_axis_labels:
        logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', a)
        return _exponential_space_einsum(equation, *inputs)

    # Use xla_einsum if executing on TPU and if the operation is a 2 input
    # einsum supported by XlaEinsumOp.
    if _enclosing_tpu_context() is not None and len(inputs) == 2:
      return gen_xla_ops.xla_einsum(
          inputs[0], inputs[1], input_axis_labels[0] + ',' +
          input_axis_labels[1] + '->' + output_axis_labels)

    temp = inputs[0]
    temp_axis_labels = input_axis_labels[0]
    for i in xrange(len(inputs) - 1):
      axes_to_sum = (
          set(temp_axis_labels) &
          set(input_axis_labels[i + 1]) - set(output_axis_labels))
      temp, temp_axis_labels = _einsum_reduction(temp, temp_axis_labels,
                                                 inputs[i + 1],
                                                 input_axis_labels[i + 1],
                                                 axes_to_sum)

    missing_indices = set(temp_axis_labels) - set(output_axis_labels)
    if missing_indices:
      axis = [
          i for i, a in enumerate(temp_axis_labels)
          if a not in output_axis_labels
      ]
      temp = math_ops.reduce_sum(temp, axis=axis)
      temp_axis_labels = ''.join(
          a for a in temp_axis_labels if a in output_axis_labels)
    if sorted(temp_axis_labels) != sorted(output_axis_labels):
      raise ValueError('Invalid equation: %s' % equation)
    perm = [temp_axis_labels.index(a) for a in output_axis_labels]
    return _transpose_if_necessary(temp, perm)
def _einsum_v2(equation, *inputs, **kwargs):
  """Implementation of einsum utilizing opt_einsum and EinsumOp."""
  name = kwargs.pop('name', None)
  optimize = kwargs.pop('optimize', 'greedy')
  if kwargs:
    msg = 'Invalid keyword arguments for einsum: {}'
    raise TypeError(msg.format(', '.join(kwargs)))
  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = []
    for operand in inputs:
      if isinstance(operand.shape, tensor_shape.TensorShape):
        input_shapes.append(operand.shape.as_list() if operand.shape else None)
      else:
        input_shapes.append(list(operand.shape))
    # Validate and sanitize the equation and resolve static input shapes, as
    # opt_einsum requires that all shapes be a tuple of positive integers.
    # Also remove ellipsis from the equation as opt_einsum will replace them
    # with named labels. Then broadcasting between different shapes or ranks
    # wouldn't work. (E.g. [1, 1, 2] wouldn't broadcast with [3, 1]).
    resolved_equation, resolved_input_shapes, ellipsis_label = (
        _einsum_v2_parse_and_resolve_equation(equation, input_shapes))

    # Use xla_einsum if executing on TPU and if the operation is a 2 input
    # einsum supported by XlaEinsumOp.
    has_enclosing_tpu_context = _enclosing_tpu_context() is not None

    if len(inputs) <= 2:  # No need to call opt_einsum.
      # Replace back ellipses that were removed for opt_einsum.
      if ellipsis_label:
        resolved_equation = resolved_equation.replace(ellipsis_label, '...')
      if has_enclosing_tpu_context and len(inputs) == 2:
        return gen_xla_ops.xla_einsum(inputs[0], inputs[1], resolved_equation)
      return gen_linalg_ops.einsum(inputs, resolved_equation)

    # Send fully specified shapes to opt_einsum, since it cannot handle unknown
    # dimensions. For unknown dimensions, we guess that the dimension equals 1.
    # Instead of creating Tensors or NumPy arrays with the specified shape,
    # create a dummy `shaped` object with a `shape` property.
    shaped = collections.namedtuple('shaped', ['shape'])
    shaped_inputs = tuple(
        [shaped(tuple(shape)) for shape in resolved_input_shapes])

    # opt_einsum breaks down an n-ary einsum operation into n-1 binary einsums.
    # Obtain the sequence of equations and the indices of operands involved in
    # each einsum operation.
    indices_and_equations = _get_opt_einsum_contract_path(
        resolved_equation, shaped_inputs, optimize)
    for operand_indices, binary_equation in indices_and_equations:
      if ellipsis_label:
        # Replace back ellipses that were removed for opt_einsum.
        binary_equation = binary_equation.replace(ellipsis_label, '...')
      operands = list(map(inputs.pop, operand_indices))
      # Use xla_einsum if executing on TPU and if the operation is a 2 input
      # einsum supported by XlaEinsumOp.
      if has_enclosing_tpu_context and len(operands) == 2:
        inputs.append(
            gen_xla_ops.xla_einsum(operands[0], operands[1], binary_equation))
      else:
        inputs.append(gen_linalg_ops.einsum(operands, binary_equation))
    return inputs[0]
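# Illustrative sketch, not part of the original file: it shows how
# opt_einsum's public contract_path API decomposes an n-ary einsum into
# binary contractions, which is the kind of (operand indices, equation)
# sequence the loop in _einsum_v2 consumes via
# _get_opt_einsum_contract_path. Assumes the opt_einsum package is
# installed; the demo name is hypothetical.
def _demo_opt_einsum_path():
  """Prints the binary contraction steps chosen for 'ij,jk,kl->il'."""
  import numpy as np  # Local imports; demo only.
  import opt_einsum
  a = np.random.randn(2, 3)
  b = np.random.randn(3, 4)
  c = np.random.randn(4, 5)
  _, path_info = opt_einsum.contract_path(
      'ij,jk,kl->il', a, b, c, optimize='greedy')
  # Each entry records which operands to pop and the binary einsum to run,
  # analogous to indices_and_equations above.
  for contraction in path_info.contraction_list:
    operand_indices, _, binary_equation = contraction[:3]
    print(operand_indices, binary_equation)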