# Example #1
    def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
        """Deserialize one preprocessed JSON blob into a training ``Instance``."""
        question_tokens = self._read_tokens_from_json_list(json_obj['question_tokens'])
        question_field = TextField(question_tokens, self._question_token_indexers)

        table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(
            json_obj['table_lines'], question_tokens)
        entity_tokens = [self._read_tokens_from_json_list(entity_text)
                         for entity_text in json_obj['entity_texts']]
        table_field = KnowledgeGraphField(table_knowledge_graph,
                                          question_tokens,
                                          tokenizer=None,
                                          token_indexers=self._table_token_indexers,
                                          entity_tokens=entity_tokens,
                                          linking_features=json_obj['linking_features'],
                                          include_in_vocab=self._use_table_for_vocab,
                                          max_table_tokens=self._max_table_tokens)
        world = WikiTablesWorld(table_knowledge_graph)

        production_rule_fields: List[Field] = []
        for rule in world.all_possible_actions():
            _, right_side = rule.split(' -> ')
            # Rules that produce table entities are instance-specific, not global.
            production_rule_fields.append(
                ProductionRuleField(rule, not world.is_table_entity(right_side)))
        action_field = ListField(production_rule_fields)

        fields = {'question': question_field,
                  'table': table_field,
                  'world': MetadataField(world),
                  'actions': action_field,
                  'example_lisp_string': MetadataField(json_obj['example_lisp_string'])}

        if 'target_action_sequences' in json_obj:
            # Map each rule string to its position inside ``action_field``.
            action_map = {action.rule: index  # type: ignore
                          for index, action in enumerate(action_field.field_list)}
            sequence_fields: List[Field] = [
                ListField([IndexField(action_map[rule], action_field) for rule in sequence])
                for sequence in json_obj['target_action_sequences']
            ]
            fields['target_action_sequences'] = ListField(sequence_fields)

        return Instance(fields)
# Example #2
    def text_to_instance(self,  # type: ignore
                         logical_forms: List[str],
                         table_lines: List[List[str]],
                         question: str) -> Instance:
        """
        Build an ``Instance`` from a question, its table, and candidate logical forms.

        Logical forms that raise ``ParsingError`` are skipped (with debug logging);
        if *none* of them can be converted into an action sequence, ``None`` is
        returned so the caller can drop the example.
        """
        # pylint: disable=arguments-differ
        tokenized_question = self._tokenizer.tokenize(question.lower())
        # Wrap the question with explicit sentence-boundary tokens.
        tokenized_question.insert(0, Token(START_SYMBOL))
        tokenized_question.append(Token(END_SYMBOL))
        question_field = TextField(tokenized_question, self._question_token_indexers)
        table_context = TableQuestionContext.read_from_lines(table_lines, tokenized_question)
        world = WikiTablesLanguage(table_context)

        action_sequences_list: List[List[str]] = []
        action_sequence_fields_list: List[TextField] = []
        for logical_form in logical_forms:
            try:
                action_sequence = world.logical_form_to_action_sequence(logical_form)
                action_sequence = reader_utils.make_bottom_up_action_sequence(action_sequence,
                                                                              world.is_nonterminal)
                action_sequence_field = TextField([Token(rule) for rule in action_sequence],
                                                  self._rule_indexers)
                action_sequences_list.append(action_sequence)
                action_sequence_fields_list.append(action_sequence_field)
            except ParsingError as error:
                logger.debug(f'Parsing error: {error.message}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Logical form was: {logical_form}')
                logger.debug(f'Table info was: {table_lines}')
            except Exception:
                # Was a bare ``except:``; narrowed so KeyboardInterrupt/SystemExit
                # are not intercepted. Log the offending form, then re-raise.
                logger.error(logical_form)
                raise

        if not action_sequences_list:
            # No usable logical form; signal the caller to skip this example.
            return None

        all_production_rule_fields: List[List[Field]] = []
        for action_sequence in action_sequences_list:
            production_rule_fields: List[Field] = []
            for production_rule in action_sequence:
                _, rule_right_side = production_rule.split(' -> ')
                # Instance-specific entities (table cells/columns) are not global rules.
                is_global_rule = not world.is_instance_specific_entity(rule_right_side)
                production_rule_fields.append(
                    ProductionRuleField(production_rule, is_global_rule=is_global_rule))
            all_production_rule_fields.append(production_rule_fields)
        action_field = ListField([ListField(production_rule_fields) for production_rule_fields in
                                  all_production_rule_fields])

        fields = {'action_sequences': ListField(action_sequence_fields_list),
                  'target_tokens': question_field,
                  'world': MetadataField(world),
                  'actions': action_field}

        return Instance(fields)
# Example #3
    def text_to_instance(
        self,  # type: ignore
        question: str,
        logical_forms: List[str] = None,
        additional_metadata: Dict[str, Any] = None,
        world_extractions: Dict[str, Union[str, List[str]]] = None,
        entity_literals: Dict[str, Union[str, List[str]]] = None,
        tokenized_question: List[Token] = None,
        debug_counter: int = None,
        qr_spec_override: List[Dict[str, int]] = None,
        dynamic_entities_override: Dict[str, str] = None,
    ) -> Instance:
        """
        Build an ``Instance`` for a QuaRel question.

        Depending on reader configuration this yields tagger-only fields,
        adds a denotation target, and/or produces the full semantic-parsing
        fields (question, table, world, actions, target action sequences).
        """
        tokenized_question = tokenized_question or self._tokenizer.tokenize(question.lower())
        additional_metadata = additional_metadata or dict()
        additional_metadata["question_tokens"] = [token.text for token in tokenized_question]
        if world_extractions is not None:
            additional_metadata["world_extractions"] = world_extractions
        question_field = TextField(tokenized_question, self._question_token_indexers)

        if qr_spec_override is not None or dynamic_entities_override is not None:
            # Dynamically specify theory and/or entities instead of the reader defaults.
            dynamic_entities = dynamic_entities_override or self._dynamic_entities
            neighbors: Dict[str, List[str]] = {key: [] for key in dynamic_entities.keys()}
            knowledge_graph = KnowledgeGraph(
                entities=set(dynamic_entities.keys()),
                neighbors=neighbors,
                entity_text=dynamic_entities,
            )
            world = QuarelWorld(knowledge_graph, self._lf_syntax, qr_coeff_sets=qr_spec_override)
        else:
            knowledge_graph = self._knowledge_graph
            world = self._world

        table_field = KnowledgeGraphField(
            knowledge_graph,
            tokenized_question,
            self._entity_token_indexers,
            tokenizer=self._tokenizer,
        )

        if self._tagger_only:
            fields: Dict[str, Field] = {"tokens": question_field}
            if entity_literals is not None:
                entity_tags = self._get_entity_tags(
                    self._all_entities, table_field, entity_literals, tokenized_question
                )
                # BUG FIX: ``debug_counter`` defaults to None and ``None > 0``
                # raises TypeError on Python 3; guard the comparison.
                if debug_counter is not None and debug_counter > 0:
                    logger.info(f"raw entity tags = {entity_tags}")
                entity_tags_bio = self._convert_tags_bio(entity_tags)
                fields["tags"] = SequenceLabelField(entity_tags_bio, question_field)
                additional_metadata["tags_gold"] = entity_tags_bio
            additional_metadata["words"] = [x.text for x in tokenized_question]
            fields["metadata"] = MetadataField(additional_metadata)
            return Instance(fields)

        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(" -> ")
            # Rules producing table entities are instance-specific, not global.
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            "question": question_field,
            "table": table_field,
            "world": world_field,
            "actions": action_field,
        }

        if self._denotation_only:
            denotation_field = LabelField(additional_metadata["answer_index"], skip_indexing=True)
            fields["denotation_target"] = denotation_field

        if self._entity_bits_mode is not None and world_extractions is not None:
            entity_bits = self._get_entity_tags(
                ["world1", "world2"], table_field, world_extractions, tokenized_question
            )
            # One-hot-style encodings of the entity tags, per configured mode.
            if self._entity_bits_mode == "simple":
                entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple_collapsed":
                entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple3":
                entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag] for tag in entity_bits]
            else:
                # BUG FIX: an unknown mode previously fell through and raised a
                # confusing NameError on ``entity_bits_v``; fail explicitly instead.
                raise ValueError(f"Unknown entity_bits_mode: {self._entity_bits_mode}")

            entity_bits_field = ArrayField(np.array(entity_bits_v))
            fields["entity_bits"] = entity_bits_field

        if logical_forms:
            action_map = {
                action.rule: i for i, action in enumerate(action_field.field_list)  # type: ignore
            }
            action_sequence_fields: List[Field] = []
            for logical_form in logical_forms:
                expression = world.parse_logical_form(logical_form)
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(IndexField(action_map[production_rule], action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    # A rule outside the known action space: skip this logical form.
                    logger.info(f"Missing production rule: {error.args}, skipping logical form")
                    logger.info(f"Question was: {question}")
                    logger.info(f"Logical form was: {logical_form}")
                    continue
            fields["target_action_sequences"] = ListField(action_sequence_fields)
        fields["metadata"] = MetadataField(additional_metadata or {})
        return Instance(fields)
# Example #4
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        """
        Build an ``Instance`` from an utterance and its target database.

        Returns ``None`` when the query cannot be parsed and
        ``self._keep_if_unparsable`` is disabled.
        """
        fields: Dict[str, Field] = {}
        # KAIMARY: SpiderDBContext bundles
        #   1. db schema (tables with corresponding columns)
        #   2. tokenized utterance
        #   3. knowledge graph (table entities, column entities, and "text"
        #      column-type related token entities)
        #   4. entity_tokens (retrieved from entities_text of the knowledge graph)
        db_context = SpiderDBContext(db_id,
                                     utterance,
                                     tokenizer=self._tokenizer,
                                     tables_file=self._tables_file,
                                     dataset_path=self._dataset_path)
        # https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html#knowledge-graph-field
        # *feature extractors*
        table_field = SpiderKnowledgeGraphField(
            db_context.knowledge_graph,
            db_context.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=db_context.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
            max_table_tokens=None)  # self._max_table_tokens)

        world = SpiderWorld(db_context, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance,
                                        self._utterance_token_indexers)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions()

        if action_sequence is None and self._keep_if_unparsable:
            # Keep the example with an empty target sequence.
            action_sequence = []
        elif action_sequence is None:
            return None

        production_rule_fields: List[Field] = []
        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            # NOTE: removed a no-op ``' '.join(production_rule.split(' '))`` round-trip.
            production_rule_fields.append(
                ProductionRuleField(production_rule,
                                    world.is_global_rule(rhs),
                                    nonterminal=nonterminal))

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        # Map each rule string to its index in the valid-actions list.
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        index_fields: List[Field] = [
            IndexField(action_map[production_rule], valid_actions_field)
            for production_rule in action_sequence
        ]
        if not action_sequence:
            # Sentinel index for kept-but-unparsable examples.
            index_fields = [IndexField(-1, valid_actions_field)]

        fields["action_sequence"] = ListField(index_fields)
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
        return Instance(fields)
# Example #5
    def text_to_instance(self,
                         utterances: List[str],
                         db_id: str,
                         sql: List[List[str]] = None):
        """
        Build a single ``Instance`` covering a sequence of related utterances
        against one database.

        Returns ``None`` when any query is unparsable and
        ``self._keep_if_unparsable`` is disabled.
        """
        fields: Dict[str, Field] = {}

        db_contexts = [SpiderDBContext(db_id,
                                       utterance,
                                       tokenizer=self._tokenizer,
                                       tables_file=self._tables_file,
                                       dataset_path=self._dataset_path)
                       for utterance in utterances]

        # HACK: the single schema field is built from all utterances concatenated,
        # so entity-linking features see the whole interaction at once.
        super_utterance = ' '.join(utterances)
        hack_ctxt = SpiderDBContext(db_id,
                                    super_utterance,
                                    tokenizer=self._tokenizer,
                                    tables_file=self._tables_file,
                                    dataset_path=self._dataset_path)
        kg = SpiderKnowledgeGraphField(
            hack_ctxt.knowledge_graph,
            hack_ctxt.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=hack_ctxt.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
            max_table_tokens=None)

        # BUG FIX: guard against ``sql=None`` (the declared default), which
        # previously crashed on ``len(None)``; pair each context with no query.
        queries = sql if sql is not None else [None] * len(db_contexts)
        worlds = [SpiderWorld(db_context, query=query)
                  for db_context, query in zip(db_contexts, queries)]

        fields["utterances"] = ListField([
            TextField(db_context.tokenized_utterance, self._utterance_token_indexers)
            for db_context in db_contexts
        ])

        action_sequences: List[List[str]] = []
        all_actions = []
        for world in worlds:
            action_sequence, actions = world.get_action_sequence_and_all_actions()
            if action_sequence is None:
                if not self._keep_if_unparsable:
                    return None
                # Keep the example with an empty target sequence.
                action_sequence = []
            action_sequences.append(action_sequence)
            all_actions.append(actions)

        all_valid_actions_fields = []
        all_action_sequence_fields = []

        for world, actions, action_sequence in zip(worlds, all_actions, action_sequences):
            production_rule_fields: List[Field] = []
            for production_rule in actions:
                nonterminal, rhs = production_rule.split(' -> ')
                # BUG FIX: ask *this* utterance's world whether the rule is global;
                # the original referenced the loop-leaked last world for every
                # utterance. (Also dropped a no-op split/join round-trip.)
                production_rule_fields.append(
                    ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal))

            valid_actions_field = ListField(production_rule_fields)
            all_valid_actions_fields.append(valid_actions_field)

            # Map each rule string to its index in this utterance's action list.
            action_map = {
                action.rule: index  # type: ignore
                for index, action in enumerate(valid_actions_field.field_list)
            }

            index_fields: List[Field] = [
                IndexField(action_map[rule], valid_actions_field)
                for rule in action_sequence
            ]
            if not action_sequence:
                # Sentinel index for kept-but-unparsable examples.
                index_fields = [IndexField(-1, valid_actions_field)]

            all_action_sequence_fields.append(ListField(index_fields))

        fields["valid_actions"] = ListField(all_valid_actions_fields)
        fields["action_sequences"] = ListField(all_action_sequence_fields)
        fields["worlds"] = ListField([MetadataField(world) for world in worlds])
        fields["schema"] = kg
        # Field structure:
        #   utterances:       ListField[TextField]
        #   valid_actions:    ListField[ListField[ProductionRuleField]]
        #   action_sequences: ListField[ListField[IndexField]]
        #   worlds:           ListField[MetadataField[SpiderWorld]]
        #   schema:           SpiderKnowledgeGraphField
        return Instance(fields)