Example #1
    def test_get_explanation_provides_non_empty_explanation_for_typical_inputs(self):
        logical_form = '(infer (a:sugar higher world1) (a:diabetes higher world2) (a:diabetes higher world1))'
        entities = {'a:sugar': 'sugar', 'a:diabetes': 'diabetes'}
        world_extractions = {'world1': 'bill', 'world2': 'sue'}
        answer_index = 0
        knowledge_graph = KnowledgeGraph(
            entities.keys(), {key: [] for key in entities}, entities)
        world = QuarelWorld(knowledge_graph, "quarel_v1_attr_entities")
        explanation = get_explanation(logical_form, world_extractions,
                                      answer_index, world)
        assert len(explanation) == 4
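The final assertion only checks the number of entries. Based on the headers that get_explanation builds in Example #6, a stricter check could be appended to the test body; this is a hypothetical addition, not part of the original test:

# Hypothetical extra assertion (not in the original test): the four entries
# correspond to the four sections produced by get_explanation in Example #6.
expected_headers = [
    "Identified two worlds",
    "The question is stating",
    "The answer options are stating",
    "Theory used",
]
assert [entry["header"] for entry in explanation] == expected_headers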
Example #3
    def __init__(
            self,
            lazy: bool = False,
            sample: int = -1,
            lf_syntax: str = None,
            replace_world_entities: bool = False,
            align_world_extractions: bool = False,
            gold_world_extractions: bool = False,
            tagger_only: bool = False,
            denotation_only: bool = False,
            world_extraction_model: Optional[str] = None,
            skip_attributes_regex: Optional[str] = None,
            entity_bits_mode: Optional[str] = None,
            entity_types: Optional[List[str]] = None,
            lexical_cues: List[str] = None,
            tokenizer: Tokenizer = None,
            question_token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super().__init__(lazy=lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        self._question_token_indexers = question_token_indexers or {
            "tokens": SingleIdTokenIndexer()
        }
        self._entity_token_indexers = self._question_token_indexers
        self._sample = sample
        self._replace_world_entities = replace_world_entities
        self._lf_syntax = lf_syntax
        self._entity_bits_mode = entity_bits_mode
        self._align_world_extractions = align_world_extractions
        self._gold_world_extractions = gold_world_extractions
        self._entity_types = entity_types
        self._tagger_only = tagger_only
        self._denotation_only = denotation_only
        self._skip_attributes_regex = None
        if skip_attributes_regex is not None:
            self._skip_attributes_regex = re.compile(skip_attributes_regex)
        self._lexical_cues = lexical_cues

        # Recording of entities in categories relevant for tagging
        all_entities = {}
        all_entities["world"] = ["world1", "world2"]
        # TODO: Clarify this into an appropriate parameter
        self._collapse_tags = ["world"]

        self._all_entities = None
        if entity_types is not None:
            if self._entity_bits_mode == "collapsed":
                self._all_entities = entity_types
            else:
                self._all_entities = [
                    e for t in entity_types for e in all_entities[t]
                ]

        logger.info(f"all_entities = {self._all_entities}")

        # Base world, depending on LF syntax only
        self._knowledge_graph = KnowledgeGraph(
            entities={"placeholder"},
            neighbors={},
            entity_text={"placeholder": "placeholder"})
        self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)

        # Decide dynamic entities, if any
        self._dynamic_entities: Dict[str, str] = dict()
        self._use_attr_entities = False
        if "_attr_entities" in lf_syntax:
            self._use_attr_entities = True
            qr_coeff_sets = self._world.qr_coeff_sets
            for qset in qr_coeff_sets:
                for attribute in qset:
                    if (self._skip_attributes_regex is not None
                            and self._skip_attributes_regex.search(attribute)):
                        continue
                    # Get text associated with each entity, both from entity identifier and
                    # associated lexical cues, if any
                    entity_strings = [
                        words_from_entity_string(attribute).lower()
                    ]
                    if self._lexical_cues is not None:
                        for key in self._lexical_cues:
                            if attribute in LEXICAL_CUES[key]:
                                entity_strings += LEXICAL_CUES[key][attribute]
                    self._dynamic_entities["a:" + attribute] = " ".join(
                        entity_strings)

        # Update world to include dynamic entities
        if self._use_attr_entities:
            logger.info(f"dynamic_entities = {self._dynamic_entities}")
            neighbors: Dict[str, List[str]] = {
                key: []
                for key in self._dynamic_entities
            }
            self._knowledge_graph = KnowledgeGraph(
                entities=set(self._dynamic_entities.keys()),
                neighbors=neighbors,
                entity_text=self._dynamic_entities)
            self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)

        self._stemmer = PorterStemmer().stemmer

        self._world_tagger_extractor = None
        self._extract_worlds = False
        if world_extraction_model is not None:
            logger.info("Loading world tagger model...")
            self._extract_worlds = True
            self._world_tagger_extractor = WorldTaggerExtractor(
                world_extraction_model)
            logger.info("Done loading world tagger model!")

        # Convenience regex for recognizing attributes
        self._attr_regex = re.compile(r"""\((\w+) (high|low|higher|lower)""")
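The constructor wires up the tokenizer and token indexers, then builds a base QuarelWorld; when the logical-form syntax contains "_attr_entities", it also precomputes dynamic attribute entities (optionally filtered by skip_attributes_regex and enriched by lexical_cues). A minimal construction sketch follows; the class name QuarelDatasetReader is an assumption, since the enclosing class statement is not shown in this excerpt.

# Minimal sketch, assuming the __init__ above belongs to a class named
# QuarelDatasetReader (the class statement is not part of this excerpt).
reader = QuarelDatasetReader(
    lf_syntax="quarel_v1_attr_entities",  # same syntax string as in Example #1
    replace_world_entities=True,
)
# With an "_attr_entities" syntax, reader._world now contains one "a:<attribute>"
# entity per (non-skipped) attribute in the QR coefficient sets.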
Example #4
    def text_to_instance(
            self,  # type: ignore
            question: str,
            logical_forms: List[str] = None,
            additional_metadata: Dict[str, Any] = None,
            world_extractions: Dict[str, Union[str, List[str]]] = None,
            entity_literals: Dict[str, Union[str, List[str]]] = None,
            tokenized_question: List[Token] = None,
            debug_counter: int = None,
            qr_spec_override: List[Dict[str, int]] = None,
            dynamic_entities_override: Dict[str, str] = None) -> Instance:

        # pylint: disable=arguments-differ
        tokenized_question = tokenized_question or self._tokenizer.tokenize(
            question.lower())
        additional_metadata = additional_metadata or dict()
        additional_metadata['question_tokens'] = [
            token.text for token in tokenized_question
        ]
        if world_extractions is not None:
            additional_metadata['world_extractions'] = world_extractions
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)

        if qr_spec_override is not None or dynamic_entities_override is not None:
            # Dynamically specify theory and/or entities
            dynamic_entities = dynamic_entities_override or self._dynamic_entities
            neighbors: Dict[str, List[str]] = {key: [] for key in dynamic_entities}
            knowledge_graph = KnowledgeGraph(
                entities=set(dynamic_entities.keys()),
                neighbors=neighbors,
                entity_text=dynamic_entities)
            world = QuarelWorld(knowledge_graph,
                                self._lf_syntax,
                                qr_coeff_sets=qr_spec_override)
        else:
            knowledge_graph = self._knowledge_graph
            world = self._world

        table_field = KnowledgeGraphField(knowledge_graph,
                                          tokenized_question,
                                          self._entity_token_indexers,
                                          tokenizer=self._tokenizer)

        if self._tagger_only:
            fields: Dict[str, Field] = {'tokens': question_field}
            if entity_literals is not None:
                entity_tags = self._get_entity_tags(self._all_entities,
                                                    table_field,
                                                    entity_literals,
                                                    tokenized_question)
                if debug_counter is not None and debug_counter > 0:
                    logger.info(f'raw entity tags = {entity_tags}')
                entity_tags_bio = self._convert_tags_bio(entity_tags)
                fields['tags'] = SequenceLabelField(entity_tags_bio,
                                                    question_field)
                additional_metadata['tags_gold'] = entity_tags_bio
            additional_metadata['words'] = [x.text for x in tokenized_question]
            fields['metadata'] = MetadataField(additional_metadata)
            return Instance(fields)

        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            'question': question_field,
            'table': table_field,
            'world': world_field,
            'actions': action_field
        }

        if self._denotation_only:
            denotation_field = LabelField(additional_metadata['answer_index'],
                                          skip_indexing=True)
            fields['denotation_target'] = denotation_field

        if self._entity_bits_mode is not None and world_extractions is not None:
            entity_bits = self._get_entity_tags(['world1', 'world2'],
                                                table_field, world_extractions,
                                                tokenized_question)
            if self._entity_bits_mode == "simple":
                entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag]
                                 for tag in entity_bits]
            elif self._entity_bits_mode == "simple_collapsed":
                entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple3":
                entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag]
                                 for tag in entity_bits]

            entity_bits_field = ArrayField(np.array(entity_bits_v))
            fields['entity_bits'] = entity_bits_field

        if logical_forms:
            action_map = {
                action.rule: i
                for i, action in enumerate(action_field.field_list)
            }  # type: ignore
            action_sequence_fields: List[Field] = []
            for logical_form in logical_forms:
                expression = world.parse_logical_form(logical_form)
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.info(
                        f'Missing production rule: {error.args}, skipping logical form'
                    )
                    logger.info(f'Question was: {question}')
                    logger.info(f'Logical form was: {logical_form}')
                    continue
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        fields['metadata'] = MetadataField(additional_metadata or {})
        return Instance(fields)
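A usage sketch for text_to_instance, reusing the logical form and world extractions from Example #1 and the hypothetical reader from the sketch after Example #3; the question string is made up for illustration, and answer_index in the metadata is only consumed by the denotation_only branch:

# Hypothetical call; `reader` comes from the construction sketch after Example #3.
instance = reader.text_to_instance(
    question="Bill eats more sugar than Sue. Who is more likely to get diabetes?",
    logical_forms=[
        "(infer (a:sugar higher world1) (a:diabetes higher world2) "
        "(a:diabetes higher world1))"
    ],
    world_extractions={"world1": "bill", "world2": "sue"},
    additional_metadata={"answer_index": 0},  # only read when denotation_only is set
)
print(sorted(instance.fields))  # e.g. ['actions', 'metadata', 'question', 'table', ...]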
Example #5
    def _create_grammar_state(self,
                              world: QuarelWorld,
                              possible_actions: List[ProductionRule],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``QuarelWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index
        entity_map = {}
        for entity_index, entity in enumerate(world.table_graph.entities):
            entity_map[entity] = entity_index

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [action_map[action_string] for action_string in action_strings]
            production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append((production_rule_array[2], action_index))
                else:
                    linked_actions.append((production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions.
            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor = torch.cat(global_action_tensors, dim=0)
            global_input_embeddings = self._action_embedder(global_action_tensor)
            if self._add_action_bias:
                global_action_biases = self._action_biases(global_action_tensor)
                global_input_embeddings = torch.cat([global_input_embeddings, global_action_biases], dim=-1)
            global_output_embeddings = self._output_action_embedder(global_action_tensor)
            translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                       global_output_embeddings,
                                                       list(global_action_ids))

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor)
                translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                           entity_type_embeddings,
                                                           list(linked_action_ids))

        return GrammarStatelet([START_SYMBOL],
                               translated_valid_actions,
                               type_declaration.is_nonterminal)
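The global/linked split above depends only on each ProductionRule's flag and rule string (indexed as [1], [2], and [0] respectively). Below is a self-contained toy sketch of that split, using a stand-in namedtuple and made-up rule strings rather than the real ProductionRule objects:

from collections import namedtuple

import torch

# Stand-in with the field order implied by the indexing above:
# [0] rule string, [1] is_global_rule flag, [2] rule id tensor (later embedded).
FakeRule = namedtuple("FakeRule", ["rule", "is_global_rule", "rule_id"])

possible_actions = [
    FakeRule("@start@ -> d", True, torch.tensor([0])),    # global action
    FakeRule("d -> [<w,d>, w]", True, torch.tensor([1])),  # global action
    FakeRule("a -> a:sugar", False, None),                 # linked (table entity) action
]
global_actions, linked_actions = [], []
for index, rule in enumerate(possible_actions):
    if rule.is_global_rule:
        global_actions.append((rule.rule_id, index))
    else:
        linked_actions.append((rule.rule, index))
# Global rule-id tensors get concatenated and embedded; linked rules are matched
# to entities via the ' -> ' right-hand side, as in the method above.
print([idx for _, idx in global_actions], [rule for rule, _ in linked_actions])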
Example #6
def get_explanation(logical_form: str, world_extractions: JsonDict,
                    answer_index: int, world: QuarelWorld
                    ) -> Optional[List[JsonDict]]:
    """
    Create an explanation (as a list of header/content entries) for an answer,
    or return None if the logical form is not an "infer" expression.
    """
    output = []
    nl_world = {}
    if (world_extractions["world1"] != "N/A"
            and world_extractions["world1"] != ["N/A"]):
        nl_world["world1"] = nl_world_string(world_extractions["world1"])
        nl_world["world2"] = nl_world_string(world_extractions["world2"])
        output.append({
            "header": "Identified two worlds",
            "content": [
                f"world1 = {nl_world['world1']}",
                f"world2 = {nl_world['world2']}",
            ],
        })
    else:
        nl_world["world1"] = "world1"
        nl_world["world2"] = "world2"
    parse = util.lisp_to_nested_expression(logical_form)
    if parse[0] != "infer":
        return None
    setup = parse[1]
    output.append({
        "header": "The question is stating",
        "content": nl_arg(setup, nl_world)
    })
    answers = parse[2:]
    output.append({
        "header": "The answer options are stating",
        "content": [
            "A: " + " and ".join(nl_arg(answers[0], nl_world)),
            "B: " + " and ".join(nl_arg(answers[1], nl_world)),
        ],
    })
    setup_core = setup
    if setup[0] == "and":
        setup_core = setup[1]
    s_attr = setup_core[0]
    s_dir = world.qr_size[setup_core[1]]
    s_world = nl_world[setup_core[2]]
    a_attr = answers[answer_index][0]
    qr_dir = world._get_qr_coeff(strip_entity_type(s_attr),
                                 strip_entity_type(a_attr))
    a_dir = s_dir * qr_dir
    a_world = nl_world[answers[answer_index][2]]

    content = [
        f"When {nl_attr(s_attr)} is {nl_dir(s_dir)} " +
        f"then {nl_attr(a_attr)} is {nl_dir(a_dir)} (for {s_world})"
    ]
    if a_world != s_world:
        content.append(
            f"Therefore {nl_attr(a_attr)} is {nl_dir(-a_dir)} for {a_world}")
    content.append(f"Therefore {chr(65+answer_index)} is the correct answer")

    output.append({"header": "Theory used", "content": content})

    return output
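Each entry in the returned list is a dict with a "header" string and a "content" list, so rendering the explanation as plain text is a small loop. A minimal sketch follows; render_explanation is a made-up helper, not part of the library:

from typing import Any, Dict, List


def render_explanation(explanation: List[Dict[str, Any]]) -> str:
    """Flatten the header/content entries returned by get_explanation into text."""
    lines = []
    for entry in explanation:
        lines.append(entry["header"])
        for item in entry["content"]:
            lines.append(f"  {item}")
    return "\n".join(lines)

Passing the explanation from Example #1 through this helper would yield four short sections, one per header.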