def batch_align_scalar(tensor, innate_spatial_dims, target):
    """
    Batch-align `tensor`, adding a trailing size-1 axis first when needed.

    A rank-0 `tensor` is expanded at axis 0 by as many dimensions as `target` has.
    Any other `tensor` gets an extra trailing axis (via `math.expand_dims(..., -1)`)
    when its last dimension is not 1 or when it has at most one dimension. It is
    then handed to `batch_align` with `innate_spatial_dims + 1` innate dimensions.

    :param tensor: tensor to align
    :param innate_spatial_dims: number of innate spatial dimensions of `tensor`; must be 0 for rank-0 input
    :param target: tensor whose rank determines the result rank
    :return: aligned tensor
    """
    if rank(tensor) == 0:
        # NOTE(review): assert is stripped under `python -O`.
        assert innate_spatial_dims == 0
        target_rank = len(math.staticshape(target))
        return math.expand_dims(tensor, 0, target_rank)
    last_dim = math.staticshape(tensor)[-1]
    if last_dim != 1 or math.ndims(tensor) <= 1:
        tensor = math.expand_dims(tensor, -1)
    return batch_align(tensor, innate_spatial_dims + 1, target)
def batch_align(tensor, innate_dims, target, convert_to_same_backend=True):
    """
    Expand `tensor` so that its rank matches the rank of `target`.

    The missing dimensions are inserted via `math.expand_dims` at axis
    `-innate_dims - 1`, i.e. directly in front of the last `innate_dims` dimensions.
    Tensors with no more than `innate_dims` dimensions, or that already have the
    target's rank, are returned unchanged (after the optional backend conversion).

    :param tensor: tensor, or tuple/list of tensors which are aligned individually (a list is returned)
    :param innate_dims: number of trailing dimensions that belong to the tensor itself
    :param target: tensor whose rank determines the result rank
    :param convert_to_same_backend: if True, convert `tensor` and `target` to a common backend first
    :return: aligned tensor, or a list of aligned tensors
    """
    if isinstance(tensor, (tuple, list)):
        # Forward every option so each element is handled exactly like a single tensor would be.
        return [batch_align(t, innate_dims, target, convert_to_same_backend) for t in tensor]
    # --- Convert type ---
    if convert_to_same_backend:
        backend = math.choose_backend([tensor, target])
        tensor = backend.as_tensor(tensor)
        target = backend.as_tensor(target)
    # --- Batch align ---
    ndims = len(math.staticshape(tensor))
    if ndims <= innate_dims:
        return tensor  # There is no batch dimension
    target_ndims = len(math.staticshape(target))
    assert target_ndims >= ndims, "target rank %d is smaller than tensor rank %d" % (target_ndims, ndims)
    if target_ndims == ndims:
        return tensor
    return math.expand_dims(tensor, axis=(-innate_dims - 1), number=(target_ndims - ndims))
def fourier_poisson(tensor, times=1):
    """
    Inverse operation to `fourier_laplace`.

    Multiplies the Fourier spectrum of `tensor` by the reciprocal of the
    Fourier-space Laplace kernel, raised to the power `times`. The kernel entry
    at index (0, ..., 0) is forced to 0, so that mode (presumably the
    zero-frequency / mean component) is removed from the result.

    :param tensor: tensor, presumably laid out as (batch, *spatial, components) given the `[1:-1]` slice — TODO confirm
    :param times: number of times the inverse operator is applied
    :return: real-valued result of the inverse FFT
    """
    spectrum = math.fft(math.to_complex(tensor))
    k_squared = fftfreq(math.staticshape(tensor)[1:-1], mode='square')
    laplace_kernel = -(2 * np.pi)**2 * k_squared
    origin = (0, ) * math.ndims(k_squared)
    # The origin entry is replaced by inf before inverting, presumably to avoid dividing by zero there.
    laplace_kernel[origin] = np.inf
    inverse_kernel = 1 / laplace_kernel
    inverse_kernel[origin] = 0
    return math.real(math.ifft(spectrum * inverse_kernel**times))
def pre_validated(self, struct, item, value):
    """
    Convert `value` to a tensor of rank >= `min_rank` and update the struct's batch shape.

    `min_rank` is read from `item.trait_kwargs['min_rank']`. If it is callable, it is
    called with `struct` to obtain the actual rank. Missing dimensions are added at
    axis 0 with `math.expand_dims`. All dimensions in front of the last `min_rank`
    ones form the batch shape. That batch shape is stored on `struct`, or merged
    with its existing batch shape via `_combined_shape`.

    :param struct: struct whose `batch_shape` and `batch_rank` attributes are updated
    :param item: trait item providing `trait_kwargs['min_rank']`
    :param value: value to convert
    :return: the (possibly expanded) tensor
    """
    tensor = math.as_tensor(value)
    min_rank = item.trait_kwargs['min_rank']
    if callable(min_rank):
        min_rank = min_rank(struct)
    # Read shapes from `tensor` (the object that is returned), not from the raw `value`.
    shape = math.staticshape(tensor)
    if len(shape) < min_rank:
        tensor = math.expand_dims(tensor, axis=0, number=min_rank - len(shape))
        shape = math.staticshape(tensor)
    batch_shape = shape[:-min_rank if min_rank != 0 else None]
    if struct.batch_shape is None:
        struct.batch_shape = batch_shape
    else:
        struct.batch_shape = _combined_shape(batch_shape, struct.batch_shape, item, struct)
    struct.batch_rank = len(struct.batch_shape)
    return tensor
def fourier_laplace(tensor, times=1):
    """
    Apply the spatial Laplace operator to `tensor`, assuming periodic boundary conditions.

    *Note:* The results of `fourier_laplace` and `laplace` are close but not identical.

    The operator is evaluated in Fourier space. For periodic fields the result is
    exact, so no numerical instabilities arise, even for higher-order derivatives.

    :param tensor: tensor, assumed to have periodic boundary conditions
    :param times: number of times the Laplace operator is applied. The Fourier kernel is
        raised to this power, so the computational cost does not depend on it.
    :return: tensor of same shape as `tensor`
    """
    spectrum = math.fft(math.to_complex(tensor))
    spatial_shape = math.staticshape(tensor)[1:-1]
    k_squared = fftfreq(spatial_shape, mode='square')
    kernel = -(2 * np.pi)**2 * k_squared
    filtered = spectrum * kernel**times
    return math.real(math.ifft(filtered))
def all_dimensions(tensor):
    """Return a `range` over every axis index of `tensor`, i.e. 0 .. rank-1."""
    static_shape = math.staticshape(tensor)
    return range(len(static_shape))
def axes(obj):
    """
    Return the axis indices 0 .. rank-3 of `obj` as a tuple.

    The tuple has two fewer entries than `obj` has dimensions. It is empty when
    the rank is 2 or less.
    """
    static_shape = math.staticshape(obj)
    count = len(static_shape) - 2
    return tuple(range(count))
def spatial_dimensions(obj):
    """
    Return the indices of every dimension of `obj` except the first and the last.

    Presumably these are the spatial dimensions between a leading batch dimension
    and a trailing component dimension (TODO confirm). The tuple is empty when
    the rank is 2 or less.
    """
    ndim = len(math.staticshape(obj))
    first, stop = 1, ndim - 1
    return tuple(range(first, stop))
def is_scalar(obj):
    """Return True when the static shape of `obj` has no entries (rank 0)."""
    static_shape = math.staticshape(obj)
    return not len(static_shape)
def batch_align_scalar(tensor, innate_spatial_dims, target):
    """
    Batch-align `tensor`, adding a trailing size-1 axis first when its last dimension is not 1.

    A rank-0 `tensor` has no last dimension to inspect. It is expanded at axis 0 by
    as many dimensions as `target` has. Any other tensor is passed to `batch_align`
    with `innate_spatial_dims + 1` innate dimensions.

    :param tensor: tensor to align
    :param innate_spatial_dims: number of innate spatial dimensions of `tensor`; must be 0 for rank-0 input
    :param target: tensor whose rank determines the result rank
    :return: aligned tensor
    :raises ValueError: if `tensor` has rank 0 but `innate_spatial_dims` is not 0
    """
    tensor_shape = math.staticshape(tensor)
    if len(tensor_shape) == 0:
        # Indexing `tensor_shape[-1]` below would raise IndexError for a scalar.
        if innate_spatial_dims != 0:
            raise ValueError("rank-0 tensor cannot have %d innate spatial dimensions" % innate_spatial_dims)
        return math.expand_dims(tensor, 0, len(math.staticshape(target)))
    if tensor_shape[-1] != 1:
        tensor = math.expand_dims(tensor, -1)
    return batch_align(tensor, innate_spatial_dims + 1, target)
def fourier_laplace(tensor, times=1):
    """
    Apply the spatial Laplace operator to `tensor` in Fourier space, assuming periodic boundaries.

    :param tensor: tensor, presumably laid out as (batch, *spatial, components) given the `[1:-1]` slice — TODO confirm
    :param times: number of times the Laplace operator is applied (default 1). The Fourier kernel
        is raised to this power, so the computational cost does not depend on it.
    :return: real-valued tensor of the same shape as `tensor`
    """
    frequencies = math.fft(math.to_complex(tensor))
    k = fftfreq(math.staticshape(tensor)[1:-1], mode='square')
    fft_laplace = -(2 * np.pi)**2 * k
    return math.real(math.ifft(frequencies * fft_laplace**times))
def rank(tensor):
    """Return the number of dimensions in the static shape of `tensor`."""
    static_shape = math.staticshape(tensor)
    return len(static_shape)