def count_where_accuracy(score_span_l, score_span_r, score_col,
                         gold_span_ls, gold_span_rs, gold_cols):
    """Count batch samples whose predicted WHERE conditions exactly match gold.

    A sample counts as an exact match only if, at every non-padded gold slot
    (gold column != -1), the predicted column, left span index and right span
    index all equal the gold values.

    Returns:
        (number of exactly matched samples, batch size)
    """
    # Decode score tensors to index predictions and put batch first.
    pred_ls = argmax(score_span_l).transpose(0, 1)
    pred_rs = argmax(score_span_r).transpose(0, 1)
    pred_cs = argmax(score_col).transpose(0, 1)
    gold_ls = gold_span_ls.transpose(0, 1)
    gold_rs = gold_span_rs.transpose(0, 1)
    gold_cs = gold_cols.transpose(0, 1)

    num_samples = pred_ls.size(0)
    exact_matched = 0
    for b in range(num_samples):
        pred_l, pred_r, pred_c = pred_ls[b], pred_rs[b], pred_cs[b]
        gold_l, gold_r, gold_c = gold_ls[b], gold_rs[b], gold_cs[b]
        sample_ok = True
        for i in range(pred_l.size(0)):
            if gold_c[i] == -1:
                # Padded condition slot: not part of the gold WHERE clause.
                continue
            if (gold_c[i] != pred_c[i]
                    or gold_l[i] != pred_l[i]
                    or gold_r[i] != pred_r[i]):
                sample_ok = False
                break
        if sample_ok:
            exact_matched += 1
    return exact_matched, num_samples
def run_tgt_decoder(self, embeddings, tgt_mask_seq, lay_index_seq, lay_all,
                    decoder, classifier, q, q_all, q_enc, max_dec_len,
                    lay_skip_list, vocab, copy_to_ext, copy_to_tgt):
    """Greedily decode the target sequence conditioned on a predicted layout.

    At each step the decoder input is a mask-weighted mix of the embedding of
    the previously predicted token and the layout encoding selected by
    ``lay_index_seq`` (presumably mask 1 = token embedding, 0 = layout vector
    -- TODO confirm against how tgt_mask_seq is built upstream). The
    classifier supports copy scoring via ``copy_to_ext`` / ``copy_to_tgt``.

    Returns a tensor stacking, per step, the CPU top-20 candidates from the
    classifier output.
    """
    batch_size = q.size(1)
    # Restrict decoder attention to real (non-pad) question positions.
    decoder.attn.applyMaskBySeqBatch(q)
    dec_list = []
    dec_state = decoder.init_decoder_state(q_all, q_enc)
    # Start every batch entry from <BOS>.
    inp = torch.LongTensor(1, batch_size).fill_(table.IO.BOS).cuda()
    batch_index = torch.LongTensor(range(batch_size)).unsqueeze_(0).cuda()
    for i in range(min(max_dec_len, lay_index_seq.size(0))):
        # (1, batch)
        lay_index = lay_index_seq[i].unsqueeze(0)
        # Layout encoding selected for this step, per batch entry.
        lay_select = lay_all[lay_index, batch_index, :]
        # Indices beyond the vocabulary (copied tokens) have no embedding
        # row; remap them to <UNK> before the lookup.
        inp.masked_fill_(inp.ge(len(vocab)), table.IO.UNK)
        tgt_inp_emb = embeddings(v_eval(inp))
        tgt_mask_expand = v_eval(tgt_mask_seq[i].unsqueeze(0).unsqueeze(
            2).expand_as(tgt_inp_emb))
        # Mix token embedding and layout vector according to the mask.
        inp = tgt_inp_emb.mul(tgt_mask_expand) + \
            lay_select.mul(1 - tgt_mask_expand)
        parent_index = None  # no parent feeding in this decoder variant
        dec_all, dec_state, attn_scores, dec_rnn_output, concat_c = decoder(
            inp, q_all, dec_state, parent_index)
        dec_out = classifier(dec_all, dec_rnn_output, concat_c, attn_scores,
                             copy_to_ext, copy_to_tgt)
        # no <unk> in decoding
        dec_out.data[:, :, table.IO.UNK] = -float('inf')
        # Greedy choice feeds the next decoding step.
        inp = argmax(dec_out.data)
        # Keep the top-20 candidates (on CPU) for later recovery/rescoring.
        topk_cpu = topk(dec_out.data.view(batch_size, -1), 20).cpu()
        dec_list.append(topk_cpu)
    return torch.stack(dec_list, 0)
def run_lay_decoder(self, decoder, classifier, q, q_all, q_enc, max_dec_len,
                    vocab_mask, vocab):
    """Greedily decode a layout token sequence for up to ``max_dec_len`` steps.

    ``vocab_mask`` (optional) disables part of the non-special vocabulary by
    setting the corresponding scores to -inf before the argmax. Returns the
    per-step greedy CPU index vectors stacked along a new leading dimension.
    """
    batch_size = q.size(1)
    # Restrict decoder attention to real (non-pad) question positions.
    decoder.attn.applyMaskBySeqBatch(q)
    dec_list = []
    dec_state = decoder.init_decoder_state(q_all, q_enc)
    # Start every batch entry from <BOS>.
    inp = torch.LongTensor(1, batch_size).fill_(table.IO.BOS).cuda()
    for i in range(max_dec_len):
        inp = v_eval(inp)
        parent_index = None  # no parent feeding in this decoder variant
        dec_all, dec_state, _, _, _ = decoder(inp, q_all, dec_state,
                                              parent_index)
        dec_all = dec_all.view(batch_size, -1)
        dec_out = classifier(dec_all)
        dec_out = dec_out.data.view(1, batch_size, -1)
        if vocab_mask is not None:
            # The mask only covers entries past the special tokens;
            # dec_out_part is a view, so the in-place fill updates dec_out.
            dec_out_part = dec_out[:, :, len(table.IO.special_token_list):]
            dec_out_part.masked_fill_(vocab_mask, -float('inf'))
            # dec_out_part.masked_scatter_(vocab_mask, dec_out_part[vocab_mask].add(-math.log(1000)))
        # Greedy choice feeds the next decoding step.
        inp = argmax(dec_out)
        # topk = [vocab.itos[idx] for idx in dec_out[0, 0, :].topk(10, dim=0)[1]]
        # print(topk)
        inp_cpu = cpu_vector(inp)
        dec_list.append(inp_cpu)
    return torch.stack(dec_list, 0)
def recover_tgt(self, tgt):
    """Convert decoded target indices into space-joined token strings.

    Accepts either a score tensor (rank > 2, argmax-decoded here) or an
    already-decoded (time, batch) index tensor. Indices past the target
    vocabulary are looked up in the copy-extension vocabulary. Decoding stops
    at the first EOS token (excluded). Returns one string per batch entry.
    """
    def _indices_to_string(idx_seq, vocab_tgt, vocab_copy_ext, max_len):
        # filter topk results using layout information
        tokens = []
        for step in range(max_len):
            idx = idx_seq[step]
            if idx < len(vocab_tgt):
                token = vocab_tgt.itos[idx]
            else:
                token = vocab_copy_ext.itos[idx - len(vocab_tgt)]
            if token == table.IO.EOS_WORD:
                break
            tokens.append(token)
        return " ".join(tokens)

    tgt_dec = argmax(tgt).cpu() if len(tgt.size()) > 2 else tgt
    seq_len = tgt_dec.size(0)
    decoded = []
    for b in range(tgt_dec.size(1)):
        column = [tgt_dec[i, b] for i in range(seq_len)]
        decoded.append(
            _indices_to_string(column, self.fields['tgt'].vocab,
                               self.fields['copy_to_ext'].vocab, seq_len))
    return decoded
def count_accuracy(scores, target, mask=None, row=False):
    """Count prediction correctness against ``target``.

    Args:
        scores: score tensor; argmax over the class dimension gives predictions.
        target: gold index tensor, same layout as the decoded predictions.
        mask: optional tensor marking padded positions (nonzero/1 = padded).
        row: if True (with mask), a whole column counts as correct only when
            every unmasked position in it is correct.

    Returns:
        (m_correct, num_all): element-/row-correctness tensor and the number
        of positions counted, as a plain int in every branch.
    """
    pred = argmax(scores)
    if mask is None:
        m_correct = pred.eq(target)
        num_all = m_correct.numel()
    elif row:
        # Cast to bool: modern PyTorch rejects uint8/int masks in
        # masked_fill_ (consistent with the other count_accuracy variant in
        # this codebase). Masked (padded) positions are forced to 1 so they
        # never break a row's product.
        m_correct = pred.eq(target).masked_fill_(
            mask.type(torch.bool), 1).prod(0, keepdim=False)
        num_all = m_correct.numel()
    else:
        non_mask = mask.ne(1).type(torch.bool)
        m_correct = pred.eq(target).masked_select(non_mask)
        # .item() keeps the return type a plain int, matching the branches
        # above (previously this branch returned a 0-dim tensor).
        num_all = non_mask.sum().item()
    return (m_correct, num_all)
def run_tgt_decoder(self, embeddings, tgt_mask_seq, lay_index_seq, lay_all,
                    decoder, classifier, q, q_all, q_enc, max_dec_len,
                    lay_skip_list, vocab):
    """Greedily decode the target sequence with optional parent feeding.

    Like the copy-variant decoder, the input at each step mixes the previous
    token's embedding with the layout encoding selected by ``lay_index_seq``.
    When ``opt.parent_feed`` is 'input' or 'output', a parent-state list is
    maintained and either fed into the decoder or concatenated onto its
    output. Positions whose layout token is RIG_WORD are forced to ')'.
    Returns the per-step greedy CPU index vectors stacked together.
    """
    batch_size = q.size(1)
    # Restrict decoder attention to real (non-pad) question positions.
    decoder.attn.applyMaskBySeqBatch(q)
    dec_list = []
    dec_state = decoder.init_decoder_state(q_all, q_enc)
    # Start every batch entry from <BOS>.
    inp = torch.LongTensor(1, batch_size).fill_(table.IO.BOS).cuda()
    batch_index = torch.LongTensor(range(batch_size)).unsqueeze_(0).cuda()
    if self.model.opt.parent_feed in ('input', 'output'):
        parent_list = self._init_parent_list(decoder, q_enc, batch_size)
    for i in range(min(max_dec_len, lay_index_seq.size(0))):
        # (1, batch)
        lay_index = lay_index_seq[i].unsqueeze(0)
        # Layout encoding selected for this step, per batch entry.
        lay_select = lay_all[lay_index, batch_index, :]
        tgt_inp_emb = embeddings(v_eval(inp))
        tgt_mask_expand = v_eval(tgt_mask_seq[i].unsqueeze(0).unsqueeze(
            2).expand_as(tgt_inp_emb))
        # Mix token embedding and layout vector according to the mask.
        inp = tgt_inp_emb.mul(tgt_mask_expand) + \
            lay_select.mul(1 - tgt_mask_expand)
        if self.model.opt.parent_feed == 'input':
            parent_index = self._cat_parent_feed_input(
                parent_list, batch_size)
        else:
            parent_index = None
        dec_all, dec_state, _, dec_rnn_output = decoder(
            inp, q_all, dec_state, parent_index)
        if self.model.opt.parent_feed == 'output':
            dec_all = self._cat_parent_feed_output(dec_all, parent_list,
                                                   batch_size)
        dec_all = dec_all.view(batch_size, -1)
        dec_out = classifier(dec_all)
        dec_out = dec_out.view(1, batch_size, -1)
        # Greedy choice feeds the next decoding step.
        inp = argmax(dec_out.data)
        # RIG_WORD -> ')'
        rig_mask = []
        for b in range(batch_size):
            # Layout token driving this step for batch entry b (None when
            # the layout is shorter than the decoding horizon).
            tk = lay_skip_list[b][i] if i < len(lay_skip_list[b]) else None
            rig_mask.append(1 if tk in (table.IO.RIG_WORD, ) else 0)
        # Force a closing parenthesis wherever the layout demands one.
        inp.masked_fill_(
            torch.ByteTensor(rig_mask).unsqueeze_(0).cuda(),
            vocab.stoi[')'])
        inp_cpu = cpu_vector(inp)
        dec_list.append(inp_cpu)
        if self.model.opt.parent_feed in ('input', 'output'):
            self._update_parent_list(i, parent_list, dec_rnn_output, inp_cpu,
                                     lay_skip_list, vocab, batch_size)
    return torch.stack(dec_list, 0)
def count_accuracy(scores, target, mask=None, row=False):
    """Count prediction correctness against ``target``.

    Args:
        scores: score tensor; argmax over the class dimension gives predictions.
        target: gold index tensor, same layout as the decoded predictions.
        mask: optional tensor marking padded positions (nonzero/1 = padded).
        row: if True (with mask), a whole column counts as correct only when
            every unmasked position in it is correct.

    Returns:
        (m_correct, num_all): correctness tensor (as cuda LongTensor when
        available) and the number of positions counted as a plain int.
    """
    pred = argmax(scores)
    if mask is None:
        m_correct = pred.eq(target)
        num_all = m_correct.numel()
    elif row:
        # bool cast keeps masked_fill_ valid on modern PyTorch; padded
        # positions are forced to 1 so they never break a row's product.
        m_correct = pred.eq(target).masked_fill_(
            mask.type(torch.bool), 1).prod(0, keepdim=False)
        #print('m_correct_row', m_correct.type())
        num_all = m_correct.numel()
    else:
        non_mask = mask.ne(1).type(torch.bool)
        m_correct = pred.eq(target).masked_select(non_mask)
        num_all = non_mask.sum().item()
    # NOTE(review): placement reconstructed — this conversion may originally
    # have lived inside the final else-branch only; verify against history.
    # Convert to Long so downstream sums don't overflow/boolean-trap.
    m_correct = m_correct.type(torch.LongTensor)
    if torch.cuda.is_available():
        m_correct = m_correct.cuda()
    return (m_correct, num_all)
def recover_lay(self, l):
    """Convert decoded layout indices into space-joined layout strings.

    Accepts either a score tensor (rank > 2, argmax-decoded here) or an
    already-decoded (time, batch) index tensor. Decoding stops at the first
    EOS token (excluded). Returns one string per batch entry.
    """
    def _to_layout_string(idx_seq, vocab, max_len):
        tokens = []
        for step in range(max_len):
            token = vocab.itos[idx_seq[step]]
            if token == table.IO.EOS_WORD:
                break  # drop EOS itself and stop
            tokens.append(token)
        return " ".join(tokens)

    lay_dec = argmax(l).cpu() if len(l.size()) > 2 else l
    seq_len = lay_dec.size(0)
    lay_vocab = self.fields['lay'].vocab
    recovered = []
    for b in range(lay_dec.size(1)):
        column = [lay_dec[i, b] for i in range(seq_len)]
        recovered.append(_to_layout_string(column, lay_vocab, seq_len))
    return recovered
def count_condition_value_F1(scores1, golden_scores1, target_ls1, target_rs1):
    """Score BIO span predictions against gold condition-value spans.

    ``scores1`` are per-token BIO scores (presumably 0 = B, 1 = I, other = O
    -- TODO confirm label encoding). Gold label -1 marks padding. Spans are
    extracted from the predicted BIO tags and matched against the gold
    (target_l, target_r) span boundaries.

    Returns four (numerator, denominator)-style pairs:
    exact-match over samples, precision, recall, and unnormalized F1 parts.
    """
    preds = argmax(scores1)
    # Batch-first for per-sample iteration.
    preds = preds.transpose(0, 1)
    target_ls = target_ls1.transpose(0, 1)
    target_rs = target_rs1.transpose(0, 1)
    golden_scores = golden_scores1.transpose(0, 1)
    # print(type(preds),preds.size(),target_ls.size())
    total_p = 0       # total predicted spans (precision denominator)
    total_r = 0       # total gold spans (recall denominator)
    matched = 0       # predicted spans that exactly match a gold span
    exact_matched = 0  # samples whose full BIO tagging matches gold
    for sample_id in range(preds.size(0)):
        pred = preds[sample_id]
        #print(pred)
        golden_score = golden_scores[sample_id]
        #print(g)
        target_l = target_ls[sample_id]
        #print(target_l)
        target_r = target_rs[sample_id]
        #print(target_r)
        # Assume the sample matches; retract on first tag mismatch.
        exact_matched += 1
        for i in range(pred.size(0)):
            if pred[i] != golden_score[i] and golden_score[i] != -1:
                exact_matched -= 1
                break
        # Extract (l, r) spans from the predicted BIO tags: 0 starts a new
        # span, 1 extends the open one, anything else closes it.
        cond_span_lr = []
        l = 0
        r = 0
        for i in range(pred.size(0)):
            if pred[i] == 0:
                if l != 0:
                    cond_span_lr.append((l, r))
                l = i
                r = i
            elif pred[i] == 1:
                r = i
            else:
                if l != 0:
                    cond_span_lr.append((l, r))
                l = 0
                r = 0
        if l != 0:
            cond_span_lr.append((l, r))
        # NOTE(review): l == 0 doubles as "no open span", so a span starting
        # at token index 0 is never emitted — confirm index 0 is always a
        # special/pad position upstream.
        for l, r in cond_span_lr:
            for i in range(target_l.size(0)):
                if l == target_l[i] and r == target_r[i]:
                    matched += 1
        for i in range(target_l.size(0)):
            if target_l[i] != -1:
                total_r += 1
        total_p += len(cond_span_lr)
    #print(matched,total_p,total_r)
    #if random.random()<0.01:
    #    print(preds[:10])
    #    print(golden_scores[:10])
    recall = 1.0 * matched / (total_r + 1e-10)
    precision = 1.0 * matched / (total_p + 1e-10)
    return (exact_matched, preds.size(0)), (precision, 1), (recall, 1), \
        (recall * precision * 2, (recall + precision + 1e-10))
def count_condition_value_EM_column_op(scores1, scores_col1, scores_op1,
                                       golden_scores1, golden_scores_col1,
                                       golden_scores_op1, target_ls1,
                                       target_rs1, target_cols1):
    """Exact-match accuracy for BIO tags, condition columns, and operators.

    For each sample: the BIO tag sequence must match gold exactly (gold -1 =
    padding); then, for every predicted span (tag 0 starts, tag 1 extends),
    the majority-vote column prediction and majority-vote operator prediction
    over the span must match the gold labels at the span start.

    Returns ((EM, n), (EM_col, n), (EM_op, n), 0) where n is the batch size.
    """
    preds = argmax(scores1)
    preds_col = argmax(scores_col1)
    preds_op = argmax(scores_op1)
    # Batch-first for per-sample iteration.
    preds = preds.transpose(0, 1)
    preds_col = preds_col.transpose(0, 1)
    preds_op = preds_op.transpose(0, 1)
    target_ls = target_ls1.transpose(0, 1)
    target_rs = target_rs1.transpose(0, 1)
    golden_scores = golden_scores1.transpose(0, 1)
    golden_scores_col = golden_scores_col1.transpose(0, 1)
    golden_scores_op = golden_scores_op1.transpose(0, 1)
    target_cols = target_cols1.transpose(0, 1)
    total_p = 0
    total_r = 0
    exact_matched = 0      # BIO + column + operator all match
    exact_matched_op = 0   # BIO + operator match
    exact_matched_col = 0  # BIO + column match
    for sample_id in range(preds.size(0)):
        pred = preds[sample_id]
        pred_col = preds_col[sample_id]
        pred_op = preds_op[sample_id]
        golden_score = golden_scores[sample_id]
        golden_score_col = golden_scores_col[sample_id]
        golden_score_op = golden_scores_op[sample_id]
        # Assume everything matches; retract on the first mismatch found.
        exact_matched += 1
        exact_matched_op += 1
        exact_matched_col += 1
        BIO_not_match = False
        for i in range(pred.size(0)):
            if pred[i] != golden_score[i] and golden_score[i] != -1:
                # A BIO mismatch invalidates all three metrics at once.
                exact_matched -= 1
                exact_matched_op -= 1
                exact_matched_col -= 1
                BIO_not_match = True
                break
        if BIO_not_match == False:
            # Column check: majority vote of per-token column predictions
            # over each predicted span vs the gold column at the span start.
            col_not_match = False
            for i in range(pred.size(0)):
                if pred[i] == 0:
                    column_cnt = []
                    for j in range(torch.max(pred_col) + 2):
                        column_cnt.append(0)
                    column_cnt[pred_col[i]] = 1
                    for j in range(i + 1, pred.size(0)):
                        if pred[j] != 1:
                            break
                        column_cnt[pred_col[j]] += 1
                    max_cnt = 0
                    argmax1 = pred_col[i]
                    for j in range(torch.max(pred_col) + 2):
                        if column_cnt[j] > max_cnt:
                            max_cnt = column_cnt[j]
                            argmax1 = j
                    if argmax1 != golden_score_col[i]:
                        exact_matched_col -= 1
                        col_not_match = True
                        break
            # Operator check: same majority vote over the 3 operator classes.
            op_not_match = False
            for i in range(pred.size(0)):
                if pred[i] == 0:
                    op_cnt = []
                    for j in range(3):
                        op_cnt.append(0)
                    op_cnt[pred_op[i]] = 1
                    for j in range(i + 1, pred.size(0)):
                        if pred[j] != 1:
                            break
                        op_cnt[pred_op[j]] += 1
                    max_cnt = 0
                    argmax1 = pred_op[i]
                    for j in range(3):
                        if op_cnt[j] > max_cnt:
                            max_cnt = op_cnt[j]
                            argmax1 = j
                    if argmax1 != golden_score_op[i]:
                        exact_matched_op -= 1
                        op_not_match = True
                        break
            if op_not_match or col_not_match:
                exact_matched -= 1
    return (exact_matched, preds.size(0)), \
        (exact_matched_col, preds.size(0)), \
        (exact_matched_op, preds.size(0)), 0
def translate(self, batch):
    """Translate a batch into ParseResults (BIO-augmented model variant).

    Pipeline: encode question+table, predict per-token BIO/operator/column
    tags, predict aggregation/selection/layout, then run the condition
    decoder once per layout operator to produce (column, op, span) triples.
    """
    q, q_len = batch.src
    tbl, tbl_len = batch.tbl
    ent, tbl_split, tbl_mask = batch.ent, batch.tbl_split, batch.tbl_mask
    # encoding
    q_enc, q_all, tbl_enc, q_ht, batch_size = self.model.enc(
        q, q_len, ent, batch.type, tbl, tbl_len, tbl_split, tbl_mask)
    #query, query length, table, table length, table split, table mask
    # Per-token BIO-operator tagging: log-probs and probabilities.
    BIO_op_out = self.model.BIO_op_classifier(q_all)
    tsp_q = BIO_op_out.size(0)
    bsz = BIO_op_out.size(1)
    BIO_op_out = BIO_op_out.view(-1, BIO_op_out.size(2))
    BIO_op_out = F.log_softmax(BIO_op_out, dim=-1)
    BIO_op_out_sf = torch.exp(BIO_op_out)
    BIO_op_out = BIO_op_out.view(tsp_q, bsz, -1)
    BIO_op_out_sf = BIO_op_out_sf.view(tsp_q, bsz, -1)
    # if fff == 1:
    #     print(BIO_op_out_sf.transpose(0,1)[0])
    #     print(BIO_op_out.transpose(0, 1)[0])
    # Per-token BIO tagging.
    BIO_out = self.model.BIO_classifier(q_all)
    BIO_out = BIO_out.view(-1, BIO_out.size(2))
    BIO_out = F.log_softmax(BIO_out, dim=-1)
    BIO_out_sf = torch.exp(BIO_out)
    BIO_out = BIO_out.view(tsp_q, bsz, -1)
    BIO_out_sf = BIO_out_sf.view(tsp_q, bsz, -1)
    # if fff == 1:
    #     print(BIO_out_sf.transpose(0,1)[0])
    #     print(BIO_out.transpose(0, 1)[0])
    # Per-token column matching against the table encoding.
    BIO_col_out = self.model.label_col_match(q_all, tbl_enc, tbl_mask)
    # if fff == 1:
    #     print(BIO_col_out.size())
    #     print(BIO_col_out.transpose(0, 1)[0])
    BIO_col_out = BIO_col_out.view(-1, BIO_col_out.size(2))
    BIO_col_out = F.log_softmax(BIO_col_out, dim=-1)
    BIO_col_out_sf = torch.exp(BIO_col_out)
    BIO_col_out = BIO_col_out.view(tsp_q, bsz, -1)
    BIO_col_out_sf = BIO_col_out_sf.view(tsp_q, bsz, -1)
    BIO_pred = argmax(BIO_out_sf.data).transpose(0, 1)
    BIO_col_pred = argmax(BIO_col_out_sf.data).transpose(0, 1)
    # Tokens tagged as outside (2) carry no column: mark with -1.
    for i in range(BIO_pred.size(0)):
        for j in range(BIO_pred.size(1)):
            if BIO_pred[i][j] == 2:
                BIO_col_pred[i][j] = -1
    # (1) decoding
    q_self_encode = self.model.agg_self_attention(q_all, q_len)  #q_ht
    q_self_encode_layout = self.model.lay_self_attention(q_all, q_len)  #q_ht
    agg_pred = cpu_vector(
        argmax(self.model.agg_classifier(q_self_encode).data))
    # NOTE(review): sel_out is unused and sel_match is called again on the
    # next statement with the same arguments — redundant work.
    sel_out = self.model.sel_match(q_self_encode, tbl_enc, tbl_mask,
                                   select=True)  # select column
    sel_pred = cpu_vector(
        argmax(
            self.model.sel_match(q_self_encode, tbl_enc, tbl_mask,
                                 select=True).data))
    lay_pred = argmax(self.model.lay_classifier(q_self_encode_layout).data)
    # get layout op tokens
    op_batch_list = []
    op_idx_batch_list = []
    if self.opt.gold_layout:
        # Use gold condition operators instead of the predicted layout.
        lay_pred = batch.lay.data
        cond_op, cond_op_len = batch.cond_op
        cond_op_len_list = cond_op_len.view(-1).tolist()
        for i, len_it in enumerate(cond_op_len_list):
            if len_it == 0:
                op_idx_batch_list.append([])
                op_batch_list.append([])
            else:
                idx_list = cond_op.data[0:len_it,
                                        i].contiguous().view(-1).tolist()
                op_idx_batch_list.append([
                    int(self.fields['cond_op'].vocab.itos[it])
                    for it in idx_list
                ])
                op_batch_list.append(idx_list)
    else:
        # Expand each predicted layout label into its operator token list.
        lay_batch_list = lay_pred.view(-1).tolist()
        for lay_it in lay_batch_list:
            tk_list = self.fields['lay'].vocab.itos[lay_it].split(' ')
            if (len(tk_list) == 0) or (tk_list[0] == ''):
                op_idx_batch_list.append([])
                op_batch_list.append([])
            else:
                op_idx_batch_list.append(
                    [int(op_str) for op_str in tk_list])
                op_batch_list.append([
                    self.fields['cond_op'].vocab.stoi[op_str]
                    for op_str in tk_list
                ])
    # -> (num_cond, batch)
    cond_op = v_eval(
        add_pad(
            op_batch_list,
            self.fields['cond_op'].vocab.stoi[table.IO.PAD_WORD]).t())
    cond_op_len = torch.LongTensor([len(it) for it in op_batch_list])
    # emb_op -> (num_cond, batch, emb_size)
    if self.model.opt.layout_encode == 'rnn':
        emb_op = table.Models.encode_unsorted_batch(
            self.model.lay_encoder, cond_op, cond_op_len.clamp(min=1))
    else:
        emb_op = self.model.cond_embedding(cond_op)
    # (2) decoding
    self.model.cond_decoder.attn.applyMaskBySeqBatch(q)
    cond_state = self.model.cond_decoder.init_decoder_state(q_all, q_enc)
    cond_col_list, cond_span_l_list, cond_span_r_list = [], [], []
    # One decoder round per condition: operator -> column -> span (l, r).
    for emb_op_t in emb_op:
        emb_op_t = emb_op_t.unsqueeze(0)
        cond_context, cond_state, _ = self.model.cond_decoder(
            emb_op_t, q_all, cond_state)
        #print(cond_context.size())
        #cond_context = self.model.decode_softattention(cond_context, q_all, q_len)
        #print(cond_context.size())
        # cond col -> (1, batch)
        cond_col = argmax(
            self.model.cond_col_match(cond_context, tbl_enc, tbl_mask).data)
        cond_col_list.append(cpu_vector(cond_col))
        # emb_col
        batch_index = torch.LongTensor(
            range(batch_size)).unsqueeze_(0).cuda().expand(
                cond_col.size(0), cond_col.size(1))
        emb_col = tbl_enc[cond_col, batch_index, :]
        cond_context, cond_state, _ = self.model.cond_decoder(
            emb_col, q_all, cond_state)
        # cond span
        q_mask = v_eval(
            q.data.eq(self.model.pad_word_index).transpose(0, 1))
        cond_span_l = argmax(
            self.model.cond_span_l_match(cond_context, q_all, q_mask).data)
        cond_span_l_list.append(cpu_vector(cond_span_l))
        # emb_span_l: (1, batch, hidden_size)
        emb_span_l = q_all[cond_span_l, batch_index, :]
        # Right boundary is conditioned on the chosen left boundary.
        cond_span_r = argmax(
            self.model.cond_span_r_match(cond_context, q_all, q_mask,
                                         emb_span_l=emb_span_l).data)
        cond_span_r_list.append(cpu_vector(cond_span_r))
        # emb_span_r: (1, batch, hidden_size)
        emb_span_r = q_all[cond_span_r, batch_index, :]
        emb_span = self.model.span_merge(
            torch.cat([emb_span_l, emb_span_r], 2))
        # mask = torch.zeros([cond_col.size(0), q_all.size(0), q_all.size(1)])  # (num_cond,tsp,bsz)
        # for j in range(q_all.size(1)):
        #     for i in range(cond_col.size(0)):
        #         for k in range(cond_span_l[i][j], cond_span_r[i][j] + 1):
        #             mask[i][k][j] = 1
        # mask = mask.unsqueeze_(3)
        # .expand(cond_col.size(0),q_all.size(0),q_all.size(1),q_all.size(2))
        # emb_span = Variable(mask.cuda()) * torch.unsqueeze(q_all, 0)
        # .expand_as(mask) #(num_cond,tsp,bsz,hidden)
        # emb_span = torch.mean(emb_span, dim=1)  # (num_cond,bsz,hidden) mean pooling
        cond_context, cond_state, _ = self.model.cond_decoder(
            emb_span, q_all, cond_state)
    # (3) recover output
    indices = cpu_vector(batch.indices.data)
    r_list = []
    for b in range(batch_size):
        idx = indices[b]
        agg = agg_pred[b]
        sel = sel_pred[b]
        BIO = BIO_pred[b]
        BIO_col = BIO_col_pred[b]
        cond = []
        for i in range(len(op_batch_list[b])):
            col = cond_col_list[i][b]
            op = op_idx_batch_list[b][i]
            span_l = cond_span_l_list[i][b]
            span_r = cond_span_r_list[i][b]
            cond.append((col, op, (span_l, span_r)))
        r_list.append(ParseResult(idx, agg, sel, cond, BIO, BIO_col))
    return r_list
def translate(self, batch):
    """Translate a batch into ParseResults (plain model variant, no BIO).

    Pipeline: encode question+table, predict aggregation/selection/layout
    from the question summary ``q_ht``, then run the condition decoder once
    per layout operator to produce (column, op, span) triples.
    """
    q, q_len = batch.src
    tbl, tbl_len = batch.tbl
    ent, tbl_split, tbl_mask = batch.ent, batch.tbl_split, batch.tbl_mask
    # encoding
    q_enc, q_all, tbl_enc, q_ht, batch_size = self.model.enc(
        q, q_len, ent, tbl, tbl_len, tbl_split, tbl_mask)
    # (1) decoding
    agg_pred = cpu_vector(argmax(self.model.agg_classifier(q_ht).data))
    sel_pred = cpu_vector(
        argmax(self.model.sel_match(q_ht, tbl_enc, tbl_mask).data))
    lay_pred = argmax(self.model.lay_classifier(q_ht).data)
    # get layout op tokens
    op_batch_list = []
    op_idx_batch_list = []
    if self.opt.gold_layout:
        # Use gold condition operators instead of the predicted layout.
        lay_pred = batch.lay.data
        cond_op, cond_op_len = batch.cond_op
        cond_op_len_list = cond_op_len.view(-1).tolist()
        for i, len_it in enumerate(cond_op_len_list):
            if len_it == 0:
                op_idx_batch_list.append([])
                op_batch_list.append([])
            else:
                idx_list = cond_op.data[0:len_it,
                                        i].contiguous().view(-1).tolist()
                op_idx_batch_list.append([
                    int(self.fields['cond_op'].vocab.itos[it])
                    for it in idx_list
                ])
                op_batch_list.append(idx_list)
    else:
        # Expand each predicted layout label into its operator token list.
        lay_batch_list = lay_pred.view(-1).tolist()
        for lay_it in lay_batch_list:
            tk_list = self.fields['lay'].vocab.itos[lay_it].split(' ')
            if (len(tk_list) == 0) or (tk_list[0] == ''):
                op_idx_batch_list.append([])
                op_batch_list.append([])
            else:
                op_idx_batch_list.append(
                    [int(op_str) for op_str in tk_list])
                op_batch_list.append([
                    self.fields['cond_op'].vocab.stoi[op_str]
                    for op_str in tk_list
                ])
    # -> (num_cond, batch)
    cond_op = v_eval(
        add_pad(
            op_batch_list,
            self.fields['cond_op'].vocab.stoi[table.IO.PAD_WORD]).t())
    cond_op_len = torch.LongTensor([len(it) for it in op_batch_list])
    # emb_op -> (num_cond, batch, emb_size)
    if self.model.opt.layout_encode == 'rnn':
        emb_op = table.Models.encode_unsorted_batch(
            self.model.lay_encoder, cond_op, cond_op_len.clamp(min=1))
    else:
        emb_op = self.model.cond_embedding(cond_op)
    # (2) decoding
    self.model.cond_decoder.attn.applyMaskBySeqBatch(q)
    cond_state = self.model.cond_decoder.init_decoder_state(q_all, q_enc)
    cond_col_list, cond_span_l_list, cond_span_r_list = [], [], []
    # One decoder round per condition: operator -> column -> span (l, r).
    for emb_op_t in emb_op:
        emb_op_t = emb_op_t.unsqueeze(0)
        cond_context, cond_state, _ = self.model.cond_decoder(
            emb_op_t, q_all, cond_state)
        # cond col -> (1, batch)
        cond_col = argmax(
            self.model.cond_col_match(cond_context, tbl_enc, tbl_mask).data)
        cond_col_list.append(cpu_vector(cond_col))
        # emb_col
        batch_index = torch.LongTensor(
            range(batch_size)).unsqueeze_(0).cuda().expand(
                cond_col.size(0), cond_col.size(1))
        emb_col = tbl_enc[cond_col, batch_index, :]
        cond_context, cond_state, _ = self.model.cond_decoder(
            emb_col, q_all, cond_state)
        # cond span
        q_mask = v_eval(
            q.data.eq(self.model.pad_word_index).transpose(0, 1))
        cond_span_l = argmax(
            self.model.cond_span_l_match(cond_context, q_all, q_mask).data)
        cond_span_l_list.append(cpu_vector(cond_span_l))
        # emb_span_l: (1, batch, hidden_size)
        emb_span_l = q_all[cond_span_l, batch_index, :]
        # Right boundary is conditioned on the chosen left boundary.
        cond_span_r = argmax(
            self.model.cond_span_r_match(cond_context, q_all, q_mask,
                                         emb_span_l).data)
        cond_span_r_list.append(cpu_vector(cond_span_r))
        # emb_span_r: (1, batch, hidden_size)
        emb_span_r = q_all[cond_span_r, batch_index, :]
        emb_span = self.model.span_merge(
            torch.cat([emb_span_l, emb_span_r], 2))
        cond_context, cond_state, _ = self.model.cond_decoder(
            emb_span, q_all, cond_state)
    # (3) recover output
    indices = cpu_vector(batch.indices.data)
    r_list = []
    for b in range(batch_size):
        idx = indices[b]
        agg = agg_pred[b]
        sel = sel_pred[b]
        cond = []
        for i in range(len(op_batch_list[b])):
            col = cond_col_list[i][b]
            op = op_idx_batch_list[b][i]
            span_l = cond_span_l_list[i][b]
            span_r = cond_span_r_list[i][b]
            cond.append((col, op, (span_l, span_r)))
        r_list.append(ParseResult(idx, agg, sel, cond))
    return r_list
def translate(self, batch, js_list=[], sql_list=[]):
    # NOTE(review): mutable default arguments ([]) — never mutated here, but
    # should become None-sentinels; also required when js_list/sql_list are
    # indexed below, so the defaults only work for empty batches.
    """Translate a batch into ParseResults, with execution-guided column
    backtracking.

    Like the plain variant, but when ``opt.beam_search`` is set each
    predicted condition is executed against the database (via ``DBEngine``);
    if execution raises, the next-best column candidates are tried in score
    order until one executes or the beam is exhausted.
    """
    q, q_len = batch.src
    tbl, tbl_len = batch.tbl
    ent, tbl_split, tbl_mask = batch.ent, batch.tbl_split, batch.tbl_mask
    # encoding
    q_enc, q_all, tbl_enc, q_ht, batch_size = self.model.enc(
        q, q_len, ent, tbl, tbl_len, tbl_split, tbl_mask)
    # (1) decoding
    agg_pred = cpu_vector(argmax(self.model.agg_classifier(q_ht).data))
    sel_pred = cpu_vector(
        argmax(self.model.sel_match(q_ht, tbl_enc, tbl_mask).data))
    lay_pred = argmax(self.model.lay_classifier(q_ht).data)
    # Execution engine used to validate candidate conditions.
    engine = DBEngine(self.opt.db_file)
    indices = cpu_vector(batch.indices.data)
    # get layout op tokens
    op_batch_list = []
    op_idx_batch_list = []
    if self.opt.gold_layout:
        # Use gold condition operators instead of the predicted layout.
        lay_pred = batch.lay.data
        cond_op, cond_op_len = batch.cond_op
        cond_op_len_list = cond_op_len.view(-1).tolist()
        for i, len_it in enumerate(cond_op_len_list):
            if len_it == 0:
                op_idx_batch_list.append([])
                op_batch_list.append([])
            else:
                idx_list = cond_op.data[0:len_it,
                                        i].contiguous().view(-1).tolist()
                op_idx_batch_list.append([
                    int(self.fields['cond_op'].vocab.itos[it])
                    for it in idx_list
                ])
                op_batch_list.append(idx_list)
    else:
        # Expand each predicted layout label into its operator token list.
        lay_batch_list = lay_pred.view(-1).tolist()
        for lay_it in lay_batch_list:
            tk_list = self.fields['lay'].vocab.itos[lay_it].split(' ')
            if (len(tk_list) == 0) or (tk_list[0] == ''):
                op_idx_batch_list.append([])
                op_batch_list.append([])
            else:
                op_idx_batch_list.append(
                    [int(op_str) for op_str in tk_list])
                op_batch_list.append([
                    self.fields['cond_op'].vocab.stoi[op_str]
                    for op_str in tk_list
                ])
    # -> (num_cond, batch)
    cond_op = v_eval(
        add_pad(
            op_batch_list,
            self.fields['cond_op'].vocab.stoi[table.IO.PAD_WORD]).t())
    cond_op_len = torch.LongTensor([len(it) for it in op_batch_list])
    # emb_op -> (num_cond, batch, emb_size)
    if self.model.opt.layout_encode == 'rnn':
        emb_op = table.Models.encode_unsorted_batch(
            self.model.lay_encoder, cond_op, cond_op_len.clamp(min=1))
    else:
        emb_op = self.model.cond_embedding(cond_op)
    # (2) decoding
    self.model.cond_decoder.attn.applyMaskBySeqBatch(q)
    cond_state = self.model.cond_decoder.init_decoder_state(q_all, q_enc)
    cond_col_list, cond_span_l_list, cond_span_r_list = [], [], []
    t = 0  # 1-based index of the condition currently being decoded
    need_back_track = [False] * batch_size
    for emb_op_t in emb_op:
        t += 1
        emb_op_t = emb_op_t.unsqueeze(0)
        cond_context, cond_state, _ = self.model.cond_decoder(
            emb_op_t, q_all, cond_state)
        # cond col -> (1, batch)
        cond_col_all = self.model.cond_col_match(cond_context, tbl_enc,
                                                 tbl_mask).data
        cond_col = argmax(cond_col_all)
        # add to this after beam search: cond_col_list.append(cpu_vector(cond_col))
        # emb_col
        batch_index = torch.LongTensor(
            range(batch_size)).unsqueeze_(0).cuda().expand(
                cond_col.size(0), cond_col.size(1))
        emb_col = tbl_enc[cond_col, batch_index, :]
        cond_context, cond_state, _ = self.model.cond_decoder(
            emb_col, q_all, cond_state)
        # cond span
        q_mask = v_eval(
            q.data.eq(self.model.pad_word_index).transpose(0, 1))
        cond_span_l_batch_all = self.model.cond_span_l_match(
            cond_context, q_all, q_mask).data
        cond_span_l_batch = argmax(cond_span_l_batch_all)
        # add to this after beam search: cond_span_l_list.append(cpu_vector(cond_span_l))
        # emb_span_l: (1, batch, hidden_size)
        emb_span_l = q_all[cond_span_l_batch, batch_index, :]
        cond_span_r_batch = argmax(
            self.model.cond_span_r_match(cond_context, q_all, q_mask,
                                         emb_span_l).data)
        # add to this after beam search: cond_span_r_list.append(cpu_vector(cond_span_r))
        if self.opt.beam_search:
            # for now just go through the next col in cond
            k = min(self.opt.beam_size, cond_col_all.size()[2])
            top_col_idx = cond_col_all.topk(k)[1]
            for b in range(batch_size):
                # Skip entries with fewer conditions than t, or that already
                # exhausted their candidates on an earlier condition.
                if t > len(op_idx_batch_list[b]) or need_back_track[b]:
                    continue
                idx = indices[b]
                agg = agg_pred[b]
                sel = sel_pred[b]
                cond = []
                for i in range(t):
                    op = op_idx_batch_list[b][i]
                    if i < t - 1:
                        # Earlier conditions are already finalized.
                        col = cond_col_list[i][b]
                        span_l = cond_span_l_list[i][b]
                        span_r = cond_span_r_list[i][b]
                    else:
                        # Current (t-th) condition: take fresh predictions.
                        col = cond_col[0, b]
                        span_l = cond_span_l_batch[0, b]
                        span_r = cond_span_r_batch[0, b]
                    cond.append((col, op, (span_l, span_r)))
                pred = ParseResult(idx, agg, sel, cond)
                pred.eval(js_list[idx], sql_list[idx], engine)
                n_test = 0
                # While execution fails, retry with the next-best column.
                while pred.exception_raised and n_test < top_col_idx.size(
                )[2] - 1:
                    n_test += 1
                    if n_test > self.opt.beam_size:
                        need_back_track[b] = True
                        break
                    cond_col[0, b] = top_col_idx[0, b, n_test]
                    emb_col = tbl_enc[cond_col, batch_index, :]
                    # NOTE(review): cond_state is advanced inside this
                    # retry loop, affecting later batch entries/conditions
                    # — confirm this is intended.
                    cond_context, cond_state, _ = self.model.cond_decoder(
                        emb_col, q_all, cond_state)
                    # cond span
                    q_mask = v_eval(
                        q.data.eq(self.model.pad_word_index).transpose(
                            0, 1))
                    cond_span_l_batch_all = self.model.cond_span_l_match(
                        cond_context, q_all, q_mask).data
                    cond_span_l_batch = argmax(cond_span_l_batch_all)
                    # emb_span_l: (1, batch, hidden_size)
                    emb_span_l = q_all[cond_span_l_batch, batch_index, :]
                    cond_span_r_batch = argmax(
                        self.model.cond_span_r_match(
                            cond_context, q_all, q_mask, emb_span_l).data)
                    # run the new query over database
                    col = cond_col[0, b]
                    span_l = cond_span_l_batch[0, b]
                    span_r = cond_span_r_batch[0, b]
                    cond.pop()
                    cond.append((col, op, (span_l, span_r)))
                    pred = ParseResult(idx, agg, sel, cond)
                    pred.eval(js_list[idx], sql_list[idx], engine)
        # Finalize this condition's (possibly corrected) predictions.
        cond_col_list.append(cpu_vector(cond_col))
        cond_span_l_list.append(cpu_vector(cond_span_l_batch))
        cond_span_r_list.append(cpu_vector(cond_span_r_batch))
        # emb_span_r: (1, batch, hidden_size)
        emb_span_r = q_all[cond_span_r_batch, batch_index, :]
        emb_span = self.model.span_merge(
            torch.cat([emb_span_l, emb_span_r], 2))
        cond_context, cond_state, _ = self.model.cond_decoder(
            emb_span, q_all, cond_state)
    # (3) recover output
    r_list = []
    for b in range(batch_size):
        idx = indices[b]
        agg = agg_pred[b]
        sel = sel_pred[b]
        cond = []
        for i in range(len(op_batch_list[b])):
            col = cond_col_list[i][b]
            op = op_idx_batch_list[b][i]
            span_l = cond_span_l_list[i][b]
            span_r = cond_span_r_list[i][b]
            cond.append((col, op, (span_l, span_r)))
        r_list.append(ParseResult(idx, agg, sel, cond))
    return r_list