def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
                                 full_matrices):
    """Emit XLA ops computing a (possibly batched) QR factorization on CPU/GPU.

    The work is split across the two supplied kernels. ``geqrf_impl`` produces
    the packed factorization ``r`` together with ``tau``. ``orgqr_impl`` then
    forms ``q`` from (a reshaped view of) that packed result.

    If either kernel reports a nonzero ``info`` for a batch element, that
    element's ``q`` and ``r`` are replaced with values from ``_nan_like``
    (presumably NaN-filled).

    Args:
      geqrf_impl: callable invoked as ``geqrf_impl(shim, x)``; returns
        ``(r, tau, info)``.
      orgqr_impl: callable invoked as ``orgqr_impl(shim, a, tau)``; returns
        ``(q, info)``.
      c: the XLA computation builder; a builder shim derived from it is what
        the kernels receive.
      operand: XLA op whose trailing two dimensions are ``(m, n)``. Any
        leading dimensions are treated as batch dimensions.
      full_matrices: only consulted when ``m >= n``.
        - When false, ``r`` is cut down to its leading ``n x n`` block.
        - When true, the packed result is zero-padded by ``m - n`` columns
          before ``orgqr_impl`` runs.
        - When ``m < n``, ``q`` is always built from the leading ``m x m``
          block.

    Returns:
      An XLA tuple op ``(q, r)``.
    """
    shape = c.GetShape(operand)
    dims = shape.dimensions()
    rank = len(dims)
    m, n = dims[-2:]
    batch_dims = dims[:-2]
    zero_starts = [0] * rank
    unit_strides = [1] * rank

    def _leading_square(x, k):
        # Keep only the leading k x k block of the two minor dimensions.
        return xops.Slice(x, zero_starts, list(batch_dims) + [k, k],
                          unit_strides)

    shim = xb.computation_builder_shim(c)
    r, tau, geqrf_info = geqrf_impl(shim, operand)

    if m < n:
        # Wide input: q comes from the leading m x m block of the packed result.
        q, orgqr_info = orgqr_impl(shim, _leading_square(r, m), tau)
    elif full_matrices:
        # Tall/square input, complete q requested: append m - n zero columns
        # so that orgqr_impl receives an m x m input.
        pad = [(0, 0, 0)] * rank
        pad[-1] = (0, m - n, 0)
        zero = xops.Constant(c, onp.array(0, dtype=shape.element_type()))
        widened = xops.Pad(r, zero, xla_client.make_padding_config(pad))
        q, orgqr_info = orgqr_impl(shim, widened, tau)
    else:
        # Reduced factorization: orgqr_impl consumes the full packed result
        # first; only afterwards is r trimmed to n x n.
        q, orgqr_info = orgqr_impl(shim, r, tau)
        r = _leading_square(r, n)

    # NOTE(review): r is returned exactly as geqrf_impl produced it (possibly
    # sliced). Nothing here zeroes the part below its diagonal. If geqrf_impl
    # follows LAPACK's packed convention, Householder data lives there.
    # Confirm that callers apply triu.

    def _succeeded(info):
        # A kernel succeeded for a batch element when its info code is 0.
        return xops.Eq(info, xops.ConstantLiteral(c, onp.array(0, onp.int32)))

    ok = xops.And(_succeeded(geqrf_info), _succeeded(orgqr_info))
    # Add two trailing unit dims so the per-batch flag broadcasts over the
    # matrix dimensions.
    ok_mask = xops.Reshape(ok, batch_dims + (1, 1))
    q = _broadcasting_select(c, ok_mask, q, _nan_like(c, q))
    r = _broadcasting_select(c, ok_mask, r, _nan_like(c, r))
    return xops.Tuple(c, [q, r])
def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand,
                                 full_matrices):
    """Emit XLA ops computing a (possibly batched) QR factorization on CPU/GPU.

    ``geqrf_impl`` produces the packed factorization ``r`` plus ``tau``.
    ``orgqr_impl`` then forms ``q`` from (a reshaped view of) that packed
    result.

    When the kernels report status codes, any batch element with a nonzero
    ``info`` has its ``q`` and ``r`` replaced with ``_nan_like`` values
    (presumably NaN-filled). Finally, everything below the diagonal of ``r``
    is zeroed with ``jnp.triu``.

    Args:
      geqrf_impl: callable invoked as ``geqrf_impl(c, x)``; returns
        ``(r, tau, info)``. ``info`` may be ``None`` (the rocsolver path).
      orgqr_impl: callable invoked as ``orgqr_impl(c, a, tau)``; returns
        ``(q, info)``.
      c: the XLA computation builder.
      operand: XLA op whose trailing two dimensions are ``(m, n)``. Any
        leading dimensions are treated as batch dimensions.
      full_matrices: only consulted when ``m >= n``.
        - When false, ``r`` is cut down to its leading ``n x n`` block.
        - When true, the packed result is zero-padded by ``m - n`` columns
          before ``orgqr_impl`` runs.
        - When ``m < n``, ``q`` is always built from the leading ``m x m``
          block.

    Returns:
      An XLA tuple op ``(q, r)``.
    """
    operand_shape = c.get_shape(operand)
    dims = operand_shape.dimensions()
    ndim = len(dims)
    m, n = dims[-2:]
    batch_dims = dims[:-2]

    def _top_left(x, size):
        # Keep only the leading size x size block of the two minor dimensions.
        limits = list(batch_dims) + [size, size]
        return xops.Slice(x, [0] * ndim, limits, [1] * ndim)

    r, tau, geqrf_info = geqrf_impl(c, operand)

    if m < n:
        # Wide input: q comes from the leading m x m block of the packed result.
        q, orgqr_info = orgqr_impl(c, _top_left(r, m), tau)
    elif full_matrices:
        # Tall/square input, complete q requested: append m - n zero columns
        # so that orgqr_impl receives an m x m input.
        padding = [(0, 0, 0)] * (ndim - 1) + [(0, m - n, 0)]
        zero = xops.Constant(c, np.array(0, dtype=operand_shape.element_type()))
        q = xops.Pad(r, zero, xla_client.make_padding_config(padding))
        q, orgqr_info = orgqr_impl(c, q, tau)
    else:
        # Reduced factorization: orgqr_impl consumes the full packed result
        # first; only afterwards is r trimmed to n x n.
        q, orgqr_info = orgqr_impl(c, r, tau)
        r = _top_left(r, n)

    # rocsolver does not return info, so failure masking is only possible when
    # geqrf_impl supplies a status. Presumably orgqr_info is then also present;
    # verify against the kernel wrappers.
    if geqrf_info is not None:
        succeeded = [
            xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
            for info in (geqrf_info, orgqr_info)
        ]
        ok = xops.And(*succeeded)
        # Add two trailing unit dims so the per-batch flag broadcasts over the
        # matrix dimensions.
        ok_mask = xops.Reshape(ok, batch_dims + (1, 1))
        q = _broadcasting_select(c, ok_mask, q, _nan_like(c, q))
        r = _broadcasting_select(c, ok_mask, r, _nan_like(c, r))

    # The packed geqrf output keeps non-R data below the diagonal; zero it so
    # that r is upper triangular.
    r = xla.lower_fun(jnp.triu, multiple_results=False)(c, r)
    return xops.Tuple(c, [q, r])