import numpy as np
import torch

import nfpq  # module under test; exact import path assumed


def test_straight_through_gradient():
    # Straight-through round: forward rounds, backward passes gradients through,
    # so d(sum)/d(inputs) should be all ones.
    inputs = torch.tensor([1.1, 0.9], requires_grad=True)
    outputs = nfpq.StraightThroughRound.apply(inputs)
    outputs.sum().backward()
    assert np.isclose(inputs.grad, [1, 1]).all()

    # Plain round() has zero gradient almost everywhere, so backprop through
    # it leaves the (zeroed) gradient at zero.
    inputs.grad.zero_()
    output_nost = inputs.round()
    output_nost.sum().backward()
    assert np.isclose(inputs.grad, [0, 0]).all()

    # Stochastic rounding: 0.5 rounds to 0 or 1 at random, so across 100
    # samples both values should almost surely appear.
    inputs = torch.full((100,), 0.5, requires_grad=True)
    outputs = nfpq.StraightThroughStochasticRound.apply(inputs)
    assert outputs.max() > 0.9 and outputs.min() < 0.1
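

# For reference, a minimal straight-through estimator can be written as a
# torch.autograd.Function whose forward rounds and whose backward returns the
# incoming gradient unchanged. This is a sketch of the technique the tests
# above exercise, not nfpq's actual implementation, which may differ.
class _StraightThroughRoundSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient: treat the rounding step as if it were y = x.
        return grad_output


class _StraightThroughStochasticRoundSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Round up with probability equal to the fractional part; the result
        # equals x in expectation.
        return x.floor() + torch.bernoulli(x - x.floor())

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output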