def __init__(self, vocab, hidden_size, latent_size, depth):
        super(JTPropVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        self.jtnn = JTNNEncoder(vocab, hidden_size, self.embedding)
        self.jtmpn = JTMPN(hidden_size, depth)
        self.mpn = MPN(hidden_size, depth)
        self.decoder = JTNNDecoder(vocab, hidden_size, latent_size // 2,
                                   self.embedding)

        # Integer division: nn.Linear sizes must be ints under Python 3
        self.T_mean = nn.Linear(hidden_size, latent_size // 2)
        self.T_var = nn.Linear(hidden_size, latent_size // 2)
        self.G_mean = nn.Linear(hidden_size, latent_size // 2)
        self.G_var = nn.Linear(hidden_size, latent_size // 2)

        self.propNN = nn.Sequential(
            nn.Linear(self.latent_size, self.hidden_size), nn.Tanh(),
            nn.Linear(self.hidden_size, 1))
        self.prop_loss = nn.MSELoss()
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')  # summed CE; the old size_average=False is deprecated
        self.stereo_loss = nn.CrossEntropyLoss(reduction='sum')
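        # A minimal usage sketch (illustrative, not from the original source;
        # assumes the codebase's Vocab class and a vocab file of cluster SMILES):
        #
        #     vocab = Vocab([line.strip() for line in open('vocab.txt')])
        #     model = JTPropVAE(vocab, hidden_size=450, latent_size=56, depth=3)
        #
        # With latent_size=56, the tree and graph branches each get a 28-dim
        # Gaussian head (T_mean/T_var, G_mean/G_var), and propNN regresses a
        # scalar property from the full 56-dim latent code.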
Example #2
    def __init__(self, vocab, hidden_size, latent_size, depthT, depthG):
        super(JTNNVAEMLP, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size = latent_size // 2  # tree and graph each use half of the latent code

        self.jtnn = JTNNEncoder(hidden_size, depthT, nn.Embedding(vocab.size(), hidden_size))
        self.decoder = JTNNDecoder(vocab, hidden_size, latent_size, nn.Embedding(vocab.size(), hidden_size))

        self.jtmpn = JTMPN(hidden_size, depthG)
        self.mpn = MPN(hidden_size, depthG)

        self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')

        self.T_mean = nn.Linear(hidden_size, latent_size)
        self.T_var = nn.Linear(hidden_size, latent_size)
        self.G_mean = nn.Linear(hidden_size, latent_size)
        self.G_var = nn.Linear(hidden_size, latent_size)

        self.T_hat_mean = nn.Linear(latent_size, latent_size)
        self.T_hat_var = nn.Linear(latent_size, latent_size)

        #New MLP
        self.gene_exp_size = 978

        self.gene_mlp = nn.Linear(self.gene_exp_size, latent_size)

        self.tree_mlp = nn.Linear(latent_size + latent_size, latent_size)
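    # A hedged sketch of how the mean/var heads above are typically consumed in
    # the reference JTNN code (standard VAE reparameterization); `rsample` is an
    # assumed helper, not part of the original snippet.
    def rsample(self, z_vecs, W_mean, W_var):
        # z_vecs: encoder output; W_mean/W_var: one of the (T|G)_mean/_var heads
        z_mean = W_mean(z_vecs)
        z_log_var = -torch.abs(W_var(z_vecs))  # log-variance kept non-positive, as in the reference code
        kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / z_vecs.size(0)
        epsilon = torch.randn_like(z_mean)
        return z_mean + torch.exp(z_log_var / 2) * epsilon, kl_loss
    # e.g. z_tree, kl_t = self.rsample(tree_vecs, self.T_mean, self.T_var)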
Example #3
    def __init__(self, vocab, args):
        super(DiffVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size = args.hidden_size
        self.rand_size = rand_size = args.rand_size

        self.jtmpn = JTMPN(hidden_size, args.depthG)
        self.mpn = MPN(hidden_size, args.depthG)

        if args.share_embedding:
            self.embedding = nn.Embedding(vocab.size(), hidden_size)
            self.jtnn = JTNNEncoder(hidden_size, args.depthT, self.embedding)
            self.decoder = JTNNDecoder(vocab, hidden_size, self.embedding, args.use_molatt)
        else:
            self.jtnn = JTNNEncoder(hidden_size, args.depthT, nn.Embedding(vocab.size(), hidden_size))
            self.decoder = JTNNDecoder(vocab, hidden_size, nn.Embedding(vocab.size(), hidden_size), args.use_molatt)

        self.A_assm = nn.Linear(hidden_size, hidden_size, bias=False)
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')

        self.T_mean = nn.Linear(hidden_size, rand_size // 2)
        self.T_var = nn.Linear(hidden_size, rand_size // 2)
        self.G_mean = nn.Linear(hidden_size, rand_size // 2)
        self.G_var = nn.Linear(hidden_size, rand_size // 2)
        self.B_t = nn.Sequential(nn.Linear(hidden_size + rand_size // 2, hidden_size), nn.ReLU())
        self.B_g = nn.Sequential(nn.Linear(hidden_size + rand_size // 2, hidden_size), nn.ReLU())
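        # Sketch of the intended data flow (an inference from the layer shapes
        # above, not code from the source): B_t/B_g fuse a deterministic
        # hidden_size encoding with a sampled rand_size // 2 noise vector:
        #
        #     eps_t = torch.randn(batch_size, self.rand_size // 2)
        #     z_tree = self.B_t(torch.cat([tree_vecs, eps_t], dim=-1))  # -> hidden_size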
Example #4
def tensorize(tree_batch, vocab, assm=True, if_need_origin_word=False):
    set_batch_nodeID(tree_batch, vocab)
    smiles_batch = [tree.smiles for tree in tree_batch]
    (jtenc_holder, mess_dict), origin_word = JTNNEncoder.tensorize(tree_batch)
    mpn_holder = MPN.tensorize(smiles_batch)

    if assm is False:
        if if_need_origin_word:
            return tree_batch, jtenc_holder, mpn_holder, origin_word
        else:
            return tree_batch, jtenc_holder, mpn_holder

    cands = []
    batch_idx = []
    for i, mol_tree in enumerate(tree_batch):
        for node in mol_tree.nodes:
            # A leaf node's attachment is determined by its neighbor's attachment
            if node.is_leaf or len(node.cands) == 1:
                continue
            cands.extend([(cand, mol_tree.nodes, node) for cand in node.cands])
            batch_idx.extend([i] * len(node.cands))

    jtmpn_holder = JTMPN.tensorize(cands, mess_dict)
    batch_idx = torch.LongTensor(batch_idx)

    return tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx)
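# Usage sketch (hypothetical call site; assumes MolTree and a Vocab from this
# codebase):
#
#     batch = [MolTree(s) for s in smiles_list]
#     tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx) = \
#         tensorize(batch, vocab, assm=True)
#
# With assm=True, the (jtmpn_holder, batch_idx) pair is used to score candidate
# attachments; batch_idx maps each candidate back to its source tree.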
Example #5
    def dfs_assemble(self, y_tree_mess, x_mol_vec_pooled, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node):
        fa_nid = fa_node.nid if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
        neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node.nid]
        cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None

        cand_smiles,cand_amap = zip(*cands)
        cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles]

        jtmpn_holder = JTMPN.tensorize(cands, y_tree_mess[1])
        fatoms,fbonds,agraph,bgraph,scope = jtmpn_holder
        cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess[0])

        scores = torch.mv(cand_vecs, x_mol_vec_pooled)
        _,cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        for i in range(min(cand_idx.numel(), 5)):  # top-5 beam; range, not Python 2's xrange
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id,ctr_atom,nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]

            cur_mol = attach_mols(cur_mol, children, [], new_global_amap) #father is already attached
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None: continue
            
            result = True
            for nei_node in children:
                if nei_node.is_leaf: continue
                cur_mol = self.dfs_assemble(y_tree_mess, x_mol_vec_pooled, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node)
                if cur_mol is None: 
                    result = False
                    break
            if result: return cur_mol

        return None
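        # Hedged driver sketch: in the reference JTNN decoder, assembly starts
        # at the root with an identity atom map for node 1 (caller-side names
        # here are assumptions):
        #
        #     cur_mol = copy_edit_mol(root.mol)
        #     global_amap = [{}] + [{} for _ in all_nodes]
        #     global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
        #     cur_mol = self.dfs_assemble(tree_mess, mol_vec_pooled, all_nodes,
        #                                 cur_mol, global_amap, [], root, None)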
Example #6
    def __init__(self,
                 vocab,
                 hidden_size,
                 latent_size,
                 depthT,
                 depthG,
                 loss_type='cos'):
        super(JTNNMJ, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size = latent_size // 2  # tree and graph each use half of the latent code

        self.jtnn = JTNNEncoder(hidden_size, depthT,
                                nn.Embedding(vocab.size(), hidden_size))
        self.decoder = JTNNDecoder(vocab, hidden_size, latent_size,
                                   nn.Embedding(vocab.size(), hidden_size))

        self.jtmpn = JTMPN(hidden_size, depthG)
        self.mpn = MPN(hidden_size, depthG)

        self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')

        self.T_mean = nn.Linear(hidden_size, latent_size)
        self.T_var = nn.Linear(hidden_size, latent_size)
        self.G_mean = nn.Linear(hidden_size, latent_size)
        self.G_var = nn.Linear(hidden_size, latent_size)

        #For MJ
        self.gene_exp_size = 978

        self.gene_mlp = nn.Linear(self.gene_exp_size, 2 * hidden_size)
        self.gene_mlp2 = nn.Linear(2 * hidden_size, 2 * hidden_size)

        self.cos = nn.CosineSimilarity()
        self.loss_type = loss_type
        if loss_type == 'L1':
            self.cos_loss = torch.nn.L1Loss(reduction='mean')  # 'elementwise_mean' was renamed to 'mean'
        elif loss_type == 'L2':
            self.cos_loss = torch.nn.MSELoss(reduction='mean')
        elif loss_type == 'cos':
            self.cos_loss = torch.nn.CosineEmbeddingLoss()
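        # Note on intended use (an assumption, not from the source): with
        # loss_type='cos', nn.CosineEmbeddingLoss expects a +/-1 target per pair,
        #
        #     target = torch.ones(mol_vecs.size(0))
        #     loss = self.cos_loss(mol_vecs, gene_vecs, target)
        #
        # whereas the L1/L2 branches compare the two vectors directly.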
Example #7
    def __init__(self, vocab, hidden_size, latent_size, depth):
        super(JTNNVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = int(hidden_size)
        self.latent_size = int(latent_size)
        self.depth = depth

        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        self.jtnn = JTNNEncoder(vocab, hidden_size, self.embedding)
        self.jtmpn = JTMPN(hidden_size, depth)
        self.mpn = MPN(hidden_size, depth)
        self.decoder = JTNNDecoder(vocab, hidden_size, self.latent_size // 2,
                                   self.embedding)

        self.T_mean = nn.Linear(hidden_size, self.latent_size // 2)
        self.T_var = nn.Linear(hidden_size, self.latent_size // 2)
        self.G_mean = nn.Linear(hidden_size, self.latent_size // 2)
        self.G_var = nn.Linear(hidden_size, self.latent_size // 2)

        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
        self.stereo_loss = nn.CrossEntropyLoss(reduction='sum')
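        # Hedged sampling sketch (assumes the decode() method from the reference
        # JTNN code is present on this class):
        #
        #     tree_vec = torch.randn(1, model.latent_size // 2)
        #     mol_vec = torch.randn(1, model.latent_size // 2)
        #     smiles = model.decode(tree_vec, mol_vec, prob_decode=False)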
Example #8
File: jtnn_vae.py, Project: yuxwind/DIG
    def dfs_assemble(self, y_tree_mess, x_mol_vecs, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node, prob_decode, check_aroma):
        fa_nid = fa_node.nid if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
        neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node.nid]
        cands,aroma_score = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0 or (sum(aroma_score) < 0 and check_aroma):
            return None, cur_mol

        cand_smiles,cand_amap = zip(*cands)
        aroma_score = torch.Tensor(aroma_score).to(x_mol_vecs.device)  # follow the input's device rather than hard-coding CUDA
        cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles]

        if len(cands) > 1:
            jtmpn_holder = JTMPN.tensorize(cands, y_tree_mess[1])
            fatoms,fbonds,agraph,bgraph,scope = jtmpn_holder
            cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess[0])
            scores = torch.mv(cand_vecs, x_mol_vecs) + aroma_score
        else:
            scores = torch.Tensor([1.0])

        if prob_decode:
            probs = F.softmax(scores.view(1,-1), dim=1).squeeze() + 1e-7 #prevent prob = 0
            cand_idx = torch.multinomial(probs, probs.numel())
        else:
            _,cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        pre_mol = cur_mol
        for i in range(cand_idx.numel()):
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id,ctr_atom,nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]

            cur_mol = attach_mols(cur_mol, children, [], new_global_amap) #father is already attached
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None: continue
            
            has_error = False
            for nei_node in children:
                if nei_node.is_leaf: continue
                tmp_mol, tmp_mol2 = self.dfs_assemble(y_tree_mess, x_mol_vecs, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node, prob_decode, check_aroma)
                if tmp_mol is None: 
                    has_error = True
                    if i == 0: pre_mol = tmp_mol2
                    break
                cur_mol = tmp_mol

            if not has_error: return cur_mol, cur_mol

        return None, pre_mol
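        # Return contract, as implemented above: on success both slots carry the
        # fully assembled molecule; on failure the first slot is None and pre_mol
        # holds the partial molecule reached under the top-ranked candidate (or
        # the incoming molecule), so the caller can fall back to it.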