def df_to_graph(struct_df, label):
    """Build a single graph from a protein-ligand structure dataframe.

    Rows whose ``chain`` equals ``'L'`` are treated as the ligand and all
    remaining rows as the protein. Each subset is converted with
    ``gr.prot_df_to_graph`` and the two results are merged by
    ``gr.combine_graphs`` with ``edges_between=True``.

    Args:
        struct_df: Dataframe with (at least) a ``chain`` column.
        label: Value stored as ``y`` on the returned object.

    Returns:
        A ``Data`` object holding the combined node features, edge index,
        edge features and positions, with ``y`` set to ``label``.
    """
    # The ligand subset is converted first, then the protein subset.
    ligand_rows = struct_df[struct_df.chain == 'L']
    ligand_graph = gr.prot_df_to_graph(ligand_rows)

    protein_rows = struct_df[struct_df.chain != 'L']
    protein_graph = gr.prot_df_to_graph(protein_rows)

    combined = gr.combine_graphs(
        protein_graph, ligand_graph, edges_between=True)
    node_feats, edge_index, edge_feats, pos = combined

    return Data(node_feats, edge_index, edge_feats, y=label, pos=pos)
def df_to_graph(struct_df, chain_res, label):
    """Build a graph of the atoms surrounding one residue's CA atom.

    The residue is selected by ``chain_res``; every row of ``struct_df``
    whose coordinates lie within a Euclidean radius of 30.0 of that
    residue's first ``'CA'`` row is kept and converted to a graph.

    Args:
        struct_df: Dataframe with ``chain``, ``residue``, ``name``, ``x``,
            ``y`` and ``z`` columns.
        chain_res: ``(chain, residue)`` pair identifying the center residue.
        label: Value stored as ``y`` on the returned object.

    Returns:
        ``None`` when the selected residue has no ``'CA'`` row. Otherwise a
        ``Data`` object with two extra attributes: ``ca_idx`` (row positions
        of the center residue's CA inside the neighbourhood) and ``n_nodes``
        (copy of ``num_nodes``).
    """
    chain, resnum = chain_res

    in_residue = (struct_df.chain == chain) & (struct_df.residue == resnum)
    residue_atoms = struct_df[in_residue]
    if 'CA' not in residue_atoms.name.tolist():
        return None

    # Coordinates of the first CA row of the residue, as float32.
    ca_rows = residue_atoms[residue_atoms['name'] == 'CA']
    center = ca_rows[['x', 'y', 'z']].astype(np.float32).to_numpy()[0]

    # Positional row indices of everything inside the ball around the CA.
    all_coords = struct_df[['x', 'y', 'z']].to_numpy()
    tree = scipy.spatial.KDTree(all_coords)
    neighbour_rows = tree.query_ball_point(center, r=30.0, p=2.0)
    env_df = struct_df.iloc[neighbour_rows].reset_index(drop=True)

    # Locate the center CA again, now relative to the neighbourhood frame.
    is_center_ca = ((env_df.chain == chain)
                    & (env_df.residue == resnum)
                    & (env_df.name == 'CA'))
    ca_idx = np.where(is_center_ca)[0]

    node_feats, edge_index, edge_feats, pos = gr.prot_df_to_graph(env_df)
    result = Data(node_feats, edge_index, edge_feats, y=label, pos=pos)
    result.ca_idx = torch.LongTensor(ca_idx)
    result.n_nodes = result.num_nodes
    return result
def df_to_graph(struct_df, label):
    """Convert a structure dataframe into a labelled graph object.

    Args:
        struct_df: Dataframe passed unchanged to ``gr.prot_df_to_graph``.
        label: Passed to ``torch.FloatTensor`` and stored as ``y``.

    Returns:
        A ``Data`` object built from the node features, edge index, edge
        features and positions returned by ``gr.prot_df_to_graph``.
    """
    y = torch.FloatTensor(label)
    graph_parts = gr.prot_df_to_graph(struct_df)
    node_feats, edge_index, edge_feats, pos = graph_parts
    return Data(node_feats, edge_index, edge_feats, y=y, pos=pos)
def df_to_graph(struct_df, label):
    """Convert a structure dataframe into a labelled graph object.

    Args:
        struct_df: Dataframe with the entire structure; passed unchanged to
            ``gr.prot_df_to_graph``.
        label: Label for the structure, either a single number or a sequence
            of numbers. It is stored as a float tensor in ``y``.

    Returns:
        A ``Data`` object built from the node features, edge index, edge
        features and positions returned by ``gr.prot_df_to_graph``.
    """
    # Local import: this block cannot rely on the file's import section.
    import numbers

    # torch.FloatTensor(n) with a bare integer n allocates an *uninitialized*
    # tensor of n elements instead of storing the value n, so a scalar label
    # is wrapped into a one-element list first. Sequences are left untouched.
    if isinstance(label, numbers.Number):
        label = [label]
    label = torch.FloatTensor(label)

    node_feats, edge_index, edge_feats, pos = gr.prot_df_to_graph(struct_df)
    data = Data(node_feats, edge_index, edge_feats, y=label, pos=pos)
    return data
def df_to_graph(struct_df, chain_res, label):
    """Convert a residue-environment dataframe into a labelled graph.

    Args:
        struct_df: Dataframe with the whole environment; needs ``chain``,
            ``residue`` and ``name`` columns.
        chain_res: String of the form ``"<chain>_<residue number>_<third
            field>"`` identifying the center residue. It must split into
            exactly three ``'_'``-separated parts; the third is ignored.
        label: Value stored as ``y`` on the returned object.

    Returns:
        A ``Data`` object with two extra attributes: ``ca_idx`` (row
        positions of the center residue's ``'CA'`` atom) and ``n_nodes``
        (copy of ``num_nodes``).
    """
    chain_id, res_id, _unused = chain_res.split('_')

    chain_match = struct_df.chain == chain_id
    residue_match = struct_df.residue == int(res_id)
    is_ca = struct_df.name == 'CA'
    ca_idx = np.where(chain_match & residue_match & is_ca)[0]

    node_feats, edge_index, edge_feats, pos = gr.prot_df_to_graph(struct_df)
    result = Data(node_feats, edge_index, edge_feats, y=label, pos=pos)
    result.ca_idx = torch.LongTensor(ca_idx)
    result.n_nodes = result.num_nodes
    return result
def process(self):
    """Convert raw pocket/ligand files into saved graph objects.

    Reads labels from ``pdbbind_refined_set_labels.csv`` under ``self.root``.
    Every path in ``self.raw_paths`` containing ``'_pocket'`` is paired with
    the ligand file (path containing ``'_ligand'``) that has the same PDB
    code, the two graphs are combined with edges between them, and the
    result is saved as ``data_<i>.pt`` in ``self.processed_dir``, where
    ``i`` counts pockets in the order they appear in ``self.raw_paths``.
    Paths matching neither pattern are ignored.

    Raises:
        FileNotFoundError: If a pocket file has no ligand file with the
            same PDB code.
    """
    label_file = os.path.join(self.root, 'pdbbind_refined_set_labels.csv')
    label_df = pd.read_csv(label_file)

    # First pass: index files by role. Pairing by PDB code (rather than by
    # "the most recent ligand seen") keeps a pocket from being combined with
    # another complex's ligand when raw_paths is not ligand-before-pocket.
    ligand_paths = {}
    pocket_paths = []
    for raw_path in self.raw_paths:
        if '_ligand' in raw_path:
            ligand_paths[fi.get_pdb_code(raw_path)] = raw_path
        elif '_pocket' in raw_path:
            pocket_paths.append(raw_path)

    # Second pass: one output file per pocket, numbered in raw_paths order.
    for i, pocket_path in enumerate(pocket_paths):
        pdb_code = fi.get_pdb_code(pocket_path)
        ligand_path = ligand_paths.get(pdb_code)
        if ligand_path is None:
            raise FileNotFoundError(
                'no ligand file found for pocket {}'.format(pocket_path))

        y = torch.FloatTensor([get_label(pdb_code, label_df)])
        mol_graph = graph.mol_to_graph(
            dt.read_sdf_to_mol(ligand_path, add_hs=True)[0])
        prot_graph = graph.prot_df_to_graph(
            dt.bp_to_df(dt.read_any(pocket_path, name=pdb_code)))

        node_feats, edge_index, edge_feats, pos = graph.combine_graphs(
            prot_graph, mol_graph, edges_between=True)
        data = Data(node_feats, edge_index, edge_feats, y=y, pos=pos)
        data.pdb = pdb_code
        torch.save(
            data, os.path.join(self.processed_dir, 'data_{}.pt'.format(i)))