Example #1
    def memorize(self, states: torch.Tensor,
                 actions: torch.IntTensor,
                 next_states: torch.Tensor,
                 rewards: torch.Tensor):
        """
        Memorizes a batch of exploration transitions (quadruples s, a, ns, r).
        :param states: Successive states encountered. Should have shape (number_of_states, state_dim + 1), where
                           the last column is 1 if the corresponding state is final and 0 otherwise.
        :param actions: Successive actions decided by the agent. Should be a tensor of shape
                       (number_of_states,).
        :param next_states: (number_of_states, state_dim)-shaped tensor containing the next states.
        :param rewards: (number_of_states,)-shaped 1D tensor containing the rewards for
                             the episode.
        """
        if len(states.size()) + len(actions.size()) + len(next_states.size()) != 5:
            raise ValueError("Wrong dimensions: expected 2D states, 1D actions and 2D next_states")

        # Make sure the tensors are on the right device. Tensor.to is not
        # in-place, so the results must be reassigned.
        states = states.to(self.device)
        next_states = next_states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)

        if self.need_init:
            self.state_mem = states
            self.action_mem = actions.type(torch.int64)
            self.nstate_mem = next_states
            self.reward_mem = rewards
            self.need_init = False
        else:
            self.state_mem = torch.cat((self.state_mem, states), dim=0)
            self.action_mem = torch.cat((self.action_mem, actions.type(torch.int64)))
            self.nstate_mem = torch.cat((self.nstate_mem, next_states), dim=0)
            self.reward_mem = torch.cat((self.reward_mem, rewards))
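
A minimal usage sketch follows. The ReplayMemory name, its constructor, and the dimensions are hypothetical; only the tensor shapes come from the docstring above, and need_init is assumed to start out True:

import torch

# Hypothetical owner class exposing memorize(); `device` and `need_init`
# are assumed to be initialised in its constructor.
memory = ReplayMemory(device=torch.device("cpu"))

batch, state_dim = 4, 8
states = torch.zeros(batch, state_dim + 1)    # last column flags final states
states[-1, -1] = 1.0                          # mark the last state as final
actions = torch.randint(0, 3, (batch,), dtype=torch.int32)
next_states = torch.zeros(batch, state_dim)
rewards = torch.ones(batch)

memory.memorize(states, actions, next_states, rewards)  # first call initialises the buffers
memory.memorize(states, actions, next_states, rewards)  # later calls concatenate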
Example #2
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """

        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Broadcastable copy of the question mask with shape
        # (batch_size, 1, question_length); unsqueeze is out of place, so
        # question_mask itself is left unchanged.
        question_mask_temp = question_mask.unsqueeze(1)

        # New question self-attention (SA) encoding.  The SA layers are
        # sharded across hard-coded GPU devices (cuda:0-2).
        device0 = torch.device('cuda:0')
        device1 = torch.device('cuda:1')
        device2 = torch.device('cuda:2')
        sa_question_sim = self._dropout(
            self._sa_matrix_attention(device1, embedded_question, None, True))
        sa_question_att = util.masked_softmax(sa_question_sim.to(device1),
                                              question_mask_temp.to(device1))
        sa_encoded_question = util.weighted_sum(encoded_question.to(device1),
                                                sa_question_att.to(device1))

        sa_encoded_question = sa_encoded_question.to(device0)

        sa_passage_sim = self._dropout(
            self._sa_matrix_attention(device2, embedded_passage, None, True))
        sa_passage_att = util.masked_softmax(
            sa_passage_sim.to(device1),
            passage_mask.clone().unsqueeze_(1).to(device1))
        sa_encoded_passage = util.weighted_sum(encoded_passage.to(device1),
                                               sa_passage_att.to(device1))
        sa_encoded_passage = sa_encoded_passage.to(device0)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        sa_passage_question_similarity = self._l_matrix_attention(
            device1, sa_encoded_passage, sa_encoded_question, False)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask_temp)
        sa_passage_question_attention = util.masked_softmax(
            sa_passage_question_similarity.to(device0), question_mask_temp)

        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)
        sa_passage_question_vectors = util.weighted_sum(
            sa_encoded_question.to(device0), sa_passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask_temp, -1e7)
        sa_masked_similarity = util.replace_masked_values(
            sa_passage_question_similarity.to(device0), question_mask_temp, -1e7)

        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        sa_question_passage_similarity = sa_masked_similarity.max(
            dim=-1)[0].squeeze(-1)

        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        sa_question_passage_attention = util.masked_softmax(
            sa_question_passage_similarity, passage_mask)

        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        sa_question_passage_vector = util.weighted_sum(
            sa_encoded_passage, sa_question_passage_attention)

        # Shape: (batch_size, passage_length, encoding_dim)
        #print("Shape:",question_passage_vector.size(),question_passage_vector.unsqueeze(1).size())
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)
        #print("Shape:",sa_question_passage_vector.size(),sa_question_passage_vector.unsqueeze(1).size())
        sa_tiled_question_passage_vector = sa_question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        #print("Shape of SA Encoded:",sa_encoded_question.size(),sa_encoded_passage.size())
        #print("Required Shape of Encoded Passage:",encoded_passage.size(),passage_question_vectors.size())

        #sa_passage_question_vectors = passage_question_vectors
        #sa_tiled_question_passage_vector = tiled_question_passage_vector

        # Shape: (batch_size, passage_length, encoding_dim * 4 + sa_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, sa_encoded_passage, passage_question_vectors,
            sa_passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector,
            sa_encoded_passage * sa_passage_question_vectors,
            sa_encoded_passage * sa_tiled_question_passage_vector
        ],
                                         dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + sa_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)

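        # The logits are zero-padded to a fixed length of 1000 (this assumes
        # passage_length <= 1000) so that a fixed-size "_w_na" linear layer
        # can emit one extra no-answer logit; after slicing below, index 0 is
        # the NA logit and indices 1..passage_length realign with tokens.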
        # Shape: (batch_size, 1000)
        span_start_logits_pad = (0, 1000 - passage_length, 0, 0)
        span_start_logits_w_na = self._span_start_predictor_w_na(
            pad(span_start_logits, span_start_logits_pad)).squeeze(-1)

        # Shape: (batch_size, passage_length + 1)
        span_start_logits_w_na = span_start_logits_w_na[:, :passage_length + 1]

        span_start_na_logits = span_start_logits_w_na[:, 0]
        span_start_logits = span_start_logits_w_na[:, 1:]

        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + sa_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + sa_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)

        # Shape: (batch_size, 1000)
        span_end_logits_pad = (0, 1000 - passage_length, 0, 0)
        span_end_logits_w_na = self._span_end_predictor_w_na(
            pad(span_end_logits, span_end_logits_pad)).squeeze(-1)

        # Shape: (batch_size, passage_length + 1)
        span_end_logits_w_na = span_end_logits_w_na[:, :passage_length + 1]

        span_end_na_logits = span_end_logits_w_na[:, 0]
        span_end_logits = span_end_logits_w_na[:, 1:]

        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
            # "na_logits": na_logits,
            # "na_probs": na_probs
        }

        # Compute the loss for training.
        if span_start is not None:
            # A gold start index of -1 marks an unanswerable question.  These
            # quantities require gold spans, so they are computed here rather
            # than at inference time, when span_start is None.
            na_gt = (span_start == -1).type(torch.cuda.LongTensor)
            na_inv = (1.0 - na_gt)
            # The "_w_na" logits prepend a no-answer slot at position 0, so
            # gold indices shift up by one (-1 maps to class 0) and the mask
            # gains a leading column of ones.
            y_start = span_start + 1
            y_end = span_end + 1
            passage_mask_w_na = torch.cat([
                torch.ones([batch_size, 1]).type(torch.cuda.FloatTensor),
                passage_mask
            ], -1)
            loss = 0.0

            # calculate loss if there is answer
            # loss for start
            preds_start = util.masked_log_softmax(
                span_start_logits_w_na.type(torch.cuda.FloatTensor),
                passage_mask_w_na.type(torch.cuda.FloatTensor)).type(
                    torch.cuda.FloatTensor)
            y_start = y_start.squeeze(-1).type(torch.cuda.LongTensor)
            loss += nll_loss(preds_start, y_start)

            # accuracy for start
            acc_p_start = na_inv.type(
                torch.cuda.FloatTensor) * span_start_logits.type(
                    torch.cuda.FloatTensor)
            acc_y_start = na_inv.squeeze(-1).type(
                torch.cuda.FloatTensor) * span_start.squeeze(-1).type(
                    torch.cuda.FloatTensor)
            self._span_start_accuracy(acc_p_start, acc_y_start)

            # loss for end
            preds_end = util.masked_log_softmax(
                span_end_logits_w_na.type(torch.cuda.FloatTensor),
                passage_mask_w_na.type(torch.cuda.FloatTensor)).type(
                    torch.cuda.FloatTensor)
            y_end = y_end.squeeze(-1).type(torch.cuda.LongTensor)
            loss += nll_loss(preds_end, y_end)

            # accuracy for end
            acc_p_end = na_inv.type(
                torch.cuda.FloatTensor) * span_end_logits.type(
                    torch.cuda.FloatTensor)
            acc_y_end = na_inv.squeeze(-1).type(
                torch.cuda.FloatTensor) * span_end.squeeze(-1).type(
                    torch.cuda.FloatTensor)
            self._span_end_accuracy(acc_p_end, acc_y_end)

            # accuracy for span
            acc_p = na_inv.type(torch.cuda.FloatTensor) * best_span.type(
                torch.cuda.FloatTensor)
            acc_y = na_inv.type(torch.cuda.FloatTensor) * torch.cat([
                span_start.type(torch.cuda.FloatTensor),
                span_end.type(torch.cuda.FloatTensor)
            ], -1)
            self._span_accuracy(acc_p, acc_y)

            output_dict["loss"] = loss

            preds_start = util.masked_softmax(
                span_start_logits_w_na.type(torch.cuda.FloatTensor),
                passage_mask_w_na.type(torch.cuda.FloatTensor)).type(
                    torch.cuda.FloatTensor)
            preds_end = util.masked_softmax(
                span_end_logits_w_na.type(torch.cuda.FloatTensor),
                passage_mask_w_na.type(torch.cuda.FloatTensor)).type(
                    torch.cuda.FloatTensor)

            output_dict["na_logits"] = preds_start[:, 0] * preds_end[:, 0]
            output_dict["na_probs"] = torch.stack(
                [1.0 - output_dict["na_logits"], output_dict["na_logits"]], -1)
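            # Note: despite its name, "na_logits" holds a probability: the
            # product of the mass each head assigns to the NA slot.
            # "na_probs" stacks [P(answerable), P(no answer)] per instance
            # for the accuracy metric below.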

            # accuracy for answer existence
            self._na_accuracy(
                output_dict["na_probs"].type(torch.cuda.FloatTensor),
                na_gt.squeeze(-1).type(torch.cuda.FloatTensor))

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
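
get_best_span is referenced above but not shown. Assuming it follows the standard BiDAF constrained decoding (a single left-to-right pass that pairs each end position with the best start seen so far, enforcing start <= end), a minimal sketch might look like the following; it is an illustration under that assumption, not the model's actual method:

import torch

def get_best_span(span_start_logits: torch.Tensor,
                  span_end_logits: torch.Tensor) -> torch.Tensor:
    """Return (batch_size, 2) inclusive token indices that maximise
    start_logit + end_logit subject to start <= end."""
    batch_size, passage_length = span_start_logits.size()
    best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                 dtype=torch.long)
    start = span_start_logits.detach().cpu().numpy()
    end = span_end_logits.detach().cpu().numpy()
    for b in range(batch_size):
        best_start = 0                # argmax of start logits seen so far
        max_span_score = float("-inf")
        for j in range(passage_length):
            if start[b, j] > start[b, best_start]:
                best_start = j        # a better start, still <= j
            score = start[b, best_start] + end[b, j]
            if score > max_span_score:
                best_word_span[b, 0] = best_start
                best_word_span[b, 1] = j
                max_span_score = score
    return best_word_span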