Example #1
0
def ceil_ste(x: Tensor) -> Tensor:
    """
    Apply :func:`torch.ceil` with a straight-through estimator (STE) for the gradient.

    The forward pass rounds every element up; the backward pass passes the
    incoming gradient through unchanged.

    Notes:
        Dispatches to :func:`~brevitas.function.autograd_ste_ops.ceil_ste_impl` when
        ``BREVITAS_JIT=0``, or to its native just-in-time compiled variant when
        ``BREVITAS_JIT=1``.

    Examples:
        >>> x = torch.tensor([1.7, -1.7], requires_grad=True)
        >>> y = ceil_ste(x)
        >>> y
        tensor([ 2., -1.], grad_fn=<CeilSteFnBackward>)
        >>> grad = torch.tensor([0.1, -0.1])
        >>> y.backward(grad)
        >>> (x.grad == grad).all().item()
        True
    """
    # fn_prefix is selected at import time based on the BREVITAS_JIT env var.
    impl = fn_prefix.ceil_ste_impl
    return impl(x)
Example #2
0
def ceil_ste(x: Tensor) -> Tensor:
    """ Perform the ceil operation with Straight-Through Estimation (STE) of the gradient

    This operation behaves like an identity on the backward pass.
    For PyTorch version >= 1.3.0, the STE operator is implemented in C++ using the
    torch::autograd::Function class and compiled. At execution time, the Just-In-Time (JIT)
    compiler of PyTorch is used to speed up the computation.
    For PyTorch version < 1.3.0, the STE operator is implemented using the
    torch.autograd.Function class in Python, and the JIT cannot be used.

    Parameters
    ----------
    x : Tensor
        Tensor on which to apply the ceil operation

    Returns
    -------
    Tensor
        Tensor after applying the ceil operation.
        When backpropagating through this value, a straight-through estimator is applied.

    """
    # fn_prefix resolves to either the Python or the compiled implementation,
    # chosen at import time depending on the PyTorch version / JIT availability.
    return fn_prefix.ceil_ste_impl(x)
Example #3
0
def ceil_ste(x: Tensor) -> Tensor:
    """
    Element-wise ceil with a straight-through gradient estimator.

    Delegates to :func:`~brevitas.function.autograd_ste_ops.ceil_ste_impl`
    (with env BREVITAS_JIT=0) or its native just-in-time compiled variant
    (with BREVITAS_JIT=1).
    """
    # The concrete implementation is chosen once, at module import time.
    result = fn_prefix.ceil_ste_impl(x)
    return result