from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, LongTensor


def scatter_topk(
    src: Tensor, index: LongTensor, k: int, num_chunks=None, fill_value=None
) -> Tuple[Tensor, LongTensor, LongTensor]:
    """Finds the top-k values of a 1D tensor within each chunk defined by ``index``.

    Args:
        src: 1D tensor of values to select from.
        index: chunk assignment for each entry of ``src``; must be sorted in
            ascending order.
        k: number of values to keep per chunk.
        num_chunks: number of chunks; defaults to ``index.max() + 1``.
        fill_value: value used to pad chunks with fewer than ``k`` entries;
            defaults to NaN.

    Returns:
        A tuple of three 1D tensors of shape ``[num_chunks * k]``: the top-k
        values per chunk, their indexes into ``src``, and their indexes within
        each chunk. Missing entries are padded with ``fill_value`` (values)
        or ``-1`` (indexes).
    """
    if src.ndimension() > 1:
        raise ValueError("Only implemented for 1D tensors")
    if num_chunks is None:
        num_chunks = index.max().item() + 1
    if fill_value is None:
        fill_value = float("NaN")

    result_values = src.new_full((num_chunks * k,), fill_value=fill_value)
    result_indexes_whole = index.new_full((num_chunks * k,), fill_value=-1)
    result_indexes_within_chunk = index.new_full((num_chunks * k,), fill_value=-1)

    # Count how many entries fall into each chunk.
    chunk_sizes = (
        index.new_zeros(num_chunks)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start = 0
    for chunk_idx, chunk_size in enumerate(chunk_sizes):
        chunk = src[start : start + chunk_size]
        values, indexes = torch.topk(chunk, k=min(k, chunk_size), dim=0)

        result_values[chunk_idx * k : chunk_idx * k + len(values)] = values
        result_indexes_within_chunk[
            chunk_idx * k : chunk_idx * k + len(indexes)
        ] = indexes
        result_indexes_whole[chunk_idx * k : chunk_idx * k + len(indexes)] = (
            indexes + start
        )

        start += chunk_size

    return result_values, result_indexes_whole, result_indexes_within_chunk
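
# A minimal usage sketch (illustrative names and values, not from the source):
# two chunks of sizes 3 and 2, keeping the top k=2 values per chunk.
example_src = torch.tensor([0.1, 0.9, 0.5, 0.3, 0.7])
example_index = torch.tensor([0, 0, 0, 1, 1])
values, idx_whole, idx_within = scatter_topk(example_src, example_index, k=2)
# values     -> tensor([0.9000, 0.5000, 0.7000, 0.3000])
# idx_whole  -> tensor([1, 2, 4, 3])
# idx_within -> tensor([1, 2, 1, 0])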
def map_predictions(
    self,
    predictions: torch.LongTensor,
    source_token_ids: torch.LongTensor,
    meta_field: List[Dict],
) -> torch.LongTensor:
    """Maps predicted indices to target-vocabulary indices.

    Indices below ``self._num_classes`` are ordinary vocabulary predictions
    and are kept as-is; indices at or above it are copy offsets into the
    source tokens and are resolved through the target vocabulary.
    """
    batch_size, max_length = predictions.size()
    mapped_predictions = predictions.new_full(
        (batch_size, max_length), fill_value=self._pad_index
    )
    for i in range(batch_size):
        source_tokens_to_copy = meta_field[i]["source_tokens_to_copy"]
        for j in range(max_length):
            idx = predictions[i, j]
            if idx < self._num_classes:
                mapped_predictions[i, j] = idx
            else:
                # Copy: the offset past ``_num_classes`` points into the source.
                source_idx = idx - self._num_classes
                if source_idx >= len(source_tokens_to_copy):
                    tid = self._pad_index
                else:
                    token = source_tokens_to_copy[source_idx]
                    # Alternatively, the token could be recovered through the
                    # source namespace:
                    # source_token_id = int(source_token_ids[i, source_idx])
                    # token = self.vocab.get_token_from_index(source_token_id, self._source_namespace)
                    tid = self.vocab.get_token_index(token, self._target_namespace)
                mapped_predictions[i, j] = tid
    return mapped_predictions.long()
def span_to_position_ids(span: torch.LongTensor, max_length: int = None) -> torch.LongTensor:
    batch_size = span.size(0)
    max_length = max_length or get_span_max_length(span)
    # Positions beyond each span are padded with -1.
    position_ids = span.new_full((batch_size, max_length), fill_value=-1)
    for i, (start, end) in enumerate(span):
        positions = torch.arange(start, end + 1)  # spans include both endpoints
        position_ids[i, : len(positions)] = positions
    return position_ids
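
# A minimal usage sketch (illustrative values; ``max_length`` passed explicitly
# so the ``get_span_max_length`` helper is not needed):
example_spans = torch.tensor([[2, 4], [0, 1]])
span_to_position_ids(example_spans, max_length=4)
# -> tensor([[ 2,  3,  4, -1],
#            [ 0,  1, -1, -1]])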
def _input_ids_to_outputs(
    self, input_ids: torch.LongTensor, step: int, cache: Cache
) -> Tuple[torch.Tensor, Cache]:
    r"""The function is called in beam-search decoding.

    :attr:`input_ids` should be of shape ``[batch_size]``.

    Returns:
        A tuple of logits and updated cache. Logits are of shape
        ``[batch_size, vocab_size]``.
    """
    _batch_size = input_ids.size(0)
    # Every position in the batch is at the same decoding step.
    times = input_ids.new_full((_batch_size,), step)
    inputs = self.embedding(input_ids, times)
    return self._inputs_to_outputs(inputs, cache)
def greedy_predict(
    self,
    final_encoder_output: torch.LongTensor,
    target_embedder: Embedding,
    decoder_cell: GRUCell,
    output_projection_layer: Linear,
) -> torch.Tensor:
    """
    Greedily produces a sequence using the provided ``decoder_cell``.
    Returns the predicted sequence.

    Parameters
    ----------
    final_encoder_output : ``torch.LongTensor``, required
        Vector produced by ``self._encoder``.
    target_embedder : ``Embedding``, required
        Used to embed the target tokens.
    decoder_cell : ``GRUCell``, required
        The recurrent cell used at each time step.
    output_projection_layer : ``Linear``, required
        Linear layer mapping to the desired number of classes.
    """
    num_decoding_steps = self._max_decoding_steps
    decoder_hidden = final_encoder_output
    batch_size = final_encoder_output.size()[0]
    predictions = [
        final_encoder_output.new_full(
            (batch_size,), fill_value=self._start_index, dtype=torch.long
        )
    ]
    for _ in range(num_decoding_steps):
        input_choices = predictions[-1]
        decoder_input = target_embedder(input_choices)
        decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
        # (batch_size, num_classes)
        output_projections = output_projection_layer(decoder_hidden)
        class_probabilities = F.softmax(output_projections, dim=-1)
        _, predicted_classes = torch.max(class_probabilities, 1)
        predictions.append(predicted_classes)
    all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
    # Drop start symbol and return.
    return all_predictions[:, 1:]
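
# A minimal usage sketch (stand-in names, not from the source): greedy_predict
# only reads ``_max_decoding_steps`` and ``_start_index`` from ``self``, so a
# SimpleNamespace suffices for a smoke test with plain torch.nn modules.
from types import SimpleNamespace
from torch.nn import Embedding, GRUCell, Linear

dummy_model = SimpleNamespace(_max_decoding_steps=5, _start_index=1)
embedder, cell, proj = Embedding(10, 8), GRUCell(8, 8), Linear(8, 10)
encoder_output = torch.randn(2, 8)  # (batch_size, decoder_hidden_size)
preds = greedy_predict(dummy_model, encoder_output, embedder, cell, proj)
# preds.shape -> torch.Size([2, 5]); the start symbol is already dropped.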
def _action_to_token(
    self, action_tokens: torch.LongTensor, draft_tokens: torch.LongTensor
) -> torch.LongTensor:
    # Write pointer into the output and read pointer into the draft. The read
    # pointer starts at 1 to skip the draft's start token.
    predicted_pointer = action_tokens.new_zeros((draft_tokens.size(0), 1))
    draft_pointer = draft_tokens.new_ones((draft_tokens.size(0), 1))
    predicted_tokens = action_tokens.new_full(action_tokens.size(), self.END)
    for act_step in action_tokens.t():
        # KEEP, DELETE, COPY, ADD (other)
        keep_mask = act_step == self.KEEP
        drop_mask = act_step == self.DROP
        add_mask = ~(keep_mask | drop_mask)
        # Tentatively copy the current draft token everywhere. For ADD rows it
        # is then overwritten with the action token itself; for DROP rows the
        # write pointer does not advance, so the next step overwrites it.
        predicted_tokens.scatter_(1, predicted_pointer, draft_tokens.gather(1, draft_pointer))
        predicted_tokens[add_mask] = predicted_tokens[add_mask].scatter(
            1, predicted_pointer[add_mask], act_step[add_mask].unsqueeze(1))
        draft_pointer[keep_mask | drop_mask] += 1
        predicted_pointer[~drop_mask] += 1
    return predicted_tokens
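
# A toy walkthrough (a sketch, not from the source: the tag ids, the
# SimpleNamespace stand-in for ``self``, and the assumption that draft
# position 0 holds a start token are all illustrative). With KEEP=0, DROP=1,
# END=2, any other id is a token to insert.
from types import SimpleNamespace

editor = SimpleNamespace(KEEP=0, DROP=1, END=2)
actions = torch.tensor([[0, 1, 7, 0]])    # KEEP a, DROP b, ADD 7, KEEP c
draft = torch.tensor([[9, 4, 5, 6]])      # [START, a=4, b=5, c=6]
_action_to_token(editor, actions, draft)  # -> tensor([[4, 7, 6, 2]])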
def scatter_topk_2d_flat(
    src: Tensor, index: LongTensor, k: int, dim_size=None, fill_value=None
) -> Tuple[Tensor, Tuple[LongTensor, LongTensor], Tuple[LongTensor, LongTensor]]:
    """Finds the top k values in a 2D array partitioned along the dimension 0.

    ::

        +-----------------------+
        |        X              |
        |   X                   |
        |             X         |
        |      X                |
        +-----------------------+
        |                       |
        |      Y                |
        |             Y         |               +-------+
        |                       |               |X X X X|
        |                       |    top 4      +-------+
        |                       |  ---------->  |Y Y Y Y|
        |   Y                   |               +-------+
        |                       |               |Z Z Z Z|
        |           Y           |               +-------+
        |                       |
        +-----------------------+
        |             Z   Z     |
        |   Z       Z           |
        +-----------------------+

    Args:
        src: 2D tensor of values to select from.
        index: chunk assignment for each row of ``src``; must be sorted in
            ascending order.
        k: number of values to keep per chunk.
        dim_size: number of chunks; defaults to ``index.max() + 1``.
        fill_value: value used to pad chunks with fewer than ``k`` values;
            defaults to NaN.

    Returns:
        A tuple of the top-k values per chunk, of shape ``[dim_size, k]``,
        the ``(row, column)`` indexes of those values in ``src``, and their
        ``(row, column)`` indexes within each chunk. Missing entries are
        padded with ``fill_value`` (values) or ``-1`` (indexes).
    """
    if src.ndimension() != 2:
        raise ValueError("Only implemented for 2D tensors")
    if dim_size is None:
        dim_size = index.max().item() + 1
    if fill_value is None:
        fill_value = float("NaN")

    ncols = src.shape[1]
    result_values = src.new_full((dim_size, k), fill_value=fill_value)
    result_indexes_whole_0 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_whole_1 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_within_chunk_0 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_within_chunk_1 = index.new_full((dim_size, k), fill_value=-1)

    # Count how many rows fall into each chunk.
    chunk_sizes = (
        index.new_zeros(dim_size)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start_src = 0
    for chunk_idx, chunk_size in enumerate(chunk_sizes):
        flat_chunk = src[start_src : start_src + chunk_size, :].flatten()
        flat_values, flat_indexes = torch.topk(
            flat_chunk, k=min(k, chunk_size * ncols), dim=0
        )
        result_values[chunk_idx, : len(flat_values)] = flat_values

        # Unflatten the indexes: floor division recovers the row within the
        # chunk, the remainder recovers the column.
        indexes_0 = flat_indexes // ncols
        indexes_1 = flat_indexes % ncols

        result_indexes_within_chunk_0[chunk_idx, : len(flat_indexes)] = indexes_0
        result_indexes_within_chunk_1[chunk_idx, : len(flat_indexes)] = indexes_1

        result_indexes_whole_0[chunk_idx, : len(flat_indexes)] = indexes_0 + start_src
        result_indexes_whole_1[chunk_idx, : len(flat_indexes)] = indexes_1

        start_src += chunk_size

    return (
        result_values,
        (result_indexes_whole_0, result_indexes_whole_1),
        (result_indexes_within_chunk_0, result_indexes_within_chunk_1),
    )
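
# A minimal usage sketch (illustrative values): two chunks of 2 and 1 rows,
# keeping the top k=3 values per chunk across all of that chunk's cells.
example_src = torch.tensor([[0.1, 0.8], [0.4, 0.2], [0.9, 0.3]])
example_index = torch.tensor([0, 0, 1])
values, whole, within = scatter_topk_2d_flat(example_src, example_index, k=3)
# values -> tensor([[0.8000, 0.4000, 0.2000],
#                   [0.9000, 0.3000,    nan]])
# whole  -> rows tensor([[0, 1, 1], [2, 2, -1]]),
#           cols tensor([[1, 0, 1], [0, 1, -1]])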
def beam_search(
    self,
    final_encoder_output: torch.LongTensor,
    width: int,
    num_decoding_steps: int,
    target_embedder: Embedding,
    decoder_cell: GRUCell,
    output_projection_layer: Linear,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Uses beam search to compute the highest probability sequences for the
    ``decoder_cell`` that fit within the given ``width``. Returns the tuple
    consisting of the sequences themselves and their log probabilities.

    Parameters
    ----------
    final_encoder_output : ``torch.LongTensor``, required
        Vector produced by ``self._encoder``.
    width : ``int``, required
        Size of the beam.
    num_decoding_steps : ``int``, required
        Maximum sequence length.
    target_embedder : ``Embedding``, required
        Used to embed the token predicted at the previous time step.
    decoder_cell : ``GRUCell``, required
        The recurrent cell used at each time step.
    output_projection_layer : ``Linear``, required
        Linear layer mapping to the desired number of classes.

    Returns
    -------
    predictions : ``torch.LongTensor``
        Tensor of shape (batch_size, width, num_decoding_steps) with the
        predicted indices.
    log_probabilities : ``torch.FloatTensor``
        Tensor of shape (batch_size, width) with the log probability of the
        corresponding prediction.
    """
    batch_size = final_encoder_output.size()[0]

    # List of (batch_size, width) tensors. One for each time step. Does not
    # include the start symbols, which are implicit.
    predictions = []

    # List of (batch_size, width) tensors. One for each time step. None for
    # the first. Stores the index n for the parent prediction, i.e.
    # predictions[t-1][i][n], that it came from.
    backpointers = []

    # Calculate the first timestep. This is done outside the main loop
    # because we are going from a single decoder input (the output from the
    # encoder) to the top ``width`` decoder outputs. On the other hand,
    # within the main loop we are going from the ``width`` elements of the
    # beam to ``width``^2 candidates from which we will select the top
    # ``width`` elements for the next iteration.
    start_predictions = final_encoder_output.new_full(
        (batch_size,), fill_value=self._start_index, dtype=torch.long
    )
    start_decoder_input = target_embedder(start_predictions)
    start_decoder_hidden = decoder_cell(start_decoder_input, final_encoder_output)
    start_output_projections = output_projection_layer(start_decoder_hidden)
    start_class_log_probabilities = F.log_softmax(start_output_projections, dim=-1)
    start_top_log_probabilities, start_predicted_classes = start_class_log_probabilities.topk(width)

    # Set starting values.
    # The log probabilities for the last time step. (batch_size, width)
    last_log_probabilities = start_top_log_probabilities
    # [(batch_size, width)]
    predictions.append(start_predicted_classes)

    # Set the same hidden state for each element in beam.
    # (batch_size * width, _decoder_output_dim)
    decoder_hidden = (
        start_decoder_hidden.unsqueeze(1)
        .expand(batch_size, width, self._decoder_output_dim)
        .reshape(batch_size * width, self._decoder_output_dim)
    )

    # Log probability tensor that mandates that the end token is selected.
    num_classes = self.vocab.get_vocab_size(self._target_namespace)
    log_probs_after_end = start_class_log_probabilities.new_full(
        (batch_size * width, num_classes), float("-inf")
    )
    log_probs_after_end[:, self._end_index] = 0.0

    for timestep in range(num_decoding_steps - 1):
        # (batch_size * width,)
        last_predictions = predictions[-1].reshape(batch_size * width)
        decoder_input = target_embedder(last_predictions)
        decoder_hidden = decoder_cell(decoder_input, decoder_hidden)

        # (batch_size * width, num_classes)
        output_projections = output_projection_layer(decoder_hidden)

        # (batch_size * width, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        # (batch_size * width, num_classes)
        last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
            batch_size * width, num_classes
        )

        # Here we are finding any beams where we predicted the end token in
        # the previous timestep and replacing the distribution with a
        # one-hot distribution, forcing the beam to predict the end token
        # this timestep as well.
        cleaned_log_probabilities = torch.where(
            last_predictions_expanded == self._end_index,
            log_probs_after_end,
            class_log_probabilities,
        )

        # Note: We could consider normalizing for length here, but the
        # original implementation does not do so.

        # (batch_size * width, width), (batch_size * width, width)
        top_log_probabilities, predicted_classes = cleaned_log_probabilities.topk(width)

        # Here we expand the last log probabilities to (batch_size * width,
        # width) so that we can add them to the current log probs for this
        # timestep. This lets us maintain the log probability of each
        # element on the beam.
        expanded_last_log_probabilities = (
            last_log_probabilities.unsqueeze(2)
            .expand(batch_size, width, width)
            .reshape(batch_size * width, width)
        )
        summed_top_log_probabilities = (
            top_log_probabilities + expanded_last_log_probabilities
        )

        reshaped_summed = summed_top_log_probabilities.reshape(batch_size, width * width)
        reshaped_predicted_classes = predicted_classes.reshape(batch_size, width * width)

        # Keep only the top ``width`` beam indices.
        restricted_beam_log_probs, restricted_beam_indices = reshaped_summed.topk(width)

        # Use the beam indices to extract the corresponding classes.
        restricted_predicted_classes = reshaped_predicted_classes.gather(
            1, restricted_beam_indices
        )

        last_log_probabilities = restricted_beam_log_probs

        predictions.append(restricted_predicted_classes)

        # The beam indices come from a width * width dimension where the
        # indices with a common ancestor are grouped together. Hence floor
        # division by width gives the ancestor.
        backpointer = restricted_beam_indices // width

        backpointers.append(backpointer)

        # For the gather below.
        expanded_backpointer = backpointer.unsqueeze(2).expand(
            batch_size, width, self._decoder_output_dim
        )

        # Keep only the pieces of the hidden state corresponding to the
        # ancestors created this iteration.
        decoder_hidden = (
            decoder_hidden.reshape(batch_size, width, self._decoder_output_dim)
            .gather(1, expanded_backpointer)
            .reshape(batch_size * width, self._decoder_output_dim)
        )

    assert len(predictions) == num_decoding_steps, (
        "len(predictions) not equal to num_decoding_steps"
    )
    assert len(backpointers) == num_decoding_steps - 1, (
        "len(backpointers) not equal to num_decoding_steps - 1"
    )

    # Reconstruct the sequences.
    reconstructed_predictions = [predictions[num_decoding_steps - 1].unsqueeze(2)]
    cur_backpointers = backpointers[num_decoding_steps - 2]
    for timestep in range(num_decoding_steps - 2, 0, -1):
        cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
        reconstructed_predictions.append(cur_preds)
        cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
    final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
    reconstructed_predictions.append(final_preds)

    # We don't add the start tokens here. They are implicit.
    all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)

    return (all_predictions, last_log_probabilities)
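
# A toy walkthrough of the backpointer reconstruction above (a standalone
# sketch with illustrative values, not from the source): batch_size=1,
# width=2, num_decoding_steps=3.
toy_predictions = [torch.tensor([[5, 7]]),   # step 0: top-2 first tokens
                   torch.tensor([[2, 9]]),   # step 1
                   torch.tensor([[4, 1]])]   # step 2
toy_backpointers = [torch.tensor([[1, 0]]),  # parents of the step-1 beams
                    torch.tensor([[0, 0]])]  # parents of the step-2 beams

reconstructed = [toy_predictions[2].unsqueeze(2)]
cur = toy_backpointers[1]
for t in range(1, 0, -1):
    reconstructed.append(toy_predictions[t].gather(1, cur).unsqueeze(2))
    cur = toy_backpointers[t - 1].gather(1, cur)
reconstructed.append(toy_predictions[0].gather(1, cur).unsqueeze(2))
torch.cat(list(reversed(reconstructed)), 2)
# -> tensor([[[7, 2, 4],
#             [7, 2, 1]]]); both surviving beams descend from step-0 beam 1.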