def test_inverse_angles(float_tolerance):
    """Composing random Euler angles with their inverse gives the identity rotation.

    Also checks that ``o3.identity_angles(requires_grad=True)`` produces angles
    that receive gradients through ``o3.angles_to_matrix``.
    """
    angles = o3.rand_angles()
    inverse = o3.inverse_angles(*angles)
    composed = o3.compose_angles(*angles, *inverse)
    identity = o3.identity_angles(requires_grad=True)

    rot_composed = o3.angles_to_matrix(*composed)
    rot_identity = o3.angles_to_matrix(*identity)
    worst_error = (rot_composed - rot_identity).abs().max()
    assert worst_error < float_tolerance

    # test `requires_grad`: backprop through the identity matrix must reach the angles
    rot_identity.sum().backward()
    assert identity[0].grad is not None
def test_compose(float_tolerance):
    """Rotation composition agrees across all parametrizations.

    Two batches of 10 random quaternions are composed as quaternions, as
    matrices, as Euler angles and as axis-angle pairs. Each result is converted
    to a rotation matrix and compared with the quaternion composition.
    """
    q1 = o3.rand_quaternion(10)
    q2 = o3.rand_quaternion(10)

    # Reference: composition performed directly on quaternions.
    q = o3.compose_quaternion(q1, q2)
    R_ref = o3.quaternion_to_matrix(q)

    # Composition as a matrix product.
    R_matrix = o3.quaternion_to_matrix(q1) @ o3.quaternion_to_matrix(q2)

    # Composition in Euler-angle form.
    abc1 = o3.quaternion_to_angles(q1)
    abc2 = o3.quaternion_to_angles(q2)
    abc = o3.compose_angles(*abc1, *abc2)
    R_angles = o3.angles_to_matrix(*abc)

    # Composition in axis-angle form.
    ax1, a1 = o3.quaternion_to_axis_angle(q1)
    ax2, a2 = o3.quaternion_to_axis_angle(q2)
    ax, a = o3.compose_axis_angle(ax1, a1, ax2, a2)
    R_axis_angle = o3.axis_angle_to_matrix(ax, a)

    # `.max()` reduces over every element of the whole batch, so these are
    # worst-case checks (a trailing `.median()` on the 0-d result would be a no-op).
    assert (R_ref - R_matrix).abs().max() < float_tolerance
    assert (R_ref - R_angles).abs().max() < float_tolerance
    assert (R_ref - R_axis_angle).abs().max() < float_tolerance
def test_equivariance(float_tolerance):
    """Spherical harmonics up to ``lmax`` commute with rotations.

    Evaluating the harmonics on rotated points must match rotating the
    evaluated harmonics with the Wigner D-matrix of the same Euler angles.
    """
    lmax = 5
    irreps = o3.Irreps.spherical_harmonics(lmax)
    points = torch.randn(2, 3)
    angles = o3.rand_angles()

    rotation = o3.angles_to_matrix(*angles)
    wigner = irreps.D_from_angles(*angles)

    sh_of_rotated = o3.spherical_harmonics(irreps, points @ rotation.T, False)
    rotated_sh = o3.spherical_harmonics(irreps, points, False) @ wigner.T

    deviation = (sh_of_rotated - rotated_sh).abs().max()
    assert deviation < 10 * float_tolerance
def find_peaks(self, signal, res=100):
    r"""Locate peaks on the sphere

    The signal is sampled with ``self.signal_on_grid`` and its local maxima are
    found with ``_find_peaks_2d``. The search is repeated on a copy of the
    signal rotated by the Euler angles :math:`(\pi/2, \pi/2, \pi/2)`. The grid
    positions of that second pass are rotated back into the original frame.
    Presumably the second pass catches peaks the first grid misses, e.g. near
    its poles (TODO confirm against ``_find_peaks_2d``).

    First-pass peaks that lie within ``2 * pi / res`` of a second-pass peak are
    treated as duplicates and dropped.

    Parameters
    ----------
    signal : `torch.Tensor`
        signal coefficients, in the basis acted on by ``self.D_from_matrix``
    res : int
        grid resolution forwarded to ``self.signal_on_grid``

    Returns
    -------
    x : `torch.Tensor`
        positions of the peaks, one row per peak
    f : `torch.Tensor`
        value of the signal at each peak

    Examples
    --------

    >>> s = SphericalTensor(4, 1, -1)
    >>> pos = torch.tensor([
    ...     [4.0, 0.0, 4.0],
    ...     [0.0, 5.0, 0.0],
    ... ])
    >>> x = s.with_peaks_at(pos)
    >>> pos, val = s.find_peaks(x)
    >>> pos[val > 4.0].mul(10).round().abs()
    tensor([[ 7.,  0.,  7.],
            [ 0., 10.,  0.]])
    >>> val[val > 4.0].mul(10).round().abs()
    tensor([57., 50.])
    """
    def _peaks(grid_x, grid_f):
        # Positions and values of the local maxima of one sampled grid.
        ij = list(_find_peaks_2d(grid_f))
        if not ij:
            # torch.stack rejects an empty list: return empty tensors instead.
            return (
                grid_x.new_zeros((0,) + tuple(grid_x.shape[2:])),
                grid_f.new_zeros((0,) + tuple(grid_f.shape[2:])),
            )
        xp = torch.stack([grid_x[i, j] for i, j in ij])
        fp = torch.stack([grid_f[i, j] for i, j in ij])
        return xp, fp

    # Build the rotation in the signal's dtype/device so that `D @ signal` does
    # not fail on a dtype or device mismatch (e.g. float64 or CUDA signals).
    abc = torch.tensor([pi / 2, pi / 2, pi / 2], dtype=signal.dtype, device=signal.device)
    R = o3.angles_to_matrix(*abc)
    D = self.D_from_matrix(R)

    x1, f1 = self.signal_on_grid(signal, res)
    rx2, f2 = self.signal_on_grid(D @ signal, res)
    # Rotate the second grid's positions back into the original frame (R^T = R^-1),
    # matching the grid tensor's dtype/device for the contraction.
    x2 = torch.einsum('ij,baj->bai', R.T.to(rx2), rx2)

    x1p, f1p = _peaks(x1, f1)
    x2p, f2p = _peaks(x2, f2)

    # Union of the results: keep first-pass peaks with no second-pass peak nearby.
    keep = ~(torch.cdist(x1p, x2p) < 2 * pi / res).any(1)
    x = torch.cat([x1p[keep], x2p])
    f = torch.cat([f1p[keep], f2p])
    return x, f
def test_cartesian(float_tolerance):
    """For l = 1 the Wigner D-matrix equals the 3x3 rotation matrix of the same angles."""
    angles = o3.rand_angles(10)
    rotation = o3.angles_to_matrix(*angles)
    wigner = o3.wigner_D(1, *angles)
    error = (rotation - wigner).abs().max()
    assert error < float_tolerance