def test_custom_weighted_tensor_product():
    """Check that CustomWeightedTensorProduct is rotation equivariant.

    A tensor product is built from an explicit instruction list and evaluated
    on random inputs. Rotating its output must match feeding it rotated inputs.
    A backward pass is also run to check that gradients can be computed.

    The default dtype is switched to float64 only for the duration of the test
    and restored afterwards, so it does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_in1 = [(20, 1), (4, 2)]
        Rs_in2 = [(10, 0), (10, 1), (4, 2)]
        Rs_out = [(3, 0), (4, 1)]

        # Each entry: (index into Rs_in1, index into Rs_in2, index into Rs_out, mode string)
        instr = [
            (0, 1, 0, 'uvw'),
            (1, 2, 1, 'uuu'),
            (0, 1, 1, 'uvw'),
        ]
        tp = CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr)

        x1 = rs.randn(20, Rs_in1)
        x2 = rs.randn(20, Rs_in2)

        angles = o3.rand_angles()
        z1 = tp(x1, x2) @ rs.rep(Rs_out, *angles).T
        z2 = tp(x1 @ rs.rep(Rs_in1, *angles).T, x2 @ rs.rep(Rs_in2, *angles).T)

        z1.sum().backward()

        assert torch.allclose(z1, z2)
    finally:
        torch.set_default_dtype(previous_dtype)
def test_equivariance(Rs_in, Rs_out, n_source, n_target, n_edge):
    """Check rotation equivariance of Convolution, with and without groups.

    Three variants are compared:
    - a Kernel-based convolution;
    - the same module called with ``groups=groups`` on grouped features;
    - a convolution built on GroupKernel.

    For each one, rotating the inputs (features and edge vectors) must give the
    same result as rotating the output.

    The float64 default dtype is restored afterwards so it does not leak into
    other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        mp = Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))
        groups = 4
        mp_group = Convolution(
            GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))

        features = rs.randn(n_target, Rs_in)
        features2 = rs.randn(n_target, Rs_in * groups)

        r_source = torch.randn(n_source, 3)
        r_target = torch.randn(n_target, 3)

        edge_index = torch.stack([
            torch.randint(n_source, size=(n_edge,)),
            torch.randint(n_target, size=(n_edge,)),
        ])
        size = (n_target, n_source)

        if n_edge == 0:
            # torch.stack rejects an empty list, so build the empty tensor directly.
            edge_r = torch.zeros(0, 3)
        else:
            edge_r = torch.stack([r_target[j] - r_source[i] for i, j in edge_index.T])

        out1 = mp(features, edge_index, edge_r, size=size)
        out1_groups = mp(features2, edge_index, edge_r, size=size, groups=groups)
        out1_kernel_groups = mp_group(features2, edge_index, edge_r, size=size, groups=groups)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)
        D_in_groups = rs.rep(Rs_in * groups, *angles)
        D_out_groups = rs.rep(Rs_out * groups, *angles)
        R = o3.rot(*angles)

        # Rotate the inputs, then right-multiply by D_out to rotate the output
        # back. This assumes the representation matrices are orthogonal. The
        # result should reproduce the unrotated outputs above.
        rotated_edge_r = edge_r @ R.T
        out2 = mp(features @ D_in.T, edge_index, rotated_edge_r, size=size) @ D_out
        out2_groups = mp(features2 @ D_in_groups.T, edge_index, rotated_edge_r,
                         size=size, groups=groups) @ D_out_groups
        out2_kernel_groups = mp_group(features2 @ D_in_groups.T, edge_index, rotated_edge_r,
                                      size=size, groups=groups) @ D_out_groups

        assert (out1 - out2).abs().max() < 1e-10
        assert (out1_groups - out2_groups).abs().max() < 1e-10
        assert (out1_kernel_groups - out2_kernel_groups).abs().max() < 1e-10
    finally:
        torch.set_default_dtype(previous_dtype)
def test_learnable_tensor_product_normalization():
    """LearnableTensorProduct should produce outputs with variance of order one.

    The test passes when log10 of the output variance lies within 1.5 of zero.
    """
    Rs_in1 = [2, 0, 4]
    Rs_in2 = [2, 3]
    Rs_out = [0, 2, 4, 5]
    module = LearnableTensorProduct(Rs_in1, Rs_in2, Rs_out)

    batch = 1000
    inputs1 = rs.randn(batch, Rs_in1)
    inputs2 = rs.randn(batch, Rs_in2)
    variance = module(inputs1, inputs2).var()

    assert variance.log10().abs() < 1.5, variance.item()
def test_tensor_product_in_in_normalization_norm(Rs_in1, Rs_in2):
    """With 'norm' normalization, tensor-product outputs should have norms near one.

    The values produced by Norm are averaged over the batch. Each average must
    be within one order of magnitude of one.
    """
    with o3.torch_default_dtype(torch.float64):
        product = rs.TensorProduct(Rs_in1, Rs_in2, o3.selection_rule, normalization='norm')

        left = rs.randn(10, Rs_in1, normalization='norm')
        right = rs.randn(10, Rs_in2, normalization='norm')

        norm_layer = Norm(product.Rs_out, normalization='norm')
        mean_norms = norm_layer(product(left, right)).mean(0)

        within_one_decade = mean_norms.log10().abs() < 1
        assert within_one_decade.all()
def test_equivariance_wtp(Rs_in, Rs_out, n_source, n_target, n_edge):
    """Check rotation equivariance of WTPConv on a random graph.

    When there is more than one edge, the first edge vector is set to zero,
    presumably to exercise the zero-length-edge code path.

    The float64 default dtype is restored afterwards so it does not leak into
    other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        mp = WTPConv(Rs_in, Rs_out, 3, ConstantRadialModel)

        features = rs.randn(n_target, Rs_in)

        edge_index = torch.stack([
            torch.randint(n_source, size=(n_edge,)),
            torch.randint(n_target, size=(n_edge,)),
        ])
        size = (n_target, n_source)

        edge_r = torch.randn(n_edge, 3)
        if n_edge > 1:
            edge_r[0] = 0

        out1 = mp(features, edge_index, edge_r, size=size)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)
        R = o3.rot(*angles)

        # Right-multiplying by D_out rotates the output back (assuming the
        # representation matrices are orthogonal), so out2 should equal out1.
        out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out

        assert (out1 - out2).abs().max() < 1e-10
    finally:
        torch.set_default_dtype(previous_dtype)
def parity_gated_block_parity(self, K):
    """Test parity equivariance on GatedBlockParity and dependencies.

    A convolution built from ``K`` feeds a GatedBlockParity activation. The
    result is compared under the transformation rs.rep(..., 0, 0, 0, 1), which
    uses zero angles and parity argument 1, paired with negated geometry.
    """
    with torch_default_dtype(torch.float64):
        mul = 2
        Rs_in = [(mul, l, p) for l in range(3 + 1) for p in [-1, 1]]

        K = partial(K, RadialModel=ConstantRadialModel)

        # Even scalars go through relu, odd scalars through absolute.
        scalar_args = [(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, absolute)]
        rs_nonscalars = [(mul, l, p) for l in (1, 2, 3) for p in (+1, -1)]

        # 3 * mul matches the number of non-scalar channels of each parity
        # (l = 1, 2, 3, each with multiplicity mul).
        gate_count = 3 * mul
        gate_args = [(gate_count, 0, +1), (gate_count, 0, -1)], [(gate_count, sigmoid), (gate_count, tanh)]

        act = GatedBlockParity(*scalar_args, *gate_args, rs_nonscalars)
        conv = Convolution(K(Rs_in, act.Rs_in))

        D_in = rs.rep(Rs_in, 0, 0, 0, 1)
        D_out = rs.rep(act.Rs_out, 0, 0, 0, 1)

        fea = rs.randn(1, 3, Rs_in)
        geo = torch.randn(1, 3, 3)

        transformed_output = torch.einsum("ij,zaj->zai", (D_out, act(conv(fea, geo))))
        output_of_transformed = act(conv(torch.einsum("ij,zaj->zai", (D_in, fea)), -geo))

        # NOTE(review): 10e-5 equals 1e-4; kept as in the original.
        self.assertLess((transformed_output - output_of_transformed).norm(),
                        10e-5 * transformed_output.norm())
def test(Rs, act):
    """Check that S2Activation commutes with a rotation combined with parity (p = 1).

    The angles are drawn with torch.rand.

    NOTE(review): this helper reads ``lmax`` and ``self`` from an enclosing
    scope that is not visible here.
    """
    x = rs.randn(2, Rs)
    ac = S2Activation(Rs, act, 200, lmax_out=lmax + 1, random_rot=True)

    a, b, c = torch.rand(3)
    p = 1

    y1 = ac(x) @ rs.rep(ac.Rs_out, a, b, c, p).T
    y2 = ac(x @ rs.rep(Rs, a, b, c, p).T)

    max_error = (y1 - y2).abs().max()
    self.assertLess(max_error, 3e-4 * y1.abs().max())
def test_learnable_tensor_square_normalization():
    """LearnableTensorSquare should produce outputs with variance of order one."""
    Rs_in = [1, 2, 3, 4]
    Rs_out = [0, 2, 4, 5]
    square = LearnableTensorSquare(Rs_in, Rs_out)

    samples = rs.randn(1000, Rs_in)
    variance = square(samples).var()

    assert variance.log10().abs() < 1.5, variance.item()
def test_that_it_runs(Rs, p, dtype):
    """Dropout either zeroes or rescales entries in training and is the identity in eval."""
    x = rs.randn(10, Rs, dtype=dtype)
    dropout = Dropout(Rs, p=p)

    # Training mode: every entry is dropped (0) or scaled by 1 / (1 - p).
    dropout.train(True)
    y = dropout(x)
    kept = y == x / (1 - p)
    dropped = y == 0
    assert (kept | dropped).all()

    # Evaluation mode: the input must pass through unchanged.
    dropout.train(False)
    assert (dropout(x) == x).all()
def test_tensor_product_equal_TensorProduct():
    """rs.TensorProduct must agree with the explicit matrix returned by rs.tensor_product."""
    with o3.torch_default_dtype(torch.float64):
        Rs_1 = [(3, 0), (2, 1), (5, 2)]
        Rs_2 = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]

        Rs_out, m = rs.tensor_product(Rs_1, Rs_2, o3.selection_rule, sorted=True)
        mul = rs.TensorProduct(Rs_1, Rs_2, o3.selection_rule)

        x1 = rs.randn(1, Rs_1)
        x2 = rs.randn(1, Rs_2)

        y1 = mul(x1, x2)

        # Build the outer product as (i, j, batch), flatten the first two axes,
        # apply m, then transpose back to (batch, out).
        outer = torch.einsum('zi,zj->ijz', x1, x2)
        flat = outer.reshape(rs.dim(Rs_1) * rs.dim(Rs_2), -1)
        y2 = (m @ flat).T

        assert rs.dim(Rs_out) == y1.shape[1]
        assert (y1 - y2).abs().max() < 1e-10 * y1.abs().max()
def test_weighted_tensor_product():
    """Check rotation equivariance of WeightedTensorProduct built with groups=2.

    Also runs a backward pass to check that gradients can be computed.

    The float64 default dtype is restored afterwards so it does not leak into
    other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_in1 = rs.simplify([1] * 20 + [2] * 4)
        Rs_in2 = rs.simplify([0] * 10 + [1] * 10 + [2] * 5)
        Rs_out = rs.simplify([0] * 3 + [1] * 4)

        tp = WeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, groups=2)

        x1 = rs.randn(20, Rs_in1)
        x2 = rs.randn(20, Rs_in2)

        angles = o3.rand_angles()
        z1 = tp(x1, x2) @ rs.rep(Rs_out, *angles).T
        z2 = tp(x1 @ rs.rep(Rs_in1, *angles).T, x2 @ rs.rep(Rs_in2, *angles).T)

        z1.sum().backward()

        assert torch.allclose(z1, z2)
    finally:
        torch.set_default_dtype(previous_dtype)
def test_normalization(self):
    """Random coefficients mapped onto the S2 grid should give values with variance near one.

    Both 'component' and 'norm' normalizations are checked, with a tolerance of 0.2.
    """
    with o3.torch_default_dtype(torch.float64):
        lmax = 5
        res = (20, 30)

        for normalization in ('component', 'norm'):
            to_grid = s2grid.ToS2Grid(lmax, res, normalization=normalization)
            coefficients = rs.randn(50, [(1, l) for l in range(lmax + 1)], normalization=normalization)
            grid_values = to_grid(coefficients)
            self.assertAlmostEqual(grid_values.var().item(), 1, delta=0.2)
def test_tensor_product_left_right():
    """TensorProduct gives the same result however its inputs are supplied.

    The reference is ``mul(x1, x2)``. It is compared against:
    - a single pre-computed outer-product argument;
    - each setting of the private ``_complete`` switch ('in1', then 'in2').
    """
    with o3.torch_default_dtype(torch.float64):
        Rs_1 = [(3, 0), (2, 1), (5, 2)]
        Rs_2 = [(1, 0), (2, 1), (2, 2), (2, 0), (2, 1), (1, 2)]
        mul = rs.TensorProduct(Rs_1, Rs_2, o3.selection_rule)

        x1 = rs.randn(2, Rs_1)
        x2 = rs.randn(2, Rs_2)

        reference = mul(x1, x2)
        tolerance = 1e-10 * reference.abs().max()

        candidate = mul(torch.einsum('zi,zj->zij', x1, x2))
        assert (reference - candidate).abs().max() < tolerance

        for side in ('in1', 'in2'):
            mul._complete = side
            candidate = mul(x1, x2)
            assert (reference - candidate).abs().max() < tolerance
def test(act, normalization):
    """Check that S2Activation commutes with a random rotation (parity argument 1).

    NOTE(review): ``Rs`` and ``self`` come from an enclosing scope that is not
    visible here.
    """
    x = rs.randn(2, Rs, normalization=normalization)
    ac = S2Activation(Rs, act, 120, normalization=normalization, lmax_out=6, random_rot=True)

    transform = (*o3.rand_angles(), 1)
    y1 = ac(x) @ rs.rep(ac.Rs_out, *transform).T
    y2 = ac(x @ rs.rep(Rs, *transform).T)

    max_error = (y1 - y2).abs().max()
    self.assertLess(max_error, 1e-10 * y1.abs().max())
def test_equivariance():
    """Check rotation equivariance of DepthwiseConvolution.

    Two variants are checked: one built from plain Kernel convolutions and one
    built from GroupKernel convolutions.

    The float64 default dtype is restored afterwards so it does not leak into
    other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        n_edge = 3
        n_source = 4
        n_target = 2

        Rs_in = [(3, 0), (0, 1)]
        Rs_mid1 = [(5, 0), (1, 1)]
        Rs_mid2 = [(5, 0), (1, 1), (1, 2)]
        Rs_out = [(5, 1), (3, 2)]

        # Defined before the factories below, which capture it by closure.
        groups = 4

        def convolution(Rs_in, Rs_out):
            return Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))

        def convolution_groups(Rs_in, Rs_out):
            return Convolution(
                GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))

        mp = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution)
        mp_groups = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution_groups)

        features = rs.randn(n_target, Rs_in)

        r_source = torch.randn(n_source, 3)
        r_target = torch.randn(n_target, 3)

        edge_index = torch.stack([
            torch.randint(n_source, size=(n_edge,)),
            torch.randint(n_target, size=(n_edge,)),
        ])
        size = (n_target, n_source)

        if n_edge == 0:
            # torch.stack rejects an empty list, so build the empty tensor directly.
            edge_r = torch.zeros(0, 3)
        else:
            edge_r = torch.stack([r_target[j] - r_source[i] for i, j in edge_index.T])

        out1 = mp(features, edge_index, edge_r, size=size)
        out1_groups = mp_groups(features, edge_index, edge_r, size=size)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)
        R = o3.rot(*angles)

        # Right-multiplying by D_out rotates the output back (assuming the
        # representation matrices are orthogonal), so these should equal out1*.
        out2 = mp(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out
        out2_groups = mp_groups(features @ D_in.T, edge_index, edge_r @ R.T, size=size) @ D_out

        assert (out1 - out2).abs().max() < 1e-10
        assert (out1_groups - out2_groups).abs().max() < 1e-10
    finally:
        torch.set_default_dtype(previous_dtype)
def test_image_network():
    """Smoke test: ImageS2Network runs a forward pass on a random 16x16x16 input.

    The float64 default dtype is restored afterwards so it does not leak into
    other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs = [0, 0, 3]
        model = ImageS2Network(Rs_in=Rs, mul=4, lmax=6, Rs_out=Rs, size=5, layers=3)

        image = rs.randn(1, 16, 16, 16, Rs)
        model(image)
    finally:
        torch.set_default_dtype(previous_dtype)
def test_inverse(self):
    """ToS2Grid and FromS2Grid round-trip within 1e-5, in both directions.

    Both 'component' and 'norm' normalizations are checked.
    """
    with o3.torch_default_dtype(torch.float64):
        lmax = 5
        res = (50, 75)

        for normalization in ('component', 'norm'):
            to = s2grid.ToS2Grid(lmax, res, normalization=normalization)
            fr = s2grid.FromS2Grid(res, lmax, normalization=normalization)

            # coefficients -> grid -> coefficients
            sig = rs.randn(10, [(1, l) for l in range(lmax + 1)])
            coeff_round_trip = fr(to(sig))
            self.assertLess((coeff_round_trip - sig).abs().max(), 1e-5)

            # grid -> coefficients -> grid
            s = to(sig)
            grid_round_trip = to(fr(s))
            self.assertLess((grid_round_trip - s).abs().max(), 1e-5)
def test_image_gated_conv_parity_network():
    """Smoke test: ImageGatedConvParityNetwork runs a forward pass on a random 16x16x16 input.

    The float64 default dtype is restored afterwards so it does not leak into
    other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs = [(2, 0, 1), (1, 1, -1)]
        model = ImageGatedConvParityNetwork(Rs_in=Rs, mul=3, Rs_out=Rs, size=5, lmax=6, layers=3)

        image = rs.randn(1, 16, 16, 16, Rs)
        model(image)
    finally:
        torch.set_default_dtype(previous_dtype)
def test_image_gated_conv_network():
    """Smoke test: ImageGatedConvNetwork runs a forward pass on a random 16x16x16 input.

    The float64 default dtype is restored afterwards so it does not leak into
    other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs = [0, 0, 3]
        model = ImageGatedConvNetwork(Rs_in=Rs, Rs_hidden=[0, 1, 2, 3], Rs_out=Rs, size=5, lmax=6, layers=3)

        image = rs.randn(1, 16, 16, 16, Rs)
        model(image)
    finally:
        torch.set_default_dtype(previous_dtype)
def test_normalization(fuzzy_pixels):
    """Convolution on random input keeps the output variance within 1.5 decades of one.

    Also checks that the last output axis has size rs.dim(Rs_out).
    """
    batch, size, input_size = 3, 5, 15
    Rs_in = [(20, 0), (20, 1), (10, 2)]
    Rs_out = [0, 1, 2]

    conv = Convolution(Rs_in, Rs_out, size, lmax=2, fuzzy_pixels=fuzzy_pixels)

    volume = rs.randn(batch, input_size, input_size, input_size, Rs_in)
    result = conv(volume)

    assert result.shape[-1] == rs.dim(Rs_out)
    assert result.var().log10().abs() < 1.5
def test_custom_weighted_tensor_product2():
    """Two CustomWeightedTensorProduct configurations must agree on the same weights.

    ``tp1`` uses the default configuration and owns its weights. ``tp2`` is
    built with ``own_weight=False, _specialized_code=False`` and receives
    ``tp1.weight`` explicitly. Both use every 'uvw' instruction allowed by the
    triangle condition |l1 - l2| <= l3 <= l1 + l2.

    The float64 default dtype is restored afterwards so it does not leak into
    other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_in1 = [(2, l) for l in [0, 1, 2]]
        Rs_in2 = [(3, l) for l in [0, 1, 2]]
        Rs_out = [(2, l) for l in [0, 1, 2]]

        instr = [
            (i1, i2, i3, 'uvw')
            for i1, (_, l1) in enumerate(Rs_in1)
            for i2, (_, l2) in enumerate(Rs_in2)
            for i3, (_, l3) in enumerate(Rs_out)
            if abs(l1 - l2) <= l3 <= l1 + l2
        ]

        tp1 = CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr)
        tp2 = CustomWeightedTensorProduct(Rs_in1, Rs_in2, Rs_out, instr,
                                          own_weight=False, _specialized_code=False)

        x1 = rs.randn(2, Rs_in1)
        x2 = rs.randn(2, Rs_in2)

        assert torch.allclose(tp1(x1, x2), tp2(x1, x2, tp1.weight))
    finally:
        torch.set_default_dtype(previous_dtype)
def test_tensor_square_equivariance(self):
    """TensorSquare commutes with rotations: sq(D_in @ x) matches D_out @ sq(x)."""
    with o3.torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (2, 1), (5, 2)]
        sq = TensorSquare(Rs_in, o3.selection_rule)

        x = rs.randn(Rs_in)
        abc = o3.rand_angles()
        D_in = rs.rep(Rs_in, *abc)
        D_out = rs.rep(sq.Rs_out, *abc)

        rotate_first = sq(D_in @ x)
        rotate_last = D_out @ sq(x)

        max_error = (rotate_first - rotate_last).abs().max()
        self.assertLess(max_error, 1e-7 * rotate_first.abs().max())
def test_equivariance_s2parity_network():
    """Check equivariance of S2ParityNetwork under a random rotation with parity argument 1.

    The float64 default dtype is restored afterwards so it does not leak into
    other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        mul = 3
        Rs_in = [(mul, l, -1) for l in range(3 + 1)]
        Rs_out = [(mul, l, 1) for l in range(3 + 1)]

        net = S2ParityNetwork(Rs_in, mul, lmax=3, Rs_out=Rs_out)

        abc = o3.rand_angles()
        D_in = rs.rep(Rs_in, *abc, 1)
        D_out = rs.rep(Rs_out, *abc, 1)

        fea = rs.randn(10, Rs_in)

        x1 = torch.einsum("ij,zj->zi", D_out, net(fea))
        x2 = net(torch.einsum("ij,zj->zi", D_in, fea))
        assert (x1 - x2).norm() < 1e-3 * x1.norm()
    finally:
        torch.set_default_dtype(previous_dtype)
def test_inverse_different_ls(self):
    """Going to the grid at lmax=5 and back at lmax=7 recovers the original coefficients.

    Only the leading entries are compared, i.e. as many as the input has.
    All three normalizations are checked.
    """
    with o3.torch_default_dtype(torch.float64):
        lin, lout = 5, 7
        res = (50, 60)

        for normalization in ('component', 'norm', 'none'):
            to = s2grid.ToS2Grid(lin, res, normalization=normalization)
            fr = s2grid.FromS2Grid(res, lout, lmax_in=lin, normalization=normalization)

            si = rs.randn(10, [(1, l) for l in range(lin + 1)])
            n_in = si.shape[1]
            so = fr(to(si))[:, :n_in]
            self.assertLess((so - si).abs().max(), 1e-5)
def test_norm_activation(Rs, normalization, dtype):
    """NormActivation with swish commutes with a random rotation.

    The tolerance depends on ``dtype``: 1e-5 for float32, 1e-10 for float64.
    """
    with o3.torch_default_dtype(dtype):
        m = NormActivation(Rs, swish, normalization=normalization)
        D = rs.rep(Rs, *o3.rand_angles())
        x = rs.randn(2, Rs, normalization=normalization)

        rotate_after = torch.einsum('ij,zj->zi', D, m(x))
        rotate_before = m(torch.einsum('ij,zj->zi', D, x))

        tolerance = {torch.float32: 1e-5, torch.float64: 1e-10}[dtype]
        assert (rotate_after - rotate_before).abs().max() < tolerance
def _test_normalization(self, f):
    """Shared check: a convolution built by ``f(Rs_in, Rs_out, size)`` is roughly normalized.

    On random input, the output must have axis 1 of size rs.dim(Rs_out),
    a mean within 0.3 of zero, and a standard deviation within 0.5 of one.
    """
    batch, size, input_size = 3, 5, 15
    Rs_in = [(20, 0), (20, 1), (10, 2)]
    Rs_out = [(2, 0), (2, 1), (2, 2)]

    conv = f(Rs_in, Rs_out, size)

    volume = rs.randn(batch, Rs_in, input_size, input_size, input_size)
    result = conv(volume)

    self.assertEqual(result.size(1), rs.dim(Rs_out))
    self.assertAlmostEqual(result.mean().item(), 0, delta=0.3)
    self.assertAlmostEqual(result.std().item(), 1, delta=0.5)
def test_s2conv_network():
    """Check equivariance of S2ConvNetwork under a random rotation with parity argument 1.

    The features are transformed with rs.rep(..., 1). The geometry is
    transformed with the negated rotation matrix.

    The float64 default dtype is restored afterwards so it does not leak into
    other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        lmax = 3
        Rs = [(1, l, 1) for l in range(lmax + 1)]
        model = S2ConvNetwork(Rs, 4, Rs, lmax)

        features = rs.randn(1, 4, Rs)
        geometry = torch.randn(1, 4, 3)

        output = model(features, geometry)

        angles = o3.rand_angles()
        D = rs.rep(Rs, *angles, 1)
        R = -o3.rot(*angles)

        ein = torch.einsum
        transformed_features = ein('ij,zaj->zai', D, features)
        transformed_geometry = ein('ij,zaj->zai', R, geometry)
        # Apply D.T to undo the feature transform on the output.
        output2 = ein('ij,zaj->zai', D.T, model(transformed_features, transformed_geometry))

        assert (output - output2).abs().max() < 1e-10 * output.abs().max()
    finally:
        torch.set_default_dtype(previous_dtype)
def test_norm(Rs, normalization):
    """Norm maps a batch of two signals to shape (2, rs.mul_dim(Rs))."""
    norm_layer = Norm(Rs, normalization=normalization)
    signals = rs.randn(2, Rs, normalization=normalization)
    result = norm_layer(signals)
    assert result.shape == (2, rs.mul_dim(Rs))
# Define the input and output representations
Rs_in = [(1, 0), (2, 1)]  # Input = one scalar plus two vectors
Rs_out = [(1, 1)]  # Output = one single vector

# Radial model: R+ -> R^d
RadialModel = partial(GaussianRadialModel, max_radius=3.0, number_of_basis=3, h=100, L=1, act=swish)

# Kernel: composed of a radial part that contains the learned parameters
# and an angular part given by the spherical harmonics and the Clebsch-Gordan coefficients
K = partial(Kernel, RadialModel=RadialModel)

# Create the convolution module
conv = Convolution(K(Rs_in, Rs_out))

# Module to compute the norm of each irreducible component
norm = Norm(Rs_out)

n = 5  # number of input points
features = rs.randn(1, n, Rs_in, requires_grad=True)
in_geometry = torch.randn(1, n, 3)
out_geometry = torch.zeros(1, 1, 3)  # One point at the origin

out = norm(conv(features, in_geometry, out_geometry))
# backward() with no gradient argument requires a single-element output.
# With one output point and a single output vector, `out` presumably has
# exactly one element.
out.backward()

print(out)
print(features.grad)
# NOTE(review): Rs_in and Rs_out are not defined in this snippet. They must
# come from earlier in the file (e.g. a previous example); confirm.

# Radial model: R+ -> R^d
RadialModel = partial(GaussianRadialModel, max_radius=3.0, number_of_basis=3, h=100, L=1, act=swish)

# Kernel: composed of a radial part that contains the learned parameters
# and an angular part given by the spherical harmonics and the Clebsch-Gordan coefficients
K = partial(Kernel, RadialModel=RadialModel, normalization='norm')

# Use the kernel to define a convolution operation
C = partial(Convolution, K)

# Create the convolution module
conv = C(Rs_in, Rs_out)

# Module to compute the norm of each irreducible component
norm = Norm(Rs_out, normalization='norm')

n = 5  # number of input points
features = rs.randn(1, n, Rs_in, normalization='norm', requires_grad=True)
in_geometry = torch.randn(1, n, 3)
out_geometry = torch.zeros(1, 1, 3)  # One point at the origin

norm(conv(features, in_geometry, out_geometry)).backward()

print(features)
print(features.grad)