def get_transformation_matrix(angles, translation):
    """
    Build a pose (transformation) matrix from rotation angles and a translation.

    The angles are converted from degrees to radians. They are then
    concatenated with ``translation`` along dim 0, given a leading batch
    dimension of size 1, and passed to ``tgm.rtvec_to_pose``.

    :param angles: rotation angles in degrees; presumably a 1-D tensor of
        three values (about the x, y and z axes) -- TODO confirm with callers.
    :param translation: translation vector placed after the angles along
        dim 0; presumably a 1-D tensor of three values -- TODO confirm.
    :return: the result of ``tgm.rtvec_to_pose`` for a batch of one vector
        (expected to be a (1, 4, 4) pose matrix -- confirm against the
        torchgeometry documentation).

    NOTE(review): ``tgm.rtvec_to_pose`` appears to interpret the rotation part
    as an axis-angle (Rodrigues) vector rather than per-axis Euler angles;
    confirm this matches what callers pass in ``angles``.
    """
    radians = tgm.deg2rad(angles)
    # add a leading batch dimension; presumably rtvec_to_pose expects a batch
    pose_vector = torch.cat([radians, translation], dim=0).unsqueeze(0)
    return tgm.rtvec_to_pose(pose_vector)
Ejemplo n.º 2
0
    def test_deg2rad(self):
        """deg2rad followed by rad2deg must reproduce the input angles.

        Also checks that the functional ``tgm.deg2rad`` agrees with the
        ``tgm.DegToRad`` module wrapper.
        """
        # random angles in degrees, drawn from [0, 180)
        degrees = torch.rand(2, 3, 4) * 180.

        # degrees -> radians -> degrees round trip
        radians = tgm.deg2rad(degrees)
        round_trip = tgm.rad2deg(radians)

        # the round trip should be lossless up to float precision
        mse = utils.compute_mse(degrees, round_trip).item()
        self.assertAlmostEqual(mse, 0.0, places=4)

        # the module form must match the functional form
        module_radians = tgm.DegToRad()(degrees)
        self.assertTrue(torch.allclose(radians, module_radians))
Ejemplo n.º 3
0
def test_deg2rad(batch_shape, device_type):
    """Check tgm.deg2rad: round trip via rad2deg, DegToRad parity, gradcheck.

    :param batch_shape: shape of the random input tensor (supplied by pytest).
    :param device_type: torch device string the input is moved to
        (supplied by pytest).
    """
    # generate input data: random angles in [0, 180) degrees
    x_deg = 180. * torch.rand(batch_shape)
    x_deg = x_deg.to(torch.device(device_type))

    # convert radians/degrees
    x_rad = tgm.deg2rad(x_deg)
    x_rad_to_deg = tgm.rad2deg(x_rad)

    # compute error and compare it to zero explicitly: pytest.approx's second
    # positional argument is `rel`, and an approx object only checks anything
    # when it is one side of `==` (asserting it directly tests nothing).
    # The tolerance mirrors `places=4` used by the unittest variant.
    error = utils.compute_mse(x_deg, x_rad_to_deg)
    assert error.item() == pytest.approx(0.0, abs=1e-4)

    # functional
    assert torch.allclose(x_rad, tgm.DegToRad()(x_deg))

    # evaluate function gradient
    assert gradcheck(tgm.deg2rad, (utils.tensor_to_gradcheck_var(x_deg), ),
                     raise_exception=True)
Ejemplo n.º 4
0
def test_rad2deg(batch_shape, device_type):
    """Check tgm.rad2deg: round trip via deg2rad, RadToDeg parity, gradcheck.

    :param batch_shape: shape of the random input tensor (supplied by pytest).
    :param device_type: torch device string the input is moved to
        (supplied by pytest).
    """
    # generate input data: random angles in [0, pi) radians
    x_rad = tgm.pi * torch.rand(batch_shape)
    x_rad = x_rad.to(torch.device(device_type))

    # convert radians/degrees
    x_deg = tgm.rad2deg(x_rad)
    x_deg_to_rad = tgm.deg2rad(x_deg)

    # compute error and compare it to zero explicitly: pytest.approx's second
    # positional argument is `rel`, and an approx object only checks anything
    # when it is one side of `==` (asserting it directly tests nothing).
    # The tolerance mirrors `places=4` used by the unittest variant.
    error = utils.compute_mse(x_rad, x_deg_to_rad)
    assert error.item() == pytest.approx(0.0, abs=1e-4)

    # functional
    assert torch.allclose(x_deg, tgm.RadToDeg()(x_rad))

    # evaluate function gradient
    assert gradcheck(tgm.rad2deg, (utils.tensor_to_gradcheck_var(x_rad), ),
                     raise_exception=True)