示例#1
0
文件: linalg.py 项目: dev-fennek/jax
def tensorsolve(a, b, axes=None):
  """Solve the tensor equation ``tensordot(a, x, axes=x.ndim) == b`` for ``x``.

  After optional axis reordering, the trailing ``a.ndim - b.ndim`` axes of
  ``a`` give the shape of the solution. ``a`` is flattened to a 2-D matrix
  whose column count is the product of those trailing dimensions, and the
  resulting linear system is solved with ``la.solve`` against ``b.ravel()``.

  Args:
    a: coefficient tensor.
    b: right-hand-side tensor.
    axes: optional iterable of axes of ``a`` that are moved, in the given
      order, to the end of ``a`` before solving. Negative values count from
      the last axis, as elsewhere in NumPy.

  Returns:
    The solution with shape ``a.shape[-(a.ndim - b.ndim):]`` (taken after any
    reordering requested by ``axes``).

  Raises:
    ValueError: if an entry of ``axes`` is out of range or repeated.
    Any error raised by ``la.solve`` — presumably when the flattened matrix
    is not square or does not match ``b.size``; confirm against ``la``.
  """
  a = np.asarray(a)
  b = np.asarray(b)
  an = a.ndim
  if axes is not None:
    allaxes = list(range(an))
    for k in axes:
      # Accept negative axes. Out-of-range or repeated axes are still
      # rejected by ``remove`` with a ValueError.
      axis = k + an if k < 0 else k
      allaxes.remove(axis)
      # Move the axis to the end; later entries of ``axes`` land further right.
      allaxes.append(axis)
    a = a.transpose(allaxes)

  # Trailing axes of ``a`` form the solution shape. Note: when a.ndim ==
  # b.ndim the slice start is -0 == 0, i.e. the whole of a.shape.
  out_shape = a.shape[-(an - b.ndim):]

  out_size = 1
  for dim in out_shape:
    out_size *= dim

  # Flatten to an ordinary linear system: (rows, out_size) @ x.ravel() = b.ravel().
  res = np.asarray(la.solve(a.reshape(-1, out_size), b.ravel()))
  return res.reshape(out_shape)
示例#2
0
def multi_dot(arrays, *, precision=None):
    """Compute the chained dot product of two or more arrays in one call.

    With exactly two operands this is a plain ``jnp.dot``. With three it uses
    ``_multi_dot_three``; with more it asks ``_multi_dot_matrix_chain_order``
    for an evaluation order and multiplies via ``_multi_dot`` (all helpers
    defined elsewhere in this module).

    A 1-D first operand is promoted to a row matrix and a 1-D last operand
    to a column matrix, so the helpers only see 2-D inputs. The result is
    then reduced back accordingly.

    Args:
      arrays: sequence of at least two array-likes.
      precision: forwarded to ``jnp.dot`` and to the multiplication helpers.

    Returns:
      A scalar if both outer operands were 1-D, a 1-D array if exactly one
      was, otherwise the 2-D product.

    Raises:
      ValueError: if fewer than two arrays are given.
    """
    count = len(arrays)
    if count < 2:
        raise ValueError("Expecting at least two arrays.")
    if count == 2:
        # Reordering cannot help with only two operands.
        first, second = arrays[0], arrays[1]
        return jnp.dot(first, second, precision=precision)

    mats = [jnp.asarray(x) for x in arrays]

    # Remember whether the outer operands were vectors so the final result
    # can be reduced back to the matching rank.
    first_is_vec = mats[0].ndim == 1
    last_is_vec = mats[-1].ndim == 1

    # Promote outer vectors: row (1, k) on the left, column (k, 1) on the right.
    if first_is_vec:
        mats[0] = jnp.atleast_2d(mats[0])
    if last_is_vec:
        mats[-1] = jnp.atleast_2d(mats[-1]).T
    _assert2d(*mats)

    if count == 3:
        # _multi_dot_three is much faster than the general chain-order search.
        result = _multi_dot_three(*mats, precision)
    else:
        order = _multi_dot_matrix_chain_order(mats)
        result = _multi_dot(mats, order, 0, count - 1, precision)

    if first_is_vec and last_is_vec:
        return result[0, 0]  # scalar
    if first_is_vec or last_is_vec:
        return result.ravel()  # 1-D
    return result
示例#3
0
文件: linalg.py 项目: dev-fennek/jax
def tensorinv(a, ind=2):
  """Return the inverse of ``a`` with respect to ``tensordot``.

  ``a`` is reshaped to a 2-D matrix with ``prod(a.shape[ind:])`` rows, and
  that matrix is inverted with ``la.inv``. The inverse is reshaped to
  ``a.shape[ind:] + a.shape[:ind]``. This relies on the matrix being square,
  i.e. ``prod(a.shape[:ind]) == prod(a.shape[ind:])``; otherwise the error
  comes from ``la.inv``/``reshape`` (exact type depends on ``la``).

  Args:
    a: tensor to invert.
    ind: number of leading axes on the "row" side; must be positive.

  Returns:
    The inverse tensor with shape ``a.shape[ind:] + a.shape[:ind]``.

  Raises:
    ValueError: if ``ind`` is not positive.
  """
  arr = np.asarray(a)
  if not ind > 0:
    raise ValueError("Invalid ind argument.")

  head, tail = arr.shape[:ind], arr.shape[ind:]
  n_rows = 1
  for size in tail:
    n_rows *= size

  inverse = la.inv(arr.reshape(n_rows, -1))
  # The trailing axes lead in the result, mirroring the swapped row/column roles.
  return inverse.reshape(*(tail + head))