def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.

        Two entries are added to ``output_dict``:

        - ``"logical_form"``: for every instance, one logical form per action sequence found in
          ``output_dict["best_action_strings"]`` (an empty string for an empty sequence).
        - ``"predicted_actions"``: for every instance, one dict per action of the *first* action
          sequence, with the keys ``predicted_action``, ``considered_actions``,
          ``action_probabilities`` and ``question_attention``.

        ``output_dict`` must contain ``"best_action_strings"``, ``"action_mapping"`` and
        ``"debug_info"``. The same dict is mutated and returned.
        """
        best_action_strings = output_dict["best_action_strings"]
        # Instantiating an empty world for getting logical forms.
        world = NlvrWorld([])
        logical_forms = [
            [world.get_logical_form(action_strings) if action_strings else ''
             for action_strings in instance_action_sequences]
            for instance_action_sequences in best_action_strings
        ]

        action_mapping = output_dict['action_mapping']
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(best_action_strings, debug_infos)):
            # Only the first action sequence is paired with the debug info. An instance may
            # have no decoded sequence at all; it then simply gets an empty action-info list
            # instead of raising an ``IndexError``.
            first_sequence = predicted_actions[0] if predicted_actions else []
            instance_action_info = []
            for predicted_action, action_debug_info in zip(first_sequence, debug_info):
                action_info = {}
                action_info['predicted_action'] = predicted_action
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                # Entries with action index -1 are skipped (presumably padding -- TODO confirm).
                actions = sorted(
                    (action_mapping[(batch_index, action)], probability)
                    for action, probability in zip(considered_actions, probabilities)
                    if action != -1)
                if actions:
                    considered_actions, probabilities = zip(*actions)
                else:
                    # ``zip(*[])`` yields nothing, so unpacking it would raise a ``ValueError``.
                    considered_actions, probabilities = (), ()
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['question_attention'] = action_debug_info.get(
                    'question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        output_dict["logical_form"] = logical_forms
        return output_dict
    def _create_grammar_state(
            self, world: NlvrWorld,
            possible_actions: List[ProductionRuleArray]) -> GrammarStatelet:
        """
        Builds the initial ``GrammarStatelet`` for ``world``.

        For every non-terminal returned by ``world.get_valid_actions()``, the production strings
        are translated into indices into ``possible_actions`` (matched on the first element of
        each entry), the third element of each matching entry is concatenated along dimension 0
        and embedded with ``self._action_embedder``. The result is stored under the ``'global'``
        key as ``(embeddings, embeddings, action_ids)``; the same embedding tensor fills both
        tensor slots.
        """
        valid_actions = world.get_valid_actions()
        rule_to_index = {
            rule[0]: position
            for position, rule in enumerate(possible_actions)
        }
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        # Each key is a non-terminal of the grammar; its value lists all valid productions of
        # that non-terminal.
        for nonterminal, productions in valid_actions.items():
            rule_indices = [rule_to_index[production] for production in productions]
            # All actions in NLVR are global actions, so there is only a 'global' entry.
            rule_tensors, rule_ids = zip(*[(possible_actions[rule_index][2], rule_index)
                                           for rule_index in rule_indices])
            embedded_rules = self._action_embedder(torch.cat(rule_tensors, dim=0))
            translated_valid_actions[nonterminal] = {
                'global': (embedded_rules, embedded_rules, list(rule_ids))
            }
        return GrammarStatelet([START_SYMBOL], {}, translated_valid_actions,
                               {}, type_declaration.is_nonterminal)
예제 #3
0
    def _create_grammar_state(self,
                              world: NlvrWorld,
                              possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        Creates the starting ``GrammarStatelet`` for the given world.

        The valid productions of each non-terminal (from ``world.get_valid_actions()``) are
        mapped to their positions in ``possible_actions``, the tensors stored at index 2 of those
        entries are concatenated and passed through ``self._action_embedder``, and the outcome is
        recorded under ``'global'`` as ``(embeddings, embeddings, action_ids)``.
        """
        valid_actions = world.get_valid_actions()
        # Production string (first element of each entry) -> position in ``possible_actions``.
        index_of_rule = {rule[0]: rule_position
                         for rule_position, rule in enumerate(possible_actions)}
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for nonterminal in valid_actions:
            # ``nonterminal`` comes from the grammar; its value holds all of its valid productions.
            production_indices = [index_of_rule[production]
                                  for production in valid_actions[nonterminal]]
            # All actions in NLVR are global actions.
            tensor_id_pairs = [(possible_actions[production_index][2], production_index)
                               for production_index in production_indices]
            tensors, ids = zip(*tensor_id_pairs)
            # Embedded representations of the global actions.
            embeddings = self._action_embedder(torch.cat(tensors, dim=0))
            translated_valid_actions[nonterminal] = {'global': (embeddings, embeddings, list(ids))}
        return GrammarStatelet([START_SYMBOL],
                               translated_valid_actions,
                               type_declaration.is_nonterminal)
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 beam_size: int,
                 max_decoding_steps: int,
                 max_num_finished_states: int = None,
                 dropout: float = 0.0,
                 normalize_beam_score_by_length: bool = False,
                 checklist_cost_weight: float = 0.6,
                 dynamic_cost_weight: Dict[str, Union[int, float]] = None,
                 penalize_non_agenda_actions: bool = False,
                 initial_mml_model_file: str = None) -> None:
        """
        Sets up the coverage-based NLVR parser.

        ``vocab``, ``sentence_embedder``, ``action_embedding_dim``, ``encoder`` and ``dropout``
        are forwarded to the parent constructor. ``beam_size``, ``max_decoding_steps``,
        ``max_num_finished_states`` and ``normalize_beam_score_by_length`` configure the
        ``ExpectedRiskMinimization`` trainer, and ``attention`` is handed to the
        ``CoverageTransitionFunction``. ``dynamic_cost_weight``, when given, must provide the keys
        ``"wait_num_epochs"`` and ``"rate"``. If ``initial_mml_model_file`` names an existing
        file, weights are initialized from that archive; if it is given but missing, only a
        warning is logged.
        """
        super(NlvrCoverageSemanticParser, self).__init__(
                vocab=vocab,
                sentence_embedder=sentence_embedder,
                action_embedding_dim=action_embedding_dim,
                encoder=encoder,
                dropout=dropout)
        self._agenda_coverage = Average()
        self._decoder_trainer: DecoderTrainer[Callable[[CoverageState], torch.Tensor]] = \
                ExpectedRiskMinimization(
                        beam_size=beam_size,
                        normalize_by_length=normalize_beam_score_by_length,
                        max_decoding_steps=max_decoding_steps,
                        max_num_finished_states=max_num_finished_states)

        # An empty NlvrWorld is sufficient to collect the terminal productions.
        empty_world = NlvrWorld([])
        self._terminal_productions = set(empty_world.terminal_productions.values())
        self._decoder_step = CoverageTransitionFunction(
                encoder_output_dim=self._encoder.get_output_dim(),
                action_embedding_dim=action_embedding_dim,
                input_attention=attention,
                num_start_types=1,
                activation=Activation.by_name('tanh')(),
                predict_start_type_separately=False,
                add_action_bias=False,
                dropout=dropout)
        self._checklist_cost_weight = checklist_cost_weight
        if dynamic_cost_weight:
            self._dynamic_cost_wait_epochs = dynamic_cost_weight["wait_num_epochs"]
            self._dynamic_cost_rate = dynamic_cost_weight["rate"]
        else:
            self._dynamic_cost_wait_epochs = None
            self._dynamic_cost_rate = None
        self._penalize_non_agenda_actions = penalize_non_agenda_actions
        self._last_epoch_in_forward: int = None

        if initial_mml_model_file is None:
            return
        # TODO (pradeep): The existence check avoids an error when a trained ERM model was copied
        # from another machine and the MML model originally used to initialize it is not available
        # here. This may not be the best solution for the problem.
        if not os.path.isfile(initial_mml_model_file):
            # A path was given but nothing is there. That is expected when decoding with a
            # trained ERM model, but the path may also simply be wrong, hence the warning.
            logger.warning("MML model file for initializing weights is passed, but does not exist."
                           " This is fine if you're just decoding.")
            return
        mml_archive = load_archive(initial_mml_model_file)
        self._initialize_weights_from_archive(mml_archive)
예제 #5
0
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.

        It reads ``"best_action_strings"``, ``"action_mapping"`` and ``"debug_info"`` from
        ``output_dict`` and adds:

        - ``"logical_form"``: per instance, the logical form of every action sequence (``''`` for
          an empty sequence).
        - ``"predicted_actions"``: per instance, a list with one dict for each action of the first
          action sequence. Each dict holds ``predicted_action``, the sorted
          ``considered_actions`` with their ``action_probabilities``, and ``question_attention``
          (an empty list when the debug info does not provide it).

        Returns the same, mutated ``output_dict``.
        """
        best_action_strings = output_dict["best_action_strings"]
        # Instantiating an empty world for getting logical forms.
        world = NlvrWorld([])
        logical_forms = []
        for instance_action_sequences in best_action_strings:
            logical_forms.append([world.get_logical_form(action_strings) if action_strings else ''
                                  for action_strings in instance_action_sequences])

        action_mapping = output_dict['action_mapping']
        debug_infos = output_dict['debug_info']
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(zip(best_action_strings,
                                                                          debug_infos)):
            instance_action_info = []
            # Guard against instances without any decoded sequence: indexing ``[0]`` on an empty
            # list would raise an ``IndexError``.
            top_sequence = predicted_actions[0] if predicted_actions else []
            for predicted_action, action_debug_info in zip(top_sequence, debug_info):
                action_info = {'predicted_action': predicted_action}
                considered_actions = action_debug_info['considered_actions']
                probabilities = action_debug_info['probabilities']
                # Action index -1 is filtered out (presumably padding -- TODO confirm).
                actions = [(action_mapping[(batch_index, action)], probability)
                           for action, probability in zip(considered_actions, probabilities)
                           if action != -1]
                actions.sort()
                # Unpacking ``zip(*actions)`` fails with a ``ValueError`` when nothing is left
                # after filtering, so fall back to empty tuples in that case.
                considered_actions, probabilities = zip(*actions) if actions else ((), ())
                action_info['considered_actions'] = considered_actions
                action_info['action_probabilities'] = probabilities
                action_info['question_attention'] = action_debug_info.get('question_attention', [])
                instance_action_info.append(action_info)
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        output_dict["logical_form"] = logical_forms
        return output_dict
def make_data(input_file: str, output_file: str, archived_model_file: str,
              max_num_decoded_sequences: int) -> None:
    """
    Decodes every sentence of ``input_file`` with an archived ``NlvrCoverageSemanticParser`` and
    writes the action sequences that are consistent with the labels to ``output_file``.

    ``input_file`` is read as JSON lines with the keys ``"identifier"``, ``"sentence"``,
    ``"worlds"`` and ``"labels"``. For each line, a decoded sequence is kept only if executing its
    logical form in every world gives the corresponding label (case-insensitive string
    comparison). At most ``max_num_decoded_sequences`` sequences are kept per sentence; sentences
    without any consistent sequence are skipped. One JSON object per kept sentence is written,
    and a summary count is written to stderr.

    Raises ``RuntimeError`` if the archive does not contain a ``NlvrCoverageSemanticParser``.
    """
    reader = NlvrDatasetReader(output_agendas=True)
    model = load_archive(archived_model_file).model
    if not isinstance(model, NlvrCoverageSemanticParser):
        model_type = type(model)
        raise RuntimeError(
            f"Expected an archived NlvrCoverageSemanticParser, but found {model_type} instead"
        )
    # Tweaking the decoder trainer to coerce it to generate a k-best list. Setting k to 100
    # here, so that we can filter out the inconsistent ones later.
    model._decoder_trainer._max_num_decoded_sequences = 100
    num_outputs = 0
    num_sentences = 0
    # The output file is opened first, as before, so it is created/truncated even when the input
    # file cannot be opened. Both handles are closed by the ``with`` statement.
    with open(output_file, "w") as outfile, open(input_file) as infile:
        for line in infile:
            num_sentences += 1
            input_data = json.loads(line)
            sentence = input_data["sentence"]
            structured_representations = input_data["worlds"]
            labels = input_data["labels"]
            instance = reader.text_to_instance(sentence,
                                               structured_representations)
            outputs = model.forward_on_instance(instance)
            action_strings = outputs["best_action_strings"]
            logical_forms = outputs["logical_form"]
            correct_sequences = []
            # Checking for consistency
            worlds = [
                NlvrWorld(structure)
                for structure in structured_representations
            ]
            for sequence, logical_form in zip(action_strings, logical_forms):
                denotations = [world.execute(logical_form) for world in worlds]
                denotations_are_correct = [
                    label.lower() == str(denotation).lower()
                    for label, denotation in zip(labels, denotations)
                ]
                if all(denotations_are_correct):
                    correct_sequences.append(sequence)
            correct_sequences = correct_sequences[:max_num_decoded_sequences]
            if not correct_sequences:
                continue
            output_data = {
                "id": input_data["identifier"],
                "sentence": sentence,
                "correct_sequences": correct_sequences,
                "worlds": structured_representations,
                "labels": labels,
            }
            json.dump(output_data, outfile)
            outfile.write("\n")
            num_outputs += 1
    sys.stderr.write(
        f"{num_outputs} out of {num_sentences} sentences have outputs.")
예제 #7
0
 def decode(self, output_dict):
     u"""
     Overrides ``Model.decode``, which runs after ``Model.forward`` at test time to finalize
     predictions. Every action string sequence in ``output_dict["best_action_strings"]`` is
     turned into a logical form (an empty sequence becomes an empty string), and the nested
     result is stored under ``"logical_form"`` before ``output_dict`` is returned.
     """
     # An empty world suffices: it is only used to convert action strings to logical forms.
     world = NlvrWorld([])
     output_dict[u"logical_form"] = [
         [world.get_logical_form(sequence) if sequence else u''
          for sequence in instance_sequences]
         for instance_sequences in output_dict[u"best_action_strings"]
     ]
     return output_dict
예제 #8
0
 def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     """
     Finalizes predictions at test time; overrides ``Model.decode``, which is invoked after
     ``Model.forward``. The only work done here is converting the action string sequences under
     ``"best_action_strings"`` into logical forms, stored per instance under ``"logical_form"``.
     Empty action sequences map to ``''``. The mutated ``output_dict`` is returned.
     """
     all_action_sequences = output_dict["best_action_strings"]
     # The world is empty because it is needed only for producing logical forms.
     empty_world = NlvrWorld([])

     def to_logical_form(action_sequence):
         # Empty sequences have no logical form; represent them with an empty string.
         if not action_sequence:
             return ''
         return empty_world.get_logical_form(action_sequence)

     per_instance_forms = []
     for sequences_of_instance in all_action_sequences:
         per_instance_forms.append([to_logical_form(action_sequence)
                                    for action_sequence in sequences_of_instance])
     output_dict["logical_form"] = per_instance_forms
     return output_dict
예제 #9
0
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions.

        Each entry of ``output_dict["best_action_strings"]`` is converted into a logical form
        (an empty entry becomes an empty string). The resulting flat list, one element per entry,
        is added to ``output_dict`` under ``"logical_form"``, and ``output_dict`` is returned.
        """
        # The empty world is used only to build logical forms from action strings.
        world = NlvrWorld([])
        output_dict["logical_form"] = [
            world.get_logical_form(action_strings) if action_strings else ''
            for action_strings in output_dict["best_action_strings"]
        ]
        return output_dict
예제 #10
0
    def setUp(self):
        """
        Prepares the worlds used by the tests: two fake worlds (without and with recursion), an
        ``NlvrWorld`` built from the first record of the NLVR sample fixture, and a
        ``WikiTablesWorld`` built from the sample table fixture.
        """
        super(TestWorld, self).setUp()
        self.world_without_recursion = FakeWorldWithoutRecursion()
        self.world_with_recursion = FakeWorldWithRecursion()

        test_filename = self.FIXTURES_ROOT / u"data" / u"nlvr" / u"sample_ungrouped_data.jsonl"
        # Use a context manager so the fixture file is closed instead of leaking the handle.
        with open(test_filename) as data_file:
            data = [json.loads(line)[u"structured_rep"] for line in data_file]
        self.nlvr_world = NlvrWorld(data[0])

        question_tokens = [Token(x) for x in [u'what', u'was', u'the', u'last', u'year', u'2004', u'?']]
        table_file = self.FIXTURES_ROOT / u'data' / u'wikitables' / u'sample_table.tsv'
        table_kg = TableQuestionKnowledgeGraph.read_from_file(table_file, question_tokens)
        self.wikitables_world = WikiTablesWorld(table_kg)
예제 #11
0
 def _create_grammar_state(
         world: NlvrWorld,
         possible_actions: List[ProductionRuleArray]) -> GrammarState:
     """
     Builds the initial ``GrammarState`` for ``world``.

     The production strings that ``world.get_valid_actions()`` lists for each non-terminal are
     replaced by their positions in ``possible_actions`` (matched on the first element of each
     entry). Both that translated mapping and the string-to-index mapping are passed on to
     ``GrammarState``.
     """
     valid_actions = world.get_valid_actions()
     action_mapping = {
         rule[0]: position
         for position, rule in enumerate(possible_actions)
     }
     translated_valid_actions = {
         nonterminal: [action_mapping[production] for production in productions]
         for nonterminal, productions in valid_actions.items()
     }
     return GrammarState([START_SYMBOL], {}, translated_valid_actions,
                         action_mapping, type_declaration.is_nonterminal)
예제 #12
0
 def _create_grammar_state(world: NlvrWorld,
                           possible_actions: List[ProductionRuleArray]) -> GrammarState:
     """
     Creates the starting ``GrammarState`` of a world: every non-terminal's valid production
     strings are converted into indices into ``possible_actions``, looked up by the first element
     of each entry.
     """
     valid_actions = world.get_valid_actions()
     # Production string -> index of that production in ``possible_actions``.
     index_by_rule = {}
     for rule_index, rule in enumerate(possible_actions):
         index_by_rule[rule[0]] = rule_index
     indices_by_nonterminal = {}
     for nonterminal in valid_actions:
         indices_by_nonterminal[nonterminal] = list(
                 index_by_rule[rule_string] for rule_string in valid_actions[nonterminal])
     return GrammarState([START_SYMBOL],
                         {},
                         indices_by_nonterminal,
                         index_by_rule,
                         type_declaration.is_nonterminal)
예제 #13
0
    def setUp(self):
        """
        Creates the fixtures shared by the tests: two fake worlds (without and with recursion),
        an ``NlvrWorld`` from the first record of the NLVR sample data, and a ``WikiTablesWorld``
        from the sample table.
        """
        super().setUp()
        self.world_without_recursion = FakeWorldWithoutRecursion()
        self.world_with_recursion = FakeWorldWithRecursion()

        test_filename = "tests/fixtures/data/nlvr/sample_ungrouped_data.jsonl"
        # Read inside a ``with`` block so the file handle is closed deterministically.
        with open(test_filename) as data_file:
            data = [
                json.loads(line)["structured_rep"]
                for line in data_file
            ]
        self.nlvr_world = NlvrWorld(data[0])

        question_tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2004', '?']
        ]
        table_file = 'tests/fixtures/data/wikitables/sample_table.tsv'
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            table_file, question_tokens)
        self.wikitables_world = WikiTablesWorld(table_kg)
예제 #14
0
    def text_to_instance(
            self,  # type: ignore
            sentence: str,
            structured_representations: List[List[List[JsonDict]]],
            labels: List[str] = None,
            target_sequences: List[List[str]] = None,
            identifier: str = None) -> Instance:
        """
        Builds an ``Instance`` with the fields ``sentence``, ``worlds`` and ``actions``, plus
        ``identifier``, ``target_action_sequences`` or ``agenda``, and ``labels`` when available.

        Parameters
        ----------
        sentence : ``str``
            The query sentence.
        structured_representations : ``List[List[List[JsonDict]]]``
            Json representations of all the worlds; one ``NlvrWorld`` is created for each.
        labels : ``List[str]`` (optional)
            String labels corresponding to ``structured_representations``. When given, they are
            added as label fields in the ``denotations`` namespace.
        target_sequences : ``List[List[str]]`` (optional)
            Target action sequences. When given, they are added as index fields into the action
            list, and no agenda is produced.
        identifier : ``str`` (optional)
            The identifier from the dataset, stored as metadata when it is not ``None``.
        """
        # pylint: disable=arguments-differ
        worlds = [NlvrWorld(representation) for representation in structured_representations]
        sentence_field = TextField(self._tokenizer.tokenize(sentence),
                                   self._sentence_token_indexers)
        instance_action_ids: Dict[str, int] = {}
        production_rule_fields: List[Field] = []
        # TODO(pradeep): Assuming that possible actions are the same in all worlds. This may change
        # later.
        for rule in worlds[0].all_possible_actions():
            instance_action_ids[rule] = len(instance_action_ids)
            production_rule_fields.append(ProductionRuleField(rule, is_global_rule=True))
        action_field = ListField(production_rule_fields)

        def index_fields_for(action_strings):
            # Each action becomes an index into ``action_field``.
            return ListField([IndexField(instance_action_ids[action_string], action_field)
                              for action_string in action_strings])

        fields: Dict[str, Field] = {
            "sentence": sentence_field,
            "worlds": ListField([MetadataField(world) for world in worlds]),
            "actions": action_field
        }
        if identifier is not None:
            fields["identifier"] = MetadataField(identifier)
        # The kind of supervision decides what else goes in: target action sequences if they were
        # provided, otherwise (when agendas are enabled) an agenda for the sentence.
        if target_sequences:
            # TODO(pradeep): Define a max length for this field.
            sequence_fields: List[Field] = [index_fields_for(target_sequence)
                                            for target_sequence in target_sequences]
            fields["target_action_sequences"] = ListField(sequence_fields)
        elif self._output_agendas:
            # TODO(pradeep): Assuming every world gives the same agenda for a sentence. This is true
            # now, but may change later too.
            agenda = worlds[0].get_agenda_for_sentence(sentence,
                                                       add_paths_to_agenda=False)
            assert agenda, "No agenda found for sentence: %s" % sentence
            fields["agenda"] = index_fields_for(agenda)
        if labels:
            fields["labels"] = ListField([LabelField(label, label_namespace='denotations')
                                          for label in labels])

        return Instance(fields)
예제 #15
0
def process_data(input_file: str,
                 output_file: str,
                 max_path_length: int,
                 max_num_logical_forms: int,
                 ignore_agenda: bool,
                 write_sequences: bool) -> None:
    """
    Reads an NLVR dataset and writes a JSON-lines file containing sentences, labels, and correct
    and incorrect logical forms (or action sequences). The output will contain at most
    ``max_num_logical_forms`` entries each in the correct and incorrect lists. Each output line
    has the keys ``"id"``, ``"sentence"``, ``"worlds"`` and ``"labels"``, plus either
    ``"correct_sequences"`` / ``"incorrect_sequences"`` (when ``write_sequences`` is true) or
    ``"correct_logical_forms"`` / ``"incorrect_logical_forms"``.

    The ``ActionSpaceWalker`` is cached in a pickle file in the current working directory, named
    after ``max_path_length``, and is reused on later runs if that file exists. When
    ``ignore_agenda`` is true, 1000 logical forms are requested from the walker; otherwise the
    sentence's agenda is used and ``max_num_logical_forms * 10`` logical forms are requested.
    """
    processed_data: JsonDict = []
    # We can instantiate the ``ActionSpaceWalker`` with any world because the action space is the
    # same for all the ``NlvrWorlds``. It is just the execution that differs.
    serialized_walker_path = f"serialized_action_space_walker_pl={max_path_length}.pkl"
    if os.path.isfile(serialized_walker_path):
        print("Reading walker from serialized file", file=sys.stderr)
        # NOTE(review): unpickling executes arbitrary code; this is only acceptable because the
        # file is a local cache written by this same function.
        with open(serialized_walker_path, "rb") as walker_file:
            walker = pickle.load(walker_file)
    else:
        walker = ActionSpaceWalker(NlvrWorld({}), max_path_length=max_path_length)
        # Closing the handle guarantees that the cache is fully flushed to disk.
        with open(serialized_walker_path, "wb") as walker_file:
            pickle.dump(walker, walker_file)
    with open(input_file) as data_file:
        for line in data_file:
            instance_id, sentence, structured_reps, label_strings = read_json_line(line)
            worlds = [NlvrWorld(structured_rep) for structured_rep in structured_reps]
            labels = [label_string == "true" for label_string in label_strings]
            correct_logical_forms = []
            incorrect_logical_forms = []
            if ignore_agenda:
                # Get 1000 shortest logical forms.
                logical_forms = walker.get_all_logical_forms(max_num_logical_forms=1000)
            else:
                # TODO (pradeep): Assuming all worlds give the same agenda.
                sentence_agenda = worlds[0].get_agenda_for_sentence(sentence,
                                                                    add_paths_to_agenda=False)
                logical_forms = walker.get_logical_forms_with_agenda(sentence_agenda,
                                                                     max_num_logical_forms * 10)
            for logical_form in logical_forms:
                is_correct = all(world.execute(logical_form) == label
                                 for world, label in zip(worlds, labels))
                # Strict ``<`` keeps each list at no more than ``max_num_logical_forms`` entries;
                # ``<=`` allowed one extra element.
                if is_correct:
                    if len(correct_logical_forms) < max_num_logical_forms:
                        correct_logical_forms.append(logical_form)
                else:
                    if len(incorrect_logical_forms) < max_num_logical_forms:
                        incorrect_logical_forms.append(logical_form)
                if len(correct_logical_forms) >= max_num_logical_forms \
                   and len(incorrect_logical_forms) >= max_num_logical_forms:
                    break
            if write_sequences:
                # Logical forms are parsed and converted to action sequences with the first world.
                correct_sequences = [
                    worlds[0].get_action_sequence(worlds[0].parse_logical_form(logical_form))
                    for logical_form in correct_logical_forms]
                incorrect_sequences = [
                    worlds[0].get_action_sequence(worlds[0].parse_logical_form(logical_form))
                    for logical_form in incorrect_logical_forms]
                processed_data.append({"id": instance_id,
                                       "sentence": sentence,
                                       "correct_sequences": correct_sequences,
                                       "incorrect_sequences": incorrect_sequences,
                                       "worlds": structured_reps,
                                       "labels": label_strings})
            else:
                processed_data.append({"id": instance_id,
                                       "sentence": sentence,
                                       "correct_logical_forms": correct_logical_forms,
                                       "incorrect_logical_forms": incorrect_logical_forms,
                                       "worlds": structured_reps,
                                       "labels": label_strings})
    with open(output_file, "w") as outfile:
        for instance_processed_data in processed_data:
            json.dump(instance_processed_data, outfile)
            outfile.write('\n')