Esempio n. 1
0
    def forward(self, inputs):
        """
        Build (weighted) atom-centered symmetry functions for a batch.

        Args:
            inputs (dict of torch.Tensor): SchNetPack format dictionary of input tensors.

        Returns:
            torch.Tensor: Nbatch x Natoms x Nsymmetry_functions Tensor containing ACSFs or wACSFs.
            When both ``self.RDF`` and ``self.ADF`` are set, radial features come
            first and angular features follow along dim 2. When only one of them
            is set, only that family is returned; when neither is set, ``None``.

        """
        positions = inputs[Structure.R]
        Z = inputs[Structure.Z]
        neighbors = inputs[Structure.neighbors]
        neighbor_mask = inputs[Structure.neighbor_mask]

        radial_sf = None
        if self.RDF is not None:
            # Element embedding of every neighbor, used as radial weights.
            neighbor_types = snn.neighbor_elements(self.radial_Z(Z), neighbors)
            r = snn.atom_distances(positions,
                                   neighbors,
                                   neighbor_mask=neighbor_mask)
            radial_sf = self.RDF(r,
                                 elemental_weights=neighbor_types,
                                 neighbor_mask=neighbor_mask)

        angular_sf = None
        if self.ADF is not None:
            # Indices of the j and k partners forming each (i, j, k) triple.
            pair_j = inputs[Structure.neighbor_pairs_j]
            pair_k = inputs[Structure.neighbor_pairs_k]
            pair_mask = inputs[Structure.neighbor_pairs_mask]

            embedding = self.angular_Z(Z)
            weights = (snn.neighbor_elements(embedding, pair_j),
                       snn.neighbor_elements(embedding, pair_k))
            r_ij, r_ik, r_jk = snn.triple_distances(positions, pair_j, pair_k)
            angular_sf = self.ADF(r_ij,
                                  r_ik,
                                  r_jk,
                                  elemental_weights=weights,
                                  triple_masks=pair_mask)

        # Return whichever families are configured; concatenate if both are.
        if self.RDF is None:
            return angular_sf
        if self.ADF is None:
            return radial_sf
        return torch.cat((radial_sf, angular_sf), 2)
Esempio n. 2
0
    def forward(self, inputs):
        """
        Predict a symmetric tensor property (presumably a polarizability).

        ``self.out_net`` produces per-atom contributions that are gathered onto
        each atom's neighbors, optionally damped by ``self.cutoff_network``.
        Channel 0 weights the neighbor vectors to build an atomic vector
        ("dipole"). Channel 1 weights the neighbor vectors divided by r**3 to
        build an atomic "field". The field is divided by its norm; norms below
        1e-10 are bumped by 1 to avoid division by zero. The outer product of
        dipole and field is symmetrized by ``symmetric_product`` and aggregated
        over atoms with ``atom_mask``.

        Args:
            inputs (dict of torch.Tensor): SchNetPack format dictionary of input tensors.

        Returns:
            dict: ``{self.property: global_polar}``. When ``self.isotropic`` is
            truthy, it also contains ``self.isotropic`` mapped to the mean of the
            diagonal of ``global_polar`` (last dim kept).
        """
        positions = inputs[Properties.R]
        neighbors = inputs[Properties.neighbors]
        nbh_mask = inputs[Properties.neighbor_mask]
        atom_mask = inputs[Properties.atom_mask]

        # Neighbor distances and the corresponding displacement vectors
        # (presumably B x A x N and B x A x N x 3 -- TODO confirm).
        distances, dist_vecs = L.atom_distances(positions,
                                                neighbors,
                                                return_vecs=True)

        # Scatter the per-atom network outputs onto each atom's neighbors.
        per_neighbor = L.neighbor_elements(self.out_net(inputs), neighbors)
        if self.cutoff_network is not None:
            per_neighbor = per_neighbor * self.cutoff_network(distances)[..., None]

        dipole_weights = per_neighbor[:, :, :, 0]
        field_weights = per_neighbor[:, :, :, 1]

        # Weighted sum of neighbor vectors -> one vector per atom.
        atomic_dipoles = self.nbh_agg(dist_vecs * dipole_weights[..., None],
                                      nbh_mask)

        # r**3 on real neighbors; masked-out entries become 1 so the division
        # below never divides by zero.
        r_cubed = (distances**3 * nbh_mask) + (1 - nbh_mask)
        field_terms = (dist_vecs * field_weights[..., None] /
                       r_cubed[..., None])
        atomic_fields = self.nbh_agg(field_terms, nbh_mask)

        # Normalize the field direction, leaving (near-)zero fields unscaled.
        norm = torch.norm(atomic_fields, dim=-1, keepdim=True)
        norm = norm + (norm < 1e-10).float()
        atomic_fields = atomic_fields / norm

        # Outer product dipole (x) field, then symmetrize.
        atomic_polar = symmetric_product(
            atomic_dipoles[..., None] * atomic_fields[:, :, None, :])
        global_polar = self.atom_agg(atomic_polar, atom_mask[..., None])

        result = {self.property: global_polar}
        if self.isotropic:
            trace_terms = torch.diagonal(global_polar, dim1=-2, dim2=-1)
            result[self.isotropic] = torch.mean(trace_terms,
                                                dim=-1,
                                                keepdim=True)
        return result
Esempio n. 3
0
    def forward(self, inputs):
        """
        Build (weighted) atom-centered symmetry functions for periodic systems.

        Distances use ``Properties.cell`` and ``Properties.cell_offset`` (and the
        per-pair offsets for triples), which are read from ``inputs``
        unconditionally.

        Args:
            inputs (dict of torch.Tensor): SchNetPack format dictionary of input tensors.

        Returns:
            torch.Tensor: Nbatch x Natoms x Nsymmetry_functions Tensor containing ACSFs or wACSFs.
            Radial features precede angular features along dim 2 when both are
            configured. When only one family is configured, only it is returned;
            when neither is configured, ``None``.

        Raises:
            HDNNException: if ``self.ADF`` is set but the neighbor-pair indices
                are missing from ``inputs``. The original ``KeyError`` is
                chained as the cause.

        """
        positions = inputs[Properties.R]
        Z = inputs[Properties.Z]
        neighbors = inputs[Properties.neighbors]
        neighbor_mask = inputs[Properties.neighbor_mask]

        cell = inputs[Properties.cell]
        cell_offset = inputs[Properties.cell_offset]

        # Compute radial functions
        if self.RDF is not None:
            # Get atom type embeddings
            Z_rad = self.radial_Z(Z)
            # Get atom types of neighbors
            Z_ij = snn.neighbor_elements(Z_rad, neighbors)
            # Compute distances
            distances = snn.atom_distances(
                positions,
                neighbors,
                neighbor_mask=neighbor_mask,
                cell=cell,
                cell_offsets=cell_offset,
            )
            radial_sf = self.RDF(distances,
                                 elemental_weights=Z_ij,
                                 neighbor_mask=neighbor_mask)
        else:
            radial_sf = None

        if self.ADF is not None:
            # Get pair indices; they are only present when triples were collected.
            try:
                idx_j = inputs[Properties.neighbor_pairs_j]
                idx_k = inputs[Properties.neighbor_pairs_k]
            except KeyError as e:
                # Chain the KeyError so the missing key stays visible in the traceback.
                raise HDNNException("Angular symmetry functions require " +
                                    "`collect_triples=True` in AtomsData.") from e
            neighbor_pairs_mask = inputs[Properties.neighbor_pairs_mask]

            # Get element contributions of the pairs
            Z_angular = self.angular_Z(Z)
            Z_ij = snn.neighbor_elements(Z_angular, idx_j)
            Z_ik = snn.neighbor_elements(Z_angular, idx_k)

            # Periodic-image offsets of the j and k partners
            offset_idx_j = inputs[Properties.neighbor_offsets_j]
            offset_idx_k = inputs[Properties.neighbor_offsets_k]

            # Compute triple distances
            r_ij, r_ik, r_jk = snn.triple_distances(
                positions,
                idx_j,
                idx_k,
                offset_idx_j=offset_idx_j,
                offset_idx_k=offset_idx_k,
                cell=cell,
                cell_offsets=cell_offset,
            )

            angular_sf = self.ADF(
                r_ij,
                r_ik,
                r_jk,
                elemental_weights=(Z_ij, Z_ik),
                triple_masks=neighbor_pairs_mask,
            )
        else:
            angular_sf = None

        # Concatenate and return symmetry functions
        if self.RDF is None:
            return angular_sf
        if self.ADF is None:
            return radial_sf
        return torch.cat((radial_sf, angular_sf), 2)
Esempio n. 4
0
    def forward(self, inputs):
        """
        Compute radial and atom-pair symmetry functions plus pair-distance features.

        Args:
            inputs (dict of torch.Tensor): SchNetPack format dictionary of input tensors.

        Returns:
            tuple: ``(radial_sf, ap_sf, dists, fd)`` where

            * ``radial_sf`` is the output of ``self.RDF``, or ``None`` when
              ``self.RDF`` is None.
            * ``ap_sf`` is the output of ``self.APF``, or ``None`` when
              ``self.APF`` is None.
            * ``dists`` holds the neighbor distances and their inverses,
              flattened over the atom/neighbor axes and stacked on a new last
              axis. It is ``None`` when ``self.APF`` is None.
            * ``fd`` is ``self.fd_react(d_react) * self.fd_prod(d_prod)`` for the
              minimal selected reactant/product distances. It is ``None`` when
              ``self.APF`` or ``self.fd_react`` is None.

        Raises:
            HDNNException: if ``self.APF`` is set but the neighbor-pair indices
                are missing from ``inputs``.

        """
        positions = inputs[Properties.R]
        Z = inputs[Properties.Z]
        neighbors = inputs[Properties.neighbors]
        neighbor_mask = inputs[Properties.neighbor_mask]

        cell = inputs[Properties.cell]
        cell_offset = inputs[Properties.cell_offset]

        # Defaults for disabled branches. Without these, the final ``return``
        # raised UnboundLocalError whenever a branch was skipped.
        radial_sf = None
        ap_sf = None
        dists = None
        fd = None

        # Both branches consume the neighbor distances, so compute them outside
        # the radial branch (the atom-pair branch used to fail without an RDF).
        distances = None
        if self.RDF is not None or self.APF is not None:
            distances = snn.atom_distances(
                positions,
                neighbors,
                neighbor_mask=neighbor_mask,
                cell=cell,
                cell_offsets=cell_offset,
            )

        # Compute radial functions
        if self.RDF is not None:
            # Atom type embeddings, gathered for every neighbor
            Z_rad = self.radial_Z(Z)
            Z_ij = snn.neighbor_elements(Z_rad, neighbors)
            radial_sf = self.RDF(distances,
                                 elemental_weights=Z_ij,
                                 neighbor_mask=neighbor_mask)

        if self.APF is not None:
            try:
                idx_j = inputs[Properties.neighbor_pairs_j]
                idx_k = inputs[Properties.neighbor_pairs_k]
            except KeyError as e:
                raise HDNNException("Angular symmetry functions require " +
                                    "`collect_triples=True` in AtomsData.") from e

            neighbor_pairs_mask = inputs[Properties.neighbor_pairs_mask]

            # Flatten (batch, atom, neighbor) -> (batch, atom * neighbor) and
            # pair every distance with its inverse.
            # NOTE(review): if atom_distances zeroes masked neighbors, inv_dist
            # is inf there -- confirm the neighbor list carries no padding.
            pair_dist = torch.reshape(
                distances,
                (distances.shape[0], distances.shape[1] * distances.shape[2]))
            inv_dist = 1 / pair_dist
            dists = torch.stack((pair_dist, inv_dist), -1)

            # Get element contributions of the pairs
            Z_ap = self.ap_Z(Z)
            Z_ik = snn.neighbor_elements(Z_ap, idx_k)

            # Offset indices
            offset_idx_j = inputs[Properties.neighbor_offsets_j]
            offset_idx_k = inputs[Properties.neighbor_offsets_k]

            # Compute triple distances
            r_ij, r_ik, r_jk = snn.triple_distances(
                positions,
                idx_j,
                idx_k,
                offset_idx_j=offset_idx_j,
                offset_idx_k=offset_idx_k,
                cell=cell,
                cell_offsets=cell_offset,
            )

            ap_sf = self.APF(
                r_ij,
                r_ik,
                r_jk,
                elemental_weights=Z_ik,
                triple_masks=neighbor_pairs_mask,
            )

            if self.fd_react is not None:
                # Start of atom ``self.atom_t``'s row in the flattened distances.
                # NOTE(review): assumes positions.shape[1] - 1 neighbor slots per
                # atom (full neighbor list) -- confirm.
                atom_t_ind = self.atom_t * (positions.shape[1] - 1)
                sel_atom_react = atom_t_ind + self.atom_react
                sel_atom_prod = atom_t_ind + self.atom_prod
                react_dists = pair_dist[:, sel_atom_react].reshape(
                    pair_dist.shape[0], sel_atom_react.shape[0])
                prod_dists = pair_dist[:, sel_atom_prod].reshape(
                    pair_dist.shape[0], sel_atom_prod.shape[0])

                # Closest reactant / product partner per structure
                react_dists = torch.min(react_dists, dim=1, keepdim=True)[0]
                prod_dists = torch.min(prod_dists, dim=1, keepdim=True)[0]

                fd = torch.mul(self.fd_react(react_dists),
                               self.fd_prod(prod_dists))

        return radial_sf, ap_sf, dists, fd
Esempio n. 5
0
    def forward(self, inputs):
        """
        Compute monomer-split radial and atom-pair symmetry functions for a dimer.

        Atoms ``0 .. ZA.shape[1] - 1`` are treated as monomer A and the following
        ``ZB.shape[1]`` atoms as monomer B (inferred from the index ranges below).

        Args:
            inputs (dict of torch.Tensor): SchNetPack format dictionary of input
                tensors. It must also contain ``'ZA'`` and ``'ZB'`` when either
                branch is enabled.

        Returns:
            tuple: ``(radial_sf_A, radial_sf_B, ap_sf_A, ap_sf_B, dists)``.

            * ``radial_sf_A`` / ``radial_sf_B`` are the ``self.RDF`` outputs
              restricted to monomer A / B atoms. They are ``None`` when
              ``self.RDF`` is None.
            * ``ap_sf_A`` / ``ap_sf_B`` are the ``self.APF`` outputs restricted
              to monomer A / B atoms. They are ``None`` when ``self.APF`` is None.
            * ``dists`` holds the A-B intermolecular distances and their
              inverses, flattened and stacked on a new last axis, optionally
              damped by ``self.cutoff2`` and ``self.fermi_dirac``. It is ``None``
              when ``self.APF`` is None.

        Raises:
            HDNNException: if ``self.APF`` is set but the neighbor-pair indices
                are missing from ``inputs``.

        """
        positions = inputs[Properties.R]
        Z = inputs[Properties.Z]
        neighbors = inputs[Properties.neighbors]
        neighbor_mask = inputs[Properties.neighbor_mask]

        cell = inputs[Properties.cell]
        cell_offset = inputs[Properties.cell_offset]
        device = cell_offset.device

        # Defaults for disabled branches. Without these, the final ``return``
        # raised UnboundLocalError whenever a branch was skipped.
        radial_sf_A = None
        radial_sf_B = None
        ap_sf_A = None
        ap_sf_B = None
        dists = None

        # Intramolecular neighbor distances. Both the radial branch and the
        # Fermi-Dirac switch of the atom-pair branch need them, so compute them
        # outside the radial branch.
        distances = None
        if self.RDF is not None or self.APF is not None:
            distances = snn.atom_distances(
                positions,
                neighbors,
                neighbor_mask=neighbor_mask,
                cell=cell,
                cell_offsets=cell_offset,
            )

        # Compute radial functions
        if self.RDF is not None:
            ZA = inputs['ZA']
            ZB = inputs['ZB']
            n_A = ZA.shape[1]
            n_B = ZB.shape[1]

            # Atom type embeddings, gathered for every neighbor
            Z_rad = self.radial_Z(Z)
            Z_ij = snn.neighbor_elements(Z_rad, neighbors)

            radial_sf = self.RDF(distances,
                                 elemental_weights=Z_ij,
                                 neighbor_mask=neighbor_mask)

            atoms_A = torch.arange(0, n_A, 1, device=device)
            atoms_B = torch.arange(n_A, n_A + n_B, 1, device=device)
            radial_sf_A = radial_sf[:, atoms_A, :]
            radial_sf_B = radial_sf[:, atoms_B, :]

        if self.APF is not None:
            try:
                idx_j = inputs[Properties.neighbor_pairs_j]
                idx_k = inputs[Properties.neighbor_pairs_k]
            except KeyError as e:
                raise HDNNException("Angular symmetry functions require " +
                                    "`collect_triples=True` in AtomsData.") from e

            ZA = inputs['ZA']
            ZB = inputs['ZB']
            n_A = ZA.shape[1]
            n_B = ZB.shape[1]

            # Global indices of A atoms, global indices of B atoms, and
            # 0-based indices into B (used for the inter-monomer axis).
            atoms_A = torch.arange(0, n_A, 1, device=device)
            atoms_B = torch.arange(n_A, n_A + n_B, 1, device=device)
            atoms_B_local = torch.arange(0, n_B, 1, device=device)

            neighbor_pairs_mask = inputs[Properties.neighbor_pairs_mask]

            neighbor_inter = inputs[Properties.neighbor_inter]
            offset_inter = inputs[Properties.neighbor_offset_inter]
            offset_intra = inputs[Properties.cell_offset_intra]
            neighbor_inter_mask = inputs[Properties.neighbor_inter_mask]

            distances_inter = snn.atom_distances(
                positions,
                neighbor_inter,
                neighbor_mask=neighbor_inter_mask,
                cell=cell,
                cell_offsets=offset_inter,
            )

            # Masked inter-monomer entries are set to 1 so their inverse is finite.
            pair_dist = torch.ones_like(distances_inter, device=device)
            pair_dist[neighbor_inter_mask != 0] = distances_inter[
                neighbor_inter_mask != 0]

            # Keep the A x B block and flatten it per structure.
            pair_dist = pair_dist[:, atoms_A, :][:, :, atoms_B_local]
            pair_dist = torch.reshape(
                pair_dist,
                (pair_dist.shape[0], pair_dist.shape[1] * pair_dist.shape[2]))
            inv_dist = 1 / pair_dist
            dists = torch.stack((pair_dist, inv_dist), -1)

            # Get element contributions of the pairs
            Z_ap = self.ap_Z(Z)
            Z_ik = snn.neighbor_elements(Z_ap, idx_k)

            # Offset indices
            offset_idx_j = inputs[Properties.neighbor_offsets_j]
            offset_idx_k = inputs[Properties.neighbor_offsets_k]

            # Compute triple distances
            r_ij, r_ik, r_jk = snn.triple_distances_apnet(
                positions,
                idx_j,
                idx_k,
                offset_idx_j=offset_idx_j,
                offset_idx_k=offset_idx_k,
                cell=cell,
                cell_offsets_intra=offset_intra,
                cell_offsets_inter=offset_inter,
            )

            ap_sf = self.APF(
                r_ij,
                r_ik,
                r_jk,
                elemental_weights=Z_ik,
                triple_masks=neighbor_pairs_mask,
            )

            ap_sf_A = ap_sf[:, atoms_A, :, :][:, :, atoms_B_local, :]
            # NOTE(review): unlike ap_sf_A, the third axis is not restricted
            # here (e.g. to the first n_A entries) -- confirm this is intended.
            ap_sf_B = ap_sf[:, atoms_B, :]

            if self.cutoff2 is not None:
                cutoffs = self.cutoff2(pair_dist).unsqueeze(-1)
                dists = torch.mul(dists, cutoffs)

                if hasattr(self, 'fermi_dirac'):
                    # Shortest selected intramolecular bond per structure
                    morse_bond = distances[:, self.morse1,
                                           self.morse2].reshape(
                                               distances.shape[0],
                                               self.morse2.shape[0])
                    morse_bond = torch.min(morse_bond, dim=1, keepdim=True)[0]
                    fd = self.fermi_dirac(morse_bond)
                    fd = torch.repeat_interleave(fd,
                                                 repeats=pair_dist.shape[1],
                                                 dim=1)
                    fd = fd.unsqueeze(-1)
                    dists = torch.mul(dists, fd)

        return radial_sf_A, radial_sf_B, ap_sf_A, ap_sf_B, dists