def _lu(input, output_idx_type=tf.int32, name=None):  # pylint: disable=redefined-builtin
  """LU-factorizes a (possibly batched) matrix, mirroring `tf.linalg.lu`.

  Args:
    input: Array whose trailing two dimensions form the matrices to factor;
      any leading dimensions are batch dimensions. The NumPy path reshapes to
      `[batch, N, N]`, so it requires square matrices.
    output_idx_type: dtype of the returned permutation on the NumPy path.
      NOTE(review): the JAX path returns whatever dtype
      `lax_linalg.lu_pivots_to_permutation` produces and does not apply this
      argument — confirm whether callers rely on it there.
    name: Ignored; accepted only for TF API compatibility.

  Returns:
    `Lu(lu, p)`. `lu` holds the packed factors returned by `lu_factor` and
    has the same shape as `input`. `p` is the row permutation (converted from
    LAPACK-style pivots) and has shape `input.shape[:-1]`.
  """
  del name
  if JAX_MODE:
    # Under JAX, `lu_factor` goes through XLA, which supports batched
    # factorization directly, so no manual batch loop is needed.
    packed, swaps = scipy_linalg.lu_factor(input)
    from jax import lax_linalg  # pylint: disable=g-import-not-at-top
    perm = lax_linalg.lu_pivots_to_permutation(swaps, packed.shape[-1])
    return Lu(packed, perm)

  # SciPy's `lu_factor` takes a single matrix, so flatten all batch
  # dimensions into one and factor each matrix in turn.
  batch_shape = input.shape[:-2]
  num_matrices = int(np.prod(batch_shape))
  n = input.shape[-1]
  matrices = input.reshape(num_matrices, n, n)
  lu_buf = np.empty((num_matrices, n, n), dtype=input.dtype)
  perm_buf = np.empty((num_matrices, n),
                      dtype=utils.numpy_dtype(output_idx_type))
  # Skip factorization entirely when there are no entries, e.g. a non-empty
  # batch of 0 x 0 matrices.
  if np.size(lu_buf):
    for idx, matrix in enumerate(matrices):
      packed, swaps = scipy_linalg.lu_factor(matrix)
      lu_buf[idx] = packed
      perm_buf[idx] = _lu_pivot_to_permutation(swaps, lu_buf.shape[-1])
  return Lu(lu_buf.reshape(*input.shape),
            perm_buf.reshape(*input.shape[:-1]))
def custom_assert(result_jax, result_tf):
  """Asserts that the TF LU outputs satisfy `P @ operand == L @ U`.

  Only `result_tf` (a `(lu, pivots, perm)` triple of tensors) is inspected;
  `result_jax` is not used. `self`, `operand`, `dtype` and `tol` are free
  variables taken from the enclosing scope.
  """
  lu, pivots, perm = [t.numpy() for t in result_tf]
  batch_shape = operand.shape[:-2]
  rows, cols = operand.shape[-2], operand.shape[-1]
  diag = min(rows, cols)
  # `lu` packs the strictly lower part of L (its unit diagonal is implicit)
  # together with U; unpack them into a (rows x diag) L and a (diag x cols) U.
  lower = (jnp.tril(lu, -1)[..., :, :diag]
           + jnp.eye(rows, diag, dtype=dtype))
  upper = jnp.triu(lu)[..., :diag, :]
  # Build P as one-hot rows: for every batch index and row `i`, the 1 sits in
  # column `perm[..., i]`, so `(P @ operand)[i] == operand[perm[i]]`.
  one_hot_rows = [
      [1 if col == perm[idx] else 0 for col in range(rows)]
      for idx in itertools.product(*(range(d) for d in operand.shape[:-1]))
  ]
  perm_matrix = np.reshape(np.array(one_hot_rows, dtype=dtype),
                           [*batch_shape, rows, rows])
  self.assertArraysEqual(lax_linalg.lu_pivots_to_permutation(pivots, rows),
                         perm)
  self.assertAllClose(jnp.matmul(perm_matrix, operand),
                      jnp.matmul(lower, upper), atol=tol, rtol=tol)
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
  """Solves a linear system given a precomputed LU factorization.

  Follows the signature of `scipy.linalg.lu_solve`.

  Args:
    lu_and_piv: Pair `(lu, pivots)`: the packed LU factors and the pivot
      indices, as produced by an `lu_factor`-style factorization.
    b: Right-hand side.
    trans: Which system to solve, using SciPy's convention (0: `a x = b`,
      1: `a^T x = b`, 2: `a^H x = b`); forwarded to `lax_linalg.lu_solve`.
    overwrite_b: Ignored; accepted for SciPy API compatibility.
    check_finite: Ignored; accepted for SciPy API compatibility.

  Returns:
    The solution `x` of the system selected by `trans`.
  """
  del overwrite_b, check_finite  # Unused; kept only for SciPy compatibility.
  lu, pivots = lu_and_piv
  # Only the row count is needed to expand the pivots into a permutation;
  # unpacking still enforces that `lu` has at least two dimensions.
  m, _ = lu.shape[-2:]
  perm = lax_linalg.lu_pivots_to_permutation(pivots, m)
  return lax_linalg.lu_solve(lu, perm, b, trans)