def partial_trace_cupy(rho: cupy.ndarray, retain_qubits) -> cupy.ndarray:
    """Compute the partial trace of a multi-qubit density matrix.

    Args:
        rho: density matrix whose first dimension is ``2**n`` for ``n`` qubits;
            it is reshaped into ``2 * n`` axes of length 2.
        retain_qubits: 0-based indices of the qubits to keep after the partial
            trace. Duplicate indices are ignored. The retained qubits appear in
            the result in ascending index order, regardless of the order given.

    Returns:
        The reduced density matrix of shape ``(2**k, 2**k)``, where ``k`` is the
        number of distinct retained qubits. If ``retain_qubits`` is empty,
        returns ``trace(rho)``. If every qubit is retained, returns ``rho``
        itself (not a copy).

    Raises:
        ValueError: if any retained qubit index is outside ``[0, n)``.
    """
    # Deduplicate so repeated indices cannot fool the "keep everything" check
    # below; a set also gives O(1) membership tests when building the removal list.
    keep = set(retain_qubits)
    if not keep:
        return trace(rho)
    total_qb = int(math.log2(rho.shape[0]))
    # Explicit check instead of ``assert``: it must survive ``python -O``,
    # otherwise a negative index would silently trace out every qubit.
    if min(keep) < 0 or max(keep) >= total_qb:
        raise ValueError(
            f"retain_qubits {sorted(keep)} out of range for {total_qb} qubits"
        )
    if len(keep) == total_qb:
        return rho
    qbs_to_remove = [q for q in range(total_qb) if q not in keep]
    # One length-2 axis per qubit: axes [0, n) index rows, axes [n, 2n) columns.
    rho = rho.reshape([2] * (2 * total_qb))
    # Trace out the highest-indexed qubits first, so the axis numbers of the
    # lower-indexed qubits still to be removed are not shifted by earlier traces.
    for qid in reversed(qbs_to_remove):
        rho = trace(rho, axis1=qid, axis2=qid + total_qb)
        total_qb -= 1
    # Merge the remaining row axes and column axes back into a square matrix.
    dim = 2 ** total_qb
    return rho.reshape(dim, dim)
def partial_trace_wf_cupy(iwf: cupy.ndarray, retain_qubits):
    """Reduce a pure state vector to a subset of its qubits.

    The state is reshaped into one length-2 axis per qubit. The retained qubits
    are moved to the leading axes (in ascending index order), and the result is
    handed to ``partial_trace_wf_keep_first_cupy``.

    Args:
        iwf: state vector whose length is ``2**n`` for ``n`` qubits.
        retain_qubits: 0-based indices of the qubits to keep. Duplicate
            indices are ignored. The retained qubits are ordered ascending
            in the result, regardless of the order given.

    Returns:
        ``outer(iwf, iwf.conj())`` when every qubit is retained. Otherwise,
        whatever ``partial_trace_wf_keep_first_cupy`` returns for the permuted
        vector and the number of retained qubits. That helper is defined
        elsewhere; presumably it returns the reduced density matrix over the
        leading qubits (TODO confirm).

    Raises:
        ValueError: if any retained qubit index is outside ``[0, n)``.
    """
    # Use math.log2 like the sibling partial-trace functions do.
    nqb = int(math.log2(iwf.shape[0]))
    # Sorted and distinct: the swap loop below relies on both properties.
    keep = sorted(set(retain_qubits))
    # Without this check a negative index would silently be treated by
    # swapaxes as "count from the last axis".
    if keep and (keep[0] < 0 or keep[-1] >= nqb):
        raise ValueError(f"retain_qubits {keep} out of range for {nqb} qubits")
    if len(keep) == nqb:
        return outer(iwf, iwf.conj())
    iwf = iwf.reshape([2] * nqb, order="C")
    # Bring retained qubit keep[i] to axis i. Because ``keep`` is sorted and
    # distinct, keep[i] >= i. Earlier swaps only touched axes j < i and
    # keep[j] < keep[i], so axis keep[i] still holds that qubit when we swap it.
    for idx, r in enumerate(keep):
        if idx != r:
            iwf = iwf.swapaxes(idx, r)
    iwf = iwf.reshape((2 ** nqb,))
    return partial_trace_wf_keep_first_cupy(iwf, len(keep))
def partial_trace_1d_cupy(rho: cupy.ndarray, retain_qubit: int):
    """Reduce a multi-qubit density matrix to a single qubit.

    Traces out every qubit except ``retain_qubit`` and returns the 2x2 reduced
    density matrix of that qubit. A single-qubit ``rho`` is returned unchanged.

    Args:
        rho: density matrix whose first dimension is ``2**n`` for ``n`` qubits.
        retain_qubit: 0-based index of the qubit to keep.

    Raises:
        ValueError: if ``retain_qubit`` is not in ``[0, n)``.
    """
    n_qubits = int(math.log2(rho.shape[0]))
    if retain_qubit < 0 or retain_qubit >= n_qubits:
        raise ValueError(retain_qubit)
    if n_qubits == 1:
        return rho

    # One length-2 axis per qubit: axes [0, n) are rows, axes [n, 2n) columns.
    reduced = rho.reshape([2] * (2 * n_qubits))
    remaining = n_qubits
    # Contract from the highest qubit index downwards. Each trace drops one
    # row axis and one column axis, and going top-down keeps the axis numbers
    # of the lower-indexed qubits (still to be traced) unchanged.
    for qid in range(n_qubits - 1, -1, -1):
        if qid == retain_qubit:
            continue
        reduced = trace(reduced, axis1=qid, axis2=qid + remaining)
        remaining -= 1
    assert reduced.shape == (2, 2)
    return reduced