def test_get_explanation_provides_non_empty_explanation_for_typical_inputs(self):
    # A standard two-world "infer" query should yield all four explanation
    # sections: identified worlds, question statement, answer options, theory.
    lf = ('(infer (a:sugar higher world1) (a:diabetes higher world2) '
          '(a:diabetes higher world1))')
    entity_text = {'a:sugar': 'sugar', 'a:diabetes': 'diabetes'}
    graph = KnowledgeGraph(entity_text.keys(),
                           {name: [] for name in entity_text},
                           entity_text)
    quarel_world = QuarelWorld(graph, "quarel_v1_attr_entities")
    extractions = {'world1': 'bill', 'world2': 'sue'}
    explanation = get_explanation(lf, extractions, 0, quarel_world)
    assert len(explanation) == 4
def test_get_explanation_provides_non_empty_explanation_for_typical_inputs(self):
    """The explanation for a typical infer logical form has exactly four entries."""
    logical_form = (
        "(infer (a:sugar higher world1)"
        " (a:diabetes higher world2)"
        " (a:diabetes higher world1))"
    )
    names = {"a:sugar": "sugar", "a:diabetes": "diabetes"}
    kg = KnowledgeGraph(names.keys(), {name: [] for name in names}, names)
    world = QuarelWorld(kg, "quarel_v1_attr_entities")
    extracted = {"world1": "bill", "world2": "sue"}
    assert len(get_explanation(logical_form, extracted, 0, world)) == 4
def __init__(self,
             lazy: bool = False,
             sample: int = -1,
             lf_syntax: str = None,
             replace_world_entities: bool = False,
             align_world_extractions: bool = False,
             gold_world_extractions: bool = False,
             tagger_only: bool = False,
             denotation_only: bool = False,
             world_extraction_model: Optional[str] = None,
             skip_attributes_regex: Optional[str] = None,
             entity_bits_mode: Optional[str] = None,
             entity_types: Optional[List[str]] = None,
             lexical_cues: List[str] = None,
             tokenizer: Tokenizer = None,
             question_token_indexers: Dict[str, TokenIndexer] = None) -> None:
    """
    Store reader configuration, build the base ``QuarelWorld`` for the given
    logical-form syntax, and — when ``lf_syntax`` contains ``"_attr_entities"`` —
    precompute "dynamic" attribute entities (attribute name plus any lexical
    cues) and rebuild the knowledge graph / world over them.

    Most parameters are simply stashed on ``self`` for use by other methods
    (``_read`` / ``text_to_instance``, presumably — not visible here):
    ``skip_attributes_regex`` is compiled and used to drop matching attributes
    from the dynamic entities; ``world_extraction_model``, when given, loads a
    ``WorldTaggerExtractor`` and enables world extraction.
    """
    super().__init__(lazy=lazy)
    # Default tokenizer/indexers when none supplied; entities reuse the
    # question indexers.
    self._tokenizer = tokenizer or WordTokenizer()
    self._question_token_indexers = question_token_indexers or {"tokens": SingleIdTokenIndexer()}
    self._entity_token_indexers = self._question_token_indexers
    self._sample = sample
    self._replace_world_entities = replace_world_entities
    self._lf_syntax = lf_syntax
    self._entity_bits_mode = entity_bits_mode
    self._align_world_extractions = align_world_extractions
    self._gold_world_extractions = gold_world_extractions
    self._entity_types = entity_types
    self._tagger_only = tagger_only
    self._denotation_only = denotation_only
    # Compile the skip pattern once up front; None means "skip nothing".
    self._skip_attributes_regex = None
    if skip_attributes_regex is not None:
        self._skip_attributes_regex = re.compile(skip_attributes_regex)
    self._lexical_cues = lexical_cues
    # Recording of entities in categories relevant for tagging
    all_entities = {}
    all_entities["world"] = ["world1", "world2"]
    # TODO: Clarify this into an appropriate parameter
    self._collapse_tags = ["world"]
    self._all_entities = None
    if entity_types is not None:
        if self._entity_bits_mode == "collapsed":
            # In collapsed mode the type names themselves act as the entities.
            self._all_entities = entity_types
        else:
            # Flatten each requested type into its member entities.
            self._all_entities = [e for t in entity_types for e in all_entities[t]]
    logger.info(f"all_entities = {self._all_entities}")
    # Base world, depending on LF syntax only
    self._knowledge_graph = KnowledgeGraph(entities={"placeholder"},
                                           neighbors={},
                                           entity_text={"placeholder": "placeholder"})
    self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)
    # Decide dynamic entities, if any
    self._dynamic_entities: Dict[str, str] = dict()
    self._use_attr_entities = False
    if "_attr_entities" in lf_syntax:
        self._use_attr_entities = True
        qr_coeff_sets = self._world.qr_coeff_sets
        for qset in qr_coeff_sets:
            for attribute in qset:
                if (self._skip_attributes_regex is not None
                        and self._skip_attributes_regex.search(attribute)):
                    continue
                # Get text associated with each entity, both from entity identifier and
                # associated lexical cues, if any
                entity_strings = [words_from_entity_string(attribute).lower()]
                if self._lexical_cues is not None:
                    for key in self._lexical_cues:
                        if attribute in LEXICAL_CUES[key]:
                            entity_strings += LEXICAL_CUES[key][attribute]
                self._dynamic_entities["a:" + attribute] = " ".join(entity_strings)
    # Update world to include dynamic entities
    if self._use_attr_entities:
        logger.info(f"dynamic_entities = {self._dynamic_entities}")
        neighbors: Dict[str, List[str]] = {key: [] for key in self._dynamic_entities}
        self._knowledge_graph = KnowledgeGraph(entities=set(self._dynamic_entities.keys()),
                                               neighbors=neighbors,
                                               entity_text=self._dynamic_entities)
        self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)
    self._stemmer = PorterStemmer().stemmer
    self._world_tagger_extractor = None
    self._extract_worlds = False
    if world_extraction_model is not None:
        logger.info("Loading world tagger model...")
        self._extract_worlds = True
        self._world_tagger_extractor = WorldTaggerExtractor(world_extraction_model)
        logger.info("Done loading world tagger model!")
    # Convenience regex for recognizing attributes
    self._attr_regex = re.compile(r"""\((\w+) (high|low|higher|lower)""")
def text_to_instance(self,  # type: ignore
                     question: str,
                     logical_forms: List[str] = None,
                     additional_metadata: Dict[str, Any] = None,
                     world_extractions: Dict[str, Union[str, List[str]]] = None,
                     entity_literals: Dict[str, Union[str, List[str]]] = None,
                     tokenized_question: List[Token] = None,
                     debug_counter: int = None,
                     qr_spec_override: List[Dict[str, int]] = None,
                     dynamic_entities_override: Dict[str, str] = None) -> Instance:
    # pylint: disable=arguments-differ
    """
    Turn a question (plus optional gold logical forms and world extractions)
    into an ``Instance``.

    Behavior depends on reader configuration:
    - ``self._tagger_only``: returns a tagging instance (``tokens``, optional
      BIO ``tags``, ``metadata``) and skips all parsing fields.
    - otherwise builds ``question``/``table``/``world``/``actions`` fields,
      optionally a ``denotation_target`` (``self._denotation_only``),
      optional ``entity_bits`` (``self._entity_bits_mode``), and
      ``target_action_sequences`` when ``logical_forms`` are given.
    ``qr_spec_override`` / ``dynamic_entities_override`` build a fresh
    ``QuarelWorld`` instead of using the one constructed in ``__init__``.
    """
    tokenized_question = tokenized_question or self._tokenizer.tokenize(question.lower())
    additional_metadata = additional_metadata or dict()
    additional_metadata['question_tokens'] = [token.text for token in tokenized_question]
    if world_extractions is not None:
        additional_metadata['world_extractions'] = world_extractions
    question_field = TextField(tokenized_question, self._question_token_indexers)

    if qr_spec_override is not None or dynamic_entities_override is not None:
        # Dynamically specify theory and/or entities
        dynamic_entities = dynamic_entities_override or self._dynamic_entities
        neighbors: Dict[str, List[str]] = {key: [] for key in dynamic_entities.keys()}
        knowledge_graph = KnowledgeGraph(entities=set(dynamic_entities.keys()),
                                         neighbors=neighbors,
                                         entity_text=dynamic_entities)
        world = QuarelWorld(knowledge_graph, self._lf_syntax, qr_coeff_sets=qr_spec_override)
    else:
        knowledge_graph = self._knowledge_graph
        world = self._world

    table_field = KnowledgeGraphField(knowledge_graph,
                                      tokenized_question,
                                      self._entity_token_indexers,
                                      tokenizer=self._tokenizer)

    if self._tagger_only:
        fields: Dict[str, Field] = {'tokens': question_field}
        if entity_literals is not None:
            entity_tags = self._get_entity_tags(self._all_entities, table_field,
                                                entity_literals, tokenized_question)
            # BUGFIX: ``debug_counter`` defaults to None, and ``None > 0``
            # raises TypeError under Python 3 — guard before comparing.
            if debug_counter is not None and debug_counter > 0:
                logger.info(f'raw entity tags = {entity_tags}')
            entity_tags_bio = self._convert_tags_bio(entity_tags)
            fields['tags'] = SequenceLabelField(entity_tags_bio, question_field)
            additional_metadata['tags_gold'] = entity_tags_bio
        additional_metadata['words'] = [x.text for x in tokenized_question]
        fields['metadata'] = MetadataField(additional_metadata)
        return Instance(fields)

    world_field = MetadataField(world)

    # One ProductionRuleField per grammar action; entity-producing rules are
    # marked non-global so the model links them to the question instead.
    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        _, rule_right_side = production_rule.split(' -> ')
        is_global_rule = not world.is_table_entity(rule_right_side)
        field = ProductionRuleField(production_rule, is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)

    fields = {'question': question_field,
              'table': table_field,
              'world': world_field,
              'actions': action_field}

    if self._denotation_only:
        denotation_field = LabelField(additional_metadata['answer_index'], skip_indexing=True)
        fields['denotation_target'] = denotation_field

    if self._entity_bits_mode is not None and world_extractions is not None:
        # Tag value (0 = neither world, 1 = world1, 2 = world2) indexes a
        # per-mode one-hot-style encoding for each question token.
        entity_bits = self._get_entity_tags(['world1', 'world2'], table_field,
                                            world_extractions, tokenized_question)
        if self._entity_bits_mode == "simple":
            entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag] for tag in entity_bits]
        elif self._entity_bits_mode == "simple_collapsed":
            entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
        elif self._entity_bits_mode == "simple3":
            entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag] for tag in entity_bits]
        entity_bits_field = ArrayField(np.array(entity_bits_v))
        fields['entity_bits'] = entity_bits_field

    if logical_forms:
        action_map = {action.rule: i
                      for i, action in enumerate(action_field.field_list)}  # type: ignore
        action_sequence_fields: List[Field] = []
        for logical_form in logical_forms:
            expression = world.parse_logical_form(logical_form)
            action_sequence = world.get_action_sequence(expression)
            try:
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            except KeyError as error:
                # A production outside the grammar: skip this logical form but
                # keep the instance.
                logger.info(f'Missing production rule: {error.args}, skipping logical form')
                logger.info(f'Question was: {question}')
                logger.info(f'Logical form was: {logical_form}')
                continue
        fields['target_action_sequences'] = ListField(action_sequence_fields)

    fields['metadata'] = MetadataField(additional_metadata or {})
    return Instance(fields)
def _create_grammar_state(self,
                          world: QuarelWorld,
                          possible_actions: List[ProductionRule],
                          linking_scores: torch.Tensor,
                          entity_types: torch.Tensor) -> GrammarStatelet:
    """
    This method creates the GrammarStatelet object that's used for decoding.  Part of
    creating that is creating the `valid_actions` dictionary, which contains embedded
    representations of all of the valid actions.  So, we create that here as well.

    The inputs to this method are for a `single instance in the batch`; none of the tensors we
    create here are batched.  We grab the global action ids from the input
    ``ProductionRules``, and we use those to embed the valid actions for every
    non-terminal type.  We use the input ``linking_scores`` for non-global actions.

    Parameters
    ----------
    world : ``QuarelWorld``
        From the input to ``forward`` for a single batch instance.
    possible_actions : ``List[ProductionRule]``
        From the input to ``forward`` for a single batch instance.
    linking_scores : ``torch.Tensor``
        Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
        dimension).
    entity_types : ``torch.Tensor``
        Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
    """
    # Map each production-rule string to its index in ``possible_actions``.
    action_map = {}
    for action_index, action in enumerate(possible_actions):
        action_string = action[0]
        action_map[action_string] = action_index
    # Map each table-graph entity to its row index into ``linking_scores``
    # and ``entity_types``.
    entity_map = {}
    for entity_index, entity in enumerate(world.table_graph.entities):
        entity_map[entity] = entity_index
    valid_actions = world.get_valid_actions()
    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
    for key, action_strings in valid_actions.items():
        translated_valid_actions[key] = {}
        # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
        # productions of that non-terminal.  We'll first split those productions by global vs.
        # linked action.
        action_indices = [action_map[action_string] for action_string in action_strings]
        production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
        global_actions = []
        linked_actions = []
        for production_rule_array, action_index in production_rule_arrays:
            # Index 1 of a ProductionRule is its is-global flag; globals carry
            # a pre-built tensor at index 2, linked actions keep the rule
            # string at index 0.
            if production_rule_array[1]:
                global_actions.append((production_rule_array[2], action_index))
            else:
                linked_actions.append((production_rule_array[0], action_index))
        # Then we get the embedded representations of the global actions.
        # NOTE(review): ``zip(*global_actions)`` assumes every non-terminal has
        # at least one global production — an empty list would raise here;
        # confirm against the grammar.
        global_action_tensors, global_action_ids = zip(*global_actions)
        global_action_tensor = torch.cat(global_action_tensors, dim=0)
        global_input_embeddings = self._action_embedder(global_action_tensor)
        if self._add_action_bias:
            # Append a learned per-action bias term to the input embedding.
            global_action_biases = self._action_biases(global_action_tensor)
            global_input_embeddings = torch.cat([global_input_embeddings, global_action_biases],
                                                dim=-1)
        global_output_embeddings = self._output_action_embedder(global_action_tensor)
        translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                   global_output_embeddings,
                                                   list(global_action_ids))
        # Then the representations of the linked actions.
        if linked_actions:
            linked_rules, linked_action_ids = zip(*linked_actions)
            # The right-hand side of a linked rule is the entity it produces.
            entities = [rule.split(' -> ')[1] for rule in linked_rules]
            entity_ids = [entity_map[entity] for entity in entities]
            # (num_linked_actions, num_question_tokens)
            entity_linking_scores = linking_scores[entity_ids]
            # (num_linked_actions,)
            entity_type_tensor = entity_types[entity_ids]
            # (num_linked_actions, entity_type_embedding_dim)
            entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor)
            translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                       entity_type_embeddings,
                                                       list(linked_action_ids))
    return GrammarStatelet([START_SYMBOL], translated_valid_actions,
                           type_declaration.is_nonterminal)
def get_explanation(logical_form: str, world_extractions: JsonDict, answer_index: int,
                    world: QuarelWorld) -> List[JsonDict]:
    """
    Create an explanation for an answer as a list of header/content entries.
    Returns ``None`` if the logical form is not an ``infer`` expression.
    """
    explanation: List[JsonDict] = []
    # Fall back to the literal names when no real worlds were extracted.
    nl_world = {"world1": "world1", "world2": "world2"}
    extracted1 = world_extractions["world1"]
    if extracted1 != "N/A" and extracted1 != ["N/A"]:
        nl_world["world1"] = nl_world_string(extracted1)
        nl_world["world2"] = nl_world_string(world_extractions["world2"])
        explanation.append({
            "header": "Identified two worlds",
            "content": [f"world1 = {nl_world['world1']}",
                        f"world2 = {nl_world['world2']}"],
        })
    parse = util.lisp_to_nested_expression(logical_form)
    if parse[0] != "infer":
        return None
    setup = parse[1]
    explanation.append({"header": "The question is stating",
                        "content": nl_arg(setup, nl_world)})
    answers = parse[2:]
    explanation.append({
        "header": "The answer options are stating",
        "content": ["A: " + " and ".join(nl_arg(answers[0], nl_world)),
                    "B: " + " and ".join(nl_arg(answers[1], nl_world))],
    })
    # For a conjunctive setup, the theory applies to the first conjunct.
    setup_core = setup[1] if setup[0] == "and" else setup
    s_attr = setup_core[0]
    s_dir = world.qr_size[setup_core[1]]
    s_world = nl_world[setup_core[2]]
    chosen = answers[answer_index]
    a_attr = chosen[0]
    qr_dir = world._get_qr_coeff(strip_entity_type(s_attr), strip_entity_type(a_attr))
    a_dir = s_dir * qr_dir
    a_world = nl_world[chosen[2]]
    content = [f"When {nl_attr(s_attr)} is {nl_dir(s_dir)} "
               f"then {nl_attr(a_attr)} is {nl_dir(a_dir)} (for {s_world})"]
    if a_world != s_world:
        # The relation flips when the answer talks about the other world.
        content.append(f"Therefore {nl_attr(a_attr)} is {nl_dir(-a_dir)} for {a_world}")
    content.append(f"Therefore {chr(65 + answer_index)} is the correct answer")
    explanation.append({"header": "Theory used", "content": content})
    return explanation