Example #1
 def btrifact(self, pivot=True):
     r"""See :func:`torch.lu`"""
     warnings.warn(
         "torch.btrifact is deprecated in favour of torch.lu and will be removed in "
         "the next release. Please use torch.lu instead.",
         stacklevel=2)
     return torch._lu_with_info(self, pivot=pivot, check_errors=True)
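A minimal migration sketch (tensor names arbitrary): the deprecated wrapper above and torch.lu compute the same factorization, so callers can switch in place.

 import torch

 A = torch.randn(2, 3, 3)

 # Deprecated path: goes through the wrapper above and emits the warning.
 # A.btrifact()

 # Replacement recommended by the warning message:
 LU, pivots = torch.lu(A)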
Example #2
 def forward(ctx, self, pivot=True, get_infos=False):
     LU, pivots, infos = torch._lu_with_info(self,
                                             pivot=pivot,
                                             check_errors=(not get_infos))
     ctx.save_for_backward(LU, pivots)
     ctx.mark_non_differentiable(pivots, infos)
     return LU, pivots, infos
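A hedged sketch of how this forward would be driven, assuming the enclosing _LU class also defines a matching backward: apply() runs forward() and records the graph, while mark_non_differentiable keeps gradients away from the integer outputs.

 import torch
 from torch._autograd_functions import _LU  # assumed location, as in Example #3

 A = torch.randn(3, 3, requires_grad=True)
 LU, pivots, infos = _LU.apply(A, True, False)

 # Only LU participates in autograd; pivots and infos were marked non-differentiable.
 LU.sum().backward()
 print(A.grad.shape)  # torch.Size([3, 3])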
Example #3
    def lu(self, pivot=True, get_infos=False):
        r"""See :func:`torch.lu`"""
        # If get_infos is True, then we don't need to check for errors and vice versa
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos)

        if not torch._jit_internal.is_scripting():
            if self.requires_grad:
                if not (self.size(-2) == self.size(-1) and (self.dtype.is_floating_point or self.dtype.is_complex)):
                    raise ValueError(
                        'lu.backward works only with batches of square full-rank matrices'
                        ' of floating-point or complex types.'
                    )

                from torch._autograd_functions import _LU
                LU, pivots, infos = _LU.apply(self, pivot, get_infos)
                if get_infos:
                    return LU, pivots, infos
                else:
                    return LU, pivots
        else:
            if self.requires_grad:
                raise RuntimeError(
                    'Scripting with gradients is not supported at the moment. '
                    'If you just want to do the forward, use .detach() '
                    'on the input before calling the function.'
                )

        LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
        if get_infos:
            return LU, pivots, infos
        else:
            return LU, pivots
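For illustration, a short usage sketch of the two return arities this method supports (tensor names are arbitrary):

    import torch

    A = torch.randn(2, 3, 3)

    LU, pivots = A.lu()                       # default: errors are checked eagerly
    LU, pivots, infos = A.lu(get_infos=True)  # no eager check; inspect infos instead

    # infos is zero wherever factorization of the corresponding matrix succeeded.
    print((infos == 0).all())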
Example #4
 def btrifact_with_info(self, pivot=True):
     r"""See :func:`torch.lu`"""
     warnings.warn(
         "torch.btrifact_with_info is deprecated in favour of torch.lu with the "
         "get_infos argument and will be removed in the next release. Please use "
         "torch.lu with the get_infos argument set to True instead.",
         stacklevel=2)
     return torch._lu_with_info(self, pivot=pivot, check_errors=False)
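A minimal migration sketch mirroring the warning text: torch.lu with get_infos=True returns the same three tensors the deprecated API produced.

 import torch

 A = torch.randn(2, 3, 3)

 # Deprecated path: goes through the wrapper above and emits the warning.
 # A.btrifact_with_info()

 # Replacement recommended by the warning message:
 LU, pivots, infos = torch.lu(A, get_infos=True)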
Example #5
 def lu(self, pivot=True, get_infos=False):
     r"""See :func:`torch.lu`"""
     # If get_infos is True, then we don't need to check for errors and vice versa
     LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
     if get_infos:
         return LU, pivots, infos
     else:
         return LU, pivots
Example #6
    def lu(self, pivot=True, get_infos=False):
        r"""See :func:`torch.lu`"""
        # If get_infos is True, then we don't need to check for errors and vice versa
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos)

        LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
        if get_infos:
            return LU, pivots, infos
        else:
            return LU, pivots
Example #7
 def lu(self, pivot=True, get_infos=False):
     r"""See :func:`torch.lu`"""
     # If get_infos is True, then we don't need to check for errors and vice versa
     relevant_args = (self,)
     from torch.overrides import has_torch_function, handle_torch_function
     if type(self) is not Tensor and has_torch_function(relevant_args):
         return handle_torch_function(Tensor.lu, relevant_args, self, pivot=pivot, get_infos=get_infos)
     LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
     if get_infos:
         return LU, pivots, infos
     else:
         return LU, pivots
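The has_torch_function/handle_torch_function pair above is the dispatch side of the __torch_function__ override protocol. A minimal sketch of a Tensor subclass that this dispatch would route to; the LoggingTensor class and its printout are hypothetical:

 import torch

 class LoggingTensor(torch.Tensor):
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         kwargs = kwargs or {}
         print(f"intercepted {func.__name__}")
         # Defer to the default implementation for the actual computation.
         return super().__torch_function__(func, types, args, kwargs)

 A = torch.randn(3, 3).as_subclass(LoggingTensor)
 LU, pivots = A.lu()  # prints "intercepted lu" before factorizing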
Example #8
def _lu_impl(A, pivot=True, get_infos=False, out=None):
    # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]
    r"""Computes the LU factorization of a matrix or batches of matrices
    :attr:`A`. Returns a tuple containing the LU factorization and
    pivots of :attr:`A`.  Pivoting is done if :attr:`pivot` is set to
    ``True``.

    .. note::
        The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``,
        then the returned pivots tensor is filled with zeros of the appropriate size.

    .. note::
        LU factorization with :attr:`pivot` = ``False`` is not available on CPU, and attempting
        it raises an error. It is, however, available on CUDA.

    .. note::
        When :attr:`get_infos` is ``True``, this function does not check whether the
        factorization succeeded, since the status of the factorization is returned in
        the third element of the tuple.

    .. note::
        In the case of batches of square matrices of size at most 32 on a CUDA device,
        the LU factorization is repeated for singular matrices due to a bug in the
        MAGMA library (see magma issue 13).

    .. note::
       ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.

    Arguments:
        A (Tensor): the tensor to factor of size :math:`(*, m, n)`
        pivot (bool, optional): controls whether pivoting is done. Default: ``True``
        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
                                    Default: ``False``
        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
                               then the elements in the tuple are Tensor, IntTensor,
                               and IntTensor. If :attr:`get_infos` is ``False``, then the
                               elements in the tuple are Tensor, IntTensor. Default: ``None``

    Returns:
        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing

            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`

            - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)`

            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
              size :math:`(*)` where a non-zero value indicates that the factorization of the
              corresponding matrix or minibatch element failed

    Example::

        >>> A = torch.randn(2, 3, 3)
        >>> A_LU, pivots = torch.lu(A)
        >>> A_LU
        tensor([[[ 1.3506,  2.5558, -0.0816],
                 [ 0.1684,  1.1551,  0.1940],
                 [ 0.1193,  0.6189, -0.5497]],

                [[ 0.4526,  1.2526, -0.3285],
                 [-0.7988,  0.7175, -0.9701],
                 [ 0.2634, -0.9255, -0.3459]]])
        >>> pivots
        tensor([[ 3,  3,  3],
                [ 3,  3,  3]], dtype=torch.int32)
        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
        >>> if info.nonzero().size(0) == 0:
        ...   print('LU factorization succeeded for all samples!')
        LU factorization succeeded for all samples!
    """
    # If get_infos is True, then we don't need to check for errors and vice versa
    return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
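As the last note above says, the explicit factors can be recovered with torch.lu_unpack; a short sketch:

import torch

A = torch.randn(2, 3, 3)
LU, pivots = torch.lu(A)
P, L, U = torch.lu_unpack(LU, pivots)

# P @ L @ U reconstructs A up to floating-point error.
print(torch.allclose(P @ L @ U, A, atol=1e-6))  # True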
Example #9
def lu(A, pivot=True, get_infos=False, out=None):
    r"""Computes the LU factorization of a square matrix or batches of square matrices
    :attr:`A`. Returns a tuple containing the LU factorization and pivots of :attr:`A`.
    Pivoting is done if :attr:`pivot` is set to ``True``.

    .. note::
        The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``,
        then the returned pivots tensor is filled with zeros of the appropriate size.

    .. note::
        LU factorization with :attr:`pivot` = ``False`` is not available on CPU, and attempting
        it raises an error. It is, however, available on CUDA.

    .. note::
        When :attr:`get_infos` is ``True``, this function does not check whether the
        factorization succeeded, since the status of the factorization is returned in
        the third element of the tuple.

    Arguments:
        A (Tensor): the tensor to factor of size :math:`(*, m, m)`
        pivot (bool, optional): controls whether pivoting is done. Default: ``True``
        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
                                    Default: ``False``
        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
                               then the elements in the tuple are Tensor, IntTensor,
                               and IntTensor. If :attr:`get_infos` is ``False``, then the
                               elements in the tuple are Tensor, IntTensor. Default: ``None``

    Returns:
        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing

            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, m)`

            - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)`

            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
              size :math:`(*)` where a non-zero value indicates that the factorization of the
              corresponding matrix or minibatch element failed

    Example::

        >>> A = torch.randn(2, 3, 3)
        >>> A_LU, pivots = torch.lu(A)
        >>> A_LU
        tensor([[[ 1.3506,  2.5558, -0.0816],
                 [ 0.1684,  1.1551,  0.1940],
                 [ 0.1193,  0.6189, -0.5497]],

                [[ 0.4526,  1.2526, -0.3285],
                 [-0.7988,  0.7175, -0.9701],
                 [ 0.2634, -0.9255, -0.3459]]])
        >>> pivots
        tensor([[ 3,  3,  3],
                [ 3,  3,  3]], dtype=torch.int32)
        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
        >>> if info.nonzero().size(0) == 0:
        ...   print('LU factorization succeeded for all samples!')
        LU factorization succeeded for all samples!
    """
    # If get_infos is True, then we don't need to check for errors and vice versa
    result = torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
    if out is not None:
        if not isinstance(out, (tuple, list)):
            raise TypeError("argument 'out' must be tuple of Tensors, not {}"
                            .format(type(out).__name__))
        if len(out) - int(get_infos) != 2:
            raise TypeError("expected tuple of {} elements but got {}"
                            .format(2 + int(get_infos), len(out)))
        return tuple(out[i].resize_as_(result[i]).copy_(result[i]) for i in range(len(out)))
    if get_infos:
        return result  # A_LU, pivots, infos
    else:
        return result[0], result[1]  # A_LU, pivots
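With the tuple fix above in place, a short sketch of the out= path: the preallocated tensors are resized and filled in place by the lu function defined in this example.

import torch

A = torch.randn(2, 3, 3)
LU_out = torch.empty(0)
piv_out = torch.empty(0, dtype=torch.int32)

lu(A, out=(LU_out, piv_out))
print(LU_out.shape, piv_out.shape)  # torch.Size([2, 3, 3]) torch.Size([2, 3])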