def test_density_matrix(gpu):
    """Check that a single training epoch modifies the DensityMatrix parameters.

    Args:
        gpu: Forwarded to ``qucumber.set_random_seed`` and ``DensityMatrix``.
    """
    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)
    np.random.seed(SEED)

    nn_state = DensityMatrix(2, 1, 1, gpu=gpu)

    def flat_params():
        # Flatten the rbm_am and rbm_ph parameters into one vector so the
        # before/after comparison is a single tensor equality check.
        am_vector = parameters_to_vector(nn_state.rbm_am.parameters())
        ph_vector = parameters_to_vector(nn_state.rbm_ph.parameters())
        return torch.cat((am_vector, ph_vector))

    params_before = flat_params()

    samples = torch.ones(100, 2)
    # Draw each basis entry at random: 'Z' with probability 0.9, otherwise 'X'.
    z_mask = np.random.binomial(1, 0.9, size=(100, 2))
    random_bases = np.where(z_mask, "Z", "X")

    nn_state.fit(samples, epochs=1, pos_batch_size=10, input_bases=random_bases)

    params_after = flat_params()
    failure_msg = "DensityMatrix's parameters did not change!"
    assert not torch.equal(params_before, params_after), failure_msg
def test_density_matrix_sizes():
    """Check the shape of ``rho`` when row and column inputs differ in length.

    The column states are the first four rows of the generated space, so the
    result must have shape ``(2, len(rows), 4)``.
    """
    nn_state = DensityMatrix(5, gpu=False)

    row_states = nn_state.generate_hilbert_space(5)
    col_states = row_states[:4, :]

    # NOTE(review): the leading axis of size 2 presumably holds the real and
    # imaginary parts -- confirm against DensityMatrix.rho.
    expected_shape = (2, row_states.shape[0], col_states.shape[0])

    result = nn_state.rho(row_states, col_states)
    assert result.shape == expected_shape
def test_density_matrix_tr1():
    """Check that the normalized density matrix has trace 1 (within ``TOL``).

    Only the ``[0]`` component of the normalized matrix is traced.
    """
    nn_state = DensityMatrix(5, gpu=False)
    basis_states = nn_state.generate_hilbert_space(5)

    unnormalized = nn_state.rho(basis_states, basis_states)
    norm = nn_state.normalization(basis_states)
    normalized = unnormalized / norm

    trace = torch.trace(normalized[0])
    expected = torch.Tensor([1])

    failure_msg = f"Trace of density matrix is not within {TOL} of 1!"
    assertAlmostEqual(trace, expected, TOL, msg=failure_msg)
def test_density_matrix_hermiticity():
    """Check that the normalized density matrix equals its own conjugate.

    The comparison is exact (``torch.equal``), not tolerance-based.
    """
    nn_state = DensityMatrix(5, 5, 5, gpu=False)
    basis_states = nn_state.generate_hilbert_space(5)

    norm = nn_state.normalization(basis_states)
    normalized_rho = nn_state.rho(basis_states, basis_states) / norm

    # NOTE(review): this is a Hermiticity check only if cplx.conjugate returns
    # the conjugate *transpose* for matrices -- confirm against the cplx module.
    conjugated = cplx.conjugate(normalized_rho)

    failure_msg = "DensityMatrix should be Hermitian!"
    assert torch.equal(normalized_rho, conjugated), failure_msg
def test_density_matrix_diagonal():
    """Check that ``rho(..., expand=False)`` matches the diagonal of the full matrix.

    The diagonal of the ``expand=True`` result is extracted with an einsum that
    keeps the leading axis and collapses the two matrix axes.
    """
    nn_state = DensityMatrix(5, gpu=False)
    basis_states = nn_state.generate_hilbert_space(5)

    full_matrix = nn_state.rho(basis_states, expand=True)
    direct_diagonal = nn_state.rho(basis_states, expand=False)

    # "cii...->ci..." picks entries whose two matrix indices coincide.
    extracted_diagonal = torch.einsum("cii...->ci...", full_matrix)

    failure_msg = "Diagonal of density matrix is wrong!"
    assertAlmostEqual(extracted_diagonal, direct_diagonal, TOL, msg=failure_msg)
def density_matrix_data(request, gpu, num_hidden):
    """Build the DensityMatrix gradient-test bundle from pickled reference data.

    Args:
        request: Object exposing ``fspath.dirname``; the pickle is read from
            ``<dirname>/data/test_grad_data.pkl``.
        gpu: Forwarded to ``qucumber.set_random_seed`` and ``DensityMatrix``.
        num_hidden: Number of hidden units passed to ``DensityMatrix``.

    Returns:
        A ``DensityMatrixFixture`` namedtuple with fields ``data_samples``,
        ``data_bases``, ``grad_utils``, ``all_bases``, ``target``, ``space``,
        ``nn_state`` and ``unitary_dict``.
    """
    pickle_path = os.path.join(request.fspath.dirname, "data", "test_grad_data.pkl")
    with open(pickle_path, "rb") as handle:
        test_data = pickle.load(handle)

    # Seed only after loading so model construction below is reproducible.
    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    dm_data = test_data["density_matrix"]
    train_bases = dm_data["train_bases"]
    train_samples = torch.tensor(dm_data["train_samples"], dtype=torch.double)
    raw_bases = dm_data["bases"]
    target_rho = torch.tensor(dm_data["density_matrix"], dtype=torch.double)

    num_visible = train_samples.shape[-1]
    num_aux = num_visible + 1

    nn_state = DensityMatrix(num_visible, num_hidden, num_aux, gpu=gpu)
    unitaries_of_state = nn_state.unitary_dict

    grad_utils = DensityGradsUtils(nn_state)
    transformed_bases = grad_utils.transform_bases(raw_bases)

    hilbert_space = nn_state.generate_hilbert_space()

    # Put the tensors on the same device as the model.
    train_samples = train_samples.to(device=nn_state.device)
    target_rho = target_rho.to(device=nn_state.device)

    # Insertion order of this dict defines the namedtuple's field order.
    fields = {
        "data_samples": train_samples,
        "data_bases": train_bases,
        "grad_utils": grad_utils,
        "all_bases": transformed_bases,
        "target": target_rho,
        "space": hilbert_space,
        "nn_state": nn_state,
        "unitary_dict": unitaries_of_state,
    }
    DensityMatrixFixture = namedtuple("DensityMatrixFixture", list(fields))
    return DensityMatrixFixture(**fields)
def test_density_matrix_expansion(prop):
    """Check that ``expand=False`` yields the diagonal of the ``expand=True`` result.

    Args:
        prop: Sequence laid out as ``(attribute_name, is_complex, *extra_args)``.
            ``attribute_name`` is resolved on the ``DensityMatrix`` instance with
            ``attrgetter``; ``extra_args`` are forwarded positionally.
    """
    qucumber.set_random_seed(INIT_SEED, cpu=True, gpu=False, quiet=True)

    nn_state = DensityMatrix(5, gpu=False)
    row_states = nn_state.generate_hilbert_space(5)
    # Column states are a random row permutation of the same space.
    col_states = row_states[torch.randperm(row_states.shape[0]), :]

    attr_name, complex_valued = prop[0], prop[1]
    extra_args = prop[2:]

    method = attrgetter(attr_name)(nn_state)
    expanded = method(row_states, col_states, *extra_args, expand=True)
    paired = method(row_states, col_states, *extra_args, expand=False)

    # Complex results carry an extra leading axis that must be preserved.
    if is_complex := complex_valued:
        diagonal_equation = "cii...->ci..."
    else:
        diagonal_equation = "ii->i"

    failure_msg = f"Diagonal of matrix {attr_name} is wrong!"
    assertAlmostEqual(
        torch.einsum(diagonal_equation, expanded), paired, TOL, msg=failure_msg
    )
def test_density_matrix_hermiticity():
    """Spot-check Hermiticity of ``rhoRBM`` on ten randomly chosen entries.

    For each sampled index pair ``(r, c)`` the ``[0]`` component must satisfy
    ``m[r, c] == m[c, r]`` and the ``[1]`` component ``m[r, c] == -m[c, r]``.
    Comparisons are exact (``torch.equal``).
    """
    nn_state = DensityMatrix(5, 5, 5, gpu=False)
    v_space = nn_state.generate_hilbert_space(5)
    matrix = nn_state.rhoRBM(v_space, v_space)

    num_checks = 10
    # Row 0 of the draw holds row indices, row 1 holds column indices.
    rows, cols = torch.randint(0, 2**5, (2, num_checks))

    real_entries = torch.zeros(num_checks)
    real_transposed = torch.zeros(num_checks)
    imag_entries = torch.zeros(num_checks)
    imag_neg_transposed = torch.zeros(num_checks)

    for k, (r, c) in enumerate(zip(rows, cols)):
        real_entries[k] = matrix[0, r, c]
        real_transposed[k] = matrix[0, c, r]
        imag_entries[k] = matrix[1, r, c]
        # Sign flip implements the conjugation of the transposed entry.
        imag_neg_transposed[k] = -matrix[1, c, r]

    assert torch.equal(real_entries, real_transposed)
    assert torch.equal(imag_entries, imag_neg_transposed)
def density_matrix_data(request, gpu, num_hidden):
    """Build the DensityMatrix gradient-test bundle from pickled reference data.

    Args:
        request: Object exposing ``fspath.dirname``; the pickle is read from
            ``<dirname>/data/test_grad_data.pkl``.
        gpu: Forwarded to ``qucumber.set_random_seed`` and ``DensityMatrix``.
        num_hidden: Number of hidden units passed to ``DensityMatrix``.

    Returns:
        A ``DensityMatrixFixture`` namedtuple with fields ``data_samples``,
        ``data_bases``, ``grad_utils``, ``bases``, ``target``, ``v_space``,
        ``a_space``, ``nn_state`` and ``unitary_dict``. All tensors in the
        bundle (samples, target, unitaries) live on ``nn_state.device``.
    """
    # The pickle is a reference-data file stored next to this test module.
    pickle_path = os.path.join(request.fspath.dirname, "data", "test_grad_data.pkl")
    with open(pickle_path, "rb") as f:
        test_data = pickle.load(f)

    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    dm_data = test_data["density_matrix"]
    data_bases = dm_data["train_bases"]
    data_samples = torch.tensor(dm_data["train_samples"], dtype=torch.double)
    bases_data = dm_data["bases"]
    target_matrix = torch.tensor(dm_data["density_matrix"], dtype=torch.double)

    num_visible = data_samples.shape[-1]
    num_aux = num_hidden + 1  # this is not a rule, will change with data

    unitary_dict = unitaries.create_dict()
    nn_state = DensityMatrix(
        num_visible, num_hidden, num_aux, unitary_dict=unitary_dict, gpu=gpu
    )

    DGU = DensityGradsUtils(nn_state)
    bases = DGU.transform_bases(bases_data)

    v_space = nn_state.generate_hilbert_space(num_visible)
    a_space = nn_state.generate_hilbert_space(num_aux)

    # Move every tensor handed to the tests onto the model's device. The target
    # was previously left on the CPU, which mismatches the model when gpu=True.
    data_samples = data_samples.to(device=nn_state.device)
    target_matrix = target_matrix.to(device=nn_state.device)
    unitary_dict = {b: v.to(device=nn_state.device) for b, v in unitary_dict.items()}

    DensityMatrixFixture = namedtuple(
        "DensityMatrixFixture",
        [
            "data_samples",
            "data_bases",
            "grad_utils",
            "bases",
            "target",
            "v_space",
            "a_space",
            "nn_state",
            "unitary_dict",
        ],
    )

    return DensityMatrixFixture(
        data_samples=data_samples,
        data_bases=data_bases,
        grad_utils=DGU,
        bases=bases,
        target=target_matrix,
        v_space=v_space,
        a_space=a_space,
        nn_state=nn_state,
        unitary_dict=unitary_dict,
    )
def quantum_state_training_data(request):
    """Assemble training inputs and targets for the requested quantum state type.

    Args:
        request: Object exposing ``param`` (one of ``PositiveWaveFunction``,
            ``ComplexWaveFunction`` or ``DensityMatrix``) and
            ``fspath.dirname`` (used to locate the tutorial data files).

    Returns:
        dict with the model (``nn_state``), training data and bases, target,
        fixed training hyperparameters, the Hilbert space, the fidelity/KL
        targets, and a ``reinit_params_fn(request, nn_state)`` callable.

    Raises:
        ValueError: If ``request.param`` is not one of the supported types.
    """
    nn_state_type = request.param

    def tutorial_dir(name):
        # Tutorial data lives under ../examples relative to this test module.
        return os.path.join(request.fspath.dirname, "..", "examples", name)

    if nn_state_type == PositiveWaveFunction:
        root = tutorial_dir("Tutorial1_TrainPosRealWaveFunction")
        train_samples, target = data.load_data(
            tr_samples_path=os.path.join(root, "tfim1d_data.txt"),
            tr_psi_path=os.path.join(root, "tfim1d_psi.txt"),
        )
        # This data set carries no basis information.
        train_bases = bases = None

        nn_state = PositiveWaveFunction(num_visible=train_samples.shape[-1], gpu=False)

        batch_size, num_chains = 100, 200
        fid_target, kl_target = 0.85, 0.29
        reinit_params_fn = initialize_posreal_params

    elif nn_state_type == ComplexWaveFunction:
        root = tutorial_dir("Tutorial2_TrainComplexWaveFunction")
        train_samples, target, train_bases, bases = data.load_data(
            tr_samples_path=os.path.join(root, "qubits_train.txt"),
            tr_psi_path=os.path.join(root, "qubits_psi.txt"),
            tr_bases_path=os.path.join(root, "qubits_train_bases.txt"),
            bases_path=os.path.join(root, "qubits_bases.txt"),
        )

        nn_state = ComplexWaveFunction(num_visible=train_samples.shape[-1], gpu=False)

        batch_size, num_chains = 50, 10
        fid_target, kl_target = 0.38, 0.33
        reinit_params_fn = initialize_complex_params

    elif nn_state_type == DensityMatrix:
        root = tutorial_dir("Tutorial3_TrainDensityMatrix")
        train_samples, target, train_bases, bases = data.load_data_DM(
            tr_samples_path=os.path.join(root, "N2_W_state_100_samples_data.txt"),
            tr_mtx_real_path=os.path.join(root, "N2_W_state_target_real.txt"),
            tr_mtx_imag_path=os.path.join(root, "N2_W_state_target_imag.txt"),
            tr_bases_path=os.path.join(root, "N2_W_state_100_samples_bases.txt"),
            bases_path=os.path.join(root, "N2_IC_bases.txt"),
        )

        nn_state = DensityMatrix(num_visible=train_samples.shape[-1], gpu=False)

        batch_size, num_chains = 100, 10
        fid_target, kl_target = 0.45, 0.42

        def reinit_params_fn(request, nn_state):
            # Same (request, nn_state) signature as the other initializers;
            # ``request`` is unused here.
            nn_state.reinitialize_parameters()

    else:
        raise ValueError(
            f"invalid test config: {nn_state_type} is not a valid quantum state type"
        )

    hilbert_space = nn_state.generate_hilbert_space()

    config = {
        "nn_state": nn_state,
        "data": train_samples,
        "input_bases": train_bases,
        "target": target,
        "bases": bases,
        "epochs": 5,
        "pos_batch_size": batch_size,
        "neg_batch_size": num_chains,
        "k": 10,
        "lr": 0.1,
        "space": hilbert_space,
        "fid_target": fid_target,
        "kl_target": kl_target,
        "reinit_params_fn": reinit_params_fn,
    }
    return config