def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num):
    """Score every table column as the SELECT column.

    Args:
        x_emb_var: embedded question tokens, encoded by ``self.sel_lstm``.
        x_len: list with the token count of each question.
        col_inp_var, col_name_len, col_len: column-name inputs forwarded to
            ``col_name_encode``.
        col_num: number of real columns per example; scores of padded column
            slots are overwritten with -100.

    Returns:
        2-D tensor of column scores indexed ``[example, column]``.
    """
    B = len(x_emb_var)
    max_x_len = max(x_len)

    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.sel_col_name_enc)
    # Both branches encode the question with the same LSTM.
    h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)

    if self.use_ca:
        # Column attention: every column attends over the question tokens.
        att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, :, num:] = -100
        att = self.softmax(att_val.view((-1, max_x_len))).view(
            B, -1, max_x_len)
        K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
    else:
        # One question-level attention shared by all columns.
        # squeeze(-1) (not squeeze()) so a batch of one keeps its batch axis.
        att_val = self.sel_att(h_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                att_val[idx, num:] = -100
        att = self.softmax(att_val)
        K_sel = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
        K_sel_expand = K_sel.unsqueeze(1)

    sel_score = self.sel_out(
        self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col)).squeeze(-1)
    # Mask column slots beyond each table's real column count.
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            sel_score[idx, num:] = -100
    return sel_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num):
    """Score every table column for the ``grp`` (presumably GROUP BY) clause.

    Each column attends over the encoded question (column attention). The
    attended question vector is combined with the column embedding to give one
    score per column.

    Args:
        x_emb_var: embedded question tokens, encoded by ``self.grp_lstm``.
        x_len: list with the token count of each question.
        col_inp_var, col_name_len, col_len: column-name inputs forwarded to
            ``col_name_encode``.
        col_num: number of real columns per example; padded slots get -100.

    Returns:
        2-D tensor of column scores indexed ``[example, column]``.
    """
    B = len(x_emb_var)
    max_x_len = max(x_len)

    # [bs, col_num, hid]
    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.grp_col_name_enc)
    # [bs, seq_len, hid]
    h_enc, _ = run_lstm(self.grp_lstm, x_emb_var, x_len)

    # [bs, col_num, seq_len]
    att_val = torch.bmm(e_col, self.grp_att(h_enc).transpose(1, 2))
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, :, num:] = -100
    att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
    K_grp_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)

    # squeeze(-1) (not squeeze()) so a batch of one keeps its batch axis.
    grp_score = self.grp_out(
        self.grp_out_K(K_grp_expand) + self.grp_out_col(e_col)).squeeze(-1)
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            grp_score[idx, num:] = -100
    return grp_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num):
    """Predict the WHERE-clause relation (``where_rela``) scores.

    The column names are encoded and attention-pooled. The pooled vector is
    projected into the initial (h, c) state of ``self.where_rela_lstm``, which
    then runs over the question. The attention-pooled question encoding is
    scored by ``self.where_rela_out``.

    Note: the ``col_num`` argument is not used. It is replaced by the column
    counts returned from ``col_name_encode``.
    """
    B = len(x_len)
    max_x_len = max(x_len)

    # NOTE(review): the same LSTM encodes the column names and the question;
    # confirm the weight sharing is intended.
    e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len,
                                         self.where_rela_lstm)
    # squeeze(-1) (not squeeze()) so a batch of one keeps its batch axis;
    # otherwise num_col_att.unsqueeze(2) below fails.
    col_att_val = self.where_rela_col_att(e_num_col).squeeze(-1)
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            col_att_val[idx, num:] = -1000000
    num_col_att = self.softmax(col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)

    # Initial LSTM state; the leading 4 is presumably
    # num_layers * num_directions of where_rela_lstm -- TODO confirm.
    h1 = self.col2hid1(K_num_col).view(B, 4, self.N_h // 2).transpose(
        0, 1).contiguous()
    h2 = self.col2hid2(K_num_col).view(B, 4, self.N_h // 2).transpose(
        0, 1).contiguous()

    h_enc, _ = run_lstm(self.where_rela_lstm, x_emb_var, x_len,
                        hidden=(h1, h2))
    att_val = self.where_rela_att(h_enc).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, num:] = -1000000
    att_val = self.softmax(att_val)
    where_rela_num = (h_enc * att_val.unsqueeze(2).expand_as(h_enc)).sum(1)
    where_rela_score = self.where_rela_out(where_rela_num)
    return where_rela_score
def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
            col_len=None, col_num=None, gt_sel=None):
    """Score aggregation operators for the selected column.

    With ``self.use_ca``, the question attention is conditioned on the
    embedding of column ``gt_sel[b]`` (one column index per example).
    Otherwise a plain question self-attention is used and the column
    arguments are ignored.

    Args:
        x_emb_var: embedded question tokens, encoded by ``self.agg_lstm``.
        x_len: list with the token count of each question.
        col_inp_var, col_name_len, col_len: column-name inputs (``use_ca``).
        col_num: unused; kept for interface compatibility.
        gt_sel: one selected column index per example (``use_ca``).

    Returns:
        Output of ``self.agg_out`` on the attended question vector.
    """
    max_x_len = max(x_len)
    h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)

    if self.use_ca:
        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                   self.agg_col_name_enc)
        chosen_sel_idx = torch.LongTensor(gt_sel)
        aux_range = torch.LongTensor(range(len(gt_sel)))
        if x_emb_var.is_cuda:
            chosen_sel_idx = chosen_sel_idx.cuda()
            aux_range = aux_range.cuda()
        # Row b picks e_col[b, gt_sel[b]].
        chosen_e_col = e_col[aux_range, chosen_sel_idx]
        att_val = torch.bmm(self.agg_att(h_enc),
                            chosen_e_col.unsqueeze(2)).squeeze(-1)
    else:
        att_val = self.agg_att(h_enc).squeeze(-1)
    # squeeze(-1) above (not squeeze()) keeps the batch axis when B == 1;
    # otherwise att.unsqueeze(2) below fails.

    for idx, num in enumerate(x_len):
        if num < max_x_len:
            att_val[idx, num:] = -100
    att = self.softmax(att_val)
    K_agg = (h_enc * att.unsqueeze(2).expand_as(h_enc)).sum(1)
    agg_score = self.agg_out(K_agg)
    return agg_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num):
    """Predict the number of SELECT columns.

    First, use the column embeddings to compute the initial hidden state.
    Then run the LSTM over the question and predict the select number.

    Note: the ``col_num`` argument is replaced by the column counts returned
    from ``col_name_encode``.

    Returns:
        Output of ``self.sel_num_out`` on the attended question vector.
    """
    B = len(x_len)
    max_x_len = max(x_len)

    e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len,
                                         self.sel_num_lstm)
    # squeeze(-1) (not squeeze()) so a batch of one keeps its batch axis.
    num_col_att_val = self.sel_num_col_att(e_num_col).squeeze(-1)
    max_col_num = max(col_num)  # hoisted: constant across the loop
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            num_col_att_val[idx, num:] = -1000000
    num_col_att = self.softmax(num_col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)

    # Initial (h, c) for the question LSTM; the leading 4 is presumably
    # num_layers * num_directions -- TODO confirm.
    sel_num_h1 = self.sel_num_col2hid1(K_num_col).view(
        (B, 4, self.N_h // 2)).transpose(0, 1).contiguous()
    sel_num_h2 = self.sel_num_col2hid2(K_num_col).view(
        (B, 4, self.N_h // 2)).transpose(0, 1).contiguous()

    h_num_enc, _ = run_lstm(self.sel_num_lstm, x_emb_var, x_len,
                            hidden=(sel_num_h1, sel_num_h2))
    num_att_val = self.sel_num_att(h_num_enc).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            num_att_val[idx, num:] = -1000000
    num_att = self.softmax(num_att_val)
    K_sel_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
    sel_num_score = self.sel_num_out(K_sel_num)
    return sel_num_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num):
    """Predict the number of SELECT columns.

    The column embeddings provide the initial hidden state of the question
    LSTM, whose attention-pooled output is classified by ``self.sel_num_out``.
    The ``col_num`` argument is replaced by the counts that
    ``col_name_encode`` returns.
    """
    B = len(x_len)
    max_x_len = max(x_len)

    # Run the column names through the LSTM to get their embeddings.
    e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len,
                                         self.sel_num_lstm)
    # Drop the trailing size-1 axis: one attention logit per column.
    # squeeze(-1) (not squeeze()) so a batch of one keeps its batch axis.
    num_col_att_val = self.sel_num_col_att(e_num_col).squeeze(-1)
    # Give padded column slots a huge negative logit so their softmax
    # weight is ~0.
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            num_col_att_val[idx, num:] = -1000000
    num_col_att = self.softmax(num_col_att_val)
    # Attention-weighted sum over columns (masked slots contribute ~0).
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
    # Project the attended column summary through a fully-connected layer
    # and reshape it into the initial hidden state of the question BiLSTM.
    sel_num_h1 = self.sel_num_col2hid1(K_num_col).view(
        (B, 4, self.N_h // 2)).transpose(0, 1).contiguous()
    sel_num_h2 = self.sel_num_col2hid2(K_num_col).view(
        (B, 4, self.N_h // 2)).transpose(0, 1).contiguous()

    # [bs, max_x_len, N_h]
    h_num_enc, _ = run_lstm(self.sel_num_lstm, x_emb_var, x_len,
                            hidden=(sel_num_h1, sel_num_h2))
    # [bs, max_x_len]
    num_att_val = self.sel_num_att(h_num_enc).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            num_att_val[idx, num:] = -1000000
    num_att = self.softmax(num_att_val)
    K_sel_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
    sel_num_score = self.sel_num_out(K_sel_num)
    return sel_num_score
def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
            col_len=None, col_num=None, gt_sel=None, gt_sel_num=None):
    """Score aggregation operators for up to four selected columns.

    ``gt_sel[b]`` lists the selected column indices of example ``b``. Their
    embeddings are padded to four slots with column 0's embedding, and each
    slot attends over the question separately.

    Args:
        x_emb_var: embedded question tokens, encoded by ``self.agg_lstm``.
        x_len: list with the token count of each question.
        col_inp_var, col_name_len, col_len: column-name inputs forwarded to
            ``col_name_encode``.
        col_num, gt_sel_num: unused; kept for interface compatibility.
        gt_sel: per-example list of selected column indices.

    Returns:
        Output of ``self.agg_out`` per column slot.
    """
    B = len(x_emb_var)
    max_x_len = max(x_len)

    e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                               self.agg_col_name_enc)
    h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)

    # [B, 4, hid]: chosen column embeddings, padded with column 0.
    col_emb = torch.stack([
        torch.stack([e_col[b, x] for x in gt_sel[b]] +
                    [e_col[b, 0]] * (4 - len(gt_sel[b])))
        for b in range(B)
    ])

    # [B, 1, T, hid] @ [B, 4, hid, 1] -> [B, 4, T, 1] -> [B, 4, T]
    att_val = torch.matmul(self.agg_att(h_enc).unsqueeze(1),
                           col_emb.unsqueeze(3)).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            # Mask padded question tokens (last axis) in every column slot.
            # Indexing att_val[idx, num:] would hit the 4 slots instead.
            att_val[idx, :, num:] = -100
    att = self.softmax(att_val.view(B * 4, -1)).view(B, 4, -1)
    K_agg = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
    agg_score = self.agg_out(
        self.agg_out_K(K_agg) + self.col_out_col(col_emb)).squeeze()
    return agg_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num, gt_where, gt_cond, reinforce):
    """Predict the WHERE clause: condition count, columns, operators, values.

    Args:
        x_emb_var: embedded question tokens.
        x_len: list with the token count of each question (batch size is
            ``len(x_len)``).
        col_inp_var, col_name_len, col_len: column-name inputs forwarded to
            ``col_name_encode``.
        col_num: replaced by the column counts from ``col_name_encode``.
        gt_where: ground truth for the value decoder (passed to
            ``self.gen_gt_batch``), or ``None`` to decode greedily for 50
            steps.
        gt_cond: ground-truth conditions, where ``x[0]`` of each condition is
            its column index; ``None`` picks the top-scoring columns instead.
        reinforce: must be falsy.

    Returns:
        ``(cond_num_score, cond_col_score, cond_op_score, cond_str_score)``.

    Raises:
        NotImplementedError: if ``reinforce`` is truthy.
    """
    max_x_len = max(x_len)
    B = len(x_len)
    if reinforce:
        raise NotImplementedError('Our model doesn\'t have RL')

    # ---- Number of conditions ------------------------------------------
    # First use column embeddings to calculate the initial hidden unit,
    # then run the LSTM and predict the condition number.
    e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len,
                                         self.cond_num_name_enc)
    # squeeze(-1) (not squeeze()) so a batch of one keeps its batch axis.
    num_col_att_val = self.cond_num_col_att(e_num_col).squeeze(-1)
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            num_col_att_val[idx, num:] = -100
    num_col_att = self.softmax(num_col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
    cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()
    cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
        B, 4, self.N_h // 2).transpose(0, 1).contiguous()

    h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len,
                            hidden=(cond_num_h1, cond_num_h2))
    num_att_val = self.cond_num_att(h_num_enc).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            num_att_val[idx, num:] = -100
    num_att = self.softmax(num_att_val)
    K_cond_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
    cond_num_score = self.cond_num_out(K_cond_num)

    # ---- Condition columns ------------------------------------------------
    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_col_name_enc)
    h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)
    if self.use_ca:
        col_att_val = torch.bmm(e_cond_col,
                                self.cond_col_att(h_col_enc).transpose(1, 2))
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                col_att_val[idx, :, num:] = -100
        col_att = self.softmax(col_att_val.view(
            (-1, max_x_len))).view(B, -1, max_x_len)
        K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
    else:
        col_att_val = self.cond_col_att(h_col_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                col_att_val[idx, num:] = -100
        col_att = self.softmax(col_att_val)
        # Weight by the normalised attention (previously the raw masked
        # scores col_att_val were used, leaving the softmax unused).
        K_cond_col = (h_col_enc * col_att.unsqueeze(2)).sum(1).unsqueeze(1)

    cond_col_score = self.cond_col_out(
        self.cond_col_out_K(K_cond_col) +
        self.cond_col_out_col(e_cond_col)).squeeze(-1)
    for b, num in enumerate(col_num):
        if num < max_col_num:
            cond_col_score[b, num:] = -100

    # ---- Condition operators ----------------------------------------------
    if gt_cond is None:
        cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
        col_scores = cond_col_score.data.cpu().numpy()
        chosen_col_gt = [
            list(np.argsort(-col_scores[b])[:cond_nums[b]])
            for b in range(len(cond_nums))
        ]
    else:
        chosen_col_gt = [[x[0] for x in one_gt_cond]
                         for one_gt_cond in gt_cond]

    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_op_name_enc)
    h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
    col_emb = []
    for b in range(B):
        # Pad the columns to maximum (4) with column 0's embedding.
        cur_col_emb = torch.stack(
            [e_cond_col[b, x] for x in chosen_col_gt[b]] +
            [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)

    if self.use_ca:
        # [B, 1, T, hid] @ [B, 4, hid, 1] -> [B, 4, T]
        op_att_val = torch.matmul(
            self.cond_op_att(h_op_enc).unsqueeze(1),
            col_emb.unsqueeze(3)).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                op_att_val[idx, :, num:] = -100
        op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
        K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
    else:
        op_att_val = self.cond_op_att(h_op_enc).squeeze(-1)
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                op_att_val[idx, num:] = -100
        op_att = self.softmax(op_att_val)
        K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)

    cond_op_score = self.cond_op_out(
        self.cond_op_out_K(K_cond_op) +
        self.cond_op_out_col(col_emb)).squeeze()

    # ---- Condition value strings (pointer over question tokens) -----------
    h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_str_name_enc)
    col_emb = []
    for b in range(B):
        cur_col_emb = torch.stack(
            [e_cond_col[b, x] for x in chosen_col_gt[b]] +
            [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
        col_emb.append(cur_col_emb)
    col_emb = torch.stack(col_emb)

    h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)  # question encoding
    col_ext = col_emb.unsqueeze(2).unsqueeze(2)  # condition column encoding

    if gt_where is not None:
        # Teacher forcing on the ground-truth token sequence.
        gt_tok_seq, _ = self.gen_gt_batch(gt_where)
        g_str_s_flat, _ = self.cond_str_decoder(
            gt_tok_seq.view(B * 4, -1, self.max_tok_num))
        g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)
        g_ext = g_str_s.unsqueeze(3)  # value-string encoding

        # Multi-axis squeeze is intentional here; leave as squeeze().
        cond_str_score = self.cond_str_out(
            self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
            self.cond_str_out_col(col_ext)).squeeze()
        for b, num in enumerate(x_len):
            if num < max_x_len:
                cond_str_score[b, :, :, num:] = -100
    else:
        # Greedy decoding for a fixed 50 steps.
        scores = []
        init_inp = np.zeros((B * 4, 1, self.max_tok_num), dtype=np.float32)
        init_inp[:, 0, 0] = 1  # Set the <BEG> token
        if self.gpu:
            cur_inp = Variable(torch.from_numpy(init_inp).cuda())
        else:
            cur_inp = Variable(torch.from_numpy(init_inp))
        cur_h = None
        for _ in range(50):
            # `is not None` also works if the decoder state is a tensor.
            if cur_h is not None:
                g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
            else:
                g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
            g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
            g_ext = g_str_s.unsqueeze(3)

            cur_cond_str_score = self.cond_str_out(
                self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                self.cond_str_out_col(col_ext)).squeeze()
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cur_cond_str_score[b, :, num:] = -100
            scores.append(cur_cond_str_score)

            _, ans_tok_var = cur_cond_str_score.view(B * 4, max_x_len).max(1)
            ans_tok = ans_tok_var.data.cpu()
            # To one-hot
            data = torch.zeros(B * 4, self.max_tok_num).scatter_(
                1, ans_tok.unsqueeze(1), 1)
            if self.gpu:
                cur_inp = Variable(data.cuda())
            else:
                cur_inp = Variable(data)
            cur_inp = cur_inp.unsqueeze(1)

        cond_str_score = torch.stack(scores, 2)
        for b, num in enumerate(x_len):
            if num < max_x_len:
                cond_str_score[b, :, :, num:] = -100  # [B, IDX, T, TOK_NUM]

    cond_score = (cond_num_score, cond_col_score, cond_op_score,
                  cond_str_score)
    return cond_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            table_content, q_seq, gt_where, gt_cond):
    """Predict the WHERE clause: condition count, columns, operators, values.

    Args:
        x_emb_var: embedded question tokens (also mean-pooled over axis 1
            when ``self.use_table`` is set).
        x_len: list with the token count of each query; ``len(x_len)`` is the
            batch size.
        col_inp_var, col_name_len, col_len: column-name inputs forwarded to
            ``col_name_encode``.
        table_content: per-example table content. ``len(table_content[b])``
            is used as the upper bound for chosen column indices.
        q_seq: forwarded to ``self.cond_value_batch`` at inference time.
        gt_where: ground truth for the value part, or ``None`` at inference.
            With ``self.use_table`` it is the 4-tuple
            ``(gt_index, gt_value, condition_num, max_value_length)``.
            Otherwise it is passed to ``self.gen_gt_batch``.
        gt_cond: ground-truth conditions (``x[0]`` is the column index), or
            ``None`` to use the predicted columns.

    Returns:
        ``(cond_num_score, cond_col_score, cond_op_score, value_part)``.
        With ``self.use_table``, ``value_part`` is the value softmax when
        ``gt_where`` is given, and ``(softmax, cond_value, cond_num)``
        otherwise. Without it, ``value_part`` is the pointer-score tensor
        over question tokens.
    """
    max_x_len = max(x_len)
    # batch size; x_len holds the token count of every query
    B = len(x_len)

    def _pad_chosen_columns(e_col):
        """Stack each query's chosen column embeddings, padded to 4 slots."""
        # stack([col_1_emb, col_2_emb, pad_0_emb, pad_0_emb]) per query
        return torch.stack([
            torch.stack([e_col[b, x] for x in chosen_col_gt[b]] +
                        [e_col[b, 0]] * (4 - len(chosen_col_gt[b])))
            for b in range(B)
        ])

    def _masked_value_scores(value_embd, values, counts, max_value_length):
        """Dot each candidate value with its query's mean embedding, masked."""
        # value_embd is indexed [condition, value position, hidden]; counts[c]
        # is the number of conditions of query c, laid out query by query.
        query_embed = torch.mean(x_emb_var, dim=1)
        # NOTE(review): the mean also averages padded token positions.
        # Allocate on value_embd's device/dtype so the product below never
        # mixes devices, and take the width from the embeddings instead of a
        # hard-coded 768.
        queries_embed = torch.zeros(
            (len(values), query_embed.size(-1)),
            dtype=value_embd.dtype, device=value_embd.device)
        one_cond = 0
        for c, cond in enumerate(counts):
            for _ in range(cond):
                queries_embed[one_cond, :] = query_embed[c]
                one_cond += 1
        assert one_cond == len(values)
        # Row-wise dot product: this equals the diagonal of the former
        # [L, N, N] batched matmul without computing the N x N matrix.
        # value_score = [condition num, max value length]
        value_score = (value_embd * queries_embed.unsqueeze(1)).sum(-1)
        for l, length in enumerate([len(x) for x in values]):
            if length < max_value_length:
                value_score[l, length:] = -100
        return value_score

    # ---- Number of conditions ------------------------------------------
    # Use the attention-pooled column embedding as the initial LSTM state.
    e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len,
                                         self.cond_num_name_enc)
    # squeeze(-1) (not squeeze()) so a batch of one keeps its batch axis.
    num_col_att_val = self.cond_num_col_att(e_num_col).squeeze(-1)
    max_col_num = max(col_num)
    for idx, num in enumerate(col_num):
        if num < max_col_num:
            num_col_att_val[idx, num:] = -100
    num_col_att = self.softmax(num_col_att_val)
    K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
    cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
        B, 2 * self.N_depth, self.N_h // 2).transpose(0, 1).contiguous()
    cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
        B, 2 * self.N_depth, self.N_h // 2).transpose(0, 1).contiguous()

    # Then run the LSTM and predict the condition number.
    h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len,
                            hidden=(cond_num_h1, cond_num_h2))
    num_att_val = self.cond_num_att(h_num_enc).squeeze(-1)
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            num_att_val[idx, num:] = -100
    num_att = self.softmax(num_att_val)
    K_cond_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
    cond_num_score = self.cond_num_out(K_cond_num)
    cond_num = torch.argmax(cond_num_score, dim=1)
    assert cond_num_score.shape == (B, 5)
    assert cond_num.shape[-1] == B

    # ---- Condition columns ------------------------------------------------
    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_col_name_enc)
    h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)
    col_att_val = torch.bmm(e_cond_col,
                            self.cond_col_att(h_col_enc).transpose(1, 2))
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            col_att_val[idx, :, num:] = -100
    col_att = self.softmax(col_att_val.view(
        (-1, max_x_len))).view(B, -1, max_x_len)
    K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
    cond_col_score = self.cond_col_out(
        self.cond_col_out_K(K_cond_col) +
        self.cond_col_out_col(e_cond_col)).squeeze(-1)
    # Pad -100 for column slots beyond each table's column count.
    for b, num in enumerate(col_num):
        if num < max_col_num:
            cond_col_score[b, num:] = -100

    # Use gt_cond when training and the predicted columns when testing.
    # chosen_col_gt = [[col_1, col_2, ...], ..., columns of the last query]
    if gt_cond is None:
        # predicted number of conditions per query (0..4)
        cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
        col_scores = cond_col_score.data.cpu().numpy()
        # take the cond_nums highest-scoring columns as the prediction
        chosen_col_gt = [
            list(np.argsort(-col_scores[b])[:cond_nums[b]])
            for b in range(len(cond_nums))
        ]
    else:
        chosen_col_gt = [[x[0] for x in one_gt_cond]
                         for one_gt_cond in gt_cond]

    # Force every column index to be within the table length.
    assert len(chosen_col_gt) == B
    for b in range(B):
        table_length = len(table_content[b])
        for c_index, x in enumerate(chosen_col_gt[b]):
            if x > table_length - 1:
                chosen_col_gt[b][c_index] = table_length - 1

    # ---- Condition operators ----------------------------------------------
    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_op_name_enc)
    h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
    col_emb = _pad_chosen_columns(e_cond_col)
    # [B, 1, T, hid] @ [B, 4, hid, 1] -> [B, 4, T]
    op_att_val = torch.matmul(self.cond_op_att(h_op_enc).unsqueeze(1),
                              col_emb.unsqueeze(3)).squeeze(-1)
    # Mask padded question tokens.
    for idx, num in enumerate(x_len):
        if num < max_x_len:
            op_att_val[idx, :, num:] = -100
    op_att = self.softmax(op_att_val.view(B * 4, -1)).view(B, 4, -1)
    K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
    cond_op_score = self.cond_op_out(
        self.cond_op_out_K(K_cond_op) +
        self.cond_op_out_col(col_emb)).squeeze()
    # cond_op_score = [batch size, max conditions per query, operator num]
    assert cond_op_score.shape == (B, 4, 4)

    # ---- Condition values -------------------------------------------------
    h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
    e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                    self.cond_str_name_enc)
    col_emb = _pad_chosen_columns(e_cond_col)

    if self.use_table:
        if gt_where is not None:
            assert len(gt_where) == 4
            # condition_num[c]: number of conditions of query c in the batch
            gt_index, gt_value, condition_num, max_value_length = gt_where
            assert len(gt_value) == gt_index.shape[0]
            # [condition num, max value length, hidden state]
            gt_value_embd = self.emb_layer.condition_value_batch(
                gt_value, max_value_length)
            value_score = _masked_value_scores(
                gt_value_embd, gt_value, condition_num, max_value_length)
            value_utils = self.softmax(value_score)
        else:
            cond_value, max_value_length = self.cond_value_batch(
                table_content, q_seq, cond_num, chosen_col_gt, cond_op_score)
            cond_value_embd = self.emb_layer.condition_value_batch(
                cond_value, max_value_length)
            value_score = _masked_value_scores(
                cond_value_embd, cond_value, cond_num, max_value_length)
            softmax_score = self.softmax(value_score)
            value_utils = (softmax_score, cond_value, cond_num)
    else:
        h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
        col_ext = col_emb.unsqueeze(2).unsqueeze(2)
        if gt_where is not None:
            # Training: teacher-force the decoder.
            # gt_where = [[[index range for value in query], column 2, ...],
            #             table 2, ...]
            gt_tok_seq, _ = self.gen_gt_batch(gt_where)
            g_str_s_flat, _ = self.cond_str_decoder(
                gt_tok_seq.view(B * 4, -1, self.max_tok_num))
            g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)
            g_ext = g_str_s.unsqueeze(3)
            # Multi-axis squeeze is intentional here; leave as squeeze().
            cond_str_score = self.cond_str_out(
                self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                self.cond_str_out_col(col_ext)).squeeze()
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    # [batch_size, condition num, T, TOK_NUM]
                    cond_str_score[b, :, :, num:] = -100
        else:
            # Validation and inference: greedy decoding for 50 steps.
            scores = []
            init_inp = np.zeros((B * 4, 1, self.max_tok_num),
                                dtype=np.float32)
            init_inp[:, 0, 0] = 1  # Set the <BEG> token
            cur_inp = torch.from_numpy(init_inp)
            if self.gpu:
                cur_inp = cur_inp.to('cuda')
            cur_h = None
            for _ in range(50):
                # `is not None` also works if the decoder state is a tensor.
                if cur_h is not None:
                    g_str_s_flat, cur_h = self.cond_str_decoder(
                        cur_inp, cur_h)
                else:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
                g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
                g_ext = g_str_s.unsqueeze(3)

                cur_cond_str_score = self.cond_str_out(
                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext)).squeeze()
                for b, num in enumerate(x_len):
                    if num < max_x_len:
                        cur_cond_str_score[b, :, num:] = -100
                scores.append(cur_cond_str_score)

                _, ans_tok_var = cur_cond_str_score.view(
                    B * 4, max_x_len).max(1)
                ans_tok = ans_tok_var.data.cpu()
                # One-hot encode the argmax token as the next input.
                cur_inp = torch.zeros(B * 4, self.max_tok_num).scatter_(
                    1, ans_tok.unsqueeze(1), 1)
                if self.gpu:
                    cur_inp = cur_inp.to('cuda')
                cur_inp = cur_inp.unsqueeze(1)

            cond_str_score = torch.stack(scores, 2)
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cond_str_score[b, :, :, num:] = -100

    if self.use_table:
        cond_score = (cond_num_score, cond_col_score, cond_op_score,
                      value_utils)
    else:
        cond_score = (cond_num_score, cond_col_score, cond_op_score,
                      cond_str_score)
    return cond_score
def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num, gt_where, gt_cond, reinforce):
    """Decode the WHERE clause as a sequence of pointers into the question.

    With ``gt_where`` the decoder is teacher-forced on
    ``self.gen_gt_batch(gt_where, gen_inp=True)``. Otherwise it decodes for at
    most 100 steps until every example has emitted token 1 (<END>). Decoding is
    greedy unless ``reinforce`` is set, in which case tokens are sampled.

    Args:
        x_emb_var: embedded question tokens, encoded by ``self.cond_lstm``.
        x_len: list with the token count of each question.
        col_inp_var, col_name_len, col_len, col_num, gt_cond: unused here.
        gt_where: ground-truth WHERE sequences, or ``None`` to decode.
        reinforce: sample tokens and also return the sampled choices.

    Returns:
        The scores stacked along dim 1, or ``(scores, choices)`` when
        ``reinforce`` is set. ``choices`` is empty under teacher forcing.
    """
    max_x_len = max(x_len)
    B = len(x_len)

    h_enc, hidden = run_lstm(self.cond_lstm, x_emb_var, x_len)
    # Concatenate hid[:2] with hid[2:] on the feature axis to build the
    # decoder state.
    # NOTE(review): if cond_lstm is a 2-layer bidirectional LSTM, PyTorch
    # orders h_n as (l0_fwd, l0_bwd, l1_fwd, l1_bwd), so this pairs layers
    # rather than directions -- confirm this is intended.
    decoder_hidden = tuple(torch.cat((hid[:2], hid[2:]), dim=2)
                           for hid in hidden)
    h_enc_expand = h_enc.unsqueeze(1)
    # Defined up front so `return cond_score, choices` never hits an unbound
    # name when reinforce is combined with teacher forcing.
    choices = []

    if gt_where is not None:
        gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where, gen_inp=True)
        g_s, _ = run_lstm(self.cond_decoder, gt_tok_seq, gt_tok_len,
                          decoder_hidden)
        g_s_expand = g_s.unsqueeze(2)
        cond_score = self.cond_out(
            self.cond_out_h(h_enc_expand) + self.cond_out_g(g_s_expand)
        ).squeeze()
        for idx, num in enumerate(x_len):
            if num < max_x_len:
                cond_score[idx, :, num:] = -100
    else:
        scores = []
        done_set = set()
        t = 0
        init_inp = np.zeros((B, 1, self.max_tok_num), dtype=np.float32)
        init_inp[:, 0, 7] = 1  # Set the <BEG> token
        if self.gpu:
            cur_inp = Variable(torch.from_numpy(init_inp).cuda())
        else:
            cur_inp = Variable(torch.from_numpy(init_inp))
        cur_h = decoder_hidden
        while len(done_set) < B and t < 100:
            g_s, cur_h = self.cond_decoder(cur_inp, cur_h)
            g_s_expand = g_s.unsqueeze(2)
            cur_cond_score = self.cond_out(
                self.cond_out_h(h_enc_expand) + self.cond_out_g(g_s_expand)
            ).squeeze()
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cur_cond_score[b, num:] = -100
            scores.append(cur_cond_score)

            if not reinforce:
                _, ans_tok_var = cur_cond_score.view(B, max_x_len).max(1)
                ans_tok_var = ans_tok_var.unsqueeze(1)
            else:
                # view() restores the batch axis when B == 1. num_samples is
                # a required argument of multinomial in current PyTorch.
                ans_tok_var = self.softmax(
                    cur_cond_score.view(B, max_x_len)).multinomial(1)
                choices.append(ans_tok_var)
            ans_tok = ans_tok_var.data.cpu()

            # To one-hot
            one_hot = torch.zeros(B, self.max_tok_num).scatter_(1, ans_tok, 1)
            if self.gpu:
                one_hot = one_hot.cuda()
            cur_inp = Variable(one_hot).unsqueeze(1)

            # view(-1) keeps a 1-D sequence even when B == 1 (squeeze() would
            # produce a 0-d tensor that cannot be enumerated).
            for idx, tok in enumerate(ans_tok.view(-1)):
                if tok == 1:  # Find the <END> token
                    done_set.add(idx)
            t += 1

        cond_score = torch.stack(scores, 1)

    if reinforce:
        return cond_score, choices
    return cond_score