Example #1
    def translate(self, batch):
        q, q_len = batch.src
        batch_size = q.size(1)

        # encoding
        q_enc, q_all = self.model.q_encoder(q, lengths=q_len, ent=None)
        if self.model.opt.seprate_encoder:
            q_tgt_enc, q_tgt_all = self.model.q_tgt_encoder(q,
                                                            lengths=q_len,
                                                            ent=None)
        else:
            q_tgt_enc, q_tgt_all = q_enc, q_all

        vocab_mask = None
        layout_token_prune_list = [None for b in range(batch_size)]

        # layout decoding
        lay_dec = self.run_lay_decoder(self.model.lay_decoder,
                                       self.model.lay_classifier, q, q_all,
                                       q_enc, self.opt.max_lay_len, vocab_mask,
                                       self.fields['lay'].vocab)
        if self.opt.gold_layout:
            lay_dec = batch.lay[0].data[1:]
        # recover layout
        lay_list = []
        for b in range(batch_size):
            lay_field = 'lay'
            lay = recover_layout_token(
                [lay_dec[i, b] for i in range(lay_dec.size(0))],
                self.fields[lay_field].vocab, lay_dec.size(0))
            lay_list.append(lay)

        # layout encoding
        # lay_len = get_decode_batch_length(lay_dec, batch_size, self.opt.max_lay_len)
        lay_len = torch.LongTensor(
            [len(lay_list[b]) for b in range(batch_size)])
        # data used for layout encoding
        lay_dec = torch.LongTensor(lay_len.max(),
                                   batch_size).fill_(table.IO.PAD)
        for b in range(batch_size):
            for i in range(lay_len[b]):
                lay_dec[i, b] = self.fields['lay'].vocab.stoi[lay_list[b][i]]
        lay_dec = v_eval(lay_dec.cuda())
        # lay_enc_len is needed by the co-attention below even when the
        # layout encoder is bypassed, so compute it unconditionally
        lay_enc_len = lay_len.cuda().clamp(min=1)
        # lay_all: (lay_len, batch, lay_size)
        if self.model.opt.no_lay_encoder:
            lay_all = self.model.lay_encoder(lay_dec)
        else:
            lay_all = encode_unsorted_batch(self.model.lay_encoder, lay_dec,
                                            lay_enc_len)
        # co-attention
        if self.model.lay_co_attention is not None:
            lay_all = self.model.lay_co_attention(lay_all, lay_enc_len, q_all,
                                                  q)

        # get lay_index and tgt_mask: (tgt_len, batch)
        lay_skip_list, tgt_mask_seq, lay_index_seq = expand_layout_with_skip(
            lay_list)

        # co-attention
        if self.model.q_co_attention is not None:
            q_tgt_enc, q_tgt_all = self.model.q_co_attention(
                q_tgt_all, q_len, lay_all, lay_dec)

        # target decoding
        tgt_dec = self.run_tgt_decoder(
            self.model.tgt_embeddings, tgt_mask_seq, lay_index_seq, lay_all,
            self.model.tgt_decoder, self.model.tgt_classifier, q, q_tgt_all,
            q_tgt_enc, self.opt.max_tgt_len, lay_skip_list,
            self.fields['tgt'].vocab, batch.copy_to_ext, batch.copy_to_tgt)
        # recover target
        tgt_list = []
        for b in range(batch_size):
            tgt = recover_target_token(
                lay_skip_list[b],
                [tgt_dec[i, b] for i in range(tgt_dec.size(0))],
                self.fields['tgt'].vocab, self.fields['copy_to_ext'].vocab,
                tgt_dec.size(0))
            tgt_list.append(tgt)

        # recover output
        indices = cpu_vector(batch.indices.data)
        return [
            ParseResult(idx, lay, tgt, token_prune)
            for idx, lay, tgt, token_prune in zip(indices, lay_list, tgt_list,
                                                  layout_token_prune_list)
        ]
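
A note on Example #1: it relies on a few small tensor helpers (argmax, cpu_vector, v_eval) whose definitions are not shown in these excerpts. A minimal sketch of what they plausibly look like, assuming the Variable-era PyTorch API used throughout:

import torch
from torch.autograd import Variable

def argmax(scores):
    # index of the best-scoring entry along the last dimension
    return scores.max(scores.dim() - 1)[1]

def cpu_vector(v):
    # flatten to a 1-D CPU tensor so per-example results can be indexed
    return v.clone().view(-1).cpu()

def v_eval(t):
    # wrap a tensor for inference only (no gradient tracking)
    return Variable(t, requires_grad=False)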
Example #2
    def translate(self, batch):
        q, q_len = batch.src
        tbl, tbl_len = batch.tbl
        ent, tbl_split, tbl_mask = batch.ent, batch.tbl_split, batch.tbl_mask

        # encoding

        # args: query, query length, entity tags, question type, table,
        # table length, table split, table mask
        q_enc, q_all, tbl_enc, q_ht, batch_size = self.model.enc(
            q, q_len, ent, batch.type, tbl, tbl_len, tbl_split, tbl_mask)

        BIO_op_out = self.model.BIO_op_classifier(q_all)
        tsp_q = BIO_op_out.size(0)
        bsz = BIO_op_out.size(1)
        BIO_op_out = BIO_op_out.view(-1, BIO_op_out.size(2))
        BIO_op_out = F.log_softmax(BIO_op_out, dim=-1)
        BIO_op_out_sf = torch.exp(BIO_op_out)
        BIO_op_out = BIO_op_out.view(tsp_q, bsz, -1)
        BIO_op_out_sf = BIO_op_out_sf.view(tsp_q, bsz, -1)


        BIO_out = self.model.BIO_classifier(q_all)
        BIO_out = BIO_out.view(-1, BIO_out.size(2))
        BIO_out = F.log_softmax(BIO_out, dim=-1)
        BIO_out_sf = torch.exp(BIO_out)
        BIO_out = BIO_out.view(tsp_q, bsz, -1)
        BIO_out_sf = BIO_out_sf.view(tsp_q, bsz, -1)

        BIO_col_out = self.model.label_col_match(q_all, tbl_enc, tbl_mask)
        BIO_col_out = BIO_col_out.view(-1, BIO_col_out.size(2))
        BIO_col_out = F.log_softmax(BIO_col_out, dim=-1)
        BIO_col_out_sf = torch.exp(BIO_col_out)
        BIO_col_out = BIO_col_out.view(tsp_q, bsz, -1)
        BIO_col_out_sf = BIO_col_out_sf.view(tsp_q, bsz, -1)

        BIO_pred = argmax(BIO_out_sf.data).transpose(0, 1)
        BIO_col_pred = argmax(BIO_col_out_sf.data).transpose(0, 1)
        # where the BIO tag is predicted as class 2 (presumably the
        # 'outside' tag), mark the column prediction as invalid
        for i in range(BIO_pred.size(0)):
            for j in range(BIO_pred.size(1)):
                if BIO_pred[i][j] == 2:
                    BIO_col_pred[i][j] = -1

        # (1) decoding
        # self-attentive pooling over the question (used in place of q_ht)
        q_self_encode = self.model.agg_self_attention(q_all, q_len)
        q_self_encode_layout = self.model.lay_self_attention(q_all, q_len)
        agg_pred = cpu_vector(
            argmax(self.model.agg_classifier(q_self_encode).data))
        sel_out = self.model.sel_match(q_self_encode,
                                       tbl_enc,
                                       tbl_mask,
                                       select=True)  # select-column scores
        sel_pred = cpu_vector(argmax(sel_out.data))
        lay_pred = argmax(self.model.lay_classifier(q_self_encode_layout).data)
        # get layout op tokens
        op_batch_list = []
        op_idx_batch_list = []
        if self.opt.gold_layout:
            lay_pred = batch.lay.data
            cond_op, cond_op_len = batch.cond_op
            cond_op_len_list = cond_op_len.view(-1).tolist()
            for i, len_it in enumerate(cond_op_len_list):
                if len_it == 0:
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    idx_list = cond_op.data[0:len_it,
                                            i].contiguous().view(-1).tolist()
                    op_idx_batch_list.append([
                        int(self.fields['cond_op'].vocab.itos[it])
                        for it in idx_list
                    ])
                    op_batch_list.append(idx_list)
        else:
            lay_batch_list = lay_pred.view(-1).tolist()
            for lay_it in lay_batch_list:
                tk_list = self.fields['lay'].vocab.itos[lay_it].split(' ')
                if (len(tk_list) == 0) or (tk_list[0] == ''):
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    op_idx_batch_list.append(
                        [int(op_str) for op_str in tk_list])
                    op_batch_list.append([
                        self.fields['cond_op'].vocab.stoi[op_str]
                        for op_str in tk_list
                    ])
            # -> (num_cond, batch)
            cond_op = v_eval(
                add_pad(
                    op_batch_list,
                    self.fields['cond_op'].vocab.stoi[table.IO.PAD_WORD]).t())
            cond_op_len = torch.LongTensor([len(it) for it in op_batch_list])
        # emb_op -> (num_cond, batch, emb_size)
        if self.model.opt.layout_encode == 'rnn':
            emb_op = table.Models.encode_unsorted_batch(
                self.model.lay_encoder, cond_op, cond_op_len.clamp(min=1))
        else:
            emb_op = self.model.cond_embedding(cond_op)

        # (2) decoding
        self.model.cond_decoder.attn.applyMaskBySeqBatch(q)
        cond_state = self.model.cond_decoder.init_decoder_state(q_all, q_enc)
        cond_col_list, cond_span_l_list, cond_span_r_list = [], [], []
        for emb_op_t in emb_op:
            emb_op_t = emb_op_t.unsqueeze(0)
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_op_t, q_all, cond_state)

            # cond col -> (1, batch)
            cond_col = argmax(
                self.model.cond_col_match(cond_context, tbl_enc,
                                          tbl_mask).data)
            cond_col_list.append(cpu_vector(cond_col))
            # emb_col
            batch_index = torch.LongTensor(
                range(batch_size)).unsqueeze_(0).cuda().expand(
                    cond_col.size(0), cond_col.size(1))
            emb_col = tbl_enc[cond_col, batch_index, :]
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_col, q_all, cond_state)

            # cond span
            q_mask = v_eval(
                q.data.eq(self.model.pad_word_index).transpose(0, 1))
            cond_span_l = argmax(
                self.model.cond_span_l_match(cond_context, q_all, q_mask).data)
            cond_span_l_list.append(cpu_vector(cond_span_l))
            # emb_span_l: (1, batch, hidden_size)
            emb_span_l = q_all[cond_span_l, batch_index, :]
            cond_span_r = argmax(
                self.model.cond_span_r_match(cond_context,
                                             q_all,
                                             q_mask,
                                             emb_span_l=emb_span_l).data)
            cond_span_r_list.append(cpu_vector(cond_span_r))
            # emb_span_r: (1, batch, hidden_size)
            emb_span_r = q_all[cond_span_r, batch_index, :]

            emb_span = self.model.span_merge(
                torch.cat([emb_span_l, emb_span_r], 2))

            # (an alternative, disabled: mean-pool q_all over the predicted
            # span instead of merging the two span-endpoint embeddings)

            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_span, q_all, cond_state)

        # (3) recover output
        indices = cpu_vector(batch.indices.data)
        r_list = []
        for b in range(batch_size):
            idx = indices[b]
            agg = agg_pred[b]
            sel = sel_pred[b]
            BIO = BIO_pred[b]
            BIO_col = BIO_col_pred[b]
            cond = []
            for i in range(len(op_batch_list[b])):
                col = cond_col_list[i][b]
                op = op_idx_batch_list[b][i]
                span_l = cond_span_l_list[i][b]
                span_r = cond_span_r_list[i][b]
                cond.append((col, op, (span_l, span_r)))
            r_list.append(ParseResult(idx, agg, sel, cond, BIO, BIO_col))

        return r_list
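
A note on Example #2: the three BIO_* blocks repeat the same flatten / log-softmax / exp / reshape sequence. A small helper could factor that out; the sketch below is illustrative, and seq_log_softmax is a hypothetical name, not part of the original code:

import torch
import torch.nn.functional as F

def seq_log_softmax(out):
    # out: (seq_len, batch, num_classes)
    # returns log-probabilities and probabilities, both of the same shape
    seq_len, bsz = out.size(0), out.size(1)
    log_probs = F.log_softmax(out.view(-1, out.size(2)), dim=-1)
    log_probs = log_probs.view(seq_len, bsz, -1)
    return log_probs, torch.exp(log_probs)

With it, each block collapses to a single call, e.g. BIO_op_out, BIO_op_out_sf = seq_log_softmax(self.model.BIO_op_classifier(q_all)).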
Example #3
    def translate(self, batch):
        q, q_len = batch.src
        tbl, tbl_len = batch.tbl
        ent, tbl_split, tbl_mask = batch.ent, batch.tbl_split, batch.tbl_mask

        # encoding
        q_enc, q_all, tbl_enc, q_ht, batch_size = self.model.enc(
            q, q_len, ent, tbl, tbl_len, tbl_split, tbl_mask)

        # (1) decoding
        agg_pred = cpu_vector(argmax(self.model.agg_classifier(q_ht).data))
        sel_pred = cpu_vector(
            argmax(self.model.sel_match(q_ht, tbl_enc, tbl_mask).data))
        lay_pred = argmax(self.model.lay_classifier(q_ht).data)
        # get layout op tokens
        op_batch_list = []
        op_idx_batch_list = []
        if self.opt.gold_layout:
            lay_pred = batch.lay.data
            cond_op, cond_op_len = batch.cond_op
            cond_op_len_list = cond_op_len.view(-1).tolist()
            for i, len_it in enumerate(cond_op_len_list):
                if len_it == 0:
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    idx_list = cond_op.data[0:len_it,
                                            i].contiguous().view(-1).tolist()
                    op_idx_batch_list.append([
                        int(self.fields['cond_op'].vocab.itos[it])
                        for it in idx_list
                    ])
                    op_batch_list.append(idx_list)
        else:
            lay_batch_list = lay_pred.view(-1).tolist()
            for lay_it in lay_batch_list:
                tk_list = self.fields['lay'].vocab.itos[lay_it].split(' ')
                if (len(tk_list) == 0) or (tk_list[0] == ''):
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    op_idx_batch_list.append(
                        [int(op_str) for op_str in tk_list])
                    op_batch_list.append([
                        self.fields['cond_op'].vocab.stoi[op_str]
                        for op_str in tk_list
                    ])
            # -> (num_cond, batch)
            cond_op = v_eval(
                add_pad(
                    op_batch_list,
                    self.fields['cond_op'].vocab.stoi[table.IO.PAD_WORD]).t())
            cond_op_len = torch.LongTensor([len(it) for it in op_batch_list])
        # emb_op -> (num_cond, batch, emb_size)
        if self.model.opt.layout_encode == 'rnn':
            emb_op = table.Models.encode_unsorted_batch(
                self.model.lay_encoder, cond_op, cond_op_len.clamp(min=1))
        else:
            emb_op = self.model.cond_embedding(cond_op)

        # (2) decoding
        self.model.cond_decoder.attn.applyMaskBySeqBatch(q)
        cond_state = self.model.cond_decoder.init_decoder_state(q_all, q_enc)
        cond_col_list, cond_span_l_list, cond_span_r_list = [], [], []
        for emb_op_t in emb_op:
            emb_op_t = emb_op_t.unsqueeze(0)
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_op_t, q_all, cond_state)
            # cond col -> (1, batch)
            cond_col = argmax(
                self.model.cond_col_match(cond_context, tbl_enc,
                                          tbl_mask).data)
            cond_col_list.append(cpu_vector(cond_col))
            # emb_col
            batch_index = torch.LongTensor(
                range(batch_size)).unsqueeze_(0).cuda().expand(
                    cond_col.size(0), cond_col.size(1))
            emb_col = tbl_enc[cond_col, batch_index, :]
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_col, q_all, cond_state)
            # cond span
            q_mask = v_eval(
                q.data.eq(self.model.pad_word_index).transpose(0, 1))
            cond_span_l = argmax(
                self.model.cond_span_l_match(cond_context, q_all, q_mask).data)
            cond_span_l_list.append(cpu_vector(cond_span_l))
            # emb_span_l: (1, batch, hidden_size)
            emb_span_l = q_all[cond_span_l, batch_index, :]
            cond_span_r = argmax(
                self.model.cond_span_r_match(cond_context, q_all, q_mask,
                                             emb_span_l).data)
            cond_span_r_list.append(cpu_vector(cond_span_r))
            # emb_span_r: (1, batch, hidden_size)
            emb_span_r = q_all[cond_span_r, batch_index, :]
            emb_span = self.model.span_merge(
                torch.cat([emb_span_l, emb_span_r], 2))
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_span, q_all, cond_state)

        # (3) recover output
        indices = cpu_vector(batch.indices.data)
        r_list = []
        for b in range(batch_size):
            idx = indices[b]
            agg = agg_pred[b]
            sel = sel_pred[b]
            cond = []
            for i in range(len(op_batch_list[b])):
                col = cond_col_list[i][b]
                op = op_idx_batch_list[b][i]
                span_l = cond_span_l_list[i][b]
                span_r = cond_span_r_list[i][b]
                cond.append((col, op, (span_l, span_r)))
            r_list.append(ParseResult(idx, agg, sel, cond))

        return r_list
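
A note on Examples #2 and #3: add_pad, which batches the variable-length condition-operator lists, is not defined in these excerpts. A minimal sketch, assuming it right-pads each list and returns a (batch, max_len) tensor that callers transpose with .t() into (num_cond, batch); device placement is left to the caller:

import torch

def add_pad(b_list, pad_index):
    # right-pad each per-example index list to the longest list in the batch
    max_len = max(len(b) for b in b_list)
    padded = [b + [pad_index] * (max_len - len(b)) for b in b_list]
    return torch.LongTensor(padded)  # (batch, max_len)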
Example #4
    def translate(self, batch):
        q, q_len = batch.src
        batch_size = q.size(1)

        # encoding
        q_enc, q_all = self.model.q_encoder(q, lengths=q_len, ent=None)
        if self.model.opt.seprate_encoder:
            q_tgt_enc, q_tgt_all = self.model.q_tgt_encoder(q,
                                                            lengths=q_len,
                                                            ent=None)
        else:
            q_tgt_enc, q_tgt_all = q_enc, q_all

        if self.model.opt.layout_token_prune:
            layout_token_prune_list = []
            q_token_enc, __ = self.model.q_token_encoder(q,
                                                         lengths=q_len,
                                                         ent=None)
            # (num_layers * num_directions, batch, hidden_size)
            q_token_ht, __ = q_token_enc
            batch_size = q_token_ht.size(1)
            # final hidden state; for a BRNN, concatenate both directions
            if self.model.opt.brnn:
                q_token_ht = q_token_ht[-2:].transpose(
                    0, 1).contiguous().view(batch_size, -1)
            else:
                q_token_ht = q_token_ht[-1]
            token_out = F.sigmoid(self.model.token_pruner(q_token_ht))
            # prune tokens whose keep-probability falls below 0.5
            vocab_mask = token_out.data.lt(0.5).view(1, batch_size, -1)
            for tk_idx in range(len(table.IO.special_token_list),
                                len(self.fields['lay'].vocab)):
                w = self.fields['lay'].vocab.itos[tk_idx]
                if w.startswith('(') or w in (')', table.IO.TOK_WORD):
                    idx = tk_idx - len(table.IO.special_token_list)
                    vocab_mask[:, :, idx] = 0
            # log pruned tokens for evaluation
            for b in range(batch_size):
                masked_v_list = []
                for i in range(vocab_mask.size(2)):
                    if vocab_mask[0, b, i] == 1:
                        masked_v_list.append(self.fields['lay'].vocab.itos[
                            i + len(table.IO.special_token_list)])
                layout_token_prune_list.append(masked_v_list)
        else:
            token_out = None
            vocab_mask = None
            layout_token_prune_list = [None for b in range(batch_size)]

        # layout decoding
        lay_dec = self.run_lay_decoder(self.model.lay_decoder,
                                       self.model.lay_classifier, q, q_all,
                                       q_enc, self.opt.max_lay_len, vocab_mask,
                                       self.fields['lay'].vocab)
        if self.opt.gold_layout:
            if self.model.opt.bpe:
                lay_dec = batch.lay_bpe[0].data[1:]
            else:
                lay_dec = batch.lay[0].data[1:]
        # recover layout
        lay_list = []
        for b in range(batch_size):
            if self.model.opt.bpe:
                lay_field = 'lay_bpe'
            else:
                lay_field = 'lay'
            lay = recover_layout_token(
                [lay_dec[i, b] for i in range(lay_dec.size(0))],
                self.fields[lay_field].vocab, lay_dec.size(0))
            if self.model.opt.bpe:
                lay = recover_bpe(lay)
            lay_list.append(lay)

        # layout encoding
        # lay_len = get_decode_batch_length(lay_dec, batch_size, self.opt.max_lay_len)
        lay_len = torch.LongTensor(
            [len(lay_list[b]) for b in range(batch_size)])
        # data used for layout encoding
        lay_dec = torch.LongTensor(lay_len.max(),
                                   batch_size).fill_(table.IO.PAD)
        for b in range(batch_size):
            for i in range(lay_len[b]):
                lay_dec[i, b] = self.fields['lay'].vocab.stoi[lay_list[b][i]]
        lay_dec = v_eval(lay_dec.cuda())
        # lay_enc_len is needed by the co-attention below even when the
        # layout encoder is bypassed, so compute it unconditionally
        lay_enc_len = lay_len.cuda().clamp(min=1)
        # lay_all: (lay_len, batch, lay_size)
        if self.model.opt.no_lay_encoder:
            lay_all = self.model.lay_encoder(lay_dec)
        else:
            lay_all = encode_unsorted_batch(self.model.lay_encoder, lay_dec,
                                            lay_enc_len)
        # co-attention
        if self.model.lay_co_attention is not None:
            lay_all = self.model.lay_co_attention(lay_all, lay_enc_len, q_all,
                                                  q)

        # get lay_index and tgt_mask: (tgt_len, batch)
        lay_skip_list, tgt_mask_seq, lay_index_seq = expand_layout_with_skip(
            lay_list)

        # co-attention
        if self.model.q_co_attention is not None:
            q_tgt_enc, q_tgt_all = self.model.q_co_attention(
                q_tgt_all, q_len, lay_all, lay_dec)

        # target decoding
        tgt_dec = self.run_tgt_decoder(self.model.tgt_embeddings, tgt_mask_seq,
                                       lay_index_seq, lay_all,
                                       self.model.tgt_decoder,
                                       self.model.tgt_classifier, q, q_tgt_all,
                                       q_tgt_enc, self.opt.max_tgt_len,
                                       lay_skip_list, self.fields['tgt'].vocab)
        # recover target
        tgt_list = []
        for b in range(batch_size):
            tgt = recover_target_token(
                lay_skip_list[b],
                [tgt_dec[i, b] for i in range(tgt_dec.size(0))],
                self.fields['tgt'].vocab, tgt_dec.size(0))
            tgt_list.append(tgt)

        # recover output
        indices = cpu_vector(batch.indices.data)
        return [
            ParseResult(idx, lay, tgt, token_prune)
            for idx, lay, tgt, token_prune in zip(indices, lay_list, tgt_list,
                                                  layout_token_prune_list)
        ]
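
A note on Examples #1 and #4: encode_unsorted_batch presumably wraps the standard sort / pack / encode / unsort pattern that length-aware RNN encoders require. A sketch under the assumption that the encoder takes (seq, lengths) like q_encoder above and returns (final_state, all_states):

import torch

def encode_unsorted_batch(encoder, seq, lengths):
    # seq: (seq_len, batch); lengths: (batch,) LongTensor on the same device
    # pack_padded_sequence needs descending lengths: sort, encode, unsort
    sorted_len, sort_idx = lengths.sort(0, descending=True)
    inv_idx = sort_idx.sort(0)[1]  # inverse permutation
    _, all_states = encoder(seq.index_select(1, sort_idx),
                            lengths=sorted_len)
    # restore the original batch order: (seq_len, batch, hidden)
    return all_states.index_select(1, inv_idx)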
Example #5
    def translate(self, batch, js_list=(), sql_list=()):
        q, q_len = batch.src
        tbl, tbl_len = batch.tbl
        ent, tbl_split, tbl_mask = batch.ent, batch.tbl_split, batch.tbl_mask
        # encoding
        q_enc, q_all, tbl_enc, q_ht, batch_size = self.model.enc(
            q, q_len, ent, tbl, tbl_len, tbl_split, tbl_mask)

        # (1) decoding
        agg_pred = cpu_vector(argmax(self.model.agg_classifier(q_ht).data))
        sel_pred = cpu_vector(
            argmax(self.model.sel_match(q_ht, tbl_enc, tbl_mask).data))
        lay_pred = argmax(self.model.lay_classifier(q_ht).data)
        engine = DBEngine(self.opt.db_file)
        indices = cpu_vector(batch.indices.data)
        # get layout op tokens
        op_batch_list = []
        op_idx_batch_list = []
        if self.opt.gold_layout:
            lay_pred = batch.lay.data
            cond_op, cond_op_len = batch.cond_op
            cond_op_len_list = cond_op_len.view(-1).tolist()
            for i, len_it in enumerate(cond_op_len_list):
                if len_it == 0:
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    idx_list = cond_op.data[0:len_it,
                                            i].contiguous().view(-1).tolist()
                    op_idx_batch_list.append([
                        int(self.fields['cond_op'].vocab.itos[it])
                        for it in idx_list
                    ])
                    op_batch_list.append(idx_list)
        else:
            lay_batch_list = lay_pred.view(-1).tolist()
            for lay_it in lay_batch_list:
                tk_list = self.fields['lay'].vocab.itos[lay_it].split(' ')
                if (len(tk_list) == 0) or (tk_list[0] == ''):
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    op_idx_batch_list.append(
                        [int(op_str) for op_str in tk_list])
                    op_batch_list.append([
                        self.fields['cond_op'].vocab.stoi[op_str]
                        for op_str in tk_list
                    ])
            # -> (num_cond, batch)
            cond_op = v_eval(
                add_pad(
                    op_batch_list,
                    self.fields['cond_op'].vocab.stoi[table.IO.PAD_WORD]).t())
            cond_op_len = torch.LongTensor([len(it) for it in op_batch_list])
        # emb_op -> (num_cond, batch, emb_size)
        if self.model.opt.layout_encode == 'rnn':
            emb_op = table.Models.encode_unsorted_batch(
                self.model.lay_encoder, cond_op, cond_op_len.clamp(min=1))
        else:
            emb_op = self.model.cond_embedding(cond_op)

        # (2) decoding
        self.model.cond_decoder.attn.applyMaskBySeqBatch(q)
        cond_state = self.model.cond_decoder.init_decoder_state(q_all, q_enc)
        cond_col_list, cond_span_l_list, cond_span_r_list = [], [], []
        t = 0
        need_back_track = [False] * batch_size
        for emb_op_t in emb_op:
            t += 1
            emb_op_t = emb_op_t.unsqueeze(0)
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_op_t, q_all, cond_state)
            # cond col -> (1, batch)
            cond_col_all = self.model.cond_col_match(cond_context, tbl_enc,
                                                     tbl_mask).data
            cond_col = argmax(cond_col_all)
            # add to this after beam search: cond_col_list.append(cpu_vector(cond_col))
            # emb_col
            batch_index = torch.LongTensor(
                range(batch_size)).unsqueeze_(0).cuda().expand(
                    cond_col.size(0), cond_col.size(1))
            emb_col = tbl_enc[cond_col, batch_index, :]
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_col, q_all, cond_state)
            # cond span
            q_mask = v_eval(
                q.data.eq(self.model.pad_word_index).transpose(0, 1))
            cond_span_l_batch_all = self.model.cond_span_l_match(
                cond_context, q_all, q_mask).data
            cond_span_l_batch = argmax(cond_span_l_batch_all)
            # add to this after beam search: cond_span_l_list.append(cpu_vector(cond_span_l))
            # emb_span_l: (1, batch, hidden_size)
            emb_span_l = q_all[cond_span_l_batch, batch_index, :]
            cond_span_r_batch = argmax(
                self.model.cond_span_r_match(cond_context, q_all, q_mask,
                                             emb_span_l).data)
            # add to this after beam search: cond_span_r_list.append(cpu_vector(cond_span_r))
            if self.opt.beam_search:
                # for now, just fall back through the next-best column candidates
                k = min(self.opt.beam_size, cond_col_all.size(2))
                top_col_idx = cond_col_all.topk(k)[1]
                for b in range(batch_size):
                    if t > len(op_idx_batch_list[b]) or need_back_track[b]:
                        continue
                    idx = indices[b]
                    agg = agg_pred[b]
                    sel = sel_pred[b]
                    cond = []
                    for i in range(t):
                        op = op_idx_batch_list[b][i]
                        if i < t - 1:
                            col = cond_col_list[i][b]
                            span_l = cond_span_l_list[i][b]
                            span_r = cond_span_r_list[i][b]
                        else:
                            col = cond_col[0, b]
                            span_l = cond_span_l_batch[0, b]
                            span_r = cond_span_r_batch[0, b]
                        cond.append((col, op, (span_l, span_r)))
                    pred = ParseResult(idx, agg, sel, cond)
                    pred.eval(js_list[idx], sql_list[idx], engine)
                    n_test = 0
                    while (pred.exception_raised
                           and n_test < top_col_idx.size(2) - 1):
                        n_test += 1
                        if n_test > self.opt.beam_size:
                            need_back_track[b] = True
                            break
                        cond_col[0, b] = top_col_idx[0, b, n_test]
                        emb_col = tbl_enc[cond_col, batch_index, :]
                        cond_context, cond_state, _ = self.model.cond_decoder(
                            emb_col, q_all, cond_state)
                        # cond span
                        q_mask = v_eval(
                            q.data.eq(self.model.pad_word_index).transpose(
                                0, 1))
                        cond_span_l_batch_all = self.model.cond_span_l_match(
                            cond_context, q_all, q_mask).data
                        cond_span_l_batch = argmax(cond_span_l_batch_all)
                        # emb_span_l: (1, batch, hidden_size)
                        emb_span_l = q_all[cond_span_l_batch, batch_index, :]
                        cond_span_r_batch = argmax(
                            self.model.cond_span_r_match(
                                cond_context, q_all, q_mask, emb_span_l).data)
                        # run the new query over database
                        col = cond_col[0, b]
                        span_l = cond_span_l_batch[0, b]
                        span_r = cond_span_r_batch[0, b]
                        cond.pop()
                        cond.append((col, op, (span_l, span_r)))
                        pred = ParseResult(idx, agg, sel, cond)
                        pred.eval(js_list[idx], sql_list[idx], engine)
            cond_col_list.append(cpu_vector(cond_col))
            cond_span_l_list.append(cpu_vector(cond_span_l_batch))
            cond_span_r_list.append(cpu_vector(cond_span_r_batch))
            # emb_span_r: (1, batch, hidden_size)
            emb_span_r = q_all[cond_span_r_batch, batch_index, :]
            emb_span = self.model.span_merge(
                torch.cat([emb_span_l, emb_span_r], 2))
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_span, q_all, cond_state)

        # (3) recover output
        r_list = []
        for b in range(batch_size):
            idx = indices[b]
            agg = agg_pred[b]
            sel = sel_pred[b]
            cond = []
            for i in range(len(op_batch_list[b])):
                col = cond_col_list[i][b]
                op = op_idx_batch_list[b][i]
                span_l = cond_span_l_list[i][b]
                span_r = cond_span_r_list[i][b]
                cond.append((col, op, (span_l, span_r)))
            r_list.append(ParseResult(idx, agg, sel, cond))

        return r_list
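
A note on Example #5: its beam_search branch is an execution-guided back-off. After decoding each condition, the predicted query is evaluated against the database (via DBEngine and ParseResult.eval), and while it raises, the next-best column from the top-k list is substituted and the span pointers are re-decoded. Distilled to its control flow, with illustrative names (build_pred and runs_ok are not from the original code):

def pick_executable_column(candidate_cols, build_pred, runs_ok, beam_size):
    # execution-guided back-off: return the first candidate column whose
    # resulting query executes cleanly; give up after beam_size attempts
    pred = None
    for n_test, col in enumerate(candidate_cols):
        if n_test >= beam_size:
            break
        pred = build_pred(col)
        if runs_ok(pred):
            break
    return pred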