Ejemplo n.º 1
0
def batch_align_scalar(tensor, innate_spatial_dims, target):
    """
    Batch-align a tensor whose last axis is treated as a single component axis.

    Rank-0 input is expanded directly via `math.expand_dims(tensor, 0, rank_of_target)`
    (presumably inserting that many leading axes). Otherwise a trailing axis is appended
    when the last axis is not of size 1 or the tensor has at most one dimension. The
    result is then aligned with `batch_align`, counting that trailing axis as an
    additional innate dimension.

    :param tensor: tensor to align
    :param innate_spatial_dims: number of spatial dimensions belonging to `tensor`; must be 0 for rank-0 input
    :param target: tensor whose rank drives the alignment
    :return: the aligned tensor
    """
    if rank(tensor) == 0:
        assert innate_spatial_dims == 0
        target_rank = len(math.staticshape(target))
        return math.expand_dims(tensor, 0, target_rank)
    needs_component_axis = math.staticshape(tensor)[-1] != 1 or math.ndims(tensor) <= 1
    if needs_component_axis:
        tensor = math.expand_dims(tensor, -1)
    # The component axis counts as one more innate (non-batch) dimension.
    return batch_align(tensor, innate_spatial_dims + 1, target)
Ejemplo n.º 2
0
def fftfreq(resolution, mode='vector', dtype=None):
    """
    Returns the discrete Fourier transform sample frequencies.
    These are the frequencies corresponding to the components of the result of `math.fft` on a tensor of shape `resolution`.

    Per-axis frequencies come from `numpy.fft.fftfreq(n)` (unit sample spacing). They are
    combined with `numpy.meshgrid(..., indexing='ij')`, stacked along a new last axis, and
    given a new leading axis.

    :param resolution: grid resolution measured in cells
    :param mode: one of ('vector', 'absolute', 'square').
        'vector' returns the frequency components stacked along the last axis.
        'square' returns their sum of squares, keeping a last axis of size 1.
        'absolute' returns the square root of that sum.
        Any other value (including None) fails the assertion.
    :param dtype: data type of the returned tensor; if None, the result is converted with `math.to_float`
    :return: tensor holding the frequencies of the corresponding values computed by math.fft
    """
    assert mode in ('vector', 'absolute', 'square')
    k = np.meshgrid(*[np.fft.fftfreq(int(n)) for n in resolution],
                    indexing='ij')
    k = math.expand_dims(math.stack(k, -1), 0)
    if dtype is not None:
        k = k.astype(dtype)
    else:
        k = math.to_float(k)
    if mode == 'vector':
        return k
    # Reduce the component axis but keep it (size 1) so the result stays broadcastable.
    k = math.sum(k**2, axis=-1, keepdims=True)
    if mode == 'square':
        return k
    return math.sqrt(k)
Ejemplo n.º 3
0
def fftfreq(resolution, mode='vector', dtype=np.float32):
    """
    Returns the discrete Fourier transform sample frequencies for a grid of shape `resolution`.

    :param resolution: grid resolution measured in cells
    :param mode: one of ('vector', 'absolute', 'square').
        'vector' yields per-axis components stacked along the last axis.
        'square' yields their sum of squares (last axis kept with size 1).
        'absolute' yields the square root of that sum.
    :param dtype: data type the frequency tensor is cast to
    :return: tensor holding the frequencies
    """
    assert mode in ('vector', 'absolute', 'square')
    axis_freqs = [np.fft.fftfreq(int(n)) for n in resolution]
    grids = np.meshgrid(*axis_freqs, indexing='ij')
    freqs = math.expand_dims(math.stack(grids, -1), 0).astype(dtype)
    if mode == 'vector':
        return freqs
    squared = math.sum(freqs**2, axis=-1, keepdims=True)
    return squared if mode == 'square' else math.sqrt(squared)
Ejemplo n.º 4
0
def batch_align(tensor, innate_dims, target, convert_to_same_backend=True):
    """
    Insert axes into `tensor` so that its rank matches the rank of `target`.

    New axes are inserted at position `-innate_dims - 1`, presumably just before the
    `innate_dims` trailing dimensions that belong to `tensor` itself. A tuple or list is
    aligned element-wise and returned as a list.

    :param tensor: tensor, or tuple/list of tensors, to align
    :param innate_dims: number of trailing dimensions owned by `tensor` (not batch dimensions)
    :param target: tensor whose rank the result should reach
    :param convert_to_same_backend: if True, convert `tensor` and `target` to a common backend first.
        This also applies to each element when `tensor` is a tuple/list.
    :return: the aligned tensor, or a list of aligned tensors.
        Returned unchanged if it has at most `innate_dims` dimensions or already has the rank of `target`.
    :raises AssertionError: if `target` has fewer dimensions than `tensor`
    """
    if isinstance(tensor, (tuple, list)):
        # Forward convert_to_same_backend; previously it was dropped here, so
        # elements were always converted regardless of the caller's choice.
        return [batch_align(t, innate_dims, target, convert_to_same_backend) for t in tensor]
    # --- Convert type ---
    if convert_to_same_backend:
        backend = math.choose_backend([tensor, target])
        tensor = backend.as_tensor(tensor)
        target = backend.as_tensor(target)
    # --- Batch align ---
    ndims = len(math.staticshape(tensor))
    if ndims <= innate_dims:
        return tensor  # There is no batch dimension
    target_ndims = len(math.staticshape(target))
    assert target_ndims >= ndims
    if target_ndims == ndims:
        return tensor
    return math.expand_dims(tensor, axis=(-innate_dims - 1), number=(target_ndims - ndims))
Ejemplo n.º 5
0
 def pre_validated(self, struct, item, value):
     """
     Convert `value` to a tensor and pad its rank up to `min_rank`, then record its batch shape on `struct`.

     `min_rank` is read from `item.trait_kwargs['min_rank']`. It may be a callable, in which
     case it is called with `struct`. Leading axes are inserted until the tensor has at
     least `min_rank` dimensions. All dimensions before the last `min_rank` are treated as
     the batch shape. That batch shape is stored directly on `struct` if none was set yet;
     otherwise it is combined with the existing one via `_combined_shape`.

     :param struct: struct being validated; its `batch_shape` and `batch_rank` are updated
     :param item: trait item providing `trait_kwargs['min_rank']`
     :param value: raw value to convert
     :return: the converted (and possibly rank-expanded) tensor
     """
     tensor = math.as_tensor(value)
     min_rank = item.trait_kwargs['min_rank']
     if callable(min_rank):
         min_rank = min_rank(struct)
     shape = math.staticshape(tensor)
     if len(shape) < min_rank:
         tensor = math.expand_dims(tensor,
                                   axis=0,
                                   number=min_rank - len(shape))
         # Read the shape from the expanded tensor; re-reading it from `value`
         # (as before) returned the stale, pre-expansion shape.
         shape = math.staticshape(tensor)
     batch_shape = shape[:-min_rank if min_rank != 0 else None]
     if struct.batch_shape is None:
         struct.batch_shape = batch_shape
     else:
         struct.batch_shape = _combined_shape(batch_shape,
                                              struct.batch_shape, item,
                                              struct)
     struct.batch_rank = len(struct.batch_shape)
     return tensor
Ejemplo n.º 6
0
def spatial_sum(tensor):
    """
    Sum `tensor` over the axes yielded by `math.dimrange(tensor)`, then re-insert those axes.

    :param tensor: tensor to reduce
    :return: the reduced tensor, with `math.expand_dims` applied once per reduced axis index.
        Presumably this restores them as size-1 axes if `dimrange` yields increasing indices.
    """
    total = math.sum(tensor, axis=math.dimrange(tensor))
    # Re-insert each reduced axis in the order dimrange yields them.
    for axis_index in math.dimrange(tensor):
        total = math.expand_dims(total, axis_index)
    return total
Ejemplo n.º 7
0
def batch_align_scalar(tensor, innate_spatial_dims, target):
    """
    Batch-align a tensor whose last axis is treated as a single component axis.

    Rank-0 input is expanded directly via `math.expand_dims(tensor, 0, rank_of_target)`
    (presumably inserting that many leading axes). Otherwise a trailing axis is appended
    when the last axis is not of size 1. The result is then aligned with `batch_align`,
    counting that trailing axis as an additional innate dimension.

    :param tensor: tensor to align
    :param innate_spatial_dims: number of spatial dimensions belonging to `tensor`; must be 0 for rank-0 input
    :param target: tensor whose rank drives the alignment
    :return: the aligned tensor
    """
    shape = math.staticshape(tensor)
    if len(shape) == 0:
        # A rank-0 tensor has no last axis, so `shape[-1]` would raise IndexError.
        # Expand it to the rank of `target` instead.
        assert innate_spatial_dims == 0
        return math.expand_dims(tensor, 0, len(math.staticshape(target)))
    if shape[-1] != 1:
        tensor = math.expand_dims(tensor, -1)
    return batch_align(tensor, innate_spatial_dims + 1, target)