def test_grads(quantum_state_graddata):
    """Compare numerical and algorithmic gradients for each network of a state.

    For every network name in ``nn_state.networks`` the two gradient sets are
    printed side by side, grouped by parameter, and then compared with
    ``assertAlmostEqual`` using the tolerance supplied by the fixture.

    Args:
        quantum_state_graddata: Fixture value that unpacks to
            ``(nn_state, alg_grads, num_grads, grad_type, test_tol)``.
            ``alg_grads`` and ``num_grads`` are indexed by the position of
            the network in ``nn_state.networks``.
    """
    nn_state, alg_grads, num_grads, grad_type, test_tol = quantum_state_graddata

    print(
        f"\nTesting {grad_type} gradients for "
        f"{nn_state.__class__.__name__} on {nn_state.device}."
    )

    for n, net in enumerate(nn_state.networks):
        print(f"\nRBM: {net}")
        rbm = getattr(nn_state, net)

        # Give each named parameter a consecutive index range sized by its
        # element count; get_param_status receives this mapping together with
        # a gradient index (presumably to find the owning parameter).
        param_ranges = {}
        counter = 0
        for param_name, param in rbm.named_parameters():
            size = param.numel()
            param_ranges[param_name] = range(counter, counter + size)
            counter += size

        for i, grad in enumerate(num_grads[n]):
            p_name, at_start = get_param_status(i, param_ranges)
            if at_start:
                # Print a header whenever a new parameter's block begins.
                print(f"\nTesting {p_name}...")
                print(f"Numerical {grad_type}\tAlg {grad_type}")

            print(f"{grad: 10.8f}\t{alg_grads[n][i].item(): 10.8f}\t\t")

        assertAlmostEqual(
            num_grads[n],
            alg_grads[n],
            test_tol,
            msg=f"{grad_type} grads are not close enough for {net}!",
        )
def test_rotate_rho_probs(num_visible, state_type, precompute_rho):
    """Check ``rotate_rho_probs`` against the diagonal of ``rotate_rho``.

    The state is rotated with an all-``"X"`` basis string of length
    ``num_visible``. When ``precompute_rho`` is truthy the expanded density
    matrix is computed up front and passed to both routines; otherwise
    ``rho=None`` is passed.
    """
    state = state_type(num_visible, gpu=False)
    basis = "X" * num_visible
    unitaries = create_dict()
    hilbert_space = state.generate_hilbert_space()

    if precompute_rho:
        rho = state.rho(hilbert_space, expand=True)
    else:
        rho = None

    # Reference: rotate the full matrix, then read off the real diagonal.
    rotated_rho = rotate_rho(state, basis, hilbert_space, unitaries, rho=rho)
    expected_probs = torch.diagonal(cplx.real(rotated_rho))

    fast_probs = rotate_rho_probs(state, basis, hilbert_space, unitaries, rho=rho)

    # A 10x looser tolerance is used here because this comparison
    # occasionally just exceeds the smaller TOL value from test_grads.py.
    assertAlmostEqual(
        expected_probs,
        fast_probs,
        tol=(TOL * 10),
        msg="Fast rho probs rotation failed!",
    )
def test_density_matrix_tr1():
    """Check that the normalized density matrix has a trace within TOL of 1.

    Only the first slice of the normalized matrix (``[0]``) is traced;
    presumably that slice is the real part -- confirm against the complex
    tensor layout used by ``DensityMatrix.rho``.
    """
    state = DensityMatrix(5, gpu=False)
    hilbert_space = state.generate_hilbert_space(5)

    unnormalized = state.rho(hilbert_space, hilbert_space)
    normalized = unnormalized / state.normalization(hilbert_space)

    assertAlmostEqual(
        torch.trace(normalized[0]),
        torch.Tensor([1]),
        TOL,
        msg=f"Trace of density matrix is not within {TOL} of 1!",
    )
def test_density_matrix_diagonal():
    """Check that ``rho(expand=False)`` matches the diagonal of ``rho(expand=True)``.

    The diagonal of the expanded result is taken over its second and third
    axes with ``einsum`` and compared to the non-expanded result within TOL.
    """
    state = DensityMatrix(5, gpu=False)
    hilbert_space = state.generate_hilbert_space(5)

    expanded = state.rho(hilbert_space, expand=True)
    unexpanded = state.rho(hilbert_space, expand=False)

    # "cii...->ci..." keeps the leading axis and extracts entries where the
    # next two indices coincide.
    expanded_diagonal = torch.einsum("cii...->ci...", expanded)

    assertAlmostEqual(
        expanded_diagonal,
        unexpanded,
        TOL,
        msg="Diagonal of density matrix is wrong!",
    )
def test_rotate_psi_inner_prod(num_visible, state_type, precompute_psi):
    """Check that ``rotate_psi_inner_prod`` agrees with ``rotate_psi``.

    Both routines are called with the same all-``"X"`` basis string of
    length ``num_visible``. When ``precompute_psi`` is truthy the
    wavefunction is evaluated once and passed to both; otherwise
    ``psi=None`` is passed.
    """
    state = state_type(num_visible, gpu=False)
    basis = "X" * num_visible
    unitaries = create_dict()
    hilbert_space = state.generate_hilbert_space()

    if precompute_psi:
        psi = state.psi(hilbert_space)
    else:
        psi = None

    rotated = rotate_psi(state, basis, hilbert_space, unitaries, psi=psi)
    rotated_inner_prod = rotate_psi_inner_prod(
        state, basis, hilbert_space, unitaries, psi=psi
    )

    assertAlmostEqual(
        rotated,
        rotated_inner_prod,
        msg="Fast psi inner product rotation failed!",
    )
def test_rotate_psi(num_visible, wvfn_type):
    """Check ``rotate_psi`` against an explicit matrix-vector product.

    The reference result applies the Kronecker product of the per-site
    unitaries (one for each character of the all-``"X"`` basis string) to
    the wavefunction with ``cplx.matmul``.
    """
    state = wvfn_type(num_visible, gpu=False)
    basis = "X" * num_visible
    unitaries = create_dict()
    hilbert_space = state.generate_hilbert_space()
    psi = state.psi(hilbert_space)

    fast_result = rotate_psi(state, basis, hilbert_space, unitaries, psi=psi)

    # Build the full rotation operator by folding the per-site unitaries
    # together, left to right, with the Kronecker product.
    site_unitaries = (unitaries[site] for site in basis)
    full_unitary = reduce(cplx.kronecker_prod, site_unitaries)
    reference_result = cplx.matmul(full_unitary, psi)

    assertAlmostEqual(fast_result, reference_result, msg="Fast psi rotation failed!")
def test_rotate_rho(num_visible, state_type):
    """Check ``rotate_rho`` against an explicit two-sided matrix product.

    The reference result is ``matmul(matmul(U, rho), conjugate(U))``, where
    ``U`` is the Kronecker product of the per-site unitaries for the
    all-``"X"`` basis string of length ``num_visible``.
    """
    state = state_type(num_visible, gpu=False)
    basis = "X" * num_visible
    unitaries = create_dict()
    hilbert_space = state.generate_hilbert_space()
    rho = state.rho(hilbert_space, hilbert_space)

    fast_result = rotate_rho(state, basis, hilbert_space, unitaries, rho=rho)

    # Build the full rotation operator by folding the per-site unitaries
    # together, left to right, with the Kronecker product.
    site_unitaries = (unitaries[site] for site in basis)
    full_unitary = reduce(cplx.kronecker_prod, site_unitaries)

    # Apply the operator on the left and its cplx.conjugate on the right.
    reference_result = cplx.matmul(
        cplx.matmul(full_unitary, rho),
        cplx.conjugate(full_unitary),
    )

    assertAlmostEqual(fast_result, reference_result, msg="Fast rho rotation failed!")
def test_density_matrix_expansion(prop):
    """Check that a property's ``expand=False`` result is the diagonal of ``expand=True``.

    Args:
        prop: Indexable parameter set. ``prop[0]`` is the (possibly dotted)
            attribute name of the callable to look up on the state,
            ``prop[1]`` selects the complex einsum equation when truthy, and
            ``prop[2:]`` are extra positional arguments forwarded to the
            callable after the two sets of states.
    """
    # Seed first so both the state's construction and the permutation below
    # are drawn after the same seeding call.
    qucumber.set_random_seed(INIT_SEED, cpu=True, gpu=False, quiet=True)

    state = DensityMatrix(5, gpu=False)
    rows = state.generate_hilbert_space(5)
    shuffled_rows = rows[torch.randperm(rows.shape[0]), :]

    prop_name, is_complex = prop[0], prop[1]
    extra_args = prop[2:]
    prop_fn = attrgetter(prop_name)(state)

    expanded = prop_fn(rows, shuffled_rows, *extra_args, expand=True)
    unexpanded = prop_fn(rows, shuffled_rows, *extra_args, expand=False)

    # Complex results carry an extra leading axis that must be preserved.
    if is_complex:
        equation = "cii...->ci..."
    else:
        equation = "ii->i"

    assertAlmostEqual(
        torch.einsum(equation, expanded),
        unexpanded,
        TOL,
        msg=f"Diagonal of matrix {prop_name} is wrong!",
    )