def _compute_validation_outputs(self,
                                actions: List[List[ProductionRuleArray]],
                                best_final_states: Mapping[int, Sequence[GrammarBasedState]],
                                world: List[SpiderWorld],
                                target_list: List[List[str]]) -> None:
    """Update validation metrics for one batch of decoded states.

    For each batch element this compares the best decoded action sequence
    against the gold targets and feeds the result into the instance's metric
    objects (``_exact_match``, ``_sql_evaluator_match``, ``_action_similarity``,
    ``_acc_single``/``_acc_multi``, ``_beam_hit``).

    Parameters
    ----------
    actions : per-batch lists of production rules; ``action[0]`` is the rule string.
    best_final_states : beam-search results keyed by batch index; a missing key
        means decoding produced no finished state for that instance.
    world : per-instance ``SpiderWorld`` objects (gold query, ``db_id``).
    target_list : gold action sequences, or ``None`` when no targets are
        available (then only predictions are built, no metrics compared).
        NOTE(review): elements are accessed via ``.data``, so despite the
        ``List[List[str]]`` annotation these look like tensors — confirm.
    """
    batch_size = len(actions)
    # Map (batch_index, action_index) -> production-rule string, so decoded
    # index sequences can be turned back into rule-string sequences.
    action_mapping = {}
    for batch_index, batch_actions in enumerate(actions):
        for action_index, action in enumerate(batch_actions):
            action_mapping[(batch_index, action_index)] = action[0]

    for i in range(batch_size):
        # Gold SQL exactly as given (table hints stripped by the world).
        original_gold_sql_query = ' '.join(world[i].get_query_without_table_hints())

        if i not in best_final_states:
            # Decoding finished with no final state: score every metric as a miss.
            # NOTE(review): unlike the other metrics, _beam_hit is NOT updated on
            # this path — confirm whether that asymmetry is intended.
            self._exact_match(0)
            self._action_similarity(0)
            self._sql_evaluator_match(0)
            self._acc_multi(0)
            self._acc_single(0)
            continue

        best_action_indices = best_final_states[i][0].action_history[0]
        action_strings = [action_mapping[(i, action_index)]
                          for action_index in best_action_indices]
        predicted_sql_query = action_sequence_to_sql(action_strings, add_table_names=True)

        if target_list is not None:
            targets = target_list[i].data
            # Exact sequence match against the gold action history.
            sequence_in_targets = self._action_history_match(best_action_indices, targets)
            self._exact_match(sequence_in_targets)
            # Execution-style equivalence of predicted vs. gold SQL.
            sql_evaluator_match = self._evaluate_func(original_gold_sql_query,
                                                      predicted_sql_query,
                                                      world[i].db_id)
            self._sql_evaluator_match(sql_evaluator_match)
            # Soft similarity between predicted and gold action index sequences.
            similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
            self._action_similarity(similarity.ratio())
            # Split accuracy by query difficulty (multi- vs. single-clause).
            difficulty = self._query_difficulty(targets, action_mapping, i)
            if difficulty:
                self._acc_multi(sql_evaluator_match)
            else:
                self._acc_single(sql_evaluator_match)

        # Beam accuracy: does ANY candidate on the beam evaluate to the gold SQL?
        beam_hit = False
        if target_list is not None:
            # Hoisted guard: without targets the original loop only built
            # candidate strings and discarded them — skip that work entirely.
            for final_state in best_final_states[i]:
                candidate_indices = final_state.action_history[0]
                candidate_strings = [action_mapping[(i, action_index)]
                                     for action_index in candidate_indices]
                candidate_sql_query = action_sequence_to_sql(candidate_strings,
                                                             add_table_names=True)
                if self._evaluate_func(original_gold_sql_query,
                                       candidate_sql_query,
                                       world[i].db_id):
                    beam_hit = True
                    break  # perf: no need to evaluate the rest of the beam
        self._beam_hit(beam_hit)
def _get_sql(self, initial_state, actions):
    """Decode with beam search and return predicted SQL for each batch element.

    Parameters
    ----------
    initial_state : initial decoder state handed to the beam search.
    actions : per-batch lists of production rules; ``action[0]`` is the rule string.

    Returns
    -------
    (predicted_sql_query_list, best_action_indices_list) : parallel lists with
        one entry per batch element.  An instance for which no finished state
        was decoded yields ``''`` and ``[]`` respectively.
    """
    num_steps = self._max_decoding_steps
    best_final_states = self._beam_search.search(num_steps,
                                                 initial_state,
                                                 self._transition_function,
                                                 keep_final_unfinished_states=True)
    batch_size = len(actions)
    predicted_sql_query_list = []
    best_action_indices_list = []
    # Map (batch_index, action_index) -> production-rule string.
    action_mapping = {}
    for batch_index, batch_actions in enumerate(actions):
        for action_index, action in enumerate(batch_actions):
            action_mapping[(batch_index, action_index)] = action[0]

    for i in range(batch_size):
        if i not in best_final_states:
            # No finished state for this instance: record a miss on every
            # metric and emit an empty prediction.
            self._exact_match(0)
            self._action_similarity(0)
            self._sql_evaluator_match(0)
            self._acc_multi(0)
            self._acc_single(0)
            predicted_sql_query_list.append('')
            # BUG FIX: also append a placeholder here.  The original appended
            # nothing, so the two returned lists went out of alignment
            # whenever any instance failed to decode.
            best_action_indices_list.append([])
            continue
        best_action_indices = best_final_states[i][0].action_history[0]
        best_action_indices_list.append(best_action_indices)
        action_strings = [action_mapping[(i, action_index)]
                          for action_index in best_action_indices]
        predicted_sql_query = action_sequence_to_sql(action_strings, add_table_names=True)
        predicted_sql_query_list.append(predicted_sql_query)
    return predicted_sql_query_list, best_action_indices_list