예제 #1
0
def _lu(input, output_idx_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
  """Computes an LU factorization of `input`, returning `Lu(lu, p)` like TF.

  Args:
    input: Array of matrices in its trailing two dimensions; leading
      dimensions are batch dimensions. The NumPy path reshapes using the last
      dimension for both matrix axes, so it assumes square matrices.
    output_idx_type: Dtype of the permutation array in the NumPy path.
      NOTE(review): the JAX path does not read this argument, so its
      permutation dtype is whatever `lu_pivots_to_permutation` returns.
    name: Ignored; accepted for signature compatibility with TF.

  Returns:
    `Lu(lu, p)`. `lu` has the shape of `input` and holds the combined factor
    matrix from `scipy_linalg.lu_factor`. `p` has shape `input.shape[:-1]`
    and holds the pivots converted by `_lu_pivot_to_permutation` (or
    `lax_linalg.lu_pivots_to_permutation` under JAX).
  """
  del name  # Unused.
  if JAX_MODE:
    # XLA can factor the whole batch in a single call.
    lu_out, pivots = scipy_linalg.lu_factor(input)
    from jax import lax_linalg  # pylint: disable=g-import-not-at-top
    permutation = lax_linalg.lu_pivots_to_permutation(pivots, lu_out.shape[-1])
    return Lu(lu_out, permutation)

  # SciPy factors one matrix per call, so flatten the batch and loop over it.
  batch_shape = input.shape[:-2]
  n = input.shape[-1]
  count = int(np.prod(batch_shape))
  matrices = input.reshape(count, n, n)
  lu_stack = np.empty((count, n, n), dtype=input.dtype)
  perm_stack = np.empty((count, n), dtype=utils.numpy_dtype(output_idx_type))
  # Zero elements means either no matrices or 0x0 matrices: nothing to factor.
  if lu_stack.size:
    for k in range(count):
      factors, piv = scipy_linalg.lu_factor(matrices[k])
      lu_stack[k] = factors
      perm_stack[k] = _lu_pivot_to_permutation(piv, n)
  return Lu(lu_stack.reshape(*input.shape),
            perm_stack.reshape(*input.shape[:-1]))
예제 #2
0
    def custom_assert(result_jax, result_tf):
      """Checks TF's `(lu, pivots, perm)` outputs by verifying P @ A == L @ U.

      Only `result_tf` is inspected; `result_jax` is unused. `operand`,
      `dtype`, `tol` and `self` are taken from the enclosing scope.
      """
      lu, pivots, perm = [t.numpy() for t in result_tf]
      batch_dims = operand.shape[:-2]
      m = operand.shape[-2]
      n = operand.shape[-1]

      def _permutation_matrix(row_perm):
        # One one-hot row per (batch..., row) index: column c is 1 exactly when
        # c == row_perm[batch..., row], so (P @ A)[row] == A[row_perm[row]].
        rows = [
            [1 if col == row_perm[idx] else 0 for col in range(m)]
            for idx in itertools.product(*map(range, operand.shape[:-1]))
        ]
        return np.array(rows, dtype=dtype).reshape([*batch_dims, m, m])

      k = min(m, n)
      # Unit lower-triangular L (m x k): strict lower part of `lu` plus identity.
      lower = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
      # Upper-triangular U (k x n).
      upper = jnp.triu(lu)[..., :k, :]
      p_mat = _permutation_matrix(perm)

      # The returned permutation must match the one implied by the pivots.
      self.assertArraysEqual(lax_linalg.lu_pivots_to_permutation(pivots, m),
                             perm)
      # Reconstruction check: permuted operand equals L @ U within tolerance.
      self.assertAllClose(jnp.matmul(p_mat, operand), jnp.matmul(lower, upper),
                          atol=tol, rtol=tol)
예제 #3
0
파일: linalg.py 프로젝트: raj0088/jax
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve a linear system from a precomputed LU factorization.

    Mirrors the signature of ``scipy.linalg.lu_solve``.

    Args:
      lu_and_piv: ``(lu, pivots)`` pair, presumably as produced by
        ``lu_factor`` — confirm against callers.
      b: Right-hand side, forwarded to ``lax_linalg.lu_solve``.
      trans: Forwarded unchanged to ``lax_linalg.lu_solve``.
      overwrite_b, check_finite: Accepted for SciPy compatibility; ignored.

    Returns:
      Whatever ``lax_linalg.lu_solve`` returns for these arguments.
    """
    del overwrite_b, check_finite  # SciPy-only options with no effect here.
    lu, pivots = lu_and_piv
    num_rows, _ = lu.shape[-2:]
    # Convert the pivots into the permutation passed to lax_linalg.lu_solve.
    permutation = lax_linalg.lu_pivots_to_permutation(pivots, num_rows)
    return lax_linalg.lu_solve(lu, permutation, b, trans)