def get_sketch_prod(examples: List, table_dict: Dict) -> tuple:
    """Index the sketch productions of a grammar-complete example.

    An example counts as grammar-complete when its context has at least two
    column types and non-empty ``_num2id`` and ``_entity2id`` maps (the
    original note describes this as containing "all three types of
    columns"). The sketch productions of the *last* such example are indexed,
    matching the previous last-wins behaviour.

    Args:
        examples: raw examples; each has a ``"context"`` key naming its table
            in ``table_dict``.
        table_dict: processed tables keyed by table id.

    Returns:
        ``(id2prod, prod2id)``: the sorted list of distinct sketch productions
        and the inverse mapping from production to index.

    Raises:
        ValueError: if no example passes the completeness check.
    """
    complete_world = None
    for example in examples:
        table_id = example["context"]
        processed_table = table_dict[table_id]
        context = WikiSQLContext.read_from_json(example, processed_table)
        context.take_features(example)
        world = WikiSQLLanguage(context)
        if len(context.column_types) >= 2 and len(context._num2id) > 0 and \
                len(context._entity2id) > 0:
            complete_world = world

    if complete_world is None:
        raise ValueError(
            "no example has a complete grammar; cannot index sketch productions")

    actions = complete_world.get_nonterminal_productions()
    sketch_actions = complete_world._get_sketch_productions(actions)

    # index all the possible actions
    action_set = set()
    for productions in sketch_actions.values():
        action_set.update(productions)
    # Sorted, not list(set): set order of strings depends on hash randomization,
    # so ids (and any embedding rows keyed by them) would differ across runs.
    # NOTE: ids differ from those produced by the old hash-ordered index.
    id2prod = sorted(action_set)
    prod2id = {prod: idx for idx, prod in enumerate(id2prod)}
    return id2prod, prod2id
def evaluate(self, context: WikiSQLContext, sketch2program: Dict) -> Dict:
    """Greedily decode a sketch and a program, and check them against offline ones.

    Args:
        context: question/table context to decode for.
        sketch2program: maps a sketch logical form to the collection of
            program logical forms recorded for that sketch.

    Returns:
        ``defaultdict(int)`` with ``best_program_lf``, ``best_program_actions``,
        ``best_sketch_lf``, ``best_sketch_actions``, ``sketch_triggered``
        (the decoded sketch is a key of ``sketch2program``) and
        ``lf_triggered`` (the decoded program is listed under that sketch).
    """
    world = WikiSQLLanguage(context)
    ret_dic = defaultdict(int)

    # encode question and offline sketches
    token_in_table_feat = context.question_in_table_feat
    token_encodes, token_reps = self.encode_question(
        context.question_tokens, token_in_table_feat)
    candidate_rep_dic = self.construct_candidates(world, token_encodes)
    sketch_actions, program_actions = self.seq2seq.decode(
        world, token_reps, token_encodes, candidate_rep_dic)
    sketch_lf = world.action_sequence_to_logical_form(sketch_actions)
    program_lf = world.action_sequence_to_logical_form(program_actions)

    sketch_triggered = sketch_lf in sketch2program
    lf_triggered = sketch_triggered and program_lf in sketch2program[sketch_lf]

    ret_dic["best_program_lf"] = program_lf
    ret_dic["best_program_actions"] = program_actions
    ret_dic["best_sketch_lf"] = sketch_lf
    ret_dic["best_sketch_actions"] = sketch_actions
    # Previously computed but never stored, so callers always read the
    # defaultdict's 0 for these keys.
    ret_dic["sketch_triggered"] = sketch_triggered
    ret_dic["lf_triggered"] = lf_triggered
    return ret_dic
def filter_program_by_execution(self, world: WikiSQLLanguage, actions: List) -> bool:
    """Return whether ``actions`` executes in ``world`` without raising.

    Args:
        world: the language/world to execute in.
        actions: production sequence passed to ``world.execute_action_sequence``.

    Returns:
        True if execution completed, False if it raised an ``Exception``.
        ``KeyboardInterrupt`` and ``SystemExit`` propagate instead of being
        reported as an invalid program.
    """
    try:
        world.execute_action_sequence(actions)
    except Exception:  # any execution failure marks the program as invalid
        return False
    return True
def forward(self, context: WikiSQLContext, sketch2program: Dict) -> torch.Tensor:
    """Loss for one example: negative log marginal likelihood of offline programs.

    For every sketch in ``sketch2program`` whose program list has at most
    ``self.CONSISTENT_INST_NUM_BOUND`` entries, each listed program is scored
    by ``self.seq2seq`` together with its sketch.

    Returns:
        ``-1 * log_sum_exp`` of the collected scores, or ``None`` when no score
        was collected (despite the ``torch.Tensor`` annotation).
    """
    world = WikiSQLLanguage(context)

    # encode questions
    in_table_feat = context.question_in_table_feat
    question_encodes, question_reps = self.encode_question(
        context.question_tokens, in_table_feat)

    lf_to_sketch = self.sketch_lf2actions(world)
    scores = []
    candidates = self.construct_candidates(world, question_encodes)

    for sketch_lf in sketch2program:
        sketch_seq = lf_to_sketch[sketch_lf]
        program_lfs = sketch2program[sketch_lf]
        if len(program_lfs) > self.CONSISTENT_INST_NUM_BOUND:
            continue  # too many instantiations for this sketch; skip it
        for program_lf in program_lfs:
            program_seq = world.logical_form_to_action_sequence(program_lf)
            log_likeli = self.seq2seq(world, question_reps, question_encodes,
                                      candidates, sketch_seq, program_seq)
            # NOTE(review): truthiness test also drops a score that is exactly
            # 0.0 (probability 1); `is not None` may have been intended.
            if log_likeli:
                scores.append(log_likeli)

    if not scores:
        return None
    return -1 * log_sum_exp(scores)
def get_sketch_prod_and_slot(examples: List, table_dict: Dict, sketch_list: List,
                             sketch_action_list: List):
    """Index sketch productions and locate the slots of each listed sketch.

    Used for the pruned version. An example counts as grammar-complete when
    its context has at least two column types and non-empty ``_num2id`` and
    ``_entity2id`` maps; the last such example supplies the grammar (the
    original note describes this as containing "all three types of columns").

    Args:
        examples: raw examples; each has a ``"context"`` key into ``table_dict``.
        table_dict: processed tables keyed by table id.
        sketch_list: sketch logical forms, in the order to return them.
        sketch_action_list: production sequences of candidate sketches; each
            is converted to its logical form to look up ``sketch_list`` entries.

    Returns:
        ``(id2prod, prod2id, sketch_action_seqs, slot_dict)``:
          * ``id2prod`` / ``prod2id``: sorted sketch productions and inverse map.
          * ``sketch_action_seqs``: one action tuple per ``sketch_list`` entry.
          * ``slot_dict``: ``defaultdict(dict)`` mapping an action tuple to
            ``{action index: slot type}`` for every ``<type> -> #PH#`` action.

    Raises:
        ValueError: if no example passes the completeness check.
        KeyError: if a ``sketch_list`` entry matches no ``sketch_action_list`` item.
    """
    # Left-hand sides that denote a fillable slot when their rhs is "#PH#".
    slot_types = ("Column", "StringColumn", "NumberColumn", "ComparableColumn",
                  "DateColumn", "str", "Number", "Date", "List[Row]")

    complete_world = None
    for example in examples:
        table_id = example["context"]
        processed_table = table_dict[table_id]
        context = WikiSQLContext.read_from_json(example, processed_table)
        context.take_features(example)
        world = WikiSQLLanguage(context)
        if len(context.column_types) >= 2 and len(context._num2id) > 0 and \
                len(context._entity2id) > 0:
            complete_world = world

    if complete_world is None:
        raise ValueError(
            "no example has a complete grammar; cannot index sketch productions")

    productions = complete_world._get_sketch_productions(
        complete_world.get_nonterminal_productions())

    # index all the possible actions; sorted so ids do not depend on string
    # hash randomization (list(set) order changes between interpreter runs)
    action_set = set()
    for rhs_productions in productions.values():
        action_set.update(rhs_productions)
    id2prod = sorted(action_set)
    prod2id = {prod: idx for idx, prod in enumerate(id2prod)}

    # lf to actions -- converted with the grammar-complete world (previously
    # whichever world the loop happened to build last)
    sketch_lf2actions = dict()
    for action_seq in sketch_action_list:
        lf = complete_world.action_sequence_to_logical_form(action_seq)
        sketch_lf2actions[lf] = action_seq

    # record slot positions of each sketch, preserving sketch_list order
    slot_dict = defaultdict(dict)
    sketch_action_seqs = []
    for sketch in sketch_list:
        action_seq = tuple(sketch_lf2actions[sketch])
        sketch_action_seqs.append(action_seq)
        for action_ind, action in enumerate(action_seq):
            assert action in prod2id
            lhs, rhs = action.split(" -> ")
            if rhs == "#PH#" and lhs in slot_types:
                slot_dict[action_seq][action_ind] = lhs
    return id2prod, prod2id, sketch_action_seqs, slot_dict
def decode(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
           token_encodes: torch.Tensor):
    """Greedily decode a sketch for the encoded question.

    Starting from ``START_SYMBOL``, the most recently pushed pending
    non-terminal is expanded at every step with the highest-probability
    production among those sharing its left-hand side. Decoding stops when
    nothing is pending or ``self._max_decoding_steps`` steps have been used.

    Args:
        world: language whose sketch productions define the grammar.
        token_rep: question representation used by ``self._get_initial_state``.
        token_encodes: question token encodings attended over at every step;
            used as the left operand of ``torch.mm`` with a column vector, so
            presumably (sent_len, dim) -- confirm.

    Returns:
        Tuple of the chosen sketch productions in expansion order.
    """
    action_dict = world._get_sketch_productions(
        world.get_nonterminal_productions())
    rnn_state = self._get_initial_state(token_rep)
    stack = [START_SYMBOL]
    history = []
    for _ in range(self._max_decoding_steps):
        if not stack:
            break
        cur_non_terminal = stack.pop()
        if cur_non_terminal not in action_dict:
            continue
        candidates = action_dict[cur_non_terminal]
        candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

        next_hidden, next_memory = self.decoder_lstm(
            rnn_state.previous_action_embedding,
            (rnn_state.hidden_state, rnn_state.memory_cell))
        # attention over question tokens
        att_feat_v = torch.mm(token_encodes, next_hidden.transpose(0, 1))  # sent_len * 1
        att_v = F.softmax(att_feat_v, dim=0)
        att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)
        score_v = self.score_action(torch.cat([next_hidden, att_ret_v], 1)).squeeze()
        # only productions that can expand this non-terminal compete
        filter_score_v = torch.stack([score_v[_id] for _id in candidate_ids], 0)
        prob_v = F.softmax(filter_score_v, dim=0)
        _, pred_id = torch.max(prob_v, dim=0)
        pred_id = pred_id.cpu().item()

        next_action_embed = self.sketch_embed.weight[
            candidate_ids[pred_id]].unsqueeze(0)
        rnn_state = RnnStatelet(next_hidden, next_memory, next_action_embed,
                                None, None, None)
        prod = candidates[pred_id]
        history.append(prod)
        # Push only expandable symbols, as beam_decode does. Pushing terminals
        # made each one burn a decoding step when popped and skipped, so long
        # sketches could be cut off by _max_decoding_steps.
        stack.extend(symbol
                     for symbol in reversed(self._get_right_side_parts(prod))
                     if symbol in action_dict)
    return tuple(history)
def forward(self, context: WikiSQLContext, sketch2program: Dict) -> torch.Tensor:
    """Loss for one example: negative log marginal likelihood of offline programs.

    Sketches with more than ``self.CONSISTENT_INST_NUM_BOUND`` listed programs
    are skipped. Every other sketch is scored by ``self.seq2seq`` and its slots
    are filled by ``self.slot_filling``:

    * exactly one path -- the sketch log-likelihood is used as-is;
    * several paths -- the sketch log-likelihood plus the log-sum-exp of the
      scores of paths whose logical form is listed under the sketch (skipped
      when none is listed). A NaN total is logged and skipped.

    Returns:
        ``-1 * log_sum_exp`` of the collected scores, or ``None`` when nothing
        was collected (despite the ``torch.Tensor`` annotation).
    """
    world = WikiSQLLanguage(context)

    # encode questions
    in_table_feat = context.question_in_table_feat
    q_encodes, q_reps, q_last_state = self.encode_question(
        context.question_tokens, in_table_feat)

    lf_to_sketch = self.sketch_lf2actions(world)
    contributions = []
    candidates = self.construct_candidates(world, q_encodes)

    for sketch_lf in sketch2program:
        program_lfs = sketch2program[sketch_lf]
        if len(program_lfs) > self.CONSISTENT_INST_NUM_BOUND:
            continue
        sketch_seq = lf_to_sketch[sketch_lf]
        sketch_log_likeli = self.seq2seq(world, q_reps, q_encodes, sketch_seq)
        paths, path_scores = self.slot_filling(world, q_encodes, q_last_state,
                                               candidates, sketch_seq)

        if len(paths) == 1:
            # NOTE(review): the lone path is not checked against program_lfs.
            contributions.append(sketch_log_likeli)
            continue

        # aggregate consistent instantiations
        gold_scores = []
        for path, path_score in zip(paths, path_scores):
            assert path_score is not None
            if world.action_sequence_to_logical_form(path) in program_lfs:
                gold_scores.append(path_score)
        if not gold_scores:
            continue

        total = sketch_log_likeli + log_sum_exp(gold_scores)
        if torch.isnan(total):
            logger.warning("Nan loss founded!")
        else:
            contributions.append(total)

    return -1 * log_sum_exp(contributions) if contributions else None
def forward(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
            token_encodes: torch.Tensor, sketch_actions: List):
    """Log-likelihood of a given sketch under the decoder (teacher forcing).

    At each step the gold production is scored against every sketch
    production sharing its left-hand side. The decoder is then fed the gold
    production's embedding, not its own prediction.

    Args:
        world: language whose sketch productions define the alternatives.
        token_rep: question representation used by ``self._get_initial_state``.
        token_encodes: question token encodings attended over at every step.
        sketch_actions: gold productions, each of the form ``"lhs -> rhs"``.

    Returns:
        Sum of the per-step gold log-probabilities (the int ``0`` for an empty
        ``sketch_actions``).
    """
    action_dict = world._get_sketch_productions(
        world.get_nonterminal_productions())
    state = self._get_initial_state(token_rep)
    step_log_probs = []

    for prod in sketch_actions:
        lhs, _rhs = prod.split(" -> ")
        rivals = action_dict[lhs]
        rival_ids = [self.sketch_prod2id[rival] for rival in rivals]

        hidden, memory = self.decoder_lstm(
            state.previous_action_embedding,
            (state.hidden_state, state.memory_cell))
        attn_logits = torch.mm(token_encodes, hidden.transpose(0, 1))  # sent_len * 1
        attn_weights = F.softmax(attn_logits, dim=0)
        attended = torch.mm(attn_weights.transpose(0, 1), token_encodes)
        all_scores = self.score_action(torch.cat([hidden, attended], 1)).squeeze()
        rival_scores = torch.stack([all_scores[rid] for rid in rival_ids], 0)
        log_probs = F.log_softmax(rival_scores, dim=0)

        gold_prod_id = self.sketch_prod2id[prod]
        step_log_probs.append(log_probs[rival_ids.index(gold_prod_id)])

        # feed the gold production back in
        gold_embed = self.sketch_embed.weight[gold_prod_id].unsqueeze(0)
        state = RnnStatelet(hidden, memory, gold_embed, None, None, None)

    return sum(step_log_probs)
def check_multi_col(world: WikiSQLLanguage, sketch_actions: List,
                    program_actions: List) -> bool:
    """Return True if some column slot's column is absent from a row slot's filler.

    ``gen_slot2action_dic`` maps slot indices of ``sketch_actions`` to their
    fillers in ``program_actions``. Slots whose sketch action has left-hand
    side ``"List[Row]"`` are row slots; all others are column slots.

    For each column filler, the first right-hand-side part is taken as the
    column name. The result is True as soon as that name is not a substring
    of a row filler joined with ``"_"``. Returns False when there are no row
    slots or no column slots.

    NOTE(review): presumably flags programs that use more than one column;
    confirm the intent.
    """
    slot_fillers = gen_slot2action_dic(world, world.get_nonterminal_productions(),
                                       sketch_actions, program_actions)
    row_fillers = []
    col_fillers = []
    for slot_idx in slot_fillers:
        is_row_slot = get_left_side_part(sketch_actions[slot_idx]) == "List[Row]"
        (row_fillers if is_row_slot else col_fillers).append(slot_fillers[slot_idx])

    if not (row_fillers and col_fillers):
        return False

    col_names = (get_right_side_parts(filler)[0] for filler in col_fillers)
    return any(name not in "_".join(row_filler)
               for name in col_names
               for row_filler in row_fillers)
def beam_decode(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
                token_encodes: torch.Tensor, beam_size: int):
    """Beam-search likely sketches for the encoded question.

    A hypothesis is ``(pending non-terminals, productions so far, decoder
    state, cumulative log-prob)``. Each step pops one pending symbol per
    hypothesis and branches on the top ``beam_size`` productions for it; only
    expandable symbols are pushed back. A hypothesis with nothing pending is
    finished, unless its logical form is the bare ``"#PH#"``, in which case it
    is dropped.

    Search stops early once more than ``beam_size`` hypotheses have finished;
    the finished list is then sorted best-first and cut to ``beam_size``.
    Unfinished hypotheses are pruned to the best ``beam_size`` after every step.

    Returns:
        List of ``(productions, log_score)`` pairs, at most ``beam_size`` long.
    """
    action_dict = world._get_sketch_productions(
        world.get_nonterminal_productions())

    def step_log_probs(state, option_ids):
        """Advance the decoder once and log-softmax over ``option_ids``."""
        hidden, memory = self.decoder_lstm(
            state.previous_action_embedding,
            (state.hidden_state, state.memory_cell))
        attn = F.softmax(torch.mm(token_encodes, hidden.transpose(0, 1)), dim=0)  # sent_len * 1
        attended = torch.mm(attn.transpose(0, 1), token_encodes)
        all_scores = self.score_action(torch.cat([hidden, attended], 1)).squeeze()
        option_scores = torch.stack([all_scores[oid] for oid in option_ids], 0)
        return hidden, memory, F.log_softmax(option_scores, dim=0)

    beams = [([START_SYMBOL], [], self._get_initial_state(token_rep), None)]
    finished = []
    for _ in range(self._max_decoding_steps):
        expanded = []
        for stack, history, state, cum_score in beams:
            symbol = stack.pop()
            if symbol not in action_dict:
                continue
            options = action_dict[symbol]
            option_ids = [self.sketch_prod2id[opt] for opt in options]
            hidden, memory, log_probs = step_log_probs(state, option_ids)
            top_scores, top_idxs = torch.topk(
                log_probs, min(beam_size, log_probs.size()[0]), dim=0)

            for step_score, idx in zip(top_scores, top_idxs):
                child_embed = self.sketch_embed.weight[option_ids[idx]].unsqueeze(0)
                child_state = RnnStatelet(hidden, memory, child_embed,
                                          None, None, None)
                prod = options[idx]
                child_stack = stack + [
                    part for part in reversed(self._get_right_side_parts(prod))
                    if part in action_dict
                ]
                if cum_score is None:
                    child_score = step_score
                else:
                    child_score = step_score + cum_score
                expanded.append((child_stack, history + [prod], child_state,
                                 child_score))

        beams = []
        for stack, history, state, cum_score in expanded:
            if stack:
                beams.append((stack, history, state, cum_score))
            elif world.action_sequence_to_logical_form(history) != "#PH#":
                finished.append((history, cum_score))

        if len(finished) > beam_size:
            finished = sorted(finished, key=lambda hyp: -hyp[1])[:beam_size]
            break
        if len(beams) > beam_size:
            beams = sorted(beams, key=lambda hyp: -hyp[3])[:beam_size]
    return finished
def compute_entropy(self, context: WikiSQLContext, sketch2program: Dict,
                    keywords: List) -> Dict:
    """Compare consistent programs against those that also match ``keywords``.

    For each sketch in ``sketch2program``, the sketch (``self.seq2seq``) and
    its slot instantiations (``self.slot_filling``) are scored. A single
    instantiation is scored by the sketch alone; otherwise the slot-filling
    score is added. An instantiation is *consistent* when it executes in the
    world and its logical form is listed under the sketch. It is
    *keyword-matching* when it is consistent and ``self.filter_by_keywords``
    accepts it.

    Returns:
        ``defaultdict(int)`` with ``triggered`` (any keyword-matching program
        found). When triggered, it also contains:

        * ``entropy``: log-sum-exp of consistent scores minus log-sum-exp of
          keyword-matching scores. Despite the name this is not Shannon
          entropy: it is ``-log`` of the share of consistent probability mass
          held by keyword-matching programs.
        * ``is_correct``: False when the best consistent program scores
          strictly higher than every keyword-matching one.
        * ``proportion``: probability of the best keyword-matching program
          divided by the total probability of consistent programs.
    """
    world = WikiSQLLanguage(context)
    ret_dic = defaultdict(int)

    # encode question and offline sketches
    token_in_table_feat = context.question_in_table_feat
    token_encodes, token_reps, last_state = self.encode_question(
        context.question_tokens, token_in_table_feat)

    consist_prob_logs = []  # scores of consistent programs
    entropy = []  # subset of the above that pass the keyword filter
    candidate_rep_dic = self.construct_candidates(world, token_encodes)
    sketch_lf2actions = self.sketch_lf2actions(world)
    for sketch_lf in sketch2program:
        sketch_actions = sketch_lf2actions[sketch_lf]
        sketch_log_score = self.seq2seq(world, token_reps, token_encodes,
                                        sketch_actions)
        _paths, _log_scores = self.slot_filling(world, token_encodes,
                                                last_state, candidate_rep_dic,
                                                sketch_actions)

        single_path = len(_paths) == 1
        if single_path:
            scored_paths = [(_paths[0], None)]  # slot score unused here
        else:
            scored_paths = zip(_paths, _log_scores)

        for _path, _score in scored_paths:
            if not self.filter_program_by_execution(world, _path):
                continue
            _path_lf = world.action_sequence_to_logical_form(_path)
            if single_path:
                _seq_score = sketch_log_score
            else:
                _seq_score = _score + sketch_log_score
            if _path_lf in sketch2program[sketch_lf]:
                if self.filter_by_keywords(_path_lf, keywords):
                    entropy.append(_seq_score)
                consist_prob_logs.append(_seq_score)

    if len(entropy) > 0:
        log_consistent = log_sum_exp(consist_prob_logs)
        best_keyword_log = max(entropy)
        ret_dic["entropy"] = log_consistent - log_sum_exp(entropy)
        ret_dic["triggered"] = True
        if max(consist_prob_logs) > best_keyword_log:
            ret_dic["is_correct"] = False
        else:
            ret_dic["is_correct"] = True
        # Ratio computed in log space. Exponentiating each log-score first
        # underflows to 0/0 = nan once the scores are very negative.
        ret_dic["proportion"] = torch.exp(best_keyword_log - log_consistent)
    else:
        ret_dic["triggered"] = False
    return ret_dic
def evaluate(self, context: WikiSQLContext, sketch2program: Dict) -> Dict:
    """Beam-decode sketches, fill their slots, and report the best program.

    Up to ``self.EVAL_NUM_SKETCH_BOUND`` sketches come from
    ``self.seq2seq.beam_decode``. Each one is instantiated by
    ``self.slot_filling``, and instantiations that fail to execute are skipped.
    The highest-scoring survivor wins; ties keep the earlier one. Its score is
    the sketch score alone for a single instantiation, otherwise slot score
    plus sketch score.

    Returns:
        ``defaultdict(int)`` with ``best_program_lf``, ``best_program_actions``,
        ``best_sketch_lf``, ``best_sketch_actions``, ``best_score`` (``exp`` of
        the winning log score), ``is_multi_col`` (from ``check_multi_col``),
        ``sketch_triggered`` (winning sketch is a key of ``sketch2program``) and
        ``lf_triggered`` (winning program is listed under that sketch).

    Raises:
        AssertionError: if no candidate program executes successfully.
    """
    world = WikiSQLLanguage(context)
    ret_dic = defaultdict(int)

    # encode question and offline sketches
    token_in_table_feat = context.question_in_table_feat
    token_encodes, token_reps, last_state = self.encode_question(
        context.question_tokens, token_in_table_feat)
    # list of (sketch productions, log score) pairs from beam search
    sketch_actions_and_scores = self.seq2seq.beam_decode(
        world, token_reps, token_encodes, self.EVAL_NUM_SKETCH_BOUND)

    # running best over all sketches and instantiations
    max_score = None
    best_sketch_actions = None
    best_sketch_lf = None
    best_program_actions = None
    best_program_lf = None
    candidate_rep_dic = self.construct_candidates(world, token_encodes)
    for sketch_actions, sketch_log_score in sketch_actions_and_scores:
        sketch_lf = world.action_sequence_to_logical_form(sketch_actions)
        _paths, _log_scores = self.slot_filling(world, token_encodes,
                                                last_state,
                                                candidate_rep_dic,
                                                sketch_actions)
        # NOTE(review): for the ConcreteProgrammer subclass, an alignment
        # log-prob is read from a name-mangled attribute (presumably set
        # during slot_filling -- confirm), added to the sketch score, then
        # cleared. This must stay after slot_filling and before scoring.
        if self.__class__.__name__ == "ConcreteProgrammer":
            assert self._ConcreteProgrammer__cur_align_prob_log is not None
            align_prob_log = self._ConcreteProgrammer__cur_align_prob_log.squeeze()
            sketch_log_score = sketch_log_score + align_prob_log
            self._ConcreteProgrammer__cur_align_prob_log = None

        # only one path: scored by the sketch alone
        if len(_paths) == 1:
            if not self.filter_program_by_execution(world, _paths[0]):
                continue
            _path_lf = world.action_sequence_to_logical_form(_paths[0])
            _seq_score = sketch_log_score
            if max_score is None or _seq_score > max_score:
                max_score = _seq_score
                best_sketch_lf = sketch_lf
                best_sketch_actions = sketch_actions
                best_program_lf = _path_lf
                best_program_actions = _paths[0]
            continue

        # multiple path: slot-filling score plus sketch score
        for _path, _score in zip(_paths, _log_scores):
            if not self.filter_program_by_execution(world, _path):
                continue
            assert _score is not None
            _path_lf = world.action_sequence_to_logical_form(_path)
            _seq_score = _score + sketch_log_score
            if max_score is None or _seq_score > max_score:
                max_score = _seq_score
                best_sketch_lf = sketch_lf
                best_sketch_actions = sketch_actions
                best_program_lf = _path_lf
                best_program_actions = _path

    # fails when no instantiation of any decoded sketch executed
    assert max_score is not None
    ret_dic["best_program_lf"] = best_program_lf
    ret_dic["best_program_actions"] = best_program_actions
    ret_dic["best_sketch_lf"] = best_sketch_lf
    ret_dic["best_sketch_actions"] = best_sketch_actions
    ret_dic["best_score"] = torch.exp(max_score)
    ret_dic["is_multi_col"] = check_multi_col(world, best_sketch_actions,
                                              best_program_actions)
    # compare the winner against the offline sketch -> programs table
    if best_sketch_lf in sketch2program:
        ret_dic["sketch_triggered"] = True
        if best_program_lf in sketch2program[best_sketch_lf]:
            ret_dic["lf_triggered"] = True
        else:
            ret_dic["lf_triggered"] = False
    else:
        ret_dic["sketch_triggered"] = False
        ret_dic["lf_triggered"] = False
    return ret_dic
def slot_filling(self, world: WikiSQLLanguage, token_encodes: torch.Tensor,
                 token_state: torch.Tensor, candidate_rep_dic: Dict,
                 sketch_actions: List):
    """Enumerate instantiations of a sketch by filling its slots.

    1) Score the candidates of every slot via ``self.collect_candidate_scores``.
    2) Recursively enumerate every combination of slot fillers. Slots with
       more than ``self.CANDIDATE_ACTION_NUM_BOUND`` candidates keep only the
       top-scoring ones. A slot whose type is missing from
       ``candidate_rep_dic`` makes the sketch inapplicable, so no paths are
       produced.

    Args:
        world: provides ``get_slot_dict`` (slot index -> slot type).
        token_encodes: question token encodings.
        token_state: question state passed to ``self.encode_sketch``.
        candidate_rep_dic: slot type -> ``(candidate representations,
            candidate actions)``. A candidate action is a production or a
            list of productions.
        sketch_actions: sketch productions, including placeholder slots.

    Returns:
        ``(possible_paths, path_scores)``: parallel lists. A path is the sketch
        with each slot replaced by its chosen action(s). Its score is the sum
        of chosen candidate scores over multi-candidate slots, or ``None`` if
        no slot had more than one candidate.
    """
    slot_dict = world.get_slot_dict(sketch_actions)
    sketch_encodes, _sketch_rep = self.encode_sketch(
        sketch_actions, token_state)
    candidate_score_dic = self.collect_candidate_scores(
        world, token_encodes, candidate_rep_dic, sketch_encodes, slot_dict)

    possible_paths = []
    path_scores = []

    def extend(prefix, action):
        """Return a new prefix with ``action`` (one or a list) appended."""
        if isinstance(action, list):
            return prefix + action
        return prefix + [action]

    def recur_compute(prefix, score, i):
        if i == len(sketch_actions):
            possible_paths.append(prefix)
            path_scores.append(score)
            return

        if i not in slot_dict:
            # not a slot: keep the sketch action itself
            recur_compute(prefix + [sketch_actions[i]], score, i + 1)
            return

        slot_type = slot_dict[i]
        if slot_type not in candidate_rep_dic:
            return  # this sketch does not apply here
        _, candidate_actions = candidate_rep_dic[slot_type]

        if len(candidate_actions) == 1:
            # a forced choice contributes no score
            recur_compute(extend(prefix, candidate_actions[0]), score, i + 1)
            return

        if len(candidate_actions) > self.CANDIDATE_ACTION_NUM_BOUND:
            _, top_k = torch.topk(candidate_score_dic[i],
                                  self.CANDIDATE_ACTION_NUM_BOUND, dim=0)
            ac_idxs = top_k.cpu().numpy()
        else:
            ac_idxs = range(len(candidate_actions))

        for ac_ind in ac_idxs:
            ac_score = candidate_score_dic[i][ac_ind]
            # `is None`, not truthiness: a zero-valued tensor score must still
            # be accumulated (and stay in the autograd graph), and truth-testing
            # a multi-element tensor raises.
            new_score = ac_score if score is None else score + ac_score
            recur_compute(extend(prefix, candidate_actions[ac_ind]),
                          new_score, i + 1)

    recur_compute([], None, 0)
    return possible_paths, path_scores
def sketch_lf2actions(self, world: WikiSQLLanguage):
    """Map each cached sketch's logical form (in ``world``) to its action sequence.

    When several cached sketches share a logical form, the one appearing last
    in ``self.sketch_actions_cache`` is kept.
    """
    return {
        world.action_sequence_to_logical_form(action_seq): action_seq
        for action_seq in self.sketch_actions_cache
    }