Example #1
0
    def __init__(self,
                 input_size=None,
                 hidden_size=None,
                 output_size=None,
                 optimizer=optim.Adam,
                 criterion=nn.NLLLoss,
                 lr=0.0001,
                 src_vocab=None,
                 tar_vocab=None,
                 sess='',
                 device='cpu'):
        """Build the seq2tree encoder/decoder pair plus their optimizers.

        Args:
            input_size: source vocabulary size fed to the Encoder.
            hidden_size: hidden-state width shared by encoder and decoder.
            output_size: target vocabulary size produced by the TreeDecoder.
            optimizer: optimizer class, instantiated once per module.
            criterion: loss class, instantiated with no arguments.
            lr: learning rate for both optimizers.
            src_vocab: source-side vocabulary object.
            tar_vocab: target-side vocabulary object.
            sess: session name used for checkpoint filenames.
            device: torch device string; inputs are later created on it.
        """
        self.encoder = Encoder(input_size, hidden_size, device)
        self.decoder = TreeDecoder(hidden_size, output_size, device)

        # Move the modules to the *requested* device. The previous code
        # called .cuda() whenever a GPU was available, which conflicts with
        # input tensors later built with device=self.device (e.g. 'cpu').
        self.encoder.to(device)
        self.decoder.to(device)

        self.encoder_opt = optimizer(self.encoder.parameters(), lr=lr)
        self.decoder_opt = optimizer(self.decoder.parameters(), lr=lr)
        self.criterion = criterion()

        self.src_vocab = src_vocab
        self.tar_vocab = tar_vocab

        self.sess = sess
        self.device = device
Example #2
0
    def __init__(self, vsize, esize, hsize, asize, buckets, **kwargs):
        super(PointerNet, self).__init__()

        # Model identification: name defaults to the class name,
        # variable scope defaults to the name.
        self.name = kwargs.get('name', self.__class__.__name__)
        self.scope = kwargs.get('scope', self.name)

        # Encoder dimensions: vocabulary, embedding, hidden.
        self.enc_vsize = vsize
        self.enc_esize = esize
        self.enc_hsize = hsize

        # Decoder dimensions. Memory and input are twice the encoder
        # hidden size because the bidirectional RNN states are concatenated.
        self.dec_msize = self.enc_hsize * 2  # concatenation of bidirectional RNN states
        self.dec_isize = self.enc_hsize * 2  # concatenation of bidirectional RNN states
        self.dec_hsize = hsize
        self.dec_asize = asize

        # Bucketed sequence lengths; the last bucket is the maximum length.
        self.buckets = buckets
        self.max_len = self.buckets[-1]

        self.max_grad_norm = kwargs.get('max_grad_norm', 100)
        self.optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
        # self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-2)

        self.num_layer = kwargs.get('num_layer', 1)
        self.rnn_class = kwargs.get('rnn_class', tf.nn.rnn_cell.BasicLSTMCell)
        # self.rnn_class = kwargs.get('rnn_class', tf.nn.rnn_cell.GRUCell)

        self.encoder = Encoder(self.enc_vsize,
                               self.enc_esize,
                               self.enc_hsize,
                               rnn_class=self.rnn_class,
                               num_layer=self.num_layer)

        # Both decoder variants share the same constructor signature, so
        # select the class first and construct it once.
        decoder_cls = TreeDecoder if kwargs.get('tree_decoder', False) else Decoder
        self.decoder = decoder_cls(self.dec_isize,
                                   self.dec_hsize,
                                   self.dec_msize,
                                   self.dec_asize,
                                   self.max_len,
                                   rnn_class=self.rnn_class,
                                   num_layer=self.num_layer,
                                   epsilon=1.0)

        # One non-trainable baseline variable per output position,
        # updated with an exponential moving average at ratio bl_ratio.
        self.bl_ratio = kwargs.get('bl_ratio', 0.95)
        self.baselines = [
            tf.Variable(0.0, trainable=False) for _ in range(self.max_len)
        ]
Example #3
0
    def __init__(self, vocab, hidden_size, latent_size, depth, stereo=True):
        """Build the junction-tree VAE sub-networks.

        Args:
            vocab: clique vocabulary (must provide ``size()``).
            hidden_size: hidden width of the encoders/decoder.
            latent_size: full latent width; tree and graph halves each get
                ``latent_size / 2`` dimensions.
            depth: message-passing depth for the graph networks.
            stereo: whether to train the stereochemistry scoring loss.
        """
        super(TreeVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        self.jtnn = TreeEncoder(vocab, hidden_size, self.embedding)
        self.jtmpn = JTMPN(hidden_size, depth)
        self.mpn = MPN(hidden_size, depth)
        # int() keeps the size integral under Python 3 true division,
        # consistent with the nn.Linear layers below.
        self.decoder = TreeDecoder(vocab, hidden_size, int(latent_size / 2),
                                   self.embedding)

        self.T_mean = nn.Linear(hidden_size, int(latent_size / 2))
        self.T_var = nn.Linear(hidden_size, int(latent_size / 2))
        self.G_mean = nn.Linear(hidden_size, int(latent_size / 2))
        self.G_var = nn.Linear(hidden_size, int(latent_size / 2))

        # Summed (not averaged) cross-entropy; averaging is done per batch
        # in the loss computation.
        self.assm_loss = nn.CrossEntropyLoss(size_average=False)
        self.use_stereo = stereo
        if stereo:
            self.stereo_loss = nn.CrossEntropyLoss(size_average=False)
Example #4
0
class TreeVAE(nn.Module):
    """Junction-Tree VAE for molecular graphs.

    A molecule is encoded twice -- by a tree encoder over its junction tree
    (``jtnn``) and by a message-passing network over its atom graph
    (``mpn``) -- producing two latent halves of ``latent_size / 2`` each.
    Decoding first generates a junction tree, then assembles candidate
    attachments, and optionally resolves stereochemistry.
    """

    def __init__(self, vocab, hidden_size, latent_size, depth, stereo=True):
        super(TreeVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        self.jtnn = TreeEncoder(vocab, hidden_size, self.embedding)
        self.jtmpn = JTMPN(hidden_size, depth)
        self.mpn = MPN(hidden_size, depth)
        # int() keeps the size integral under Python 3 true division,
        # consistent with the nn.Linear layers below.
        self.decoder = TreeDecoder(vocab, hidden_size, int(latent_size / 2),
                                   self.embedding)

        self.T_mean = nn.Linear(hidden_size, int(latent_size / 2))
        self.T_var = nn.Linear(hidden_size, int(latent_size / 2))
        self.G_mean = nn.Linear(hidden_size, int(latent_size / 2))
        self.G_var = nn.Linear(hidden_size, int(latent_size / 2))

        # Summed (not averaged) cross-entropy; averaging is done per batch
        # inside assm()/stereo().
        self.assm_loss = nn.CrossEntropyLoss(size_average=False)
        self.use_stereo = stereo
        if stereo:
            self.stereo_loss = nn.CrossEntropyLoss(size_average=False)

    def encode(self, mol_batch):
        """Encode a batch of MolTree objects.

        Returns:
            (tree_mess, tree_vec, mol_vec): tree messages kept for later
            candidate scoring, tree-level vectors, and graph-level vectors.
        """
        set_batch_nodeID(mol_batch, self.vocab)
        root_batch = [mol_tree.nodes[0] for mol_tree in mol_batch]
        tree_mess, tree_vec = self.jtnn(root_batch)

        smiles_batch = [mol_tree.smiles for mol_tree in mol_batch]
        mol_vec = self.mpn(mol2graph(smiles_batch))
        return tree_mess, tree_vec, mol_vec

    def encode_latent_mean(self, smiles_list):
        """Encode SMILES strings to the concatenated (tree, graph) mean."""
        mol_batch = [MolTree(s) for s in smiles_list]
        for mol_tree in mol_batch:
            mol_tree.recover()

        _, tree_vec, mol_vec = self.encode(mol_batch)
        tree_mean = self.T_mean(tree_vec)
        mol_mean = self.G_mean(mol_vec)
        return torch.cat([tree_mean, mol_mean], dim=1)

    def forward(self, mol_batch, beta=0):
        """Full VAE training pass.

        Args:
            mol_batch: list of recovered MolTree objects.
            beta: KL-term weight (beta-VAE style annealing).

        Returns:
            (loss, kl, word_acc, topo_acc, assm_acc, stereo_acc).
        """
        batch_size = len(mol_batch)

        tree_mess, tree_vec, mol_vec = self.encode(mol_batch)

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(
            self.T_var(tree_vec))  # Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(
            self.G_var(mol_vec))  # Following Mueller et al.

        z_mean = torch.cat([tree_mean, mol_mean], dim=1)
        z_log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
        kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean -
                                   torch.exp(z_log_var)) / batch_size

        # Reparameterization trick, applied independently to each half.
        epsilon = create_var(
            torch.randn(batch_size, int(self.latent_size / 2)), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(
            torch.randn(batch_size, int(self.latent_size / 2)), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

        word_loss, topo_loss, word_acc, topo_acc = self.decoder(
            mol_batch, tree_vec)
        assm_loss, assm_acc = self.assm(mol_batch, mol_vec, tree_mess)
        if self.use_stereo:
            stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
        else:
            stereo_loss, stereo_acc = 0, 0

        loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss

        return loss, kl_loss.item(), word_acc, topo_acc, assm_acc, stereo_acc

    def assm(self, mol_batch, mol_vec, tree_mess):
        """Candidate-attachment (assembly) loss and accuracy.

        Scores each enumerated attachment candidate against the molecule
        latent vector via a dot product; the correct candidate is the label.
        """
        cands = []
        batch_idx = []
        for i, mol_tree in enumerate(mol_batch):
            for node in mol_tree.nodes:
                # Leaf node's attachment is determined by neighboring node's attachment
                if node.is_leaf or len(node.cands) == 1: continue
                cands.extend([(cand, mol_tree.nodes, node)
                              for cand in node.cand_mols])
                batch_idx.extend([i] * len(node.cands))

        cand_vec = self.jtmpn(cands, tree_mess)
        cand_vec = self.G_mean(cand_vec)

        batch_idx = create_var(torch.LongTensor(batch_idx))
        mol_vec = mol_vec.index_select(0, batch_idx)

        # Batched dot products: (N, 1, d) x (N, d, 1) -> (N,)
        mol_vec = mol_vec.view(-1, 1, int(self.latent_size / 2))
        cand_vec = cand_vec.view(-1, int(self.latent_size / 2), 1)
        scores = torch.bmm(mol_vec, cand_vec).squeeze()

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(mol_batch):
            comp_nodes = [
                node for node in mol_tree.nodes
                if len(node.cands) > 1 and not node.is_leaf
            ]
            cnt += len(comp_nodes)
            for node in comp_nodes:
                label = node.cands.index(node.label)
                ncand = len(node.cands)
                cur_score = scores.narrow(0, tot, ncand)
                tot += ncand

                if cur_score[label].item() >= cur_score.max().item():
                    acc += 1

                label = create_var(torch.LongTensor([label]))
                all_loss.append(self.assm_loss(cur_score.view(1, -1), label))

        # all_loss = torch.stack(all_loss).sum() / len(mol_batch)
        all_loss = sum(all_loss) / len(mol_batch)
        return all_loss, acc * 1.0 / cnt

    def stereo(self, mol_batch, mol_vec):
        """Stereoisomer-ranking loss and accuracy via cosine similarity."""
        stereo_cands, batch_idx = [], []
        labels = []
        for i, mol_tree in enumerate(mol_batch):
            cands = mol_tree.stereo_cands
            if len(cands) == 1: continue
            if mol_tree.smiles3D not in cands:
                cands.append(mol_tree.smiles3D)
            stereo_cands.extend(cands)
            batch_idx.extend([i] * len(cands))
            labels.append((cands.index(mol_tree.smiles3D), len(cands)))

        # No molecule in the batch has more than one stereoisomer.
        if len(labels) == 0:
            return create_var(torch.zeros(1)), 1.0

        batch_idx = create_var(torch.LongTensor(batch_idx))
        stereo_cands = self.mpn(mol2graph(stereo_cands))
        stereo_cands = self.G_mean(stereo_cands)
        stereo_labels = mol_vec.index_select(0, batch_idx)
        scores = torch.nn.CosineSimilarity()(stereo_cands, stereo_labels)

        st, acc = 0, 0
        all_loss = []
        for label, le in labels:
            cur_scores = scores.narrow(0, st, le)
            if cur_scores.data[label] >= cur_scores.max().data:
                acc += 1
            label = create_var(torch.LongTensor([label]))
            all_loss.append(self.stereo_loss(cur_scores.view(1, -1), label))
            st += le
        # all_loss = torch.cat(all_loss).sum() / len(labels)
        all_loss = sum(all_loss) / len(labels)
        return all_loss, acc * 1.0 / len(labels)

    def reconstruct(self, smiles, prob_decode=False):
        """Encode one SMILES, sample the posterior once, and decode."""
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        _, tree_vec, mol_vec = self.encode([mol_tree])

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(
            self.T_var(tree_vec))  # Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(
            self.G_var(mol_vec))  # Following Mueller et al.

        epsilon = create_var(torch.randn(1, int(self.latent_size / 2)), False)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = create_var(torch.randn(1, int(self.latent_size / 2)), False)
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
        return self.decode(tree_vec, mol_vec, prob_decode)

    def recon_eval(self, smiles):
        """Sample the posterior 10 times and decode 10 times per sample,
        returning all 100 decoded SMILES (possibly with None entries)."""
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        _, tree_vec, mol_vec = self.encode([mol_tree])

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(
            self.T_var(tree_vec))  # Following Mueller et al.
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(
            self.G_var(mol_vec))  # Following Mueller et al.

        all_smiles = []
        for i in range(10):
            epsilon = create_var(torch.randn(1, int(self.latent_size / 2)),
                                 False)
            tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
            epsilon = create_var(torch.randn(1, int(self.latent_size / 2)),
                                 False)
            mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
            for j in range(10):
                new_smiles = self.decode(tree_vec, mol_vec, prob_decode=True)
                all_smiles.append(new_smiles)
        return all_smiles

    def sample_prior(self, prob_decode=False):
        """Decode a single molecule from a standard-normal prior sample."""
        tree_vec = create_var(torch.randn(1, int(self.latent_size / 2)), False)
        mol_vec = create_var(torch.randn(1, int(self.latent_size / 2)), False)
        return self.decode(tree_vec, mol_vec, prob_decode)

    def sample_eval(self):
        """Decode one prior sample 100 times with probabilistic decoding."""
        tree_vec = create_var(torch.randn(1, int(self.latent_size / 2)), False)
        mol_vec = create_var(torch.randn(1, int(self.latent_size / 2)), False)
        all_smiles = []
        for i in range(100):
            s = self.decode(tree_vec, mol_vec, prob_decode=True)
            all_smiles.append(s)
        return all_smiles

    def decode(self, tree_vec, mol_vec, prob_decode):
        """Decode latent vectors into a SMILES string (or None on failure)."""
        pred_root, pred_nodes = self.decoder.decode(tree_vec, prob_decode)

        # Mark nid & is_leaf & atommap
        for i, node in enumerate(pred_nodes):
            node.nid = i + 1
            node.is_leaf = (len(node.neighbors) == 1)
            if len(node.neighbors) > 1:
                set_atommap(node.mol, node.nid)

        tree_mess = self.jtnn([pred_root])[0]

        cur_mol = copy_edit_mol(pred_root.mol)
        # global_amap[nid] maps node-local atom indices to cur_mol indices;
        # index 0 is a placeholder so nids (1-based) index directly.
        global_amap = [{}] + [{} for node in pred_nodes]
        global_amap[1] = {
            atom.GetIdx(): atom.GetIdx()
            for atom in cur_mol.GetAtoms()
        }

        cur_mol = self.dfs_assemble(tree_mess, mol_vec, pred_nodes, cur_mol,
                                    global_amap, [], pred_root, None,
                                    prob_decode)
        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        # Round-trip through SMILES to canonicalize/sanitize.
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        if cur_mol is None: return None
        if not self.use_stereo:
            return Chem.MolToSmiles(cur_mol)

        smiles2D = Chem.MolToSmiles(cur_mol)
        stereo_cands = decode_stereo(smiles2D)
        if len(stereo_cands) == 1:
            return stereo_cands[0]
        # Rank stereoisomers by cosine similarity to the molecule latent.
        stereo_vecs = self.mpn(mol2graph(stereo_cands))
        stereo_vecs = self.G_mean(stereo_vecs)
        scores = nn.CosineSimilarity()(stereo_vecs, mol_vec)
        _, max_id = scores.max(dim=0)
        return stereo_cands[max_id.data[0]]

    def dfs_assemble(self, tree_mess, mol_vec, all_nodes, cur_mol, global_amap,
                     fa_amap, cur_node, fa_node, prob_decode):
        """Depth-first assembly of the decoded tree into a molecule.

        Tries candidate attachments for cur_node (best-scoring first, or
        sampled when prob_decode), recursing into non-leaf children and
        backtracking on failure. Returns an RWMol or None.
        """
        fa_nid = fa_node.nid if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
        neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors,
                           key=lambda x: x.mol.GetNumAtoms(),
                           reverse=True)
        # Singleton (one-atom) neighbors are attached first.
        singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap
                    if nid == cur_node.nid]
        cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None
        cand_smiles, cand_mols, cand_amap = zip(*cands)

        cands = [(candmol, all_nodes, cur_node) for candmol in cand_mols]

        cand_vecs = self.jtmpn(cands, tree_mess)
        cand_vecs = self.G_mean(cand_vecs)
        mol_vec = mol_vec.squeeze()
        scores = torch.mv(cand_vecs, mol_vec) * 20

        if prob_decode:
            probs = nn.Softmax()(scores.view(
                1, -1)).squeeze() + 1e-5  # prevent prob = 0
            cand_idx = torch.multinomial(probs, probs.numel())
        else:
            _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        for i in range(cand_idx.numel()):
            # Restore the pre-attachment molecule before each attempt.
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[
                    cur_node.nid][ctr_atom]

            cur_mol = attach_mols(
                cur_mol, children, [],
                new_global_amap)  # father is already attached
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            # Invalid chemistry: try the next candidate.
            if new_mol is None: continue

            result = True
            for nei_node in children:
                if nei_node.is_leaf: continue
                cur_mol = self.dfs_assemble(tree_mess, mol_vec, all_nodes,
                                            cur_mol, new_global_amap,
                                            pred_amap, nei_node, cur_node,
                                            prob_decode)
                if cur_mol is None:
                    result = False
                    break
            if result: return cur_mol

        return None
Example #5
0
class Seq2Tree:
    """Sequence-to-tree model in the style of Dong & Lapata (2016):
    an Encoder over the source sequence and a parent-feeding TreeDecoder
    that emits one token sequence per tree node, expanding nonterminals
    (<N>) breadth-first.
    """

    def __init__(self,
                 input_size=None,
                 hidden_size=None,
                 output_size=None,
                 optimizer=optim.Adam,
                 criterion=nn.NLLLoss,
                 lr=0.0001,
                 src_vocab=None,
                 tar_vocab=None,
                 sess='',
                 device='cpu'):
        """Build the encoder/decoder pair plus their optimizers.

        Args:
            input_size: source vocabulary size fed to the Encoder.
            hidden_size: hidden-state width shared by encoder and decoder.
            output_size: target vocabulary size produced by the TreeDecoder.
            optimizer: optimizer class, instantiated once per module.
            criterion: loss class, instantiated with no arguments.
            lr: learning rate for both optimizers.
            src_vocab: source-side vocabulary object.
            tar_vocab: target-side vocabulary object.
            sess: session name used for checkpoint/log filenames.
            device: torch device string; input tensors are created on it.
        """
        self.encoder = Encoder(input_size, hidden_size, device)
        self.decoder = TreeDecoder(hidden_size, output_size, device)

        # Move the modules to the *requested* device. The previous code
        # called .cuda() whenever a GPU was available, which conflicts with
        # input tensors later built with device=self.device (e.g. 'cpu').
        self.encoder.to(device)
        self.decoder.to(device)

        self.encoder_opt = optimizer(self.encoder.parameters(), lr=lr)
        self.decoder_opt = optimizer(self.decoder.parameters(), lr=lr)
        self.criterion = criterion()

        self.src_vocab = src_vocab
        self.tar_vocab = tar_vocab

        self.sess = sess
        self.device = device

    def get_idx(self, decoder_output):
        """Greedy readout: return (best index, detached tensor for the
        next decoder input)."""
        topv, topi = decoder_output.data.topk(1)
        idx = topi.item()
        decoder_input = topi.squeeze().detach()
        return idx, decoder_input

    def run_epoch(self, src_inputs, tar_outputs, batch_size=1):
        """
        Run one optimization step over a batch of (source, tree) pairs
        and return the per-token average loss.
        """
        self.encoder_opt.zero_grad()
        self.decoder_opt.zero_grad()

        # encode the source input
        _, (encoder_h, encoder_c) = self.encoder(src_inputs,
                                                 batch_size=batch_size)

        # print encoder_h.shape # (1, 20, 10)
        # print encoder_c.shape # (1, 20, 10)

        SOS_token = self.tar_vocab.word_to_index['<S>']
        EOS_token = self.tar_vocab.word_to_index['</S>']
        NON_token = self.tar_vocab.word_to_index['<N>']

        loss = 0
        decoder_h = encoder_h.view(batch_size, 1, 1, -1)
        decoder_c = encoder_c.view(batch_size, 1, 1, -1)

        # total number of target tokens (for averaging the loss)
        tar_count = 0

        for batch in range(batch_size):
            decoder_hidden = decoder_h[batch], decoder_c[batch]
            # (1, 1, hidden), (1, 1, hidden)

            # see Dong et al. (2016) [Algorithm 1]
            root = {
                'parent': decoder_hidden[0],  # (1, 1, 200)
                'hidden': decoder_hidden,  # (1, 1, 200) * 2
            }

            tar_idx = 0

            queue = [root]

            while queue:
                # until no more nonterminals
                subtree = queue.pop(0)
                # get the next subtree in tar_output
                tar_seq = tar_outputs[batch][tar_idx]

                # count items in sequence (for averaging loss)
                tar_count += len(tar_seq)

                # initialize the sequence
                # NOTE: batch_size is 1
                decoder_input = torch.tensor([[SOS_token]],
                                             dtype=torch.long,
                                             device=self.device)

                # get the parent-feeding vector
                parent_input = subtree['parent']

                idx = SOS_token

                # Teacher forcing with trees
                for i in range(1, len(tar_seq)):
                    # decode the input sequence
                    decoder_output, decoder_hidden = self.decoder(
                        decoder_input,
                        hidden=decoder_hidden,
                        parent=parent_input)
                    # interpret the output
                    idx, decoder_input = self.get_idx(decoder_output)

                    # get the desired output
                    target_output = torch.tensor([tar_seq[i]],
                                                 dtype=torch.long,
                                                 device=self.device)

                    # calculate loss
                    loss += self.criterion(decoder_output, target_output)

                    # if we have a non-terminal token
                    if tar_seq[i] == NON_token:
                        # add a subtree to the queue
                        ### parent: the previous state for <n>
                        ### hidden: the hidden state for <n>
                        ### children: subtrees
                        nonterminal = {
                            'parent': decoder_hidden[0],
                            'hidden': decoder_hidden,
                            # 'children': []
                        }

                        queue.append(nonterminal)

                    decoder_input = target_output  # Teacher forcing

                # next subtree in tar_output
                tar_idx += 1

        loss.backward()
        self.encoder_opt.step()
        self.decoder_opt.step()

        return loss.item() / tar_count

    def train(self,
              X_train,
              y_train,
              epochs=10,
              retrain=0,
              batch_size=10,
              loss_update=10):
        """Train for (epochs - retrain) epochs, converting y_train formulas
        to index trees on the first epoch. Returns {'losses': [...]}."""
        cum_loss = 0
        history = {}
        losses = []

        def get_progress(num, den, length):
            """
            simple progress bar
            """
            if num == den - 1:
                return '=' * length
            arrow = int(float(num) / den * length)
            # pad with dots up to `length` (was hard-coded to 20)
            return '=' * (arrow - 1) + '>' + '.' * (length - arrow)

        for epoch in range(retrain, epochs):
            epoch_loss = 0

            print('Epoch %d/%d' % (epoch, epochs))

            for i in range(0, len(X_train), batch_size):
                X_batch = X_train[i:i + batch_size]

                # drop a trailing partial batch
                if len(X_batch) < batch_size:
                    continue

                if epoch == 0:
                    # initialize training data to trees
                    for j in range(i, i + batch_size):
                        root = Tree(formula=y_train[j])
                        y_train[j] = [
                            self.tar_vocab.sent_to_idx(formula)
                            for formula in root.inorder()
                        ]

                X_batch = torch.tensor(X_batch,
                                       dtype=torch.long,
                                       device=self.device)
                y_batch = y_train[i:i + batch_size]

                loss = self.run_epoch(X_batch, y_batch, batch_size=batch_size)

                cum_loss += loss
                epoch_loss += loss

                progress = get_progress(i, len(X_train), length=20)
                out = '%d/%d [%s] loss: %f' % (i, len(X_train), progress,
                                               epoch_loss / (i + 1))
                sys.stdout.write('{0}\r'.format(out))
                sys.stdout.flush()
            print()

            losses.append(epoch_loss)
            if epoch % loss_update == 0:
                self.save('%s_epoch_%d.json' % (self.sess, epoch))
        history['losses'] = losses
        return history

    def flatten(self, root):
        """Flatten a decoded tree (nested 'children' dicts) into a flat
        list of token indices, depth-first."""
        decoded_seq = []
        for child in root['children']:
            if isinstance(child, int):
                decoded_seq.append(child)
            else:
                flattened = self.flatten(child)
                decoded_seq += flattened
        return decoded_seq

    def predict(self, X_test):
        """Greedy-decode each test input into a flat index sequence."""
        with torch.no_grad():
            decoded_text = []

            for i in range(len(X_test)):
                src_input = torch.tensor(X_test[i],
                                         dtype=torch.long,
                                         device=self.device).view(1, -1)

                # encode the source input
                encoder_output, encoder_hidden = self.encoder(src_input)

                SOS_token = self.tar_vocab.word_to_index['<S>']
                EOS_token = self.tar_vocab.word_to_index['</S>']
                NON_token = self.tar_vocab.word_to_index['<N>']

                decoder_hidden = encoder_hidden

                # see Dong et al. (2016) [Algorithm 1]
                root = {
                    'parent': decoder_hidden[0],  # (1, 1, 200)
                    'hidden': decoder_hidden,  # (1, 1, 200) * 2
                    'children': []
                }

                queue = [root]
                depth = 0
                while queue:
                    # until no more nonterminals
                    subtree = queue.pop(0)

                    # initialize the sequence
                    # NOTE: batch_size is 1
                    decoder_input = torch.tensor([[SOS_token]],
                                                 dtype=torch.long,
                                                 device=self.device)

                    # get the parent-feeding vector
                    parent_input = subtree['parent']

                    idx = SOS_token

                    while idx != EOS_token:
                        # decode the input sequence
                        decoder_output, decoder_hidden = self.decoder(
                            decoder_input,
                            hidden=decoder_hidden,
                            parent=parent_input)

                        # interpret the output
                        idx, decoder_input = self.get_idx(decoder_output)

                        # if exceeds max length
                        if len(subtree['children']) == 20:
                            idx = EOS_token
                        # if exceed max depth
                        if depth > 3:
                            idx = EOS_token

                        # if we have a non-terminal token
                        if idx == NON_token:

                            # add a subtree to the queue
                            ### parent: the previous state for <n>
                            ### hidden: the hidden state for <n>
                            ### children: subtrees
                            nonterminal = {
                                'parent': decoder_hidden[0],
                                'hidden': decoder_hidden,
                                'children': []
                            }

                            queue.append(nonterminal)
                            subtree['children'].append(nonterminal)

                            depth += 1
                        else:
                            subtree['children'].append(idx)

                decoded_seq = self.flatten(root)
                decoded_text.append(decoded_seq)

        return decoded_text

    def evaluate(self, X_test, y_test, preds, out=None):
        """
        for seq2tree models, X_test and preds are going to be
        lists of lists of indexes => [[idx]]

        y_test is going to be
        list of sents => [sent]
        """
        if out:
            outfile = out
            errfile = 'err_' + out
        else:
            outfile = 'logs/sessions/%s.out' % self.sess
            errfile = 'logs/sessions/err_%s.out' % self.sess

        print('logging to %s...' % outfile)

        num_correct = 0

        def preprocess(fol_pred_idx):
            # strip the sequence delimiters before comparison
            fol_pred = self.tar_vocab.reverse(fol_pred_idx)
            return fol_pred.replace('<S>', '').replace('</S>', '')

        with open(outfile, 'w') as w:
            with open(errfile, 'w') as err:
                for nl_idx, fol_gold, fol_pred_idx in zip(
                        X_test, y_test, preds):
                    nl_sent = self.src_vocab.reverse(nl_idx)

                    fol_pred = preprocess(fol_pred_idx)

                    if fol_gold != fol_pred:
                        err.write('input:  ' + nl_sent + '\n')
                        err.write('gold:   ' + fol_gold + '\n')
                        err.write('output: ' + fol_pred + '\n')
                        err.write('\n')
                    else:
                        num_correct += 1

                    w.write('%s\t%s\t%s\t\n' % (nl_sent, fol_gold, fol_pred))

        print('########################')
        print('# Evaluation:')
        print('# %d out of %d correct' % (num_correct, len(preds)))
        print('# %0.3f accuracy' % (float(num_correct) / len(preds)))
        print('########################')

    def save(self, filename):
        """Save encoder/decoder state dicts under logs/sessions/."""
        torch.save(self.encoder.state_dict(),
                   'logs/sessions/enc_%s' % filename)
        torch.save(self.decoder.state_dict(),
                   'logs/sessions/dec_%s' % filename)

    def load(self, filename):
        """Load encoder/decoder state dicts from logs/sessions/."""
        self.encoder.load_state_dict(
            torch.load('logs/sessions/enc_%s' % filename))
        self.decoder.load_state_dict(
            torch.load('logs/sessions/dec_%s' % filename))