def decode(self, tree_vec, mol_vec, prob_decode):
    """Decode a latent (tree, molecule) vector pair into a SMILES string.

    The junction tree is decoded from ``tree_vec`` by ``self.decoder``,
    its nodes are assembled into a molecule by ``self.dfs_assemble``, and
    when several stereoisomer candidates exist the one whose encoding has
    the highest cosine similarity to ``mol_vec`` is returned.

    Args:
        tree_vec: Latent tree vector handed to ``self.decoder.decode``.
        mol_vec: Latent molecule vector, used during assembly and for
            ranking the stereoisomer candidates.
        prob_decode: Forwarded unchanged to ``self.decoder.decode`` and
            ``self.dfs_assemble``.

    Returns:
        The chosen SMILES string, or ``None`` when ``self.dfs_assemble``
        fails or the assembled molecule cannot be re-parsed from its own
        SMILES.
    """
    pred_root, pred_nodes = self.decoder.decode(tree_vec, prob_decode)

    # Mark nid, is_leaf and atom map numbers. Node ids are 1-based.
    for nid, node in enumerate(pred_nodes, start=1):
        node.nid = nid
        degree = len(node.neighbors)
        node.is_leaf = degree == 1
        if degree > 1:
            set_atommap(node.mol, node.nid)

    tree_mess = self.jtnn([pred_root])[0]

    cur_mol = copy_edit_mol(pred_root.mol)
    # One dict per node plus a leading extra entry, presumably so the list
    # can be indexed directly by the 1-based nid (TODO confirm). Slot 1 is
    # seeded with the identity mapping over the root fragment's atoms.
    global_amap = [{}] + [{} for _ in pred_nodes]
    global_amap[1] = {
        atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()
    }

    cur_mol = self.dfs_assemble(
        tree_mess,
        mol_vec,
        pred_nodes,
        cur_mol,
        global_amap,
        [],
        pred_root,
        None,
        prob_decode,
    )
    if cur_mol is None:
        return None

    cur_mol = cur_mol.GetMol()
    # NOTE(review): called without a map number; presumably this resets the
    # atom map numbers set above -- confirm against set_atommap's default.
    set_atommap(cur_mol)
    # Round-trip through SMILES; a None result means the molecule is invalid.
    cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
    if cur_mol is None:
        return None

    smiles2D = Chem.MolToSmiles(cur_mol)
    stereo_cands = decode_stereo(smiles2D)
    if len(stereo_cands) == 1:
        return stereo_cands[0]

    stereo_vecs = self.mpn(mol2graph(stereo_cands))
    stereo_vecs = self.G_mean(stereo_vecs)
    scores = nn.CosineSimilarity()(stereo_vecs, mol_vec)
    _, max_id = scores.max(dim=0)
    # max(dim=0) over a 1-D score vector yields a 0-dim index tensor;
    # indexing it with [0] raises IndexError, so extract it with .item().
    return stereo_cands[max_id.item()]
def stereo(self, mol_batch, mol_vec):
    """Compute the stereoisomer-selection loss and accuracy for a batch.

    For every tree in ``mol_batch`` that does not have exactly one
    stereoisomer candidate, all candidates are encoded with ``self.mpn``
    and ``self.G_mean`` and scored by cosine similarity against that
    tree's row of ``mol_vec``. ``self.stereo_loss`` is applied to each
    tree's scores with the index of ``smiles3D`` as the target.

    Args:
        mol_batch: Iterable of tree objects exposing ``stereo_cands`` (a
            list) and ``smiles3D``. If ``smiles3D`` is missing from a
            tree's ``stereo_cands`` it is appended to that list in place.
        mol_vec: Tensor indexed along dim 0 by position in ``mol_batch``.

    Returns:
        A ``(loss, accuracy)`` pair. ``loss`` is a one-element tensor with
        the mean per-tree loss, still attached to the autograd graph;
        ``accuracy`` is a float, the fraction of scored trees whose true
        isomer attains the maximum score. When no tree is scored the
        result is ``(zeros(1), 1.0)``.
    """
    stereo_cands, batch_idx, labels = [], [], []
    for i, mol_tree in enumerate(mol_batch):
        cands = mol_tree.stereo_cands
        if len(cands) == 1:
            continue
        if mol_tree.smiles3D not in cands:
            cands.append(mol_tree.smiles3D)
        stereo_cands.extend(cands)
        batch_idx.extend([i] * len(cands))
        labels.append((cands.index(mol_tree.smiles3D), len(cands)))

    if not labels:
        # Same (loss, accuracy) shape as the main path, so callers can
        # treat both outcomes uniformly.
        return torch.zeros(1).to(self.decive), 1.0

    batch_idx = torch.tensor(batch_idx, dtype=torch.long).to(self.decive)
    cand_vecs = self.G_mean(self.mpn(mol2graph(stereo_cands)))
    stereo_labels = mol_vec.index_select(0, batch_idx)
    scores = torch.nn.CosineSimilarity()(cand_vecs, stereo_labels)

    start, correct = 0, 0
    losses = []
    for label, count in labels:
        # Scores are laid out contiguously per tree, in batch order.
        cur_scores = scores.narrow(0, start, count)
        if cur_scores[label].item() >= cur_scores.max().item():
            correct += 1
        target = torch.tensor([label], dtype=torch.long).to(self.decive)
        loss = self.stereo_loss(cur_scores.view(1, -1), target)
        # Flatten so 0-dim and one-element losses can both be concatenated.
        losses.append(loss.reshape(-1))
        start += count

    # Reduce with tensor ops (not .item()) so gradients reach the encoder.
    mean_loss = (torch.cat(losses).sum() / len(labels)).reshape(1)
    return mean_loss, correct * 1.0 / len(labels)