def KL(nn_state, target_psi, bases=None):
    psi_r = torch.zeros(
        2, 1 << nn_state.num_visible, dtype=torch.double, device=nn_state.device
    )
    KL = 0.0
    unitary_dict = unitaries.create_dict()
    target_psi = target_psi.to(nn_state.device)
    space = nn_state.generate_Hilbert_space(nn_state.num_visible)
    nn_state.compute_normalization()

    if bases is None:
        num_bases = 1
        for i in range(len(space)):
            KL += cplx.norm(target_psi[:, i]) * cplx.norm(target_psi[:, i]).log()
            KL -= cplx.norm(target_psi[:, i]) * cplx.norm(nn_state.psi(space[i])).log()
            KL += cplx.norm(target_psi[:, i]) * nn_state.Z.log()
    else:
        num_bases = len(bases)
        for b in range(1, len(bases)):
            psi_r = rotate_psi(nn_state, bases[b], unitary_dict)
            target_psi_r = rotate_psi(nn_state, bases[b], unitary_dict, target_psi)
            for ii in range(len(space)):
                if cplx.norm(target_psi_r[:, ii]) > 0.0:
                    KL += (cplx.norm(target_psi_r[:, ii])
                           * cplx.norm(target_psi_r[:, ii]).log())
                    KL -= (cplx.norm(target_psi_r[:, ii])
                           * cplx.norm(psi_r[:, ii]).log().item())
                    KL += cplx.norm(target_psi_r[:, ii]) * nn_state.Z.log()

    return KL / float(num_bases)
def test_rotate_rho_probs(num_visible, state_type, precompute_rho):
    nn_state = state_type(num_visible, gpu=False)
    basis = "X" * num_visible
    unitary_dict = create_dict()
    space = nn_state.generate_hilbert_space()

    rho = nn_state.rho(space, expand=True) if precompute_rho else None
    rho_r = rotate_rho(nn_state, basis, space, unitary_dict, rho=rho)
    rho_r_probs = torch.diagonal(cplx.real(rho_r))

    rho_r_probs_fast = rotate_rho_probs(nn_state, basis, space, unitary_dict, rho=rho)

    # use different tolerance as this sometimes just barely breaks through the
    # smaller TOL value from test_grads.py
    assertAlmostEqual(
        rho_r_probs,
        rho_r_probs_fast,
        tol=(TOL * 10),
        msg="Fast rho probs rotation failed!",
    )
def __init__(self, num_visible, num_hidden, num_aux, unitary_dict=None, gpu=False):
    self.rbm_am = PurificationRBM(
        int(num_visible), int(num_hidden), int(num_aux), gpu=gpu
    )
    self.rbm_ph = PurificationRBM(
        int(num_visible), int(num_hidden), int(num_aux), gpu=gpu
    )

    self.num_visible = int(num_visible)
    self.num_hidden = int(num_hidden)
    self.num_aux = int(num_aux)
    self.device = self.rbm_am.device

    self.unitary_dict = unitary_dict if unitary_dict else unitaries.create_dict()
    # move the unitaries onto the same device as the RBMs
    self.unitary_dict = {
        k: v.to(device=self.device) for k, v in self.unitary_dict.items()
    }
def complex_wavefunction_data(gpu, num_hidden):
    with open(
        os.path.join(__tests_location__, "data", "test_grad_data.pkl"), "rb"
    ) as f:
        test_data = pickle.load(f)

    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    data_bases = test_data["2qubits"]["train_bases"]
    data_samples = torch.tensor(
        test_data["2qubits"]["train_samples"], dtype=torch.double
    )

    bases_data = test_data["2qubits"]["bases"]
    target_psi_tmp = torch.tensor(
        test_data["2qubits"]["target_psi"], dtype=torch.double
    )

    num_visible = data_samples.shape[-1]

    unitary_dict = unitaries.create_dict()
    nn_state = ComplexWaveFunction(
        num_visible, num_hidden, unitary_dict=unitary_dict, gpu=gpu
    )

    CGU = ComplexGradsUtils(nn_state)

    bases = CGU.transform_bases(bases_data)
    psi_dict = CGU.load_target_psi(bases, target_psi_tmp)
    vis = nn_state.generate_hilbert_space(num_visible)

    data_samples = data_samples.to(device=nn_state.device)
    unitary_dict = {b: v.to(device=nn_state.device) for b, v in unitary_dict.items()}
    psi_dict = {b: v.to(device=nn_state.device) for b, v in psi_dict.items()}

    ComplexWaveFunctionFixture = namedtuple(
        "ComplexWaveFunctionFixture",
        [
            "data_samples",
            "data_bases",
            "grad_utils",
            "bases",
            "psi_dict",
            "vis",
            "nn_state",
            "unitary_dict",
        ],
    )

    return ComplexWaveFunctionFixture(
        data_samples=data_samples,
        data_bases=data_bases,
        grad_utils=CGU,
        bases=bases,
        psi_dict=psi_dict,
        vis=vis,
        nn_state=nn_state,
        unitary_dict=unitary_dict,
    )
def __init__(
    self, num_visible, num_hidden=None, unitary_dict=None, gpu=False, module=None
):
    if gpu and torch.cuda.is_available():
        warnings.warn(
            "Using ComplexWaveFunction on GPU is not recommended due to poor "
            "performance compared to CPU.",
            ResourceWarning,
            2,
        )
        self.device = torch.device("cuda")
    else:
        self.device = torch.device("cpu")

    if module is None:
        self.rbm_am = BinaryRBM(num_visible, num_hidden, gpu=gpu)
        self.rbm_ph = BinaryRBM(num_visible, num_hidden, gpu=gpu)
    else:
        _warn_on_missing_gpu(gpu)
        self.rbm_am = module.to(self.device)
        self.rbm_am.device = self.device
        self.rbm_ph = module.to(self.device).clone()
        self.rbm_ph.device = self.device

    self.num_visible = self.rbm_am.num_visible
    self.num_hidden = self.rbm_am.num_hidden
    self.device = self.rbm_am.device

    self.unitary_dict = unitary_dict if unitary_dict else unitaries.create_dict()
    self.unitary_dict = {
        k: v.to(device=self.device) for k, v in self.unitary_dict.items()
    }
def test_adding_unitaries():
    unitary_dict = unitaries.create_dict(
        A=(
            0.5
            * torch.tensor(
                [[[1, 1], [1, 1]], [[1, -1], [-1, 1]]], dtype=torch.double
            )
        )
    )

    msg = "Unitary dictionary has the wrong keys!"
    assert {"X", "Y", "Z", "A"} == set(unitary_dict.keys()), msg
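As a point of reference for the test above, a custom unitary is handed to create_dict in QuCumber's two-channel complex layout: the first 2x2 block is the real part and the second is the imaginary part. The sketch below registers a Pauli-X matrix under the key "B"; the key name and the qucumber.utils import path are assumptions for illustration, not taken from these tests.

import torch
from qucumber.utils import unitaries  # assumed import path for create_dict

# Pauli-X in the two-channel format: shape (2, 2, 2), real block first, imaginary block second.
# The key "B" is a hypothetical label chosen only for this sketch.
pauli_x = torch.tensor(
    [[[0.0, 1.0], [1.0, 0.0]],   # real part
     [[0.0, 0.0], [0.0, 0.0]]],  # imaginary part
    dtype=torch.double,
)
unitary_dict = unitaries.create_dict(B=pauli_x)
assert set(unitary_dict.keys()) == {"X", "Y", "Z", "B"}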
def NLL(nn_state, samples, space, train_bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The Hilbert space of the system.
    :type space: torch.Tensor
    :param train_bases: An array of bases where measurements were taken.
    :type train_bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The Negative Log-Likelihood.
    :rtype: float
    """
    psi_r = torch.zeros(
        2, 1 << nn_state.num_visible, dtype=torch.double, device=nn_state.device
    )
    NLL = 0.0
    unitary_dict = unitaries.create_dict()
    Z = nn_state.compute_normalization(space)
    eps = 0.000001

    if train_bases is None:
        for i in range(len(samples)):
            NLL -= (cplx.norm_sqr(nn_state.psi(samples[i])) + eps).log()
            NLL += Z.log()
    else:
        for i in range(len(samples)):
            # check whether the sample was measured in the reference (Z) basis
            is_reference_basis = True
            for j in range(nn_state.num_visible):
                if train_bases[i][j] != "Z":
                    is_reference_basis = False
                    break

            if is_reference_basis is True:
                NLL -= (cplx.norm_sqr(nn_state.psi(samples[i])) + eps).log()
                NLL += Z.log()
            else:
                psi_r = rotate_psi(nn_state, train_bases[i], space, unitary_dict)

                # get the index of the sample state within the Hilbert space
                ind = 0
                for j in range(nn_state.num_visible):
                    if samples[i, nn_state.num_visible - j - 1] == 1:
                        ind += pow(2, j)

                NLL -= cplx.norm_sqr(psi_r[:, ind]).log().item()
                NLL += Z.log()

    return (NLL / float(len(samples))).item()
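A minimal usage sketch for NLL, shown only to illustrate the call signature; `nn_state` and `train_samples` are placeholders assumed to exist (a trained wavefunction state and a tensor of Z-basis measurement samples), not objects defined in this file.

# assumption: `nn_state` is an already-trained wavefunction state and `train_samples`
# is a (num_samples, num_visible) tensor of 0/1 outcomes measured in the Z basis
space = nn_state.generate_hilbert_space()
nll = NLL(nn_state, train_samples, space)  # train_bases=None -> reference-basis NLL
print(nll)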
def test_rotate_psi_inner_prod(num_visible, state_type, precompute_psi):
    nn_state = state_type(num_visible, gpu=False)
    basis = "X" * num_visible
    unitary_dict = create_dict()
    space = nn_state.generate_hilbert_space()

    psi = nn_state.psi(space) if precompute_psi else None
    psi_r = rotate_psi(nn_state, basis, space, unitary_dict, psi=psi)
    psi_r_ip = rotate_psi_inner_prod(nn_state, basis, space, unitary_dict, psi=psi)

    assertAlmostEqual(psi_r, psi_r_ip, msg="Fast psi inner product rotation failed!")
def test_rotate_psi(num_visible, wvfn_type):
    nn_state = wvfn_type(num_visible, gpu=False)
    basis = "X" * num_visible
    unitary_dict = create_dict()
    space = nn_state.generate_hilbert_space()
    psi = nn_state.psi(space)

    psi_r_fast = rotate_psi(nn_state, basis, space, unitary_dict, psi=psi)

    U = reduce(cplx.kronecker_prod, [unitary_dict[b] for b in basis])
    psi_r_correct = cplx.matmul(U, psi)

    assertAlmostEqual(psi_r_fast, psi_r_correct, msg="Fast psi rotation failed!")
def test_rotate_rho(num_visible, state_type):
    nn_state = state_type(num_visible, gpu=False)
    basis = "X" * num_visible
    unitary_dict = create_dict()
    space = nn_state.generate_hilbert_space()
    rho = nn_state.rho(space, space)

    rho_r_fast = rotate_rho(nn_state, basis, space, unitary_dict, rho=rho)

    U = reduce(cplx.kronecker_prod, [unitary_dict[b] for b in basis])
    rho_r_correct = cplx.matmul(U, rho)
    rho_r_correct = cplx.matmul(rho_r_correct, cplx.conjugate(U))

    assertAlmostEqual(rho_r_fast, rho_r_correct, msg="Fast rho rotation failed!")
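For context, the identity this test checks is the usual basis change for density matrices, rho -> U rho U^dagger. A standalone sketch using plain complex torch tensors (not QuCumber's two-channel cplx format); the single-qubit matrices below are illustrative values only.

import torch

# illustrative single-qubit density matrix and a Hadamard basis change
U = torch.tensor([[1.0, 1.0], [1.0, -1.0]], dtype=torch.cdouble) / (2 ** 0.5)
rho = torch.tensor([[0.75, 0.25], [0.25, 0.25]], dtype=torch.cdouble)

rho_r = U @ rho @ U.conj().T  # rho transforms as U rho U^dagger
print(torch.trace(rho_r).real.item())  # trace stays 1.0 under a unitary rotation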
def KL(nn_state, target_psi, space, bases=None, **kwargs):
    r"""A function for calculating the total KL divergence.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param target_psi: The true wavefunction of the system.
    :type target_psi: torch.Tensor
    :param space: The Hilbert space of the system.
    :type space: torch.Tensor
    :param bases: An array of unique bases.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    psi_r = torch.zeros(
        2, 1 << nn_state.num_visible, dtype=torch.double, device=nn_state.device
    )
    KL = 0.0
    unitary_dict = unitaries.create_dict()
    target_psi = target_psi.to(nn_state.device)
    Z = nn_state.compute_normalization(space)
    eps = 0.000001

    if bases is None:
        num_bases = 1
        for i in range(len(space)):
            KL += cplx.norm_sqr(target_psi[:, i]) * (
                cplx.norm_sqr(target_psi[:, i]) + eps
            ).log()
            KL -= cplx.norm_sqr(target_psi[:, i]) * (
                cplx.norm_sqr(nn_state.psi(space[i])) + eps
            ).log()
            KL += cplx.norm_sqr(target_psi[:, i]) * Z.log()
    else:
        num_bases = len(bases)
        for b in range(1, len(bases)):
            psi_r = rotate_psi(nn_state, bases[b], space, unitary_dict)
            target_psi_r = rotate_psi(
                nn_state, bases[b], space, unitary_dict, target_psi
            )
            for ii in range(len(space)):
                if cplx.norm_sqr(target_psi_r[:, ii]) > 0.0:
                    KL += cplx.norm_sqr(target_psi_r[:, ii]) * cplx.norm_sqr(
                        target_psi_r[:, ii]
                    ).log()
                    KL -= cplx.norm_sqr(target_psi_r[:, ii]) * cplx.norm_sqr(
                        psi_r[:, ii]
                    ).log().item()
                    KL += cplx.norm_sqr(target_psi_r[:, ii]) * Z.log()

    return (KL / float(num_bases)).item()
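A minimal usage sketch of the reference-basis branch of KL; `nn_state` is a placeholder for an already-constructed two-qubit wavefunction state, and the equal-superposition target below is an illustrative choice, not data from this repository.

import torch

# assumption: `nn_state` is an existing two-qubit wavefunction state
space = nn_state.generate_hilbert_space()

# target in the two-channel cplx layout: row 0 holds real parts, row 1 imaginary parts
target_psi = torch.zeros(2, 4, dtype=torch.double)
target_psi[0, :] = 0.5  # normalized equal superposition of the four basis states

print(KL(nn_state, target_psi, space))  # bases=None -> computational-basis KL only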
def __init__(
    self, num_visible, num_hidden=None, unitary_dict=None, gpu=True, module=None
):
    if gpu and torch.cuda.is_available():
        warnings.warn(
            (
                "Using ComplexWaveFunction on GPU is not recommended due to poor "
                "performance compared to CPU. In the future, ComplexWaveFunction "
                "will default to using CPU, even if a GPU is available."
            ),
            ResourceWarning,
            2,
        )
        self.device = torch.device("cuda")
    else:
        self.device = torch.device("cpu")

    if module is None:
        self.rbm_am = BinaryRBM(
            int(num_visible),
            int(num_hidden) if num_hidden else int(num_visible),
            gpu=gpu,
        )
        self.rbm_ph = BinaryRBM(
            int(num_visible),
            int(num_hidden) if num_hidden else int(num_visible),
            gpu=gpu,
        )
    else:
        _warn_on_missing_gpu(gpu)
        self.rbm_am = module.to(self.device)
        self.rbm_am.device = self.device
        self.rbm_ph = module.to(self.device).clone()
        self.rbm_ph.device = self.device

    self.num_visible = int(num_visible)
    self.num_hidden = int(num_hidden) if num_hidden else self.num_visible

    self.unitary_dict = unitary_dict if unitary_dict else unitaries.create_dict()
    self.unitary_dict = {
        k: v.to(device=self.device) for k, v in self.unitary_dict.items()
    }
def __init__(self, num_visible, num_hidden=None, unitary_dict=None, gpu=True):
    self.num_visible = int(num_visible)
    self.num_hidden = int(num_hidden) if num_hidden else self.num_visible

    self.rbm_am = BinaryRBM(self.num_visible, self.num_hidden, gpu=gpu)
    self.rbm_ph = BinaryRBM(self.num_visible, self.num_hidden, gpu=gpu)

    self.device = self.rbm_am.device

    self.unitary_dict = unitary_dict if unitary_dict else unitaries.create_dict()
    self.unitary_dict = {
        k: v.to(device=self.device) for k, v in self.unitary_dict.items()
    }
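A minimal construction sketch for the class above, assuming the qucumber.nn_states import path; the two-qubit sizes are illustrative.

from qucumber.nn_states import ComplexWaveFunction  # assumed import path
from qucumber.utils import unitaries

unitary_dict = unitaries.create_dict()  # default X, Y, Z unitaries
nn_state = ComplexWaveFunction(
    num_visible=2, num_hidden=2, unitary_dict=unitary_dict, gpu=False
)
print(nn_state.device)  # cpu, since gpu=False was requested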
def test_trainingcomplex(self):
    print("Complex Wavefunction")
    print("--------------------")

    train_samples_path = os.path.join(
        __tests_location__,
        "..",
        "examples",
        "Tutorial2_TrainComplexWavefunction",
        "qubits_train.txt",
    )
    train_bases_path = os.path.join(
        __tests_location__,
        "..",
        "examples",
        "Tutorial2_TrainComplexWavefunction",
        "qubits_train_bases.txt",
    )
    bases_path = os.path.join(
        __tests_location__,
        "..",
        "examples",
        "Tutorial2_TrainComplexWavefunction",
        "qubits_bases.txt",
    )
    psi_path = os.path.join(
        __tests_location__,
        "..",
        "examples",
        "Tutorial2_TrainComplexWavefunction",
        "qubits_psi.txt",
    )

    train_samples, target_psi, train_bases, bases = data.load_data(
        train_samples_path, psi_path, train_bases_path, bases_path
    )

    unitary_dict = unitaries.create_dict()
    nv = nh = train_samples.shape[-1]

    fidelities = []
    KLs = []

    epochs = 5
    batch_size = 50
    num_chains = 10
    CD = 10
    lr = 0.1
    log_every = 5

    print("Training 10 times and checking fidelity and KL at 5 epochs...\n")
    for i in range(10):
        print("Iteration: ", i + 1)

        nn_state = ComplexWavefunction(
            unitary_dict=unitary_dict, num_visible=nv, num_hidden=nh, gpu=False
        )
        space = nn_state.generate_hilbert_space(nv)
        callbacks = [
            MetricEvaluator(
                log_every,
                {"Fidelity": ts.fidelity, "KL": ts.KL},
                target_psi=target_psi,
                bases=bases,
                space=space,
                verbose=True,
            )
        ]

        self.initialize_complex_params(nn_state)

        nn_state.fit(
            data=train_samples,
            epochs=epochs,
            pos_batch_size=batch_size,
            neg_batch_size=num_chains,
            k=CD,
            lr=lr,
            time=True,
            input_bases=train_bases,
            progbar=False,
            callbacks=callbacks,
        )

        fidelities.append(ts.fidelity(nn_state, target_psi, space).item())
        KLs.append(ts.KL(nn_state, target_psi, space, bases=bases).item())

    print("\nStatistics")
    print("----------")
    print(
        "Fidelity: ",
        np.average(fidelities),
        "+/-",
        np.std(fidelities) / np.sqrt(len(fidelities)),
        "\n",
    )
    print("KL: ", np.average(KLs), "+/-", np.std(KLs) / np.sqrt(len(KLs)), "\n")

    self.assertTrue(abs(np.average(fidelities) - 0.38) < 0.05)
    self.assertTrue(abs(np.average(KLs) - 0.33) < 0.05)
    self.assertTrue((np.std(fidelities) / np.sqrt(len(fidelities))) < 0.01)
    self.assertTrue((np.std(KLs) / np.sqrt(len(KLs))) < 0.01)
def density_matrix_data(request, gpu, num_hidden):
    with open(
        os.path.join(request.fspath.dirname, "data", "test_grad_data.pkl"), "rb"
    ) as f:
        test_data = pickle.load(f)

    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    data_bases = test_data["density_matrix"]["train_bases"]
    data_samples = torch.tensor(
        test_data["density_matrix"]["train_samples"], dtype=torch.double
    )

    bases_data = test_data["density_matrix"]["bases"]
    target_matrix = torch.tensor(
        test_data["density_matrix"]["density_matrix"], dtype=torch.double
    )

    num_visible = data_samples.shape[-1]
    num_aux = num_hidden + 1  # not a general rule; this just matches the test data

    unitary_dict = unitaries.create_dict()
    nn_state = DensityMatrix(
        num_visible, num_hidden, num_aux, unitary_dict=unitary_dict, gpu=gpu
    )

    DGU = DensityGradsUtils(nn_state)

    bases = DGU.transform_bases(bases_data)
    v_space = nn_state.generate_hilbert_space(num_visible)
    a_space = nn_state.generate_hilbert_space(num_aux)

    data_samples = data_samples.to(device=nn_state.device)
    unitary_dict = {b: v.to(device=nn_state.device) for b, v in unitary_dict.items()}

    DensityMatrixFixture = namedtuple(
        "DensityMatrixFixture",
        [
            "data_samples",
            "data_bases",
            "grad_utils",
            "bases",
            "target",
            "v_space",
            "a_space",
            "nn_state",
            "unitary_dict",
        ],
    )

    return DensityMatrixFixture(
        data_samples=data_samples,
        data_bases=data_bases,
        grad_utils=DGU,
        bases=bases,
        target=target_matrix,
        v_space=v_space,
        a_space=a_space,
        nn_state=nn_state,
        unitary_dict=unitary_dict,
    )
def test_rotate_psi(num_visible, wvfn_type):
    nn_state = wvfn_type(num_visible, gpu=False)
    basis = "X" * num_visible
    unitary_dict = create_dict()

    rotate_psi(nn_state, basis, nn_state.generate_hilbert_space(), unitary_dict)
print('')


if __name__ == '__main__':
    k = 2
    num_chains = 10
    seed = 1234

    with open('test_data.pkl', 'rb') as fin:
        test_data = pickle.load(fin)

    qucumber.set_random_seed(seed)

    train_bases = test_data['2qubits']['train_bases']
    train_samples = torch.tensor(
        test_data['2qubits']['train_samples'], dtype=torch.double
    )
    bases_data = test_data['2qubits']['bases']
    target_psi_tmp = torch.tensor(
        test_data['2qubits']['target_psi'], dtype=torch.double
    )

    nh = train_samples.shape[-1]
    bases = transform_bases(bases_data)
    unitary_dict = unitaries.create_dict()
    psi_dict = load_target_psi(bases, target_psi_tmp)
    vis = generate_visible_space(train_samples.shape[-1])

    nn_state = ComplexWavefunction(
        num_visible=train_samples.shape[-1], num_hidden=nh, unitary_dict=unitary_dict
    )
    qr = QuantumReconstruction(nn_state)

    eps = 1.e-6
    run(qr, psi_dict, train_samples, train_bases, unitary_dict, bases, vis, eps, k)
def test_default_unitary_dict():
    unitary_dict = unitaries.create_dict()

    msg = "Default Unitary dictionary has the wrong keys!"
    assert {"X", "Y", "Z"} == set(unitary_dict.keys()), msg