Ejemplo n.º 1
0
def _map_coordinates(input, coordinates, order, mode, cval):
    """Sample ``input`` at ``coordinates`` with order-0 or order-1 interpolation.

    For every axis, ``interp_fun`` turns that axis' coordinate into a list of
    ``(index, weight)`` nodes. Each index is passed through the index fixer
    registered for ``mode``. Every combination of per-axis nodes is gathered
    from ``input``, scaled by the product of its weights, and the
    contributions are summed.

    Args:
      input: array-like to sample from.
      coordinates: sequence holding one array-like per axis of ``input``.
      order: ``0`` selects ``_nearest_indices_and_weights``; ``1`` selects
        ``_linear_indices_and_weights``.
      mode: key looked up in ``_INDEX_FIXERS``. With ``'constant'``, any node
        whose raw index lies outside ``[0, size)`` contributes ``cval``.
      cval: fill value, converted to ``input.dtype``.

    Returns:
      The interpolated values cast to ``input.dtype``. For integer dtypes the
      result is first passed through ``_round_half_away_from_zero``.

    Raises:
      ValueError: if ``len(coordinates) != input.ndim``.
      NotImplementedError: if ``mode`` is not in ``_INDEX_FIXERS`` or
        ``order`` is neither 0 nor 1.
    """
    input = jnp.asarray(input)
    coords = [jnp.asarray(c) for c in coordinates]
    cval = jnp.asarray(cval, input.dtype)

    if len(coords) != input.ndim:
        raise ValueError(
            'coordinates must be a sequence of length input.ndim, but '
            '{} != {}'.format(len(coords), input.ndim))

    index_fixer = _INDEX_FIXERS.get(mode)
    if index_fixer is None:
        supported = set(_INDEX_FIXERS)
        raise NotImplementedError(
            'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
            'Currently supported modes are {}.'.format(mode, supported))

    if mode == 'constant':
        def is_valid(index, size):
            # Out-of-range raw indices are masked and replaced by cval below.
            return (0 <= index) & (index < size)
    else:
        def is_valid(index, size):
            # The literal True lets the gather below skip the jnp.where mask.
            return True

    if order == 0:
        interp_fun = _nearest_indices_and_weights
    elif order == 1:
        interp_fun = _linear_indices_and_weights
    else:
        raise NotImplementedError(
            'jax.scipy.ndimage.map_coordinates currently requires order<=1')

    # One list per axis of (fixed_index, validity, weight) triples.
    per_axis_nodes = [
        [(index_fixer(index, size), is_valid(index, size), weight)
         for index, weight in interp_fun(coordinate)]
        for coordinate, size in zip(coords, input.shape)
    ]

    terms = []
    for combo in itertools.product(*per_axis_nodes):
        indices, validities, weights = zip(*combo)
        if all(flag is True for flag in validities):
            # Every node is statically in range: gather without masking.
            gathered = input[indices]
        else:
            mask = functools.reduce(operator.and_, validities)
            gathered = jnp.where(mask, input[indices], cval)
        terms.append(_nonempty_prod(weights) * gathered)

    result = _nonempty_sum(terms)
    if jnp.issubdtype(input.dtype, jnp.integer):
        result = _round_half_away_from_zero(result)
    return result.astype(input.dtype)
Ejemplo n.º 2
0
def _parse_input_dimensions(
    args: Tuple[NDArray, ...],
    input_core_dims: List[CoreDims],
    error_context: str = "",
) -> Tuple[Tuple[int, ...], Dict[str, int]]:
  """Work out the batch broadcast shape and the named core dimension sizes.

  Each argument's trailing dimensions are matched against its core dimension
  names via ``_update_dim_sizes``. Its remaining leading dimensions are
  collected and broadcast together with ``lax.broadcast_shapes``.

  Args:
    args: the positional arguments to inspect.
    input_core_dims: core dimension names for each argument, in order.
    error_context: text appended to error messages.

  Returns:
    A pair ``(broadcast_shape, dim_sizes)``. ``broadcast_shape`` is the
    broadcast of all non-core (leading) shapes. ``dim_sizes`` maps each core
    dimension name to its size.

  Raises:
    TypeError: if ``len(args) != len(input_core_dims)``.
  """
  expected_count = len(input_core_dims)
  if len(args) != expected_count:
    raise TypeError(
        'wrong number of positional arguments: expected %r, got %r %s'
        % (expected_count, len(args), error_context))

  dim_sizes: Dict[str, int] = {}
  batch_shapes = []
  for operand, core_dims in zip(args, input_core_dims):
    _update_dim_sizes(dim_sizes, operand.shape, core_dims, error_context,
                      is_input=True)
    # Everything in front of the core dimensions is a batch dimension.
    batch_rank = operand.ndim - len(core_dims)
    batch_shapes.append(operand.shape[:batch_rank])

  broadcast_shape = lax.broadcast_shapes(*batch_shapes)
  # TODO(mattjj): this code needs updating for dynamic shapes (hence ignore)
  return broadcast_shape, dim_sizes  # type: ignore
Ejemplo n.º 3
0
def _update_dim_sizes(dim_sizes: Dict[str, int],
                      shape: Tuple[int, ...],
                      core_dims: CoreDims,
                      error_context: str = "",
                      *,
                      is_input: bool):
    """Incrementally check and update core dimension sizes for a single argument.

  Args:
    dim_sizes: sizes of existing core dimensions. Will be updated in-place.
    shape: shape of this argument.
    core_dims: core dimensions for this argument.
    error_context: string context for error messages.
    is_input: are we parsing input or output arguments?
  """
    num_core_dims = len(core_dims)
    if is_input:
        if len(shape) < num_core_dims:
            raise ValueError(
                'input with shape %r does not have enough dimensions for all core '
                'dimensions %r %s' % (shape, core_dims, error_context))
    else:
        if len(shape) != num_core_dims:
            raise ValueError(
                'output shape %r does not match core dimensions %r %s' %
                (shape, core_dims, error_context))

    core_shape = shape[-num_core_dims:] if core_dims else ()
    for dim, size in zip(core_dims, core_shape):
        if dim not in dim_sizes:
            dim_sizes[dim] = size
        elif size != dim_sizes[dim]:
            raise ValueError(
                'inconsistent size for core dimension %r: %r vs %r %s' %
                (dim, size, dim_sizes[dim], error_context))
Ejemplo n.º 4
0
    def wrapped(*args):
        """Call ``func`` and validate the core dimensions of its outputs.

        A non-tuple result is treated as a single output. When
        ``expected_output_core_dims`` is None, every output is checked
        against empty core dimensions. Otherwise the number of outputs must
        match, and multiple expected outputs require a tuple result. Each
        output shape is checked with ``_update_dim_sizes(..., is_input=False)``.

        Returns:
          ``func(*args)`` unchanged.

        Raises:
          TypeError: if the output count or container type is wrong.
          ValueError: propagated from ``_update_dim_sizes`` on a shape mismatch.
        """
        out = func(*args)
        outputs = out if isinstance(out, tuple) else [out]
        # Build a real list: len() is taken below, and the builtin ``map``
        # returns a lazy iterator that has no length.
        out_shapes = [jnp.shape(o) for o in outputs]

        if expected_output_core_dims is None:
            output_core_dims = [()] * len(out_shapes)
        else:
            output_core_dims = expected_output_core_dims
            if len(output_core_dims) > 1 and not isinstance(out, tuple):
                raise TypeError(
                    "output must be a tuple when multiple outputs are expected, "
                    "got: {!r}\n{}".format(out, error_context))
            if len(out_shapes) != len(output_core_dims):
                raise TypeError(
                    'wrong number of output arguments: expected %r, got %r %s'
                    % (len(output_core_dims), len(out_shapes), error_context))

        # Check against a copy, so that sizes first seen on outputs are not
        # written back into ``dim_sizes`` itself.
        sizes = dict(dim_sizes)
        for shape, core_dims in zip(out_shapes, output_core_dims):
            _update_dim_sizes(sizes,
                              shape,
                              core_dims,
                              error_context,
                              is_input=False)

        return out
Ejemplo n.º 5
0
    def wrapped(*args):
        """Vectorized entry point built around ``pyfunc``.

        The arguments returned by ``_apply_excluded`` are converted with
        ``jnp.asarray``. They are then checked against the core dimensions
        parsed from ``signature``; with no signature, every argument has no
        core dimensions and outputs are not given expected core dimensions.
        Each argument is broadcast over its part of the common batch shape,
        and the output-checking function is mapped over the batch dimensions
        with nested ``api.vmap`` calls.
        """
        error_context = ("on vectorized function with excluded={!r} and "
                         "signature={!r}".format(excluded, signature))
        excluded_func, args = _apply_excluded(pyfunc, excluded, args)
        args = tuple(jnp.asarray(a) for a in args)

        if signature is None:
            input_core_dims = [()] * len(args)
            output_core_dims = None
        else:
            input_core_dims, output_core_dims = _parse_gufunc_signature(
                signature)

        broadcast_shape, dim_sizes = _parse_input_dimensions(
            args, input_core_dims, error_context)
        checked_func = _check_output_dims(
            excluded_func, dim_sizes, output_core_dims, error_context)

        # Rather than broadcasting every argument to the full broadcast shape,
        # leave batch dimensions to vmap where possible. Batching rules for
        # primitives where only some operands are batched (e.g.
        # lax_linalg.triangular_solve) can then be more efficient.
        #
        # Dimensions an argument already has are still broadcast explicitly,
        # because they might be of size 1, which vmap doesn't handle.
        # TODO(shoyer): Consider squeezing out size 1 dimensions instead, and
        # doing all vectorization with vmap? This *might* be a little more
        # efficient but would require more careful book-keeping.
        vec_args = []
        vmap_counts = []
        for arg, core_dims in zip(args, input_core_dims):
            core_shape = tuple(dim_sizes[name] for name in core_dims)
            full_shape = broadcast_shape + core_shape
            own_shape = full_shape[-arg.ndim:] if arg.ndim else ()
            vec_args.append(jnp.broadcast_to(arg, own_shape))
            # Number of leading batch dimensions this argument carries.
            vmap_counts.append(len(own_shape) - len(core_shape))

        vectorized_func = checked_func
        pending = vmap_counts
        while any(pending):
            # Arguments with batch dimensions left are mapped over axis 0;
            # the rest are passed through unbatched.
            in_axes = tuple(0 if count > 0 else None for count in pending)
            pending = [count - 1 if count > 0 else 0 for count in pending]
            vectorized_func = api.vmap(vectorized_func, in_axes)
        return vectorized_func(*vec_args)