def search(tables_directory: str,
           input_examples_file: str,
           output_file: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool) -> None:
    """
    Searches for logical forms whose denotation matches the target values of each example and
    writes the results to a single text file.

    For every example, a header line ``<question id> <utterance>`` is written, followed (when
    ``use_agenda`` is set) by an ``Agenda: ...`` line, then either the matching logical forms or
    the line ``NO LOGICAL FORMS FOUND!``, and finally a blank line.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that is prefixed to the tagged table path of each example.
    input_examples_file : ``str``
        File with one example per line, parsed with ``wikitables_util.parse_example_line``.
    output_file : ``str``
        Path of the text file to write (opened in ``"w"`` mode).
    max_path_length : ``int``
        Passed to ``ActionSpaceWalker`` as ``max_path_length``.
    max_num_logical_forms : ``int``
        Maximum number of matching logical forms written per example.
    use_agenda : ``bool``
        If set, the walker is restricted with the agenda obtained from ``world.get_agenda()``.
    """
    # Read every example up front inside a context manager so that the input file handle is
    # closed deterministically instead of being left to the garbage collector.
    with open(input_examples_file) as input_file_pointer:
        data = [wikitables_util.parse_example_line(example_line)
                for example_line in input_file_pointer]
    tokenizer = WordTokenizer()
    with open(output_file, "w") as output_file_pointer:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            # Strip one pair of enclosing double quotes, if present.
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            # pylint: disable=protected-access
            target_list = [TableQuestionContext._normalize_string(value) for value in
                           instance_data["target_values"]]
            try:
                target_value_list = evaluator.to_value_list(target_list)
            except Exception:
                # Show the offending targets for debugging, then propagate the original error
                # (instead of re-running the same failing call just to make it raise again).
                print(target_list)
                raise
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            correct_logical_forms = []
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                agenda = world.get_agenda()
                print(f"Agenda: {agenda}", file=output_file_pointer)
                all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                         max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            for logical_form in all_logical_forms:
                try:
                    denotation = world.execute(logical_form)
                except ExecutionError:
                    # Logical forms that cannot be executed are reported on stderr and skipped.
                    print(f"Failed to execute: {logical_form}", file=sys.stderr)
                    continue
                if isinstance(denotation, list):
                    denotation_list = [str(denotation_item) for denotation_item in denotation]
                else:
                    # For numbers and dates
                    denotation_list = [str(denotation)]
                denotation_value_list = evaluator.to_value_list(denotation_list)
                if evaluator.check_denotation(target_value_list, denotation_value_list):
                    correct_logical_forms.append(logical_form)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            # Only the first ``max_num_logical_forms`` matches are written out.
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            # Blank line separating consecutive examples.
            print(file=output_file_pointer)
Exemple #2
0
def search(tables_directory: str, input_examples_file: str, output_path: str,
           max_path_length: int, max_num_logical_forms: int, use_agenda: bool,
           output_separate_files: bool) -> None:
    """
    Searches for logical forms that evaluate to the target values of each example and writes
    them either to one gzipped file per question or to a single text file.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that is prefixed to the tagged table path of each example.
    input_examples_file : ``str``
        File with one example per line, parsed with ``wikitables_util.parse_example_line``.
    output_path : ``str``
        When ``output_separate_files`` is set, a directory (created if missing) that receives one
        ``<question id>.gz`` file per question with at least one matching logical form.
        Otherwise the path of a single text file opened in ``"w"`` mode.
    max_path_length : ``int``
        Passed to ``ActionSpaceWalker`` as ``max_path_length``.
    max_num_logical_forms : ``int``
        Maximum number of logical forms written per question in single-file mode. It is not
        applied in separate-files mode, where all matches are written.
    use_agenda : ``bool``
        If set, the walker is restricted with the agenda from ``world.get_agenda()`` (allowing
        partial matches), and the agenda is echoed in single-file mode.
    output_separate_files : ``bool``
        Selects between the two output layouts described above.
    """
    # Close the input file deterministically once all examples are parsed.
    with open(input_examples_file) as input_file_pointer:
        data = [
            wikitables_util.parse_example_line(example_line)
            for example_line in input_file_pointer
        ]
    tokenizer = WordTokenizer()
    if output_separate_files:
        # ``exist_ok`` avoids the check-then-create race of ``os.path.exists`` + ``makedirs``.
        os.makedirs(output_path, exist_ok=True)
        combined_file_pointer = None
    else:
        combined_file_pointer = open(output_path, "w")
    try:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            # Strip one pair of enclosing double quotes, if present.
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            target_list = instance_data["target_values"]
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file,
                                                          tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            if use_agenda:
                agenda = world.get_agenda()
                all_logical_forms = walker.get_logical_forms_with_agenda(
                    agenda=agenda,
                    max_num_logical_forms=10000,
                    allow_partial_match=True)
            else:
                all_logical_forms = walker.get_all_logical_forms(
                    max_num_logical_forms=10000)
            correct_logical_forms = []
            for logical_form in all_logical_forms:
                if world.evaluate_logical_form(logical_form, target_list):
                    correct_logical_forms.append(logical_form)
            if output_separate_files:
                # Questions without any matching logical form produce no file at all.
                if correct_logical_forms:
                    with gzip.open(f"{output_path}/{question_id}.gz",
                                   "wt") as question_file_pointer:
                        for logical_form in correct_logical_forms:
                            print(logical_form, file=question_file_pointer)
            else:
                print(f"{question_id} {utterance}", file=combined_file_pointer)
                if use_agenda:
                    print(f"Agenda: {agenda}", file=combined_file_pointer)
                if not correct_logical_forms:
                    print("NO LOGICAL FORMS FOUND!", file=combined_file_pointer)
                for logical_form in correct_logical_forms[:max_num_logical_forms]:
                    print(logical_form, file=combined_file_pointer)
                # Blank line separating consecutive examples.
                print(file=combined_file_pointer)
    finally:
        # Close the combined output file even when processing an example raised.
        if combined_file_pointer is not None:
            combined_file_pointer.close()
def search(tables_directory: str,
           input_examples_file: str,
           output_path: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool,
           output_separate_files: bool) -> None:
    """
    Searches for logical forms that evaluate to the target values of each example and writes
    them either to one gzipped file per question or to a single text file.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that is prefixed to the tagged table path of each example.
    input_examples_file : ``str``
        File with one example per line, parsed with ``wikitables_util.parse_example_line``.
    output_path : ``str``
        When ``output_separate_files`` is set, a directory (created if missing) that receives one
        ``<question id>.gz`` file per question with at least one matching logical form.
        Otherwise the path of a single text file opened in ``"w"`` mode.
    max_path_length : ``int``
        Passed to ``ActionSpaceWalker`` as ``max_path_length``.
    max_num_logical_forms : ``int``
        Maximum number of logical forms written per question in single-file mode. It is not
        applied in separate-files mode, where all matches are written.
    use_agenda : ``bool``
        If set, the walker is restricted with the agenda from ``world.get_agenda()``, and the
        agenda is echoed in single-file mode.
    output_separate_files : ``bool``
        Selects between the two output layouts described above.
    """
    # Close the input file deterministically once all examples are parsed.
    with open(input_examples_file) as input_file_pointer:
        data = [wikitables_util.parse_example_line(example_line) for example_line in
                input_file_pointer]
    tokenizer = WordTokenizer()
    if output_separate_files:
        # ``exist_ok`` avoids the check-then-create race of ``os.path.exists`` + ``makedirs``.
        os.makedirs(output_path, exist_ok=True)
        combined_file_pointer = None
    else:
        combined_file_pointer = open(output_path, "w")
    try:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            # Strip one pair of enclosing double quotes, if present.
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            target_list = instance_data["target_values"]
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            if use_agenda:
                agenda = world.get_agenda()
                all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                         max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            correct_logical_forms = []
            for logical_form in all_logical_forms:
                if world.evaluate_logical_form(logical_form, target_list):
                    correct_logical_forms.append(logical_form)
            if output_separate_files:
                # Questions without any matching logical form produce no file at all.
                if correct_logical_forms:
                    with gzip.open(f"{output_path}/{question_id}.gz", "wt") as question_file_pointer:
                        for logical_form in correct_logical_forms:
                            print(logical_form, file=question_file_pointer)
            else:
                print(f"{question_id} {utterance}", file=combined_file_pointer)
                if use_agenda:
                    print(f"Agenda: {agenda}", file=combined_file_pointer)
                if not correct_logical_forms:
                    print("NO LOGICAL FORMS FOUND!", file=combined_file_pointer)
                for logical_form in correct_logical_forms[:max_num_logical_forms]:
                    print(logical_form, file=combined_file_pointer)
                # Blank line separating consecutive examples.
                print(file=combined_file_pointer)
    finally:
        # Close the combined output file even when processing an example raised.
        if combined_file_pointer is not None:
            combined_file_pointer.close()
Exemple #4
0
    def text_to_instance(
            self,  # type: ignore
            question: str,
            table_lines: List[List[str]],
            target_values: List[str],
            offline_search_output: List[str] = None) -> Instance:
        """
        Turns a question, its table and its target values into an ``Instance``. The
        WikitableQuestions dataset provides tables as TSV files pre-tagged using CoreNLP, which we
        use for training.

        Parameters
        ----------
        question : ``str``
            Input question. It is lower-cased before being tokenized.
        table_lines : ``List[List[str]]``
            The table content preprocessed by CoreNLP. See ``TableQuestionContext.read_from_lines``
            for the expected format.
        target_values : ``List[str]``
            Stored unchanged as metadata under the ``target_values`` field.
        offline_search_output : ``List[str]``, optional
            List of logical forms, produced by offline search. Not required during test.

        Returns
        -------
        An ``Instance`` with the fields ``question``, ``table``, ``world``, ``actions`` and
        ``target_values``, plus ``target_action_sequences`` when offline search output is given
        and ``agenda`` when agendas are enabled on the reader. ``None`` is returned when logical
        forms were given but none of them could be converted into an action sequence.
        """
        # pylint: disable=arguments-differ
        question_tokens = self._tokenizer.tokenize(question.lower())
        question_field = TextField(question_tokens, self._question_token_indexers)
        # TODO(pradeep): We'll need a better way to input CoreNLP processed lines.
        context = TableQuestionContext.read_from_lines(table_lines, question_tokens)
        target_values_field = MetadataField(target_values)
        world = WikiTablesVariableFreeWorld(context)
        world_field = MetadataField(world)
        # Note: no feature extractors are passed to the field below, which makes it use all the
        # available extractors.
        table_field = KnowledgeGraphField(context.get_table_knowledge_graph(),
                                          question_tokens,
                                          self._table_token_indexers,
                                          tokenizer=self._tokenizer,
                                          include_in_vocab=self._use_table_for_vocab,
                                          max_table_tokens=self._max_table_tokens)

        # One ProductionRuleField per possible action; a rule is global unless its right side is
        # an entity specific to this instance.
        rule_fields: List[Field] = []
        for rule in world.all_possible_actions():
            _, right_side = rule.split(' -> ')
            is_global = not world.is_instance_specific_entity(right_side)
            rule_fields.append(ProductionRuleField(rule, is_global_rule=is_global))
        action_field = ListField(rule_fields)

        fields = {'question': question_field,
                  'table': table_field,
                  'world': world_field,
                  'actions': action_field,
                  'target_values': target_values_field}

        # Each target action sequence becomes a List[IndexField] whose indices point into the
        # action list built above.  The type has to be ignored here because mypy cannot tell that
        # the ListField only contains ProductionRuleFields (and hence has `.rule`).
        action_map = {rule_field.rule: index
                      for index, rule_field in enumerate(action_field.field_list)}  # type: ignore

        if offline_search_output:
            sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    expression = world.parse_logical_form(logical_form)
                except ParsingError as error:
                    logger.debug(f'Parsing error: {error.message}, skipping logical form')
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Logical form was: {logical_form}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                except:  # pylint: disable=bare-except
                    # Record which logical form triggered the unexpected failure, then re-raise.
                    logger.error(logical_form)
                    raise
                rule_sequence = world.get_action_sequence(expression)
                try:
                    sequence_field = ListField([IndexField(action_map[rule], action_field)
                                                for rule in rule_sequence])
                except KeyError as error:
                    logger.debug(f'Missing production rule: {error.args}, skipping logical form')
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    logger.debug(f'Logical form was: {logical_form}')
                    continue
                sequence_fields.append(sequence_field)
                # Stop as soon as enough logical forms have been converted.
                if len(sequence_fields) >= self._max_offline_logical_forms:
                    break

            if not sequence_fields:
                # Not ideal: logical form supervision was supplied, yet none of the logical forms
                # could be turned into an action sequence, so the instance is skipped.  Note that
                # this affects _dev_ and _test_ instances, too, so your metrics could be
                # over-estimates on the full test data.
                return None
            fields['target_action_sequences'] = ListField(sequence_fields)

        if self._output_agendas:
            agenda_fields: List[Field] = [IndexField(action_map[agenda_string], action_field)
                                          for agenda_string in world.get_agenda()]
            # An empty agenda is represented by a single index field holding -1.
            fields['agenda'] = ListField(agenda_fields or [IndexField(-1, action_field)])
        return Instance(fields)