Example #1
    def forward(self, fnode, fmess, node_graph, mess_graph, scope):
        fnode = create_var(fnode)
        fmess = create_var(fmess)
        node_graph = create_var(node_graph)
        mess_graph = create_var(mess_graph)
        messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))
        fnode = self.embedding(fnode)

        fmess = index_select_ND(fnode, 0, fmess)
        messages = self.GRU(messages, fmess, mess_graph)

        mess_nei = index_select_ND(messages, 0, node_graph)
        node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        node_vecs = self.outputNN(node_vecs)

        max_len = max([x for _,x in scope])
        batch_vecs = []
        for st,le in scope:
            cur_vecs = node_vecs[st] #Root is the first node
            batch_vecs.append( cur_vecs )

        tree_vecs = torch.stack(batch_vecs, dim=0)
        return tree_vecs, messages
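These snippets rely on two small helpers, create_var and index_select_ND, that the listing does not show. Below is a minimal sketch of what they typically look like in this family of JT-VAE code; the exact signatures and the device handling are assumptions, not the original implementations.

import torch
from torch.autograd import Variable

def create_var(tensor, requires_grad=False):
    # Assumed helper: wrap a tensor for autograd and move it to the GPU when available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return Variable(tensor, requires_grad=requires_grad).to(device)

def index_select_ND(source, dim, index):
    # Assumed helper: gather rows of `source` along `dim` using an index tensor of
    # arbitrary shape (e.g. a (num_nodes, MAX_NB) neighbor index), keeping the index
    # shape as the leading dimensions of the result.
    index_size = index.size()
    suffix_dim = source.size()[1:]
    final_size = index_size + suffix_dim
    target = source.index_select(dim, index.view(-1))
    return target.view(final_size)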
Example #2
    def stereo(self, mol_batch, mol_vec):
        stereo_cands, batch_idx = [], []
        labels = []
        for i, mol_tree in enumerate(mol_batch):
            cands = mol_tree.stereo_cands
            if len(cands) == 1: continue
            if mol_tree.smiles3D not in cands:
                cands.append(mol_tree.smiles3D)
            stereo_cands.extend(cands)
            batch_idx.extend([i] * len(cands))
            labels.append((cands.index(mol_tree.smiles3D), len(cands)))

        if len(labels) == 0:
            return create_var(torch.zeros(1)), 1.0

        batch_idx = create_var(torch.LongTensor(batch_idx))
        stereo_cands = self.mpn(mol2graph(stereo_cands))
        stereo_cands = self.G_mean(stereo_cands)
        stereo_labels = mol_vec.index_select(0, batch_idx)
        scores = torch.nn.CosineSimilarity()(stereo_cands, stereo_labels)

        st, acc = 0, 0
        all_loss = []
        for label, le in labels:
            cur_scores = scores.narrow(0, st, le)
            if cur_scores.data[label] >= cur_scores.max():
                acc += 1
            label = create_var(torch.LongTensor([label]))
            all_loss.append(self.stereo_loss(cur_scores.view(1, -1), label))
            st += le
        #all_loss = torch.cat(all_loss).sum() / len(labels)
        all_loss = sum(all_loss) / len(labels)
        return all_loss, acc * 1.0 / len(labels)
Example #3
    def assm(self, junc_tree_batch, x_jtmpn_holder, z_mol_vecs, x_tree_mess):
        jtmpn_holder, batch_idx = x_jtmpn_holder
        atom_feature_matrix, bond_feature_matrix, atom_adjacency_graph, bond_adjacency_graph, scope = jtmpn_holder

        batch_idx = create_var(batch_idx)

        candidate_vecs = self.jtmpn(atom_feature_matrix, bond_feature_matrix, atom_adjacency_graph, bond_adjacency_graph, scope, x_tree_mess)

        z_mol_vecs = z_mol_vecs.index_select(0, batch_idx)
        z_mol_vecs = self.A_assm(z_mol_vecs)  # bilinear
        scores = torch.bmm(
            z_mol_vecs.unsqueeze(1),
            candidate_vecs.unsqueeze(-1)
        ).squeeze()

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(junc_tree_batch):
            comp_nodes = [node for node in mol_tree.nodes if len(node.candidates) > 1 and not node.is_leaf]
            cnt += len(comp_nodes)
            for node in comp_nodes:
                label = node.candidates.index(node.label)
                num_candidates = len(node.candidates)
                cur_score = scores.narrow(0, tot, num_candidates)
                tot += num_candidates

                if cur_score.data[label] >= cur_score.max().item():
                    acc += 1

                label = create_var(torch.LongTensor([label]))
                all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

        all_loss = sum(all_loss) / len(junc_tree_batch)
        return all_loss, acc * 1.0 / cnt
Example #4
    def forward(self, fnode, fmess, node_graph, mess_graph, scope):
        fnode = create_var(fnode)
        fmess = create_var(fmess)
        node_graph = create_var(node_graph)
        mess_graph = create_var(mess_graph)
        messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

        fnode = self.embedding(fnode)
        fmess1 = index_select_ND(fnode, 0, fmess[:, 0])
        fmess2 = self.E_pos(fmess[:, 1])
        fmess = self.inputNN( torch.cat([fmess1,fmess2], dim=-1) )
        messages = self.GRU(messages, fmess, mess_graph)

        mess_nei = index_select_ND(messages, 0, node_graph)
        node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        node_vecs = self.outputNN(node_vecs)

        max_len = max([x for _,x in scope])
        batch_vecs = []
        for st,le in scope:
            cur_vecs = node_vecs[st] #Root is the first node
            batch_vecs.append( cur_vecs )

        tree_vecs = torch.stack(batch_vecs, dim=0)
        return tree_vecs, messages
Example #5
    def forward(self, mol_batch, beta=0):
        batch_size = len(mol_batch)

        tree_mess, tree_vec, mol_vec = self.encode(mol_batch)

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec)) #Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec)) #Following Mueller et al.

        z_mean = torch.cat([tree_mean,mol_mean], dim=1)
        z_log_var = torch.cat([tree_log_var,mol_log_var], dim=1)
        kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

        epsilon = create_var(torch.randn(batch_size, (int)(self.latent_size / 2)), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(batch_size, (int)(self.latent_size / 2)), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
        
        word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, tree_vec)
        assm_loss, assm_acc = self.assm(mol_batch, mol_vec, tree_mess)
        if self.use_stereo:
            stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
        else:
            stereo_loss, stereo_acc = 0, 0

        all_vec = torch.cat([tree_vec, mol_vec], dim=1)
        loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss 

        return loss, kl_loss.item(), word_acc, topo_acc, assm_acc, stereo_acc
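With z_log_var = log σ², the kl_loss term above is the closed-form KL divergence between the diagonal Gaussian posterior and a standard normal prior, summed over the latent dimensions and averaged over the batch:

\mathrm{KL}\!\left(\mathcal{N}(\mu, \sigma^2 I)\,\big\|\,\mathcal{N}(0, I)\right)
  = -\tfrac{1}{2} \sum_{d} \left(1 + \log \sigma_d^2 - \mu_d^2 - \sigma_d^2\right)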
Example #6
    def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, x_tree_mess):
        jtmpn_holder,batch_idx = jtmpn_holder
        fatoms,fbonds,agraph,bgraph,scope = jtmpn_holder
        batch_idx = create_var(batch_idx)

        cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, x_tree_mess)

        x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
        x_mol_vecs = self.A_assm(x_mol_vecs) #bilinear
        scores = torch.bmm(
                x_mol_vecs.unsqueeze(1),
                cand_vecs.unsqueeze(-1)
        ).squeeze()
        
        cnt,tot,acc = 0,0,0
        all_loss = []
        for i,mol_tree in enumerate(mol_batch):
            comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
            cnt += len(comp_nodes)
            for node in comp_nodes:
                label = node.cands.index(node.label)
                ncand = len(node.cands)
                cur_score = scores.narrow(0, tot, ncand)
                tot += ncand

                if cur_score.data[label] >= cur_score.max().item():
                    acc += 1

                label = create_var(torch.LongTensor([label]))
                all_loss.append( self.assm_loss(cur_score.view(1,-1), label) )
        
        all_loss = sum(all_loss) / len(mol_batch)
        return all_loss, acc * 1.0 / cnt
Example #7
    def forward(self, fatoms, fbonds, agraph, bgraph, scope, tree_message): #tree_message[0] == vec(0)
        fatoms = create_var(fatoms)
        fbonds = create_var(fbonds)
        agraph = create_var(agraph)
        bgraph = create_var(bgraph)

        binput = self.W_i(fbonds)
        graph_message = F.relu(binput)

        for i in range(self.depth - 1):
            message = torch.cat([tree_message,graph_message], dim=0) 
            nei_message = index_select_ND(message, 0, bgraph)
            nei_message = nei_message.sum(dim=1) #assuming tree_message[0] == vec(0)
            nei_message = self.W_h(nei_message)
            graph_message = F.relu(binput + nei_message)

        message = torch.cat([tree_message,graph_message], dim=0)
        nei_message = index_select_ND(message, 0, agraph)
        nei_message = nei_message.sum(dim=1)
        ainput = torch.cat([fatoms, nei_message], dim=1)
        atom_hiddens = F.relu(self.W_o(ainput))
        
        mol_vecs = []
        for st,le in scope:
            mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
            mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)
        return mol_vecs
Example #8
    def forward(self, fnode, fmess, node_graph, mess_graph, scope):
        fnode = create_var(fnode)
        fmess = create_var(fmess)
        node_graph = create_var(node_graph)
        mess_graph = create_var(mess_graph)
        messages = create_var(torch.zeros(mess_graph.size(0),
                                          self.hidden_size))

        fnode = self.embedding(fnode)
        fmess = index_select_ND(fnode, 0, fmess)
        messages = self.GRU(messages, fmess, mess_graph)

        mess_nei = index_select_ND(messages, 0, node_graph)
        node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        node_vecs = self.outputNN(node_vecs)

        max_len = max([x for _, x in scope])
        batch_vecs = []
        for st, le in scope:
            cur_vecs = node_vecs[st:st + le]
            cur_vecs = F.pad(cur_vecs, (0, 0, 0, max_len - le))
            batch_vecs.append(cur_vecs)

        tree_vecs = torch.stack(batch_vecs, dim=0)
        return tree_vecs, messages
Example #9
    def forward(self, mol_batch, beta=0):
        batch_size = len(mol_batch)
        mol_batch, prop_batch = zip(*mol_batch)
        tree_mess, tree_vec, mol_vec = self.encode(mol_batch)

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec)) #Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec)) #Following Mueller et al.

        z_mean = torch.cat([tree_mean,mol_mean], dim=1)
        z_log_var = torch.cat([tree_log_var,mol_log_var], dim=1)
        kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

        epsilon = create_var(torch.randn(batch_size, self.latent_size // 2), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(batch_size, self.latent_size // 2), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
        
        word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, tree_vec)
        assm_loss, assm_acc = self.assm(mol_batch, mol_vec, tree_mess)
        stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)

        all_vec = torch.cat([tree_vec, mol_vec], dim=1)
        prop_label = create_var(torch.Tensor(prop_batch))
        prop_loss = self.prop_loss(self.propNN(all_vec).squeeze(), prop_label)
        
        loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss + prop_loss
        return loss, kl_loss.item(), word_acc, topo_acc, assm_acc, stereo_acc, prop_loss.item()
Example #10
    def reconstruct(self, smiles):
        junc_tree = MolJuncTree(smiles)
        junc_tree.recover()

        set_batch_nodeID([junc_tree], self.vocab)

        jtenc_holder, _ = JTNNEncoder.tensorize([junc_tree])
        mpn_holder = MessPassNet.tensorize([smiles])

        tree_vec, _, mol_vec = self.encode(jtenc_holder, mpn_holder)

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

        epsilon = create_var(torch.randn(1, self.latent_size), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(1, self.latent_size), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

        return self.decode(tree_vec, mol_vec)
Example #11
    def forward(self, mol_graph):
        fatoms, fbonds, agraph, bgraph, scope = mol_graph
        fatoms = create_var(fatoms)
        fbonds = create_var(fbonds)
        agraph = create_var(agraph)
        bgraph = create_var(bgraph)

        binput = self.W_i(fbonds)
        message = nn.ReLU()(binput)

        for _ in range(self.depth - 1):
            nei_message = index_select_ND(message, 0, bgraph)
            nei_message = nei_message.sum(dim=1)
            nei_message = self.W_h(nei_message)
            message = nn.ReLU()(binput + nei_message)

        nei_message = index_select_ND(message, 0, agraph)
        nei_message = nei_message.sum(dim=1)
        ainput = torch.cat([fatoms, nei_message], dim=1)
        atom_hiddens = nn.ReLU()(self.W_o(ainput))

        mol_vecs = []
        for st, le in scope:
            mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
            mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)
        return mol_vecs
Example #12
 def sample_eval(self):
     tree_vec = create_var(torch.randn(1, self.latent_size // 2), False)
     mol_vec = create_var(torch.randn(1, self.latent_size // 2), False)
     all_smiles = []
     for i in range(100):
         s = self.decode(tree_vec, mol_vec, prob_decode=True)
         all_smiles.append(s)
     return all_smiles
Example #13
    def fuse_noise(self, tree_vecs, mol_vecs):
        tree_eps = create_var( torch.randn(tree_vecs.size(0), 1, self.rand_size // 2) )
        tree_eps = tree_eps.expand(-1, tree_vecs.size(1), -1)
        mol_eps = create_var( torch.randn(mol_vecs.size(0), 1, self.rand_size // 2) )
        mol_eps = mol_eps.expand(-1, mol_vecs.size(1), -1)

        tree_vecs = torch.cat([tree_vecs,tree_eps], dim=-1) 
        mol_vecs = torch.cat([mol_vecs,mol_eps], dim=-1) 
        return self.B_t(tree_vecs), self.B_g(mol_vecs)
Example #14
 def sample_prior_eval(self, prob_decode=False, ns=1000, nd=500):
     priors = []
     for i in range(ns):  #100
         dec = []
         tree_vec = create_var(torch.randn(1, self.latent_size // 2), False)
         mol_vec = create_var(torch.randn(1, self.latent_size // 2), False)
         for j in range(nd):  #500
             dec.append(self.decode(tree_vec, mol_vec, prob_decode))
         priors.append(dec)
     return priors
Example #15
    def optimize(self, smiles, sim_cutoff, lr=2.0, num_iter=20):
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        _, tree_vec, mol_vec = self.encode([mol_tree])

        mol = Chem.MolFromSmiles(smiles)
        fp1 = AllChem.GetMorganFingerprint(mol, 2)

        tree_mean = self.T_mean(tree_vec)
        # Following Mueller et al.
        tree_log_var = -torch.abs(self.T_var(tree_vec))
        mol_mean = self.G_mean(mol_vec)
        # Following Mueller et al.
        mol_log_var = -torch.abs(self.G_var(mol_vec))
        mean = torch.cat([tree_mean, mol_mean], dim=1)
        log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
        cur_vec = create_var(mean.data, True)

        visited = []
        for _ in range(num_iter):
            prop_val = self.propNN(cur_vec).squeeze()
            grad = torch.autograd.grad(prop_val, cur_vec)[0]
            cur_vec = cur_vec.data + lr * grad.data
            cur_vec = create_var(cur_vec, True)
            visited.append(cur_vec)

        l, r = 0, num_iter - 1
        while l < r - 1:
            mid = (l + r) // 2
            new_vec = visited[mid]
            tree_vec, mol_vec = torch.chunk(new_vec, 2, dim=1)
            new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
            if new_smiles is None:
                r = mid - 1
                continue

            new_mol = Chem.MolFromSmiles(new_smiles)
            fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
            sim = DataStructs.TanimotoSimilarity(fp1, fp2)
            if sim < sim_cutoff:
                r = mid - 1
            else:
                l = mid

        tree_vec, mol_vec = torch.chunk(visited[l], 2, dim=1)
        new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
        if new_smiles is None:
            return smiles, 1.0
        new_mol = Chem.MolFromSmiles(new_smiles)
        fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
        sim = DataStructs.TanimotoSimilarity(fp1, fp2)
        if sim >= sim_cutoff:
            return new_smiles, sim
        else:
            return smiles, 1.0
Example #16
    def forward(self, root_batch):
        orders = []  # orders: list of lists; each sublist is the result of a level-order traversal from one root
        for root in root_batch:
            # order: a list in two parts:
            #   a bottom-up ordering, where each sublist holds that level's nodes together with their parents
            #   a top-down ordering, where each sublist holds that level's nodes together with their children
            order = get_prop_order(root)
            orders.append(order)

        h = {}
        max_depth = max([len(order) for order in orders])
        padding = create_var(torch.zeros(self.hidden_size), False)

        for t in range(max_depth):
            prop_list = []
            for order in orders:
                if len(order) > t:  # make sure this tree actually has a t-th level
                    prop_list.extend(order[t])  # add the t-th level's node pairs to prop_list

            cur_x = []
            cur_h_nei = []
            for node_x, node_y in prop_list:
                x, y = node_x.idx, node_y.idx  # node indices
                cur_x.append(node_x.wid)  # node type (vocabulary word) id

                h_nei = []
                for node_z in node_x.neighbors:
                    z = node_z.idx
                    if z == y:
                        continue
                    # h_nei: x's neighbors other than y, i.e. the nodes adjacent to x
                    h_nei.append(h[(z, x)])

                # if there are fewer neighbors than MAX_NB, pad with the padding variable
                pad_len = MAX_NB - len(h_nei)
                h_nei.extend([padding] * pad_len)
                cur_h_nei.extend(h_nei)

            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)  # from here on, labels become vectors; cur_x.size = (len(prop_list), hidden_size)
            cur_h_nei = torch.cat(cur_h_nei,
                                  dim=0).view(-1, MAX_NB, self.hidden_size)
            # cur_h_nei.size = (len(prop_list), MAX_NB, hidden_size)

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r,
                        self.W_h)
            for i, m in enumerate(prop_list):
                x, y = m[0].idx, m[1].idx
                h[(x, y)] = new_h[i]

        # node aggregate
        root_vecs = node_aggregate(root_batch, h, self.embedding, self.W)
        return h, root_vecs
Example #17
    def forward(self, root_batch):
        orders = []
        for root in root_batch:
            order = get_prop_order(root)
            orders.append(order)
        
        h = {}
        max_depth = max([len(x) for x in orders])
        padding = create_var(torch.zeros(self.hidden_size), False)

        for t in range(max_depth):
            prop_list = []
            for order in orders:
                if t < len(order):
                    prop_list.extend(order[t])

            cur_x = []
            cur_h_nei = []
            for node_x,node_y in prop_list:
                x,y = node_x.idx,node_y.idx
                cur_x.append(node_x.wid)

                h_nei = []
                for node_z in node_x.neighbors:
                    z = node_z.idx
                    if z == y: continue
                    h_nei.append(h[(z,x)])
                pad_len = MAX_NB - len(h_nei)
                h_nei.extend([padding] * pad_len)
                cur_h_nei.extend(h_nei)
            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)
            cur_h_nei = torch.cat(cur_h_nei, dim=0).view(-1,MAX_NB,self.hidden_size)

            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
            for i,m in enumerate(prop_list):
                x,y = m[0].idx,m[1].idx
                h[(x,y)] = new_h[i]

        root_vecs = node_aggregate(root_batch, h, self.embedding, self.W)

        return h, root_vecs
Example #18
    def forward(self, atom_feature_matrix, bond_feature_matrix, atom_adjacency_graph, atom_bond_adjacency_graph, bond_atom_adjacency_graph, scope):
        """
        Args:
            atom_feature_matrix: torch.tensor (shape: num_atoms x atom_feature_dim)
                The matrix containing feature vectors, for all the atoms, across the entire batch.
                * atom_feature_dim = len(ELEM_LIST) + 6 + 5 + 4 + 1

            bond_feature_matrix: torch.tensor (shape: num_bonds x bond_feature_dim)
                The matrix containing feature vectors, for all the bonds, across the entire batch.
                * bond_feature_dim = 5 + 6

            atom_adjacency_graph: torch.tensor (shape: num_atoms x MAX_NUM_NEIGHBORS(=6))
                For each atom, across the entire batch, the idxs of neighboring atoms.

            atom_bond_adjacency_graph: torch.tensor(shape: num_atoms x MAX_NUM_NEIGHBORS(=6))
                For each atom, across the entire batch, the idxs of all the bonds, in which it is the initial atom.

            bond_atom_adjacency_graph: torch.tensor (shape: num_bonds x 2)
                For each bond, across the entire batch, the idxs of the two atoms that form the bond.

            scope: List[Tuple(int, int)]
                A list of (total_bonds, num_bonds) tuples used to keep track of which bond feature
                vectors belong to each molecule.

        Returns:
            mol_vecs: torch.tensor (shape: batch_size x hidden_size)
                The hidden vector representation of each molecular graph, across the entire batch
        """
        # create PyTorch variables
        atom_feature_matrix = create_var(atom_feature_matrix)
        bond_feature_matrix = create_var(bond_feature_matrix)
        atom_adjacency_graph = create_var(atom_adjacency_graph)
        atom_bond_adjacency_graph = create_var(atom_bond_adjacency_graph)
        bond_atom_adjacency_graph = create_var(bond_atom_adjacency_graph)

        # implement convolution
        atom_layer_input = atom_feature_matrix
        bond_layer_input = bond_feature_matrix

        for conv_layer in self.conv_layers:
            # implement forward pass for this convolutional layer
            atom_layer_output, bond_layer_output = conv_layer(atom_layer_input, bond_layer_input, atom_adjacency_graph,
                                                              atom_bond_adjacency_graph, bond_atom_adjacency_graph)

            # set the input features for the next convolutional layer
            atom_layer_input, bond_layer_input = atom_layer_output, bond_layer_output

        # for each molecular graph, pool all the edge feature vectors
        mol_vecs = self.pool_bond_features_for_mols(atom_layer_output, bond_layer_output, bond_atom_adjacency_graph, scope)

        return mol_vecs
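self.pool_bond_features_for_mols is not part of this listing. Below is a minimal sketch of one plausible implementation, reusing the (start, length) scope convention and the mean pooling seen in Examples #7 and #11; the method body here is an assumption, not the original code.

    def pool_bond_features_for_mols(self, atom_features, bond_features,
                                    bond_atom_adjacency_graph, scope):
        # Hypothetical sketch: average the bond feature vectors of each molecule,
        # using the (start, length) tuples in `scope` to slice the batch.
        mol_vecs = []
        for st, le in scope:
            mol_vec = bond_features.narrow(0, st, le).sum(dim=0) / le
            mol_vecs.append(mol_vec)
        return torch.stack(mol_vecs, dim=0)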
Example #19
    def reconstruct(self, smiles, prob_decode=False):
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        _,tree_vec,mol_vec = self.encode([mol_tree])
        
        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec)) #Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec)) #Following Mueller et al.

        epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(1, self.latent_size // 2), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
        return self.decode(tree_vec, mol_vec, prob_decode)
Example #20
 def rsample(self, z_vecs, W_mean, W_var):
     z_mean = W_mean(z_vecs)
     z_log_var = -torch.abs(W_var(z_vecs)) #Following Mueller et al.
     kl_loss = -0.5 * torch.mean(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var))
     epsilon = create_var(torch.randn_like(z_mean))
     z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
     return z_vecs, kl_loss
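A minimal usage sketch for the reparameterization step implemented by rsample; the layer names and sizes below are placeholders rather than the original model configuration.

import torch
import torch.nn as nn

hidden_size, latent_size = 450, 56                    # placeholder sizes
W_mean = nn.Linear(hidden_size, latent_size // 2)     # stands in for self.T_mean / self.G_mean
W_var = nn.Linear(hidden_size, latent_size // 2)      # stands in for self.T_var / self.G_var

z_vecs = torch.randn(32, hidden_size)                 # a batch of encoder outputs
z_mean = W_mean(z_vecs)
z_log_var = -torch.abs(W_var(z_vecs))                 # following Mueller et al., as above
kl_loss = -0.5 * torch.mean(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var))
epsilon = torch.randn_like(z_mean)
z_sample = z_mean + torch.exp(z_log_var / 2) * epsilon  # reparameterized latent sample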
Example #21
    def assm(self, mol_batch, mol_vec, tree_mess):
        cands = []
        batch_idx = []
        for i, mol_tree in enumerate(mol_batch):
            for node in mol_tree.nodes:
                # Leaf node's attachment is determined by neighboring node's
                # attachment
                if node.is_leaf() or len(node.cands) == 1:
                    continue
                cands.extend([(cand, mol_tree.nodes, node)
                              for cand in node.cand_mols])
                batch_idx.extend([i] * len(node.cands))

        cand_vec = self.jtmpn(cands, tree_mess)
        cand_vec = self.G_mean(cand_vec)

        batch_idx = create_var(torch.LongTensor(batch_idx))
        mol_vec = mol_vec.index_select(0, batch_idx)

        mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
        cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
        scores = torch.bmm(mol_vec, cand_vec).squeeze()

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(mol_batch):
            comp_nodes = [
                node for node in mol_tree.nodes
                if len(node.cands) > 1 and not node.is_leaf()
            ]
            cnt += len(comp_nodes)
            for node in comp_nodes:
                label = node.cands.index(node.label)
                ncand = len(node.cands)
                cur_score = scores.narrow(0, tot, ncand)
                tot += ncand

                if cur_score.data[label] >= cur_score.max().item():
                    acc += 1

                label = create_var(torch.LongTensor([label]))
                all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

        # all_loss = torch.stack(all_loss).sum() / len(mol_batch)
        all_loss = sum(all_loss) / len(mol_batch)
        return all_loss, acc * 1.0 / cnt
Example #22
    def reconstruct1(self, smiles, prob_decode=False):
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        # print("tree olusturuldu")
        _, tree_vec, mol_vec = self.encode([mol_tree])
        # print("encode edildi")
        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))  # Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))  # Following Mueller et al.

        epsilon = create_var(torch.randn(1, (int)(self.latent_size / 2)), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(1, (int)(self.latent_size / 2)), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

        return tree_vec,mol_vec,prob_decode
Example #23
    def fuse_pair(self, x_tree_vecs, x_mol_vecs, y_tree_vecs, y_mol_vecs, jtenc_scope, mpn_scope):
        diff_tree_vecs = y_tree_vecs.sum(dim=1) - x_tree_vecs.sum(dim=1)
        size = create_var(torch.Tensor([le for _,le in jtenc_scope]))
        diff_tree_vecs = diff_tree_vecs / size.unsqueeze(-1)

        diff_mol_vecs = y_mol_vecs.sum(dim=1) - x_mol_vecs.sum(dim=1)
        size = create_var(torch.Tensor([le for _,le in mpn_scope]))
        diff_mol_vecs = diff_mol_vecs / size.unsqueeze(-1)

        diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean, self.T_var)
        diff_mol_vecs, mol_kl = self.rsample(diff_mol_vecs, self.G_mean, self.G_var)

        diff_tree_vecs = diff_tree_vecs.unsqueeze(1).expand(-1, x_tree_vecs.size(1), -1)
        diff_mol_vecs = diff_mol_vecs.unsqueeze(1).expand(-1, x_mol_vecs.size(1), -1)
        x_tree_vecs = torch.cat([x_tree_vecs,diff_tree_vecs], dim=-1)
        x_mol_vecs = torch.cat([x_mol_vecs,diff_mol_vecs], dim=-1)

        return self.B_t(x_tree_vecs), self.B_g(x_mol_vecs), tree_kl + mol_kl
Example #24
def node_aggregate(nodes, h, embedding, W):
    x_idx = []
    h_nei = []
    hidden_size = embedding.embedding_dim
    padding = create_var(torch.zeros(hidden_size), False)

    for node_x in nodes:
        x_idx.append(node_x.wid)
        nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NB - len(nei)
        nei.extend([padding] * pad_len)
        h_nei.extend(nei)

    h_nei = torch.cat(h_nei, dim=0).view(-1, MAX_NB, hidden_size)
    sum_h_nei = h_nei.sum(dim=1)
    x_vec = create_var(torch.LongTensor(x_idx))
    x_vec = embedding(x_vec)
    node_vec = torch.cat([x_vec, sum_h_nei], dim=1)
    return nn.ReLU()(W(node_vec))
Example #25
 def recon_eval(self, smiles):
     mol_tree = MolTree(smiles)
     mol_tree.recover()
     _,tree_vec,mol_vec = self.encode([mol_tree])
     
     tree_mean = self.T_mean(tree_vec)
     tree_log_var = -torch.abs(self.T_var(tree_vec)) #Following Mueller et al.
     mol_mean = self.G_mean(mol_vec)
     mol_log_var = -torch.abs(self.G_var(mol_vec)) #Following Mueller et al.
     
     all_smiles = []
     for i in range(10):
         epsilon = create_var(torch.randn(1, (int)(self.latent_size / 2)), False)
         tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
         epsilon = create_var(torch.randn(1, (int)(self.latent_size / 2)), False)
         mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
         for j in range(10):
             new_smiles = self.decode(tree_vec, mol_vec, prob_decode=True)
             all_smiles.append(new_smiles)
     return all_smiles
Example #26
    def reconstruct(self, smiles, prob_decode=False, DataFrame=None):
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        #print("tree olusturuldu")
        _,tree_vec,mol_vec = self.encode([mol_tree])
        #print("encode edildi")
        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec)) #Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec)) #Following Mueller et al.

        epsilon = create_var(torch.randn(1, (int)(self.latent_size / 2)), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(1, (int)(self.latent_size / 2)), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
        thethird = torch.cat((tree_vec, mol_vec), 1)
        DataFrame.loc[smiles] = thethird.to('cpu').data.numpy()[0]

        return self.decode(tree_vec, mol_vec, prob_decode)
Example #27
    def rsample(self, z_vecs, W_mean, W_var):
        batch_size = z_vecs.size(0)
        z_mean = W_mean(z_vecs)
        z_log_var = -torch.abs(W_var(z_vecs)) #Following Mueller et al.

        # as per Kingma & Welling
        kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

        # reparameterization trick
        epsilon = create_var(torch.randn_like(z_mean))
        z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon

        return z_vecs, kl_loss
Example #28
    def forward(self, node_wid_list, node_child_adjacency_graph,
                node_edge_adjacency_graph, edge_node_adjacency_graph, scope,
                root_scope):
        # list to store embedding vectors for junction-tree nodes
        node_feature_vecs = []

        # padding vector for node features
        node_feature_padding = create_var(torch.zeros(self.hidden_size))

        node_feature_vecs.append(node_feature_padding)

        # put this tensor on the GPU
        node_wid_list = create_var(node_wid_list)

        # obtain embedding vectors for all the junction-tree nodes
        node_embeddings = self.embedding(node_wid_list)

        node_feature_vecs.extend(list(node_embeddings))

        node_feature_matrix = torch.stack(node_feature_vecs, dim=0)

        total_num_edges = edge_node_adjacency_graph.shape[0]

        edge_feature_matrix = torch.zeros(total_num_edges, self.hidden_size)

        # create PyTorch variables
        node_feature_matrix = create_var(node_feature_matrix)
        edge_feature_matrix = create_var(edge_feature_matrix)
        node_child_adjacency_graph = create_var(node_child_adjacency_graph)
        node_edge_adjacency_graph = create_var(node_edge_adjacency_graph)
        edge_node_adjacency_graph = create_var(edge_node_adjacency_graph)

        # implement convolution
        node_layer_input = node_feature_matrix
        edge_layer_input = edge_feature_matrix

        for conv_layer in self.conv_layers:
            # implement forward pass for this convolutional layer
            node_layer_output, edge_layer_output = conv_layer(
                node_layer_input, edge_layer_input, node_child_adjacency_graph,
                node_edge_adjacency_graph, edge_node_adjacency_graph)

            # set the input features for the next convolutional layer
            node_layer_input, edge_layer_input = node_layer_output, edge_layer_output

        # take the root node's output vector as the representation of each junction tree
        tree_vecs = node_layer_output[root_scope]
        return tree_vecs
Example #29
    def gradient_penalty(self, real_vecs, fake_vecs):
        eps = create_var(torch.rand(real_vecs.size(0), 1))
        inter_data = eps * real_vecs + (1 - eps) * fake_vecs
        inter_data = autograd.Variable(inter_data, requires_grad=True)
        inter_score = self.netD(inter_data).squeeze(-1)

        inter_grad = autograd.grad(inter_score,
                                   inter_data,
                                   grad_outputs=torch.ones(
                                       inter_score.size()).cuda(),
                                   create_graph=True,
                                   retain_graph=True,
                                   only_inputs=True)[0]

        inter_norm = inter_grad.norm(2, dim=1)
        inter_gp = ((inter_norm - 1)**2).mean() * self.beta
        #inter_norm = (inter_grad ** 2).sum(dim=1)
        #inter_gp = torch.max(inter_norm - 1, self.zero).mean() * self.beta

        return inter_gp, inter_norm.mean().item()
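For context, here is a self-contained CPU sketch of the same interpolation-based penalty (WGAN-GP style); netD, beta, and the vector sizes below are placeholders, not the original training code.

import torch
import torch.nn as nn
from torch import autograd

netD = nn.Sequential(nn.Linear(56, 128), nn.ReLU(), nn.Linear(128, 1))  # placeholder critic
beta = 10.0                                                             # penalty weight
real_vecs, fake_vecs = torch.randn(32, 56), torch.randn(32, 56)

# interpolate between real and fake samples, then penalize gradient norms away from 1
eps = torch.rand(real_vecs.size(0), 1)
inter_data = (eps * real_vecs + (1 - eps) * fake_vecs).requires_grad_(True)
inter_score = netD(inter_data).squeeze(-1)
inter_grad = autograd.grad(inter_score, inter_data,
                           grad_outputs=torch.ones_like(inter_score),
                           create_graph=True, retain_graph=True,
                           only_inputs=True)[0]
inter_gp = ((inter_grad.norm(2, dim=1) - 1) ** 2).mean() * beta         # gradient penalty term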
Example #30
    def forward(self, h, x, mess_graph):
        mask = torch.ones(h.size(0), 1)
        mask[0] = 0 #first vector is padding
        mask = create_var(mask)
        for it in range(self.depth):
            h_nei = index_select_ND(h, 0, mess_graph)
            sum_h = h_nei.sum(dim=1)
            z_input = torch.cat([x, sum_h], dim=1)
            z = torch.sigmoid(self.W_z(z_input))

            r_1 = self.W_r(x).view(-1, 1, self.hidden_size)
            r_2 = self.U_r(h_nei)
            r = torch.sigmoid(r_1 + r_2)
            
            gated_h = r * h_nei
            sum_gated_h = gated_h.sum(dim=1)
            h_input = torch.cat([x, sum_gated_h], dim=1)
            pre_h = torch.tanh(self.W_h(h_input))
            h = (1.0 - z) * sum_h + z * pre_h
            h = h * mask

        return h
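For reference, the gated update in the loop above corresponds to the following equations, where h_k are the incoming messages selected through mess_graph and [·, ·] denotes concatenation:

\begin{aligned}
s &= \sum_{k} h_k, \qquad z = \sigma\!\left(W_z\,[x,\, s]\right),\\
r_k &= \sigma\!\left(W_r\, x + U_r\, h_k\right),\\
\tilde{h} &= \tanh\!\left(W_h\,[x,\, \textstyle\sum_{k} r_k \odot h_k]\right),\\
h_{\text{new}} &= (1 - z) \odot s + z \odot \tilde{h}.
\end{aligned}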