def test_tensor_square_norm(self):
    """Check the change of basis ``Q`` returned by ``rs.tensor_square``.

    Two properties are verified in float64:

    * equivariance: applying ``D_out`` on the first index of ``Q`` matches
      applying ``D_in`` on each of its two remaining indices (relative RMS
      error below 1e-10);
    * orthonormal rows: ``Q`` flattened to ``(Q.size(0), -1)`` satisfies
      ``M @ M.T == I`` (RMS error below 1e-10).
    """
    def rms(t):
        return t.pow(2).mean().sqrt()

    representations = [[(1, 0), (2, 1), (4, 3)]]
    for Rs_in in representations:
        with o3.torch_default_dtype(torch.float64):
            Rs_out, Q = rs.tensor_square(Rs_in, o3.selection_rule, normalization='component', sorted=True)

            angles = o3.rand_angles()
            D_in = rs.rep(Rs_in, *angles)
            D_out = rs.rep(Rs_out, *angles)

            rotated_out = torch.einsum("ijk,il->ljk", (Q, D_out))
            rotated_in = torch.einsum("li,mj,kij->klm", (D_in, D_in, Q))
            self.assertLess(rms(rotated_out - rotated_in) / rms(rotated_out), 1e-10)

            n_out = Q.size(0)
            flat = Q.reshape(n_out, -1)
            identity = torch.eye(n_out)
            self.assertLess(rms(flat @ flat.t() - identity), 1e-10)
def test(Rs, act):
    """Check that ``S2Activation`` commutes with ``rs.rep(..., 0, 0, 0, -1)``.

    All Euler angles are zero and ``-1`` is passed as the parity argument.
    ``Rs`` entries are ``(mul, l, p)`` triples. ``self`` comes from the
    enclosing test method.
    """
    n_components = sum(2 * l + 1 for _, l, _ in Rs)
    x = torch.randn(55, n_components)
    activation = S2Activation(Rs, act, 1000)

    transform_out = rs.rep(activation.Rs_out, 0, 0, 0, -1)
    transform_in = rs.rep(Rs, 0, 0, 0, -1)

    after = activation(x, dim=-1) @ transform_out.T
    before = activation(x @ transform_in.T, dim=-1)
    self.assertLess((after - before).abs().max(), 1e-10)
def test_tensor_product_norm(self):
    """Check the change of basis ``Q`` returned by ``rs.tensor_product``.

    For each pair of input representations (in float64):

    * ``Q`` must be equivariant: applying ``D_out`` on its first index matches
      applying ``D_in1`` and ``D_in2`` on its second and third indices;
    * ``Q`` reshaped to a square ``(n, n)`` matrix must be orthogonal
      (``M @ M.T == I`` and ``M.T @ M == I``).

    All errors are RMS values compared against 1e-10.
    """
    def rms(t):
        return t.pow(2).mean().sqrt()

    cases = [
        ([(1, 0)], [(2, 0)]),
        ([(3, 1), (2, 2)], [(2, 0), (1, 1), (1, 3)]),
    ]
    for Rs_in1, Rs_in2 in cases:
        with o3.torch_default_dtype(torch.float64):
            Rs_out, Q = rs.tensor_product(Rs_in1, Rs_in2, o3.selection_rule)

            angles = torch.rand(3, dtype=torch.float64)
            D_in1 = rs.rep(Rs_in1, *angles)
            D_in2 = rs.rep(Rs_in2, *angles)
            D_out = rs.rep(Rs_out, *angles)

            rotated_out = torch.einsum("ijk,il->ljk", (Q, D_out))
            rotated_in = torch.einsum("li,mj,kij->klm", (D_in1, D_in2, Q))
            self.assertLess(rms(rotated_out - rotated_in) / rms(rotated_out), 1e-10)

            size = Q.size(0)
            square = Q.reshape(size, size)
            identity = torch.eye(size, dtype=square.dtype)
            self.assertLess(rms(square @ square.t() - identity), 1e-10)
            self.assertLess(rms(square.t() @ square - identity), 1e-10)
def rotation_gated_block(self, K):
    """Check that ``GatedBlock(Convolution(...))`` is rotation equivariant.

    The network output is rotated with ``D_out``. The result is compared with
    the output obtained from rotated features (``D_in``) and rotated geometry
    (``o3.rot``). The relative error must stay below 1e-4.
    """
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]
        kernel = partial(K, RadialModel=ConstantRadialModel)

        gate = GatedBlock(Rs_out, scalar_activation=sigmoid, gate_activation=sigmoid)
        conv = Convolution(kernel(Rs_in, gate.Rs_in))

        def network(features, geometry):
            return gate(conv(features, geometry))

        angles = torch.randn(3)
        rot_geo = o3.rot(*angles)
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)

        fea = torch.randn(1, 4, rs.dim(Rs_in))
        geo = torch.randn(1, 4, 3)

        rotate_after = torch.einsum("ij,zaj->zai", (D_out, network(fea, geo)))
        rotate_before = network(
            torch.einsum("ij,zaj->zai", (D_in, fea)),
            torch.einsum("ij,zaj->zai", rot_geo, geo),
        )
        self.assertLess((rotate_after - rotate_before).norm(), 10e-5 * rotate_after.norm())
def test_equivariance_wtp(Rs_in, Rs_out, n_source, n_target, n_edge):
    """Check that ``WTPConv`` is rotation equivariant.

    The plain output is compared with the output obtained from rotated
    features and rotated edge vectors, mapped back through ``D_out``. The
    maximum absolute error must be below 1e-10. Side effect: sets the global
    default dtype to float64.
    """
    torch.set_default_dtype(torch.float64)
    conv = WTPConv(Rs_in, Rs_out, 3, ConstantRadialModel)

    features = rs.randn(n_target, Rs_in)
    sources = torch.randint(n_source, size=(n_edge, ))
    targets = torch.randint(n_target, size=(n_edge, ))
    edge_index = torch.stack([sources, targets])
    size = (n_target, n_source)

    edge_r = torch.randn(n_edge, 3)
    if n_edge > 1:
        # Force one zero-length edge vector (presumably to exercise r == 0).
        edge_r[0] = 0

    reference = conv(features, edge_index, edge_r, size=size)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    R = o3.rot(*angles)
    rotated = conv(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out

    assert (reference - rotated).abs().max() < 1e-10
def test_custom_weighted_tensor_product():
    """Check that ``CustomWeightedTensorProduct`` is rotation equivariant and differentiable.

    Side effect: sets the global default dtype to float64.
    """
    torch.set_default_dtype(torch.float64)
    Rs_in1 = [(20, 1), (4, 2)]
    Rs_in2 = [(10, 0), (10, 1), (4, 2)]
    Rs_out = [(3, 0), (4, 1)]
    # Presumably (index into Rs_in1, index into Rs_in2, index into Rs_out, mode);
    # TODO confirm against CustomWeightedTensorProduct.
    instructions = [
        (0, 1, 0, 'uvw'),
        (1, 2, 1, 'uuu'),
        (0, 1, 1, 'uvw'),
    ]
    tp = CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instructions)

    x1 = rs.randn(20, Rs_in1)
    x2 = rs.randn(20, Rs_in2)
    angles = o3.rand_angles()

    def rep_t(Rs):
        return rs.rep(Rs, *angles).T

    rotate_after = tp(x1, x2) @ rep_t(Rs_out)
    rotate_before = tp(x1 @ rep_t(Rs_in1), x2 @ rep_t(Rs_in2))
    # Backward pass must not raise.
    rotate_after.sum().backward()
    assert torch.allclose(rotate_after, rotate_before)
def parity_rotation_gated_block_parity(self, K):
    """Check that ``GatedBlockParity(Convolution(...))`` is equivariant under an improper rotation.

    Features are transformed with ``rs.rep(..., 1)``, which passes parity 1.
    Geometry is transformed with the negated rotation matrix. The relative
    error must stay below 1e-4.
    """
    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]
        kernel = partial(K, RadialModel=ConstantRadialModel)

        scalars = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, absolute)]
        rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]
        n_gates = 3 * mul
        gates = [(n_gates, 0, +1), (n_gates, 0, -1)], [(n_gates, sigmoid), (n_gates, tanh)]

        act = GatedBlockParity(*scalars, *gates, rs_nonscalars)
        conv = Convolution(kernel(Rs_in, act.Rs_in))

        angles = torch.randn(3)
        rot_geo = -o3.rot(*angles)
        D_in = rs.rep(Rs_in, *angles, 1)
        D_out = rs.rep(act.Rs_out, *angles, 1)

        fea = torch.randn(1, 4, rs.dim(Rs_in))
        geo = torch.randn(1, 4, 3)

        def apply(matrix, tensor):
            return torch.einsum("ij,zaj->zai", matrix, tensor)

        rotate_after = apply(D_out, act(conv(fea, geo)))
        rotate_before = act(conv(apply(D_in, fea), apply(rot_geo, geo)))
        self.assertLess((rotate_after - rotate_before).norm(), 10e-5 * rotate_after.norm())
def check_rotation_parity(batch: int = 10, n_atoms: int = 25):
    """Return whether a ``Convolution`` + ``GatedBlockParity`` network is equivariant.

    Features, geometry and outputs are all transformed by the same random
    rotation combined with parity. The result is a boolean tensor that is
    true when the relative error is below 1e-4.
    """
    kernel = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0, +1)]
    gated = GatedBlockParity(
        Rs_scalars=[(4, 0, +1)],
        act_scalars=[(-1, relu)],
        Rs_gates=[(8, 0, +1)],
        act_gates=[(-1, tanh)],
        Rs_nonscalars=[(4, 1, -1), (4, 2, +1)],
    )
    conv = Convolution(kernel, Rs_in, gated.Rs_in)
    Rs_out = gated.Rs_out

    def network(features, geometry):
        return gated(conv(features, geometry))

    euler = torch.randn(3)
    # Geometry is odd under parity, so an improper rotation negates the matrix.
    rot_geo = -o3.rot(*euler)
    D_in = rs.rep(Rs_in, *euler, parity=1)
    D_out = rs.rep(Rs_out, *euler, parity=1)

    feat = torch.randn(batch, n_atoms, rs.dim(Rs_in))
    geo = torch.randn(batch, n_atoms, 3)

    transformed_output = torch.einsum("ij,zkj->zki", D_out, network(feat, geo))
    output_of_transformed = network(feat @ D_in.t(), geo @ rot_geo.t())
    return (transformed_output - output_of_transformed).norm() < 10e-5 * transformed_output.norm()
def check_rotation(batch: int = 10, n_atoms: int = 25):
    """Return whether a ``Convolution`` + ``GatedBlock`` network is rotation equivariant.

    Features, geometry and outputs are all transformed by the same random
    rotation. The result is a boolean tensor that is true when the relative
    error is below 1e-4.
    """
    kernel = partial(Kernel, RadialModel=ConstantRadialModel)
    Rs_in = [(1, 0), (1, 1)]
    Rs_out = [(1, 0), (1, 1), (1, 2)]
    gated = GatedBlock(
        Rs_out,
        scalar_activation=sigmoid,
        gate_activation=absolute,
    )
    conv = Convolution(kernel, Rs_in, gated.Rs_in)

    def network(features, geometry):
        return gated(conv(features, geometry))

    euler = torch.randn(3)
    rot_geo = o3.rot(*euler)
    D_in = rs.rep(Rs_in, *euler)
    D_out = rs.rep(Rs_out, *euler)

    feat = torch.randn(batch, n_atoms, rs.dim(Rs_in))
    geo = torch.randn(batch, n_atoms, 3)

    transformed_output = torch.einsum("ij,zkj->zki", D_out, network(feat, geo))
    output_of_transformed = network(feat @ D_in.t(), geo @ rot_geo.t())
    return (transformed_output - output_of_transformed).norm() < 10e-5 * transformed_output.norm()
def test(Rs, act):
    """Check ``S2Activation`` (with ``random_rot=True``) against a random rotation with parity 1.

    The maximum error must be below ``3e-4 * max|y|``. ``self`` and ``lmax``
    come from the enclosing scope.
    """
    x = rs.randn(2, Rs)
    activation = S2Activation(Rs, act, 200, lmax_out=lmax + 1, random_rot=True)
    alpha, beta, gamma = torch.rand(3)
    parity = 1

    rotate_after = activation(x) @ rs.rep(activation.Rs_out, alpha, beta, gamma, parity).T
    rotate_before = activation(x @ rs.rep(Rs, alpha, beta, gamma, parity).T)
    self.assertLess((rotate_after - rotate_before).abs().max(), 3e-4 * rotate_after.abs().max())
def test(Rs, ac):
    """Sanity-check the equivariance of ``ac`` against a mismatched input transform.

    ``y1``/``y2`` compare rotating after vs. before ``ac`` with angles
    ``(a, b, 1)``. ``y3`` instead rotates the input with ``(-1, -b, -a)``
    (presumably the inverse rotation). The true equivariance error must be
    smaller than the mismatched one. ``self`` comes from the enclosing test.
    """
    x = torch.randn(99, rs.dim(Rs))
    alpha, beta = torch.rand(2)
    gamma = 1

    def run(v):
        return ac(v, dim=-1)

    y1 = run(x) @ rs.rep(ac.Rs_out, alpha, beta, gamma).T
    y2 = run(x @ rs.rep(Rs, alpha, beta, gamma).T)
    y3 = run(x @ rs.rep(Rs, -gamma, -beta, -alpha).T)
    self.assertLess((y1 - y2).norm(), (y1 - y3).norm())
def test_equivariance(Rs_in, Rs_out, n_source, n_target, n_edge):
    """Check rotation equivariance of ``Convolution`` in three configurations.

    1. A plain ``Kernel`` convolution.
    2. The same convolution called with ``groups=4`` on ``Rs_in * groups`` features.
    3. A convolution built from a ``GroupKernel`` with 4 groups.

    For each configuration, the plain output is compared with the output
    obtained from rotated features and rotated edge vectors, mapped back
    through ``D_out``. The maximum absolute error must be below 1e-10.
    Side effect: sets the global default dtype to float64.
    """
    torch.set_default_dtype(torch.float64)
    groups = 4
    mp = Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))
    mp_group = Convolution(
        GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))

    features = rs.randn(n_target, Rs_in)
    features2 = rs.randn(n_target, Rs_in * groups)
    r_source = torch.randn(n_source, 3)
    r_target = torch.randn(n_target, 3)
    edge_index = torch.stack([
        torch.randint(n_source, size=(n_edge, )),
        torch.randint(n_target, size=(n_edge, )),
    ])
    size = (n_target, n_source)
    # Vector from each edge's source position to its target position. Advanced
    # indexing with an empty index yields a (0, 3) tensor, so n_edge == 0 needs
    # no special case.
    edge_r = r_target[edge_index[1]] - r_source[edge_index[0]]

    out1 = mp(features, edge_index, edge_r, size=size)
    out1_groups = mp(features2, edge_index, edge_r, size=size, groups=groups)
    out1_kernel_groups = mp_group(features2, edge_index, edge_r, size=size, groups=groups)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    D_in_groups = rs.rep(Rs_in * groups, *angles)
    D_out_groups = rs.rep(Rs_out * groups, *angles)
    R = o3.rot(*angles)
    edge_r_rot = edge_r @ R.T

    out2 = mp(features @ D_in.T, edge_index, edge_r_rot, size=size) @ D_out
    out2_groups = mp(features2 @ D_in_groups.T, edge_index, edge_r_rot,
                     size=size, groups=groups) @ D_out_groups
    out2_kernel_groups = mp_group(features2 @ D_in_groups.T, edge_index, edge_r_rot,
                                  size=size, groups=groups) @ D_out_groups

    assert (out1 - out2).abs().max() < 1e-10
    assert (out1_groups - out2_groups).abs().max() < 1e-10
    assert (out1_kernel_groups - out2_kernel_groups).abs().max() < 1e-10
def test_equiv(self):
    """Check that ``Linear`` commutes with rotation of its input and output.

    Side effect: sets the global default dtype to float64.
    """
    torch.set_default_dtype(torch.float64)
    Rs_in = [(5, 0), (15, 1), (5, 0), (10, 2)]
    Rs_out = [(2, 0), (1, 1), (1, 2), (3, 0)]
    lin = Linear(Rs_in, Rs_out)

    f_in = torch.randn(100, rs.dim(Rs_in))
    angles = torch.randn(3)

    def rotate(Rs, f):
        return torch.einsum('ij,zj->zi', rs.rep(Rs, *angles), f)

    rotate_before = lin(rotate(Rs_in, f_in))
    rotate_after = rotate(Rs_out, lin(f_in))
    self.assertLess((rotate_before - rotate_after).abs().max(), 1e-10)
def test(act, normalization):
    """Check ``S2Activation`` (with ``random_rot=True``) for the given normalization.

    Uses a random rotation with parity 1. The maximum error must be below
    ``1e-10 * max|y|``. ``self`` and ``Rs`` come from the enclosing scope.
    """
    x = rs.randn(2, Rs, normalization=normalization)
    activation = S2Activation(Rs, act, 120, normalization=normalization, lmax_out=6, random_rot=True)
    alpha, beta, gamma = o3.rand_angles()
    parity = 1

    rotate_after = activation(x) @ rs.rep(activation.Rs_out, alpha, beta, gamma, parity).T
    rotate_before = activation(x @ rs.rep(Rs, alpha, beta, gamma, parity).T)
    error = (rotate_after - rotate_before).abs().max()
    self.assertLess(error, 1e-10 * rotate_after.abs().max())
def test_equivariance():
    """Check rotation equivariance of ``DepthwiseConvolution`` with plain and grouped kernels.

    For each variant, the plain output is compared with the output obtained
    from rotated features and rotated edge vectors, mapped back through
    ``D_out``. The maximum absolute error must be below 1e-10.
    Side effect: sets the global default dtype to float64.
    """
    torch.set_default_dtype(torch.float64)
    n_edge = 3
    n_source = 4
    n_target = 2
    groups = 4
    Rs_in = [(3, 0), (0, 1)]
    Rs_mid1 = [(5, 0), (1, 1)]
    Rs_mid2 = [(5, 0), (1, 1), (1, 2)]
    Rs_out = [(5, 1), (3, 2)]

    # Factories passed to DepthwiseConvolution (presumably called to build its
    # inner convolutions; TODO confirm).
    def convolution(Rs_in, Rs_out):
        return Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))

    def convolution_groups(Rs_in, Rs_out):
        return Convolution(
            GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))

    mp = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution)
    mp_groups = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution_groups)

    features = rs.randn(n_target, Rs_in)
    r_source = torch.randn(n_source, 3)
    r_target = torch.randn(n_target, 3)
    edge_index = torch.stack([
        torch.randint(n_source, size=(n_edge,)),
        torch.randint(n_target, size=(n_edge,)),
    ])
    size = (n_target, n_source)
    # Vector from each edge's source position to its target position; a
    # vectorized gather (n_edge is a positive constant here, so no empty-edge
    # special case is needed).
    edge_r = r_target[edge_index[1]] - r_source[edge_index[0]]

    out1 = mp(features, edge_index, edge_r, size=size)
    out1_groups = mp_groups(features, edge_index, edge_r, size=size)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles)
    D_out = rs.rep(Rs_out, *angles)
    R = o3.rot(*angles)
    features_rot = features @ D_in.T
    edge_r_rot = edge_r @ R.T

    out2 = mp(features_rot, edge_index, edge_r_rot, size=size) @ D_out
    out2_groups = mp_groups(features_rot, edge_index, edge_r_rot, size=size) @ D_out

    assert (out1 - out2).abs().max() < 1e-10
    assert (out1_groups - out2_groups).abs().max() < 1e-10
def parity_kernel(self, K):
    """Check parity equivariance of a kernel: ``D_out @ k(r)`` must equal ``k(-r) @ D_in``.

    The representations use zero Euler angles with parity argument 1. The
    relative error must stay below 1e-4.
    """
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0, 1), (2, 1, 1), (2, 2, -1)]
        Rs_out = [(2, 0, -1), (2, 1, 1), (2, 2, 1)]
        kernel = K(Rs_in, Rs_out, ConstantRadialModel)
        r = torch.randn(3)

        P_in = rs.rep(Rs_in, 0, 0, 0, 1)
        P_out = rs.rep(Rs_out, 0, 0, 0, 1)

        lhs = P_out @ kernel(r)  # [i, j]
        rhs = kernel(-r) @ P_in  # [i, j]
        self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())
def test_tensor_square_equivariance(self):
    """Check that ``TensorSquare`` commutes with rotation: ``sq(D_in x) == D_out sq(x)``.

    The error must be within ``1e-7`` relative to ``max|sq(D_in x)|``.
    """
    with o3.torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (2, 1), (5, 2)]
        square = TensorSquare(Rs_in, o3.selection_rule)
        x = rs.randn(Rs_in)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(square.Rs_out, *angles)

        rotated_then_squared = square(D_in @ x)
        squared_then_rotated = D_out @ square(x)
        gap = (rotated_then_squared - squared_then_rotated).abs().max()
        self.assertLess(gap, 1e-7 * rotated_then_squared.abs().max())
def rotation_kernel(self, K):
    """Check rotation equivariance of a kernel: ``D_out @ k(r)`` must equal ``k(R r) @ D_in``.

    The relative error must stay below 1e-4.
    """
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]
        kernel = K(Rs_in, Rs_out, ConstantRadialModel)
        r = torch.randn(3)

        angles = torch.randn(3)
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)

        lhs = D_out @ kernel(r)  # [i, j]
        rhs = kernel(o3.rot(*angles) @ r) @ D_in  # [i, j]
        self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())
def find_peaks(self, res=100):
    """Locate local maxima of the signal on the sphere.

    Peaks are searched on two grids: the signal's own grid, and the grid of a
    copy of the signal rotated by Euler angles ``(pi/2, pi/2, pi/2)``. The
    second grid's points are rotated back into the original frame. This is
    presumably done to recover peaks the first grid resolves poorly.

    A first-grid peak closer than ``2 * pi / res`` (Euclidean distance) to any
    second-grid peak is treated as a duplicate and dropped. All second-grid
    peaks are kept.

    :param res: resolution forwarded to ``signal_on_grid``
    :return: ``(x, f)`` -- stacked peak positions and the signal values there
    """
    def gather(points, values, indices):
        # Pick the grid entries at each (i, j) peak index.
        positions = torch.stack([points[i, j] for i, j in indices])
        heights = torch.stack([values[i, j] for i, j in indices])
        return positions, heights

    grid1, sig1 = self.signal_on_grid(res)

    angles = pi / 2, pi / 2, pi / 2
    rot = o3.rot(*angles)
    wigner = rs.rep(self.Rs, *angles)
    rotated = SphericalTensor(wigner @ self.signal)
    rotated_grid, sig2 = rotated.signal_on_grid(res)
    # Map the rotated grid's points back into the original frame.
    grid2 = torch.einsum('ij,baj->bai', rot.T, rotated_grid)

    pos1, val1 = gather(grid1, sig1, _find_peaks_2d(sig1))
    pos2, val2 = gather(grid2, sig2, _find_peaks_2d(sig2))

    close = torch.cdist(pos1, pos2) < 2 * pi / res
    only_in_first = close.sum(1) == 0
    return torch.cat([pos1[only_in_first], pos2]), torch.cat([val1[only_in_first], val2])
def test_equivariance_s2network(self):
    """Check that ``S2Network`` commutes with rotation, up to 1e-3 relative error."""
    with torch_default_dtype(torch.float64):
        mul = 3
        l_max_rs = 3
        Rs_in = [(mul, l) for l in range(l_max_rs + 1)]
        Rs_out = [(mul, l) for l in range(l_max_rs + 1)]
        net = S2Network(Rs_in, mul, lmax=4, Rs_out=Rs_out)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)

        def rotate(matrix, f):
            return torch.einsum("ij,zj->zi", matrix, f)

        fea = torch.randn(10, rs.dim(Rs_in))
        rotate_after = rotate(D_out, net(fea))
        rotate_before = net(rotate(D_in, fea))
        self.assertLess((rotate_after - rotate_before).norm(), 1e-3 * rotate_after.norm())
def test_equivariance_s2parity_network():
    """Check that ``S2ParityNetwork`` commutes with a rotation combined with parity.

    The relative error must stay below 1e-3. Side effect: sets the global
    default dtype to float64.
    """
    torch.set_default_dtype(torch.float64)
    mul = 3
    Rs_in = [(mul, l, -1) for l in range(3 + 1)]
    Rs_out = [(mul, l, 1) for l in range(3 + 1)]
    net = S2ParityNetwork(Rs_in, mul, lmax=3, Rs_out=Rs_out)

    angles = o3.rand_angles()
    D_in = rs.rep(Rs_in, *angles, 1)
    D_out = rs.rep(Rs_out, *angles, 1)

    def transform(matrix, f):
        return torch.einsum("ij,zj->zi", matrix, f)

    fea = rs.randn(10, Rs_in)
    transform_after = transform(D_out, net(fea))
    transform_before = net(transform(D_in, fea))
    assert (transform_after - transform_before).norm() < 1e-3 * transform_after.norm()
def forward(self, features):
    '''
    Apply the activation pointwise on the sphere.

    The input is mapped to a spherical grid with ``self.to_s2``, passed
    through ``self.act``, and mapped back with ``self.from_s2``. When
    ``self.random_rot`` is set, the input is first rotated by random Euler
    angles (via ``rs.rep(self.Rs_in, ...)``). The output is then transformed
    with the transpose of ``rs.rep(self.Rs_out, ...)`` for the same angles.

    :param features: tensor whose last axis holds the ``Rs_in`` coefficients
        (original annotation: ``[..., l * m]``)
    :return: tensor whose last axis holds the ``Rs_out`` coefficients
    '''
    if self.random_rot:
        angles = o3.rand_angles()
        d_in = rs.rep(self.Rs_in, *angles)
        features = torch.einsum('ij,...j->...i', d_in, features)

    on_sphere = self.to_s2(features)  # [..., beta, alpha]
    on_sphere = self.act(on_sphere)
    result = self.from_s2(on_sphere)

    if not self.random_rot:
        return result
    d_out_t = rs.rep(self.Rs_out, *angles).T
    return torch.einsum('ij,...j->...i', d_out_t, result)
def parity_rotation_linear(self, L):
    """Check that a linear layer ``L`` commutes with a rotation combined with parity.

    The relative error must stay below 1e-4.
    """
    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]
        Rs_out = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]
        lin = L(Rs_in, Rs_out)

        angles = torch.randn(3)
        D_in = rs.rep(lin.Rs_in, *angles, 1)
        D_out = rs.rep(lin.Rs_out, *angles, 1)

        def apply(matrix, v):
            return torch.einsum("ij,j->i", matrix, v)

        fea = torch.randn(rs.dim(Rs_in))
        transform_after = apply(D_out, lin(fea))
        transform_before = lin(apply(D_in, fea))
        self.assertLess((transform_after - transform_before).norm(), 10e-5 * transform_after.norm())
def test_weighted_tensor_product():
    """Check that ``WeightedTensorProduct`` (``groups=2``) is rotation equivariant and differentiable.

    Side effect: sets the global default dtype to float64.
    """
    torch.set_default_dtype(torch.float64)
    Rs_in1 = rs.simplify([1] * 20 + [2] * 4)
    Rs_in2 = rs.simplify([0] * 10 + [1] * 10 + [2] * 5)
    Rs_out = rs.simplify([0] * 3 + [1] * 4)
    tp = WeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, groups=2)

    x1 = rs.randn(20, Rs_in1)
    x2 = rs.randn(20, Rs_in2)
    angles = o3.rand_angles()

    def rep_t(Rs):
        return rs.rep(Rs, *angles).T

    rotate_after = tp(x1, x2) @ rep_t(Rs_out)
    rotate_before = tp(x1 @ rep_t(Rs_in1), x2 @ rep_t(Rs_in2))
    # Backward pass must not raise.
    rotate_after.sum().backward()
    assert torch.allclose(rotate_after, rotate_before)
def test_equivariance_gatedconvnetwork(self):
    """Check that ``GatedConvNetwork`` is rotation equivariant, up to 1e-4 relative error."""
    with torch_default_dtype(torch.float64):
        mul = 3
        Rs_in = [(mul, l) for l in range(3 + 1)]
        Rs_out = [(mul, l) for l in range(3 + 1)]
        net = GatedConvNetwork(Rs_in, [(10, 0), (1, 1), (1, 2), (1, 3)], Rs_out)

        angles = torch.randn(3)
        rot_geo = o3.rot(*angles)
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)

        fea = torch.randn(1, 10, rs.dim(Rs_in))
        geo = torch.randn(1, 10, 3)

        def apply(matrix, tensor):
            return torch.einsum("ij,zaj->zai", matrix, tensor)

        rotate_after = apply(D_out, net(fea, geo))
        rotate_before = net(apply(D_in, fea), apply(rot_geo, geo))
        self.assertLess((rotate_after - rotate_before).norm(), 10e-5 * rotate_after.norm())
def test_norm_activation(Rs, normalization, dtype):
    """Check that ``NormActivation`` (with ``swish``) commutes with a random rotation.

    The tolerance depends on ``dtype``: 1e-5 for float32 and 1e-10 for float64.
    """
    tolerances = {torch.float32: 1e-5, torch.float64: 1e-10}
    with o3.torch_default_dtype(dtype):
        activation = NormActivation(Rs, swish, normalization=normalization)
        D = rs.rep(Rs, *o3.rand_angles())
        x = rs.randn(2, Rs, normalization=normalization)

        rotate_after = torch.einsum('ij,zj->zi', D, activation(x))
        rotate_before = activation(torch.einsum('ij,zj->zi', D, x))
        assert (rotate_after - rotate_before).abs().max() < tolerances[dtype]
def test_s2conv_network():
    """Check that ``S2ConvNetwork`` is equivariant under an improper rotation.

    Features are transformed with ``rs.rep(..., 1)`` (parity 1). Geometry is
    transformed with the negated rotation matrix. The output is mapped back
    with ``D.T``. Side effect: sets the global default dtype to float64.
    """
    torch.set_default_dtype(torch.float64)
    lmax = 3
    Rs = [(1, l, 1) for l in range(lmax + 1)]
    model = S2ConvNetwork(Rs, 4, Rs, lmax)

    features = rs.randn(1, 4, Rs)
    geometry = torch.randn(1, 4, 3)
    output = model(features, geometry)

    angles = o3.rand_angles()
    D = rs.rep(Rs, *angles, 1)
    R = -o3.rot(*angles)

    def transform(matrix, tensor):
        return torch.einsum('ij,zaj->zai', matrix, tensor)

    transformed_output = model(transform(D, features), transform(R, geometry))
    output2 = transform(D.T, transformed_output)
    assert (output - output2).abs().max() < 1e-10 * output.abs().max()
def _rep_zx(Rs, dtype, device):
    """Return ``rs.rep(Rs, 0, -pi/2, 0)``.

    Alpha and gamma are zero-valued scalar tensors with the requested
    ``dtype`` and ``device``. Beta is the Python float ``-pi / 2``. The name
    suggests a rotation exchanging the z and x axes -- TODO confirm.
    """
    zero = torch.zeros((), dtype=dtype, device=device)
    beta = -math.pi / 2
    return rs.rep(Rs, zero, beta, zero)