def test_funcional_shapes(self, device, dtype, shapes):
    """Check slice-wise and one-shot evaluation agree.

    Builds a random tensor of shape ``shapes + (3,)``, applies
    ``epi.cross_product_matrix`` to every slice along dim 1, stacks the
    results back along dim 1, and compares against a single call on the
    whole tensor.
    """
    vectors = torch.rand(*(shapes + (3, )), device=device, dtype=dtype)

    # Slice-by-slice along dim 1, then reassemble along the same dim.
    per_slice = [
        epi.cross_product_matrix(vectors[:, idx, ...])
        for idx in range(vectors.shape[1])
    ]
    stacked = torch.stack(per_slice, dim=1)

    # Whole tensor in a single call.
    one_shot = epi.cross_product_matrix(vectors)

    assert_close(stacked, one_shot)
def test_mean_std(self, device, dtype):
    """Check that the off-diagonal entries are antisymmetric.

    For the fixed vector ``[1, 2, 3]`` the result ``M`` of
    ``epi.cross_product_matrix`` must satisfy ``M[..., i, j] == -M[..., j, i]``
    for the index pairs (0, 1), (0, 2) and (1, 2).

    NOTE(review): despite the name, nothing here computes a mean or a
    standard deviation; the body only checks antisymmetry.
    """
    sample = torch.tensor([[1., 2., 3.]], device=device, dtype=dtype)
    skew = epi.cross_product_matrix(sample)

    # Upper-triangular index pairs, compared against their mirrored entries.
    for row, col in ((0, 1), (0, 2), (1, 2)):
        assert_allclose(skew[..., row, col], -skew[..., col, row])
def test_shape(self, batch_size, device, dtype):
    """Check the output shape for a batch of 3-vectors.

    A random ``(batch_size, 3)`` input is expected to produce a
    ``(batch_size, 3, 3)`` output from ``epi.cross_product_matrix``.
    """
    batch = torch.rand(batch_size, 3, device=device, dtype=dtype)
    result = epi.cross_product_matrix(batch)

    expected_shape = (batch_size, 3, 3)
    assert result.shape == expected_shape
def test_smoke(self, device, dtype):
    """Smoke test: a single random 3-vector yields a ``(1, 3, 3)`` result."""
    single = torch.rand(1, 3, device=device, dtype=dtype)
    result = epi.cross_product_matrix(single)

    expected_shape = (1, 3, 3)
    assert result.shape == expected_shape
def test_shapes(self, device, dtype, shapes):
    """Check the output shape for arbitrary leading dimensions.

    The input has shape ``shapes + (3,)``; the result of
    ``epi.cross_product_matrix`` is expected to have shape
    ``shapes + (3, 3)``.
    """
    in_shape = shapes + (3, )
    expected_shape = shapes + (3, 3)

    vectors = torch.rand(*in_shape, device=device, dtype=dtype)
    result = epi.cross_product_matrix(vectors)

    assert result.shape == expected_shape