Example #1
from typing import Tuple

import torch
from torch import LongTensor, Tensor


def scatter_topk(
    src: Tensor, index: LongTensor, k: int, num_chunks=None, fill_value=None
) -> Tuple[Tensor, LongTensor, LongTensor]:
    """

    Args:
        src:
        index: must be sorted in ascending order
        k:
        num_chunks:
        fill_value:

    Returns: A 1D tensor of shape [num_chunks * k]

    """
    if src.ndimension() > 1:
        raise ValueError("Only implemented for 1D tensors")

    if num_chunks is None:
        num_chunks = index.max().item() + 1

    if fill_value is None:
        fill_value = float("NaN")

    result_values = src.new_full((num_chunks * k,), fill_value=fill_value)
    result_indexes_whole = index.new_full((num_chunks * k,), fill_value=-1)
    result_indexes_within_chunk = index.new_full((num_chunks * k,), fill_value=-1)

    chunk_sizes = (
        index.new_zeros(num_chunks)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start = 0
    for chunk_idx, chunk_size in enumerate(chunk_sizes):
        chunk = src[start : start + chunk_size]
        values, indexes = torch.topk(chunk, k=min(k, chunk_size), dim=0)

        result_values[chunk_idx * k : chunk_idx * k + len(values)] = values
        result_indexes_within_chunk[
            chunk_idx * k : chunk_idx * k + len(indexes)
        ] = indexes
        result_indexes_whole[chunk_idx * k : chunk_idx * k + len(indexes)] = (
            indexes + start
        )

        start += chunk_size

    return result_values, result_indexes_whole, result_indexes_within_chunk
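
A minimal usage sketch of scatter_topk, assuming the definition above is in scope; the input values are made up for illustration:

import torch

src = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0, 9.0])
index = torch.tensor([0, 0, 0, 1, 1, 1])   # chunk ids, already sorted
values, idx_whole, idx_within = scatter_topk(src, index, k=2)
print(values)      # tensor([4., 3., 9., 5.])
print(idx_whole)   # tensor([2, 0, 5, 4])
print(idx_within)  # tensor([2, 0, 2, 1])
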
Example #2
    def map_predictions(self, predictions: torch.LongTensor,
                        source_token_ids: torch.LongTensor,
                        meta_field: List[Dict]) -> torch.LongTensor:
        """
        Map copy indices in ``predictions`` back to target-vocabulary indices.
        :return: a (batch_size, max_length) LongTensor of target token ids.
        """
        batch_size, max_length = predictions.size()
        mapped_predictions = predictions.new_full((batch_size, max_length), fill_value=self._pad_index)
        for i in range(batch_size):
            source_tokens_to_copy = meta_field[i]['source_tokens_to_copy']
            for j in range(max_length):
                idx = predictions[i, j]
                if idx < self._num_classes:
                    mapped_predictions[i, j] = idx
                else:
                    # Copy from the source sentence.
                    source_idx = idx - self._num_classes
                    if source_idx >= len(source_tokens_to_copy):
                        tid = self._pad_index
                    else:
                        token = source_tokens_to_copy[source_idx]
                        # source_token_id = int(source_token_ids[i, source_idx])
                        # token = self.vocab.get_token_from_index(source_token_id, self._source_namespace)
                        tid = self.vocab.get_token_index(token, self._target_namespace)
                    mapped_predictions[i, j] = tid
        return mapped_predictions.long()
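
The copy convention assumed above can be illustrated with plain values: prediction ids below the vocabulary size are ordinary target-vocabulary indices, while anything at or above it points back into the source sentence. A hypothetical sketch (all names and values invented for illustration):

import torch

num_classes = 5
predictions = torch.tensor([1, 6, 3, 7])              # 6 and 7 are copy indices
source_tokens_to_copy = ["Alice", "likes", "Bob"]

for idx in predictions.tolist():
    if idx < num_classes:
        print("vocabulary id:", idx)
    else:
        print("copied token:", source_tokens_to_copy[idx - num_classes])
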
Example #3
import torch


def span_to_position_ids(span: torch.LongTensor,
                         max_length: int = None) -> torch.LongTensor:
    """Converts inclusive [start, end] spans into per-row position ids, padded with -1."""
    batch_size = span.size(0)
    max_length = max_length or get_span_max_length(span)
    position_ids = span.new_full((batch_size, max_length), fill_value=-1)

    for i, (start, end) in enumerate(span):
        positions = torch.arange(start, end + 1)
        position_ids[i, :len(positions)] = positions
    return position_ids
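
A minimal usage sketch with an explicit max_length, so the external get_span_max_length helper is not needed (values are illustrative only):

import torch

span = torch.tensor([[2, 4], [0, 1]])
print(span_to_position_ids(span, max_length=4))
# tensor([[ 2,  3,  4, -1],
#         [ 0,  1, -1, -1]])
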
Example #4
    def _input_ids_to_outputs(self, input_ids: torch.LongTensor, step: int,
                              cache: Cache) -> Tuple[torch.Tensor, Cache]:
        r"""The function is called in beam-search decoding.

        :attr:`input_ids` should be of shape ``[batch_size]``.

        Returns:
            A tuple of logits and updated cache. Logits are of shape
            ``[batch_size, vocab_size]``.
        """
        _batch_size = input_ids.size(0)
        times = input_ids.new_full((_batch_size, ), step)
        inputs = self.embedding(input_ids, times)
        return self._inputs_to_outputs(inputs, cache)
Example #5
    def greedy_predict(
        self,
        final_encoder_output: torch.LongTensor,
        target_embedder: Embedding,
        decoder_cell: GRUCell,
        output_projection_layer: Linear,
    ) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence.

        # Parameters

        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell : ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer : ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        predictions = [
            final_encoder_output.new_full((batch_size, ),
                                          fill_value=self._start_index,
                                          dtype=torch.long)
        ]
        for _ in range(num_decoding_steps):
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]
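
The same greedy loop can be run standalone with small dummy modules; the dimensions, start index, and number of steps below are made-up placeholders, not values from the original model:

import torch
import torch.nn.functional as F
from torch.nn import Embedding, GRUCell, Linear

num_classes, embedding_dim, hidden_dim, batch_size = 10, 8, 8, 2
start_index, max_decoding_steps = 0, 5

target_embedder = Embedding(num_classes, embedding_dim)
decoder_cell = GRUCell(embedding_dim, hidden_dim)
output_projection_layer = Linear(hidden_dim, num_classes)

decoder_hidden = torch.zeros(batch_size, hidden_dim)   # stands in for the encoder output
predictions = [decoder_hidden.new_full((batch_size,), start_index, dtype=torch.long)]
for _ in range(max_decoding_steps):
    decoder_input = target_embedder(predictions[-1])
    decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
    class_probabilities = F.softmax(output_projection_layer(decoder_hidden), dim=-1)
    predictions.append(class_probabilities.argmax(dim=1))

all_predictions = torch.cat([p.unsqueeze(1) for p in predictions], 1)[:, 1:]
print(all_predictions.shape)   # torch.Size([2, 5])
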
Example #6
    def greedy_predict(self,
                       final_encoder_output: torch.LongTensor,
                       target_embedder: Embedding,
                       decoder_cell: GRUCell,
                       output_projection_layer: Linear) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        predictions = [final_encoder_output.new_full(
                (batch_size,), fill_value=self._start_index, dtype=torch.long
        )]
        for _ in range(num_decoding_steps):
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]
Example #7
    def _action_to_token(self, action_tokens: torch.LongTensor,
                         draft_tokens: torch.LongTensor) -> torch.LongTensor:
        """Applies the edit actions in ``action_tokens`` to ``draft_tokens``
        to build the predicted token sequence."""
        predicted_pointer = action_tokens.new_zeros((draft_tokens.size(0), 1))
        draft_pointer = draft_tokens.new_ones((draft_tokens.size(0), 1))

        predicted_tokens = action_tokens.new_full(action_tokens.size(),
                                                  self.END)

        for act_step in action_tokens.t():
            # KEEP, DELETE, COPY, ADD (other)
            keep_mask = act_step == self.KEEP
            drop_mask = act_step == self.DROP
            add_mask = ~(keep_mask | drop_mask)

            predicted_tokens.scatter_(1, predicted_pointer,
                                      draft_tokens.gather(1, draft_pointer))
            predicted_tokens[add_mask] = predicted_tokens[add_mask].scatter(
                1, predicted_pointer[add_mask],
                act_step[add_mask].unsqueeze(1))

            draft_pointer[keep_mask | drop_mask] += 1
            predicted_pointer[~drop_mask] += 1
        return predicted_tokens
Example #8
from typing import Tuple

import torch
from torch import LongTensor, Tensor


def scatter_topk_2d_flat(
    src: Tensor, index: LongTensor, k: int, dim_size=None, fill_value=None
) -> Tuple[Tensor, Tuple[LongTensor, LongTensor], Tuple[LongTensor, LongTensor]]:
    """Finds the top k values in a 2D array partitioned along the dimension 0.

    ::

        +-----------------------+
        |          X            |
        |  X                    |
        |              X        |
        |     X                 |
        +-----------------------+
        |                       |
        |                 Y     |
        |       Y               |              +-------+
        |                       |              |X X X X|
        |                       |    top 4     +-------+
        |                       |  -------->   |X X X X|
        |                       |              +-------+
        |             Y         |              |Z Z Z Z|
        |                       |              +-------+
        |   Y                   |
        |                       |
        +-----------------------+
        |                       |
        |     Z       Z         |
        |                       |
        |        Z        Z     |
        |                       |
        +-----------------------+


    Args:
        src: 2D tensor of values.
        index: chunk id for each row of ``src``; must be sorted in ascending order.
        k: number of top values to keep per chunk.
        dim_size: number of chunks; defaults to ``index.max() + 1``.
        fill_value: value used to pad chunks with fewer than ``k`` entries; defaults to NaN.

    Returns:
        A tuple of the top values of shape ``[dim_size, k]``, their (row, column)
        indexes into ``src``, and their (row, column) indexes within each chunk.

    """
    if src.ndimension() != 2:
        raise ValueError("Only implemented for 2D tensors")

    if dim_size is None:
        dim_size = index.max().item() + 1

    if fill_value is None:
        fill_value = float("NaN")

    ncols = src.shape[1]

    result_values = src.new_full((dim_size, k), fill_value=fill_value)
    result_indexes_whole_0 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_whole_1 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_within_chunk_0 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_within_chunk_1 = index.new_full((dim_size, k), fill_value=-1)

    chunk_sizes = (
        index.new_zeros(dim_size)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start_src = 0
    for chunk_idx, chunk_size in enumerate(chunk_sizes):
        flat_chunk = src[start_src : start_src + chunk_size, :].flatten()
        flat_values, flat_indexes = torch.topk(
            flat_chunk, k=min(k, chunk_size * ncols), dim=0
        )
        result_values[chunk_idx, : len(flat_values)] = flat_values

        indexes_0 = flat_indexes // ncols
        indexes_1 = flat_indexes % ncols
        result_indexes_within_chunk_0[chunk_idx, : len(flat_indexes)] = indexes_0
        result_indexes_within_chunk_1[chunk_idx, : len(flat_indexes)] = indexes_1

        result_indexes_whole_0[chunk_idx, : len(flat_indexes)] = indexes_0 + start_src
        result_indexes_whole_1[chunk_idx, : len(flat_indexes)] = indexes_1

        start_src += chunk_size

    return (
        result_values,
        (result_indexes_whole_0, result_indexes_whole_1),
        (result_indexes_within_chunk_0, result_indexes_within_chunk_1),
    )
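
A minimal usage sketch of scatter_topk_2d_flat (illustrative values; rows 0-1 form chunk 0 and row 2 forms chunk 1):

import torch

src = torch.tensor([[1.0, 8.0, 3.0],
                    [2.0, 0.0, 7.0],
                    [9.0, 4.0, 6.0]])
index = torch.tensor([0, 0, 1])
values, (rows, cols), _ = scatter_topk_2d_flat(src, index, k=2)
print(values)   # tensor([[8., 7.], [9., 6.]])
print(rows)     # tensor([[0, 1], [2, 2]])
print(cols)     # tensor([[1, 2], [0, 2]])
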
Example #9
    def beam_search(
            self, final_encoder_output: torch.LongTensor, width: int,
            num_decoding_steps: int, target_embedder: Embedding,
            decoder_cell: GRUCell, output_projection_layer: Linear
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Uses beam search to compute the highest probability sequences for the
        ``decoder_cell`` that fit within the given ``width``.  Returns a tuple
        consisting of the sequences themselves and their log probabilities.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        width : ``int``, required
            Size of the beam.
        num_decoding_steps : ``int``, required
            Maximum sequence length.
        target_embedder : ``Embedding``, required
            Used to embed the token predicted at the previous time step.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.

        Returns
        -------
        predictions : ``torch.LongTensor``
            Tensor of shape (batch_size, width, num_decoding_steps) with the predicted indices.
        log_probabilities : ``torch.FloatTensor``
            Tensor of shape (batch_size, width) with the log probability of the
            corresponding prediction.
        """
        batch_size = final_encoder_output.size()[0]
        # List of (batch_size, width) tensors. One for each time step. Does not
        # include the start symbols, which are implicit.
        predictions = []
        # List of (batch_size, width) tensors. One for each time step. None for
        # the first.  Stores the index n for the parent prediction, i.e.
        # predictions[t-1][i][n], that it came from.
        backpointers = []

        # Calculate the first timestep. This is done outside the main loop
        # because we are going from a single decoder input (the output from the
        # encoder) to the top ``width`` decoder outputs. On the other hand,
        # within the main loop we are going from the ``width`` elements of the
        # beam to ``width``^2 candidates from which we will select the top
        # ``width`` elements for the next iteration.
        start_predictions = final_encoder_output.new_full(
            (batch_size, ), fill_value=self._start_index, dtype=torch.long)
        start_decoder_input = target_embedder(start_predictions)
        start_decoder_hidden = decoder_cell(start_decoder_input,
                                            final_encoder_output)
        start_output_projections = output_projection_layer(
            start_decoder_hidden)
        start_class_log_probabilities = F.log_softmax(start_output_projections,
                                                      dim=-1)
        start_top_log_probabilities, start_predicted_classes = start_class_log_probabilities.topk(
            width)

        # Set starting values
        # The log probabilities for the last time step. (batch_size, width)
        last_log_probabilities = start_top_log_probabilities
        # [(batch_size, width)]
        predictions.append(start_predicted_classes)
        # Set the same hidden state for each element in beam.
        # (batch_size * width, _decoder_output_dim)
        decoder_hidden = start_decoder_hidden.\
            unsqueeze(1).expand(batch_size, width, self._decoder_output_dim).\
            reshape(batch_size * width, self._decoder_output_dim)

        # Log probability tensor that mandates that the end token is selected.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        log_probs_after_end = start_class_log_probabilities.new_full(
            (batch_size * width, num_classes), float("-inf"))
        log_probs_after_end[:, self._end_index] = 0.0

        for timestep in range(num_decoding_steps - 1):
            # (batch_size * width,)
            last_predictions = predictions[-1].reshape(batch_size * width)
            decoder_input = target_embedder(last_predictions)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size * width, num_classes)
            output_projections = output_projection_layer(decoder_hidden)

            # (batch_size * width, num_classes)
            class_log_probabilities = F.log_softmax(output_projections, dim=-1)

            # (batch_size * width, num_classes)
            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
                batch_size * width, num_classes)
            # Here we are finding any beams where we predicted the end token in
            # the previous timestep and replacing the distribution with a
            # one-hot distribution, forcing the beam to predict the end token
            # this timestep as well.
            cleaned_log_probabilities = torch.where(
                last_predictions_expanded == self._end_index,
                log_probs_after_end, class_log_probabilities)

            # Note: We could consider normalizing for length here, but the
            # original implementation does not do so.

            # (batch_size * width, width), (batch_size * width, width)
            top_log_probabilities, predicted_classes = cleaned_log_probabilities.topk(
                width)
            # Here we expand the last log probabilities to (batch_size * width,
            # width) so that we can add them to the current log probs for this
            # timestep. This lets us maintain the log probability of each
            # element on the beam.
            expanded_last_log_probabilities = last_log_probabilities.\
                unsqueeze(2).\
                expand(batch_size, width, width).\
                reshape(batch_size * width, width)
            summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities

            reshaped_summed = summed_top_log_probabilities.reshape(
                batch_size, width * width)
            reshaped_predicted_classes = predicted_classes.reshape(
                batch_size, width * width)
            # Keep only the top ``width`` beam indices.
            restricted_beam_log_probs, restricted_beam_indices = reshaped_summed.topk(
                width)
            # Use the beam indices to extract the corresponding classes.
            restricted_predicted_classes = reshaped_predicted_classes.gather(
                1, restricted_beam_indices)

            last_log_probabilities = restricted_beam_log_probs
            predictions.append(restricted_predicted_classes)
            # The beam indices come from a width * width dimension where the
            # indices with a common ancestor are grouped together. Hence
            # floor-dividing by width gives the ancestor.
            backpointer = restricted_beam_indices // width
            backpointers.append(backpointer)
            # For the gather below.
            expanded_backpointer = backpointer.unsqueeze(2).expand(
                batch_size, width, self._decoder_output_dim)
            # Keep only the pieces of the hidden state corresponding to the
            # ancestors created this iteration.
            decoder_hidden = decoder_hidden.\
                    reshape(batch_size, width, self._decoder_output_dim).\
                    gather(1, expanded_backpointer).\
                    reshape(batch_size * width, self._decoder_output_dim)

        assert len(predictions) == num_decoding_steps,\
               "len(predictions) not equal to num_decoding_steps"
        assert len(backpointers) == num_decoding_steps - 1,\
               "len(backpointers) not equal to num_decoding_steps - 1"

        # Reconstruct the sequences.
        reconstructed_predictions = [
            predictions[num_decoding_steps - 1].unsqueeze(2)
        ]
        cur_backpointers = backpointers[num_decoding_steps - 2]
        for timestep in range(num_decoding_steps - 2, 0, -1):
            cur_preds = predictions[timestep].gather(
                1, cur_backpointers).unsqueeze(2)
            reconstructed_predictions.append(cur_preds)
            cur_backpointers = backpointers[timestep - 1].gather(
                1, cur_backpointers)
        final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
        reconstructed_predictions.append(final_preds)
        # We don't add the start tokens here. They are implicit.

        all_predictions = torch.cat(list(reversed(reconstructed_predictions)),
                                    2)
        return (all_predictions, last_log_probabilities)
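
The backpointer bookkeeping at the end of beam_search can be traced on tiny hand-written tensors; the values below are invented purely to illustrate the gather-based reconstruction:

import torch

# batch_size = 1, width = 2, num_decoding_steps = 3
predictions = [torch.tensor([[4, 7]]),    # step 0: top-2 first tokens
               torch.tensor([[2, 9]]),    # step 1
               torch.tensor([[5, 1]])]    # step 2
backpointers = [torch.tensor([[1, 0]]),   # step-1 beams came from beams 1 and 0
                torch.tensor([[0, 0]])]   # step-2 beams both came from beam 0

reconstructed = [predictions[-1].unsqueeze(2)]
cur_backpointers = backpointers[-1]
for t in range(len(predictions) - 2, 0, -1):
    reconstructed.append(predictions[t].gather(1, cur_backpointers).unsqueeze(2))
    cur_backpointers = backpointers[t - 1].gather(1, cur_backpointers)
reconstructed.append(predictions[0].gather(1, cur_backpointers).unsqueeze(2))

print(torch.cat(list(reversed(reconstructed)), 2))
# tensor([[[7, 2, 5],
#          [7, 2, 1]]])
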