def test_basex_basis_sets_cache():
    """Exercise the on-disk basis-set cache path of get_bs_basex_cached.

    The expected cache file is removed up front so the first call cannot
    find it, then get_bs_basex_cached is called twice with the same
    arguments.  The cache file is removed again afterwards, even when one
    of the calls raises, so a failing run does not leave a stale file in
    DATA_DIR for later test runs.
    """
    n = 121
    file_name = os.path.join(
        DATA_DIR, "basex_basis_{}_{}.npy".format(n, n//2))

    def _remove_cache_file():
        # Delete the cache file if a previous call (or run) left one behind.
        if os.path.exists(file_name):
            os.remove(file_name)

    _remove_cache_file()
    try:
        # 1st call: expected to generate the basis and save it
        get_bs_basex_cached(n, basis_dir=DATA_DIR, verbose=False)
        # 2nd call: expected to load the basis from the saved file
        get_bs_basex_cached(n, basis_dir=DATA_DIR, verbose=False)
    finally:
        # Always clean up, including when one of the calls above fails.
        _remove_cache_file()
def test_basex_zeros():
    """Check that the BASEX transform of an all-zero image is all zeros."""
    size = 21
    # basis_dir=None is passed so no cache directory is specified here.
    basis = get_bs_basex_cached(size, basis_dir=None, verbose=False)
    blank_image = np.zeros(shape=(size, size), dtype='float32')

    transformed = basex_transform(blank_image, *basis)

    assert_allclose(transformed, 0)
def test_basex_shape():
    """Check that the BASEX transform output has the input's (n, n) shape."""
    size = 21
    expected_shape = (size, size)
    basis = get_bs_basex_cached(size, basis_dir=None, verbose=False)
    flat_image = np.ones(expected_shape, dtype='float32')

    transformed = basex_transform(flat_image, *basis)

    assert transformed.shape == expected_shape
def test_basex_step_ratio():
    """Check a gaussian solution for BASEX.

    A 2D input is built by repeating the reference ``abel`` profile on
    every row; one row of the transformed result is then compared with
    the reference through absolute_ratio_benchmark, which must come out
    within 3% of 1.0.
    """
    size = 51
    radius_max = 25
    reference = GaussianAnalytical(size, radius_max, symmetric=True, sigma=10)

    # Stack the 1D profile into `size` identical rows to get a 2D array.
    profile_row = reference.abel[None, :]
    image = np.tile(profile_row, (size, 1))

    basis = get_bs_basex_cached(size, basis_dir=None, verbose=False)
    transformed = basex_transform(image, *basis)

    # Row index is size//2 + size%2, i.e. 26 for size = 51.
    half, remainder = divmod(size, 2)
    row_1d = transformed[half + remainder]

    ratio = absolute_ratio_benchmark(reference, row_1d)
    assert_allclose(ratio, 1.0, rtol=3e-2, atol=0)