Ejemplo n.º 1
0
 def __call__(self, item):
     """Convert an LBA item into one combined receptor + ligand graph.

     The receptor is the binding pocket when ``self.pocket_only`` is true and
     the full protein otherwise; its graph is merged with the ligand graph via
     ``gr.combine_graphs``.

     Args:
         item: dataset item holding the ``'atoms_pocket'``,
             ``'atoms_ligand'`` and (when not pocket-only)
             ``'atoms_protein'`` atom tables, plus a ``'scores'`` entry that
             has a ``'neglog_aff'`` value.

     Returns:
         ``Data`` built from the combined node features, edges, edge features
         and node positions, with ``y`` set to
         ``item['scores']['neglog_aff']``.
     """
     # transform protein and/or pocket to PTG graphs
     if self.pocket_only:
         receptor_key = 'atoms_pocket'
         atom_keys = ['atoms_pocket']
     else:
         # Combine the full-protein graph in this mode; combining the pocket
         # here would silently discard the protein graph that was just built.
         # 'atoms_pocket' is still converted, as before, so the item ends up
         # with the same contents it always had.
         receptor_key = 'atoms_protein'
         atom_keys = ['atoms_protein', 'atoms_pocket']
     item = prot_graph_transform(item,
                                 atom_keys=atom_keys,
                                 label_key='scores')
     # transform ligand into PTG graph
     item = mol_graph_transform(item,
                                'atoms_ligand',
                                'scores',
                                use_bonds=True,
                                onehot_edges=False)
     node_feats, edges, edge_feats, node_pos = gr.combine_graphs(
         item[receptor_key], item['atoms_ligand'])
     return Data(node_feats,
                 edges,
                 edge_feats,
                 y=item['scores']['neglog_aff'],
                 pos=node_pos)
Ejemplo n.º 2
0
 def __call__(self, item):
     """Turn the item's ``'atoms'`` table into a graph labelled by GDT_TS.

     The ``'gdt_ts'`` entry is read from ``graph.y`` and replaced by a
     one-element ``torch.FloatTensor``. ``graph.target`` and ``graph.decoy``
     are taken from the first and second elements of ``item['id']``.
     NOTE(review): assumes ``prot_graph_transform`` stores the ``'scores'``
     mapping on ``graph.y`` -- confirm against its implementation.
     """
     transformed = prot_graph_transform(item, ['atoms'], 'scores')
     protein_graph = transformed['atoms']
     gdt_ts_score = protein_graph.y['gdt_ts']
     protein_graph.y = torch.FloatTensor([gdt_ts_score])
     identifier = transformed['id']
     protein_graph.target = identifier[0]
     protein_graph.decoy = identifier[1]
     return protein_graph
Ejemplo n.º 3
0
 def __call__(self, item):
     """Turn the item's ``'atoms'`` table into a graph labelled by RMS.

     The ``'rms'`` entry is read from ``graph.y`` and replaced by a
     one-element ``torch.FloatTensor``. ``item['id']`` is split on single
     quotes: the 2nd piece becomes ``graph.target`` and the 4th becomes
     ``graph.decoy``.
     NOTE(review): this presumably expects a stringified pair such as
     "('target', 'decoy')" -- confirm the id format with the dataset.
     """
     transformed = prot_graph_transform(item, ['atoms'], 'scores')
     protein_graph = transformed['atoms']
     rms_value = protein_graph.y['rms']
     protein_graph.y = torch.FloatTensor([rms_value])
     quoted_pieces = transformed['id'].split("'")
     protein_graph.target = quoted_pieces[1]
     protein_graph.decoy = quoted_pieces[3]
     return protein_graph
Ejemplo n.º 4
0
    def __call__(self, item):
        """Build the (original, mutated) graph pair for a mutation item.

        The mutation code is the last ``'_'``-separated field of
        ``item['id']``. It is located in each atom table (after
        ``reset_index(drop=True)``) via ``self._extract_mut_idx``. Both
        tables are then converted to graphs, and each graph is passed
        together with its index to ``self._augment_graph``.

        Returns:
            Tuple ``(original_graph, mutated_graph)``.
        """
        atom_table_keys = ('original_atoms', 'mutated_atoms')
        mutation_code = item['id'].split('_')[-1]

        # Resolve mutation indices while the entries are still atom tables;
        # prot_graph_transform below replaces them with graphs.
        mutation_index = {}
        for table_key in atom_table_keys:
            renumbered = item[table_key].reset_index(drop=True)
            mutation_index[table_key] = self._extract_mut_idx(
                renumbered, mutation_code)

        converted = prot_graph_transform(item,
                                         atom_keys=list(atom_table_keys),
                                         label_key='label')
        return tuple(
            self._augment_graph(converted[table_key],
                                mutation_index[table_key])
            for table_key in atom_table_keys)
Ejemplo n.º 5
0
 def __call__(self, item):
     """Convert the configured atom tables of *item* into PTG graphs.

     Thin wrapper: forwards ``self.atom_keys`` and ``self.label_key`` to
     ``prot_graph_transform`` and returns its result unchanged.
     """
     return prot_graph_transform(
         item,
         atom_keys=self.atom_keys,
         label_key=self.label_key,
     )