def test_ste_round_grads(self):
    """Check that backpropagating through ``ste_round`` leaves the gradient untouched.

    The input tensor itself is used as the upstream gradient, so after the
    backward pass the gradient stored on the input is expected to be
    element-wise equal to the input.
    """
    inputs = torch.rand(24, requires_grad=True)
    rounded = ste_round(inputs)

    # Feed the input values in as the incoming gradient; any scaling or
    # zeroing done by the backward pass would show up as a mismatch below.
    rounded.backward(inputs)

    grad = inputs.grad
    assert grad is not None
    assert (grad == inputs).all()
def test_ste_round_ok(self):
    """Check that the forward output of ``ste_round`` equals ``torch.round``."""
    sample = torch.rand(16)

    actual = ste_round(sample)
    expected = torch.round(sample)

    # Exact element-wise equality is required, not approximate closeness.
    assert (actual == expected).all()