def evaluate_logical_form(self, logical_form: str,
                          target_list: List[str]) -> bool:
    """
    Takes a logical form, and the list of target values as strings from the original lisp
    string, and returns True iff the logical form executes to the target list.

    The target strings are normalized with ``TableQuestionContext.normalize_string`` before
    being converted with ``evaluator.to_value_list``; the comparison itself is delegated to
    ``evaluator.check_denotation``. A logical form whose execution raises ``ExecutionError``
    is logged at WARNING level and treated as incorrect (returns False).
    """
    normalized_target_list = [
        TableQuestionContext.normalize_string(value)
        for value in target_list
    ]
    target_value_list = evaluator.to_value_list(normalized_target_list)
    try:
        denotation = self.execute(logical_form)
    except ExecutionError:
        # Lazy %-style arguments: the message is only built if WARNING records are emitted.
        logger.warning('Failed to execute: %s', logical_form)
        return False
    if isinstance(denotation, list):
        denotation_list = [
            str(denotation_item) for denotation_item in denotation
        ]
    else:
        # Non-list denotations (presumably numbers or dates) become a one-item list so both
        # sides are compared as value lists.
        denotation_list = [str(denotation)]
    denotation_value_list = evaluator.to_value_list(denotation_list)
    return evaluator.check_denotation(target_value_list,
                                      denotation_value_list)
def search(tables_directory: str,
           input_examples_file: str,
           output_file: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool,
           max_num_logical_forms_to_search: int = 10000) -> None:
    """
    Search for logical forms that execute to each example's target values and write them out.

    For every example in ``input_examples_file`` this builds a ``WikiTablesVariableFreeWorld``
    from the example's tagged table, enumerates candidate logical forms with an
    ``ActionSpaceWalker`` (optionally restricted to the world's agenda), executes each
    candidate and keeps those whose denotation ``evaluator.check_denotation`` accepts.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that the example's ``tagged/...`` table path is resolved against.
    input_examples_file : ``str``
        UTF-8 text file with one example per line, parsed by
        ``wikitables_util.parse_example_line``.
    output_file : ``str``
        Path of the UTF-8 output file. For each example it contains a ``"<id> <utterance>"``
        line, an ``Agenda: ...`` line when ``use_agenda`` is set, then either the correct
        logical forms (one per line) or ``NO LOGICAL FORMS FOUND!``, and a blank line.
    max_path_length : ``int``
        Maximum path length passed to the ``ActionSpaceWalker``.
    max_num_logical_forms : ``int``
        Maximum number of correct logical forms written per example.
    use_agenda : ``bool``
        If True, only enumerate logical forms using ``world.get_agenda()``.
    max_num_logical_forms_to_search : ``int``, optional (default = 10000)
        Maximum number of candidate logical forms requested from the walker per example.

    Raises
    ------
    Any exception raised by ``evaluator.to_value_list`` on an example's target values is
    propagated after the offending targets are reported on stderr.
    """
    # Read all examples up front and close the input file deterministically.
    with open(input_examples_file, encoding="utf-8") as input_file_pointer:
        data = [wikitables_util.parse_example_line(example_line)
                for example_line in input_file_pointer]
    tokenizer = WordTokenizer()
    with open(output_file, "w", encoding="utf-8") as output_file_pointer:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            # pylint: disable=protected-access
            target_list = [TableQuestionContext._normalize_string(value) for value in
                           instance_data["target_values"]]
            try:
                target_value_list = evaluator.to_value_list(target_list)
            except Exception:
                # Report which targets could not be converted, then let the error propagate.
                print(f"Failed to convert target values: {target_list}", file=sys.stderr)
                raise
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                agenda = world.get_agenda()
                print(f"Agenda: {agenda}", file=output_file_pointer)
                all_logical_forms = walker.get_logical_forms_with_agenda(
                    agenda=agenda,
                    max_num_logical_forms=max_num_logical_forms_to_search)
            else:
                all_logical_forms = walker.get_all_logical_forms(
                    max_num_logical_forms=max_num_logical_forms_to_search)
            correct_logical_forms = []
            for logical_form in all_logical_forms:
                try:
                    denotation = world.execute(logical_form)
                except ExecutionError:
                    print(f"Failed to execute: {logical_form}", file=sys.stderr)
                    continue
                # Numbers and dates come back as scalars; wrap them so both cases are lists.
                denotation_items = denotation if isinstance(denotation, list) else [denotation]
                denotation_value_list = evaluator.to_value_list(
                    [str(denotation_item) for denotation_item in denotation_items])
                if evaluator.check_denotation(target_value_list, denotation_value_list):
                    correct_logical_forms.append(logical_form)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            print(file=output_file_pointer)
 def evaluate_logical_form(self, logical_form: str, target_list: List[str]) -> bool:
     """
     Takes a logical form, and the list of target values as strings from the original lisp
     string, and returns True iff the logical form executes to the target list.

     The target strings are normalized with ``TableQuestionContext.normalize_string`` before
     being converted with ``evaluator.to_value_list``; the comparison itself is delegated to
     ``evaluator.check_denotation``. A logical form whose execution raises ``ExecutionError``
     is logged at WARNING level and treated as incorrect (returns False).
     """
     normalized_target_list = [TableQuestionContext.normalize_string(value) for value in
                               target_list]
     target_value_list = evaluator.to_value_list(normalized_target_list)
     try:
         denotation = self.execute(logical_form)
     except ExecutionError:
         # Lazy %-style arguments: the message is only built if WARNING records are emitted.
         logger.warning('Failed to execute: %s', logical_form)
         return False
     if isinstance(denotation, list):
         denotation_list = [str(denotation_item) for denotation_item in denotation]
     else:
         # Non-list denotations (presumably numbers or dates) become a one-item list so both
         # sides are compared as value lists.
         denotation_list = [str(denotation)]
     denotation_value_list = evaluator.to_value_list(denotation_list)
     return evaluator.check_denotation(target_value_list, denotation_value_list)
Ejemplo n.º 4
0
 def evaluate_denotation(self, denotation: Any,
                         target_list: List[str]) -> bool:
     """
     Check whether ``denotation`` matches ``target_list`` according to the official evaluator.

     Each target string is first normalized with ``TableQuestionContext.normalize_string``.
     A list denotation contributes every element (converted with ``str``); any other
     denotation is treated as a single-element list. Both sides are turned into value lists
     with ``evaluator.to_value_list`` and the result of ``evaluator.check_denotation`` is
     returned.
     """
     expected_values = evaluator.to_value_list(
         list(map(TableQuestionContext.normalize_string, target_list)))
     # Wrap scalar denotations so both branches share one conversion path.
     items = denotation if isinstance(denotation, list) else [denotation]
     predicted_values = evaluator.to_value_list([str(item) for item in items])
     return evaluator.check_denotation(expected_values, predicted_values)