def test_downsample():
    """Downsampling a wave built from the antialias mask yields ones in Fourier space.

    A wave is constructed as the inverse FFT of the antialias mask. Downsampling
    it with the ``'valid'`` mode and ``return_fourier_space=True`` is expected to
    return an array whose real part is 1 everywhere.
    """
    antialias_filter = AntialiasFilter()
    sampling = (.1, .1)
    # Deliberately non-square, odd/even gpts to exercise both axes differently.
    gpts = (228, 229)
    mask = antialias_filter.get_mask(gpts, sampling, np)
    array = np.fft.ifft2(mask)
    # NOTE(review): (2 / 3.,) * 2 presumably matches the filter's default
    # aperture fraction on both axes — confirm against AntialiasFilter.
    waves = Waves(array, sampling=sampling, energy=80e3, antialias_aperture=(2 / 3.,) * 2)
    downsampled = waves.downsample('valid', return_fourier_space=True)
    assert np.allclose(downsampled.array.real, 1.)
def test_waves_raises():
    """Check that grid/accelerator validation fails until extent and energy are set."""
    data = np.ones((1, 25, 25), dtype=np.complex64)

    # No extent or sampling given: the grid cannot be fully defined.
    bare = Waves(data)
    with pytest.raises(RuntimeError):
        bare.grid.check_is_defined()

    # Extent given but no energy: the accelerator is still undefined.
    without_energy = Waves(data, extent=5)
    with pytest.raises(RuntimeError):
        without_energy.accelerator.check_is_defined()

    # With both extent and energy, neither check should raise.
    complete = Waves(data, extent=5, energy=60e3)
    complete.grid.check_is_defined()
    complete.accelerator.check_is_defined()
def test_multislice():
    """Exercise multislice in place, then on a copy, and check the copy changes.

    The first ``multislice`` call is expected to return the same object it was
    called on. Running it again on a copy must leave the copy's array different
    from the original wave's array.
    """
    data = np.ones((1, 25, 25), dtype=np.complex64)
    waves = Waves(data, energy=60e3)
    potential = DummyPotential(extent=5)

    def assert_grids_defined():
        # After propagation both the potential and the wave should have grids.
        assert potential.gpts is not None
        assert waves.extent is not None

    propagated = waves.multislice(potential, pbar=False)
    assert_grids_defined()
    assert propagated is waves

    propagated = copy(propagated)
    propagated = propagated.multislice(potential, pbar=False)
    assert_grids_defined()
    assert not np.all(np.isclose(propagated.array, waves.array))
def test_export_import_waves(tmp_path):
    """Round-trip a built probe wave through an HDF5 file and compare its attributes."""
    sub_dir = tmp_path / 'sub'
    sub_dir.mkdir()
    file_path = sub_dir / 'waves.hdf5'

    original = Probe(semiangle_cutoff=30, sampling=.05, extent=10, energy=80e3).build()
    original.write(file_path)
    restored = Waves.read(file_path)

    # Array data, extent and energy must all survive the write/read cycle.
    for attribute in ('array', 'extent', 'energy'):
        assert np.allclose(getattr(original, attribute), getattr(restored, attribute))
def test_multislice_raises():
    """multislice must refuse to run until the wave's energy has been set."""
    data = np.ones((1, 25, 25), dtype=np.complex64)
    potential = DummyPotential(extent=5)
    waves = Waves(data, extent=5)

    with pytest.raises(RuntimeError) as excinfo:
        waves.multislice(potential, pbar=False)
    assert str(excinfo.value) == 'Energy is not defined'

    # Once an energy is supplied, the same call should succeed.
    waves.energy = 60e3
    waves.multislice(potential, pbar=False)
def test_create_waves():
    """Check that extent and sampling are derived from the array shape.

    Setting ``extent`` must yield ``sampling = extent / n`` along each grid axis.
    Constructing with ``sampling`` must yield ``extent = n * sampling``.
    Assigning ``gpts`` on a wave built from an array must raise ``RuntimeError``.
    """
    array = np.ones((1, 25, 25), dtype=np.complex64)
    # Assumes the last two array axes are the grid axes (shape[1], shape[2]).
    waves = Waves(array)
    waves.extent = 10
    # Compare derived floats with a tolerance rather than exact equality, and
    # assert each axis separately so a failure identifies the offending axis.
    assert np.isclose(waves.sampling[0], 10 / array.shape[1])
    assert np.isclose(waves.sampling[1], 10 / array.shape[2])

    # NOTE(review): gpts is presumably locked by the array shape, hence the error.
    with pytest.raises(RuntimeError):
        waves.gpts = 200

    waves = Waves(array, sampling=.2)
    assert np.isclose(waves.extent[0], array.shape[1] * .2)
    assert np.isclose(waves.extent[1], array.shape[2] * .2)