# Example #1
def for_back_np_2d():
    """Check that a forward 2D FFT followed by a backward 2D FFT is lossless.

    For every grid shape ``(Nx, Ny)`` with ``Nx, Ny`` in {128, 256, 512, 1024},
    a two-component grid of ones (``np.complex128``) is transformed with
    ``ttools.fft_2d`` and transformed back with ``ttools.ifft_2d``. The test
    asserts that

    * each component is recovered pointwise to within ``10 * machine eps``, and
    * the atom number reported by ``ttools.calc_atoms`` changes by at most
      ``10 * machine eps * Nx * Ny``.

    Raises:
        AssertionError: if the round trip does not return two components or
            if either tolerance is exceeded.

    """
    lengths = [2**s for s in range(7, 11)]
    dtype = np.complex128
    # Tolerance: ten machine epsilons of the working precision.
    eps = 10 * np.finfo(dtype).eps
    delta_r = (1, 1)
    for shape in itertools.product(lengths, repeat=2):
        # Two independent arrays. ``[arr] * 2`` would make both components the
        # same object, so an in-place operation on one would hit the other.
        grid = [np.ones(shape, dtype=dtype) for _ in range(2)]
        grid_k = ttools.fft_2d(grid, delta_r)
        grid_r = ttools.ifft_2d(grid_k, delta_r)

        # Check the component count first so a short result fails clearly
        # instead of raising IndexError below.
        assert len(grid_r) == 2, f"Expected 2 components, got {len(grid_r)}."

        max_diff0 = np.max(np.abs(grid_r[0] - grid[0]))
        max_diff1 = np.max(np.abs(grid_r[1] - grid[1]))
        # Implicit literal concatenation keeps indentation out of the message.
        assert max_diff0 <= eps and max_diff1 <= eps, (
            f"Max errors: {max_diff0}, {max_diff1}")

        atoms = [ttools.calc_atoms(grid), ttools.calc_atoms(grid_r)]
        atom_diff = abs(atoms[0] - atoms[1])
        # The total-atom tolerance scales with the number of grid points.
        atom_tol = eps * math.prod(shape)
        assert atom_diff <= atom_tol, (
            f"\nAtom num. before/after: {atoms[0]}, {atoms[1]};"
            f"\nDifference: {atom_diff};"
            f"\nTolerance (10*eps*N): {atom_tol}.")

    print("Test `for_back_np_2d` passed.")
# Example #2
def compare3_torch_1d_2d():
    """Check that one 2D FFT equals two successive 1D FFTs (PyTorch, CUDA).

    For every grid shape ``(Nx, Ny)`` with ``Nx, Ny`` in {128, 256, 512, 1024},
    a two-component grid of ones (``torch.complex128`` on the ``'cuda'``
    device) is transformed in two ways:

    * ``ttools.fft_1d`` along axis 1 and then along axis 0, and
    * a single ``ttools.fft_2d``.

    The two k-space results must agree pointwise within ``10 * machine eps``.
    Their atom numbers (``ttools.calc_atoms``) may differ by at most
    ``10 * machine eps * Nx * Ny``.

    Raises:
        AssertionError: if the 1D route does not return two components or if
            either tolerance is exceeded.

    """
    lengths = [2**s for s in range(7, 11)]
    dtype = torch.complex128
    device = 'cuda'  # this test requires a CUDA-capable GPU
    # Tolerance: ten machine epsilons of the working precision.
    eps = 10 * torch.finfo(dtype).eps
    delta_r = (1, 1)
    for shape in itertools.product(lengths, repeat=2):
        # Independent tensors per component. ``[t] * 2`` would alias one
        # tensor, so an in-place op on one component would hit the other.
        grid = [torch.ones(shape, dtype=dtype, device=device)
                for _ in range(2)]
        # Route 1: two successive 1D transforms (axis 1, then axis 0).
        grid_k_half = ttools.fft_1d(grid, delta_r, axis=1)
        grid_k_1d = ttools.fft_1d(grid_k_half, delta_r, axis=0)
        # Route 2: a single 2D transform of the same input.
        grid_k_2d = ttools.fft_2d(grid, delta_r)

        # Check the component count before indexing.
        assert len(grid_k_1d) == 2, (
            f"Expected 2 components, got {len(grid_k_1d)}.")

        max_diff0 = torch.max(torch.abs(grid_k_1d[0] - grid_k_2d[0])).item()
        max_diff1 = torch.max(torch.abs(grid_k_1d[1] - grid_k_2d[1])).item()
        # Implicit literal concatenation keeps indentation out of the message.
        assert max_diff0 <= eps and max_diff1 <= eps, (
            f"Max errors: {max_diff0}, {max_diff1}.")

        atoms = [ttools.calc_atoms(grid_k_2d), ttools.calc_atoms(grid_k_1d)]
        atom_diff = abs(atoms[0] - atoms[1])
        # The total-atom tolerance scales with the number of grid points.
        atom_tol = eps * math.prod(shape)
        assert atom_diff <= atom_tol, (
            f"\nAtom num. 2D/1D FFT: {atoms[0]}, {atoms[1]};"
            f"\nDifference: {atom_diff};"
            f"\nTolerance (10*eps*N): {atom_tol}.")

    print("Test `compare3_torch_1d_2d` passed.")
# Example #3
def compare_norm_torch():
    """Check that the total norm is preserved by the 2D FFT (PyTorch, CPU).

    For every grid shape ``(Nx, Ny)`` with ``Nx, Ny`` in {128, 256, 512, 1024},
    a two-component grid of ones (``torch.complex128``) is transformed with
    ``ttools.fft_2d``. The atom number of the real-space grid is compared with
    the atom number of the k-space grid, computed with a k-space volume element
    of ``(2*pi)**2 / (Nx * Ny)``; this is a Parseval-type check. The two must
    agree within ``10 * machine eps * Nx * Ny``.

    Raises:
        AssertionError: if the atom numbers differ by more than the tolerance.

    """
    lengths = [2**s for s in range(7, 11)]
    dtype = torch.complex128
    # Tolerance: ten machine epsilons of the working precision.
    eps = 10 * torch.finfo(dtype).eps
    delta_r = (1, 1)
    for shape in itertools.product(lengths, repeat=2):
        # Independent tensors per component; ``[t] * 2`` would alias one tensor.
        grid = [torch.ones(shape, dtype=dtype) for _ in range(2)]
        grid_k = ttools.fft_2d(grid, delta_r)

        # k-space area element for unit real-space spacing (delta_r = (1, 1)).
        # NOTE(review): presumably matches the k-grid convention of
        # ttools.fft_2d; revisit if delta_r is ever changed.
        vol_elem = 4 * math.pi**2 / math.prod(shape)
        atoms = [ttools.calc_atoms(grid), ttools.calc_atoms(grid_k, vol_elem)]

        atom_diff = abs(atoms[0] - atoms[1])
        # The total-atom tolerance scales with the number of grid points.
        atom_tol = eps * math.prod(shape)
        # Implicit literal concatenation keeps indentation out of the message.
        assert atom_diff <= atom_tol, (
            f"\nAtom num. before/after: {atoms[0]}, {atoms[1]};"
            f"\nDifference: {atom_diff};"
            f"\nTolerance (10*eps*N): {atom_tol}.")

    print("Test `compare_norm_torch` passed.")
# Example #4
# Script: build a two-component spinor with spin.PSpinor, inspect its
# momentum-space density, sanity-check the FFT round trip, then run an
# imaginary-time propagation on the GPU. DATA_PATH, ATOM_NUM, omeg, spin, ttools,
# plt and np come from outside this snippet.

# Interaction strengths keyed by component pair ('uu', 'dd', 'ud').
# NOTE(review): units/normalization are whatever PSpinor expects — confirm.
g_sc = {'uu': 1.0, 'dd': 1.0, 'ud': 1.04}
# Per-component population fractions (presumably must sum to 1 — confirm).
pop_frac = (0.5, 0.5)
# r_sizes / mesh_points presumably set the spatial extent and the grid
# resolution per axis — TODO confirm against spin.PSpinor.
ps = spin.PSpinor(DATA_PATH, overwrite=True, atom_num=ATOM_NUM, omeg=omeg,
                  g_sc=g_sc, phase_factor=-1, is_coupling=False,
                  pop_frac=pop_frac, r_sizes=(8, 8), mesh_points=(256, 512))

# Show the density of component 0 of the momentum-space wavefunction.
plt.figure()
# Disabled alternative: plot the density of an explicit fft_2d of ps.psi.
# plt.imshow(ttools.density(ttools.fft_2d(ps.psi, ps.delta_r))[0])
plt.imshow(ttools.density(ps.psik)[0])
plt.show()

# Configure and apply the coupling (wavelength presumably in metres).
ps.coupling_setup(wavel=790.1e-9)
ps.coupling_grad()

# FFT round-trip sanity check: forward then inverse 2D FFT of psi.
psi = ps.psi
psik = ttools.fft_2d(psi, ps.delta_r)
psi_prime = ttools.ifft_2d(psik, ps.delta_r)
# Prints the largest *signed* difference in component-0 density.
# NOTE(review): .max() of a signed difference ignores negative deviations;
# wrapping the difference in np.abs(...) would report the true max error.
print((np.abs(psi[0])**2 - np.abs(psi_prime[0])**2).max())

# --------- 2. RUN (Imaginary) ----
# Propagation settings; the device is switched to the first CUDA GPU
# before the run.
ps.N_STEPS = 1000
ps.dt = 1/50
ps.is_sampling = True
ps.device = 'cuda:0'

res0 = ps.imaginary()
# `res0` is an object containing the final wavefunctions, the energy exp.
# values, populations, average positions, and a directory path to sampled
# wavefunctions. It also has class methods for plotting and analysis.

# Shift the momentum-space wavefunction; frac=(0.0, 1.0) presumably gives the
# per-axis shift — TODO confirm against PSpinor.shift_momentum.
psik_shifted = ps.shift_momentum(ps.psik, frac=(0.0, 1.0))