def test_table_data_from_untagged_file(self):
     question = "what was the attendance when usl a league played?"
     question_tokens = self.tokenizer.tokenize(question)
     test_file = f"{self.FIXTURES_ROOT}/data/wikitables/sample_table.tsv"
     table_lines = [line.strip() for line in open(test_file).readlines()]
     table_question_context = TableQuestionContext.read_from_lines(
         table_lines, question_tokens)
     # The table content in the untagged file read here is the same as in the tagged file above, except that
     # the "Avg. Attendance" column is replaced by a "Score" column, which exercises the num2 extraction
     # logic. Values that are *not* extracted are shown below as commented-out entries. A short sketch of
     # calling read_from_lines directly on untagged lines follows this test.
     assert table_question_context.table_data == [
         {
             "number_column:year": 2001.0,
             # The value extraction logic we have for untagged lines does
             # not extract this value as a date.
             # 'date_column:year': Date(2001, -1, -1),
             "string_column:year": "2001",
             "number_column:division": 2.0,
             "string_column:division": "2",
             "string_column:league": "usl_a_league",
             "string_column:regular_season": "4th_western",
             # We only check for strings that are entirely numbers. So 4.0
             # will not be extracted.
             # 'number_column:regular_season': 4.0,
             "string_column:playoffs": "quarterfinals",
             "string_column:open_cup": "did_not_qualify",
             # 'number_column:open_cup': None,
             "number_column:score": 20.0,
             "num2_column:score": 30.0,
             "string_column:score": "20_30",
         },
         {
             "number_column:year": 2005.0,
             # 'date_column:year': Date(2005, -1, -1),
             "string_column:year": "2005",
             "number_column:division": 2.0,
             "string_column:division": "2",
             "string_column:league": "usl_first_division",
             "string_column:regular_season": "5th",
             # Same as the "regular_season" column in the first row: 5.0
             # will not be extracted from "5th".
             # 'number_column:regular_season': 5.0,
             "string_column:playoffs": "quarterfinals",
             "string_column:open_cup": "4th_round",
             # 'number_column:open_cup': 4.0,
             "number_column:score": 50.0,
             "num2_column:score": 40.0,
             "string_column:score": "50_40",
         },
     ]
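
The commented-out entries above illustrate the extraction rules for untagged input: date values are not extracted, numbers are only extracted from cells that are entirely numeric, and score-like cells yield both a number and a num2 value. Below is a minimal sketch of calling read_from_lines directly on such untagged lines; it assumes the same TableQuestionContext import and fixture file as the test, and the tokenizer and fixtures_root arguments are hypothetical stand-ins for self.tokenizer and self.FIXTURES_ROOT.

def inspect_untagged_table(tokenizer, fixtures_root):
    # Hypothetical helper; `tokenizer` and `fixtures_root` stand in for the
    # test class attributes self.tokenizer and self.FIXTURES_ROOT above.
    question = "what was the attendance when usl a league played?"
    question_tokens = tokenizer.tokenize(question)
    test_file = f"{fixtures_root}/data/wikitables/sample_table.tsv"
    with open(test_file) as table_file:
        table_lines = [line.strip() for line in table_file]
    context = TableQuestionContext.read_from_lines(table_lines, question_tokens)
    # Each row is a dict keyed by "<type>_column:<name>"; a missing key means that
    # value was not extracted (e.g. no date_column entries for untagged input).
    for row in context.table_data:
        print(sorted(row.keys()))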
Example #2
    def text_to_instance(
        self,  # type: ignore
        question: str,
        table_lines: List[List[str]],
        target_values: Optional[List[str]] = None,
        offline_search_output: Optional[List[str]] = None,
    ) -> Optional[Instance]:
        """
        Reads text inputs and makes an instance. We pass ``table_lines`` to
        ``TableQuestionContext.read_from_lines``, which accepts them either as lines from the
        CoreNLP-processed tagged files that come with the dataset, or as plain tsv lines where each line
        corresponds to a row and the cells are tab-separated. A brief usage sketch follows this method.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[List[str]]``
            The table content optionally preprocessed by CoreNLP. See ``TableQuestionContext.read_from_lines``
            for the expected format.
        target_values : ``List[str]``, optional
            The denotations that the logical forms should execute to. Not required at test time.
        offline_search_output : ``List[str]``, optional
            Logical forms produced by offline search. Not required at test time.
        """
        tokenized_question = self._tokenizer.tokenize(question.lower())
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)
        metadata: Dict[str, Any] = {
            "question_tokens": [x.text for x in tokenized_question]
        }
        table_context = TableQuestionContext.read_from_lines(
            table_lines, tokenized_question)
        world = WikiTablesLanguage(table_context)
        world_field = MetadataField(world)
        # Note: we are not passing any feature extractors when instantiating the field below, so it
        # will use all of the available extractors.
        table_field = KnowledgeGraphField(
            table_context.get_table_knowledge_graph(),
            tokenized_question,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens,
        )
        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_productions():
            _, rule_right_side = production_rule.split(" -> ")
            is_global_rule = not world.is_instance_specific_entity(
                rule_right_side)
            field = ProductionRuleField(production_rule,
                                        is_global_rule=is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            "question": question_field,
            "metadata": MetadataField(metadata),
            "table": table_field,
            "world": world_field,
            "actions": action_field,
        }

        if target_values is not None:
            target_values_field = MetadataField(target_values)
            fields["target_values"] = target_values_field

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(action_field.field_list)
        }
        if offline_search_output:
            action_sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    action_sequence = world.logical_form_to_action_sequence(
                        logical_form)
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except ParsingError as error:
                    logger.debug(
                        f"Parsing error: {error.message}, skipping logical form"
                    )
                    logger.debug(f"Question was: {question}")
                    logger.debug(f"Logical form was: {logical_form}")
                    logger.debug(f"Table info was: {table_lines}")
                    continue
                except KeyError as error:
                    logger.debug(
                        f"Missing production rule: {error.args}, skipping logical form"
                    )
                    logger.debug(f"Question was: {question}")
                    logger.debug(f"Table info was: {table_lines}")
                    logger.debug(f"Logical form was: {logical_form}")
                    continue
                except:  # noqa
                    logger.error(logical_form)
                    raise
                if len(action_sequence_fields) >= self._max_offline_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields["target_action_sequences"] = ListField(
                action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
            for agenda_string in world.get_agenda(conservative=True):
                agenda_index_fields.append(
                    IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                agenda_index_fields = [IndexField(-1, action_field)]
            fields["agenda"] = ListField(agenda_index_fields)
        return Instance(fields)
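
For orientation, here is a hypothetical usage sketch of text_to_instance. The reader object and the table path are assumptions made for illustration (the reader's construction is not part of this excerpt), and the table lines are passed as stripped strings, mirroring the untagged-file test in the first example.

# Hypothetical usage; `reader` is an already-constructed dataset reader exposing
# the text_to_instance method above, and the table path is illustrative only.
question = "what was the attendance when usl a league played?"
with open("fixtures/data/wikitables/sample_table.tsv") as table_file:
    table_lines = [line.strip() for line in table_file]
instance = reader.text_to_instance(question, table_lines)
# text_to_instance can return None when offline search output is supplied but no
# logical form converts to an action sequence, so guarding is harmless here.
if instance is not None:
    print(list(instance.fields.keys()))  # question, metadata, table, world, actions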