def partial_trace_wf_keep_first_cupy(iwf: cupy.ndarray, n):
    """Reduced density matrix of the leading ``n`` qubits of a state vector (GPU).

    The state is treated as ``nqb = log2(len(iwf))`` qubits; the first ``n``
    qubits are kept and the remaining ``m = nqb - n`` are traced out by the
    CUDA kernel ``partial_trace_wf_keep_first_cuda``.

    Args:
        iwf: C-contiguous state vector of length ``2 ** nqb`` with dtype
            ``default_dtype``.
        n: number of leading qubits to keep, ``0 <= n <= nqb``.

    Returns:
        A ``(2 ** n, 2 ** n)`` C-ordered array of dtype ``default_dtype``,
        filled in by the kernel.

    Raises:
        AssertionError: if ``iwf`` is not C-contiguous or has the wrong dtype.
        ValueError: if ``len(iwf)`` is not a power of two, or ``n`` is outside
            ``[0, nqb]``.
    """
    # TODO: improve the cuda version of partial_trace_wf.
    # For example, cleverly adjust the blockDim to deal with the other case
    assert iwf.flags.c_contiguous
    assert iwf.dtype == default_dtype

    length = iwf.shape[0]
    nqb = int(math_log2(length))
    # int(log2(.)) truncates, so a non-power-of-two length would otherwise be
    # silently treated as a smaller state.
    if 2 ** nqb != length:
        raise ValueError(f"state length {length} is not a power of two")
    # A negative m would make 2 ** m a float (passed straight to the kernel),
    # and a negative n would produce a fractional matrix shape.
    if not 0 <= n <= nqb:
        raise ValueError(f"n must be in [0, {nqb}], got {n}")

    iwf_conj = iwf.conj()
    m = nqb - n
    m_idx = 2 ** m  # size of the traced-out subspace
    n_idx = 2 ** n  # size of the kept subspace
    rho = zeros(shape=(n_idx, n_idx), dtype=default_dtype, order="C")

    # Each thread handles one (i, j) entry of rho in the CUDA code: a 2-D block
    # of 32 x 32 = 1024 threads, and enough blocks along each axis to cover
    # n_idx rows / columns (ceiling division).
    threads_per_block = 32
    block_shape = (threads_per_block, threads_per_block)
    blocks_per_axis = (n_idx + threads_per_block - 1) // threads_per_block
    grid_shape = (blocks_per_axis, blocks_per_axis)

    partial_trace_wf_keep_first_cuda(
        grid=grid_shape,
        block=block_shape,
        args=(
            iwf,
            iwf_conj,
            rho,
            m,
            m_idx,
            n_idx,
        ),
    )
    return rho
def partial_trace_wf_cupy(iwf: cupy.ndarray, retain_qubits):
    """Reduced density matrix of a state vector on a chosen subset of qubits (GPU).

    The flat state is viewed as ``nqb`` qubit axes via a C-order reshape to
    ``[2] * nqb``, so qubit 0 is the most significant bit of the flat index.
    The kept qubits are moved to the leading axes and the rest are traced out
    by ``partial_trace_wf_keep_first_cupy``.

    Args:
        iwf: state vector of length ``2 ** nqb``.
        retain_qubits: iterable of qubit indices to keep. The given order is
            ignored (indices are sorted ascending). Negative indices count
            from the end, as with array axes.

    Returns:
        A ``(2 ** k, 2 ** k)`` matrix, ``k = len(retain_qubits)``, whose
        subsystems appear in ascending qubit order. When every qubit is kept
        this is ``outer(iwf, iwf.conj())``.

    Raises:
        ValueError: if an index is out of range or appears more than once.
    """
    nqb = int(math_log2(iwf.shape[0]))

    # Normalise negative indices first so that sorting orders them correctly.
    retained = sorted(q + nqb if q < 0 else q for q in retain_qubits)
    if len(set(retained)) != len(retained):
        raise ValueError(f"duplicate qubit index in retain_qubits: {retained}")
    if retained and not (0 <= retained[0] and retained[-1] < nqb):
        raise ValueError(
            f"retain_qubits {retained} out of range for a {nqb}-qubit state"
        )

    if len(retained) == nqb:
        # Nothing to trace out: the reduced state is the full |psi><psi|.
        return outer(iwf, iwf.conj())

    psi = iwf.reshape([2] * nqb, order="C")
    # Bring the kept qubits, in ascending order, to the leading axes. Since
    # `retained` is strictly increasing, axis position `q` still holds the
    # original qubit `q` when it is reached, so each swap pulls the correct
    # axis forward. Traced-out axes may end up permuted; the trace over them
    # does not depend on their order.
    for pos, q in enumerate(retained):
        if pos != q:
            psi = psi.swapaxes(pos, q)

    flat = psi.reshape((2 ** nqb,))
    # The keep-first routine asserts a C-contiguous buffer; a strided input
    # can yield a non-contiguous 1-D view here, so copy in that case.
    if not flat.flags.c_contiguous:
        flat = flat.copy()
    return partial_trace_wf_keep_first_cupy(flat, len(retained))