Example #1
    def test_doubly_nested_field_works(self):
        field1 = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        field2 = ProductionRuleField('NP -> test', is_global_rule=True)
        field3 = ProductionRuleField('VP -> eat', is_global_rule=False)
        list_field = ListField(
            [ListField([field1, field2, field3]),
             ListField([field1, field2])])
        list_field.index(self.vocab)
        padding_lengths = list_field.get_padding_lengths()
        tensors = list_field.as_tensor(padding_lengths)
        assert isinstance(tensors, list)
        assert len(tensors) == 2
        assert isinstance(tensors[0], list)
        assert len(tensors[0]) == 3
        assert isinstance(tensors[1], list)
        assert len(tensors[1]) == 3

        tensor_tuple = tensors[0][0]
        assert tensor_tuple[0] == 'S -> [NP, VP]'
        assert tensor_tuple[1] is True
        assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(),
                            [self.s_rule_index])

        tensor_tuple = tensors[0][1]
        assert tensor_tuple[0] == 'NP -> test'
        assert tensor_tuple[1] is True
        assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(),
                            [self.np_index])

        tensor_tuple = tensors[0][2]
        assert tensor_tuple[0] == 'VP -> eat'
        assert tensor_tuple[1] is False
        assert tensor_tuple[2] is None

        tensor_tuple = tensors[1][0]
        assert tensor_tuple[0] == 'S -> [NP, VP]'
        assert tensor_tuple[1] is True
        assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(),
                            [self.s_rule_index])

        tensor_tuple = tensors[1][1]
        assert tensor_tuple[0] == 'NP -> test'
        assert tensor_tuple[1] is True
        assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(),
                            [self.np_index])

        # This item was just padding.
        tensor_tuple = tensors[1][2]
        assert tensor_tuple[0] == ''
        assert tensor_tuple[1] is False
        assert tensor_tuple[2] is None
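For reference, a minimal sketch of the single-field lifecycle these tests exercise (assuming an AllenNLP-style Vocabulary; ProductionRuleField later moved to allennlp_semparse.fields, so the import path below may differ by version):

    from allennlp.data import Vocabulary
    from allennlp.data.fields import ProductionRuleField

    vocab = Vocabulary()
    rule_index = vocab.add_token_to_namespace('S -> [NP, VP]', namespace='rule_labels')

    field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    field.index(vocab)
    # Global rules are rendered as (rule_string, is_global, rule_id_tensor, nonterminal);
    # non-global rules carry None in place of the id tensor.
    rule, is_global, rule_id, nonterminal = field.as_tensor(field.get_padding_lengths())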
Example #2
    def text_to_instance(self, question_tokens: List[Token], logical_form: str,
                         question_entities: List[str],
                         question_predicates: List[str]) -> Instance:
        context = LCQuADContext(self.executor, question_tokens,
                                question_entities, question_predicates)
        language = LCQuADLanguage(context)
        print("CONSTANT:" +
              str(language._functions['http://dbpedia.org/ontology/creator']))
        target_action_sequence = language.logical_form_to_action_sequence(
            logical_form)

        production_rule_fields = [
            ProductionRuleField(rule, is_global_rule=True)
            for rule in language.all_possible_productions()
        ]
        action_field = ListField(production_rule_fields)
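        # Map each production rule string to its position in the action list so the
        # target action sequence can be encoded as IndexFields below.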
        action_map = {
            action.rule: i
            for i, action in enumerate(production_rule_fields)
        }
        target_action_sequence_field = ListField([
            IndexField(action_map[a], action_field)
            for a in target_action_sequence
        ])

        fields = {
            'question': TextField(question_tokens, self.token_indexers),
            'question_entities': MetadataField(question_entities),
            'question_predicates': MetadataField(question_predicates),
            'world': MetadataField(language),
            'actions': action_field,
            'target_action_sequence': target_action_sequence_field
        }

        return Instance(fields)
Example #3
    def text_to_instance(self, question_tokens: List[Token], logical_form: str,
                         question_entities: List[str],
                         question_predicates: List[str]) -> Optional[Instance]:
        try:
            if any(ontology_type in logical_form for ontology_type in self.ontology_types):
                return None

            if "intersection" in logical_form or "contains" in logical_form or "count" in logical_form:
                return None

            for old, new in zip(self.original_predicates, self.predicates):
                logical_form = logical_form.replace(old, new)

            # question_entities could also be extended with self.ontology_types here.

            import random
            random.shuffle(question_predicates)
            context = LCQuADContext(self.executor, question_tokens,
                                    question_entities, question_predicates)
            language = LCQuADLanguage(context)
            # print("CONSTANT:" + str(language._functions['http://dbpedia.org/ontology/creator']))
            target_action_sequence = language.logical_form_to_action_sequence(
                logical_form)
            #labelled_results = language.execute_action_sequence(target_action_sequence)
            # if isinstance(labelled_results, set) and len(labelled_results) > 1000:
            #     return None

            production_rule_fields = [
                ProductionRuleField(rule, is_global_rule=True)
                for rule in language.all_possible_productions()
            ]
            action_field = ListField(production_rule_fields)
            action_map = {
                action.rule: i
                for i, action in enumerate(production_rule_fields)
            }
            target_action_sequence_field = ListField([
                ListField([
                    IndexField(action_map[a], action_field)
                    for a in target_action_sequence
                ])
            ])

            fields = {
                'question': TextField(question_tokens, self.token_indexers),
                'question_entities': MetadataField(question_entities),
                'question_predicates': MetadataField(question_predicates),
                'world': MetadataField(language),
                'actions': action_field,
                'target_action_sequences': target_action_sequence_field,
                'logical_forms': MetadataField([logical_form])
                #'labelled_results': MetadataField(labelled_results),
            }

            return Instance(fields)
        except ParsingError:
            print(logical_form)
            return None
Example #4
    def test_field_counts_vocab_items_correctly(self):
        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        namespace_token_counts = defaultdict(lambda: defaultdict(int))
        field.count_vocab_items(namespace_token_counts)
        assert namespace_token_counts["rule_labels"]["S -> [NP, VP]"] == 1

        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=False)
        namespace_token_counts = defaultdict(lambda: defaultdict(int))
        field.count_vocab_items(namespace_token_counts)
        assert namespace_token_counts["rule_labels"]["S -> [NP, VP]"] == 0
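Example #5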
    def text_to_instance(self,  # type: ignore
                         query: List[str],
                         derived_cols: List[Tuple[str, str]],
                         derived_tables: List[str],
                         prelinked_entities: Dict[str, Dict[str, str]] = None,
                         sql: List[str] = None,
                         spans: List[Tuple[int, int]] = None) -> Instance:
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}
        tokens_tokenized = self._token_tokenizer.tokenize(' '.join(query))
        tokens = TextField(tokens_tokenized, self._token_indexers)
        fields["tokens"] = tokens

        spans_field: List[Field] = []
        spans = self._fix_spans_coverage(spans, len(tokens_tokenized))
        for start, end in spans:
            spans_field.append(SpanField(start, end, tokens))
        span_list_field: ListField = ListField(spans_field)
        fields["spans"] = span_list_field

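        # Note: `sql` is assumed to be provided; `action_sequence` and `all_actions`
        # are only bound inside the branch below but are used unconditionally later.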
        if sql is not None:
            action_sequence, all_actions = self._world.get_action_sequence_and_all_actions(query=sql,
                                                                                           derived_cols=derived_cols,
                                                                                           derived_tables=derived_tables,
                                                                                           prelinked_entities=prelinked_entities)
            if action_sequence is None and self._keep_if_unparsable:
                print("Parse error")
                action_sequence = []
            elif action_sequence is None:
                return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, _ = production_rule.split(' ->')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        self._world.is_global_rule(nonterminal),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {action.rule: i # type: ignore
                      for i, action in enumerate(valid_actions_field.field_list)}

        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]
        # if not action_sequence and re.findall(r"COUNT \( \* \) (?:<|>|<>|=) 0", " ".join(sql)):
        #     index_fields = [IndexField(-2, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        return Instance(fields)
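Example #6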
    def test_as_tensor_produces_correct_output(self):
        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        field.index(self.vocab)
        tensor_tuple = field.as_tensor(field.get_padding_lengths())
        assert isinstance(tensor_tuple, tuple)
        assert len(tensor_tuple) == 4
        assert tensor_tuple[0] == 'S -> [NP, VP]'
        assert tensor_tuple[1] is True
        assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [self.s_rule_index])

        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=False)
        field.index(self.vocab)
        tensor_tuple = field.as_tensor(field.get_padding_lengths())
        assert isinstance(tensor_tuple, tuple)
        assert len(tensor_tuple) == 4
        assert tensor_tuple[0] == 'S -> [NP, VP]'
        assert tensor_tuple[1] is False
        assert tensor_tuple[2] is None
Example #7
    def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
        question_tokens = self._read_tokens_from_json_list(
            json_obj['question_tokens'])
        question_field = TextField(question_tokens,
                                   self._question_token_indexers)
        table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(
            json_obj['table_lines'], question_tokens)
        entity_tokens = [
            self._read_tokens_from_json_list(token_list)
            for token_list in json_obj['entity_texts']
        ]
        table_field = KnowledgeGraphField(
            table_knowledge_graph,
            question_tokens,
            tokenizer=None,
            token_indexers=self._table_token_indexers,
            entity_tokens=entity_tokens,
            linking_features=json_obj['linking_features'],
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        world = WikiTablesWorld(table_knowledge_graph)
        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        example_string_field = MetadataField(json_obj['example_lisp_string'])

        fields = {
            'question': question_field,
            'table': table_field,
            'world': world_field,
            'actions': action_field,
            'example_lisp_string': example_string_field
        }

        if 'target_action_sequences' in json_obj:
            action_map = {
                action.rule: i
                for i, action in enumerate(action_field.field_list)
            }  # type: ignore
            action_sequence_fields: List[Field] = []
            for sequence in json_obj['target_action_sequences']:
                index_fields: List[Field] = []
                for production_rule in sequence:
                    index_fields.append(
                        IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)

        return Instance(fields)
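Example #8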
    def text_to_instance(self,  # type: ignore
                         logical_forms: List[str],
                         table_lines: List[List[str]],
                         question: str) -> Instance:
        # pylint: disable=arguments-differ
        tokenized_question = self._tokenizer.tokenize(question.lower())
        tokenized_question.insert(0, Token(START_SYMBOL))
        tokenized_question.append(Token(END_SYMBOL))
        question_field = TextField(tokenized_question, self._question_token_indexers)
        table_context = TableQuestionContext.read_from_lines(table_lines, tokenized_question)
        world = WikiTablesLanguage(table_context)

        action_sequences_list: List[List[str]] = []
        action_sequence_fields_list: List[TextField] = []
        for logical_form in logical_forms:
            try:
                action_sequence = world.logical_form_to_action_sequence(logical_form)
                action_sequence = reader_utils.make_bottom_up_action_sequence(action_sequence,
                                                                              world.is_nonterminal)
                action_sequence_field = TextField([Token(rule) for rule in action_sequence],
                                                  self._rule_indexers)
                action_sequences_list.append(action_sequence)
                action_sequence_fields_list.append(action_sequence_field)
            except ParsingError as error:
                logger.debug(f'Parsing error: {error.message}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Logical form was: {logical_form}')
                logger.debug(f'Table info was: {table_lines}')
            except:
                logger.error(logical_form)
                raise

        if not action_sequences_list:
            return None

        all_production_rule_fields: List[List[Field]] = []
        for action_sequence in action_sequences_list:
            all_production_rule_fields.append([])
            for production_rule in action_sequence:
                _, rule_right_side = production_rule.split(' -> ')
                is_global_rule = not world.is_instance_specific_entity(rule_right_side)
                field = ProductionRuleField(production_rule, is_global_rule=is_global_rule)
                all_production_rule_fields[-1].append(field)
        action_field = ListField([ListField(production_rule_fields) for production_rule_fields in
                                  all_production_rule_fields])

        fields = {'action_sequences': ListField(action_sequence_fields_list),
                  'target_tokens': question_field,
                  'world': MetadataField(world),
                  'actions': action_field}

        return Instance(fields)
Example #9
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}

        db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                     tables_file=self._tables_file, dataset_path=self._dataset_path)
        table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                                db_context.tokenized_utterance,
                                                self._utterance_token_indexers,
                                                entity_tokens=db_context.entity_tokens,
                                                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                                max_table_tokens=None)  # self._max_table_tokens)
        
        world = SpiderWorld(db_context, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions()
        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)
            
        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(valid_actions_field.field_list)}

        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
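        # An index of -1 marks an empty (unparsable) action sequence so the field
        # is still well-formed.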
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]
        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
        return Instance(fields)
Example #10
    def text_to_instance(
            self,  # type: ignore
            query: List[str],
            sql: text2sql_utils.SqlData = None) -> Instance:
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}
        tokens = TextField([Token(t) for t in query], self._token_indexers)
        fields["tokens"] = tokens

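        # `action_sequence` and `all_actions` are only defined when `sql` is given,
        # so callers are expected to pass it.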
        if sql is not None:
            try:
                action_sequence, all_actions = self._world.get_action_sequence_and_all_actions(
                    sql.sql)
            except ParseError:
                return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, _ = production_rule.split(' ->')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(
                production_rule, self._world.is_global_rule(nonterminal))
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        for production_rule in action_sequence:
            # Temporarily skipping this production to
            # make a PR smaller. The next PR will constrain
            # the strings produced to be from the table,
            # but at the moment they are blank so they
            # aren't present in the global actions.
            # TODO(Mark): fix the above.
            if production_rule.startswith("string"):
                continue
            index_fields.append(
                IndexField(action_map[production_rule], valid_actions_field))

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        return Instance(fields)
Example #11
    def text_to_instance(
        self,  # type: ignore
        query: List[str],
        prelinked_entities: Dict[str, Dict[str, str]] = None,
        sql: List[str] = None,
    ) -> Instance:

        fields: Dict[str, Field] = {}
        tokens = TextField([Token(t) for t in query], self._token_indexers)
        fields["tokens"] = tokens

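        # `sql` is assumed to be provided here; the names bound in this branch are
        # used unconditionally below.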
        if sql is not None:
            action_sequence, all_actions = self._world.get_action_sequence_and_all_actions(
                sql, prelinked_entities)
            if action_sequence is None and self._keep_if_unparsable:
                print("Parse error")
                action_sequence = []
            elif action_sequence is None:
                return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, _ = production_rule.split(" ->")
            production_rule = " ".join(production_rule.split(" "))
            field = ProductionRuleField(
                production_rule,
                self._world.is_global_rule(nonterminal),
                nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        for production_rule in action_sequence:
            index_fields.append(
                IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        return Instance(fields)
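Example #12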
    def test_field_counts_vocab_items_correctly(self):
        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        namespace_token_counts = defaultdict(lambda: defaultdict(int))
        field.count_vocab_items(namespace_token_counts)
        assert namespace_token_counts["rule_labels"]["S -> [NP, VP]"] == 1

        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=False)
        namespace_token_counts = defaultdict(lambda: defaultdict(int))
        field.count_vocab_items(namespace_token_counts)
        assert namespace_token_counts["rule_labels"]["S -> [NP, VP]"] == 0
Example #13
    def text_to_instance(
            self,  # type: ignore
            query: List[str],
            prelinked_entities: Dict[str, Dict[str, str]] = None,
            sql: List[str] = None) -> Instance:
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}
        tokens = TextField([Token(t) for t in query], self._token_indexers)
        fields["tokens"] = tokens

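        # `action_sequence` and `all_actions` are only bound when `sql` is provided.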
        if sql is not None:
            try:
                action_sequence, all_actions = self._world.get_action_sequence_and_all_actions(
                    sql, prelinked_entities)
            except ParseError:
                return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, _ = production_rule.split(' ->')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(
                production_rule,
                self._world.is_global_rule(nonterminal),
                nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        for production_rule in action_sequence:
            index_fields.append(
                IndexField(action_map[production_rule], valid_actions_field))

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        return Instance(fields)
Example #14
    def text_to_instance(self,
                         utterances: List[str],
                         db_id: str,
                         sql: List[List[str]] = None):
        fields: Dict[str, Field] = {}

        ctxts = [
            SpiderDBContext(db_id,
                            utterance,
                            tokenizer=self._tokenizer,
                            tables_file=self._tables_file,
                            dataset_path=self._dataset_path)
            for utterance in utterances
        ]

        super_utterance = ' '.join(utterances)
        hack_ctxt = SpiderDBContext(db_id,
                                    super_utterance,
                                    tokenizer=self._tokenizer,
                                    tables_file=self._tables_file,
                                    dataset_path=self._dataset_path)

        kg = SpiderKnowledgeGraphField(
            hack_ctxt.knowledge_graph,
            hack_ctxt.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=hack_ctxt.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
            max_table_tokens=None)
        '''
        kgs = [SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                        db_context.tokenized_utterance,
                                        self._utterance_token_indexers,
                                        entity_tokens=db_context.entity_tokens,
                                        include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                        max_table_tokens=None)  # self._max_table_tokens)
                        for db_context in ctxts]
        '''
        worlds = []

        # One SQL query per utterance is assumed here.
        for sqli, db_context in zip(sql, ctxts):
            world = SpiderWorld(db_context, query=sqli)
            worlds.append(world)

        fields["utterances"] = ListField([
            TextField(db_context.tokenized_utterance,
                      self._utterance_token_indexers) for db_context in ctxts
        ])

        #action_sequence, all_actions = world.get_action_sequence_and_all_actions()

        action_tups = [
            world.get_action_sequence_and_all_actions() for world in worlds
        ]
        action_sequences = [tup[0] for tup in action_tups]
        all_actions = [tup[1] for tup in action_tups]

        for i in range(len(action_sequences)):
            action_sequence = action_sequences[i]

            if action_sequence is None and self._keep_if_unparsable:
                # print("Parse error")
                action_sequence = []
            elif action_sequence is None:
                return None

            action_sequences[i] = action_sequence

        all_valid_actions_fields = []
        all_action_sequence_fields = []

        for i in range(len(all_actions)):
            production_rule_fields: List[Field] = []

            for production_rule in all_actions[i]:
                nonterminal, rhs = production_rule.split(' -> ')
                production_rule = ' '.join(production_rule.split(' '))
                # Use the world corresponding to this interaction turn.
                field = ProductionRuleField(production_rule,
                                            worlds[i].is_global_rule(rhs),
                                            nonterminal=nonterminal)
                production_rule_fields.append(field)

            valid_actions_field = ListField(production_rule_fields)

            all_valid_actions_fields.append(valid_actions_field)

            action_map = {
                action.rule: i  # type: ignore
                for i, action in enumerate(valid_actions_field.field_list)
            }

            index_fields: List[Field] = []

            action_sequence = action_sequences[i]

            for production_rule in action_sequence:
                index_fields.append(
                    IndexField(action_map[production_rule],
                               valid_actions_field))
            if not action_sequence:
                index_fields = [IndexField(-1, valid_actions_field)]

            action_sequence_field = ListField(index_fields)

            all_action_sequence_fields.append(action_sequence_field)

        fields["valid_actions"] = ListField(all_valid_actions_fields)
        fields["action_sequences"] = ListField(all_action_sequence_fields)
        fields["worlds"] = ListField(
            [MetadataField(world) for world in worlds])
        fields["schema"] = kg
        '''
        fields['utterances'] = ListField[TextField]
        fields['valid_actions'] = ListField[ListField[ProductionRuleField]]
        fields['action_sequences'] = ListField[ListField[IndexField]]
        fields['worlds'] = ListField[MetadataField[SpiderWorld]]
        fields['schemas'] = ListField[SpiderKnowledgeGraphField]
        '''

        return Instance(fields)
Example #15
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}

        db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                     tables_file=self._tables_file, dataset_path=self._dataset_path)
        table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                                db_context.tokenized_utterance,
                                                {},
                                                entity_tokens=db_context.entity_tokens,
                                                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                                max_table_tokens=None)  # self._max_table_tokens)

        combined_tokens = list(db_context.tokenized_utterance)
        entity_token_map = dict(zip(db_context.knowledge_graph.entities, db_context.entity_tokens))
        entity_tokens = []
        for e in db_context.knowledge_graph.entities:
            if e.startswith('column:'):
                table_name, column_name = e.split(':')[-2:]
                table_tokens = entity_token_map['table:'+table_name]
                column_tokens = entity_token_map[e]
                if column_name.startswith(table_name):
                    column_tokens = column_tokens[len(table_tokens):]
                entity_tokens.append(table_tokens + [Token(text='[unused30]')] + column_tokens)
            else:
                entity_tokens.append(entity_token_map[e])
        for e in entity_tokens:
            combined_tokens += [Token(text='[SEP]')] + e
        if len(combined_tokens) > 450:
            return None
        db_context.entity_tokens = entity_tokens
        fields["utterance"] = TextField(combined_tokens, self._utterance_token_indexers)

        world = SpiderWorld(db_context, query=sql)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions()

        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, _ = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(nonterminal),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(valid_actions_field.field_list)}

        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
        return Instance(fields)
Example #16
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}
        if self._is_spider:
            db_context = SpiderDBContext(db_id,
                                         utterance,
                                         tokenizer=self._tokenizer,
                                         tables_file=self._tables_file,
                                         dataset_path=self._dataset_path)
            table_field = SpiderKnowledgeGraphField(
                db_context.knowledge_graph,
                db_context.tokenized_utterance,
                self._utterance_token_indexers,
                entity_tokens=db_context.entity_tokens,
                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                max_table_tokens=None)  # self._max_table_tokens)

            world = SpiderWorld(db_context, nl_context=None, query=sql)
            fields["utterance"] = TextField(db_context.tokenized_utterance,
                                            self._utterance_token_indexers)

            action_sequence, all_actions = world.get_action_sequence_and_all_actions(
            )

            if action_sequence is None and self._keep_if_unparsable:
                # print("Parse error")
                action_sequence = []
            elif action_sequence is None:
                return None

            index_fields: List[Field] = []
            production_rule_fields: List[Field] = []

            for production_rule in all_actions:
                nonterminal, rhs = production_rule.split(' -> ')
                production_rule = ' '.join(production_rule.split(' '))
                field = ProductionRuleField(production_rule,
                                            world.is_global_rule(rhs),
                                            nonterminal=nonterminal)
                production_rule_fields.append(field)

            valid_actions_field = ListField(production_rule_fields)
            fields["valid_actions"] = valid_actions_field

            action_map = {
                action.rule: i  # type: ignore
                for i, action in enumerate(valid_actions_field.field_list)
            }

            for production_rule in action_sequence:
                index_fields.append(
                    IndexField(action_map[production_rule],
                               valid_actions_field))
            if not action_sequence:
                index_fields = [IndexField(-1, valid_actions_field)]

            action_sequence_field = ListField(index_fields)
            fields["action_sequence"] = action_sequence_field
            fields["world"] = MetadataField(world)
            fields["schema"] = table_field
        else:
            db_context = WikiDBContext(db_id,
                                       utterance,
                                       tokenizer=self._tokenizer,
                                       tables_file=self._tables_file,
                                       dataset_path=self._dataset_path)
            #print(db_context.entity_tokens)
            # TODO: WikiKnowledgeGraphField is identical to the Spider version; only the class name differs.
            table_field = WikiKnowledgeGraphField(
                db_context.knowledge_graph,
                db_context.tokenized_utterance,
                self._utterance_token_indexers,
                entity_tokens=db_context.entity_tokens,
                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                max_table_tokens=None)  # self._max_table_tokens)

            world = WikiWorld(db_context, nl_context=None, query=sql)
            fields["utterance"] = TextField(db_context.tokenized_utterance,
                                            self._utterance_token_indexers)

            # TODO: this step raises a ParseError, probably a grammar mismatch, e.g. for
            # ['select', '1-10015132-11@Position', 'where', '1-10015132-11@School/Club Team', '=', "'value'"]:
            # action_sequence comes back None even though all_actions contains the full
            # grammar (arg_list, binaryop, col_ref, column_name, expr, from_clause,
            # join_clauses, query, select_core, where_clause, ... productions).
            action_sequence, all_actions = world.get_action_sequence_and_all_actions(
            )

            if action_sequence is None and self._keep_if_unparsable:
                # print("Parse error")
                action_sequence = []
            elif action_sequence is None:
                return None

            production_rule_fields: List[Field] = []

            for production_rule in all_actions:

                nonterminal, rhs = production_rule.split(' -> ')
                production_rule = ' '.join(production_rule.split(' '))

                field = ProductionRuleField(production_rule,
                                            world.is_global_rule(rhs),
                                            nonterminal=nonterminal)
                production_rule_fields.append(field)

            valid_actions_field = ListField(production_rule_fields)
            fields["valid_actions"] = valid_actions_field

            index_fields: List[Field] = []

            # action: ProductionRuleField
            action_map = {
                action.rule: i  # type: ignore
                for i, action in enumerate(valid_actions_field.field_list)
            }

            for production_rule in action_sequence:
                index_fields.append(
                    IndexField(action_map[production_rule],
                               valid_actions_field))
            if not action_sequence:
                index_fields = [IndexField(-1, valid_actions_field)]

            action_sequence_field = ListField(index_fields)
            fields["action_sequence"] = action_sequence_field

            fields["world"] = MetadataField(world)
            fields["schema"] = table_field

            #print(fields)

        return Instance(fields)
Example #17
    def text_to_instance(
            self,  # type: ignore
            question: str,
            table_lines: List[str],
            example_lisp_string: str = None,
            dpd_output: List[str] = None,
            tokenized_question: List[Token] = None) -> Instance:
        """
        Reads text inputs and makes an instance. WikitableQuestions dataset provides tables as TSV
        files, which we use for training.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[str]``
            The table content itself, as a list of rows. See
            ``TableQuestionKnowledgeGraph.read_from_lines`` for the expected format.
        example_lisp_string : ``str``, optional
            The original (lisp-formatted) example string in the WikiTableQuestions dataset.  This
            comes directly from the ``.examples`` file provided with the dataset.  We pass this to
            SEMPRE for evaluating logical forms during training.  It isn't otherwise used for
            anything.
        dpd_output : ``List[str]``, optional
            List of logical forms, produced by dynamic programming on denotations. Not required
            during test.
        tokenized_question : ``List[Token]``, optional
            If you have already tokenized the question, you can pass that in here, so we don't
            duplicate that work.  You might, for example, do batch processing on the questions in
            the whole dataset, then pass the result in here.
        """
        # pylint: disable=arguments-differ
        tokenized_question = tokenized_question or self._tokenizer.tokenize(
            question.lower())
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)
        metadata: Dict[str, Any] = {
            "question_tokens": [x.text for x in tokenized_question]
        }
        metadata["original_table"] = "\n".join(table_lines)
        table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(
            table_lines, tokenized_question)
        table_metadata = MetadataField(table_lines)
        table_field = KnowledgeGraphField(
            table_knowledge_graph,
            tokenized_question,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            feature_extractors=self._linking_feature_extractors,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        world = WikiTablesWorld(table_knowledge_graph)
        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            'question': question_field,
            'metadata': MetadataField(metadata),
            'table': table_field,
            'world': world_field,
            'actions': action_field
        }
        if self._include_table_metadata:
            fields['table_metadata'] = table_metadata
        if example_lisp_string:
            fields['example_lisp_string'] = MetadataField(example_lisp_string)

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {
            action.rule: i
            for i, action in enumerate(action_field.field_list)
        }  # type: ignore
        if dpd_output:
            action_sequence_fields: List[Field] = []
            for logical_form in dpd_output:
                if not self._should_keep_logical_form(logical_form):
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                try:
                    expression = world.parse_logical_form(logical_form)
                except ParsingError as error:
                    logger.debug(
                        f'Parsing error: {error.message}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Logical form was: {logical_form}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                except:
                    logger.error(logical_form)
                    raise
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.debug(
                        f'Missing production rule: {error.args}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    logger.debug(f'Logical form was: {logical_form}')
                    continue
                if len(action_sequence_fields) >= self._max_dpd_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
            for agenda_string in world.get_agenda():
                agenda_index_fields.append(
                    IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)
Example #18
    def test_batch_tensors_does_not_modify_list(self):
        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        field.index(self.vocab)
        padding_lengths = field.get_padding_lengths()
        tensor_dict1 = field.as_tensor(padding_lengths)

        field = ProductionRuleField('NP -> test', is_global_rule=True)
        field.index(self.vocab)
        padding_lengths = field.get_padding_lengths()
        tensor_dict2 = field.as_tensor(padding_lengths)
        tensor_list = [tensor_dict1, tensor_dict2]
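        # ProductionRuleField.as_tensor returns plain tuples rather than tensors,
        # so batching simply returns the list unchanged.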
        assert field.batch_tensors(tensor_list) == tensor_list
Example #19
    def test_padding_lengths_are_computed_correctly(self):
        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        field.index(self.vocab)
        assert field.get_padding_lengths() == {}
Example #20
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}
        """KAIMARY"""
        # Contains
        # 1. db schema(Tables with corresponding columns)
        # 2. Tokenized utterance
        # 3. Knowledge graph(Table entities, column entities and "text" column type related token entities)
        # 4. Entity_tokens(Retrieved from entities_text from kg)
        db_context = SpiderDBContext(db_id,
                                     utterance,
                                     tokenizer=self._tokenizer,
                                     tables_file=self._tables_file,
                                     dataset_path=self._dataset_path)
        # https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html#knowledge-graph-field
        # *feature extractors*
        table_field = SpiderKnowledgeGraphField(
            db_context.knowledge_graph,
            db_context.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=db_context.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
            max_table_tokens=None)  # self._max_table_tokens)

        world = SpiderWorld(db_context, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance,
                                        self._utterance_token_indexers)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions(
        )

        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        for production_rule in action_sequence:
            index_fields.append(
                IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
        return Instance(fields)
Example #21
    def text_to_instance(
            self,  # type: ignore
            utterances: List[str],
            sql_query_labels: List[str] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        utterances: ``List[str]``, required.
            List of utterances in the interaction, the last element is the current utterance.
        sql_query_labels: ``List[str]``, optional
            The SQL queries that are given as labels during training or validation.
        """
        if self._num_turns_to_concatenate:
            utterances[-1] = f' {END_OF_UTTERANCE_TOKEN} '.join(
                utterances[-self._num_turns_to_concatenate:])

        utterance = utterances[-1]
        action_sequence: List[str] = []

        if not utterance:
            return None

        world = AtisWorld(utterances=utterances)

        if sql_query_labels:
            # If there are multiple sql queries given as labels, we use the shortest
            # one for training.
            sql_query = min(sql_query_labels, key=len)
            try:
                action_sequence = world.get_action_sequence(sql_query)
            except ParseError:
                logger.debug('Parsing error')

        tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
        utterance_field = TextField(tokenized_utterance, self._token_indexers)

        production_rule_fields: List[Field] = []

        for production_rule in world.all_possible_actions():
            nonterminal, _ = production_rule.split(' ->')
            # The whitespaces are not semantically meaningful, so we filter them out.
            production_rule = ' '.join([
                token for token in production_rule.split(' ') if token != 'ws'
            ])
            field = ProductionRuleField(production_rule,
                                        self._is_global_rule(nonterminal))
            production_rule_fields.append(field)

        action_field = ListField(production_rule_fields)
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(action_field.field_list)
        }
        index_fields: List[Field] = []
        world_field = MetadataField(world)
        fields = {
            'utterance': utterance_field,
            'actions': action_field,
            'world': world_field,
            'linking_scores': ArrayField(world.linking_scores)
        }

        if sql_query_labels is not None:
            fields['sql_queries'] = MetadataField(sql_query_labels)
            if self._keep_if_unparseable or action_sequence:
                for production_rule in action_sequence:
                    index_fields.append(
                        IndexField(action_map[production_rule], action_field))
                if not index_fields:
                    # The query could not be parsed; pad with a single -1 index.
                    index_fields = [IndexField(-1, action_field)]
                action_sequence_field = ListField(index_fields)
                fields['target_action_sequence'] = action_sequence_field
            else:
                # We are given a SQL query but cannot parse it, and we were not
                # explicitly told to keep unparseable queries, so we skip this
                # instance.
                return None

        return Instance(fields)
Example #22
    def text_to_instance(
        self,  # type: ignore
        question: str,
        logical_forms: List[str] = None,
        additional_metadata: Dict[str, Any] = None,
        world_extractions: Dict[str, Union[str, List[str]]] = None,
        entity_literals: Dict[str, Union[str, List[str]]] = None,
        tokenized_question: List[Token] = None,
        debug_counter: int = None,
        qr_spec_override: List[Dict[str, int]] = None,
        dynamic_entities_override: Dict[str, str] = None,
    ) -> Instance:

        tokenized_question = tokenized_question or self._tokenizer.tokenize(question.lower())
        additional_metadata = additional_metadata or dict()
        additional_metadata["question_tokens"] = [token.text for token in tokenized_question]
        if world_extractions is not None:
            additional_metadata["world_extractions"] = world_extractions
        question_field = TextField(tokenized_question, self._question_token_indexers)

        if qr_spec_override is not None or dynamic_entities_override is not None:
            # Dynamically specify theory and/or entities
            dynamic_entities = dynamic_entities_override or self._dynamic_entities
            neighbors: Dict[str, List[str]] = {key: [] for key in dynamic_entities.keys()}
            knowledge_graph = KnowledgeGraph(
                entities=set(dynamic_entities.keys()),
                neighbors=neighbors,
                entity_text=dynamic_entities,
            )
            world = QuarelWorld(knowledge_graph, self._lf_syntax, qr_coeff_sets=qr_spec_override)
        else:
            knowledge_graph = self._knowledge_graph
            world = self._world

        table_field = KnowledgeGraphField(
            knowledge_graph,
            tokenized_question,
            self._entity_token_indexers,
            tokenizer=self._tokenizer,
        )

        if self._tagger_only:
            fields: Dict[str, Field] = {"tokens": question_field}
            if entity_literals is not None:
                entity_tags = self._get_entity_tags(
                    self._all_entities, table_field, entity_literals, tokenized_question
                )
                if debug_counter is not None and debug_counter > 0:
                    logger.info(f"raw entity tags = {entity_tags}")
                entity_tags_bio = self._convert_tags_bio(entity_tags)
                fields["tags"] = SequenceLabelField(entity_tags_bio, question_field)
                additional_metadata["tags_gold"] = entity_tags_bio
            additional_metadata["words"] = [x.text for x in tokenized_question]
            fields["metadata"] = MetadataField(additional_metadata)
            return Instance(fields)

        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(" -> ")
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            "question": question_field,
            "table": table_field,
            "world": world_field,
            "actions": action_field,
        }

        if self._denotation_only:
            denotation_field = LabelField(additional_metadata["answer_index"], skip_indexing=True)
            fields["denotation_target"] = denotation_field

        if self._entity_bits_mode is not None and world_extractions is not None:
            entity_bits = self._get_entity_tags(
                ["world1", "world2"], table_field, world_extractions, tokenized_question
            )
            if self._entity_bits_mode == "simple":
                entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple_collapsed":
                entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple3":
                entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag] for tag in entity_bits]

            entity_bits_field = ArrayField(np.array(entity_bits_v))
            fields["entity_bits"] = entity_bits_field

        if logical_forms:
            action_map = {
                action.rule: i for i, action in enumerate(action_field.field_list)  # type: ignore
            }
            action_sequence_fields: List[Field] = []
            for logical_form in logical_forms:
                expression = world.parse_logical_form(logical_form)
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(IndexField(action_map[production_rule], action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.info(f"Missing production rule: {error.args}, skipping logical form")
                    logger.info(f"Question was: {question}")
                    logger.info(f"Logical form was: {logical_form}")
                    continue
            fields["target_action_sequences"] = ListField(action_sequence_fields)
        fields["metadata"] = MetadataField(additional_metadata or {})
        return Instance(fields)
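The three entity_bits_mode encodings above differ only in the per-tag lookup row. A worked illustration (hypothetical tag sequence, where 0 = neither world, 1 = world1, 2 = world2):

tags = [0, 1, 0, 2]
simple = [[[0, 0], [1, 0], [0, 1]][t] for t in tags]
# -> [[0, 0], [1, 0], [0, 0], [0, 1]]
simple_collapsed = [[[0], [1], [1]][t] for t in tags]
# -> [[0], [1], [0], [1]]  (both worlds collapse into one "is a world" bit)
simple3 = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][t] for t in tags]
# -> a one-hot row per tag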
Example #23
    def text_to_instance(
            self,
            utterance: str,  # question
            db_id: str,
            sql: List[str] = None):
        fields: Dict[str, Field] = {}

        # db_context holds the DB graph and its tokens; it includes the utterance and the DB graph.
        db_context = SpiderDBContext(db_id,
                                     utterance,
                                     tokenizer=self._tokenizer,
                                     tables_file=self._tables_file,
                                     dataset_path=self._dataset_path)

        # An Instance contains many fields, and each must be a Field object in allennlp
        # (you can think of fields as columns in a table, or attributes of an object).
        # So we need to convert db_context into a Field object, which is table_field here.
        # db_context.knowledge_graph is a graph, so we need a graph field; SpiderKnowledgeGraphField inherits from KnowledgeGraphField.
        table_field = SpiderKnowledgeGraphField(
            db_context.knowledge_graph,
            db_context.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=db_context.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
            max_table_tokens=None,
            conceptnet=self._concept_word)  # self._max_table_tokens)

        world = SpiderWorld(db_context, query=sql)

        fields["utterance"] = TextField(db_context.tokenized_utterance,
                                        self._utterance_token_indexers)

        # action_sequence is the query parsed by the grammar, which is built for the specific database.
        # all_actions is the full list of grammar rule strings,
        # so action_sequence is a subset of all_actions:
        # it contains exactly the rules needed to derive this query.
        # The grammar is defined in semparse/contexts/spider_db_grammar.py, in a BNF-like style.
        action_sequence, all_actions = world.get_action_sequence_and_all_actions()
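        # Illustration (hypothetical rule strings): all_actions is every rule the
        # database-specific grammar allows, e.g.
        #   ['statement -> [query, iue, query]', 'binaryop -> ["!="]', ...]
        # while action_sequence is the derivation of this particular query, e.g.
        #   ['statement -> [query]', 'query -> [select_core]', ...]
        # so every element of action_sequence also appears in all_actions.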

        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            # Help ProductionRuleField: https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html?highlight=productionrulefield#production-rule-field
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        # valid_actions_field is generated by all_actions that include all grammar.
        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        # Give every grammar rule an id.
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        # Map every rule in action_sequence to its index in the full rule list,
        # so the decoder can be supervised with positions into valid_actions.
        if action_sequence:
            for production_rule in action_sequence:
                index_fields.append(
                    IndexField(action_map[production_rule],
                               valid_actions_field))
        else:
            # action_sequence is empty here, meaning the grammar could not parse this query.
            index_fields = [IndexField(-1, valid_actions_field)]
        action_sequence_field = ListField(index_fields)

        # The fields["valid_actions"] include the global rule which is the same in all SQL and database specific rules.
        # For example, 'binaryop -> ["!="]' and 'query -> [select_core, groupby_clause, limit]' are global rules.
        # 'column_name -> ["department@budget_in_billions"]' and 'col_ref -> ["department@budget_in_billions"]' are not global rules.
        # So fields["valid_actions"] is case by case but will contain the same global rules. And it must include the rules appearing in fields["action_sequence"].
        # Now the attribute of _rule_id is None. But when finish loading all data, allennlp will automatically build the vocabulary and give a unique _ruled_id for every rule.
        # But when forward the fields["valid_actions"] to the model, it will become the ProductionRule List.
        # We will find that the global rules will contain a _rule_id but non-global rules will not.
        # And in ProductionRule List, the global rules will become a tuple and non-global will become a ProductionRule obj.
        # In a tuple and ProductionRule, its[0] shows its rule value, such as: 'where_clause -> ["where", expr, where_conj]' or 'col_ref -> ["department@budget_in_billions"]'.
        # In a tuple and ProductionRule, its[1] shows whether it is global rules, such as True or False.
        # In a tuple and ProductionRule, its[2] shows _rule_id. But if it is non-global rule, it will be None.
        # In a tuple and ProductionRule, its[3] shows left rule value. For example: 'where_clause' is the left rule value of 'where_clause -> ["where", expr, where_conj]'.
        # fields["valid_actions"]   # All action / All grammar

        # The information in fields["valid_actions"] is almost the same as world.valid_actions,
        # just in a different representation (world is a SpiderWorld object).
        # The project uses two representations of the valid actions, carrying the same information.
        # The first is a list of dicts (the "list-type" actions):
        # [   #key                #value                       #key          #value   #other key/value pairs
        #    {rule: 'arg_list -> [expr, ",", arg_list]',   is_global_rule: True,            ... }
        #    {rule: 'arg_list -> [expr]',                  is_global_rule: True,            ... }
        #    {...}
        # ]
        # From it you can easily extract all valid actions into a flat list, e.g.
        # all_actions:
        # ['arg_list -> [expr, ",", arg_list]', 'arg_list -> [expr]', ... ]

        # The second is a dict keyed by the left-hand side, grouping rules that share it
        # (the "dict-type" actions):
        # {       #key           #value-list
        #       arg_list: [
        #                   '[expr, ",", arg_list]',
        #                   '[expr]'
        #                 ],
        #       ...: [...]
        # }
        # Both are just different views of the same valid actions.
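        # A small sketch (hedged; 'list_type_actions' is an illustrative name) of
        # converting the list-type representation into the dict-type one:
        #   from collections import defaultdict
        #   dict_actions = defaultdict(list)
        #   for action in list_type_actions:
        #       lhs, rhs = action['rule'].split(' -> ')
        #       dict_actions[lhs].append(rhs)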

        # fields["utterance"]       # TextFile for utterance
        fields[
            "action_sequence"] = action_sequence_field  # grammar rules (action) of this query, and every rule contains a total grammar set which is fields["valid_actions"].
        fields["world"] = MetadataField(
            world
        )  #Maybe just for calc the metric. # A MetadataField is a Field that does not get converted into tensors. https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html?highlight=metadatafield#metadata-field
        fields["schema"] = table_field
        return Instance(fields)
Example #24
    def text_to_instance(self,  # type: ignore
                         question: str,
                         table_lines: List[List[str]],
                         answer_json: JsonDict,
                         offline_search_output: List[str] = None) -> Instance:
        """
        Reads text inputs and makes an instance. We assume we have access to DROP paragraphs parsed
        and tagged in a format similar to the tagged tables in WikiTableQuestions.
        # TODO(pradeep): Explain the format.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[List[str]]``
            Preprocessed paragraph content. See ``ParagraphQuestionContext.read_from_lines``
            for the expected format.
        answer_json : ``JsonDict``
            The "answer" dict from the original data file.
        offline_search_output : ``List[str]``, optional
            List of logical forms, produced by offline search. Not required during test.
        """
        # pylint: disable=arguments-differ
        tokenized_question = self._tokenizer.tokenize(question.lower())
        question_field = TextField(tokenized_question, self._question_token_indexers)
        # TODO(pradeep): We'll need a better way to input processed lines.
        paragraph_context = ParagraphQuestionContext.read_from_lines(table_lines,
                                                                     tokenized_question,
                                                                     self._entity_extraction_embedding,
                                                                     self._entity_extraction_distance_threshold)
        target_values_field = MetadataField(answer_json)
        world = DropWorld(paragraph_context)
        world_field = MetadataField(world)
        # Note: Not passing any feature extractors when instantiating the field below. This will make
        # it use all the available extractors.
        table_field = KnowledgeGraphField(paragraph_context.get_knowledge_graph(),
                                          tokenized_question,
                                          self._table_token_indexers,
                                          tokenizer=self._tokenizer,
                                          include_in_vocab=self._use_table_for_vocab,
                                          max_table_tokens=self._max_table_tokens)
        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_instance_specific_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule=is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {'question': question_field,
                  'table': table_field,
                  'world': world_field,
                  'actions': action_field,
                  'target_values': target_values_field}

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
        if offline_search_output:
            action_sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    expression = world.parse_logical_form(logical_form)
                except ParsingError as error:
                    logger.debug(f'Parsing error: {error.message}, skipping logical form')
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Logical form was: {logical_form}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                except:
                    logger.error(logical_form)
                    raise
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(IndexField(action_map[production_rule], action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.debug(f'Missing production rule: {error.args}, skipping logical form')
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    logger.debug(f'Logical form was: {logical_form}')
                    continue
                if len(action_sequence_fields) >= self._max_offline_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields['target_action_sequences'] = ListField(action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
            for agenda_string in world.get_agenda():
                agenda_index_fields.append(IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)
Example #25
    def text_to_instance(
            self,  # type: ignore
            utterances: List[str],
            sql_query: str = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        utterances: ``List[str]``, required.
            List of utterances in the interaction; the last element is the current utterance.
        sql_query: ``str``, optional
            The SQL query, given as label during training or validation.
        """
        utterance = utterances[-1]
        action_sequence: List[str] = []

        if not utterance:
            return None

        world = AtisWorld(utterances=utterances,
                          database_directory=self._database_directory)

        if sql_query:
            try:
                action_sequence = world.get_action_sequence(sql_query)
            except ParseError:
                logger.debug('Parsing error')

        tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
        utterance_field = TextField(tokenized_utterance, self._token_indexers)

        production_rule_fields: List[Field] = []

        for production_rule in world.all_possible_actions():
            lhs, _ = production_rule.split(' ->')
            is_global_rule = lhs not in ['number', 'string']
            # The whitespaces are not semantically meaningful, so we filter them out.
            production_rule = ' '.join([
                token for token in production_rule.split(' ') if token != 'ws'
            ])
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)

        action_field = ListField(production_rule_fields)
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(action_field.field_list)
        }
        index_fields: List[Field] = []
        world_field = MetadataField(world)
        fields = {
            'utterance': utterance_field,
            'actions': action_field,
            'world': world_field,
            'linking_scores': ArrayField(world.linking_scores)
        }

        if sql_query:
            if action_sequence:
                for production_rule in action_sequence:
                    index_fields.append(
                        IndexField(action_map[production_rule], action_field))

                action_sequence_fields: List[Field] = [ListField(index_fields)]
                fields['target_action_sequence'] = ListField(
                    action_sequence_fields)
            else:
                # If we are given a SQL query, but we are unable to parse it, then we will skip it.
                return None

        return Instance(fields)
Example #26
    def text_to_instance(
            self,  # type: ignore
            sentence: str,
            structured_representations: List[List[List[JsonDict]]],
            labels: List[str] = None,
            target_sequences: List[List[str]] = None,
            identifier: str = None) -> Instance:
        """
        Parameters
        ----------
        sentence : ``str``
            The query sentence.
        structured_representations : ``List[List[List[JsonDict]]]``
            A list of Json representations of all the worlds. See expected format in this class' docstring.
        labels : ``List[str]`` (optional)
            List of string representations of the labels (true or false) corresponding to the
            ``structured_representations``. Not required while testing.
        target_sequences : ``List[List[str]]`` (optional)
            List of target action sequences for each element which lead to the correct denotation in
            worlds corresponding to the structured representations.
        identifier : ``str`` (optional)
            The identifier from the dataset if available.
        """
        # pylint: disable=arguments-differ
        worlds = [NlvrWorld(data) for data in structured_representations]
        tokenized_sentence = self._tokenizer.tokenize(sentence)
        sentence_field = TextField(tokenized_sentence,
                                   self._sentence_token_indexers)
        production_rule_fields: List[Field] = []
        instance_action_ids: Dict[str, int] = {}
        # TODO(pradeep): Assuming that possible actions are the same in all worlds. This may change
        # later.
        for production_rule in worlds[0].all_possible_actions():
            instance_action_ids[production_rule] = len(instance_action_ids)
            field = ProductionRuleField(production_rule, is_global_rule=True)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)
        worlds_field = ListField([MetadataField(world) for world in worlds])
        fields: Dict[str, Field] = {
            "sentence": sentence_field,
            "worlds": worlds_field,
            "actions": action_field
        }
        if identifier is not None:
            fields["identifier"] = MetadataField(identifier)
        # Depending on the type of supervision used for training the parser, we may want either
        # target action sequences or an agenda in our instance. We check if target sequences are
        # provided, and include them if they are. If not, we'll get an agenda for the sentence, and
        # include that in the instance.
        if target_sequences:
            action_sequence_fields: List[Field] = []
            for target_sequence in target_sequences:
                index_fields = ListField([
                    IndexField(instance_action_ids[action], action_field)
                    for action in target_sequence
                ])
                action_sequence_fields.append(index_fields)
                # TODO(pradeep): Define a max length for this field.
            fields["target_action_sequences"] = ListField(
                action_sequence_fields)
        elif self._output_agendas:
            # TODO(pradeep): Assuming every world gives the same agenda for a sentence. This is true
            # now, but may change later too.
            agenda = worlds[0].get_agenda_for_sentence(
                sentence, add_paths_to_agenda=False)
            assert agenda, "No agenda found for sentence: %s" % sentence
            # agenda_field contains indices into actions.
            agenda_field = ListField([
                IndexField(instance_action_ids[action], action_field)
                for action in agenda
            ])
            fields["agenda"] = agenda_field
        if labels:
            labels_field = ListField([
                LabelField(label, label_namespace='denotations')
                for label in labels
            ])
            fields["labels"] = labels_field

        return Instance(fields)
Example #27
    def test_padding_lengths_are_computed_correctly(self):
        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        field.index(self.vocab)
        assert field.get_padding_lengths() == {}
Example #28
    def test_production_rule_field_can_print(self):
        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        print(field)
Example #29
    def __init__(
        self,
        image_feat_path: str,
        topk: int = -1,
        lazy: bool = False,
        reload_tsv: bool = False,
        cache_path: str = "",
        relations_path: str = "",
        attributes_path: str = "",
        objects_path: str = "",
        positive_threshold: float = 0.5,
        negative_threshold: float = 0.5,
        object_supervision: bool = False,
        require_some_positive: bool = False,
    ) -> None:
        super().__init__(lazy)
        self._tokenizer = PretrainedTransformerTokenizer("bert-base-uncased",
                                                         do_lowercase=True)
        self._word_tokenizer = WordTokenizer()
        self._token_indexers = {
            "tokens":
            PretrainedTransformerIndexer("bert-base-uncased",
                                         do_lowercase=True)
        }

        self._language = VisualReasoningGqaLanguage(None, None, None, None,
                                                    None)
        self._production_rules = self._language.all_possible_productions()
        self._action_map = {
            rule: i
            for i, rule in enumerate(self._production_rules)
        }
        production_rule_fields = [
            ProductionRuleField(rule, is_global_rule=True)
            for rule in self._production_rules
        ]
        self._production_rule_field = ListField(production_rule_fields)
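        # The grammar is context-independent here, so the ProductionRuleField list is
        # built once in the constructor and reused for every instance this reader emits.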
        self._image_feat_cache_dir = os.path.join(
            "cache",
            image_feat_path.split("/")[-1])
        if len(cache_path) > 0:
            self._image_feat_cache_dir = os.path.join(
                cache_path, "cache",
                image_feat_path.split("/")[-1])
        self.img_data = None
        if reload_tsv:
            self.img_data = load_obj_tsv(
                image_feat_path,
                topk,
                save_cache=False,
                cache_path=self._image_feat_cache_dir,
            )
            self.img_data = {img["img_id"]: img for img in self.img_data}
        self.object_data = None
        self.attribute_data = None
        if len(objects_path) > 0:
            self.object_data = json.load(open(objects_path))
            self.object_data = {
                img["image_id"]: img
                for img in self.object_data
            }
        if len(attributes_path) > 0:
            self.attribute_data = json.load(open(attributes_path))
            self.attribute_data = {
                img["image_id"]: img
                for img in self.attribute_data
            }
        if len(relations_path) > 0:
            self.relation_data = json.load(open(relations_path))
            self.relation_data = {
                img["image_id"]: img
                for img in self.relation_data
            }
        self.positive_threshold = positive_threshold
        self.negative_threshold = negative_threshold
        self.object_supervision = object_supervision
        self.require_some_positive = require_some_positive
Example #30
    def test_index_converts_field_correctly(self):
        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        field.index(self.vocab)
        assert field._rule_id == self.s_rule_index
Example #31
    def test_batch_tensors_does_not_modify_list(self):
        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        field.index(self.vocab)
        padding_lengths = field.get_padding_lengths()
        tensor_dict1 = field.as_tensor(padding_lengths)

        field = ProductionRuleField('NP -> test', is_global_rule=True)
        field.index(self.vocab)
        padding_lengths = field.get_padding_lengths()
        tensor_dict2 = field.as_tensor(padding_lengths)
        tensor_list = [tensor_dict1, tensor_dict2]
        assert field.batch_tensors(tensor_list) == tensor_list
Example #32
    def test_as_tensor_produces_correct_output(self):
        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        field.index(self.vocab)
        tensor_tuple = field.as_tensor(field.get_padding_lengths())
        assert isinstance(tensor_tuple, tuple)
        assert len(tensor_tuple) == 4
        assert tensor_tuple[0] == 'S -> [NP, VP]'
        assert tensor_tuple[1] is True
        assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(),
                            [self.s_rule_index])

        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=False)
        field.index(self.vocab)
        tensor_tuple = field.as_tensor(field.get_padding_lengths())
        assert isinstance(tensor_tuple, tuple)
        assert len(tensor_tuple) == 4
        assert tensor_tuple[0] == 'S -> [NP, VP]'
        assert tensor_tuple[1] is False
        assert tensor_tuple[2] is None
Example #33
    def text_to_instance(
            self,  # type: ignore
            question: str,
            table_lines: List[List[str]],
            target_values: List[str] = None,
            offline_search_output: List[str] = None) -> Instance:
        """
        Reads text inputs and makes an instance. We pass the ``table_lines`` to ``TableQuestionContext``, and that
        method accepts this field either as lines from CoreNLP processed tagged files that come with the dataset,
        or simply in a tsv format where each line corresponds to a row and the cells are tab-separated.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[List[str]]``
            The table content optionally preprocessed by CoreNLP. See ``TableQuestionContext.read_from_lines``
            for the expected format.
        target_values : ``List[str]``, optional
            Target values for the denotations the logical forms should execute to. Not required for testing.
        offline_search_output : ``List[str]``, optional
            List of logical forms, produced by offline search. Not required during test.
        """
        # pylint: disable=arguments-differ
        tokenized_question = self._tokenizer.tokenize(question.lower())
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)
        metadata: Dict[str, Any] = {
            "question_tokens": [x.text for x in tokenized_question]
        }
        table_context = TableQuestionContext.read_from_lines(
            table_lines, tokenized_question)
        target_values_field = MetadataField(target_values)
        world = WikiTablesLanguage(table_context)
        world_field = MetadataField(world)
        # Note: Not passing any feature extractors when instantiating the field below. This will make
        # it use all the available extractors.
        table_field = KnowledgeGraphField(
            table_context.get_table_knowledge_graph(),
            tokenized_question,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_productions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_instance_specific_entity(
                rule_right_side)
            field = ProductionRuleField(production_rule,
                                        is_global_rule=is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            'question': question_field,
            'metadata': MetadataField(metadata),
            'table': table_field,
            'world': world_field,
            'actions': action_field,
            'target_values': target_values_field
        }

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {
            action.rule: i
            for i, action in enumerate(action_field.field_list)
        }  # type: ignore
        if offline_search_output:
            action_sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    action_sequence = world.logical_form_to_action_sequence(
                        logical_form)
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except ParsingError as error:
                    logger.debug(
                        f'Parsing error: {error.message}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Logical form was: {logical_form}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                except KeyError as error:
                    logger.debug(
                        f'Missing production rule: {error.args}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    logger.debug(f'Logical form was: {logical_form}')
                    continue
                except:
                    logger.error(logical_form)
                    raise
                if len(action_sequence_fields) >= self._max_offline_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
            for agenda_string in world.get_agenda(conservative=True):
                agenda_index_fields.append(
                    IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)
Example #34
    def text_to_instance(
        self,
        sentence: str,
        identifier: str,
        image_ids: List[str],
        logical_form: str = None,
        attention_mode: int = None,
        box_annotation: Dict = None,
        denotation: str = None,
    ) -> Instance:
        tokenized_sentence = self._tokenizer.tokenize(sentence)
        sentence_field = TextField(tokenized_sentence, self._token_indexers)

        world = VisualReasoningNlvr2Language(None, None, None, None, None,
                                             None)

        production_rule_fields: List[Field] = []
        instance_action_ids: Dict[str, int] = {}
        for production_rule in world.all_possible_productions():
            instance_action_ids[production_rule] = len(instance_action_ids)
            field = ProductionRuleField(production_rule, is_global_rule=True)
            production_rule_fields.append(field)

        action_field = ListField(production_rule_fields)

        boxes2 = []
        feats2 = []
        max_num_boxes = 0
        for key in image_ids:
            if self.img_data is not None:
                img_info = self.img_data[key]
            else:
                split_name = "train"
                if "dev" in key:
                    split_name = "valid"
                img_info = pickle.load(
                    open(
                        os.path.join(self._image_feat_cache_dir,
                                     split_name + "_obj36.tsv", key),
                        "rb",
                    ))
            boxes = img_info["boxes"].copy()
            feats = img_info["features"].copy()
            assert len(boxes) == len(feats)

            # Normalize the boxes (to 0 ~ 1)
            img_h, img_w = img_info["img_h"], img_info["img_w"]
            boxes[..., (0, 2)] /= img_w
            boxes[..., (1, 3)] /= img_h
            np.testing.assert_array_less(boxes, 1 + 1e-5)
            np.testing.assert_array_less(-boxes, 0 + 1e-5)
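            # E.g. (hypothetical numbers): a box [x1, y1, x2, y2] == [64, 32, 128, 96]
            # in a 640x480 image becomes [0.1, 0.0667, 0.2, 0.2].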

            if boxes.shape[0] > self._max_boxes:
                boxes = boxes[:self._max_boxes, :]
                feats = feats[:self._max_boxes, :]
            max_num_boxes = max(max_num_boxes, boxes.shape[0])
            boxes2.append(boxes)
            feats2.append(feats)
        boxes3 = [
            np.zeros((max_num_boxes, img_boxes.shape[-1]))
            for img_boxes in boxes2
        ]
        feats3 = [
            np.zeros((max_num_boxes, img_feats.shape[-1]))
            for img_feats in feats2
        ]
        for i in range(len(boxes2)):
            boxes3[i][:boxes2[i].shape[0], :] = boxes2[i]
            feats3[i][:feats2[i].shape[0], :] = feats2[i]
        boxes2 = boxes3
        feats2 = feats3
        feats = np.stack(feats2)
        boxes = np.stack(boxes2)
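        # feats and boxes are now (num_images, max_num_boxes, dim) arrays, with
        # zero rows padding images that yielded fewer boxes.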
        metadata: Dict[str, Any] = {
            "utterance": sentence,
            "tokenized_utterance": tokenized_sentence,
            "identifier": identifier,
        }

        fields: Dict[str, Field] = {
            "sentence": sentence_field,
            "actions": action_field,
            "metadata": MetadataField(metadata),
            "image_id": MetadataField(identifier[:-2]),
            "visual_feat": ArrayField(feats),
            "pos": ArrayField(boxes),
        }
        if denotation is not None:
            fields["denotation"] = LabelField(denotation, skip_indexing=True)

        if logical_form:
            lisp_exp = annotation_to_lisp_exp(logical_form)
            target_sequence = world.logical_form_to_action_sequence(lisp_exp)
            index_field = [
                IndexField(instance_action_ids[action], action_field)
                for action in target_sequence
            ]
            fields["target_action_sequence"] = ListField(index_field)

            module_attention = annotation_to_module_attention(logical_form)
            target_attention = target_sequence_to_target_attn(
                target_sequence, module_attention)
            gold_question_attentions = self._assign_attention_to_tokens(
                target_attention, sentence, attention_mode)
            attn_index_field = [
                ListField(
                    [IndexField(att, sentence_field) for att in target_att])
                for target_att in gold_question_attentions
            ]
            fields["gold_question_attentions"] = ListField(attn_index_field)
            if box_annotation is None and len(self.box_annotations) > 0:
                fields["gold_box_annotations"] = MetadataField([])
            elif box_annotation is not None:
                modules = logical_form.split("\n")
                children = [[] for _ in modules]
                for j, module in enumerate(modules):
                    num_periods = len(module) - len(module.strip("."))
                    for k in range(j + 1, len(modules)):
                        num_periods_k = len(modules[k]) - len(
                            modules[k].strip("."))
                        if num_periods_k <= num_periods:
                            break
                        if num_periods_k == num_periods + 1:
                            children[j].append(k)
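                # children[j] now lists the direct children of module j, recovered from
                # the leading-period depth encoding of the annotated program.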
                for j in range(len(modules) - 1, -1, -1):
                    if modules[j].strip(".") == "in_left_image":
                        box_annotation[j] = {}
                        box_annotation[j]["module"] = modules[j].strip(".")
                        box_annotation[j][0] = box_annotation[j + 1][0]
                        box_annotation[j][1] = []
                        """for k in children[j]:
                            box_annotation[k][0] = box_annotation[k][0]
                            box_annotation[k][1] = []"""
                    elif modules[j].strip(".") == "in_right_image":
                        box_annotation[j] = {}
                        box_annotation[j]["module"] = modules[j].strip(".")
                        box_annotation[j][1] = box_annotation[j + 1][1]
                        box_annotation[j][0] = []
                    elif modules[j].strip(".") in {
                            "in_one_image", "in_other_image"
                    }:
                        box_annotation[j] = {}
                        box_annotation[j]["module"] = modules[j].strip(".")
                        box_annotation[j][0] = box_annotation[j + 1][0]
                        box_annotation[j][1] = box_annotation[j + 1][1]
                        """for k in children[j]:
                            box_annotation[k][0] = []
                            box_annotation[k][1] = box_annotation[k][1]"""
                keys = sorted(list(box_annotation.keys()))
                # print(identifier, keys)
                # print(box_annotation)
                # print(target_sequence)
                module_boxes = [(
                    mod,
                    box_annotation[mod]["module"],
                    [box_annotation[mod][0], box_annotation[mod][1]],
                ) for mod in keys]
                gold_boxes, gold_counts = target_sequence_to_target_boxes(
                    target_sequence, module_boxes, children)
                # print(identifier, target_sequence, module_boxes, gold_boxes)
                fields["gold_box_annotations"] = MetadataField(gold_boxes)
            metadata["gold"] = world.action_sequence_to_logical_form(
                target_sequence)
            fields["valid_target_sequence"] = ArrayField(
                np.array(1, dtype=np.int32))
        else:
            fields["target_action_sequence"] = ListField(
                [IndexField(0, action_field)])
            fields["gold_question_attentions"] = ListField(
                [ListField([IndexField(0, sentence_field)])])
            fields["valid_target_sequence"] = ArrayField(
                np.array(0, dtype=np.int32))
            if len(self.box_annotations) > 0:
                fields["gold_box_annotations"] = MetadataField([])
        return Instance(fields)