def forward(self, q_emb_var, q_len, hs_emb_var, hs_len):
    max_q_len = max(q_len)
    max_hs_len = max(hs_len)
    B = len(q_len)

    q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
    hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)

    # Uniform attention over question tokens, with padded positions masked out
    att_np_q = np.ones((B, max_q_len))
    att_val_q = torch.from_numpy(att_np_q).float()
    att_val_q = Variable(att_val_q.cuda())
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_q[idx, num:] = -100
    att_prob_q = self.softmax(att_val_q)
    q_weighted = (q_enc * att_prob_q.unsqueeze(2)).sum(1)

    # Same as the above, compute the SQL history embedding weighted by attention
    att_np_h = np.ones((B, max_hs_len))
    att_val_h = torch.from_numpy(att_np_h).float()
    att_val_h = Variable(att_val_h.cuda())
    for idx, num in enumerate(hs_len):
        if num < max_hs_len:
            att_val_h[idx, num:] = -100
    att_prob_h = self.softmax(att_val_h)
    hs_weighted = (hs_enc * att_prob_h.unsqueeze(2)).sum(1)

    # ao_score: (B, 2)
    ao_score = self.ao_out(self.ao_out_q(q_weighted) +
                           int(self.use_hs) * self.ao_out_hs(hs_weighted))
    # 06/14/2019: add softmax layer
    ao_score = F.softmax(ao_score, dim=1)
    return ao_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
    B = len(x_emb_var)
    max_x_len = max(x_len)

    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.sel_col_name_enc)

    if self.use_ca:
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
        att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
        K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
    else:
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
        att_val = self.sel_att(h_enc).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val)
        K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        K_sel_expand = K_sel.unsqueeze(1)

    sel_score = self.sel_out(self.sel_out_K(K_sel_expand) +
                             self.sel_out_col(e_col)).squeeze()
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            sel_score[idx, num:] = -100
    return sel_score
def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len,
            col_name_len, gt_col):
    max_q_len = max(q_len)
    max_hs_len = max(hs_len)
    max_col_len = max(col_len)
    B = len(q_len)

    q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
    hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
    col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)

    # get target/predicted column's embedding
    # col_emb: (B, hid_dim)
    col_emb = []
    for b in range(B):
        col_emb.append(col_enc[b, gt_col[b]])
    col_emb = torch.stack(col_emb)

    att_val_qc = torch.bmm(col_emb.unsqueeze(1),
                           self.q_att(q_enc).transpose(1, 2)).view(B, -1)
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc[idx, num:] = -100
    att_prob_qc = self.softmax(att_val_qc)
    q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1)

    # Same as the above, compute SQL history embedding weighted by column attentions
    att_val_hc = torch.bmm(col_emb.unsqueeze(1),
                           self.hs_att(hs_enc).transpose(1, 2)).view(B, -1)
    for idx, num in enumerate(hs_len):
        if num < max_hs_len:
            att_val_hc[idx, num:] = -100
    att_prob_hc = self.softmax(att_val_hc)
    hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1)

    # rt_score: (B, 2)
    rt_score = self.rt_out(self.rt_out_q(q_weighted) +
                           int(self.use_hs) * self.rt_out_hs(hs_weighted) +
                           self.rt_out_c(col_emb))
    return rt_score
def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len,
            col_name_len, gt_col):
    max_q_len = max(q_len)
    max_hs_len = max(hs_len)
    max_col_len = max(col_len)
    B = len(q_len)

    q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
    hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
    col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)

    col_emb = []
    for b in range(B):
        col_emb.append(col_enc[b, gt_col[b]])
    col_emb = torch.stack(col_emb)

    # Predict agg number
    att_val_qc_num = torch.bmm(col_emb.unsqueeze(1),
                               self.q_num_att(q_enc).transpose(1, 2)).view(B, -1)
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc_num[idx, num:] = -100
    att_prob_qc_num = self.softmax(att_val_qc_num)
    q_weighted_num = (q_enc * att_prob_qc_num.unsqueeze(2)).sum(1)

    # Same as the above, compute SQL history embedding weighted by column attentions
    att_val_hc_num = torch.bmm(col_emb.unsqueeze(1),
                               self.hs_num_att(hs_enc).transpose(1, 2)).view(B, -1)
    for idx, num in enumerate(hs_len):
        if num < max_hs_len:
            att_val_hc_num[idx, num:] = -100
    att_prob_hc_num = self.softmax(att_val_hc_num)
    hs_weighted_num = (hs_enc * att_prob_hc_num.unsqueeze(2)).sum(1)

    # agg_num_score: (B, 4)
    agg_num_score = self.agg_num_out(self.agg_num_out_q(q_weighted_num) +
                                     int(self.use_hs) * self.agg_num_out_hs(hs_weighted_num) +
                                     self.agg_num_out_c(col_emb)) / self.T1

    # Predict aggregators
    att_val_qc = torch.bmm(col_emb.unsqueeze(1),
                           self.q_att(q_enc).transpose(1, 2)).view(B, -1)
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc[idx, num:] = -100
    att_prob_qc = self.softmax(att_val_qc)
    q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1)

    # Same as the above, compute SQL history embedding weighted by column attentions
    att_val_hc = torch.bmm(col_emb.unsqueeze(1),
                           self.hs_att(hs_enc).transpose(1, 2)).view(B, -1)
    for idx, num in enumerate(hs_len):
        if num < max_hs_len:
            att_val_hc[idx, num:] = -100
    att_prob_hc = self.softmax(att_val_hc)
    hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1)

    # agg_score: (B, 5)
    agg_score = self.agg_out(self.agg_out_q(q_weighted) +
                             int(self.use_hs) * self.agg_out_hs(hs_weighted) +
                             self.agg_out_c(col_emb)) / self.T2

    score = (agg_num_score, agg_score)
    return score
def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var, mkw_len):
    # print("q_emb_shape:{} hs_emb_shape:{}".format(q_emb_var.size(), hs_emb_var.size()))
    max_q_len = max(q_len)
    max_hs_len = max(hs_len)
    B = len(q_len)

    # q_enc: (B, max_q_len, hid_dim)
    # hs_enc: (B, max_hs_len, hid_dim)
    # mkw_enc: (B, 4, hid_dim)
    q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
    hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
    mkw_enc, _ = run_lstm(self.mkw_lstm, mkw_emb_var, mkw_len)

    # Compute attention values between multi-SQL keywords and question tokens.
    # self.q_att(q_enc).transpose(1, 2): (B, hid_dim, max_q_len)
    # att_val_qmkw: (B, 4, max_q_len)
    # print("mkw_enc {} q_enc {}".format(mkw_enc.size(), self.q_att(q_enc).transpose(1, 2).size()))
    att_val_qmkw = torch.bmm(mkw_enc, self.q_att(q_enc).transpose(1, 2))
    # assign appended (padded) positions the value -100
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qmkw[idx, :, num:] = -100
    # att_prob_qmkw: (B, 4, max_q_len)
    att_prob_qmkw = self.softmax(att_val_qmkw.view((-1, max_q_len))).view(B, -1, max_q_len)
    # q_enc.unsqueeze(1): (B, 1, max_q_len, hid_dim)
    # att_prob_qmkw.unsqueeze(3): (B, 4, max_q_len, 1)
    # q_weighted: (B, 4, hid_dim)
    q_weighted = (q_enc.unsqueeze(1) * att_prob_qmkw.unsqueeze(3)).sum(2)

    # Same as the above, compute SQL history embedding weighted by keyword attentions
    att_val_hsmkw = torch.bmm(mkw_enc, self.hs_att(hs_enc).transpose(1, 2))
    for idx, num in enumerate(hs_len):
        if num < max_hs_len:
            att_val_hsmkw[idx, :, num:] = -100
    att_prob_hsmkw = self.softmax(att_val_hsmkw.view((-1, max_hs_len))).view(B, -1, max_hs_len)
    hs_weighted = (hs_enc.unsqueeze(1) * att_prob_hsmkw.unsqueeze(3)).sum(2)

    # Compute prediction scores
    # multi_score: (B, 4, 1) -> (B, 4)
    multi_score = self.multi_out(self.multi_out_q(q_weighted) +
                                 int(self.use_hs) * self.multi_out_hs(hs_weighted) +
                                 self.multi_out_c(mkw_enc)).view(B, -1)
    # 06/14/2019: add softmax layer
    multi_score = F.softmax(multi_score, dim=1)
    return multi_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
    B = len(x_emb_var)
    max_x_len = max(x_len)

    # compute E_col
    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.sel_col_name_enc)

    if self.use_ca:
        # compute the hidden states of the LSTM over the question
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
        # multiplicative attention (overwritten by the additive attention below)
        att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
        h_ext = h_enc.unsqueeze(1).unsqueeze(1)
        col_ext = e_col.unsqueeze(2).unsqueeze(2)
        # additive attention
        att_val = self.sel_att_out(self.sel_W1(h_ext) + self.sel_W2(col_ext)).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100
        # L-dimensional attention weights
        att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
        # compute E_Q|col
        K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
    else:
        # plain attention without column conditioning
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
        att_val = self.sel_att(h_enc).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val)
        K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        K_sel_expand = K_sel.unsqueeze(1)

    # compute P_selcol(i|Q)
    sel_score = self.sel_out(self.sel_out_K(K_sel_expand) +
                             self.sel_out_col(e_col)).squeeze()
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            sel_score[idx, num:] = -100
    return sel_score
def forward(self, x_emb_var, x_len, agg_emb_var, col_inp_var=None, col_len=None):
    B = len(x_emb_var)
    max_x_len = max(x_len)

    h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
    agg_enc = self.agg_out_agg(agg_emb_var)
    # agg_enc: (B, 6, hid_dim)
    # self.sel_att(h_enc) -> (B, max_x_len, hid_dim); .transpose(1, 2) -> (B, hid_dim, max_x_len)
    # att_val_agg: (B, 6, max_x_len)
    att_val_agg = torch.bmm(agg_enc, self.sel_att(h_enc).transpose(1, 2))
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val_agg[idx, :, num:] = -100

    # att_agg: (B, 6, max_x_len)
    att_agg = self.softmax(att_val_agg.view((-1, max_x_len))).view(B, -1, max_x_len)
    # h_enc.unsqueeze(1) -> (B, 1, max_x_len, hid_dim)
    # att_agg.unsqueeze(3) -> (B, 6, max_x_len, 1)
    # K_agg_expand -> (B, 6, hid_dim)
    K_agg_expand = (h_enc.unsqueeze(1) * att_agg.unsqueeze(3)).sum(2)
    # agg_score = self.agg_out(K_agg)
    agg_score = self.agg_out_f(self.agg_out_se(agg_emb_var) +
                               self.agg_out_K(K_agg_expand)).squeeze()
    return agg_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
    # Based on number of selections to predict select-column
    B = len(x_emb_var)
    max_x_len = max(x_len)

    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.sel_col_name_enc)  # [bs, col_num, hid]
    h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)  # [bs, seq_len, hid]

    att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))  # [bs, col_num, seq_len]
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, :, num:] = -100
    att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
    K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)

    sel_score = self.sel_out(self.sel_out_K(K_sel_expand) +
                             self.sel_out_col(e_col)).squeeze()
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            sel_score[idx, num:] = -100
    return sel_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
    B = len(x_len)
    max_x_len = max(x_len)

    # Predict the condition relationship part
    # First use column embeddings to calculate the initial hidden unit
    # Then run the LSTM and predict the condition relationship
    e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len,
                                         self.where_rela_lstm)
    col_att_val = self.where_rela_col_att(e_num_col).squeeze()
    for idx, num in enumerate(col_num):
        if num < max(col_num):
            col_att_val[idx, num:] = -1000000
    num_col_att = self.softmax(col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
    h1 = self.col2hid1(K_num_col).view(B, 4, self.N_h // 2).transpose(0, 1).contiguous()
    h2 = self.col2hid2(K_num_col).view(B, 4, self.N_h // 2).transpose(0, 1).contiguous()

    h_enc, _ = run_lstm(self.where_rela_lstm, x_emb_var, x_len, hidden=(h1, h2))
    att_val = self.where_rela_att(h_enc).squeeze()
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, num:] = -1000000
    att_val = self.softmax(att_val)
    where_rela_num = (h_enc * att_val.unsqueeze(2).expand_as(h_enc)).sum(1)
    where_rela_score = self.where_rela_out(where_rela_num)
    return where_rela_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
    B = len(x_len)
    max_x_len = max(x_len)

    # Predict the number of select parts
    # First use column embeddings to calculate the initial hidden unit
    # Then run the LSTM and predict select number
    e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len,
                                         self.sel_num_lstm)
    num_col_att_val = self.sel_num_col_att(e_num_col).squeeze()
    for idx, num in enumerate(col_num):
        if num < max(col_num):
            num_col_att_val[idx, num:] = -1000000
    num_col_att = self.softmax(num_col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
    sel_num_h1 = self.sel_num_col2hid1(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()
    sel_num_h2 = self.sel_num_col2hid2(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()

    h_num_enc, _ = run_lstm(self.sel_num_lstm, x_emb_var, x_len,
                            hidden=(sel_num_h1, sel_num_h2))
    num_att_val = self.sel_num_att(h_num_enc).squeeze()
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            num_att_val[idx, num:] = -1000000
    num_att = self.softmax(num_att_val)
    K_sel_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
    sel_num_score = self.sel_num_out(K_sel_num)
    return sel_num_score
def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
            col_len=None, col_num=None, gt_sel=None):
    B = len(x_emb_var)
    max_x_len = max(x_len)

    h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
    if self.use_ca:
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)
        chosen_sel_idx = torch.LongTensor(gt_sel)
        aux_range = torch.LongTensor(range(len(gt_sel)))
        if x_emb_var.is_cuda:
            chosen_sel_idx = chosen_sel_idx.cuda()
            aux_range = aux_range.cuda()
        chosen_e_col = e_col[aux_range, chosen_sel_idx]
        att_val = torch.bmm(self.agg_att(h_enc), chosen_e_col.unsqueeze(2)).squeeze()
    else:
        att_val = self.agg_att(h_enc).squeeze()

    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, num:] = -100
    att = self.softmax(att_val)

    K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
    agg_score = self.agg_out(K_agg)
    return agg_score
def forward(self, x_emb_var, x_len, col_inp_var=None, col_len=None, col_name_len=None,
            x_type_emb_var=None, gt_sel=None, sel_cond_score=None):
    B = len(x_emb_var)
    max_x_len = max(x_len)
    x_emb_concat = torch.cat((x_emb_var, x_type_emb_var), 2)

    chosen_sel_col_gt = []
    if gt_sel is None:
        if sel_cond_score is None:
            raise Exception("""In the test mode, sel_num_score and sel_col_score
                should be passed in order to predict aggregation!""")
        sel_num_score, _, sel_score, _ = sel_cond_score
        sel_nums = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
        sel_col_scores = sel_score.data.cpu().numpy()
        chosen_sel_col_gt = [list(np.argsort(-sel_col_scores[b])[:sel_nums[b]])
                             for b in range(len(sel_nums))]
    else:
        chosen_sel_col_gt = [[x for x in one_gt_sel] for one_gt_sel in gt_sel]

    h_enc, _ = run_lstm(self.agg_lstm, x_emb_concat, x_len)
    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.agg_col_name_enc)
    # e_col, _ = run_lstm(self.agg_col_name_enc, col_inp_var, col_len)

    sel_col_emb = []
    for b in range(B):
        cur_sel_col_emb = torch.stack(
            [e_col[b, x] for x in chosen_sel_col_gt[b]] +
            [e_col[b, 0]] * (4 - len(chosen_sel_col_gt[b])))  # Pad the columns to maximum (4)
        sel_col_emb.append(cur_sel_col_emb)
    sel_col_emb = torch.stack(sel_col_emb)

    agg_att_val = torch.matmul(self.agg_att(h_enc).unsqueeze(1),
                               sel_col_emb.unsqueeze(3)).squeeze()
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            agg_att_val[idx, :, num:] = -100
    agg_att = self.softmax(agg_att_val.view(B * 4, -1)).view(B, 4, -1)
    K_agg = (h_enc.unsqueeze(1) * agg_att.unsqueeze(3)).sum(2)
    agg_score = self.agg_out(self.agg_out_K(K_agg) +
                             self.col_out_col(sel_col_emb)).squeeze()
    return agg_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num,
            dropout_rate=0.):
    B = len(x_emb_var)
    max_x_len = max(x_len)

    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.sel_col_name_enc, dropout_rate=dropout_rate)

    if self.use_ca:
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len, dropout_rate=dropout_rate)
        att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
        K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
    else:
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len, dropout_rate=dropout_rate)
        att_val = self.sel_att(h_enc).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val)
        K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        K_sel_expand = K_sel.unsqueeze(1)

    if dropout_rate > 0.:
        # K_sel_expand: [batch_size, col_size, hid_size]
        K_sel_expand_mask = torch.FloatTensor(K_sel_expand.size()[-1]).view(1, 1, -1)\
            .fill_(1. - dropout_rate).bernoulli().div_(1. - dropout_rate)
        if K_sel_expand.is_cuda:
            K_sel_expand_mask = K_sel_expand_mask.cuda()
        K_sel_expand.data = K_sel_expand.data * K_sel_expand_mask

    sel_score = self.sel_out(self.sel_out_K(K_sel_expand) +
                             self.sel_out_col(e_col)).squeeze() / self.T
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            sel_score[idx, num:] = -100
    return sel_score
def forward(self, x_emb, x_len, col_input, col_token_num, col_len, hidden=None):
    batch_size = len(x_emb)
    max_x_len = max(x_len)

    emb_col, _ = column_encode(self.select_colname_enc, col_input, col_token_num, col_len)

    if not hidden:
        h_enc, _ = run_lstm(self.select_lstm, x_emb, x_len)
    else:
        h_enc, _ = run_lstm(self.select_lstm, x_emb, x_len, hidden)

    # compute the attention scores
    attn_value = self.select_att(h_enc).squeeze(2)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            attn_value[idx, num:] = -100
    attention = F.softmax(attn_value, 1)

    K_select = (h_enc * attention.unsqueeze(2).expand_as(h_enc)).sum(1)
    K_select_expand = K_select.unsqueeze(1)
    select_score = self.select_out(self.select_out_K(K_select_expand) +
                                   self.select_out_col(emb_col)).squeeze(2)

    max_col_num = max(col_len)
    for idx, num in enumerate(col_len):
        if num < max_col_num:
            select_score[idx, num:] = -100
    return select_score
def forward(self, q, q_len, hidden):
    max_q_len = max(q_len)  # for padding up to the length of the longest question
    output, hidden = run_lstm(self.rnn, q, q_len, hidden)
    att_val = self.attn(output).squeeze(2)
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            # give attention value -100 to positions that do not belong to the question
            att_val[idx, num:] = -100
    att = F.softmax(att_val, dim=1)
    k_agg = (output * att.unsqueeze(2).expand_as(output)).sum(1)
    agg_score = self.agg_out(k_agg)
    return agg_score
def op_forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
               chosen_col_gt, dropout_rate=0.):
    B = len(x_len)
    max_x_len = max(x_len)

    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_op_name_enc, dropout_rate=dropout_rate)
    col_emb = []
    for b in range(B):
        cur_col_emb = torch.stack([e_cond_col[b, x] for x in chosen_col_gt[b]] +
                                  [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))  # Pad the columns to maximum (4)
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)

    h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len, dropout_rate=dropout_rate)
    if self.use_ca:
        op_att_val = torch.matmul(self.cond_op_att(h_op_enc).unsqueeze(1),
                                  col_emb.unsqueeze(3)).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                op_att_val[idx, :, num:] = -100
        op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
        K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
    else:
        op_att_val = self.cond_op_att(h_op_enc).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                op_att_val[idx, num:] = -100
        op_att = self.softmax(op_att_val)
        K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)

    if dropout_rate > 0.:
        # K_cond_op: [batch_size, op_size, hid_size]
        K_cond_op_mask = torch.FloatTensor(K_cond_op.size()[-1]).view(1, 1, -1)\
            .fill_(1. - dropout_rate).bernoulli().div_(1. - dropout_rate)
        if K_cond_op.is_cuda:
            K_cond_op_mask = K_cond_op_mask.cuda()
        K_cond_op.data = K_cond_op.data * K_cond_op_mask

    cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) +
                                     self.cond_op_out_col(col_emb)).squeeze(-1) / self.T3
    return cond_op_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
    B = len(x_len)
    max_x_len = max(x_len)

    # Predict the number of selected columns:
    # first use column embeddings to calculate the initial hidden unit,
    # then run the LSTM and predict the select number
    e_num_col, col_num = col_name_encode(
        col_inp_var, col_name_len, col_len,
        self.sel_num_lstm)  # e_num_col: encoded columns, (B, max_col_len, hid_dim)
    num_col_att_val = self.sel_num_col_att(e_num_col).squeeze(-1)  # (B, max_col_len): one score per column
    for idx, num in enumerate(col_num):
        if num < max(col_num):
            num_col_att_val[idx, num:] = -1000000  # give padded columns a very small value
    num_col_att = self.softmax(num_col_att_val)  # softmax over columns: (B, max_col_len)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)  # (B, hid_dim): column encodings weighted by their attention scores
    sel_num_h1 = self.sel_num_col2hid1(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()  # (4, B, N_h // 2)
    sel_num_h2 = self.sel_num_col2hid2(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()

    # encode the question
    h_num_enc, _ = run_lstm(self.sel_num_lstm, x_emb_var, x_len,
                            hidden=(sel_num_h1, sel_num_h2))  # (B, max_x_len, hid_dim)
    num_att_val = self.sel_num_att(h_num_enc).squeeze(-1)  # (B, max_x_len)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            num_att_val[idx, num:] = -1000000
    num_att = self.softmax(num_att_val)
    K_sel_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)  # (B, hid_dim)
    sel_num_score = self.sel_num_out(K_sel_num)  # (B, 4)
    return sel_num_score
def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
            col_len=None, col_num=None, gt_sel=None):
    B = len(x_emb_var)
    max_x_len = max(x_len)

    # compute the hidden states of the LSTM over the question
    h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
    if self.use_ca:
        # compute E_col
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)
        chosen_sel_idx = torch.LongTensor(gt_sel)
        aux_range = torch.LongTensor(range(len(gt_sel)))
        if x_emb_var.is_cuda:
            chosen_sel_idx = chosen_sel_idx.cuda()
            aux_range = aux_range.cuda()
        # chosen_e_col is v
        chosen_e_col = e_col[aux_range, chosen_sel_idx]
        att_val = torch.bmm(self.agg_att(h_enc), chosen_e_col.unsqueeze(2)).squeeze()
    else:
        att_val = self.agg_att(h_enc).squeeze()

    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, num:] = -100
    # att is the attention weight
    att = self.softmax(att_val)
    # K_agg is E_Q|col
    K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
    agg_score = self.agg_out(K_agg)
    return agg_score
def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
            col_len=None, col_num=None, gt_sel=None):
    B = len(x_emb_var)
    max_x_len = max(x_len)

    # hidden state for each token in the question
    h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
    if self.use_ca:
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)
        chosen_sel_idx = torch.LongTensor(gt_sel)
        aux_range = torch.LongTensor(range(len(gt_sel)))
        if x_emb_var.is_cuda:
            chosen_sel_idx = chosen_sel_idx.cuda()
            aux_range = aux_range.cuda()
        chosen_e_col = e_col[aux_range, chosen_sel_idx]
        att_val = torch.bmm(self.agg_att(h_enc), chosen_e_col.unsqueeze(2)).squeeze()
    else:
        att_val = self.agg_att(h_enc).squeeze()

    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, num:] = -100  # make sure the padded positions get softmax ~= 0
    att = self.softmax(att_val)
    K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
    agg_item_score = self.agg_out(K_agg)
    return agg_item_score
def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None, col_len=None,
            col_num=None, gt_sel=None, gt_sel_num=None):
    B = len(x_emb_var)
    max_x_len = max(x_len)

    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.agg_col_name_enc)  # (B, max_col_len, hid_dim)
    h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)  # (B, max_x_len, hid_dim)

    col_emb = []
    for b in range(B):
        cur_col_emb = torch.stack([e_col[b, x] for x in gt_sel[b]] +
                                  [e_col[b, 0]] * (4 - len(gt_sel[b])))
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)  # (B, 4, hid_dim)

    att_val = torch.matmul(self.agg_att(h_enc).unsqueeze(1),
                           col_emb.unsqueeze(3)).squeeze()  # (B, 4, max_x_len)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, :, num:] = -100
    att = self.softmax(att_val.view(B * 4, -1)).view(B, 4, -1)  # (B, 4, max_x_len)
    K_agg = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
    # the trailing .squeeze() is omitted so that single-sample batches keep their shape
    agg_score = self.agg_out(self.agg_out_K(K_agg) + self.col_out_col(col_emb))
    return agg_score  # (B, 4, 6)
def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None, col_len=None,
            col_num=None, gt_sel=None, dropout_rate=0.):
    B = len(x_emb_var)
    max_x_len = max(x_len)

    h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len, dropout_rate=dropout_rate)
    if self.use_ca:
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc, dropout_rate=dropout_rate)
        chosen_sel_idx = torch.LongTensor(gt_sel)
        aux_range = torch.LongTensor(range(len(gt_sel)))
        if x_emb_var.is_cuda:
            chosen_sel_idx = chosen_sel_idx.cuda()
            aux_range = aux_range.cuda()
        chosen_e_col = e_col[aux_range, chosen_sel_idx]
        att_val = torch.bmm(self.agg_att(h_enc),
                            chosen_e_col.unsqueeze(2)).squeeze(-1)  # squeeze dim=-1
    else:
        att_val = self.agg_att(h_enc).squeeze()

    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, num:] = -100
    att = self.softmax(att_val)
    K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)

    if dropout_rate > 0.:
        # K_agg: [batch_size, hid_size]
        K_agg_mask = torch.FloatTensor(K_agg.size()[-1]).view(1, -1)\
            .fill_(1. - dropout_rate).bernoulli().div_(1. - dropout_rate)
        if K_agg.is_cuda:
            K_agg_mask = K_agg_mask.cuda()
        K_agg.data = K_agg.data * K_agg_mask

    agg_score = self.agg_out(K_agg) / self.T
    return agg_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
    '''
    Based on the number of selections, predict the select column.
    input:
        x_emb_var: embedding of each question
        col_inp_var: embedding of each header
        col_name_len: length of each header
        col_len: number of headers in each table, array type
        col_num: number of headers in each table, list type
    '''
    B = len(x_emb_var)
    max_x_len = max(x_len)

    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.sel_col_name_enc)  # [bs, col_num, hid]
    # h_enc is the LSTM output over the question, used as the attention target
    h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)  # [bs, seq_len, hid]

    # e_col:   [batch_size(16), max_num_of_col_in_train_tab, hidden_size(100)]
    # h_enc:   [batch_size(16), max_len_of_question, hidden_size(100)]
    # att_val: [bs(16), max_num_of_col_in_train_tab, max_len_of_question]
    att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))  # [bs, col_num, seq_len]
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            # mask the padded question positions so they receive no attention
            att_val[idx, :, num:] = -100
    att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
    K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)

    sel_score = self.sel_out(self.sel_out_K(K_sel_expand) +
                             self.sel_out_col(e_col)).squeeze()
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            sel_score[idx, num:] = -100
    return sel_score
def forward_mult(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
    # Predict the number of selected columns
    # First use column embeddings to calculate the initial hidden unit
    # Then run the LSTM and predict the number.
    B = len(x_len)
    max_x_len = max(x_len)

    e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len,
                                         self.sel_col_num_name_enc)
    num_col_att_val = self.sel_num_col_att(e_num_col).squeeze()
    for idx, num in enumerate(col_num):
        if num < max(col_num):
            num_col_att_val[idx, num:] = -100
    # get a probability distribution over how many columns are likely to be selected
    num_col_att = self.softmax(num_col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)  # attention-weighted column summary
    # the second dimension matches the LSTM's num_layers * num_directions (previously hard-coded to 4)
    sel_num_h1 = self.sel_num_col2hid1(K_num_col).view(
        B, -1, self.N_h // 2).transpose(0, 1).contiguous()
    sel_num_h2 = self.sel_num_col2hid2(K_num_col).view(
        B, -1, self.N_h // 2).transpose(0, 1).contiguous()

    h_num_enc, _ = run_lstm(self.col_num_lstm, x_emb_var, x_len,
                            hidden=(sel_num_h1, sel_num_h2))
    num_att_val = self.sel_num_att(h_num_enc).squeeze()
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            num_att_val[idx, num:] = -100
    num_att = self.softmax(num_att_val)
    K_sel_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
    sel_num_score = self.sel_num_out(K_sel_num)

    # Predict the selected columns
    e_sel_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.sel_col_name_enc)
    h_col_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
    if self.use_ca:
        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)
        sel_att_val = torch.bmm(e_sel_col, self.sel_att(h_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                sel_att_val[idx, :, num:] = -100
        sel_att = self.softmax(sel_att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
        K_sel_col = (h_enc.unsqueeze(1) * sel_att.unsqueeze(3)).sum(2)
    else:
        # col_att_val = self.cond_col_att(h_col_enc).squeeze()
        # for idx, num in enumerate(x_len):
        #     if num < max_x_len:
        #         col_att_val[idx, num:] = -100
        # col_att = self.softmax(col_att_val)
        # K_cond_col = (h_col_enc * col_att_val.unsqueeze(2)).sum(1).unsqueeze(1)
        sel_att_val = self.sel_col_att(h_col_enc).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                sel_att_val[idx, num:] = -100
        sel_att = self.softmax(sel_att_val)
        K_sel_col = (h_col_enc * sel_att.unsqueeze(2)).sum(1).unsqueeze(1)

    sel_col_score = self.sel_col_out(self.sel_out_K(K_sel_col) +
                                     self.sel_out_col(e_sel_col)).squeeze()
    max_col_num = max(col_num)
    for b, num in enumerate(col_num):
        if num < max_col_num:
            sel_col_score[b, num:] = -100

    sel_score = (sel_num_score, sel_col_score)
    return sel_score
def forward(self, x_emb_var, x_len, col_inp_var, col_len, x_type_emb_var, gt_where,
            gt_cond, sel_cond_score=None, x_pos_emb_var=None):
    max_x_len = max(x_len)
    max_col_len = max(col_len)
    B = len(x_len)

    # Predict the operator of conditions
    chosen_col_gt = []
    if gt_cond is None:
        if sel_cond_score is None:
            raise Exception("""In the test mode, cond_num_score and cond_col_score
                should be passed in order to predict condition op and str!""")
        cond_num_score, _, cond_col_score = sel_cond_score
        cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
        col_scores = cond_col_score.data.cpu().numpy()
        chosen_col_gt = [list(np.argsort(-col_scores[b])[:cond_nums[b]])
                         for b in range(len(cond_nums))]
    else:
        chosen_col_gt = [[x[0] for x in one_gt_cond] for one_gt_cond in gt_cond]

    if self.types:  # with type embeddings concatenation
        if self.POS:
            x_emb_concat = torch.cat((x_emb_var, x_type_emb_var, x_pos_emb_var), 2)
        else:
            x_emb_concat = torch.cat((x_emb_var, x_type_emb_var), 2)
    else:
        if self.POS:
            x_emb_concat = torch.cat((x_emb_var, x_pos_emb_var), 2)
        else:
            x_emb_concat = x_emb_var

    h_enc, _ = run_lstm(self.cond_opstr_lstm, x_emb_concat, x_len)
    e_col, _ = run_lstm(self.cond_name_enc, col_inp_var, col_len)

    col_emb = []
    for b in range(B):
        cur_col_emb = torch.stack([e_col[b, x] for x in chosen_col_gt[b]] +
                                  [e_col[b, 0]] * (4 - len(chosen_col_gt[b])))  # Pad the columns to maximum (4)
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)

    op_att_val = torch.matmul(self.cond_op_att(h_enc).unsqueeze(1),
                              col_emb.unsqueeze(3)).squeeze()
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            op_att_val[idx, :, num:] = -100
    op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
    K_cond_op = (h_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
    cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) +
                                     self.cond_op_out_col(col_emb)).squeeze()

    # Predict the string of conditions
    if self.types:
        xt_str_enc = self.cond_str_x_type(x_type_emb_var)
    if self.POS:
        xpos_str_enc = self.cond_str_x_pos(x_pos_emb_var)
    col_emb = []
    for b in range(B):
        cur_col_emb = torch.stack([e_col[b, x] for x in chosen_col_gt[b]] +
                                  [e_col[b, 0]] * (4 - len(chosen_col_gt[b])))
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)

    if gt_where is not None:
        gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where)
        g_str_s_flat, _ = self.cond_str_decoder(gt_tok_seq.view(B * 4, -1, self.max_tok_num))
        g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)

        h_ext = h_enc.unsqueeze(1).unsqueeze(1)
        ## CHANGES ## - ## FOR BERT IMPLEMENTATION ##
        # for BERT, don't use the hidden representation of type embeddings
        g_ext = g_str_s.unsqueeze(3)
        col_ext = col_emb.unsqueeze(2).unsqueeze(2)

        if self.types:  # with type embeddings concatenation
            ht_ext = xt_str_enc.unsqueeze(1).unsqueeze(1)
            if self.POS:
                hpos_ext = xpos_str_enc.unsqueeze(1).unsqueeze(1)
                cond_str_score = self.cond_str_out(
                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext) + self.cond_str_out_ht(ht_ext) +
                    self.cond_str_out_pos(hpos_ext)).squeeze()
            else:
                cond_str_score = self.cond_str_out(
                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext) + self.cond_str_out_ht(ht_ext)).squeeze()
        else:  # without type embeddings concatenation
            if self.POS:
                hpos_ext = xpos_str_enc.unsqueeze(1).unsqueeze(1)
                cond_str_score = self.cond_str_out(
                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext) + self.cond_str_out_pos(hpos_ext)).squeeze()
            else:
                cond_str_score = self.cond_str_out(
                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext)).squeeze()
        for b, num in enumerate(x_len):
            if num < max_x_len:
                cond_str_score[b, :, :, num:] = -100
    else:
        h_ext = h_enc.unsqueeze(1).unsqueeze(1)
        col_ext = col_emb.unsqueeze(2).unsqueeze(2)
        scores = []
        t = 0
        # TODO: maybe we should store BERT's [CLS] and [SEP] tokens somewhere?
        init_inp = np.zeros((B * 4, 1, self.max_tok_num), dtype=np.float32)
        init_inp[:, 0, 0] = 1  # Set the <BEG> token  # TODO: for BERT rather the [CLS] token?
        if self.gpu:
            cur_inp = Variable(torch.from_numpy(init_inp).cuda())
        else:
            cur_inp = Variable(torch.from_numpy(init_inp))
        cur_h = None
        while t < 50:
            if cur_h:
                g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
            else:
                g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
            g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
            g_ext = g_str_s.unsqueeze(3)

            if self.types:  # with type embeddings concatenation
                ht_ext = xt_str_enc.unsqueeze(1).unsqueeze(1)
                if self.POS:
                    hpos_ext = xpos_str_enc.unsqueeze(1).unsqueeze(1)
                    cur_cond_str_score = self.cond_str_out(
                        self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                        self.cond_str_out_col(col_ext) + self.cond_str_out_ht(ht_ext) +
                        self.cond_str_out_pos(hpos_ext)).squeeze()
                else:
                    cur_cond_str_score = self.cond_str_out(
                        self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                        self.cond_str_out_col(col_ext) + self.cond_str_out_ht(ht_ext)).squeeze()
            else:  # without type embeddings concatenation
                if self.POS:
                    hpos_ext = xpos_str_enc.unsqueeze(1).unsqueeze(1)
                    cur_cond_str_score = self.cond_str_out(
                        self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                        self.cond_str_out_col(col_ext) + self.cond_str_out_pos(hpos_ext)).squeeze()
                else:
                    cur_cond_str_score = self.cond_str_out(
                        self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                        self.cond_str_out_col(col_ext)).squeeze()
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cur_cond_str_score[b, :, num:] = -100
            scores.append(cur_cond_str_score)

            _, ans_tok_var = cur_cond_str_score.view(B * 4, max_x_len).max(1)
            ans_tok = ans_tok_var.data.cpu()
            data = torch.zeros(B * 4, self.max_tok_num).scatter_(1, ans_tok.unsqueeze(1), 1)
            if self.gpu:  # To one-hot
                cur_inp = Variable(data.cuda())
            else:
                cur_inp = Variable(data)
            cur_inp = cur_inp.unsqueeze(1)
            t += 1
        cond_str_score = torch.stack(scores, 2)
        for b, num in enumerate(x_len):
            if num < max_x_len:
                cond_str_score[b, :, :, num:] = -100  # [B, IDX, T, TOK_NUM]

    cond_op_str_score = (cond_op_score, cond_str_score)
    return cond_op_str_score
def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len,
            col_name_len, gt_col):
    max_q_len = max(q_len)
    max_hs_len = max(hs_len)
    max_col_len = max(col_len)
    B = len(q_len)

    q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
    hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
    col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)

    # get target/predicted column's embedding
    # col_emb: (B, hid_dim)
    col_emb = []
    for b in range(B):
        col_emb.append(col_enc[b, gt_col[b]])
    col_emb = torch.stack(col_emb)

    # Predict op number
    att_val_qc_num = torch.bmm(col_emb.unsqueeze(1),
                               self.q_num_att(q_enc).transpose(1, 2)).view(B, -1)
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc_num[idx, num:] = -100
    att_prob_qc_num = self.softmax(att_val_qc_num)
    q_weighted_num = (q_enc * att_prob_qc_num.unsqueeze(2)).sum(1)

    # Same as the above, compute SQL history embedding weighted by column attentions
    att_val_hc_num = torch.bmm(col_emb.unsqueeze(1),
                               self.hs_num_att(hs_enc).transpose(1, 2)).view(B, -1)
    for idx, num in enumerate(hs_len):
        if num < max_hs_len:
            att_val_hc_num[idx, num:] = -100
    att_prob_hc_num = self.softmax(att_val_hc_num)
    hs_weighted_num = (hs_enc * att_prob_hc_num.unsqueeze(2)).sum(1)

    # op_num_score: (B, 2)
    op_num_score = self.op_num_out(self.op_num_out_q(q_weighted_num) +
                                   int(self.use_hs) * self.op_num_out_hs(hs_weighted_num) +
                                   self.op_num_out_c(col_emb))

    # Compute attention values between the selected column and question tokens.
    # q_enc.transpose(1, 2): (B, hid_dim, max_q_len)
    # col_emb.unsqueeze(1): (B, 1, hid_dim)
    # att_val_qc: (B, max_q_len)
    # print("col_emb {} q_enc {}".format(col_emb.unsqueeze(1).size(), self.q_att(q_enc).transpose(1, 2).size()))
    att_val_qc = torch.bmm(col_emb.unsqueeze(1),
                           self.q_att(q_enc).transpose(1, 2)).view(B, -1)
    # assign appended positions values -100
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc[idx, num:] = -100
    # att_prob_qc: (B, max_q_len)
    att_prob_qc = self.softmax(att_val_qc)
    # q_enc: (B, max_q_len, hid_dim)
    # att_prob_qc.unsqueeze(2): (B, max_q_len, 1)
    # q_weighted: (B, hid_dim)
    q_weighted = (q_enc * att_prob_qc.unsqueeze(2)).sum(1)

    # Same as the above, compute SQL history embedding weighted by column attentions
    att_val_hc = torch.bmm(col_emb.unsqueeze(1),
                           self.hs_att(hs_enc).transpose(1, 2)).view(B, -1)
    for idx, num in enumerate(hs_len):
        if num < max_hs_len:
            att_val_hc[idx, num:] = -100
    att_prob_hc = self.softmax(att_val_hc)
    hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1)

    # Compute prediction scores
    # op_score: (B, 10)
    op_score = self.op_out(self.op_out_q(q_weighted) +
                           int(self.use_hs) * self.op_out_hs(hs_weighted) +
                           self.op_out_c(col_emb))

    score = (op_num_score, op_score)
    return score
def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len):
    max_q_len = max(q_len)
    max_hs_len = max(hs_len)
    max_col_len = max(col_len)
    B = len(q_len)

    q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
    hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
    col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)

    # Predict column number: 1-3
    # att_val_qc_num: (B, max_col_len, max_q_len)
    att_val_qc_num = torch.bmm(col_enc, self.q_num_att(q_enc).transpose(1, 2))
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            att_val_qc_num[idx, num:, :] = -100
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc_num[idx, :, num:] = -100
    att_prob_qc_num = self.softmax(att_val_qc_num.view((-1, max_q_len))).view(B, -1, max_q_len)
    # q_weighted_num: (B, hid_dim)
    q_weighted_num = (q_enc.unsqueeze(1) * att_prob_qc_num.unsqueeze(3)).sum(2).sum(1)

    # Same as the above, compute SQL history embedding weighted by column attentions
    # att_val_hc_num: (B, max_col_len, max_hs_len)
    att_val_hc_num = torch.bmm(col_enc, self.hs_num_att(hs_enc).transpose(1, 2))
    for idx, num in enumerate(hs_len):
        if num < max_hs_len:
            att_val_hc_num[idx, :, num:] = -100
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            att_val_hc_num[idx, num:, :] = -100
    att_prob_hc_num = self.softmax(att_val_hc_num.view((-1, max_hs_len))).view(B, -1, max_hs_len)
    hs_weighted_num = (hs_enc.unsqueeze(1) * att_prob_hc_num.unsqueeze(3)).sum(2).sum(1)

    # self.col_num_out: (B, 3)
    col_num_score = self.col_num_out(self.col_num_out_q(q_weighted_num) +
                                     int(self.use_hs) * self.col_num_out_hs(hs_weighted_num))
    for idx, num in enumerate(col_len):
        if num < 5:
            col_num_score[idx, num:] = -100

    # Predict columns.
    att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2))
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc[idx, :, num:] = -100
    att_prob_qc = self.softmax(att_val_qc.view((-1, max_q_len))).view(B, -1, max_q_len)
    # q_weighted: (B, max_col_len, hid_dim)
    q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2)

    # Same as the above, compute SQL history embedding weighted by column attentions
    att_val_hc = torch.bmm(col_enc, self.hs_att(hs_enc).transpose(1, 2))
    for idx, num in enumerate(hs_len):
        if num < max_hs_len:
            att_val_hc[idx, :, num:] = -100
    att_prob_hc = self.softmax(att_val_hc.view((-1, max_hs_len))).view(B, -1, max_hs_len)
    hs_weighted = (hs_enc.unsqueeze(1) * att_prob_hc.unsqueeze(3)).sum(2)

    # Compute prediction scores
    # self.col_out.squeeze(): (B, max_col_len)
    col_score = self.col_out(self.col_out_q(q_weighted) +
                             int(self.use_hs) * self.col_out_hs(hs_weighted) +
                             self.col_out_c(col_enc)).view(B, -1)
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            col_score[idx, num:] = -100

    score = (col_num_score, col_score)
    return score
def forward(self, q_emb_var, q_len, col_emb_var, col_len, x_type_emb_var):
    max_q_len = max(q_len)
    max_col_len = max(col_len)
    B = len(q_len)

    x_emb_concat = torch.cat((q_emb_var, x_type_emb_var), 2)
    q_enc, _ = run_lstm(self.q_lstm, x_emb_concat, q_len)
    col_enc, _ = run_lstm(self.col_lstm, col_emb_var, col_len)

    # Predict number
    gby_num_att = torch.bmm(col_enc, self.gby_num_h(q_enc).transpose(1, 2))
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            gby_num_att[idx, num:, :] = -100
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            gby_num_att[idx, :, num:] = -100
    gby_num_att_val = self.softmax(gby_num_att.view((-1, max_q_len))).view(B, -1, max_q_len)
    gby_num_K = (q_enc.unsqueeze(1) * gby_num_att_val.unsqueeze(3)).sum(2).sum(1)
    gby_num_score = self.gby_num_out(self.gby_num_l(gby_num_K))

    # Predict columns.
    att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2))
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc[idx, :, num:] = -100
    att_prob_qc = self.softmax(att_val_qc.view((-1, max_q_len))).view(B, -1, max_q_len)
    # q_weighted: (B, max_col_len, hid_dim)
    q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2)
    # Compute prediction scores
    # col_score: (B, max_col_len)
    col_score = self.col_out(self.col_out_q(q_weighted) + self.col_out_c(col_enc)).squeeze()
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            col_score[idx, num:] = -100

    # Predict aggregation
    agg_att_val = torch.bmm(col_enc, self.agg_att(q_enc).transpose(1, 2))
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            agg_att_val[idx, num:, :] = -100
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            agg_att_val[idx, :, num:] = -100
    agg_att = self.softmax(agg_att_val.view((-1, max_q_len))).view(B, -1, max_q_len)
    # q_weighted_agg: (B, hid_dim)
    q_weighted_agg = (q_enc.unsqueeze(1) * agg_att.unsqueeze(3)).sum(2).sum(1)
    agg_score = self.agg_out(self.agg_out_q(q_weighted_agg))

    # Predict desc/asc/limit
    dat_att_val = torch.bmm(col_enc, self.dat_att(q_enc).transpose(1, 2))
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            dat_att_val[idx, num:, :] = -100
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            dat_att_val[idx, :, num:] = -100
    dat_att = self.softmax(dat_att_val.view((-1, max_q_len))).view(B, -1, max_q_len)
    # q_weighted_dat: (B, hid_dim)
    q_weighted_dat = (q_enc.unsqueeze(1) * dat_att.unsqueeze(3)).sum(2).sum(1)
    dat_score = self.dat_out(self.dat_out_q(q_weighted_dat))

    score = (gby_num_score, col_score, agg_score, dat_score)
    return score
def forward(self, q_emb_var, q_len, col_emb_var, col_len, col_num, col_name_len, gt_sel):
    max_q_len = max(q_len)
    max_col_len = max(col_len)
    B = len(q_len)

    q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
    col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len, self.col_lstm)
    # col_enc, _ = run_lstm(self.col_lstm, col_emb_var, col_len)

    # Predict column number: 1-3
    # att_val_qc_num: (B, max_col_len, max_q_len)
    att_val_qc_num = torch.bmm(col_enc, self.q_num_att(q_enc).transpose(1, 2))
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            att_val_qc_num[idx, num:, :] = -100
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc_num[idx, :, num:] = -100
    att_prob_qc_num = self.softmax(att_val_qc_num.view((-1, max_q_len))).view(B, -1, max_q_len)
    # q_weighted_num: (B, hid_dim)
    q_weighted_num = (q_enc.unsqueeze(1) * att_prob_qc_num.unsqueeze(3)).sum(2).sum(1)
    # col_num_score: (B, 4)
    col_num_score = self.col_num_out(self.col_num_out_q(q_weighted_num))

    # Predict columns.
    att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2))
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc[idx, :, num:] = -100
    att_prob_qc = self.softmax(att_val_qc.view((-1, max_q_len))).view(B, -1, max_q_len)
    # q_weighted: (B, max_col_len, hid_dim)
    q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2)
    # Compute prediction scores
    # col_score: (B, max_col_len)
    col_score = self.col_out(self.col_out_q(q_weighted) + self.col_out_c(col_enc)).squeeze()
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            col_score[idx, num:] = -100

    # get select columns for agg prediction
    chosen_sel_gt = []
    if gt_sel is None:
        sel_nums = [x + 1 for x in list(np.argmax(col_num_score.data.cpu().numpy(), axis=1))]
        sel_col_scores = col_score.data.cpu().numpy()
        chosen_sel_gt = [list(np.argsort(-sel_col_scores[b])[:sel_nums[b]])
                         for b in range(len(sel_nums))]
    else:
        for x in gt_sel:
            curr = x[0]
            curr_sel = [curr]
            for col in x:
                if col != curr:
                    curr_sel.append(col)
            chosen_sel_gt.append(curr_sel)

    col_emb = []
    for b in range(B):
        cur_col_emb = torch.stack([col_enc[b, x] for x in chosen_sel_gt[b]] +
                                  [col_enc[b, 0]] * (5 - len(chosen_sel_gt[b])))
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)  # (B, 5, hd)

    # Predict aggregation
    # q_enc.unsqueeze(1): (B, 1, max_x_len, hd)
    # col_emb.unsqueeze(3): (B, 5, hd, 1)
    # agg_num_att_val.squeeze(): (B, 5, max_x_len)
    agg_num_att_val = torch.matmul(self.agg_num_att(q_enc).unsqueeze(1),
                                   col_emb.unsqueeze(3)).squeeze()
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            agg_num_att_val[idx, :, num:] = -100
    agg_num_att = self.softmax(agg_num_att_val.view(-1, max_q_len)).view(B, -1, max_q_len)
    q_weighted_agg_num = (q_enc.unsqueeze(1) * agg_num_att.unsqueeze(3)).sum(2)
    agg_num_score = self.agg_num_out(self.agg_num_out_q(q_weighted_agg_num) +
                                     self.agg_num_out_c(col_emb)).squeeze()

    agg_att_val = torch.matmul(self.agg_att(q_enc).unsqueeze(1),
                               col_emb.unsqueeze(3)).squeeze()
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            agg_att_val[idx, :, num:] = -100
    agg_att = self.softmax(agg_att_val.view(-1, max_q_len)).view(B, -1, max_q_len)
    q_weighted_agg = (q_enc.unsqueeze(1) * agg_att.unsqueeze(3)).sum(2)
    agg_score = self.agg_out(self.agg_out_q(q_weighted_agg) +
                             self.agg_out_c(col_emb)).squeeze()

    score = (col_num_score, col_score, agg_num_score, agg_score)
    return score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num,
            gt_where, gt_cond, reinforce):
    max_x_len = max(x_len)
    B = len(x_len)

    h_enc, hidden = run_lstm(self.cond_lstm, x_emb_var, x_len)
    decoder_hidden = tuple(torch.cat((hid[:2], hid[2:]), dim=2) for hid in hidden)

    if gt_where is not None:
        gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where, gen_inp=True)
        g_s, _ = run_lstm(self.cond_decoder, gt_tok_seq, gt_tok_len, decoder_hidden)

        h_enc_expand = h_enc.unsqueeze(1)
        g_s_expand = g_s.unsqueeze(2)
        cond_score = self.cond_out(self.cond_out_h(h_enc_expand) +
                                   self.cond_out_g(g_s_expand)).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                cond_score[idx, :, num:] = -100
    else:
        h_enc_expand = h_enc.unsqueeze(1)
        scores = []
        choices = []
        done_set = set()
        t = 0
        init_inp = np.zeros((B, 1, self.max_tok_num), dtype=np.float32)
        init_inp[:, 0, 7] = 1  # Set the <BEG> token
        if self.gpu:
            cur_inp = Variable(torch.from_numpy(init_inp).cuda())
        else:
            cur_inp = Variable(torch.from_numpy(init_inp))
        cur_h = decoder_hidden
        while len(done_set) < B and t < 100:
            g_s, cur_h = self.cond_decoder(cur_inp, cur_h)
            g_s_expand = g_s.unsqueeze(2)
            cur_cond_score = self.cond_out(self.cond_out_h(h_enc_expand) +
                                           self.cond_out_g(g_s_expand)).squeeze()
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cur_cond_score[b, num:] = -100
            scores.append(cur_cond_score)

            if not reinforce:
                _, ans_tok_var = cur_cond_score.view(B, max_x_len).max(1)
                ans_tok_var = ans_tok_var.unsqueeze(1)
            else:
                ans_tok_var = self.softmax(cur_cond_score).multinomial()
                choices.append(ans_tok_var)
            ans_tok = ans_tok_var.data.cpu()
            if self.gpu:  # To one-hot
                cur_inp = Variable(torch.zeros(B, self.max_tok_num).scatter_(1, ans_tok, 1).cuda())
            else:
                cur_inp = Variable(torch.zeros(B, self.max_tok_num).scatter_(1, ans_tok, 1))
            cur_inp = cur_inp.unsqueeze(1)

            for idx, tok in enumerate(ans_tok.squeeze()):
                if tok == 1:  # Find the <END> token
                    done_set.add(idx)
            t += 1
        cond_score = torch.stack(scores, 1)

    if reinforce:
        return cond_score, choices
    else:
        return cond_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num,
            gt_where, gt_cond, reinforce):
    max_x_len = max(x_len)
    B = len(x_len)
    if reinforce:
        raise NotImplementedError("Our model doesn't have RL")

    # Predict the number of conditions
    # First use column embeddings to calculate the initial hidden unit
    # Then run the LSTM and predict condition number.
    e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len,
                                         self.cond_num_name_enc)
    num_col_att_val = self.cond_num_col_att(e_num_col).squeeze()
    for idx, num in enumerate(col_num):
        if num < max(col_num):
            num_col_att_val[idx, num:] = -100
    num_col_att = self.softmax(num_col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
    cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()
    cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()

    h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len,
                            hidden=(cond_num_h1, cond_num_h2))
    num_att_val = self.cond_num_att(h_num_enc).squeeze()
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            num_att_val[idx, num:] = -100
    num_att = self.softmax(num_att_val)
    K_cond_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
    cond_num_score = self.cond_num_out(K_cond_num)

    # Predict the columns of conditions
    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_col_name_enc)
    h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)
    if self.use_ca:
        col_att_val = torch.bmm(e_cond_col, self.cond_col_att(h_col_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                col_att_val[idx, :, num:] = -100
        col_att = self.softmax(col_att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
        K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
    else:
        col_att_val = self.cond_col_att(h_col_enc).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                col_att_val[idx, num:] = -100
        col_att = self.softmax(col_att_val)
        K_cond_col = (h_col_enc * col_att.unsqueeze(2)).sum(1).unsqueeze(1)

    cond_col_score = self.cond_col_out(self.cond_col_out_K(K_cond_col) +
                                       self.cond_col_out_col(e_cond_col)).squeeze()
    max_col_num = max(col_num)
    for b, num in enumerate(col_num):
        if num < max_col_num:
            cond_col_score[b, num:] = -100

    # Predict the operator of conditions
    chosen_col_gt = []
    if gt_cond is None:
        cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
        col_scores = cond_col_score.data.cpu().numpy()
        chosen_col_gt = [list(np.argsort(-col_scores[b])[:cond_nums[b]])
                         for b in range(len(cond_nums))]
    else:
        chosen_col_gt = [[x[0] for x in one_gt_cond] for one_gt_cond in gt_cond]

    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_op_name_enc)
    h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
    col_emb = []
    for b in range(B):
        cur_col_emb = torch.stack([e_cond_col[b, x] for x in chosen_col_gt[b]] +
                                  [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))  # Pad the columns to maximum (4)
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)

    if self.use_ca:
        op_att_val = torch.matmul(self.cond_op_att(h_op_enc).unsqueeze(1),
                                  col_emb.unsqueeze(3)).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                op_att_val[idx, :, num:] = -100
        op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
        K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
    else:
        op_att_val = self.cond_op_att(h_op_enc).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                op_att_val[idx, num:] = -100
        op_att = self.softmax(op_att_val)
        K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)

    cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) +
                                     self.cond_op_out_col(col_emb)).squeeze()

    # Predict the string of conditions
    h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_str_name_enc)
    col_emb = []
    for b in range(B):
        cur_col_emb = torch.stack([e_cond_col[b, x] for x in chosen_col_gt[b]] +
                                  [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)

    if gt_where is not None:
        gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where)
        g_str_s_flat, _ = self.cond_str_decoder(gt_tok_seq.view(B * 4, -1, self.max_tok_num))
        g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)

        h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
        g_ext = g_str_s.unsqueeze(3)
        col_ext = col_emb.unsqueeze(2).unsqueeze(2)

        cond_str_score = self.cond_str_out(self.cond_str_out_h(h_ext) +
                                           self.cond_str_out_g(g_ext) +
                                           self.cond_str_out_col(col_ext)).squeeze()
        for b, num in enumerate(x_len):
            if num < max_x_len:
                cond_str_score[b, :, :, num:] = -100
    else:
        h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
        col_ext = col_emb.unsqueeze(2).unsqueeze(2)
        scores = []
        t = 0
        init_inp = np.zeros((B * 4, 1, self.max_tok_num), dtype=np.float32)
        init_inp[:, 0, 0] = 1  # Set the <BEG> token
        if self.gpu:
            cur_inp = Variable(torch.from_numpy(init_inp).cuda())
        else:
            cur_inp = Variable(torch.from_numpy(init_inp))
        cur_h = None
        while t < 50:
            if cur_h:
                g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
            else:
                g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
            g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
            g_ext = g_str_s.unsqueeze(3)

            cur_cond_str_score = self.cond_str_out(self.cond_str_out_h(h_ext) +
                                                   self.cond_str_out_g(g_ext) +
                                                   self.cond_str_out_col(col_ext)).squeeze()
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cur_cond_str_score[b, :, num:] = -100
            scores.append(cur_cond_str_score)

            _, ans_tok_var = cur_cond_str_score.view(B * 4, max_x_len).max(1)
            ans_tok = ans_tok_var.data.cpu()
            data = torch.zeros(B * 4, self.max_tok_num).scatter_(1, ans_tok.unsqueeze(1), 1)
            if self.gpu:  # To one-hot
                cur_inp = Variable(data.cuda())
            else:
                cur_inp = Variable(data)
            cur_inp = cur_inp.unsqueeze(1)
            t += 1
        cond_str_score = torch.stack(scores, 2)
        for b, num in enumerate(x_len):
            if num < max_x_len:
                cond_str_score[b, :, :, num:] = -100  # [B, IDX, T, TOK_NUM]

    cond_score = (cond_num_score, cond_col_score, cond_op_score, cond_str_score)
    return cond_score