def downsample2x(tensor, interpolation='linear', axes=None):
    """
    Halves the resolution of `tensor` along the selected spatial axes by
    averaging adjacent pairs of cells. Structs are mapped recursively.
    Odd-sized axes are replicate-padded by one cell so they split evenly.

    NOTE(review): a second, axes-less `downsample2x` is defined later in this
    file and shadows this one — confirm which definition is intended.

    :param tensor: field of shape (batch, spatial dimensions..., components) — presumed; TODO confirm
    :param interpolation: only 'linear' is supported
    :param axes: spatial axes to downsample; defaults to all spatial axes
    :return: tensor with the selected spatial dimensions halved
    """
    if struct.isstruct(tensor):
        return struct.map(lambda s: downsample2x(s, interpolation, axes), tensor, recursive=False)
    if interpolation.lower() != 'linear':
        raise ValueError('Only linear interpolation supported')
    rank = spatial_rank(tensor)
    if axes is None:
        axes = range(rank)
    # Pad selected axes of odd size by one cell (replicate) so each axis
    # can be split into complete pairs.
    pad_width = [[0, 0]]
    for ax, dim in enumerate(tensor.shape[1:-1]):
        needs_pad = (dim % 2) != 0 and _contains_axis(axes, ax, rank)
        pad_width.append([0, 1] if needs_pad else [0, 0])
    pad_width.append([0, 0])
    tensor = math.pad(tensor, pad_width, 'replicate')
    for axis in axes:
        # Average every pair of adjacent cells along `axis` via stride-2 slices.
        odd = tuple(slice(1, None, 2) if i == axis else slice(None) for i in range(rank))
        even = tuple(slice(0, None, 2) if i == axis else slice(None) for i in range(rank))
        paired = tensor[(slice(None),) + odd + (slice(None),)] + tensor[(slice(None),) + even + (slice(None),)]
        tensor = paired / 2
    return tensor
def laplace(tensor, padding='replicate', axes=None, use_fft_for_periodic=False):
    """
    Spatial Laplace operator as defined for scalar fields.
    If a vector field is passed, the laplace is computed component-wise.

    :param tensor: n-dimensional field of shape (batch, spatial dimensions..., components)
    :param padding: 'valid', 'constant', 'reflect', 'replicate', 'circular'
    :param use_fft_for_periodic: If True and padding='circular', uses FFT to compute laplace
    :param axes: The second derivative along these axes is summed over
    :type axes: list
    :return: tensor of same shape
    """
    rank = spatial_rank(tensor)
    if padding is None or padding == 'valid':
        # No padding: the convolutional stencil shrinks the valid region.
        pass
    elif use_fft_for_periodic and padding in ('circular', 'wrap'):
        # NOTE(review): the FFT path does not receive `axes` — confirm it is
        # only used when the laplace is taken over all spatial axes.
        return fourier_laplace(tensor)
    else:
        pad_width = _get_pad_width_axes(rank, axes, val_true=[1, 1], val_false=[0, 0])
        tensor = math.pad(tensor, pad_width, padding)
    # --- convolutional laplace ---
    if axes is not None:
        return _sliced_laplace_nd(tensor, axes)
    if rank == 2:
        return _conv_laplace_2d(tensor)
    if rank == 3:
        return _conv_laplace_3d(tensor)
    return _sliced_laplace_nd(tensor)
def _gradient_nd(tensor, padding, relative_shifts):
    """
    Finite-difference gradient: for each spatial dimension, the difference
    between the upper- and lower-shifted tensor, concatenated along the
    component axis.

    :param tensor: field of shape (batch, spatial dimensions..., components) — presumed; TODO confirm
    :param padding: pad mode forwarded to math.pad
    :param relative_shifts: (lower, upper) shift offsets for the stencil
    :return: per-axis differences concatenated on the last axis
    """
    rank = spatial_rank(tensor)
    pad_width = _get_pad_width(rank, (-relative_shifts[0], relative_shifts[1]))
    padded = math.pad(tensor, pad_width, mode=padding)
    diminish = (-relative_shifts[0], relative_shifts[1])
    differences = []
    for dim in range(rank):
        lower, upper = _dim_shifted(padded, dim, relative_shifts, diminish_others=diminish)
        differences.append(upper - lower)
    return math.concat(differences, axis=-1)
def _divergence_nd(tensor, relative_shifts):
    """
    Finite-difference divergence: sums, over all spatial dimensions, the
    difference between the upper- and lower-shifted component belonging to
    that dimension.

    NOTE(review): math.pad is called without an explicit mode here (unlike
    _gradient_nd) — confirm the default pad mode is the intended one.

    :param tensor: vector field of shape (batch, spatial dimensions..., components) — presumed; TODO confirm
    :param relative_shifts: (lower, upper) shift offsets for the stencil
    :return: sum of per-dimension differences
    """
    rank = spatial_rank(tensor)
    pad_width = _get_pad_width(rank, (-relative_shifts[0], relative_shifts[1]))
    padded = math.pad(tensor, pad_width)
    diminish = (-relative_shifts[0], relative_shifts[1])
    differences = []
    for dim in range(rank):
        # components=rank - dim - 1 selects this dimension's vector component.
        lower, upper = _dim_shifted(padded, dim, relative_shifts,
                                    diminish_others=diminish,
                                    components=rank - dim - 1)
        differences.append(upper - lower)
    return math.sum(differences, 0)
def upsample2x(tensor, interpolation='linear'):
    """
    Doubles the resolution of `tensor` along every spatial axis using linear
    interpolation (each cell is split into two sub-cells weighted 3:1 toward
    the nearer neighbor). Structs are mapped recursively.

    :param tensor: field of shape (batch, spatial dimensions..., components) — presumed; TODO confirm
    :param interpolation: only 'linear' is supported
    :return: tensor with every spatial dimension doubled
    """
    if struct.isstruct(tensor):
        return struct.map(lambda s: upsample2x(s, interpolation), tensor, recursive=False)
    if interpolation.lower() != 'linear':
        raise ValueError('Only linear interpolation supported')
    rank = spatial_rank(tensor)
    dims = range(rank)
    vlen = tensor.shape[-1]
    spatial_dims = tensor.shape[1:-1]  # sizes before padding; used for the reshape targets
    tensor = math.pad(tensor, _get_pad_width(rank), 'replicate')
    for dim in dims:
        lower, center, upper = _dim_shifted(tensor, dim, (-1, 0, 1))
        # Two interpolated sub-cells per cell, stacked on a fresh axis and
        # then folded into the spatial axis by the reshape below.
        combined = math.stack([0.25 * lower + 0.75 * center,
                               0.75 * center + 0.25 * upper], axis=2 + dim)
        target_shape = [-1] + [spatial_dims[dim] * 2 if i == dim else tensor.shape[i + 1] for i in dims] + [vlen]
        tensor = math.reshape(combined, target_shape)
    return tensor
def downsample2x(tensor, interpolation='linear'):
    """
    Halves the resolution of `tensor` along all spatial axes by averaging
    adjacent pairs of cells. Structs are mapped recursively; odd-sized axes
    are replicate-padded by one cell first.

    NOTE(review): this redefinition shadows the earlier, axes-aware
    `downsample2x` in this file — confirm which version is intended.

    :param tensor: field of shape (batch, spatial dimensions..., components) — presumed; TODO confirm
    :param interpolation: only 'linear' is supported
    :return: tensor with every spatial dimension halved
    """
    if struct.isstruct(tensor):
        return struct.map(lambda s: downsample2x(s, interpolation), tensor, recursive=False)
    if interpolation.lower() != 'linear':
        raise ValueError('Only linear interpolation supported')
    dims = range(spatial_rank(tensor))
    # Replicate-pad odd spatial sizes so every axis splits into whole pairs.
    pad_width = [[0, 0]] + [[0, 1] if (dim % 2) != 0 else [0, 0] for dim in tensor.shape[1:-1]] + [[0, 0]]
    tensor = math.pad(tensor, pad_width, 'replicate')
    for axis in dims:
        odd = tuple(slice(1, None, 2) if i == axis else slice(None) for i in dims)
        even = tuple(slice(0, None, 2) if i == axis else slice(None) for i in dims)
        paired = tensor[(slice(None),) + odd + (slice(None),)] + tensor[(slice(None),) + even + (slice(None),)]
        tensor = paired / 2
    return tensor