def forward(self, inputs):
    """Evaluate the configured radial and/or angular symmetry functions.

    Args:
        inputs (dict of torch.Tensor): SchNetPack format dictionary of input
            tensors.

    Returns:
        torch.Tensor: Nbatch x Natoms x Nsymmetry_functions tensor containing
        ACSFs or wACSFs. When only one of ``self.RDF`` / ``self.ADF`` is set,
        that part alone is returned (``None`` if neither is set); when both
        are set, the radial and angular parts are concatenated along
        dimension 2, radial first.
    """
    R = inputs[Structure.R]
    atom_types = inputs[Structure.Z]
    nbr_idx = inputs[Structure.neighbors]
    nbr_mask = inputs[Structure.neighbor_mask]

    radial_part = None
    if self.RDF is not None:
        # Embed atom types, then gather the embedding of every neighbor to
        # use as elemental weights of the radial functions.
        radial_weights = snn.neighbor_elements(self.radial_Z(atom_types), nbr_idx)
        r = snn.atom_distances(R, nbr_idx, neighbor_mask=nbr_mask)
        radial_part = self.RDF(
            r, elemental_weights=radial_weights, neighbor_mask=nbr_mask
        )

    angular_part = None
    if self.ADF is not None:
        pair_j = inputs[Structure.neighbor_pairs_j]
        pair_k = inputs[Structure.neighbor_pairs_k]
        pair_mask = inputs[Structure.neighbor_pairs_mask]

        # Element weights for both members of each neighbor pair.
        angular_embedding = self.angular_Z(atom_types)
        weights_j = snn.neighbor_elements(angular_embedding, pair_j)
        weights_k = snn.neighbor_elements(angular_embedding, pair_k)

        r_ij, r_ik, r_jk = snn.triple_distances(R, pair_j, pair_k)
        angular_part = self.ADF(
            r_ij,
            r_ik,
            r_jk,
            elemental_weights=(weights_j, weights_k),
            triple_masks=pair_mask,
        )

    if self.RDF is None:
        return angular_part
    if self.ADF is None:
        return radial_part
    return torch.cat((radial_part, angular_part), 2)
def forward(self, inputs):
    """Predict a symmetrized polarizability tensor from atomic contributions.

    Two scalar channels of ``self.out_net`` (indices 0 and 1 of the last axis)
    are gathered onto each atom's neighbor list. Channel 0 weights the
    neighbor distance vectors to build atomic dipoles. Channel 1 weights
    ``dist_vec / distance**3`` terms to build atomic fields, which are then
    normalized. The dipole/field product is symmetrized and aggregated over
    atoms.

    Args:
        inputs (dict of torch.Tensor): SchNetPack format dictionary of input
            tensors.

    Returns:
        dict: ``{self.property: global_polar}``. When ``self.isotropic`` is
        truthy, an extra entry keyed by ``self.isotropic`` holds the mean of
        the diagonal of ``global_polar`` (last two axes), with ``keepdim``.
    """
    positions = inputs[Properties.R]
    neighbors = inputs[Properties.neighbors]
    nbh_mask = inputs[Properties.neighbor_mask]
    atom_mask = inputs[Properties.atom_mask]

    # Neighbor distances and the corresponding distance vectors.
    distances, dist_vecs = L.atom_distances(positions, neighbors, return_vecs=True)

    # Atomic contributions; channels 0 and 1 of the last axis are used below.
    contributions = self.out_net(inputs)

    # Redistribute contributions to neighbors.
    neighbor_contributions = L.neighbor_elements(contributions, neighbors)

    if self.cutoff_network is not None:
        f_cut = self.cutoff_network(distances)[..., None]
        neighbor_contributions = neighbor_contributions * f_cut

    neighbor_contributions1 = neighbor_contributions[:, :, :, 0]
    neighbor_contributions2 = neighbor_contributions[:, :, :, 1]

    # Atomic dipoles: neighbor-aggregated, contribution-weighted distance vectors.
    atomic_dipoles = self.nbh_agg(
        dist_vecs * neighbor_contributions1[..., None], nbh_mask
    )

    # Padded neighbors get a denominator of 1 instead of 0 (division guard).
    masked_dist = (distances**3 * nbh_mask) + (1 - nbh_mask)
    nbh_fields = (
        dist_vecs * neighbor_contributions2[..., None] / masked_dist[..., None]
    )
    atomic_fields = self.nbh_agg(nbh_fields, nbh_mask)

    # Normalize fields; (near-)zero norms are shifted by 1 to keep the division
    # finite. The mask is cast to field_norm's own dtype so half/bfloat16
    # tensors are not silently promoted to float32 (``.float()`` did that).
    field_norm = torch.norm(atomic_fields, dim=-1, keepdim=True)
    field_norm = field_norm + (field_norm < 1e-10).to(field_norm.dtype)
    atomic_fields = atomic_fields / field_norm

    # Broadcast product dipole (column) x field (row); an outer product over
    # the last axis, assuming both are (batch, atoms, 3) - TODO confirm.
    atomic_polar = atomic_dipoles[..., None] * atomic_fields[:, :, None, :]

    # Symmetrize
    atomic_polar = symmetric_product(atomic_polar)

    global_polar = self.atom_agg(atomic_polar, atom_mask[..., None])

    result = {self.property: global_polar}

    if self.isotropic:
        result[self.isotropic] = torch.mean(
            torch.diagonal(global_polar, dim1=-2, dim2=-1), dim=-1, keepdim=True
        )

    return result
def forward(self, inputs):
    """Evaluate radial and/or angular symmetry functions with periodic offsets.

    Args:
        inputs (dict of torch.Tensor): SchNetPack format dictionary of input
            tensors. Angular functions additionally require the neighbor-pair
            entries produced with ``collect_triples=True``.

    Returns:
        torch.Tensor: Nbatch x Natoms x Nsymmetry_functions tensor containing
        ACSFs or wACSFs. With both ``self.RDF`` and ``self.ADF`` set, the
        radial and angular parts are concatenated along dimension 2 (radial
        first); otherwise the single configured part is returned.

    Raises:
        HDNNException: if angular functions are requested but the neighbor
            pair indices are missing from ``inputs``.
    """
    positions = inputs[Properties.R]
    Z = inputs[Properties.Z]
    neighbors = inputs[Properties.neighbors]
    neighbor_mask = inputs[Properties.neighbor_mask]
    cell = inputs[Properties.cell]
    cell_offset = inputs[Properties.cell_offset]

    # Compute radial functions
    if self.RDF is not None:
        # Get atom type embeddings
        Z_rad = self.radial_Z(Z)
        # Get atom types of neighbors
        Z_ij = snn.neighbor_elements(Z_rad, neighbors)
        # Compute distances (cell/cell_offsets account for periodic images)
        distances = snn.atom_distances(
            positions,
            neighbors,
            neighbor_mask=neighbor_mask,
            cell=cell,
            cell_offsets=cell_offset,
        )
        radial_sf = self.RDF(
            distances, elemental_weights=Z_ij, neighbor_mask=neighbor_mask
        )
    else:
        radial_sf = None

    if self.ADF is not None:
        # Get pair indices; chain the KeyError so the missing key stays visible.
        try:
            idx_j = inputs[Properties.neighbor_pairs_j]
            idx_k = inputs[Properties.neighbor_pairs_k]
        except KeyError as e:
            raise HDNNException(
                "Angular symmetry functions require "
                + "`collect_triples=True` in AtomsData."
            ) from e

        neighbor_pairs_mask = inputs[Properties.neighbor_pairs_mask]

        # Get element contributions of the pairs
        Z_angular = self.angular_Z(Z)
        Z_ij = snn.neighbor_elements(Z_angular, idx_j)
        Z_ik = snn.neighbor_elements(Z_angular, idx_k)

        # Periodic-image offsets of both pair members
        offset_idx_j = inputs[Properties.neighbor_offsets_j]
        offset_idx_k = inputs[Properties.neighbor_offsets_k]

        # Compute triple distances
        r_ij, r_ik, r_jk = snn.triple_distances(
            positions,
            idx_j,
            idx_k,
            offset_idx_j=offset_idx_j,
            offset_idx_k=offset_idx_k,
            cell=cell,
            cell_offsets=cell_offset,
        )

        angular_sf = self.ADF(
            r_ij,
            r_ik,
            r_jk,
            elemental_weights=(Z_ij, Z_ik),
            triple_masks=neighbor_pairs_mask,
        )
    else:
        angular_sf = None

    # Concatenate and return symmetry functions
    if self.RDF is None:
        symmetry_functions = angular_sf
    elif self.ADF is None:
        symmetry_functions = radial_sf
    else:
        symmetry_functions = torch.cat((radial_sf, angular_sf), 2)

    return symmetry_functions
def forward(self, inputs):
    """Compute radial and atom-pair symmetry functions plus pair-distance features.

    Args:
        inputs (dict of torch.Tensor): SchNetPack format dictionary of input
            tensors. The atom-pair branch additionally requires the
            neighbor-pair entries produced with ``collect_triples=True``.

    Returns:
        tuple: ``(radial_sf, ap_sf, dists, fd)``:

        * ``radial_sf`` - output of ``self.RDF``, or ``None`` if RDF is unset.
        * ``ap_sf`` - output of ``self.APF``, or ``None`` if APF is unset.
        * ``dists`` - neighbor distances flattened to
          (batch, atoms * neighbors) and stacked with their inverses on a new
          last axis; ``None`` if APF is unset.
        * ``fd`` - product of ``self.fd_react`` applied to the smallest
          selected "reactant" distance and ``self.fd_prod`` applied to the
          smallest selected "product" distance; ``None`` if APF or
          ``self.fd_react`` is unset.

    Raises:
        HDNNException: if the neighbor-pair indices are missing from
            ``inputs`` while APF is enabled.
    """
    positions = inputs[Properties.R]
    Z = inputs[Properties.Z]
    neighbors = inputs[Properties.neighbors]
    neighbor_mask = inputs[Properties.neighbor_mask]
    cell = inputs[Properties.cell]
    cell_offset = inputs[Properties.cell_offset]

    # Bind every returned name up front. Previously `dists`/`fd` (and
    # `distances`) were only bound in some branches, so valid configurations
    # crashed with NameError/UnboundLocalError.
    radial_sf = None
    ap_sf = None
    dists = None
    fd = None
    distances = None

    # Compute radial functions
    if self.RDF is not None:
        # Get atom type embeddings
        Z_rad = self.radial_Z(Z)
        # Get atom types of neighbors
        Z_ij = snn.neighbor_elements(Z_rad, neighbors)
        # Compute distances
        distances = snn.atom_distances(
            positions,
            neighbors,
            neighbor_mask=neighbor_mask,
            cell=cell,
            cell_offsets=cell_offset,
        )
        radial_sf = self.RDF(
            distances, elemental_weights=Z_ij, neighbor_mask=neighbor_mask
        )

    if self.APF is not None:
        try:
            idx_j = inputs[Properties.neighbor_pairs_j]
            idx_k = inputs[Properties.neighbor_pairs_k]
        except KeyError as e:
            raise HDNNException(
                "Angular symmetry functions require "
                + "`collect_triples=True` in AtomsData."
            ) from e

        neighbor_pairs_mask = inputs[Properties.neighbor_pairs_mask]

        if distances is None:
            # Radial branch was skipped, but the pair distances are needed here.
            distances = snn.atom_distances(
                positions,
                neighbors,
                neighbor_mask=neighbor_mask,
                cell=cell,
                cell_offsets=cell_offset,
            )

        # Flatten the (batch, atoms, neighbors) distances to
        # (batch, atoms * neighbors).
        # NOTE(review): masked-out neighbors may carry distance 0 here, which
        # would make inv_dist inf - confirm inputs are never padded.
        inv_dist = 1 / distances
        n_batch, n_atoms, n_nbh = distances.shape[0], distances.shape[1], distances.shape[2]
        pair_dist = torch.reshape(distances, (n_batch, n_atoms * n_nbh))
        inv_dist = torch.reshape(inv_dist, (n_batch, n_atoms * n_nbh))
        dists = torch.stack((pair_dist, inv_dist), -1)

        # Get element contributions of the pairs
        Z_ap = self.ap_Z(Z)
        Z_ik = snn.neighbor_elements(Z_ap, idx_k)

        # Offset indices
        offset_idx_j = inputs[Properties.neighbor_offsets_j]
        offset_idx_k = inputs[Properties.neighbor_offsets_k]

        # Compute triple distances
        r_ij, r_ik, r_jk = snn.triple_distances(
            positions,
            idx_j,
            idx_k,
            offset_idx_j=offset_idx_j,
            offset_idx_k=offset_idx_k,
            cell=cell,
            cell_offsets=cell_offset,
        )

        ap_sf = self.APF(
            r_ij,
            r_ik,
            r_jk,
            elemental_weights=Z_ik,
            triple_masks=neighbor_pairs_mask,
        )

        if self.fd_react is not None:
            # Flat index = atom_t * (n_atoms - 1) + neighbor slot, i.e. the
            # layout of pair_dist when each atom has n_atoms - 1 neighbors.
            atom_t_ind = self.atom_t * (positions.shape[1] - 1)

            sel_atom_react = atom_t_ind + self.atom_react
            react_dists = pair_dist[:, sel_atom_react].reshape(
                pair_dist.shape[0], sel_atom_react.shape[0]
            )
            sel_atom_prod = atom_t_ind + self.atom_prod
            prod_dists = pair_dist[:, sel_atom_prod].reshape(
                pair_dist.shape[0], sel_atom_prod.shape[0]
            )

            # Smallest selected distance per structure for each side.
            react_dists = torch.min(react_dists, dim=1, keepdim=True)[0]
            prod_dists = torch.min(prod_dists, dim=1, keepdim=True)[0]

            fd_react = self.fd_react(react_dists)
            fd_prod = self.fd_prod(prod_dists)
            fd = torch.mul(fd_react, fd_prod)

    return radial_sf, ap_sf, dists, fd
def forward(self, inputs):
    """Compute per-fragment radial/atom-pair features and inter-fragment distances.

    Atoms are treated as two groups: group A is the first ``inputs['ZA'].shape[1]``
    atoms and group B is the following ``inputs['ZB'].shape[1]`` atoms.
    The names suggest two monomers, but that is not confirmed from this block.

    Args:
        inputs (dict of torch.Tensor): SchNetPack format dictionary of input
            tensors, including ``'ZA'``, ``'ZB'`` and the inter-group neighbor
            entries (``neighbor_inter``, ``neighbor_offset_inter``,
            ``cell_offset_intra``, ``neighbor_inter_mask``).

    Returns:
        tuple: ``(radial_sf_A, radial_sf_B, ap_sf_A, ap_sf_B, dists)``:

        * ``radial_sf_A`` / ``radial_sf_B`` - rows of the ``self.RDF`` output
          for the group-A / group-B atoms; ``None`` if RDF is unset.
        * ``ap_sf_A`` - ``self.APF`` output for group-A atoms restricted to
          the first ``n_B`` neighbor slots.
        * ``ap_sf_B`` - ``self.APF`` output rows for group-B atoms.
          Both AP entries are ``None`` if APF is unset.
        * ``dists`` - A-B inter-group distances flattened to
          (batch, n_A * n_B) and stacked with their inverses on a new last
          axis, optionally scaled by ``self.cutoff2`` and ``self.fermi_dirac``;
          ``None`` if APF is unset.

    Raises:
        HDNNException: if the neighbor-pair indices are missing from
            ``inputs`` while APF is enabled.
    """
    positions = inputs[Properties.R]
    Z = inputs[Properties.Z]
    neighbors = inputs[Properties.neighbors]
    neighbor_mask = inputs[Properties.neighbor_mask]
    cell = inputs[Properties.cell]
    cell_offset = inputs[Properties.cell_offset]
    device = cell_offset.device

    # Bind every returned name up front. Previously these were only bound in
    # one branch each, so disabling RDF or APF crashed at the return.
    radial_sf_A = None
    radial_sf_B = None
    ap_sf_A = None
    ap_sf_B = None
    dists = None
    distances = None

    # Compute radial functions
    if self.RDF is not None:
        ZA = inputs['ZA']
        ZB = inputs['ZB']
        # Get atom type embeddings
        Z_rad = self.radial_Z(Z)
        # Get atom types of neighbors
        Z_ij = snn.neighbor_elements(Z_rad, neighbors)
        # Compute distances
        distances = snn.atom_distances(
            positions,
            neighbors,
            neighbor_mask=neighbor_mask,
            cell=cell,
            cell_offsets=cell_offset,
        )
        radial_sf = self.RDF(
            distances, elemental_weights=Z_ij, neighbor_mask=neighbor_mask
        )

        n_A, n_B = ZA.shape[1], ZB.shape[1]
        atoms_A = torch.arange(0, n_A, 1, device=device)
        atoms_B = torch.arange(n_A, n_A + n_B, 1, device=device)
        radial_sf_A = radial_sf[:, atoms_A, :]
        radial_sf_B = radial_sf[:, atoms_B, :]

    if self.APF is not None:
        try:
            idx_j = inputs[Properties.neighbor_pairs_j]
            idx_k = inputs[Properties.neighbor_pairs_k]
        except KeyError as e:
            raise HDNNException(
                "Angular symmetry functions require "
                + "`collect_triples=True` in AtomsData."
            ) from e

        ZA = inputs['ZA']
        ZB = inputs['ZB']
        neighbor_pairs_mask = inputs[Properties.neighbor_pairs_mask]
        neighbor_inter = inputs[Properties.neighbor_inter]
        offset_inter = inputs[Properties.neighbor_offset_inter]
        offset_intra = inputs[Properties.cell_offset_intra]
        neighbor_inter_mask = inputs[Properties.neighbor_inter_mask]

        distances_inter = snn.atom_distances(
            positions,
            neighbor_inter,
            neighbor_mask=neighbor_inter_mask,
            cell=cell,
            cell_offsets=offset_inter,
        )

        # Masked-out entries become 1 so that the inverse stays finite.
        valid = neighbor_inter_mask != 0
        pair_dist = torch.ones_like(distances_inter, device=device)
        pair_dist[valid] = distances_inter[valid]
        inv_dist = 1 / pair_dist

        n_A, n_B = ZA.shape[1], ZB.shape[1]
        atoms_A = torch.arange(0, n_A, 1, device=device)          # atom rows of group A
        atoms_B = torch.arange(n_A, n_A + n_B, 1, device=device)  # atom rows of group B
        slots_B = torch.arange(0, n_B, 1, device=device)          # first n_B neighbor slots

        # Keep the A-atom x B-slot block and flatten it to (batch, n_A * n_B).
        pair_dist = pair_dist[:, atoms_A, :][:, :, slots_B]
        inv_dist = inv_dist[:, atoms_A, :][:, :, slots_B]
        pair_dist = torch.reshape(
            pair_dist, (pair_dist.shape[0], pair_dist.shape[1] * pair_dist.shape[2])
        )
        inv_dist = torch.reshape(
            inv_dist, (inv_dist.shape[0], inv_dist.shape[1] * inv_dist.shape[2])
        )
        dists = torch.stack((pair_dist, inv_dist), -1)

        # Get element contributions of the pairs
        Z_ap = self.ap_Z(Z)
        Z_ik = snn.neighbor_elements(Z_ap, idx_k)

        # Offset indices
        offset_idx_j = inputs[Properties.neighbor_offsets_j]
        offset_idx_k = inputs[Properties.neighbor_offsets_k]

        # Compute triple distances
        r_ij, r_ik, r_jk = snn.triple_distances_apnet(
            positions,
            idx_j,
            idx_k,
            offset_idx_j=offset_idx_j,
            offset_idx_k=offset_idx_k,
            cell=cell,
            cell_offsets_intra=offset_intra,
            cell_offsets_inter=offset_inter,
        )

        ap_sf = self.APF(
            r_ij,
            r_ik,
            r_jk,
            elemental_weights=Z_ik,
            triple_masks=neighbor_pairs_mask,
        )

        ap_sf_A = ap_sf[:, atoms_A, :, :][:, :, slots_B, :]
        ap_sf_B = ap_sf[:, atoms_B, :]

        if self.cutoff2 is not None:
            cutoffs = self.cutoff2(pair_dist).unsqueeze(-1)
            dists = torch.mul(dists, cutoffs)

        if hasattr(self, 'fermi_dirac'):
            if distances is None:
                # Radial branch was skipped; the intra-system distances are
                # still needed to locate the bond given by morse1/morse2.
                distances = snn.atom_distances(
                    positions,
                    neighbors,
                    neighbor_mask=neighbor_mask,
                    cell=cell,
                    cell_offsets=cell_offset,
                )
            morse_bond = distances[:, self.morse1, self.morse2].reshape(
                distances.shape[0], self.morse2.shape[0]
            )
            morse_bond = torch.min(morse_bond, dim=1, keepdim=True)[0]
            fd = self.fermi_dirac(morse_bond)
            # Broadcast one switching value per structure over all pair features.
            fd = torch.repeat_interleave(fd, repeats=pair_dist.shape[1], dim=1)
            fd = fd.unsqueeze(-1)
            dists = torch.mul(dists, fd)

    return radial_sf_A, radial_sf_B, ap_sf_A, ap_sf_B, dists