Example #1
 def setUp(self):
     super().setUp()
     self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
     question = "did the redskins score in the final two minutes of the game?"
     question_tokens = self.tokenizer.tokenize(question)
     self.test_file = 'fixtures/data/tables/sample_paragraph.tagged'
     self.context = ParagraphQuestionContext.read_from_file(self.test_file, question_tokens)
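Every example on this page goes through ParagraphQuestionContext.read_from_file, while Example #8 below calls read_from_lines directly. A hedged sketch of how the file variant plausibly wraps the line variant (the class name here is a stand-in, and the tab-separated format is an assumption; the real tagged format is left unspecified by the docstring's TODO in Example #8):

    class ParagraphQuestionContextSketch:
        """Illustrative stand-in, not the library class."""

        @classmethod
        def read_from_file(cls, filename, question_tokens,
                           embedding=None, distance_threshold=None):
            # Assumption: the .tagged file is tab-separated, one entry per line.
            with open(filename) as paragraph_file:
                lines = [line.strip().split('\t')
                         for line in paragraph_file if line.strip()]
            return cls.read_from_lines(lines, question_tokens,
                                       embedding, distance_threshold)

        @classmethod
        def read_from_lines(cls, lines, question_tokens,
                            embedding=None, distance_threshold=None):
            raise NotImplementedError  # the real context construction lives here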
Example #2
 def test_get_knowledge_graph_with_empty_paragraph(self):
     question = "did the redskins score in the final two minutes of the game?"
     question_tokens = self.tokenizer.tokenize(question)
     empty_test_file = 'fixtures/data/tables/empty_paragraph.tagged'
     empty_context = ParagraphQuestionContext.read_from_file(empty_test_file, question_tokens)
     knowledge_graph = empty_context.get_knowledge_graph()
     assert knowledge_graph.entities == []
     assert knowledge_graph.neighbors == {}
     assert knowledge_graph.entity_text == {}
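The three assertions above cover everything the knowledge graph exposes: its entities, their neighbor lists, and their surface text. A minimal stand-in for that structure, assuming the usual entities/neighbors/entity_text triple (this class is illustrative, not the library's KnowledgeGraph):

    from typing import Dict, List

    class SimpleKnowledgeGraph:
        """Illustrative stand-in for the graph checked above."""
        def __init__(self,
                     entities: List[str],
                     neighbors: Dict[str, List[str]],
                     entity_text: Dict[str, str]) -> None:
            self.entities = entities
            self.neighbors = neighbors
            self.entity_text = entity_text

    # An empty paragraph yields an empty graph, matching the assertions above.
    empty_graph = SimpleKnowledgeGraph(entities=[], neighbors={}, entity_text={})
    assert not empty_graph.entities and not empty_graph.neighbors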
Example #3
 def setUp(self):
     super().setUp()
     self.tokenizer = WordTokenizer()
     self.tokens = self.tokenizer.tokenize(
         "how many points did the redskins score in the final "
         "two minutes of the game?")
     context = ParagraphQuestionContext.read_from_file(
         "fixtures/data/tables/sample_paragraph.tagged", self.tokens)
     self.world = DropWorld(context)
Example #4
 def setUp(self):
     super().setUp()
     self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
     question = "how many touchdowns did the redskins score??"
     question_tokens = self.tokenizer.tokenize(question)
     self.test_file = 'fixtures/data/tables/sample_paragraph.tagged'
     context = ParagraphQuestionContext.read_from_file(
         self.test_file, question_tokens)
     self.executor = DropExecutor(context.paragraph_data)
Example #5
 def test_context_with_embedding_to_select_entities(self):
     question = "what resulted in the redskins not scoring in the final two minutes of the game?"
     question_tokens = self.tokenizer.tokenize(question)
     embedding_file = "fixtures/data/glove_100d_sample.txt.gz"
     embedding = context_util.read_pretrained_embedding(embedding_file)
     context = ParagraphQuestionContext.read_from_file(self.test_file,
                                                       question_tokens,
                                                       embedding)
     entities = context.paragraph_tokens_to_keep
     expected_entities = [('first', ['relation:arg1']),
                          ('four', ['relation:arg1']),
                          ('six', ['relation:arg1'])]
     assert all([entity in entities for entity in expected_entities])
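Example #5 passes a pretrained embedding so the context keeps only paragraph tokens that are semantically close to the question. A minimal sketch of that selection idea, assuming word vectors stored in a plain dict and a cosine-similarity cutoff (the function name and threshold are illustrative, not the library's implementation):

    import numpy as np

    def select_paragraph_tokens(paragraph_tokens, question_tokens,
                                embedding, threshold=0.6):
        """Keep paragraph tokens whose vector is close to any question token's."""
        def cosine(u, v):
            return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))
        kept = []
        for p_token in paragraph_tokens:
            p_vec = embedding.get(p_token.lower())
            if p_vec is None:
                continue  # out-of-vocabulary tokens are never selected
            for q_token in question_tokens:
                q_vec = embedding.get(q_token.lower())
                if q_vec is not None and cosine(p_vec, q_vec) >= threshold:
                    kept.append(p_token)
                    break
        return kept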
Example #6
 def _get_context_with_question(self, question):
     question_tokens = self.tokenizer.tokenize(question)
     context = ParagraphQuestionContext.read_from_file(self.test_file, question_tokens)
     return context
Example #7
 def test_world_with_empty_paragraph(self):
     context = ParagraphQuestionContext.read_from_file(
         "fixtures/data/tables/empty_paragraph.tagged", self.tokens)
     # We're just confirming that creating a world with an empty context does not throw an error.
     DropWorld(context)
Example #8
    def text_to_instance(self,  # type: ignore
                         question: str,
                         table_lines: List[List[str]],
                         answer_json: JsonDict,
                         offline_search_output: List[str] = None) -> Instance:
        """
        Reads text inputs and makes an instance. We assume we have access to DROP paragraphs parsed
        and tagged in a format similar to the tagged tables in WikiTableQuestions.
        # TODO(pradeep): Explain the format.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[List[str]]``
            Preprocessed paragraph content. See ``ParagraphQuestionContext.read_from_lines``
            for the expected format.
        answer_json : ``JsonDict``
            The "answer" dict from the original data file.
        offline_search_output : ``List[str]``, optional
            List of logical forms produced by offline search. Not required at test time.
        """
        # pylint: disable=arguments-differ
        tokenized_question = self._tokenizer.tokenize(question.lower())
        question_field = TextField(tokenized_question, self._question_token_indexers)
        # TODO(pradeep): We'll need a better way to input processed lines.
        paragraph_context = ParagraphQuestionContext.read_from_lines(table_lines,
                                                                     tokenized_question,
                                                                     self._entity_extraction_embedding,
                                                                     self._entity_extraction_distance_threshold)
        target_values_field = MetadataField(answer_json)
        world = DropWorld(paragraph_context)
        world_field = MetadataField(world)
        # Note: we're not passing any feature extractors when instantiating the field below,
        # so it will use all the available extractors.
        table_field = KnowledgeGraphField(paragraph_context.get_knowledge_graph(),
                                          tokenized_question,
                                          self._table_token_indexers,
                                          tokenizer=self._tokenizer,
                                          include_in_vocab=self._use_table_for_vocab,
                                          max_table_tokens=self._max_table_tokens)
        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_instance_specific_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule=is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {'question': question_field,
                  'table': table_field,
                  'world': world_field,
                  'actions': action_field,
                  'target_values': target_values_field}

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
        if offline_search_output:
            action_sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    expression = world.parse_logical_form(logical_form)
                except ParsingError as error:
                    logger.debug(f'Parsing error: {error.message}, skipping logical form')
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Logical form was: {logical_form}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                except:
                    logger.error(logical_form)
                    raise
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(IndexField(action_map[production_rule], action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.debug(f'Missing production rule: {error.args}, skipping logical form')
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    logger.debug(f'Logical form was: {logical_form}')
                    continue
                if len(action_sequence_fields) >= self._max_offline_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields['target_action_sequences'] = ListField(action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
            for agenda_string in world.get_agenda():
                agenda_index_fields.append(IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)
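The supervision block in text_to_instance hinges on one pattern: map every production-rule string to its index in the global action list, then turn each searched action sequence into indices, skipping logical forms whose rules are missing. A library-free sketch of that pattern (the rules below are illustrative placeholders, not the DROP grammar):

    # Map each production rule to its index in the global action list.
    all_actions = ['@start@ -> n', 'n -> [<n,n>, n]', '<n,n> -> max']
    action_map = {rule: i for i, rule in enumerate(all_actions)}

    def sequence_to_indices(action_sequence):
        # A missing rule means the sequence cannot be represented against
        # this action list, so we skip it, mirroring the KeyError branch above.
        try:
            return [action_map[rule] for rule in action_sequence]
        except KeyError:
            return None

    print(sequence_to_indices(['@start@ -> n', '<n,n> -> max']))  # [0, 2]
    print(sequence_to_indices(['n -> undefined_rule']))           # None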