def evaluate(self, context: WikiSQLContext, sketch2program: Dict) -> Dict:
    """
    Decode one sketch/program pair for a question and report whether it
    matches the offline-searched consistent programs.

    Args:
        context: question/table context; its ``question_tokens`` and
            ``question_in_table_feat`` are fed to the question encoder.
        sketch2program: maps a sketch logical form to the collection of
            consistent program logical forms found offline for it.

    Returns:
        A ``defaultdict(int)`` with the keys ``best_program_lf``,
        ``best_program_actions``, ``best_sketch_lf``,
        ``best_sketch_actions``, ``sketch_triggered`` and ``lf_triggered``.
    """
    world = WikiSQLLanguage(context)
    ret_dic = defaultdict(int)

    # encode question and offline sketches
    token_in_table_feat = context.question_in_table_feat
    token_encodes, token_reps = self.encode_question(
        context.question_tokens, token_in_table_feat)

    candidate_rep_dic = self.construct_candidates(world, token_encodes)
    sketch_actions, program_actions = self.seq2seq.decode(
        world, token_reps, token_encodes, candidate_rep_dic)
    sketch_lf = world.action_sequence_to_logical_form(sketch_actions)
    program_lf = world.action_sequence_to_logical_form(program_actions)

    # A program can only count as triggered when its sketch is known.
    sketch_triggered = sketch_lf in sketch2program
    lf_triggered = (sketch_triggered
                    and program_lf in sketch2program[sketch_lf])

    ret_dic["best_program_lf"] = program_lf
    ret_dic["best_program_actions"] = program_actions
    ret_dic["best_sketch_lf"] = sketch_lf
    ret_dic["best_sketch_actions"] = sketch_actions
    # Previously these flags were computed but never stored, so readers of
    # the defaultdict silently saw 0 for both.
    ret_dic["sketch_triggered"] = sketch_triggered
    ret_dic["lf_triggered"] = lf_triggered
    return ret_dic
def get_sketch_prod_and_slot(examples: List, table_dict: Dict,
                             sketch_list: List, sketch_action_list: List):
    """
    Index the sketch productions and locate the placeholder slots of each
    sketch, using the first example whose grammar is complete.

    An example qualifies when its context has at least two column types and
    non-empty ``_num2id`` and ``_entity2id``. Used for the pruned version.

    Returns ``(id2prod, prod2id, sketch_action_seqs, slot_dict)`` where
    ``slot_dict[action_tuple][position]`` is the left-hand side of every
    ``... -> #PH#`` slot production. Returns ``None`` implicitly when no
    example qualifies.
    """
    typed_slot_lhs = ("Column", "StringColumn", "NumberColumn",
                      "ComparableColumn", "DateColumn", "str", "Number",
                      "Date")

    for example in examples:
        processed_table = table_dict[example["context"]]
        context = WikiSQLContext.read_from_json(example, processed_table)
        context.take_features(example)
        world = WikiSQLLanguage(context)

        grammar_complete = (len(context.column_types) >= 2
                            and len(context._num2id) > 0
                            and len(context._entity2id) > 0)
        if not grammar_complete:
            continue

        productions = world._get_sketch_productions(
            world.get_nonterminal_productions())

        # index all the possible actions
        action_set = set()
        for rhs_list in productions.values():
            action_set = action_set | set(rhs_list)
        id2prod = list(action_set)
        prod2id = {prod: idx for idx, prod in enumerate(id2prod)}

        # lf to actions
        sketch_lf2actions = {}
        for action_seq in sketch_action_list:
            lf = world.action_sequence_to_logical_form(action_seq)
            sketch_lf2actions[lf] = action_seq

        # collect the action tuple and slot positions of every sketch
        slot_dict = defaultdict(dict)
        sketch_action_seqs = []
        for sketch in sketch_list:
            action_tuple = tuple(sketch_lf2actions[sketch])
            sketch_action_seqs.append(action_tuple)
            for action_ind, action in enumerate(action_tuple):
                assert action in prod2id
                lhs, rhs = action.split(" -> ")
                if rhs == "#PH#" and (lhs in typed_slot_lhs
                                      or lhs == "List[Row]"):
                    slot_dict[action_tuple][action_ind] = lhs

        return id2prod, prod2id, sketch_action_seqs, slot_dict
def forward(self, context: WikiSQLContext,
            sketch2program: Dict) -> torch.Tensor:
    """
    Compute the training loss: the negative log of the summed likelihood of
    the consistent programs.

    Sketches with more than ``CONSISTENT_INST_NUM_BOUND`` consistent
    programs are skipped. For a sketch with a single slot-filling path only
    the sketch log-likelihood is used; otherwise the scores of the paths
    whose logical form is in ``sketch2program[sketch_lf]`` are aggregated
    with ``log_sum_exp`` and added to it.

    Returns ``None`` when no sketch contributes a score.
    """
    world = WikiSQLLanguage(context)

    # encode questions
    in_table_feat = context.question_in_table_feat
    token_encodes, token_reps, last_state = self.encode_question(
        context.question_tokens, in_table_feat)

    lf_to_actions = self.sketch_lf2actions(world)
    candidate_rep_dic = self.construct_candidates(world, token_encodes)

    consistent_scores = []
    for sketch_lf in sketch2program:
        gold_programs = sketch2program[sketch_lf]
        if len(gold_programs) > self.CONSISTENT_INST_NUM_BOUND:
            continue

        sketch_actions = lf_to_actions[sketch_lf]
        seq_log_likeli = self.seq2seq(world, token_reps, token_encodes,
                                      sketch_actions)
        paths, path_log_scores = self.slot_filling(
            world, token_encodes, last_state, candidate_rep_dic,
            sketch_actions)

        # only one path
        if len(paths) == 1:
            consistent_scores.append(seq_log_likeli)
            continue

        gold_scores = []
        for path, path_score in zip(paths, path_log_scores):
            assert path_score is not None
            path_lf = world.action_sequence_to_logical_form(path)
            if path_lf in gold_programs:
                gold_scores.append(path_score)

        # aggregate consistent instantiations
        if len(gold_scores) == 0:
            continue
        sketch_score = seq_log_likeli + log_sum_exp(gold_scores)
        if torch.isnan(sketch_score) == 0:
            consistent_scores.append(sketch_score)
        else:
            logger.warning("Nan loss founded!")

    if len(consistent_scores) == 0:
        return None
    return -1 * log_sum_exp(consistent_scores)
def beam_decode(self, world: WikiSQLLanguage, token_rep: torch.Tensor,
                token_encodes: torch.Tensor, beam_size: int):
    """
    Beam search over sketch productions.

    Starting from a stack that holds only START_SYMBOL, each step pops the
    top non-terminal of every partial derivation, scores its candidate
    productions with the decoder LSTM plus dot-product attention over
    ``token_encodes``, and expands the top candidates.

    Input: the world (sketch grammar and logical-form conversion),
        ``token_rep`` (used to build the initial decoder state),
        ``token_encodes`` (attended over) and the beam size.
    Output: a list of ``(action_history, accumulated_log_prob)`` tuples with
        at most ``beam_size`` entries. The list is sorted by score only when
        the search stops early because more than ``beam_size`` derivations
        completed; otherwise it is in completion order.
    """
    action_dict = world._get_sketch_productions(
        world.get_nonterminal_productions())
    initial_rnn_state = self._get_initial_state(token_rep)
    # each beam item: (stack, history, rnn_state, accumulated score);
    # the score is None until the first production is chosen
    incomplete = [([START_SYMBOL], [], initial_rnn_state, None)]
    completed = []

    for i in range(self._max_decoding_steps):
        next_paths = []
        for stack, history, rnn_state, seq_score in incomplete:
            # pops in place; the copies made below therefore hold only the
            # remaining non-terminals
            cur_non_terminal = stack.pop()
            if cur_non_terminal not in action_dict:
                # no sketch production for this symbol: the path is dropped
                continue
            candidates = action_dict[cur_non_terminal]
            candidate_ids = [self.sketch_prod2id[ac] for ac in candidates]

            cur_hidden, cur_memory = rnn_state.hidden_state, rnn_state.memory_cell
            next_hidden, next_memory = self.decoder_lstm(
                rnn_state.previous_action_embedding, (cur_hidden, cur_memory))

            # dot-product attention of the new hidden state over the tokens
            hidden_tran = next_hidden.transpose(0, 1)
            att_feat_v = torch.mm(token_encodes, hidden_tran)  # sent_len * 1
            att_v = F.softmax(att_feat_v, dim=0)
            att_ret_v = torch.mm(att_v.transpose(0, 1), token_encodes)

            # score all productions, then renormalise over the candidates
            # that are valid for the current non-terminal
            score_feat_v = torch.cat([next_hidden, att_ret_v], 1)
            score_v = self.score_action(score_feat_v).squeeze()
            filter_score_v_list = [score_v[_id] for _id in candidate_ids]
            filter_score_v = torch.stack(filter_score_v_list, 0)
            prob_v = F.log_softmax(filter_score_v, dim=0)

            # there may be fewer candidates than the beam size
            pred_logits, pred_ids = torch.topk(prob_v,
                                               min(beam_size,
                                                   prob_v.size()[0]),
                                               dim=0)

            for _logits, _idx in zip(pred_logits, pred_ids):
                next_action_embed = self.sketch_embed.weight[
                    candidate_ids[_idx]].unsqueeze(0)
                # rebinding the loop variable is harmless here: the hidden
                # and memory states were already read above
                rnn_state = RnnStatelet(next_hidden, next_memory,
                                        next_action_embed, None, None, None)
                prod = candidates[_idx]
                _history = history[:]
                _history.append(prod)
                non_terminals = self._get_right_side_parts(prod)
                _stack = stack[:]
                # push in reverse so the left-most symbol is popped first
                for ac in reversed(non_terminals):
                    if ac in action_dict:
                        _stack.append(ac)
                if seq_score is None:
                    _score = _logits
                else:
                    _score = _logits + seq_score

                next_paths.append((_stack, _history, rnn_state, _score))

        # split the expansions into finished and unfinished derivations
        incomplete = []
        for stack, history, rnn_state, seq_score in next_paths:
            if len(stack) == 0:
                # a derivation that is just a bare placeholder is discarded
                if world.action_sequence_to_logical_form(
                        history) != "#PH#":
                    completed.append((history, seq_score))
            else:
                incomplete.append((stack, history, rnn_state, seq_score))

        # stop as soon as more than beam_size derivations have completed
        if len(completed) > beam_size:
            completed = sorted(completed, key=lambda x: -x[1])
            completed = completed[:beam_size]
            break

        # prune the unfinished derivations to the beam
        if len(incomplete) > beam_size:
            incomplete = sorted(incomplete, key=lambda x: -x[3])
            incomplete = incomplete[:beam_size]
    return completed
def compute_entropy(self, context: WikiSQLContext, sketch2program: Dict,
                    keywords: List) -> Dict:
    """
    Return a dictionary for different analysis.

    Scores every sketch in ``sketch2program`` and its slot-filling paths,
    keeping two lists of log scores for programs that pass
    ``filter_program_by_execution`` and whose logical form is in
    ``sketch2program[sketch_lf]``:

    * ``consist_prob_logs``: all such programs;
    * ``entropy``: the subset accepted by ``filter_by_keywords``.

    Keys written to the returned ``defaultdict(int)``:

    * ``triggered``: whether any program passed the keyword filter;
    * ``entropy``: ``log_sum_exp(consist_prob_logs) - log_sum_exp(entropy)``;
    * ``is_correct``: False when the best program overall scores higher
      than the best keyword-matching one;
    * ``proportion``: exp of the best keyword-matching score divided by the
      summed exp of all ``consist_prob_logs``.

    NOTE(review): block nesting was reconstructed from a flattened source
    line; confirm that ``consist_prob_logs.append`` belongs inside the
    consistency check and that ``proportion`` is set whenever ``triggered``
    is True.
    """
    world = WikiSQLLanguage(context)
    ret_dic = defaultdict(int)

    # encode question and offline sketches
    token_in_table_feat = context.question_in_table_feat
    token_encodes, token_reps, last_state = self.encode_question(
        context.question_tokens, token_in_table_feat)

    consist_prob_logs = []
    # despite its name, this list holds log scores (of the keyword-matching
    # programs), not entropy terms
    entropy = []
    candidate_rep_dic = self.construct_candidates(world, token_encodes)
    sketch_lf2actions = self.sketch_lf2actions(world)
    for sketch_lf in sketch2program:
        sketch_actions = sketch_lf2actions[sketch_lf]
        sketch_log_score = self.seq2seq(world, token_reps, token_encodes,
                                        sketch_actions)
        _paths, _log_scores = self.slot_filling(world, token_encodes,
                                                last_state,
                                                candidate_rep_dic,
                                                sketch_actions)

        # only one path: the program score is the sketch score alone
        if len(_paths) == 1:
            if not self.filter_program_by_execution(world, _paths[0]):
                continue
            _path_lf = world.action_sequence_to_logical_form(_paths[0])
            _seq_score = sketch_log_score
            if _path_lf in sketch2program[sketch_lf]:
                if self.filter_by_keywords(_path_lf, keywords):
                    entropy.append(_seq_score)
                consist_prob_logs.append(_seq_score)
            continue

        # multiple path: add each slot-filling score to the sketch score
        for _path, _score in zip(_paths, _log_scores):
            if not self.filter_program_by_execution(world, _path):
                continue
            _path_lf = world.action_sequence_to_logical_form(_path)
            _seq_score = _score + sketch_log_score
            if _path_lf in sketch2program[sketch_lf]:
                if self.filter_by_keywords(_path_lf, keywords):
                    entropy.append(_seq_score)
                consist_prob_logs.append(_seq_score)

    if len(entropy) > 0:
        # log ratio between the total mass and the keyword-matching mass
        ret_dic["entropy"] = log_sum_exp(consist_prob_logs) - log_sum_exp(
            entropy)
        ret_dic["triggered"] = True
        # correct only when the top-scoring program is a keyword match
        if max(consist_prob_logs) > max(entropy):
            ret_dic["is_correct"] = False
        else:
            ret_dic["is_correct"] = True
        p = torch.exp(max(entropy)) / sum(
            torch.exp(l_) for l_ in consist_prob_logs)
        ret_dic["proportion"] = p
    else:
        ret_dic["triggered"] = False
    return ret_dic
def evaluate(self, context: WikiSQLContext, sketch2program: Dict) -> Dict:
    """
    Return a dictionary for different analysis.

    Beam-decodes up to ``EVAL_NUM_SKETCH_BOUND`` sketches, fills their
    slots, and keeps the highest-scoring program that passes
    ``filter_program_by_execution``. The result holds the best program and
    sketch (logical forms and actions), ``best_score``, ``is_multi_col``,
    and whether the sketch / program appear in ``sketch2program``.

    Raises AssertionError when no candidate program survives the filter.
    """
    world = WikiSQLLanguage(context)
    ret_dic = defaultdict(int)

    # encode question and offline sketches
    in_table_feat = context.question_in_table_feat
    token_encodes, token_reps, last_state = self.encode_question(
        context.question_tokens, in_table_feat)

    sketch_beam = self.seq2seq.beam_decode(world, token_reps, token_encodes,
                                           self.EVAL_NUM_SKETCH_BOUND)

    max_score = None
    best_sketch_actions = best_sketch_lf = None
    best_program_actions = best_program_lf = None
    candidate_rep_dic = self.construct_candidates(world, token_encodes)

    for sketch_actions, sketch_log_score in sketch_beam:
        sketch_lf = world.action_sequence_to_logical_form(sketch_actions)
        paths, slot_log_scores = self.slot_filling(
            world, token_encodes, last_state, candidate_rep_dic,
            sketch_actions)

        if self.__class__.__name__ == "ConcreteProgrammer":
            assert self._ConcreteProgrammer__cur_align_prob_log is not None
            align_prob_log = (
                self._ConcreteProgrammer__cur_align_prob_log.squeeze())
            sketch_log_score = sketch_log_score + align_prob_log
            self._ConcreteProgrammer__cur_align_prob_log = None

        # A single path is scored by the sketch alone; with several paths
        # each slot-filling score is added to the sketch score.
        single_path = len(paths) == 1
        if single_path:
            scored_paths = [(paths[0], None)]
        else:
            scored_paths = zip(paths, slot_log_scores)

        for path, slot_score in scored_paths:
            if not self.filter_program_by_execution(world, path):
                continue
            if not single_path:
                assert slot_score is not None
            path_lf = world.action_sequence_to_logical_form(path)
            if single_path:
                total_score = sketch_log_score
            else:
                total_score = slot_score + sketch_log_score

            if max_score is None or total_score > max_score:
                max_score = total_score
                best_sketch_lf = sketch_lf
                best_sketch_actions = sketch_actions
                best_program_lf = path_lf
                best_program_actions = path

    assert max_score is not None

    ret_dic["best_program_lf"] = best_program_lf
    ret_dic["best_program_actions"] = best_program_actions
    ret_dic["best_sketch_lf"] = best_sketch_lf
    ret_dic["best_sketch_actions"] = best_sketch_actions
    ret_dic["best_score"] = torch.exp(max_score)
    ret_dic["is_multi_col"] = check_multi_col(world, best_sketch_actions,
                                              best_program_actions)

    # the program can only be a hit when its sketch is one
    sketch_hit = best_sketch_lf in sketch2program
    ret_dic["sketch_triggered"] = sketch_hit
    ret_dic["lf_triggered"] = (
        sketch_hit and best_program_lf in sketch2program[best_sketch_lf])
    return ret_dic
def sketch_lf2actions(self, world: WikiSQLLanguage):
    """
    Map the logical form of every cached sketch action sequence (as
    rendered by ``world``) to that action sequence. When two sequences
    render to the same logical form, the later one wins.
    """
    return {
        world.action_sequence_to_logical_form(action_seq): action_seq
        for action_seq in self.sketch_actions_cache
    }