def batch_align_scalar(tensor, innate_spatial_dims, target):
    """Batch-align a single-channel (scalar-valued) tensor with `target`.

    Ensures the tensor ends in a component dimension of size 1, then
    delegates to `batch_align`, counting that component dimension as innate.
    A rank-0 tensor is broadcast directly to the rank of `target`.

    NOTE(review): a second `batch_align_scalar` is defined later in this
    file and shadows this one — confirm which definition is intended.
    """
    if rank(tensor) == 0:
        # A true scalar has no spatial dims to align; just match target's rank.
        assert innate_spatial_dims == 0
        return math.expand_dims(tensor, 0, len(math.staticshape(target)))
    needs_component_dim = math.ndims(tensor) <= 1 or math.staticshape(tensor)[-1] != 1
    if needs_component_dim:
        tensor = math.expand_dims(tensor, -1)
    return batch_align(tensor, innate_spatial_dims + 1, target)
def fftfreq(resolution, mode='vector', dtype=None):
    """Return the discrete Fourier transform sample frequencies for a grid.

    The result matches the frequency layout of `math.fft` applied to a
    tensor of shape `resolution`, with a leading batch dimension of size 1.

    :param resolution: grid resolution measured in cells
    :param mode: one of ('vector', 'absolute', 'square')
    :param dtype: data type of the returned tensor; None selects the default float type
    :return: tensor holding the frequencies of the corresponding values computed by math.fft
    """
    assert mode in ('vector', 'absolute', 'square')
    per_axis = [np.fft.fftfreq(int(cells)) for cells in resolution]
    freq = math.stack(np.meshgrid(*per_axis, indexing='ij'), -1)
    freq = math.expand_dims(freq, 0)  # prepend batch dimension
    freq = math.to_float(freq) if dtype is None else freq.astype(dtype)
    if mode == 'vector':
        return freq
    freq = math.sum(freq ** 2, axis=-1, keepdims=True)  # squared magnitude
    return freq if mode == 'square' else math.sqrt(freq)
def fftfreq(resolution, mode='vector', dtype=np.float32):
    """Return the discrete Fourier transform sample frequencies for a grid.

    The result matches the frequency layout of `math.fft` applied to a
    tensor of shape `resolution`, with a leading batch dimension of size 1.

    NOTE(review): this redefines `fftfreq`; being the later definition it
    shadows the earlier one in this file — confirm the duplication is intended.

    :param resolution: grid resolution measured in cells
    :param mode: one of ('vector', 'absolute', 'square')
    :param dtype: data type of the returned tensor; None selects the default float type
    :return: tensor holding the frequencies of the corresponding values computed by math.fft
    """
    assert mode in ('vector', 'absolute', 'square')
    k = np.meshgrid(*[np.fft.fftfreq(int(n)) for n in resolution], indexing='ij')
    k = math.expand_dims(math.stack(k, -1), 0)  # stack per-axis grids, prepend batch dim
    # Generalized: dtype=None previously crashed in astype; it now falls back
    # to the default float conversion. Default dtype (np.float32) is unchanged.
    if dtype is not None:
        k = k.astype(dtype)
    else:
        k = math.to_float(k)
    if mode == 'vector':
        return k
    k = math.sum(k ** 2, axis=-1, keepdims=True)  # squared magnitude per cell
    if mode == 'square':
        return k
    else:
        return math.sqrt(k)
def batch_align(tensor, innate_dims, target, convert_to_same_backend=True):
    """Insert singleton batch dimensions into `tensor` so it broadcasts against `target`.

    The last `innate_dims` dimensions of `tensor` are treated as non-batch
    (innate) dimensions; missing dimensions are inserted just before them.
    Tuples/lists are aligned element-wise.

    :param tensor: tensor (or tuple/list of tensors) to align
    :param innate_dims: number of trailing non-batch dimensions of `tensor`
    :param target: tensor whose rank the result is aligned to
    :param convert_to_same_backend: if True, convert tensor and target to a common backend
    :return: tensor with rank equal to `target`'s (or the input unchanged if it has no batch dims)
    """
    if isinstance(tensor, (tuple, list)):
        # Fix: propagate convert_to_same_backend through the recursion
        # (previously it was silently reset to the default True).
        return [batch_align(t, innate_dims, target, convert_to_same_backend) for t in tensor]
    # --- Convert type ---
    if convert_to_same_backend:
        backend = math.choose_backend([tensor, target])
        tensor = backend.as_tensor(tensor)
        target = backend.as_tensor(target)
    # --- Batch align ---
    ndims = len(math.staticshape(tensor))
    if ndims <= innate_dims:
        return tensor  # There is no batch dimension
    target_ndims = len(math.staticshape(target))
    assert target_ndims >= ndims
    if target_ndims == ndims:
        return tensor
    # Insert the missing singleton dims just before the innate dimensions.
    return math.expand_dims(tensor, axis=(-innate_dims - 1), number=(target_ndims - ndims))
def pre_validated(self, struct, item, value):
    """Validate a tensor-valued struct item: expand it to the item's minimum
    rank and record the resulting batch shape on `struct`.

    :param struct: the struct being validated; its batch_shape/batch_rank are updated in place
    :param item: struct item whose trait_kwargs supply 'min_rank' (an int or a callable of struct)
    :param value: the raw value to validate; converted to a tensor
    :return: the validated (possibly rank-expanded) tensor
    """
    tensor = math.as_tensor(value)
    min_rank = item.trait_kwargs['min_rank']
    if callable(min_rank):
        # min_rank may be given as a function of the struct being validated
        min_rank = min_rank(struct)
    shape = math.staticshape(value)
    if len(shape) < min_rank:
        # Prepend singleton dimensions until the tensor reaches min_rank.
        tensor = math.expand_dims(tensor, axis=0, number=min_rank - len(shape))
        # NOTE(review): shape is recomputed from the original `value`, not the
        # expanded `tensor`; since len(shape) < min_rank here, the slice below
        # yields an empty batch_shape in this branch — confirm this is intended.
        shape = math.staticshape(value)
    # Everything before the last min_rank dims counts as batch shape
    # (min_rank == 0 means the entire shape is batch shape).
    batch_shape = shape[:-min_rank if min_rank != 0 else None]
    if struct.batch_shape is None:
        struct.batch_shape = batch_shape
    else:
        # Merge with the batch shape seen on previously validated items.
        struct.batch_shape = _combined_shape(batch_shape, struct.batch_shape, item, struct)
    struct.batch_rank = len(struct.batch_shape)
    return tensor
def spatial_sum(tensor):
    """Sum `tensor` over its spatial dimensions, re-inserting each summed
    dimension with size 1 so the result broadcasts against the input."""
    result = math.sum(tensor, axis=math.dimrange(tensor))
    # Restore every spatial axis as a singleton to match the input rank.
    for spatial_axis in math.dimrange(tensor):
        result = math.expand_dims(result, spatial_axis)
    return result
def batch_align_scalar(tensor, innate_spatial_dims, target):
    """Batch-align a single-channel (scalar-valued) tensor with `target`.

    Ensures the tensor ends in a component dimension of size 1, then
    delegates to `batch_align`, counting that component dimension as innate.

    NOTE(review): this shadows the earlier `batch_align_scalar` definition in
    this file — confirm the duplication is intended.

    :param tensor: tensor whose last dimension is 1 (or missing)
    :param innate_spatial_dims: number of spatial dimensions of `tensor`
    :param target: tensor whose rank the result is aligned to
    :return: tensor aligned with `target`
    """
    shape = math.staticshape(tensor)
    # Robustness fix: a rank-0 tensor previously crashed on shape[-1]
    # (IndexError on an empty shape); it now receives a component dimension
    # like any other channel-less input. Behavior for rank >= 1 is unchanged.
    if len(shape) == 0 or shape[-1] != 1:
        tensor = math.expand_dims(tensor, -1)
    result = batch_align(tensor, innate_spatial_dims + 1, target)
    return result