Example #1
0
def dot( A, B, axis=-1 ):
  '''Transform axis of A by contraction with first axis of B and inserting
     remaining axes. Note: with default axis=-1 this leads to multiplication of
     vectors and matrices following linear algebra conventions.

     Axis `axis` of A (negative values count from the end) is summed against
     the first axis of B, and the remaining axes of B are inserted in its
     place, so for B.ndim >= 1 the result has shape
     A.shape[:axis] + B.shape[1:] + A.shape[axis+1:].

     Raises AssertionError if `axis` is out of range for A.'''

  A = asarray( A )
  B = asarray( B )

  if axis < 0:
    axis += A.ndim
  assert 0 <= axis < A.ndim

  # numpy.dot contracts the last axis of its first argument with the
  # second-to-last axis of its second (or the only axis of a 1-D array), so
  # bubble B's first axis forward to the second-to-last position.
  for j in range( B.ndim-2 ):
    B = B.swapaxes( j, j+1 )

  if axis == A.ndim-1:
    return numpy.dot( A, B )

  # Bubble 'axis' forward to the last position, keeping the other axes in
  # their original order. The swaps must run in ascending order; descending
  # swaps would instead rotate the trailing axis back towards 'axis'.
  for i in range( axis, A.ndim-1 ):
    A = A.swapaxes( i, i+1 )

  AB = numpy.dot( A, B )
  # AB's axes are: A's axes before 'axis', A's axes after 'axis', then B's
  # remaining axes. Move B's axes into the slot of the contracted axis.
  # list() keeps this working on Python 3, where range objects cannot be added.
  perm = list( range(axis) ) + list( range(A.ndim-1,AB.ndim) ) + list( range(axis,A.ndim-1) )
  return AB.transpose( perm )
Example #2
0
def contract( A, B, axis=-1 ):
  '''Contract A and B over one or more axes after broadcasting them together.

     A and B are converted to float arrays. The operand with fewer dimensions
     is padded with leading zero-stride axes, and an axis of length 1 in one
     operand is stretched (zero stride) to the other operand's length.
     `axis` is an integer or a sequence of integers (negative values count from
     the end of the broadcast shape). These axes are moved to the end, and the
     strided views are handed to `_contract` (defined elsewhere) together with
     the number of contraction axes. When the broadcast operands have no
     elements, a zero array with the remaining (non-contracted) axes is
     returned instead.

     Raises AssertionError on incompatible shapes, out-of-range axes, or axes
     that repeat after normalization.'''

  import numbers

  A = asarray( A, dtype=float )
  B = asarray( B, dtype=float )

  # Pad the lower-dimensional operand with leading zero-stride axes so that
  # both operands have the same number of dimensions.
  n = B.ndim - A.ndim
  if n > 0:
    Ashape = list(B.shape[:n]) + list(A.shape)
    Astrides = [0]*n + list(A.strides)
    Bshape = list(B.shape)
    Bstrides = list(B.strides)
  elif n < 0:
    n = -n
    Ashape = list(A.shape)
    Astrides = list(A.strides)
    Bshape = list(A.shape[:n]) + list(B.shape)
    Bstrides = [0]*n + list(B.strides)
  else:
    Ashape = list(A.shape)
    Astrides = list(A.strides)
    Bshape = list(B.shape)
    Bstrides = list(B.strides)

  # Broadcast length-1 axes by giving them zero stride. The padded leading
  # axes agree by construction, so only the remaining axes are checked.
  shape = list(Ashape)
  nd = len(shape)
  for i in range( n, nd ):
    if Ashape[i] == 1:
      shape[i] = Bshape[i]
      Astrides[i] = 0
    elif Bshape[i] == 1:
      Bstrides[i] = 0
    else:
      assert Ashape[i] == Bshape[i]

  # Accept any integral scalar (including NumPy integer types) as one axis.
  if isinstance( axis, numbers.Integral ):
    axis = axis,
  axis = sorted( [ ax+nd if ax < 0 else ax for ax in axis ], reverse=True )
  # A repeated axis would be moved twice and silently contract an extra axis.
  assert len( set(axis) ) == len( axis ), 'repeated contraction axis'
  # Move the contraction axes to the end, highest index first, so each pop
  # leaves the positions of the remaining (lower) axes unchanged.
  for ax in axis:
    assert 0 <= ax < nd, 'invalid contraction axis'
    shape.append( shape.pop(ax) )
    Astrides.append( Astrides.pop(ax) )
    Bstrides.append( Bstrides.pop(ax) )

  A = as_strided( A, shape, Astrides )
  B = as_strided( B, shape, Bstrides )

  ncontract = len( axis )
  if not A.size:
    # Slice by the number of kept axes: shape[:-0] would wrongly be empty.
    return zeros( A.shape[:nd-ncontract] )

  return _contract( A, B, ncontract )
Example #3
0
def contract_fast( A, B, naxes ):
  '''Contract the last `naxes` axes of A and B after broadcasting them together.

     A and B are converted to float arrays. The operand with fewer dimensions
     is padded with leading zero-stride axes, and an axis of length 1 in one
     operand is stretched (zero stride) to the other operand's length. The
     strided views are handed to `_contract` (defined elsewhere) together with
     `naxes`. When the broadcast operands have no elements, a zero array with
     the leading len(shape)-naxes axes is returned instead.

     Raises AssertionError on incompatible shapes, or if naxes is not between
     0 and the number of broadcast dimensions.'''

  A = asarray( A, dtype=float )
  B = asarray( B, dtype=float )

  # Pad the lower-dimensional operand with leading zero-stride axes so that
  # both operands have the same number of dimensions.
  n = B.ndim - A.ndim
  if n > 0:
    Ashape = list(B.shape[:n]) + list(A.shape)
    Astrides = [0]*n + list(A.strides)
    Bshape = list(B.shape)
    Bstrides = list(B.strides)
  elif n < 0:
    n = -n
    Ashape = list(A.shape)
    Astrides = list(A.strides)
    Bshape = list(A.shape[:n]) + list(B.shape)
    Bstrides = [0]*n + list(B.strides)
  else:
    Ashape = list(A.shape)
    Astrides = list(A.strides)
    Bshape = list(B.shape)
    Bstrides = list(B.strides)

  # Broadcast length-1 axes by giving them zero stride.
  shape = list(Ashape)
  for i in range( len(Ashape) ):
    if Ashape[i] == 1:
      shape[i] = Bshape[i]
      Astrides[i] = 0
    elif Bshape[i] == 1:
      Bstrides[i] = 0
    else:
      assert Ashape[i] == Bshape[i]

  nd = len( shape )
  assert 0 <= naxes <= nd, 'invalid number of contraction axes'

  A = as_strided( A, shape, Astrides )
  B = as_strided( B, shape, Bstrides )

  if not A.size:
    # Slice by the number of kept axes: shape[:-0] would wrongly be empty.
    return zeros( shape[:nd-naxes] )

  return _contract( A, B, naxes )