Beispiel #1
0
    def stereo(self, mol_batch, mol_vec):
        """Compute the stereoisomer-selection loss and accuracy.

        Each stereo candidate graph is encoded and scored by cosine
        similarity against its molecule's latent vector; a per-molecule
        cross-entropy over the candidate scores forms the loss.

        Args:
            mol_batch: dict with keys 'stereo_cand_graph_batch' (batched
                candidate graphs), 'stereo_cand_batch_idx' (molecule index
                per candidate), 'stereo_cand_labels' (index of the correct
                candidate per molecule), 'stereo_cand_lengths' (number of
                candidates per molecule).
            mol_vec: latent vectors for the molecules in the batch.

        Returns:
            (loss, accuracy); (0, 1.0) when no molecule has more than one
            stereoisomer.
        """
        stereo_cands = mol_batch['stereo_cand_graph_batch']
        batch_idx = mol_batch['stereo_cand_batch_idx']
        labels = mol_batch['stereo_cand_labels']
        lengths = mol_batch['stereo_cand_lengths']

        if not labels:
            # Only one stereoisomer exists; do nothing
            return cuda(torch.tensor(0.)), 1.

        batch_idx = cuda(torch.LongTensor(batch_idx))
        stereo_cands = self.mpn(stereo_cands)
        stereo_cands = self.G_mean(stereo_cands)
        # Repeat each molecule's vector once per candidate.
        stereo_labels = mol_vec[batch_idx]
        scores = F.cosine_similarity(stereo_cands, stereo_labels)

        st, acc = 0, 0
        all_loss = []
        for label, le in zip(labels, lengths):
            cur_scores = scores[st:st + le]
            # Count as correct when the true candidate ties or beats the max.
            # (Direct indexing replaces the deprecated `.data[...]` access.)
            if cur_scores[label].item() >= cur_scores.max().item():
                acc += 1
            label = cuda(torch.LongTensor([label]))
            all_loss.append(
                # reduction='sum' replaces the deprecated size_average=False.
                F.cross_entropy(cur_scores.view(1, -1), label, reduction='sum'))
            st += le

        all_loss = sum(all_loss) / len(labels)
        return all_loss, acc / len(labels)
Beispiel #2
0
 def forward(self, nodes):
     """Node update: h = ReLU(W([x || m])).

     Concatenates node features 'x' with accumulated messages 'm' and
     applies a linear layer followed by ReLU.  When the message field
     does not exist yet, a zero message is substituted.
     """
     x = nodes.data['x']
     try:
         m = nodes.data['m']
     except KeyError:
         # Narrowed from a bare `except:` — only a missing 'm' field is
         # expected here; anything else should propagate.
         m = torch.cuda.FloatTensor(1, self.hidden_size).fill_(0)
     return {
         'h': cuda(torch.relu(self.W(cuda(torch.cat([x, m], 1))))),
     }
Beispiel #3
0
    def decode(self, tree_vec, mol_vec):
        """Decode latent vectors into a SMILES string, or None on failure.

        Pipeline: decode a junction tree from ``tree_vec``, re-encode the
        effective subtree to obtain tree messages, assemble the molecular
        graph by DFS, sanitize it, then pick the stereoisomer whose
        encoding is most similar to ``mol_vec``.
        """
        mol_tree, nodes_dict, effective_nodes = self.decoder.decode(tree_vec)
        effective_nodes_list = effective_nodes.tolist()
        # Keep node dicts only for nodes the decoder did not mark as failed.
        nodes_dict = [nodes_dict[v] for v in effective_nodes_list]

        for i, (node_id, node) in enumerate(zip(effective_nodes_list, nodes_dict)):
            node['idx'] = i
            node['nid'] = i + 1
            node['is_leaf'] = True
            if mol_tree.in_degree(node_id) > 1:
                # More than one incoming edge => internal node; tag its atoms
                # with the node id so attachment points can be traced later.
                node['is_leaf'] = False
                set_atommap(node['mol'], node['nid'])

        # Re-run tree message passing on the subgraph of effective nodes.
        mol_tree_sg = mol_tree.subgraph(effective_nodes)
        mol_tree_sg.copy_from_parent()
        mol_tree_msg, _ = self.jtnn([mol_tree_sg])
        mol_tree_msg = unbatch(mol_tree_msg)[0]
        mol_tree_msg.nodes_dict = nodes_dict

        # Start assembly from the root clique (node 0).
        cur_mol = copy_edit_mol(nodes_dict[0]['mol'])
        # global_amap[nid] maps clique-local atom index -> atom index in cur_mol.
        global_amap = [{}] + [{} for node in nodes_dict]
        global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}

        cur_mol = self.dfs_assemble(mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None)
        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        # Round-trip through SMILES to sanitize; None means an invalid molecule.
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        if cur_mol is None:
            return None

        # Enumerate stereoisomers; if only one exists there is nothing to rank.
        smiles2D = Chem.MolToSmiles(cur_mol)
        stereo_cands = decode_stereo(smiles2D)
        if len(stereo_cands) == 1:
            return stereo_cands[0]
        stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
        stereo_cand_graphs, atom_x, bond_x = \
                zip(*stereo_graphs)
        stereo_cand_graphs = batch(stereo_cand_graphs)
        atom_x = cuda(torch.cat(atom_x))
        bond_x = cuda(torch.cat(bond_x))
        stereo_cand_graphs.ndata['x'] = atom_x
        stereo_cand_graphs.edata['x'] = bond_x
        stereo_cand_graphs.edata['src_x'] = atom_x.new(
                bond_x.shape[0], atom_x.shape[1]).zero_()
        # Score candidates against the molecule vector; return the best one.
        stereo_vecs = self.mpn(stereo_cand_graphs)
        stereo_vecs = self.G_mean(stereo_vecs)
        scores = F.cosine_similarity(stereo_vecs, mol_vec)
        _, max_id = scores.max(0)
        return stereo_cands[max_id.item()]
Beispiel #4
0
    def sample(self, tree_vec, mol_vec, e1=None, e2=None):
        """Reparameterized sampling of the tree and molecule latents.

        Computes posterior means and (negative-abs) log-variances, draws
        Gaussian noise unless ``e1``/``e2`` are supplied, and returns the
        sampled vectors plus the concatenated mean/log-variance for the
        KL term.
        """
        tree_mean = cuda(self.T_mean(tree_vec))
        tree_log_var = -torch.abs(self.T_var(tree_vec))
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))

        # z = mu + sigma * eps, with sigma = exp(log_var / 2).
        eps_tree = e1 if e1 is not None else cuda(torch.randn(*tree_mean.shape))
        eps_mol = e2 if e2 is not None else cuda(torch.randn(*mol_mean.shape))
        tree_sample = tree_mean + torch.exp(tree_log_var / 2) * eps_tree
        mol_sample = mol_mean + torch.exp(mol_log_var / 2) * eps_mol

        z_mean = torch.cat([tree_mean, mol_mean], 1)
        z_log_var = torch.cat([tree_log_var, mol_log_var], 1)
        return tree_sample, mol_sample, z_mean, z_log_var
Beispiel #5
0
    def run(self, cand_graphs, cand_line_graph, tree_mess_src_edges,
            tree_mess_tgt_edges, tree_mess_tgt_nodes, mol_tree_batch):
        """Run message passing over candidate graphs.

        Edge states are initialized from bond/source-node features, seeded
        with junction-tree messages ('alpha'), refined for ``depth - 1``
        rounds of loopy BP on the line graph, then gathered into node
        representations on ``cand_graphs``.
        """
        n_nodes = cand_graphs.number_of_nodes()

        # Copy each edge's source-node features into 'src_x'.
        cand_graphs.apply_edges(func=lambda edges: {'src_x': edges.src['x']}, )

        # Line-graph nodes correspond to edges of cand_graphs.
        get_bond_features = cand_line_graph.ndata['x']
        source_features = cand_line_graph.ndata['src_x']
        features = torch.cat([source_features, get_bond_features], 1)
        msg_input = self.W_i(features)
        cand_line_graph.ndata.update({
            'msg_input': msg_input,
            'msg': torch.relu(msg_input),
            'accum_msg': torch.zeros_like(msg_input),
        })
        zero_node_state = get_bond_features.new(n_nodes,
                                                self.hidden_size).zero_()
        cand_graphs.ndata.update({
            'm': zero_node_state.clone(),
            'h': zero_node_state.clone(),
        })

        cand_graphs.edata['alpha'] = \
                cuda(torch.zeros(cand_graphs.number_of_edges(), self.hidden_size))
        cand_graphs.ndata['alpha'] = zero_node_state
        if tree_mess_src_edges.shape[0] > 0:
            # Inject messages computed on mol_tree_batch into the candidate
            # graphs.  PAPER is a module-level flag selecting the variant —
            # presumably the per-edge formulation from the JTNN paper.
            if PAPER:
                # Per-edge injection: copy tree edge messages onto matching
                # candidate edges.
                src_u, src_v = tree_mess_src_edges.unbind(1)
                tgt_u, tgt_v = tree_mess_tgt_edges.unbind(1)
                alpha = mol_tree_batch.edges[src_u, src_v].data['m']
                cand_graphs.edges[tgt_u, tgt_v].data['alpha'] = alpha
            else:
                # Aggregate messages per target node via scatter_add, then
                # broadcast each node's sum to its outgoing edges.
                src_u, src_v = tree_mess_src_edges.unbind(1)
                alpha = mol_tree_batch.edges[src_u, src_v].data['m']
                node_idx = (tree_mess_tgt_nodes.to(
                    device=zero_node_state.device)[:, None].expand_as(alpha))
                node_alpha = zero_node_state.clone().scatter_add(
                    0, node_idx, alpha)
                cand_graphs.ndata['alpha'] = node_alpha
                cand_graphs.apply_edges(
                    func=lambda edges: {'alpha': edges.src['alpha']}, )

        # Loopy belief propagation on the line graph.
        for i in range(self.depth - 1):
            cand_line_graph.update_all(
                mpn_loopy_bp_msg,
                mpn_loopy_bp_reduce,
                self.loopy_bp_updater,
            )

        # Gather final edge messages into node states.
        cand_graphs.update_all(
            mpn_gather_msg,
            mpn_gather_reduce,
            self.gather_updater,
        )

        return cand_graphs
Beispiel #6
0
    def encode(self, mol_batch):
        """Encode a batch of molecules.

        Runs the graph MPN over the batched molecular graphs and the tree
        encoder over the junction trees, updating running size counters.

        Returns:
            (mol_tree_batch, tree_vec, mol_vec)
        """
        graphs = mol_batch['mol_graph_batch']
        trees = mol_batch['mol_trees']

        mol_vec = cuda(self.mpn(graphs))
        mol_tree_batch, tree_vec = self.jtnn(trees)

        # Bookkeeping: accumulate graph/tree sizes across forward passes.
        self.n_nodes_total += graphs.number_of_nodes()
        self.n_edges_total += graphs.number_of_edges()
        self.n_tree_nodes_total += sum(t.number_of_nodes() for t in trees)
        self.n_passes += 1

        return mol_tree_batch, tree_vec, mol_vec
Beispiel #7
0
    def assm(self, mol_batch, mol_tree_batch, mol_vec):
        """Candidate-assembly loss and accuracy.

        Scores every candidate attachment configuration with the junction
        tree MPN (dot product against the molecule latent vector), then
        accumulates a cross-entropy loss per non-leaf tree node that has
        more than one candidate.

        Args:
            mol_batch: batch dict providing candidate graphs, candidate
                batch indices, tree-message index tensors and mol trees.
            mol_tree_batch: batched junction tree carrying messages.
            mol_vec: molecule latent vectors.

        Returns:
            (loss averaged over trees, accuracy over scored nodes);
            (0, 1.0) when no node requires an assembly decision.
        """
        cands = [mol_batch['cand_graph_batch'],
                 mol_batch['tree_mess_src_e'],
                 mol_batch['tree_mess_tgt_e'],
                 mol_batch['tree_mess_tgt_n']]
        cand_vec = self.jtmpn(cands, mol_tree_batch)
        cand_vec = self.G_mean(cand_vec)

        # Repeat each molecule's vector once per candidate of that molecule.
        batch_idx = cuda(torch.LongTensor(mol_batch['cand_batch_idx']))
        mol_vec = mol_vec[batch_idx]

        # Score = dot product between molecule and candidate vectors.
        mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
        cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
        scores = (mol_vec @ cand_vec)[:, 0, 0]

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(mol_batch['mol_trees']):
            comp_nodes = [node_id for node_id, node in mol_tree.nodes_dict.items()
                          if len(node['cands']) > 1 and not node['is_leaf']]
            cnt += len(comp_nodes)
            # segmented accuracy and cross entropy
            for node_id in comp_nodes:
                node = mol_tree.nodes_dict[node_id]
                label = node['cands'].index(node['label'])
                ncand = len(node['cands'])
                cur_score = scores[tot:tot + ncand]
                tot += ncand

                if cur_score[label].item() >= cur_score.max().item():
                    acc += 1

                label = cuda(torch.LongTensor([label]))
                all_loss.append(
                    # reduction='sum' replaces deprecated size_average=False.
                    F.cross_entropy(cur_score.view(1, -1), label, reduction='sum'))

        if cnt == 0:
            # No node needed an assembly decision; avoid division by zero
            # (mirrors the early return in stereo()).
            return cuda(torch.tensor(0.)), 1.

        all_loss = sum(all_loss) / len(mol_batch['mol_trees'])
        return all_loss, acc / cnt
Beispiel #8
0
    def run(self, mol_tree_batch, mol_tree_batch_lg):
        """Encode a batched junction tree via level-ordered message passing.

        Initializes node/edge state tensors, propagates edge messages in
        level order starting from each tree's root, gathers messages into
        node states, and returns the updated batch plus root vectors.
        """
        # Since tree roots are designated to 0.  In the batched graph we can
        # simply find the corresponding node ID by looking at node_offset
        node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
        root_ids = node_offset[:-1]
        n_nodes = mol_tree_batch.number_of_nodes()
        n_edges = mol_tree_batch.number_of_edges()

        # Assign structure embeddings to tree nodes
        x = cuda(self.embedding(cuda(mol_tree_batch.ndata['wid'])))
        h = torch.cuda.FloatTensor(n_nodes, self.hidden_size).fill_(0)

        mol_tree_batch.ndata.update({
            'x': x,
            'h': h,
        })

        # Initialize the intermediate variables according to Eq (4)-(8).
        # Also initialize the src_x and dst_x fields.
        # TODO: context?
        mol_tree_batch.edata.update({
            's':
            torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'm':
            torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'r':
            torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'z':
            torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'src_x':
            torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'dst_x':
            torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'rm':
            torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'accum_rm':
            torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
        })

        # Send the source/destination node features to edges
        mol_tree_batch.apply_edges(func=lambda edges: {
            'src_x': edges.src['x'],
            'dst_x': edges.dst['x']
        }, )

        # Message passing
        # I exploited the fact that the reduce function is a sum of incoming
        # messages, and the uncomputed messages are zero vectors.  Essentially,
        # we can always compute s_ij as the sum of incoming m_ij, no matter
        # if m_ij is actually computed or not.
        # NOTE(review): level_order presumably yields edge IDs in BFS levels
        # from the roots — confirm against its definition elsewhere.
        for eid in level_order(mol_tree_batch, root_ids):
            #eid = mol_tree_batch.edge_ids(u, v)
            mol_tree_batch_lg.pull(
                eid,
                enc_tree_msg,
                enc_tree_reduce,
                self.enc_tree_update,
            )

        # Readout
        mol_tree_batch.update_all(
            enc_tree_gather_msg,
            enc_tree_gather_reduce,
            self.enc_tree_gather_update,
        )

        # Root node hidden states serve as the tree-level representation.
        root_vecs = mol_tree_batch.nodes[root_ids].data['h']

        return mol_tree_batch, root_vecs
Beispiel #9
0
    def decode(self, mol_vec):
        """Autoregressively decode one junction tree from a latent vector.

        Builds the tree node by node: at each step a stop predictor decides
        whether to expand the current node or backtrack; on expansion a
        label predictor picks the next clique, rejecting candidates that
        cannot be assembled.  The graph (and its line graph) is rebuilt
        incrementally, with GRU-style messages pulled along new edges.

        Args:
            mol_vec: latent vector, batch size must be 1.

        Returns:
            (mol_tree, all_nodes, effective_nodes) — the decoded tree, the
            dict of all node dicts, and the sorted IDs of nodes that did
            not fail.
        """
        assert mol_vec.shape[0] == 1

        mol_tree = MolTree(None) 

        init_hidden = torch.cuda.FloatTensor(1, self.hidden_size).fill_(0)

        # Predict the root clique label directly from the latent vector.
        root_hidden = torch.cat([init_hidden, mol_vec], 1)
        root_hidden = F.relu(self.W(root_hidden))
        root_score = self.W_o(root_hidden)
        _, root_wid = torch.max(root_score, 1)
        root_wid = root_wid.view(1)

        mol_tree.add_nodes(1)   # root
        mol_tree.nodes[0].data['wid'] = root_wid
        mol_tree.nodes[0].data['x'] = self.embedding(root_wid)
        mol_tree.nodes[0].data['h'] = init_hidden
        mol_tree.nodes[0].data['fail'] = cuda(torch.tensor([0]))
        mol_tree.nodes_dict[0] = root_node_dict = create_node_dict(
                self.vocab.get_smiles(root_wid))

        # stack holds (node_id, slots) for the path from root to the
        # currently expanded node.
        stack, trace = [], []
        stack.append((0, self.vocab.get_slots(root_wid)))

        all_nodes = {0: root_node_dict}
        h = {}
        first = True
        new_node_id = 0
        new_edge_id = 0

        for step in range(MAX_DECODE_LEN):
            u, u_slots = stack[-1]
            udata = mol_tree.nodes[u].data
            x = udata['x']
            h = udata['h']

            # Predict stop
            p_input = torch.cat([x, h, mol_vec], 1)
            p_score = torch.sigmoid(self.U_s(torch.relu(self.U(p_input))))
            backtrack = (p_score.item() < 0.5)

            if not backtrack:
                # Predict next clique.  Note that the prediction may fail due
                # to lack of assemblable components
                mol_tree.add_nodes(1)
                new_node_id += 1
                v = new_node_id
                mol_tree.add_edges(u, v)
                uv = new_edge_id
                new_edge_id += 1
                
                # Edge feature schema must exist before per-edge writes;
                # initialize it once, on the first edge ever added.
                if first:
                    mol_tree.edata.update({
                        's': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                        'm': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                        'r': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                        'z': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                        'src_x': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                        'dst_x': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                        'rm': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                        'accum_rm': torch.cuda.FloatTensor(1, self.hidden_size).fill_(0),
                    })
                    first = False

                mol_tree.edges[uv].data['src_x'] = mol_tree.nodes[u].data['x']
                # keeping dst_x 0 is fine as h on new edge doesn't depend on that.

                # DGL doesn't dynamically maintain a line graph.
                mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True)

                # Compute the new edge's message, then pull it into node v.
                mol_tree_lg.pull(
                    uv,
                    dec_tree_edge_msg,
                    dec_tree_edge_reduce,
                    self.dec_tree_edge_update.update_zm,
                )
                mol_tree.pull(
                    v,
                    dec_tree_node_msg,
                    dec_tree_node_reduce,
                )

                vdata = mol_tree.nodes[v].data
                h_v = vdata['h']
                q_input = torch.cat([h_v, mol_vec], 1)
                q_score = torch.softmax(self.W_o(torch.relu(self.W(q_input))), -1)
                _, sort_wid = torch.sort(q_score, 1, descending=True)
                sort_wid = sort_wid.squeeze()

                # Try the top-5 predicted labels; keep the first that both
                # fits the parent's slots and can actually be assembled.
                next_wid = None
                for wid in sort_wid.tolist()[:5]:
                    slots = self.vocab.get_slots(wid)
                    cand_node_dict = create_node_dict(self.vocab.get_smiles(wid))
                    if (have_slots(u_slots, slots) and can_assemble(mol_tree, u, cand_node_dict)):
                        next_wid = wid
                        next_slots = slots
                        next_node_dict = cand_node_dict
                        break

                if next_wid is None:
                    # Failed adding an actual children; v is a spurious node
                    # and we mark it.
                    vdata['fail'] = cuda(torch.tensor([1]))
                    backtrack = True
                else:
                    next_wid = cuda(torch.tensor([next_wid]))
                    vdata['wid'] = next_wid
                    vdata['x'] = self.embedding(next_wid)
                    mol_tree.nodes_dict[v] = next_node_dict
                    all_nodes[v] = next_node_dict
                    stack.append((v, next_slots))
                    # Add the reverse edge v->u so messages can flow back
                    # during backtracking.
                    mol_tree.add_edge(v, u)
                    vu = new_edge_id
                    new_edge_id += 1
                    mol_tree.edges[uv].data['dst_x'] = mol_tree.nodes[v].data['x']
                    mol_tree.edges[vu].data['src_x'] = mol_tree.nodes[v].data['x']
                    mol_tree.edges[vu].data['dst_x'] = mol_tree.nodes[u].data['x']

                    # DGL doesn't dynamically maintain a line graph.
                    mol_tree_lg = mol_tree.line_graph(backtracking=False, shared=True)
                    mol_tree_lg.apply_nodes(
                        self.dec_tree_edge_update.update_r,
                        uv
                        )

            if backtrack:
                if len(stack) == 1:
                    break   # At root, terminate

                # Send a message from u back to its parent pu, then pop u.
                pu, _ = stack[-2]
                u_pu = mol_tree.edge_id(u, pu)

                mol_tree_lg.pull(
                    u_pu,
                    dec_tree_edge_msg,
                    dec_tree_edge_reduce,
                    self.dec_tree_edge_update,
                )
                mol_tree.pull(
                    pu,
                    dec_tree_node_msg,
                    dec_tree_node_reduce,
                )
                stack.pop()

        # Nodes whose expansion failed are excluded from the result.
        effective_nodes = mol_tree.filter_nodes(lambda nodes: nodes.data['fail'] != 1)
        effective_nodes, _ = torch.sort(effective_nodes)
        return mol_tree, all_nodes, effective_nodes
Beispiel #10
0
    def run(self, mol_tree_batch, mol_tree_batch_lg, n_trees, tree_vec):
        """Teacher-forced tree decoding over a batched junction tree.

        Traverses the batched trees in DFS order, collecting inputs and
        targets for the stop predictor (p) and the label predictor (q),
        then computes both losses and accuracies in one batched pass.

        Args:
            mol_tree_batch: batched junction tree graph.
            mol_tree_batch_lg: its line graph (edge message passing).
            n_trees: number of trees in the batch (loss normalizer).
            tree_vec: per-tree latent vectors conditioning the decoder.

        Returns:
            (q_loss, p_loss, q_acc, p_acc)
        """
        times = []
        times.append((116,time.time()))
        node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
        root_ids = node_offset[:-1]
        n_nodes = mol_tree_batch.number_of_nodes()
        n_edges = mol_tree_batch.number_of_edges()

        times.append((122,time.time()))
        mol_tree_batch.ndata.update({
            'x': self.embedding(mol_tree_batch.ndata['wid']),
            'h': torch.cuda.FloatTensor(n_nodes, self.hidden_size).fill_(0),
            'new': torch.cuda.ByteTensor(n_nodes).fill_(1)  # whether it's newly generated node
        })

        times.append((129,time.time()))
        mol_tree_batch.edata.update({
            's': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'm': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'r': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'z': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'src_x': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'dst_x': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'rm': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0),
            'accum_rm': torch.cuda.FloatTensor(n_edges, self.hidden_size).fill_(0)
        })

        times.append((141,time.time()))
        mol_tree_batch.apply_edges(
            func=lambda edges: {'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
        )

        # input tensors for stop prediction (p) and label prediction (q)
        p_inputs = []
        p_targets = []
        q_inputs = []
        q_targets = []

        times.append((152,time.time()))
        # Predict root
        mol_tree_batch.pull(
            root_ids,
            dec_tree_node_msg,
            dec_tree_node_reduce,
            dec_tree_node_update,
        )
        
        times.append((161,time.time()))
        # Extract hidden states and store them for stop/label prediction
        h = mol_tree_batch.nodes[root_ids].data['h']
        x = mol_tree_batch.nodes[root_ids].data['x']
        p_inputs.append(torch.cat([x, h, tree_vec], 1))
        # If the out degree is 0 we don't generate any edges at all
        root_out_degrees = mol_tree_batch.out_degrees(root_ids).cuda()
        q_inputs.append(torch.cat([h, tree_vec], 1))
        q_targets.append(mol_tree_batch.nodes[root_ids].data['wid'])

        times.append((171,time.time()))
        # Traverse the tree and predict on children
        for eid, p in dfs_order(mol_tree_batch, root_ids):
            u, v = mol_tree_batch.find_edges(eid)
            
            # Stop targets: trees already exhausted (degree counter < 0)
            # are dropped; active trees get target 1 - p.
            p_target_list = torch.cuda.LongTensor(root_out_degrees.shape).fill_(0)
            p_target_list[root_out_degrees > 0] = 1 - p.cuda()
            p_target_list = p_target_list[root_out_degrees >= 0]
            p_targets.append(torch.tensor(p_target_list).cuda())

            root_out_degrees -= (root_out_degrees == 0).long().cuda()
            root_out_degrees -= torch.tensor(np.isin(root_ids, v).astype('int64')).cuda()

            mol_tree_batch_lg.pull(
                eid,
                dec_tree_edge_msg,
                dec_tree_edge_reduce,
                self.dec_tree_edge_update,
            )
            is_new = mol_tree_batch.nodes[v].data['new']
            mol_tree_batch.pull(
                v,
                dec_tree_node_msg,
                dec_tree_node_reduce,
                dec_tree_node_update,
            )
            # Extract
            n_repr = mol_tree_batch.nodes[v].data
            h = n_repr['h']
            x = n_repr['x']
            tree_vec_set = tree_vec[root_out_degrees >= 0]
            wid = n_repr['wid']
            p_inputs.append(torch.cat([x, h, tree_vec_set], 1))
            # Only newly generated nodes are needed for label prediction
            # NOTE: The following works since the uncomputed messages are zeros.

            q_input = torch.cat([h, tree_vec_set], 1)[is_new]
            q_target = wid[is_new]
            if q_input.shape[0] > 0:
                q_inputs.append(q_input)
                q_targets.append(q_target)
        p_targets.append(torch.zeros((root_out_degrees == 0).sum()).long().cuda())

        times.append((214,time.time()))
        # Batch compute the stop/label prediction losses
        p_inputs = torch.cat(p_inputs, 0)
        p_targets = cuda(torch.cat(p_targets, 0))
        q_inputs = torch.cat(q_inputs, 0)
        q_targets = torch.cat(q_targets, 0)

        times.append((221,time.time()))
        q = self.W_o(torch.relu(self.W(q_inputs)))
        p = self.U_s(torch.relu(self.U(p_inputs)))[:, 0]

        times.append((225,time.time()))
        # reduction='sum' replaces the deprecated size_average=False kwarg
        # (removed in modern PyTorch); the math is identical.
        p_loss = F.binary_cross_entropy_with_logits(
            p, p_targets.float(), reduction='sum'
        ) / n_trees
        q_loss = F.cross_entropy(q, q_targets, reduction='sum') / n_trees
        p_acc = ((p > 0).long() == p_targets).sum().float() / p_targets.shape[0]
        q_acc = (q.max(1)[1] == q_targets).float().sum() / q_targets.shape[0]

        times.append((233,time.time()))
        # Keep intermediates around for external inspection/debugging.
        self.q_inputs = q_inputs
        self.q_targets = q_targets
        self.q = q
        self.p_inputs =  p_inputs
        self.p_targets = p_targets
        self.p = p
        
        #print("Dec Profile:")
        #for i in range(len(times)-1):
        #    print("\t%d: %f" % (times[i][0], 
        #                        times[i+1][1]-times[i][1]))

        return q_loss, p_loss, q_acc, p_acc
Beispiel #11
0
    def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol,
                     global_amap, fa_amap, cur_node_id, fa_node_id):
        """Recursively assemble the molecule for the subtree at cur_node_id.

        Enumerates candidate attachments of the current clique to its
        neighbors, scores them against ``mol_vec``, and tries them in
        descending score order, recursing into non-leaf children and
        backtracking on failure.  Returns the assembled RWMol or ``None``
        when no candidate leads to a valid molecule.
        """
        nodes_dict = mol_tree_msg.nodes_dict
        fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
        cur_node = nodes_dict[cur_node_id]

        fa_nid = fa_node['nid'] if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        # Children of the current node, excluding the parent we came from.
        children_node_id = [v for v in mol_tree_msg.successors(cur_node_id).tolist()
                            if nodes_dict[v]['nid'] != fa_nid]
        children = [nodes_dict[v] for v in children_node_id]
        # Singleton (single-atom) cliques are attached first; the rest are
        # ordered by decreasing size.
        neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
        singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        # Atom-map constraints inherited from the parent's chosen attachment.
        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
        cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None
        cand_smiles, cand_mols, cand_amap = list(zip(*cands))

        # Build DGL graphs for all candidates and score them with the JTMPN.
        cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]
        cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
                tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(
                        cands)
        cand_graphs = batch(cand_graphs)
        atom_x = cuda(atom_x)
        bond_x = cuda(bond_x)
        cand_graphs.ndata['x'] = atom_x
        cand_graphs.edata['x'] = bond_x
        cand_graphs.edata['src_x'] = atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_()

        cand_vecs = self.jtmpn(
                (cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes),
                mol_tree_msg,
                )
        cand_vecs = self.G_mean(cand_vecs)
        mol_vec = mol_vec.squeeze()
        # Dot-product score of each candidate against the molecule vector.
        scores = cand_vecs @ mol_vec

        _, cand_idx = torch.sort(scores, descending=True)

        # Try candidates best-first; restore from backup before each attempt.
        backup_mol = Chem.RWMol(cur_mol)
        for i in range(len(cand_idx)):
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            # Extend the global atom map with this candidate's attachment.
            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node['nid']][ctr_atom]

            cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
            new_mol = cur_mol.GetMol()
            # Round-trip through SMILES to validate the partial molecule.
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None:
                continue

            # Recurse into non-leaf children; any failure rejects this candidate.
            result = True
            for nei_node_id, nei_node in zip(children_node_id, children):
                if nei_node['is_leaf']:
                    continue
                cur_mol = self.dfs_assemble(
                        mol_tree_msg, mol_vec, cur_mol, new_global_amap, pred_amap,
                        nei_node_id, cur_node_id)
                if cur_mol is None:
                    result = False
                    break

            if result:
                return cur_mol

        return None