def eig_cpu_translation_rule(c, operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
  """CPU lowering of the general (non-symmetric) eigendecomposition.

  Calls the LAPACK-backed ``_cpu_geev`` kernel and replaces any batch entry
  whose ``info`` status is nonzero with NaNs, so failures surface as NaN
  outputs rather than garbage values.

  Returns an XLA tuple of ``(w[, vl][, vr])`` depending on which eigenvector
  sides were requested.
  """
  operand_shape = c.get_shape(operand)
  batch_dims = operand_shape.dimensions()[:-2]
  w, vl, vr, info = _cpu_geev(c, operand, jobvl=compute_left_eigenvectors,
                              jobvr=compute_right_eigenvectors)
  # info == 0 signals a successful factorization for that batch element.
  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  # Eigenvalues broadcast over one trailing dim; eigenvectors over two.
  w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
                           _nan_like(c, w))
  results = [w]
  if compute_left_eigenvectors:
    results.append(_broadcasting_select(
        c, xops.Reshape(ok, batch_dims + (1, 1)), vl, _nan_like(c, vl)))
  if compute_right_eigenvectors:
    results.append(_broadcasting_select(
        c, xops.Reshape(ok, batch_dims + (1, 1)), vr, _nan_like(c, vr)))
  return xops.Tuple(c, results)
def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand, full_matrices):
  """CPU/GPU lowering of QR via LAPACK-style ``geqrf`` (factor) + ``orgqr``
  (explicitly form Q) kernels.

  Three shape regimes:
    * m < n (wide): Q is m x m, built from the leading m columns of the factor.
    * reduced (not full_matrices): Q is m x n, R is cropped to n x n.
    * full with m > n: the factor is zero-padded to m x m before forming Q.

  Batch entries whose ``info`` is nonzero in either kernel are replaced with
  NaNs. R is upper-triangularized at the end.
  """
  operand_shape = c.get_shape(operand)
  dims = operand_shape.dimensions()
  m, n = dims[-2:]
  batch_dims = dims[:-2]
  r, tau, info_geqrf = geqrf_impl(c, operand)
  if m < n:
    # Wide matrix: orgqr needs a square m x m workspace.
    starts = [0] * len(dims)
    strides = [1] * len(dims)
    q = xops.Slice(r, starts, list(batch_dims) + [m, m], strides)
    q, info_orgqr = orgqr_impl(c, q, tau)
  elif not full_matrices:
    # Economy QR: form Q in place, then crop R to its n x n upper block.
    q, info_orgqr = orgqr_impl(c, r, tau)
    r = xops.Slice(r, [0] * len(dims), list(batch_dims) + [n, n],
                   [1] * len(dims))
  else:
    # Full QR with m > n: pad the last dim from n to m columns with zeros.
    padding_config = [(0, 0, 0)] * len(dims)
    padding_config[-1] = (0, m - n, 0)
    pad_value = xops.Constant(c, np.array(0, dtype=operand_shape.element_type()))
    q = xops.Pad(r, pad_value, xla_client.make_padding_config(padding_config))
    q, info_orgqr = orgqr_impl(c, q, tau)
  # Both kernels must report success for the result to be valid.
  ok = xops.And(
      xops.Eq(info_geqrf, xops.ConstantLiteral(c, np.array(0, np.int32))),
      xops.Eq(info_orgqr, xops.ConstantLiteral(c, np.array(0, np.int32))))
  ok_matrix = xops.Reshape(ok, batch_dims + (1, 1))
  q = _broadcasting_select(c, ok_matrix, q, _nan_like(c, q))
  r = _broadcasting_select(c, ok_matrix, r, _nan_like(c, r))
  # geqrf leaves Householder data below the diagonal; zero it out.
  r = xla.lower_fun(jnp.triu, multiple_results=False)(c, r)
  return xops.Tuple(c, [q, r])
def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv):
  """CPU/GPU lowering of SVD via a LAPACK-style ``gesvd`` kernel.

  Empty matrices (m == 0 or n == 0) are routed to a dedicated lowering since
  the kernel is not used for them. Batch entries whose ``info`` is nonzero
  are replaced with NaNs.

  Returns an XLA tuple ``(s,)`` or ``(s, u, vt)`` depending on ``compute_uv``.
  """
  dims = c.get_shape(operand).dimensions()
  m, n = dims[-2:]
  batch_dims = dims[:-2]
  if m == 0 or n == 0:
    # Degenerate shapes bypass the kernel entirely.
    return xla.lower_fun(_empty_svd, multiple_results=True)(
        c, operand, full_matrices=full_matrices, compute_uv=compute_uv)
  s, u, vt, info = gesvd_impl(c, operand, full_matrices=full_matrices,
                              compute_uv=compute_uv)
  # info == 0 means the decomposition converged for that batch element.
  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  s = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), s,
                           _nan_like(c, s))
  outputs = [s]
  if compute_uv:
    ok_matrix = xops.Reshape(ok, batch_dims + (1, 1))
    u = _broadcasting_select(c, ok_matrix, u, _nan_like(c, u))
    vt = _broadcasting_select(c, ok_matrix, vt, _nan_like(c, vt))
    outputs.extend([u, vt])
  return xops.Tuple(c, outputs)
def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
  """CPU/GPU lowering of the symmetric/Hermitian eigendecomposition.

  Calls a ``syevd``-style kernel and NaNs out batch entries whose ``info``
  status is nonzero. Returns an XLA tuple ``(v, w)`` of eigenvectors and
  eigenvalues.
  """
  operand_shape = c.get_shape(operand)
  batch_dims = operand_shape.dimensions()[:-2]
  v, w, info = syevd_impl(c, operand, lower=lower)
  # info == 0 signals success for that batch element.
  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  # Eigenvector matrix broadcasts over two trailing dims, eigenvalues over one.
  v = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), v,
                           _nan_like(c, v))
  w = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), w,
                           _nan_like(c, w))
  return xops.Tuple(c, [v, w])
def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
  """CPU/GPU lowering of Cholesky via a LAPACK-style ``potrf`` kernel.

  Always factors the lower triangle. Batch entries whose ``info`` status is
  nonzero (e.g. the matrix is not positive definite) become NaNs.
  """
  operand_shape = c.get_shape(operand)
  batch_dims = operand_shape.dimensions()[:-2]
  factor, info = potrf_impl(c, operand, lower=True)
  # info == 0 signals a successful factorization for that batch element.
  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  return _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)),
                              factor, _nan_like(c, factor))
# NOTE(review): this is a duplicate definition of _svd_cpu_gpu_translation_rule
# (an earlier definition exists in this module); being defined later, THIS one
# is the version actually bound. The duplicate should eventually be deleted,
# but until then it must carry the same empty-matrix fast path as the other
# copy so that shadowing does not silently drop that handling.
def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv):
  """CPU/GPU lowering of SVD via a LAPACK-style ``gesvd`` kernel.

  Empty matrices (m == 0 or n == 0) are routed to a dedicated lowering
  (``_empty_svd``), since the kernel is not used for them. Batch entries
  whose ``info`` is nonzero are replaced with NaNs.

  Returns an XLA tuple ``(s,)`` or ``(s, u, vt)`` depending on ``compute_uv``.
  """
  shape = c.get_shape(operand).dimensions()
  m, n = shape[-2:]
  batch_dims = shape[:-2]
  if m == 0 or n == 0:
    # Restored to match the sibling definition: bypass the kernel for
    # degenerate shapes instead of handing them to gesvd.
    return xla.lower_fun(_empty_svd, multiple_results=True)(
        c, operand, full_matrices=full_matrices, compute_uv=compute_uv)
  s, u, vt, info = gesvd_impl(c, operand, full_matrices=full_matrices,
                              compute_uv=compute_uv)
  # info == 0 means the decomposition converged for that batch element.
  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  s = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), s,
                           _nan_like(c, s))
  result = [s]
  if compute_uv:
    u = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), u,
                             _nan_like(c, u))
    vt = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vt,
                              _nan_like(c, vt))
    result += [u, vt]
  return xops.Tuple(c, result)
def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
  """CPU/GPU lowering of LU with partial pivoting via a LAPACK-style
  ``getrf`` kernel.

  Returns an XLA tuple ``(lu, pivot, perm)``: the packed LU factors, the
  0-based pivot indices, and the pivots expanded into a full permutation.
  """
  operand_shape = c.get_shape(operand)
  dims = operand_shape.dimensions()
  batch_dims = dims[:-2]
  m = dims[-2]
  lu, pivot, info = getrf_impl(c, operand)
  # LAPACK reports 1-based pivots; shift them to 0-based indices.
  pivot = xops.Sub(pivot, xops.ConstantLiteral(c, np.array(1, np.int32)))
  # Unlike the other rules this checks info >= 0, not info == 0: in the
  # LAPACK getrf convention a positive info flags an exactly-singular U but
  # the factorization itself is still valid; only negative info (bad
  # argument) invalidates the output.
  ok = xops.Ge(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  lu = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), lu,
                            _nan_like(c, lu))
  # Expand the pivot sequence into an explicit length-m permutation.
  perm = xla.lower_fun(lambda p: lu_pivots_to_permutation(p, m),
                       multiple_results=False)(c, pivot)
  return xops.Tuple(c, [lu, pivot, perm])