def test():
    """Smoke-test ``Convolution`` on a bipartite radius graph.

    Five random input nodes are linked to two random output nodes whenever
    they are within distance 2.0, and the convolution is evaluated once on
    those edges (no assertion is made on the result).
    """
    from torch_cluster import radius
    from e3nn.math import soft_one_hot_linspace

    conv = Convolution(
        irreps_node_input='0e + 1e',
        irreps_node_output='0e + 1e',
        irreps_node_attr_input='2x0e',
        irreps_node_attr_output='3x0e',
        irreps_edge_attr='0e + 1e',
        num_edge_scalar_attr=4,
        radial_layers=1,
        radial_neurons=50,
        num_neighbors=3.0,
    )

    num_in, num_out = 5, 2
    pos_in = torch.randn(num_in, 3)
    pos_out = torch.randn(num_out, 3)
    node_input = torch.randn(num_in, 4)  # '0e + 1e' -> 1 + 3 = 4 components
    node_attr_input = torch.randn(num_in, 2)  # '2x0e'
    node_attr_output = torch.randn(num_out, 3)  # '3x0e'

    # edge_src indexes pos_in and edge_dst indexes pos_out, as used below
    edge_index = radius(pos_out, pos_in, r=2.0)
    edge_src, edge_dst = edge_index[0], edge_index[1]

    edge_vec = pos_in[edge_src] - pos_out[edge_dst]
    edge_attr = o3.spherical_harmonics([0, 1], edge_vec, True)
    edge_scalar_attr = soft_one_hot_linspace(
        x=edge_vec.norm(dim=1),
        start=0.0,
        end=2.0,
        number=4,
        basis='smooth_finite',
        cutoff=True,
    )

    conv(
        node_input,
        node_attr_input,
        node_attr_output,
        edge_src,
        edge_dst,
        edge_attr,
        edge_scalar_attr,
    )
def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Run the message-passing network on a (possibly batched) graph.

    ``self.preprocess`` turns ``data`` into the batch vector, the node inputs
    and the edge list / edge vectors; only those tensors are used afterwards.

    Returns per-graph outputs (summed over nodes and divided by
    ``sqrt(self.num_nodes)``) when ``self.pool_nodes`` is set, otherwise the
    per-node outputs.
    """
    batch, node_inputs, edge_src, edge_dst, edge_vec = self.preprocess(data)
    del data  # nothing else from the input dict is needed

    edge_attr = o3.spherical_harmonics(
        range(self.lmax + 1), edge_vec, True, normalization='component'
    )

    # Radial embedding of edge lengths. The smooth_finite basis with
    # cutoff=True already goes to zero at max_radius, so no additional
    # smooth cutoff is applied.
    n_basis = self.number_of_basis
    edge_length_embedding = soft_one_hot_linspace(
        edge_vec.norm(dim=1),
        0.0,
        self.max_radius,
        n_basis,
        basis='smooth_finite',
        cutoff=True,
    ).mul(n_basis**0.5)

    # This network has no node attributes: feed a constant 1 per node.
    node_attr = node_inputs.new_ones(node_inputs.shape[0], 1)

    node_outputs = self.mp(
        node_inputs, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedding
    )

    if not self.pool_nodes:
        return node_outputs

    num_graphs = int(batch.max()) + 1
    return scatter(node_outputs, batch, num_graphs).div(self.num_nodes**0.5)
def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Evaluate the network.

    Parameters
    ----------
    data : `torch_geometric.data.Data` or dict
        data object containing

        - ``pos`` the position of the nodes (atoms)
        - ``x`` the input features of the nodes, optional
        - ``z`` the attributes of the nodes, for instance the atom type, optional
        - ``batch`` the graph to which each node belongs, optional

    Returns
    -------
    `torch.Tensor`
        per-graph outputs (summed over nodes, divided by ``sqrt(self.num_nodes)``)
        when ``self.reduce_output`` is set, otherwise per-node outputs.
    """
    pos = data['pos']
    num_nodes = pos.shape[0]

    # Without an explicit batch vector every node belongs to graph 0.
    batch = data['batch'] if 'batch' in data else pos.new_zeros(num_nodes, dtype=torch.long)

    edge_index = radius_graph(pos, self.max_radius, batch)
    edge_src, edge_dst = edge_index[0], edge_index[1]
    edge_vec = pos[edge_src] - pos[edge_dst]
    edge_length = edge_vec.norm(dim=1)

    edge_sh = o3.spherical_harmonics(
        self.irreps_edge_attr, edge_vec, True, normalization='component'
    )
    edge_length_embedded = soft_one_hot_linspace(
        x=edge_length,
        start=0.0,
        end=self.max_radius,
        number=self.number_of_basis,
        basis='gaussian',
        cutoff=False,
    ).mul(self.number_of_basis**0.5)

    # The embedding uses cutoff=False, so the decay toward max_radius is
    # applied explicitly to the edge attributes through smooth_cutoff.
    edge_attr = smooth_cutoff(edge_length / self.max_radius)[:, None] * edge_sh

    # Node input features: given ones, or a constant scalar per node.
    if self.input_has_node_in and 'x' in data:
        assert self.irreps_in is not None
        x = data['x']
    else:
        assert self.irreps_in is None
        x = pos.new_ones((num_nodes, 1))

    # Node attributes: given ones, or a constant scalar (requires irreps "0e").
    if self.input_has_node_attr and 'z' in data:
        z = data['z']
    else:
        assert self.irreps_node_attr == o3.Irreps("0e")
        z = pos.new_ones((num_nodes, 1))

    for layer in self.layers:
        x = layer(x, z, edge_src, edge_dst, edge_attr, edge_length_embedded)

    if not self.reduce_output:
        return x
    return scatter(x, batch, dim=0).div(self.num_nodes**0.5)
def test_zero_out(basis):
    """Outside ``[start, end]`` a ``cutoff=True`` basis must (nearly) vanish.

    Every basis except 'gaussian' must be exactly zero there; the gaussian
    basis only decays, so it is allowed a residual below 0.22.
    """
    below = torch.linspace(-2.0, -1.1, 20)
    above = torch.linspace(2.1, 3.0, 20)
    outside = torch.cat([below, above])

    values = soft_one_hot_linspace(outside, -1.0, 2.0, 5, basis, cutoff=True)
    peak = values.abs().max()

    if basis != 'gaussian':
        assert peak == 0.0
    else:
        assert peak < 0.22
def forward(self, data) -> torch.Tensor:
    """Predict one scalar per graph from atom types and positions.

    ``data`` must provide ``'z'``, ``'pos'`` and ``'batch'``. ``'z'`` is used
    to index a 10-entry lookup table in which only 1, 6, 7, 8 and 9 map to a
    valid (non-negative) one-hot class; presumably these are atomic numbers.
    """
    node_atom = data['z']
    node_pos = data['pos']
    batch = data['batch']

    # Radius graph, allowing up to 1000 neighbours per node
    edge_src, edge_dst = radius_graph(
        node_pos, r=self.max_radius, batch=batch, max_num_neighbors=1000
    )

    # Edge attributes: spherical harmonics of the edge vectors
    edge_vec = node_pos[edge_src] - node_pos[edge_dst]
    edge_sh = o3.spherical_harmonics(
        l=range(self.sh_lmax + 1), x=edge_vec, normalize=True, normalization='component'
    )

    # Edge length embedding
    n_basis = self.num_basis
    edge_length_embedding = soft_one_hot_linspace(
        edge_vec.norm(dim=1),
        0.0,
        self.max_radius,
        n_basis,
        basis='smooth_finite',
        cutoff=True,
    ).mul(n_basis**0.5)

    # Constant scalar input; atom identity enters through the one-hot attributes.
    node_input = node_pos.new_ones(node_pos.shape[0], 1)
    type_index = node_atom.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])[node_atom]
    node_attr = torch.nn.functional.one_hot(type_index, 5).mul(5**0.5)

    node_outputs = self.mp(
        node_features=node_input,
        node_attr=node_attr,
        edge_src=edge_src,
        edge_dst=edge_dst,
        edge_attr=edge_sh,
        edge_scalars=edge_length_embedding,
    )

    # Channel 0 plus half the square of channel 1, as one column per node
    per_node = (node_outputs[:, 0] + node_outputs[:, 1].pow(2).mul(0.5)).view(-1, 1)
    per_node = per_node.div(self.num_nodes**0.5)

    if self.atomref is not None:
        per_node = per_node + self.atomref[node_atom]
    # for target=7, MAE of 75eV

    return scatter(per_node, batch, dim=0)
def __init__(self, irreps_in, irreps_out, irreps_sh, diameter, num_radial_basis, steps=(1.0, 1.0, 1.0), **kwargs):
    """Set up an equivariant convolution over a regular 3-D lattice of offsets.

    Parameters
    ----------
    irreps_in, irreps_out, irreps_sh
        anything accepted by ``o3.Irreps``: input, output and
        spherical-harmonics representations.
    diameter : float
        extent of the kernel; offsets up to ``diameter / 2`` along each axis
        are kept.
    num_radial_basis : int
        number of ``smooth_finite`` radial basis functions.
    steps : sequence of 3 floats
        lattice spacing along each of the three axes.
    **kwargs
        stored in ``self.kwargs``. ``padding`` defaults to half the kernel
        size per axis (presumably forwarded to the underlying 3-D
        convolution; confirm against ``forward``).
    """
    super().__init__()

    self.irreps_in = o3.Irreps(irreps_in)
    self.irreps_out = o3.Irreps(irreps_out)
    self.irreps_sh = o3.Irreps(irreps_sh)
    self.num_radial_basis = num_radial_basis

    # self-connection
    self.sc = Linear(self.irreps_in, self.irreps_out)

    # connection with neighbors: a centred lattice of offsets along each axis
    radius = diameter / 2
    axes = []
    for step in (steps[0], steps[1], steps[2]):
        half = math.floor(radius / step)
        axes.append(torch.arange(-half, half + 1.0) * step)

    lattice = torch.stack(torch.meshgrid(*axes), dim=-1)  # [x, y, z, R^3]
    self.register_buffer('lattice', lattice)

    if 'padding' not in kwargs:
        kwargs['padding'] = tuple(n // 2 for n in lattice.shape[:3])
    self.kwargs = kwargs

    emb = soft_one_hot_linspace(
        x=lattice.norm(dim=-1),
        start=0.0,
        end=radius,
        number=self.num_radial_basis,
        basis='smooth_finite',
        cutoff=True,
    )
    self.register_buffer('emb', emb)

    sh = o3.spherical_harmonics(
        l=self.irreps_sh,
        x=lattice,
        normalize=True,
        normalization='component',
    )  # [x, y, z, irreps_sh.dim]
    self.register_buffer('sh', sh)

    self.tp = FullyConnectedTensorProduct(
        self.irreps_in, self.irreps_sh, self.irreps_out, shared_weights=False
    )

    # One weight vector per radial basis function
    self.weight = torch.nn.Parameter(
        torch.randn(self.num_radial_basis, self.tp.weight_numel)
    )
def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Run message passing on a graph whose edges are supplied in ``data``.

    ``data`` must provide ``'pos'``, ``'edge_index'``, ``'edge_attr'``,
    ``'node_input'`` and ``'node_attr'``; ``'batch'`` is optional (defaults
    to a single graph). Returns per-graph outputs (summed over nodes and
    divided by ``sqrt(self.num_nodes)``) when ``self.pool_nodes`` is set,
    otherwise per-node outputs.
    """
    pos = data['pos']
    if 'batch' in data:
        batch = data['batch']
    else:
        batch = pos.new_zeros(pos.shape[0], dtype=torch.long)

    # The graph is given, not computed
    edge_index = data['edge_index']
    edge_src, edge_dst = edge_index[0], edge_index[1]

    # Edge attributes: the provided ones, followed by spherical harmonics
    edge_vec = pos[edge_src] - pos[edge_dst]
    edge_sh = o3.spherical_harmonics(
        range(self.lmax + 1), edge_vec, True, normalization='component'
    )
    edge_attr = torch.cat([data['edge_attr'], edge_sh], dim=1)

    # Edge length embedding. The cosine basis with cutoff=True goes to zero at
    # max_radius, so no additional smooth cutoff is needed.
    edge_length_embedding = soft_one_hot_linspace(
        edge_vec.norm(dim=1),
        0.0,
        self.max_radius,
        self.number_of_basis,
        basis='cosine',
        cutoff=True,
    ).mul(self.number_of_basis**0.5)

    node_outputs = self.mp(
        data['node_input'], data['node_attr'], edge_src, edge_dst, edge_attr, edge_length_embedding
    )

    if not self.pool_nodes:
        return node_outputs
    return scatter(node_outputs, batch, dim=0).div(self.num_nodes**0.5)
def forward(self, data) -> torch.Tensor:
    """Return one output row per graph in the batch ``data``.

    ``data`` must provide ``pos`` (node positions) and ``batch`` (graph index
    of every node). Nodes closer than 2.5 are connected. The initial node
    features are the summed spherical harmonics of the incoming edges; they
    then go through ``self.conv``, ``self.gate`` and ``self.final`` before
    being summed per graph.
    """
    num_nodes = 4  # typical number of nodes, used to normalize the per-graph sum
    num_total_nodes = data.pos.shape[0]

    edge_src, edge_dst = radius_graph(x=data.pos, r=2.5, batch=data.batch)
    edge_vec = data.pos[edge_src] - data.pos[edge_dst]
    edge_attr = o3.spherical_harmonics(
        l=self.irreps_sh, x=edge_vec, normalize=True, normalization='component'
    )
    # 3 basis functions on [0.5, 2.5], rescaled by sqrt(3)
    edge_length_embedded = soft_one_hot_linspace(
        x=edge_vec.norm(dim=1),
        start=0.5,
        end=2.5,
        number=3,
        basis='smooth_finite',
        cutoff=True,
    ) * 3**0.5

    # dim_size keeps exactly one row per node even when the highest-indexed
    # nodes receive no edge. Without it, x would be shorter than data.batch
    # and the final per-graph scatter would be misaligned.
    x = scatter(edge_attr, edge_dst, dim=0, dim_size=num_total_nodes).div(self.num_neighbors**0.5)

    x = self.conv(x, edge_src, edge_dst, edge_attr, edge_length_embedded)
    x = self.gate(x)
    x = self.final(x, edge_src, edge_dst, edge_attr, edge_length_embedded)

    return scatter(x, data.batch, dim=0).div(num_nodes**0.5)
def test_normalized(basis, cutoff):
    """Inside the range, the squared basis values summed per sample stay in (0.4, 2.0)."""
    samples = torch.linspace(-14.0, 105.0, 50)
    values = soft_one_hot_linspace(samples, -20.0, 120.0, 12, basis, cutoff)

    squared_norm = values.pow(2).sum(1)
    assert 0.4 < squared_norm.min()
    assert squared_norm.max() < 2.0
def forward(self, node_atom, node_pos, batch) -> torch.Tensor:
    """Predict one scalar per graph from atom types, positions and batch index.

    ``node_atom`` indexes a 10-entry lookup table in which only 1, 6, 7, 8 and
    9 map to a valid one-hot class (presumably atomic numbers). Per-node
    predictions are optionally de-standardized (``self.mean`` / ``self.std``),
    shifted by ``self.atomref``, summed per graph and optionally multiplied by
    ``self.scale``.
    """
    # Radius graph, allowing up to 1000 neighbours per node
    edge_src, edge_dst = radius_graph(
        node_pos, r=self.max_radius, batch=batch, max_num_neighbors=1000
    )

    # Edge attributes: spherical harmonics of the edge vectors
    edge_vec = node_pos[edge_src] - node_pos[edge_dst]
    edge_sh = o3.spherical_harmonics(
        l=range(self.lmax + 1), x=edge_vec, normalize=True, normalization='component'
    )

    # Edge length embedding. The cosine basis with cutoff=True goes to zero at
    # max_radius, so no additional smooth cutoff is needed.
    basis_size = self.number_of_basis
    edge_length_embedding = soft_one_hot_linspace(
        edge_vec.norm(dim=1),
        0.0,
        self.max_radius,
        basis_size,
        basis='cosine',
        cutoff=True,
    ).mul(basis_size**0.5)

    # Constant scalar input; atom identity enters through the one-hot attributes.
    node_input = node_pos.new_ones(node_pos.shape[0], 1)
    species = node_atom.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])[node_atom]
    node_attr = torch.nn.functional.one_hot(species, 5).mul(5**0.5)

    out = self.mp(
        node_features=node_input,
        node_attr=node_attr,
        edge_src=edge_src,
        edge_dst=edge_dst,
        edge_attr=edge_sh,
        edge_scalars=edge_length_embedding,
    )

    # Channel 0 plus half the square of channel 1, as one column per node
    per_atom = out[:, 0] + out[:, 1].pow(2).mul(0.5)
    per_atom = per_atom.view(-1, 1).div(self.num_nodes**0.5)

    if self.mean is not None and self.std is not None:
        per_atom = per_atom * self.std + self.mean
    if self.atomref is not None:
        per_atom = per_atom + self.atomref[node_atom]
    # for target=7, MAE of 75eV

    outputs = scatter(per_atom, batch, dim=0)
    return outputs if self.scale is None else self.scale * outputs
def forward(self, z, pos, batch=None):
    """Predict one scalar per graph from atom types ``z`` and positions ``pos``.

    Parameters
    ----------
    z : `torch.Tensor`
        1-D long tensor indexing a 10-entry lookup table in which only
        1, 6, 7, 8 and 9 map to valid one-hot classes.
    pos : `torch.Tensor`
        2-D tensor of positions, shape ``[num_atoms, 3]``.
    batch : `torch.Tensor`, optional
        graph index of every atom; defaults to all zeros (a single graph).

    Returns
    -------
    `torch.Tensor`
        per-graph sums of the per-atom scalars, optionally multiplied by
        ``self.scale``.
    """
    assert z.dim() == 1 and z.dtype == torch.long
    assert pos.dim() == 2 and pos.shape[1] == 3
    batch = torch.zeros_like(z) if batch is None else batch

    edge_src, edge_dst = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=1000)
    edge_vec = pos[edge_src] - pos[edge_dst]
    edge_sh = o3.spherical_harmonics(self.irreps_sh, edge_vec, True, 'component')
    edge_len = edge_vec.norm(dim=1)

    # NOTE(review): basis/cutoff are passed explicitly, as at every other
    # soft_one_hot_linspace call site; recent e3nn versions reject calls that
    # omit ``cutoff``. Gaussian without cutoff is used here because the
    # cosine envelope below already handles the decay toward self.cutoff.
    edge_len_emb = self.radial(
        soft_one_hot_linspace(
            edge_len, 0.0, self.cutoff, self.rad_gaussians, basis='gaussian', cutoff=False
        )
    )

    # Cosine envelope: 1 at distance 0, 0 at self.cutoff
    edge_c = (pi * edge_len / self.cutoff).cos().add(1).div(2)
    edge_sh = edge_c[:, None] * edge_sh / self.num_neighbors**0.5

    # z : [1, 6, 7, 8, 9] -> [0, 1, 2, 3, 4]
    node_z = z.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])[z]
    node_z = torch.nn.functional.one_hot(node_z, 5).mul(5**0.5)

    edge_attr = edge_sh

    # Initial node features: enveloped harmonics summed over the edges whose
    # source is the atom, with the first channel reset to 1.
    h = scatter(edge_sh, edge_src, dim=0, dim_size=len(pos))
    h[:, 0] = 1
    h = self.mul_node(node_z, h)
    print_std('h', h)

    for conv, act in self.layers[:-1]:
        with torch.autograd.profiler.record_function("Layer"):
            h = conv(h, node_z, edge_src, edge_dst, edge_len_emb, edge_attr)  # convolution
            print_std('post conv', h)
            h = act(h)  # gate non linearity
            print_std('post gate', h)

    with torch.autograd.profiler.record_function("Layer"):
        h = self.layers[-1](h, node_z, edge_src, edge_dst, edge_len_emb, edge_attr)

    print_std('h out', h)

    # Collapse the scalar output channels into one value per atom:
    # even scalars are added directly, odd scalars are squared first.
    s = 0
    for i, (mul, (l, p)) in enumerate(self.irreps_out):
        assert mul == 1 and l == 0
        if p == 1:
            s += h[:, i]
        if p == -1:
            s += h[:, i].pow(2).mul(0.5)  # odd^2 = even
    h = s.view(-1, 1)
    print_std('h out+', h)

    # for the scatter we normalize
    h = h / self.num_atoms**0.5

    if self.mean is not None and self.std is not None:
        h = h * self.std + self.mean

    if self.atomref is not None:
        h = h + self.atomref[z]
    # for target=7, MAE of 75eV

    out = scatter(h, batch, dim=0)
    if self.scale is not None:
        out = self.scale * out
    return out