Esempio n. 1
0
def _get_ith_data(data_index, E, N, R, D, Q, Z):
    """Assemble a ``Data`` record for the molecule stored at ``data_index``.

    The atom count is read from ``N[data_index]`` and used to cut the
    per-atom arrays ``R`` and ``Z`` down to that many leading entries; the
    remaining fields are copied from row ``data_index`` and flattened.

    :param data_index: row index of the molecule in every input array
    :param E: per-molecule values, stored flattened as ``E``
    :param N: per-molecule atom counts (must support ``.item()``)
    :param R: per-atom rows, reshaped to ``(-1, 3)``
    :param D: per-molecule values, reshaped to ``(-1, 3)``
    :param Q: per-molecule values, stored flattened as ``Q``
    :param Z: per-atom values, stored flattened as ``Z``
    :return: a ``Data`` object carrying attributes E, N, R, D, Q and Z
    """
    atom_count = N[data_index].item()
    # Attributes are attached in the same order as before (E, N, R, D, Q, Z).
    fields = (
        ("E", E[data_index].view(-1)),
        ("N", N[data_index].view(-1)),
        ("R", R[data_index, :atom_count, :].view(-1, 3)),
        ("D", D[data_index, :].view(-1, 3)),
        ("Q", Q[data_index].view(-1)),
        ("Z", Z[data_index, :atom_count].view(-1)),
    )
    record = Data()
    for attr_name, attr_value in fields:
        setattr(record, attr_name, attr_value)
    return record
def physnet_to_datalist(self,
                        N,
                        R,
                        E,
                        D,
                        Q,
                        Z,
                        num_mol,
                        mols,
                        efgs_batch,
                        EFG_R,
                        EFG_Z,
                        num_efg,
                        sol_data=None):
    """
    Load data from PhysNet structure to InMemoryDataset structure (more compact).

    For each of the ``num_mol`` molecules a ``Data`` object is filled with the
    un-padded slices of the input arrays and then passed through
    ``self.pre_transform``. Results that are ``None`` are dropped.

    :param N: per-molecule atom counts; ``N[i]`` is used as the slice length
        for ``R``, ``Z`` and ``efgs_batch``
    :param R: per-atom rows, stored as ``(-1, 3)``
        (presumably padded coordinates -- confirm against the caller)
    :param E: per-molecule values, stored flattened
    :param D: per-molecule values, stored as ``(-1, 3)``
    :param Q: per-molecule values, stored flattened
    :param Z: per-atom values, stored flattened
        (presumably atomic numbers with 0 as padding -- confirm)
    :param num_mol: number of molecules to convert
    :param mols: indexable collection of molecule objects; read when
        ``self.bond_atom_sep`` is true or when ``sol_data`` is given
    :param efgs_batch: per-atom values, only read when ``self.cal_efg``
    :param EFG_R: per-EFG rows stored as ``(-1, 3)``, only read when
        ``self.cal_efg``
    :param EFG_Z: per-EFG values, only read when ``self.cal_efg``
    :param num_efg: per-molecule EFG counts, only read when ``self.cal_efg``
    :param sol_data: optional table with an ``"InChI"`` column (accessed via
        ``.loc`` / ``.iloc``). When given, a molecule is kept only if exactly
        one row matches its InChI; that row's ``sol_keys`` columns are
        attached to the ``Data`` object. A 0/1 flag per molecule is saved to
        ``jianing_to_dongdong_map_<n_heavy>.pt`` in the working directory.
    :return: list of pre-transformed ``Data`` objects in molecule order
    """
    use_sol_data = sol_data is not None
    if use_sol_data:
        # rdkit is only needed for the InChI lookup, so import it lazily.
        from rdkit.Chem.inchi import MolToInchi

        # Number of entries in the first row of Z that are neither 0 nor 1
        # (presumably the heavy-atom count of this data set -- confirm).
        # Only used to name the saved flag file.
        Z_0 = Z[0, :]
        n_heavy = len(Z_0) - (Z_0 == 0).sum() - (Z_0 == 1).sum()

    jianing_to_dongdong_map = []
    data_list = []

    for i in tqdm(range(num_mol)):
        mol = mols[i] if self.bond_atom_sep else None
        num_atoms = N[i]

        # atomic infos
        _tmp_Data = Data()
        _tmp_Data.N = num_atoms.view(-1)
        _tmp_Data.R = R[i, :num_atoms].view(-1, 3)
        _tmp_Data.E = E[i].view(-1)
        _tmp_Data.D = D[i].view(-1, 3)
        _tmp_Data.Q = Q[i].view(-1)
        _tmp_Data.Z = Z[i, :num_atoms].view(-1)

        if self.cal_efg:
            n_efg = num_efg[i]
            _tmp_Data.atom_to_EFG_batch = efgs_batch[i, :num_atoms].view(-1)
            _tmp_Data.EFG_R = EFG_R[i, :n_efg].view(-1, 3)
            _tmp_Data.EFG_Z = EFG_Z[i, :n_efg].view(-1)
            _tmp_Data.EFG_N = n_efg.view(-1)

        if use_sol_data:
            # Find the molecule in the solvation table based on InChI.
            # Use mols[i] directly: `mol` is None when bond_atom_sep is
            # disabled, which previously made the lookup fail.
            inchi = MolToInchi(mols[i])
            this_sol_data = sol_data.loc[sol_data["InChI"] == inchi]
            if this_sol_data.shape[0] != 1:
                # Missing or ambiguous match: flag it and skip the molecule.
                jianing_to_dongdong_map.append(0)
                continue
            sol_row = this_sol_data.iloc[0]
            # NOTE: sol_keys is expected to be defined at module level.
            for key in sol_keys:
                setattr(_tmp_Data, key, torch.as_tensor(sol_row[key]).view(-1))
            jianing_to_dongdong_map.append(1)

        _tmp_Data = self.pre_transform(
            data=_tmp_Data,
            edge_version=self.edge_version,
            do_sort_edge=self.sort_edge,
            cal_efg=self.cal_efg,
            cutoff=self.cutoff,
            extended_bond=self.extended_bond,
            boundary_factor=self.boundary_factor,
            type_3_body=self.type_3_body,
            use_center=self.use_center,
            mol=mol,
            cal_3body_term=self.cal_3body_term,
            bond_atom_sep=self.bond_atom_sep,
            record_long_range=self.record_long_range)

        # Keep only molecules for which pre_transform produced a result.
        if _tmp_Data is not None:
            data_list.append(_tmp_Data)

    if use_sol_data:
        torch.save(torch.as_tensor(jianing_to_dongdong_map),
                   "jianing_to_dongdong_map_{}.pt".format(n_heavy))

    return data_list