Exemplo n.º 1
0
def set_default_dtype(d):
    r"""Make :attr:`d` the default floating point dtype.

    The default floating point dtype starts out as ``torch.float32`` and
    drives type inference in two ways:

    * Python floats passed to :func:`torch.tensor` are given this dtype.
    * Python complex numbers are given the matching complex dtype:
      ``torch.complex128`` when the default floating point dtype is
      ``torch.float64``, and ``torch.complex64`` otherwise.

    Args:
        d (:class:`torch.dtype`): the floating point dtype to use as the new default

    Example:
        >>> # the floating point default starts out as torch.float32
        >>> torch.tensor([1.2, 3]).dtype
        torch.float32
        >>> # so the complex default starts out as torch.complex64
        >>> torch.tensor([1.2, 3j]).dtype
        torch.complex64
        >>> torch.set_default_dtype(torch.float64)
        >>> torch.tensor([1.2, 3]).dtype    # floats are now inferred as float64
        torch.float64
        >>> torch.tensor([1.2, 3j]).dtype   # complex numbers now become complex128
        torch.complex128

    """
    # This wrapper only carries the documentation; the `_C` binding does the work.
    apply_default = _C._set_default_dtype
    apply_default(d)
Exemplo n.º 2
0
def set_default_dtype(d):
    r"""

    Make :attr:`d` the default floating point dtype.

    Only torch.float32 and torch.float64 are supported. Passing another dtype
    may not raise an error, but it is unsupported and is unlikely to behave
    as expected.

    PyTorch starts with torch.float32 as its default floating point dtype.
    Calling set_default_dtype(torch.float64) is meant to make type inference
    behave more like NumPy. The default floating point dtype is consulted in
    three places:

    1. It implicitly fixes the default complex dtype: complex64 while the
       default floating point dtype is float32, and complex128 while it is
       float64.
    2. It is the dtype inferred for tensors built from Python floats or
       complex Python numbers (see the examples below).
    3. It decides the result of type promotion between bool and integer
       tensors and Python floats and complex Python numbers.

    Args:
        d (:class:`torch.dtype`): the floating point dtype to use as the new
                                  default; either torch.float32 or torch.float64.

    Example:
        >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
        >>> # the floating point default starts out as torch.float32,
        >>> # so Python floats are interpreted as float32
        >>> torch.tensor([1.2, 3]).dtype
        torch.float32
        >>> # the matching complex default is torch.complex64,
        >>> # so complex Python numbers are interpreted as complex64
        >>> torch.tensor([1.2, 3j]).dtype
        torch.complex64

        >>> torch.set_default_dtype(torch.float64)

        >>> # Python floats are interpreted as float64 from here on
        >>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
        torch.float64
        >>> # complex Python numbers are interpreted as complex128 from here on
        >>> torch.tensor([1.2, 3j]).dtype   # a new complex tensor
        torch.complex128

    """
    # No validation happens at this level; `d` is handed straight to the `_C` binding.
    apply_default = _C._set_default_dtype
    apply_default(d)
Exemplo n.º 3
0
def set_default_dtype(d):
    r"""Make :attr:`d` the default floating point dtype.

    :func:`torch.tensor` uses the default floating point dtype when it has to
    infer a floating point type. The default starts out as ``torch.float32``.

    Args:
        d (:class:`torch.dtype`): the floating point dtype to use as the new default

    Example::

        >>> torch.tensor([1.2, 3]).dtype           # the default starts out as torch.float32
        torch.float32
        >>> torch.set_default_dtype(torch.float64)
        >>> torch.tensor([1.2, 3]).dtype           # new tensors are inferred as float64
        torch.float64

    """
    # This wrapper only carries the documentation; the `_C` binding does the work.
    apply_default = _C._set_default_dtype
    apply_default(d)