Ejemplo n.º 1
0
 def test_get_explanation_provides_non_empty_explanation_for_typical_inputs(
         self):
     """Check that ``get_explanation`` returns four entries for a typical input."""
     # Attribute entities keyed by identifier, mapped to their surface text.
     entity_text = {'a:sugar': 'sugar', 'a:diabetes': 'diabetes'}
     # Minimal graph: every entity exists but has no neighbors.
     no_neighbors = {name: [] for name in entity_text}
     graph = KnowledgeGraph(entity_text.keys(), no_neighbors, entity_text)
     quarel_world = QuarelWorld(graph, "quarel_v1_attr_entities")
     result = get_explanation(
         '(infer (a:sugar higher world1) (a:diabetes higher world2) (a:diabetes higher world1))',
         {'world1': 'bill', 'world2': 'sue'},
         0,
         quarel_world)
     assert len(result) == 4
    def get_table_knowledge_graph(self) -> KnowledgeGraph:
        """
        Build (on first call) and return the knowledge graph for this table.

        Entities are every name in ``self.column_names`` plus the string and number entities
        returned by ``self.get_entities_from_question()``. String entities are linked to the
        columns they were reported with; every number is linked to all number and date columns.
        If the table has date columns and ``"-1"`` is not already present, ``"-1"`` is added as
        a wild-card entity linked to the date columns.

        The graph is cached in ``self._table_knowledge_graph``; later calls return the cached
        object without recomputing anything.

        Returns
        -------
        KnowledgeGraph
            The (possibly cached) graph built from the entities, neighbors and entity text.
        """
        if self._table_knowledge_graph is not None:
            return self._table_knowledge_graph

        entities: Set[str] = set()
        neighbors: Dict[str, List[str]] = defaultdict(list)
        entity_text: Dict[str, str] = {}
        # Add all column names to entities. We'll define their neighbors to be empty lists for
        # now, and later add number and string entities as needed.
        number_columns: List[str] = []
        date_columns: List[str] = []
        for typed_column_name in self.column_names:
            if "number_column:" in typed_column_name or "num2_column" in typed_column_name:
                number_columns.append(typed_column_name)

            if "date_column:" in typed_column_name:
                date_columns.append(typed_column_name)

            # Add column names to entities, with no neighbors yet.
            entities.add(typed_column_name)
            neighbors[typed_column_name] = []
            # Text is the part after the first ":" with underscores turned into spaces.
            entity_text[typed_column_name] = typed_column_name.split(
                ":", 1)[-1].replace("_", " ")

        string_entities, numbers = self.get_entities_from_question()
        for entity, column_names in string_entities:
            entities.add(entity)
            for column_name in column_names:
                neighbors[entity].append(column_name)
                neighbors[column_name].append(entity)
            entity_text[entity] = entity.replace("string:",
                                                 "").replace("_", " ")

        # Every number extracted from the question gets all number and date columns as its
        # neighbors (and vice versa). The combined list does not change inside the loop.
        numeric_columns = number_columns + date_columns
        for number, _ in numbers:
            entities.add(number)
            neighbors[number].extend(numeric_columns)
            for column_name in numeric_columns:
                neighbors[column_name].append(number)
            entity_text[number] = number

        # De-duplicate neighbor lists while keeping first-seen order. ``list(set(...))`` would
        # make the order depend on string hash randomization and so differ between runs.
        for entity, entity_neighbors in neighbors.items():
            neighbors[entity] = list(dict.fromkeys(entity_neighbors))

        # Add "-1" as an entity only if we have date columns in the table because we will need
        # it as a wild-card in dates. The neighbors are the date columns.
        if "-1" not in neighbors and date_columns:
            entities.add("-1")
            # Copy so the graph does not share a list object with the local ``date_columns``.
            neighbors["-1"] = list(date_columns)
            entity_text["-1"] = "-1"
            for date_column in date_columns:
                neighbors[date_column].append("-1")

        # Hand over a plain dict so lookups of unknown keys no longer create entries.
        self._table_knowledge_graph = KnowledgeGraph(
            entities, dict(neighbors), entity_text)
        return self._table_knowledge_graph
Ejemplo n.º 3
0
 def test_get_explanation_provides_non_empty_explanation_for_typical_inputs(
         self):
     """Check that ``get_explanation`` returns four entries for a typical input."""
     # Attribute entities keyed by identifier, mapped to their surface text.
     entity_text = {"a:sugar": "sugar", "a:diabetes": "diabetes"}
     # Minimal graph: every entity exists but has no neighbors.
     no_neighbors = {name: [] for name in entity_text}
     graph = KnowledgeGraph(entity_text.keys(), no_neighbors, entity_text)
     quarel_world = QuarelWorld(graph, "quarel_v1_attr_entities")
     result = get_explanation(
         "(infer (a:sugar higher world1) (a:diabetes higher world2) (a:diabetes higher world1))",
         {"world1": "bill", "world2": "sue"},
         0,
         quarel_world)
     assert len(result) == 4
Ejemplo n.º 4
0
    def __init__(
            self,
            lazy: bool = False,
            sample: int = -1,
            lf_syntax: Optional[str] = None,
            replace_world_entities: bool = False,
            align_world_extractions: bool = False,
            gold_world_extractions: bool = False,
            tagger_only: bool = False,
            denotation_only: bool = False,
            world_extraction_model: Optional[str] = None,
            skip_attributes_regex: Optional[str] = None,
            entity_bits_mode: Optional[str] = None,
            entity_types: Optional[List[str]] = None,
            lexical_cues: Optional[List[str]] = None,
            tokenizer: Optional[Tokenizer] = None,
            question_token_indexers: Optional[Dict[str, TokenIndexer]] = None) -> None:
        """
        Configure the reader and build its base knowledge graph and world.

        Parameters
        ----------
        lazy : ``bool``
            Forwarded to the parent constructor.
        sample : ``int``
            Stored as ``self._sample``; not otherwise used here.
        lf_syntax : ``Optional[str]``
            Logical-form syntax name passed to ``QuarelWorld``. When it contains
            ``"_attr_entities"``, attribute entities are derived from the world's
            ``qr_coeff_sets`` and the knowledge graph and world are rebuilt around them.
        replace_world_entities, align_world_extractions, gold_world_extractions : ``bool``
            Stored on the instance; not otherwise used here.
        tagger_only, denotation_only : ``bool``
            Stored on the instance; not otherwise used here.
        world_extraction_model : ``Optional[str]``
            If given, a ``WorldTaggerExtractor`` is created from it and
            ``self._extract_worlds`` is set to ``True``.
        skip_attributes_regex : ``Optional[str]``
            If given, it is compiled and attributes it matches (``search``) are left out of
            the dynamic entities.
        entity_bits_mode : ``Optional[str]``
            Stored on the instance. The value ``"collapsed"`` makes ``self._all_entities``
            equal to ``entity_types`` instead of the expanded entity names.
        entity_types : ``Optional[List[str]]``
            Entity categories used for tagging. Unless the mode is ``"collapsed"`` each type
            must be a known category (only ``"world"`` is defined here), otherwise a
            ``KeyError`` is raised.
        lexical_cues : ``Optional[List[str]]``
            Keys into ``LEXICAL_CUES`` whose cue words are appended to the text of matching
            attribute entities.
        tokenizer : ``Optional[Tokenizer]``
            Defaults to ``WordTokenizer()``.
        question_token_indexers : ``Optional[Dict[str, TokenIndexer]]``
            Defaults to ``{"tokens": SingleIdTokenIndexer()}``; also used for entities.
        """
        super().__init__(lazy=lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        self._question_token_indexers = question_token_indexers or {
            "tokens": SingleIdTokenIndexer()
        }
        # Entities share the question's indexers.
        self._entity_token_indexers = self._question_token_indexers
        self._sample = sample
        self._replace_world_entities = replace_world_entities
        self._lf_syntax = lf_syntax
        self._entity_bits_mode = entity_bits_mode
        self._align_world_extractions = align_world_extractions
        self._gold_world_extractions = gold_world_extractions
        self._entity_types = entity_types
        self._tagger_only = tagger_only
        self._denotation_only = denotation_only
        self._skip_attributes_regex = None
        if skip_attributes_regex is not None:
            self._skip_attributes_regex = re.compile(skip_attributes_regex)
        self._lexical_cues = lexical_cues

        # Recording of entities in categories relevant for tagging
        all_entities = {"world": ["world1", "world2"]}
        # TODO: Clarify this into an appropriate parameter
        self._collapse_tags = ["world"]

        self._all_entities = None
        if entity_types is not None:
            if self._entity_bits_mode == "collapsed":
                self._all_entities = entity_types
            else:
                # Expand each category into its concrete entity names.
                self._all_entities = [
                    e for t in entity_types for e in all_entities[t]
                ]

        logger.info(f"all_entities = {self._all_entities}")

        # Base world, depending on LF syntax only
        self._knowledge_graph = KnowledgeGraph(
            entities={"placeholder"},
            neighbors={},
            entity_text={"placeholder": "placeholder"})
        self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)

        # Decide dynamic entities, if any
        self._dynamic_entities: Dict[str, str] = {}
        self._use_attr_entities = False
        # Guard against the default ``lf_syntax=None``: a bare ``in`` test on None raises
        # TypeError. NOTE(review): whether QuarelWorld itself accepts a None syntax is not
        # visible from here.
        if lf_syntax is not None and "_attr_entities" in lf_syntax:
            self._use_attr_entities = True
            qr_coeff_sets = self._world.qr_coeff_sets
            for qset in qr_coeff_sets:
                for attribute in qset:
                    if (self._skip_attributes_regex is not None
                            and self._skip_attributes_regex.search(attribute)):
                        continue
                    # Get text associated with each entity, both from entity identifier and
                    # associated lexical cues, if any
                    entity_strings = [
                        words_from_entity_string(attribute).lower()
                    ]
                    if self._lexical_cues is not None:
                        for key in self._lexical_cues:
                            if attribute in LEXICAL_CUES[key]:
                                entity_strings += LEXICAL_CUES[key][attribute]
                    self._dynamic_entities["a:" + attribute] = " ".join(
                        entity_strings)

        # Update world to include dynamic entities
        if self._use_attr_entities:
            logger.info(f"dynamic_entities = {self._dynamic_entities}")
            neighbors: Dict[str, List[str]] = {
                key: []
                for key in self._dynamic_entities
            }
            self._knowledge_graph = KnowledgeGraph(
                entities=set(self._dynamic_entities),
                neighbors=neighbors,
                entity_text=self._dynamic_entities)
            self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)

        self._stemmer = PorterStemmer().stemmer

        self._world_tagger_extractor = None
        self._extract_worlds = False
        if world_extraction_model is not None:
            logger.info("Loading world tagger model...")
            self._extract_worlds = True
            self._world_tagger_extractor = WorldTaggerExtractor(
                world_extraction_model)
            logger.info("Done loading world tagger model!")

        # Convenience regex for recognizing attributes
        self._attr_regex = re.compile(r"""\((\w+) (high|low|higher|lower)""")
Ejemplo n.º 5
0
    def text_to_instance(
            self,  # type: ignore
            question: str,
            logical_forms: List[str] = None,
            additional_metadata: Dict[str, Any] = None,
            world_extractions: Dict[str, Union[str, List[str]]] = None,
            entity_literals: Dict[str, Union[str, List[str]]] = None,
            tokenized_question: List[Token] = None,
            debug_counter: int = None,
            qr_spec_override: List[Dict[str, int]] = None,
            dynamic_entities_override: Dict[str, str] = None) -> Instance:
        """
        Turn one question (and optional supervision) into an ``Instance``.

        Parameters
        ----------
        question : ``str``
            The question text. It is lower-cased and tokenized with ``self._tokenizer``
            unless ``tokenized_question`` is supplied.
        logical_forms : ``List[str]``, optional
            Logical forms to convert into target action sequences. A logical form that uses
            a production rule missing from the world's actions is logged and skipped.
        additional_metadata : ``Dict[str, Any]``, optional
            Extra metadata. A non-empty dict passed here is updated in place
            (``question_tokens`` and possibly other keys are added). It must contain
            ``answer_index`` when the reader is in denotation-only mode.
        world_extractions : ``Dict[str, Union[str, List[str]]]``, optional
            Stored in the metadata and, when an entity-bits mode is set, used to compute
            the ``entity_bits`` field.
        entity_literals : ``Dict[str, Union[str, List[str]]]``, optional
            Only used in tagger-only mode, to compute the ``tags`` field.
        tokenized_question : ``List[Token]``, optional
            Pre-tokenized question; takes precedence over tokenizing ``question``.
        debug_counter : ``int``, optional
            When positive, the raw entity tags are logged in tagger-only mode.
        qr_spec_override : ``List[Dict[str, int]]``, optional
        dynamic_entities_override : ``Dict[str, str]``, optional
            When either override is given, a fresh knowledge graph and ``QuarelWorld`` are
            built for this instance instead of reusing the reader's own.

        Returns
        -------
        Instance
            In tagger-only mode: ``tokens``, ``metadata`` and optionally ``tags``.
            Otherwise: ``question``, ``table``, ``world``, ``actions``, ``metadata`` and,
            depending on the configuration and inputs, ``denotation_target``,
            ``entity_bits`` and ``target_action_sequences``.

        Raises
        ------
        ValueError
            If entity bits are requested with an ``entity_bits_mode`` other than
            ``"simple"``, ``"simple_collapsed"`` or ``"simple3"``.
        """

        # pylint: disable=arguments-differ
        tokenized_question = tokenized_question or self._tokenizer.tokenize(
            question.lower())
        additional_metadata = additional_metadata or dict()
        additional_metadata['question_tokens'] = [
            token.text for token in tokenized_question
        ]
        if world_extractions is not None:
            additional_metadata['world_extractions'] = world_extractions
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)

        if qr_spec_override is not None or dynamic_entities_override is not None:
            # Dynamically specify theory and/or entities
            dynamic_entities = dynamic_entities_override or self._dynamic_entities
            neighbors: Dict[str, List[str]] = {
                key: []
                for key in dynamic_entities
            }
            knowledge_graph = KnowledgeGraph(entities=set(dynamic_entities),
                                             neighbors=neighbors,
                                             entity_text=dynamic_entities)
            world = QuarelWorld(knowledge_graph,
                                self._lf_syntax,
                                qr_coeff_sets=qr_spec_override)
        else:
            knowledge_graph = self._knowledge_graph
            world = self._world

        table_field = KnowledgeGraphField(knowledge_graph,
                                          tokenized_question,
                                          self._entity_token_indexers,
                                          tokenizer=self._tokenizer)

        if self._tagger_only:
            fields: Dict[str, Field] = {'tokens': question_field}
            if entity_literals is not None:
                entity_tags = self._get_entity_tags(self._all_entities,
                                                    table_field,
                                                    entity_literals,
                                                    tokenized_question)
                # ``debug_counter`` defaults to None, which cannot be compared with ``>``.
                if debug_counter is not None and debug_counter > 0:
                    logger.info(f'raw entity tags = {entity_tags}')
                entity_tags_bio = self._convert_tags_bio(entity_tags)
                fields['tags'] = SequenceLabelField(entity_tags_bio,
                                                    question_field)
                additional_metadata['tags_gold'] = entity_tags_bio
            additional_metadata['words'] = [x.text for x in tokenized_question]
            fields['metadata'] = MetadataField(additional_metadata)
            return Instance(fields)

        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            # A rule is global unless its right-hand side is an entity of this world.
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            'question': question_field,
            'table': table_field,
            'world': world_field,
            'actions': action_field
        }

        if self._denotation_only:
            denotation_field = LabelField(additional_metadata['answer_index'],
                                          skip_indexing=True)
            fields['denotation_target'] = denotation_field

        if self._entity_bits_mode is not None and world_extractions is not None:
            entity_bits = self._get_entity_tags(['world1', 'world2'],
                                                table_field, world_extractions,
                                                tokenized_question)
            # Each tag (0, 1 or 2) selects one row of the per-mode encoding table.
            if self._entity_bits_mode == "simple":
                entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag]
                                 for tag in entity_bits]
            elif self._entity_bits_mode == "simple_collapsed":
                entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple3":
                entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag]
                                 for tag in entity_bits]
            else:
                # Without this branch the code below would fail with an opaque
                # UnboundLocalError for any other mode.
                raise ValueError(
                    f'Unsupported entity_bits_mode for entity bits: '
                    f'{self._entity_bits_mode!r}')

            entity_bits_field = ArrayField(np.array(entity_bits_v))
            fields['entity_bits'] = entity_bits_field

        if logical_forms:
            # Map each production rule string to its index in the action field.
            action_map = {
                action.rule: i
                for i, action in enumerate(action_field.field_list)
            }  # type: ignore
            action_sequence_fields: List[Field] = []
            for logical_form in logical_forms:
                expression = world.parse_logical_form(logical_form)
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.info(
                        f'Missing production rule: {error.args}, skipping logical form'
                    )
                    logger.info(f'Question was: {question}')
                    logger.info(f'Logical form was: {logical_form}')
                    continue
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        # The metadata dict always holds at least 'question_tokens' at this point.
        fields['metadata'] = MetadataField(additional_metadata)
        return Instance(fields)
 def empty_field(self) -> "KnowledgeGraphField":
     """Return a field over an empty graph and an empty token list, reusing this field's token indexers."""
     empty_graph = KnowledgeGraph(set(), {})
     return KnowledgeGraphField(empty_graph, [], self._token_indexers)