Esempio n. 1
0
def _cofactor_solve(a, b):
    """Equivalent to det(a)*solve(a, b) for nonsingular mat.

    Intermediate function used for jvp and vjp of det.
    This function borrows heavily from jax.numpy.linalg.solve and
    jax.numpy.linalg.slogdet to compute the gradient of the determinant
    in a way that is well defined even for low rank matrices.

    This function handles two different cases:
    * rank(a) == n or n-1
    * rank(a) < n-1

    For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
    Rather than computing det(a)*solve(a, b), which would return NaN, we work
    directly with the LU decomposition. If a = p @ l @ u, then
    det(a) = sign(p) * prod(diag(u)) and
    det(a)*solve(a, b) =
    det(a) * u^-1 @ l^-1 @ p^-1 b =
    det(a) * triangular_solve(u, solve(p @ l, b))
    If a is rank n-1, then the lower right corner of u will be zero and the
    triangular_solve will fail.
    Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
    Then y_{n} =
    det(a) * x_{n} / u_{nn} =
    sign(p) * x_{n} * prod_{i=1...n-1}(u_{ii})
    So by replacing the lower-right corner of u with
    (sign(p) * prod_{i=1...n-1}(u_{ii}))^-1
    we can avoid the triangular_solve failing.
    To correctly compute the rest of y_{i} for i != n, we simply multiply
    x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

    For the second case, a check is done on the matrix to see if `solve`
    returns NaN or Inf, and gives a matrix of zeros as a result, as the
    gradient of the determinant of a matrix with rank less than n-1 is 0.
    This will still return the correct value for rank n-1 matrices, as the check
    is applied *after* the lower right corner of u has been updated.

    Args:
      a: A square matrix or batch of matrices, possibly singular.
      b: A matrix, or batch of matrices of the same dimension as a.

    Returns:
      det(a) and cofactor(a)^T*b, aka adjugate(a)*b

    Raises:
      ValueError: if `a` is not a (batch of) square matrices or the trailing
        two dimensions of `b` differ from those of `a`.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    a_shape = jnp.shape(a)
    b_shape = jnp.shape(b)
    a_ndims = len(a_shape)
    if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
            and b_shape[-2:] == a_shape[-2:]):
        msg = ("The arguments to _cofactor_solve must have shapes "
               "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(msg.format(a_shape, b_shape))
    if a_shape[-1] == 1:
        # 1x1 case: det(a) is the single entry and the adjugate is [[1]],
        # so adjugate(a) @ b is just b.
        return a[..., 0, 0], b
    # lu contains u in the upper triangular matrix and l in the strict lower
    # triangular matrix.
    # The diagonal of l is set to ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    # Broadcast the batch dimensions of the factorization and of b together.
    batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
    # Compute (partial) determinant, ignoring last diagonal of LU
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    # Every step i with pivots[i] != i is a row swap, and each swap flips the
    # sign of the determinant.
    parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
    sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
    # partial_det[..., -1] contains the full determinant and
    # partial_det[..., -2] contains sign * det(u) / u_{nn}.
    partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
    # Replace u_{nn} by 1 / partial_det[..., -2] so the final back-substitution
    # yields y_n = x_n * partial_det[..., -2] without dividing by a (possibly
    # zero) u_{nn}.
    lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
    permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1], ))
    # Open-mesh index arrays over the batch dims; combined with `permutation`
    # below to gather the rows of x (i.e. apply p^-1) batch-wise.
    iotas = jnp.ix_(*(lax.iota(jnp.int32, size)
                      for size in batch_dims + (1, )))
    # filter out any matrices that are not full rank
    # Back-substituting a vector of ones through the *modified* u produces
    # NaN/Inf exactly when one of u_{11} ... u_{n-1,n-1} is zero. Those
    # matrices are treated as rank < n-1 and get an all-zero result.
    # NOTE(review): this means a rank n-1 matrix is only handled correctly
    # when its single zero pivot lands in the last diagonal slot of u;
    # confirm that partial pivoting guarantees this for the inputs of interest.
    d = jnp.ones(x.shape[:-1], x.dtype)
    d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
    d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
    # Expand the per-matrix flag to x's full shape for use with jnp.where.
    d = jnp.tile(d[..., None, None], d.ndim * (1, ) + x.shape[-2:])
    x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
    # x <- l^-1 @ p^-1 @ b
    x = x[iotas[:-1] + (permutation, slice(None))]
    x = lax_linalg.triangular_solve(lu,
                                    x,
                                    left_side=True,
                                    lower=True,
                                    unit_diagonal=True)
    # Scale every row except the last by det(a). The last row picks up its
    # det-scaling from the modified u_{nn} in the upper-triangular solve.
    x = jnp.concatenate(
        (x[..., :-1, :] * partial_det[..., -1, None, None], x[..., -1:, :]),
        axis=-2)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
    # The solve above can reintroduce NaN (e.g. 0 * inf) for the flagged
    # matrices, so zero them again.
    x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

    return partial_det[..., -1], x
Esempio n. 2
0
def conv_general_dilated_patches(
    lhs: lax.Array,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Optional[Sequence[int]] = None,
    rhs_dilation: Optional[Sequence[int]] = None,
    dimension_numbers: Optional[lax.ConvGeneralDilatedDimensionNumbers] = None,
    precision: Optional[lax.PrecisionType] = None,
    preferred_element_type: Optional[DType] = None,
) -> lax.Array:
    """Extract patches subject to the receptive field of `conv_general_dilated`.

    Runs the input through a convolution with given parameters. The kernel of
    the convolution is constructed such that the output channel dimension
    `"C"` contains flattened image patches, so instead a single `"C"`
    dimension represents, for example, three dimensions `"chw"` collapsed.
    The order of these dimensions is
    `"c" + ''.join(c for c in rhs_spec if c not in 'OI')`,
    where `rhs_spec == dimension_numbers[1]`, and the size of this `"C"`
    dimension is therefore the size of each patch, i.e.
    `np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`, where
    `lhs_spec == dimension_numbers[0]`.

    Docstring below adapted from `jax.lax.conv_general_dilated`.

    See Also:
      https://www.tensorflow.org/xla/operation_semantics#conv_convolution

    Args:
      lhs: a rank `n+2` dimensional input array (any array-like is accepted
        and converted with `jnp.asarray`).
      filter_shape: a sequence of `n` integers, representing the receptive
        window spatial shape in the order as specified in
        `rhs_spec = dimension_numbers[1]`.
      window_strides: a sequence of `n` integers, representing the
        inter-window strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply
        before and after each spatial dimension.
      lhs_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `lhs`. LHS
        dilation is also known as transposed convolution.
      rhs_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `rhs`. RHS
        dilation is also known as atrous convolution.
      dimension_numbers: either `None`, or a 3-tuple
        `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
        of length `n+2`. `None` defaults to
        `("NCHWD...", "OIHWD...", "NCHWD...")`.
      precision: Optional. Either ``None``, which means the default precision
        for the backend, or a ``lax.Precision`` enum value
        (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``).
      preferred_element_type: Optional. Either ``None``, which means the
        default accumulation type for the input types, or a datatype,
        indicating to accumulate results to and return a result with that
        datatype.

    Returns:
      A rank `n+2` array containing the flattened image patches in the output
      channel (`"C"`) dimension. For example if
      `dimension_numbers = ("NcHW", "OIwh", "CNHW")`, the output has dimension
      numbers `"CNHW" = "{cwh}NHW"`, with the size of dimension `"C"` equal to
      the size of each patch
      (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).
    """
    # Only `.shape` / `.dtype` are read below; converting first lets callers
    # pass any array-like (lists, NumPy arrays) as well as JAX arrays.
    lhs = jnp.asarray(lhs)
    filter_shape = tuple(filter_shape)
    dimension_numbers = lax.conv_dimension_numbers(lhs.shape,
                                                   (1, 1) + filter_shape,
                                                   dimension_numbers)

    lhs_spec, rhs_spec, out_spec = dimension_numbers

    # Number of positions in one receptive window, and number of input
    # channels (lhs_spec[1] is the index of the lhs feature dimension).
    spatial_size = prod(filter_shape)
    n_channels = lhs.shape[lhs_spec[1]]

    # Move separate `lhs` spatial locations into separate `rhs` channels:
    # row s of the identity, reshaped to `filter_shape`, is a one-hot kernel
    # selecting the s-th (row-major) position of the window.
    rhs = jnp.eye(spatial_size, dtype=lhs.dtype)
    rhs = rhs.reshape((spatial_size, 1) + filter_shape)
    # One copy of the one-hot kernels per input channel. Combined with
    # `feature_group_count=n_channels` below, copy c only sees channel c, so
    # output features are ordered channel-major (c, then window position).
    rhs = jnp.tile(rhs, (n_channels, ) + (1, ) * (rhs.ndim - 1))
    # Place the output/input feature axes where `rhs_spec` expects them.
    rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))

    out = lax.conv_general_dilated(
        lhs=lhs,
        rhs=rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        # The rhs is an exact 0/1 selector, so only the lhs precision matters.
        precision=None if precision is None else
        (precision, lax.Precision.DEFAULT),
        feature_group_count=n_channels,
        preferred_element_type=preferred_element_type)
    return out