def test_fft_dims(self):
    """A joint FFT over 'x,y,z' must match three successive single-dimension FFTs on every backend."""
    for backend in BACKENDS:
        with backend:
            values = math.random_normal(batch(x=8, y=6, z=4))
            joint = math.fft(values, 'x,y,z')
            sequential = values
            for axis_name in ('x', 'y', 'z'):
                sequential = math.fft(sequential, axis_name)
            math.assert_close(sequential, joint, abs_tolerance=1e-5, msg=backend.name)
def fourier(field: GridType,
            diffusivity: float or math.Tensor,
            dt: float or math.Tensor) -> FieldType:
    """
    Exact diffusion of a periodic field in frequency space.

    For non-periodic fields or non-constant diffusivity, use another diffusion function such as `explicit()`.

    Args:
        field: Periodic `Grid` to diffuse. A `ConstantField` is returned unchanged.
        diffusivity: Diffusion per time. `diffusion_amount = diffusivity * dt`
        dt: Time interval. `diffusion_amount = diffusivity * dt`

    Returns:
        Diffused field of same type as `field`.
    """
    if isinstance(field, ConstantField):
        return field  # nothing to diffuse
    assert isinstance(field, Grid), "Cannot diffuse field of type '%s'" % type(field)
    assert field.extrapolation == math.extrapolation.PERIODIC, "Fourier diffusion can only be applied to periodic fields."
    amount = diffusivity * dt
    # fftfreq is expressed per grid cell; dividing by the cell size yields physical
    # wave numbers, so `amount` is interpreted in physical units (as in `diffuse`).
    k = math.fftfreq(field.resolution) / field.dx
    k2 = math.vec_squared(k)
    fft_laplace = -(2 * math.PI) ** 2 * k2
    diffuse_kernel = math.exp(fft_laplace * amount)
    result_k = math.fft(field.values) * diffuse_kernel
    result_values = math.real(math.ifft(result_k))
    return field.with_(values=result_values)
def test_ifft(self):
    """`math.ifft(math.fft(x))` must reproduce `x` for one to three spatial dimensions on every backend."""
    dimensions = 'xyz'
    for backend in BACKENDS:
        with backend:
            for rank in range(1, len(dimensions) + 1):
                sizes = {name: 6 for name in dimensions[:rank]}
                noise = math.random_normal(spatial(**sizes))
                # Add a batch dimension so the round trip is also checked across batch entries.
                x = noise + math.tensor((0, 1), batch('batch'))
                restored = math.ifft(math.fft(x))
                math.assert_close(x, restored, abs_tolerance=1e-5, msg=backend.name)
def diffuse(field, amount, substeps=1):
    """
    Diffuse a `CenteredGrid` by `amount`.

    Periodic grids are diffused exactly in Fourier space. All other grids use `substeps`
    explicit steps, each adding `amount / substeps` times the Laplacian of the current state.

    :param field: CenteredGrid to diffuse
    :param amount: diffusion amount, typically diffusivity * dt
    :param substeps: number of explicit steps (not used for periodic grids)
    :return: CenteredGrid holding the diffused data
    """
    assert isinstance(field, CenteredGrid)
    if field.extrapolation == 'periodic':
        frequencies = math.fft(math.to_complex(field.data))
        # Divide the per-cell frequencies by the cell size to get physical wave numbers.
        k = math.fftfreq(field.resolution) / field.dx
        k = math.sum(k**2, axis=-1, keepdims=True)
        fft_laplace = -(2 * pi)**2 * k
        diffuse_kernel = math.to_complex(math.exp(fft_laplace * amount))
        data = math.real(math.ifft(frequencies * diffuse_kernel))
    else:
        data = field.data
        for _ in range(substeps):
            # Evaluate the Laplacian on the current state (not the input field) so that
            # sub-steps actually compound, and add out of place so the caller's
            # `field.data` is left untouched.
            data = data + amount / substeps * field.with_data(data).laplace()
    return field.with_data(data)
def test_fft(self):
    """Compare `math.fft` on a 2D sine pattern against NumPy's `fft2` at 64-bit precision."""
    def _sine_pattern(grid_size, L):
        # Integer lattice coordinates rescaled to the physical interval [0, L).
        lattice = np.array(np.meshgrid(*[range(n) for n in grid_size]))
        coords = lattice.T * L / (grid_size[0])
        x, y = coords.T
        return np.sin(2 * np.pi * x + 1) * np.sin(2 * np.pi * y + 1)

    sine_field = _sine_pattern((32, 32), L=2)
    reference = math.wrap(np.fft.fft2(sine_field), 'x,y')
    with math.precision(64):
        for backend in BACKENDS:
            with backend:
                actual = math.fft(math.tensor(sine_field, 'x,y'))
                # Should usually be more precise. GitHub Actions has larger errors than usual.
                math.assert_close(reference, actual, abs_tolerance=1e-12)
def test_fft(self):
    """Check that `math.fft` of a 2D sine pattern is complex128 and matches NumPy's `fft2`."""
    def _sine_pattern(grid_size, L):
        # Integer lattice coordinates rescaled to the physical interval [0, L).
        lattice = np.array(np.meshgrid(*[range(n) for n in grid_size]))
        coords = lattice.T * L / (grid_size[0])
        x, y = coords.T
        return np.sin(2 * np.pi * x + 1) * np.sin(2 * np.pi * y + 1)

    sine_field = _sine_pattern((32, 32), L=2)
    reference = math.wrap(np.fft.fft2(sine_field), spatial('x,y'))
    with math.precision(64):
        for backend in BACKENDS:
            if backend.name == 'Jax':
                continue  # TODO Jax casts to float32 / complex64 on GitHub Actions
            with backend:
                actual = math.fft(math.tensor(sine_field, spatial('x,y')))
                self.assertEqual(actual.dtype, math.DType(complex, 128), msg=backend.name)
                # Should usually be more precise. GitHub Actions has larger errors than usual.
                math.assert_close(reference, actual, abs_tolerance=1e-12, msg=backend.name)
def step(self, state, dt=1.0, potentials=(), obstacles=()):
    """
    Advance `state.amplitude` by one time step `dt`.

    The amplitude is rotated by the (real) potential, propagated by a phase rotation in
    Fourier space, masked by obstacles and by a border of `self.margin` cells, and finally
    passed through `normalize_probability`.

    Args:
        state: state providing `amplitude`, `resolution`, `mass` and `domain`.
        dt: time step.
        potentials: potential effects combined via `effect_applied`; empty means zero potential.
        obstacles: obstacles whose geometry mask is removed from the amplitude.

    Returns:
        A copy of `state` with the updated amplitude.
    """
    if len(potentials) == 0:
        potential = 0
    else:
        # for the moment, allow only real potentials
        potential = math.zeros_like(math.real(state.amplitude))
        for pot in potentials:
            potential = effect_applied(pot, potential, dt)
        potential = potential.data

    amplitude = state.amplitude.data

    # Rotate by potential
    rotation = math.exp(1j * math.to_complex(potential * dt))
    amplitude = amplitude * rotation

    # Move by rotating in Fourier space
    amplitude_fft = math.fft(amplitude)
    # NOTE(review): no division by the cell size here; presumably the frequencies are in grid units — confirm.
    laplace = math.fftfreq(state.resolution, mode='square')
    amplitude_fft *= math.exp(-1j * (2 * np.pi)**2 * math.to_complex(dt) * laplace / (2 * state.mass))
    amplitude = math.ifft(amplitude_fft)

    obstacle_mask = union_mask([obstacle.geometry for obstacle in obstacles]).at(state.amplitude).data
    amplitude *= 1 - obstacle_mask

    symmetric = False  # hard-coded switch: the boundary margin is currently always applied
    if not symmetric:
        boundary_mask = math.zeros(state.domain.centered_shape(1, batch_size=1)).data
        margin = self.margin
        # `-margin or None` keeps the whole axis when margin == 0 (slice(0, -0) would be empty).
        inner = slice(margin, -margin or None)
        # NumPy >= 1.23 no longer accepts a list of slices as a multidimensional index; use a tuple.
        index = tuple([slice(None)] + [inner for _ in math.spatial_dimensions(boundary_mask)] + [slice(None)])
        boundary_mask[index] = 1
        amplitude *= boundary_mask

    if len(obstacles) > 0 or not symmetric:
        amplitude = normalize_probability(amplitude)

    return state.copied_with(amplitude=amplitude)
def diffuse(field, amount, substeps=1):
    u"""
    Simulate a finite-time diffusion process of the form dF/dt = α · ΔF on a given `Field` F with diffusion coefficient α.

    If `field` is periodic (set via `extrapolation='periodic'`) and `amount` is not a `Field`, diffusion is simulated
    exactly in Fourier space. Otherwise, `substeps` explicit finite-difference steps are used to approximate the
    diffusion, each adding `amount / substeps` times the Laplacian of the current state.

    :param field: CenteredGrid, StaggeredGrid or ConstantField
    :param amount: number or Field, typically α · dt
    :param substeps: number of iterations to use (ignored by the Fourier solver)
    :return: Field of same type as `field`
    :rtype: Field
    """
    if isinstance(field, ConstantField):
        return field
    if isinstance(field, StaggeredGrid):
        # Diffuse each component CenteredGrid independently.
        return struct.map(
            lambda grid: diffuse(grid, amount, substeps=substeps),
            field,
            leaf_condition=lambda x: isinstance(x, CenteredGrid))
    assert isinstance(field, CenteredGrid), "Cannot diffuse field of type '%s'" % type(field)
    if field.extrapolation == 'periodic' and not isinstance(amount, Field):
        frequencies = math.fft(field.data)
        # Divide the per-cell frequencies by the cell size to get physical wave numbers.
        k = math.fftfreq(field.resolution) / field.dx
        k = math.sum(k**2, axis=-1, keepdims=True)
        fft_laplace = -(2 * pi)**2 * k
        diffuse_kernel = math.to_complex(math.exp(fft_laplace * amount))
        data = math.ifft(frequencies * diffuse_kernel)
        data = math.real(data)
    else:
        data = field.data
        if isinstance(amount, Field):
            amount = amount.at(field).data
        else:
            amount = math.batch_align(amount, 0, data)
        for _ in range(substeps):
            # Evaluate the Laplacian on the current state (not the input field) so that
            # sub-steps actually compound, and add out of place so the caller's
            # `field.data` is left untouched.
            data = data + amount / substeps * field.with_data(data).laplace().data
    return field.with_data(data)
def fft(self):
    """Return a copy of this object whose data is replaced by `math.fft` of that data."""
    spectrum = math.fft(self.data)
    return self.with_data(spectrum)