コード例 #1
0
ファイル: lax_linalg.py プロジェクト: linan7788626/jax
def lu_impl(operand):
    lu, pivot = xla.apply_primitive(lu_p, operand)
    return core.pack((lu, pivot))
コード例 #2
0
ファイル: lax_linalg.py プロジェクト: linan7788626/jax
def qr_impl(operand, full_matrices):
    q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
    return core.pack((q, r))
コード例 #3
0
ファイル: lax_linalg.py プロジェクト: sgpohlj87/jax
def svd_impl(operand, full_matrices, compute_uv):
  s, u, vt = xla.apply_primitive(svd_p, operand, full_matrices=full_matrices,
                                 compute_uv=compute_uv)
  return s, u, vt
コード例 #4
0
ファイル: lax_linalg.py プロジェクト: linan7788626/jax
def eigh_impl(operand, lower):
    v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
    return core.pack((v, w))
コード例 #5
0
ファイル: lax_linalg.py プロジェクト: sgpohlj87/jax
def _lu_impl(operand):
  lu, pivot = xla.apply_primitive(lu_p, operand)
  return lu, pivot
コード例 #6
0
ファイル: lax_linalg.py プロジェクト: sgpohlj87/jax
def eigh_impl(operand, lower):
  v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
  return v, w
コード例 #7
0
ファイル: lax_linalg.py プロジェクト: sgpohlj87/jax
def eig_impl(operand):
  return xla.apply_primitive(eig_p, operand)
コード例 #8
0
def fft_impl(x, fft_type, fft_lengths):
    return xla.apply_primitive(fft_p,
                               x,
                               fft_type=fft_type,
                               fft_lengths=fft_lengths)
コード例 #9
0
def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
  return (
    xla.apply_primitive(eig_p, operand,
                        compute_left_eigenvectors=compute_left_eigenvectors,
                        compute_right_eigenvectors=compute_right_eigenvectors))