Ejemplo n.º 1
0
 def test_cat(self):
     tensor1 = MaskedTensor(torch.tensor([1, 2]))
     tensor2 = MaskedTensor(torch.tensor([3, 4]))
     stack = MaskedTorch.cat([tensor1, tensor2], dim=0)
     res = MaskedTensor(torch.tensor([[1, 2, 3, 4]]))
     self.assertTrue(torch.all(stack == res),
                     msg="Cat is not equal to expected")
Ejemplo n.º 2
0
 def test_not_implemented_method(self):
     tensor = MaskedTensor(tensor=torch.tensor([1, 2, 3]))
     torch_sum = MaskedTorch.sum(tensor)
     self.assertEqual(torch_sum, torch.tensor(6))
Ejemplo n.º 3
0
 def test_zeros_mask_value(self):
     zeros = MaskedTorch.zeros(1, 2, 3)
     self.assertTrue(torch.all(zeros.mask == 0),
                     msg="Zeros mask are not all zeros")
Ejemplo n.º 4
0
 def test_zeros_tensor_type_bool(self):
     dtype = torch.bool
     zeros = MaskedTorch.zeros(1, 2, 3, dtype=dtype)
     self.assertEqual(zeros.tensor.dtype, dtype)
Ejemplo n.º 5
0
 def test_zeros_tensor_value(self):
     zeros = MaskedTorch.zeros(1, 2, 3)
     self.assertTrue(torch.all(zeros == 0), msg="Zeros are not all zeros")
Ejemplo n.º 6
0
 def test_zeros_tensor_shape(self):
     zeros = MaskedTorch.zeros(1, 2, 3)
     self.assertEqual(zeros.shape, (1, 2, 3))
Ejemplo n.º 7
0
 def distance(self, p1s: MaskedTensor, p2s: MaskedTensor) -> MaskedTensor:
     diff = p1s - p2s  # (..., Len, Dims)
     square = diff.pow_(2)
     sum_squares = square.sum(dim=-1)
     return MaskedTorch.sqrt(sum_squares)