def forward(self, x_emb_var, x_len, col_inp_var=None, col_len=None,
            col_name_len=None, x_type_emb_var=None, gt_sel=None,
            sel_cond_score=None):
    """Score the aggregator of every selected column (4 padded slots).

    The question embeddings are concatenated with the type embeddings and
    encoded by ``self.agg_lstm``. The column names are encoded with
    ``col_name_encode`` / ``self.agg_col_name_enc``. For each example, the
    embeddings of its selected columns are gathered and padded to 4 slots
    with column 0. Each slot attends over the question, and ``self.agg_out``
    scores the attended summary plus the column embedding.

    Args:
        x_emb_var: question token embeddings; concatenated with
            ``x_type_emb_var`` along dim 2.
        x_len: question lengths. ``max(x_len)`` is the padded length used
            for masking.
        col_inp_var, col_len, col_name_len: column-name inputs forwarded to
            ``col_name_encode``.
        x_type_emb_var: type embeddings of the question tokens. It is
            concatenated unconditionally, so it must be given despite the
            ``None`` default.
        gt_sel: per-example iterables of ground-truth selected column
            indices. When ``None``, columns are decoded from
            ``sel_cond_score``.
        sel_cond_score: tuple ``(sel_num_score, _, sel_score, _)`` used when
            ``gt_sel`` is ``None``.

    Returns:
        ``self.agg_out(...)`` squeezed.

    Raises:
        Exception: if ``gt_sel`` and ``sel_cond_score`` are both ``None``.
    """
    B = len(x_emb_var)
    max_x_len = max(x_len)
    x_emb_concat = torch.cat((x_emb_var, x_type_emb_var), 2)

    if gt_sel is None:
        if sel_cond_score is None:
            raise Exception(
                """In the test mode, sel_num_score and sel_col_score should be passed in order to predict aggregation!"""
            )
        sel_num_score, _, sel_score, _ = sel_cond_score
        # NOTE(review): argmax is used directly as the number of selected
        # columns. The sel/cond forward that builds sel_cond_score decodes
        # the same score as argmax + 1 -- confirm the label encoding.
        sel_nums = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
        sel_col_scores = sel_score.data.cpu().numpy()
        # Top-k columns by score, k = predicted number of selected columns.
        chosen_sel_col_gt = [
            list(np.argsort(-sel_col_scores[b])[:sel_nums[b]])
            for b in range(len(sel_nums))
        ]
    else:
        chosen_sel_col_gt = [list(one_gt_sel) for one_gt_sel in gt_sel]

    h_enc, _ = run_lstm(self.agg_lstm, x_emb_concat, x_len)
    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.agg_col_name_enc)

    # Selected-column embeddings per example, padded to 4 slots with
    # column 0 -> stacked to (B, 4, ...).
    sel_col_emb = torch.stack([
        torch.stack([e_col[b, x] for x in chosen_sel_col_gt[b]]
                    + [e_col[b, 0]] * (4 - len(chosen_sel_col_gt[b])))
        for b in range(B)
    ])

    # (B, 1, L, H) @ (B, 4, H, 1) -> (B, 4, L, 1). Only the trailing
    # singleton is squeezed so that B == 1 or a one-token question keeps the
    # (B, 4, L) layout the masking below indexes into.
    agg_att_val = torch.matmul(
        self.agg_att(h_enc).unsqueeze(1), sel_col_emb.unsqueeze(3)).squeeze(3)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            # Padded question positions get ~0 softmax weight.
            agg_att_val[idx, :, num:] = -100
    agg_att = self.softmax(agg_att_val.view(B * 4, -1)).view(B, 4, -1)
    K_agg = (h_enc.unsqueeze(1) * agg_att.unsqueeze(3)).sum(2)
    # NOTE(review): this bare squeeze() also drops the batch dim when B == 1;
    # left unchanged because agg_out's output width is not visible here.
    agg_score = self.agg_out(
        self.agg_out_K(K_agg) + self.col_out_col(sel_col_emb)).squeeze()
    return agg_score
def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len,
            col_name_len, gt_col):
    """Predict the aggregator count and the aggregators for a given column.

    The question (``self.q_lstm``) and SQL history (``self.hs_lstm``) are
    encoded. The embedding of each example's column ``gt_col[b]`` attends
    over both sequences. The attended summaries, together with the column
    embedding, feed two heads.

    Args:
        q_emb_var, q_len: question embeddings and lengths.
        hs_emb_var, hs_len: SQL-history embeddings and lengths.
        col_emb_var, col_len, col_name_len: column-name inputs forwarded to
            ``col_name_encode``.
        gt_col: per-example column index to predict aggregators for.

    Returns:
        ``(agg_num_score, agg_score)``, scaled by ``1 / self.T1`` and
        ``1 / self.T2`` respectively. The history terms are multiplied by
        ``int(self.use_hs)``.
    """
    max_q_len = max(q_len)
    max_hs_len = max(hs_len)
    B = len(q_len)

    q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
    hs_enc, _ = run_lstm(self.hs_lstm, hs_emb_var, hs_len)
    col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len,
                                 self.col_lstm)

    # Embedding of the requested column of every example.
    col_emb = torch.stack([col_enc[b, gt_col[b]] for b in range(B)])

    def _attend(att_layer, enc, lens, max_len):
        """Column-to-sequence attention, padded positions masked, pooled."""
        att_val = torch.bmm(col_emb.unsqueeze(1),
                            att_layer(enc).transpose(1, 2)).view(B, -1)
        for idx, num in enumerate(lens):
            if num < max_len:
                att_val[idx, num:] = -100
        att_prob = self.softmax(att_val)
        return (enc * att_prob.unsqueeze(2)).sum(1)

    use_hs = int(self.use_hs)

    # Number of aggregators.
    q_weighted_num = _attend(self.q_num_att, q_enc, q_len, max_q_len)
    hs_weighted_num = _attend(self.hs_num_att, hs_enc, hs_len, max_hs_len)
    agg_num_score = self.agg_num_out(
        self.agg_num_out_q(q_weighted_num)
        + use_hs * self.agg_num_out_hs(hs_weighted_num)
        + self.agg_num_out_c(col_emb)) / self.T1

    # Aggregators.
    q_weighted = _attend(self.q_att, q_enc, q_len, max_q_len)
    hs_weighted = _attend(self.hs_att, hs_enc, hs_len, max_hs_len)
    agg_score = self.agg_out(
        self.agg_out_q(q_weighted)
        + use_hs * self.agg_out_hs(hs_weighted)
        + self.agg_out_c(col_emb)) / self.T2

    return (agg_num_score, agg_score)
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num):
    """Score every table column as a SELECT column.

    With ``self.use_ca``, each column attends over the question using
    additive attention: ``sel_att_out(sel_W1(h) + sel_W2(col))``. Otherwise
    a single question summary is built with ``self.sel_att``.

    Args:
        x_emb_var, x_len: question embeddings and lengths.
        col_inp_var, col_name_len, col_len: column-name inputs forwarded to
            ``col_name_encode``.
        col_num: number of columns per example. Positions ``>= col_num[b]``
            are set to -100.

    Returns:
        ``sel_score`` with one score per column, shape (B, max columns).
    """
    B = len(x_emb_var)
    max_x_len = max(x_len)
    # E_col: column-name encodings.
    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.sel_col_name_enc)
    # Question token hidden states.
    h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)

    if self.use_ca:
        # Additive attention. The operands broadcast to
        # (B, C, 1, L, .) and sel_att_out yields (B, C, 1, L, 1); both
        # singleton dims are squeezed by index so B/C/L == 1 stays intact.
        h_ext = h_enc.unsqueeze(1).unsqueeze(1)
        col_ext = e_col.unsqueeze(2).unsqueeze(2)
        att_val = self.sel_att_out(
            self.sel_W1(h_ext) + self.sel_W2(col_ext)).squeeze(-1).squeeze(2)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100
        # Softmax over question positions for every column.
        att = self.softmax(att_val.view((-1, max_x_len))).view(
            B, -1, max_x_len)
        # E_Q|col: one attended question summary per column.
        K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
    else:
        # Column-independent attention over the question.
        att_val = self.sel_att(h_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val)
        K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        K_sel_expand = K_sel.unsqueeze(1)

    # P_selcol(i|Q)
    sel_score = self.sel_out(
        self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col)).squeeze(-1)
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            sel_score[idx, num:] = -100
    return sel_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num, dropout_rate=0.):
    """Score every table column as a SELECT column, with optional dropout.

    Args:
        x_emb_var, x_len: question embeddings and lengths.
        col_inp_var, col_name_len, col_len: column-name inputs forwarded to
            ``col_name_encode``.
        col_num: number of columns per example. Positions ``>= col_num[b]``
            are set to -100.
        dropout_rate: forwarded to the encoders. When > 0, it is also used
            for a dropout mask on the attended question summary. The mask is
            applied whenever ``dropout_rate > 0``, independent of
            ``self.training``.

    Returns:
        ``sel_score / self.T``, one score per column.
    """
    B = len(x_emb_var)
    max_x_len = max(x_len)
    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.sel_col_name_enc,
                               dropout_rate=dropout_rate)
    h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len,
                        dropout_rate=dropout_rate)

    if self.use_ca:
        # Column-over-question multiplicative attention.
        att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view((-1, max_x_len))).view(
            B, -1, max_x_len)
        K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
    else:
        att_val = self.sel_att(h_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val)
        K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        K_sel_expand = K_sel.unsqueeze(1)

    if dropout_rate > 0.:
        # One Bernoulli(1 - p) mask over the last dim, shape (1, 1, H),
        # shared across batch and columns and rescaled by 1 / (1 - p).
        K_sel_expand_mask = torch.FloatTensor(K_sel_expand.size()[-1]).view(1, 1, -1) \
            .fill_(1. - dropout_rate).bernoulli().div_(1. - dropout_rate)
        if K_sel_expand.is_cuda:
            K_sel_expand_mask = K_sel_expand_mask.cuda()
        # Multiply out-of-place rather than through ``.data``, so autograd
        # also masks the gradient of the dropped units.
        K_sel_expand = K_sel_expand * K_sel_expand_mask

    sel_score = self.sel_out(
        self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col)
    ).squeeze(-1) / self.T
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            sel_score[idx, num:] = -100
    return sel_score
def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
            col_len=None, col_num=None, gt_sel=None):
    """Score aggregators from the question, optionally column-conditioned.

    With ``self.use_ca``, the question attention is keyed on the encoding of
    column ``gt_sel[b]``. ``gt_sel`` and the column inputs are then required.
    Otherwise ``self.agg_att`` scores the tokens directly.

    Args:
        x_emb_var, x_len: question embeddings and lengths.
        col_inp_var, col_name_len, col_len: column-name inputs (``use_ca``).
        col_num: unused.
        gt_sel: one selected column index per example (``use_ca``).

    Returns:
        ``(agg_num_score, agg_item_score)``. ``agg_num_score`` is always
        ``None`` (see note below).
    """
    max_x_len = max(x_len)
    # Hidden state of every question token.
    h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
    if self.use_ca:
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)
        chosen_sel_idx = torch.LongTensor(gt_sel)
        aux_range = torch.LongTensor(range(len(gt_sel)))
        if x_emb_var.is_cuda:
            chosen_sel_idx = chosen_sel_idx.cuda()
            aux_range = aux_range.cuda()
        # Encoding of the selected column of each example.
        chosen_e_col = e_col[aux_range, chosen_sel_idx]
        # bmm -> (B, L, 1); squeeze only the trailing dim so B == 1 or
        # L == 1 keeps the (B, L) layout used by the masking below.
        att_val = torch.bmm(self.agg_att(h_enc),
                            chosen_e_col.unsqueeze(2)).squeeze(2)
    else:
        att_val = self.agg_att(h_enc).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            # Make sure the padded positions get softmax weight ~= 0.
            att_val[idx, num:] = -100
    att = self.softmax(att_val)
    K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
    agg_item_score = self.agg_out(K_agg)
    # NOTE(review): no layer in this method produces an aggregator-count
    # score. The name was previously undefined, so every call raised
    # NameError. None keeps the 2-tuple shape -- confirm what callers expect.
    agg_num_score = None
    return (agg_num_score, agg_item_score)
def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
            col_len=None, col_num=None, gt_sel=None):
    """Score the aggregator from the question, optionally column-conditioned.

    With ``self.use_ca``, the attention over question tokens is keyed on the
    encoding (``v``) of column ``gt_sel[b]``. ``gt_sel`` and the column
    inputs are then required.

    Args:
        x_emb_var, x_len: question embeddings and lengths.
        col_inp_var, col_name_len, col_len: column-name inputs (``use_ca``).
        col_num: unused.
        gt_sel: one selected column index per example (``use_ca``).

    Returns:
        ``self.agg_out(K_agg)`` where ``K_agg`` is the attended question
        summary (E_Q|col).
    """
    max_x_len = max(x_len)
    # Hidden states of the question LSTM.
    h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
    if self.use_ca:
        # E_col: column-name encodings.
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)
        chosen_sel_idx = torch.LongTensor(gt_sel)
        aux_range = torch.LongTensor(range(len(gt_sel)))
        if x_emb_var.is_cuda:
            chosen_sel_idx = chosen_sel_idx.cuda()
            aux_range = aux_range.cuda()
        # v: encoding of each example's selected column.
        chosen_e_col = e_col[aux_range, chosen_sel_idx]
        # bmm -> (B, L, 1); squeeze only that trailing dim so B == 1 or
        # L == 1 keeps the (B, L) layout.
        att_val = torch.bmm(self.agg_att(h_enc),
                            chosen_e_col.unsqueeze(2)).squeeze(2)
    else:
        att_val = self.agg_att(h_enc).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, num:] = -100
    # Attention weights over question tokens.
    att = self.softmax(att_val)
    # K_agg is E_Q|col.
    K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
    agg_score = self.agg_out(K_agg)
    return agg_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num):
    """Score every table column as a SELECT column (column attention).

    Args:
        x_emb_var: embedding of each question.
        x_len: question lengths.
        col_inp_var: embedding of each header.
        col_name_len: length of each header.
        col_len: number of headers in each table (array).
        col_num: number of headers in each table (list). Scores at positions
            ``>= col_num[b]`` are set to -100.

    Returns:
        ``sel_score`` of shape (B, max number of columns).
    """
    B = len(x_emb_var)
    max_x_len = max(x_len)
    # e_col: [bs, col_num, hid]
    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.sel_col_name_enc)
    # h_enc: [bs, seq_len, hid]; the attention target.
    h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)

    # att_val: [bs, col_num, seq_len]
    att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            # Padded question positions must not receive attention.
            att_val[idx, :, num:] = -100
    att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
    # One attended question summary per column.
    K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)

    # Squeeze only the trailing score dim so B == 1 or a single column keeps
    # the (B, C) layout indexed below.
    sel_score = self.sel_out(
        self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col)).squeeze(-1)
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            sel_score[idx, num:] = -100
    return sel_score
def forward(self, q_emb_var, q_len, col_emb_var, col_len, col_num,
            col_name_len, gt_cond):
    """Predict the condition count, the condition columns, and their ops.

    Args:
        q_emb_var, q_len: question embeddings and lengths.
        col_emb_var, col_len, col_name_len: column-name inputs forwarded to
            ``col_name_encode``. ``col_len`` is also used to mask padded
            columns.
        col_num: unused.
        gt_cond: per-example condition lists whose ``x[0]`` is the column
            index (training). When ``None``, the columns are decoded from the
            predicted scores.

    Returns:
        ``(col_num_score, col_score, op_score)``.
    """
    max_q_len = max(q_len)
    max_col_len = max(col_len)
    B = len(q_len)

    q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
    col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len,
                                 self.col_lstm)

    # --- Number of condition columns ---
    # att_val_qc_num: (B, max_col_len, max_q_len)
    att_val_qc_num = torch.bmm(col_enc,
                               self.q_num_att(q_enc).transpose(1, 2))
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            att_val_qc_num[idx, num:, :] = -100
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc_num[idx, :, num:] = -100
    att_prob_qc_num = self.softmax(att_val_qc_num.view(
        (-1, max_q_len))).view(B, -1, max_q_len)
    # NOTE(review): padded column rows are -100 everywhere, so their softmax
    # is uniform and they still contribute to the sum over columns below.
    # q_weighted_num: (B, hid_dim)
    q_weighted_num = (q_enc.unsqueeze(1)
                      * att_prob_qc_num.unsqueeze(3)).sum(2).sum(1)
    col_num_score = self.col_num_out(self.col_num_out_q(q_weighted_num))

    # --- Condition columns ---
    att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2))
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc[idx, :, num:] = -100
    att_prob_qc = self.softmax(att_val_qc.view(
        (-1, max_q_len))).view(B, -1, max_q_len)
    # q_weighted: (B, max_col_len, hid_dim)
    q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2)
    # col_score: (B, max_col_len). Only the trailing dim is squeezed so that
    # B == 1 keeps the batch dim used by the masking below.
    col_score = self.col_out(
        self.col_out_q(q_weighted) + self.col_out_c(col_enc)).squeeze(-1)
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            col_score[idx, num:] = -100

    # --- Columns to predict operators for ---
    if gt_cond is None:
        cond_nums = np.argmax(col_num_score.data.cpu().numpy(), axis=1)
        col_scores = col_score.data.cpu().numpy()
        chosen_col_gt = [
            list(np.argsort(-col_scores[b])[:cond_nums[b]])
            for b in range(len(cond_nums))
        ]
    else:
        chosen_col_gt = [[x[0] for x in one_gt_cond]
                         for one_gt_cond in gt_cond]

    # Chosen column encodings padded to 5 slots with column 0: (B, 5, H).
    col_emb = torch.stack([
        torch.stack([col_enc[b, x] for x in chosen_col_gt[b]]
                    + [col_enc[b, 0]] * (5 - len(chosen_col_gt[b])))
        for b in range(B)
    ])

    # --- Operators ---
    # (B, 1, L, H) @ (B, 5, H, 1) -> (B, 5, L, 1) -> (B, 5, L)
    op_att_val = torch.matmul(
        self.op_att(q_enc).unsqueeze(1), col_emb.unsqueeze(3)).squeeze(3)
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            op_att_val[idx, :, num:] = -100
    op_att = self.softmax(op_att_val.view(-1, max_q_len)).view(
        B, -1, max_q_len)
    q_weighted_op = (q_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
    # NOTE(review): bare squeeze() drops the batch dim when B == 1; left as-is
    # because op_out's output width is not visible here.
    op_score = self.op_out(
        self.op_out_q(q_weighted_op) + self.op_out_c(col_emb)).squeeze()

    return (col_num_score, col_score, op_score)
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            x_type_emb_var, gt_sel):
    """Jointly predict SELECT columns and WHERE-condition columns.

    Args:
        x_emb_var, x_len: question embeddings and lengths.
        col_inp_var, col_name_len, col_len: column-name inputs.
            ``col_len`` also masks padded columns.
        x_type_emb_var: type embeddings, concatenated with ``x_emb_var``.
        gt_sel: per-example selected column indices (training). When
            ``None``, the top ``argmax(sel_num_score) + 1`` columns by
            ``sel_score`` are used.

    Returns:
        ``(sel_num_score, cond_num_score, sel_score, cond_col_score)``.
    """
    max_x_len = max(x_len)
    max_col_len = max(col_len)
    B = len(x_len)

    x_emb_concat = torch.cat((x_emb_var, x_type_emb_var), 2)
    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.selcond_name_enc)
    h_enc, _ = run_lstm(self.selcond_lstm, x_emb_concat, x_len)

    def _mask_tokens(att_val):
        """Set logits at padded question positions (dim 2) to -100."""
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100

    def _mask_cols(att_val):
        """Set logits of padded column rows (dim 1) to -100."""
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                att_val[idx, num:, :] = -100

    def _attend(att_val):
        """Softmax over question positions and pool h_enc per row."""
        att = self.softmax(att_val.view((-1, max_x_len))).view(
            B, -1, max_x_len)
        return (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)

    # --- Number of selected columns ---
    # (B, max_col_len, max_x_len)
    att_sel_num_type_val = torch.bmm(
        e_col, self.sel_num_type_att(h_enc).transpose(1, 2))
    _mask_cols(att_sel_num_type_val)
    _mask_tokens(att_sel_num_type_val)
    # NOTE(review): padded column rows get a uniform softmax and still add to
    # this sum over columns. (B, hid_dim)
    K_sel_num_type = _attend(att_sel_num_type_val).sum(1)
    sel_num_score = self.sel_num_out(self.ty_sel_num_out(K_sel_num_type))

    # --- Selected columns ---
    sel_att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
    _mask_tokens(sel_att_val)
    # (B, max_col_len, hid_dim)
    K_sel_expand = _attend(sel_att_val)
    sel_score = self.sel_out(
        self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col)).squeeze(-1)
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            sel_score[idx, num:] = -100

    # --- Number of conditions ---
    att_cond_num_type_val = torch.bmm(
        e_col, self.cond_num_type_att(h_enc).transpose(1, 2))
    _mask_cols(att_cond_num_type_val)
    _mask_tokens(att_cond_num_type_val)
    K_cond_num_type = _attend(att_cond_num_type_val).sum(1)
    cond_num_score = self.cond_num_out(
        self.ty_cond_num_out(K_cond_num_type))

    # --- Condition columns, conditioned on the selected columns ---
    if gt_sel is None:
        # Index i of sel_num_score decodes to i + 1 selected columns.
        sel_nums = np.argmax(sel_num_score.data.cpu().numpy(), axis=1) + 1
        sel_scores = sel_score.data.cpu().numpy()
        chosen_sel_col_gt = [
            list(np.argsort(-sel_scores[b])[:sel_nums[b]])
            for b in range(len(sel_nums))
        ]
    else:
        chosen_sel_col_gt = [list(one_gt_sel) for one_gt_sel in gt_sel]

    # Selected-column encodings padded with column 0 to max_col_len slots.
    sel_col_emb = torch.stack([
        torch.stack([e_col[b, x] for x in chosen_sel_col_gt[b]]
                    + [e_col[b, 0]] * (max_col_len - len(chosen_sel_col_gt[b])))
        for b in range(B)
    ])

    # (B, 1, L, H) @ (B, max_col_len, H, 1) -> (B, max_col_len, L)
    att_sel_val = torch.matmul(
        self.col_att(h_enc).unsqueeze(1), sel_col_emb.unsqueeze(3)).squeeze(3)
    col_att_val = torch.bmm(e_col, self.cond_col_att(h_enc).transpose(1, 2))
    # Both tensors are (B, slots, L): mask the question positions (dim 2).
    # Masking att_sel_val[idx, num:] would hit selected-column slots instead.
    _mask_tokens(col_att_val)
    _mask_tokens(att_sel_val)

    sel_col_att = self.softmax(att_sel_val.view(B * max_col_len, -1)).view(
        B, max_col_len, -1)
    K_sel_agg = (h_enc.unsqueeze(1) * sel_col_att.unsqueeze(3)).sum(2)
    K_cond_col = _attend(col_att_val)

    cond_col_score = self.cond_col_out(
        self.cond_col_out_K(K_cond_col)
        + self.cond_col_out_col(e_col)
        + self.cond_col_out_sel(K_sel_agg.expand_as(K_cond_col))).squeeze(-1)
    for b, num in enumerate(col_len):
        if num < max_col_len:
            cond_col_score[b, num:] = -100

    return (sel_num_score, cond_num_score, sel_score, cond_col_score)
def forward(self, q_emb_var, q_len, col_emb_var, col_len, col_num,
            col_name_len):
    """Predict column count, columns, aggregator and desc/asc/limit scores.

    The names (``ody_num_score``, ``dat_*``) suggest an ORDER BY clause
    predictor -- TODO confirm.

    Args:
        q_emb_var, q_len: question embeddings and lengths.
        col_emb_var, col_len, col_name_len: column-name inputs.
            ``col_len`` also masks padded columns.
        col_num: unused.

    Returns:
        ``(ody_num_score, col_score, agg_score, dat_score)``.
    """
    max_q_len = max(q_len)
    max_col_len = max(col_len)
    B = len(q_len)

    q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
    col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len,
                                 self.col_lstm)

    def _pooled_question(att_layer):
        """Column-over-question attention pooled to one (B, hid) vector.

        Padded columns and padded tokens are masked to -100. Padded column
        rows therefore get a uniform softmax and still add to the column sum.
        """
        att_val = torch.bmm(col_enc, att_layer(q_enc).transpose(1, 2))
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                att_val[idx, num:, :] = -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view((-1, max_q_len))).view(
            B, -1, max_q_len)
        return (q_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2).sum(1)

    # Number of columns.
    ody_num_score = self.gby_num_out(
        self.gby_num_l(_pooled_question(self.gby_num_h)))

    # Columns.
    att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2))
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc[idx, :, num:] = -100
    att_prob_qc = self.softmax(att_val_qc.view((-1, max_q_len))).view(
        B, -1, max_q_len)
    # q_weighted: (B, max_col_len, hid_dim)
    q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2)
    # col_score: (B, max_col_len). Squeeze only the trailing dim so B == 1
    # keeps its batch dim.
    col_score = self.col_out(
        self.col_out_q(q_weighted) + self.col_out_c(col_enc)).squeeze(-1)
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            col_score[idx, num:] = -100

    # Aggregation.
    agg_score = self.agg_out(self.agg_out_q(_pooled_question(self.agg_att)))

    # Desc / asc / limit.
    dat_score = self.dat_out(self.dat_out_q(_pooled_question(self.dat_att)))

    return (ody_num_score, col_score, agg_score, dat_score)
def forward(self, perm, st, ed, q_emb_var, q_len, col_emb_var, col_len,
            col_num, col_name_len, q_seq, col_seq, emb_layer, train=True):
    """Like the plain predictor, plus direction features for the dat head.

    The column with the highest ``col_score`` for each example is combined
    with ``emb_layer.get_direction_feature`` to build a 50-wide direction
    vector. That vector is added to the desc/asc/limit score.

    Args:
        perm, st, ed: ``perm[b]`` for ``b`` in ``range(st, ed)`` is the
            example id passed to ``get_direction_feature``. The batch size
            must equal ``ed - st``.
        q_emb_var, q_len: question embeddings and lengths.
        col_emb_var, col_len, col_name_len: column-name inputs.
        col_num, q_seq, col_seq: unused.
        emb_layer: object providing ``get_direction_feature``.
        train: forwarded to ``get_direction_feature``.

    Returns:
        ``(ody_num_score, col_score, agg_score, dat_score)``.
    """
    max_q_len = max(q_len)
    max_col_len = max(col_len)
    B = len(q_len)

    q_enc, _ = run_lstm(self.q_lstm, q_emb_var, q_len)
    col_enc, _ = col_name_encode(col_emb_var, col_name_len, col_len,
                                 self.col_lstm)

    def _pooled_question(att_layer):
        """Column-over-question attention pooled to one (B, hid) vector."""
        att_val = torch.bmm(col_enc, att_layer(q_enc).transpose(1, 2))
        for idx, num in enumerate(col_len):
            if num < max_col_len:
                att_val[idx, num:, :] = -100
        for idx, num in enumerate(q_len):
            if num < max_q_len:
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view((-1, max_q_len))).view(
            B, -1, max_q_len)
        return (q_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2).sum(1)

    # Number of columns.
    ody_num_score = self.gby_num_out(
        self.gby_num_l(_pooled_question(self.gby_num_h)))

    # Columns.
    att_val_qc = torch.bmm(col_enc, self.q_att(q_enc).transpose(1, 2))
    for idx, num in enumerate(q_len):
        if num < max_q_len:
            att_val_qc[idx, :, num:] = -100
    att_prob_qc = self.softmax(att_val_qc.view((-1, max_q_len))).view(
        B, -1, max_q_len)
    q_weighted = (q_enc.unsqueeze(1) * att_prob_qc.unsqueeze(3)).sum(2)
    # Squeeze only the trailing dim: with B == 1 a bare squeeze() would make
    # col_scores[b] a scalar and the argmax below would always be 0.
    col_score = self.col_out(
        self.col_out_q(q_weighted) + self.col_out_c(col_enc)).squeeze(-1)
    for idx, num in enumerate(col_len):
        if num < max_col_len:
            col_score[idx, num:] = -100

    # Aggregation.
    agg_score = self.agg_out(self.agg_out_q(_pooled_question(self.agg_att)))

    # Desc / asc / limit: question summary first, direction features below.
    q_weighted_dat = _pooled_question(self.dat_att)

    # Highest-scoring column of every example.
    col_scores = col_score.data.cpu().numpy()
    chosen_col_gt = [np.argmax(col_scores[b]) for b in range(B)]

    assert B == ed - st
    dirc_vecs = torch.zeros([B, 50])
    zero_feats = torch.zeros([50])
    if self.gpu:
        dirc_vecs = dirc_vecs.cuda()
        zero_feats = zero_feats.cuda()
    dirc_vecs = Variable(dirc_vecs, requires_grad=False)
    for b in range(st, ed):
        idx = perm[b]
        gt_col = chosen_col_gt[b - st]
        dirc_feat = emb_layer.get_direction_feature(
            max_q_len, idx, gt_col, train)
        if self.feats_format == 'direct':
            # dirc_feat[0]: per-token -1/0/1 flags, weighted by the column's
            # question attention.
            mask = (att_prob_qc[b - st, gt_col] * dirc_feat[0])
            # NOTE(review): only element 0 of the weighted mask decides the
            # branch -- confirm this is not meant to be a sum over tokens.
            mask_i = mask.cpu().data.numpy()[0]
            if mask_i > 0:
                dirc_vec = dirc_feat[1]
            elif mask_i < 0:
                dirc_vec = dirc_feat[2]
            else:
                dirc_vec = zero_feats
            dirc_vec = Variable(dirc_vec, requires_grad=False)
        else:
            # Attention-weighted average of the per-token feature rows.
            dirc_vec = torch.matmul(
                att_prob_qc[b - st, gt_col].unsqueeze(0), dirc_feat).squeeze()
        dirc_vecs[b - st] = dirc_vec

    dat_score = self.dat_out(self.dat_out_q(q_weighted_dat)) + \
        self.dat_out_dirc_out(self.dat_out_dirc(dirc_vecs))

    return (ody_num_score, col_score, agg_score, dat_score)