import random
from io import BytesIO

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image
from rdkit import Chem
from captum.attr import IntegratedGradients
from IPython.display import display

# Assumed project-local helpers: smile_to_graph, molgraph_collate_fn,
# EMNImplementation, get_colors, moltopng, visualize_importances.


def scoring_function(smiles_list):
    """Score a batch of SMILES strings with the module-level `model` and output activation `m`."""
    scores = []
    for smiles in smiles_list:
        # Featurize the molecule and collate it into a single-graph batch.
        adjacency, nodes, edges = smile_to_graph(smiles)
        adjacency, nodes, edges = molgraph_collate_fn(((adjacency, nodes, edges),))
        output = model.forward(nodes, edges, adjacency)
        scores.append(float(m(output)))
    return scores
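For orientation, a minimal usage sketch: it assumes `model` is a trained EMNImplementation bound at module level and `m` is its output activation (for example `torch.nn.Sigmoid()`); neither is defined in this snippet.

# Usage sketch: `model` and `m` are assumed module-level globals
# (a trained EMNImplementation and an output activation such as torch.nn.Sigmoid()).
smiles_batch = ["CCO", "c1ccccc1O", "CC(=O)Nc1ccc(O)cc1"]
with torch.no_grad():
    scores = scoring_function(smiles_batch)
for s, score in zip(smiles_batch, scores):
    print(f"{s}: {score:.3f}")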
def property_prediction(smiles, Property):
    """Predict a single ADMET property ("t_half", "logD" or "hml_clearance") for one SMILES string."""
    valid = {"t_half", "logD", "hml_clearance"}
    if Property not in valid:
        raise ValueError("property must be one of %r." % valid)

    # Featurize the molecule and collate it into a single-graph batch.
    adjacency, nodes, edges = smile_to_graph(smiles)
    adjacency, nodes, edges = molgraph_collate_fn(((adjacency, nodes, edges),))

    # The clearance model uses narrower attention/message layers than the other two.
    if Property == "hml_clearance":
        att_hidden_dim, msg_hidden_dim, gather_width = 60, 60, 80
    else:
        att_hidden_dim, msg_hidden_dim, gather_width = 80, 80, 100

    model = EMNImplementation(
        node_features=40, edge_features=4, edge_embedding_size=50,
        message_passes=6, out_features=1,
        edge_emb_depth=3, edge_emb_hidden_dim=120,
        att_depth=3, att_hidden_dim=att_hidden_dim,
        msg_depth=3, msg_hidden_dim=msg_hidden_dim,
        gather_width=gather_width,
        gather_att_depth=3, gather_att_hidden_dim=80,
        gather_emb_depth=3, gather_emb_hidden_dim=80,
        out_depth=2, out_hidden_dim=60)

    # Load the property-specific checkpoint and run a single forward pass.
    checkpoint = torch.load(f"checkpoints/{Property}.ckpt")
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    output = model.forward(nodes, edges, adjacency)

    # The clearance model's output is exponentiated before being returned.
    if Property == "hml_clearance":
        return torch.exp(output)
    return output
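A hedged usage example, assuming the property-specific checkpoints exist under `checkpoints/` relative to the working directory:

# Usage sketch: predict logD and half-life for paracetamol.
# Assumes checkpoints/logD.ckpt and checkpoints/t_half.ckpt exist on disk.
smiles = "CC(=O)Nc1ccc(O)cc1"
with torch.no_grad():
    logd = property_prediction(smiles, "logD")
    t_half = property_prediction(smiles, "t_half")
print("predicted logD:", float(logd))
print("predicted half-life:", float(t_half))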
def visualizations(model, smile, feature_list, color_map=plt.cm.bwr):
    """Display atom-level Integrated Gradients attributions for one molecule."""
    model.eval()
    adjacency, nodes, edges = smile_to_graph(smile)
    mol = Chem.MolFromSmiles(smile)
    ig = IntegratedGradients(model)
    adjacency, nodes, edges = molgraph_collate_fn(((adjacency, nodes, edges),))

    # Attribute the prediction to the node features; edges and adjacency are
    # passed through unchanged as additional forward arguments.
    attr = ig.attribute(nodes, additional_forward_args=(edges, adjacency), target=0)
    attr1 = torch.squeeze(attr, dim=0).detach()  # (num_atoms, num_node_features)
    attr2 = attr1.sum(dim=1)                     # per-atom attribution

    # Symmetric colour scale around zero for the atom colouring.
    vmax = max(attr2.abs().max().item(), 1e-16)
    vmin = -vmax
    node_colors = get_colors(attr1, color_map)
    node_colors = node_colors[:, :3]  # drop the alpha channel

    # Horizontal colour bar for the attribution scale.
    fig, ax = plt.subplots(figsize=(6, 1))
    fig.subplots_adjust(bottom=0.5)
    norm = plt.Normalize(vmin, vmax)
    fig.colorbar(cm.ScalarMappable(norm=norm, cmap=color_map),
                 cax=ax, orientation='horizontal', label='color_bar')

    # Render the molecule with per-atom colours and display it inline.
    b = BytesIO()
    b.write(moltopng(mol, node_colors=node_colors, edge_colors={}, molSize=(600, 600)))
    b.seek(0)
    display(Image.open(b))
    b.close()

    # Heatmap of per-atom, per-feature attributions.
    symbols = {i: f'{mol.GetAtomWithIdx(i).GetSymbol()}{i}' for i in range(mol.GetNumAtoms())}
    x_pos = np.arange(len(feature_list))
    y_pos = np.arange(len(list(symbols.values())))
    plt.matshow(attr1, cmap=color_map)
    plt.xticks(x_pos, feature_list, rotation='vertical')
    plt.yticks(y_pos, list(symbols.values()))
    plt.show()

    # Bar chart of summed per-atom importances.
    visualize_importances(list(symbols.values()), attr2)
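Calling the visualization requires a model with loaded weights and a list naming the 40 node features. The sketch below uses a hypothetical `trained_model` and placeholder feature names; substitute the project's actual featurization labels.

# Sketch: attribution plots for the half-life model.
# `trained_model` is assumed to be an EMNImplementation loaded as in
# property_prediction above; the feature names are placeholders for the
# 40 node features produced by smile_to_graph.
feature_names = [f"node_feat_{i}" for i in range(40)]
visualizations(trained_model, "CC(=O)Nc1ccc(O)cc1", feature_names)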
def prepare_data(self):
    """Load t_half.csv, featurize every SMILES and build an 80/10/10 train/val/test split."""
    data = pd.read_csv('t_half.csv')
    data_list = []
    for index in range(len(data)):
        adjacency, nodes, edges = smile_to_graph(data['Smiles'][index])
        # targets = data.iloc[index, 1:]
        targets = np.expand_dims(data['Standard Value'][index], axis=0)
        data_list.append(((adjacency, nodes, edges), targets))

    # Shuffle the featurized molecules and split them 80/10/10.
    n = len(data)
    shuffled_indices = list(range(n))
    random.shuffle(shuffled_indices)
    shuffled = [data_list[i] for i in shuffled_indices]
    self.ggnn_train = tuple(shuffled[:int(n * 0.8)])
    self.ggnn_val = tuple(shuffled[int(n * 0.8):int(n * 0.9)])
    self.ggnn_test = tuple(shuffled[int(n * 0.9):])
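prepare_data reads like a hook of a PyTorch Lightning data module. If so, the stored splits would typically be wrapped in DataLoaders along the lines below; the batch size and the assumption that molgraph_collate_fn can batch (graph, target) pairs are both assumptions, not part of the original code.

from torch.utils.data import DataLoader

# Sketch of companion DataLoader hooks, assuming this class is a
# pytorch_lightning.LightningDataModule and that molgraph_collate_fn can
# batch ((adjacency, nodes, edges), target) pairs; batch_size=32 is arbitrary.
def train_dataloader(self):
    return DataLoader(self.ggnn_train, batch_size=32, shuffle=True,
                      collate_fn=molgraph_collate_fn)

def val_dataloader(self):
    return DataLoader(self.ggnn_val, batch_size=32, collate_fn=molgraph_collate_fn)

def test_dataloader(self):
    return DataLoader(self.ggnn_test, batch_size=32, collate_fn=molgraph_collate_fn)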