def test_weird_irreps():
    """Exercise ``spherical_harmonics`` / ``SphericalHarmonics`` with unusual irreps specifications."""
    # A plain string is accepted as the irreps argument.
    o3.spherical_harmonics("0e + 1o", torch.randn(1, 3), False)

    # Weird multiplicities: the output width must still equal irreps.dim.
    sh_irreps = o3.Irreps("1x0e + 4x1o + 3x2e")
    result = o3.spherical_harmonics(sh_irreps, torch.randn(7, 3), True)
    assert result.shape[-1] == sh_irreps.dim

    # Bad parity: for a vector (1o) input, L = 1 outputs shouldn't be even.
    with pytest.raises(ValueError):
        o3.SphericalHarmonics(
            irreps_out="1x0e + 4x1e + 3x2e",
            normalize=True,
            normalization='integral',
            irreps_in="1o",
        )

    # Good parity, because the input is a pseudovector (1e).
    _ = o3.SphericalHarmonics(irreps_in="1e", irreps_out="1x0e + 4x1e + 3x2e", normalize=True)

    # An irreps_in of "1e + 3o" is invalid and must be rejected.
    with pytest.raises(ValueError):
        _ = o3.SphericalHarmonics(
            irreps_in="1e + 3o",
            irreps_out="1x0e + 4x1e + 3x2e",
            normalize=True,
        )
def test_parity(float_tolerance, l):
    r"""Check the parity relation :math:`(-1)^l Y^l(x) = Y^l(-x)`."""
    vec = torch.randn(3)
    sign = (-1)**l
    expected = sign * o3.spherical_harmonics(l, vec, False)
    flipped = o3.spherical_harmonics(l, -vec, False)
    assert (expected - flipped).abs().max() < float_tolerance
def test_equivariance(float_tolerance):
    """Rotating the inputs must transform the harmonics by the matching Wigner D matrix."""
    irreps = o3.Irreps.spherical_harmonics(5)
    points = torch.randn(2, 3)
    angles = o3.rand_angles()

    rotation = o3.angles_to_matrix(*angles)
    wigner_d = irreps.D_from_angles(*angles)

    # Y(R x) versus D(R) Y(x), both computed on row-vector batches.
    sh_of_rotated = o3.spherical_harmonics(irreps, points @ rotation.T, False)
    rotated_sh = o3.spherical_harmonics(irreps, points, False) @ wigner_d.T
    assert (sh_of_rotated - rotated_sh).abs().max() < 10 * float_tolerance
def forward(self, data) -> torch.Tensor:
    """Run two rounds of tensor-product message passing and sum node features per graph."""
    num_neighbors = 2  # typical number of neighbors
    num_nodes = 4  # typical number of nodes

    # Graph: tensors of indices of source/destination nodes within r = 1.1.
    edge_src, edge_dst = radius_graph(x=data.pos, r=1.1, batch=data.batch)
    edge_vec = data.pos[edge_src] - data.pos[edge_dst]
    # The edge vectors are not normalized, otherwise the harmonics would not be polynomials.
    edge_sh = o3.spherical_harmonics(
        l=self.irreps_sh,
        x=edge_vec,
        normalize=False,
        normalization='component',
    )

    # Initial node features: sum of the neighbors' spherical harmonics.
    node_features = scatter(edge_sh, edge_dst, dim=0).div(num_neighbors**0.5)

    # Each round: tensor product of source-node features with the edge harmonics,
    # then aggregate onto destination nodes.
    for tensor_product in (self.tp1, self.tp2):
        edge_features = tensor_product(node_features[edge_src], edge_sh)
        node_features = scatter(edge_features, edge_dst, dim=0).div(num_neighbors**0.5)

    # Sum all node features of each graph.
    return scatter(node_features, data.batch, dim=0).div(num_nodes**0.5)
def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Embed edges, run the message-passing module and optionally pool per graph."""
    batch, node_inputs, edge_src, edge_dst, edge_vec = self.preprocess(data)
    del data  # the raw input is not used after preprocessing

    edge_attr = o3.spherical_harmonics(range(self.lmax + 1), edge_vec, True, normalization='component')

    # Edge length embedding. The smooth_finite basis with cutoff=True goes to zero
    # at max_radius, so no additional smooth cutoff is needed.
    n_basis = self.number_of_basis
    edge_length_embedding = soft_one_hot_linspace(
        edge_vec.norm(dim=1),
        0.0,
        self.max_radius,
        n_basis,
        basis='smooth_finite',
        cutoff=True,
    ).mul(n_basis**0.5)

    # Node attributes are not used here: a constant 1 per node.
    node_attr = node_inputs.new_ones(node_inputs.shape[0], 1)

    node_outputs = self.mp(node_inputs, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedding)

    if not self.pool_nodes:
        return node_outputs
    num_graphs = int(batch.max()) + 1
    return scatter(node_outputs, batch, num_graphs).div(self.num_nodes**0.5)
def forward(self, data) -> torch.Tensor:
    """Message passing where edge attributes combine atom-type pairs with spherical harmonics."""
    num_neighbors = 3  # typical number of neighbors
    num_nodes = 4  # typical number of nodes
    num_z = self.num_z  # number of atom types

    # Graph.
    edge_src, edge_dst = radius_graph(data.pos, 10.0, data.batch)

    # Spherical harmonics of the (unnormalized) edge vectors.
    edge_vec = data.pos[edge_src] - data.pos[edge_dst]
    edge_sh = o3.spherical_harmonics(self.irreps_sh, edge_vec, normalize=False, normalization='component')

    # Edge types: the ordered pair (z_src, z_dst) encoded as one index in [0, num_z**2).
    pair_index = num_z * data.z[edge_src] + data.z[edge_dst]
    pair_one_hot = torch.nn.functional.one_hot(pair_index, num_z**2).mul(num_z).to(edge_sh.dtype)

    # Edge attributes.
    edge_attr = self.mul(pair_one_hot, edge_sh)

    # Initial node features: sum of the neighbors' spherical harmonics.
    node_features = scatter(edge_sh, edge_dst, dim=0).div(num_neighbors**0.5)

    # Two rounds: tensor product of source-node features with the edge attributes.
    for tensor_product in (self.tp1, self.tp2):
        edge_features = tensor_product(node_features[edge_src], edge_attr)
        node_features = scatter(edge_features, edge_dst, dim=0).div(num_neighbors**0.5)

    # Sum all node features of each graph.
    return scatter(node_features, data.batch, dim=0).div(num_nodes**0.5)
def signal_xyz(self, signal, r):
    r"""Evaluate the signal on given points on the sphere

    .. math::

        f(\vec x / \|\vec x\|)

    Parameters
    ----------
    signal : `torch.Tensor`
        tensor of shape ``(*A, self.dim)``

    r : `torch.Tensor`
        tensor of shape ``(*B, 3)``

    Returns
    -------
    `torch.Tensor`
        tensor of shape ``(*A, *B)``

    Examples
    --------
    >>> s = SphericalTensor(3, 1, -1)
    >>> s.signal_xyz(s.randn(2, 1, 3, -1), torch.randn(2, 4, 3)).shape
    torch.Size([2, 1, 3, 2, 4])
    """
    num_coeffs = (self.lmax + 1)**2

    # Harmonics at the normalized points, flattened to [prod(B), num_coeffs].
    basis = o3.spherical_harmonics(self, r, normalize=True).reshape(-1, num_coeffs)
    # Coefficients flattened to [prod(A), num_coeffs].
    coeffs = signal.reshape(-1, num_coeffs)

    # values[a, b] = sum_i basis[b, i] * coeffs[a, i]
    values = torch.einsum('bi,ai->ab', basis, coeffs)

    out_shape = signal.shape[:-1] + r.shape[:-1]
    return values.reshape(out_shape)
def test_zeros():
    """At the origin only the l = 0 component is non-zero (normalization='norm')."""
    origin = torch.zeros(1, 3)
    values = o3.spherical_harmonics([0, 1], origin, False, normalization='norm')
    expected = torch.tensor([[1, 0, 0, 0.0]])
    assert torch.allclose(values, expected)
def test():
    """Smoke test: a Convolution between two point clouds runs end to end."""
    from torch_cluster import radius
    from e3nn.math import soft_one_hot_linspace

    conv = Convolution(
        irreps_node_input='0e + 1e',
        irreps_node_output='0e + 1e',
        irreps_node_attr_input='2x0e',
        irreps_node_attr_output='3x0e',
        irreps_edge_attr='0e + 1e',
        num_edge_scalar_attr=4,
        radial_layers=1,
        radial_neurons=50,
        num_neighbors=3.0,
    )

    n_in, n_out = 5, 2
    pos_in = torch.randn(n_in, 3)
    pos_out = torch.randn(n_out, 3)
    node_input = torch.randn(n_in, 4)
    node_attr_input = torch.randn(n_in, 2)
    node_attr_output = torch.randn(n_out, 3)

    # Edges from input points to output points within r = 2.0.
    edge_src, edge_dst = radius(pos_out, pos_in, r=2.0)
    edge_vec = pos_in[edge_src] - pos_out[edge_dst]
    edge_attr = o3.spherical_harmonics([0, 1], edge_vec, True)
    edge_length = edge_vec.norm(dim=1)
    edge_scalar_attr = soft_one_hot_linspace(
        x=edge_length,
        start=0.0,
        end=2.0,
        number=4,
        basis='smooth_finite',
        cutoff=True,
    )

    conv(node_input, node_attr_input, node_attr_output, edge_src, edge_dst, edge_attr, edge_scalar_attr)
def test_sh_same(float_tolerance):
    """The xyz-based and the angle-based spherical harmonics must agree for l <= 4."""
    for degree in range(5):
        points = torch.randn(10, 3)
        alpha, beta = o3.xyz_to_angles(points)
        from_xyz = o3.spherical_harmonics(degree, points, True)
        from_angles = o3.spherical_harmonics_alpha_beta(degree, alpha, beta)
        assert (from_xyz - from_angles).abs().max() < float_tolerance
def test_module(normalization, normalize):
    """The SphericalHarmonics module matches the functional API, passes ``assert_auto_jitable`` and is equivariant."""
    irreps = o3.Irreps("0e + 1o + 3o")
    module = o3.SphericalHarmonics(irreps, normalize, normalization)
    scripted = assert_auto_jitable(module)

    xyz = torch.randn(11, 3)
    actual = scripted(xyz)
    expected = o3.spherical_harmonics(irreps, xyz, normalize, normalization)
    assert torch.allclose(actual, expected)

    assert_equivariant(module)
def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
    """evaluate the network

    Parameters
    ----------
    data : `torch_geometric.data.Data` or dict
        data object containing

        - ``pos`` the position of the nodes (atoms)
        - ``x`` the input features of the nodes, optional
        - ``z`` the attributes of the nodes, for instance the atom type, optional
        - ``batch`` the graph to which the node belong, optional

    Returns
    -------
    `torch.Tensor`
        if ``self.reduce_output`` is set, the node outputs summed per graph and divided by
        ``sqrt(self.num_nodes)``; otherwise the per-node outputs.
    """
    # Without an explicit ``batch``, every node is assigned to graph 0.
    if 'batch' in data:
        batch = data['batch']
    else:
        batch = data['pos'].new_zeros(data['pos'].shape[0], dtype=torch.long)

    # Edges between nodes within self.max_radius (as computed by radius_graph, per batch).
    edge_index = radius_graph(data['pos'], self.max_radius, batch)
    edge_src = edge_index[0]
    edge_dst = edge_index[1]
    edge_vec = data['pos'][edge_src] - data['pos'][edge_dst]
    # Spherical harmonics of the normalized edge directions.
    edge_sh = o3.spherical_harmonics(self.irreps_edge_attr, edge_vec, True, normalization='component')
    edge_length = edge_vec.norm(dim=1)
    # Gaussian radial basis without a built-in cutoff, rescaled by sqrt(number_of_basis).
    edge_length_embedded = soft_one_hot_linspace(
        x=edge_length,
        start=0.0,
        end=self.max_radius,
        number=self.number_of_basis,
        basis='gaussian',
        cutoff=False
    ).mul(self.number_of_basis**0.5)
    # The cutoff is applied here instead: edge attributes are damped by smooth_cutoff(r / max_radius).
    edge_attr = smooth_cutoff(edge_length / self.max_radius)[:, None] * edge_sh

    # Node features: data['x'] when self.input_has_node_in is set and 'x' is present,
    # otherwise a constant 1 per node. NOTE(review): if the model expects 'x' but it is missing,
    # the else-branch assert fails because irreps_in is not None.
    if self.input_has_node_in and 'x' in data:
        assert self.irreps_in is not None
        x = data['x']
    else:
        assert self.irreps_in is None
        x = data['pos'].new_ones((data['pos'].shape[0], 1))

    # Node attributes: data['z'] when enabled and present; otherwise a constant 1,
    # which is only valid when the attribute irreps are a single scalar "0e".
    if self.input_has_node_attr and 'z' in data:
        z = data['z']
    else:
        assert self.irreps_node_attr == o3.Irreps("0e")
        z = data['pos'].new_ones((data['pos'].shape[0], 1))

    for lay in self.layers:
        x = lay(x, z, edge_src, edge_dst, edge_attr, edge_length_embedded)

    if self.reduce_output:
        return scatter(x, batch, dim=0).div(self.num_nodes**0.5)
    else:
        return x
def test_normalization(float_tolerance, l):
    """Check the three normalization conventions on a random direction."""
    def sh(normalization):
        return o3.spherical_harmonics(l, torch.randn(3), normalize=True, normalization=normalization)

    # 'integral': the mean of the squared components is 1 / (4 pi).
    mean_sq = sh('integral').pow(2).mean()
    assert abs(mean_sq - 1 / (4 * math.pi)) < float_tolerance

    # 'norm': the vector of components has unit norm.
    norm = sh('norm').norm()
    assert abs(norm - 1) < float_tolerance

    # 'component': the mean of the squared components is 1.
    mean_sq = sh('component').pow(2).mean()
    assert abs(mean_sq - 1) < float_tolerance
def sum_of_diracs(self, positions: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    r"""Sum (almost-) dirac deltas

    .. math::

        f(x) = \sum_i v_i \delta^L(\vec r_i)

    where :math:`\delta^L` is the approximation of a dirac delta.

    Parameters
    ----------
    positions : `torch.Tensor`
        :math:`\vec r_i` tensor of shape ``(..., N, 3)``

    values : `torch.Tensor`
        :math:`v_i` tensor of shape ``(..., N)``

    Returns
    -------
    `torch.Tensor`
        tensor of shape ``(..., self.dim)``. Its dtype is the promotion of the input
        dtypes and its device is that of ``positions``, also when there are no points.

    Examples
    --------
    >>> s = SphericalTensor(7, 1, -1)
    >>> pos = torch.tensor([
    ...     [1.0, 0.0, 0.0],
    ...     [0.0, 1.0, 0.0],
    ... ])
    >>> val = torch.tensor([
    ...     -1.0,
    ...     1.0,
    ... ])
    >>> x = s.sum_of_diracs(pos, val)
    >>> s.signal_xyz(x, torch.eye(3)).mul(10.0).round()
    tensor([-10.,  10.,  -0.])

    >>> s.sum_of_diracs(torch.empty(1, 0, 2, 3), torch.empty(2, 0, 1)).shape
    torch.Size([2, 0, 64])

    >>> s.sum_of_diracs(torch.randn(1, 3, 2, 3), torch.randn(2, 1, 1)).shape
    torch.Size([2, 3, 64])
    """
    # Broadcast values against positions (values gets a trailing axis for the xyz dimension).
    positions, values = torch.broadcast_tensors(positions, values[..., None])
    values = values[..., 0]

    if positions.numel() == 0:
        # No harmonics to evaluate. Allocate the zero result with the dtype and device the
        # non-empty path would produce, instead of the global default dtype on the CPU.
        return torch.zeros(
            values.shape[:-1] + (self.dim,),
            dtype=torch.promote_types(positions.dtype, values.dtype),
            device=positions.device,
        )

    y = o3.spherical_harmonics(self, positions, True)  # [..., N, dim]
    v = values[..., None]

    # Sum over the N points; 4 pi / (lmax + 1)^2 is the scale of the truncated delta.
    return 4 * pi / (self.lmax + 1)**2 * (y * v).sum(-2)
def forward(self, data) -> torch.Tensor:
    """Predict one value per graph from atomic numbers and positions.

    ``data`` is indexed with ``'z'`` (atomic numbers; only 1, 6, 7, 8 and 9 are mapped below),
    ``'pos'`` and ``'batch'``. Returns the per-graph sum of the node outputs.
    """
    node_atom = data['z']
    node_pos = data['pos']
    batch = data['batch']

    # The graph
    edge_src, edge_dst = radius_graph(node_pos, r=self.max_radius, batch=batch, max_num_neighbors=1000)

    # Edge attributes: spherical harmonics of degrees 0..sh_lmax of the normalized edge vectors.
    edge_vec = node_pos[edge_src] - node_pos[edge_dst]
    edge_sh = o3.spherical_harmonics(l=range(self.sh_lmax + 1), x=edge_vec, normalize=True, normalization='component')

    # Edge length embedding (smooth_finite basis with cutoff=True), rescaled by sqrt(num_basis)
    edge_length = edge_vec.norm(dim=1)
    edge_length_embedding = soft_one_hot_linspace(
        edge_length,
        0.0,
        self.max_radius,
        self.num_basis,
        basis='smooth_finite',
        cutoff=True,
    ).mul(self.num_basis**0.5)

    # Constant scalar input feature (1) for every node.
    node_input = node_pos.new_ones(node_pos.shape[0], 1)

    # Map atomic numbers 1, 6, 7, 8, 9 (H, C, N, O, F) to classes 0..4. Other numbers below 10
    # map to -1, which one_hot rejects; numbers >= 10 index past the lookup table.
    node_attr = node_atom.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])[node_atom]
    node_attr = torch.nn.functional.one_hot(node_attr, 5).mul(5**0.5)

    node_outputs = self.mp(node_features=node_input,
                           node_attr=node_attr,
                           edge_src=edge_src,
                           edge_dst=edge_dst,
                           edge_attr=edge_sh,
                           edge_scalars=edge_length_embedding)

    # Combine the two output channels: channel 0 plus half the square of channel 1.
    # NOTE(review): presumably channel 1 is a pseudoscalar whose square is invariant — confirm.
    node_outputs = node_outputs[:, 0] + node_outputs[:, 1].pow(2).mul(0.5)
    node_outputs = node_outputs.view(-1, 1)

    node_outputs = node_outputs.div(self.num_nodes**0.5)

    # Optional per-element offset indexed by atomic number.
    if self.atomref is not None:
        node_outputs = node_outputs + self.atomref[node_atom]
    # for target=7, MAE of 75eV

    # Sum node outputs per graph.
    outputs = scatter(node_outputs, batch, dim=0)

    return outputs
def main():
    """Plot Clebsch-Gordan coefficients contracted with spherical harmonics and save a PNG.

    For every ``l`` with ``|l_out - l_in| <= l <= l_out + l_in``, ``C(l_out, l_in, l)`` is
    contracted with ``Y^l`` over the grid. Each ``(i, j)`` matrix entry is drawn as a sphere
    colored with the "bwr" colormap. The figure is saved as ``kernels{l_in}to{l_out}.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--l_in", type=int, required=True)
    parser.add_argument("--l_out", type=int, required=True)
    parser.add_argument("--n", type=int, default=30, help="size of the SOFT grid")
    parser.add_argument("--dpi", type=float, default=100)
    parser.add_argument("--sep", type=float, default=0.5, help="space between matrices")

    args = parser.parse_args()

    torch.set_default_dtype(torch.float64)

    # Grid on the sphere: x, y, z are used for plotting, alpha and beta to evaluate Y.
    x, y, z, alpha, beta = spherical_surface(args.n)

    out = []
    # Every l allowed by the triangle inequality |l_out - l_in| <= l <= l_out + l_in.
    for l in range(abs(args.l_out - args.l_in), args.l_out + args.l_in + 1):
        C = o3.clebsch_gordan(args.l_out, args.l_in, l)
        # NOTE(review): angle-based call (l, alpha, beta) — presumably a legacy o3 signature; verify.
        Y = o3.spherical_harmonics(l, alpha, beta)
        # Contract the last index of C with the first axis of Y.
        out.append(torch.einsum("ijk,k...->ij...", (C, Y)))
    f = torch.stack(out)

    nf, dim_out, dim_in, *_ = f.size()

    # Map values into [0, 1], centered on 0.5, so the diverging colormap centers on zero.
    f = 0.5 + 0.5 * f / f.abs().max()

    fig = plt.figure(figsize=(nf * dim_in + (nf - 1) * args.sep, dim_out), dpi=args.dpi)

    for index in range(nf):
        for i in range(dim_out):
            for j in range(dim_in):
                # Axes rectangle in figure-fraction units: matrices side by side, separated by
                # args.sep cells, with matrix row 0 at the top.
                width = 1 / (nf * dim_in + (nf - 1) * args.sep)
                height = 1 / dim_out
                rect = [
                    (index * (dim_in + args.sep) + j) * width,
                    (dim_out - i - 1) * height,
                    width,
                    height
                ]
                ax = fig.add_axes(rect, projection='3d')

                fc = plt.get_cmap("bwr")(f[index, i, j].detach().cpu().numpy())

                ax.plot_surface(x.numpy(), y.numpy(), z.numpy(), rstride=1, cstride=1, facecolors=fc)
                ax.set_axis_off()

                a = 0.6
                ax.set_xlim3d(-a, a)
                ax.set_ylim3d(-a, a)
                ax.set_zlim3d(-a, a)

                # View from above.
                ax.view_init(90, 0)

    plt.savefig("kernels{}to{}.png".format(args.l_in, args.l_out), transparent=True)
def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0), **kwargs):
    """Set up a convolution whose kernel is sampled on a regular 3D lattice.

    Parameters
    ----------
    irreps_in, irreps_out : anything accepted by `o3.Irreps`
        irreps of the input / output features.
    irreps_sh : anything accepted by `o3.Irreps`
        irreps of the spherical harmonics evaluated at the lattice points.
    diameter : float
        kernel diameter; along each axis the lattice spans coordinates with ``|c| <= diameter / 2``.
    num_radial_basis : int
        number of radial basis functions on ``[0, diameter / 2]``.
    steps : tuple of 3 floats
        lattice spacing along x, y and z.
    **kwargs
        stored as ``self.kwargs``; ``padding`` defaults to half the lattice size per axis.
        NOTE(review): presumably forwarded to the convolution call in ``forward`` — confirm.
    """
    super().__init__()

    self.irreps_in = o3.Irreps(irreps_in)
    self.irreps_out = o3.Irreps(irreps_out)
    self.irreps_sh = o3.Irreps(irreps_sh)
    self.num_radial_basis = num_radial_basis

    # self-connection
    self.sc = Linear(self.irreps_in, self.irreps_out)

    # connection with neighbors
    r = diameter / 2

    # Per axis: multiples of the step in [-r, r], i.e. 2 * floor(r / step) + 1 points.
    s = math.floor(r / steps[0])
    x = torch.arange(-s, s + 1.0) * steps[0]

    s = math.floor(r / steps[1])
    y = torch.arange(-s, s + 1.0) * steps[1]

    s = math.floor(r / steps[2])
    z = torch.arange(-s, s + 1.0) * steps[2]

    # NOTE(review): torch.meshgrid without `indexing=` uses 'ij' and warns on recent PyTorch.
    lattice = torch.stack(torch.meshgrid(x, y, z), dim=-1)  # [x, y, z, R^3]
    self.register_buffer('lattice', lattice)

    if 'padding' not in kwargs:
        kwargs['padding'] = tuple(s // 2 for s in lattice.shape[:3])
    self.kwargs = kwargs

    # Radial embedding of each lattice point's distance to the center (smooth_finite, cutoff at r).
    emb = soft_one_hot_linspace(
        x=lattice.norm(dim=-1),
        start=0.0,
        end=r,
        number=self.num_radial_basis,
        basis='smooth_finite',
        cutoff=True,
    )
    self.register_buffer('emb', emb)

    sh = o3.spherical_harmonics(
        l=self.irreps_sh,
        x=lattice,
        normalize=True,
        normalization='component'
    )  # [x, y, z, irreps_sh.dim]
    self.register_buffer('sh', sh)

    # shared_weights=False: the tensor-product weights are passed in at call time.
    self.tp = FullyConnectedTensorProduct(self.irreps_in, self.irreps_sh, self.irreps_out, shared_weights=False)

    # One set of tensor-product weights per radial basis function.
    self.weight = torch.nn.Parameter(torch.randn(self.num_radial_basis, self.tp.weight_numel))
def test_spherical_harmonics(self):
    """
    This test tests that
    - irr_repr
    - compose
    - spherical_harmonics
    are compatible

    Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)
    with x = Z(a) Y(b) eta
    """
    for order in range(7):
        with o3.torch_default_dtype(torch.float64):
            a, b = torch.rand(2)
            alpha, beta, gamma = torch.rand(3)

            # Angles of the rotated point.
            rotated_a, rotated_b, _ = o3.compose(alpha, beta, gamma, a, b, 0)

            lhs = o3.spherical_harmonics(order, rotated_a, rotated_b)
            sh = o3.spherical_harmonics(order, a, b)
            rhs = o3.irr_repr(order, alpha, beta, gamma) @ sh

            # Relative tolerance with respect to the magnitude of Y(x).
            self.assertLess((lhs - rhs).abs().max(), 1e-10 * sh.abs().max())
def message(self, x_i, x_j, pos_i, pos_j, cell_offsets):
    """ Message according to eqs 3-4 in the paper """
    displacement = (pos_i - pos_j) + cell_offsets
    # NOTE: this is the *squared* distance.
    sq_dist = displacement.pow(2).sum(-1, keepdims=True)
    sh = spherical_harmonics(self.irreps_rel_pos, displacement, normalize=True, normalization='component')

    hidden = self.message_layer_1(torch.cat((x_i, x_j, sq_dist, sh), dim=-1))
    hidden = self.message_layer_2(torch.cat((hidden, sh), dim=-1))
    # Pass the relative position (its spherical harmonics) along with the message.
    message = torch.cat((hidden, sh), dim=-1)

    if not self.update_pos:
        return message

    # TODO: currently no updated...
    pos_message = (pos_i - pos_j) * self.pos_net(message)
    # torch geometric does not support tuple outputs.
    return torch.cat((pos_message, message), dim=-1)
def test_recurrence_relation(float_tolerance, l):
    """Y^{l+1}(x) is parallel to W(l+1, l, 1) contracted with Y^l(x) and x, and the Jacobians match."""
    if torch.get_default_dtype() != torch.float64 and l > 6:
        pytest.xfail('we expect this to fail for high l and single precision')

    x = torch.randn(3, requires_grad=True)
    w3j = o3.wigner_3j(l + 1, l, 1)

    def f(x):
        return o3.spherical_harmonics(l + 1, x, False)

    y_next = f(x)
    y_prev = o3.spherical_harmonics(l, x, False)
    coupled = torch.einsum('ijk,j,k->i', w3j, y_prev, x)

    # Same direction, with proportionality factor alpha.
    alpha = coupled.norm() / y_next.norm()
    assert (y_next / y_next.norm() - coupled / coupled.norm()).abs().max() < 10 * float_tolerance

    jac = torch.autograd.functional.jacobian(f, x)
    expected_jac = (l + 1) / alpha * torch.einsum('ijk,j->ik', w3j, y_prev)
    assert (jac - expected_jac).abs().max() < 100 * float_tolerance
def test_closure():
    r"""
    integral of Ylm * Yjn = delta_lj delta_mn
    integral of 1 over the unit sphere = 4 pi
    """
    points = torch.randn(1_000_000, 3)
    harmonics = [o3.spherical_harmonics(degree, points, True) for degree in range(4)]

    for l1, y1 in enumerate(harmonics):
        for l2, y2 in enumerate(harmonics):
            # Monte-Carlo estimate of the Gram matrix of the two blocks.
            gram = (y1[:, :, None] * y2[:, None, :]).mean(0) * 4 * math.pi
            if l1 != l2:
                assert gram.abs().max() < 0.01
                continue
            identity = torch.eye(2 * l1 + 1)
            assert (gram - identity).abs().max() < 0.01
def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Message passing on the graph given by ``data['edge_index']``, optionally pooled per graph."""
    pos = data['pos']
    if 'batch' in data:
        batch = data['batch']
    else:
        batch = pos.new_zeros(pos.shape[0], dtype=torch.long)

    # The graph.
    edge_src, edge_dst = data['edge_index'][0], data['edge_index'][1]

    # Edge attributes: provided edge features followed by the spherical harmonics.
    edge_vec = pos[edge_src] - pos[edge_dst]
    edge_sh = o3.spherical_harmonics(range(self.lmax + 1), edge_vec, True, normalization='component')
    edge_attr = torch.cat([data['edge_attr'], edge_sh], dim=1)

    # Edge length embedding. The cosine basis with cutoff=True goes to zero at max_radius,
    # so no additional smooth cutoff is needed.
    n_basis = self.number_of_basis
    edge_length_embedding = soft_one_hot_linspace(
        edge_vec.norm(dim=1),
        0.0,
        self.max_radius,
        n_basis,
        basis='cosine',
        cutoff=True,
    ).mul(n_basis**0.5)

    node_outputs = self.mp(data['node_input'], data['node_attr'], edge_src, edge_dst, edge_attr, edge_length_embedding)

    if not self.pool_nodes:
        return node_outputs
    return scatter(node_outputs, batch, dim=0).div(self.num_nodes**0.5)
def forward(self, data) -> torch.Tensor:
    """One convolution, a gate and a final convolution, summed per graph."""
    num_nodes = 4  # typical number of nodes
    pos, batch = data.pos, data.batch

    edge_src, edge_dst = radius_graph(x=pos, r=2.5, batch=batch)
    edge_vec = pos[edge_src] - pos[edge_dst]
    edge_attr = o3.spherical_harmonics(l=self.irreps_sh, x=edge_vec, normalize=True, normalization='component')

    edge_length = edge_vec.norm(dim=1)
    edge_length_embedded = soft_one_hot_linspace(
        x=edge_length,
        start=0.5,
        end=2.5,
        number=3,
        basis='smooth_finite',
        cutoff=True,
    ) * 3**0.5

    # Initial node features: sum of the incoming edges' spherical harmonics.
    x = scatter(edge_attr, edge_dst, dim=0).div(self.num_neighbors**0.5)

    x = self.gate(self.conv(x, edge_src, edge_dst, edge_attr, edge_length_embedded))
    x = self.final(x, edge_src, edge_dst, edge_attr, edge_length_embedded)

    return scatter(x, batch, dim=0).div(num_nodes**0.5)
def forward(self, node_atom, node_pos, batch) -> torch.Tensor:
    """Predict one value per graph from atomic numbers, positions and graph indices.

    Parameters
    ----------
    node_atom : `torch.Tensor`
        atomic numbers; only 1, 6, 7, 8 and 9 are mapped to node classes below.
    node_pos : `torch.Tensor`
        node positions (differences are fed to ``spherical_harmonics`` as 3D vectors).
    batch : `torch.Tensor`
        graph index of every node.

    Returns
    -------
    `torch.Tensor`
        per-graph sum of the node outputs, multiplied by ``self.scale`` when it is set.
    """
    # The graph
    edge_src, edge_dst = radius_graph(
        node_pos,
        r=self.max_radius,
        batch=batch,
        max_num_neighbors=1000
    )

    # Edge attributes: spherical harmonics of degrees 0..lmax of the normalized edge vectors.
    edge_vec = node_pos[edge_src] - node_pos[edge_dst]
    edge_sh = o3.spherical_harmonics(
        l=range(self.lmax + 1),
        x=edge_vec,
        normalize=True,
        normalization='component'
    )

    # Edge length embedding
    edge_length = edge_vec.norm(dim=1)
    edge_length_embedding = soft_one_hot_linspace(
        edge_length,
        0.0,
        self.max_radius,
        self.number_of_basis,
        basis='cosine',  # the cosine basis with cutoff = True goes to zero at max_radius
        cutoff=True,  # no need for an additional smooth cutoff
    ).mul(self.number_of_basis**0.5)

    # Constant scalar input feature (1) for every node.
    node_input = node_pos.new_ones(node_pos.shape[0], 1)

    # Map atomic numbers 1, 6, 7, 8, 9 (H, C, N, O, F) to classes 0..4. Other numbers below 10
    # map to -1, which one_hot rejects; numbers >= 10 index past the lookup table.
    node_attr = node_atom.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])[node_atom]
    node_attr = torch.nn.functional.one_hot(node_attr, 5).mul(5**0.5)

    node_outputs = self.mp(
        node_features=node_input,
        node_attr=node_attr,
        edge_src=edge_src,
        edge_dst=edge_dst,
        edge_attr=edge_sh,
        edge_scalars=edge_length_embedding
    )

    # Combine the two output channels: channel 0 plus half the square of channel 1.
    # NOTE(review): presumably channel 1 is a pseudoscalar whose square is invariant — confirm.
    node_outputs = node_outputs[:, 0] + node_outputs[:, 1].pow(2).mul(0.5)
    node_outputs = node_outputs.view(-1, 1)

    node_outputs = node_outputs.div(self.num_nodes**0.5)

    # Optional de-standardization with stored mean/std (applied only when both are set).
    if self.mean is not None and self.std is not None:
        node_outputs = node_outputs * self.std + self.mean

    # Optional per-element offset indexed by atomic number.
    if self.atomref is not None:
        node_outputs = node_outputs + self.atomref[node_atom]
    # for target=7, MAE of 75eV

    # Sum node outputs per graph.
    outputs = scatter(node_outputs, batch, dim=0)

    if self.scale is not None:
        outputs = self.scale * outputs

    return outputs
def test_weird_call():
    """An unsorted list of degrees with repeats, plus extra batch dimensions, must be accepted."""
    degrees = [4, 1, 2, 3, 3, 1, 0]
    points = torch.randn(2, 1, 2, 3)
    o3.spherical_harmonics(degrees, points, False)
def func(pos):
    """Spherical harmonics of the degrees ``ls`` (from the enclosing scope) without normalizing ``pos``."""
    degrees = ls
    return o3.spherical_harmonics(degrees, pos, False)
def forward(self, z, pos, batch=None):
    """Predict one value per graph from atomic numbers ``z`` and positions ``pos``.

    Parameters
    ----------
    z : `torch.Tensor`
        1D long tensor of atomic numbers (asserted below); only 1, 6, 7, 8 and 9 are mapped.
    pos : `torch.Tensor`
        node positions of shape ``[num_nodes, 3]`` (asserted below).
    batch : `torch.Tensor`, optional
        graph index of every node; defaults to all zeros, i.e. a single graph.

    Returns
    -------
    `torch.Tensor`
        per-graph sum of the node outputs, multiplied by ``self.scale`` when it is set.
    """
    assert z.dim() == 1 and z.dtype == torch.long
    assert pos.dim() == 2 and pos.shape[1] == 3
    batch = torch.zeros_like(z) if batch is None else batch

    edge_src, edge_dst = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=1000)
    edge_vec = pos[edge_src] - pos[edge_dst]
    edge_sh = o3.spherical_harmonics(self.irreps_sh, edge_vec, True, 'component')

    # Radial features: soft one-hot of the edge length passed through self.radial.
    edge_len = edge_vec.norm(dim=1)
    edge_len_emb = self.radial(
        soft_one_hot_linspace(edge_len, 0.0, self.cutoff, self.rad_gaussians))

    # Cosine envelope (1 + cos(pi r / cutoff)) / 2: equals 1 at r = 0 and 0 at r = cutoff.
    edge_c = (pi * edge_len / self.cutoff).cos().add(1).div(2)
    edge_sh = edge_c[:, None] * edge_sh / self.num_neighbors**0.5

    # z : [1, 6, 7, 8, 9] -> [0, 1, 2, 3, 4]
    node_z = z.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])[z]
    # edge_zz = 5 * node_z[edge_src] + node_z[edge_dst]
    node_z = torch.nn.functional.one_hot(node_z, 5).mul(5**0.5)
    # edge_zz = torch.nn.functional.one_hot(edge_zz, 25).mul(5.0)
    # edge_attr = self.mul(edge_zz, edge_sh)
    edge_attr = edge_sh

    # Initial node features: sum of the edge harmonics, with channel 0 then overwritten by 1.
    # NOTE(review): aggregates over edge_src, whereas the layers take (edge_src, edge_dst) — confirm intended.
    h = scatter(edge_sh, edge_src, dim=0, dim_size=len(pos))
    h[:, 0] = 1
    h = self.mul_node(node_z, h)
    print_std('h', h)  # debugging output

    for conv, act in self.layers[:-1]:
        with torch.autograd.profiler.record_function("Layer"):
            h = conv(h, node_z, edge_src, edge_dst, edge_len_emb, edge_attr)  # convolution
            print_std('post conv', h)
            h = act(h)  # gate non linearity
            print_std('post gate', h)

    with torch.autograd.profiler.record_function("Layer"):
        h = self.layers[-1](h, node_z, edge_src, edge_dst, edge_len_emb, edge_attr)

    print_std('h out', h)

    # Collapse the scalar outputs into one invariant per node: even-parity scalars are added
    # directly, odd-parity scalars are squared (times 0.5) to make them invariant.
    s = 0
    for i, (mul, (l, p)) in enumerate(self.irreps_out):
        assert mul == 1 and l == 0
        if p == 1:
            s += h[:, i]
        if p == -1:
            s += h[:, i].pow(2).mul(0.5)  # odd^2 = even
    h = s.view(-1, 1)

    print_std('h out+', h)

    # for the scatter we normalize
    h = h / self.num_atoms**0.5

    # Optional de-standardization with stored mean/std (applied only when both are set).
    if self.mean is not None and self.std is not None:
        h = h * self.std + self.mean

    # Optional per-element offset indexed by atomic number.
    if self.atomref is not None:
        h = h + self.atomref[z]
    # for target=7, MAE of 75eV

    out = scatter(h, batch, dim=0)

    if self.scale is not None:
        out = self.scale * out

    return out
def _forward(self, data):
    """Run the network on a (periodic) batch ``data`` and return pooled per-graph outputs.

    NOTE(review): only the periodic path (``self.use_pbc``) is implemented; the non-periodic
    branch raises NotImplementedError after building the graph.
    """
    pos = data.pos
    batch = data.batch

    # Optionally build the periodic neighbor graph on the fly and store it on ``data``.
    if self.otf_graph:
        edge_index, cell_offsets, neighbors = radius_graph_pbc(
            data, self.cutoff, 50, data.pos.device
        )
        data.edge_index = edge_index
        data.cell_offsets = cell_offsets
        data.neighbors = neighbors

    if self.use_pbc:
        out = get_pbc_distances(
            pos,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
            return_offsets=True,
        )

        edge_index = out["edge_index"]
        cell_offsets = out["offsets"]
    else:
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        # Unreachable past this point without PBC: cell_offsets would be undefined below.
        raise NotImplementedError

    # construct the node and edge attributes
    rel_pos = (pos[edge_index[0]] - pos[edge_index[1]]) + cell_offsets
    # NOTE: edge_dist is the *squared* distance (sum of squared components).
    edge_dist = rel_pos.pow(2).sum(-1, keepdims=True)
    # Squared distance minus the atomic radius of the source / destination / both atoms.
    # NOTE(review): atom_radii is presumably a length, so subtracting it from a squared
    # distance mixes units — confirm this is intended.
    edge_dist_radii_1 = edge_dist - self.atom_radii[data.atomic_numbers.long()[edge_index[0]]][:, None]
    edge_dist_radii_2 = edge_dist - self.atom_radii[data.atomic_numbers.long()[edge_index[1]]][:, None]
    edge_dist_radii_12 = edge_dist - self.atom_radii[data.atomic_numbers.long()[edge_index[0]]][:, None] - self.atom_radii[
        data.atomic_numbers.long()[edge_index[1]]][:, None]
    edge_attr = spherical_harmonics(self.attr_irreps, rel_pos, normalize=True, normalization='component')
    node_attr = self.node_attribute_net(edge_index, edge_attr)

    # If the highest node index in edge_index is below num_nodes - 1, append one row per missing
    # trailing node, each equal to the first basis vector (1, 0, ..., 0).
    # NOTE(review): presumably node_attribute_net returns edge_index.max() + 1 rows, so only
    # trailing isolated nodes are covered — confirm.
    if (data.contains_isolated_nodes() and edge_index.max().item() + 1 != data.num_nodes):
        nr_add_attr = data.num_nodes - (edge_index.max().item() + 1)
        add_attr = node_attr.new_tensor(np.tile(np.eye(node_attr.shape[-1])[0,:], (nr_add_attr,1)))
        #add_attr = node_attr.new_tensor(np.zeros((nr_add_attr, node_attr.shape[-1])))
        node_attr = torch.cat((node_attr, add_attr), -2)
    # node_attr, edge_attr = self.attribute_net(pos, edge_index)

    # Initial node features looked up from the atomic numbers, then embedded.
    x = self.atom_map[data.atomic_numbers.long()]
    x = self.embedding_layer_1(x, node_attr)
    x = self.embedding_layer_2(x, node_attr)
    x = self.embedding_layer_3(x, node_attr)

    # The main layers
    for layer in self.layers:
        x, pos = layer(x, pos, edge_index, edge_dist, edge_dist_radii_1, edge_dist_radii_2, edge_dist_radii_12, edge_attr,
                       node_attr)

    # Output head: per-node layers, mean pooling over each graph, then per-graph layers.
    x = self.head_pre_pool_layer_1(x, node_attr)
    x = self.head_pre_pool_layer_2(x)
    x = global_mean_pool(x, batch)
    x = self.head_post_pool_layer_1(x)
    x = self.head_post_pool_layer_2(x)

    # Return the result
    return x
def f(x):
    """Spherical harmonics of degree ``l + 1`` (``l`` from the enclosing scope), input not normalized."""
    degree = l + 1
    return o3.spherical_harmonics(degree, x, False)
def __init__(self, alpha, beta, lmax):
    """Precompute the spherical harmonics of degrees 0..lmax at the given angles as buffer ``sh``."""
    super().__init__()
    per_degree = [o3.spherical_harmonics(degree, alpha, beta) for degree in range(lmax + 1)]
    self.register_buffer("sh", torch.cat(per_degree))