def forward(self, data):
    """Build (or load) the edge set for ``data`` and run the model on it.

    Side effects: records ``self.device``, ``self.num_atoms`` and
    ``self.batch_size`` from the batch, and, when ``self.otf_graph`` is set,
    overwrites ``data.edge_index`` / ``data.cell_offsets`` / ``data.neighbors``.

    Args:
        data: batch object exposing ``pos``, ``batch``, ``natoms`` and
            ``atomic_numbers`` (plus ``cell``, ``edge_index``,
            ``cell_offsets`` and ``neighbors`` on the periodic path).

    Returns:
        Whatever ``self._forward_helper`` returns for the filtered edges.
    """
    positions = data.pos
    self.device = positions.device
    self.num_atoms = len(data.batch)
    self.batch_size = len(data.natoms)
    atomic_numbers = data.atomic_numbers.long()

    if self.regress_forces:
        # In-place: forces are later taken as gradients w.r.t. positions.
        positions = positions.requires_grad_(True)

    if self.otf_graph:
        # Periodic radius graph built on the fly, capped at 100 neighbours.
        data.edge_index, data.cell_offsets, data.neighbors = radius_graph_pbc(
            data, self.cutoff, 100
        )

    if self.use_pbc:
        assert atomic_numbers.dim() == 1 and atomic_numbers.dtype == torch.long
        pbc = get_pbc_distances(
            positions,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
            return_distance_vec=True,
        )
        edge_index = pbc["edge_index"]
        edge_distance = pbc["distances"]
        edge_distance_vec = pbc["distance_vec"]
    else:
        edge_index = radius_graph(positions, r=self.cutoff, batch=data.batch)
        row, col = edge_index
        # Vector from the second endpoint to the first, matching pos[j] - pos[i].
        edge_distance_vec = positions[row] - positions[col]
        edge_distance = edge_distance_vec.norm(dim=-1)

    edge_index, edge_distance, edge_distance_vec = self._filter_edges(
        edge_index,
        edge_distance,
        edge_distance_vec,
        self.max_num_neighbors,
    )

    outputs = self._forward_helper(
        data, edge_index, edge_distance, edge_distance_vec
    )

    if self.show_timing_info is True:
        torch.cuda.synchronize()
        num_edges = len(edge_index[0])
        print(
            "Memory: {}\t{}\t{}".format(
                num_edges,
                torch.cuda.memory_allocated() / (1000 * num_edges),
                torch.cuda.max_memory_allocated() / 1000000,
            )
        )

    return outputs
def update_graph(self, atoms, cutoff=6, max_neighbors=50):
    """Attach an on-the-fly periodic radius graph to ``atoms``.

    Overwrites ``atoms.edge_index``, ``atoms.cell_offsets`` and
    ``atoms.neighbors`` with the output of ``radius_graph_pbc``.

    Args:
        atoms: data object accepted by ``radius_graph_pbc``.
        cutoff: radius argument forwarded to ``radius_graph_pbc``. Defaults
            to ``6``, the value this method previously hard-coded.
        max_neighbors: neighbour cap forwarded to ``radius_graph_pbc``.
            Defaults to ``50``, the value this method previously hard-coded.

    Returns:
        ``atoms``, passed through ``self.transform`` when one is configured.
    """
    edge_index, cell_offsets, num_neighbors = radius_graph_pbc(
        atoms, cutoff, max_neighbors
    )
    atoms.edge_index = edge_index
    atoms.cell_offsets = cell_offsets
    atoms.neighbors = num_neighbors
    if self.transform is not None:
        atoms = self.transform(atoms)
    return atoms
def forward(self, data):
    """Predict energy (and optionally forces) for a batch.

    Side effects: may move ``self.embedding`` to the input device, and sets
    ``data.x``, ``data.edge_index`` and ``data.edge_attr`` (plus
    ``data.cell_offsets`` / ``data.neighbors`` when ``self.otf_graph``).

    Returns:
        ``(energy, forces)`` when ``self.regress_forces`` is set, where
        ``forces`` is the negative gradient of energy w.r.t. positions;
        otherwise just ``energy``.
    """
    device = data.atomic_numbers.device
    # Keep the embedding lookup table on the same device as the inputs.
    if self.embedding.device != device:
        self.embedding = self.embedding.to(device)
    # Node features: table row (atomic number - 1).
    data.x = self.embedding[data.atomic_numbers.long() - 1]

    pos = data.pos
    if self.regress_forces:
        pos = pos.requires_grad_(True)

    if self.otf_graph:
        data.edge_index, data.cell_offsets, data.neighbors = radius_graph_pbc(
            data, self.cutoff, 50, data.pos.device
        )

    if self.use_pbc:
        pbc = get_pbc_distances(
            pos,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
        )
        data.edge_index = pbc["edge_index"]
        distances = pbc["distances"]
    else:
        data.edge_index = radius_graph(data.pos, r=self.cutoff, batch=data.batch)
        row, col = data.edge_index
        distances = (pos[row] - pos[col]).norm(dim=-1)
    data.edge_attr = self.distance_expansion(distances)

    # Convolutions -> pooled features -> optional hidden MLP -> output layer.
    features = self.conv_to_fc(self._convolve(data))
    if hasattr(self, "fcs"):
        features = self.fcs(features)
    energy = self.fc_out(features)

    if not self.regress_forces:
        return energy

    grad_pos = torch.autograd.grad(
        energy,
        pos,
        grad_outputs=torch.ones_like(energy),
        create_graph=True,
    )[0]
    return energy, -1 * grad_pos
def forward(self, data):
    """Predict energy (and optionally forces) with SchNet.

    On the periodic path, runs embedding -> interactions -> MLP -> readout
    over PBC distances. Otherwise defers to the parent class's ``forward``.

    Returns:
        ``(energy, forces)`` when ``self.regress_forces`` is set (forces are
        the negative gradient of energy w.r.t. positions); otherwise just
        ``energy``.
    """
    z = data.atomic_numbers.long()
    pos = data.pos
    if self.regress_forces:
        pos = pos.requires_grad_(True)
    batch = data.batch

    if self.otf_graph:
        data.edge_index, data.cell_offsets, data.neighbors = radius_graph_pbc(
            data, self.cutoff, 50, data.pos.device
        )

    # TODO return distance computation in radius_graph_pbc to remove need
    # for get_pbc_distances call
    if not self.use_pbc:
        energy = super(SchNetWrap, self).forward(z, pos, batch)
    else:
        assert z.dim() == 1 and z.dtype == torch.long
        pbc = get_pbc_distances(
            pos,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
        )
        edge_index = pbc["edge_index"]
        edge_weight = pbc["distances"]
        edge_attr = self.distance_expansion(edge_weight)

        h = self.embedding(z)
        for block in self.interactions:
            h = h + block(h, edge_index, edge_weight, edge_attr)
        h = self.lin2(self.act(self.lin1(h)))

        if batch is None:
            # No batch vector: treat every atom as part of one structure.
            batch = torch.zeros_like(z)
        energy = scatter(h, batch, dim=0, reduce=self.readout)

    if not self.regress_forces:
        return energy

    grad_pos = torch.autograd.grad(
        energy,
        pos,
        grad_outputs=torch.ones_like(energy),
        create_graph=True,
    )[0]
    return energy, -1 * grad_pos
def calculate(self, atoms, properties, system_changes):
    """ASE calculator hook: predict energy and forces for ``atoms``.

    Converts ``atoms`` with ``self.a2g`` and collates it into a one-item batch.
    When ``self.pbc_graph`` is set, builds the periodic graph with
    ``radius_graph_pbc(batch, 6, 50, device)``. It then stores the first
    entry of the trainer's energy and force predictions in ``self.results``.
    """
    Calculator.calculate(self, atoms, properties, system_changes)
    batch = data_list_collater([self.a2g.convert(atoms)])

    if self.pbc_graph:
        batch.edge_index, batch.cell_offsets, batch.neighbors = radius_graph_pbc(
            batch, 6, 50, batch.pos.device
        )

    predictions = self.trainer.predict(batch)
    # Single-structure batch: keep only its entry.
    self.results["energy"] = predictions["energy"][0]
    self.results["forces"] = predictions["forces"][0]
def calculate(self, atoms, properties, system_changes):
    """ASE calculator hook: fill ``self.results`` for ``atoms``.

    The keys written depend on ``self.trainer.name``:

    * ``"s2ef"``: ``energy`` (Python scalar via ``.item()``) and ``forces``
      (NumPy array copied from the prediction tensor).
    * ``"is2re"``: ``energy`` only.
    * any other name: nothing is written.
    """
    Calculator.calculate(self, atoms, properties, system_changes)
    batch = data_list_collater([self.a2g.convert(atoms)])

    if self.pbc_graph:
        graph = radius_graph_pbc(batch, 6, 50, batch.pos.device)
        batch.edge_index, batch.cell_offsets, batch.neighbors = graph

    predictions = self.trainer.predict(batch, per_image=False)
    task = self.trainer.name
    if task in ("s2ef", "is2re"):
        self.results["energy"] = predictions["energy"].item()
    if task == "s2ef":
        self.results["forces"] = predictions["forces"].cpu().numpy()
def _forward(self, data):
    """Compute per-graph energy for a periodic batch.

    Side effect: when ``self.otf_graph`` is set, overwrites
    ``data.edge_index`` / ``data.cell_offsets`` / ``data.neighbors``.

    Returns:
        Output of ``self.head_post_pool`` applied to the mean-pooled node
        features.

    Raises:
        NotImplementedError: if ``self.use_pbc`` is false. Only the periodic
            branch produces the ``cell_offsets`` that every layer below
            consumes.
    """
    pos = data.pos
    batch = data.batch

    if self.otf_graph:
        edge_index, cell_offsets, neighbors = radius_graph_pbc(
            data, self.cutoff, 50, data.pos.device
        )
        data.edge_index = edge_index
        data.cell_offsets = cell_offsets
        data.neighbors = neighbors

    if not self.use_pbc:
        # Fail fast: the old code built a full radius graph only to discard it.
        raise NotImplementedError(
            "Non-periodic systems are not supported; set use_pbc=True."
        )

    out = get_pbc_distances(
        pos,
        data.edge_index,
        data.cell,
        data.cell_offsets,
        data.neighbors,
        return_offsets=True,
    )
    edge_index = out["edge_index"]
    cell_offsets = out["offsets"]

    h = self.atom_map[data.atomic_numbers.long()]

    # Propagate messages along edges and average over energies.
    h = self.embedding_mlp(h)
    for i in range(self.hidden_layer):
        h, pos = self.egnn[i](h, pos, edge_index, cell_offsets, batch)

    out = self.head_pre_pool(h)
    out = global_mean_pool(out, batch)
    return self.head_post_pool(out)
def _forward(self, data):
    """Compute per-graph energy for a periodic batch.

    Side effect: when ``self.otf_graph`` is set, overwrites
    ``data.edge_index`` / ``data.cell_offsets`` / ``data.neighbors``.

    Returns:
        Output of ``self.head_post_pool`` applied to the mean-pooled node
        features.

    Raises:
        NotImplementedError: if ``self.use_pbc`` is false. Only the periodic
            branch produces the ``cell_offsets`` that ``self.layers`` consume.
    """
    pos = data.pos
    batch = data.batch

    if self.otf_graph:
        edge_index, cell_offsets, neighbors = radius_graph_pbc(
            data, self.cutoff, 50, data.pos.device
        )
        data.edge_index = edge_index
        data.cell_offsets = cell_offsets
        data.neighbors = neighbors

    if not self.use_pbc:
        # Fail fast: the old code built a full radius graph only to discard it.
        raise NotImplementedError(
            "Non-periodic systems are not supported; set use_pbc=True."
        )

    out = get_pbc_distances(
        pos,
        data.edge_index,
        data.cell,
        data.cell_offsets,
        data.neighbors,
        return_offsets=True,
    )
    edge_index = out["edge_index"]
    cell_offsets = out["offsets"]

    h = self.atom_map[data.atomic_numbers.long()]
    h = self.embedding(h)
    for layer in self.layers:
        h, pos = layer(h, pos, edge_index, cell_offsets)

    # Output heads
    h = self.head_pre_pool(h)
    h = global_mean_pool(h, batch)
    return self.head_post_pool(h)
def generate_interaction_graph(self, data):
    """Build the interaction graph and the index tensors used for message passing.

    The edges come from one of two places. With ``self.otf_graph`` they are
    built on the fly via ``radius_graph_pbc``; otherwise they are read from
    ``data``. Periodic distances and vectors are computed, and edges are
    filtered through ``self.select_edges`` and reordered through
    ``self.reorder_symmetric_edges``. Finally the edge-swap and triplet
    indices are derived.

    Args:
        data: batch exposing ``atomic_numbers``, ``pos`` and ``cell``, and
            (when not ``self.otf_graph``) ``edge_index``, ``cell_offsets``
            and ``neighbors``.

    Returns:
        Tuple ``(edge_index, neighbors, D_st, V_st, id_swap, id3_ba, id3_ca,
        id3_ragged_idx)``. ``D_st`` holds edge distances and ``V_st`` the
        negated distance vectors divided by those distances. The filtered
        ``cell_offsets`` are computed but not returned.
    """
    num_atoms = data.atomic_numbers.size(0)

    if self.otf_graph:
        # Build the periodic radius graph on the fly, capped at self.max_neighbors.
        edge_index, cell_offsets, neighbors = radius_graph_pbc(
            data, self.cutoff, self.max_neighbors
        )
    else:
        # Reuse the graph already stored on the data object.
        edge_index = data.edge_index
        cell_offsets = data.cell_offsets
        neighbors = data.neighbors

    # Switch the indices, so the second one becomes the target index,
    # over which we can efficiently aggregate.
    out = get_pbc_distances(
        data.pos,
        edge_index,
        data.cell,
        cell_offsets,
        neighbors,
        return_offsets=True,
        return_distance_vec=True,
    )

    edge_index = out["edge_index"]
    D_st = out["distances"]
    # These vectors actually point in the opposite direction.
    # But we want to use col as idx_t for efficient aggregation.
    # NOTE(review): no guard against zero distances in this division.
    V_st = -out["distance_vec"] / D_st[:, None]
    # offsets_ca = -out["offsets"]  # a - c + offset

    # Mask interaction edges if required. No extra cutoff is applied for
    # on-the-fly graphs or when self.cutoff is ~6. Presumably a precomputed
    # graph already uses a 6 cutoff -- confirm against the dataset pipeline.
    if self.otf_graph or np.isclose(self.cutoff, 6):
        select_cutoff = None
    else:
        select_cutoff = self.cutoff
    (
        edge_index,
        cell_offsets,
        neighbors,
        D_st,
        V_st,
    ) = self.select_edges(
        data=data,
        edge_index=edge_index,
        cell_offsets=cell_offsets,
        neighbors=neighbors,
        edge_dist=D_st,
        edge_vector=V_st,
        cutoff=select_cutoff,
    )

    (
        edge_index,
        cell_offsets,
        neighbors,
        D_st,
        V_st,
    ) = self.reorder_symmetric_edges(
        edge_index, cell_offsets, neighbors, D_st, V_st
    )

    # Indices for swapping c->a and a->c (for symmetric MP).
    # Assumes reorder_symmetric_edges lays each image's edges out as two
    # equal halves (hence neighbors // 2) -- TODO confirm.
    block_sizes = neighbors // 2
    id_swap = repeat_blocks(
        block_sizes,
        repeats=2,
        continuous_indexing=False,
        start_idx=block_sizes[0],
        block_inc=block_sizes[:-1] + block_sizes[1:],
        repeat_inc=-block_sizes,
    )

    # Triplet indices over the final edge set.
    id3_ba, id3_ca, id3_ragged_idx = self.get_triplets(
        edge_index, num_atoms=num_atoms
    )

    return (
        edge_index,
        neighbors,
        D_st,
        V_st,
        id_swap,
        id3_ba,
        id3_ca,
        id3_ragged_idx,
    )
def _forward(self, data):
    """Run the DimeNet-style model and return per-atom and per-graph outputs.

    Side effect: when ``self.otf_graph`` is set, overwrites
    ``data.edge_index`` / ``data.cell_offsets`` / ``data.neighbors``.

    Returns:
        ``(P, energy)``. ``P`` is the accumulated output-block result per
        node. ``energy`` is ``P`` summed over all nodes when ``data.batch``
        is ``None``, and ``P`` scattered per graph otherwise.
    """
    pos = data.pos
    batch = data.batch

    if self.otf_graph:
        edge_index, cell_offsets, neighbors = radius_graph_pbc(
            data, self.cutoff, 50, data.pos.device
        )
        data.edge_index = edge_index
        data.cell_offsets = cell_offsets
        data.neighbors = neighbors

    if self.use_pbc:
        out = get_pbc_distances(
            pos,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
            return_offsets=True,
        )
        edge_index = out["edge_index"]
        dist = out["distances"]
        offsets = out["offsets"]
        j, i = edge_index
    else:
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        j, i = edge_index
        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
        edge_index, num_nodes=data.atomic_numbers.size(0)
    )

    # Calculate angles from detached positions (no gradient through angles).
    pos_i = pos[idx_i].detach()
    pos_j = pos[idx_j].detach()
    pos_k = pos[idx_k].detach()
    pos_ji = pos_j - pos_i
    pos_kj = pos_k - pos_j
    if self.use_pbc:
        # Add the periodic image offsets of the corresponding edges.
        pos_ji = pos_ji + offsets[idx_ji]
        pos_kj = pos_kj + offsets[idx_kj]

    a = (pos_ji * pos_kj).sum(dim=-1)
    # dim=-1 is required: without it torch.cross uses the first size-3
    # dimension, which is dim 0 whenever there are exactly three triplets.
    b = torch.cross(pos_ji, pos_kj, dim=-1).norm(dim=-1)
    angle = torch.atan2(b, a)

    rbf = self.rbf(dist)
    sbf = self.sbf(dist, angle, idx_kj)

    # Embedding block.
    x = self.emb(data.atomic_numbers.long(), rbf, i, j)
    P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))

    # Interaction blocks.
    for interaction_block, output_block in zip(
        self.interaction_blocks, self.output_blocks[1:]
    ):
        x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
        P += output_block(x, rbf, i, num_nodes=pos.size(0))

    energy = P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)

    return P, energy
def _forward(self, data):
    """Compute the per-graph prediction for a periodic batch.

    Side effect: when ``self.otf_graph`` is set, overwrites
    ``data.edge_index`` / ``data.cell_offsets`` / ``data.neighbors``.

    Returns:
        Output of ``self.head_post_pool_layer_2`` after mean pooling.

    Raises:
        NotImplementedError: if ``self.use_pbc`` is false. Only the periodic
            branch produces the ``cell_offsets`` needed for ``rel_pos``.
    """
    pos = data.pos
    batch = data.batch

    if self.otf_graph:
        edge_index, cell_offsets, neighbors = radius_graph_pbc(
            data, self.cutoff, 50, data.pos.device
        )
        data.edge_index = edge_index
        data.cell_offsets = cell_offsets
        data.neighbors = neighbors

    if not self.use_pbc:
        # Fail fast: the old code built a full radius graph only to discard it.
        raise NotImplementedError(
            "Non-periodic systems are not supported; set use_pbc=True."
        )

    out = get_pbc_distances(
        pos,
        data.edge_index,
        data.cell,
        data.cell_offsets,
        data.neighbors,
        return_offsets=True,
    )
    edge_index = out["edge_index"]
    cell_offsets = out["offsets"]

    atomic_numbers = data.atomic_numbers.long()

    # Construct the node and edge attributes.
    rel_pos = (pos[edge_index[0]] - pos[edge_index[1]]) + cell_offsets
    # NOTE: this is the squared length (no sqrt). The radius-adjusted
    # variants below subtract radii from this squared value.
    edge_dist = rel_pos.pow(2).sum(-1, keepdims=True)
    radii_0 = self.atom_radii[atomic_numbers[edge_index[0]]][:, None]
    radii_1 = self.atom_radii[atomic_numbers[edge_index[1]]][:, None]
    edge_dist_radii_1 = edge_dist - radii_0
    edge_dist_radii_2 = edge_dist - radii_1
    edge_dist_radii_12 = edge_dist - radii_0 - radii_1

    edge_attr = spherical_harmonics(
        self.attr_irreps, rel_pos, normalize=True, normalization='component'
    )
    node_attr = self.node_attribute_net(edge_index, edge_attr)

    # Pad node_attr with one-hot rows (first basis vector) for trailing nodes
    # that never appear in edge_index. Presumably node_attribute_net only
    # yields rows up to the largest node index seen -- confirm.
    if (data.contains_isolated_nodes()
            and edge_index.max().item() + 1 != data.num_nodes):
        nr_add_attr = data.num_nodes - (edge_index.max().item() + 1)
        add_attr = node_attr.new_tensor(
            np.tile(np.eye(node_attr.shape[-1])[0, :], (nr_add_attr, 1))
        )
        node_attr = torch.cat((node_attr, add_attr), -2)

    x = self.atom_map[atomic_numbers]
    x = self.embedding_layer_1(x, node_attr)
    x = self.embedding_layer_2(x, node_attr)
    x = self.embedding_layer_3(x, node_attr)

    # The main layers
    for layer in self.layers:
        x, pos = layer(
            x,
            pos,
            edge_index,
            edge_dist,
            edge_dist_radii_1,
            edge_dist_radii_2,
            edge_dist_radii_12,
            edge_attr,
            node_attr,
        )

    # Output head
    x = self.head_pre_pool_layer_1(x, node_attr)
    x = self.head_pre_pool_layer_2(x)
    x = global_mean_pool(x, batch)
    x = self.head_post_pool_layer_1(x)
    x = self.head_post_pool_layer_2(x)

    return x
def forward(self, data):
    """Predict per-graph energy and per-atom forces (ForceNet-style).

    Forces come directly from ``self.decoder`` applied to node features, not
    from autograd. Side effect: when ``self.otf_graph`` is set, overwrites
    ``data.edge_index`` / ``data.cell_offsets`` / ``data.neighbors``.

    Returns:
        ``(energy, force)``.

    Raises:
        RuntimeError: if ``self.feat`` is neither ``"simple"`` nor ``"full"``.
    """
    z = data.atomic_numbers.long()
    pos = data.pos
    batch = data.batch

    if self.feat == "simple":
        h = self.embedding(z)
    elif self.feat == "full":
        h = self.embedding(self.atom_map[z])
    else:
        raise RuntimeError("Undefined feature type for atom")

    if self.otf_graph:
        edge_index, cell_offsets, neighbors = radius_graph_pbc(
            data, self.cutoff, 50
        )
        data.edge_index = edge_index
        data.cell_offsets = cell_offsets
        data.neighbors = neighbors

    out = get_pbc_distances(
        pos,
        data.edge_index,
        data.cell,
        data.cell_offsets,
        data.neighbors,
        return_distance_vec=True,
    )
    edge_index = out["edge_index"]
    edge_dist = out["distances"]
    edge_vec = out["distance_vec"]

    # Normalized edge vectors, computed once and shared by both uses below.
    edge_vec_normalized = edge_vec / edge_dist.view(-1, 1)
    if self.pbc_apply_sph_harm:
        edge_attr_sph = self.pbc_sph(edge_vec_normalized)
    # NOTE(review): edge_attr_sph only exists when pbc_apply_sph_harm is set;
    # an "sph" basis_type without it fails further down -- keep them in sync.

    # calculate the edge weight according to the dist
    edge_weight = torch.cos(0.5 * edge_dist * PI / self.cutoff)

    # Edge distance with and without the endpoint atom radii, divided by the
    # cutoff. Values below 1e-3 are clamped further down.
    radii_0 = self.atom_radii[z[edge_index[0]]]
    radii_1 = self.atom_radii[z[edge_index[1]]]
    edge_dist_list = (
        torch.stack(
            [
                edge_dist,
                edge_dist - radii_0,
                edge_dist - radii_1,
                edge_dist - radii_0 - radii_1,
            ]
        ).transpose(0, 1)
        / self.cutoff
    )

    if self.ablation == "nodistlist":
        edge_dist_list = edge_dist_list[:, 0].view(-1, 1)

    # make sure distance is positive
    edge_dist_list[edge_dist_list < 1e-3] = 1e-3

    # squash to [0,1] for gaussian basis
    if self.basis_type == "gauss":
        edge_vec_normalized = (edge_vec_normalized + 1) / 2.0

    # process raw_edge_attributes to generate edge_attributes
    if self.ablation == "onlydist":
        raw_edge_attr = edge_dist_list
    else:
        raw_edge_attr = torch.cat([edge_vec_normalized, edge_dist_list], dim=1)

    if "sph" in self.basis_type:
        edge_attr = self.basis_fun(raw_edge_attr, edge_attr_sph)
    else:
        edge_attr = self.basis_fun(raw_edge_attr)

    # pass edge_attributes through interaction blocks
    for interaction in self.interactions:
        h = h + interaction(h, edge_index, edge_attr, edge_weight)

    h = self.lin(h)
    h = self.activation(h)

    out = scatter(h, batch, dim=0, reduce="add")

    force = self.decoder(h)
    energy = self.energy_mlp(out)
    return energy, force
def forward(self, data):
    """Predict energy (and optionally forces) with the DimeNet++-style model.

    During training, the triplet set is randomly subsampled to at most
    ``self.max_angles_per_image * data.natoms.size(0)`` entries. Side effect:
    when ``self.otf_graph`` is set, overwrites ``data.edge_index`` /
    ``data.cell_offsets`` / ``data.neighbors``.

    Returns:
        ``(energy, forces)`` when ``self.regress_forces`` is set (forces are
        the negative gradient of energy w.r.t. positions); otherwise just
        ``energy``.
    """
    pos = data.pos
    if self.regress_forces:
        pos = pos.requires_grad_(True)
    batch = data.batch

    if self.otf_graph:
        edge_index, cell_offsets, neighbors = radius_graph_pbc(
            data, self.cutoff, 50, data.pos.device
        )
        data.edge_index = edge_index
        data.cell_offsets = cell_offsets
        data.neighbors = neighbors

    if self.use_pbc:
        out = get_pbc_distances(
            pos,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
            return_offsets=True,
        )
        edge_index = out["edge_index"]
        dist = out["distances"]
        offsets = out["offsets"]
        j, i = edge_index
    else:
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        j, i = edge_index
        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
        edge_index, num_nodes=data.atomic_numbers.size(0)
    )

    # Cap no. of triplets during training.
    if self.training:
        sub_ix = torch.randperm(idx_i.size(0))[
            : self.max_angles_per_image * data.natoms.size(0)
        ]
        idx_i, idx_j, idx_k = idx_i[sub_ix], idx_j[sub_ix], idx_k[sub_ix]
        idx_kj, idx_ji = idx_kj[sub_ix], idx_ji[sub_ix]

    # Calculate angles from detached positions (no gradient through angles).
    pos_i = pos[idx_i].detach()
    pos_j = pos[idx_j].detach()
    pos_k = pos[idx_k].detach()
    pos_ji = pos_j - pos_i
    pos_kj = pos_k - pos_j
    if self.use_pbc:
        # Add the periodic image offsets of the corresponding edges.
        pos_ji = pos_ji + offsets[idx_ji]
        pos_kj = pos_kj + offsets[idx_kj]

    a = (pos_ji * pos_kj).sum(dim=-1)
    # dim=-1 is required: without it torch.cross uses the first size-3
    # dimension, which is dim 0 whenever there are exactly three triplets.
    b = torch.cross(pos_ji, pos_kj, dim=-1).norm(dim=-1)
    angle = torch.atan2(b, a)

    rbf = self.rbf(dist)
    sbf = self.sbf(dist, angle, idx_kj)

    # Embedding block.
    x = self.emb(data.atomic_numbers.long(), rbf, i, j)
    P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))

    # Interaction blocks.
    for interaction_block, output_block in zip(
        self.interaction_blocks, self.output_blocks[1:]
    ):
        x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
        P += output_block(x, rbf, i, num_nodes=pos.size(0))

    energy = P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)

    if not self.regress_forces:
        return energy

    forces = -1 * (
        torch.autograd.grad(
            energy,
            pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )[0]
    )
    return energy, forces