def _map_coordinates(input, coordinates, order, mode, cval):
    """Evaluate ``input`` at ``coordinates`` from weighted interpolation nodes.

    Args:
      input: array-like; converted with ``jnp.asarray``.
      coordinates: sequence holding one coordinate array per dimension of
        ``input``.
      order: ``0`` selects ``_nearest_indices_and_weights`` and ``1`` selects
        ``_linear_indices_and_weights`` to produce the (index, weight) nodes.
      mode: key into ``_INDEX_FIXERS`` choosing how node indices are mapped
        into range. With ``'constant'``, a node whose raw index lies outside
        ``[0, size)`` contributes ``cval`` instead of a value from ``input``.
      cval: fill value; converted to ``input.dtype``.

    Returns:
      The weighted sum over every combination of per-axis nodes, cast to
      ``input.dtype``. Integer inputs are passed through
      ``_round_half_away_from_zero`` before the cast.

    Raises:
      ValueError: if ``len(coordinates) != input.ndim``.
      NotImplementedError: if ``mode`` is not a key of ``_INDEX_FIXERS`` or
        ``order`` is neither 0 nor 1.
    """
    input = jnp.asarray(input)
    coordinates = [jnp.asarray(c) for c in coordinates]
    cval = jnp.asarray(cval, input.dtype)

    if len(coordinates) != input.ndim:
        raise ValueError(
            'coordinates must be a sequence of length input.ndim, but '
            '{} != {}'.format(len(coordinates), input.ndim))

    index_fixer = _INDEX_FIXERS.get(mode)
    if index_fixer is None:
        raise NotImplementedError(
            'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
            'Currently supported modes are {}.'.format(
                mode, set(_INDEX_FIXERS)))

    if mode == 'constant':
        def in_bounds(index, size):
            return (0 <= index) & (index < size)
    else:
        def in_bounds(index, size):
            # The literal ``True`` (not an array) is what enables the
            # ``is True`` fast path further down.
            return True

    if order == 0:
        make_nodes = _nearest_indices_and_weights
    elif order == 1:
        make_nodes = _linear_indices_and_weights
    else:
        raise NotImplementedError(
            'jax.scipy.ndimage.map_coordinates currently requires order<=1')

    # For each axis: a list of (fixed index, validity, weight) triples.
    per_axis_nodes = [
        [(index_fixer(node_index, extent), in_bounds(node_index, extent),
          node_weight)
         for node_index, node_weight in make_nodes(axis_coords)]
        for axis_coords, extent in zip(coordinates, input.shape)
    ]

    terms = []
    for combo in itertools.product(*per_axis_nodes):
        indices, validities, weights = zip(*combo)
        if all(flag is True for flag in validities):
            # Fast path: nothing can be out of bounds, so no masking needed.
            contribution = input[indices]
        else:
            mask = functools.reduce(operator.and_, validities)
            contribution = jnp.where(mask, input[indices], cval)
        terms.append(_nonempty_prod(weights) * contribution)

    result = _nonempty_sum(terms)
    if jnp.issubdtype(input.dtype, jnp.integer):
        result = _round_half_away_from_zero(result)
    return result.astype(input.dtype)
def _parse_input_dimensions(
    args: Tuple[NDArray, ...],
    input_core_dims: List[CoreDims],
    error_context: str = "",
) -> Tuple[Tuple[int, ...], Dict[str, int]]:
    """Split each input's shape into broadcast dims and named core dims.

    Args:
      args: tuple of input arguments to examine.
      input_core_dims: core dimensions for each input, in the same order as
        ``args``.
      error_context: text appended to error messages.

    Returns:
      broadcast_shape: the result of ``lax.broadcast_shapes`` applied to the
        leading (non-core) part of every argument's shape.
      dim_sizes: mapping from core dimension name to its size, accumulated
        across all arguments by ``_update_dim_sizes``.

    Raises:
      TypeError: if ``args`` and ``input_core_dims`` differ in length.
    """
    received, expected = len(args), len(input_core_dims)
    if received != expected:
        raise TypeError(
            'wrong number of positional arguments: expected %r, got %r %s'
            % (expected, received, error_context))

    dim_sizes: Dict[str, int] = {}
    leading_shapes = []
    for arg, core_dims in zip(args, input_core_dims):
        _update_dim_sizes(
            dim_sizes, arg.shape, core_dims, error_context, is_input=True)
        # Everything in front of the trailing core dims gets broadcast.
        num_leading = arg.ndim - len(core_dims)
        leading_shapes.append(arg.shape[:num_leading])

    broadcast_shape = lax.broadcast_shapes(*leading_shapes)
    # TODO(mattjj): this code needs updating for dynamic shapes (hence ignore)
    return broadcast_shape, dim_sizes  # type: ignore
def _update_dim_sizes(dim_sizes: Dict[str, int], shape: Tuple[int, ...], core_dims: CoreDims, error_context: str = "", *, is_input: bool): """Incrementally check and update core dimension sizes for a single argument. Args: dim_sizes: sizes of existing core dimensions. Will be updated in-place. shape: shape of this argument. core_dims: core dimensions for this argument. error_context: string context for error messages. is_input: are we parsing input or output arguments? """ num_core_dims = len(core_dims) if is_input: if len(shape) < num_core_dims: raise ValueError( 'input with shape %r does not have enough dimensions for all core ' 'dimensions %r %s' % (shape, core_dims, error_context)) else: if len(shape) != num_core_dims: raise ValueError( 'output shape %r does not match core dimensions %r %s' % (shape, core_dims, error_context)) core_shape = shape[-num_core_dims:] if core_dims else () for dim, size in zip(core_dims, core_shape): if dim not in dim_sizes: dim_sizes[dim] = size elif size != dim_sizes[dim]: raise ValueError( 'inconsistent size for core dimension %r: %r vs %r %s' % (dim, size, dim_sizes[dim], error_context))
def wrapped(*args):
    """Call ``func`` and validate the shapes of what it returns.

    ``func``, ``dim_sizes``, ``expected_output_core_dims`` and
    ``error_context`` are taken from the enclosing scope.

    Args:
      *args: positional arguments forwarded unchanged to ``func``.

    Returns:
      Whatever ``func`` returned, unmodified.

    Raises:
      TypeError: if several outputs are expected but ``func`` did not return
        a tuple, or if the number of outputs does not match
        ``expected_output_core_dims``.
      ValueError: raised by ``_update_dim_sizes`` when an output shape does
        not match its core dimensions or the known dimension sizes.
    """
    out = func(*args)
    outputs = out if isinstance(out, tuple) else [out]
    # Build a concrete list: its length is needed below, and a lazy iterator
    # (such as the one the builtin ``map`` returns) does not support len().
    out_shapes = [jnp.shape(value) for value in outputs]

    if expected_output_core_dims is None:
        # No signature was given: every output is treated as a scalar core.
        output_core_dims = [()] * len(out_shapes)
    else:
        output_core_dims = expected_output_core_dims
        if len(output_core_dims) > 1 and not isinstance(out, tuple):
            raise TypeError(
                "output must be a tuple when multiple outputs are expected, "
                "got: {!r}\n{}".format(out, error_context))
        if len(out_shapes) != len(output_core_dims):
            raise TypeError(
                'wrong number of output arguments: expected %r, got %r %s'
                % (len(output_core_dims), len(out_shapes), error_context))

    # Work on a copy so the enclosing ``dim_sizes`` is not mutated by the
    # in-place updates done while checking the outputs.
    sizes = dict(dim_sizes)
    for shape, core_dims in zip(out_shapes, output_core_dims):
        _update_dim_sizes(
            sizes, shape, core_dims, error_context, is_input=False)

    return out
def wrapped(*args):
    """Run ``pyfunc`` over ``args``, vectorizing the non-core dimensions.

    ``pyfunc``, ``excluded`` and ``signature`` are taken from the enclosing
    scope.

    Args:
      *args: positional arguments. After ``_apply_excluded`` has processed
        them, the remaining ones are converted with ``jnp.asarray``.

    Returns:
      The result of calling the shape-checked function, wrapped in as many
      ``api.vmap`` layers as the arguments require, on the broadcast
      arguments.
    """
    error_context = (
        "on vectorized function with excluded={!r} and "
        "signature={!r}".format(excluded, signature))
    excluded_func, args = _apply_excluded(pyfunc, excluded, args)
    args = tuple(map(jnp.asarray, args))

    if signature is None:
        input_core_dims = [()] * len(args)
        output_core_dims = None
    else:
        input_core_dims, output_core_dims = _parse_gufunc_signature(
            signature)

    broadcast_shape, dim_sizes = _parse_input_dimensions(
        args, input_core_dims, error_context)
    checked_func = _check_output_dims(
        excluded_func, dim_sizes, output_core_dims, error_context)

    # Instead of broadcasting every argument to the full broadcast shape,
    # prefer adding dimensions through vmap where possible. Pushing the
    # broadcasting into vmap lets primitives use more efficient batching
    # rules when only some arguments are batched (e.g., for
    # lax_linalg.triangular_solve).
    vec_args = []
    vmap_counts = []
    for arg, core_dims in zip(args, input_core_dims):
        # Explicitly broadcast the dimensions already present on each
        # argument, because these dimensions might be of size 1, which vmap
        # doesn't handle.
        # TODO(shoyer): Consider squeezing out size 1 dimensions instead, and
        # doing all vectorization with vmap? This *might* be a little more
        # efficient but would require more careful book-keeping.
        core_shape = tuple(dim_sizes[dim] for dim in core_dims)
        full_shape = broadcast_shape + core_shape
        if arg.ndim:
            vec_shape = full_shape[-arg.ndim:]
        else:
            vec_shape = ()

        vec_args.append(jnp.broadcast_to(arg, vec_shape))
        # Number of leading (non-core) dims this argument still has to be
        # mapped over.
        vmap_counts.append(len(vec_shape) - len(core_shape))

    vectorized_func = checked_func
    remaining = vmap_counts
    while any(remaining):
        # Map axis 0 of every argument that still has batch dims left; pass
        # the others through unbatched.
        in_axes = tuple(0 if count > 0 else None for count in remaining)
        vectorized_func = api.vmap(vectorized_func, in_axes)
        remaining = [max(count - 1, 0) for count in remaining]
    return vectorized_func(*vec_args)