def reconstruct(self, smiles):
        junc_tree = MolJuncTree(smiles)
        junc_tree.recover()

        set_batch_nodeID([junc_tree], self.vocab)

        jtenc_holder, _ = JTNNEncoder.tensorize([junc_tree])
        mpn_holder = MessPassNet.tensorize([smiles])

        tree_vec, _, mol_vec = self.encode(jtenc_holder, mpn_holder)
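        # encode returns tree vectors, tree messages, and molecule vectors;
        # the tree messages are unused during reconstruction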

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

        # reparameterization trick: z = mean + exp(log_var / 2) * epsilon;
        # the mean/var projections output latent_size // 2 dimensions, so the
        # sampled noise must have that shape as well
        epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

        return self.decode(tree_vec, mol_vec)

    def __init__(self, vocab, hidden_size, latent_size, depth):
        super(JTPropVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        self.jtnn = JTNNEncoder(vocab, hidden_size, self.embedding)
        self.jtmpn = JTMessPassNet(hidden_size, depth)
        self.mpn = MessPassNet(hidden_size, depth)
        self.decoder = JTNNDecoder(vocab, hidden_size, latent_size // 2,
                                   self.embedding)

        # the latent space is split between tree and graph halves, so each
        # projection outputs latent_size // 2 dimensions (integer division:
        # nn.Linear requires an int, not the float produced by "/")
        self.T_mean = nn.Linear(hidden_size, latent_size // 2)
        self.T_var = nn.Linear(hidden_size, latent_size // 2)
        self.G_mean = nn.Linear(hidden_size, latent_size // 2)
        self.G_var = nn.Linear(hidden_size, latent_size // 2)

        # property prediction head; the second layer's input is the first
        # layer's hidden_size output, not the latent vector
        self.propNN = nn.Sequential(
            nn.Linear(self.latent_size, self.hidden_size), nn.Tanh(),
            nn.Linear(self.hidden_size, self.hidden_size), nn.Tanh(),
            nn.Linear(self.hidden_size, 1))
        self.prop_loss = nn.MSELoss()
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')  # replaces deprecated size_average=False
        self.stereo_loss = nn.CrossEntropyLoss(reduction='sum')
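
For reference, the sampling step in reconstruct above is the standard VAE reparameterization trick, with the log-variance forced non-positive via -torch.abs(...) (the comment credits Mueller et al.). A minimal, self-contained sketch, assuming plain torch tensors in place of the project's create_var helper:

import torch

def reparameterize(mean, log_var):
    """Sample z = mean + sigma * eps, with sigma = exp(log_var / 2)."""
    log_var = -torch.abs(log_var)      # keep sigma <= 1, as in the model above
    epsilon = torch.randn_like(mean)   # eps ~ N(0, I), one sample per row
    return mean + torch.exp(log_var / 2) * epsilon

# toy usage: one 28-dimensional latent vector
z = reparameterize(torch.zeros(1, 28), torch.zeros(1, 28))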
Example #3
def tensorize(junc_tree_batch, vocab, use_graph_conv, assm=True):
    set_batch_nodeID(junc_tree_batch, vocab)
    smiles_batch = [junc_tree.smiles for junc_tree in junc_tree_batch]
    jtenc_holder, mess_dict = JTNNEncoder.tensorize(junc_tree_batch)

    # compute logP for each molecule; this is the supervised property target
    prop_batch = [Descriptors.MolLogP(MolFromSmiles(smiles))
                  for smiles in smiles_batch]

    if use_graph_conv:
        molenc_holder = MolGraphEncoder.tensorize(smiles_batch)

        if assm is False:
            return junc_tree_batch, jtenc_holder, molenc_holder

        candidate_smiles = []
        cand_batch_idx = []
        for idx, junc_tree in enumerate(junc_tree_batch):
            for node in junc_tree.nodes:
                # a leaf node's attachment is determined by its neighboring node's attachment
                if node.is_leaf or len(node.candidates) == 1:
                    continue
                candidate_smiles.extend(
                    [candidate for candidate in node.candidates])
                cand_batch_idx.extend([idx] * len(node.candidates))

        cand_molenc_holder = MolGraphEncoder.tensorize(candidate_smiles)
        cand_batch_idx = torch.LongTensor(cand_batch_idx)

        return junc_tree_batch, jtenc_holder, molenc_holder, (
            cand_molenc_holder, cand_batch_idx), prop_batch

    else:
        mpn_holder = MessPassNet.tensorize(smiles_batch)

        if assm is False:
            return junc_tree_batch, jtenc_holder, mpn_holder

        candidates = []
        cand_batch_idx = []
        for idx, junc_tree in enumerate(junc_tree_batch):
            for node in junc_tree.nodes:
                # a leaf node's attachment is determined by its neighboring node's attachment
                if node.is_leaf or len(node.candidates) == 1:
                    continue
                candidates.extend([(candidate, junc_tree.nodes, node)
                                   for candidate in node.candidates])
                cand_batch_idx.extend([idx] * len(node.candidates))

        jtmpn_holder = JTMessPassNet.tensorize(candidates, mess_dict)
        cand_batch_idx = torch.LongTensor(cand_batch_idx)

        return junc_tree_batch, jtenc_holder, mpn_holder, (
            jtmpn_holder, cand_batch_idx), prop_batch
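
A hedged sketch of how this tensorize function could serve as a collate function for batching; tree_list and vocab are placeholders assumed to exist in the surrounding project:

from functools import partial
from torch.utils.data import DataLoader

# hypothetical wiring: tree_list is a list of MolJuncTree objects and vocab
# is the cluster vocabulary; neither is defined in this snippet
loader = DataLoader(tree_list, batch_size=32, shuffle=True,
                    collate_fn=partial(tensorize, vocab=vocab,
                                       use_graph_conv=False))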
Example #4
    def decode(self, x_tree_vecs, x_mol_vecs):
        # batch decoding is not currently supported
        assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1

        pred_root, pred_nodes = self.decoder.decode(x_tree_vecs)
        if len(pred_nodes) == 0:
            return None
        elif len(pred_nodes) == 1:
            return pred_root.smiles

        # mark nid & is_leaf & atommap
        for idx, node in enumerate(pred_nodes):
            node.nid = idx + 1
            node.is_leaf = (len(node.neighbors) == 1)
            if len(node.neighbors) > 1:
                set_atom_map(node.mol, node.nid)

        scope = [(0, len(pred_nodes))]
        jtenc_holder, mess_dict = JTNNEncoder.tensorize_nodes(
            pred_nodes, scope)
        _, tree_mess = self.jtnn(*jtenc_holder)

        # important: tree_mess is a matrix, mess_dict is a python dict
        tree_mess = (tree_mess, mess_dict)

        # bilinear scoring: project the molecule latent vector into hidden
        # space so candidate vectors can be scored by dot product
        x_mol_vecs = self.A_assm(x_mol_vecs).squeeze()

        cur_mol = deep_copy_mol(pred_root.mol)
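        # global_amap[node.nid] maps a node-local atom index to the index of
        # the corresponding atom in the growing molecule (slot 0 is unused)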
        global_amap = [{}] + [{} for node in pred_nodes]
        global_amap[1] = {
            atom.GetIdx(): atom.GetIdx()
            for atom in cur_mol.GetAtoms()
        }

        cur_mol = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol,
                                    global_amap, [], pred_root, None)
        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atom_map(cur_mol)
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        return Chem.MolToSmiles(cur_mol) if cur_mol is not None else None

    def reconstruct_graph_conv(self, smiles):
        junc_tree = MolJuncTree(smiles)
        junc_tree.recover()

        set_batch_nodeID([junc_tree], self.vocab)

        jtenc_holder, _ = JTNNEncoder.tensorize([junc_tree])
        molenc_holder = MolGraphEncoder.tensorize([smiles])

        tree_vec, mol_vec = self.encode_graph_conv(jtenc_holder, molenc_holder)

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

        epsilon = create_var(torch.randn(1, self.latent_size), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(1, self.latent_size), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
        return self.decode_graph_conv(tree_vec, mol_vec)

    # def recon_eval(self, smiles):
    #     junc_tree = MolJuncTree(smiles)
    #     junc_tree.recover()
    #     _, tree_vec, mol_vec = self.encode([junc_tree])
    #
    #     tree_mean = self.T_mean(tree_vec)
    #     tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
    #     mol_mean = self.G_mean(mol_vec)
    #     mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.
    #
    #     all_smiles = []
    #     for i in range(10):
    #         epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
    #         tree_vec = tree_mean + torch.exp(tree_log_var // 2) * epsilon
    #         epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
    #         mol_vec = mol_mean + torch.exp(mol_log_var // 2) * epsilon
    #         for j in range(10):
    #             new_smiles = self.decode(tree_vec, mol_vec, prob_decode=True)
    #             all_smiles.append(new_smiles)
    #     return all_smiles

    # def sample_prior(self):
    #     z_tree = torch.randn(1, self.latent_size).cuda()
    #     z_mol = torch.randn(1, self.latent_size).cuda()
    #     return self.decode(z_tree, z_mol)

    # def sample_eval(self):
    #     tree_vec = create_var(torch.randn(1, self.latent_size // 2), False)
    #     mol_vec = create_var(torch.randn(1, self.latent_size // 2), False)
    #     all_smiles = []
    #     for i in range(100):
    #         s = self.decode(tree_vec, mol_vec, prob_decode=True)
    #         all_smiles.append(s)
    #     return all_smiles

    # def decode(self, tree_vec, mol_vec, prob_decode):
    #
    #     pred_root, pred_nodes = self.decoder.decode(tree_vec, prob_decode)
    #
    #     # Mark nid & is_leaf & atommap
    #     for idx, node in enumerate(pred_nodes):
    #         node.nid = idx + 1
    #         node.is_leaf = (len(node.neighbors) == 1)
    #         if len(node.neighbors) > 1:
    #                 set_atom_map(node.mol, node.nid)
    #
    #     tree_mess = self.jtnn([pred_root])[0]
    #
    #     cur_mol = deep_copy_mol(pred_root.mol)
    #     global_atom_map = [{}] + [{} for node in pred_nodes]
    #     global_atom_map[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
    #
    #     cur_mol = self.dfs_assemble(tree_mess, mol_vec, pred_nodes, cur_mol, global_atom_map, [], pred_root, None,
    #                                 prob_decode)
    #     if cur_mol is None:
    #         return None
    #
    #     cur_mol = cur_mol.GetMol()
    #     set_atom_map(cur_mol)
    #     cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
    #     if cur_mol is None:
    #         return None
    #
    #     smiles2D = Chem.MolToSmiles(cur_mol)
    #     stereo_candidates = decode_stereo(smiles2D)
    #     if len(stereo_candidates) == 1:
    #         return stereo_candidates[0]
    #     stereo_vecs = self.mpn(stereo_candidates)
    #     stereo_vecs = self.G_mean(stereo_vecs)
    #     scores = nn.CosineSimilarity()(stereo_vecs, mol_vec)
    #     _, max_id = scores.max(dim=0)
    #     return stereo_candidates[max_id.item()]

    # def dfs_assemble(self, tree_mess, mol_vec, all_nodes, cur_mol, global_atom_map, parent_atom_map, cur_node, parent_node, prob_decode):
    #     parent_nid = parent_node.nid if parent_node is not None else -1
    #     prev_nodes = [parent_node] if parent_node is not None else []
    #
    #     children = [neighbor_node for neighbor_node in cur_node.neighbors if neighbor_node.nid != parent_nid]
    #
    #     # exclude neighbor nodes corresponding to "singleton clusters"
    #     neighbors = [neighbor_node for neighbor_node in children if neighbor_node.mol.GetNumAtoms() > 1]
    #
    #     # sort neighbor nodes in descending order by number of atoms
    #     neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
    #
    #     # obtain neighbor nodes corresponding to "singleton-clusters"
    #     singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
    #     neighbors = singletons + neighbors
    #
    #     # neighbor_id, ctr_atom_idx, neighbor_atom_idx
    #     cur_atom_map = [(parent_nid, a2, a1) for nid, a1, a2 in parent_atom_map if nid == cur_node.nid]
    #     candidates = enum_assemble(cur_node, neighbors, prev_nodes, cur_atom_map)
    #     if len(candidates) == 0:
    #         return None
    #
    #     candidate_smiles, candidate_mols, candidate_atom_maps = zip(*candidates)
    #
    #     candidates = [(candidate_mol, all_nodes, cur_node) for candidate_mol in candidate_mols]
    #
    #     candidate_vecs = self.jtmpn(candidates, tree_mess)
    #     candidate_vecs = self.G_mean(candidate_vecs)
    #     mol_vec = mol_vec.squeeze()
    #     scores = torch.mv(candidate_vecs, mol_vec) * 20
    #
    #     if prob_decode:
    #         probs = nn.Softmax()(scores.view(1, -1)).squeeze() + 1e-5  # prevent prob = 0
    #         # quick fix
    #         if len(probs.shape) == 0:
    #             probs = torch.tensor([probs])
    #         cand_idx = torch.multinomial(probs, probs.numel())
    #     else:
    #         _, cand_idx = torch.sort(scores, descending=True)
    #
    #     backup_mol = Chem.RWMol(cur_mol)
    #     for idx in range(cand_idx.numel()):
    #         cur_mol = Chem.RWMol(backup_mol)
    #         pred_atom_map = candidate_atom_maps[cand_idx[idx].item()]
    #         new_global_atom_map = copy.deepcopy(global_atom_map)
    #
    #         for neighbor_id, ctr_atom_idx, neighbor_atom_idx in pred_atom_map:
    #             if neighbor_id == parent_nid:
    #                 continue
    #             new_global_atom_map[neighbor_id][neighbor_atom_idx] = new_global_atom_map[cur_node.nid][ctr_atom_idx]
    #
    #         cur_mol = attach_mols(cur_mol, children, [], new_global_atom_map)  # parent is already attached
    #         new_mol = cur_mol.GetMol()
    #         new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
    #
    #         if new_mol is None:
    #             continue
    #
    #         result = True
    #         for child_node in children:
    #             if child_node.is_leaf: continue
    #             cur_mol = self.dfs_assemble(tree_mess, mol_vec, all_nodes, cur_mol, new_global_atom_map, pred_atom_map,
    #                                         child_node, cur_node, prob_decode)
    #             if cur_mol is None:
    #                 result = False
    #                 break
    #         if result:
    #             return cur_mol
    #
    #     return None

    def __init__(self, vocab, hidden_size, latent_size, depthT, depthG,
                 num_layers, use_graph_conv, share_embedding=False):
        """
        Description: This is the constructor for the class.

        Args:
            vocab: List[MolJuncTreeNode]
                The cluster vocabulary over the dataset.

            hidden_size: int
                The dimension of the embedding space.

            latent_size: int
                The dimension of the latent space.

            depthT: int
                The number of message passing timesteps used to encode the junction trees.

            depthG: int
                The number of message passing timesteps used to encode the molecular graphs.

            num_layers: int
                The number of layers in the graph convolutional encoder.

            use_graph_conv: Boolean
                Whether to use the graph ConvNet (instead of message passing) to encode molecular graphs.

            share_embedding: Boolean
                Whether to share the cluster-label embedding between the junction-tree encoder and the decoder.
        """

        # invoke superclass constructor
        super(JTNNVAE, self).__init__()

        # whether to use message passing or graph convnet
        self.use_graph_conv = use_graph_conv

        # cluster vocabulary over the entire dataset.
        self.vocab = vocab

        # size of hidden layer 
        self.hidden_size = hidden_size

        # the latent representation is split evenly between the tree and graph
        # encodings, so each half gets latent_size // 2 dimensions
        self.latent_size = latent_size = latent_size // 2

        if self.use_graph_conv:
            # for encoding molecular graphs, to hidden vector representation
            self.graph_enc = MolGraphEncoder(hidden_size, num_layers)

        else:
            # encoder for producing the molecule graph encoding given batch of molecules
            self.mpn = MessPassNet(hidden_size, depthG)

            # for encoding candidate subgraphs, in the graph decoding phase (section 2.5)
            self.jtmpn = JTMessPassNet(hidden_size, depthG)

        if share_embedding:
            self.embedding = nn.Embedding(vocab.size(), hidden_size)
            self.jtnn = JTNNEncoder(hidden_size, depthT, self.embedding)
            self.decoder = JTNNDecoder(vocab, hidden_size, latent_size, self.embedding)
        else:
            self.jtnn = JTNNEncoder(hidden_size, depthT, nn.Embedding(vocab.size(), hidden_size))
            self.decoder = JTNNDecoder(vocab, hidden_size, latent_size, nn.Embedding(vocab.size(), hidden_size))

        # weight matrices for computing the mean and log-variance vectors of the VAE
        self.T_mean = nn.Linear(hidden_size, latent_size)
        self.T_var = nn.Linear(hidden_size, latent_size)
        self.G_mean = nn.Linear(hidden_size, latent_size)
        self.G_var = nn.Linear(hidden_size, latent_size)

        # bilinear projection used to score candidate attachments during assembly
        self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)

        # candidate-attachment (assembly) loss for graph reconstruction
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')  # replaces deprecated size_average=False
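
A minimal instantiation sketch; vocab and the hyperparameter values below are illustrative assumptions, not values fixed by this code:

# hypothetical setup: vocab must expose size() and the cluster lookups the
# encoders expect; the numbers echo common JT-VAE settings but are not fixed
model = JTNNVAE(vocab, hidden_size=450, latent_size=56, depthT=20, depthG=3,
                num_layers=3, use_graph_conv=False)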
Example #7
def tensorize(junc_tree_batch, vocab, use_graph_conv, assm=True):
    set_batch_nodeID(junc_tree_batch, vocab)
    smiles_batch = [junc_tree.smiles for junc_tree in junc_tree_batch]
    jtenc_holder, mess_dict = JTNNEncoder.tensorize(junc_tree_batch)

    if use_graph_conv:
        molenc_holder = MolGraphEncoder.tensorize(smiles_batch)

        if assm is False:
            return junc_tree_batch, jtenc_holder, molenc_holder

        candidate_smiles = []
        cand_batch_idx = []
        for idx, junc_tree in enumerate(junc_tree_batch):
            for node in junc_tree.nodes:
                # a leaf node's attachment is determined by its neighboring node's attachment
                if node.is_leaf or len(node.candidates) == 1:
                    continue
                candidate_smiles.extend(
                    [candidate for candidate in node.candidates])
                cand_batch_idx.extend([idx] * len(node.candidates))

        cand_molenc_holder = MolGraphEncoder.tensorize(candidate_smiles)
        cand_batch_idx = torch.LongTensor(cand_batch_idx)

        # stereo_candidates = []
        # stereo_batch_idx = []
        # stereo_labels = []
        # for idx, junc_tree in enumerate(junc_tree_batch):
        #     candidates = junc_tree.stereo_candidates
        #     if len(candidates) == 1:
        #         continue
        #     if junc_tree.smiles3D not in candidates:
        #         candidates.append(junc_tree.smiles3D)
        #
        #     stereo_candidates.extend(candidates)
        #     stereo_batch_idx.extend([idx] * len(candidates))
        #     stereo_labels.append( (candidates.index(junc_tree.smiles3D), len(candidates)) )
        #
        # stereo_molenc_holder = None
        # if len(stereo_labels) > 0:
        #     stereo_molenc_holder = MolGraphEncoder.tensorize(stereo_candidates)
        # stereo_batch_idx = torch.LongTensor(stereo_batch_idx)

        # return junc_tree_batch, jtenc_holder, molenc_holder, (cand_molenc_holder, cand_batch_idx), (stereo_molenc_holder, stereo_batch_idx, stereo_labels)
        return junc_tree_batch, jtenc_holder, molenc_holder, (
            cand_molenc_holder, cand_batch_idx)

    else:
        mpn_holder = MessPassNet.tensorize(smiles_batch)

        if assm is False:
            return junc_tree_batch, jtenc_holder, mpn_holder

        candidates = []
        cand_batch_idx = []
        for idx, junc_tree in enumerate(junc_tree_batch):
            for node in junc_tree.nodes:
                # a leaf node's attachment is determined by its neighboring node's attachment
                if node.is_leaf or len(node.candidates) == 1:
                    continue
                candidates.extend([(candidate, junc_tree.nodes, node)
                                   for candidate in node.candidates])
                cand_batch_idx.extend([idx] * len(node.candidates))

        jtmpn_holder = JTMessPassNet.tensorize(candidates, mess_dict)
        cand_batch_idx = torch.LongTensor(cand_batch_idx)

        # stereo_candidates = []
        # stereo_batch_idx = []
        # stereo_labels = []
        # for idx, junc_tree in enumerate(junc_tree_batch):
        #     candidates = junc_tree.stereo_candidates
        #     if len(candidates) == 1:
        #         continue
        #     if junc_tree.smiles3D not in candidates:
        #         candidates.append(junc_tree.smiles3D)
        #
        #     stereo_candidates.extend(candidates)
        #     stereo_batch_idx.extend([idx] * len(candidates))
        #     stereo_labels.append((candidates.index(junc_tree.smiles3D), len(candidates)))
        #
        # stereo_molenc_holder = None
        # if len(stereo_labels) > 0:
        #     stereo_molenc_holder = MessPassNet.tensorize(stereo_candidates)
        # stereo_batch_idx = torch.LongTensor(stereo_batch_idx)
        #
        # stereo_batch_idx.to(stereo_batch_idx)

        # return junc_tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, cand_batch_idx), (stereo_molenc_holder, stereo_batch_idx, stereo_labels)
        return junc_tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder,
                                                           cand_batch_idx)
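
The message-passing branch above returns a nested tuple; unpacking one batch looks like this (junc_tree_batch and vocab are assumed to exist):

tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx) = tensorize(
    junc_tree_batch, vocab, use_graph_conv=False)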