Exemplo n.º 1
0
    def test_read_from_json_handles_parentheses_correctly(self):
        # Each case is (table json, column entity, expected cell entities in that column).
        # The expected names pin how parentheses, non-ASCII characters and other punctuation
        # inside cell text are normalized into entity names.
        cases = [
            ({
                'question': [],
                'columns': ['Urban settlements'],
                'cells': [['Dzhebariki-Khaya\\n(Джебарики-Хая)'],
                          ['South Korea (KOR)'], ['Area (km²)']]
            }, 'fb:row.row.urban_settlements', {
                'fb:cell.dzhebariki_khaya', 'fb:cell.south_korea_kor',
                'fb:cell.area_km'
            }),
            ({
                'question': [],
                'columns': ['Margin\\nof victory'],
                'cells': [['−9 (67-67-68-69=271)']]
            }, 'fb:row.row.margin_of_victory', {'fb:cell._9_67_67_68_69_271'}),
            ({
                'question': [],
                'columns': ['Record'],
                'cells': [['4.08 m (13 ft 41⁄2 in)']]
            }, 'fb:row.row.record', {'fb:cell.4_08_m_13_ft_41_2_in'}),
        ]
        for table_json, column, expected_cells in cases:
            graph = TableQuestionKnowledgeGraph.read_from_json(table_json)
            assert set(graph.neighbors[column]) == expected_cells
Exemplo n.º 2
0
 def test_get_cell_parts_returns_cell_text_on_simple_cells(self):
     # A simple cell is not split: it yields exactly one (part entity, original text) pair.
     expected_part_names = {
         'Team': 'fb:part.team',
         '2006': 'fb:part.2006',
         'Wolfe Tones': 'fb:part.wolfe_tones',
     }
     for cell_text, part_name in expected_part_names.items():
         parts = TableQuestionKnowledgeGraph._get_cell_parts(cell_text)
         assert parts == [(part_name, cell_text)]
Exemplo n.º 3
0
    def test_read_from_json_handles_simple_cases(self):
        table_json = {
            'question': [Token(x) for x in ['where', 'is', 'mersin', '?']],
            'columns': ['Name in English', 'Location'],
            'cells': [['Paradeniz', 'Mersin'], ['Lake Gala', 'Edirne']]
        }
        graph = TableQuestionKnowledgeGraph.read_from_json(table_json)
        mersin_neighbors = set(graph.neighbors['fb:cell.mersin'])
        # The expected entity list is in sorted order: default numbers, cells, then columns.
        assert graph.entities == [
            '-1', '0', '1', 'fb:cell.edirne', 'fb:cell.lake_gala',
            'fb:cell.mersin', 'fb:cell.paradeniz', 'fb:row.row.location',
            'fb:row.row.name_in_english'
        ]
        assert mersin_neighbors == {'fb:row.row.location'}
        name_column_neighbors = set(graph.neighbors['fb:row.row.name_in_english'])
        assert name_column_neighbors == {'fb:cell.paradeniz', 'fb:cell.lake_gala'}

        expected_text = {
            'fb:cell.edirne': 'Edirne',
            'fb:cell.lake_gala': 'Lake Gala',
            'fb:cell.mersin': 'Mersin',
            'fb:cell.paradeniz': 'Paradeniz',
            'fb:row.row.location': 'Location',
            'fb:row.row.name_in_english': 'Name in English',
        }
        for entity, text in expected_text.items():
            assert graph.entity_text[entity] == text

        # -1, 0 and 1 are default numbers that should always be in the graph: they have no
        # neighbors and their text is the number itself.
        default_numbers = ['-1', '0', '1']
        for number in default_numbers:
            assert graph.neighbors[number] == []
        for number in default_numbers:
            assert graph.entity_text[number] == number
Exemplo n.º 4
0
 def test_world_parses_logical_forms_with_decimals(self):
     table_path = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_table.tsv"
     question_tokens = [Token('0.2')]
     world = WikiTablesWorld(TableQuestionKnowledgeGraph.read_from_file(table_path, question_tokens))
     # The question says "0.2" while the logical form uses the normalized "0.200".
     parsed = world.parse_logical_form("(fb:cell.cell.number (number 0.200))")
     assert str(parsed) == "I1(I(num:0_200))"
Exemplo n.º 5
0
    def test_read_from_json_handles_crazy_unicode(self):
        # Each case is (table json, [(entity, expected neighbors), ...]).  The expected names
        # pin how characters without an obvious ASCII counterpart (ð, ø, °, ′, €, Σ, ª, ...)
        # are normalized into entity names.
        cases = [
            ({
                'question': [],
                'columns': ['Town'],
                'cells': [['Viðareiði'], ['Funningsfjørður'], ['Froðba']]
            }, [
                ('fb:row.row.town', {
                    'fb:cell.funningsfj_r_ur',
                    'fb:cell.vi_arei_i',
                    'fb:cell.fro_ba',
                }),
            ]),
            ({
                'question': [],
                'columns': ['Fate'],
                'cells': [['Sunk at 45°00′N 11°21′W / 45.000°N 11.350°W'],
                          ['66°22′32″N 29°20′19″E / 66.37556°N 29.33861°E']]
            }, [
                ('fb:row.row.fate', {
                    'fb:cell.sunk_at_45_00_n_11_21_w_45_000_n_11_350_w',
                    'fb:cell.66_22_32_n_29_20_19_e_66_37556_n_29_33861_e'
                }),
            ]),
            ({
                'question': [],
                'columns': ['€0.01', 'Σ Points'],
                'cells': [['6,000', '9.5']]
            }, [
                ('fb:row.row._0_01', {'fb:cell.6_000'}),
                ('fb:row.row._points', {'fb:cell.9_5'}),
            ]),
            ({
                'question': [],
                'columns': ['Division'],
                'cells': [['1ª Aut. Pref.']]
            }, [
                ('fb:row.row.division', {'fb:cell.1_aut_pref'}),
            ]),
        ]
        for table_json, checks in cases:
            graph = TableQuestionKnowledgeGraph.read_from_json(table_json)
            for entity, expected in checks:
                assert set(graph.neighbors[entity]) == expected
Exemplo n.º 6
0
 def test_get_linked_agenda_items(self):
     question = [Token(word) for word in ['where', 'is', 'mersin', '?']]
     table_json = {
         'question': question,
         'columns': ['Name in English', 'Location'],
         'cells': [['Paradeniz', 'Mersin'], ['Lake Gala', 'Edirne']]
     }
     graph = TableQuestionKnowledgeGraph.read_from_json(table_json)
     # "mersin" in the question is expected to link to the Mersin cell and to the Location
     # column that holds it.
     expected_items = ['fb:cell.mersin', 'fb:row.row.location']
     assert graph.get_linked_agenda_items() == expected_items
Exemplo n.º 7
0
 def test_read_from_json_handles_diacritics_and_newlines(self):
     notes_cell = '8 districts\nFormed from Orūzgān Province in 2004'
     graph = TableQuestionKnowledgeGraph.read_from_json({
         'question': [],
         'columns': ['Notes'],
         'cells': [[notes_cell]]
     })
     # The real newline becomes an underscore and the macrons (ū, ā) become plain vowels.
     expected_cells = {'fb:cell.8_districts_formed_from_oruzgan_province_in_2004'}
     assert set(graph.neighbors['fb:row.row.notes']) == expected_cells
Exemplo n.º 8
0
 def test_get_longest_span_matching_entities(self):
     words = ['where', 'is', 'lake', 'big', 'gala', '?']
     graph = TableQuestionKnowledgeGraph.read_from_json({
         'question': [Token(word) for word in words],
         'columns': ['Name in English', 'Location'],
         'cells': [['Paradeniz', 'Lake Big'], ['Lake Big Gala', 'Edirne']]
     })
     # Both "Lake Big" and "Lake Big Gala" overlap the question; only the entity for the
     # longest matching span is expected back.
     longest_matches = graph._get_longest_span_matching_entities()
     assert longest_matches == ['fb:cell.lake_big_gala']
Exemplo n.º 9
0
    def test_read_from_json_handles_diacritics(self):
        # Each case is (table json, column entity, expected cells).  Letters with diacritics
        # (ö, ș, ž, ț, ă) are expected to lose the mark and keep the base letter.
        cases = [
            ({
                'question': [],
                'columns': ['Name in English', 'Name in Turkish', 'Location'],
                'cells': [['Lake Van', 'Van Gölü', 'Mersin'],
                          ['Lake Gala', 'Gala Gölü', 'Edirne']]
            }, 'fb:row.row.name_in_turkish', {'fb:cell.van_golu', 'fb:cell.gala_golu'}),
            ({
                'question': [],
                'columns': ['Notes'],
                'cells': [['Ordained as a priest at\nReșița on March, 29th 1936']]
            }, 'fb:row.row.notes', {
                'fb:cell.ordained_as_a_priest_at_resita_on_march_29th_1936'
            }),
            ({
                'question': [],
                'columns': ['Player'],
                'cells': [['Mateja Kežman']]
            }, 'fb:row.row.player', {'fb:cell.mateja_kezman'}),
            ({
                'question': [],
                'columns': ['Venue'],
                'cells': [['Arena Națională, Bucharest, Romania']]
            }, 'fb:row.row.venue', {'fb:cell.arena_nationala_bucharest_romania'}),
        ]
        for table_json, column, expected_cells in cases:
            graph = TableQuestionKnowledgeGraph.read_from_json(table_json)
            assert set(graph.neighbors[column]) == expected_cells
Exemplo n.º 10
0
 def test_read_from_json_handles_numbers_in_question(self):
     # Numbers mentioned in the question, whether spelled out ('one') or written as digits
     # ('4'), become number entities with no neighbors even when the table is empty.  Each
     # number's entity text is the question token it came from.
     table_json = {
         'question': [Token(x) for x in ['one', '4']],
         'columns': [],
         'cells': []
     }
     graph = TableQuestionKnowledgeGraph.read_from_json(table_json)
     assert graph.neighbors['1'] == []
     assert graph.neighbors['4'] == []
     assert graph.entity_text['1'] == 'one'
     assert graph.entity_text['4'] == '4'
Exemplo n.º 11
0
    def test_read_from_json_handles_cells_with_duplicate_normalizations(self):
        answers = ['yes', 'yes*', 'yes', 'yes ', 'yes*']
        table_json = {
            'question': [],
            'columns': ['answer'],
            'cells': [[answer] for answer in answers]
        }
        graph = TableQuestionKnowledgeGraph.read_from_json(table_json)

        # 'yes', 'yes*' and 'yes ' are three distinct strings that all normalize to "yes", so
        # they get fb:cell.yes, fb:cell.yes_2 and fb:cell.yes_3.  The order is meant to match
        # the order SEMPRE assigns, which this test cannot confirm on its own.
        expected_entities = [
            '-1', '0', '1', 'fb:cell.yes', 'fb:cell.yes_2', 'fb:cell.yes_3',
            'fb:row.row.answer'
        ]
        assert graph.entities == expected_entities
Exemplo n.º 12
0
 def test_read_from_json_handles_columns_with_duplicate_normalizations(
         self):
     # '# of votes' and '% of votes' both normalize to "_of_votes"; the second column is
     # expected to get a "_2" suffix so the two columns stay distinct.
     table_json = {
         'question': [],
         'columns': ['# of votes', '% of votes'],
         'cells': [['1', '2'], ['3', '4']]
     }
     graph = TableQuestionKnowledgeGraph.read_from_json(table_json)
     expected_neighbors = [
         ('fb:row.row._of_votes', {'fb:cell.1', 'fb:cell.3'}),
         ('fb:row.row._of_votes_2', {'fb:cell.2', 'fb:cell.4'}),
         ('fb:cell.1', {'fb:row.row._of_votes'}),
     ]
     for entity, expected in expected_neighbors:
         assert set(graph.neighbors[entity]) == expected
Exemplo n.º 13
0
    def test_read_from_json_handles_newlines_in_columns(self):
        # The TSV fixtures store newlines as the two characters backslash + n rather than as a
        # real newline character, so the tables below use that escaped form on purpose.
        cases = [
            ({
                'question': [],
                'columns': ['Peak\\nAUS', 'Peak\\nNZ'],
                'cells': [['1', '2'], ['3', '4']]
            }, [
                ('fb:row.row.peak_aus', {'fb:cell.1', 'fb:cell.3'}),
                ('fb:row.row.peak_nz', {'fb:cell.2', 'fb:cell.4'}),
                ('fb:cell.1', {'fb:row.row.peak_aus'}),
            ]),
            ({
                'question': [],
                'columns': ['Title'],
                'cells': [['Dance of the\\nSeven Veils']]
            }, [
                ('fb:row.row.title', {'fb:cell.dance_of_the_seven_veils'}),
            ]),
        ]
        for table_json, checks in cases:
            graph = TableQuestionKnowledgeGraph.read_from_json(table_json)
            for entity, expected in checks:
                assert set(graph.neighbors[entity]) == expected
Exemplo n.º 14
0
 def test_with_deeply_nested_logical_form(self):
     # The logical form below nests 21 `or` expressions over the fb:cell.virginia_* cells.  The
     # test passes as long as parsing completes without raising.
     question_tokens = [Token(x) for x in ['what', 'was', 'the', 'district', '?']]
     table_filename = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'table' / '109.tsv'
     table_kg = TableQuestionKnowledgeGraph.read_from_file(table_filename, question_tokens)
     world = WikiTablesWorld(table_kg)
     logical_form = ("(count ((reverse fb:cell.cell.number) (or (or (or (or (or (or (or (or "
                     "(or (or (or (or (or (or (or (or (or (or (or (or (or fb:cell.virginia_1 "
                     "fb:cell.virginia_10) fb:cell.virginia_11) fb:cell.virginia_12) "
                     "fb:cell.virginia_13) fb:cell.virginia_14) fb:cell.virginia_15) "
                     "fb:cell.virginia_16) fb:cell.virginia_17) fb:cell.virginia_18) "
                     "fb:cell.virginia_19) fb:cell.virginia_2) fb:cell.virginia_20) "
                     "fb:cell.virginia_21) fb:cell.virginia_22) fb:cell.virginia_3) "
                     "fb:cell.virginia_4) fb:cell.virginia_5) fb:cell.virginia_6) "
                     "fb:cell.virginia_7) fb:cell.virginia_8) fb:cell.virginia_9)))")
     world.parse_logical_form(logical_form)
Exemplo n.º 15
0
    def test_world_adds_numbers_from_question(self):
        words = ['what', '2007', '2,107', '0.2', '1800s', '1950s', '?']
        table_path = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_table.tsv"
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
                table_path, [Token(word) for word in words])
        world = WikiTablesWorld(table_kg)
        number_actions = world.get_valid_actions()['n']
        expected_actions = [
            'n -> 2007',
            'n -> 2107',
            # It appears that sempre normalizes floating point numbers (0.2 -> 0.200).
            'n -> 0.200',
            # Spans like "1800s" contribute both end-points: 1800 and 1900.
            'n -> 1800',
            'n -> 1900',
            'n -> 1950',
            'n -> 1960',
        ]
        for action in expected_actions:
            assert action in number_actions
Exemplo n.º 16
0
 def test_read_from_json_splits_columns_when_necessary(self):
     table_json = {
         'question': [Token(word) for word in ['where', 'is', 'mersin', '?']],
         'columns': ['Name in English', 'Location'],
         'cells': [['Paradeniz', 'Mersin with spaces'],
                   ['Lake, Gala', 'Edirne']]
     }
     graph = TableQuestionKnowledgeGraph.read_from_json(table_json)
     # 'Lake, Gala' contains a comma, so the Name in English column also yields fb:part
     # entities (lake, gala, paradeniz); the Location column is not split.
     assert graph.entities == [
         '-1', '0', '1', 'fb:cell.edirne', 'fb:cell.lake_gala',
         'fb:cell.mersin_with_spaces', 'fb:cell.paradeniz', 'fb:part.gala',
         'fb:part.lake', 'fb:part.paradeniz', 'fb:row.row.location',
         'fb:row.row.name_in_english'
     ]
     # Part entities are not connected to anything in the graph.
     for part in ['fb:part.lake', 'fb:part.gala', 'fb:part.paradeniz']:
         assert graph.neighbors[part] == []
Exemplo n.º 17
0
 def test_read_from_json_replaces_newlines(self):
     # The csv -> tsv conversion writes each newline as the two characters backslash + n.
     # read_from_json has to turn that pair back into a single newline character so that
     # cell splitting and the rest of the processing see a real newline.
     table_json = {
         'question': [Token(word) for word in ['where', 'is', 'mersin', '?']],
         'columns': ['Name\\nin English', 'Location'],
         'cells': [['Paradeniz', 'Mersin'], ['Lake\\nGala', 'Edirne']]
     }
     graph = TableQuestionKnowledgeGraph.read_from_json(table_json)
     expected_entities = [
         '-1', '0', '1', 'fb:cell.edirne', 'fb:cell.lake_gala',
         'fb:cell.mersin', 'fb:cell.paradeniz', 'fb:part.gala',
         'fb:part.lake', 'fb:part.paradeniz', 'fb:row.row.location',
         'fb:row.row.name_in_english'
     ]
     assert graph.entities == expected_entities
     column_text = graph.entity_text['fb:row.row.name_in_english']
     assert column_text == 'Name\nin English'
Exemplo n.º 18
0
    def setUp(self):
        # Run the base test-case setup first, as the other setUp methods in this suite do.
        super().setUp()
        self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        table_json = {
            'question': self.utterance,
            'columns': ['Name in English', 'Location in English'],
            'cells': [['Paradeniz', 'Mersin'], ['Lake Gala', 'Edirne']]
        }
        self.graph = TableQuestionKnowledgeGraph.read_from_json(table_json)
        self.vocab = Vocabulary()

        def add_token(token):
            return self.vocab.add_token_to_namespace(token, namespace='tokens')

        # The insertion order determines the indices, so keep these in this exact order.
        self.name_index = add_token("name")
        self.in_index = add_token("in")
        self.english_index = add_token("english")
        self.location_index = add_token("location")
        self.paradeniz_index = add_token("paradeniz")
        self.mersin_index = add_token("mersin")
        self.lake_index = add_token("lake")
        self.gala_index = add_token("gala")
        self.negative_one_index = add_token("-1")
        self.zero_index = add_token("0")
        self.one_index = add_token("1")

        self.oov_index = self.vocab.get_token_index('random OOV string',
                                                    namespace='tokens')
        # "edirne" is deliberately never added to the vocabulary, so it maps to the OOV index.
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(self.graph, self.utterance,
                                         self.token_indexers, self.tokenizer)
Exemplo n.º 19
0
    def setUp(self):
        super().setUp()
        self.world_without_recursion = FakeWorldWithoutRecursion()
        self.world_with_recursion = FakeWorldWithRecursion()

        test_filename = self.FIXTURES_ROOT / "data" / "nlvr" / "sample_ungrouped_data.jsonl"
        # One JSON object per line; only its "structured_rep" is needed.  The context manager
        # closes the fixture file deterministically instead of leaving it to the GC.
        with open(test_filename) as data_file:
            data = [json.loads(line)["structured_rep"] for line in data_file]
        self.nlvr_world = NlvrWorld(data[0])

        question_tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2004', '?']
        ]
        table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tsv'
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            table_file, question_tokens)
        self.wikitables_world = WikiTablesWorld(table_kg)
Exemplo n.º 20
0
 def test_get_numbers_from_tokens_works_for_arabic_numerals(self):
     # Floats come back with three decimal places; integers are left unchanged.
     words = ['7', '1.0', '-20']
     extracted = TableQuestionKnowledgeGraph._get_numbers_from_tokens([Token(w) for w in words])
     assert extracted == [('7', '7'), ('1.000', '1.0'), ('-20', '-20')]
Exemplo n.º 21
0
 def test_should_split_column_returns_false_when_all_text_is_simple(self):
     simple_cells = ['Team', '2006', 'Wolfe Tones']
     should_split = TableQuestionKnowledgeGraph._should_split_column_cells(simple_cells)
     assert should_split is False
Exemplo n.º 22
0
    def text_to_instance(
            self,  # type: ignore
            question: str,
            table_lines: List[str],
            example_lisp_string: str = None,
            dpd_output: List[str] = None,
            tokenized_question: List[Token] = None) -> Instance:
        """
        Reads text inputs and makes an instance. WikitableQuestions dataset provides tables as TSV
        files, which we use for training.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[str]``
            The table content itself, as a list of rows. See
            ``TableQuestionKnowledgeGraph.read_from_lines`` for the expected format.
        example_lisp_string : ``str``, optional
            The original (lisp-formatted) example string in the WikiTableQuestions dataset.  This
            comes directly from the ``.examples`` file provided with the dataset.  We pass this to
            SEMPRE for evaluating logical forms during training.  It isn't otherwise used for
            anything.
        dpd_output : List[str], optional
            List of logical forms, produced by dynamic programming on denotations. Not required
            during test.
        tokenized_question : ``List[Token]``, optional
            If you have already tokenized the question, you can pass that in here, so we don't
            duplicate that work.  You might, for example, do batch processing on the questions in
            the whole dataset, then pass the result in here.

        Returns
        -------
        ``Instance`` or ``None``
            An instance with ``question``, ``metadata``, ``table``, ``world`` and ``actions``
            fields.  ``table_metadata``, ``example_lisp_string``, ``target_action_sequences`` and
            ``agenda`` are added when the corresponding option or input is present.  Returns
            ``None`` when ``dpd_output`` is given but none of its logical forms can be converted
            into an action sequence.
        """
        # pylint: disable=arguments-differ
        tokenized_question = tokenized_question or self._tokenizer.tokenize(
            question.lower())
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)
        metadata: Dict[str, Any] = {
            "question_tokens": [x.text for x in tokenized_question],
            "original_table": "".join(table_lines),
        }
        table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(
            table_lines, tokenized_question)
        table_field = KnowledgeGraphField(
            table_knowledge_graph,
            tokenized_question,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            feature_extractors=self._linking_feature_extractors,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        world = WikiTablesWorld(table_knowledge_graph)
        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            # Rules whose right side is a table entity are marked as non-global.
            is_global_rule = not world.is_table_entity(rule_right_side)
            production_rule_fields.append(ProductionRuleField(production_rule, is_global_rule))
        action_field = ListField(production_rule_fields)

        fields = {
            'question': question_field,
            'metadata': MetadataField(metadata),
            'table': table_field,
            'world': world_field,
            'actions': action_field
        }
        if self._include_table_metadata:
            fields['table_metadata'] = MetadataField(table_lines)
        if example_lisp_string:
            fields['example_lisp_string'] = MetadataField(example_lisp_string)

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type on the `action.rule` line
        # because mypy can't tell that the ListField is made up of ProductionRuleFields.
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(action_field.field_list)
        }
        if dpd_output:
            action_sequence_fields: List[Field] = []
            for logical_form in dpd_output:
                if not self._should_keep_logical_form(logical_form):
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                try:
                    expression = world.parse_logical_form(logical_form)
                except ParsingError as error:
                    logger.debug(
                        f'Parsing error: {error.message}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Logical form was: {logical_form}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                except Exception:
                    # Anything other than a ParsingError is unexpected: record the offending
                    # logical form, then let the error propagate.
                    logger.error(logical_form)
                    raise
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.debug(
                        f'Missing production rule: {error.args}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    logger.debug(f'Logical form was: {logical_form}')
                    continue
                if len(action_sequence_fields) >= self._max_dpd_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
            for agenda_string in world.get_agenda():
                agenda_index_fields.append(
                    IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                # An empty agenda is represented by a single -1 index.
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)
Exemplo n.º 23
0
    def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
        """
        Builds an ``Instance`` from one pre-processed JSON blob.

        The blob supplies the question tokens, table lines, entity tokens, linking features and
        example lisp string, which are passed straight through (the table field is built with
        ``tokenizer=None``).  ``target_action_sequences`` and ``agenda``, when present, are lists
        of production-rule strings; they are converted to indices into the ``actions`` field.
        """
        question_tokens = self._read_tokens_from_json_list(
            json_obj['question_tokens'])
        question_field = TextField(question_tokens,
                                   self._question_token_indexers)
        question_metadata = MetadataField(
            {"question_tokens": [x.text for x in question_tokens]})
        table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(
            json_obj['table_lines'], question_tokens)
        entity_tokens = [
            self._read_tokens_from_json_list(token_list)
            for token_list in json_obj['entity_texts']
        ]
        table_field = KnowledgeGraphField(
            table_knowledge_graph,
            question_tokens,
            tokenizer=None,
            token_indexers=self._table_token_indexers,
            entity_tokens=entity_tokens,
            linking_features=json_obj['linking_features'],
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        world = WikiTablesWorld(table_knowledge_graph)
        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            # Rules whose right side is a table entity are marked as non-global.
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        example_string_field = MetadataField(json_obj['example_lisp_string'])

        fields = {
            'question': question_field,
            'metadata': question_metadata,
            'table': table_field,
            'world': world_field,
            'actions': action_field,
            'example_lisp_string': example_string_field
        }

        # Maps each production rule string to its index in ``action_field``.  mypy can't tell
        # that the ListField holds ProductionRuleFields, hence the ignore on ``action.rule``.
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(action_field.field_list)
        }
        if 'target_action_sequences' in json_obj:
            action_sequence_fields: List[Field] = []
            for sequence in json_obj['target_action_sequences']:
                index_fields: List[Field] = []
                for production_rule in sequence:
                    index_fields.append(
                        IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        if 'agenda' in json_obj:
            agenda_index_fields: List[Field] = [
                IndexField(action_map[agenda_action], action_field)
                for agenda_action in json_obj['agenda']
            ]
            if not agenda_index_fields:
                # Mirror text_to_instance: an empty agenda is padded with a single -1 index
                # (presumably because an empty ListField cannot be built).
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)
Exemplo n.º 24
0
 def test_get_numbers_from_tokens_works_for_ordinal_and_cardinal_numbers(
         self):
     # Cardinal and ordinal words map to their value; the original token is kept as the text.
     words = ['one', 'five', 'Seventh']
     extracted = TableQuestionKnowledgeGraph._get_numbers_from_tokens([Token(w) for w in words])
     assert extracted == [('1', 'one'), ('5', 'five'), ('7', 'Seventh')]
Exemplo n.º 25
0
 def _get_world_with_question_tokens(self, tokens: List[Token]) -> WikiTablesWorld:
     """Build a ``WikiTablesWorld`` over ``self.table_file`` for the given question tokens."""
     return WikiTablesWorld(TableQuestionKnowledgeGraph.read_from_file(self.table_file, tokens))
Exemplo n.º 26
0
 def test_get_numbers_from_tokens_works_for_months(self):
     # Month names map to their month number regardless of capitalization.
     months = ['January', 'March', 'october']
     extracted = TableQuestionKnowledgeGraph._get_numbers_from_tokens([Token(m) for m in months])
     assert extracted == [('1', 'January'), ('3', 'March'), ('10', 'october')]
Exemplo n.º 27
0
 def test_get_numbers_from_tokens_works_for_units(self):
     # A unit suffix is stripped from the number but kept in the returned token text.
     words = ['1ghz', '3.5mm', '-2m/s']
     extracted = TableQuestionKnowledgeGraph._get_numbers_from_tokens([Token(w) for w in words])
     assert extracted == [('1', '1ghz'), ('3.500', '3.5mm'), ('-2', '-2m/s')]
Exemplo n.º 28
0
 def setUp(self):
     super().setUp()
     self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tsv'
     question_words = 'what was the last year 2000 ?'.split()
     question_tokens = [Token(word) for word in question_words]
     self.table_kg = TableQuestionKnowledgeGraph.read_from_file(self.table_file, question_tokens)
     self.world = WikiTablesWorld(self.table_kg)
Exemplo n.º 29
0
 def test_get_numbers_from_tokens_works_with_magnitude_words(self):
     # A magnitude word multiplies the preceding number, and both tokens form the text.
     words = ['one', 'million', '7', 'thousand']
     extracted = TableQuestionKnowledgeGraph._get_numbers_from_tokens([Token(w) for w in words])
     assert extracted == [('1000000', 'one million'), ('7000', '7 thousand')]
Exemplo n.º 30
0
 def test_should_split_column_returns_true_when_one_input_is_splitable(
         self):
     # A single comma-separated cell is enough to require splitting the whole column.
     column_cells = ['Team, 2006', 'Wolfe Tones']
     should_split = TableQuestionKnowledgeGraph._should_split_column_cells(column_cells)
     assert should_split is True