コード例 #1
0
ファイル: test_grads.py プロジェクト: yangyuan16/QuCumber
def test_grads(quantum_state_graddata):
    """Check algorithmic gradients against numerical ones, network by network.

    ``quantum_state_graddata`` is unpacked into the state under test, the
    algorithmic gradients, the numerical gradients, a label for the kind of
    gradient being tested, and the comparison tolerance (presumably supplied
    by a pytest fixture -- confirm against the test setup).

    Both gradient values are printed side by side for every flat parameter
    index before the per-network comparison is asserted.
    """
    (
        nn_state,
        alg_grads,
        num_grads,
        grad_type,
        test_tol,
    ) = quantum_state_graddata

    banner = "\nTesting {} gradients for {} on {}.".format(
        grad_type, nn_state.__class__.__name__, nn_state.device)
    print(banner)

    for net_idx, net in enumerate(nn_state.networks):
        print("\nRBM: %s" % net)
        network = getattr(nn_state, net)

        # Map every parameter name to the run of flat indices it occupies,
        # following the order in which the network reports its parameters.
        index_ranges = {}
        offset = 0
        for name, tensor in network.named_parameters():
            size = tensor.numel()
            index_ranges[name] = range(offset, offset + size)
            offset += size

        for flat_idx, numeric in enumerate(num_grads[net_idx]):
            p_name, at_start = get_param_status(flat_idx, index_ranges)
            if at_start:
                # Table header, printed only when get_param_status flags
                # this index.
                print(f"\nTesting {p_name}...")
                print(f"Numerical {grad_type}\tAlg {grad_type}")

            algorithmic = alg_grads[net_idx][flat_idx].item()
            print("{: 10.8f}\t{: 10.8f}\t\t".format(numeric, algorithmic))

        assertAlmostEqual(
            num_grads[net_idx],
            alg_grads[net_idx],
            test_tol,
            msg=f"{grad_type} grads are not close enough for {net}!",
        )
コード例 #2
0
def test_rotate_rho_probs(num_visible, state_type, precompute_rho):
    """Check ``rotate_rho_probs`` against the diagonal of ``rotate_rho``.

    A state of ``num_visible`` sites is built on the CPU and rotated into the
    all-"X" basis. When ``precompute_rho`` is truthy the density matrix is
    computed up front and handed to both rotation routines; otherwise ``None``
    is passed for ``rho``.
    """
    nn_state = state_type(num_visible, gpu=False)
    basis = num_visible * "X"
    unitaries = create_dict()

    hilbert_space = nn_state.generate_hilbert_space()

    if precompute_rho:
        rho = nn_state.rho(hilbert_space, expand=True)
    else:
        rho = None

    # Reference: rotate the full matrix, then read off the real diagonal.
    rotated = rotate_rho(nn_state, basis, hilbert_space, unitaries, rho=rho)
    expected_probs = torch.diagonal(cplx.real(rotated))

    fast_probs = rotate_rho_probs(
        nn_state, basis, hilbert_space, unitaries, rho=rho
    )

    # A looser tolerance is used here because this comparison occasionally
    # just barely exceeds the smaller TOL value from test_grads.py.
    loose_tol = TOL * 10
    assertAlmostEqual(
        expected_probs,
        fast_probs,
        tol=loose_tol,
        msg="Fast rho probs rotation failed!",
    )
コード例 #3
0
def test_density_matrix_tr1():
    """Check that the normalized density matrix of a 5-site state has trace 1.

    The matrix is evaluated over the full Hilbert space and divided by the
    state's normalization before the trace is compared with 1 within ``TOL``.
    """
    num_sites = 5
    state = DensityMatrix(num_sites, gpu=False)

    basis_states = state.generate_hilbert_space(num_sites)
    unnormalized = state.rho(basis_states, basis_states)
    norm = state.normalization(basis_states)
    normalized = unnormalized / norm

    # Only component 0 is traced -- presumably the real part; confirm against
    # the complex tensor layout used by the project.
    trace = torch.trace(normalized[0])
    expected = torch.Tensor([1])

    msg = f"Trace of density matrix is not within {TOL} of 1!"
    assertAlmostEqual(trace, expected, TOL, msg=msg)
コード例 #4
0
def test_density_matrix_diagonal():
    """Check that ``rho(..., expand=False)`` matches the expanded diagonal.

    The expanded result is reduced with an einsum that keeps the leading axis
    and takes the diagonal over the next two axes, then compared with the
    non-expanded result within ``TOL``.
    """
    num_sites = 5
    state = DensityMatrix(num_sites, gpu=False)
    basis_states = state.generate_hilbert_space(num_sites)

    full_matrix = state.rho(basis_states, expand=True)
    direct_diag = state.rho(basis_states, expand=False)

    extracted_diag = torch.einsum("cii...->ci...", full_matrix)

    assertAlmostEqual(
        extracted_diag,
        direct_diag,
        TOL,
        msg="Diagonal of density matrix is wrong!",
    )
コード例 #5
0
def test_rotate_psi_inner_prod(num_visible, state_type, precompute_psi):
    """Check ``rotate_psi_inner_prod`` against ``rotate_psi``.

    Both routines rotate a CPU state of ``num_visible`` sites into the all-"X"
    basis. When ``precompute_psi`` is truthy the wavefunction is evaluated once
    and passed to both; otherwise ``None`` is passed for ``psi``.
    """
    nn_state = state_type(num_visible, gpu=False)
    basis = num_visible * "X"
    unitaries = create_dict()

    hilbert_space = nn_state.generate_hilbert_space()

    if precompute_psi:
        psi = nn_state.psi(hilbert_space)
    else:
        psi = None

    reference = rotate_psi(nn_state, basis, hilbert_space, unitaries, psi=psi)
    via_inner_prod = rotate_psi_inner_prod(
        nn_state, basis, hilbert_space, unitaries, psi=psi
    )

    assertAlmostEqual(
        reference,
        via_inner_prod,
        msg="Fast psi inner product rotation failed!",
    )
コード例 #6
0
def test_rotate_psi(num_visible, wvfn_type):
    """Check ``rotate_psi`` against an explicit matrix-vector product.

    The reference result is obtained by folding the per-site unitaries of the
    all-"X" basis together with ``cplx.kronecker_prod`` and applying the
    resulting operator to the wavefunction with ``cplx.matmul``.
    """
    nn_state = wvfn_type(num_visible, gpu=False)
    basis = num_visible * "X"
    unitaries = create_dict()

    hilbert_space = nn_state.generate_hilbert_space()
    psi = nn_state.psi(hilbert_space)

    fast_result = rotate_psi(nn_state, basis, hilbert_space, unitaries, psi=psi)

    # Build the full rotation operator from one unitary per basis character.
    site_unitaries = [unitaries[site] for site in basis]
    full_unitary = reduce(cplx.kronecker_prod, site_unitaries)
    explicit_result = cplx.matmul(full_unitary, psi)

    assertAlmostEqual(
        fast_result, explicit_result, msg="Fast psi rotation failed!"
    )
コード例 #7
0
def test_rotate_rho(num_visible, state_type):
    """Check ``rotate_rho`` against an explicit two-sided matrix product.

    The reference result multiplies the density matrix on the left by the
    full rotation operator and on the right by ``cplx.conjugate`` of that
    operator, where the operator is the ``cplx.kronecker_prod`` fold of the
    per-site unitaries for the all-"X" basis.
    """
    nn_state = state_type(num_visible, gpu=False)
    basis = num_visible * "X"
    unitaries = create_dict()

    hilbert_space = nn_state.generate_hilbert_space()
    rho = nn_state.rho(hilbert_space, hilbert_space)

    fast_result = rotate_rho(nn_state, basis, hilbert_space, unitaries, rho=rho)

    # Build the full rotation operator from one unitary per basis character.
    site_unitaries = [unitaries[site] for site in basis]
    full_unitary = reduce(cplx.kronecker_prod, site_unitaries)

    # NOTE(review): assumes cplx.conjugate yields the Hermitian conjugate for
    # matrices -- confirm against the cplx module.
    explicit_result = cplx.matmul(
        cplx.matmul(full_unitary, rho), cplx.conjugate(full_unitary)
    )

    assertAlmostEqual(
        fast_result, explicit_result, msg="Fast rho rotation failed!"
    )
コード例 #8
0
def test_density_matrix_expansion(prop):
    """Check that a property's ``expand=False`` result is its expanded diagonal.

    ``prop`` is a sequence whose first item is the (possibly dotted) attribute
    name to look up on the state, whose second item says whether the result
    carries a leading axis that the einsum must keep, and whose remaining
    items are forwarded as extra positional arguments.
    """
    prop_name, is_complex, *extra_args = prop

    qucumber.set_random_seed(INIT_SEED, cpu=True, gpu=False, quiet=True)

    num_sites = 5
    state = DensityMatrix(num_sites, gpu=False)
    rows = state.generate_hilbert_space(num_sites)
    # Second argument is a random row-permutation of the same basis states.
    cols = rows[torch.randperm(rows.shape[0]), :]

    evaluate = attrgetter(prop_name)(state)

    expanded = evaluate(rows, cols, *extra_args, expand=True)
    paired = evaluate(rows, cols, *extra_args, expand=False)

    msg = f"Diagonal of matrix {prop_name} is wrong!"

    if is_complex:
        equation = "cii...->ci..."
    else:
        equation = "ii->i"

    assertAlmostEqual(torch.einsum(equation, expanded), paired, TOL, msg=msg)