Ejemplo n.º 1
0
def cast(pda: Union[pdarray, Strings],
         dt: Union[np.dtype, str]) -> Union[pdarray, Strings]:
    """
    Cast an array to another dtype.

    Parameters
    ----------
    pda : pdarray or Strings
        The array of values to cast
    dt : np.dtype or str
        The target dtype to cast values to

    Returns
    -------
    pdarray or Strings
        Array of values cast to desired dtype

    Raises
    ------
    TypeError
        Raised if pda is neither a pdarray nor a Strings object

    Notes
    -----
    The cast is performed according to Chapel's casting rules and is NOT safe
    from overflows or underflows. The user must ensure that the target dtype
    has the precision and capacity to hold the desired result.

    Examples
    --------
    >>> ak.cast(ak.linspace(1.0,5.0,5), dt=ak.int64)
    array([1, 2, 3, 4, 5])

    >>> ak.cast(ak.arange(0,5), dt=ak.float64).dtype
    dtype('float64')

    >>> ak.cast(ak.arange(0,5), dt=ak.bool)
    array([False, True, True, True, True])

    >>> ak.cast(ak.linspace(0,4,5), dt=ak.bool)
    array([False, True, True, True, True])
    """
    if isinstance(pda, pdarray):
        name = pda.name
        objtype = "pdarray"
    elif isinstance(pda, Strings):
        # A Strings object is identified in the message by the names of its
        # offsets and bytes arrays, joined with "+".
        name = '+'.join((pda.offsets.name, pda.bytes.name))
        objtype = "str"
    else:
        # A typechecked decorator is expected to reject other types before we
        # get here; if it is bypassed, fail with a clear error instead of an
        # UnboundLocalError on `name` below.
        raise TypeError(
            "pda must be a pdarray or Strings, not {}".format(
                type(pda).__name__))

    dt = _as_dtype(dt)
    opt = ""
    cmd = "cast"
    args = "{} {} {} {}".format(name, objtype, dt.name, opt)
    repMsg = generic_msg(cmd=cmd, args=args)
    if dt.name.startswith("str"):
        # String results come back as two "+"-joined array names
        # (offsets, bytes), mirroring how the input was encoded above.
        return Strings(*(type_cast(str, repMsg).split("+")))
    return create_pdarray(type_cast(str, repMsg))
Ejemplo n.º 2
0
def hash(pda: pdarray,
         full: bool = True) -> Union[Tuple[pdarray, pdarray], pdarray]:
    """
    Compute an element-wise hash of the array.

    Parameters
    ----------
    pda : pdarray
        The values to hash.
    full : bool
        When True (the default), compute a 128-bit hash and return it as
        two int64 arrays. When False, compute a 64-bit hash and return it
        as a single int64 array.

    Returns
    -------
    hashes
        With full=True, a 2-tuple of pdarrays holding the high and low
        64 bits of each hash, in that order.
        With full=False, a single pdarray holding a 64-bit hash.

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray

    Notes
    -----
    The SIPhash algorithm is used, which can produce either a 64-bit or a
    128-bit hash. The 64-bit variant carries a significant collision risk
    once more than a few million unique values are hashed, so the 128-bit
    variant is strongly recommended unless the number of unique values is
    known to be small.

    This hash must not be used for security or any cryptographic purpose.
    SIPhash is not designed for such uses, and this implementation uses a
    fixed key, which lets an adversary who controls the input engineer
    collisions.
    """
    subcmd = "hash128" if full else "hash64"
    reply = type_cast(
        str, generic_msg(cmd="efunc", args=f"{subcmd} {pda.name}"))
    if not full:
        return create_pdarray(reply)
    # The 128-bit reply names two arrays joined by "+": high bits, low bits.
    high, low = reply.split('+')
    return create_pdarray(high), create_pdarray(low)
Ejemplo n.º 3
0
def abs(pda: pdarray) -> pdarray:
    """
    Compute the absolute value of each element of the array.

    Parameters
    ----------
    pda : pdarray
        The input values.

    Returns
    -------
    pdarray
        A pdarray holding the absolute value of every input element

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray

    Examples
    --------
    >>> ak.abs(ak.arange(-5,-1))
    array([5, 4, 3, 2])

    >>> ak.abs(ak.linspace(-5,-1,5))
    array([5, 4, 3, 2, 1])
    """
    reply = generic_msg(cmd="efunc", args=f"abs {pda.name}")
    return create_pdarray(type_cast(str, reply))
Ejemplo n.º 4
0
def exp(pda: pdarray) -> pdarray:
    """
    Compute e raised to the power of each element of the array.

    Parameters
    ----------
    pda : pdarray
        The exponents.

    Returns
    -------
    pdarray
        A pdarray holding the exponential of every input element

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray

    Examples
    --------
    >>> ak.exp(ak.arange(1,5))
    array([2.7182818284590451, 7.3890560989306504, 20.085536923187668, 54.598150033144236])

    >>> ak.exp(ak.uniform(5,1.0,5.0))
    array([11.84010843172504, 46.454368507659211, 5.5571769623557188,
           33.494295836924771, 13.478894913238722])
    """
    reply = generic_msg(cmd="efunc", args=f"exp {pda.name}")
    return create_pdarray(type_cast(str, reply))
Ejemplo n.º 5
0
def cumprod(pda: pdarray) -> pdarray:
    """
    Return the cumulative product over the array.

    The product is inclusive, such that the ``i`` th element of the
    result is the product of elements up to and including ``i``.

    Parameters
    ----------
    pda : pdarray
        The values to multiply cumulatively.

    Returns
    -------
    pdarray
        A pdarray containing cumulative products for each element
        of the original pdarray

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray

    Examples
    --------
    >>> ak.cumprod(ak.arange(1,5))
    array([1, 2, 6, 24])

    >>> ak.cumprod(ak.uniform(5,1.0,5.0))
    array([1.5728783400481925, 7.0472855509390593, 33.78523998586553,
           134.05309592737584, 450.21589865655358])
    """
    repMsg = generic_msg(cmd="efunc", args="{} {}".format("cumprod", pda.name))
    return create_pdarray(type_cast(str, repMsg))
Ejemplo n.º 6
0
    def get_alt_field(self, var_info: Union[Variable, FieldInfo],
                      alt_prefix: str) -> Variable:
        """Return the counterpart of an element of `ScanArgs` in another field.

        Given an element stored in some ``<prefix>_<kind>`` field (for example
        ``ScanArgs.outer_out_sit_sot``), look up the element at the same index
        in the ``<alt_prefix>_<kind>`` field. So if `var_info` lives in
        ``ScanArgs.outer_out_sit_sot``, then
        ``get_alt_field(var_info, "inner_out")`` returns the element
        corresponding to `var_info` in ``ScanArgs.inner_out_sit_sot``.

        Parameters
        ----------
        var_info:
            The element whose counterpart is wanted, either as a variable
            (located via ``find_among_fields``) or as an already-resolved
            ``FieldInfo``.
        alt_prefix:
            The prefix of the counterpart field. It can be one of
            ``"inner_out"``, ``"inner_in"``, ``"outer_in"``, and
            ``"outer_out"``.

        Returns
        -------
        The alternate variable.

        Raises
        ------
        ValueError
            If `var_info` is a variable that cannot be found among the fields.
        """
        if isinstance(var_info, FieldInfo):
            field_info: FieldInfo = var_info
        else:
            located = self.find_among_fields(var_info)
            if located is None:
                raise ValueError(f"Couldn't find {var_info} among fields")
            field_info = located

        # Every prefix begins with a five-letter word ("inner"/"outer"), so
        # searching from index 6 skips that first underscore and finds the one
        # that ends the prefix; the rest of the name is the field kind.
        kind_start = field_info.name.index("_", 6) + 1
        alt_kind = field_info.name[kind_start:]
        alt_field = getattr(self, f"{alt_prefix}_{alt_kind}")
        return type_cast(Variable, alt_field[field_info.index])
Ejemplo n.º 7
0
def _cast_tile_layer(raw_layer: RawLayer) -> TileLayer:
    """Build a TileLayer from a raw_layer.

    Chunked tile data and plain tile data are both converted when present.
    Either is decoded with the layer's "encoding"/"compression" settings
    when an encoding is given, and converted directly otherwise.

    Args:
        raw_layer: RawLayer to convert into a TileLayer

    Returns:
        TileLayer: The TileLayer built from raw_layer
    """
    tile_layer = TileLayer(**_get_common_attributes(raw_layer).__dict__)
    encoding = raw_layer.get("encoding")

    raw_chunks = raw_layer.get("chunks")
    if raw_chunks is not None:
        if encoding is not None:
            tile_layer.chunks = [
                _cast_chunk(chunk, encoding, raw_layer["compression"])
                for chunk in raw_chunks
            ]
        else:
            tile_layer.chunks = [_cast_chunk(chunk) for chunk in raw_chunks]

    raw_data = raw_layer.get("data")
    if raw_data is not None:
        if encoding is not None:
            # Encoded data is a string that must be decoded (and possibly
            # decompressed) into tiles.
            tile_layer.data = _decode_tile_layer_data(
                data=type_cast(str, raw_data),
                compression=raw_layer["compression"],
                layer_width=raw_layer["width"],
            )
        else:
            tile_layer.data = _convert_raw_tile_layer_data(
                raw_data,
                raw_layer["width"]  # type: ignore
            )

    return tile_layer
Ejemplo n.º 8
0
def histogram(pda: pdarray, bins: int_scalars = 10) -> pdarray:
    """
    Count the values of an array in evenly spaced bins over its range.

    Parameters
    ----------
    pda : pdarray
        The values to histogram

    bins : int_scalars
        How many equal-width bins to use (default: 10)

    Returns
    -------
    pdarray, int64 or float64
        The number of values falling in each bin

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray or if bins is
        not an int.
    ValueError
        Raised if bins < 1
    NotImplementedError
        Raised if pdarray dtype is bool or uint8

    See Also
    --------
    value_counts

    Notes
    -----
    The bins evenly divide the interval [pda.min(), pda.max()].
    The bin edges are not returned; to plot the histogram, recompute them,
    e.g. with np.linspace as shown below.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> A = ak.arange(0, 10, 1)
    >>> nbins = 3
    >>> h = ak.histogram(A, bins=nbins)
    >>> h
    array([3, 3, 4])
    # Recreate the bin edges in NumPy
    >>> binEdges = np.linspace(A.min(), A.max(), nbins+1)
    >>> binEdges
    array([0., 3., 6., 9.])
    # To plot, use only the left edges, and export the histogram to NumPy
    >>> plt.plot(binEdges[:-1], h.to_ndarray())
    """
    if bins < 1:
        raise ValueError('bins must be 1 or greater')
    reply = generic_msg(cmd="histogram", args=f"{pda.name} {bins}")
    return create_pdarray(type_cast(str, reply))
Ejemplo n.º 9
0
def cos(pda: pdarray) -> pdarray:
    """
    Compute the cosine of every element of the array.

    Parameters
    ----------
    pda : pdarray
        The input angles.

    Returns
    -------
    pdarray
        A pdarray holding the cosine of each element
        of the original pdarray

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray
    """
    message = "efunc {} {}".format("cos", pda.name)
    return create_pdarray(type_cast(str, generic_msg(message)))
Ejemplo n.º 10
0
def cast(pda: Union[pdarray, Strings], dt) -> Union[pdarray, Strings]:
    """
    Cast an array to another dtype.

    Parameters
    ----------
    pda : pdarray or Strings
        The array of values to cast
    dt : np.dtype or str
        The target dtype to cast values to

    Returns
    -------
    pdarray or Strings
        Array of values cast to desired dtype

    Raises
    ------
    TypeError
        Raised if pda is neither a pdarray nor a Strings object

    Notes
    -----
    The cast is performed according to Chapel's casting rules and is NOT safe
    from overflows or underflows. The user must ensure that the target dtype
    has the precision and capacity to hold the desired result.
    """
    if isinstance(pda, pdarray):
        name = pda.name
        objtype = "pdarray"
    elif isinstance(pda, Strings):
        # A Strings object is identified in the message by the names of its
        # offsets and bytes arrays, joined with "+".
        name = '+'.join((pda.offsets.name, pda.bytes.name))
        objtype = "str"
    else:
        # A typechecked decorator is expected to reject other types before we
        # get here; if it is bypassed, fail with a clear error instead of an
        # UnboundLocalError on `name` below.
        raise TypeError(
            "pda must be a pdarray or Strings, not {}".format(
                type(pda).__name__))

    dt = _as_dtype(dt)
    opt = ""
    msg = "cast {} {} {} {}".format(name, objtype, dt.name, opt)
    repMsg = generic_msg(msg)
    if dt.name.startswith("str"):
        # String results come back as two "+"-joined array names
        # (offsets, bytes), mirroring how the input was encoded above.
        return Strings(*(type_cast(str, repMsg).split("+")))
    return create_pdarray(type_cast(str, repMsg))
Ejemplo n.º 11
0
def cast(raw_properties: List[RawProperty]) -> Properties:
    """Convert a list of `RawProperty`s into a `Properties` mapping.

    "file" properties become `Path` objects and "color" properties are
    parsed with `parse_color`; every other property keeps its raw value.

    Args:
        raw_properties: The `RawProperty`s to convert.

    Returns:
        Properties: A mapping from property name to converted value.
    """
    final: Properties = {}

    for raw in raw_properties:
        kind = raw["type"]
        if kind == "file":
            converted: Property = Path(type_cast(str, raw["value"]))
        elif kind == "color":
            converted = parse_color(type_cast(str, raw["value"]))
        else:
            converted = raw["value"]
        final[raw["name"]] = converted

    return final
Ejemplo n.º 12
0
def abs(pda: pdarray) -> pdarray:
    """
    Compute the absolute value of each element of the array.

    Parameters
    ----------
    pda : pdarray
        The input values.

    Returns
    -------
    pdarray
        A pdarray holding the absolute value of every input element

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray
    """
    message = "efunc {} {}".format("abs", pda.name)
    reply = generic_msg(message)
    return create_pdarray(type_cast(str, reply))
Ejemplo n.º 13
0
def sin(pda: pdarray) -> pdarray:
    """
    Compute the sine of every element of the array.

    Parameters
    ----------
    pda : pdarray
        The input angles.

    Returns
    -------
    pdarray
        A pdarray holding the sine of each element
        of the original pdarray

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray
    """
    reply = generic_msg(cmd="efunc", args=f"sin {pda.name}")
    return create_pdarray(type_cast(str, reply))
Ejemplo n.º 14
0
def isnan(pda: pdarray) -> pdarray:
    """
    Test each element of a pdarray for Not a Number (NaN).

    Only float-valued arrays are currently supported.

    Parameters
    ----------
    pda : pdarray
        The array to test.

    Returns
    -------
    pdarray
        True / False values: True where the element is NaN, False otherwise.

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray
    RuntimeError
        Raised if the underlying pdarray is not float-based
    """
    args = "isnan " + f"{pda.name}"
    return create_pdarray(type_cast(str, generic_msg(cmd="efunc", args=args)))
Ejemplo n.º 15
0
def log(pda: pdarray) -> pdarray:
    """
    Return the element-wise natural log of the array.

    Parameters
    ----------
    pda : pdarray
        The input values.

    Returns
    -------
    pdarray
        A pdarray containing natural log values of the input
        array elements

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray

    Notes
    -----
    Logarithms with other bases can be computed by dividing the natural
    log by the natural log of the desired base, as shown in the examples
    below.

    Examples
    --------
    >>> A = ak.array([1, 10, 100])
    # Natural log
    >>> ak.log(A)
    array([0, 2.3025850929940459, 4.6051701859880918])
    # Log base 10
    >>> ak.log(A) / np.log(10)
    array([0, 1, 2])
    # Log base 2
    >>> ak.log(A) / np.log(2)
    array([0, 3.3219280948873626, 6.6438561897747253])
    """
    repMsg = generic_msg(cmd="efunc", args="{} {}".format("log", pda.name))
    return create_pdarray(type_cast(str, repMsg))
Ejemplo n.º 16
0
def cumsum(pda: pdarray) -> pdarray:
    """
    Return the cumulative sum over the array.

    The sum is inclusive, such that the ``i`` th element of the
    result is the sum of elements up to and including ``i``.

    Parameters
    ----------
    pda : pdarray
        The values to sum cumulatively.

    Returns
    -------
    pdarray
        A pdarray containing cumulative sums for each element
        of the original pdarray

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray

    Examples
    --------
    >>> ak.cumsum(ak.arange(1,5))
    array([1, 3, 6, 10])

    >>> ak.cumsum(ak.uniform(5,1.0,5.0))
    array([3.1598310770203937, 5.4110385860243131, 9.1622479306453748,
           12.710615785506533, 13.945880905466208])

    >>> ak.cumsum(ak.randint(0, 1, 5, dtype=ak.bool))
    array([0, 1, 1, 2, 3])
    """
    repMsg = generic_msg(cmd="efunc", args="{} {}".format("cumsum", pda.name))
    return create_pdarray(type_cast(str, repMsg))
Ejemplo n.º 17
0
def cumprod(pda: pdarray) -> pdarray:
    """
    Compute the running (inclusive) product over the array.

    Element ``i`` of the result is the product of all input elements from
    the start of the array up to and including element ``i``.

    Parameters
    ----------
    pda : pdarray
        The values to multiply cumulatively.

    Returns
    -------
    pdarray
        A pdarray holding the running product at each position
        of the original pdarray

    Raises
    ------
    TypeError
        Raised if the parameter is not a pdarray
    """
    message = "efunc {} {}".format("cumprod", pda.name)
    reply = generic_msg(message)
    return create_pdarray(type_cast(str, reply))
Ejemplo n.º 18
0
def where(condition: pdarray, A: Union[numeric_scalars, pdarray],
          B: Union[numeric_scalars, pdarray]) -> pdarray:
    """
    Returns an array with elements chosen from A and B based upon a
    conditioning array. As is the case with numpy.where, the return array
    consists of values from the first array (A) where the conditioning array
    elements are True and from the second array (B) where the conditioning
    array elements are False.

    Parameters
    ----------
    condition : pdarray
        Used to choose values from A or B
    A : Union[numeric_scalars, pdarray]
        Value(s) used when condition is True
    B : Union[numeric_scalars, pdarray]
        Value(s) used when condition is False

    Returns
    -------
    pdarray
        Values chosen from A where the condition is True and B where
        the condition is False

    Raises
    ------
    TypeError
        Raised if the condition object is not a pdarray, if A or B is not
        an int, np.int64, float, np.float64, or pdarray, if pdarray dtypes
        are not supported or do not match, if scalar A and B cannot be cast
        to a common supported dtype, or if multiple condition clauses (see
        Notes section) are applied
    ValueError
        Raised if the shapes of the condition, A, and B pdarrays are unequal

    Notes
    -----
    A and B must have the same dtype. Only one conditional clause is
    supported: a condition combining several clauses (e.g., ``n < 5`` and
    ``n > 1``), which is supported in NumPy, is not currently supported in
    Arkouda.

    Examples
    --------
    >>> a1 = ak.arange(1,10)
    >>> a2 = ak.ones(9, dtype=np.int64)
    >>> cond = a1 < 5
    >>> ak.where(cond,a1,a2)
    array([1, 2, 3, 4, 1, 1, 1, 1, 1])

    >>> a1 = ak.arange(1,10)
    >>> a2 = ak.ones(9, dtype=np.int64)
    >>> cond = a1 == 5
    >>> ak.where(cond,a1,a2)
    array([1, 1, 1, 1, 5, 1, 1, 1, 1])

    >>> a1 = ak.arange(1,10)
    >>> a2 = 10
    >>> cond = a1 < 5
    >>> ak.where(cond,a1,a2)
    array([1, 2, 3, 4, 10, 10, 10, 10, 10])
    """
    if (not isSupportedNumber(A) and not isinstance(A, pdarray)) or \
            (not isSupportedNumber(B) and not isinstance(B, pdarray)):
        raise TypeError(
            'both A and B must be an int, np.int64, float, np.float64, or pdarray'
        )
    if isinstance(A, pdarray) and isinstance(B, pdarray):
        repMsg = generic_msg(cmd="efunc3vv", args="{} {} {} {}".format(
            "where", condition.name, A.name, B.name))
    # For scalars, try to convert it to the array's dtype
    elif isinstance(A, pdarray) and np.isscalar(B):
        repMsg = generic_msg(cmd="efunc3vs", args="{} {} {} {} {}".format(
            "where", condition.name, A.name, A.dtype.name,
            A.format_other(B)))
    elif isinstance(B, pdarray) and np.isscalar(A):
        repMsg = generic_msg(cmd="efunc3sv", args="{} {} {} {} {}".format(
            "where", condition.name, B.dtype.name, B.format_other(A),
            B.name))
    elif np.isscalar(A) and np.isscalar(B):
        # Scalars must share a common dtype (or be cast)
        dtA = resolve_scalar_dtype(A)
        dtB = resolve_scalar_dtype(B)
        # Make sure at least one of the dtypes is supported
        if not (dtA in DTypes or dtB in DTypes):
            raise TypeError(
                ("Not implemented for scalar types {} " + "and {}").format(
                    dtA, dtB))
        # If the dtypes are the same, do not cast
        if dtA == dtB:  # type: ignore
            dt = dtA
        # If the dtypes are different, try casting one direction then the other.
        # NOTE(review): NumPy >= 2.0 no longer accepts Python int/float scalars
        # in np.can_cast; confirm against the supported NumPy versions.
        elif dtB in DTypes and np.can_cast(A, dtB):
            A = np.dtype(dtB).type(A)
            dt = dtB
        elif dtA in DTypes and np.can_cast(B, dtA):
            B = np.dtype(dtA).type(B)
            dt = dtA
        # Cannot safely cast
        else:
            raise TypeError(("Cannot cast between scalars {} and {} to " +
                             "supported dtype").format(A, B))
        repMsg = generic_msg(cmd="efunc3ss", args="{} {} {} {} {} {}".format(
            "where", condition.name, dt, A, dt, B))
    else:
        # Defensive: an input accepted by isSupportedNumber that NumPy does not
        # treat as a scalar would otherwise leave repMsg unbound below.
        raise TypeError(
            "unsupported combination of A ({}) and B ({})".format(
                type(A).__name__, type(B).__name__))
    return create_pdarray(type_cast(str, repMsg))
Ejemplo n.º 19
0
def conv_general_dilated_local(
        lhs: jnp.ndarray,
        rhs: jnp.ndarray,
        window_strides: Sequence[int],
        padding: Union[str, Sequence[Tuple[int, int]]],
        filter_shape: Sequence[int],
        lhs_dilation: Optional[Sequence[int]] = None,
        rhs_dilation: Optional[Sequence[int]] = None,
        dimension_numbers: Optional[
            convolution.ConvGeneralDilatedDimensionNumbers] = None,
        precision: lax.PrecisionLike = None) -> jnp.ndarray:
    """General n-dimensional unshared convolution operator with optional dilation.

  Also known as locally connected layer, the operation is equivalent to
  convolution with a separate (unshared) `rhs` kernel used at each output
  spatial location. Docstring below adapted from `jax.lax.conv_general_dilated`.

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights. Unlike in regular
      CNNs, its spatial coordinates (`H`, `W`, ...) correspond to output spatial
      locations, while input spatial locations are fused with the input channel
      locations in the single `I` dimension, in the order of
      `"C" + ''.join(c for c in rhs_spec if c not in 'OI')`, where
      `rhs_spec = dimension_numbers[1]`. For example, if `rhs_spec == "WHIO"`,
      the unfolded kernel shape is
      `"[output W][output H]{I[receptive window W][receptive window H]}O"`.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    filter_shape: a sequence of `n` integers, representing the receptive window
      spatial shape in the order as specified in
      `rhs_spec = dimension_numbers[1]`.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each input spatial dimension of `rhs`.
      RHS dilation is also known as atrous convolution.
    dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
      a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
      of length `n+2`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.

  Returns:
    An array containing the unshared convolution result.

  In the string case of `dimension_numbers`, each character identifies by
  position:

  - the batch dimensions in `lhs`, `rhs`, and the output with the character
    'N',
  - the feature dimensions in `lhs` and the output with the character 'C',
  - the input and output feature dimensions in rhs with the characters 'I'
    and 'O' respectively, and
  - spatial dimension correspondences between `lhs`, `rhs`, and the output using
    any distinct characters.

  For example, to indicate dimension numbers consistent with the `conv` function
  with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
  another example, to indicate dimension numbers consistent with the TensorFlow
  Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
  latter form of convolution dimension specification, window strides are
  associated with spatial dimension character labels according to the order in
  which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
  is matched with the dimension corresponding to the first character
  appearing in rhs_spec that is not `'I'` or `'O'`.

  If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
  (for a 2D convolution).
  """
    # Patch extraction only touches `lhs`, so when a (lhs, rhs) precision pair
    # is given, only its first entry is used for that step; the full
    # `precision` is still passed to the final dot_general below.
    c_precision = lax.canonicalize_precision(precision)
    lhs_precision = type_cast(Optional[lax.PrecisionType],
                              (c_precision[0] if
                               (isinstance(c_precision, tuple)
                                and len(c_precision) == 2) else c_precision))

    # Unfold every receptive window of `lhs` into the feature dimension.
    patches = conv_general_dilated_patches(lhs=lhs,
                                           filter_shape=filter_shape,
                                           window_strides=window_strides,
                                           padding=padding,
                                           lhs_dilation=lhs_dilation,
                                           rhs_dilation=rhs_dilation,
                                           dimension_numbers=dimension_numbers,
                                           precision=lhs_precision)

    # Resolve the (lhs, rhs, out) specs as dimension indices, using a
    # placeholder kernel shape of (1, 1) + filter_shape for the resolution.
    lhs_spec, rhs_spec, out_spec = convolution.conv_dimension_numbers(
        lhs.shape, (1, 1) + tuple(filter_shape), dimension_numbers)

    # Contract the patches' feature dimension with the kernel's input
    # ("I") dimension.
    lhs_c_dims, rhs_c_dims = [out_spec[1]], [rhs_spec[1]]

    # Spatial dimensions act as batch dimensions: each output location is
    # paired with its own (unshared) slice of the kernel.
    lhs_b_dims = out_spec[2:]
    rhs_b_dims = rhs_spec[2:]

    # Reorder the rhs batch dims into the ascending order of their matching
    # lhs batch dims, then sort the lhs dims, so corresponding pairs still
    # line up position by position.
    rhs_b_dims = [
        rhs_b_dims[i]
        for i in sorted(range(len(rhs_b_dims)), key=lambda k: lhs_b_dims[k])
    ]
    lhs_b_dims = sorted(lhs_b_dims)

    dn = ((lhs_c_dims, rhs_c_dims), (lhs_b_dims, rhs_b_dims))
    out = lax.dot_general(patches,
                          rhs,
                          dimension_numbers=dn,
                          precision=precision)
    # dot_general places batch (spatial) dims first, followed by the free
    # patches dim (out_spec[0]) and the free kernel dim ("O"); move those
    # last two axes to the output's batch and feature positions.
    out = jnp.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1]))
    return out