Ejemplo n.º 1
0
 def test_ste_round_grads(self):
     x = torch.rand(24, requires_grad=True)
     y = ste_round(x)
     y.backward(x)
     assert x.grad is not None
     assert (x.grad == x).all()
Ejemplo n.º 2
0
 def test_ste_round_ok(self):
     x = torch.rand(16)
     assert (ste_round(x) == torch.round(x)).all()