Example #1
    def _prepare_decode_step_input(
        self,
        input_indices: torch.LongTensor,
        decoder_hidden_state: torch.FloatTensor = None,
        encoder_outputs: torch.FloatTensor = None,
        encoder_outputs_mask: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        """
        Given the input indices for the current timestep of the decoder, and all the encoder
        outputs, compute the input at the current timestep.  Note: This method is agnostic to
        whether the indices are gold indices or the predictions made by the decoder at the last
        timestep.

        If we're not using attention, the output of this method is just an embedding of the input
        indices.  If we are, the output will be a concatenation of the embedding and an attended
        average of the encoder inputs.

        Parameters
        ----------
        input_indices : torch.LongTensor
            Indices of either the gold inputs to the decoder or the predicted labels from the
            previous timestep.
        decoder_hidden_state : torch.FloatTensor, optional (not needed if no attention)
            Output from the decoder at the last timestep. Needed only if using attention.
        encoder_outputs : torch.FloatTensor, optional (not needed if no attention)
            Encoder outputs from all timesteps. Needed only if using attention.
        encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
            Masks on encoder outputs. Needed only if using attention.
        """
        input_indices = input_indices.long()
        # input_indices : (batch_size,)  since we are processing these one timestep at a time.
        # (batch_size, target_embedding_dim)
        embedded_input = self._target_embedder(input_indices)

        if self._decoder_attention is not None:
            # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
            # Ensure the mask is also a FloatTensor; otherwise the multiplication
            # inside the attention module will complain.

            # Important: use zero-masking instead of -inf for attention. This has
            # been checked not to significantly increase time per batch, but
            # consider doing it only once.
            encoder_outputs.data.masked_fill_(
                ~encoder_outputs_mask.bool().data, 0.0)

            encoder_outputs = 0.5 * encoder_outputs
            encoder_outputs_mask = encoder_outputs_mask.float()
            encoder_outputs_mask = encoder_outputs_mask[:, :, 0]
            attention_input = torch.cat((decoder_hidden_state, embedded_input),
                                        1)
            # (batch_size, input_sequence_length)
            input_weights = self._decoder_attention(attention_input,
                                                    encoder_outputs,
                                                    encoder_outputs_mask)
            # (batch_size, input_dim)
            attended_input = weighted_sum(encoder_outputs, input_weights)
            # (batch_size, input_dim + target_embedding_dim)
            return torch.cat((attended_input, embedded_input), -1)
        else:
            return embedded_input
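
For orientation, here is a minimal, self-contained sketch of the same idea: embed the current indices, attend over the (masked) encoder outputs, and concatenate. The dot-product attention and all dimensions below are illustrative assumptions, not the self._decoder_attention module used above.

import torch
import torch.nn.functional as F

batch_size, src_len, enc_dim, emb_dim = 2, 5, 8, 8
embedder = torch.nn.Embedding(num_embeddings=20, embedding_dim=emb_dim)

input_indices = torch.randint(0, 20, (batch_size,))           # (batch_size,)
encoder_outputs = torch.randn(batch_size, src_len, enc_dim)   # (batch_size, src_len, enc_dim)
decoder_hidden = torch.randn(batch_size, enc_dim)             # (batch_size, enc_dim)
mask = torch.ones(batch_size, src_len)                        # 1 = real token, 0 = padding

# (batch_size, emb_dim)
embedded_input = embedder(input_indices)

# Dot-product attention of the decoder state over the encoder outputs,
# with padded positions excluded before the softmax.
scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(-1)).squeeze(-1)
scores = scores.masked_fill(mask == 0, -1e20)
weights = F.softmax(scores, dim=-1)                           # (batch_size, src_len)
attended = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)

# (batch_size, enc_dim + emb_dim)
decoder_input = torch.cat((attended, embedded_input), dim=-1)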
Example #2
def convert_span_to_sequence(sequence_tensor: torch.FloatTensor,
                             spans_tensor: torch.FloatTensor,
                             span_mask: torch.LongTensor) -> torch.FloatTensor:
    """
    Re-packs span representations onto the original sequence layout: for each batch
    element, the embeddings at unmasked span positions are selected, flattened, padded,
    and reshaped to (batch_size, sequence_length, embedding_dim). The row with the most
    unmasked positions is expected to cover exactly ``sequence_tensor.size(1)`` positions;
    shorter rows are zero-padded at the end.
    """
    batch_size, num_spans, max_batch_span_width = span_mask.size()
    recovered_indices = []
    for slice_embs, m in zip(spans_tensor.view(batch_size, num_spans, max_batch_span_width, -1),
                             span_mask.bool()):
        # Keep only the embeddings at unmasked span positions, flattened per batch element.
        recovered_indices.append(torch.masked_select(slice_embs, m.unsqueeze(-1)))
    recovered_context_representation = torch.nn.utils.rnn.pad_sequence(recovered_indices, batch_first=True)
    recovered_context_representation = recovered_context_representation.view(batch_size, sequence_tensor.size(1), -1)
    return recovered_context_representation
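
A toy call, with made-up tensors, assuming the unmasked span positions of the longest row cover the full sequence length (as the final view requires):

import torch

batch_size, seq_len, num_spans, max_width, dim = 2, 4, 2, 3, 3
sequence_tensor = torch.randn(batch_size, seq_len, dim)
spans_tensor = torch.randn(batch_size, num_spans, max_width, dim)
# Per-row span widths sum to seq_len: (2 + 2) and (3 + 1).
span_mask = torch.tensor([[[1, 1, 0], [1, 1, 0]],
                          [[1, 1, 1], [1, 0, 0]]])

recovered = convert_span_to_sequence(sequence_tensor, spans_tensor, span_mask)
print(recovered.shape)  # torch.Size([2, 4, 3])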
Example #3
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:

        # Mask out padding tokens in the premise and hypothesis.
        mask_p = util.get_text_field_mask(premise)
        mask_h = util.get_text_field_mask(hypothesis)

        # Embed and encode both sides with the shared encoder.
        embedded_p = self.text_field_embedder(premise)
        encoded_p = self.encoder(embedded_p, mask_p)

        embedded_h = self.text_field_embedder(hypothesis)
        encoded_h = self.encoder(embedded_h, mask_h)

        # Project both encodings through the feedforward layer.
        fc_p, fc_h = self.feedforward(encoded_p, encoded_h)

        # L2 distance between the two projections; predict "similar" when the
        # distance is within half the margin.
        distance = F.pairwise_distance(fc_p, fc_h)
        prediction = distance < (self.margin / 2.0)
        output_dict = {'distance': distance, "prediction": prediction}

        if label is not None:
            """
            Contrastive loss function.
            Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
            """
            y = label.float()
            l1 = y * torch.pow(distance, 2) / 2.0
            l2 = (1 - y) * torch.pow(
                torch.clamp(self.margin - distance, min=0.0), 2) / 2.0
            loss = torch.mean(l1 + l2)

            self.accuracy(prediction, label.byte())

            output_dict["loss"] = loss

        return output_dict
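
The loss inside the `if label is not None` block is the standard contrastive loss; as a standalone sketch with made-up distances, labels, and margin:

import torch

margin = 1.0
distance = torch.tensor([0.2, 1.5, 0.4])   # distances for three pairs
label = torch.tensor([1.0, 0.0, 0.0])      # 1 = similar pair, 0 = dissimilar pair

# Similar pairs are pulled together; dissimilar pairs are pushed apart up to the margin.
l_similar = label * distance.pow(2) / 2.0
l_dissimilar = (1 - label) * torch.clamp(margin - distance, min=0.0).pow(2) / 2.0
loss = (l_similar + l_dissimilar).mean()
print(loss)  # tensor(0.0667)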
Example #4
    def forward(
        self,  # pylint: disable=arguments-differ
        embeddings: torch.FloatTensor,
        mask: torch.LongTensor,
        num_items_to_keep: Union[int, torch.LongTensor],
        class_scores: torch.FloatTensor = None,
        gold_labels: torch.LongTensor = None
    ) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor,
               torch.FloatTensor, torch.LongTensor]:
        """
        Extracts the top-k scoring items with respect to the scorer. We additionally return
        the indices of the top-k in their original order, not ordered by score, so that downstream
        components can rely on the original ordering (e.g., for knowing what spans are valid
        antecedents in a coreference resolution model). May use the same k for all sentences in
        the minibatch, or a different k for each.

        Parameters
        ----------
        embeddings : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
            each item in the list that we want to prune.
        mask : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_items), denoting unpadded elements of
            ``embeddings``.
        num_items_to_keep : ``Union[int, torch.LongTensor]``, required.
            If a tensor of shape (batch_size), specifies the number of items to keep for each
            individual sentence in the minibatch.
            If an int, keep the same number of items for all sentences.
        class_scores : ``torch.FloatTensor``, optional.
            Class scores to be used with the entity beam.
        gold_labels : ``torch.LongTensor``, optional.
            If in debugging mode, use gold labels to get the beam.

        Returns
        -------
        top_embeddings : ``torch.FloatTensor``
            The representations of the top-k scoring items.
            Has shape (batch_size, max_num_items_to_keep, embedding_size).
        top_mask : ``torch.LongTensor``
            The corresponding mask for ``top_embeddings``.
            Has shape (batch_size, max_num_items_to_keep).
        top_indices : ``torch.LongTensor``
            The indices of the top-k scoring items into the original ``embeddings``
            tensor. This is returned because it can be useful to retain pointers to
            the original items, if each item is being scored by multiple distinct
            scorers, for instance. Has shape (batch_size, max_num_items_to_keep).
        top_item_scores : ``torch.FloatTensor``
            The values of the top-k scoring items.
            Has shape (batch_size, max_num_items_to_keep, 1).
        num_items_to_keep : ``torch.LongTensor``
            The number of items kept for each sentence in the minibatch. Has shape (batch_size,).
        """
        # If an int was given for number of items to keep, construct tensor by repeating the value.
        if isinstance(num_items_to_keep, int):
            batch_size = mask.size(0)
            # Put the tensor on same device as the mask.
            num_items_to_keep = num_items_to_keep * torch.ones(
                [batch_size], dtype=torch.long, device=mask.device)

        mask = mask.unsqueeze(-1)
        num_items = embeddings.size(1)

        # Shape: (batch_size, num_items, 1)
        # If the entity beam is on, use the class scores; otherwise ignore them and use the scorer.
        if self._entity_beam:
            scores, _ = class_scores.max(dim=-1)
            scores = scores.unsqueeze(-1)
        # If the gold beam is on, give a score of 0 wherever the gold label is non-zero (indicating
        # a non-null label); otherwise give a large negative number.
        elif self._gold_beam:
            scores = torch.where(
                gold_labels > 0,
                torch.zeros_like(gold_labels, dtype=torch.float),
                -1e20 * torch.ones_like(gold_labels, dtype=torch.float))
            scores = scores.unsqueeze(-1)
        else:
            scores = self._scorer(embeddings)

        # If we're only keeping items that score above a given threshold, change the number of kept
        # items here.
        if self._min_score_to_keep is not None:
            num_good_items = torch.sum(scores > self._min_score_to_keep,
                                       dim=1).squeeze()
            num_items_to_keep = torch.min(num_items_to_keep, num_good_items)
        # If gold beam is on, keep the gold items.
        if self._gold_beam:
            num_items_to_keep = torch.sum(gold_labels > 0, dim=1)

        # Always keep at least one item to avoid the edge case of an empty matrix.
        max_items_to_keep = max(num_items_to_keep.max().item(), 1)

        if scores.size(-1) != 1 or scores.dim() != 3:
            raise ValueError(
                f"The scorer passed to Pruner must produce a tensor of shape"
                f"(batch_size, num_items, 1), but found shape {scores.size()}")
        # Make sure that we don't select any masked items by setting their scores to be very
        # negative.  These are logits, typically, so -1e20 should be plenty negative.
        # NOTE: ``mask`` needs to be a byte tensor here.
        scores = util.replace_masked_values(scores, mask.byte(), -1e20)

        # Shape: (batch_size, max_num_items_to_keep, 1)
        _, top_indices = scores.topk(max_items_to_keep, 1)

        # Mask based on number of items to keep for each sentence.
        # Shape: (batch_size, max_num_items_to_keep)
        top_indices_mask = util.get_mask_from_sequence_lengths(
            num_items_to_keep, max_items_to_keep)
        top_indices_mask = top_indices_mask.bool()

        # Shape: (batch_size, max_num_items_to_keep)
        top_indices = top_indices.squeeze(-1)

        # Fill all masked indices with largest "top" index for that sentence, so that all masked
        # indices will be sorted to the end.
        # Shape: (batch_size, 1)
        fill_value, _ = top_indices.max(dim=1)
        fill_value = fill_value.unsqueeze(-1)
        # Shape: (batch_size, max_num_items_to_keep)
        top_indices = torch.where(top_indices_mask, top_indices, fill_value)

        # Now we order the selected indices in increasing order with
        # respect to their indices (and hence, with respect to the
        # order they originally appeared in the ``embeddings`` tensor).
        top_indices, _ = torch.sort(top_indices, 1)

        # Shape: (batch_size * max_num_items_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select items for each element in the batch.
        flat_top_indices = util.flatten_and_batch_shift_indices(
            top_indices, num_items)

        # Shape: (batch_size, max_num_items_to_keep, embedding_size)
        top_embeddings = util.batched_index_select(embeddings, top_indices,
                                                   flat_top_indices)

        # Combine the masks on spans that are out-of-bounds, and the mask on spans that are outside
        # the top k for each sentence.
        # Shape: (batch_size, max_num_items_to_keep)
        sequence_mask = util.batched_index_select(mask, top_indices,
                                                  flat_top_indices)
        sequence_mask = sequence_mask.squeeze(-1).bool()
        top_mask = top_indices_mask & sequence_mask
        top_mask = top_mask.long()

        # Shape: (batch_size, max_num_items_to_keep, 1)
        top_scores = util.batched_index_select(scores, top_indices,
                                               flat_top_indices)

        return top_embeddings, top_mask, top_indices, top_scores, num_items_to_keep
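
Stripped of the entity-beam, gold-beam, and threshold options, the core of this method is: score, take the top k, restore the original order, and gather. A minimal sketch with plain PyTorch, where the linear scorer and all shapes are stand-ins rather than the module's actual components:

import torch

batch_size, num_items, dim, k = 2, 6, 4, 3
embeddings = torch.randn(batch_size, num_items, dim)
mask = torch.ones(batch_size, num_items)
scorer = torch.nn.Linear(dim, 1)               # stand-in for self._scorer

# (batch_size, num_items); padded items get a very negative score.
scores = scorer(embeddings).squeeze(-1)
scores = scores.masked_fill(mask == 0, -1e20)

_, top_indices = scores.topk(k, dim=1)         # (batch_size, k), ordered by score
top_indices, _ = top_indices.sort(dim=1)       # restore the original item order

# (batch_size, k, dim)
top_embeddings = embeddings.gather(
    1, top_indices.unsqueeze(-1).expand(-1, -1, dim))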