Example #1
def interpolate_linear(tensor, upper_weight, dimensions):
    """

    :param tensor:
    :param upper_weight: tensor of floats (leading dimensions must be 1) or nan to ignore interpolation along this axis
    :param dimensions: list or tuple of dimensions (first spatial axis=1) to be interpolated. Other axes are ignored.
    :return:
    """
    lower_weight = 1 - upper_weight
    for dimension in spatial_dimensions(tensor):
        if dimension in dimensions:
            upper_slices = tuple([
                (slice(1, None) if i == dimension else slice(None))
                for i in all_dimensions(tensor)
            ])
            lower_slices = tuple([
                (slice(-1) if i == dimension else slice(None))
                for i in all_dimensions(tensor)
            ])
            tensor = math.mul(tensor[upper_slices],
                              upper_weight[..., dimension - 1]) + math.mul(
                                  tensor[lower_slices],
                                  lower_weight[..., dimension - 1])
    return tensor
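
A minimal usage sketch for the function above, assuming the NumPy-backed `math` module registered later in this collection and the `spatial_dimensions`/`all_dimensions` helpers shown in Examples #19 and #10. With a weight of 0.5 along the single spatial axis, each output is the midpoint of two neighbouring values:

import numpy as np

values = np.array([0., 1., 2., 4.]).reshape([1, 4, 1])   # (batch, x, components)
weights = np.array([[0.5]])                              # one weight per interpolated axis
midpoints = interpolate_linear(values, weights, dimensions=(1,))
# midpoints has shape (1, 3, 1) and contains [0.5, 1.5, 3.0]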
Example #2
def l1_loss(tensor, batch_norm=True, reduce_batches=True):
    """
    Computes the L1 loss, i.e. the sum of absolute values, of a tensor or struct of tensors.

    :param tensor: tensor or struct of tensors
    :param batch_norm: divide the total loss by the batch size (only has an effect if reduce_batches=True)
    :param reduce_batches: if True, sum over all axes including the batch axis; otherwise sum per batch entry
    :return: the L1 loss
    """
    if struct.isstruct(tensor):
        all_tensors = struct.flatten(tensor)
        return sum(
            l1_loss(t, batch_norm, reduce_batches)
            for t in all_tensors)
    if reduce_batches:
        total_loss = math.sum(math.abs(tensor))
    else:
        total_loss = math.sum(math.abs(tensor),
                              axis=list(range(1, len(tensor.shape))))
    if batch_norm and reduce_batches:
        batch_size = math.shape(tensor)[0]
        return math.div(total_loss, math.to_float(batch_size))
    else:
        return total_loss
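
A quick sketch of the two reduction modes above, assuming plain NumPy arrays pass through the `math` wrappers:

import numpy as np

batch = np.array([[1., -2., 3.],
                  [0., 4., -1.]])        # shape (batch=2, n=3)
l1_loss(batch)                           # (6 + 5) / 2 = 5.5, normalized by batch size
l1_loss(batch, reduce_batches=False)     # per-sample losses [6., 5.]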
Example #3
def _sliced_laplace_nd(tensor, axes=None):
    """
    Laplace Stencil for N-Dimensions
    aggregated from (c)enter, (u)pper, and (l)ower parts
    """
    rank = spatial_rank(tensor)
    dims = range(rank)
    components = []
    for ax in dims:
        if _contains_axis(axes, ax, rank):
            lower, center, upper = _dim_shifted(
                tensor,
                ax, (-1, 0, 1),
                diminish_others=(1, 1),
                diminish_other_condition=lambda other_ax: _contains_axis(
                    axes, other_ax, rank))
            components.append(upper + lower - 2 * center)
    return math.sum(components, 0)
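
The stencil assembled above is the standard second-order central difference, u[i+1] - 2*u[i] + u[i-1] per axis. A hypothetical standalone 1D NumPy version of one such component (the `_dim_shifted` helper itself is not shown in this collection):

import numpy as np

u = np.array([0., 1., 4., 9., 16.])         # samples of x**2 on a unit grid
laplace = u[2:] - 2 * u[1:-1] + u[:-2]      # upper + lower - 2 * center
# laplace == [2., 2., 2.], matching d2(x^2)/dx2 = 2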
Example #4
def downsample2x(tensor, interpolation='linear'):
    """ Halves every spatial dimension of the tensor by averaging pairs of neighbouring cells. Odd spatial sizes are padded by replication first. """
    if struct.isstruct(tensor):
        return struct.map(lambda s: downsample2x(s, interpolation),
                          tensor, recursive=False)

    if interpolation.lower() != 'linear':
        raise ValueError('Only linear interpolation supported')
    dims = range(spatial_rank(tensor))
    tensor = math.pad(tensor,
                      [[0, 0]]
                      + [([0, 1] if (size % 2) != 0 else [0, 0]) for size in tensor.shape[1:-1]]  # pad odd spatial sizes by one
                      + [[0, 0]], 'replicate')
    for dimension in dims:
        upper_slices = tuple([(slice(1, None, 2) if i == dimension else slice(None)) for i in dims])
        lower_slices = tuple([(slice(0, None, 2) if i == dimension else slice(None)) for i in dims])
        tensor_sum = tensor[(slice(None),) + upper_slices + (slice(None),)] + tensor[(slice(None),) + lower_slices + (slice(None),)]
        tensor = tensor_sum / 2
    return tensor
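
A usage sketch, assuming NumPy arrays and the `math`/`struct` wrappers from this package:

import numpy as np

fine = np.arange(4, dtype=np.float64).reshape([1, 4, 1])  # values [0, 1, 2, 3]
coarse = downsample2x(fine)
# coarse has shape (1, 2, 1) and contains [0.5, 2.5]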
Example #5
def indices_tensor(tensor, dtype=None):
    """
    Returns an index tensor of the same spatial shape as the given tensor.
    Each index denotes the location within the tensor starting from zero.
    Indices are encoded as vectors in the index tensor.

    :param tensor: a tensor of shape (batch size, spatial dimensions..., component size)
    :param dtype: NumPy data type or `None` for default
    :return: an index tensor of shape (1, spatial dimensions..., spatial rank)
    """
    spatial_dimensions = list(tensor.shape[1:-1])
    idx_zyx = np.meshgrid(*[range(dim) for dim in spatial_dimensions],
                          indexing='ij')
    idx = np.stack(idx_zyx, axis=-1).reshape([1] + spatial_dimensions + [len(spatial_dimensions)])
    if dtype is not None:
        return idx.astype(dtype)
    else:
        return math.to_float(idx)
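
For a 2x3 grid the index tensor enumerates (y, x) positions; a small sketch with NumPy inputs:

import numpy as np

field = np.zeros([1, 2, 3, 1])               # (batch, y, x, components)
idx = indices_tensor(field, dtype=np.int32)
# idx.shape == (1, 2, 3, 2) and idx[0, 1, 2] == [1, 2]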
Example #6
def conjugate_gradient(k,
                       apply_A,
                       initial_x=None,
                       accuracy=1e-5,
                       max_iterations=1024,
                       back_prop=False):
    """
    Solve the linear system of equations Ax=k using the conjugate gradient (CG) algorithm.
    The implementation is based on https://nvlpubs.nist.gov/nistpubs/jres/049/jresv49n6p409_A1b.pdf

    :param k: Right-hand-side vector
    :param apply_A: function that takes x and calculates Ax
    :param initial_x: initial guess for the value of x
    :param accuracy: the algorithm terminates once |Ax-k| ≤ accuracy for every element. If None, the algorithm runs until max_iterations is reached.
    :param max_iterations: maximum number of CG iterations to perform
    :param back_prop: whether the loop should record gradients for backpropagation (passed through to math.while_loop)
    :return: Pair containing the result for x and the number of iterations performed
    """
    # Get momentum = k - Ax
    if initial_x is None:
        x = math.zeros_like(k)
        momentum = k
    else:
        x = initial_x
        momentum = k - apply_A(x)
    # Further Variables
    residual = momentum  # residual is previous momentum
    laplace_momentum = apply_A(momentum)  # = A*momentum
    loop_index = 0  # initial
    # Pack Variables for loop
    variables = [x, momentum, laplace_momentum, residual, loop_index]
    # Ensure to run until desired accuracy is achieved
    if accuracy is not None:

        def loop_condition(_1, _2, _3, residual, _i):
            '''continue if the maximum deviation from zero is bigger than desired accuracy'''
            return math.max(math.abs(residual)) >= accuracy
    else:

        def loop_condition(*_args):
            return True

    non_batch_dims = tuple(range(1, len(k.shape)))

    def loop_body(pressure, momentum, A_times_momentum, residual, loop_index):
        """
        iteratively solve for:
        x : pressure
        momentum : momentum
        laplace_momentum : A_times_momentum
        residual : residual
        """
        tmp = math.sum(momentum * A_times_momentum,
                       axis=non_batch_dims,
                       keepdims=True)  # t = sum(mAm)
        a = math.divide_no_nan(
            math.sum(momentum * residual, axis=non_batch_dims, keepdims=True),
            tmp)  # a = sum(mr)/sum(mAm)
        pressure += a * momentum  # p += am
        residual -= a * A_times_momentum  # r -= aAm
        momentum = residual - math.divide_no_nan(
            math.sum(residual * A_times_momentum,
                     axis=non_batch_dims,
                     keepdims=True) * momentum,
            tmp)  # m = r-sum(rAm)*m/t = r-sum(rAm)*m/sum(mAm)
        A_times_momentum = apply_A(momentum)  # Am = A*m
        return [pressure, momentum, A_times_momentum, residual, loop_index + 1]

    x, momentum, laplace_momentum, residual, loop_index = math.while_loop(
        loop_condition,
        loop_body,
        variables,
        parallel_iterations=2,
        back_prop=back_prop,
        swap_memory=False,
        name="pressure_solve_loop",
        maximum_iterations=max_iterations)

    return x, loop_index
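
An end-to-end sketch on a small symmetric positive-definite system, assuming the NumPy-backed `math` module (whose `while_loop` then runs as a plain Python loop). The operator below is a hypothetical 1D negative Laplace with zero boundaries:

import numpy as np

def apply_A(x):
    # (A x)[i] = 2 x[i] - x[i-1] - x[i+1], zero outside the domain
    padded = np.pad(x, ((0, 0), (1, 1)))
    return 2 * x - padded[:, :-2] - padded[:, 2:]

k = np.ones([1, 8])                          # right-hand side, shape (batch, n)
x, iterations = conjugate_gradient(k, apply_A, accuracy=1e-6)
# x now satisfies |A x - k| <= 1e-6 in every element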
Example #7
    axis_gradient,
    laplace,
    fourier_laplace,
    fourier_poisson,
    fftfreq,
    abs_square,
    downsample2x,
    upsample2x,
    interpolate_linear,
    spatial_sum,
)
from .batched import BATCHED, ShapeMismatch
from . import optim

# Setup Backend
DYNAMIC_BACKEND.add_backend(SciPyBackend())
DYNAMIC_BACKEND.add_backend(StructBroadcastBackend(DYNAMIC_BACKEND))


def set_precision(floating_point_bits):
    """
    Sets the floating point precision of DYNAMIC_BACKEND which affects all registered backends.

    If `floating_point_bits` is an integer, all floating point tensors created henceforth will be of the corresponding data type, float16, float32 or float64.
    Operations may also convert floating point values to this precision, even if the input had a different precision.

    If `floating_point_bits` is None, new tensors will default to float32 unless specified otherwise.
    The output of math operations has the same precision as its inputs.

    :param floating_point_bits: one of (16, 32, 64, None)
    """
Example #8
def spatial_rank(tensor):
    """ The spatial rank of a tensor is ndims - 2. """
    return math.ndims(tensor) - 2
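
The convention throughout this collection is tensors of shape (batch, spatial..., components), so a batch of 2D vector fields has spatial rank 2 (assuming the NumPy-backed `math` module):

import numpy as np

velocity = np.zeros([1, 64, 64, 2])   # (batch, y, x, components)
spatial_rank(velocity)                # == 2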
Example #9
def is_scalar(tensor):
    return math.ndims(tensor) == 0
Example #10
def all_dimensions(tensor):
    return range(len(math.staticshape(tensor)))
Example #11
def fourier_laplace(tensor):
    frequencies = math.fft(math.to_complex(tensor))
    k = fftfreq(math.staticshape(tensor)[1:-1], mode='square')
    fft_laplace = -(2 * np.pi)**2 * k
    return math.real(math.ifft(frequencies * fft_laplace))
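
This uses the Fourier identity for the Laplace operator: differentiating twice in space multiplies each Fourier mode by (2*pi*i*k)**2 = -(2*pi)**2 * |k|**2, which is why `fftfreq` is queried with mode='square' and the spectrum is scaled by -(2*pi)**2.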
Example #12
def spatial_sum(tensor):
    summed = math.sum(tensor, axis=math.dimrange(tensor))
    for i in math.dimrange(tensor):
        summed = math.expand_dims(summed, i)
    return summed
Example #13
def randn(shape, dtype=None):
    array = np.random.randn(*_none_to_one(shape))
    if dtype is not None:
        return array.astype(dtype)
    else:
        return math.to_float(array)
Example #14
def ones(shape, dtype=None):
    if dtype is not None:
        return np.ones(_none_to_one(shape), dtype)
    else:
        return math.to_float(np.ones(_none_to_one(shape), np.int8))
Example #15
def is_scalar(obj):
    return len(math.staticshape(obj)) == 0
Example #16
def batch_align_scalar(tensor, innate_spatial_dims, target):
    if math.staticshape(tensor)[-1] != 1:
        tensor = math.expand_dims(tensor, -1)
    result = batch_align(tensor, innate_spatial_dims + 1, target)
    return result
Example #17
def abs_square(complex):
    return math.imag(complex)**2 + math.real(complex)**2
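
This is the squared magnitude |z|**2 = Re(z)**2 + Im(z)**2, avoiding the square root implied by `abs`; a tiny sketch, assuming NumPy complex arrays pass through the `math` wrappers:

import numpy as np

z = np.array([3 + 4j])
abs_square(z)          # == [25.]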
Example #18
def loop_condition(_1, _2, _3, residual, _i):
    '''continue if the maximum deviation from zero is bigger than desired accuracy'''
    return math.max(math.abs(residual)) >= accuracy
Example #19
def spatial_dimensions(obj):
    return tuple(range(1, len(math.staticshape(obj)) - 1))
Example #20
def _max_residual_condition(residual_index, accuracy):
    """continue if the maximum deviation from zero is bigger than desired accuracy"""
    if accuracy is None:
        return lambda *args: True
    else:
        return lambda *args: math.max(math.abs(args[residual_index])) > accuracy
Example #21
def axes(obj):
    return tuple(range(len(math.staticshape(obj)) - 2))
Example #22
def rank(tensor):
    return len(math.staticshape(tensor))