Example #1
0
def floor_ste(x: Tensor) -> Tensor:
    """
    Apply :func:`torch.floor` on the forward pass while passing gradients
    straight through (identity) on the backward pass.

    Notes:
        Dispatches to :func:`~brevitas.function.autograd_ste_ops.floor_ste_impl`
        when ``BREVITAS_JIT=0``, or to its native just-in-time compiled variant
        when ``BREVITAS_JIT=1``.

    Examples:
        >>> x = torch.tensor([1.7, -1.7], requires_grad=True)
        >>> y = floor_ste(x)
        >>> y
        tensor([ 1., -2.], grad_fn=<FloorSteFnBackward>)
        >>> grad = torch.tensor([0.1, -0.1])
        >>> y.backward(grad)
        >>> (x.grad == grad).all().item()
        True
    """
    # Delegate to the environment-selected implementation (autograd or JIT).
    result = fn_prefix.floor_ste_impl(x)
    return result
Example #2
0
def floor_ste(x: Tensor) -> Tensor:
    """ Perform floor operation with Straight Through Estimation (STE) of the Gradient

    This operation behaves like an identity on the backward pass.
    For PyTorch version >= 1.3.0, the STE operator is implemented in C++ using the
    torch::autograd::Function class and compiled. At execution time, the Just-In-Time (JIT) compiler of PyTorch
    is used to speed-up the computation.
    For PyTorch version < 1.3.0, the STE operator is implemented using the
    torch.autograd.Function class in python, and the JIT cannot be used.

    Parameters
    ----------
    x : Tensor
        Tensor on which to apply the floor operation

    Returns
    -------
    Tensor
        Tensor after applying floor operation.
        When backpropagating through this value, a straight through estimator is applied.

    """
    return fn_prefix.floor_ste_impl(x)
Example #3
0
def floor_ste(x: Tensor) -> Tensor:
    """
    Floor with a straight-through gradient estimator.

    Delegates to :func:`~brevitas.function.autograd_ste_ops.floor_ste_impl`
    when the env var BREVITAS_JIT=0, or to its native just-in-time compiled
    variant when BREVITAS_JIT=1.
    """
    # Forward to whichever implementation the environment selected at import.
    out = fn_prefix.floor_ste_impl(x)
    return out