def test1(self):
    """Test gradients of the Kernel.

    Builds a Kernel whose selection rule is ``o3.selection_rule_in_out_sh``
    restricted to ``lmax=1``, counts how many radial outputs it needs
    (one per (mul_out, mul_in, l_filter) path), and runs
    ``torch.autograd.gradcheck`` on ``KernelFn.apply`` with both the spherical
    harmonics ``Y`` and the radial values ``R`` requiring grad.

    gradcheck is only reliable in double precision, so the default dtype is
    switched to float64 for the duration of the test and then restored so the
    setting does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_in = [(1, 0), (1, 1), (1, 0), (1, 2)]
        Rs_out = [(1, 0), (1, 1), (1, 2), (1, 0)]
        kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel, partial(o3.selection_rule_in_out_sh, lmax=1))

        # Number of radial outputs the kernel consumes.
        n_path = 0
        for mul_out, l_out, p_out in kernel.Rs_out:
            for mul_in, l_in, p_in in kernel.Rs_in:
                l_filters = kernel.selection_rule(l_in, p_in, l_out, p_out)
                n_path += mul_out * mul_in * len(l_filters)

        r = torch.randn(2, 3)
        Y = rsh.spherical_harmonics_xyz(kernel.set_of_l_filters, r)  # [l_filter * m_filter, batch]
        Y = Y.clone().detach().requires_grad_(True)
        R = torch.randn(2, n_path, requires_grad=True)  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

        inputs = (Y, R, kernel.norm_coef, kernel.Rs_in, kernel.Rs_out, kernel.selection_rule, kernel.set_of_l_filters)
        self.assertTrue(torch.autograd.gradcheck(KernelFn.apply, inputs))
    finally:
        torch.set_default_dtype(previous_dtype)
def test1(self):
    """Test KernelFn gradients for several requires_grad combinations of Y and R.

    For each of (Y, R), (Y only), (R only) requiring grad, random vectors are
    drawn, their spherical harmonics and random radial values are built, and
    ``torch.autograd.gradcheck`` is run on ``KernelFn.apply``. The (False,
    False) combination is omitted because gradcheck needs at least one input
    that requires grad.

    The default dtype is switched to float64 for the test (gradcheck needs
    double precision) and restored afterwards so it does not leak into other
    tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        Rs_in = [(1, 0), (1, 1), (2, 0), (1, 2)]
        Rs_out = [(2, 0), (1, 1), (1, 2), (3, 0)]
        kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel)

        # Number of radial outputs: one per (mul_out, mul_in, l_filter) path.
        n_path = 0
        for mul_out, l_out, p_out in kernel.Rs_out:
            for mul_in, l_in, p_in in kernel.Rs_in:
                l_filters = kernel.get_l_filters(l_in, p_in, l_out, p_out)
                n_path += mul_out * mul_in * len(l_filters)

        for rg_Y, rg_R in [(True, True), (True, False), (False, True)]:
            r = torch.randn(2, 3)
            radii = r.norm(2, dim=1)  # [batch]

            Y = kernel.sh(kernel.set_of_l_filters, r)  # [l_filter * m_filter, batch]
            Y = Y.clone().detach().requires_grad_(rg_Y)
            R = torch.randn(2, n_path, requires_grad=rg_R)  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

            # Pick norm_coef[..., 1] for zero-length vectors and norm_coef[..., 0] otherwise.
            norm_coef = kernel.norm_coef
            norm_coef = norm_coef[:, :, (radii == 0).type(torch.long)]  # [l_out, l_in, batch]

            inputs = (Y, R, norm_coef, kernel.Rs_in, kernel.Rs_out, kernel.get_l_filters, kernel.set_of_l_filters)
            self.assertTrue(torch.autograd.gradcheck(KernelFn.apply, inputs))
    finally:
        torch.set_default_dtype(previous_dtype)
def test_compare_forward(self):
    """Check that the default and custom-backward Kernel forwards agree.

    For each normalization, two kernels are constructed right after reseeding
    the RNG with the same seed. Their parameters are checked to be equal
    before comparing the forward output of the first kernel with the output
    of the second kernel called with ``custom_backward=True``.
    """
    for normalization in ["norm", "component"]:
        torch.manual_seed(0)
        K = Kernel(self.Rs_in, self.Rs_out, RadialModel=ConstantRadialModel, normalization=normalization)
        new_features = K(self.geometry)

        torch.manual_seed(0)
        K2 = Kernel(self.Rs_in, self.Rs_out, RadialModel=ConstantRadialModel, normalization=normalization)
        check_new_features = K2(self.geometry, custom_backward=True)

        # Compare K against K2 (not K against itself). Check the counts first,
        # because zip would silently ignore extra parameters.
        params, params2 = list(K.parameters()), list(K2.parameters())
        assert len(params) == len(params2), self.msg
        assert all(torch.all(a == b) for a, b in zip(params, params2)), self.msg
        self.assertTrue(torch.allclose(new_features, check_new_features))
def test_flow():
    """
    This test checks that information is flowing as expected from target to source.
    edge_index[0] is source (convolution center)
    edge_index[1] is target (neighbors)
    """
    edge_index = torch.LongTensor([
        [0, 0, 0, 0],
        [1, 2, 3, 4],
    ])
    features = torch.tensor([-1., 1., 1., 1., 1.])
    features = features.unsqueeze(-1)
    edge_r = torch.ones(edge_index.shape[-1], 3)

    Rs = [0]
    conv = Convolution(Kernel(Rs, Rs, ConstantRadialModel))
    conv.kernel.R.weight.data.fill_(1.)  # Fix weight to 1.

    # Node 0 is the source of every edge, so only node 0 receives the
    # neighbours' features.
    output = conv(features, edge_index, edge_r)
    expected = torch.tensor([4., 0., 0., 0., 0.]).unsqueeze(-1)
    # The allclose result must be asserted; a bare call can never fail the test.
    assert torch.allclose(output, expected), output

    # Reversed edges: nodes 1..4 are now sources and each receives node 0's feature.
    edge_index = torch.LongTensor([[1, 2, 3, 4], [0, 0, 0, 0]])
    output = conv(features, edge_index, edge_r)
    expected = torch.tensor([0., -1., -1., -1., -1.]).unsqueeze(-1)
    assert torch.allclose(output, expected), output
def test1(self):
    """Test rotation equivariance of a Convolution followed by a GatedBlock.

    Rotating the input features with D_in and the geometry with the rotation
    matrix must give the same result as rotating the output with D_out.
    """
    with torch_default_dtype(torch.float64):
        Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
        Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

        f = GatedBlock(Rs_out, rescaled_act.Softplus(beta=5), rescaled_act.sigmoid)
        c = Convolution(Kernel(Rs_in, f.Rs_in, ConstantRadialModel))

        abc = torch.randn(3)
        D_in = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_in for _ in range(mul)])
        D_out = o3.direct_sum(*[o3.irr_repr(l, *abc) for mul, l in Rs_out for _ in range(mul)])

        x = torch.randn(1, 5, sum(mul * (2 * l + 1) for mul, l in Rs_in))  # [batch, point, feature]
        geo = torch.randn(1, 5, 3)  # [batch, point, xyz]

        rx = torch.einsum("ij,zaj->zai", (D_in, x))
        rgeo = geo @ o3.rot(*abc).t()

        y = f(c(x, geo), dim=2)
        ry = torch.einsum("ij,zaj->zai", (D_out, y))

        # Use the same feature axis (dim=2) in the rotated pass as in the
        # unrotated pass, so the two outputs are directly comparable.
        self.assertLess((f(c(rx, rgeo), dim=2) - ry).norm(), 1e-10 * ry.norm())
def __init__(self, Rs_in, Rs_out):
    """Build the module's weight matrix from a Kernel with a restricted selection rule.

    The selection rule allows only the l=0 filter, and only when the input
    and output orders are equal.
    """
    super().__init__()

    # hack: use Kernel to construct the weight matrix
    def _scalar_filter_for_equal_orders(l_in, l_out):
        if l_in != l_out:
            return []
        return [0]

    self.kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel, _scalar_filter_for_equal_orders)
def __init__(self, Rs_in, Rs_out, size, **kwargs):
    """Create a cosine-basis Kernel and a cubic sampling grid on [-1, 1]^3.

    Args:
        Rs_in, Rs_out: input/output representations passed to Kernel.
        size: number of grid points per axis; also sets the number of radial
            basis functions to ``(size + 1) // 2``.
        **kwargs: stored unchanged on ``self.kwargs``.
    """
    super().__init__()

    radial_model = partial(
        CosineBasisModel,
        max_radius=1.0,
        number_of_basis=(size + 1) // 2,
        h=50,
        L=3,
        act=rescaled_act.relu,
    )
    self.kernel = Kernel(Rs_in, Rs_out, radial_model, normalization='component')

    axis = torch.linspace(-1, 1, size)
    grid = torch.meshgrid(axis, axis, axis)
    self.r = torch.stack(grid, dim=-1)  # [size, size, size, 3]
    self.kwargs = kwargs
def test_equivariance(Rs_in, Rs_out, n_source, n_target, n_edge):
    """Check rotation equivariance of Convolution, with and without groups.

    Three variants are checked: a plain Kernel convolution, the same
    convolution called with ``groups=groups`` on grouped features, and a
    convolution built from a GroupKernel. For each, rotating the inputs
    (features by D_in, edge vectors by R) must match rotating the output by
    D_out.

    The default dtype is set to float64 for the test and restored afterwards
    so it does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        mp = Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))
        groups = 4
        mp_group = Convolution(
            GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))

        features = rs.randn(n_target, Rs_in)
        features2 = rs.randn(n_target, Rs_in * groups)

        r_source = torch.randn(n_source, 3)
        r_target = torch.randn(n_target, 3)

        edge_index = torch.stack([
            torch.randint(n_source, size=(n_edge,)),
            torch.randint(n_target, size=(n_edge,)),
        ])
        size = (n_target, n_source)

        # Per-edge relative vector r_target[j] - r_source[i]. Advanced indexing
        # with an empty edge_index yields shape [0, 3], so n_edge == 0 needs no
        # special case.
        edge_r = r_target[edge_index[1]] - r_source[edge_index[0]]

        out1 = mp(features, edge_index, edge_r, size=size)
        out1_groups = mp(features2, edge_index, edge_r, size=size, groups=groups)
        out1_kernel_groups = mp_group(features2, edge_index, edge_r, size=size, groups=groups)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)
        D_in_groups = rs.rep(Rs_in * groups, *angles)
        D_out_groups = rs.rep(Rs_out * groups, *angles)
        R = o3.rot(*angles)
        rotated_edge_r = edge_r @ R.T

        out2 = mp(features @ D_in.T, edge_index, rotated_edge_r, size=size) @ D_out
        out2_groups = mp(features2 @ D_in_groups.T, edge_index, rotated_edge_r, size=size,
                         groups=groups) @ D_out_groups
        out2_kernel_groups = mp_group(features2 @ D_in_groups.T, edge_index, rotated_edge_r, size=size,
                                      groups=groups) @ D_out_groups

        assert (out1 - out2).abs().max() < 1e-10
        assert (out1_groups - out2_groups).abs().max() < 1e-10
        assert (out1_kernel_groups - out2_kernel_groups).abs().max() < 1e-10
    finally:
        torch.set_default_dtype(previous_dtype)
def test2(self):
    """Check the statistics of a 'norm'-normalized Kernel at large input multiplicity.

    For every l_in, l_out in 0..3, the kernel entries evaluated at a random
    point must have mean close to 0, and variance times ``mul`` close to
    ``1 / (2 * l_out + 1)``.
    """
    with torch_default_dtype(torch.float64):
        mul = 100000
        for l_in in range(4):
            Rs_in = [(mul, l_in)]
            for l_out in range(4):
                Rs_out = [(1, l_out)]
                k = Kernel(Rs_in, Rs_out, ConstantRadialModel, normalization='norm')
                k = k(torch.randn(1, 3))
                # Bound the magnitude of the mean. A one-sided check
                # (mean < 1e-3) would accept any large negative mean.
                self.assertLess(abs(k.mean().item()), 1e-3)
                self.assertAlmostEqual(k.var().item() * mul, 1 / (2 * l_out + 1), places=1)
def __init__(self, num_radial=30, max_radius=2):
    """Build a single Convolution mapping one scalar field to one vector field.

    Args:
        num_radial: number of cosine radial basis functions.
        max_radius: cutoff radius passed to CosineBasisModel.
    """
    super().__init__()
    softplus = rescaled_act.Softplus(beta=5)
    radial_model = partial(
        CosineBasisModel,
        max_radius=max_radius,
        number_of_basis=num_radial,
        h=100,
        L=2,
        act=softplus,
    )
    kernel = Kernel([(1, 0)], [(1, 1)], radial_model)
    self.conv = Convolution(kernel)
def test_equivariance():
    """Check rotation equivariance of DepthwiseConvolution.

    Two variants are checked: one built on plain Kernel convolutions and one
    built on GroupKernel convolutions. For each, rotating the inputs
    (features by D_in, edge vectors by R) must match rotating the output by
    D_out.

    The default dtype is set to float64 for the test and restored afterwards
    so it does not leak into other tests.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        n_edge = 3
        n_source = 4
        n_target = 2

        Rs_in = [(3, 0), (0, 1)]
        Rs_mid1 = [(5, 0), (1, 1)]
        Rs_mid2 = [(5, 0), (1, 1), (1, 2)]
        Rs_out = [(5, 1), (3, 2)]
        groups = 4

        def convolution(Rs_in, Rs_out):
            return Convolution(Kernel(Rs_in, Rs_out, ConstantRadialModel))

        def convolution_groups(Rs_in, Rs_out):
            return Convolution(
                GroupKernel(Rs_in, Rs_out, partial(Kernel, RadialModel=ConstantRadialModel), groups))

        mp = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution)
        mp_groups = DepthwiseConvolution(Rs_in, Rs_out, Rs_mid1, Rs_mid2, groups, convolution_groups)

        features = rs.randn(n_target, Rs_in)

        r_source = torch.randn(n_source, 3)
        r_target = torch.randn(n_target, 3)

        edge_index = torch.stack([
            torch.randint(n_source, size=(n_edge,)),
            torch.randint(n_target, size=(n_edge,)),
        ])
        size = (n_target, n_source)

        # Per-edge relative vector r_target[j] - r_source[i]. An empty
        # edge_index yields shape [0, 3], so no special case is needed.
        edge_r = r_target[edge_index[1]] - r_source[edge_index[0]]

        out1 = mp(features, edge_index, edge_r, size=size)
        out1_groups = mp_groups(features, edge_index, edge_r, size=size)

        angles = o3.rand_angles()
        D_in = rs.rep(Rs_in, *angles)
        D_out = rs.rep(Rs_out, *angles)
        R = o3.rot(*angles)
        rotated_edge_r = edge_r @ R.T

        out2 = mp(features @ D_in.T, edge_index, rotated_edge_r, size=size) @ D_out
        out2_groups = mp_groups(features @ D_in.T, edge_index, rotated_edge_r, size=size) @ D_out

        assert (out1 - out2).abs().max() < 1e-10
        assert (out1_groups - out2_groups).abs().max() < 1e-10
    finally:
        torch.set_default_dtype(previous_dtype)
def get_kernel_conv_kernelconv(self, seed, normalization):
    """Return ``(Convolution(Kernel), KernelConv)`` built from the same seed.

    The RNG is reseeded with ``seed`` before each construction, so both
    modules draw their random initialization from the same RNG state.
    """
    torch.manual_seed(seed)
    kernel = Kernel(self.Rs_in, self.Rs_out, ConstantRadialModel, normalization=normalization)
    conv = Convolution(kernel)

    torch.manual_seed(seed)
    kernel_conv = KernelConv(self.Rs_in, self.Rs_out,
                             RadialModel=ConstantRadialModel,
                             normalization=normalization)
    return conv, kernel_conv
def layer(Rs1, Rs2):
    """Build one Convolution layer mapping representation Rs1 to Rs2.

    The radial hyper-parameters (max_radius, radial_basis, radial_neurons,
    radial_layers, radial_act, min_radius) and lmax are read from the
    surrounding scope, not passed in.
    """
    radial = partial(
        GaussianRadialModel,
        max_radius=max_radius,
        number_of_basis=radial_basis,
        h=radial_neurons,
        L=radial_layers,
        act=radial_act,
        min_radius=min_radius,
    )
    selection_rule = partial(o3.selection_rule_in_out_sh, lmax=lmax)
    kernel = Kernel(Rs1, Rs2, radial, selection_rule, allow_unused_inputs=True)
    return Convolution(kernel)
def test2(self):
    """Test rotation equivariance on Kernel: D_out K(r) == K(R r) D_in."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0), (0, 1), (2, 2)]
        Rs_out = [(2, 0), (2, 1), (2, 2)]
        k = Kernel(Rs_in, Rs_out, ConstantRadialModel)

        r = torch.randn(3)
        abc = torch.randn(3)

        def rotation_rep(Rs):
            # Block-diagonal matrix with one irr_repr block per copy of each order l.
            blocks = [o3.irr_repr(l, *abc) for mul, l in Rs for _ in range(mul)]
            return o3.direct_sum(*blocks)

        D_in = rotation_rep(Rs_in)
        D_out = rotation_rep(Rs_out)

        lhs = D_out @ k(r)  # [i, j]
        rhs = k(o3.rot(*abc) @ r) @ D_in  # [i, j]
        self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())
def test4(self):
    """Test parity equivariance on Kernel: D_out K(r) == K(-r) D_in."""
    with torch_default_dtype(torch.float64):
        Rs_in = [(2, 0, 1), (2, 1, 1), (2, 2, -1)]
        Rs_out = [(2, 0, -1), (2, 1, 1), (2, 2, 1)]
        k = Kernel(Rs_in, Rs_out, ConstantRadialModel)

        r = torch.randn(3)

        def parity_rep(Rs):
            # Under inversion, each (l, p) copy contributes a block p * I of size 2l+1.
            blocks = [p * torch.eye(2 * l + 1) for mul, l, p in Rs for _ in range(mul)]
            return o3.direct_sum(*blocks)

        D_in = parity_rep(Rs_in)
        D_out = parity_rep(Rs_out)

        lhs = D_out @ k(r)  # [i, j]
        rhs = k(-r) @ D_in  # [i, j]
        self.assertLess((lhs - rhs).norm(), 10e-5 * lhs.norm())