from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F

# WikiSQLContext, WikiSQLLanguage, RnnStatelet, START_SYMBOL and the slot
# helpers (gen_slot2action_dic, get_left_side_part, get_right_side_parts)
# are assumed to be imported from the project / AllenNLP as in the full repo.


def get_sketch_prod(examples: List, table_dict: Dict) -> Tuple[List, Dict]:
    """
    Scan the examples for one whose table covers all column types (at least
    two typed columns, plus numbers and entities); for such a table the
    induced grammar is complete.  Return the list of sketch productions and
    its inverse index.
    """
    for example in examples:
        table_id = example["context"]
        processed_table = table_dict[table_id]
        context = WikiSQLContext.read_from_json(example, processed_table)
        context.take_features(example)
        world = WikiSQLLanguage(context)
        if len(context.column_types) >= 2 and len(context._num2id) > 0 and \
                len(context._entity2id) > 0:
            actions = world.get_nonterminal_productions()
            sketch_actions = world._get_sketch_productions(actions)
            # index all the possible actions
            action_set = set()
            for prods in sketch_actions.values():
                action_set |= set(prods)
            id2prod = list(action_set)
            prod2id = {prod: idx for idx, prod in enumerate(id2prod)}
            return id2prod, prod2id
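
# A minimal, self-contained sketch of the id2prod / prod2id indexing idiom
# used above; the production strings are toy stand-ins, the real ones come
# from world._get_sketch_productions().
def _demo_prod_index():
    toy_sketch_actions = {
        "@start@": ["@start@ -> Number"],
        "Number": ["Number -> #PH#"],
    }
    action_set = set()
    for prods in toy_sketch_actions.values():
        action_set |= set(prods)
    id2prod = list(action_set)
    prod2id = {prod: idx for idx, prod in enumerate(id2prod)}
    # the two tables are exact inverses of each other
    assert all(id2prod[prod2id[p]] == p for p in id2prod)
    return id2prod, prod2id
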
def get_sketch_prod_and_slot(examples: List, table_dict: Dict,
                             sketch_list: List, sketch_action_list: List):
    """
    Pruned version of get_sketch_prod: besides the production index, also
    return the action sequence of every sketch in sketch_list together with
    the positions and non-terminal types of its placeholder (#PH#) slots.
    """
    for example in examples:
        table_id = example["context"]
        processed_table = table_dict[table_id]
        context = WikiSQLContext.read_from_json(example, processed_table)
        context.take_features(example)
        world = WikiSQLLanguage(context)
        if len(context.column_types) >= 2 and len(context._num2id) > 0 and \
                len(context._entity2id) > 0:
            actions = world.get_nonterminal_productions()
            sketch_actions = world._get_sketch_productions(actions)
            # index all the possible actions
            action_set = set()
            for prods in sketch_actions.values():
                action_set |= set(prods)
            id2prod = list(action_set)
            prod2id = {prod: idx for idx, prod in enumerate(id2prod)}

            # map each sketch logical form back to its action sequence
            sketch_lf2actions = dict()
            for action_seq in sketch_action_list:
                lf = world.action_sequence_to_logical_form(action_seq)
                sketch_lf2actions[lf] = action_seq

            # record where each sketch leaves a placeholder and which
            # non-terminal type must fill it
            slot_dict = defaultdict(dict)
            sketch_action_seqs = []
            for sketch in sketch_list:
                action_seq = tuple(sketch_lf2actions[sketch])
                sketch_action_seqs.append(action_seq)

                for action_ind, action in enumerate(action_seq):
                    assert action in prod2id
                    lhs, rhs = action.split(" -> ")
                    if rhs != "#PH#":
                        continue
                    if lhs in [
                            "Column", "StringColumn", "NumberColumn",
                            "ComparableColumn", "DateColumn", "str",
                            "Number", "Date", "List[Row]"
                    ]:
                        slot_dict[action_seq][action_ind] = lhs

            return id2prod, prod2id, sketch_action_seqs, slot_dict
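
# Hypothetical illustration of the slot_dict layout returned above: each
# sketch action tuple maps a step index to the non-terminal type of the
# placeholder that must be filled at that step.
def _demo_slot_dict():
    sketch = ("@start@ -> Number", "Number -> #PH#", "List[Row] -> #PH#")
    slot_dict = {sketch: {1: "Number", 2: "List[Row]"}}
    for action_ind, slot_type in sorted(slot_dict[sketch].items()):
        print(f"step {action_ind}: fill a {slot_type} slot")
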
def decode(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
           token_encodes: torch.Tensor):
    """
    Greedily decode the single most probable sketch action sequence for the
    question encoded by token_rep / token_encodes.
    """
    action_dict = world._get_sketch_productions(
        world.get_nonterminal_productions())
    initial_rnn_state = self._get_initial_state(token_rep)

    stack = [START_SYMBOL]
    history = []
    rnn_state = initial_rnn_state
    for _ in range(self._max_decoding_steps):
        if len(stack) == 0:
            break
        cur_non_terminal = stack.pop()
        if cur_non_terminal not in action_dict:
            continue

        candidates = action_dict[cur_non_terminal]
        candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

        cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
        next_hidden, next_memory = self.decoder_lstm(
            rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

        # attention over the question tokens
        hidden_tran = next_hidden.transpose(0, 1)
        att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
        att_v = F.softmax(att_feat_v, dim=0)
        att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

        # score every production, then restrict to the legal candidates
        score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
        score_v = self.score_action(score_feat_v).squeeze()
        filter_score_v_list = [score_v[_id] for _id in candidate_ids]
        filter_score_v = torch.stack(filter_score_v_list, 0)
        prob_v = F.softmax(filter_score_v, dim=0)
        _, pred_id = torch.max(prob_v, dim=0)
        pred_id = pred_id.cpu().item()

        next_action_embed = self.sketch_embed.weight[
            candidate_ids[pred_id]].unsqueeze(0)
        rnn_state = RnnStatelet(next_hidden, next_memory, next_action_embed,
                                None, None, None)

        prod = candidates[pred_id]
        history.append(prod)
        non_terminals = self._get_right_side_parts(prod)
        stack += list(reversed(non_terminals))

    return tuple(history)
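
# Toy illustration (single-production hypothetical grammar) of the
# stack-driven expansion loop in decode(): pop a non-terminal, emit a
# production, push its right-hand-side symbols in reverse so that they are
# expanded left to right.
def _demo_stack_expansion():
    grammar = {"@start@": "@start@ -> Number", "Number": "Number -> #PH#"}
    stack, history = ["@start@"], []
    while stack:
        non_terminal = stack.pop()
        if non_terminal not in grammar:  # terminal symbol: nothing to expand
            continue
        prod = grammar[non_terminal]
        history.append(prod)
        rhs_parts = [prod.split(" -> ")[1]]  # single-symbol RHS in this toy
        stack += list(reversed(rhs_parts))
    return tuple(history)  # ('@start@ -> Number', 'Number -> #PH#')
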
def forward(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
            token_encodes: torch.Tensor, sketch_actions: List):
    """
    Input: the gold sequence of sketch actions.
    Output: the sum of per-step log-likelihoods of that sequence under the
    decoder (teacher forcing).
    """
    action_dict = world._get_sketch_productions(
        world.get_nonterminal_productions())
    initial_rnn_state = self._get_initial_state(token_rep)

    seq_likeli = []
    rnn_state = initial_rnn_state
    for prod in sketch_actions:
        left_side, _ = prod.split(" -> ")
        candidates = action_dict[left_side]
        candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

        cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
        next_hidden, next_memory = self.decoder_lstm(
            rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

        # attention over the question tokens
        hidden_tran = next_hidden.transpose(0, 1)
        att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
        att_v = F.softmax(att_feat_v, dim=0)
        att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

        # score all productions, normalize over the legal candidates only
        score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
        score_v = self.score_action(score_feat_v).squeeze()
        filter_score_v_list = [score_v[_id] for _id in candidate_ids]
        filter_score_v = torch.stack(filter_score_v_list, 0)
        log_likeli = F.log_softmax(filter_score_v, dim=0)

        gold_id = candidate_ids.index(self.sketch_prod2id[prod])
        seq_likeli.append(log_likeli[gold_id])

        # teacher forcing: feed the gold action's embedding to the next step
        next_action_embed = self.sketch_embed.weight[
            self.sketch_prod2id[prod]].unsqueeze(0)
        rnn_state = RnnStatelet(next_hidden, next_memory, next_action_embed,
                                None, None, None)

    return sum(seq_likeli)
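
# A minimal sketch of the candidate-restricted softmax in forward(): the
# scorer produces one score per production, but the distribution is
# normalized only over the productions that are legal expansions of the
# current non-terminal (toy scores and ids, not model outputs).
def _demo_restricted_log_likelihood():
    score_v = torch.randn(10)        # scores for all 10 productions
    candidate_ids = [3, 7, 9]        # globally-indexed legal candidates
    gold_prod_id = 7                 # global id of the gold production
    filter_score_v = torch.stack([score_v[i] for i in candidate_ids], 0)
    log_likeli = F.log_softmax(filter_score_v, dim=0)
    gold_id = candidate_ids.index(gold_prod_id)  # local position of the gold
    return log_likeli[gold_id]       # one step's log-probability
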
def check_multi_col(world: WikiSQLLanguage, sketch_actions: List,
                    program_actions: List) -> bool:
    """
    Return True if the program uses more than one column, i.e. some column
    slot mentions a column that never appears among the actions filling any
    row slot.
    """
    prod_dic = world.get_nonterminal_productions()
    slot_dic = gen_slot2action_dic(world, prod_dic, sketch_actions,
                                   program_actions)

    # split the slot fillers into row-filter slots and column slots
    row_slot_acs = []
    col_slot_acs = []
    for idx in slot_dic:
        slot_type = get_left_side_part(sketch_actions[idx])
        if slot_type == "List[Row]":
            row_slot_acs.append(slot_dic[idx])
        else:
            col_slot_acs.append(slot_dic[idx])
    if len(row_slot_acs) == 0 or len(col_slot_acs) == 0:
        return False

    for col_slot_ac in col_slot_acs:
        col_name = get_right_side_parts(col_slot_ac)[0]
        for row_slot_ac in row_slot_acs:
            if col_name not in "_".join(row_slot_ac):
                return True
    return False
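
# Hypothetical action strings illustrating the test above: the column slot
# selects "population" while the only row slot filters on "city", so the
# program touches two distinct columns and the check returns True.
def _demo_multi_col():
    col_name = "number_column:population"  # RHS of a column-slot action
    row_slot_acs = [["List[Row] -> #toy-filter#",
                     "StringColumn -> string_column:city"]]
    return any(col_name not in "_".join(r) for r in row_slot_acs)  # True
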
def beam_decode(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
                token_encodes: torch.Tensor, beam_size: int):
    """
    Beam-search counterpart of decode.
    Output: the top-k most probable complete sketch action sequences.
    """
    action_dict = world._get_sketch_productions(
        world.get_nonterminal_productions())
    initial_rnn_state = self._get_initial_state(token_rep)

    # each path is (stack, history, rnn_state, accumulated log-prob)
    incomplete = [([START_SYMBOL], [], initial_rnn_state, None)]
    completed = []

    for _ in range(self._max_decoding_steps):
        next_paths = []
        for stack, history, rnn_state, seq_score in incomplete:
            cur_non_terminal = stack.pop()
            if cur_non_terminal not in action_dict:
                continue
            candidates = action_dict[cur_non_terminal]
            candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

            cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
            next_hidden, next_memory = self.decoder_lstm(
                rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

            # attention over the question tokens
            hidden_tran = next_hidden.transpose(0, 1)
            att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

            # score every production, restrict to the legal candidates
            score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
            score_v = self.score_action(score_feat_v).squeeze()
            filter_score_v_list = [score_v[_id] for _id in candidate_ids]
            filter_score_v = torch.stack(filter_score_v_list, 0)
            prob_v = F.log_softmax(filter_score_v, dim=0)

            # expand this path with its top-k productions
            pred_logits, pred_ids = torch.topk(
                prob_v, min(beam_size, prob_v.size()[0]), dim=0)
            for _logits, _idx in zip(pred_logits, pred_ids):
                next_action_embed = self.sketch_embed.weight[
                    candidate_ids[_idx]].unsqueeze(0)
                next_rnn_state = RnnStatelet(next_hidden, next_memory,
                                             next_action_embed, None, None,
                                             None)
                prod = candidates[_idx]
                _history = history[:]
                _history.append(prod)
                non_terminals = self._get_right_side_parts(prod)
                _stack = stack[:]
                for ac in reversed(non_terminals):
                    if ac in action_dict:
                        _stack.append(ac)
                _score = _logits if seq_score is None else _logits + seq_score
                next_paths.append((_stack, _history, next_rnn_state, _score))

        # move finished paths to completed, keep the rest on the beam
        incomplete = []
        for stack, history, rnn_state, seq_score in next_paths:
            if len(stack) == 0:
                if world.action_sequence_to_logical_form(history) != "#PH#":
                    completed.append((history, seq_score))
            else:
                incomplete.append((stack, history, rnn_state, seq_score))

        if len(completed) > beam_size:
            completed = sorted(completed, key=lambda x: -x[1])[:beam_size]
            break
        if len(incomplete) > beam_size:
            incomplete = sorted(incomplete, key=lambda x: -x[3])[:beam_size]

    return completed
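
# Toy sketch of the beam bookkeeping used above: every path carries an
# accumulated log-probability (higher is better) and only the top beam_size
# paths survive each pruning step.
def _demo_beam_prune(beam_size: int = 2):
    paths = [(["prod_a"], -0.1), (["prod_b"], -2.3), (["prod_c"], -0.7)]
    return sorted(paths, key=lambda x: -x[1])[:beam_size]
    # -> [(['prod_a'], -0.1), (['prod_c'], -0.7)]
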