def test_sh_parity():
    """
    Check the parity relation (-1)^l Y_l(x) == Y_l(-x) for every degree l in 0..7.
    """
    with o3.torch_default_dtype(torch.float64):
        for degree in range(8):
            point = torch.randn(3)
            sign = (-1) ** degree
            lhs = sign * rsh.spherical_harmonics_xyz([degree], point)
            rhs = rsh.spherical_harmonics_xyz([degree], -point)
            # relative tolerance: scale by the magnitude of the harmonics themselves
            assert (lhs - rhs).abs().max() < 1e-10 * lhs.abs().max()
def test_sh_cuda_single():
    """
    Compare CPU and CUDA spherical harmonics one degree at a time (l = 0..10).

    Prints a message and does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("Cuda is not available! test_sh_cuda_single skipped!")
        return
    with o3.torch_default_dtype(torch.float64):
        for degree in range(11):
            points = torch.randn(10, 3)
            points_gpu = points.cuda()
            on_cpu = rsh.spherical_harmonics_xyz([degree], points)
            on_gpu = rsh.spherical_harmonics_xyz([degree], points_gpu).cpu()
            assert (on_cpu - on_gpu).abs().max() < 1e-7
def test_sh_cuda_ordered_partial():
    """
    Compare CPU and CUDA spherical harmonics for the degree subset [0, 2, 5, 7, 10].

    Prints a message and does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("Cuda is not available! test_sh_cuda_ordered_partial skipped!")
        return
    with o3.torch_default_dtype(torch.float64):
        degrees = [0, 2, 5, 7, 10]
        points = torch.randn(10, 3)
        points_gpu = points.cuda()
        on_cpu = rsh.spherical_harmonics_xyz(degrees, points)
        on_gpu = rsh.spherical_harmonics_xyz(degrees, points_gpu).cpu()
        assert (on_cpu - on_gpu).abs().max() < 1e-7
def test_sh_cuda_ordered_full(self):
    """
    Compare CPU and CUDA spherical harmonics for all degrees 0..10 in a single call.

    Prints a message and does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("Cuda is not available! test_sh_cuda_ordered_full skipped!")
        return
    with o3.torch_default_dtype(torch.float64):
        degrees = list(range(11))
        points = torch.randn(10, 3)
        points_gpu = points.cuda()
        on_cpu = rsh.spherical_harmonics_xyz(degrees, points)
        on_gpu = rsh.spherical_harmonics_xyz(degrees, points_gpu).cpu()
        self.assertLess((on_cpu - on_gpu).abs().max(), 1e-7)
def test_sh_norm():
    """
    Check the spherical-harmonic normalization conventions.

    * default: for each l in 0..14 the mean over the last axis of Y^2 is 1 / (4 pi)
    * 'norm': the l = 3 harmonics have unit norm
    * 'component': the l = 3 harmonics have norm sqrt(7) = sqrt(2 * 3 + 1)
    """
    with o3.torch_default_dtype(torch.float64):
        mean_squares = torch.stack([
            rsh.spherical_harmonics_xyz([l], torch.randn(10, 3)).pow(2).mean(-1)
            for l in range(15)
        ])
        rms_error = (mean_squares - 1 / (4 * math.pi)).pow(2).mean().sqrt()
        assert rms_error < 1e-10

        norm_normalized = rsh.spherical_harmonics_xyz(3, torch.randn(3), 'norm').norm()
        assert abs(norm_normalized - 1) < 1e-10

        component_normalized = rsh.spherical_harmonics_xyz(3, torch.randn(3), 'component').norm()
        assert abs(component_normalized - 7**0.5) < 1e-10
def test_rsh_backwardable():
    """
    Check gradients of the spherical harmonics (l = 0..10) at the pole (0, 0, 1).

    With ``eps=0`` every gradient component is expected to be NaN; with
    ``eps=1e-10`` none of them should be.
    """
    lmax = 10
    Rs = [(1, l) for l in range(lmax + 1)]

    xyz = torch.tensor([0., 0., 1.], requires_grad=True)
    sph = rsh.spherical_harmonics_xyz(Rs, xyz, eps=0)
    sph.norm(2, -1).mean().backward()
    # Inspect the NaN mask directly: comparing ``nonzero()`` index tensors with
    # ``allclose`` silently broadcasts mismatched (empty) shapes.
    assert torch.isnan(xyz.grad).all()

    xyz = torch.tensor([0., 0., 1.], requires_grad=True)
    sph = rsh.spherical_harmonics_xyz(Rs, xyz, eps=1e-10)
    sph.norm(2, -1).mean().backward()
    assert not torch.isnan(xyz.grad).any()
def spherical_harmonics_dirac(vectors, lmax):
    """
    Truncated spherical-harmonic approximation of a signal that is zero
    everywhere except in the direction(s) given by ``vectors``.

    The higher ``lmax``, the better the approximation.

    :param vectors: direction(s), passed to ``rsh.spherical_harmonics_xyz``
    :param lmax: highest degree l included (degrees 0..lmax are used)
    :return: ``4 pi / (lmax + 1)^2`` times the spherical harmonics of degrees 0..lmax
    """
    degrees = list(range(lmax + 1))
    # (lmax + 1)^2 is the total number of harmonic components for degrees 0..lmax
    scale = 4 * math.pi / (lmax + 1)**2
    return scale * rsh.spherical_harmonics_xyz(degrees, vectors)
def forward(self, r, r_eps=0, custom_backward=False):
    """
    Evaluate the kernel at the relative positions ``r``.

    :param r: tensor [..., 3]
    :param r_eps: positions whose norm is <= r_eps use the r-independent ``self.linear()`` term
    :param custom_backward: call KernelFn rather than using automatic differentiation
    :return: tensor [..., l_out * mul_out * m_out, l_in * mul_in * m_in]
    """
    *size, xyz = r.size()
    assert xyz == 3
    r = r.reshape(-1, 3)
    radii = r.norm(2, dim=1)  # [batch]
    far = radii > r_eps
    near = radii <= r_eps  # kept explicit (not ~far) so NaN radii fall in neither set

    # (1) Case r > r_eps: angular part from the spherical harmonics, radial part from self.R
    Y = rsh.spherical_harmonics_xyz(self.set_of_l_filters, r[far])  # [batch, l_filter * m_filter]
    # note: for the normalization we assume that the variance of R[i] is one
    R = self.R(radii[far])  # [batch, l_out * l_in * mul_out * mul_in * l_filter]
    kernel_fn = KernelFn.apply if custom_backward else kernel_fn_forward
    kernel1 = kernel_fn(Y, R, self.norm_coef, self.Rs_in, self.Rs_out, self.selection_rule, self.set_of_l_filters)

    # (2) Case r <= r_eps
    kernel2 = self.linear()
    kernel = r.new_zeros(len(r), *kernel2.shape)
    kernel[far] = kernel1
    kernel[near] = kernel2
    return kernel.reshape(*size, *kernel2.shape)
def __init__(self, Rs_in, Rs_out, RadialModel, r, r_eps=0, selection_rule=o3.selection_rule_in_out_sh, normalization='component'):
    """
    Precompute the angular part of a kernel on a fixed set of positions ``r``.

    :param Rs_in: list of triplet (multiplicity, representation order, parity)
    :param Rs_out: list of triplet (multiplicity, representation order, parity)
    :param RadialModel: Class(d), trainable model: R -> R^d
    :param tensor r: [..., 3]
    :param float r_eps: distance considered as zero
    :param selection_rule: function of signature (l_in, p_in, l_out, p_out) -> [l_filter]
    :param normalization: either 'norm' or 'component'

    representation order = nonnegative integer
    parity = 0 (no parity), 1 (even), -1 (odd)
    """
    super().__init__()
    self.Rs_in = rs.convention(Rs_in)
    self.Rs_out = rs.convention(Rs_out)
    self.check_input_output(selection_rule)

    *self.size, xyz = r.size()
    assert xyz == 3
    r = r.reshape(-1, 3)  # [batch, space]
    self.register_buffer('radii', r.norm(2, dim=1))  # [batch]
    self.r_eps = r_eps

    self.tp = rs.TensorProduct(self.Rs_in, selection_rule, self.Rs_out, normalization, sorted=True)
    self.Rs_f = self.tp.Rs_in2

    # spherical harmonics of the filter orders, only for positions with |r| > r_eps
    Y = rsh.spherical_harmonics_xyz(
        [(1, l, p) for _, l, p in self.Rs_f], r[self.radii > self.r_eps])  # [batch, l_filter * m_filter]

    # Normalize the spherical harmonics
    if normalization == 'component':
        Y.mul_(math.sqrt(4 * math.pi))
    elif normalization == 'norm':
        # build the per-component scale on Y's device so a CUDA ``r`` works
        diag = math.sqrt(4 * math.pi) * torch.cat([
            torch.ones(2 * l + 1, device=Y.device) / math.sqrt(2 * l + 1) for _, l, _ in self.Rs_f
        ])
        Y.mul_(diag)
    self.register_buffer('Y', Y)

    self.R = RadialModel(rs.mul_dim(self.Rs_f))

    # the r-independent linear term is only needed when some position lies within r_eps of the origin
    if (self.radii <= self.r_eps).any():
        self.linear = KernelLinear(self.Rs_in, self.Rs_out)
    else:
        self.linear = None
def forward(self, features, edge_index, edge_r, sh=None, size=None, n_norm=1):
    """
    :param features: Tensor of shape [n_target, dim(Rs_in)]
    :param edge_index: LongTensor of shape [2, num_messages]
                       edge_index[0] = sources (convolution centers)
                       edge_index[1] = targets (neighbors)
    :param edge_r: Tensor of shape [num_messages, 3]
                   edge_r = position_target - position_source
    :param sh: Tensor of shape [num_messages, dim(Rs_sh)]; computed from ``edge_r`` when None
    :param size: (n_target, n_source) or None
    :param n_norm: typical number of targets per source; sh is divided by sqrt(n_norm)
    :return: Tensor of shape [n_source, dim(Rs_out)]
    """
    if sh is None:
        # [num_messages, dim(Rs_sh)]
        sh = rsh.spherical_harmonics_xyz(self.Rs_sh, edge_r, self.normalization)
    scaled_sh = sh / n_norm**0.5
    weights = self.rm(edge_r.norm(dim=1))  # [num_messages, nweight]
    return self.propagate(edge_index, size=size, x=features, sh=scaled_sh, w=weights)
def test1(self):
    """Check the gradients of KernelFn with torch.autograd.gradcheck."""
    # gradcheck needs float64; scope the dtype change so later tests keep the original default
    with o3.torch_default_dtype(torch.float64):
        Rs_in = [(1, 0), (1, 1), (1, 0), (1, 2)]
        Rs_out = [(1, 0), (1, 1), (1, 2), (1, 0)]
        kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel, partial(o3.selection_rule_in_out_sh, lmax=1))

        # number of radial degrees of freedom = number of (out, in, filter) paths
        n_path = 0
        for mul_out, l_out, p_out in kernel.Rs_out:
            for mul_in, l_in, p_in in kernel.Rs_in:
                l_filters = kernel.selection_rule(l_in, p_in, l_out, p_out)
                n_path += mul_out * mul_in * len(l_filters)

        r = torch.randn(2, 3)
        Y = rsh.spherical_harmonics_xyz(kernel.set_of_l_filters, r)  # [batch, l_filter * m_filter]
        Y = Y.clone().detach().requires_grad_(True)
        R = torch.randn(2, n_path, requires_grad=True)  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

        inputs = (Y, R, kernel.norm_coef, kernel.Rs_in, kernel.Rs_out, kernel.selection_rule, kernel.set_of_l_filters)
        self.assertTrue(torch.autograd.gradcheck(KernelFn.apply, inputs))
def signal_xyz(self, r):
    """
    Evaluate the signal at the points ``r`` on the sphere.

    :param r: tensor [..., 3]
    :return: tensor of shape (*self.signal.shape[:-1], *r.shape[:-1])
    """
    degrees = list(range(self.lmax + 1))
    n_coeff = (self.lmax + 1)**2  # number of harmonic components for degrees 0..lmax
    harmonics = rsh.spherical_harmonics_xyz(degrees, r).reshape(-1, n_coeff)  # [points, n_coeff]
    coefficients = self.signal.reshape(-1, n_coeff)  # [signals, n_coeff]
    values = torch.einsum('pk,sk->sp', harmonics, coefficients)  # [signals, points]
    return values.reshape((*self.signal.shape[:-1], *r.shape[:-1]))
def test_wigner_3j_sh_norm():
    """
    sqrt(4 pi) * Q @ Y has unit norm for every admissible (l_out, l_in, l_f)
    with l_out <= 3 and l_out <= l_in <= 4.
    """
    triples = (
        (l_out, l_in, l_f)
        for l_out in range(3 + 1)
        for l_in in range(l_out, 4 + 1)
        for l_f in range(abs(l_out - l_in), l_out + l_in + 1)
    )
    with o3.torch_default_dtype(torch.float64):
        for l_out, l_in, l_f in triples:
            coupling = o3.wigner_3j(l_out, l_in, l_f)
            direction = torch.randn(3)
            harmonics = rsh.spherical_harmonics_xyz([l_f], direction)
            projected = math.sqrt(4 * math.pi) * coupling @ harmonics
            assert abs(projected.norm() - 1) < 1e-10
def test_sh_norm(self):
    """For each l in 0..14, the mean over the last axis of Y^2 should be 1 / (4 pi)."""
    with o3.torch_default_dtype(torch.float64):
        per_degree = []
        for degree in range(15):
            harmonics = rsh.spherical_harmonics_xyz([degree], torch.randn(10, 3))
            per_degree.append(harmonics.pow(2).mean(-1))
        deviation = torch.stack(per_degree) - 1 / (4 * math.pi)
        self.assertLess(deviation.pow(2).mean().sqrt(), 1e-10)
def forward(f, shapes, labels, lmax, device):
    """
    Run ``f`` on a batch built from ``shapes`` and return tanh of the per-graph sums.

    Spherical harmonics of degrees 0..lmax ('component' normalization) are computed
    once from the edge vectors and passed to ``f`` via ``sh=``.
    """
    r_max = 1.1
    node_features = torch.ones(4, 1)
    graphs = [
        DataNeighbors(node_features, shape, r_max, y=label, self_interaction=False)
        for shape, label in zip(shapes, labels)
    ]
    batch = Batch.from_data_list(graphs).to(device)

    sh = rsh.spherical_harmonics_xyz(list(range(lmax + 1)), batch.edge_attr, 'component')
    per_node = f(batch.x, batch.edge_index, batch.edge_attr, sh=sh, n_norm=3)
    per_graph = scatter_add(per_node, batch.batch, dim=0)
    return torch.tanh(per_graph)
def __init__(self, Rs_in, Rs_out, size, steps=(1, 1, 1), lmax=None, fuzzy_pixels=False, allow_unused_inputs=False, allow_zero_outputs=False, **kwargs):
    """
    Build a FrozenKernel sampled on a voxel grid.

    :param Rs_in: input representation, forwarded to FrozenKernel
    :param Rs_out: output representation, forwarded to FrozenKernel
    :param size: number of samples of ``torch.linspace(-1, 1, size)`` per axis
    :param steps: per-axis spacing; axis i is stretched by ``steps[i] / min(steps)``
                  and samples outside [-1, 1] are dropped
    :param lmax: maximum filter degree passed to the selection rule
    :param fuzzy_pixels: if True, replace the kernel's spherical harmonics by their
                         average over 20**3 random jitters inside each voxel
    :param allow_unused_inputs: forwarded to FrozenKernel
    :param allow_zero_outputs: forwarded to FrozenKernel
    :param kwargs: stored in ``self.kwargs``
    """
    super().__init__()
    line = torch.linspace(-1, 1, size)
    axes = []
    for step in (steps[0], steps[1], steps[2]):
        coords = line * step / min(steps)
        axes.append(coords[coords.abs() <= 1])
    grid = torch.stack(torch.meshgrid(*axes), dim=-1)  # [x, y, z, R^3]

    radial = partial(CosineBasisModel, max_radius=1.0, number_of_basis=(size + 1) // 2, h=50, L=3, act=swish)
    self.kernel = FrozenKernel(
        Rs_in,
        Rs_out,
        radial,
        grid,
        selection_rule=partial(o3.selection_rule_in_out_sh, lmax=lmax),
        normalization='component',
        allow_unused_inputs=allow_unused_inputs,
        allow_zero_outputs=allow_zero_outputs,
    )
    self.kwargs = kwargs

    if not fuzzy_pixels:
        return

    # re-evaluate spherical harmonics by adding randomness within each voxel
    points = grid.reshape(-1, 3)
    points = points[self.kernel.radii > 0]
    jitter = torch.rand(20**3, *points.shape).mul(2).sub(1)  # uniform in [-1, 1]
    jitter.mul_(1 / (size - 1))
    for axis in range(3):
        jitter[:, :, axis].mul_(steps[axis] / min(steps))
    samples = jitter + points.unsqueeze(0)  # [rand, batch, R^3]

    Y = rsh.spherical_harmonics_xyz([(1, l, p) for _, l, p in self.kernel.Rs_f], samples)  # [rand, batch, l_filter * m_filter]
    Y.mul_(math.sqrt(4 * math.pi))  # normalization='component'
    self.kernel.Y.copy_(Y.mean(0))
def forward(self, features, difference_geometry, mask, y=None, radii=None, custom_backward=True):
    """
    :param features: tensor [batch, b, l_in * mul_in * m_in]
    :param difference_geometry: tensor [batch, a, b, xyz]
    :param mask: tensor [batch, a] (In order to zero contributions from padded atoms.)
    :param y: Optional precomputed spherical harmonics [batch, a, b, l_filter * m_filter].
              Not modified by this call.
    :param radii: Optional precomputed normed geometry.
    :param custom_backward: call KernelConvFn rather than using automatic differentiation, (default True)
    :return: tensor [batch, a, l_out * mul_out * m_out]
    """
    _batch, _a, _b, xyz = difference_geometry.size()
    assert xyz == 3

    if radii is None:
        radii = difference_geometry.norm(2, dim=-1)  # [batch, a, b]
    at_origin = (radii == 0).unsqueeze(-1)  # [batch, a, b, 1]

    # precompute all needed spherical harmonics
    if y is None:
        y = rsh.spherical_harmonics_xyz(self.set_of_l_filters, difference_geometry)  # [batch, a, b, l_filter * m_filter]
    # out-of-place so a caller-supplied ``y`` is left untouched (and leaf tensors stay valid)
    y = y.masked_fill(at_origin, 0)

    # use the radial model to fix all the degrees of freedom
    # note: for the normalization we assume that the variance of R[i] is one
    r = self.R(radii.flatten()).reshape(*radii.shape, -1)  # [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]
    r = r.masked_fill(at_origin, 0)

    if custom_backward:
        output = KernelConvFn.apply(features, y, r, self.norm_coef, self.Rs_in, self.Rs_out, self.selection_rule, self.set_of_l_filters)
    else:
        output = kernel_conv_fn_forward(features, y, r, self.norm_coef, self.Rs_in, self.Rs_out, self.selection_rule, self.set_of_l_filters)

    # Case r = 0: r-independent self-interaction, only when the a and b axes have equal length
    if radii.shape[1] == radii.shape[2]:
        output += torch.einsum('ij,zaj->zai', self.linear(), features)

    return output * mask.unsqueeze(-1)
def forward(f, shapes, Rs_sh, device):
    """
    Run ``f`` on a batch built from ``shapes`` and return tanh of the per-graph sums.

    The spherical harmonics for ``Rs_sh`` ('component' normalization) are computed
    once from the edge vectors and passed to ``f`` via ``sh=``.
    """
    r_max = 1.1
    node_features = torch.ones(4, 1)
    graphs = [DataNeighbors(node_features, shape, r_max, self_interaction=False) for shape in shapes]
    batch = Batch.from_data_list(graphs).to(device)

    # Pre-compute the spherical harmonics and re-use them in each convolution
    sh = rsh.spherical_harmonics_xyz(Rs_sh, batch.edge_attr, 'component')
    per_node = f(batch.x, batch.edge_index, batch.edge_attr, sh=sh, n_norm=3)
    per_graph = scatter_add(per_node, batch.batch, dim=0)
    return torch.tanh(per_graph)
def test_irr_repr_wigner_3j(self):
    """Test irr_repr and wigner_3j equivariance."""
    with torch_default_dtype(torch.float64):
        l_in, l_out = 3, 2
        for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
            points = torch.randn(100, 3)
            Q = o3.wigner_3j(l_out, l_in, l_f)
            angles = torch.randn(3)
            D_in = o3.irr_repr(l_in, *angles)
            D_out = o3.irr_repr(l_out, *angles)

            # left side: harmonics of the rotated points, then D_in on the right
            Y_rotated = rsh.spherical_harmonics_xyz([l_f], points @ o3.rot(*angles).t())
            W_rotated = torch.einsum("ijk,zk->zij", (Q, Y_rotated))
            W1 = torch.einsum("zij,jk->zik", (W_rotated, D_in))

            # right side: harmonics of the original points, then D_out on the left
            Y = rsh.spherical_harmonics_xyz([l_f], points)
            W = torch.einsum("ijk,zk->zij", (Q, Y))
            W2 = torch.einsum("ij,zjk->zik", (D_out, W))

            self.assertLess((W1 - W2).norm(), 1e-5 * W.norm(), l_f)
def adjusted_projection(vectors, lmax):
    """
    Combine per-vector projections into one signal, weighted by least squares.

    Each non-zero vector b gets a coefficient row ``projection(vectors, lmax)[b]``.
    The rows are scaled by weights s_b that solve, in the least-squares sense,
    ``sum_b (Y(v_a) . coeff_b) s_b = |v_a|`` for every vector a. The scaled rows are then summed.

    :param vectors: tensor of shape [..., xyz]
    :param lmax: maximum spherical-harmonic degree
    :return: tensor of shape [l * m]
    """
    vectors = vectors.reshape(-1, 3)
    radii = vectors.norm(2, -1)  # [batch]
    nonzero = radii > 0
    vectors = vectors[nonzero]  # [batch', 3]
    # keep radii aligned with the filtered vectors; otherwise the solve below gets a
    # right-hand side of the wrong length whenever a zero vector is present
    radii = radii[nonzero]  # [batch']

    coeff = projection(vectors, lmax)  # [batch', l * m]
    A = torch.einsum(
        "ai,bi->ab",
        rsh.spherical_harmonics_xyz(list(range(lmax + 1)), vectors),
        coeff)  # [batch', batch']
    coeff *= torch.lstsq(radii, A).solution.reshape(-1).unsqueeze(-1)
    return coeff.sum(0)
def forward(self, r, r_eps=0, **_kwargs):
    """
    Evaluate the kernel at the relative positions ``r``.

    :param r: tensor [..., 3]
    :param r_eps: positions whose norm is <= r_eps use the r-independent ``self.linear()`` term
    :return: tensor [..., l_out * mul_out * m_out, l_in * mul_in * m_in]
    """
    *size, xyz = r.size()
    assert xyz == 3
    r = r.reshape(-1, 3)
    radii = r.norm(2, dim=1)  # [batch]

    # (1) Case r > 0
    # precompute all needed spherical harmonics
    Y = rsh.spherical_harmonics_xyz(self.Ls, r[radii > r_eps])  # [batch, l_filter * m_filter]

    # Normalize the spherical harmonics
    if self.normalization == 'component':
        Y.mul_(math.sqrt(4 * math.pi))
    elif self.normalization == 'norm':
        # build the scale on Y's device so CUDA inputs work
        diag = math.sqrt(4 * math.pi) * torch.cat([
            torch.ones(2 * l + 1, device=Y.device) / math.sqrt(2 * l + 1) for _, l, _ in self.Rs_f
        ])
        Y.mul_(diag)

    # use the radial model to fix all the degrees of freedom
    # note: for the normalization we assume that the variance of R[i] is one
    R = self.R(radii[radii > r_eps])  # [batch, l_out * l_in * mul_out * mul_in * l_filter]

    RY = rsh.mul_radial_angular(self.Rs_f, R, Y)
    if Y.shape[0] == 0:
        # no position beyond r_eps: empty block with kernel's device/dtype so the assignment below is valid
        kernel1 = r.new_zeros(0, rs.dim(self.Rs_out), rs.dim(self.Rs_in))
    else:
        kernel1 = self.tp.right(RY)

    # (2) Case r = 0
    kernel2 = self.linear()

    kernel = r.new_zeros(len(r), *kernel2.shape)
    kernel[radii > r_eps] = kernel1
    kernel[radii <= r_eps] = kernel2
    return kernel.reshape(*size, *kernel2.shape)
def forward(self, features, edge_index, edge_r, sh=None, size=None, n_norm=1):
    """
    Convolution plus a self-interaction path.

    :param features: tensor [num_atoms, dim(Rs_in)]
    :param edge_index: forwarded to ``self.propagate``
    :param edge_r: tensor [num_messages, 3]
    :param sh: optional precomputed spherical harmonics [num_messages, dim(Rs_sh)];
               computed from ``edge_r`` with 'component' normalization when None
    :param size: forwarded to ``self.propagate``
    :param n_norm: the spherical harmonics are divided by sqrt(n_norm)
    :return: ``0.5**0.5 * lin1(features) + scale * lin2(propagated)``. ``scale`` is
             0.5**0.5 on output channels whose (l, p) also appears in Rs_in, and 1 elsewhere.
    """
    if sh is None:
        sh = rsh.spherical_harmonics_xyz(self.Rs_sh, edge_r, "component")  # [num_messages, dim(Rs_sh)]
    sh = sh / n_norm**0.5
    w = self.rm(edge_r.norm(dim=1))  # [num_messages, nweight]

    self_interaction = self.lin1(features)
    features = self.propagate(edge_index, size=size, x=features, sh=sh, w=w)
    features = self.lin2(features)

    # 1 on output channels whose (l, p) appears in Rs_in, 0 elsewhere; created on the
    # features' device/dtype so the product below also works on CUDA
    has_self_interaction = torch.cat([
        features.new_full(
            (mul * (2 * l + 1),),
            float(any(l_in == l and p_in == p for _, l_in, p_in in self.Rs_in)),
        )
        for mul, l, p in self.Rs_out
    ])
    return 0.5**0.5 * self_interaction + (1 + (0.5**0.5 - 1) * has_self_interaction) * features
def test_sh_closure():
    """
    Monte Carlo check of orthonormality over the sphere:
    4 pi * mean(Y_lm * Y_jn) ~= delta_lj delta_mn for l, j in 0..3,
    estimated from 300000 random points (integral of 1 over the unit sphere = 4 pi).
    """
    with o3.torch_default_dtype(torch.float64):
        x = torch.randn(300000, 3)
        Ys = [rsh.spherical_harmonics_xyz([l], x) for l in range(0, 3 + 1)]
        for l1, Y1 in enumerate(Ys):
            d1 = 2 * l1 + 1
            for l2, Y2 in enumerate(Ys):
                d2 = 2 * l2 + 1
                gram = (Y1.reshape(-1, d1, 1) * Y2.reshape(-1, 1, d2)).mean(0) * 4 * math.pi
                expected = torch.eye(d1) if l1 == l2 else torch.zeros(d1, d2)
                assert (gram - expected).abs().max() < 0.01
def adjusted_projection(vectors, lmax):
    """
    Signal whose value in each input direction approximates that vector's length.

    The signal is a weighted sum of the spherical harmonics Y(v_b) of the non-zero
    vectors. The weights solve ``sum_b (Y(v_a) . Y(v_b)) w_b = |v_a|`` by least squares.

    :param vectors: tensor of shape [..., xyz]
    :param lmax: maximum spherical-harmonic degree
    :return: tensor of shape [l * m]
    """
    vectors = vectors.reshape(-1, 3)
    radii = vectors.norm(2, -1)  # [batch]
    nonzero = radii > 0
    vectors = vectors[nonzero]  # [batch', 3]
    # keep radii aligned with the filtered vectors (same length as the Gram matrix below)
    radii = radii[nonzero]  # [batch']

    coeff = rsh.spherical_harmonics_xyz(list(range(lmax + 1)), vectors)  # [batch', l * m]
    A = torch.einsum("ai,bi->ab", coeff, coeff)  # [batch', batch']
    # Y(v_a) . Y(v_b) solution_b = radii_a
    solution = torch.lstsq(radii, A).solution.reshape(-1)  # [b]
    assert (radii - A @ solution).abs().max() < 1e-5 * radii.abs().max()

    return solution @ coeff
def from_geometry(cls, vectors, lmax, p=0, adjusted=True):
    """
    Build a spherical signal from a set of vectors.

    :param vectors: tensor of vectors (p=-1) or pseudovectors (p=1) of shape [..., 3=xyz]
    :param lmax: maximum spherical-harmonic degree
    :param p: parity of the argument, forwarded as ``p_arg``
    :param adjusted: if True use ``adjusted_projection``. Otherwise use
                     ``s * sum_a r_a Y(v_a)`` with one scalar s fitted by least squares.
    """
    if not adjusted:
        flat = vectors.reshape(-1, 3)
        lengths = flat.norm(dim=1)
        sh = rsh.spherical_harmonics_xyz(list(range(lmax + 1)), flat)

        # minimize 0.5 * sum_a ( Y(v_a) . sum_b r_b Y(v_b) s - r_a )^2 = 0.5 * sum_a ( A_a s - r_a )^2
        A = torch.einsum('ai,b,bi->a', sh, lengths, sh)
        # stationary point: sum_a A_a^2 s = sum_a A_a r_a
        s = torch.dot(A, lengths) / A.norm().pow(2)
        signal = s * torch.einsum('a,ai->i', lengths, sh)
    else:
        signal = adjusted_projection(vectors, lmax)

    return cls(signal, p_val=1, p_arg=p)
def forward(self, features, edge_index, edge_r, size=None, n_norm=1, custom_backward=False):
    """
    :param features: Tensor of shape [n_target, dim(Rs_in)]
    :param edge_index: LongTensor of shape [2, num_edges] ~ [a, b]
                       edge_index[0] = sources (convolution centers)
                       edge_index[1] = targets (neighbors)
    :param edge_r: Tensor of shape [num_edges, 3]
                   edge_r = position_target - position_source
    :param size: n_points or None (currently unused)
    :param n_norm: typical number of targets per source; the output is divided by sqrt(n_norm)
    :param custom_backward: not implemented for the sparse kernel; must be False
    :return: Tensor of shape [n_points, dim(Rs_out)]
    :raises NotImplementedError: if ``custom_backward`` is True
    """
    # explicit raise: an ``assert False`` vanishes under ``python -O`` and would leave ``output`` unbound
    if custom_backward:
        raise NotImplementedError("Custom backward for sparse kernel: not coded yet!")

    assert edge_r.shape[1] == 3
    radii = edge_r.norm(2, dim=-1)  # [n_edges]

    # precompute all needed spherical harmonics
    y = rsh.spherical_harmonics_xyz(self.set_of_l_filters, edge_r)  # [n_edges, l_filter * m_filter]
    y[radii == 0] = 0

    # use the radial model to fix all the degrees of freedom
    # note: for the normalization we assume that the variance of R[i] is one
    r = self.R(radii.flatten()).reshape(*radii.shape, -1)  # [n_edges, l_out * l_in * mul_out * mul_in * l_filter]
    r = r.clone()
    r[radii == 0] = 0

    output = kernel_conv_fn_forward(
        features, edge_index, y, r, self.norm_coef, self.Rs_in, self.Rs_out, self.selection_rule, self.set_of_l_filters
    )
    output.div_(n_norm ** 0.5)

    # r-independent self-interaction term, added for every point
    output += torch.einsum('ij,aj->ai', self.linear(), features)
    return output
def forward(self, features, edge_index, edge_r, sh=None, size=None, n_norm=1):
    """
    Message-passing convolution followed by ``self.lin``.

    :param features: forwarded to ``self.propagate`` as ``x``
    :param edge_index: forwarded to ``self.propagate``
    :param edge_r: tensor [num_messages, 3]; its norms feed the radial model ``self.rm``
    :param sh: optional precomputed spherical harmonics [num_messages, dim(Rs_sh)];
               computed with 'component' normalization when None
    :param size: forwarded to ``self.propagate``
    :param n_norm: the spherical harmonics are divided by sqrt(n_norm)
    """
    if sh is None:
        sh = rsh.spherical_harmonics_xyz(self.Rs_sh, edge_r, "component")  # [num_messages, dim(Rs_sh)]
    normalized_sh = sh / n_norm**0.5
    radial_weights = self.rm(edge_r.norm(dim=1))  # [num_messages, nweight]
    aggregated = self.propagate(edge_index, size=size, x=features, sh=normalized_sh, w=radial_weights)
    return self.lin(aggregated)
def test1(self):
    """Check the gradients of KernelConvFn with torch.autograd.gradcheck (CUDA if available)."""
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # gradcheck needs float64; scope the dtype change so later tests keep the original default
    with o3.torch_default_dtype(torch.float64):
        Rs_in = [(1, 0), (1, 1), (1, 2), (1, 0)]
        Rs_out = [(1, 0), (1, 1), (2, 0), (1, 2)]
        KC = KernelConv(Rs_in, Rs_out, ConstantRadialModel, partial(o3.selection_rule_in_out_sh, lmax=1)).to(device)

        # number of radial degrees of freedom = number of (out, in, filter) paths
        n_path = 0
        for mul_out, l_out, p_out in KC.Rs_out:
            for mul_in, l_in, p_in in KC.Rs_in:
                l_filters = KC.selection_rule(l_in, p_in, l_out, p_out)
                n_path += mul_out * mul_in * len(l_filters)

        batch = 1
        atoms = 3
        # create inputs directly on ``device`` so they are leaf tensors
        F = torch.randn(batch, atoms, dim(Rs_in), device=device, requires_grad=True)
        geo = torch.randn(batch, atoms, 3)
        r = (geo.unsqueeze(1) - geo.unsqueeze(2)).to(device)

        Y = rsh.spherical_harmonics_xyz(KC.set_of_l_filters, r)  # [batch, a, b, l_filter * m_filter]
        Y[r.norm(2, dim=-1) == 0] = 0
        Y = Y.clone().detach().requires_grad_(True)
        R = torch.randn(batch, atoms, atoms, n_path, device=device, requires_grad=True)  # [batch, a, b, l_out * l_in * mul_out * mul_in * l_filter]

        inputs = (F, Y, R, KC.norm_coef, KC.Rs_in, KC.Rs_out, KC.selection_rule, KC.set_of_l_filters)
        self.assertTrue(torch.autograd.gradcheck(KernelConvFn.apply, inputs))
def main():
    """
    Train an MLNetwork on the tetris dataset for 100 steps.

    Each step reports wall time, loss, accuracy and an equivariance error. The
    equivariance error compares against the network output on a new
    ``get_dataset()`` call (presumably a randomly rotated copy — TODO confirm).
    """
    torch.set_default_dtype(torch.float64)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    x = torch.ones(4, 1)
    Rs_in = [(1, 0, 1)]
    r_max = 1.1

    tetris, labels = get_dataset()
    tetris_dataset = [
        dh.DataNeighbors(x, shape, r_max, y=label)
        for shape, label in zip(tetris, labels)
    ]
    # ``labels`` stays on the CPU for accuracy/printing; the loss needs a copy on ``device``
    labels_on_device = labels.to(device)

    Rs_out = [(1, 0, -1), (6, 0, 1)]
    lmax = 3

    f = MLNetwork(Rs_in, Rs_out, Convolution, partial(make_gated_block, mul=16, lmax=lmax), 2)
    f = f.to(device)

    batch = Batch.from_data_list(tetris_dataset)
    batch = batch.to(device)
    sh = rsh.spherical_harmonics_xyz(list(range(lmax + 1)), batch.edge_attr, 'component')

    optimizer = torch.optim.Adam(f.parameters(), lr=3e-3)

    wall = time.perf_counter()
    for step in range(100):
        out = f(batch.x, batch.edge_index, batch.edge_attr, sh=sh, n_norm=3)
        out = scatter_add(out, batch.batch, dim=0)
        out = torch.tanh(out)

        acc = out.cpu().round().eq(labels).double().mean().item()

        r_tetris_dataset = [
            dh.DataNeighbors(x, shape, r_max, y=label)
            for shape, label in zip(*get_dataset())
        ]
        r_batch = Batch.from_data_list(r_tetris_dataset)
        r_batch = r_batch.to(device)
        r_sh = rsh.spherical_harmonics_xyz(list(range(lmax + 1)), r_batch.edge_attr, 'component')

        with torch.no_grad():
            r_out = f(r_batch.x, r_batch.edge_index, r_batch.edge_attr, sh=r_sh, n_norm=3)
            r_out = scatter_add(r_out, r_batch.batch, dim=0)
            r_out = torch.tanh(r_out)

        loss = (out - labels_on_device).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(
            "wall={:.1f} step={} loss={:.2e} accuracy={:.2f} equivariance error={:.1e}"
            .format(time.perf_counter() - wall, step, loss.item(), acc,
                    (out - r_out).pow(2).mean().sqrt().item()))

    print(labels.numpy().round(1))
    # move to CPU first: ``.numpy()`` is not available on CUDA tensors
    print(out.detach().cpu().numpy().round(1))