Ejemplo n.º 1
0
    def test_funcional_shapes(self, device, dtype, shapes):
        """Compare slice-by-slice results against a single call on the full tensor.

        A random tensor of shape ``shapes + (3,)`` is built. The function
        under test is applied to every slice taken along dimension 1, the
        per-slice outputs are stacked back along dimension 1, and the stacked
        tensor is compared with the output of one call on the whole tensor.

        NOTE(review): indexing dimension 1 presumes ``shapes`` provides at
        least two leading dimensions -- confirm against the parametrization.
        """
        vectors = torch.rand(*(shapes + (3, )), device=device, dtype=dtype)

        # Slice-wise path: one call per index along dimension 1.
        per_slice = [
            epi.cross_product_matrix(vectors[:, idx, ...])
            for idx in range(vectors.shape[1])
        ]
        stacked = torch.stack(per_slice, dim=1)

        # Single-call path: the whole tensor at once.
        whole = epi.cross_product_matrix(vectors)

        assert_close(stacked, whole)
Ejemplo n.º 2
0
 def test_mean_std(self, device, dtype):
     """Check that mirrored off-diagonal entries are negatives of each other.

     For the fixed input ``[[1., 2., 3.]]`` the entry at ``(row, col)`` of the
     result is compared with minus the entry at ``(col, row)`` for each of the
     three off-diagonal index pairs.

     NOTE(review): the test name mentions mean/std, but the body only checks
     this antisymmetry property.
     """
     fixed_vec = torch.tensor([[1., 2., 3.]], device=device, dtype=dtype)
     matrix = epi.cross_product_matrix(fixed_vec)
     # Upper-triangular positions; each is compared with its mirrored entry.
     for row, col in ((0, 1), (0, 2), (1, 2)):
         assert_allclose(matrix[..., row, col], -matrix[..., col, row])
Ejemplo n.º 3
0
 def test_shape(self, batch_size, device, dtype):
     """Check that a ``(batch_size, 3)`` input yields a ``(batch_size, 3, 3)`` output."""
     vectors = torch.rand(batch_size, 3, device=device, dtype=dtype)
     expected_shape = (batch_size, 3, 3)
     result = epi.cross_product_matrix(vectors)
     assert result.shape == expected_shape
Ejemplo n.º 4
0
 def test_smoke(self, device, dtype):
     """Run the function on a single random ``(1, 3)`` input and check the output shape."""
     single_vector = torch.rand(1, 3, device=device, dtype=dtype)
     result = epi.cross_product_matrix(single_vector)
     assert result.shape == (1, 3, 3)
Ejemplo n.º 5
0
 def test_shapes(self, device, dtype, shapes):
     """Check the output shape for arbitrary leading dimensions.

     A random input of shape ``shapes + (3,)`` must produce an output of
     shape ``shapes + (3, 3)``.
     """
     expected_shape = shapes + (3, 3)
     vectors = torch.rand(*(shapes + (3, )), device=device, dtype=dtype)
     result = epi.cross_product_matrix(vectors)
     assert result.shape == expected_shape