Ejemplo n.º 1
0
    def __init__(self, vocab, args):
        """Build the sub-networks of the DiffVAE model.

        Args:
            vocab: vocabulary object; ``vocab.size()`` sets the number of rows
                of every embedding table created here.
            args: configuration namespace. Reads ``hidden_size``,
                ``rand_size``, ``depthG``, ``depthT``, ``share_embedding`` and
                ``use_molatt``.
        """
        super(DiffVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size = args.hidden_size
        self.rand_size = rand_size = args.rand_size
        # rand_size is split evenly between the T_* and G_* heads. Floor
        # division is required: under Python 3 ``/`` yields a float, which
        # nn.Linear rejects as a feature size.
        half_rand = rand_size // 2

        self.jtmpn = JTMPN(hidden_size, args.depthG)
        self.mpn = MPN(hidden_size, args.depthG)

        if args.share_embedding:
            # One embedding table shared by the tree encoder and the decoder.
            self.embedding = nn.Embedding(vocab.size(), hidden_size)
            self.jtnn = JTNNEncoder(hidden_size, args.depthT, self.embedding)
            self.decoder = JTNNDecoder(vocab, hidden_size, self.embedding, args.use_molatt)
        else:
            # Encoder and decoder each get their own embedding table.
            self.jtnn = JTNNEncoder(hidden_size, args.depthT, nn.Embedding(vocab.size(), hidden_size))
            self.decoder = JTNNDecoder(vocab, hidden_size, nn.Embedding(vocab.size(), hidden_size), args.use_molatt)

        self.A_assm = nn.Linear(hidden_size, hidden_size, bias=False)
        # reduction='sum' is the non-deprecated equivalent of size_average=False.
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')

        self.T_mean = nn.Linear(hidden_size, half_rand)
        self.T_var = nn.Linear(hidden_size, half_rand)
        self.G_mean = nn.Linear(hidden_size, half_rand)
        self.G_var = nn.Linear(hidden_size, half_rand)
        # B_t / B_g take a hidden vector concatenated with a half_rand-sized
        # sample (presumably from the T_* / G_* heads — confirm in forward).
        self.B_t = nn.Sequential(nn.Linear(hidden_size + half_rand, hidden_size), nn.ReLU())
        self.B_g = nn.Sequential(nn.Linear(hidden_size + half_rand, hidden_size), nn.ReLU())
    def __init__(self, vocab, hidden_size, latent_size, depth):
        """Build the sub-networks of the JTPropVAE model.

        Args:
            vocab: vocabulary object; ``vocab.size()`` sets the embedding
                table size.
            hidden_size: width of hidden representations.
            latent_size: total latent width; each of the T_* and G_* heads
                produces ``latent_size // 2`` features.
            depth: message-passing depth for JTMPN and MPN.
        """
        super(JTPropVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth
        # Floor division: ``/`` would give a float, which nn.Linear rejects.
        half_latent = latent_size // 2

        # A single embedding table is shared by the tree encoder and decoder.
        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        self.jtnn = JTNNEncoder(vocab, hidden_size, self.embedding)
        self.jtmpn = JTMPN(hidden_size, depth)
        self.mpn = MPN(hidden_size, depth)
        self.decoder = JTNNDecoder(vocab, hidden_size, half_latent,
                                   self.embedding)

        self.T_mean = nn.Linear(hidden_size, half_latent)
        self.T_var = nn.Linear(hidden_size, half_latent)
        self.G_mean = nn.Linear(hidden_size, half_latent)
        self.G_var = nn.Linear(hidden_size, half_latent)

        # Property regressor over the full latent width (self.latent_size).
        # NOTE(review): assumes latent_size is even, so the two halves
        # concatenate back to latent_size — confirm.
        self.propNN = nn.Sequential(
            nn.Linear(self.latent_size, self.hidden_size), nn.Tanh(),
            nn.Linear(self.hidden_size, 1))
        self.prop_loss = nn.MSELoss()
        # reduction='sum' is the non-deprecated equivalent of size_average=False.
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
        self.stereo_loss = nn.CrossEntropyLoss(reduction='sum')
Ejemplo n.º 3
0
    def __init__(self, vocab, hidden_size, latent_size, depthT, depthG, gene_exp_size=978):
        """Build the sub-networks of the JTNNVAEMLP model.

        Args:
            vocab: vocabulary object; ``vocab.size()`` sets the embedding
                table sizes.
            hidden_size: width of hidden representations.
            latent_size: total latent width; it is halved and stored in
                ``self.latent_size`` (tree and mol each get one half).
            depthT: depth of the tree encoder (JTNNEncoder).
            depthG: message-passing depth for JTMPN and MPN.
            gene_exp_size: width of the gene-expression input consumed by
                ``gene_mlp`` (default 978, the previously hard-coded value).
        """
        super(JTNNVAEMLP, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        # Tree and Mol have two vectors. Floor division keeps this an int;
        # ``/`` would give a float that nn.Linear rejects.
        self.latent_size = latent_size = latent_size // 2

        # Encoder and decoder each get their own embedding table.
        self.jtnn = JTNNEncoder(hidden_size, depthT, nn.Embedding(vocab.size(), hidden_size))
        self.decoder = JTNNDecoder(vocab, hidden_size, latent_size, nn.Embedding(vocab.size(), hidden_size))

        self.jtmpn = JTMPN(hidden_size, depthG)
        self.mpn = MPN(hidden_size, depthG)

        self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)
        # reduction='sum' is the non-deprecated equivalent of size_average=False.
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')

        self.T_mean = nn.Linear(hidden_size, latent_size)
        self.T_var = nn.Linear(hidden_size, latent_size)
        self.G_mean = nn.Linear(hidden_size, latent_size)
        self.G_var = nn.Linear(hidden_size, latent_size)

        self.T_hat_mean = nn.Linear(latent_size, latent_size)
        self.T_hat_var = nn.Linear(latent_size, latent_size)

        # Gene-expression MLP branch.
        self.gene_exp_size = gene_exp_size

        self.gene_mlp = nn.Linear(self.gene_exp_size, latent_size)

        # Maps two concatenated latent_size vectors back to latent_size.
        self.tree_mlp = nn.Linear(latent_size + latent_size, latent_size)
Ejemplo n.º 4
0
    def decode(self, x_tree_vecs, x_mol_vecs):
        """Decode a single (tree, mol) latent pair into a SMILES string.

        Batch decoding is not supported: both inputs must have size 1 along
        dim 0.

        Returns:
            None if the decoder predicts no nodes, if assembly fails, or if the
            final SMILES round-trip fails to parse. The root node's ``smiles``
            if exactly one node is predicted. Otherwise the SMILES of the
            assembled molecule.
        """
        # Only single-item "batches" are handled.
        assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1

        root, nodes = self.decoder.decode(x_tree_vecs, x_mol_vecs)
        node_count = len(nodes)
        if node_count == 0:
            return None
        if node_count == 1:
            return root.smiles

        # Assign 1-based node ids and leaf flags. Atom-map numbers are set
        # only on nodes with more than one neighbour.
        for nid, node in enumerate(nodes, start=1):
            node.nid = nid
            degree = len(node.neighbors)
            node.is_leaf = degree == 1
            if degree > 1:
                set_atommap(node.mol, node.nid)

        enc_holder, message_index = JTNNEncoder.tensorize_nodes(nodes, [(0, node_count)])
        _, messages = self.jtnn(*enc_holder)
        # Pair of (message tensor, plain python dict) handed to assembly.
        tree_messages = (messages, message_index)

        # Sum over dim 1 (a sum, not an average), then project with A_assm.
        pooled_mol = self.A_assm(x_mol_vecs.sum(dim=1)).squeeze()

        working_mol = copy_edit_mol(root.mol)
        # One atom map per node id plus an unused slot at index 0; the root
        # (nid 1) starts with an identity mapping over its atoms.
        atom_maps = [{} for _ in range(node_count + 1)]
        atom_maps[1] = {atom.GetIdx(): atom.GetIdx() for atom in working_mol.GetAtoms()}

        assembled = self.dfs_assemble(tree_messages, pooled_mol, nodes, working_mol, atom_maps, [], root, None)
        if assembled is None:
            return None

        result = assembled.GetMol()
        # Called without a number, presumably to reset atom-map numbers — confirm.
        set_atommap(result)
        # SMILES round-trip; MolFromSmiles returns None when parsing fails.
        result = Chem.MolFromSmiles(Chem.MolToSmiles(result))
        if result is None:
            return None
        return Chem.MolToSmiles(result)
Ejemplo n.º 5
0
def tensorize(tree_batch, vocab, assm=True, if_need_origin_word=False):
    """Tensorize a batch of molecule trees for the encoders.

    Args:
        tree_batch: sequence of tree objects; each must expose ``smiles`` and
            ``nodes``.
        vocab: vocabulary passed to ``set_batch_nodeID``.
        assm: when exactly ``False``, skip building assembly-candidate inputs.
        if_need_origin_word: only consulted when ``assm is False``; if truthy,
            ``origin_word`` is appended to the returned tuple.

    Returns:
        ``(tree_batch, jtenc_holder, mpn_holder)``, optionally followed by
        ``origin_word`` when assembly is skipped. Otherwise
        ``(tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx))``,
        where ``batch_idx`` is a LongTensor with the batch index of each
        candidate.
    """
    set_batch_nodeID(tree_batch, vocab)
    smiles_batch = [tree.smiles for tree in tree_batch]
    (jtenc_holder, mess_dict), origin_word = JTNNEncoder.tensorize(tree_batch)
    mpn_holder = MPN.tensorize(smiles_batch)

    if assm is False:
        encoded = (tree_batch, jtenc_holder, mpn_holder)
        return encoded + (origin_word,) if if_need_origin_word else encoded

    candidates = []
    owner_index = []
    for tree_pos, mol_tree in enumerate(tree_batch):
        for node in mol_tree.nodes:
            # A leaf's attachment follows from its neighbour's attachment,
            # and a single-candidate node has nothing to choose.
            if node.is_leaf or len(node.cands) == 1:
                continue
            for cand in node.cands:
                candidates.append((cand, mol_tree.nodes, node))
                owner_index.append(tree_pos)

    jtmpn_holder = JTMPN.tensorize(candidates, mess_dict)
    owner_tensor = torch.LongTensor(owner_index)
    return tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, owner_tensor)
Ejemplo n.º 6
0
    def __init__(self,
                 vocab,
                 hidden_size,
                 latent_size,
                 depthT,
                 depthG,
                 loss_type='cos',
                 gene_exp_size=978):
        """Build the sub-networks of the JTNNMJ model.

        Args:
            vocab: vocabulary object; ``vocab.size()`` sets the embedding
                table sizes.
            hidden_size: width of hidden representations.
            latent_size: total latent width; it is halved and stored in
                ``self.latent_size`` (tree and mol each get one half).
            depthT: depth of the tree encoder (JTNNEncoder).
            depthG: message-passing depth for JTMPN and MPN.
            loss_type: selects ``self.cos_loss``: ``'L1'`` -> L1Loss,
                ``'L2'`` -> MSELoss, ``'cos'`` -> CosineEmbeddingLoss.
            gene_exp_size: width of the gene-expression input consumed by
                ``gene_mlp`` (default 978, the previously hard-coded value).
        """
        super(JTNNMJ, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        # Tree and Mol have two vectors. Floor division keeps this an int;
        # ``/`` would give a float that nn.Linear rejects.
        self.latent_size = latent_size = latent_size // 2

        # Encoder and decoder each get their own embedding table.
        self.jtnn = JTNNEncoder(hidden_size, depthT,
                                nn.Embedding(vocab.size(), hidden_size))
        self.decoder = JTNNDecoder(vocab, hidden_size, latent_size,
                                   nn.Embedding(vocab.size(), hidden_size))

        self.jtmpn = JTMPN(hidden_size, depthG)
        self.mpn = MPN(hidden_size, depthG)

        self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)
        # reduction='sum' is the non-deprecated equivalent of size_average=False.
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')

        self.T_mean = nn.Linear(hidden_size, latent_size)
        self.T_var = nn.Linear(hidden_size, latent_size)
        self.G_mean = nn.Linear(hidden_size, latent_size)
        self.G_var = nn.Linear(hidden_size, latent_size)

        # Gene-expression branch: gene_exp_size -> 2*hidden -> 2*hidden.
        self.gene_exp_size = gene_exp_size

        self.gene_mlp = nn.Linear(self.gene_exp_size, 2 * hidden_size)
        self.gene_mlp2 = nn.Linear(2 * hidden_size, 2 * hidden_size)

        self.cos = nn.CosineSimilarity()
        self.loss_type = loss_type
        # 'mean' replaces the deprecated 'elementwise_mean' (same reduction).
        if loss_type == 'L1':
            self.cos_loss = torch.nn.L1Loss(reduction='mean')
        elif loss_type == 'L2':
            self.cos_loss = torch.nn.MSELoss(reduction='mean')
        elif loss_type == 'cos':
            self.cos_loss = torch.nn.CosineEmbeddingLoss()
        # NOTE(review): any other loss_type leaves self.cos_loss unset —
        # confirm whether forward() relies on that.
Ejemplo n.º 7
0
    def __init__(self, vocab, hidden_size, latent_size, depth):
        """Build the sub-networks of the JTNNVAE model.

        Args:
            vocab: vocabulary object; ``vocab.size()`` sets the embedding
                table size.
            hidden_size: width of hidden representations (cast to int).
            latent_size: total latent width (cast to int); each of the T_*
                and G_* heads, and the decoder, use ``latent_size // 2``.
            depth: message-passing depth for JTMPN and MPN.
        """
        super(JTNNVAE, self).__init__()
        self.vocab = vocab
        # Use the int-cast sizes for every module, not just the attributes,
        # so float arguments (e.g. 450.0) do not reach nn.Embedding/nn.Linear.
        self.hidden_size = hidden_size = int(hidden_size)
        self.latent_size = latent_size = int(latent_size)
        self.depth = depth
        # Integer half-width; previously the decoder received a float.
        half_latent = latent_size // 2

        # A single embedding table is shared by the tree encoder and decoder.
        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        self.jtnn = JTNNEncoder(vocab, hidden_size, self.embedding)
        self.jtmpn = JTMPN(hidden_size, depth)
        self.mpn = MPN(hidden_size, depth)
        self.decoder = JTNNDecoder(vocab, hidden_size, half_latent,
                                   self.embedding)

        self.T_mean = nn.Linear(hidden_size, half_latent)
        self.T_var = nn.Linear(hidden_size, half_latent)
        self.G_mean = nn.Linear(hidden_size, half_latent)
        self.G_var = nn.Linear(hidden_size, half_latent)

        # reduction='sum' is the non-deprecated equivalent of size_average=False.
        self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
        self.stereo_loss = nn.CrossEntropyLoss(reduction='sum')