def count_where_accuracy(score_span_l, score_span_r, score_col, gold_span_ls, gold_span_rs, gold_cols):
    """Exact-match accuracy of predicted WHERE clauses.

    A sample counts as matched only when, at every position whose gold
    column is not -1 (the padding value), the predicted column, left span
    and right span all equal the gold values.

    Returns:
        (exact_matched, batch_size): number of fully matched samples and
        the total number of samples.
    """
    # Predictions come in (len, batch); flip to (batch, len) like the golds.
    pred_ls = argmax(score_span_l).transpose(0, 1)
    pred_rs = argmax(score_span_r).transpose(0, 1)
    pred_c = argmax(score_col).transpose(0, 1)
    gold_ls = gold_span_ls.transpose(0, 1)
    gold_rs = gold_span_rs.transpose(0, 1)
    gold_c = gold_cols.transpose(0, 1)

    num_samples = pred_ls.size(0)
    # Start from "all matched" and subtract on the first mismatch per sample.
    exact_matched = num_samples
    for b in range(num_samples):
        for i in range(pred_ls[b].size(0)):
            if gold_c[b][i] == -1:
                # Padding position: not part of the gold clause.
                continue
            mismatch = (gold_c[b][i] != pred_c[b][i]
                        or gold_ls[b][i] != pred_ls[b][i]
                        or gold_rs[b][i] != pred_rs[b][i])
            if mismatch:
                exact_matched -= 1
                break
    return exact_matched, num_samples
# Esempio n. 2
# 0
 def run_tgt_decoder(self, embeddings, tgt_mask_seq, lay_index_seq, lay_all,
                     decoder, classifier, q, q_all, q_enc, max_dec_len,
                     lay_skip_list, vocab, copy_to_ext, copy_to_tgt):
     """Greedy target-token decoding conditioned on a fixed layout sequence.

     At each step the decoder input is a blend of the previous predicted
     token's embedding and the layout state picked by ``lay_index_seq``,
     gated by ``tgt_mask_seq``.  Per-step top-20 candidates are collected.

     Returns:
         Stacked tensor of the per-step top-20 candidate results (on CPU),
         one entry per decoding step.
     """
     batch_size = q.size(1)
     # Mask attention over padded question positions.
     decoder.attn.applyMaskBySeqBatch(q)
     dec_list = []
     dec_state = decoder.init_decoder_state(q_all, q_enc)
     # Start every batch element from <BOS>.
     inp = torch.LongTensor(1, batch_size).fill_(table.IO.BOS).cuda()
     batch_index = torch.LongTensor(range(batch_size)).unsqueeze_(0).cuda()
     for i in range(min(max_dec_len, lay_index_seq.size(0))):
         # (1, batch)
         lay_index = lay_index_seq[i].unsqueeze(0)
         lay_select = lay_all[lay_index, batch_index, :]
         # Out-of-vocabulary predictions (copy indices) fall back to <UNK>
         # before the embedding lookup.
         inp.masked_fill_(inp.ge(len(vocab)), table.IO.UNK)
         tgt_inp_emb = embeddings(v_eval(inp))
         tgt_mask_expand = v_eval(tgt_mask_seq[i].unsqueeze(0).unsqueeze(
             2).expand_as(tgt_inp_emb))
         # Blend: token embedding where mask==1, layout state where mask==0.
         inp = tgt_inp_emb.mul(tgt_mask_expand) + \
             lay_select.mul(1 - tgt_mask_expand)
         parent_index = None
         dec_all, dec_state, attn_scores, dec_rnn_output, concat_c = decoder(
             inp, q_all, dec_state, parent_index)
         dec_out = classifier(dec_all, dec_rnn_output, concat_c,
                              attn_scores, copy_to_ext, copy_to_tgt)
         # no <unk> in decoding
         dec_out.data[:, :, table.IO.UNK] = -float('inf')
         # Greedy choice feeds the next step; top-20 kept for later use.
         inp = argmax(dec_out.data)
         topk_cpu = topk(dec_out.data.view(batch_size, -1), 20).cpu()
         dec_list.append(topk_cpu)
     return torch.stack(dec_list, 0)
# Esempio n. 3
# 0
 def run_lay_decoder(self, decoder, classifier, q, q_all, q_enc,
                     max_dec_len, vocab_mask, vocab):
     """Greedy layout decoding: feed each argmax prediction back as input.

     Returns:
         Stacked (max_dec_len, ...) tensor of per-step predicted token
         indices (on CPU).
     """
     batch_size = q.size(1)
     # Mask attention over padded question positions.
     decoder.attn.applyMaskBySeqBatch(q)
     dec_list = []
     dec_state = decoder.init_decoder_state(q_all, q_enc)
     # Start every batch element from <BOS>.
     inp = torch.LongTensor(1, batch_size).fill_(table.IO.BOS).cuda()
     for i in range(max_dec_len):
         inp = v_eval(inp)
         parent_index = None
         dec_all, dec_state, _, _, _ = decoder(inp, q_all, dec_state,
                                               parent_index)
         dec_all = dec_all.view(batch_size, -1)
         dec_out = classifier(dec_all)
         dec_out = dec_out.view(1, batch_size, -1)
         if vocab_mask is not None:
             # Force masked (non-special) vocabulary entries to -inf so
             # they can never be selected by the argmax below.
             dec_out_part = dec_out[:, :, len(table.IO.special_token_list):]
             dec_out_part.masked_fill_(vocab_mask, -float('inf'))
             # dec_out_part.masked_scatter_(vocab_mask, dec_out_part[vocab_mask].add(-math.log(1000)))
         inp = argmax(dec_out)
         # topk = [vocab.itos[idx] for idx in dec_out[0, 0, :].topk(10, dim=0)[1]]
         # print(topk)
         inp_cpu = cpu_vector(inp)
         dec_list.append(inp_cpu)
     return torch.stack(dec_list, 0)
# Esempio n. 4
# 0
    def recover_tgt(self, tgt):
        """Convert target predictions into whitespace-joined token strings.

        Accepts either raw scores (rank > 2, argmax'd here) or an index
        tensor of shape (len, batch).  Indices past the target vocabulary
        are looked up in the copy-to-ext vocabulary; decoding stops at the
        first EOS token.  Returns one string per batch element.
        """
        dec = argmax(tgt).cpu() if len(tgt.size()) > 2 else tgt
        vocab_tgt = self.fields['tgt'].vocab
        vocab_ext = self.fields['copy_to_ext'].vocab
        sent_len = dec.size(0)
        tgt_list = []
        for b in range(dec.size(1)):
            tokens = []
            for i in range(sent_len):
                idx = dec[i, b]
                if idx < len(vocab_tgt):
                    tok = vocab_tgt.itos[idx]
                else:
                    # Copy-mechanism index: offset into the extended vocab.
                    tok = vocab_ext.itos[idx - len(vocab_tgt)]
                if tok == table.IO.EOS_WORD:
                    break
                tokens.append(tok)
            tgt_list.append(" ".join(tokens))
        return tgt_list
# Esempio n. 5
# 0
def count_accuracy(scores, target, mask=None, row=False):
    """Count correct predictions of argmax(scores) against target.

    Args:
        scores: score tensor; project-level ``argmax`` gives predictions.
        target: gold label tensor matching the predictions' shape.
        mask: optional tensor marking positions to skip (value 1 = masked).
        row: when True (with a mask), a row is correct only if every
            unmasked position in it matches (product over dim 0).

    Returns:
        (m_correct, num_all): correctness tensor and the number of scored
        positions (a plain int in every branch).
    """
    pred = argmax(scores)
    if mask is None:
        m_correct = pred.eq(target)
        num_all = m_correct.numel()
    elif row:
        # masked_fill_ requires a bool mask in recent PyTorch; masked slots
        # are treated as correct so they don't break the row product.
        # (Same fix as the other count_accuracy variant in this file.)
        m_correct = pred.eq(target).masked_fill_(
            mask.type(torch.bool), 1).prod(0, keepdim=False)
        num_all = m_correct.numel()
    else:
        non_mask = mask.ne(1).type(torch.bool)
        m_correct = pred.eq(target).masked_select(non_mask)
        # .item() yields an int, consistent with numel() in other branches.
        num_all = non_mask.sum().item()
    return (m_correct, num_all)
 def run_tgt_decoder(self, embeddings, tgt_mask_seq, lay_index_seq, lay_all,
                     decoder, classifier, q, q_all, q_enc, max_dec_len,
                     lay_skip_list, vocab):
     """Greedy target decoding with optional parent feeding.

     Blends previous-token embeddings with layout states per
     ``tgt_mask_seq``; depending on ``self.model.opt.parent_feed`` a parent
     state is concatenated at the decoder input or output.  Steps whose
     layout skip token is RIG_WORD are forced to predict ``')'``.

     Returns:
         Stacked tensor of per-step predicted token indices (on CPU).
     """
     batch_size = q.size(1)
     # Mask attention over padded question positions.
     decoder.attn.applyMaskBySeqBatch(q)
     dec_list = []
     dec_state = decoder.init_decoder_state(q_all, q_enc)
     # Start every batch element from <BOS>.
     inp = torch.LongTensor(1, batch_size).fill_(table.IO.BOS).cuda()
     batch_index = torch.LongTensor(range(batch_size)).unsqueeze_(0).cuda()
     if self.model.opt.parent_feed in ('input', 'output'):
         parent_list = self._init_parent_list(decoder, q_enc, batch_size)
     for i in range(min(max_dec_len, lay_index_seq.size(0))):
         # (1, batch)
         lay_index = lay_index_seq[i].unsqueeze(0)
         lay_select = lay_all[lay_index, batch_index, :]
         tgt_inp_emb = embeddings(v_eval(inp))
         tgt_mask_expand = v_eval(tgt_mask_seq[i].unsqueeze(0).unsqueeze(
             2).expand_as(tgt_inp_emb))
         # Blend: token embedding where mask==1, layout state where mask==0.
         inp = tgt_inp_emb.mul(tgt_mask_expand) + \
             lay_select.mul(1 - tgt_mask_expand)
         if self.model.opt.parent_feed == 'input':
             parent_index = self._cat_parent_feed_input(
                 parent_list, batch_size)
         else:
             parent_index = None
         dec_all, dec_state, _, dec_rnn_output = decoder(
             inp, q_all, dec_state, parent_index)
         if self.model.opt.parent_feed == 'output':
             dec_all = self._cat_parent_feed_output(dec_all, parent_list,
                                                    batch_size)
         dec_all = dec_all.view(batch_size, -1)
         dec_out = classifier(dec_all)
         dec_out = dec_out.view(1, batch_size, -1)
         inp = argmax(dec_out.data)
         # RIG_WORD -> ')'
         rig_mask = []
         for b in range(batch_size):
             # Layout skip token for this step (None past the layout end).
             tk = lay_skip_list[b][i] if i < len(lay_skip_list[b]) else None
             rig_mask.append(1 if tk in (table.IO.RIG_WORD, ) else 0)
         # Overwrite the prediction with ')' wherever the layout says so.
         inp.masked_fill_(
             torch.ByteTensor(rig_mask).unsqueeze_(0).cuda(),
             vocab.stoi[')'])
         inp_cpu = cpu_vector(inp)
         dec_list.append(inp_cpu)
         if self.model.opt.parent_feed in ('input', 'output'):
             self._update_parent_list(i, parent_list, dec_rnn_output,
                                      inp_cpu, lay_skip_list, vocab,
                                      batch_size)
     return torch.stack(dec_list, 0)
def count_accuracy(scores, target, mask=None, row=False):
    """Compare argmax(scores) with target and count the matches.

    With no mask, every element is scored.  With ``row=True`` the masked
    positions are forced correct and a row matches only if its whole dim-0
    product is 1; otherwise only unmasked (mask != 1) elements are scored.

    Returns:
        (m_correct, num_all): a LongTensor of correctness flags (moved to
        CUDA when available) and the number of scored positions.
    """
    pred = argmax(scores)
    matches = pred.eq(target)
    if mask is None:
        m_correct = matches
        num_all = matches.numel()
    elif row:
        bool_mask = mask.type(torch.bool)
        # Masked slots count as correct so they don't break the row product.
        m_correct = matches.masked_fill_(bool_mask, 1).prod(0, keepdim=False)
        num_all = m_correct.numel()
    else:
        keep = mask.ne(1).type(torch.bool)
        m_correct = matches.masked_select(keep)
        num_all = keep.sum().item()

    m_correct = m_correct.type(torch.LongTensor)
    if torch.cuda.is_available():
        m_correct = m_correct.cuda()
    return (m_correct, num_all)
# Esempio n. 8
# 0
    def recover_lay(self, l):
        """Turn layout predictions into whitespace-joined token strings.

        ``l`` is either a score tensor (rank > 2, argmax'd here) or an
        index tensor of shape (len, batch).  Tokens come from the 'lay'
        vocabulary and decoding stops before the first EOS token.  Returns
        one layout string per batch element.
        """
        lay_dec = argmax(l).cpu() if len(l.size()) > 2 else l
        vocab = self.fields['lay'].vocab
        max_len = lay_dec.size(0)
        lay_list = []
        for b in range(lay_dec.size(1)):
            tokens = []
            for i in range(max_len):
                tok = vocab.itos[lay_dec[i, b]]
                if tok == table.IO.EOS_WORD:
                    # EOS terminates the layout; it is not emitted.
                    break
                tokens.append(tok)
            lay_list.append(" ".join(tokens))
        return lay_list
def count_condition_value_F1(scores1, golden_scores1, target_ls1, target_rs1):
    """Exact-match and span-level precision/recall/F1 for BIO value tagging.

    Predicted tags (0 = begin, 1 = inside, other = outside) are decoded
    into (left, right) spans and matched against the gold span boundaries
    ``target_ls1``/``target_rs1`` (-1 = padding).

    Returns:
        ((exact_matched, batch), (precision, 1), (recall, 1),
         (2*p*r, p + r + eps)) — the last pair are the F1 numerator and
        denominator.
    """
    preds = argmax(scores1)

    # (len, batch) -> (batch, len) for per-sample iteration.
    preds = preds.transpose(0,1)
    target_ls=target_ls1.transpose(0,1)
    target_rs=target_rs1.transpose(0,1)
    golden_scores=golden_scores1.transpose(0,1)

    # print(type(preds),preds.size(),target_ls.size())

    total_p=0
    total_r=0
    matched = 0
    exact_matched=0
    for sample_id in range(preds.size(0)):
        pred=preds[sample_id]
        #print(pred)
        golden_score=golden_scores[sample_id]
        #print(g)
        target_l=target_ls[sample_id]
        #print(target_l)
        target_r=target_rs[sample_id]
        #print(target_r)
        # Exact match: every non-padding (gold != -1) tag must agree.
        exact_matched+=1
        for i in range(pred.size(0)):
            if pred[i]!=golden_score[i] and golden_score[i]!=-1:
                exact_matched-=1
                break
        # Decode predicted BIO tags into [l, r] spans.  l == 0 is used as
        # the "no open span" sentinel.
        # NOTE(review): this conflates "no open span" with a span starting
        # at index 0 — a span opened at position 0 is never emitted.
        # Presumably position 0 can never start a span here; verify.
        cond_span_lr = []
        l=0
        r=0
        for i in range(pred.size(0)):
            if pred[i]==0:
                # Begin tag: flush any open span, then open a new one.
                if l!=0:
                    cond_span_lr.append((l,r))
                l=i
                r=i
            elif pred[i]==1:
                # Inside tag: extend the open span.
                r=i
            else:
                # Outside tag: flush and reset.
                if l!=0:
                    cond_span_lr.append((l,r))
                l=0
                r=0
        if l != 0:
            cond_span_lr.append((l, r))

        # Span-level matching against gold boundaries.
        for l,r in cond_span_lr:
            for i in range(target_l.size(0)):
                if l==target_l[i] and r==target_r[i]:
                    matched+=1
        for i in range(target_l.size(0)):
            if target_l[i]!=-1:
                total_r+=1
        total_p+=len(cond_span_lr)

    #print(matched,total_p,total_r)
    #if random.random()<0.01:
    #    print(preds[:10])
    #    print(golden_scores[:10])
    recall=1.0*matched/(total_r+1e-10)
    precision=1.0*matched/(total_p+1e-10)
    return (exact_matched,preds.size(0)),(precision,1),(recall,1),(recall*precision*2,(recall+precision+1e-10))
def count_condition_value_EM_column_op(scores1,scores_col1,scores_op1, golden_scores1, golden_scores_col1,golden_scores_op1, target_ls1, target_rs1, target_cols1):
    """Exact-match counts for BIO tags plus per-span column/op majority vote.

    A sample is exact-matched only when its BIO sequence matches the gold
    at every non-padding position AND, for every predicted span (begin tag
    0 followed by inside tags 1), the majority-voted column and operator
    agree with the gold labels at the span start.

    Returns:
        ((exact_matched, batch), (exact_matched_col, batch),
         (exact_matched_op, batch), 0)
    """
    preds = argmax(scores1)
    preds_col = argmax(scores_col1)
    preds_op = argmax(scores_op1)

    # (len, batch) -> (batch, len) for per-sample iteration.
    preds = preds.transpose(0,1)
    preds_col = preds_col.transpose(0, 1)
    preds_op = preds_op.transpose(0, 1)

    target_ls = target_ls1.transpose(0,1)
    target_rs = target_rs1.transpose(0,1)
    golden_scores = golden_scores1.transpose(0,1)
    golden_scores_col = golden_scores_col1.transpose(0, 1)
    golden_scores_op = golden_scores_op1.transpose(0, 1)
    target_cols = target_cols1.transpose(0,1)

    total_p=0
    total_r=0

    exact_matched=0
    exact_matched_op=0
    exact_matched_col=0
    for sample_id in range(preds.size(0)):
        pred=preds[sample_id]
        pred_col=preds_col[sample_id]
        pred_op = preds_op[sample_id]
        golden_score=golden_scores[sample_id]
        golden_score_col = golden_scores_col[sample_id]
        golden_score_op = golden_scores_op[sample_id]

        # Assume matched; subtract on the first mismatch of each kind.
        exact_matched += 1
        exact_matched_op += 1
        exact_matched_col += 1
        # BIO check: a tag mismatch at any non-padding position fails all
        # three metrics at once.
        BIO_not_match = False
        for i in range(pred.size(0)):
            if pred[i] != golden_score[i] and golden_score[i] != -1:
                exact_matched-=1
                exact_matched_op-=1
                exact_matched_col-=1
                BIO_not_match = True
                break

        if BIO_not_match == False:
            # Column check: for each span starting at a begin tag (0),
            # majority-vote the per-token column predictions across the
            # span and compare with the gold column at the span start.
            col_not_match = False
            for i in range(pred.size(0)):
                if pred[i]==0:
                    column_cnt = []
                    for j in range(torch.max(pred_col) + 2):
                        column_cnt.append(0)
                    column_cnt[pred_col[i]] = 1
                    for j in range(i+1,pred.size(0)):
                        if pred[j]!=1:
                            # Span ends at the first non-inside tag.
                            break
                        column_cnt[pred_col[j]] += 1
                    max_cnt=0
                    argmax1=pred_col[i]
                    for j in range(torch.max(pred_col)+2):
                        if column_cnt[j]>max_cnt:
                            max_cnt=column_cnt[j]
                            argmax1=j
                    if argmax1!=golden_score_col[i]:
                        exact_matched_col-=1
                        col_not_match=True
                        break

            # Operator check: same majority vote over the 3 operator ids.
            op_not_match = False
            for i in range(pred.size(0)):
                if pred[i]==0:
                    op_cnt = []
                    for j in range(3):
                        op_cnt.append(0)
                    op_cnt[pred_op[i]] = 1
                    for j in range(i+1,pred.size(0)):
                        if pred[j]!=1:
                            break
                        op_cnt[pred_op[j]] += 1
                    max_cnt=0
                    argmax1=pred_op[i]
                    for j in range(3):
                        if op_cnt[j]>max_cnt:
                            max_cnt=op_cnt[j]
                            argmax1=j
                    if argmax1!=golden_score_op[i]:
                        exact_matched_op-=1
                        op_not_match=True
                        break
            # Overall exact match requires column AND operator to agree too.
            if op_not_match or col_not_match:
                exact_matched-=1

    return (exact_matched,preds.size(0)),(exact_matched_col,preds.size(0)),(exact_matched_op,preds.size(0)),0
# Esempio n. 11
# 0
    def translate(self, batch):
        """Translate a batch of questions into structured query parses.

        Pipeline: encode question/table, tag BIO/operator/column labels,
        predict aggregation/selection/layout, then sequentially decode each
        condition's (column, span_l, span_r) with the condition decoder.

        Returns:
            A list with one ParseResult per batch element:
            (idx, agg, sel, cond, BIO, BIO_col) where cond is a list of
            (col, op, (span_l, span_r)) triples.
        """
        q, q_len = batch.src
        tbl, tbl_len = batch.tbl
        ent, tbl_split, tbl_mask = batch.ent, batch.tbl_split, batch.tbl_mask

        # encoding

        q_enc, q_all, tbl_enc, q_ht, batch_size = self.model.enc(
            q, q_len, ent, batch.type, tbl, tbl_len, tbl_split, tbl_mask
        )  #query, query length, table, table length, table split, table mask

        # Per-token BIO+operator tagging over the question (log-softmax and
        # its exp, both reshaped back to (len, batch, classes)).
        BIO_op_out = self.model.BIO_op_classifier(q_all)
        tsp_q = BIO_op_out.size(0)
        bsz = BIO_op_out.size(1)
        BIO_op_out = BIO_op_out.view(-1, BIO_op_out.size(2))
        BIO_op_out = F.log_softmax(BIO_op_out, dim=-1)
        BIO_op_out_sf = torch.exp(BIO_op_out)
        BIO_op_out = BIO_op_out.view(tsp_q, bsz, -1)
        BIO_op_out_sf = BIO_op_out_sf.view(tsp_q, bsz, -1)

        # if fff == 1:
        #    print(BIO_op_out_sf.transpose(0,1)[0])
        #    print(BIO_op_out.transpose(0, 1)[0])

        # Plain BIO tagging over the question tokens.
        BIO_out = self.model.BIO_classifier(q_all)
        BIO_out = BIO_out.view(-1, BIO_out.size(2))
        BIO_out = F.log_softmax(BIO_out, dim=-1)
        BIO_out_sf = torch.exp(BIO_out)
        BIO_out = BIO_out.view(tsp_q, bsz, -1)
        BIO_out_sf = BIO_out_sf.view(tsp_q, bsz, -1)
        # if fff == 1:
        #    print(BIO_out_sf.transpose(0,1)[0])
        #    print(BIO_out.transpose(0, 1)[0])

        # Per-token column matching against the encoded table headers.
        BIO_col_out = self.model.label_col_match(q_all, tbl_enc, tbl_mask)
        # if fff == 1:
        #    print(BIO_col_out.size())
        #    print(BIO_col_out.transpose(0, 1)[0])
        BIO_col_out = BIO_col_out.view(-1, BIO_col_out.size(2))
        BIO_col_out = F.log_softmax(BIO_col_out, dim=-1)
        BIO_col_out_sf = torch.exp(BIO_col_out)
        BIO_col_out = BIO_col_out.view(tsp_q, bsz, -1)
        BIO_col_out_sf = BIO_col_out_sf.view(tsp_q, bsz, -1)

        BIO_pred = argmax(BIO_out_sf.data).transpose(0, 1)
        BIO_col_pred = argmax(BIO_col_out_sf.data).transpose(0, 1)
        # Outside (tag 2) tokens get no column: mark with -1.
        for i in range(BIO_pred.size(0)):
            for j in range(BIO_pred.size(1)):
                if BIO_pred[i][j] == 2:
                    BIO_col_pred[i][j] = -1

        # (1) decoding
        q_self_encode = self.model.agg_self_attention(q_all, q_len)  #q_ht
        q_self_encode_layout = self.model.lay_self_attention(q_all,
                                                             q_len)  #q_ht
        agg_pred = cpu_vector(
            argmax(self.model.agg_classifier(q_self_encode).data))
        sel_out = self.model.sel_match(q_self_encode,
                                       tbl_enc,
                                       tbl_mask,
                                       select=True)  # select column
        sel_pred = cpu_vector(
            argmax(
                self.model.sel_match(q_self_encode,
                                     tbl_enc,
                                     tbl_mask,
                                     select=True).data))
        lay_pred = argmax(self.model.lay_classifier(q_self_encode_layout).data)
        # get layout op tokens
        op_batch_list = []
        op_idx_batch_list = []
        if self.opt.gold_layout:
            # Use the gold layout/operators instead of the predicted ones.
            lay_pred = batch.lay.data
            cond_op, cond_op_len = batch.cond_op
            cond_op_len_list = cond_op_len.view(-1).tolist()
            for i, len_it in enumerate(cond_op_len_list):
                if len_it == 0:
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    idx_list = cond_op.data[0:len_it,
                                            i].contiguous().view(-1).tolist()
                    op_idx_batch_list.append([
                        int(self.fields['cond_op'].vocab.itos[it])
                        for it in idx_list
                    ])
                    op_batch_list.append(idx_list)
        else:
            # Expand each predicted layout label into its operator tokens.
            lay_batch_list = lay_pred.view(-1).tolist()
            for lay_it in lay_batch_list:
                tk_list = self.fields['lay'].vocab.itos[lay_it].split(' ')
                if (len(tk_list) == 0) or (tk_list[0] == ''):
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    op_idx_batch_list.append(
                        [int(op_str) for op_str in tk_list])
                    op_batch_list.append([
                        self.fields['cond_op'].vocab.stoi[op_str]
                        for op_str in tk_list
                    ])
            # -> (num_cond, batch)
            cond_op = v_eval(
                add_pad(
                    op_batch_list,
                    self.fields['cond_op'].vocab.stoi[table.IO.PAD_WORD]).t())
            cond_op_len = torch.LongTensor([len(it) for it in op_batch_list])
        # emb_op -> (num_cond, batch, emb_size)
        if self.model.opt.layout_encode == 'rnn':
            emb_op = table.Models.encode_unsorted_batch(
                self.model.lay_encoder, cond_op, cond_op_len.clamp(min=1))
        else:
            emb_op = self.model.cond_embedding(cond_op)

        # (2) decoding
        self.model.cond_decoder.attn.applyMaskBySeqBatch(q)
        cond_state = self.model.cond_decoder.init_decoder_state(q_all, q_enc)
        cond_col_list, cond_span_l_list, cond_span_r_list = [], [], []
        # One condition per operator embedding: decode col, then span_l,
        # then span_r, feeding each choice's embedding back to the decoder.
        for emb_op_t in emb_op:
            emb_op_t = emb_op_t.unsqueeze(0)
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_op_t, q_all, cond_state)
            #print(cond_context.size())
            #cond_context = self.model.decode_softattention(cond_context, q_all, q_len)
            #print(cond_context.size())

            # cond col -> (1, batch)
            cond_col = argmax(
                self.model.cond_col_match(cond_context, tbl_enc,
                                          tbl_mask).data)
            cond_col_list.append(cpu_vector(cond_col))
            # emb_col
            batch_index = torch.LongTensor(
                range(batch_size)).unsqueeze_(0).cuda().expand(
                    cond_col.size(0), cond_col.size(1))
            emb_col = tbl_enc[cond_col, batch_index, :]
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_col, q_all, cond_state)

            # cond span
            q_mask = v_eval(
                q.data.eq(self.model.pad_word_index).transpose(0, 1))
            cond_span_l = argmax(
                self.model.cond_span_l_match(cond_context, q_all, q_mask).data)
            cond_span_l_list.append(cpu_vector(cond_span_l))
            # emb_span_l: (1, batch, hidden_size)
            emb_span_l = q_all[cond_span_l, batch_index, :]
            cond_span_r = argmax(
                self.model.cond_span_r_match(cond_context,
                                             q_all,
                                             q_mask,
                                             emb_span_l=emb_span_l).data)
            cond_span_r_list.append(cpu_vector(cond_span_r))
            # emb_span_r: (1, batch, hidden_size)
            emb_span_r = q_all[cond_span_r, batch_index, :]

            emb_span = self.model.span_merge(
                torch.cat([emb_span_l, emb_span_r], 2))

            #            mask = torch.zeros([cond_col.size(0), q_all.size(0), q_all.size(1)])  # (num_cond,tsp,bsz)
            #            for j in range(q_all.size(1)):
            #                for i in range(cond_col.size(0)):
            #                    for k in range(cond_span_l[i][j], cond_span_r[i][j] + 1):
            #                        mask[i][k][j] = 1

            #            mask = mask.unsqueeze_(3)  # .expand(cond_col.size(0),q_all.size(0),q_all.size(1),q_all.size(2))

            #            emb_span = Variable(mask.cuda()) * torch.unsqueeze(q_all, 0)  # .expand_as(mask)  #(num_cond,tsp,bsz,hidden)
            #            emb_span = torch.mean(emb_span, dim=1)  # (num_cond,bsz,hidden)  mean pooling

            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_span, q_all, cond_state)

        # (3) recover output
        indices = cpu_vector(batch.indices.data)
        r_list = []
        for b in range(batch_size):
            idx = indices[b]
            agg = agg_pred[b]
            sel = sel_pred[b]
            BIO = BIO_pred[b]
            BIO_col = BIO_col_pred[b]
            cond = []
            for i in range(len(op_batch_list[b])):
                col = cond_col_list[i][b]
                op = op_idx_batch_list[b][i]
                span_l = cond_span_l_list[i][b]
                span_r = cond_span_r_list[i][b]
                cond.append((col, op, (span_l, span_r)))
            r_list.append(ParseResult(idx, agg, sel, cond, BIO, BIO_col))

        return r_list
# Esempio n. 12
# 0
    def translate(self, batch):
        """Translate a batch of questions into structured query parses.

        Encodes question and table, predicts aggregation, selected column
        and layout, then sequentially decodes each condition's
        (column, span_l, span_r) with the condition decoder.

        Returns:
            A list with one ParseResult per batch element:
            (idx, agg, sel, cond) where cond is a list of
            (col, op, (span_l, span_r)) triples.
        """
        q, q_len = batch.src
        tbl, tbl_len = batch.tbl
        ent, tbl_split, tbl_mask = batch.ent, batch.tbl_split, batch.tbl_mask

        # encoding
        q_enc, q_all, tbl_enc, q_ht, batch_size = self.model.enc(
            q, q_len, ent, tbl, tbl_len, tbl_split, tbl_mask)

        # (1) decoding
        agg_pred = cpu_vector(argmax(self.model.agg_classifier(q_ht).data))
        sel_pred = cpu_vector(
            argmax(self.model.sel_match(q_ht, tbl_enc, tbl_mask).data))
        lay_pred = argmax(self.model.lay_classifier(q_ht).data)
        # get layout op tokens
        op_batch_list = []
        op_idx_batch_list = []
        if self.opt.gold_layout:
            # Use the gold layout/operators instead of the predicted ones.
            lay_pred = batch.lay.data
            cond_op, cond_op_len = batch.cond_op
            cond_op_len_list = cond_op_len.view(-1).tolist()
            for i, len_it in enumerate(cond_op_len_list):
                if len_it == 0:
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    idx_list = cond_op.data[0:len_it,
                                            i].contiguous().view(-1).tolist()
                    op_idx_batch_list.append([
                        int(self.fields['cond_op'].vocab.itos[it])
                        for it in idx_list
                    ])
                    op_batch_list.append(idx_list)
        else:
            # Expand each predicted layout label into its operator tokens.
            lay_batch_list = lay_pred.view(-1).tolist()
            for lay_it in lay_batch_list:
                tk_list = self.fields['lay'].vocab.itos[lay_it].split(' ')
                if (len(tk_list) == 0) or (tk_list[0] == ''):
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    op_idx_batch_list.append(
                        [int(op_str) for op_str in tk_list])
                    op_batch_list.append([
                        self.fields['cond_op'].vocab.stoi[op_str]
                        for op_str in tk_list
                    ])
            # -> (num_cond, batch)
            cond_op = v_eval(
                add_pad(
                    op_batch_list,
                    self.fields['cond_op'].vocab.stoi[table.IO.PAD_WORD]).t())
            cond_op_len = torch.LongTensor([len(it) for it in op_batch_list])
        # emb_op -> (num_cond, batch, emb_size)
        if self.model.opt.layout_encode == 'rnn':
            emb_op = table.Models.encode_unsorted_batch(
                self.model.lay_encoder, cond_op, cond_op_len.clamp(min=1))
        else:
            emb_op = self.model.cond_embedding(cond_op)

        # (2) decoding
        self.model.cond_decoder.attn.applyMaskBySeqBatch(q)
        cond_state = self.model.cond_decoder.init_decoder_state(q_all, q_enc)
        cond_col_list, cond_span_l_list, cond_span_r_list = [], [], []
        # One condition per operator embedding: decode col, then span_l,
        # then span_r, feeding each choice's embedding back to the decoder.
        for emb_op_t in emb_op:
            emb_op_t = emb_op_t.unsqueeze(0)
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_op_t, q_all, cond_state)
            # cond col -> (1, batch)
            cond_col = argmax(
                self.model.cond_col_match(cond_context, tbl_enc,
                                          tbl_mask).data)
            cond_col_list.append(cpu_vector(cond_col))
            # emb_col
            batch_index = torch.LongTensor(
                range(batch_size)).unsqueeze_(0).cuda().expand(
                    cond_col.size(0), cond_col.size(1))
            emb_col = tbl_enc[cond_col, batch_index, :]
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_col, q_all, cond_state)
            # cond span
            q_mask = v_eval(
                q.data.eq(self.model.pad_word_index).transpose(0, 1))
            cond_span_l = argmax(
                self.model.cond_span_l_match(cond_context, q_all, q_mask).data)
            cond_span_l_list.append(cpu_vector(cond_span_l))
            # emb_span_l: (1, batch, hidden_size)
            emb_span_l = q_all[cond_span_l, batch_index, :]
            cond_span_r = argmax(
                self.model.cond_span_r_match(cond_context, q_all, q_mask,
                                             emb_span_l).data)
            cond_span_r_list.append(cpu_vector(cond_span_r))
            # emb_span_r: (1, batch, hidden_size)
            emb_span_r = q_all[cond_span_r, batch_index, :]
            emb_span = self.model.span_merge(
                torch.cat([emb_span_l, emb_span_r], 2))
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_span, q_all, cond_state)

        # (3) recover output
        indices = cpu_vector(batch.indices.data)
        r_list = []
        for b in range(batch_size):
            idx = indices[b]
            agg = agg_pred[b]
            sel = sel_pred[b]
            cond = []
            for i in range(len(op_batch_list[b])):
                col = cond_col_list[i][b]
                op = op_idx_batch_list[b][i]
                span_l = cond_span_l_list[i][b]
                span_r = cond_span_r_list[i][b]
                cond.append((col, op, (span_l, span_r)))
            r_list.append(ParseResult(idx, agg, sel, cond))

        return r_list
# Esempio n. 13
# 0
    def translate(self, batch, js_list=[], sql_list=[]):
        q, q_len = batch.src
        tbl, tbl_len = batch.tbl
        ent, tbl_split, tbl_mask = batch.ent, batch.tbl_split, batch.tbl_mask
        # encoding
        q_enc, q_all, tbl_enc, q_ht, batch_size = self.model.enc(
            q, q_len, ent, tbl, tbl_len, tbl_split, tbl_mask)

        # (1) decoding
        agg_pred = cpu_vector(argmax(self.model.agg_classifier(q_ht).data))
        sel_pred = cpu_vector(
            argmax(self.model.sel_match(q_ht, tbl_enc, tbl_mask).data))
        lay_pred = argmax(self.model.lay_classifier(q_ht).data)
        engine = DBEngine(self.opt.db_file)
        indices = cpu_vector(batch.indices.data)
        # get layout op tokens
        op_batch_list = []
        op_idx_batch_list = []
        if self.opt.gold_layout:
            # --- (1b) condition-operator recovery from GOLD layout ---
            # NOTE(review): this sits inside an if/else whose condition is above
            # this excerpt; presumably this branch is taken when gold layout
            # supervision is available in the batch -- confirm in the enclosing
            # method. Names such as op_idx_batch_list / op_batch_list / q /
            # q_all / tbl_enc / indices / agg_pred / sel_pred / engine are also
            # bound earlier in the method, outside this excerpt.
            lay_pred = batch.lay.data
            cond_op, cond_op_len = batch.cond_op
            cond_op_len_list = cond_op_len.view(-1).tolist()
            for i, len_it in enumerate(cond_op_len_list):
                if len_it == 0:
                    # sample i has no WHERE conditions
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    # first len_it operator vocab ids for sample i
                    # (cond_op appears to be (num_cond, batch) shaped, per the
                    # "-> (num_cond, batch)" comment below -- TODO confirm)
                    idx_list = cond_op.data[0:len_it,
                                            i].contiguous().view(-1).tolist()
                    # vocab itos entries are numeric strings -> int operator ids
                    op_idx_batch_list.append([
                        int(self.fields['cond_op'].vocab.itos[it])
                        for it in idx_list
                    ])
                    op_batch_list.append(idx_list)
        else:
            # --- (1b') condition-operator recovery from PREDICTED layout ---
            # Each 'lay' vocab entry is a space-separated string of operator id
            # tokens; split it back into per-condition operator ids.
            lay_batch_list = lay_pred.view(-1).tolist()
            for lay_it in lay_batch_list:
                tk_list = self.fields['lay'].vocab.itos[lay_it].split(' ')
                if (len(tk_list) == 0) or (tk_list[0] == ''):
                    # empty layout string -> no WHERE conditions
                    op_idx_batch_list.append([])
                    op_batch_list.append([])
                else:
                    # raw integer operator ids (as written in the layout string)
                    op_idx_batch_list.append(
                        [int(op_str) for op_str in tk_list])
                    # same operators mapped through the cond_op vocab for the
                    # embedding/encoder input below
                    op_batch_list.append([
                        self.fields['cond_op'].vocab.stoi[op_str]
                        for op_str in tk_list
                    ])
            # -> (num_cond, batch)
            cond_op = v_eval(
                add_pad(
                    op_batch_list,
                    self.fields['cond_op'].vocab.stoi[table.IO.PAD_WORD]).t())
            cond_op_len = torch.LongTensor([len(it) for it in op_batch_list])
        # emb_op -> (num_cond, batch, emb_size)
        # clamp(min=1) guards the RNN path against zero-length sequences
        # (samples with no conditions).
        if self.model.opt.layout_encode == 'rnn':
            emb_op = table.Models.encode_unsorted_batch(
                self.model.lay_encoder, cond_op, cond_op_len.clamp(min=1))
        else:
            emb_op = self.model.cond_embedding(cond_op)

        # (2) decoding: one decoder step group per condition slot. For each
        # slot we feed the operator embedding, predict a column, feed the
        # column embedding, then predict the left/right span endpoints over
        # the question tokens and feed the merged span embedding back in.
        self.model.cond_decoder.attn.applyMaskBySeqBatch(q)
        cond_state = self.model.cond_decoder.init_decoder_state(q_all, q_enc)
        cond_col_list, cond_span_l_list, cond_span_r_list = [], [], []
        t = 0  # 1-based index of the current condition slot (incremented first)
        need_back_track = [False] * batch_size
        for emb_op_t in emb_op:
            t += 1
            emb_op_t = emb_op_t.unsqueeze(0)
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_op_t, q_all, cond_state)
            # cond col -> (1, batch)
            cond_col_all = self.model.cond_col_match(cond_context, tbl_enc,
                                                     tbl_mask).data
            cond_col = argmax(cond_col_all)
            # add to this after beam search: cond_col_list.append(cpu_vector(cond_col))
            # emb_col: gather the chosen column's encoding per sample via
            # advanced indexing with (row, batch) index tensors
            batch_index = torch.LongTensor(
                range(batch_size)).unsqueeze_(0).cuda().expand(
                    cond_col.size(0), cond_col.size(1))
            emb_col = tbl_enc[cond_col, batch_index, :]
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_col, q_all, cond_state)
            # cond span: mask out question padding positions
            q_mask = v_eval(
                q.data.eq(self.model.pad_word_index).transpose(0, 1))
            cond_span_l_batch_all = self.model.cond_span_l_match(
                cond_context, q_all, q_mask).data
            cond_span_l_batch = argmax(cond_span_l_batch_all)
            # add to this after beam search: cond_span_l_list.append(cpu_vector(cond_span_l))
            # emb_span_l: (1, batch, hidden_size)
            emb_span_l = q_all[cond_span_l_batch, batch_index, :]
            # right endpoint is conditioned on the chosen left endpoint
            cond_span_r_batch = argmax(
                self.model.cond_span_r_match(cond_context, q_all, q_mask,
                                             emb_span_l).data)
            # add to this after beam search: cond_span_r_list.append(cpu_vector(cond_span_r))
            if self.opt.beam_search:
                # Execution-guided backtracking: if executing the partial SQL
                # raises, retry with the next-best column candidates.
                # for now just go through the next col in cond
                k = min(self.opt.beam_size, cond_col_all.size()[2])
                top_col_idx = cond_col_all.topk(k)[1]
                for b in range(batch_size):
                    # skip samples whose condition list is already exhausted or
                    # that have given up backtracking
                    if t > len(op_idx_batch_list[b]) or need_back_track[b]:
                        continue
                    idx = indices[b]
                    agg = agg_pred[b]
                    sel = sel_pred[b]
                    cond = []
                    # rebuild the conditions decoded so far (slots < t-1 come
                    # from the committed lists, the current slot from the fresh
                    # argmax predictions)
                    for i in range(t):
                        op = op_idx_batch_list[b][i]
                        if i < t - 1:
                            col = cond_col_list[i][b]
                            span_l = cond_span_l_list[i][b]
                            span_r = cond_span_r_list[i][b]
                        else:
                            col = cond_col[0, b]
                            span_l = cond_span_l_batch[0, b]
                            span_r = cond_span_r_batch[0, b]
                        cond.append((col, op, (span_l, span_r)))
                    pred = ParseResult(idx, agg, sel, cond)
                    pred.eval(js_list[idx], sql_list[idx], engine)
                    n_test = 0
                    # top_col_idx.size()[2] == k, so this tries at most the
                    # k-1 alternative columns after the argmax choice.
                    while pred.exception_raised and n_test < top_col_idx.size(
                    )[2] - 1:
                        n_test += 1
                        # NOTE(review): since k <= beam_size, this break only
                        # fires when n_test == beam_size == k exactly -- looks
                        # redundant with the while bound; confirm intent.
                        if n_test > self.opt.beam_size:
                            need_back_track[b] = True
                            break
                        # substitute candidate column n_test for sample b only
                        cond_col[0, b] = top_col_idx[0, b, n_test]
                        # NOTE(review): the re-decoding below runs on the WHOLE
                        # batch, so it advances the shared cond_state and
                        # overwrites cond_span_l/r predictions for every
                        # sample, not just b -- possible cross-sample
                        # interference; verify against the original repo.
                        emb_col = tbl_enc[cond_col, batch_index, :]
                        cond_context, cond_state, _ = self.model.cond_decoder(
                            emb_col, q_all, cond_state)
                        # cond span
                        q_mask = v_eval(
                            q.data.eq(self.model.pad_word_index).transpose(
                                0, 1))
                        cond_span_l_batch_all = self.model.cond_span_l_match(
                            cond_context, q_all, q_mask).data
                        cond_span_l_batch = argmax(cond_span_l_batch_all)
                        # emb_span_l: (1, batch, hidden_size)
                        emb_span_l = q_all[cond_span_l_batch, batch_index, :]
                        cond_span_r_batch = argmax(
                            self.model.cond_span_r_match(
                                cond_context, q_all, q_mask, emb_span_l).data)
                        # run the new query over database
                        col = cond_col[0, b]
                        span_l = cond_span_l_batch[0, b]
                        span_r = cond_span_r_batch[0, b]
                        # replace the current slot's condition with the retried one
                        cond.pop()
                        cond.append((col, op, (span_l, span_r)))
                        pred = ParseResult(idx, agg, sel, cond)
                        pred.eval(js_list[idx], sql_list[idx], engine)
            # commit this slot's (possibly backtracked) predictions
            cond_col_list.append(cpu_vector(cond_col))
            cond_span_l_list.append(cpu_vector(cond_span_l_batch))
            cond_span_r_list.append(cpu_vector(cond_span_r_batch))
            # emb_span_r: (1, batch, hidden_size)
            emb_span_r = q_all[cond_span_r_batch, batch_index, :]
            # merge left/right endpoint embeddings and feed back as the next
            # decoder input
            emb_span = self.model.span_merge(
                torch.cat([emb_span_l, emb_span_r], 2))
            cond_context, cond_state, _ = self.model.cond_decoder(
                emb_span, q_all, cond_state)

        # (3) recover output: assemble one ParseResult per sample, truncating
        # each condition list to that sample's own number of operators.
        r_list = []
        for b in range(batch_size):
            idx = indices[b]
            agg = agg_pred[b]
            sel = sel_pred[b]
            cond = []
            for i in range(len(op_batch_list[b])):
                col = cond_col_list[i][b]
                op = op_idx_batch_list[b][i]
                span_l = cond_span_l_list[i][b]
                span_r = cond_span_r_list[i][b]
                cond.append((col, op, (span_l, span_r)))
            r_list.append(ParseResult(idx, agg, sel, cond))

        return r_list