Example no. 1
    def test_get_final_encoder_states(self):
        encoder_outputs = torch.Tensor([[[1, 2, 3, 4],
                                         [5, 6, 7, 8],
                                         [9, 10, 11, 12]],
                                        [[13, 14, 15, 16],
                                         [17, 18, 19, 20],
                                         [21, 22, 23, 24]]])
        mask = torch.Tensor([[1, 1, 1], [1, 1, 0]])
        final_states = util.get_final_encoder_states(encoder_outputs, mask, bidirectional=False)
        assert_almost_equal(final_states.data.numpy(), [[9, 10, 11, 12], [17, 18, 19, 20]])
        final_states = util.get_final_encoder_states(encoder_outputs, mask, bidirectional=True)
        assert_almost_equal(final_states.data.numpy(), [[9, 10, 3, 4], [17, 18, 15, 16]])
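The expected values above pin down the behaviour under test: the final state comes from the last unmasked timestep, and in the bidirectional case the second (backward) half of the vector is taken from timestep 0 instead. Below is a minimal sketch of that logic, not the library implementation; the even split of the last dimension into forward/backward halves is an assumption based on the test's expected output.

import torch

def final_encoder_states_sketch(encoder_outputs: torch.Tensor,
                                mask: torch.Tensor,
                                bidirectional: bool = False) -> torch.Tensor:
    # encoder_outputs: (batch_size, sequence_length, encoder_dim); mask: (batch_size, sequence_length) of 0/1.
    last_indices = mask.sum(dim=1).long() - 1                    # index of the last unmasked timestep
    batch_indices = torch.arange(encoder_outputs.size(0))
    final_states = encoder_outputs[batch_indices, last_indices]  # (batch_size, encoder_dim)
    if bidirectional:
        # The backward RNN's final state lives at timestep 0, so keep the forward half from the
        # last unmasked timestep and splice in the backward half from the first timestep.
        half = encoder_outputs.size(-1) // 2
        final_states = torch.cat([final_states[:, :half], encoder_outputs[:, 0, half:]], dim=-1)
    return final_states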
    def _get_initial_rnn_state(self, sentence: Dict[str, torch.LongTensor]):
        embedded_input = self._sentence_embedder(sentence)
        # (batch_size, sentence_length)
        sentence_mask = util.get_text_field_mask(sentence).float()

        batch_size = embedded_input.size(0)

        # (batch_size, sentence_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_input, sentence_mask))

        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             sentence_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
        attended_sentence, _ = self._decoder_step.attend_on_question(final_encoder_output,
                                                                     encoder_outputs, sentence_mask)
        encoder_outputs_list = [encoder_outputs[i] for i in range(batch_size)]
        sentence_mask_list = [sentence_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 attended_sentence[i],
                                                 encoder_outputs_list,
                                                 sentence_mask_list))
        return initial_rnn_state
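Nearly every example in this listing builds one RnnStatelet per batch instance to seed the decoder. The real class ships with AllenNLP's state-machine code; the container below is only a rough sketch of its role, with field names assumed from the constructor calls above.

from typing import List, NamedTuple
import torch

class RnnStateletSketch(NamedTuple):
    hidden_state: torch.Tensor               # decoder hidden state, seeded with the final encoder state
    memory_cell: torch.Tensor                # decoder LSTM memory cell, seeded with zeros
    previous_action_embedding: torch.Tensor  # embedding of the last action (a learned "first action" at step 0)
    attended_input: torch.Tensor             # attention-weighted summary of the encoded input
    encoder_outputs: List[torch.Tensor]      # per-instance encoder outputs, shared by all statelets in the batch
    encoder_output_mask: List[torch.Tensor]  # per-instance input masks, shared by all statelets in the batch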
Example no. 3
    def _get_initial_state(self,
                           encoder_outputs: torch.Tensor,
                           mask: torch.Tensor,
                           actions: List[List[ProductionRule]]) -> GrammarBasedState:

        batch_size = encoder_outputs.size(0)
        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
        initial_score = encoder_outputs.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_utterance,
                                                 encoder_output_list,
                                                 utterance_mask_list))

        initial_grammar_state = [self._create_grammar_state(actions[i]) for i in range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          debug_info=None)
        return initial_state
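The comment about converting the batch dimension into an outer list deserves a toy illustration (not from the source): once each instance's tensors live in a plain Python list, the decoder can regroup any subset of the batch with a single torch.stack or torch.cat and no index_select bookkeeping.

import torch

batch_size, utterance_length, encoder_output_dim = 3, 7, 16
encoder_outputs = torch.randn(batch_size, utterance_length, encoder_output_dim)
encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]

# Suppose the decoder is currently expanding states that belong to instances 0 and 2:
group_indices = [0, 2]
grouped = torch.stack([encoder_output_list[i] for i in group_indices])
assert grouped.shape == (2, utterance_length, encoder_output_dim)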
Example no. 4
    def _get_initial_state(self, utterance: Dict[str, torch.LongTensor],
                           worlds: List[AtisWorld],
                           actions: List[List[ProductionRuleArray]],
                           linking_scores: torch.Tensor) -> GrammarBasedState:
        embedded_utterance = self._utterance_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size = embedded_utterance.size(0)
        num_entities = max([len(world.entities) for world in worlds])

        # entity_types: tensor with shape (batch_size, num_entities)
        entity_types, _ = self._get_type_vector(worlds, num_entities,
                                                embedded_utterance)

        # (batch_size, num_utterance_tokens, embedding_dim)
        encoder_input = embedded_utterance

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, utterance_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            if self._decoder_num_layers > 1:
                initial_rnn_state.append(
                    RnnStatelet(
                        final_encoder_output[i].repeat(
                            self._decoder_num_layers, 1),
                        memory_cell[i].repeat(self._decoder_num_layers,
                                              1), self._first_action_embedding,
                        self._first_attended_utterance, encoder_output_list,
                        utterance_mask_list))
            else:
                initial_rnn_state.append(
                    RnnStatelet(final_encoder_output[i], memory_cell[i],
                                self._first_action_embedding,
                                self._first_attended_utterance,
                                encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(worlds[i], actions[i],
                                       linking_scores[i], entity_types[i])
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            debug_info=None)
        return initial_state
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            spans: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            pos_tags: Dict[str, torch.LongTensor] = None,
            span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : List[Dict[str, Any]], required.
            A dictionary of metadata for each batch element which has keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.
        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)
        encoder_final_state = get_final_encoder_states(encoded_text, mask)

        span_representations = self.span_extractor(encoded_text, spans, mask,
                                                   span_mask)

        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)

        logits = self.tag_projection_layer(span_representations)
        class_probabilities = masked_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
            "encoder_final_state": encoder_final_state,
            "encoded_text": encoded_text,
            "class_probabilities": class_probabilities,
            "spans": spans,
            "tokens": [meta["tokens"] for meta in metadata],
            "pos_tags": [meta.get("pos_tags") for meta in metadata],
            "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels,
                                                      span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees
               ) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [
                list(zip(*tree.pos()))[1] for tree in batch_gold_trees
            ]
            predicted_trees = self.construct_trees(
                class_probabilities.cpu().data,
                spans.cpu().data, num_spans.data, output_dict["tokens"],
                gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict
Example no. 6
    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        table: Dict[str, torch.LongTensor],
        world: List[QuarelWorld],
        actions: List[List[ProductionRule]],
        entity_bits: torch.Tensor = None,
        denotation_target: torch.Tensor = None,
        target_action_sequences: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        entity_bits : ``torch.Tensor``, optional (default=None)
            Tensor encoding bits for the world entities.
        denotation_target : ``torch.Tensor``, optional (default=None)
            If model's field ``denotation_only`` is True, this is the tensor target denotation.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        metadata : List[Dict[str, Any]], optional (default=None).
            A dictionary of metadata for each batch element which has keys:
                question_tokens : ``List[str]``, optional.
                    The original string tokens in the question.
                world_extractions : ``nltk.Tree``, optional.
                    Extracted worlds from the question.
                answer_index : ``List[str]``, optional.
                    Index of the correct answer.
        """

        table_text = table["text"]

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity. Need to add a small value
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (
                    embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (
                    embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(
                    embedded_table.view(batch_size,
                                        num_entities * num_entity_tokens,
                                        self._embedding_dim),
                    torch.transpose(embedded_question, 1, 2),
                )

                question_entity_similarity = question_entity_similarity.view(
                    batch_size, num_entities, num_entity_tokens,
                    num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(
                    question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                embedded_table = embedded_table / (
                    embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (
                    embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                eqe = embedded_question.unsqueeze(1).expand(
                    -1, num_entities * num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size,
                                          num_entities * num_entity_tokens,
                                          self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                product = product.view(
                    batch_size,
                    num_question_tokens * num_entities * num_entity_tokens,
                    self._embedding_dim,
                )
                question_entity_similarity = self._entity_similarity_layer(
                    product)
                question_entity_similarity = question_entity_similarity.view(
                    batch_size, num_entities, num_entity_tokens,
                    num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(
                    question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table["linking"]

            if self._linking_params is not None:
                feature_scores = self._linking_params(
                    linking_features).squeeze(3)
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(
                world, linking_scores.transpose(1, 2), question_mask,
                entity_type_dict)
            encoder_input = embedded_question
        else:
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input = torch.cat([embedded_question, entity_bits], 2)
            else:
                encoder_input = embedded_question

            # Fake linking_scores added for downstream code to not object
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        if self._entity_bits_output and entity_bits is not None:
            encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        # For predicting a categorical denotation directly
        if self._denotation_only:
            denotation_logits = self._denotation_classifier(
                final_encoder_output)
            loss = torch.nn.functional.cross_entropy(
                denotation_logits, denotation_target.view(-1))
            self._denotation_accuracy_cat(denotation_logits, denotation_target)
            return {"loss": loss}

        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder_output_dim)

        _, num_entities, num_question_tokens = linking_scores.size()

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []

        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(
                    final_encoder_output[i],
                    memory_cell[i],
                    self._first_action_embedding,
                    self._first_attended_question,
                    encoder_output_list,
                    question_mask_list,
                ))

        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i], linking_scores[i],
                                       entity_types[i])
            for i in range(batch_size)
        ]

        initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=None,
            debug_info=None,
        )

        if self.training:
            outputs = self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
            return outputs

        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs = {"action_mapping": action_mapping}
            if target_action_sequences is not None:
                outputs["loss"] = self._decoder_trainer.decode(
                    initial_state, self._decoder_step,
                    (target_action_sequences, target_mask))["loss"]

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False)
            outputs["best_action_sequence"] = []
            outputs["debug_info"] = []
            outputs["entities"] = []
            if self._linking_params is not None:
                outputs["linking_scores"] = linking_scores
                outputs["feature_scores"] = feature_scores
                outputs["linking_features"] = linking_features
            if self._use_entities:
                outputs["linking_probabilities"] = linking_probabilities
            if entity_bits is not None:
                outputs["entity_bits"] = entity_bits
            # outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs["logical_form"] = []
            outputs["denotation_acc"] = []
            outputs["score"] = []
            outputs["parse_acc"] = []
            outputs["answer_index"] = []
            if metadata is not None:
                outputs["question_tokens"] = []
                outputs["world_extractions"] = []
            for i in range(batch_size):
                if metadata is not None:
                    outputs["question_tokens"].append(metadata[i].get(
                        "question_tokens", []))
                if metadata is not None:
                    outputs["world_extractions"].append(metadata[i].get(
                        "world_extractions", {}))
                outputs["entities"].append(world[i].table_graph.entities)
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][
                        0].action_history[0]
                    sequence_in_targets = 0
                    if target_action_sequences is not None:
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(
                            best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [
                        action_mapping[(i, action_index)]
                        for action_index in best_action_indices
                    ]
                    try:
                        self._has_logical_form(1.0)
                        logical_form = world[i].get_logical_form(
                            action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = "Error producing logical form"
                    denotation_accuracy = 0.0
                    predicted_answer_index = world[i].execute(logical_form)
                    if metadata is not None and "answer_index" in metadata[i]:
                        answer_index = metadata[i]["answer_index"]
                        denotation_accuracy = self._denotation_match(
                            predicted_answer_index, answer_index)
                        self._denotation_accuracy(denotation_accuracy)
                    score = math.exp(
                        best_final_states[i][0].score[0].data.cpu().item())
                    outputs["answer_index"].append(predicted_answer_index)
                    outputs["score"].append(score)
                    outputs["parse_acc"].append(sequence_in_targets)
                    outputs["best_action_sequence"].append(action_strings)
                    outputs["logical_form"].append(logical_form)
                    outputs["denotation_acc"].append(denotation_accuracy)
                    outputs["debug_info"].append(
                        best_final_states[i][0].debug_info[0])  # type: ignore
                else:
                    outputs["parse_acc"].append(0)
                    outputs["logical_form"].append("")
                    outputs["denotation_acc"].append(0)
                    outputs["score"].append(0)
                    outputs["answer_index"].append(-1)
                    outputs["best_action_sequence"].append([])
                    outputs["debug_info"].append([])
                    self._has_logical_form(0.0)
            return outputs
Example no. 7
File: model.py Project: vivi0204/-
    def forward(self,  # type: ignore
                context_tokens: Dict[str, torch.LongTensor],
                tokens: Dict[str, torch.LongTensor],
                tags: torch.LongTensor = None,
                intents: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None,
                # pylint: disable=unused-argument
                **kwargs) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------

        Returns
        -------
        """
        if self.context_for_intent or self.context_for_tag or \
            self.attention_for_intent or self.attention_for_tag:
            embedded_context_input = self.text_field_embedder(context_tokens)

            if self.dropout:
                embedded_context_input = self.dropout(embedded_context_input)

            context_mask = util.get_text_field_mask(context_tokens)
            encoded_context = self.encoder(embedded_context_input, context_mask)

            if self.dropout:
                encoded_context = self.dropout(encoded_context)

            encoded_context_summary = util.get_final_encoder_states(
                encoded_context,
                context_mask,
                self.encoder.is_bidirectional())

        embedded_text_input = self.text_field_embedder(tokens)
        mask = util.get_text_field_mask(tokens)

        if self.dropout:
            embedded_text_input = self.dropout(embedded_text_input)

        encoded_text = self.encoder(embedded_text_input, mask)

        if self.dropout:
            encoded_text = self.dropout(encoded_text)

        intent_encoded_text = self.intent_encoder(encoded_text, mask) if self.intent_encoder else encoded_text

        if self.dropout and self.intent_encoder:
            intent_encoded_text = self.dropout(intent_encoded_text)

        is_bidirectional = self.intent_encoder.is_bidirectional() if self.intent_encoder else self.encoder.is_bidirectional()
        if self._feedforward is not None:
            encoded_summary = self._feedforward(util.get_final_encoder_states(
                intent_encoded_text,
                mask,
                is_bidirectional))
        else:
            encoded_summary = util.get_final_encoder_states(
                intent_encoded_text,
                mask,
                is_bidirectional)
        
        tag_encoded_text = self.tag_encoder(encoded_text, mask) if self.tag_encoder else encoded_text

        if self.dropout and self.tag_encoder:
            tag_encoded_text = self.dropout(tag_encoded_text)

        if self.attention_for_intent or self.attention_for_tag:
            attention_weights = self.attention(encoded_summary, encoded_context, context_mask.float())
            attended_context = util.weighted_sum(encoded_context, attention_weights)

        if self.context_for_intent:
            encoded_summary = torch.cat([encoded_summary, encoded_context_summary], dim=-1)
        
        if self.attention_for_intent:
            encoded_summary = torch.cat([encoded_summary, attended_context], dim=-1)

        if self.context_for_tag:
            tag_encoded_text = torch.cat([tag_encoded_text, 
                encoded_context_summary.unsqueeze(dim=1).expand(
                    encoded_context_summary.size(0),
                    tag_encoded_text.size(1),
                    encoded_context_summary.size(1))], dim=-1)

        if self.attention_for_tag:
            tag_encoded_text = torch.cat([tag_encoded_text, 
                attended_context.unsqueeze(dim=1).expand(
                    attended_context.size(0),
                    tag_encoded_text.size(1),
                    attended_context.size(1))], dim=-1)

        intent_logits = self.intent_projection_layer(encoded_summary)
        intent_probs = torch.sigmoid(intent_logits)
        predicted_intents = (intent_probs > 0.5).long()

        sequence_logits = self.tag_projection_layer(tag_encoded_text)
        if self.crf is not None:
            best_paths = self.crf.viterbi_tags(sequence_logits, mask)
            # Just get the tags and ignore the score.
            predicted_tags = [x for x, y in best_paths]
        else:
            predicted_tags = self.get_predicted_tags(sequence_logits)

        output = {"sequence_logits": sequence_logits, "mask": mask, "tags": predicted_tags,
        "intent_logits": intent_logits, "intent_probs": intent_probs, "intents": predicted_intents}

        if tags is not None:
            if self.crf is not None:
                # Add negative log-likelihood as loss
                log_likelihood = self.crf(sequence_logits, tags, mask)
                output["loss"] = -log_likelihood

                # Represent viterbi tags as "class probabilities" that we can
                # feed into the metrics
                class_probabilities = sequence_logits * 0.
                for i, instance_tags in enumerate(predicted_tags):
                    for j, tag_id in enumerate(instance_tags):
                        class_probabilities[i, j, tag_id] = 1
            else:
                loss = sequence_cross_entropy_with_logits(sequence_logits, tags, mask)
                class_probabilities = sequence_logits
                output["loss"] = loss

            if self.calculate_span_f1:
                self._f1_metric(class_probabilities, tags, mask.float())
        
        if metadata is not None:
            output["words"] = [x["words"] for x in metadata]

        if tags is not None and metadata:
            self.decode(output)
            self._dai_f1_metric(output["dialog_act"], [x["dialog_act"] for x in metadata])
            rewards = self.get_rewards(output["dialog_act"], [x["dialog_act"] for x in metadata]) if self.rl else None

        if intents is not None:
            output["loss"] += torch.mean(self.intent_loss(intent_logits, intents.float()))
            self._intent_f1_metric(predicted_intents, intents)

        return output
Example no. 8
    def _get_initial_state(self,
                           utterance: Dict[str, torch.LongTensor],
                           worlds: List[AtisWorld],
                           actions: List[List[ProductionRule]],
                           linking_scores: torch.Tensor) -> GrammarBasedState:
        embedded_utterance = self._utterance_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size = embedded_utterance.size(0)
        num_entities = max([len(world.entities) for world in worlds])

        # entity_types: tensor with shape (batch_size, num_entities)
        entity_types, _ = self._get_type_vector(worlds, num_entities, embedded_utterance)

        # (batch_size, num_utterance_tokens, embedding_dim)
        encoder_input = embedded_utterance

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             utterance_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            if self._decoder_num_layers > 1:
                initial_rnn_state.append(RnnStatelet(final_encoder_output[i].repeat(self._decoder_num_layers, 1),
                                                     memory_cell[i].repeat(self._decoder_num_layers, 1),
                                                     self._first_action_embedding,
                                                     self._first_attended_utterance,
                                                     encoder_output_list,
                                                     utterance_mask_list))
            else:
                initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                     memory_cell[i],
                                                     self._first_action_embedding,
                                                     self._first_attended_utterance,
                                                     encoder_output_list,
                                                     utterance_mask_list))


        initial_grammar_state = [self._create_grammar_state(worlds[i],
                                                            actions[i],
                                                            linking_scores[i],
                                                            entity_types[i])
                                 for i in range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          debug_info=None)
        return initial_state
Example no. 9
    def _get_initial_state(
            self, utterance: Dict[str, torch.LongTensor],
            worlds: List[SpiderWorld], schema: Dict[str, torch.LongTensor],
            actions: List[List[ProductionRule]]) -> GrammarBasedState:
        schema_text = schema['text']
        embedded_schema = self._question_embedder(schema_text,
                                                  num_wrapping_dims=1)
        schema_mask = util.get_text_field_mask(schema_text,
                                               num_wrapping_dims=1).float()

        embedded_utterance = self._question_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()
        num_entities = max([
            len(world.db_context.knowledge_graph.entities) for world in worlds
        ])
        num_question_tokens = utterance['tokens'].size(1)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            worlds, num_entities, embedded_schema.device)

        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_schema.view(batch_size, num_entities * num_entity_tokens,
                                 self._embedding_dim),
            torch.transpose(embedded_utterance, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)
        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = schema['linking']

        linking_scores = question_entity_similarity_max_score

        feature_scores = self._linking_params(linking_features).squeeze(3)

        linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            worlds, linking_scores.transpose(1, 2), utterance_mask,
            entity_type_dict)

        # (batch_size, num_entities, num_neighbors) or None
        neighbor_indices = self._get_neighbor_indices(worlds, num_entities,
                                                      linking_scores.device)

        if self._use_neighbor_similarity_for_linking and neighbor_indices is not None:
            # (batch_size, num_entities, embedding_dim)
            encoded_table = self._entity_encoder(embedded_schema, schema_mask)

            # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
            # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
            # be added for the mask since that method expects 0 for padding.
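            # For example (toy values): with neighbor_indices = [0, 2, -1] for one entity,
            # torch.abs gives [0, 2, 1], so every index stays in range, while
            # neighbor_indices + 1 = [1, 3, 0] yields a mask of [1, 1, 0] for that entity.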
            # (batch_size, num_entities, num_neighbors, embedding_dim)
            embedded_neighbors = util.batched_index_select(
                encoded_table, torch.abs(neighbor_indices))

            neighbor_mask = util.get_text_field_mask(
                {
                    'ignored': neighbor_indices + 1
                }, num_wrapping_dims=1).float()

            # Encoder initialized to easily obtain a masked average.
            neighbor_encoder = TimeDistributed(
                BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
            # (batch_size, num_entities, embedding_dim)
            embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                                  neighbor_mask)
            projected_neighbor_embeddings = self._neighbor_params(
                embedded_neighbors.float())

            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings +
                                           projected_neighbor_embeddings)
        else:
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings)

        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_utterance], 2)

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, utterance_mask))

        max_entities_relevance = linking_probabilities.max(dim=1)[0]
        entities_relevance = max_entities_relevance.unsqueeze(-1).detach()

        graph_initial_embedding = entity_type_embeddings * entities_relevance

        encoder_output_dim = self._encoder.get_output_dim()
        if self._gnn:
            entities_graph_encoding = self._get_schema_graph_encoding(
                worlds, graph_initial_embedding)
            graph_link_embedding = util.weighted_sum(entities_graph_encoding,
                                                     linking_probabilities)
            encoder_outputs = torch.cat(
                (encoder_outputs, graph_link_embedding), dim=-1)
            encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim(
            )
        else:
            entities_graph_encoding = None

        if self._self_attend:
            # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
            entities_ff = self._ent2ent_ff(entities_graph_encoding)
            linked_actions_linking_scores = torch.bmm(
                entities_ff, entities_ff.transpose(1, 2))
        else:
            linked_actions_linking_scores = [None] * batch_size

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(
                worlds[i], actions[i], linking_scores[i],
                linked_actions_linking_scores[i], entity_types[i],
                entities_graph_encoding[i]
                if entities_graph_encoding is not None else None)
            for i in range(batch_size)
        ]

        initial_sql_state = [
            SqlState(actions[i], self.parse_sql_on_decoding)
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            sql_state=initial_sql_state,
            possible_actions=actions,
            action_entity_mapping=[
                w.get_action_entity_mapping() for w in worlds
            ])

        return initial_state
Example no. 10
    def _get_initial_state(
            self, utterance: Dict[str, torch.LongTensor],
            worlds: List[SpiderWorld], schema: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]]) -> GrammarBasedState:

        utterance_mask = util.get_text_field_mask(utterance).float()
        embedded_utterance = self.question_embedder(utterance)
        batch_size, _, _ = embedded_utterance.size()
        encoder_outputs = self._dropout(
            self._question_encoder(embedded_utterance, utterance_mask))

        schema_text = schema['text']
        input_mm_schema = self._input_mm_embedder(schema_text,
                                                  num_wrapping_dims=1)
        output_mm_schema = self._output_mm_embedder(schema_text,
                                                    num_wrapping_dims=1)
        batch_size, num_entities, num_entity_tokens, _ = input_mm_schema.size()
        schema_mask = util.get_text_field_mask(schema_text,
                                               num_wrapping_dims=1).float()

        # TODO
        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            worlds, num_entities, input_mm_schema.device)
        # (batch_size, num_entities, embedding_dim)
        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)

        # (batch_size, num_entities, embedding_dim)
        # An entity memory-representation is concatenated with two parts:
        # 1. Entity tokens embedding
        # 2. Entity type embedding
        K = torch.cat([
            self._input_mm_encoder(input_mm_schema, schema_mask),
            entity_type_embeddings
        ],
                      dim=2)
        V = torch.cat([
            self._output_mm_encoder(output_mm_schema, schema_mask),
            entity_type_embeddings
        ],
                      dim=2)
        encoder_output_dim = self._question_encoder.get_output_dim()

        # Encodes utterance in the context of the schema, which is stored in external memory
        encoder_outputs_with_context, attn_weights = self._mm_attn(
            encoder_outputs, K, V)
        attn_weights = attn_weights.transpose(1, 2)
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs_with_context, utterance_mask,
            self._question_encoder.is_bidirectional())

        max_entities_relevance = attn_weights.max(dim=2)[0]
        entities_relevance = max_entities_relevance.unsqueeze(-1).detach()
        if self._self_attend:
            entities_ff = self._ent2ent_ff(entity_type_embeddings *
                                           entities_relevance)
            linked_actions_linking_scores = torch.bmm(
                entities_ff, entities_ff.transpose(1, 2))
        else:
            linked_actions_linking_scores = [None] * batch_size

        memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        # RnnStatelet is used to keep track of the internal state of a decoder RNN:
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(worlds[i], valid_actions[i],
                                       attn_weights[i],
                                       linked_actions_linking_scores[i],
                                       entity_types[i])
            for i in range(batch_size)
        ]

        initial_sql_state = [
            SqlState(valid_actions[i], self.parse_sql_on_decoding)
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            sql_state=initial_sql_state,
            possible_actions=valid_actions,
            action_entity_mapping=[
                w.get_action_entity_mapping() for w in worlds
            ])

        return initial_state
Example no. 11
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                premise_tags,
                hypothesis: Dict[str, torch.LongTensor],
                hypothesis_tags,
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
               ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
                [encoded_premise, attended_hypothesis,
                 encoded_premise - attended_hypothesis,
                 encoded_premise * attended_hypothesis],
                dim=-1
        )
        hypothesis_enhanced = torch.cat(
                [encoded_hypothesis, attended_premise,
                 encoded_hypothesis - attended_premise,
                 encoded_hypothesis * attended_premise],
                dim=-1
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max, _ = replace_masked_values(
                v_ai, premise_mask.unsqueeze(-1), -1e7
        ).max(dim=1)
        v_b_max, _ = replace_masked_values(
                v_bi, hypothesis_mask.unsqueeze(-1), -1e7
        ).max(dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
                premise_mask, 1, keepdim=True
        )
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
                hypothesis_mask, 1, keepdim=True
        )

        # running the parser
        encoded_p_parse, p_parse_mask = self._parser(premise, premise_tags)
        p_parse_encoder_final_state = get_final_encoder_states(encoded_p_parse, p_parse_mask)
        encoded_h_parse, h_parse_mask = self._parser(hypothesis, hypothesis_tags)
        h_parse_encoder_final_state = get_final_encoder_states(encoded_h_parse, h_parse_mask)

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg,
                           v_a_max,
                           v_b_avg,
                           v_b_max,
                           p_parse_encoder_final_state,
                           h_parse_encoder_final_state], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
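The masked max and average pooling above can also be written with plain tensor ops. A minimal, self-contained sketch (tensor names and sizes are made up, not from this model):

import torch

v = torch.randn(2, 5, 8)                        # (batch_size, seq_len, model_dim)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]]).float()  # (batch_size, seq_len)

# Masked max: push padded positions to a very negative value before max().
v_max, _ = v.masked_fill(mask.unsqueeze(-1) == 0, -1e7).max(dim=1)

# Masked mean: zero out padded positions and divide by the true lengths.
v_avg = (v * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)

assert v_max.shape == v_avg.shape == (2, 8)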
Example No. 12
0
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            premise_tags: torch.LongTensor,
            hypothesis: Dict[str, torch.LongTensor],
            hypothesis_tags: torch.LongTensor,
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        premise_tags : torch.LongTensor
            The POS tags of the premise.
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``.
        hypothesis_tags: torch.LongTensor
            The POS tags of the hypothesis.
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._attention(projected_premise,
                                            projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat(
            [embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat(
            [embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(
            hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(
            -1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        # running the parser
        encoded_p_parse, p_parse_mask = self._parser(premise, premise_tags)
        p_parse_encoder_final_state = get_final_encoder_states(
            encoded_p_parse, p_parse_mask)
        encoded_h_parse, h_parse_mask = self._parser(hypothesis,
                                                     hypothesis_tags)
        h_parse_encoder_final_state = get_final_encoder_states(
            encoded_h_parse, h_parse_mask)

        compared_premise = torch.cat(
            [compared_premise, p_parse_encoder_final_state], dim=-1)
        compared_hypothesis = torch.cat(
            [compared_hypothesis, h_parse_encoder_final_state], dim=-1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis],
                                    dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {'logits': label_logits, 'label_probs': label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict['loss'] = loss

        if metadata is not None:
            output_dict['premise_tokens'] = [
                x['premise_tokens'] for x in metadata
            ]
            output_dict['hypothesis_tokens'] = [
                x['hypothesis_tokens'] for x in metadata
            ]

        return output_dict
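Both entailment examples use the same attend step: a masked softmax over the similarity matrix followed by a weighted sum. A rough sketch with plain PyTorch and hypothetical tensor sizes:

import torch
import torch.nn.functional as F

encoded_premise = torch.randn(2, 6, 32)     # (batch, premise_len, dim)
encoded_hypothesis = torch.randn(2, 4, 32)  # (batch, hypothesis_len, dim)
hypothesis_mask = torch.tensor([[1, 1, 1, 0],
                                [1, 1, 1, 1]]).float()

# Shape: (batch, premise_len, hypothesis_len)
similarity = torch.bmm(encoded_premise, encoded_hypothesis.transpose(1, 2))

# Masked softmax over the hypothesis dimension.
masked_similarity = similarity.masked_fill(hypothesis_mask.unsqueeze(1) == 0, -1e7)
p2h_attention = F.softmax(masked_similarity, dim=-1)

# Weighted sum: each premise token gets a summary of the hypothesis.
attended_hypothesis = torch.bmm(p2h_attention, encoded_hypothesis)
assert attended_hypothesis.shape == (2, 6, 32)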
Example No. 13
0
 def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
     return get_final_encoder_states(self.seq2seq(inputs, None), mask)
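For reference, a rough sketch of the behaviour this utility provides (hedged, not the actual AllenNLP source): take each sequence's last unmasked timestep and, for a bidirectional encoder, splice in the backward direction's output from timestep 0.

import torch

def final_encoder_states_sketch(encoder_outputs: torch.Tensor,
                                mask: torch.Tensor,
                                bidirectional: bool = False) -> torch.Tensor:
    # encoder_outputs: (batch_size, seq_len, encoder_dim); mask: (batch_size, seq_len)
    last_indices = mask.sum(dim=1).long() - 1                      # (batch_size,)
    expanded = last_indices.view(-1, 1, 1).expand(-1, 1, encoder_outputs.size(-1))
    final_states = encoder_outputs.gather(1, expanded).squeeze(1)  # (batch_size, encoder_dim)
    if bidirectional:
        half = encoder_outputs.size(-1) // 2
        # Forward half from the last unmasked step, backward half from step 0.
        final_states = torch.cat([final_states[:, :half],
                                  encoder_outputs[:, 0, half:]], dim=-1)
    return final_states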
Example No. 14
0
File: model.py Project: zqwerty/NLU
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            tags: torch.LongTensor = None,
            intents: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None,
            # pylint: disable=unused-argument
            **kwargs) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : ``torch.LongTensor``, optional (default = ``None``)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:

        logits : ``torch.FloatTensor``
            The logits that are the output of the ``tag_projection_layer``
        mask : ``torch.LongTensor``
            The text field mask for the input tokens
        tags : ``List[List[int]]``
            The predicted tags using the Viterbi algorithm.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. Only computed if gold label ``tags`` are provided.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        mask = util.get_text_field_mask(tokens)

        if self.dropout:
            embedded_text_input = self.dropout(embedded_text_input)

        encoded_text = self.encoder(embedded_text_input, mask)

        if self.dropout:
            encoded_text = self.dropout(encoded_text)

        intent_encoded_text = self.intent_encoder(
            encoded_text, mask) if self.intent_encoder else encoded_text
        if self.dropout and self.intent_encoder:
            intent_encoded_text = self.dropout(intent_encoded_text)

        is_bidirectional = (self.intent_encoder.is_bidirectional()
                            if self.intent_encoder else
                            self.encoder.is_bidirectional())
        if self._feedforward is not None:
            encoded_summary = self._feedforward(
                util.get_final_encoder_states(intent_encoded_text, mask,
                                              is_bidirectional))
        else:
            encoded_summary = util.get_final_encoder_states(
                intent_encoded_text, mask, is_bidirectional)

        sequence_logits = self.tag_projection_layer(encoded_text)
        if self.crf is not None:
            best_paths = self.crf.viterbi_tags(sequence_logits, mask)
            # Just get the tags and ignore the score.
            predicted_tags = [x for x, y in best_paths]
        else:
            predicted_tags = self.get_predicted_tags(sequence_logits)

        intent_logits = self.intent_projection_layer(encoded_summary)
        predicted_intents = (torch.sigmoid(intent_logits) > 0.5).long()

        output = {
            "sequence_logits": sequence_logits,
            "mask": mask,
            "tags": predicted_tags,
            "intent_logits": intent_logits,
            "intents": predicted_intents
        }

        if tags is not None:
            if self.crf is not None:
                # Add negative log-likelihood as loss
                log_likelihood = self.crf(sequence_logits, tags, mask)
                output["loss"] = -log_likelihood

                # Represent viterbi tags as "class probabilities" that we can
                # feed into the metrics
                class_probabilities = sequence_logits * 0.
                for i, instance_tags in enumerate(predicted_tags):
                    for j, tag_id in enumerate(instance_tags):
                        class_probabilities[i, j, tag_id] = 1
            else:
                loss = sequence_cross_entropy_with_logits(
                    sequence_logits, tags, mask)
                class_probabilities = sequence_logits
                output["loss"] = loss

            # self.metrics['tag_acc'](class_probabilities, tags, mask.float())
            if self.calculate_span_f1:
                self._f1_metric(class_probabilities, tags, mask.float())

        if intents is not None:
            output["loss"] += self.intent_loss(intent_logits, intents.float())
            # bloss = self.intent_loss2(intent_logits, intents.float())

            # self.metrics['int_acc'](predicted_intents, intents)
            self._intent_f1_metric(predicted_intents, intents)

            # print(list([self.vocab.get_token_from_index(intent[0], namespace=self.intent_label_namespace)
            # for intent in instance_intents.nonzero().tolist()] for instance_intents in predicted_intents))
            # print(list([self.vocab.get_token_from_index(intent[0], namespace=self.intent_label_namespace)
            # for intent in instance_intents.nonzero().tolist()] for instance_intents in intents))

        if metadata is not None:
            output["words"] = [x["words"] for x in metadata]

        if tags is not None and metadata:
            self.decode(output)
            # print(output)
            # print(metadata)
            self._dai_f1_metric(output["dialog_act"],
                                [x["dialog_act"] for x in metadata])

        return output
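The intent head above makes an independent binary decision per intent (sigmoid with a 0.5 threshold) rather than a single softmax over intents. A tiny illustration with made-up logits:

import torch

intent_logits = torch.tensor([[ 2.0, -1.0, 0.3],
                              [-0.5,  1.5, -2.0]])
predicted_intents = (torch.sigmoid(intent_logits) > 0.5).long()
# tensor([[1, 0, 1],
#         [0, 1, 0]])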
Example No. 15
0
    def forward(self,
                tokens: TextFieldTensors,
                targets: TextFieldTensors,
                target_sentiments: torch.LongTensor = None,
                target_sequences: Optional[torch.LongTensor] = None,
                metadata: torch.LongTensor = None,
                position_weights: Optional[torch.LongTensor] = None,
                position_embeddings: Optional[Dict[str,
                                                   torch.LongTensor]] = None,
                **kwargs) -> Dict[str, torch.Tensor]:
        '''
        The text and targets are dictionaries because, as text fields, they can
        be represented in many different ways, e.g. just words, or words and
        characters. The dictionary captures these different representations,
        e.g. {'words': words_tensor_ids, 'chars': char_tensor_ids}.
        '''
        # Get masks for the targets before they get manipulated
        targets_mask = util.get_text_field_mask(targets, num_wrapping_dims=1)
        # This is required when the input has more than 3 dims, e.g. character
        # input, where the shape is
        # (batch size, number targets, token length, char length)
        label_mask = (targets_mask.sum(dim=-1) >= 1).type(torch.int64)
        batch_size, number_targets = label_mask.shape
        batch_size_num_targets = batch_size * number_targets

        # Embed and encode text as a sequence
        embedded_context = self.context_field_embedder(tokens)
        embedded_context = self._variational_dropout(embedded_context)
        context_mask = util.get_text_field_mask(tokens)
        # Need to repeat the context so it is of shape
        # (Batch Size * Number Targets, Sequence Length, Dim). Currently it is
        # (Batch Size, Sequence Length, Dim)
        batch_size, context_sequence_length, context_embed_dim = embedded_context.shape
        reshaped_embedding_context = embedded_context.unsqueeze(1).repeat(
            1, number_targets, 1, 1)
        reshaped_embedding_context = reshaped_embedding_context.view(
            batch_size_num_targets, context_sequence_length, context_embed_dim)
        # Embed and encode the targets as a sequence. If `_use_target_sequences`
        # is True, the target embeddings come from the context.
        if self._use_target_sequences:
            _, _, target_sequence_length, target_index_length = target_sequences.shape
            target_index_len_err = (
                'The size of the context sequence '
                f'{context_sequence_length} is not the same'
                ' as the target index sequence '
                f'{target_index_length}. This is to get '
                'the contextualized target through the context')
            assert context_sequence_length == target_index_length, target_index_len_err
            seq_targets_mask = target_sequences.view(batch_size_num_targets,
                                                     target_sequence_length,
                                                     target_index_length)
            reshaped_embedding_targets = torch.matmul(
                seq_targets_mask.type(torch.float32),
                reshaped_embedding_context)
        else:
            temp_targets = elmo_input_reshape(targets, batch_size,
                                              number_targets,
                                              batch_size_num_targets)
            if self.target_field_embedder:
                embedded_targets = self.target_field_embedder(temp_targets)
            else:
                embedded_targets = self.context_field_embedder(temp_targets)
                embedded_targets = elmo_input_reverse(embedded_targets,
                                                      targets, batch_size,
                                                      number_targets,
                                                      batch_size_num_targets)

            # Size (batch size, num targets, target sequence length, embedding dim)
            embedded_targets = self._time_variational_dropout(embedded_targets)
            batch_size, number_targets, target_sequence_length, target_embed_dim = embedded_targets.shape
            reshaped_embedding_targets = embedded_targets.view(
                batch_size_num_targets, target_sequence_length,
                target_embed_dim)

        encoded_targets_mask = targets_mask.view(batch_size_num_targets,
                                                 target_sequence_length)
        # Shape: (Batch Size * Number Targets, encoded dim)
        encoded_targets_seq = self.target_encoder(reshaped_embedding_targets,
                                                  encoded_targets_mask)
        encoded_targets_seq = self._naive_dropout(encoded_targets_seq)

        repeated_context_mask = context_mask.unsqueeze(1).repeat(
            1, number_targets, 1)
        repeated_context_mask = repeated_context_mask.view(
            batch_size_num_targets, context_sequence_length)
        # Need to concat the target embeddings to the context words
        repeated_encoded_targets = encoded_targets_seq.unsqueeze(1).repeat(
            1, context_sequence_length, 1)
        if self._AE:
            reshaped_embedding_context = torch.cat(
                (reshaped_embedding_context, repeated_encoded_targets), -1)
        # add position embeddings if required.
        reshaped_embedding_context = concat_position_embeddings(
            reshaped_embedding_context, position_embeddings,
            self.target_position_embedding)
        # Size (batch size * number targets, sequence length, embedding dim)
        reshaped_encoded_context_seq = self.context_encoder(
            reshaped_embedding_context, repeated_context_mask)
        reshaped_encoded_context_seq = self._variational_dropout(
            reshaped_encoded_context_seq)
        # Weighted position information encoded into the context sequence.
        if self.target_position_weight is not None:
            if position_weights is None:
                raise ValueError(
                    'This model requires `position_weights` to '
                    'better encode the target but none were given')
            position_output = self.target_position_weight(
                reshaped_encoded_context_seq, position_weights,
                repeated_context_mask)
            reshaped_encoded_context_seq, weighted_position_weights = position_output
        # Whether to concat the aspect embeddings on to the contextualised word
        # representations
        attention_encoded_context_seq = reshaped_encoded_context_seq
        if self._AttentionAE:
            attention_encoded_context_seq = torch.cat(
                (attention_encoded_context_seq, repeated_encoded_targets), -1)
        _, _, attention_encoded_dim = attention_encoded_context_seq.shape

        # Projection layer before the attention layer
        attention_encoded_context_seq = self.attention_project_layer(
            attention_encoded_context_seq)
        attention_encoded_context_seq = self._context_attention_activation_function(
            attention_encoded_context_seq)
        attention_encoded_context_seq = self._variational_dropout(
            attention_encoded_context_seq)

        # Attention over the context sequence
        attention_vector = self.attention_vector.unsqueeze(0).repeat(
            batch_size_num_targets, 1)
        attention_weights = self.context_attention_layer(
            attention_vector, attention_encoded_context_seq,
            repeated_context_mask)
        expanded_attention_weights = attention_weights.unsqueeze(-1)
        weighted_encoded_context_seq = reshaped_encoded_context_seq * expanded_attention_weights
        weighted_encoded_context_vec = weighted_encoded_context_seq.sum(dim=1)

        # Add the final hidden state of the context encoder to the attention-weighted context vector
        context_final_states = util.get_final_encoder_states(
            reshaped_encoded_context_seq,
            repeated_context_mask,
            bidirectional=self.context_encoder_bidirectional)
        context_final_states = self.final_hidden_state_projection_layer(
            context_final_states)
        weighted_encoded_context_vec = self.final_attention_projection_layer(
            weighted_encoded_context_vec)
        feature_vector = context_final_states + weighted_encoded_context_vec
        feature_vector = self._naive_dropout(feature_vector)
        # Reshape the vector into (Batch Size, Number Targets, number labels)
        _, feature_dim = feature_vector.shape
        feature_target_seq = feature_vector.view(batch_size, number_targets,
                                                 feature_dim)

        if self.inter_target_encoding is not None:
            feature_target_seq = self.inter_target_encoding(
                feature_target_seq, label_mask)
            feature_target_seq = self._variational_dropout(feature_target_seq)

        if self.feedforward is not None:
            feature_target_seq = self.feedforward(feature_target_seq)

        logits = self.label_projection(feature_target_seq)
        masked_class_probabilities = util.masked_softmax(
            logits, label_mask.unsqueeze(-1))
        output_dict = {
            "class_probabilities": masked_class_probabilities,
            "targets_mask": label_mask
        }
        # Convert it to bool tensor.
        label_mask = label_mask == 1

        if target_sentiments is not None:
            # gets the loss per target instance due to the average=`token`
            if self.loss_weights is not None:
                loss = util.sequence_cross_entropy_with_logits(
                    logits,
                    target_sentiments,
                    label_mask,
                    average='token',
                    alpha=self.loss_weights)
            else:
                loss = util.sequence_cross_entropy_with_logits(
                    logits, target_sentiments, label_mask, average='token')
            for metrics in [self.metrics, self.f1_metrics]:
                for metric in metrics.values():
                    metric(logits, target_sentiments, label_mask)
            output_dict["loss"] = loss

        if metadata is not None:
            words = []
            texts = []
            targets = []
            target_words = []
            for batch_index, sample in enumerate(metadata):
                words.append(sample['text words'])
                texts.append(sample['text'])
                targets.append(sample['targets'])
                target_words.append(sample['target words'])

            output_dict["words"] = words
            output_dict["text"] = texts
            word_attention_weights = attention_weights.view(
                batch_size, number_targets, context_sequence_length)
            output_dict["word_attention"] = word_attention_weights
            output_dict["targets"] = targets
            output_dict["target words"] = target_words
            output_dict["context_mask"] = context_mask

        return output_dict
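The repeat-and-flatten step above, which gives every target its own copy of the encoded context, is easier to see on its own. A quick sketch with hypothetical sizes:

import torch

batch_size, num_targets, seq_len, dim = 2, 3, 5, 8
embedded_context = torch.randn(batch_size, seq_len, dim)

repeated = embedded_context.unsqueeze(1).repeat(1, num_targets, 1, 1)
flattened = repeated.view(batch_size * num_targets, seq_len, dim)
assert flattened.shape == (6, 5, 8)
# Row i of `flattened` belongs to batch element i // num_targets and target i % num_targets.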
Example No. 16
0
    def get_BILOU_features(self, token_indices, sent_len, span_len):

        #print(token_indices)
        span_level_token_indices = {}        
        for ky,val in list(token_indices.items()):
            if ky == 'elmo':
                continue
            val = val.unsqueeze(1)
            span_level_token_indices[ky] = torch.cat([val[:, :, i:i + span_len + 1] for i in range(sent_len - 1 - span_len)], 1)

        '''
        print(span_level_token_indices)
        t = span_level_token_indices["tokens"][0].cpu().numpy().tolist()
        import json
        with open("./data/dict.json", "r", encoding="utf-8") as df:
            dic = json.load(df)
        a = [[dic[str(word)] for word in span] for span in t]
        print(a)
        '''
        ori_seq = [self.id2words(each.cpu().numpy().tolist()) for each in span_level_token_indices["tokens"]]
        att_logits = torch.Tensor([self.span_score(seq) for seq in ori_seq])


        spans_embedded = self.softdict_text_field_embedder(span_level_token_indices, num_wrapping_dims=1)
        spans_mask = util.get_text_field_mask(span_level_token_indices, num_wrapping_dims=1)
        
        '''
        for param in self.softdict_text_field_embedder.parameters():
            #np.save("embed.npy", param.detach().numpy())
            print(param.size()), exit(0)
        '''
        
        #print(spans_mask)
        #print(spans_mask.size())
        
        if util.get_device_of(spans_mask) >= 0:
            att_mask = torch.ge(torch.mean(spans_mask.float(), -1), (torch.ones(spans_mask.size(0), spans_mask.size(1)) - 2e-6).cuda(util.get_device_of(spans_mask)))
        else:
            att_mask = torch.ge(torch.mean(spans_mask.float(), -1), (torch.ones(spans_mask.size(0), spans_mask.size(1)) - 2e-6))
                
        dim_2_pad = self.ALLOWED_SPANLEN - spans_embedded.size(2)
        p2d = (0,0,0, dim_2_pad)
        # now shape (batch_size, num_span, max_span_width, dim)
        spans_embedded = F.pad(spans_embedded, p2d, "constant", 0.)
        spans_mask = F.pad(spans_mask, (0, dim_2_pad), "constant", 0.)
        #print("embed:")
        #print(spans_embedded)

        '''
        tt = {"tokens":torch.LongTensor([   50,  1138,    84,     7,   645,  1135,  7386,  1123,  4979,   952,
             2,   381,   173,   128,  8932,     9,    95,  1098, 16550,   524,
          3897,  5190,  8242,    22,  2112,  6912,  1408,   814,  9853,   128])}
        t = self.softdict_text_field_embedder(tt)
        print(t),exit(0)
        '''
        
        
        batch_size = spans_mask.size(0)
        num_spans = spans_mask.size(1)
        if util.get_device_of(spans_mask) >= 0:
            length_vec = torch.autograd.Variable(torch.LongTensor(range(self.ALLOWED_SPANLEN))).cuda(util.get_device_of(spans_mask))
        else:
            length_vec = torch.autograd.Variable(torch.LongTensor(range(self.ALLOWED_SPANLEN)))
        length_vec = self.length_embedder(length_vec).unsqueeze(0).unsqueeze(0).expand(batch_size, num_spans, -1,-1)
        
        spans_encoded = self.encoder(spans_embedded, spans_mask)        #BiLSTM
        
        
        #spans_encoded = torch.cat((spans_encoded, length_vec), 3).contiguous()
        #print(spans_encoded)

        spans_encoded = spans_encoded.reshape([batch_size * num_spans, self.ALLOWED_SPANLEN, -1])
        '''
        A mask of shape [batch_size * num_spans, self.ALLOWED_SPANLEN] may contain all-zero rows,
        e.g. tensor([[1, 0, 0,  ..., 0, 0, 0],
                     [1, 0, 0,  ..., 0, 0, 0],
                     [1, 0, 0,  ..., 0, 0, 0],
                     ...,
                     [0, 0, 0,  ..., 0, 0, 0],
                     [0, 0, 0,  ..., 0, 0, 0],
                     [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0').
        Calling 'get_final_encoder_states()' on such a mask leads to
        RuntimeError: cuda runtime error (59) : device-side assert triggered at /pytorch/aten/src/THC/THCReduceAll.cuh:327

        The workaround: a span whose mask is all zeros at the span-sequence level is treated as if
        its first position were unmasked, so every row keeps at least one 1.
        '''
        spans_mask = spans_mask.reshape([batch_size * num_spans, self.ALLOWED_SPANLEN])
        if util.get_device_of(spans_mask) >= 0:
            tmp = torch.zeros(self.ALLOWED_SPANLEN, dtype=torch.int64).cuda(util.get_device_of(spans_mask))
        else:
            tmp = torch.zeros(self.ALLOWED_SPANLEN, dtype=torch.int64)
        tmp[0] = 1
        tmp = tmp.expand([batch_size * num_spans, self.ALLOWED_SPANLEN])
        
        new_spans_mask = spans_mask | tmp
        #print(new_spans_mask)
        last_state = get_final_encoder_states(spans_encoded, new_spans_mask)
        attention_coe, attention_out, attention_logits = self.attention(lstm_output=spans_encoded, final_state=last_state, mask_cuda=util.get_device_of(spans_mask))
        #print(attention_logits),exit(0)
        attention_logits = attention_logits.reshape([batch_size, num_spans, -1])[:,:,1]     # here 0 stand for true / 1
        #print(attention_logits)
        #print(attention_logits.size())
        attention_logits = attention_logits * att_mask.float()
        #print(attention_logits), exit(0)
        attention_out = attention_out.reshape([batch_size, num_spans, -1])
        
        #print(attention_coe.size())
        #attention_coe = attention_coe * spans_mask.float()
        attention_coe = attention_coe.reshape([batch_size, num_spans, -1])
        attention_coe = attention_coe.unsqueeze(-1)
        #print(attention_coe)
        #attention_coe = torch.gt(attention_coe, 0.1).float()
        attention_coe = attention_coe.expand([batch_size, num_spans, attention_coe.size(2), 1])
        #print(attention_coe.size()), exit(0)
        attention_coe = torch.cat([attention_coe, attention_coe.new_zeros(batch_size, 1, attention_coe.size(2), attention_coe.size(3))], dim=1)
        attention_out = torch.cat([attention_out, attention_out.new_zeros(batch_size, 1, attention_out.size(-1))], dim=1)
        attention_logits = torch.cat([attention_logits, attention_logits.new_zeros(batch_size, 1)], dim=1)

        #print(attention_logits.size(), att_logits.size()),exit(0)
        att_logits = torch.cat([att_logits, att_logits.new_zeros(batch_size, 1)], dim=1)
        return attention_coe[:,:,:span_len+1,:].detach(), att_logits
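The workaround described in the docstring above boils down to guaranteeing that every row of the mask keeps at least one unmasked position before get_final_encoder_states is called. A distilled sketch with a made-up mask:

import torch

spans_mask = torch.tensor([[1, 1, 0, 0],
                           [0, 0, 0, 0]], dtype=torch.int64)  # second row fully masked

guard = torch.zeros_like(spans_mask)
guard[:, 0] = 1
safe_mask = spans_mask | guard   # every row now keeps at least one 1
# tensor([[1, 1, 0, 0],
#         [1, 0, 0, 0]])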
    def forward(
        self,  # type: ignore
        tokens: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata to persist

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            unnormalized log probabilities of the label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()

        encoder_output = self._encoder(embedded_text, mask)

        encoded_repr = []
        for aggregation in self._aggregations:
            if aggregation == "meanpool":
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoder_output * broadcast_mask
                encoded_text = masked_mean(context_vectors,
                                           broadcast_mask,
                                           dim=1,
                                           keepdim=False)
            elif aggregation == 'maxpool':
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoder_output * broadcast_mask
                encoded_text = masked_max(context_vectors,
                                          broadcast_mask,
                                          dim=1)
            elif aggregation == 'final_state':
                is_bi = self._encoder.is_bidirectional()
                encoded_text = get_final_encoder_states(
                    encoder_output, mask, is_bi)
            encoded_repr.append(encoded_text)

        encoded_repr = torch.cat(encoded_repr, 1)

        if self.dropout:
            encoded_repr = self.dropout(encoded_repr)

        output_hidden = self._output_feedforward(encoded_repr)
        label_logits = self._classification_layer(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
Example No. 18
0
    def bidaf_reprs(self, question, contexts):
        # Shape: (B, ques_len, D), (B, num_contexts, context_len, D)
        (embedded_question_tensor, embedded_passages_tensor,
         question_mask_tensor,
         passages_mask_tensor) = self.embed_ques_passages(question, contexts)

        batch_size = embedded_question_tensor.size()[0]
        num_contexts = embedded_passages_tensor.size()[1]

        embedded_questions = []
        questions_mask = []
        embedded_contexts = []
        contexts_mask = []

        for i in range(0, batch_size):
            embedded_questions.append(embedded_question_tensor[i])
            embedded_contexts.append(embedded_passages_tensor[i])
            questions_mask.append(question_mask_tensor[i])
            contexts_mask.append(passages_mask_tensor[i])

        # Shape: (B, ques_len, D)
        encoded_ques_tensor = self.encode_question(
            embedded_question=embedded_question_tensor,
            question_lstm_mask=question_mask_tensor)

        # Shape: (B, D)
        ques_encoded_final_state = allenutil.get_final_encoder_states(
            encoded_ques_tensor, question_mask_tensor,
            self.bidaf_encoder_bidirectional)

        # List of tensors: (question_len, D)
        encoded_questions = []
        # List of tensors: (num_contexts, context_len, D)
        encoded_contexts = []
        for i in range(0, batch_size):
            # Shape: (1, ques_len, D)
            # encoded_ques = self.encode_question(embedded_question=embedded_questions[i].unsqueeze(0),
            #                                     question_lstm_mask=questions_mask[i].unsqueeze(0))
            encoded_questions.append(encoded_ques_tensor[i])
            # Shape: (num_contexts, context_len, D)
            encoded_context = self.encode_context(
                embedded_passage=embedded_contexts[i],
                passage_lstm_mask=contexts_mask[i])
            encoded_contexts.append(encoded_context)

        modeled_contexts = []
        for i in range(0, batch_size):
            # Shape: (question_len, D)
            encoded_ques = encoded_questions[i]
            ques_mask = questions_mask[i]
            encoded_ques_ex = encoded_ques.unsqueeze(0).expand(
                num_contexts, *encoded_ques.size())
            ques_mask_ex = ques_mask.unsqueeze(0).expand(
                num_contexts, *ques_mask.size())

            output_dict = self.forward_bidaf(
                encoded_question=encoded_ques_ex,
                encoded_passage=encoded_contexts[i],
                question_lstm_mask=ques_mask_ex,
                passage_lstm_mask=contexts_mask[i])

            # Shape: (num_contexts, context_len, D)
            modeled_context = output_dict['modeled_passage']
            modeled_contexts.append(modeled_context)

        return (ques_encoded_final_state, encoded_ques_tensor,
                question_mask_tensor, embedded_questions, questions_mask,
                embedded_contexts, contexts_mask, encoded_questions,
                encoded_contexts, modeled_contexts)
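The unsqueeze(0).expand(...) pattern above broadcasts the question encoding over every context without copying data, unlike repeat. A tiny sketch with hypothetical sizes:

import torch

num_contexts, ques_len, dim = 3, 6, 16
encoded_ques = torch.randn(ques_len, dim)
encoded_ques_ex = encoded_ques.unsqueeze(0).expand(num_contexts, ques_len, dim)
assert encoded_ques_ex.shape == (3, 6, 16)
assert encoded_ques_ex.data_ptr() == encoded_ques.data_ptr()  # shared storage, no copy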
Example No. 19
0
    def forward(
        self,  # type: ignore
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the source ``TextField``. This will be
           passed through a ``TextFieldEmbedder`` and then through an encoder.
        target_tokens : Dict[str, torch.LongTensor], optional (default = None)
           Output of ``Textfield.as_array()`` applied on target ``TextField``. We assume that the
           target tokens are also represented as a ``TextField``.
        """
        # (batch_size, input_sequence_length, encoder_output_dim)
        embedded_input = self._source_embedder(source_tokens)
        batch_size, _, _ = embedded_input.size()
        source_mask = util.get_text_field_mask(source_tokens)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        # (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, source_mask, self._encoder.is_bidirectional())
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end symbol. Either way, we
            # don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        decoder_context = encoder_outputs.new_zeros(batch_size,
                                                    self._decoder_output_dim)
        last_predictions = None
        step_logits = []
        step_probabilities = []
        step_predictions = []
        for timestep in range(num_decoding_steps):
            use_gold_targets = False
            # Use gold tokens at test time when provided and at a rate of 1 -
            # _scheduled_sampling_ratio during training.
            if self.training:
                if torch.rand(1).item() >= self._scheduled_sampling_ratio:
                    use_gold_targets = True
            elif target_tokens:
                use_gold_targets = True

            if use_gold_targets:
                input_choices = targets[:, timestep]
            else:
                if timestep == 0:
                    # For the first timestep, when we do not have targets, we input start symbols.
                    # (batch_size,)
                    input_choices = source_mask.new_full(
                        (batch_size, ), fill_value=self._start_index)
                else:
                    input_choices = last_predictions
            decoder_input = self._prepare_decode_step_input(
                input_choices, decoder_hidden, encoder_outputs, source_mask)
            decoder_hidden, decoder_context = self._decoder_cell(
                decoder_input, (decoder_hidden, decoder_context))
            # (batch_size, num_classes)
            output_projections = self._output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            step_probabilities.append(class_probabilities.unsqueeze(1))
            last_predictions = predicted_classes
            # (batch_size, 1)
            step_predictions.append(last_predictions.unsqueeze(1))
        # step_logits is a list containing tensors of shape (batch_size, 1, num_classes)
        # This is (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        class_probabilities = torch.cat(step_probabilities, 1)
        all_predictions = torch.cat(step_predictions, 1)
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities,
            "predictions": all_predictions
        }
        if target_tokens:
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss
            # TODO: Define metrics
            relevant_targets = targets[:, 1:].contiguous()
            # shape: (batch_size, num_decoding_steps)
            relevant_mask = target_mask[:, 1:].contiguous()

            self.__sequence_accuracy(all_predictions.unsqueeze(1),
                                     relevant_targets, relevant_mask)
        return output_dict
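The scheduled-sampling decision in the decoding loop above can be summarised in a few lines; a minimal sketch (the ratio is a made-up value):

import torch

scheduled_sampling_ratio = 0.3  # hypothetical value

def choose_gold(training: bool, have_gold_targets: bool) -> bool:
    if training:
        # Use gold tokens with probability 1 - scheduled_sampling_ratio.
        return torch.rand(1).item() >= scheduled_sampling_ratio
    # At evaluation time, use gold tokens whenever they are provided.
    return have_gold_targets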
Example No. 20
0
    def forward(self,
                spans_tensor: torch.FloatTensor,
                spans_mask: torch.FloatTensor,
                question_tensor: torch.FloatTensor,
                question_mask: torch.FloatTensor,
                evd_chain_labels: torch.FloatTensor,
                self_att_layer: Seq2SeqEncoder,
                sent_encoder: Seq2SeqEncoder,
                get_all_beam: bool = False):

        print("spans_tensor", spans_tensor.shape)
        batch_size, num_spans, max_batch_span_width = spans_mask.size()
        # shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        spans_tensor = spans_tensor.view(batch_size, num_spans,
                                         max_batch_span_width,
                                         spans_tensor.size(2))
        # shape: (batch_size, num_spans)
        max_pooled_span_mask = (torch.sum(spans_mask, dim=-1) >= 1).float()
        att_score = None

        # extract the final hidden states as the question vector
        # Shape: (batch_size, embedding_dim)
        question_emb = util.get_final_encoder_states(question_tensor,
                                                     question_mask, True)

        # decode the most likely evidence path
        # shape (all_predictions): (batch_size, K, num_decoding_steps)
        # shape (all_logprobs): (batch_size, K, num_decoding_steps)
        # shape (seq_logprobs): (batch_size, K)
        # shape (final_hidden): (batch_size, K, decoder_output_dim)
        all_predictions, all_logprobs, seq_logprobs, final_hidden = self.evd_decoder(
            spans_tensor,
            spans_mask,
            question_emb,
            aux_input=None,  #question_emb,#None
            transition_mask=None,
            labels=evd_chain_labels)
        if self._pass_label:
            all_predictions = evd_chain_labels.long().unsqueeze(1)
            all_logprobs = torch.zeros_like(all_predictions).float()
        #print("batch:", batch_size)
        #print("predict num:", torch.sum((all_predictions > 0).float(), dim=1))
        print("all prediction:", all_predictions)

        # The selection order of each sentence. Set to -1 if not being chosen
        # shape: (batch_size, K, num_spans)
        _, beam, num_steps = all_predictions.size()
        orders = spans_tensor.new_ones((batch_size, beam, 1 + num_spans)) * -1
        indices = util.get_range_vector(num_steps, util.get_device_of(spans_tensor)).\
                float().\
                unsqueeze(0).\
                unsqueeze(0).\
                expand(batch_size, beam, num_steps)
        orders.scatter_(2, all_predictions, indices)
        orders = orders[:, :, 1:]

        # For beam search, take the top beam. For other helpers, this just acts like a squeeze.
        if not get_all_beam:
            all_predictions = all_predictions[:, 0, :]
            all_logprobs = all_logprobs[:, 0, :]
            seq_logprobs = seq_logprobs[:, 0]
            final_hidden = final_hidden[:, 0, :]

        # build the gate. The dim is set to 1 + num_spans to account for the end embedding
        # shape: (batch_size, 1+num_spans) or (batch_size, K, 1+num_spans)
        if not get_all_beam:
            gate = spans_tensor.new_zeros((batch_size, 1 + num_spans))
        else:
            gate = spans_tensor.new_zeros((batch_size, beam, 1 + num_spans))
        gate.scatter_(-1, all_predictions, 1.)
        # remove the column for end embedding
        # shape: (batch_size, num_spans) or (batch_size, K, num_spans)
        gate = gate[..., 1:]
        #print("gate:", gate)
        #print("real num:", torch.sum(gate, dim=1))
        #print("seq probs:", torch.exp(seq_logprobs))

        # shape: (batch_size * num_spans, 1) or (batch_size * K * num_spans, 1)
        if not get_all_beam:
            gate = gate.reshape(batch_size * num_spans, 1)
        else:
            gate = gate.reshape(batch_size * beam * num_spans, 1)

        # The probability of each selected sentence being selected. If not selected, set to 0.
        # shape: (batch_size * num_spans, 1) or (batch_size * K * num_spans, 1)
        if not get_all_beam:
            gate_probs = spans_tensor.new_zeros((batch_size, 1 + num_spans))
        else:
            gate_probs = spans_tensor.new_zeros(
                (batch_size, beam, 1 + num_spans))
        gate_probs.scatter_(-1, all_predictions, all_logprobs.exp())
        gate_probs = gate_probs[..., 1:]
        if not get_all_beam:
            gate_probs = gate_probs.reshape(batch_size * num_spans, 1)
        else:
            gate_probs = gate_probs.reshape(batch_size * beam * num_spans, 1)

        return all_predictions, all_logprobs, seq_logprobs, gate, gate_probs, max_pooled_span_mask, att_score, orders
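The scatter_ trick above turns predicted sentence indices into a 0/1 gate over 1 + num_spans slots, where slot 0 is the end symbol and is dropped afterwards. A compact sketch with made-up predictions:

import torch

batch_size, num_spans = 2, 4
all_predictions = torch.tensor([[2, 3, 0],   # picked spans 2 and 3 (1-indexed; 0 is the end symbol)
                                [1, 0, 0]])  # picked span 1, then the end symbol

gate = torch.zeros(batch_size, 1 + num_spans)
gate.scatter_(-1, all_predictions, 1.)
gate = gate[..., 1:]   # drop the end column
# tensor([[0., 1., 1., 0.],
#         [1., 0., 0., 0.]])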
Example No. 21
0
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        text_encoder: Seq2SeqEncoder,
        classifier_feedforward: FeedForward,
        verbose_metrics: bool = False,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
        loss: Optional[dict] = None,
    ) -> None:
        super(MultilabelTextClassifier, self).__init__(vocab, regularizer)

        self.log = logging.getLogger(__name__)
        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.log.warning(f'num_classes: {self.num_classes}')
        self.text_encoder = text_encoder
        self.classifier_feedforward = classifier_feedforward
        self.log.warning(
            f'output_dim: {self.classifier_feedforward.get_output_dim()}')
        self.prediction_layer = torch.nn.Linear(
            self.classifier_feedforward.get_output_dim(), self.num_classes)
        self.pool = lambda text, mask: util.get_final_encoder_states(
            text, mask, bidirectional=True)

        self.label_accuracy = CategoricalAccuracy()
        self.label_f1_metrics = OrderedDict()
        self.verbose_metrics = verbose_metrics
        for i in range(self.num_classes):
            label = vocab.get_token_from_index(index=i, namespace="labels")
            self.log.warning(f'label {i}: {label}')
            self.label_f1_metrics[label] = F1Measure(positive_label=i)
        self.micro_f1 = MultiLabelF1Measure()
        self.label_f1 = OrderedDict()
        for i in range(self.num_classes):
            label = vocab.get_token_from_index(index=i, namespace="labels")
            self.label_f1[label] = MultiLabelF1Measure()

        if loss is not None:
            alpha = loss.get('alpha')
            gamma = loss.get('gamma')
            weight = loss.get('weight')
            if alpha is not None:
                alpha = float(alpha)
            if gamma is not None:
                gamma = float(gamma)
            if weight is not None:
                weight = torch.tensor(ast.literal_eval(weight))
        if loss is None or loss.get('type') == 'CrossEntropyLoss':
            self.loss = torch.nn.CrossEntropyLoss()
        elif loss.get('type') == 'BinaryFocalLoss':
            self.loss = BinaryFocalLoss(alpha=alpha, gamma=gamma)
        elif loss.get('type') == 'FocalLoss':
            self.loss = FocalLoss(alpha=alpha, gamma=gamma)
        elif loss.get('type') == 'MultiLabelMarginLoss':
            self.loss = torch.nn.MultiLabelMarginLoss()
        elif loss.get('type') == 'MultiLabelSoftMarginLoss':
            self.loss = torch.nn.MultiLabelSoftMarginLoss(weight)
        else:
            raise ValueError(f'Unexpected loss "{loss}"')

        initializer(self)
Example No. 22
0
    def forward(self, tags, history, next_sym, source_tokens, his_symptoms,
                target_tokens, **args):
        bs = len(tags)
        # self.flatten_parameters()
        # get the embedding of the history
        embeddings = self._source_embedder(history)
        mask = get_text_field_mask(history,
                                   num_wrapping_dims=1)  # num_wrapping_dims adds a dimension
        sz = list(embeddings.size())
        embeddings = embeddings.view(sz[0] * sz[1], sz[2], sz[3])
        mask = mask.view(sz[0] * sz[1], sz[2])

        # get each utterance's hidden state: bs * sen_num * hidden
        utter_hidden = self._vecoder(embeddings, mask)
        utter_hidden = utter_hidden.view(sz[0], sz[1],
                                         -1)  # bs * sen_num * hidden

        dialog_mask = get_text_field_mask(history)
        dialog_hidden = self._sen_encoder(utter_hidden, dialog_mask)  # HRED-style
        # print("dialog_hidden: ",dialog_hidden.size())

        # initialize every node
        symp_state = torch.zeros(
            bs, self.symp_size,
            self.outfeature).cuda()  # bs * symp_size * hidden
        symp_state += self.symp_state  # every node starts from the same initial embedding; is that a problem?

        # connect sentence-to-sentence edges; what if CUDA is not used?
        sym_mat = torch.zeros(bs, self.symp_size, self.symp_size)

        for i in range(bs):
            dic = {}
            for j in range(len(tags[i])):  # this could be changed; here edges are added to all related preceding sentences
                symp_state[i][self.topic + j] = utter_hidden[i][j]
                dic[j] = set(list(tags[i][j]))
                for k in range(j):
                    for aa in dic[j]:
                        if aa in dic[k] and aa != -1:
                            # sym_mat[i][self.topic+j][self.topic+k] += 1
                            sym_mat[i][self.topic + k][self.topic + j] += 1

        last_h = self.attn_one(symp_state, sym_mat)
        sym_mat = torch.zeros(bs, self.symp_size, self.symp_size)
        for i in range(bs):
            for j in range(len(tags[i])):
                for tt in tags[i][j]:
                    if tt != -1:
                        sym_mat[i][self.topic + j][tt] += 1
        #
        last_h = self.attn_two(last_h, sym_mat)
        #
        # connect topic-to-topic edges
        sym_mat = torch.zeros(bs, self.symp_size, self.symp_size)
        # add edges
        # for symp_i in his_symptoms:
        #     for symp_j in his_symptoms:
        #         self.evovl_mat[symp_i][symp_j] = 1
        # temp_mat = (torch.nn.functional.relu(self.symp_mat) + self.evovl_mat).cpu()
        # with open('visulize_graph.txt', 'a') as fout:
        #     fout.write('evovl_mat is: \n')
        #     for i in self.evovl_mat.detach().cpu().numpy():
        #         fout.write(str(i) + '\n')
        #     fout.write('temp_mat is: \n')
        #     for i in temp_mat.detach().cpu().numpy():
        #         fout.write(str(i) + '\n')
        # print('[info] temp_mat is:{}'.format(temp_mat))
        sym_mat[:, :self.topic, :self.topic] += self.symp_mat

        last_h = self.attn_three(last_h, sym_mat)
        # last_h = self.attn_three(last_h, sym_mat)

        topic_pre = torch.sum(self.predict_layer * last_h,
                              dim=-1) + self.predict_bias
        topic_probs = torch.sigmoid(topic_pre)
        topics_weight = torch.ones_like(topic_probs) + 5 * next_sym.float()
        topic_loss = torch.nn.functional.binary_cross_entropy(
            topic_probs, next_sym.float(), weight=topics_weight)

        ans = (topic_probs > 0.5).long()

        # his_symptoms bs * sym_size?
        # his_mask = torch.where(his_symptoms > 0, torch.full_like(his_symptoms, 0), torch.full_like(his_symptoms,1)).long()

        # Hide (mask out) the sentence nodes
        # his_mask
        his_sentence_mask = torch.zeros(bs, self.sen_num).long()
        total_mask = torch.cat(
            (torch.ones(bs, self.topic).long(), his_sentence_mask), -1)

        if self.training and torch.rand(
                1).item() < self._scheduled_sampling_ratio:
            aa = next_sym.long()
        else:
            aa = ans

        # total_mask = torch.ones(bs, self.symp_size).cuda()
        # total_mask = total_mask.long() & his_mask.long()
        topic_embedding = aa.float().matmul(self.symp_state)
        topic_hidden = last_h

        # Compute topic F1, accuracy, and recall
        pre_total = torch.sum(ans).item()
        true_total = torch.sum(next_sym).item()
        pre_right = torch.sum((ans == next_sym).long() * next_sym).item()
        # print(pre_total,pre_right)
        self.topic_acc(pre_right, pre_total)
        self.topic_rec(pre_right, true_total)
        acc = self.topic_acc.get_metric(False)
        rec = self.topic_rec.get_metric(False)
        f1 = 0.
        if acc + rec > 0:
            f1 = acc * rec * 2 / (acc + rec)
        self.topic_f1(f1)

        # Encoding source_tokens
        embedded_input = self._source_embedder(source_tokens)
        source_mask = util.get_text_field_mask(source_tokens)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, source_mask, self._encoder.is_bidirectional())
        # if self.training:
        #     ff = next_sym.float().matmul(symp_state)
        # else:
        #     ff = topics_weight.matmul(symp_state)
        # print('[info]final_encoder_output is:{}, ff:{}'.format(final_encoder_output.size(), ff.size()))
        state = {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,  # bs * seq_len * dim
            "decoder_hidden": dialog_hidden,  # bs * dim, the HRED output
            # "decoder_hidden": torch.cat((topic_embedding, dialog_hidden), -1),
            "decoder_context": encoder_outputs.new_zeros(bs, self._decoder_output_dim),
            "topic_embedding": topic_embedding
        }
        # state[''] = topic_embedding
        # Run the decoder once
        output_dict = self._forward_loop(state, topic_hidden,
                                         total_mask.cuda(), target_tokens)
        best_predictions = output_dict["predictions"]

        # output something
        references, hypothesis = [], []
        for i in range(bs):
            cut_hypo = best_predictions[i][:]
            if self._end_index in list(best_predictions[i]):
                cut_hypo = best_predictions[i][:list(best_predictions[i]).
                                               index(self._end_index)]
            hypothesis.append([
                self.vocab.get_token_from_index(idx.item()) for idx in cut_hypo
            ])

        flag = 1
        for i in range(bs):
            cut_ref = target_tokens['tokens'][i][1:]
            if self._end_index in list(target_tokens['tokens'][i]):
                cut_ref = target_tokens['tokens'][i][
                    1:list(target_tokens['tokens'][i]).index(self._end_index)]
            references.append([
                self.vocab.get_token_from_index(idx.item()) for idx in cut_ref
            ])
            if random.random() <= 0.001 and flag == 1:  #not self.training and
                flag = 0
                for jj in range(i):
                    print('___hypo___', ''.join(hypothesis[jj]), end=' ## ')
                    print(''.join(references[jj]))
                    print("")

        self.bleu_aver(references, hypothesis)
        self.bleu1(references, hypothesis)
        self.bleu2(references, hypothesis)
        self.bleu4(references, hypothesis)
        self.kd_metric(references, hypothesis)
        self.dink1(hypothesis)
        self.dink2(hypothesis)
        if self.training:
            output_dict['loss'] = output_dict['loss'] + 8 * topic_loss
        else:
            output_dict['loss'] = topic_loss
        return output_dict
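The forward pass above weights the topic loss so that gold topics count six times as much as the rest. A small illustrative sketch of that weighted binary cross-entropy, with `topic_probs` and `next_sym` replaced by made-up tensors (the names mirror the code above, the values do not):

import torch

topic_probs = torch.sigmoid(torch.randn(2, 5))        # (batch_size, symp_size) sigmoid outputs
next_sym = torch.tensor([[1, 0, 0, 1, 0],
                         [0, 1, 0, 0, 0]])            # multi-hot gold topics
# Positive topics get a 6x weight (1 + 5), matching ones_like(...) + 5 * next_sym.float().
weights = torch.ones_like(topic_probs) + 5 * next_sym.float()
loss = torch.nn.functional.binary_cross_entropy(topic_probs, next_sym.float(), weight=weights)
predictions = (topic_probs > 0.5).long()
print(loss, predictions)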
Esempio n. 23
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[QuarelWorld],
                actions: List[List[ProductionRule]],
                entity_bits: torch.Tensor = None,
                denotation_target: torch.Tensor = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        """

        table_text = table['text']

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity. Need to add a small value
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                           num_entities * num_entity_tokens,
                                                                           self._embedding_dim),
                                                       torch.transpose(embedded_question, 1, 2))

                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                product = product.view(batch_size,
                                       num_question_tokens*num_entities*num_entity_tokens,
                                       self._embedding_dim)
                question_entity_similarity = self._entity_similarity_layer(product)
                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table['linking']

            if self._linking_params is not None:
                feature_scores = self._linking_params(linking_features).squeeze(3)
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                    question_mask, entity_type_dict)
            encoder_input = embedded_question
        else:
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input = torch.cat([embedded_question, entity_bits], 2)
            else:
                encoder_input = embedded_question

            # Fake linking_scores added so downstream code does not complain
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        if self._entity_bits_output and entity_bits is not None:
            encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        # For predicting a categorical denotation directly
        if self._denotation_only:
            denotation_logits = self._denotation_classifier(final_encoder_output)
            loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1))
            self._denotation_accuracy_cat(denotation_logits, denotation_target)
            return {"loss": loss}

        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim)

        _, num_entities, num_question_tokens = linking_scores.size()

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []

        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_question,
                                                 encoder_output_list,
                                                 question_mask_list))

        initial_grammar_state = [self._create_grammar_state(world[i], actions[i],
                                                            linking_scores[i], entity_types[i])
                                 for i in range(batch_size)]

        initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          extras=None,
                                          debug_info=None)

        if self.training:
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))
            return outputs

        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            if self._linking_params is not None:
                outputs['linking_scores'] = linking_scores
                outputs['feature_scores'] = feature_scores
                outputs['linking_features'] = linking_features
            if self._use_entities:
                outputs['linking_probabilities'] = linking_probabilities
            if entity_bits is not None:
                outputs['entity_bits'] = entity_bits
            # outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            outputs['denotation_acc'] = []
            outputs['score'] = []
            outputs['parse_acc'] = []
            outputs['answer_index'] = []
            if metadata is not None:
                outputs['question_tokens'] = []
                outputs['world_extractions'] = []
            for i in range(batch_size):
                if metadata is not None:
                    outputs['question_tokens'].append(metadata[i].get('question_tokens', []))
                if metadata is not None:
                    outputs['world_extractions'].append(metadata[i].get('world_extractions', {}))
                outputs['entities'].append(world[i].table_graph.entities)
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    sequence_in_targets = 0
                    if target_action_sequences is not None:
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    try:
                        self._has_logical_form(1.0)
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    denotation_accuracy = 0.0
                    predicted_answer_index = world[i].execute(logical_form)
                    if metadata is not None and 'answer_index' in metadata[i]:
                        answer_index = metadata[i]['answer_index']
                        denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index)
                        self._denotation_accuracy(denotation_accuracy)
                    score = math.exp(best_final_states[i][0].score[0].data.cpu().item())
                    outputs['answer_index'].append(predicted_answer_index)
                    outputs['score'].append(score)
                    outputs['parse_acc'].append(sequence_in_targets)
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['denotation_acc'].append(denotation_accuracy)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                else:
                    outputs['parse_acc'].append(0)
                    outputs['logical_form'].append('')
                    outputs['denotation_acc'].append(0)
                    outputs['score'].append(0)
                    outputs['answer_index'].append(-1)
                    outputs['best_action_sequence'].append([])
                    outputs['debug_info'].append([])
                    self._has_logical_form(0.0)
            return outputs
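The "dot_product" branch above normalizes the table and question embeddings and uses a single batched matrix multiplication to score every entity token against every question token, then takes a max over entity tokens. A minimal sketch of that computation in plain PyTorch, with all shapes made up for the example:

import torch

batch_size, num_entities, num_entity_tokens, num_question_tokens, dim = 2, 3, 4, 5, 8
embedded_table = torch.randn(batch_size, num_entities, num_entity_tokens, dim)
embedded_question = torch.randn(batch_size, num_question_tokens, dim)

# Normalize; the 1e-13 guards against all-zero padding rows.
embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)

similarity = torch.bmm(embedded_table.view(batch_size, num_entities * num_entity_tokens, dim),
                       embedded_question.transpose(1, 2))
similarity = similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens)

# (batch_size, num_entities, num_question_tokens)
linking_scores, _ = similarity.max(dim=2)
print(linking_scores.shape)  # torch.Size([2, 3, 5])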
Esempio n. 24
0
    def _get_initial_state_and_scores(
            self,
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesWorld],
            actions: List[List[ProductionRuleArray]],
            example_lisp_string: List[str] = None,
            add_world_to_initial_state: bool = False,
            checklist_states: List[ChecklistState] = None) -> Dict:
        """
        Does initial preparation and creates an initial state for both the semantic parsers. Note
        that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
        pass it.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text,
                                              num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities,
                                                      encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(
            encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask(
            {
                'ignored': neighbor_indices + 1
            }, num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(
            BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                              neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(
            embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.nn.functional.tanh(
            entity_type_embeddings + projected_neighbor_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                self._embedding_dim),
            torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score,
                torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(
                question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                    -1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask,
            entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())

        initial_score = embedded_question.data.new_zeros(batch_size)

        action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(
            actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(
            linking_scores, world, actions)
        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnState(final_encoder_output[i], memory_cell[i],
                         self._first_action_embedding,
                         self._first_attended_question, encoder_output_list,
                         question_mask_list))
        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i])
            for i in range(batch_size)
        ]
        initial_state_world = world if add_world_to_initial_state else None
        initial_state = WikiTablesDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            output_action_embeddings=output_action_embeddings,
            action_biases=action_biases,
            action_indices=action_indices,
            possible_actions=actions,
            flattened_linking_scores=flattened_linking_scores,
            actions_to_entities=actions_to_entities,
            entity_types=entity_type_dict,
            world=initial_state_world,
            example_lisp_string=example_lisp_string,
            checklist_state=checklist_states,
            debug_info=None)
        return {
            "initial_state": initial_state,
            "linking_scores": linking_scores,
            "feature_scores": feature_scores,
            "similarity_scores": question_entity_similarity_max_score
        }
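The neighbor handling above relies on a padding trick: neighbor indices are padded with -1 (because 0 is a legitimate index), so the gather uses abs(indices) and the mask uses indices + 1. A plain-PyTorch sketch of that trick plus the masked average the bag-of-embeddings encoder computes; this is an illustration with invented shapes, not the AllenNLP utility itself:

import torch

encoded_table = torch.randn(1, 4, 8)                                      # (batch, num_entities, dim)
neighbor_indices = torch.tensor([[[1, 2], [0, -1], [0, -1], [-1, -1]]])   # -1 marks padding

gathered = encoded_table[0][neighbor_indices.abs()[0]]   # (num_entities, num_neighbors, dim)
mask = (neighbor_indices + 1 > 0).float()                # 1 for real neighbors, 0 for padding

# Masked average over the neighbor dimension.
summed = (gathered * mask[0].unsqueeze(-1)).sum(dim=1)
counts = mask[0].sum(dim=1, keepdim=True).clamp(min=1)
neighbor_average = summed / counts                       # (num_entities, dim)
print(neighbor_average.shape)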
Esempio n. 25
0
    def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
        # https://github.com/allenai/allennlp/issues/2411
        return get_final_encoder_states(self._seq2seq(inputs, None), mask)
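This wrapper leans entirely on `get_final_encoder_states`. As a reference point, here is a plain-PyTorch sketch of what that utility does in the unidirectional case (pick each sequence's last unmasked timestep); it is an illustration of the behavior, not the library code:

import torch

encoder_outputs = torch.arange(24, dtype=torch.float).view(2, 3, 4)  # (batch, seq_len, dim)
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])                          # 0 marks padding

last_indices = mask.sum(dim=1) - 1                                   # index of the last real token
final_states = encoder_outputs[torch.arange(2), last_indices]        # (batch, dim)
print(final_states)
# tensor([[ 8.,  9., 10., 11.],
#         [16., 17., 18., 19.]])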
Esempio n. 26
0
    def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
        out = self.stacked_self_att_enc(inputs, mask)
        return get_final_encoder_states(out, mask)
    def forward(
            self,  # type: ignore
            source_tokens: Dict[str, torch.LongTensor] = None,
            target_tokens: Dict[str, torch.LongTensor] = None,
            source_tokens_raw=None,
            target_tokens_raw=None,
            predict: bool = False) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        embedded_input = self._source_embedder(source_tokens)
        batch_size, _, _ = embedded_input.size()
        source_mask = get_text_field_mask(source_tokens)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        final_encoder_output = get_final_encoder_states(
            encoder_outputs, source_mask
        )  #encoder_outputs[:, -1]  # (batch_size, encoder_output_dim)
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end symbol. Either way, we
            # don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps
        decoder_hidden = self.decode_h0_projection_layer(final_encoder_output)
        decoder_context = self.decode_h0_projection_layer(final_encoder_output)
        last_predictions = None
        step_attensions = []
        step_probabilities = []
        step_predictions = []
        step_p_gen = []
        for timestep in range(num_decoding_steps):
            if self.training and all(
                    torch.rand(1) >= self._scheduled_sampling_ratio):
                input_choices = targets[:, timestep]
            else:
                if timestep == 0:
                    # For the first timestep, when we do not have targets, we input start symbols.
                    # (batch_size,)
                    input_choices = source_mask.new().resize_(
                        batch_size).fill_(self._start_index)
                else:
                    input_choices = last_predictions
            # input_indices : (batch_size,)  since we are processing these one timestep at a time.
            # (batch_size, target_embedding_dim)
            input_choices = {'tokens': input_choices}
            decoder_input = self._target_embedder(input_choices)
            #Dh_t(S_t),Dc_t
            decoder_hidden, decoder_context = self._decoder_cell(
                decoder_input, (decoder_hidden, decoder_context))

            #cat[S_t,H*_t(short memory)]
            P_attensions, decoder_output = self._decode_step_output(
                decoder_hidden, encoder_outputs, source_mask)
            # (batch_size, num_classes)
            # W[S_t,H*_t]+b
            output_attention = self._output_attention_layer(decoder_output)
            output_projections = self._output_projection_layer(
                output_attention)
            # P_vocab
            class_probabilities = F.softmax(output_projections, dim=-1)
            # generation probability
            #P_gen = F.sigmoid(self._pointer_gen_layer(torch.cat((decoder_output,decoder_input),-1)))
            #class_probabilities = P_gen*class_probabilities
            #step_p_gen.append(P_gen.unsqueeze(1))
            #print(f'P_gen:{P_gen.data.mean()}')
            # list of (batch_size, 1, num_classes)
            step_attensions.append(P_attensions.unsqueeze(1))
            _, predicted_classes = torch.max(class_probabilities, 1)
            step_probabilities.append(class_probabilities.unsqueeze(1))
            last_predictions = predicted_classes
            # (batch_size, 1)
            step_predictions.append(last_predictions.unsqueeze(1))
        # This is (batch_size, num_decoding_steps, num_classes)
        all_attensions = torch.cat(step_attensions, 1)
        #all_p_gens = torch.cat(step_p_gen,1)
        class_probabilities = torch.cat(step_probabilities, 1)
        all_predictions = torch.cat(step_predictions, 1)
        output_dict = {
            "all_attensions": all_attensions,
            #"all_p_gens": all_p_gens,
            "source_tokens": source_tokens_raw,
            "class_probabilities": class_probabilities,
            "predictions": all_predictions
        }
        #att_dists = self._att_dists(all_predictions,all_attensions,source_tokens_raw)
        #output_dict.update({"att_dists":att_dists})
        if target_tokens:
            target_mask = get_text_field_mask(target_tokens)
            gen_loss = self._get_loss(class_probabilities, targets,
                                      target_mask)
            #copy_loss = self._get_copy_loss(all_p_gens,att_dists,target_tokens_raw)
            #copy_loss = self._get_copy_loss(att_dists,target_tokens_raw)
            #loss = gen_loss#+copy_loss
            print(f'gen_loss:{gen_loss.data.mean()}')  # ,copy_loss:{copy_loss.data.mean()}
            output_dict["loss"] = gen_loss
            for metric in self.metrics.values():
                evaluated_sentences = [
                    ''.join(i)
                    for i in self.decode(output_dict)["predicted_tokens"]
                ]
                reference_sentences = [
                    ''.join([j.text for j in i]) for i in target_tokens_raw
                ]
                #print(f'evaluated_sentences:{evaluated_sentences},reference_sentences:{reference_sentences}')
                metric(evaluated_sentences, reference_sentences)

        return output_dict
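The decoding loop above makes a scheduled-sampling decision at every timestep: with probability (1 - ratio) it feeds the gold token (teacher forcing), otherwise it feeds back the model's own previous prediction. A toy sketch of that decision; the function and variable names here are illustrative only:

import torch

def next_input(gold_token, predicted_token, scheduled_sampling_ratio, training=True):
    # Teacher forcing when the draw lands above the sampling ratio (training only).
    if training and torch.rand(1).item() >= scheduled_sampling_ratio:
        return gold_token
    return predicted_token

gold = torch.tensor([5, 7])        # (batch,) gold token ids for this timestep
pred = torch.tensor([5, 2])        # (batch,) the model's previous predictions
print(next_input(gold, pred, scheduled_sampling_ratio=0.25))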
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRuleArray]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default=None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        """

        table_text = table['text']

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.nn.functional.tanh(entity_type_embeddings + projected_neighbor_embeddings)


        # Compute entity and question word cosine similarity. Need to add a small value
        # to the table norm since there are padding values which cause a divide by 0.
        embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
        embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = question_entity_similarity_max_score + feature_scores
        else:
            # The linking score is computed as a linear projection of two terms. The first is the maximum
            # similarity score over the entity's words and the question token. The second is the maximum
            # similarity over the words in the entity's neighbors and the question token.
            #   The second term, projected_question_neighbor_similarity, is useful when
            # a column needs to be selected. For example, the question token might have no similarity
            # with the column name, but is similar with the cells in the column.
            #   Note that projected_question_neighbor_similarity is intended to capture the same information
            # as the related_column feature.
            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = Variable(encoder_outputs.data.new(batch_size, self._encoder.get_output_dim()).fill_(0))

        initial_score = Variable(embedded_question.data.new(batch_size).fill_(0))

        action_embeddings, action_indices = self._embed_actions(actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                     world,
                                                                                     actions)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnState(final_encoder_output[i],
                                              memory_cell[i],
                                              self._first_action_embedding,
                                              self._first_attended_question,
                                              encoder_output_list,
                                              question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                                 for i in range(batch_size)]
        initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                               action_history=[[] for _ in range(batch_size)],
                                               score=initial_score_list,
                                               rnn_state=initial_rnn_state,
                                               grammar_state=initial_grammar_state,
                                               action_embeddings=action_embeddings,
                                               action_indices=action_indices,
                                               possible_actions=actions,
                                               flattened_linking_scores=flattened_linking_scores,
                                               actions_to_entities=actions_to_entities,
                                               entity_types=entity_type_dict,
                                               debug_info=None)
        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            outputs['linking_scores'] = linking_scores
            if self._linking_params is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = 0
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    try:
                        self._has_logical_form(1.0)
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    if example_lisp_string:
                        self._denotation_accuracy(logical_form, example_lisp_string[i])
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                    outputs['entities'].append(world[i].table_graph.entities)
                else:
                    outputs['logical_form'].append('')
                    self._has_logical_form(0.0)
                    if example_lisp_string:
                        self._denotation_accuracy(None, example_lisp_string[i])
            return outputs
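The comment inside the forward pass above explains why the batch dimension is converted into an outer list before building decoder states: regrouping instances later becomes a simple `torch.cat` rather than an index select. A minimal sketch of that conversion with invented shapes:

import torch

encoder_outputs = torch.randn(3, 7, 16)            # (batch_size, question_length, dim)
encoder_output_list = [encoder_outputs[i] for i in range(encoder_outputs.size(0))]

# Regrouping two arbitrary instances later is just a cat over a fresh batch dimension.
regrouped = torch.cat([encoder_output_list[0].unsqueeze(0),
                       encoder_output_list[2].unsqueeze(0)], dim=0)
print(regrouped.shape)  # torch.Size([2, 7, 16])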
Esempio n. 29
0
    def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
        output_seq = self.encoder(inputs, mask)
        output_vec = get_final_encoder_states(output_seq, mask)
        return output_vec
    def _get_initial_state_and_scores(self,
                                      question: Dict[str, torch.LongTensor],
                                      table: Dict[str, torch.LongTensor],
                                      world: List[WikiTablesWorld],
                                      actions: List[List[ProductionRuleArray]],
                                      example_lisp_string: List[str] = None,
                                      add_world_to_initial_state: bool = False,
                                      checklist_states: List[ChecklistState] = None) -> Dict:
        """
        Does initial preparation and creates an initial state for both the semantic parsers. Note
        that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
        pass it.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

        initial_score = embedded_question.data.new_zeros(batch_size)

        action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                     world,
                                                                                     actions)
        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnState(final_encoder_output[i],
                                              memory_cell[i],
                                              self._first_action_embedding,
                                              self._first_attended_question,
                                              encoder_output_list,
                                              question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                                 for i in range(batch_size)]
        initial_state_world = world if add_world_to_initial_state else None
        initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                               action_history=[[] for _ in range(batch_size)],
                                               score=initial_score_list,
                                               rnn_state=initial_rnn_state,
                                               grammar_state=initial_grammar_state,
                                               action_embeddings=action_embeddings,
                                               output_action_embeddings=output_action_embeddings,
                                               action_biases=action_biases,
                                               action_indices=action_indices,
                                               possible_actions=actions,
                                               flattened_linking_scores=flattened_linking_scores,
                                               actions_to_entities=actions_to_entities,
                                               entity_types=entity_type_dict,
                                               world=initial_state_world,
                                               example_lisp_string=example_lisp_string,
                                               checklist_state=checklist_states,
                                               debug_info=None)
        return {"initial_state": initial_state,
                "linking_scores": linking_scores,
                "feature_scores": feature_scores,
                "similarity_scores": question_entity_similarity_max_score}
Example n. 31
0
    def _get_initial_rnn_and_grammar_state(
            self, question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld],
            actions: List[List[ProductionRuleArray]],
            outputs: Dict[str,
                          Any]) -> Tuple[List[RnnState], List[GrammarState]]:
        """
        Encodes the question and table, computes a linking between the two, and constructs an
        initial RnnState and GrammarState for each batch instance to pass to the decoder.

        We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
        visualize in a demo.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text,
                                              num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities,
                                                      encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(
            encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask(
            {
                'ignored': neighbor_indices + 1
            }, num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(
            BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                              neighbor_mask)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, encoded_table)

        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)
        projected_neighbor_embeddings = self._neighbor_params(
            embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings +
                                       projected_neighbor_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                self._embedding_dim),
            torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score,
                torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(
                question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                    -1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask,
            entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnState(final_encoder_output[i], memory_cell[i],
                         self._first_action_embedding,
                         self._first_attended_question, encoder_output_list,
                         question_mask_list))
        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i], linking_scores[i],
                                       entity_types[i])
            for i in range(batch_size)
        ]
        if not self.training:
            # We add a few things to the outputs that will be returned from `forward` at evaluation
            # time, for visualization in a demo.
            outputs['linking_scores'] = linking_scores
            if feature_scores is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
        return initial_rnn_state, initial_grammar_state
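The comments above describe the -1 padding convention for `neighbor_indices`. The toy sketch below shows why `abs()` makes the padded indices safe to gather and why adding 1 produces a mask with 0 for padding; it substitutes plain advanced indexing for `util.batched_index_select` and a hand-rolled masked mean for the averaged `BagOfEmbeddingsEncoder`:

import torch

# Neighbor indices padded with -1: a real index of 0 is valid, so 0 cannot mark padding.
neighbor_indices = torch.tensor([[[1, 2, -1],
                                  [0, 2, -1],
                                  [0, -1, -1]]])      # (batch_size=1, num_entities=3, num_neighbors=3)
encoded_table = torch.randn(1, 3, 4)                  # (batch_size=1, num_entities=3, embedding_dim=4)

# abs() turns the -1 padding into a harmless (but masked-out) index of 1.
batch_index = torch.arange(1).view(1, 1, 1)
gathered = encoded_table[batch_index, torch.abs(neighbor_indices)]   # (1, 3, 3, 4)

# Adding 1 makes padded entries 0, which is what the mask utilities treat as padding.
neighbor_mask = (neighbor_indices + 1 > 0).float()    # (1, 3, 3)

# Masked average over neighbors, i.e. what the averaged BagOfEmbeddingsEncoder computes.
summed = (gathered * neighbor_mask.unsqueeze(-1)).sum(dim=2)
counts = neighbor_mask.sum(dim=2, keepdim=True).clamp(min=1)
neighbor_average = summed / counts                    # (1, 3, 4)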
    def _get_initial_rnn_and_grammar_state(self,
                                           question: Dict[str, torch.LongTensor],
                                           table: Dict[str, torch.LongTensor],
                                           world: List[WikiTablesWorld],
                                           actions: List[List[ProductionRule]],
                                           outputs: Dict[str, Any]) -> Tuple[List[RnnStatelet],
                                                                             List[LambdaGrammarStatelet]]:
        """
        Encodes the question and table, computes a linking between the two, and constructs an
        initial RnnStatelet and LambdaGrammarStatelet for each batch instance to pass to the
        decoder.

        We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
        visualize in a demo.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)


        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_question,
                                                 encoder_output_list,
                                                 question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i],
                                                            actions[i],
                                                            linking_scores[i],
                                                            entity_types[i])
                                 for i in range(batch_size)]
        if not self.training:
            # We add a few things to the outputs that will be returned from `forward` at evaluation
            # time, for visualization in a demo.
            outputs['linking_scores'] = linking_scores
            if feature_scores is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
        return initial_rnn_state, initial_grammar_state
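The link-embedding step used by both versions of `_get_initial_rnn_and_grammar_state` amounts to a probability-weighted mix of entity embeddings per question token. A toy-shape sketch of that weighted sum and the subsequent concatenation (random tensors, illustrative only):

import torch

# Toy dimensions for illustration only.
batch_size, num_question_tokens, num_entities, embedding_dim = 2, 5, 3, 8
entity_embeddings = torch.randn(batch_size, num_entities, embedding_dim)
linking_probabilities = torch.softmax(torch.randn(batch_size, num_question_tokens, num_entities), dim=-1)
embedded_question = torch.randn(batch_size, num_question_tokens, embedding_dim)

# weighted_sum over entities: each question token gets a probability-weighted mix of entity embeddings.
link_embedding = torch.bmm(linking_probabilities, entity_embeddings)   # (batch_size, num_question_tokens, embedding_dim)

# The encoder then sees each token's word embedding concatenated with its entity summary.
encoder_input = torch.cat([link_embedding, embedded_question], dim=2)  # (batch_size, num_question_tokens, 2 * embedding_dim)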
def embed_and_encode_ques_contexts(text_field_embedder: TextFieldEmbedder,
                                   qencoder: Seq2SeqEncoder, batch_size: int,
                                   question: Dict[str, torch.LongTensor],
                                   contexts: Dict[str, torch.LongTensor]):
    """ Embed and Encode question and contexts

        Parameters:
        -----------
        text_field_embedder: ``TextFieldEmbedder``
        qencoder: ``Seq2SeqEncoder``
        question: Dict[str, torch.LongTensor]
            Output of a TextField. Should yield tensors of shape (B, ques_length, D)
        contexts: Dict[str, torch.LongTensor]
            Output of a TextField. Should yield tensors of shape (B, num_contexts, ques_length, D)

        Returns:
        ---------
        embedded_questions: List[(ques_length, D)]
            Batch-sized list of embedded questions from the text_field_embedder
        encoded_questions: List[(ques_length, D)]
            Batch-sized list of encoded questions from the qencoder
        questions_mask: List[(ques_length)]
            Batch-sized list of questions masks
        encoded_ques_tensor: Shape: (batch_size, ques_len, D)
            Output of the qencoder
        questions_mask_tensor: Shape: (batch_size, ques_length)
            Questions mask as a tensor
        ques_encoded_final_state: Shape: (batch_size, D)
            For each question, the final state of the qencoder
        embedded_contexts: List[(num_contexts, context_length, D)]
            Batch-sized list of embedded contexts for each instance from the text_field_embedder
        contexts_mask: List[(num_contexts, context_length)]
            Batch-sized list of contexts_mask for each context in the instance

        """
    # Shape: (B, question_length, D)
    embedded_questions_tensor = text_field_embedder(question)
    # Shape: (B, question_length)
    questions_mask_tensor = allenutil.get_text_field_mask(question).float()
    embedded_questions = [
        embedded_questions_tensor[i] for i in range(batch_size)
    ]
    questions_mask = [questions_mask_tensor[i] for i in range(batch_size)]

    # Shape: (B, ques_len, D)
    encoded_ques_tensor = qencoder(embedded_questions_tensor,
                                   questions_mask_tensor)
    # Shape: (B, D)
    ques_encoded_final_state = allenutil.get_final_encoder_states(
        encoded_ques_tensor, questions_mask_tensor,
        qencoder.is_bidirectional())
    encoded_questions = [encoded_ques_tensor[i] for i in range(batch_size)]

    # # contexts is a (B, num_contexts, context_length, *) tensors
    # (tokenindexer, indexed_tensor) = next(iter(contexts.items()))
    # num_contexts = indexed_tensor.size()[1]
    # # Making a separate batched token_indexer_dict for each context -- [{token_inderxer: (C, T, *)}]
    # contexts_indices_list: List[Dict[str, torch.LongTensor]] = [{} for _ in range(batch_size)]
    # for token_indexer_name, token_indices_tensor in contexts.items():
    #         print(f"{token_indexer_name}: {token_indices_tensor.size()}")
    #         for i in range(batch_size):
    #                 contexts_indices_list[i][token_indexer_name] = token_indices_tensor[i, ...]
    #
    # # Each tensor of shape (num_contexts, context_len, D)
    # embedded_contexts = []
    # contexts_mask = []
    # # Shape: (num_contexts, context_length, D)
    # for i in range(batch_size):
    #         embedded_contexts_i = text_field_embedder(contexts_indices_list[i])
    #         embedded_contexts.append(embedded_contexts_i)
    #         contexts_mask_i = allenutil.get_text_field_mask(contexts_indices_list[i]).float()
    #         contexts_mask.append(contexts_mask_i)

    embedded_contexts_tensor = text_field_embedder(contexts,
                                                   num_wrapping_dims=1)
    contexts_mask_tensor = allenutil.get_text_field_mask(
        contexts, num_wrapping_dims=1).float()

    embedded_contexts = [
        embedded_contexts_tensor[i] for i in range(batch_size)
    ]
    contexts_mask = [contexts_mask_tensor[i] for i in range(batch_size)]

    return (embedded_questions, encoded_questions, questions_mask,
            encoded_ques_tensor, questions_mask_tensor,
            ques_encoded_final_state, embedded_contexts, contexts_mask)
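A quick shape check of what `embed_and_encode_ques_contexts` hands back, using stand-in tensors instead of a real embedder and encoder (the dimensions are made up):

import torch

# Stand-in tensors for the embedder / encoder outputs described in the docstring above.
batch_size, num_contexts, ques_len, context_len, dim = 2, 3, 6, 9, 16
encoded_ques_tensor = torch.randn(batch_size, ques_len, dim)
embedded_contexts_tensor = torch.randn(batch_size, num_contexts, context_len, dim)
contexts_mask_tensor = torch.ones(batch_size, num_contexts, context_len)

# The helper returns batch-sized Python lists; each element drops the batch dimension.
encoded_questions = [encoded_ques_tensor[i] for i in range(batch_size)]
embedded_contexts = [embedded_contexts_tensor[i] for i in range(batch_size)]
contexts_mask = [contexts_mask_tensor[i] for i in range(batch_size)]
assert encoded_questions[0].shape == (ques_len, dim)
assert embedded_contexts[0].shape == (num_contexts, context_len, dim)
assert contexts_mask[0].shape == (num_contexts, context_len)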