def test_smoke(self, device, dtype):
    """Smoke-test SpatialSoftArgmax2d on an all-zero (1, 1, 2, 3) tensor.

    The tensor is created with the requested ``device`` and ``dtype``; the
    module output is expected to have shape (1, 1, 2).
    """
    in_shape = (1, 1, 2, 3)
    expected_shape = (1, 1, 2)
    zeros = torch.zeros(in_shape, device=device, dtype=dtype)
    layer = kornia.SpatialSoftArgmax2d()
    output = layer(zeros)
    assert output.shape == expected_shape
Example #2
0
 def test_smoke_batch(self, device):
     """Smoke-test SpatialSoftArgmax2d on an all-zero (2, 1, 2, 3) tensor.

     The tensor is allocated first and then moved with ``.to(device)``; the
     module output is expected to have shape (2, 1, 2).
     """
     in_shape = (2, 1, 2, 3)
     expected_shape = (2, 1, 2)
     zeros = torch.zeros(in_shape)
     batch = zeros.to(device)
     layer = kornia.SpatialSoftArgmax2d()
     output = layer(batch)
     assert output.shape == expected_shape
 def test_smoke(self):
     """Smoke-test SpatialSoftArgmax2d on an all-zero (1, 1, 2, 3) tensor.

     Uses torch's default device and dtype; the module output is expected
     to have shape (1, 1, 2).
     """
     in_shape = (1, 1, 2, 3)
     expected_shape = (1, 1, 2)
     zeros = torch.zeros(in_shape)
     layer = kornia.SpatialSoftArgmax2d()
     output = layer(zeros)
     assert output.shape == expected_shape