def vdot(a, b):
    """Return the dot product of two inputs treated as flat vectors.

    Both arguments are converted with ``asarray`` and flattened with
    ``ravel``.  Unlike ``dot``, the first argument is complex-conjugated
    before the product is taken, and because both operands are 1-d the
    result is always a scalar.
    """
    conj_left = asarray(a).ravel().conj()
    flat_right = asarray(b).ravel()
    return dot(conj_left, flat_right)
def tensordot(a, b, axes=2):
    """Return the tensor dot product of two arrays along the given axes.

    Elements of ``a`` and ``b`` are multiplied and summed over the paired
    axes::

        r[xxx, yyy] = sum_k a[xxx, k] * b[k, yyy]

    Parameters
    ----------
    a, b : array_like
        Operands; each is converted with ``asarray``.
    axes : int or pair, optional
        * An integer ``N`` sums over the last ``N`` axes of ``a`` and the
          first ``N`` axes of ``b``, in order.
        * A pair ``(axes_a, axes_b)`` names the axis or axes of ``a`` and
          of ``b`` to sum over.  Each element is either a single axis or a
          sequence of axes; the two sequences must have the same length
          and are paired position by position.  Negative axes count from
          the end.

    Returns
    -------
    ndarray
        Array whose shape is the non-summed axes of ``a`` followed by the
        non-summed axes of ``b``.

    Raises
    ------
    ValueError
        If the two axis lists differ in length or a pair of summed axes
        differ in size.
    """
    try:
        iter(axes)
    except TypeError:
        # Integer form: last N axes of a against first N axes of b.
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes

    # Accept either a single axis or a sequence of axes for each operand.
    try:
        na = len(axes_a)
        axes_a = list(axes_a)
    except TypeError:
        axes_a = [axes_a]
        na = 1
    try:
        nb = len(axes_b)
        axes_b = list(axes_b)
    except TypeError:
        axes_b = [axes_b]
        nb = 1

    a, b = asarray(a), asarray(b)
    as_ = a.shape
    nda = len(as_)
    bs = b.shape
    ndb = len(bs)

    if na != nb:
        raise ValueError("shape-mismatch for sum")
    for k in range(na):
        if as_[axes_a[k]] != bs[axes_b[k]]:
            raise ValueError("shape-mismatch for sum")
        # Normalise negative axes so the membership tests below work.
        if axes_a[k] < 0:
            axes_a[k] += nda
        if axes_b[k] < 0:
            axes_b[k] += ndb

    # Move the axes to sum over to the end of "a" and to the front of "b".
    free_a = [k for k in range(nda) if k not in axes_a]
    free_b = [k for k in range(ndb) if k not in axes_b]
    olda = [as_[axis] for axis in free_a]
    oldb = [bs[axis] for axis in free_b]

    # Paired axes were verified equal, so one summed length serves both.
    summed = 1
    for axis in axes_a:
        summed *= as_[axis]

    # Explicit sizes instead of -1: reshape cannot infer a dimension when
    # the known ones multiply to zero (a zero-length summed axis).
    rows = 1
    for extent in olda:
        rows *= extent
    cols = 1
    for extent in oldb:
        cols *= extent

    at = a.transpose(free_a + axes_a).reshape((rows, summed))
    bt = b.transpose(axes_b + free_b).reshape((summed, cols))
    res = dot(at, bt)
    return res.reshape(olda + oldb)