コード例 #1
0
    def forward(self, fatoms, fbonds, agraph, bgraph, scope,
                tree_message):  #tree_message[0] == vec(0)
        """Graph message passing over atoms/bonds, mean-pooled per molecule.

        Bond messages are initialized from fbonds and refined for
        self.depth - 1 rounds. Tree messages are concatenated in front of
        the bond messages so that agraph/bgraph indices can address both
        pools; this relies on tree_message[0] being the zero padding vector.
        Returns one mean-pooled atom-vector per (start, length) span in
        scope, stacked into a batch.
        """
        fatoms = create_var(fatoms)
        fbonds = create_var(fbonds)
        agraph = create_var(agraph)
        bgraph = create_var(bgraph)

        bond_input = self.W_i(fbonds)
        graph_message = F.relu(bond_input)

        for _ in range(self.depth - 1):
            all_messages = torch.cat([tree_message, graph_message], dim=0)
            # Sum incoming neighbor messages; padding rows contribute vec(0).
            nei_sum = index_select_ND(all_messages, 0, bgraph).sum(dim=1)
            graph_message = F.relu(bond_input + self.W_h(nei_sum))

        all_messages = torch.cat([tree_message, graph_message], dim=0)
        atom_nei = index_select_ND(all_messages, 0, agraph).sum(dim=1)
        atom_input = torch.cat([fatoms, atom_nei], dim=1)
        atom_hiddens = F.relu(self.W_o(atom_input))

        # Mean-pool atom hiddens over each molecule's atom span.
        mol_vecs = [
            atom_hiddens.narrow(0, st, le).sum(dim=0) / le
            for st, le in scope
        ]
        return torch.stack(mol_vecs, dim=0)
コード例 #2
0
    def forward(self, fnode, fmess, node_graph, mess_graph, scope):
        """Encode a batch of junction trees.

        Runs GRU message passing over the tree edges, aggregates incoming
        messages onto nodes, then pads each tree's node vectors to the
        batch maximum length. Returns (tree_vecs, messages); the final
        edge messages are returned so a downstream decoder can reuse them.
        """
        fnode = create_var(fnode)
        fmess = create_var(fmess)
        node_graph = create_var(node_graph)
        mess_graph = create_var(mess_graph)
        messages = create_var(
            torch.zeros(mess_graph.size(0), self.hidden_size))

        fnode = self.embedding(fnode)
        fmess = index_select_ND(fnode, 0, fmess)
        messages = self.GRU(messages, fmess, mess_graph)

        incoming = index_select_ND(messages, 0, node_graph).sum(dim=1)
        node_vecs = self.outputNN(torch.cat([fnode, incoming], dim=-1))

        # Pad every tree's node vectors to a common length for stacking.
        max_len = max(le for _, le in scope)
        batch_vecs = [
            F.pad(node_vecs[st:st + le], (0, 0, 0, max_len - le))
            for st, le in scope
        ]

        return torch.stack(batch_vecs, dim=0), messages
コード例 #3
0
    def fuse_noise(self, tree_vecs, mol_vecs):
        """Concatenate per-example Gaussian noise onto tree and mol vectors.

        One noise vector of size rand_size_half is drawn per example and
        broadcast (expanded) across the sequence dimension, then B_t/B_g
        project the widened vectors back down.
        """
        def draw_noise(vecs):
            # One sample per batch element, shared across positions.
            eps = create_var(
                torch.randn(vecs.size(0), 1, self.rand_size_half))
            return eps.expand(-1, vecs.size(1), -1)

        # Tree noise is drawn before mol noise, matching the original
        # RNG consumption order.
        tree_vecs = torch.cat([tree_vecs, draw_noise(tree_vecs)], dim=-1)
        mol_vecs = torch.cat([mol_vecs, draw_noise(mol_vecs)], dim=-1)
        return self.B_t(tree_vecs), self.B_g(mol_vecs)
コード例 #4
0
 def rsample(self, z_vecs, W_mean, W_var):
     """Reparameterized Gaussian sample plus KL penalty.

     The log-variance is clamped non-positive via -|W_var(z)| (so the
     variance never exceeds 1), following Mueller et al. The KL term is
     averaged over all elements of the batch.
     """
     z_mean = W_mean(z_vecs)
     z_log_var = -torch.abs(W_var(z_vecs))  #Following Mueller et al.
     kl_terms = 1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)
     kl_loss = -0.5 * torch.mean(kl_terms)
     noise = create_var(torch.randn_like(z_mean))
     sampled = z_mean + torch.exp(z_log_var / 2) * noise
     return sampled, kl_loss
コード例 #5
0
    def fuse_pair(self, x_tree_vecs, x_mol_vecs, y_tree_vecs, y_mol_vecs,
                  jtenc_scope, mpn_scope):
        """Encode the (y - x) difference as a latent code and fuse it into x.

        The mean difference of tree/mol vectors (normalized by the number
        of nodes/atoms from the scopes) is sampled through rsample, then
        broadcast along x's sequence dimension, concatenated onto x, and
        projected by B_t/B_g. Returns the two fused tensors and the total
        KL loss.
        """
        tree_sizes = create_var(torch.Tensor([le for _, le in jtenc_scope]))
        diff_tree_vecs = (y_tree_vecs.sum(dim=1) -
                          x_tree_vecs.sum(dim=1)) / tree_sizes.unsqueeze(-1)

        mol_sizes = create_var(torch.Tensor([le for _, le in mpn_scope]))
        diff_mol_vecs = (y_mol_vecs.sum(dim=1) -
                         x_mol_vecs.sum(dim=1)) / mol_sizes.unsqueeze(-1)

        # Sample the latent difference codes (tree first, then mol).
        diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean,
                                               self.T_var)
        diff_mol_vecs, mol_kl = self.rsample(diff_mol_vecs, self.G_mean,
                                             self.G_var)

        # Broadcast each code across the corresponding sequence dimension.
        tree_code = diff_tree_vecs.unsqueeze(1).expand(
            -1, x_tree_vecs.size(1), -1)
        mol_code = diff_mol_vecs.unsqueeze(1).expand(
            -1, x_mol_vecs.size(1), -1)
        x_tree_vecs = torch.cat([x_tree_vecs, tree_code], dim=-1)
        x_mol_vecs = torch.cat([x_mol_vecs, mol_code], dim=-1)

        return self.B_t(x_tree_vecs), self.B_g(x_mol_vecs), tree_kl + mol_kl
コード例 #6
0
    def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, y_tree_mess):
        """Score candidate attachment configurations and return
        (assembly loss, accuracy).

        Candidates are embedded by self.jtmpn; each candidate is scored
        by a dot product against its molecule's (A_assm-projected) vector.
        The flat score tensor is consumed in order via the running offset
        `tot`, so the candidate ordering produced by jtmpn_holder must
        match the per-node iteration order below — do not reorder.
        """
        jtmpn_holder, batch_idx = jtmpn_holder
        fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
        # batch_idx maps each candidate to its molecule in the batch.
        batch_idx = create_var(batch_idx)

        cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope,
                               y_tree_mess)

        x_mol_vecs = x_mol_vecs.sum(dim=1)  #average pooling?
        # Replicate each molecule vector once per candidate.
        x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
        x_mol_vecs = self.A_assm(x_mol_vecs)  #bilinear
        # Per-candidate dot-product score; squeeze() collapses the
        # singleton dims (NOTE(review): with exactly one candidate total
        # this squeezes to a 0-d tensor — presumably never hits narrow()
        # below in practice; confirm against the data pipeline).
        scores = torch.bmm(x_mol_vecs.unsqueeze(1),
                           cand_vecs.unsqueeze(-1)).squeeze()

        # cnt: #scored nodes, tot: running offset into `scores`,
        # acc: #nodes whose true candidate attains the max score.
        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(mol_batch):
            # Only ambiguous (multi-candidate) non-leaf nodes are scored.
            comp_nodes = [
                node for node in mol_tree.nodes
                if len(node.cands) > 1 and not node.is_leaf
            ]
            cnt += len(comp_nodes)
            for node in comp_nodes:
                label = node.cands.index(node.label)
                ncand = len(node.cands)
                cur_score = scores.narrow(0, tot, ncand)
                tot += ncand

                # Correct iff the true candidate's score is the maximum
                # (ties count as correct).
                if cur_score.data[label] >= cur_score.max().item():
                    acc += 1

                label = create_var(torch.LongTensor([label]))
                all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

        # NOTE(review): assumes at least one scored node in the batch
        # (cnt > 0), otherwise this divides by zero.
        all_loss = sum(all_loss) / len(mol_batch)
        return all_loss, acc * 1.0 / cnt
コード例 #7
0
    def gradient_penalty(self, real_vecs, fake_vecs):
        """WGAN-GP gradient penalty on random interpolates of real/fake.

        Draws one interpolation coefficient per example, scores the
        interpolates with the discriminator, and penalizes the squared
        deviation of the gradient norm from 1, scaled by self.beta.

        Returns:
            (penalty tensor, mean gradient norm as a Python float).
        """
        eps = create_var(torch.rand(real_vecs.size(0), 1))
        inter_data = eps * real_vecs + (1 - eps) * fake_vecs
        # Detach from the generator graph and make a fresh leaf so the
        # penalty differentiates only through the discriminator
        # (modern replacement for the deprecated autograd.Variable).
        inter_data = inter_data.detach().requires_grad_(True)
        inter_score = self.netD(inter_data).squeeze(-1)

        inter_grad = autograd.grad(
            inter_score,
            inter_data,
            # ones_like keeps device/dtype consistent with inter_score;
            # the previous hard-coded .cuda() crashed on CPU-only runs.
            grad_outputs=torch.ones_like(inter_score),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]

        inter_norm = inter_grad.norm(2, dim=1)
        inter_gp = ((inter_norm - 1)**2).mean() * self.beta

        return inter_gp, inter_norm.mean().item()
コード例 #8
0
    def forward(self, h, x, mess_graph):
        """Run self.depth rounds of GRU-style message updates.

        Row 0 of h is the padding message and is re-zeroed after every
        round via the mask.
        """
        mask = torch.ones(h.size(0), 1)
        mask[0] = 0  #first vector is padding
        mask = create_var(mask)

        for _ in range(self.depth):
            h_nei = index_select_ND(h, 0, mess_graph)
            sum_h = h_nei.sum(dim=1)

            # Update gate.
            z = torch.sigmoid(self.W_z(torch.cat([x, sum_h], dim=1)))

            # Reset gate, applied per-neighbor before aggregation.
            r = torch.sigmoid(
                self.W_r(x).view(-1, 1, self.hidden_size) + self.U_r(h_nei))
            sum_gated_h = (r * h_nei).sum(dim=1)

            # Candidate state, blended with the neighbor sum by z.
            pre_h = torch.tanh(
                self.W_h(torch.cat([x, sum_gated_h], dim=1)))
            h = ((1.0 - z) * sum_h + z * pre_h) * mask

        return h
コード例 #9
0
    def forward(self, holder, depth):
        """Run `depth` rounds of GRU message passing and return root vectors.

        holder is (fnode, fmess, node_graph, mess_graph, scope); one root
        vector (the node at each scope's start offset) is returned per tree.
        """
        fnode, fmess, node_graph, mess_graph = (
            create_var(holder[0]),
            create_var(holder[1]),
            create_var(holder[2]),
            create_var(holder[3]),
        )
        scope = holder[4]

        fnode = self.embedding(fnode)
        x = index_select_ND(fnode, 0, fmess)
        h = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

        mask = torch.ones(h.size(0), 1)
        mask[0] = 0  #first vector is padding
        mask = create_var(mask)

        for _ in range(depth):
            h_nei = index_select_ND(h, 0, mess_graph)
            # Re-zero the padding row after every update.
            h = GRU(x, h_nei, self.W_z, self.W_r, self.U_r, self.W_h) * mask

        mess_nei = index_select_ND(h, 0, node_graph)
        node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        return torch.stack([node_vecs[st] for st, _ in scope], dim=0)
コード例 #10
0
    def forward(self, mol_batch, x_tree_vecs, x_mol_vecs):
        """Teacher-forced tree decoding over a batch of junction trees.

        Replays each tree's DFS trace step by step, accumulating hidden
        states for next-clique ("word") prediction and stop prediction,
        scores both through self.attention, and returns
        (pred_loss, stop_loss, pred_acc, stop_acc).

        Side effect: rebuilds node.neighbors on every tree in mol_batch.
        """
        pred_hiddens,pred_contexts,pred_targets = [],[],[]
        stop_hiddens,stop_contexts,stop_targets = [],[],[]
        traces = []
        for mol_tree in mol_batch:
            s = []
            # s becomes a list of (node_x, node_y, direction) trace steps.
            dfs(s, mol_tree.nodes[0], -1)
            traces.append(s)
            for node in mol_tree.nodes:
                # Neighbor lists are rebuilt incrementally during replay.
                node.neighbors = []

        #Predict Root
        batch_size = len(mol_batch)
        # Root prediction is conditioned on a zero hidden state.
        pred_hiddens.append(create_var(torch.zeros(len(mol_batch),self.hidden_size)))
        pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
        pred_contexts.append( create_var( torch.LongTensor(range(batch_size)) ) )

        max_iter = max([len(tr) for tr in traces])
        padding = create_var(torch.zeros(self.hidden_size), False)
        # h maps directed edge (src_idx, dst_idx) -> message vector.
        h = {}

        for t in range(max_iter):
            # Gather the t-th trace step of every tree that is still active.
            prop_list = []
            batch_list = []
            for i,plist in enumerate(traces):
                if t < len(plist):
                    prop_list.append(plist[t])
                    batch_list.append(i)

            cur_x = []
            cur_h_nei,cur_o_nei = [],[]

            for node_x, real_y, _ in prop_list:
                #Neighbors for message passing (target not included)
                cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
                pad_len = MAX_NB - len(cur_nei)
                cur_h_nei.extend(cur_nei)
                cur_h_nei.extend([padding] * pad_len)

                #Neighbors for stop prediction (all neighbors)
                cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors]
                pad_len = MAX_NB - len(cur_nei)
                cur_o_nei.extend(cur_nei)
                cur_o_nei.extend([padding] * pad_len)

                #Current clique embedding
                cur_x.append(node_x.wid)

            #Clique embedding
            cur_x = create_var(torch.LongTensor(cur_x))
            cur_x = self.embedding(cur_x)

            #Message passing
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1,MAX_NB,self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)

            #Node Aggregate
            cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1,MAX_NB,self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            #Gather targets
            pred_target,pred_list = [],[]
            stop_target = []
            for i,m in enumerate(prop_list):
                node_x,node_y,direction = m
                x,y = node_x.idx,node_y.idx
                # Store the freshly computed message and reveal node_x to
                # node_y for subsequent steps.
                h[(x,y)] = new_h[i]
                node_y.neighbors.append(node_x)
                # direction == 1 steps additionally predict the next
                # clique label (node_y.wid); all steps get a stop target.
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(i)
                stop_target.append(direction)

            #Hidden states for stop prediction
            cur_batch = create_var(torch.LongTensor(batch_list))
            stop_hidden = torch.cat([cur_x,cur_o], dim=1)
            stop_hiddens.append( stop_hidden )
            stop_contexts.append( cur_batch )
            stop_targets.extend( stop_target )

            #Hidden states for clique prediction
            if len(pred_list) > 0:
                batch_list = [batch_list[i] for i in pred_list]
                cur_batch = create_var(torch.LongTensor(batch_list))
                pred_contexts.append( cur_batch )

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append( new_h.index_select(0, cur_pred) )
                pred_targets.extend( pred_target )

        #Last stop at root
        # Every tree ends with one extra stop decision (target 0) at its root.
        cur_x,cur_o_nei = [],[]
        for mol_tree in mol_batch:
            node_x = mol_tree.nodes[0]
            cur_x.append(node_x.wid)
            cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = create_var(torch.LongTensor(cur_x))
        cur_x = self.embedding(cur_x)
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1,MAX_NB,self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x,cur_o], dim=1)
        stop_hiddens.append( stop_hidden )
        stop_contexts.append( create_var( torch.LongTensor(range(batch_size)) ) )
        stop_targets.extend( [0] * len(mol_batch) )

        #Predict next clique
        pred_contexts = torch.cat(pred_contexts, dim=0)
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_scores = self.attention(pred_hiddens, pred_contexts, x_tree_vecs, x_mol_vecs, 'word')
        pred_targets = create_var(torch.LongTensor(pred_targets))

        pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
        _,preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        #Predict stop
        stop_contexts = torch.cat(stop_contexts, dim=0)
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_hiddens = F.relu( self.U_i(stop_hiddens) )
        stop_scores = self.attention(stop_hiddens, stop_contexts, x_tree_vecs, x_mol_vecs, 'stop')
        stop_scores = stop_scores.squeeze(-1)
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
        # A non-negative stop score is interpreted as "continue" (1).
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
コード例 #11
0
    def soft_decode(self, x_tree_vecs, x_mol_vecs, gumbel, slope, temp):
        """Differentiable ("soft") tree decoding for a single example.

        Node labels are kept as probability vectors (softmax or Gumbel
        softmax) instead of hard choices, and the stop decision uses a
        straight-through estimator (hard forward value, soft hardtanh
        gradient). Returns (root hidden state concat, all decoded nodes).
        """
        assert x_tree_vecs.size(0) == 1

        # Soft embedding: probability-weighted mix of embedding rows.
        soft_embedding = lambda x: x.matmul(self.embedding.weight)
        if gumbel:
            sample_softmax = lambda x: F.gumbel_softmax(x, tau=temp)
        else:
            sample_softmax = lambda x: F.softmax(x / temp, dim=1)

        stack = []
        init_hiddens = create_var( torch.zeros(1, self.hidden_size) )
        zero_pad = create_var(torch.zeros(1,1,self.hidden_size))
        contexts = create_var( torch.LongTensor(1).zero_() )

        #Root Prediction
        root_score = self.attention(init_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'word')
        root_prob = sample_softmax(root_score)

        root = MolTreeNode("")
        root.embedding = soft_embedding(root_prob)
        root.prob = root_prob
        root.idx = 0
        stack.append(root)

        all_nodes = [root]
        all_hiddens = []
        # h maps directed edge (src_idx, dst_idx) -> message vector.
        h = {}
        for step in range(MAX_SOFT_DECODE_LEN):
            node_x = stack[-1]
            cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size)
            else:
                cur_h_nei = zero_pad

            #Predict stop
            cur_x = node_x.embedding
            cur_h = cur_h_nei.sum(dim=1)
            stop_hiddens = torch.cat([cur_x,cur_h], dim=1)
            stop_hiddens = F.relu( self.U_i(stop_hiddens) )
            stop_score = self.attention(stop_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'stop')
            all_hiddens.append( stop_hiddens )

            # Hard decision in the forward pass; stop_val_ste keeps the
            # soft hardtanh gradient (straight-through estimator).
            forward = 0 if stop_score.item() < 0 else 1
            stop_prob = F.hardtanh(slope * stop_score + 0.5, min_val=0, max_val=1).unsqueeze(1)
            stop_val_ste = forward + stop_prob - stop_prob.detach()

            if forward == 1: #Forward: Predict next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
                pred_score = self.attention(new_h, contexts, x_tree_vecs, x_mol_vecs, 'word')
                pred_prob = sample_softmax(pred_score)

                node_y = MolTreeNode("")
                node_y.embedding = soft_embedding(pred_prob)
                node_y.prob = pred_prob
                node_y.idx = len(all_nodes)
                node_y.neighbors.append(node_x)

                # Message to the new child, gated by the soft stop value.
                h[(node_x.idx,node_y.idx)] = new_h[0] * stop_val_ste
                stack.append(node_y)
                all_nodes.append(node_y)
            else:
                if len(stack) == 1: #At root, terminate
                    return torch.cat([cur_x,cur_h], dim=1), all_nodes

                # Backtrack: send a message to the parent, excluding the
                # parent's own incoming message.
                node_fa = stack[-2]
                cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
                # Gated by (1 - stop value): complement of the forward case.
                h[(node_x.idx,node_fa.idx)] = new_h[0] * (1.0 - stop_val_ste)
                node_fa.neighbors.append(node_x)
                stack.pop()

        #Failure mode: decoding unfinished
        # Step budget exhausted: synthesize the root readout from whatever
        # messages the root has received so far.
        cur_h_nei = [ h[(node_y.idx,root.idx)] for node_y in root.neighbors ]
        if len(cur_h_nei) > 0:
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size)
        else:
            cur_h_nei = zero_pad
        cur_h = cur_h_nei.sum(dim=1)

        stop_hiddens = torch.cat([root.embedding, cur_h], dim=1)
        stop_hiddens = F.relu( self.U_i(stop_hiddens) )
        all_hiddens.append( stop_hiddens )

        return torch.cat([root.embedding,cur_h], dim=1), all_nodes
コード例 #12
0
    def decode(self, x_tree_vecs, x_mol_vecs):
        """Greedy hard tree decoding for a single example.

        Grows a junction tree with a DFS stack: at each step a stop score
        decides between expanding a new child clique or backtracking to
        the parent. Candidate cliques are tried in score order and must
        pass have_slots/can_assemble chemical-validity checks.
        Returns (root node, all decoded nodes).
        """
        assert x_tree_vecs.size(0) == 1

        stack = []
        init_hiddens = create_var( torch.zeros(1, self.hidden_size) )
        zero_pad = create_var(torch.zeros(1,1,self.hidden_size))
        contexts = create_var( torch.LongTensor(1).zero_() )

        #Root Prediction
        root_score = self.attention(init_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'word')
        _,root_wid = torch.max(root_score, dim=1)
        root_wid = root_wid.item()

        root = MolTreeNode(self.vocab.get_smiles(root_wid))
        root.wid = root_wid
        root.idx = 0
        stack.append( (root, self.vocab.get_slots(root.wid)) )

        all_nodes = [root]
        # h maps directed edge (src_idx, dst_idx) -> message vector.
        h = {}
        for step in range(MAX_DECODE_LEN):
            node_x,fa_slot = stack[-1]
            cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors ]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size)
            else:
                cur_h_nei = zero_pad

            cur_x = create_var(torch.LongTensor([node_x.wid]))
            cur_x = self.embedding(cur_x)

            #Predict stop
            cur_h = cur_h_nei.sum(dim=1)
            stop_hiddens = torch.cat([cur_x,cur_h], dim=1)
            stop_hiddens = F.relu( self.U_i(stop_hiddens) )
            stop_score = self.attention(stop_hiddens, contexts, x_tree_vecs, x_mol_vecs, 'stop')

            # Negative stop score means "stop expanding here" -> backtrack.
            backtrack = (stop_score.item() < 0)

            if not backtrack: #Forward: Predict next clique
                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
                pred_score = self.attention(new_h, contexts, x_tree_vecs, x_mol_vecs, 'word')

                _,sort_wid = torch.sort(pred_score, dim=1, descending=True)
                sort_wid = sort_wid.data.squeeze()

                # Try the top-5 cliques; accept the first chemically
                # compatible one.
                next_wid = None
                for wid in sort_wid[:5]:
                    slots = self.vocab.get_slots(wid)
                    node_y = MolTreeNode(self.vocab.get_smiles(wid))
                    if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
                        next_wid = wid
                        next_slots = slots
                        break

                if next_wid is None:
                    backtrack = True #No more children can be added
                else:
                    node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
                    node_y.wid = next_wid
                    node_y.idx = len(all_nodes)
                    node_y.neighbors.append(node_x)
                    h[(node_x.idx,node_y.idx)] = new_h[0]
                    stack.append( (node_y,next_slots) )
                    all_nodes.append(node_y)

            if backtrack: #Backtrack, use if instead of else
                if len(stack) == 1:
                    break #At root, terminate

                # Message to the parent, excluding the parent's own
                # incoming message.
                node_fa,_ = stack[-2]
                cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx ]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
                h[(node_x.idx,node_fa.idx)] = new_h[0]
                node_fa.neighbors.append(node_x)
                stack.pop()

        return root, all_nodes