Esempio n. 1
0
def df_to_graph(struct_df, label):
    """Convert a protein-ligand structure DataFrame into one combined graph.

    Rows whose ``chain`` value is ``'L'`` form the ligand and every other row
    forms the protein. Each part is turned into a graph with
    ``gr.prot_df_to_graph`` and the two are merged by ``gr.combine_graphs``
    with edges between the parts enabled.

    Args:
        struct_df: Dataframe of the structure; must have a ``chain`` column.
        label: target value, stored unchanged as ``y``.

    Returns:
        ``Data`` holding node features, edge index, edge features, positions
        and the label.
    """
    is_ligand = struct_df.chain == 'L'
    ligand_graph = gr.prot_df_to_graph(struct_df[is_ligand])
    protein_graph = gr.prot_df_to_graph(struct_df[~is_ligand])

    merged = gr.combine_graphs(protein_graph, ligand_graph,
                               edges_between=True)
    node_feats, edge_index, edge_feats, pos = merged
    return Data(node_feats, edge_index, edge_feats, y=label, pos=pos)
Esempio n. 2
0
def df_to_graph(struct_df, chain_res, label, radius=30.0):
    """Build a graph of all atoms within ``radius`` of one residue's CA atom.

    Args:
        struct_df: Dataframe of the whole structure with ``chain``,
            ``residue``, ``name`` and ``x``/``y``/``z`` columns.
        chain_res: ``(chain ID, residue number)`` pair selecting the centre
            residue.
        label: target value, stored unchanged as ``y``.
        radius: Euclidean cut-off around the centre CA atom, in the same
            units as the coordinates (default 30.0).

    Returns:
        ``Data`` built from the neighbourhood atoms, with ``ca_idx`` holding
        the row position(s) of the centre CA atom among the neighbourhood
        rows and ``n_nodes`` equal to ``num_nodes``; ``None`` if the selected
        residue has no atom named ``'CA'``.
    """
    chain, resnum = chain_res
    res_df = struct_df[(struct_df.chain == chain)
                       & (struct_df.residue == resnum)]
    ca_rows = res_df[res_df['name'] == 'CA']
    # Without a CA atom there is no centre to build the environment around.
    if ca_rows.empty:
        return None
    # The first CA row is used should the residue contain several.
    ca_pos = ca_rows[['x', 'y', 'z']].astype(np.float32).to_numpy()[0]

    kd_tree = scipy.spatial.KDTree(struct_df[['x', 'y', 'z']].to_numpy())
    graph_pt_idx = kd_tree.query_ball_point(ca_pos, r=radius, p=2.0)
    # KD-tree indices are row positions in struct_df, hence iloc (not loc).
    graph_df = struct_df.iloc[graph_pt_idx].reset_index(drop=True)
    ca_idx = np.where((graph_df.chain == chain)
                      & (graph_df.residue == resnum)
                      & (graph_df['name'] == 'CA'))[0]

    node_feats, edge_index, edge_feats, pos = gr.prot_df_to_graph(graph_df)
    data = Data(node_feats, edge_index, edge_feats, y=label, pos=pos)
    data.ca_idx = torch.LongTensor(ca_idx)
    data.n_nodes = data.num_nodes

    return data
Esempio n. 3
0
def df_to_graph(struct_df, label):
    """Convert a structure Dataframe into a graph ``Data`` object.

    Args:
        struct_df: Dataframe of the structure, passed to
            ``gr.prot_df_to_graph``.
        label: target value(s) -- a number or a sequence/array of numbers.

    Returns:
        ``Data`` holding node features, edge index, edge features, positions
        and ``y`` = ``label`` as a float32 tensor with at least one
        dimension (a scalar label becomes a one-element tensor).
    """
    # torch.FloatTensor(n) with an int n allocates an *uninitialised* tensor
    # of length n rather than wrapping the value (and a float raises), so
    # convert with as_tensor and promote scalars to shape (1,).
    label = torch.as_tensor(label, dtype=torch.float32)
    if label.dim() == 0:
        label = label.reshape(1)
    node_feats, edge_index, edge_feats, pos = gr.prot_df_to_graph(struct_df)
    data = Data(node_feats, edge_index, edge_feats, y=label, pos=pos)

    return data
Esempio n. 4
0
def df_to_graph(struct_df, label):
    """Convert a whole-structure Dataframe into a graph ``Data`` object.

    Args:
        struct_df: Dataframe with the entire structure, passed to
            ``gr.prot_df_to_graph``.
        label: residue label -- a number (e.g. an int) or a sequence/array
            of numbers.

    Returns:
        ``Data`` holding node features, edge index, edge features, positions
        and ``y`` = ``label`` as a float32 tensor with at least one
        dimension (a scalar label becomes a one-element tensor).
    """
    # torch.FloatTensor(n) with an int n allocates an *uninitialised* tensor
    # of length n rather than wrapping the value, so an int label would be
    # silently replaced by garbage; as_tensor wraps the value instead.
    label = torch.as_tensor(label, dtype=torch.float32)
    if label.dim() == 0:
        label = label.reshape(1)
    node_feats, edge_index, edge_feats, pos = gr.prot_df_to_graph(struct_df)
    data = Data(node_feats, edge_index, edge_feats, y=label, pos=pos)

    return data
Esempio n. 5
0
def df_to_graph(struct_df, chain_res, label):
    """Build a graph of a residue environment and record the centre CA index.

    Args:
        struct_df: Dataframe with the entire environment; must have
            ``chain``, ``residue`` and ``name`` columns.
        chain_res: string ``'<chain ID>_<residue number>_<residue name>'``
            defining the centre residue; the residue name part is ignored.
        label: residue label, stored unchanged as ``y``.

    Returns:
        ``Data`` built from all of ``struct_df``, with ``ca_idx`` holding the
        positional row index (indices) in ``struct_df`` of the centre
        residue's CA atom -- empty if there is none -- and ``n_nodes`` equal
        to ``num_nodes``.

    Raises:
        ValueError: if ``chain_res`` does not split into exactly three
            ``'_'``-separated parts or the residue number is not an integer.
    """
    chain, resnum, _ = chain_res.split('_')
    resnum = int(resnum)
    is_center_ca = ((struct_df.chain == chain)
                    & (struct_df.residue == resnum)
                    & (struct_df['name'] == 'CA'))
    # NOTE(review): assumes gr.prot_df_to_graph keeps one node per row in
    # row order, so these row positions are node indices -- verify.
    ca_idx = np.where(is_center_ca)[0]

    node_feats, edge_index, edge_feats, pos = gr.prot_df_to_graph(struct_df)
    data = Data(node_feats, edge_index, edge_feats, y=label, pos=pos)
    data.ca_idx = torch.LongTensor(ca_idx)
    data.n_nodes = data.num_nodes

    return data
Esempio n. 6
0
 def process(self):
     """Build one combined pocket/ligand graph per complex and save it.

     Files in ``self.raw_paths`` are classified by file name: names
     containing ``'_ligand'`` are ligands (read with
     ``dt.read_sdf_to_mol``), names containing ``'_pocket'`` are pockets,
     and anything else is ignored. For every pocket, in ``self.raw_paths``
     order, the ligand with the *same* PDB code is looked up, the two graphs
     are combined with edges between them, and the resulting ``Data``
     (``y`` from the label CSV via ``get_label``, plus a ``pdb`` attribute)
     is saved to ``processed_dir/data_{i}.pt`` with ``i`` counting saved
     complexes from 0.

     Raises:
         ValueError: if a pocket file has no ligand file with the same PDB
             code in ``self.raw_paths``.
     """
     label_file = os.path.join(self.root, 'pdbbind_refined_set_labels.csv')
     label_df = pd.read_csv(label_file)

     # Pair ligands with pockets by PDB code, not by iteration order: relying
     # on "last ligand seen" combines a pocket with another complex's ligand
     # (or hits an unbound name) when raw_paths is not ordered
     # ligand-before-pocket for each complex.
     ligand_paths = {}
     pocket_paths = []
     for raw_path in self.raw_paths:
         # Classify on the file name only, so '_ligand'/'_pocket' appearing
         # in a parent directory name cannot misclassify a file.
         file_name = os.path.basename(raw_path)
         if '_ligand' in file_name:
             ligand_paths[fi.get_pdb_code(raw_path)] = raw_path
         elif '_pocket' in file_name:
             pocket_paths.append(raw_path)

     for i, raw_path in enumerate(pocket_paths):
         pdb_code = fi.get_pdb_code(raw_path)
         ligand_path = ligand_paths.get(pdb_code)
         if ligand_path is None:
             raise ValueError(
                 'No ligand file found for pocket {} (PDB code {})'.format(
                     raw_path, pdb_code))
         y = torch.FloatTensor([get_label(pdb_code, label_df)])
         mol_graph = graph.mol_to_graph(
             dt.read_sdf_to_mol(ligand_path, add_hs=True)[0])
         prot_graph = graph.prot_df_to_graph(
             dt.bp_to_df(dt.read_any(raw_path, name=pdb_code)))
         node_feats, edge_index, edge_feats, pos = graph.combine_graphs(
             prot_graph, mol_graph, edges_between=True)
         data = Data(node_feats, edge_index, edge_feats, y=y, pos=pos)
         data.pdb = pdb_code
         torch.save(
             data,
             os.path.join(self.processed_dir, 'data_{}.pt'.format(i)))