def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
    """Score every table column as the SELECT column.

    Each column-name encoding attends over the encoded question. The
    attended question context and the column encoding are combined by
    ``self.sel_out`` into one logit per column.

    Args:
        x_emb_var: embedded question batch; ``len(x_emb_var)`` is the batch size.
        x_len: question length (in tokens) of each example.
        col_inp_var, col_name_len, col_len: column-name inputs, forwarded
            unchanged to ``col_name_encode``.
        col_num: number of real columns in each example's table; the scores
            of padded column slots are set to -100.

    Returns:
        Column scores indexed as ``[example, column]``.
    """
    B = len(x_emb_var)
    max_x_len = max(x_len)

    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.sel_col_name_enc)  # [bs, col_num, hid]
    h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)  # [bs, seq_len, hid]

    # Column-to-token attention logits: [bs, col_num, seq_len].
    att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
    # Suppress padded question positions before the softmax.
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, :, num:] = -100
    att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
    # Attention-weighted question context for each column: [bs, col_num, hid].
    K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)

    # squeeze(-1) drops only the trailing size-1 score dim. A bare squeeze()
    # would also drop the batch dim when B == 1 (or the column dim for
    # single-column tables), breaking the [idx, num:] masking below.
    sel_score = self.sel_out(
        self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col)).squeeze(-1)
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            sel_score[idx, num:] = -100
    return sel_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
    """Predict the relationship between WHERE conditions.

    Column encodings are attention-pooled and projected into the initial
    (h, c) state of the question LSTM. The encoded question is then
    attention-pooled and scored by ``self.where_rela_out``.

    Args:
        x_emb_var: embedded question batch, passed to ``run_lstm``.
        x_len: question length (in tokens) of each example; ``len(x_len)``
            is the batch size.
        col_inp_var, col_name_len, col_len: column-name inputs, forwarded
            unchanged to ``col_name_encode``.
        col_num: not read. The per-example column count returned by
            ``col_name_encode`` is used for masking instead. The parameter is
            kept for interface compatibility.

    Returns:
        Output of ``self.where_rela_out`` on the pooled question encoding.
    """
    B = len(x_len)
    max_x_len = max(x_len)

    # NOTE(review): the same LSTM encodes both the column names and the
    # question -- confirm this weight sharing is intended.
    e_num_col, enc_col_num = col_name_encode(
        col_inp_var, col_name_len, col_len, self.where_rela_lstm)

    # Attention over columns. squeeze(-1) keeps the batch dim even when B == 1.
    col_att_val = self.where_rela_col_att(e_num_col).squeeze(-1)
    max_col_num = max(enc_col_num)  # hoisted: invariant across the loop
    for idx, num in enumerate(enc_col_num):
        if num < max_col_num:
            col_att_val[idx, num:] = -1000000
    num_col_att = self.softmax(col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)

    # Project the pooled columns to (B, 4, N_h // 2), then move the "4" axis
    # first. Presumably this is the (layers * directions, B, hidden) layout
    # of an nn.LSTM initial state -- TODO confirm against run_lstm.
    h1 = self.col2hid1(K_num_col).view(B, 4, self.N_h // 2).transpose(0, 1).contiguous()
    h2 = self.col2hid2(K_num_col).view(B, 4, self.N_h // 2).transpose(0, 1).contiguous()

    h_enc, _ = run_lstm(self.where_rela_lstm, x_emb_var, x_len, hidden=(h1, h2))

    # Attention over question tokens; padded positions are suppressed.
    att_val = self.where_rela_att(h_enc).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, num:] = -1000000
    att_val = self.softmax(att_val)
    where_rela_num = (h_enc * att_val.unsqueeze(2).expand_as(h_enc)).sum(1)
    where_rela_score = self.where_rela_out(where_rela_num)
    return where_rela_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
    """Predict the number of SELECT columns.

    Column encodings are attention-pooled and projected into the initial
    (h, c) state of the question LSTM. The encoded question is then
    attention-pooled and scored by ``self.sel_num_out``.

    Args:
        x_emb_var: embedded question batch, passed to ``run_lstm``.
        x_len: question length (in tokens) of each example; ``len(x_len)``
            is the batch size.
        col_inp_var, col_name_len, col_len: column-name inputs, forwarded
            unchanged to ``col_name_encode``.
        col_num: not read. The per-example column count returned by
            ``col_name_encode`` is used for masking instead. The parameter is
            kept for interface compatibility.

    Returns:
        Output of ``self.sel_num_out`` on the pooled question encoding.
    """
    B = len(x_len)
    max_x_len = max(x_len)

    # NOTE(review): the same LSTM encodes both the column names and the
    # question -- confirm this weight sharing is intended.
    e_num_col, enc_col_num = col_name_encode(
        col_inp_var, col_name_len, col_len, self.sel_num_lstm)

    # Attention over columns. squeeze(-1) keeps the batch dim even when B == 1.
    num_col_att_val = self.sel_num_col_att(e_num_col).squeeze(-1)
    max_col_num = max(enc_col_num)  # hoisted: invariant across the loop
    for idx, num in enumerate(enc_col_num):
        if num < max_col_num:
            num_col_att_val[idx, num:] = -1000000
    num_col_att = self.softmax(num_col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)

    # Project the pooled columns to (B, 4, N_h // 2), then move the "4" axis
    # first. Presumably this is the (layers * directions, B, hidden) layout
    # of an nn.LSTM initial state -- TODO confirm against run_lstm.
    sel_num_h1 = self.sel_num_col2hid1(K_num_col).view((B, 4, self.N_h // 2)).transpose(0, 1).contiguous()
    sel_num_h2 = self.sel_num_col2hid2(K_num_col).view((B, 4, self.N_h // 2)).transpose(0, 1).contiguous()

    h_num_enc, _ = run_lstm(self.sel_num_lstm, x_emb_var, x_len,
                            hidden=(sel_num_h1, sel_num_h2))

    # Attention over question tokens; padded positions are suppressed.
    num_att_val = self.sel_num_att(h_num_enc).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            num_att_val[idx, num:] = -1000000
    num_att = self.softmax(num_att_val)
    K_sel_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
    sel_num_score = self.sel_num_out(K_sel_num)
    return sel_num_score
def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
            col_len=None, col_num=None, gt_sel=None, gt_sel_num=None):
    """Score the aggregation operator for each selected column.

    The encodings of the columns in ``gt_sel`` are padded to 4 slots. Each
    slot attends over the encoded question, and the attended context plus
    the column encoding are scored by ``self.agg_out``.

    Args:
        x_emb_var: embedded question batch; ``len(x_emb_var)`` is the batch size.
        x_len: question length (in tokens) of each example.
        col_inp_var, col_name_len, col_len: column-name inputs, forwarded
            unchanged to ``col_name_encode``.
        col_num, gt_sel_num: not read; kept for interface compatibility.
        gt_sel: per-example list of selected column indices. It is required
            despite its ``None`` default, because it is indexed
            unconditionally. Assumed to hold at most 4 entries -- TODO confirm.

    Returns:
        Output of ``self.agg_out(...)`` with size-1 dims squeezed.
    """
    B = len(x_emb_var)
    max_x_len = max(x_len)

    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.agg_col_name_enc)
    h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)

    # Gather the selected columns' encodings, padded to 4 slots with column 0.
    col_emb = []
    for b in range(B):
        cur_col_emb = torch.stack(
            [e_col[b, x] for x in gt_sel[b]] + [e_col[b, 0]] * (4 - len(gt_sel[b])))
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)  # [B, 4, hid]

    # [B, 1, T, hid] @ [B, 4, hid, 1] -> [B, 4, T, 1] -> [B, 4, T].
    # squeeze(-1) keeps the batch dim when B == 1.
    att_val = torch.matmul(
        self.agg_att(h_enc).unsqueeze(1), col_emb.unsqueeze(3)).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            # Mask padded question tokens (last axis) for every column slot.
            # The previous att_val[idx, num:] masked column slots instead.
            att_val[idx, :, num:] = -100
    att = self.softmax(att_val.view(B * 4, -1)).view(B, 4, -1)
    # Attention-weighted question context per column slot: [B, 4, hid].
    K_agg = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
    agg_score = self.agg_out(
        self.agg_out_K(K_agg) + self.col_out_col(col_emb)).squeeze()
    return agg_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num, gt_where, gt_cond, reinforce):
    """Predict the WHERE clause: condition count, columns, operators and values.

    Args:
        x_emb_var: embedded question batch, passed to ``run_lstm``.
        x_len: question length (in tokens) of each example; ``len(x_len)``
            is the batch size.
        col_inp_var, col_name_len, col_len: column-name inputs, forwarded
            unchanged to ``col_name_encode``.
        col_num: not read. The per-example column count returned by the
            first ``col_name_encode`` call is used for masking instead.
        gt_where: ground-truth value token data for teacher forcing, passed
            to ``self.gen_gt_batch``. When None, the value pointer is decoded
            greedily for 50 steps.
        gt_cond: ground-truth conditions; element ``[0]`` of each condition
            is used as its column index. When None, the top-k columns by
            score are used, where k is the predicted condition count.
        reinforce: must be falsy; reinforcement learning is not implemented.

    Returns:
        Tuple ``(cond_num_score, cond_col_score, cond_op_score, cond_str_score)``.

    Raises:
        NotImplementedError: if ``reinforce`` is truthy.
    """
    max_x_len = max(x_len)
    B = len(x_len)
    if reinforce:
        raise NotImplementedError('Our model doesn\'t have RL')

    # ---- Number of conditions ----
    # Column embeddings are attention-pooled and projected into the initial
    # hidden state of the question LSTM, which then predicts the count.
    e_num_col, enc_col_num = col_name_encode(
        col_inp_var, col_name_len, col_len, self.cond_num_name_enc)
    max_col_num = max(enc_col_num)  # hoisted: reused by every column mask below
    num_col_att_val = self.cond_num_col_att(e_num_col).squeeze(-1)
    for idx, num in enumerate(enc_col_num):
        if num < max_col_num:
            num_col_att_val[idx, num:] = -100
    num_col_att = self.softmax(num_col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
    # (B, 4, N_h // 2) -> (4, B, N_h // 2); presumably the LSTM's
    # (layers * directions, B, hidden) initial-state layout -- TODO confirm.
    cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()
    cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()
    h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len,
                            hidden=(cond_num_h1, cond_num_h2))
    num_att_val = self.cond_num_att(h_num_enc).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            num_att_val[idx, num:] = -100
    num_att = self.softmax(num_att_val)
    K_cond_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
    cond_num_score = self.cond_num_out(K_cond_num)

    # ---- Columns of the conditions ----
    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.cond_col_name_enc)
    h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)
    if self.use_ca:
        # Column-aware attention: every column attends over the question.
        col_att_val = torch.bmm(e_cond_col, self.cond_col_att(h_col_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                col_att_val[idx, :, num:] = -100
        col_att = self.softmax(col_att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
        K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
    else:
        col_att_val = self.cond_col_att(h_col_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                col_att_val[idx, num:] = -100
        col_att = self.softmax(col_att_val)
        # Pool with the normalized attention weights. The previous code
        # multiplied by the raw masked logits (col_att_val), so col_att was
        # computed but never used.
        K_cond_col = (h_col_enc * col_att.unsqueeze(2)).sum(1).unsqueeze(1)

    # squeeze(-1) keeps cond_col_score 2-D even for B == 1; the inference
    # path below indexes col_scores[b] per example.
    cond_col_score = self.cond_col_out(
        self.cond_col_out_K(K_cond_col) + self.cond_col_out_col(e_cond_col)).squeeze(-1)
    for b, num in enumerate(enc_col_num):
        if num < max_col_num:
            cond_col_score[b, num:] = -100

    # ---- Operators of the conditions ----
    if gt_cond is None:
        # Inference: keep the highest-scoring columns, as many as predicted.
        cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
        col_scores = cond_col_score.data.cpu().numpy()
        chosen_col_gt = [
            list(np.argsort(-col_scores[b])[:cond_nums[b]])
            for b in range(len(cond_nums))
        ]
    else:
        chosen_col_gt = [[x[0] for x in one_gt_cond] for one_gt_cond in gt_cond]

    def pad_chosen_cols(e_col):
        # Stack the chosen columns' encodings per example, padded to the
        # maximum of 4 slots with column 0.
        return torch.stack([
            torch.stack([e_col[b, x] for x in chosen_col_gt[b]]
                        + [e_col[b, 0]] * (4 - len(chosen_col_gt[b])))
            for b in range(B)
        ])

    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.cond_op_name_enc)
    h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
    col_emb = pad_chosen_cols(e_cond_col)
    if self.use_ca:
        # [B, 1, T, hid] @ [B, 4, hid, 1] -> [B, 4, T, 1] -> [B, 4, T]
        op_att_val = torch.matmul(
            self.cond_op_att(h_op_enc).unsqueeze(1), col_emb.unsqueeze(3)).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                op_att_val[idx, :, num:] = -100
        op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
        K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
    else:
        op_att_val = self.cond_op_att(h_op_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                op_att_val[idx, num:] = -100
        op_att = self.softmax(op_att_val)
        K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)
    cond_op_score = self.cond_op_out(
        self.cond_op_out_K(K_cond_op) + self.cond_op_out_col(col_emb)).squeeze()

    # ---- Values of the conditions (pointer over question positions) ----
    h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.cond_str_name_enc)
    col_emb = pad_chosen_cols(e_cond_col)
    # Broadcast shapes: question positions vs. (condition slot, decode step).
    h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
    col_ext = col_emb.unsqueeze(2).unsqueeze(2)

    if gt_where is not None:
        # Teacher forcing: run the decoder over the ground-truth token sequence.
        gt_tok_seq, _ = self.gen_gt_batch(gt_where)
        g_str_s_flat, _ = self.cond_str_decoder(
            gt_tok_seq.view(B * 4, -1, self.max_tok_num))
        g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)
        g_ext = g_str_s.unsqueeze(3)
        # squeeze(-1) only: a bare squeeze() would also collapse a length-1
        # decode axis or a batch of 1.
        cond_str_score = self.cond_str_out(
            self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext)
            + self.cond_str_out_col(col_ext)).squeeze(-1)
    else:
        # Greedy decoding for 50 steps, starting from a one-hot <BEG> token.
        scores = []
        init_inp = np.zeros((B * 4, 1, self.max_tok_num), dtype=np.float32)
        init_inp[:, 0, 0] = 1  # set the <BEG> token
        if self.gpu:
            cur_inp = Variable(torch.from_numpy(init_inp).cuda())
        else:
            cur_inp = Variable(torch.from_numpy(init_inp))
        cur_h = None
        for _ in range(50):
            # `is not None` instead of truthiness: a tensor state would make
            # `if cur_h:` ambiguous.
            if cur_h is not None:
                g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
            else:
                g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
            g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
            g_ext = g_str_s.unsqueeze(3)
            # [B, 4, 1, T, 1] -> [B, 4, T]
            cur_cond_str_score = self.cond_str_out(
                self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext)
                + self.cond_str_out_col(col_ext)).squeeze(-1).squeeze(2)
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cur_cond_str_score[b, :, num:] = -100
            scores.append(cur_cond_str_score)

            # Feed the argmax position back as the next one-hot input.
            _, ans_tok_var = cur_cond_str_score.view(B * 4, max_x_len).max(1)
            ans_tok = ans_tok_var.data.cpu()
            data = torch.zeros(B * 4, self.max_tok_num).scatter_(
                1, ans_tok.unsqueeze(1), 1)
            if self.gpu:
                cur_inp = Variable(data.cuda())
            else:
                cur_inp = Variable(data)
            cur_inp = cur_inp.unsqueeze(1)
        cond_str_score = torch.stack(scores, 2)

    # Mask padded question positions: [B, IDX, T, TOK_NUM].
    for b, num in enumerate(x_len):
        if num < max_x_len:
            cond_str_score[b, :, :, num:] = -100

    cond_score = (cond_num_score, cond_col_score, cond_op_score, cond_str_score)
    return cond_score