def _get_mask_for_eval(self,
                       mask: torch.LongTensor,
                       pos_tags: torch.LongTensor) -> torch.LongTensor:
    """
    Build the mask used when scoring dependency parses.

    Dependency evaluation excludes words that are punctuation, so every position
    whose part-of-speech tag is one of ``self._pos_to_ignore`` is zeroed out on top
    of the original mask.

    Parameters
    ----------
    mask : ``torch.LongTensor``, required.
        The original mask.
    pos_tags : ``torch.LongTensor``, required.
        The pos tags for the sequence.

    Returns
    -------
    A new mask, where any indices equal to labels we should be ignoring are masked.
    """
    eval_mask = mask.detach()
    for ignored_tag in self._pos_to_ignore:
        # 1 wherever the tag is *not* the ignored one, 0 where it is.
        keep = 1 - pos_tags.eq(ignored_tag).long()
        eval_mask = eval_mask * keep
    return eval_mask
def flattened_index_select(target: torch.Tensor,
                           indices: torch.LongTensor) -> torch.Tensor:
    """
    Select the same ``(set_size, subset_size)`` grid of sequence positions from every
    batch element of ``target``.

    Each of the ``set_size`` rows of ``indices`` names a subset of positions along the
    ``sequence_length`` dimension of ``target``; the rows are looked up for all batch
    elements at once.

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, sequence_length, embedding_size).
    indices : ``torch.LongTensor``, required.
        A LongTensor of shape (set_size, subset_size). All indices must be < sequence_length
        as this tensor is an index into the sequence_length dimension of the target.

    Returns
    -------
    selected : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, set_size, subset_size, embedding_size).
    """
    if indices.dim() != 2:
        raise ConfigurationError("Indices passed to flattened_index_select had shape {} but "
                                 "only 2 dimensional inputs are supported.".format(indices.size()))
    batch_size = target.size(0)
    set_size, subset_size = indices.size(0), indices.size(1)

    # Shape: (batch_size, set_size * subset_size, embedding_size)
    gathered = target.index_select(1, indices.view(-1))

    # Shape: (batch_size, set_size, subset_size, embedding_size)
    return gathered.view(batch_size, set_size, subset_size, -1)
def __call__(self,  # type: ignore
             predictions: torch.LongTensor,
             gold_targets: torch.LongTensor) -> None:
    """
    Update precision counts.

    Parameters
    ----------
    predictions : ``torch.LongTensor``, required
        Batched predicted tokens of shape `(batch_size, max_sequence_length)`.
    gold_targets : ``torch.LongTensor``, required
        Batched reference (gold) translations with shape
        `(batch_size, max_gold_sequence_length)`.

    Returns
    -------
    None
    """
    predictions, gold_targets = self.unwrap_to_tensors(predictions, gold_targets)

    # One n-gram order per configured weight, starting at unigrams.
    for ngram_size, _ in enumerate(self._ngram_weights, start=1):
        matches, totals = self._get_modified_precision_counts(predictions,
                                                              gold_targets,
                                                              ngram_size)
        self._precision_matches[ngram_size] += matches
        self._precision_totals[ngram_size] += totals

    if self._exclude_indices:
        # Only tokens that are not excluded contribute to the lengths.
        self._prediction_lengths += self._get_valid_tokens_mask(predictions).sum().item()
        self._reference_lengths += self._get_valid_tokens_mask(gold_targets).sum().item()
    else:
        self._prediction_lengths += predictions.size(0) * predictions.size(1)
        self._reference_lengths += gold_targets.size(0) * gold_targets.size(1)
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Represent each span by combining the embeddings of its two endpoints.

    Parameters
    ----------
    sequence_tensor : ``torch.FloatTensor``, required.
        The tensor the span endpoints index into. It is passed to
        ``util.batched_index_select`` together with ``(batch_size, num_spans)`` indices.
    span_indices : ``torch.LongTensor``, required.
        Inclusive ``(start, end)`` pairs stored along the last dimension, which is split
        into two ``(batch_size, num_spans)`` tensors.
    sequence_mask : ``torch.LongTensor``, optional (default = ``None``).
        Accepted for interface compatibility; not read by this method.
    span_indices_mask : ``torch.LongTensor``, optional (default = ``None``).
        A ``(batch_size, num_spans)`` mask. Masked spans have their indices zeroed before
        lookup and their final representation zeroed before returning.

    Returns
    -------
    The combined endpoint representations, with the span width embedding concatenated on
    the last dimension when one is configured.

    Raises
    ------
    ValueError
        If exclusive start indices are requested and an adjusted start index is still
        negative after the sentinel positions have been reset.
    """
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    if not self._use_exclusive_start_indices:
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        # Point sentinel positions at index 0 for the lookup; they are overwritten below.
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any():
            raise ValueError(f"Adjusted span indices must lie inside the the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}.")

        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        # We're using sentinels, so we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with the start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
                           + float_start_sentinel_mask * self._start_sentinel

    combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        # Do not return here: the mask below must also cover the width features,
        # otherwise padded spans come back with a non-zero width embedding.
        combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()
    return combined_tensors
def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
    """
    Return 1 if ``predicted`` matches any row of ``targets``, otherwise 0.

    Each target row is trimmed to ``len(predicted)`` before comparing, so a prediction
    that equals the leading part of a target row also counts as a match.

    Parameters
    ----------
    predicted : ``List[int]``, required.
        The predicted action indices.
    targets : ``torch.LongTensor``, required.
        A two-dimensional tensor with one candidate target sequence per row.

    Returns
    -------
    ``1`` on a match and ``0`` otherwise, always as a plain Python ``int``. An empty
    prediction or a ``targets`` tensor with no rows never matches.
    """
    # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
    num_predicted = len(predicted)
    # Nothing to compare: avoid reducing over an empty dimension.
    if num_predicted == 0 or targets.size(0) == 0:
        return 0
    # Check if target is big enough to cover prediction (including start/end symbols)
    if num_predicted > targets.size(1):
        return 0
    predicted_tensor = targets.new_tensor(predicted)
    targets_trimmed = targets[:, :num_predicted]
    # A row matches when every position agrees; we succeed if any row matches.
    row_matches = targets_trimmed.eq(predicted_tensor).all(dim=1)
    return int(row_matches.any().item())
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).
    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    # The class dimension is named explicitly: relying on the implicit dimension is
    # deprecated and emits a warning.
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.view(-1, 1).long()

    # Contribution to the negative log likelihood only comes from the exact indices
    # of the targets, as the target distributions are one-hot. Here we use torch.gather
    # to extract the indices of the num_classes dimension which contribute to the loss.
    # shape : (batch * sequence_length, 1)
    negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights.float()
    # shape : (batch_size,)
    # The small epsilon keeps fully-masked sequences from dividing by zero.
    per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)

    if batch_average:
        # Average only over sequences that have at least one weighted element.
        num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
def _get_modified_precision_counts(self,
                                   predicted_tokens: torch.LongTensor,
                                   reference_tokens: torch.LongTensor,
                                   ngram_size: int) -> Tuple[int, int]:
    """
    Count the numerator and denominator of a modified n-gram precision.

    For every batch row the predicted and reference (gold) n-grams of size
    ``ngram_size`` are counted with ``self._ngrams``. The numerator adds, for each
    predicted n-gram, its count clipped by the count of the same n-gram in the
    corresponding reference row. The denominator is the total count of predicted
    n-grams.

    Returns
    -------
    A ``(clipped_matches, total_predicted)`` tuple accumulated over the batch.
    """
    matched = 0
    predicted_total = 0
    for row in range(predicted_tokens.size(0)):
        predicted_row, reference_row = predicted_tokens[row, :], reference_tokens[row, :]
        hypothesis_counts = self._ngrams(predicted_row, ngram_size)
        gold_counts = self._ngrams(reference_row, ngram_size)
        # Clip each predicted n-gram by how often the reference contains it.
        matched += sum(min(count, gold_counts[ngram])
                       for ngram, count in hypothesis_counts.items())
        predicted_total += sum(hypothesis_counts.values())
    return matched, predicted_total
def _get_checklist_info(agenda: torch.LongTensor,
                        all_actions: List[ProductionRule],
                        terminal_productions: Set[str],
                        max_num_terminals: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Takes an agenda, a list of all actions, a set of terminal productions in the corresponding
    world, and a length to pad the checklist vectors to, and returns a target checklist against
    which the checklist at each state will be compared to compute a loss, indices of
    ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions
    are relevant for checklist loss computation.

    Parameters
    ----------
    ``agenda`` : ``torch.LongTensor``
        Agenda of one instance of size ``(agenda_size, 1)``.
    ``all_actions`` : ``List[ProductionRule]``
        All actions for one instance.
    ``terminal_productions`` : ``Set[str]``
        String representations of terminal productions in the corresponding world.
    ``max_num_terminals`` : ``int``
        Length to which the checklist vectors will be padded till. This is the max number of
        terminal productions in all the worlds in the batch.

    Returns
    -------
    A ``(target_checklist, terminal_actions, checklist_mask)`` tuple of column vectors.
    ``terminal_actions`` holds the index of every terminal action (padded with ``-1``),
    ``target_checklist`` is ``1`` where that action is on the agenda and ``0`` elsewhere
    (padding included), and ``checklist_mask`` is ``1`` exactly where ``target_checklist``
    is non-zero.
    """
    # Flatten explicitly: ``squeeze(0)`` only drops the leading dimension when it has size 1,
    # which would otherwise leave one-element rows to be converted to ints one by one.
    agenda_indices_set = set(agenda.reshape(-1).tolist())

    terminal_indices = []
    target_checklist_list = []
    # We want to return checklist target and terminal actions that are column vectors to make
    # computing softmax over the difference between checklist and target easier.
    for index, action in enumerate(all_actions):
        # Each action is a ProductionRule, a tuple where the first item is the production
        # rule string.
        if action[0] in terminal_productions:
            terminal_indices.append([index])
            target_checklist_list.append([1 if index in agenda_indices_set else 0])

    # Pad both lists so every instance in the batch has the same number of rows.
    while len(target_checklist_list) < max_num_terminals:
        target_checklist_list.append([0])
        terminal_indices.append([-1])

    # (max_num_terminals, 1); the ``view`` keeps the column shape even when the lists are empty.
    terminal_actions = agenda.new_tensor(terminal_indices).view(-1, 1)
    # (max_num_terminals, 1)
    target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float).view(-1, 1)
    checklist_mask = (target_checklist != 0).float()
    return target_checklist, terminal_actions, checklist_mask
def forward(self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            mask: torch.LongTensor) -> torch.Tensor:
    """
    Parameters
    ----------
    inputs : ``torch.Tensor``, required.
        A Tensor of shape ``(batch_size, sequence_length, hidden_size)``.
    mask : ``torch.LongTensor``, required.
        A binary mask of shape ``(batch_size, sequence_length)`` representing the
        non-padded elements in each sequence in the batch.

    Returns
    -------
    A ``torch.Tensor`` of shape (num_layers, batch_size, sequence_length, hidden_size),
    where the num_layers dimension represents the LSTM output from that layer.
    """
    batch_size, total_sequence_length = mask.size()
    stacked_sequence_output, final_states, restoration_indices = \
        self.sort_and_run_forward(self._lstm_forward, inputs, mask)

    num_layers, num_valid, returned_timesteps, encoder_dim = stacked_sequence_output.size()
    # Add back invalid rows which were removed in the call to sort_and_run_forward.
    if num_valid < batch_size:
        num_missing = batch_size - num_valid
        zeros = stacked_sequence_output.new_zeros(num_layers,
                                                  num_missing,
                                                  returned_timesteps,
                                                  encoder_dim)
        stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 1)

        # The states also need to have invalid rows added back.
        new_states = []
        for state in final_states:
            zeros = state.new_zeros(num_layers, num_missing, state.size(-1))
            new_states.append(torch.cat([state, zeros], 1))
        final_states = new_states

    # It's possible to need to pass sequences which are padded to longer than the
    # max length of the sequence to a Seq2StackEncoder. However, packing and unpacking
    # the sequences mean that the returned tensor won't include these dimensions, because
    # the RNN did not need to process them. We add them back on in the form of zeros here.
    sequence_length_difference = total_sequence_length - returned_timesteps
    if sequence_length_difference > 0:
        zeros = stacked_sequence_output.new_zeros(num_layers,
                                                  batch_size,
                                                  sequence_length_difference,
                                                  stacked_sequence_output[0].size(-1))
        stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 2)

    self._update_states(final_states, restoration_indices)

    # Restore the original indices and return the sequence.
    # Has shape (num_layers, batch_size, sequence_length, hidden_size)
    return stacked_sequence_output.index_select(1, restoration_indices)
def _get_checklist_info(self,
                        agenda: torch.LongTensor,
                        all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor,
                                                                         torch.Tensor,
                                                                         torch.Tensor]:
    """
    Takes an agenda and a list of all actions and returns a target checklist against which the
    checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
    and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
    checklist loss computation. If ``self._penalize_non_agenda_actions`` is set to ``True``,
    ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
    ``False``, indices of all terminals that are not in the agenda will be masked.

    Parameters
    ----------
    ``agenda`` : ``torch.LongTensor``
        Agenda of one instance of size ``(agenda_size, 1)``.
    ``all_actions`` : ``List[ProductionRuleArray]``
        All actions for one instance.

    Returns
    -------
    A ``(target_checklist, terminal_actions, checklist_mask)`` tuple of column vectors with
    one row per action whose rule string is in ``self._terminal_productions``.
    """
    # Flatten explicitly: ``squeeze(0)`` only drops the leading dimension when it has size 1,
    # which would otherwise leave one-element rows to be converted to ints one by one.
    agenda_indices_set = set(agenda.reshape(-1).tolist())

    terminal_indices = []
    target_checklist_list = []
    for index, action in enumerate(all_actions):
        # Each action is a ProductionRuleArray, a tuple where the first item is the production
        # rule string.
        if action[0] in self._terminal_productions:
            terminal_indices.append([index])
            target_checklist_list.append([1 if index in agenda_indices_set else 0])

    # We want to return checklist target and terminal actions that are column vectors to make
    # computing softmax over the difference between checklist and target easier.
    # (num_terminals, 1); the ``view`` keeps the column shape even when there are no terminals.
    terminal_actions = agenda.new_tensor(terminal_indices).view(-1, 1)
    # (num_terminals, 1)
    target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float).view(-1, 1)
    if self._penalize_non_agenda_actions:
        # All terminal actions are relevant
        checklist_mask = torch.ones_like(target_checklist)
    else:
        checklist_mask = (target_checklist != 0).float()
    return target_checklist, terminal_actions, checklist_mask
def batched_index_select(target: torch.Tensor,
                         indices: torch.LongTensor,
                         flattened_indices: Optional[torch.LongTensor] = None) -> torch.Tensor:
    """
    Look up, for every batch element, the rows of ``target`` named by ``indices``.

    The given ``indices`` of size ``(batch_size, d_1, ..., d_n)`` index into the sequence
    dimension of the target, which has size ``(batch_size, sequence_length, embedding_size)``.
    The result has size ``(batch_size, d_1, ..., d_n, embedding_size)``.

    The lookup is done on a flattened view of ``target``, using indices that have been
    shifted per batch element. Those can be precomputed with
    :func:`~flatten_and_batch_shift_indices` and passed in as ``flattened_indices``,
    of size ``(batch_size * d_1 * ... * d_n)``.

    An example use case of this function is looking up the start and end indices of spans in a
    sequence tensor, where the number of spans is not known a-priori.

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A 3 dimensional tensor of shape (batch_size, sequence_length, embedding_size).
        This is the tensor to be indexed.
    indices : ``torch.LongTensor``
        A tensor of shape (batch_size, ...), where each element is an index into the
        ``sequence_length`` dimension of the ``target`` tensor.
    flattened_indices : Optional[torch.Tensor], optional (default = None)
        An optional tensor representing the result of calling
        :func:~`flatten_and_batch_shift_indices` on ``indices``. This is helpful in the case
        that the indices can be flattened once and cached for many batch lookups.

    Returns
    -------
    selected_targets : ``torch.Tensor``
        A tensor with shape [indices.size(), target.size(-1)] representing the embedded indices
        extracted from the batch flattened target tensor.
    """
    embedding_size = target.size(-1)

    if flattened_indices is None:
        # Shape: (batch_size * d_1 * ... * d_n)
        flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1))

    # Shape: (batch_size * sequence_length, embedding_size)
    rows = target.view(-1, embedding_size)

    # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
    picked = rows.index_select(0, flattened_indices)

    # Shape: (batch_size, d_1, ..., d_n, embedding_size)
    output_shape = [*indices.size(), embedding_size]
    return picked.view(*output_shape)
def greedy_predict(self,
                   final_encoder_output: torch.LongTensor,
                   target_embedder: Embedding,
                   decoder_cell: GRUCell,
                   output_projection_layer: Linear) -> torch.Tensor:
    """
    Greedily produces a sequence using the provided ``decoder_cell``.
    Returns the predicted sequence.

    Decoding starts from ``self._start_index`` for every batch element and runs for
    ``self._max_decoding_steps`` steps, feeding the most probable class of each step
    back in as the next input.

    Parameters
    ----------
    final_encoder_output : ``torch.LongTensor``, required
        Vector produced by ``self._encoder``. Used as the initial decoder hidden state.
    target_embedder : ``Embedding``, required
        Used to embed the target tokens.
    decoder_cell: ``GRUCell``, required
        The recurrent cell used at each time step.
    output_projection_layer: ``Linear``, required
        Linear layer mapping to the desired number of classes.
    """
    batch_size = final_encoder_output.size()[0]
    start_tokens = final_encoder_output.new_full((batch_size,),
                                                 fill_value=self._start_index,
                                                 dtype=torch.long)
    step_predictions = [start_tokens]
    hidden = final_encoder_output
    for _ in range(self._max_decoding_steps):
        embedded = target_embedder(step_predictions[-1])
        hidden = decoder_cell(embedded, hidden)
        # (batch_size, num_classes)
        class_probabilities = F.softmax(output_projection_layer(hidden), dim=-1)
        predicted_classes = torch.max(class_probabilities, 1)[1]
        step_predictions.append(predicted_classes)

    # (batch_size, num_decoding_steps + 1); column 0 is the start symbol, which we drop.
    stacked = torch.stack(step_predictions, dim=1)
    return stacked[:, 1:]
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None) -> torch.Tensor:
    """
    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        The (batch_size, max_sequence_length) tensor of wordpiece ids. Positions whose id
        is 0 are treated as padding when building the input mask.
    offsets : ``torch.LongTensor``, optional
        The BERT embeddings are one per wordpiece. However it's possible/likely
        you might want one per original token. In that case, ``offsets``
        represents the indices of the desired wordpiece for each original token.
        Depending on how your token indexer is configured, this could be the
        position of the last wordpiece for each token, or it could be the position
        of the first wordpiece for each token.

        For example, if you had the sentence "Definitely not", and if the corresponding
        wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
        would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
        If offsets are provided, the returned tensor will contain only the wordpiece
        embeddings at those positions, and (in particular) will contain one embedding
        per token. If offsets are not provided, the entire tensor of wordpiece embeddings
        will be returned.
    token_type_ids : ``torch.LongTensor``, optional
        If an input consists of two sentences (as in the BERT paper),
        tokens from the first sentence should have type 0 and tokens from
        the second sentence should have type 1.  If you don't provide this
        (the default BertIndexer doesn't) then it's assumed to be all 0s.
    """
    # pylint: disable=arguments-differ
    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids)

    input_mask = (input_ids != 0).long()
    all_encoder_layers, _ = self.bert_model(input_ids, input_mask, token_type_ids)

    # Either mix all layers with the learned scalar mix, or just take the top layer.
    if self._scalar_mix is None:
        mix = all_encoder_layers[-1]
    else:
        mix = self._scalar_mix(all_encoder_layers, input_mask)

    if offsets is None:
        return mix

    # Pair every offset with the index of its batch element for advanced indexing.
    batch_size = input_ids.size(0)
    range_vector = util.get_range_vector(batch_size,
                                         device=util.get_device_of(mix)).unsqueeze(1)
    return mix[range_vector, offsets]
def _ngrams(self,
            tensor: torch.LongTensor,
            ngram_size: int) -> Dict[Tuple[int, ...], int]:
    """
    Count the n-grams of size ``ngram_size`` in a one-dimensional tensor of token ids.

    Any n-gram containing an index from ``self._exclude_indices`` is skipped. The result
    is empty when ``ngram_size`` is larger than the sequence.

    Returns
    -------
    A ``Counter`` mapping each n-gram (a tuple of ints) to its number of occurrences.
    """
    ngram_counts: Dict[Tuple[int, ...], int] = Counter()
    sequence_length = tensor.size(-1)
    if ngram_size > sequence_length or ngram_size < 1:
        return ngram_counts

    token_ids = tensor.tolist()
    # Slide a window of width ``ngram_size`` over every valid start position.
    for start in range(sequence_length - ngram_size + 1):
        ngram = tuple(token_ids[start:start + ngram_size])
        if not any(token in self._exclude_indices for token in ngram):
            ngram_counts[ngram] += 1
    return ngram_counts
def _prepare_decode_step_input(self,
                               input_indices: torch.LongTensor,
                               decoder_hidden_state: torch.LongTensor = None,
                               encoder_outputs: torch.LongTensor = None,
                               encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
    """
    Given the input indices for the current timestep of the decoder, and all the encoder
    outputs, compute the input at the current timestep.  Note: This method is agnostic to
    whether the indices are gold indices or the predictions made by the decoder at the last
    timestep. So, this can be used even if we're doing some kind of scheduled sampling.

    If we're not using attention, the output of this method is just an embedding of the input
    indices.  If we are, the output will be a concatentation of the embedding and an attended
    average of the encoder inputs.

    Parameters
    ----------
    input_indices : torch.LongTensor
        Indices of either the gold inputs to the decoder or the predicted labels from the
        previous timestep.
    decoder_hidden_state : torch.LongTensor, optional (not needed if no attention)
        Output of from the decoder at the last time step. Needed only if using attention.
    encoder_outputs : torch.LongTensor, optional (not needed if no attention)
        Encoder outputs from all time steps. Needed only if using attention.
    encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
        Masks on encoder outputs. Needed only if using attention.
    """
    # input_indices : (batch_size,)  since we are processing these one timestep at a time.
    # (batch_size, target_embedding_dim)
    embedded_input = self._target_embedder(input_indices)
    if not self._attention_function:
        return embedded_input

    # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
    # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
    # complain.
    float_mask = encoder_outputs_mask.float()
    # (batch_size, input_sequence_length)
    attention_weights = self._decoder_attention(decoder_hidden_state,
                                                encoder_outputs,
                                                float_mask)
    # (batch_size, encoder_output_dim)
    attended_input = weighted_sum(encoder_outputs, attention_weights)
    # (batch_size, encoder_output_dim + target_embedding_dim)
    return torch.cat((attended_input, embedded_input), -1)
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            tags: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None,
            # pylint: disable=unused-argument
            **kwargs) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    tokens : ``Dict[str, torch.LongTensor]``, required
        The output of ``TextField.as_array()``, which should typically be passed directly to a
        ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
        tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
        Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
        for the ``TokenIndexers`` when you created the ``TextField`` representing your
        sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
        which knows how to combine different word representations into a single vector per
        token in your input.
    tags : ``torch.LongTensor``, optional (default = ``None``)
        A torch tensor representing the sequence of integer gold class labels of shape
        ``(batch_size, num_tokens)``.
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        metadata containg the original words in the sentence to be tagged under a 'words' key.

    Returns
    -------
    An output dictionary consisting of:

    logits : ``torch.FloatTensor``
        The log-softmax of the output of the ``tag_projection_layer``.
    mask : ``torch.LongTensor``
        The text field mask for the input tokens
    tags : ``torch.LongTensor``
        The predicted tag ids, obtained as the argmax of ``logits`` over the last dimension.
    loss : ``torch.FloatTensor``, optional
        The value of ``self._loss`` on the unmasked positions. Only computed if gold label
        ``tags`` are provided.
    words : ``List``, optional
        The ``"words"`` entry of every metadata dictionary. Only present if ``metadata``
        is provided.
    """
    embedded_text_input = self.text_field_embedder(tokens)
    mask = util.get_text_field_mask(tokens)

    if self.dropout:
        embedded_text_input = self.dropout(embedded_text_input)

    encoded_text = self.encoder(embedded_text_input, mask)

    if self.dropout:
        encoded_text = self.dropout(encoded_text)

    if self._feedforward is not None:
        encoded_text = self._feedforward(encoded_text)

    logits = torch.nn.functional.log_softmax(self.tag_projection_layer(encoded_text), dim=-1)
    predicted_tags = torch.argmax(logits, -1)

    output = {"logits": logits, "mask": mask, "tags": predicted_tags}

    if tags is not None:
        # Add negative log-likelihood as loss, computed on unmasked positions only.
        active_loss = mask.view(-1) == 1
        active_logits = logits.view(-1, self.num_tags)[active_loss]
        active_tags = tags.view(-1)[active_loss]
        output["loss"] = self._loss(active_logits, active_tags)

        for key, metric in self.metrics.items():
            if 'accuracy' in key:
                metric(logits.view(-1, self.num_tags), tags.view(-1), active_loss)

        if self.calculate_span_f1:
            def flat_labels(tag_ids: torch.Tensor) -> List[str]:
                # Map every tag id to its label string, flattening the batch row by row.
                return [self.vocab.get_token_from_index(tag_id, namespace=self.label_namespace)
                        for instance_tags in tag_ids.tolist()
                        for tag_id in instance_tags]

            prec, recall, f1 = evaluate(flat_labels(tags), flat_labels(predicted_tags))
            self.metrics['precision-measure-overall'](prec)
            self.metrics['recall-measure-overall'](recall)
            self.metrics['f1-measure-overall'](f1)

    if metadata is not None:
        output["words"] = [x["words"] for x in metadata]
    return output
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_starts: torch.LongTensor,
            span_ends: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Represent each span of a single, unbatched sequence by combining the embeddings of its
    two endpoints.

    All tensor inputs get a leading batch dimension of size 1 added before the lookup, and
    that dimension is removed again from the result.

    Parameters
    ----------
    sequence_tensor : ``torch.FloatTensor``, required.
        The tensor the span endpoints index into, without a batch dimension.
    span_starts : ``torch.LongTensor``, required.
        Inclusive start index of every span.
    span_ends : ``torch.LongTensor``, required.
        Inclusive end index of every span.
    sequence_mask : ``torch.LongTensor``, optional (default = ``None``).
        Accepted for interface compatibility; not read by this method.
    span_indices_mask : ``torch.LongTensor``, optional (default = ``None``).
        A mask over the spans. Masked spans have their indices zeroed before lookup and
        their final representation zeroed before returning.

    Returns
    -------
    The combined endpoint representations, with the span width embedding concatenated on
    the last dimension when one is configured, and with the temporary batch dimension removed.

    Raises
    ------
    ValueError
        If exclusive start indices are requested and an adjusted start index is still
        negative after the sentinel positions have been reset.
    """
    # Add a batch dimension of size 1 so the batched utilities can be reused.
    sequence_tensor = sequence_tensor.unsqueeze(0)
    span_starts = span_starts.unsqueeze(0)
    span_ends = span_ends.unsqueeze(0)

    if span_indices_mask is not None:
        span_indices_mask = span_indices_mask.unsqueeze(0)
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    if not self._use_exclusive_start_indices:
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        # Point sentinel positions at index 0 for the lookup; they are overwritten below.
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any():
            # This must be an f-string, otherwise the offending indices are never shown.
            raise ValueError(
                f"Adjusted span indices must lie inside the the sequence tensor, "
                f"but found: exclusive_span_starts: {exclusive_span_starts}.")

        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        # We're using sentinels, so we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with the start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
                           + float_start_sentinel_mask * self._start_sentinel

    combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        # Do not return here: the mask below must also cover the width features,
        # otherwise padded spans come back with a non-zero width embedding.
        combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        combined_tensors = combined_tensors * span_indices_mask.unsqueeze(-1).float()

    # Drop the batch dimension that was added at the top.
    return combined_tensors.squeeze(0)
def train_model(model, loss_fn, lr=0.001, batch_size=64, n_epochs=5):
    """
    Train ``model`` on every train/validation split in the module-level ``idx_splits``.

    The same ``model`` instance is trained on each fold in turn. At the end of every epoch
    the model is scored on the fold's validation set and saved with ``save`` whenever the
    validation loss improves on the best loss seen so far (across all folds).

    Besides its arguments, this function reads the module-level names ``idx_splits``,
    ``kfold_train_tabular_dataset``, ``device``, ``TorchtextSubset``, ``LongTensor``,
    ``binary_accuracy``, ``get_val_score`` and ``save``.

    Parameters
    ----------
    model :
        The model to train. It is called as ``model(batch_text, batch_lengths)``.
    loss_fn :
        Callable taking ``(predictions, targets)`` and returning a loss tensor.
    lr : float, optional (default = 0.001)
        Initial learning rate for Adam. It is multiplied by 0.6 after every epoch.
    batch_size : int, optional (default = 64)
        Batch size used for both the training and the validation iterator.
    n_epochs : int, optional (default = 5)
        Number of epochs to train on each fold.

    Returns
    -------
    The trained ``model``.
    """
    min_loss = float('inf')
    for fold, (train_idx, val_idx) in enumerate(idx_splits):
        train_ds = TorchtextSubset(kfold_train_tabular_dataset, train_idx)
        val_ds = TorchtextSubset(kfold_train_tabular_dataset, val_idx)
        train_loader, val_loader = torchtext.data.BucketIterator.splits(
            [train_ds, val_ds],
            batch_sizes=[batch_size, batch_size],
            device=device,
            sort_key=lambda x: len(x.comment_text),
            sort_within_batch=True,
            repeat=False)
        print('Fold::::::::::', fold)

        # A fresh optimizer and schedule for every fold.
        param_lrs = [{'params': param, 'lr': lr} for param in model.parameters()]
        optimizer = torch.optim.Adam(param_lrs, lr=lr)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.6**epoch)

        step = 0
        for epoch in range(n_epochs):
            train_loader.init_epoch()
            start_time = time.time()
            epoch_loss, epoch_acc = 0, 0
            for data in train_loader:
                step += 1
                optimizer.zero_grad()

                data_len_comment_text = LongTensor([len(text) for text in data.comment_text])
                x_batch = data.comment_text.to(device)
                y_batch = data.our_target.to(device).float()

                y_pred = model(x_batch, data_len_comment_text).squeeze(1)
                loss = loss_fn(y_pred, y_batch)
                acc = binary_accuracy(y_pred, y_batch)

                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                epoch_acc += acc.item()

                # True on the last batch of each epoch: validate and checkpoint.
                if step % len(train_loader) == 0:
                    val_acc, val_loss = get_val_score(model, val_loader, loss_fn)
                    print("current val_loss", val_loss, "last min_loss", min_loss)
                    if val_loss < min_loss:
                        save(m=model, info={'epoch': epoch, 'val_loss': val_loss})
                        min_loss = val_loss
                    print('val_acc', val_acc, 'val_loss', val_loss,
                          'train_acc', epoch_acc / len(train_loader),
                          'train_loss', epoch_loss / len(train_loader))

            # The schedule is defined per epoch (0.6 ** epoch), so advance it once per epoch.
            # Stepping it per batch would shrink the learning rate by 0.6 on every batch.
            scheduler.step()

            elapsed_time = time.time() - start_time
            print(
                'Epoch {}/{} \t loss={:.4f} \t accouracy={} \t time={:.2f}s '.
                format(epoch + 1, n_epochs, epoch_loss / len(train_loader),
                       epoch_acc / len(train_loader), elapsed_time))
    return model
def _get_valid_tokens_mask(self, tensor: torch.LongTensor) -> torch.ByteTensor:
    """
    Build a mask that is 1 wherever ``tensor`` holds an index that is not excluded.

    Parameters
    ----------
    tensor : ``torch.LongTensor``, required.
        A tensor of token indices of any shape.

    Returns
    -------
    A ``uint8`` tensor with the same shape as ``tensor`` (and on the same device)
    that is 0 at every position whose value appears in ``self._exclude_indices``
    and 1 everywhere else.
    """
    # Create the mask on the input's device; a CPU mask cannot be combined with the
    # comparison result below when ``tensor`` lives on a GPU.
    valid_tokens_mask = torch.ones(tensor.size(), dtype=torch.uint8, device=tensor.device)
    for index in self._exclude_indices:
        valid_tokens_mask = valid_tokens_mask & (tensor != index)
    return valid_tokens_mask
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            mask_positions: torch.LongTensor,
            target_ids: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
    """
    Parameters
    ----------
    tokens : ``Dict[str, torch.LongTensor]``
        The output of ``TextField.as_tensor()`` for a batch of sentences.
    mask_positions : ``torch.LongTensor``
        The positions in ``tokens`` that correspond to [MASK] tokens that we should try to fill
        in.  Shape should be (batch_size, num_masks); an extra trailing dimension of size one,
        i.e. (batch_size, num_masks, 1), is also accepted and removed.
    target_ids : ``Dict[str, torch.LongTensor]``
        This is a list of token ids that correspond to the mask positions we're trying to fill.
        It is the output of a ``TextField``, purely for convenience, so we can handle wordpiece
        tokenizers and such without having to do crazy things in the dataset reader.  We assume
        that there is exactly one entry in the dictionary, and that it has a shape identical to
        ``mask_positions`` - one target token per mask position.  If there is more than one
        entry, the ``'bert'`` entry is used.

    Returns
    -------
    A dictionary with:

    probabilities : the top-k (k = min(vocab_size, 5)) softmax probabilities per mask position.
    top_indices : the vocabulary indices those probabilities belong to.
    token_ids : ``tokens[self._target_namespace]``, passed through unchanged.
    loss : cross-entropy over all mask positions; only present when ``target_ids`` is given.

    Raises
    ------
    ValueError
        If the targets and ``mask_positions`` do not have the same shape.
    """
    # pylint: disable=arguments-differ
    targets = None
    if target_ids is not None:
        # A bit of a hack to get the right targets out of the TextField output...
        if len(target_ids) != 1:
            targets = target_ids['bert']
        else:
            targets = list(target_ids.values())[0]

    # Only strip the trailing dimension when it is really there; an unconditional
    # ``squeeze(-1)`` would also collapse a legitimate (batch_size, 1) input.
    if mask_positions.dim() == 3:
        mask_positions = mask_positions.squeeze(-1)
    batch_size, num_masks = mask_positions.size()
    if targets is not None and targets.size() != mask_positions.size():
        raise ValueError(f"Number of targets ({targets.size()}) and number of masks "
                         f"({mask_positions.size()}) are not equal")

    # Shape: (batch_size, num_tokens, embedding_dim)
    embeddings = self._text_field_embedder(tokens)

    # Shape: (batch_size, num_tokens, encoding_dim)
    if self._contextualizer:
        # The padding mask is derived from the token-id dictionary, not from the
        # embedded tensor.
        mask = util.get_text_field_mask(tokens)
        contextual_embeddings = self._contextualizer(embeddings, mask)
    else:
        contextual_embeddings = embeddings

    # Does advanced indexing to get the embeddings of just the mask positions, which is what
    # we're trying to predict.  Both index tensors must live on the same device.
    batch_index = torch.arange(0, batch_size, device=mask_positions.device).long().unsqueeze(1)
    mask_embeddings = contextual_embeddings[batch_index, mask_positions]

    target_logits = self._language_model_head(self._dropout(mask_embeddings))

    vocab_size = target_logits.size(-1)
    probs = torch.nn.functional.softmax(target_logits, dim=-1)
    k = min(vocab_size, 5)  # min here largely because tests use small vocab
    top_probs, top_indices = probs.topk(k=k, dim=-1)

    output_dict = {"probabilities": top_probs, "top_indices": top_indices}

    # Using the namespace here is a hack...
    output_dict["token_ids"] = tokens[self._target_namespace]

    if targets is not None:
        target_logits = target_logits.view(batch_size * num_masks, vocab_size)
        targets = targets.view(batch_size * num_masks)
        loss = torch.nn.functional.cross_entropy(target_logits, targets)
        self._perplexity(loss)
        output_dict['loss'] = loss

    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        table: Dict[str, torch.LongTensor],
        world: List[WikiTablesWorld],
        actions: List[List[ProductionRuleArray]],
        example_lisp_string: List[str] = None,
        target_action_sequences: torch.LongTensor = None
) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    # pylint: disable=unused-argument
    """
    In this method we encode the table entities, link them to words in the question, then
    encode the question. Then we set up the initial state for the decoder, and pass that
    state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
    if we're not.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
        be passed through a ``TextFieldEmbedder`` and then through an encoder.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
        entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
        get embeddings for each entity.  The ``'text'`` and ``'linking'`` entries are read here.
    world : ``List[WikiTablesWorld]``
        We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of how
        ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``,
    actions : ``List[List[ProductionRuleArray]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    example_lisp_string : ``List[str]``, optional (default=None)
        The example (lisp-formatted) string corresponding to the given input.  This comes
        directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
        when evaluating denotation accuracy; it is otherwise unused.
    target_action_sequences : torch.Tensor, optional (default=None)
        A list of possibly valid action sequences, where each action is an index into the list
        of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
        sequence_length)`` once the trailing dimension has been squeezed away below.

    Returns
    -------
    When ``self.training`` is true, whatever ``self._decoder_trainer.decode`` returns.
    Otherwise a dictionary with the keys ``action_mapping``, ``best_action_sequence``,
    ``debug_info``, ``entities``, ``linking_scores``, ``similarity_scores``,
    ``logical_form``, plus ``feature_scores`` when ``self._linking_params`` is set and
    ``loss`` when ``target_action_sequences`` is given.
    """
    table_text = table['text']
    # (batch_size, question_length, embedding_dim)
    embedded_question = self._question_embedder(question)
    question_mask = util.get_text_field_mask(question).float()
    # (batch_size, num_entities, num_entity_tokens, embedding_dim)
    embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
    table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

    batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
    num_question_tokens = embedded_question.size(1)

    # (batch_size, num_entities, embedding_dim)
    encoded_table = self._entity_encoder(embedded_table, table_mask)
    # (batch_size, num_entities, num_neighbors)
    neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

    # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
    # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
    # be added for the mask since that method expects 0 for padding.
    # (batch_size, num_entities, num_neighbors, embedding_dim)
    embedded_neighbors = util.batched_index_select(
        encoded_table, torch.abs(neighbor_indices))

    neighbor_mask = util.get_text_field_mask(
        {'ignored': neighbor_indices + 1}, num_wrapping_dims=1).float()

    # Encoder initialized to easily obtain a masked average.
    # NOTE(review): this wrapper is rebuilt on every forward call.
    neighbor_encoder = TimeDistributed(
        BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
    # (batch_size, num_entities, embedding_dim)
    embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

    # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(
        world, num_entities, encoded_table)

    entity_type_embeddings = self._type_params(entity_types.float())
    projected_neighbor_embeddings = self._neighbor_params(
        embedded_neighbors.float())
    # (batch_size, num_entities, embedding_dim)
    entity_embeddings = torch.nn.functional.tanh(
        entity_type_embeddings + projected_neighbor_embeddings)

    # Compute entity and question word cosine similarity. Need to add a small value
    # to the table norm since there are padding values which cause a divide by 0.
    embedded_table = embedded_table / (
        embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
    embedded_question = embedded_question / (
        embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
    # One batched matrix product over all entity tokens, reshaped back per entity below.
    question_entity_similarity = torch.bmm(
        embedded_table.view(batch_size, num_entities * num_entity_tokens,
                            self._embedding_dim),
        torch.transpose(embedded_question, 1, 2))

    question_entity_similarity = question_entity_similarity.view(
        batch_size, num_entities, num_entity_tokens, num_question_tokens)

    # (batch_size, num_entities, num_question_tokens)
    question_entity_similarity_max_score, _ = torch.max(
        question_entity_similarity, 2)

    # (batch_size, num_entities, num_question_tokens, num_features)
    linking_features = table['linking']

    linking_scores = question_entity_similarity_max_score

    if self._use_neighbor_similarity_for_linking:
        # The linking score is computed as a linear projection of two terms. The first is the
        # maximum similarity score over the entity's words and the question token. The second
        # is the maximum similarity over the words in the entity's neighbors and the question
        # token.
        #
        # The second term, projected_question_neighbor_similarity, is useful when a column
        # needs to be selected. For example, the question token might have no similarity with
        # the column name, but is similar with the cells in the column.
        #
        # Note that projected_question_neighbor_similarity is intended to capture the same
        # information as the related_column feature.
        #
        # Also note that this block needs to be _before_ the `linking_params` block, because
        # we're overwriting `linking_scores`, not adding to it.

        # (batch_size, num_entities, num_neighbors, num_question_tokens)
        question_neighbor_similarity = util.batched_index_select(
            question_entity_similarity_max_score, torch.abs(neighbor_indices))
        # (batch_size, num_entities, num_question_tokens)
        question_neighbor_similarity_max_score, _ = torch.max(
            question_neighbor_similarity, 2)
        projected_question_entity_similarity = self._question_entity_params(
            question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        projected_question_neighbor_similarity = self._question_neighbor_params(
            question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                -1)
        linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

    # ``feature_scores`` only exists when ``self._linking_params`` is set; the output
    # section further down reads it under the same condition.
    if self._linking_params is not None:
        feature_scores = self._linking_params(linking_features).squeeze(3)
        linking_scores = linking_scores + feature_scores

    # (batch_size, num_question_tokens, num_entities)
    linking_probabilities = self._get_linking_probabilities(
        world, linking_scores.transpose(1, 2), question_mask,
        entity_type_dict)

    # (batch_size, num_question_tokens, embedding_dim)
    link_embedding = util.weighted_sum(entity_embeddings,
                                       linking_probabilities)
    encoder_input = torch.cat([link_embedding, embedded_question], 2)

    # (batch_size, question_length, encoder_output_dim)
    encoder_outputs = self._dropout(
        self._encoder(encoder_input, question_mask))

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(
        encoder_outputs, question_mask, self._encoder.is_bidirectional())
    memory_cell = Variable(
        encoder_outputs.data.new(batch_size,
                                 self._encoder.get_output_dim()).fill_(0))

    initial_score = Variable(
        embedded_question.data.new(batch_size).fill_(0))

    action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(
        actions)

    # Re-reads the entity / question-token counts from the final linking scores.
    _, num_entities, num_question_tokens = linking_scores.size()
    flattened_linking_scores, actions_to_entities = self._map_entity_productions(
        linking_scores, world, actions)

    if target_action_sequences is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
    else:
        target_mask = None

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list.  For instance, the encoder outputs have shape
    # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
    # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    question_mask_list = [question_mask[i] for i in range(batch_size)]
    initial_rnn_state = []
    for i in range(batch_size):
        # Every RnnState receives the full per-batch lists, not just its own entry.
        initial_rnn_state.append(
            RnnState(final_encoder_output[i], memory_cell[i],
                     self._first_action_embedding,
                     self._first_attended_question, encoder_output_list,
                     question_mask_list))
    initial_grammar_state = [
        self._create_grammar_state(world[i], actions[i])
        for i in range(batch_size)
    ]
    initial_state = WikiTablesDecoderState(
        batch_indices=list(range(batch_size)),
        action_history=[[] for _ in range(batch_size)],
        score=initial_score_list,
        rnn_state=initial_rnn_state,
        grammar_state=initial_grammar_state,
        action_embeddings=action_embeddings,
        output_action_embeddings=output_action_embeddings,
        action_biases=action_biases,
        action_indices=action_indices,
        possible_actions=actions,
        flattened_linking_scores=flattened_linking_scores,
        actions_to_entities=actions_to_entities,
        entity_types=entity_type_dict,
        debug_info=None)

    if self.training:
        return self._decoder_trainer.decode(
            initial_state, self._decoder_step,
            (target_action_sequences, target_mask))
    else:
        # Maps (batch_index, action_index) to the first element of each action entry,
        # which is used below as the action's string form.
        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]
        outputs: Dict[str, Any] = {'action_mapping': action_mapping}
        if target_action_sequences is not None:
            outputs['loss'] = self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))['loss']
        num_steps = self._max_decoding_steps
        # This tells the state to start keeping track of debug info, which we'll pass along in
        # our output dictionary.
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._beam_search.search(
            num_steps,
            initial_state,
            self._decoder_step,
            keep_final_unfinished_states=False)
        outputs['best_action_sequence'] = []
        outputs['debug_info'] = []
        outputs['entities'] = []
        outputs['linking_scores'] = linking_scores
        if self._linking_params is not None:
            outputs['feature_scores'] = feature_scores
        outputs['similarity_scores'] = question_entity_similarity_max_score
        outputs['logical_form'] = []
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = best_final_states[i][
                    0].action_history[0]
                if target_action_sequences is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = target_action_sequences[i].data
                    # NOTE(review): this initial value is overwritten on the next line.
                    sequence_in_targets = 0
                    sequence_in_targets = self._action_history_match(
                        best_action_indices, targets)
                    self._action_sequence_accuracy(sequence_in_targets)
                action_strings = [
                    action_mapping[(i, action_index)]
                    for action_index in best_action_indices
                ]
                try:
                    # Recorded as a success first; the except branch records a failure
                    # as well when building the logical form raises.
                    self._has_logical_form(1.0)
                    logical_form = world[i].get_logical_form(
                        action_strings, add_var_function=False)
                except ParsingError:
                    self._has_logical_form(0.0)
                    logical_form = 'Error producing logical form'
                if example_lisp_string:
                    self._denotation_accuracy(logical_form,
                                              example_lisp_string[i])
                outputs['best_action_sequence'].append(action_strings)
                outputs['logical_form'].append(logical_form)
                outputs['debug_info'].append(
                    best_final_states[i][0].debug_info[0])  # type: ignore
                outputs['entities'].append(world[i].table_graph.entities)
            else:
                # No finished state for this instance: only ``logical_form`` gets a
                # placeholder, the other per-instance lists are not extended.
                outputs['logical_form'].append('')
                self._has_logical_form(0.0)
                if example_lisp_string:
                    self._denotation_accuracy(None, example_lisp_string[i])
        return outputs
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            verb_span: torch.LongTensor,
            entity_span: torch.LongTensor,
            state_change_type_labels: torch.LongTensor = None,
            state_change_tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``, which should typically be passed directly to a
        ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
        tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
        Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
        for the ``TokenIndexers`` when you created the ``TextField`` representing your
        sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
        which knows how to combine different word representations into a single vector per
        token in your input.
    verb_span: torch.LongTensor, required.
        An integer ``SequenceLabelField`` representation of the position of the focus verb
        in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
        all zeros, in the case that pre-processing stage could not extract a verbal predicate.
    entity_span: torch.LongTensor, required.
        An integer ``SequenceLabelField`` representation of the position of the focus entity
        in the sentence. This should have shape (batch_size, num_tokens)
    state_change_type_labels: torch.LongTensor, optional (default = None)
        A torch tensor representing the state change type class labels.  It is flattened with
        ``view(-1)`` for the loss and squeezed on its last dimension for the metrics, so a
        shape of ``(batch_size, 1)`` is presumably expected -- confirm with the dataset reader.
    state_change_tags : torch.LongTensor, optional (default = None)
        A torch tensor representing the sequence of integer gold class labels
        of shape ``(batch_size, num_tokens)``
        In the first implementation we focus only on state_change_types.

    Returns
    -------
    An output dictionary consisting of:
    type_probs : torch.FloatTensor
        A tensor of shape ``(batch_size, num_state_change_types)`` representing
        a distribution of state change types per datapoint.
    tags_class_probabilities : torch.FloatTensor
        A tensor of shape ``(batch_size, num_tokens, num_tags)`` representing
        a distribution of location tags per token in a sentence.  Only present when
        ``state_change_tags`` is given.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised: the sum of the type loss and the tag loss when both
        kinds of labels are given, or whichever one is available when only one is given.
    """
    # Cast the binary span indicators once; they are reused several times below.
    verb_indicator = verb_span.float()
    entity_indicator = entity_span.float()
    batch_size, sequence_length = verb_span.size()

    # Layer 1 = Word + Character embedding layer
    embedded_sentence = self.text_field_embedder(tokens)
    mask = get_text_field_mask(tokens).float()

    # Layer 2 = Add positional bit to encode position of focus verb and entity
    embedded_sentence_verb_entity = \
        torch.cat([embedded_sentence,
                   verb_indicator.unsqueeze(-1),
                   entity_indicator.unsqueeze(-1)], dim=-1)

    # Layer 3 = Contextual embedding layer using Bi-LSTM over the sentence
    contextual_embedding = self.seq2seq_encoder(embedded_sentence_verb_entity, mask)

    # Layer 4: Attention (Contextual embedding, BOW(verb span))
    # The 1e-13 guards against division by zero for an all-zero span.
    verb_weight_matrix = verb_indicator / (verb_indicator.sum(-1).unsqueeze(-1) + 1e-13)
    verb_vector = weighted_sum(contextual_embedding * verb_indicator.unsqueeze(-1),
                               verb_weight_matrix)
    entity_weight_matrix = entity_indicator / (entity_indicator.sum(-1).unsqueeze(-1) + 1e-13)
    entity_vector = weighted_sum(contextual_embedding * entity_indicator.unsqueeze(-1),
                                 entity_weight_matrix)
    verb_entity_vector = torch.cat([verb_vector, entity_vector], 1)

    # attention weights for type prediction
    attention_weights_types = self.attention_layer(verb_entity_vector, contextual_embedding)
    attention_output_vector = weighted_sum(contextual_embedding, attention_weights_types)

    # contextual embedding + positional vectors for tag prediction
    context_positional_tags = torch.cat([contextual_embedding,
                                         verb_indicator.unsqueeze(-1),
                                         entity_indicator.unsqueeze(-1)], dim=-1)

    # Layer 5 = Dense softmax layer to pick one state change type per datapoint,
    # and one tag per word in the sentence
    type_logits = self.aggregate_feedforward(attention_output_vector)
    type_probs = torch.nn.functional.softmax(type_logits, dim=-1)

    tags_logits = self.tag_projection_layer(context_positional_tags)
    reshaped_log_probs = tags_logits.view(-1, self.num_tags)
    tags_class_probabilities = torch.nn.functional.softmax(reshaped_log_probs, dim=-1).view(
        [batch_size, sequence_length, self.num_tags])

    # Create output dictionary for the trainer
    # Compute loss and epoch metrics
    output_dict = {'type_probs': type_probs}
    type_loss = None
    tags_loss = None

    if state_change_type_labels is not None:
        type_loss = self._loss(type_logits, state_change_type_labels.long().view(-1))
        for type_label in self.type_labels_vocab.values():
            metric = self.type_f1_metrics["type_" + type_label]
            metric(type_probs, state_change_type_labels.squeeze(-1))
        self._type_accuracy(type_probs, state_change_type_labels.squeeze(-1))

    if state_change_tags is not None:
        tags_loss = sequence_cross_entropy_with_logits(tags_logits, state_change_tags, mask)
        self.span_metric(tags_class_probabilities, state_change_tags, mask)
        output_dict["tags_class_probabilities"] = tags_class_probabilities

    # Only combine the losses that were actually computed, so that calling the model
    # without one (or both) kinds of labels does not fail on an unbound variable.
    if type_loss is not None and tags_loss is not None:
        output_dict['loss'] = (type_loss + tags_loss)
    elif type_loss is not None:
        output_dict['loss'] = type_loss
    elif tags_loss is not None:
        output_dict['loss'] = tags_loss

    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.LongTensor = None,
        span_end: torch.LongTensor = None,
        spans=None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Score answer spans over a passage that is embedded with one extra wrapping
    dimension of size four, with an additional "no answer" score.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        Text-field tensors for the question, fed to ``self._text_field_embedder``.
    passage : Dict[str, torch.LongTensor]
        Text-field tensors for the passage; its embedding is treated as
        ``(batch_size, 4, passage_length, embedding_dim)`` and flattened over the
        first two dimensions.
    span_start, span_end : torch.LongTensor, optional (default = None)
        Gold span boundaries, read as ``span_start[b, 0]`` / ``span_end[b, 0]``.
        A negative combined index is mapped to the no-answer class.  When
        ``span_start`` is ``None`` neither the loss nor the span metrics are computed.
    spans : optional
        Not used by the current implementation.  A multi-answer loss that consumed it
        (``spans[:, :, 0]`` as starts, ``spans[:, :, 1]`` as ends) used to be sketched
        here as commented-out code.
    metadata : List[Dict[str, Any]], optional (default = None)
        Per-instance dictionaries with ``question_tokens``, ``passage_tokens``,
        ``original_passage``, ``token_offsets`` and optionally ``answer_texts``.

    Returns
    -------
    A dictionary with ``passage_question_attention``, ``span_start_logits``,
    ``span_start_probs``, ``span_end_logits``, ``span_end_probs`` and ``best_span``;
    ``loss`` when gold spans are given; and ``best_span_str``, ``question_tokens``
    and ``passage_tokens`` when ``metadata`` is given.

    NOTE(review): tensors are moved with ``.cuda(cuda_device)`` where ``cuda_device``
    comes from ``get_device()``, so this looks GPU-only -- confirm.
    NOTE(review): ``batch_size`` is read from the un-tiled question, while the passage
    and question tensors are flattened to ``batch_size * 4`` rows; the later
    ``expand(batch_size, ...)``, ``resize(batch_size, ...)`` and ``view(batch_size, -1)``
    calls should be checked against that.  The question is also tiled with
    ``expand(4, ...)`` on a new leading dimension while the passage is flattened from
    ``(batch_size, 4, ...)``, so their row orders may not line up -- confirm.
    """
    # pylint: disable=arguments-differ
    embedded_question = self._highway_layer(
        self._text_field_embedder(question))
    # Shape: (batch_size, 4, passage_length, embedding_dim)
    embedded_passage = self._text_field_embedder(passage)
    (batch_size, q_length, embedding_dim) = embedded_question.size()
    passage_length = embedded_passage.size(2)
    # reshape: (batch_size*4, -1, embedding_dim)
    embedded_passage = embedded_passage.view(-1, passage_length, embedding_dim)
    embedded_passage = self._highway_layer(embedded_passage)
    # Repeat the question four times so it can be paired with every passage row.
    embedded_question = embedded_question.unsqueeze(0).expand(
        4, -1, -1, -1).contiguous().view(-1, q_length, embedding_dim)
    question_mask = util.get_text_field_mask(question).float()
    question_mask = question_mask.unsqueeze(0).expand(
        4, -1, -1).contiguous().view(-1, q_length)
    passage_mask = util.get_text_field_mask(passage, 1).float()
    passage_mask = passage_mask.view(-1, passage_length)
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(
        self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(
        self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    cuda_device = encoded_question.get_device()

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(
        encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.last_dim_softmax(
        passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(
        encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(
        passage_question_similarity, question_mask.unsqueeze(1), -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(
        dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(
        question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(
        encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(
        1).expand(batch_size, passage_length, encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([
        encoded_passage, passage_question_vectors,
        encoded_passage * passage_question_vectors,
        encoded_passage * tiled_question_passage_vector
    ], dim=-1)

    # Shape: (batch_size, passage_length, encoding_dim)
    question_attended_passage = relu(
        self._linear_layer(final_merged_passage))

    # TODO: attach residual self-attention layer
    # Shape: (batch_size, passage_length, encoding_dim)
    residual_passage = self._dropout(
        self._residual_encoder(self._dropout(question_attended_passage),
                               passage_lstm_mask))

    # Pairwise passage mask with the diagonal removed, so that a token never attends
    # to itself in the self-attention below.
    mask = passage_mask.resize(batch_size, passage_length, 1) * passage_mask.resize(
        batch_size, 1, passage_length)
    self_mask = Variable(
        torch.eye(passage_length, passage_length).cuda(cuda_device)).resize(
            1, passage_length, passage_length)
    mask = mask * (1 - self_mask)

    # Shape: (batch_size, passage_length, passage_length)
    x_similarity = torch.matmul(residual_passage, self._w_x).unsqueeze(2)
    y_similarity = torch.matmul(residual_passage, self._w_y).unsqueeze(1)
    dot_similarity = torch.bmm(residual_passage * self._w_xy,
                               residual_passage.transpose(1, 2))
    passage_self_similarity = dot_similarity + x_similarity + y_similarity

    # Shape: (batch_size, passage_length, passage_length)
    passage_self_attention = util.last_dim_softmax(passage_self_similarity, mask)

    # Shape: (batch_size, passage_length, encoding_dim)
    passage_vectors = util.weighted_sum(residual_passage, passage_self_attention)

    # Shape: (batch_size, passage_length, encoding_dim * 3)
    merged_passage = torch.cat([
        residual_passage, passage_vectors,
        residual_passage * passage_vectors
    ], dim=-1)

    # Shape: (batch_size, passage_length, encoding_dim)
    self_attended_passage = relu(
        self._residual_linear_layer(merged_passage))

    # Shape: (batch_size, passage_length, encoding_dim)
    mixed_passage = question_attended_passage + self_attended_passage

    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_start = self._dropout(
        self._span_start_encoder(mixed_passage, passage_lstm_mask))
    span_start_logits = self._span_start_predictor(
        encoded_span_start).squeeze(-1)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, passage_length, encoding_dim * 2)
    concatenated_passage = torch.cat([mixed_passage, encoded_span_start], dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(
        self._span_end_encoder(concatenated_passage, passage_lstm_mask))
    span_end_logits = self._span_end_predictor(encoded_span_end).squeeze(
        -1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

    # Shape: (batch_size, encoding_dim)
    v_1 = util.weighted_sum(encoded_span_start, span_start_probs)
    v_2 = util.weighted_sum(encoded_span_end, span_end_probs)

    no_span_logits = self._no_answer_predictor(
        self_attended_passage).squeeze(-1)
    no_span_probs = util.masked_softmax(no_span_logits, passage_mask)
    v_3 = util.weighted_sum(self_attended_passage, no_span_probs)

    # Shape: (batch_size, 1)
    z_score = self._feed_forward(torch.cat([v_1, v_2, v_3], dim=-1))

    # compute no-answer score
    span_start_logits = util.replace_masked_values(span_start_logits,
                                                   passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits,
                                                 passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # The gold spans are optional (both default to ``None``), so the loss -- like the
    # metrics -- is only computed when they are supplied.
    if span_start is not None:
        # create target tensor including no-answer label
        # Each (start, end) pair is encoded as start * passage_length + end; a negative
        # value is remapped to the extra class index passage_length**2.
        span_target = Variable(torch.ones(batch_size).long()).cuda(cuda_device)
        for b in range(batch_size):
            span_target[b].data[0] = span_start[
                b, 0].data[0] * passage_length + span_end[b, 0].data[0]
        span_target[span_target < 0] = passage_length**2

        # Shape: (batch_size, passage_length, passage_length)
        span_start_logits_tiled = span_start_logits.unsqueeze(1).expand(
            batch_size, passage_length, passage_length)
        span_end_logits_tiled = span_end_logits.unsqueeze(-1).expand(
            batch_size, passage_length, passage_length)
        span_logits = (span_start_logits_tiled + span_end_logits_tiled).view(
            batch_size, -1)
        answer_mask = torch.bmm(passage_mask.unsqueeze(-1),
                                passage_mask.unsqueeze(1)).view(
                                    batch_size, -1)
        # The no-answer column is never masked out.
        no_answer_mask = Variable(torch.ones(batch_size, 1)).cuda(cuda_device)
        combined_mask = torch.cat([answer_mask, no_answer_mask], dim=1)
        all_logits = torch.cat([span_logits, z_score], dim=-1)
        loss = nll_loss(util.masked_log_softmax(all_logits, combined_mask),
                        span_target)
        output_dict["loss"] = loss

        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].data.cpu().numpy())
            # Map the predicted token span back to character offsets in the passage.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def compute_class_AP(model, dl, n_classes, show_progress, iou_thresh=0.1, detect_thresh=0.5, num_keep=100):
    """Compute one average-precision (AP) value per class over all batches of ``dl``.

    For every image, each prediction is matched to the target box it overlaps
    most. The prediction counts as a true positive when that IoU is at least
    ``iou_thresh``, the target has not already been claimed by an earlier
    prediction of the same image, and the target's class equals the predicted
    class; otherwise it counts as a false positive. Predictions are then
    ranked by score and the cumulative precision / recall per class are handed
    to ``compute_ap``.

    Parameters
    ----------
    model :
        Detector wrapper. The code reads ``model.learn``, ``model._data`` and
        ``model._device`` and passes ``model`` itself as ``ssd=``.
    dl :
        Iterable of ``(input, target)`` batches; ``target[0]`` / ``target[1]``
        are indexed per image and un-padded by ``_get_y``
        (presumably boxes and classes -- confirm against the data pipeline).
    n_classes : int
        Number of classes. Class ids are treated as 1-based
        (``1 .. n_classes``) in both predictions and targets.
    show_progress :
        Forwarded as ``display=`` to ``progress_bar``.
    iou_thresh : float
        Used both as ``nms_overlap`` for ``analyze_pred`` and as the minimum
        IoU for a prediction to match a target.
    detect_thresh : float
        Forwarded as ``thresh=`` to ``analyze_pred``.
    num_keep : int
        Unused; kept so existing callers keep working.

    Returns
    -------
    list
        ``n_classes`` entries, one per class id ``1 .. n_classes``; ``0.`` for
        classes without any true positive (or when nothing was matched at all).
    """
    tps, clas, p_scores = [], [], []
    classes, n_gts = LongTensor(range(n_classes)), torch.zeros(n_classes).long()
    model.learn.model.eval()
    with torch.no_grad():
        for images, target in progress_bar(dl, display=show_progress):
            output = model.learn.pred_batch(batch=(images, target))
            for item_idx in range(target[0].size(0)):  # one iteration per image of the batch
                # output[0] - class preds, output[1] - bbox preds (original author's note)
                op = model._data.y.analyze_pred((output[0][item_idx], output[1][item_idx]),
                                                thresh=detect_thresh,
                                                nms_overlap=iou_thresh,
                                                ssd=model,
                                                ret_scores=True,
                                                device=model._device)
                # Unpad the targets
                tgt_bbox, tgt_clas = _get_y(target[0][item_idx], target[1][item_idx])
                try:
                    # op - bbox preds, class preds, scores
                    bbox_pred, preds, scores = op
                    if len(bbox_pred) != 0 and len(tgt_bbox) != 0:
                        bbox_pred = bbox_pred.to(model._device)
                        preds = preds.to(model._device)
                        tgt_bbox = tgt_bbox.to(model._device)
                        # Convert the bbox coordinates to center-height-width (cthw)
                        # before calculating Intersection Over Union.
                        ious = IoU_values(tlbr2cthw(bbox_pred), tlbr2cthw(tgt_bbox))
                        max_iou, matches = ious.max(1)
                        detected = []
                        image_tps = []
                        for pred_idx in range(len(preds)):
                            if (max_iou[pred_idx] >= iou_thresh
                                    and matches[pred_idx] not in detected
                                    and tgt_clas[matches[pred_idx]] == preds[pred_idx]):
                                detected.append(matches[pred_idx])
                                image_tps.append(1)
                            else:
                                image_tps.append(0)
                        preds_cpu, scores_cpu = preds.cpu(), scores.cpu()
                        # Commit the three per-image results together so that a
                        # failure above can never leave `tps` longer than
                        # `clas` / `p_scores`.
                        tps.extend(image_tps)
                        clas.append(preds_cpu)
                        p_scores.append(scores_cpu)
                except Exception:
                    # Best effort: an image whose predictions cannot be unpacked
                    # or matched contributes no detections, but its ground truth
                    # is still counted below.
                    pass
                n_gts += ((tgt_clas.cpu()[:, None] - 1) == classes[None, :]).sum(0)

    # If no true positives are found return an average precision score of 0.
    if len(tps) == 0:
        return [0.] * n_classes

    tps, p_scores, clas = torch.tensor(tps), torch.cat(p_scores, 0), torch.cat(clas, 0)
    fps = 1 - tps
    # Rank every prediction of the whole dataset by descending confidence.
    idx = p_scores.argsort(descending=True)
    tps, fps, clas = tps[idx], fps[idx], clas[idx]
    aps = []
    for cls in range(1, n_classes + 1):
        cls_mask = clas == cls
        tps_cls = tps[cls_mask].float().cumsum(0)
        fps_cls = fps[cls_mask].float().cumsum(0)
        if tps_cls.numel() != 0 and tps_cls[-1] != 0:
            precision = tps_cls / (tps_cls + fps_cls + 1e-8)
            recall = tps_cls / (n_gts[cls - 1] + 1e-8)
            aps.append(compute_ap(precision, recall))
        else:
            aps.append(0.)
    return aps
def get_long_tensor(np_tensor):
    """Wrap a NumPy array in a ``LongTensor``, moved to the GPU when CUDA is available.

    ``np_tensor`` is handed to ``from_numpy`` and the result to the
    ``LongTensor`` constructor; when ``torch.cuda.is_available()`` is true the
    tensor is additionally transferred with ``.cuda()``.
    """
    long_tensor = LongTensor(from_numpy(np_tensor))
    if not torch.cuda.is_available():
        return long_tensor
    return long_tensor.cuda()
def encode_class(idxs, n_classes):
    """One-hot encode 1-based class ids, treating id ``0`` as "no class".

    Parameters
    ----------
    idxs :
        1-D integer tensor of class ids. An entry of ``0`` produces an
        all-zero row; an entry ``k > 0`` sets column ``k - 1``.
    n_classes : int
        Width of the encoding.

    Returns
    -------
    Float tensor of shape ``(len(idxs), n_classes)`` on the same device as
    ``idxs``.
    """
    target = idxs.new_zeros(len(idxs), n_classes).float()
    mask = idxs != 0
    # Build the row indices on the device of `idxs`: indexing a CPU tensor
    # with a mask that lives on another device is a device mismatch.
    rows = torch.arange(len(idxs), device=idxs.device)
    target[rows[mask], idxs[mask] - 1] = 1
    return target
def _dynamic_rnn_loop(cell: RNNCellBase[State],
                      inputs: torch.Tensor,
                      initial_state: State,
                      sequence_length: torch.LongTensor) \
        -> Tuple[torch.Tensor, State]:
    r"""Internal implementation of Dynamic RNN.

    The cell is run over every time step of ``inputs``; the stacked outputs
    are passed through ``mask_sequences`` with ``sequence_length``, and the
    returned state of each batch element is the one recorded at its own last
    valid step.

    Args:
        cell: Callable invoked as ``cell(inputs[t], state)`` and returning
            ``(output, new_state)``.
        inputs: A ``Tensor`` indexed along its first (time) dimension, i.e.
            laid out time-major.
        initial_state: A ``Tensor``, or a 2-tuple of tensors, whose first
            dimension is indexed by batch element.
        sequence_length: A ``Tensor`` with one length per batch element.
            Required: ``.tolist()`` is called on it unconditionally.

    Returns:
        Tuple ``(final_outputs, final_state)``.

        final_outputs:
            The per-step outputs stacked along dimension 0 and masked by
            ``mask_sequences(..., time_major=True)``.

        final_state:
            For each batch element, the state after step
            ``sequence_length - 1``, or the matching slice of
            ``initial_state`` when its length is not positive. Stacked along
            dimension 0; a 2-tuple when the state is a tuple.
    """
    state = initial_state
    step_outputs = []
    # History of the state after every step, kept so that each batch element
    # can later pick the state belonging to its own length.
    step_states: MaybeTuple[List[torch.Tensor]] = (
        ([], []) if isinstance(state, tuple) else [])

    for step in range(inputs.shape[0]):
        step_output, state = cell(inputs[step], state)
        step_outputs.append(step_output)
        if isinstance(state, tuple):
            step_states[0].append(state[0])
            step_states[1].append(state[1])
        else:
            step_states.append(state)  # type: ignore

    # TODO: Do not compute everything regardless of sequence_length
    final_outputs = mask_sequences(torch.stack(step_outputs, dim=0),
                                   sequence_length=sequence_length,
                                   time_major=True)

    # `state` no longer changes from here on, so one check is enough.
    state_is_tuple = isinstance(state, tuple)
    final_state: MaybeTuple[List[torch.Tensor]] = (
        ([], []) if state_is_tuple else [])

    for batch_idx, length in enumerate(sequence_length.tolist()):
        if length > 0:
            last_step = length - 1
            if state_is_tuple:
                final_state[0].append(step_states[0][last_step][batch_idx])
                final_state[1].append(step_states[1][last_step][batch_idx])
            else:
                final_state.append(  # type: ignore
                    step_states[last_step][batch_idx])
            continue
        # No valid step for this element: fall back to its initial state.
        if isinstance(initial_state, tuple):
            final_state[0].append(initial_state[0][batch_idx])
            final_state[1].append(initial_state[1][batch_idx])
        else:
            final_state.append(initial_state[batch_idx])  # type: ignore

    if state_is_tuple:
        final_state = (torch.stack(final_state[0], dim=0),
                       torch.stack(final_state[1], dim=0))
    else:
        final_state = torch.stack(final_state, dim=0)  # type: ignore
    return final_outputs, final_state
def _get_linking_probabilities(self,
                               worlds: List[WikiTablesVariableFreeWorld],
                               linking_scores: torch.FloatTensor,
                               question_mask: torch.LongTensor,
                               entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
    """
    Turns linking scores into the probability of an entity given a question word and an
    entity type.  Entities are grouped by type because the softmax normaliser only sums
    over entities that share a type (plus one "null" entity per type).

    Parameters
    ----------
    worlds : ``List[WikiTablesWorld]``
        One world per batch instance; ``world.table_graph.entities`` is enumerated.
    linking_scores : ``torch.FloatTensor``
        Has shape (batch_size, num_question_tokens, num_entities).
    question_mask: ``torch.LongTensor``
        Has shape (batch_size, num_question_tokens).
    entity_type_dict : ``Dict[int, int]``
        Maps ((batch_index * num_entities) + entity_index) to an entity type id.

    Returns
    -------
    batch_probabilities : ``torch.FloatTensor``
        Has shape ``(batch_size, num_question_tokens, num_entities)``; rows of masked
        question tokens are zero.
    """
    _, num_question_tokens, num_entities = linking_scores.size()
    per_instance = []

    for batch_index, world in enumerate(worlds):
        entity_offset = batch_index * num_entities
        per_type = []
        linked_count = 0
        # NOTE: The way that we're doing this here relies on the fact that entities are
        # implicitly sorted by their types when we sort them by name, and that numbers come
        # before "date_column:", followed by "number_column:", "string:", and "string_column:".
        # This is not a great assumption, and could easily break later, but it should work for now.
        for type_index in range(self._num_entity_types):
            members = [entity_index
                       for entity_index, _ in enumerate(world.table_graph.entities)
                       if entity_type_dict[entity_offset + entity_index] == type_index]
            if not members:
                # No entities of this type; move along...
                continue
            linked_count += len(members)

            # Column 0 is prepended as a slot for the per-type null entity (the case where
            # a word links to nothing).  Selecting on dimension 1 of the
            # (num_question_tokens, num_entities) slice yields a tensor of shape
            # (num_question_tokens, len(members) + 1).
            indices = linking_scores.new_tensor([0] + members, dtype=torch.long)
            entity_scores = linking_scores[batch_index].index_select(1, indices)
            # The slot currently holds the score of entity 0; the null entity's score
            # must be 0, so overwrite it.
            entity_scores[:, 0] = 0

            # No need for a mask here, as this is done per batch instance, with no padding.
            type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
            # Drop the null-entity column again.
            per_type.append(type_probabilities[:, 1:])

        # We need to add padding here if we don't have the right number of entities.
        if linked_count != num_entities:
            per_type.append(linking_scores.new_zeros(num_question_tokens,
                                                     num_entities - linked_count))

        # (num_question_tokens, num_entities)
        per_instance.append(torch.cat(per_type, dim=1))

    batch_probabilities = torch.stack(per_instance, dim=0)
    return batch_probabilities * question_mask.unsqueeze(-1).float()
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[WikiTablesWorld],
            actions: List[List[ProductionRule]],
            example_lisp_string: List[str] = None,
            target_action_sequences: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Builds the initial decoder state from the question and table, then either trains on
    the target action sequences (training mode) or runs beam search and records
    validation outputs (evaluation mode).

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        The output of ``TextField.as_array()`` applied on the question ``TextField``.
        Forwarded to ``_get_initial_rnn_and_grammar_state``.
    table : ``Dict[str, torch.LongTensor]``
        The output of ``KnowledgeGraphField.as_array()`` applied on the table
        ``KnowledgeGraphField``, where each table entity is treated as a "token".
        Forwarded to ``_get_initial_rnn_and_grammar_state``.
    world : ``List[WikiTablesWorld]``
        One ``World`` per input instance (a ``MetadataField`` delivers them as a list).
    actions : ``List[List[ProductionRule]]``
        All possible actions for each ``World`` in the batch.  Stored on the initial state
        as ``possible_actions``.
    example_lisp_string : ``List[str]``, optional (default = None)
        The example (lisp-formatted) string for each input, taken from the dataset's
        ``.examples`` file.  Stored on the state as ``extras`` and handed to
        ``_compute_validation_outputs``.
    target_action_sequences : torch.Tensor, optional (default = None)
        Possibly valid action sequences, each action being an index into the list of
        possible actions.  Shape ``(batch_size, num_action_sequences, sequence_length)``
        after the trailing dimension is squeezed away below.
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Metadata containing the original tokenized question within a 'question_tokens' key.

    Returns
    -------
    In training mode, whatever ``self._decoder_trainer.decode`` returns.  Otherwise the
    ``outputs`` dictionary, which holds ``'loss'`` when targets were given plus whatever
    ``_get_initial_rnn_and_grammar_state`` and ``_compute_validation_outputs`` add to it.
    """
    outputs: Dict[str, Any] = {}
    rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                       table,
                                                                       world,
                                                                       actions,
                                                                       outputs)
    batch_size = len(rnn_state)
    batch_indices = list(range(batch_size))
    # Every instance starts with a score of zero, on the same device/dtype as the RNN state.
    zero_scores = rnn_state[0].hidden_state.new_zeros(batch_size)
    initial_state = GrammarBasedState(batch_indices=batch_indices,  # type: ignore
                                      action_history=[[] for _ in batch_indices],
                                      score=[zero_scores[index] for index in batch_indices],
                                      rnn_state=rnn_state,
                                      grammar_state=grammar_state,
                                      possible_actions=actions,
                                      extras=example_lisp_string,
                                      debug_info=None)

    if target_action_sequences is None:
        target_mask = None
    else:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
    supervision = (target_action_sequences, target_mask)

    if self.training:
        return self._decoder_trainer.decode(initial_state,
                                            self._decoder_step,
                                            supervision)

    if target_action_sequences is not None:
        outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                       self._decoder_step,
                                                       supervision)['loss']
    num_steps = self._max_decoding_steps
    # This tells the state to start keeping track of debug info, which we'll pass along in
    # our output dictionary.
    initial_state.debug_info = [[] for _ in range(batch_size)]
    best_final_states = self._beam_search.search(num_steps,
                                                 initial_state,
                                                 self._decoder_step,
                                                 keep_final_unfinished_states=False)
    for index in range(batch_size):
        # Decoding may not have terminated with any completed logical forms, if `num_steps`
        # isn't long enough (or if the model is not trained enough and gets into an
        # infinite action loop).
        if index not in best_final_states:
            continue
        best_action_indices = best_final_states[index][0].action_history[0]
        if target_action_sequences is None:
            continue
        # Use a Tensor, not a Variable, to avoid a memory leak.
        targets = target_action_sequences[index].data
        sequence_in_targets = self._action_history_match(best_action_indices, targets)
        self._action_sequence_accuracy(sequence_in_targets)

    self._compute_validation_outputs(actions,
                                     best_final_states,
                                     world,
                                     example_lisp_string,
                                     metadata,
                                     outputs)
    return outputs
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Embeds spans from the two directional halves of ``sequence_tensor``.

    The last dimension of ``sequence_tensor`` is split into a forward and a backward
    half.  For each span, the forward half is read at the exclusive start (``start - 1``)
    and the inclusive end, the backward half at the exclusive end (``end + 1``) and the
    inclusive start; each pair is combined with ``util.combine_tensors`` and the two
    results are concatenated.  Positions that fall outside the sequence are replaced by
    learned sentinels when ``self._use_sentinels`` is set.

    Parameters
    ----------
    sequence_tensor : ``torch.FloatTensor``
        Shape ``(batch_size, sequence_length, embedding_size)`` (per the inline shape
        comments below).
    span_indices : ``torch.LongTensor``
        Inclusive ``(start, end)`` pairs; the last dimension is split into starts and ends.
    sequence_mask : ``torch.LongTensor``, optional (default = None)
        Binary mask used to derive per-instance sequence lengths.  When omitted every
        sequence is assumed to span the full ``sequence_length``.
    span_indices_mask : ``torch.LongTensor``, optional (default = None)
        Mask over spans.  Masked spans have their indices zeroed before lookup and their
        embeddings zeroed in the result.

    Returns
    -------
    The span embeddings, with the span-width embedding appended when one is configured.
    Rows of spans masked out by ``span_indices_mask`` are all zero in either case.

    Raises
    ------
    ValueError
        If an adjusted span index falls outside the sequence.
    """
    # Both of shape (batch_size, sequence_length, embedding_size / 2)
    forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1)
    forward_sequence = forward_sequence.contiguous()
    backward_sequence = backward_sequence.contiguous()

    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # Point padded spans at index 0 so the lookups below stay in bounds.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    # We want `exclusive` span starts, so we remove 1 from the forward span starts
    # as the AllenNLP ``SpanField`` is inclusive.
    # shape (batch_size, num_spans)
    exclusive_span_starts = span_starts - 1
    # shape (batch_size, num_spans, 1)
    start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

    # We want `exclusive` span ends for the backward direction
    # (so that the `start` of the span in that direction is exclusive), so
    # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
    exclusive_span_ends = span_ends + 1

    if sequence_mask is not None:
        # shape (batch_size)
        sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
    else:
        # shape (batch_size), filled with the sequence length size of the sequence_tensor.
        sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)

    # shape (batch_size, num_spans, 1)
    end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

    # As we added 1 to the span_ends to make them exclusive, which might have caused indices
    # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
    # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
    # The same argument follows for the exclusive span start indices.
    exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
    exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

    # We'll check the indices here at runtime, because it's difficult to debug
    # if this goes wrong and it's tricky to get right.
    if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
        raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                         f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                         f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                         f"{sequence_lengths}.")

    # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
    forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
    # Forward Direction: end indices are inclusive, so we can just use span_ends.
    # Shape (batch_size, num_spans, input_size / 2)
    forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

    # Backward Direction: The backward start embeddings use the `forward` end
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
    # Backward Direction: The backward end embeddings use the `forward` start
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

    if self._use_sentinels:
        # If we're using sentinels, we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with either the start sentinel,
        # or the end sentinel.
        float_end_sentinel_mask = end_sentinel_mask.float()
        float_start_sentinel_mask = start_sentinel_mask.float()
        forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                    + float_start_sentinel_mask * self._start_sentinel
        backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                    + float_end_sentinel_mask * self._end_sentinel

    # Now we combine the forward and backward spans in the manner specified by the
    # respective combinations and concatenate these representations.
    # Shape (batch_size, num_spans, forward_combination_dim)
    forward_spans = util.combine_tensors(self._forward_combination,
                                         [forward_start_embeddings, forward_end_embeddings])
    # Shape (batch_size, num_spans, backward_combination_dim)
    backward_spans = util.combine_tensors(self._backward_combination,
                                          [backward_start_embeddings, backward_end_embeddings])
    # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
    span_embeddings = torch.cat([forward_spans, backward_spans], -1)

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        # Do not return here: the span mask below must also cover the width
        # embedding, otherwise padded spans come back with non-zero vectors.
        span_embeddings = torch.cat([span_embeddings, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return span_embeddings * span_indices_mask.float().unsqueeze(-1)
    return span_embeddings
def construct_trees(self,
                    predictions: torch.FloatTensor,
                    all_spans: torch.LongTensor,
                    num_spans: torch.LongTensor,
                    sentences: List[List[str]],
                    pos_tags: List[List[str]] = None) -> List[Tree]:
    """
    Decode one ``nltk.Tree`` per batch element by keeping the labelled spans (plus the
    root span), resolving overlaps greedily and nesting what remains.  Internally the
    spans are converted to exclusive end indices, unlike the rest of the model.

    Parameters
    ----------
    predictions : ``torch.FloatTensor``, required.
        A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
        representing a distribution over the label classes per span.
    all_spans : ``torch.LongTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the span
        indices we scored.  It is cloned, not modified.
    num_spans : ``torch.LongTensor``, required.
        A tensor of shape (batch_size) giving how many leading spans of each
        batch element are real (non-padded).
    sentences : ``List[List[str]]``, required.
        A list of tokens in the sentence for each element in the batch.
    pos_tags : ``List[List[str]]``, optional (default = None).
        A list of POS tags for each word in the sentence for each element
        in the batch.

    Returns
    -------
    A ``List[Tree]`` containing the decoded trees for each element in the batch.
    """
    # Switch to using exclusive end spans.
    exclusive_end_spans = all_spans.clone()
    exclusive_end_spans[:, :, -1] += 1
    no_label_id = self.vocab.get_token_index("NO-LABEL", "labels")

    trees: List[Tree] = []
    batch = zip(predictions, exclusive_end_spans, sentences)
    for batch_index, (span_scores, span_bounds, sentence) in enumerate(batch):
        # Only the first `span_count` spans of this element are real.
        span_count = num_spans[batch_index]
        candidates = []
        for prediction, span in zip(span_scores[:span_count], span_bounds[:span_count]):
            start, end = span
            no_label_prob = prediction[no_label_id]
            label_prob, label_index = torch.max(prediction, -1)

            # Skip spans predicted as NO-LABEL, unless the span is the root node
            # (it covers the whole sentence), which is always kept.
            if int(label_index) == no_label_id and not (start == 0 and end == len(sentence)):
                continue
            # TODO(Mark): Remove this once pylint sorts out named tuples.
            # https://github.com/PyCQA/pylint/issues/1418
            candidates.append(SpanInformation(start=int(start),  # pylint: disable=no-value-for-parameter
                                              end=int(end),
                                              label_prob=float(label_prob),
                                              no_label_prob=float(no_label_prob),
                                              label_index=int(label_index)))

        # The spans we've selected might overlap, which causes problems when we try
        # to construct the tree as they won't nest properly.
        consistent_spans = self.resolve_overlap_conflicts_greedily(candidates)

        spans_to_labels = {}
        for kept in consistent_spans:
            label = self.vocab.get_token_from_index(kept.label_index, "labels")
            spans_to_labels[(kept.start, kept.end)] = label

        sentence_pos = None if pos_tags is None else pos_tags[batch_index]
        trees.append(self.construct_tree_from_spans(spans_to_labels, sentence, sentence_pos))

    return trees
def forward(self,  # type: ignore
            start_tokens: torch.LongTensor,
            memory: Optional[State] = None,
            cache_len: int = 512,
            max_decoding_length: Optional[int] = 500,
            recompute_memory: bool = True,
            print_steps: bool = False,
            helper_type: Optional[Union[str, Type[Helper]]] = None,
            **helper_kwargs) \
        -> Tuple[Output, Optional[State]]:
    r"""Perform autoregressive decoding using XLNet. The algorithm is
    largely inspired by: https://github.com/rusiaaman/XLNet-gen.

    Args:
        start_tokens: A LongTensor of shape `[batch_size, prompt_len]`,
            representing the tokenized initial prompt. The last token of
            each prompt is handed to the helper as its start token; the
            embeddings of the preceding tokens are kept as previous inputs.
        memory (optional): The initial memory.
        cache_len: Length of memory (number of tokens) to cache.
        max_decoding_length (int): Maximum number of tokens to decode.
        recompute_memory (bool): If `True`, the entire memory is recomputed
            for each token to generate. This leads to better performance
            because it enables every generated token to attend to each
            other, compared to reusing previous memory which is equivalent
            to using a causal attention mask. However, it is computationally
            more expensive. Defaults to `True`.
        print_steps (bool): If `True`, will print decoding progress.
        helper_type: Type (or name of the type) of any sub-class of
            :class:`~texar.modules.decoders.Helper`. Defaults to
            ``SampleEmbeddingHelper`` when `None`.
        helper_kwargs: The keyword arguments to pass to constructor of the
            specific helper type. Must contain a non-`None` ``end_token``;
            ``start_tokens`` is filled in by this method.

    :returns: A tuple of `(output, new_memory)`:

        - **`output`**: The sampled tokens as a list of integers.
        - **`new_memory`**: The memory of the sampled tokens.

    :raises ValueError: If ``end_token`` is missing from ``helper_kwargs``.
    """
    # Work time-major from here on: [prompt_len, batch_size].
    start_tokens = start_tokens.t()
    self._state_recompute_memory = recompute_memory
    self._state_cache_len = cache_len
    prompt_embeddings = self.word_embed(start_tokens).unbind(dim=0)
    # Everything but the last prompt token counts as already-consumed input.
    self._state_previous_inputs = list(prompt_embeddings)[:-1]

    if helper_type is None:
        helper_type = SampleEmbeddingHelper

    if not recompute_memory and start_tokens.size(0) > 1:
        # Memory is reused across steps, so prime it once with the prompt.
        prompt_input = self._create_input(
            self._state_previous_inputs, initial=True)
        _, memory = self._forward(
            memory=memory, cache_len=cache_len, **prompt_input)
    start_tokens = start_tokens[-1]

    helper_kwargs.update(start_tokens=start_tokens)
    if helper_kwargs.get("end_token") is None:
        raise ValueError("'end_token' must be specified.")
    helper = get_instance(
        helper_type, helper_kwargs,
        module_paths=['texar.modules.decoders.decoder_helpers'])

    step_hook = None
    if print_steps:
        def step_hook(step):
            # Clear the current terminal line and rewrite the progress counter.
            return print(f"\033[2K\rDecoding step: {step}", end='')

    output, new_memory, _ = self.dynamic_decode(
        helper, inputs=None, sequence_length=None, initial_state=memory,
        max_decoding_length=max_decoding_length, step_hook=step_hook)
    if print_steps:
        # Wipe the progress line once decoding has finished.
        print("\033[2K\r", end='')

    return output, new_memory
def forward(
    self,
    src_tokens: torch.LongTensor,
    src_lengths: torch.LongTensor,
    return_encoder_out: bool = False,
    return_encoder_padding_mask: bool = False,
) -> EncoderOuts:
    """Encode a batch of sequences into max-pooled sentence embeddings

    Arguments:
        src_tokens {torch.LongTensor} -- [batch_size, seq_len]
        src_lengths {torch.LongTensor} -- [batch_size]; the batch is packed
            with ``enforce_sorted=True``

    Keyword Arguments:
        return_encoder_out {bool} -- Return output tensors? (default: {False})
        return_encoder_padding_mask {bool} -- Return encoder padding mask?
            (default: {False})

    Returns:
        EncoderOuts -- ``sentemb`` is the max over time of the LSTM outputs;
            ``encoder_out`` is ``(outputs, final_hiddens, final_cells)`` or
            ``None``; ``encoder_padding_mask`` is the time-major padding mask
            or ``None``
    """
    bsz, seqlen = src_tokens.size()

    # Embed, then go time-major: BTC -> TBC
    embedded = self.embed_tokens(src_tokens).transpose(0, 1)

    # Pack then apply LSTM
    packed_x = nn.utils.rnn.pack_padded_sequence(
        embedded, src_lengths, batch_first=False, enforce_sorted=True)
    packed_outs, (final_hiddens, final_cells) = self.lstm.forward(packed_x)
    outputs, _ = nn.utils.rnn.pad_packed_sequence(
        packed_outs, padding_value=self.padding_value)
    assert list(outputs.size()) == [seqlen, bsz, self.output_units]

    # Set padded outputs to -inf so they are not selected by max-pooling
    padding_mask = src_tokens.eq(self.padding_idx).t()
    if padding_mask.any():
        filled = outputs.float().masked_fill_(
            mask=padding_mask.unsqueeze(-1),
            value=float('-inf'),
        )
        outputs = filled.type_as(outputs)

    # Build the sentence embedding by max-pooling over the encoder outputs
    sentemb = outputs.max(dim=0)[0]

    encoder_out = None
    if return_encoder_out:
        encoder_out = (outputs,
                       self._combine_outs(final_hiddens),
                       self._combine_outs(final_cells))

    return EncoderOuts(
        sentemb=sentemb,
        encoder_out=encoder_out,
        encoder_padding_mask=padding_mask if return_encoder_padding_mask else None)
def sequence_cross_entropy_with_logits(
        logits: torch.FloatTensor,
        targets: torch.LongTensor,
        weights: torch.FloatTensor,
        average: str = "batch",
        label_smoothing: float = None,
        gamma: float = None,
        alpha: Union[float, List[float], torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    average: str, optional (default = "batch")
        If "batch", average the loss across the batches. If "token", average
        the loss across each item in the input. If ``None``, return a vector
        of losses per batch element.
    label_smoothing : ``float``, optional (default = None)
        Whether or not to apply label smoothing to the cross-entropy loss.
        For example, with a label smoothing value of 0.2, a 4 class classification
        target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
        the correct label.
    gamma : ``float``, optional (default = None)
        Focal loss[*] focusing parameter ``gamma`` to reduce the relative loss for
        well-classified examples and put more focus on hard ones. The greater the
        value of ``gamma``, the more focus on hard examples.
    alpha : ``float`` or ``List[float]``, optional (default = None)
        Focal loss[*] weighting factor ``alpha`` to balance between classes. Can be
        used independently of ``gamma``. If a single ``float`` is provided, the
        binary case is assumed, using ``alpha`` and ``1 - alpha`` for positive and
        negative respectively. If a list of ``float`` is provided, with the same
        length as the number of classes, the weights will match the classes.
        [*] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár, "Focal Loss for
        Dense Object Detection," 2017 IEEE International Conference on Computer
        Vision (ICCV), Venice, 2017, pp. 2999-3007.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``average=="batch"`` or ``average=="token"``, the returned loss is a scalar.
    If ``average is None``, the returned loss is a vector of shape (batch_size,).

    Raises
    ------
    ValueError
        If ``average`` is not one of ``None``, ``"token"`` or ``"batch"``.
    TypeError
        If ``alpha`` is neither a number, a list, a numpy array nor a tensor.
    """
    if average not in {None, "token", "batch"}:
        raise ValueError(f"Got average {average}, expected one of "
                         "None, 'token', or 'batch'")

    # make sure weights are float
    weights = weights.float()
    # sum all dim except batch
    non_batch_dims = tuple(range(1, len(weights.shape)))
    # shape : (batch_size,)
    # NOTE: taken before the focal / alpha re-weighting below, so the
    # normaliser only reflects the weights supplied by the caller.
    weights_batch_sum = weights.sum(dim=non_batch_dims)
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * sequence_length, 1)
    targets_flat = targets.reshape(-1, 1).long()

    # focal loss coefficient
    if gamma:
        # shape : (batch * sequence_length, num_classes)
        probs_flat = log_probs_flat.exp()
        # shape : (batch * sequence_length, 1)
        probs_flat = torch.gather(probs_flat, dim=1, index=targets_flat)
        # shape : (batch * sequence_length, 1)
        focal_factor = (1.0 - probs_flat) ** gamma
        # shape : (batch, sequence_length)
        focal_factor = focal_factor.view(*targets.size())
        weights = weights * focal_factor

    if alpha is not None:
        # shape : () / (num_classes,)
        if isinstance(alpha, (float, int)):
            # shape : (2,)
            alpha_factor = torch.tensor([1.0 - float(alpha), float(alpha)],
                                        dtype=weights.dtype, device=weights.device)
        elif isinstance(alpha, (list, numpy.ndarray, torch.Tensor)):
            # shape : (c,)
            alpha_factor = torch.tensor(alpha, dtype=weights.dtype, device=weights.device)
            if not alpha_factor.size():
                # A 0-dim alpha is treated like a plain float.
                # shape : (1,)
                alpha_factor = alpha_factor.view(1)
                # shape : (2,)
                alpha_factor = torch.cat([1 - alpha_factor, alpha_factor])
        else:
            raise TypeError(("alpha must be float, list of float, or torch.FloatTensor, "
                             "{} provided.").format(type(alpha)))
        # Look up the class weight of every target.
        # shape : (batch, sequence_length)
        alpha_factor = torch.gather(alpha_factor,
                                    dim=0,
                                    index=targets_flat.view(-1)).view(*targets.size())
        weights = weights * alpha_factor

    if label_smoothing is not None and label_smoothing > 0.0:
        num_classes = logits.size(-1)
        smoothing_value = label_smoothing / num_classes
        # Fill all the correct indices with 1 - smoothing value.
        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat,
                                                                    1.0 - label_smoothing)
        smoothed_targets = one_hot_targets + smoothing_value
        negative_log_likelihood_flat = -log_probs_flat * smoothed_targets
        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
    else:
        # Contribution to the negative log likelihood only comes from the exact indices
        # of the targets, as the target distributions are one-hot. Here we use torch.gather
        # to extract the indices of the num_classes dimension which contribute to the loss.
        # shape : (batch * sequence_length, 1)
        negative_log_likelihood_flat = -torch.gather(log_probs_flat, dim=1, index=targets_flat)

    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights

    if average == "token":
        return negative_log_likelihood.sum() / (weights_batch_sum.sum() + 1e-13)

    # shape : (batch_size,)
    per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / (weights_batch_sum + 1e-13)
    if average == "batch":
        # Sequences whose weights are all zero do not count towards the mean.
        num_non_empty_sequences = (weights_batch_sum > 0).float().sum() + 1e-13
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
def forward(self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            pos_tags: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Encode a batch of sentences, score every (head, child) arc, decode a parse
    and compute the arc and tag losses.

    Parameters
    ----------
    words : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``; passed unchanged to the
        ``TextFieldEmbedder`` and to ``get_text_field_mask``.
    pos_tags : ``torch.LongTensor``, required.
        The output of a ``SequenceLabelField`` containing POS tags. They are
        embedded when the model has a POS embedding, and are always used to
        build the evaluation mask when gold heads and tags are given.
    metadata : List[Dict[str, Any]], required.
        Per-instance metadata; the ``"words"`` and ``"pos"`` entries are copied
        into the output.
    head_tags : torch.LongTensor, optional (default = None)
        Gold class labels for the arcs, shape ``(batch_size, sequence_length)``.
    head_indices : torch.LongTensor, optional (default = None)
        Gold parent index of every word, shape ``(batch_size, sequence_length)``.

    Returns
    -------
    An output dictionary consisting of:
    heads : the predicted head indices exactly as returned by the decoder
        (the sentinel position prepended below is not stripped).
    head_tags : the predicted head tags, as returned by the decoder.
    arc_loss : the loss contribution from the unlabeled arcs.
    tag_loss : the loss contribution from the dependency tags.
    loss : ``arc_loss + tag_loss``.
    mask : the token mask with one extra leading column for the head sentinel.
    words, pos : the ``"words"`` / ``"pos"`` metadata entries of each instance.

    Raises
    ------
    ConfigurationError
        If the model has a POS tag embedding but ``pos_tags`` is ``None``.
    """
    token_features = self.text_field_embedder(words)
    if self._pos_tag_embedding is not None:
        if pos_tags is None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")
        pos_features = self._pos_tag_embedding(pos_tags)
        token_features = torch.cat([token_features, pos_features], -1)

    mask = get_text_field_mask(words)
    encoded_text = self.encoder(self._input_dropout(token_features), mask)

    batch_size, _, encoding_dim = encoded_text.size()

    # Prepend the head sentinel to the sentence representation, and pad the mask
    # and any gold annotations with one leading column so everything stays aligned.
    root_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    encoded_text = torch.cat([root_sentinel, encoded_text], 1)
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
    float_mask = mask.float()
    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation)

    # Push the score of every arc that touches a masked position towards -inf.
    padding_penalty = (1 - float_mask) * -1e8
    attended_arcs = attended_arcs + padding_penalty.unsqueeze(2) + padding_penalty.unsqueeze(1)

    # Greedy decoding while training, or whenever MST decoding is switched off.
    use_greedy = self.training or not self.use_mst_decoding_for_validation
    decode = self._greedy_decode if use_greedy else self._mst_decode
    predicted_heads, predicted_head_tags = decode(head_tag_representation,
                                                  child_tag_representation,
                                                  attended_arcs,
                                                  mask)

    # With gold annotations the loss is computed against them; otherwise it is
    # computed against the model's own predictions.
    has_gold = head_indices is not None and head_tags is not None
    if has_gold:
        loss_heads, loss_tags = head_indices, head_tags
    else:
        loss_heads, loss_tags = predicted_heads.long(), predicted_head_tags.long()

    arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                            child_tag_representation=child_tag_representation,
                                            attended_arcs=attended_arcs,
                                            head_indices=loss_heads,
                                            head_tags=loss_tags,
                                            mask=mask)
    loss = arc_nll + tag_nll

    if has_gold:
        evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
        # Attachment scores are computed for the whole sentence except the
        # symbolic ROOT token at position 0, hence the ``[:, 1:]`` slices.
        self._attachment_scores(predicted_heads[:, 1:],
                                predicted_head_tags[:, 1:],
                                head_indices[:, 1:],
                                head_tags[:, 1:],
                                evaluation_mask)

    return {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "arc_loss": arc_nll,
        "tag_loss": tag_nll,
        "loss": loss,
        "mask": mask,
        "words": [meta["words"] for meta in metadata],
        "pos": [meta["pos"] for meta in metadata],
    }
def triple_tensor_to_set(tensor: torch.LongTensor) -> Set[Tuple[int, ...]]:
    """Return the rows of ``tensor`` as a set of int-tuples (duplicate rows collapse)."""
    return {tuple(row) for row in tensor.tolist()}
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True,
                                       label_smoothing: float = None) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.
    label_smoothing : ``float``, optional (default = None)
        Whether or not to apply label smoothing to the cross-entropy loss.
        For example, with a label smoothing value of 0.2, a 4 class classification
        target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
        the correct label.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).
    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.view(-1, 1).long()

    if label_smoothing is not None and label_smoothing > 0.0:
        num_classes = logits.size(-1)
        smoothing_value = label_smoothing / num_classes
        # Fill all the correct indices with 1 - smoothing value.
        # ``torch.zeros_like`` is spelled out in full: a bare ``zeros_like`` is not a
        # name this function can rely on being in scope.
        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(
            -1, targets_flat, 1.0 - label_smoothing)
        smoothed_targets = one_hot_targets + smoothing_value
        negative_log_likelihood_flat = -log_probs_flat * smoothed_targets
        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
    else:
        # Contribution to the negative log likelihood only comes from the exact indices
        # of the targets, as the target distributions are one-hot. Here we use torch.gather
        # to extract the indices of the num_classes dimension which contribute to the loss.
        # shape : (batch * sequence_length, 1)
        negative_log_likelihood_flat = -torch.gather(log_probs_flat, dim=1, index=targets_flat)

    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights.float()
    # shape : (batch_size,)
    # The small epsilon keeps all-zero-weight sequences from dividing by zero.
    per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)

    if batch_average:
        # Average only over the sequences that have at least one weighted element.
        num_non_empty_sequences = (weights.sum(1) > 0).float().sum() + 1e-13
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
def transformer_sliding_window(
        transformer: PreTrainedModel,
        input_ids: torch.LongTensor,
        input_mask=None,
        offsets: torch.LongTensor = None,
        token_type_ids: torch.LongTensor = None,
        max_pieces=512,
        start_tokens: int = 1,
        end_tokens: int = 1,
        ret_cls=None,
) -> torch.Tensor:
    """
    Run ``transformer`` over ``input_ids``, splitting sequences longer than
    ``max_pieces`` into windows and recombining the outputs afterwards.

    Args:
      transformer: The model to call. It is invoked with ``input_ids`` and
        ``attention_mask`` keyword arguments and its result must offer ``to_tuple()``.
      input_ids: torch.LongTensor: Wordpiece ids; the last dimension is the sequence.
      input_mask: Attention mask. When ``None`` it is derived as ``input_ids != 0``
        (after any window splitting). (Default value = None)
      offsets: torch.LongTensor: If given, indices used to select positions from
        the recombined embeddings. (Default value = None)
      token_type_ids: torch.LongTensor: Accepted for interface compatibility; it is
        currently not forwarded to the transformer. (Default value = None)
      max_pieces: Longest sequence passed to the transformer in one window.
        (Default value = 512)
      start_tokens: int: Forwarded to ``restore_from_sliding_window``. (Default value = 1)
      end_tokens: int: Forwarded to ``restore_from_sliding_window``. (Default value = 1)
      ret_cls: Only consulted when the input had to be split. ``'max'`` max-pools
        the first-position vector of each window, ``'raw'`` returns the per-window
        vectors together with their mask, any other non-``None`` value averages
        them over the windows selected by the mask. (Default value = None)

    Returns:
      The (optionally offset-selected) embeddings reshaped back to the leading
      dimensions of the input; or the pooled first-position vectors when
      ``ret_cls`` is set and the input was split (a ``(cls_hidden, cls_mask)``
      tuple for ``ret_cls == 'raw'``).
    """
    # pylint: disable=arguments-differ
    batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
    initial_dims = list(input_ids.shape[:-1])

    # The embedder may receive an input tensor that has a sequence length longer than can
    # be fit. In that case, we should expect the wordpiece indexer to create padded windows
    # of length `max_pieces` for us, and have them concatenated into one long sequence.
    # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
    # We can then split the sequence into sub-sequences of that length, and concatenate them
    # along the batch dimension so we effectively have one huge batch of partial sentences.
    # This can then be fed into BERT without any sentence length issues. Keep in mind
    # that the memory consumption can dramatically increase for large batches with extremely
    # long sentences.
    needs_split = full_seq_len > max_pieces
    if needs_split:
        input_ids = split_to_sliding_window(input_ids, max_pieces)

    if input_mask is None:
        input_mask = (input_ids != 0).long()

    # input_ids may have extra dimensions, so we reshape down to 2-d
    # before calling the model and then reshape back at the end.
    model_output = transformer(
        input_ids=util.combine_initial_dims_to_1d_or_2d(input_ids),
        attention_mask=util.combine_initial_dims_to_1d_or_2d(input_mask))
    outputs = model_output.to_tuple()
    if len(outputs) == 3:
        # ``outputs`` is a plain tuple and has no ``hidden_states`` attribute, so the
        # per-layer states are read from the model output object itself.
        all_encoder_layers = torch.stack(model_output.hidden_states)
    elif len(outputs) == 2:
        all_encoder_layers, _ = outputs[:2]
    else:
        all_encoder_layers = outputs[0]

    if needs_split:
        if ret_cls is not None:
            # A window counts as real when its first token equals the very first
            # token of the batch (presumably the [CLS] id — confirm with the indexer).
            cls_mask = input_ids[:, 0] == input_ids[0][0]
            cls_hidden = all_encoder_layers[:, 0, :]
            if ret_cls == 'max':
                cls_hidden[~cls_mask] = -1e20
            else:
                cls_hidden[~cls_mask] = 0
            cls_mask = cls_mask.view(-1, batch_size).transpose(0, 1)
            cls_hidden = cls_hidden.reshape(cls_mask.size(1), batch_size, -1).transpose(0, 1)
            if ret_cls == 'max':
                cls_hidden = cls_hidden.max(1)[0]
            elif ret_cls == 'raw':
                return cls_hidden, cls_mask
            else:
                cls_hidden = torch.sum(cls_hidden, dim=1)
                cls_hidden /= torch.sum(cls_mask, dim=1, keepdim=True)
            return cls_hidden
        else:
            recombined_embeddings, select_indices = restore_from_sliding_window(
                all_encoder_layers, batch_size, max_pieces, full_seq_len, start_tokens, end_tokens)
            initial_dims.append(len(select_indices))
    else:
        recombined_embeddings = all_encoder_layers

    if offsets is None:
        # Resize back to the leading dimensions of the input.
        dims = initial_dims if needs_split else input_ids.size()
        layers = util.uncombine_initial_dims(recombined_embeddings, dims)
    else:
        # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
        offsets2d = util.combine_initial_dims_to_1d_or_2d(offsets)
        # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
        range_vector = util.get_range_vector(
            offsets2d.size(0), device=util.get_device_of(recombined_embeddings)).unsqueeze(1)
        # Pick, for every row, the positions named by its offsets.
        selected_embeddings = recombined_embeddings[:, range_vector, offsets2d]
        layers = util.uncombine_initial_dims(selected_embeddings, offsets.size())

    return layers
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Build the initial decoder state and hand it to the decoder trainer (when a
    gold ``action_sequence`` is given) and, outside of training, to beam search.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor]
        The output of ``TextField.as_array()`` applied on the tokens ``TextField``.
        This will be passed through a ``TextFieldEmbedder`` and then through an encoder.
    valid_actions : ``List[List[ProductionRule]]``
        All possible actions for each batch element. ``action[0]`` of every entry
        is used as the action string when mapping predicted indices back to text.
    action_sequence : torch.Tensor, optional (default=None)
        The gold action sequence, where each action is an index into the list of
        possible actions. The trailing dimension is squeezed away here.

    Returns
    -------
    A dictionary holding whatever the decoder trainer returned (when gold
    actions were given) and, when not training, ``action_mapping``,
    ``best_action_sequence``, ``debug_info``, ``predicted_sql_query`` and
    ``sql_queries`` (the latter is created empty and left untouched here).
    """
    embedded_utterance = self._utterance_embedder(tokens)
    mask = util.get_text_field_mask(tokens).float()
    batch_size = embedded_utterance.size(0)

    # (batch_size, num_tokens, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
    initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)

    outputs: Dict[str, Any] = {}
    if action_sequence is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        action_sequence = action_sequence.squeeze(-1)
        target_mask = action_sequence != self._action_padding_index
        # The trainer expects (batch_size, 1, target_sequence_length), hence the unsqueeze.
        supervision = (action_sequence.unsqueeze(1), target_mask.unsqueeze(1))
        outputs.update(self._decoder_trainer.decode(initial_state,
                                                    self._transition_function,
                                                    supervision))

    if self.training:
        return outputs

    # One {action index -> action string} lookup per batch element.
    action_mapping = [
        {action_index: action[0] for action_index, action in enumerate(batch_actions)}
        for batch_actions in valid_actions
    ]
    outputs['action_mapping'] = action_mapping

    # This tells the state to start keeping track of debug info, which we'll pass
    # along in our output dictionary.
    initial_state.debug_info = [[] for _ in range(batch_size)]
    best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                 initial_state,
                                                 self._transition_function,
                                                 keep_final_unfinished_states=True)
    outputs['best_action_sequence'] = []
    outputs['debug_info'] = []
    outputs['predicted_sql_query'] = []
    outputs['sql_queries'] = []

    for i in range(batch_size):
        # Decoding may not have terminated with any completed valid SQL queries, if
        # `num_steps` isn't long enough (or if the model is not trained enough and
        # gets into an infinite action loop).
        if i not in best_final_states:
            self._exact_match(0)
            self._denotation_accuracy(0)
            self._valid_sql_query(0)
            self._action_similarity(0)
            outputs['predicted_sql_query'].append('')
            continue

        top_state = best_final_states[i][0]
        best_action_indices = top_state.action_history[0]
        action_strings = [action_mapping[i][action_index]
                          for action_index in best_action_indices]
        predicted_sql_query = action_sequence_to_sql(action_strings)

        if action_sequence is not None:
            # Use the raw ``.data`` tensor here; the original notes this avoids a memory leak.
            targets = action_sequence[i].data
            self._exact_match(self._action_history_match(best_action_indices, targets))
            similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
            self._action_similarity(similarity.ratio())

        outputs['best_action_sequence'].append(action_strings)
        outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
        outputs['debug_info'].append(top_state.debug_info[0])  # type: ignore
    return outputs
def forward(
    self,
    sequence_tensor: torch.FloatTensor,
    span_indices: torch.LongTensor,
    sequence_mask: torch.LongTensor = None,
    span_indices_mask: torch.LongTensor = None,
) -> torch.FloatTensor:
    """
    Build span representations from the forward and backward halves of
    ``sequence_tensor``, combining exclusive-start and inclusive-end endpoints
    in each direction.

    Parameters
    ----------
    sequence_tensor : ``torch.FloatTensor``, required.
        Shape (batch_size, sequence_length, embedding_size). The last dimension is
        split in two halves: the first is treated as the forward direction, the
        second as the backward direction.
    span_indices : ``torch.LongTensor``, required.
        Shape (batch_size, num_spans, 2) holding inclusive start and end indices.
    sequence_mask : ``torch.LongTensor``, optional (default = None).
        Used to derive per-instance sequence lengths; without it every sequence is
        assumed to span the full ``sequence_tensor.size(1)``.
    span_indices_mask : ``torch.LongTensor``, optional (default = None).
        Shape (batch_size, num_spans). Masked spans have their indices zeroed
        before lookup and their final representation zeroed on return.

    Returns
    -------
    A tensor of shape (batch_size, num_spans, output_dim): the concatenated
    forward and backward span combinations, followed by the span width embedding
    when one is configured.

    Raises
    ------
    ValueError
        If an adjusted span index falls outside its sequence.
    """
    # Both of shape (batch_size, sequence_length, embedding_size / 2)
    forward_sequence, backward_sequence = sequence_tensor.split(
        int(self._input_dim / 2), dim=-1
    )
    forward_sequence = forward_sequence.contiguous()
    backward_sequence = backward_sequence.contiguous()

    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    # We want `exclusive` span starts, so we remove 1 from the forward span starts
    # as the AllenNLP `SpanField` is inclusive.
    # shape (batch_size, num_spans)
    exclusive_span_starts = span_starts - 1
    # shape (batch_size, num_spans, 1)
    start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

    # We want `exclusive` span ends for the backward direction
    # (so that the `start` of the span in that direction is exclusive), so
    # we add 1 to the span ends as the AllenNLP `SpanField` is inclusive.
    exclusive_span_ends = span_ends + 1

    if sequence_mask is not None:
        # shape (batch_size)
        sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
    else:
        # shape (batch_size), filled with the sequence length size of the sequence_tensor.
        sequence_lengths = torch.ones_like(
            sequence_tensor[:, 0, 0], dtype=torch.long
        ) * sequence_tensor.size(1)

    # shape (batch_size, num_spans, 1)
    end_sentinel_mask = (
        (exclusive_span_ends >= sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)
    )

    # As we added 1 to the span_ends to make them exclusive, which might have caused indices
    # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
    # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
    # The same argument follows for the exclusive span start indices.
    exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
    exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

    # We'll check the indices here at runtime, because it's difficult to debug
    # if this goes wrong and it's tricky to get right.
    if (exclusive_span_starts < 0).any() or (
        exclusive_span_ends > sequence_lengths.unsqueeze(-1)
    ).any():
        raise ValueError(
            f"Adjusted span indices must lie inside the length of the sequence tensor, "
            f"but found: exclusive_span_starts: {exclusive_span_starts}, "
            f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
            f"{sequence_lengths}."
        )

    # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
    forward_start_embeddings = util.batched_index_select(
        forward_sequence, exclusive_span_starts
    )
    # Forward Direction: end indices are inclusive, so we can just use span_ends.
    # Shape (batch_size, num_spans, input_size / 2)
    forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

    # Backward Direction: The backward start embeddings use the `forward` end
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_start_embeddings = util.batched_index_select(
        backward_sequence, exclusive_span_ends
    )
    # Backward Direction: The backward end embeddings use the `forward` start
    # indices, because we are going backwards.
    # Shape (batch_size, num_spans, input_size / 2)
    backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

    if self._use_sentinels:
        # If we're using sentinels, we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with either the start sentinel,
        # or the end sentinel.
        float_end_sentinel_mask = end_sentinel_mask.float()
        float_start_sentinel_mask = start_sentinel_mask.float()
        forward_start_embeddings = (
            forward_start_embeddings * (1 - float_start_sentinel_mask)
            + float_start_sentinel_mask * self._start_sentinel
        )
        backward_start_embeddings = (
            backward_start_embeddings * (1 - float_end_sentinel_mask)
            + float_end_sentinel_mask * self._end_sentinel
        )

    # Now we combine the forward and backward spans in the manner specified by the
    # respective combinations and concatenate these representations.
    # Shape (batch_size, num_spans, forward_combination_dim)
    forward_spans = util.combine_tensors(
        self._forward_combination, [forward_start_embeddings, forward_end_embeddings]
    )
    # Shape (batch_size, num_spans, backward_combination_dim)
    backward_spans = util.combine_tensors(
        self._backward_combination, [backward_start_embeddings, backward_end_embeddings]
    )
    # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
    span_embeddings = torch.cat([forward_spans, backward_spans], -1)

    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(
                span_ends - span_starts, num_total_buckets=self._num_width_embeddings
            )
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        # No early return here: the width-augmented representation must go through
        # the same span mask as the plain one, otherwise padded spans come back
        # non-zero only when width embeddings are enabled.
        span_embeddings = torch.cat([span_embeddings, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return span_embeddings * span_indices_mask.float().unsqueeze(-1)
    return span_embeddings
def _joint_likelihood(self,
                      logits: torch.Tensor,
                      tags: torch.Tensor,
                      mask: torch.LongTensor) -> torch.Tensor:
    """
    Computes the numerator term for the log-likelihood, i.e. score(inputs, tags):
    the sum of the emission and transition scores along the given tag sequence,
    plus the start/end transitions when the model uses them.

    ``logits`` is unpacked as (batch_size, sequence_length, num_tags); ``tags`` and
    ``mask`` are transposed alongside it, so they are expected to be
    (batch_size, sequence_length). Returns one score per batch element.
    """
    batch_size, sequence_length, _ = logits.data.shape

    # Work time-major so that indexing by timestep yields one row per batch element.
    logits = logits.transpose(0, 1).contiguous()
    mask = mask.float().transpose(0, 1).contiguous()
    tags = tags.transpose(0, 1).contiguous()

    # Begin with the transition from the start state into the first tag.
    if self.include_start_end_transitions:
        score = self.start_transitions.index_select(0, tags[0])
    else:
        score = 0.0

    # Accumulate every observed transition and every emission except the last one.
    for step in range(sequence_length - 1):
        # Each is shape (batch_size,)
        current_tag, next_tag = tags[step], tags[step + 1]

        # Row ``current_tag`` of the transition matrix, then column ``next_tag``.
        transition_rows = self.transitions.index_select(0, current_tag)
        transition_score = transition_rows.gather(1, next_tag.view(batch_size, 1)).squeeze(1)

        # The score for emitting current_tag at this step.
        emit_score = logits[step].gather(1, current_tag.view(batch_size, 1)).squeeze(1)

        # The transition counts only if the next element is unmasked, the
        # emission only if this element is unmasked.
        score = score + transition_score * mask[step + 1] + emit_score * mask[step]

    # Locate the last unmasked tag of every instance for the transition to "stop".
    last_tag_index = mask.sum(0).long() - 1
    last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0)

    if self.include_start_end_transitions:
        last_transition_score = self.end_transitions.index_select(0, last_tags)
    else:
        last_transition_score = 0.0

    # The final timestep's emission is added only where that timestep is unmasked.
    final_logits = logits[-1]                                           # (batch_size, num_tags)
    last_input_score = final_logits.gather(1, last_tags.view(-1, 1))    # (batch_size, 1)
    last_input_score = last_input_score.squeeze()                       # (batch_size,)

    return score + last_transition_score + last_input_score * mask[-1]
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            spans: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            pos_tags: Dict[str, torch.LongTensor] = None,
            span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Encode the sentence, score every candidate span against the label set and,
    when gold labels are given, compute the loss and update the metrics.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``; passed unchanged to the
        ``TextFieldEmbedder`` and to ``get_text_field_mask``.
    spans : ``torch.LongTensor``, required.
        A tensor of shape ``(batch_size, num_spans, 2)`` representing the
        inclusive start and end indices of all possible spans in the sentence.
        A negative start index marks a padded span.
    metadata : List[Dict[str, Any]], required.
        A dictionary of metadata for each batch element which has keys:
            tokens : ``List[str]``, required.
                The original string tokens in the sentence.
            gold_tree : ``nltk.Tree``, optional (default = None)
                Gold NLTK trees for use in evaluation.
            pos_tags : ``List[str]``, optional.
                The POS tags for the sentence; copied into the output.
    pos_tags : optional (default = None)
        The output of a ``SequenceLabelField`` containing POS tags; fed to the
        POS tag embedding when the model has one.
    span_labels : ``torch.LongTensor``, optional (default = None)
        A torch tensor representing the integer gold class labels for all possible
        spans, of shape ``(batch_size, num_spans)``.

    Returns
    -------
    An output dictionary consisting of:
    class_probabilities : ``torch.FloatTensor``
        A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
        representing a distribution over the label classes per span.
    spans : ``torch.LongTensor``
        The original spans tensor.
    tokens : ``List[List[str]]``
        The ``"tokens"`` metadata entry of each batch element.
    pos_tags : ``List[List[str]]``
        The ``"pos_tags"`` metadata entry of each batch element (``None`` if absent).
    num_spans : ``torch.LongTensor``
        A tensor of shape (batch_size), the number of non-padded spans.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised; present only when ``span_labels`` is given.

    Raises
    ------
    ConfigurationError
        If the model has a POS tag embedding but ``pos_tags`` is ``None``.
    """
    token_features = self.text_field_embedder(tokens)
    if self.pos_tag_embedding is not None:
        if pos_tags is None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")
        pos_features = self.pos_tag_embedding(pos_tags)
        token_features = torch.cat([token_features, pos_features], -1)

    mask = get_text_field_mask(tokens)

    # Looking at the span start index is enough to know if
    # this is padding or not. Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
    if span_mask.dim() == 1:
        # This happens if you use batch_size 1 and encounter
        # a length 1 sentence in PTB, which do exist.
        span_mask = span_mask.unsqueeze(-1)
    if span_labels is not None and span_labels.dim() == 1:
        span_labels = span_labels.unsqueeze(-1)

    num_spans = get_lengths_from_binary_sequence_mask(span_mask)

    encoded_text = self.encoder(token_features, mask)
    span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
    if self.feedforward_layer is not None:
        span_representations = self.feedforward_layer(span_representations)

    logits = self.tag_projection_layer(span_representations)
    class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

    output_dict = {
        "class_probabilities": class_probabilities,
        "spans": spans,
        "tokens": [meta["tokens"] for meta in metadata],
        "pos_tags": [meta.get("pos_tags") for meta in metadata],
        "num_spans": num_spans,
    }

    if span_labels is not None:
        output_dict["loss"] = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
        self.tag_accuracy(class_probabilities, span_labels, span_mask)

    # The evalb score is expensive to compute, so we only compute
    # it for the validation and test sets, and only if every instance has a gold tree.
    batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
    should_score = all(batch_gold_trees) and self._evalb_score is not None and not self.training
    if should_score:
        # ``tree.pos()`` yields (word, tag) pairs; keep only the tags.
        gold_pos_tags: List[List[str]] = [
            list(zip(*tree.pos()))[1] for tree in batch_gold_trees
        ]
        predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                               spans.cpu().data,
                                               num_spans.data,
                                               output_dict["tokens"],
                                               gold_pos_tags)
        self._evalb_score(predicted_trees, batch_gold_trees)

    return output_dict
def forward(
    self,
    token_ids: torch.LongTensor,
    mask: torch.BoolTensor,
    type_ids: Optional[torch.LongTensor] = None,
    segment_concat_mask: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:  # type: ignore
    """
    Embed wordpieces with the wrapped transformer, folding sequences longer than
    ``self._max_length`` into segments and unfolding the result afterwards.

    # Parameters

    token_ids: `torch.LongTensor`
        Shape: `[batch_size, num_wordpieces if max_length is None else num_segment_concat_wordpieces]`.
        num_segment_concat_wordpieces is num_wordpieces plus special tokens inserted in the
        middle, e.g. the length of: "[CLS] A B C [SEP] [CLS] D E F [SEP]" (see indexer logic).
    mask: `torch.BoolTensor`
        Shape: [batch_size, num_wordpieces]. Used as the attention mask only when
        no maximum length is configured.
    type_ids: `Optional[torch.LongTensor]`
        Shape: `[batch_size, num_wordpieces if max_length is None else num_segment_concat_wordpieces]`.
    segment_concat_mask: `Optional[torch.BoolTensor]`
        Shape: `[batch_size, num_segment_concat_wordpieces]`. Used as the attention
        mask whenever a maximum length is configured.

    # Returns

    `torch.Tensor`
        Shape: `[batch_size, num_wordpieces, embedding_size]`.

    # Raises

    `ValueError`
        If a type id is not smaller than the number of token type embeddings.
    """
    # Some of the huggingface transformers don't support type ids at all and crash when you
    # supply them. For others, you can supply a tensor of zeros, and if you don't, they act as
    # if you did. There is no practical difference to the caller, so all-zero type ids are
    # simply dropped.
    if type_ids is not None:
        largest_type_id = type_ids.max()
        if largest_type_id == 0:
            type_ids = None
        elif largest_type_id >= self._number_of_token_type_embeddings():
            raise ValueError("Found type ids too large for the chosen transformer model.")
        else:
            assert token_ids.shape == type_ids.shape

    has_max_length = self._max_length is not None
    fold_long_sequences = has_max_length and token_ids.size(1) > self._max_length
    if fold_long_sequences:
        batch_size, num_segment_concat_wordpieces = token_ids.size()
        token_ids, segment_concat_mask, type_ids = self._fold_long_sequences(
            token_ids, segment_concat_mask, type_ids
        )

    transformer_mask = segment_concat_mask if has_max_length else mask

    # Shape: [batch_size, num_wordpieces, embedding_size],
    # or if self._max_length is not None:
    # [batch_size * num_segments, self._max_length, embedding_size]

    # The model is called with kwargs because some of the huggingface models don't have the
    # token_type_ids parameter and fail even when it's given as None.
    # Also, as of transformers v2.5.1, they are taking FloatTensor masks.
    model_inputs = {"input_ids": token_ids, "attention_mask": transformer_mask.float()}
    if type_ids is not None:
        model_inputs["token_type_ids"] = type_ids

    transformer_output = self.transformer_model(**model_inputs)

    if self._scalar_mix is None:
        embeddings = transformer_output[0]
    else:
        # The hidden states are taken from the last element of the output (this holds
        # as long as the model is not also configured to return attention scores).
        # They include the embedding layer, which is not part of the scalar mix,
        # hence the `[1:]` slicing.
        hidden_states = transformer_output[-1][1:]
        embeddings = self._scalar_mix(hidden_states)

    if fold_long_sequences:
        embeddings = self._unfold_long_sequences(
            embeddings, segment_concat_mask, batch_size, num_segment_concat_wordpieces
        )

    return embeddings
def construct_trees(self,
                    predictions: torch.FloatTensor,
                    all_spans: torch.LongTensor,
                    num_spans: torch.LongTensor,
                    sentences: List[List[str]],
                    pos_tags: List[List[str]] = None) -> List[Tree]:
    """
    Build one ``nltk.Tree`` per batch element by greedily nesting the labelled spans.

    The trees use exclusive end indices, unlike the inclusive span ends used by
    the rest of the model.

    Parameters
    ----------
    predictions : ``torch.FloatTensor``, required.
        A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
        representing a distribution over the label classes per span.
    all_spans : ``torch.LongTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the span
        indices we scored.
    num_spans : ``torch.LongTensor``, required.
        A tensor of shape (batch_size), representing the lengths of non-padded spans
        in ``enumerated_spans``.
    sentences : ``List[List[str]]``, required.
        A list of tokens in the sentence for each element in the batch.
    pos_tags : ``List[List[str]]``, optional (default = None).
        A list of POS tags for each word in the sentence for each element
        in the batch.

    Returns
    -------
    A ``List[Tree]`` containing the decoded trees for each element in the batch.
    """
    # Work on a copy so the caller's inclusive spans stay untouched.
    shifted_spans = all_spans.clone()
    shifted_spans[:, :, -1] += 1
    no_label_id = self.vocab.get_token_index("NO-LABEL", "labels")

    trees: List[Tree] = []
    per_instance = zip(predictions, shifted_spans, sentences)
    for batch_index, (instance_scores, instance_spans, sentence) in enumerate(per_instance):
        valid_count = num_spans[batch_index]
        candidates = []
        for distribution, span in zip(instance_scores[:valid_count],
                                      instance_spans[:valid_count]):
            start, end = span
            no_label_prob = distribution[no_label_id]
            label_prob, label_index = torch.max(distribution, -1)

            # Keep a span only if it got a real label, or if it is the root node
            # (which must be present for a tree to exist at all).
            has_label = int(label_index) != no_label_id
            if not has_label and not (start == 0 and end == len(sentence)):
                continue

            # TODO(Mark): Remove this once pylint sorts out named tuples.
            # https://github.com/PyCQA/pylint/issues/1418
            candidates.append(
                    SpanInformation(start=int(start),  # pylint: disable=no-value-for-parameter
                                    end=int(end),
                                    label_prob=float(label_prob),
                                    no_label_prob=float(no_label_prob),
                                    label_index=int(label_index)))

        # Selected spans may overlap, in which case they would not nest into a tree;
        # conflicts are resolved before labels are looked up.
        nested_spans = self.resolve_overlap_conflicts_greedily(candidates)

        spans_to_labels = {}
        for info in nested_spans:
            label = self.vocab.get_token_from_index(info.label_index, "labels")
            spans_to_labels[(info.start, info.end)] = label

        if pos_tags is not None:
            sentence_pos = pos_tags[batch_index]
        else:
            sentence_pos = None
        trees.append(self.construct_tree_from_spans(spans_to_labels, sentence, sentence_pos))

    return trees
optimizer=optimizer, ) # # validate model with validation set if True: loss_avg_val, accu_avg_val = trainer_epoch( model, dataloader_val, criterion, ) # # Test if params_LocalTest: loss_avg_tst, accu_avg_tst = trainer_epoch( model, dataloader_tst, criterion, idxCat=LongTensor(idxCatUnseen).cuda(), ) # Predict if False: arr_outputs_prd = trainer_epoch( model, dataloader_prd, criterion, predict=True, predict_DataAug=False, idxCat=LongTensor(idxCatUnannotd).cuda(), ) time_elapsed = time.time() - time_start EpochResult = r'TRN_lss_{:.3g}_accu_{:.3g}_VAL_lss_{:.3g}_accu_{:.3g}_TST_lss_{:.3g}_accu_{:.3g}'.format(
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.LongTensor = None,
        span_end: torch.LongTensor = None,
        spans=None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Score answer-span starts and ends over several passages for one question.

    Every passage is encoded against the (tiled) question with bidirectional
    attention followed by a residual self-attention layer; span start and end
    logits are then predicted per passage.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor], required
        Indexed question, passed to ``self._text_field_embedder``.
    passage : Dict[str, torch.LongTensor], required
        Indexed passages, passed to ``self._text_field_embedder``. The embedded
        result is unpacked as (batch_size, num_passage, passage_length, embedding_dim).
    span_start : torch.LongTensor, optional (default = None)
        Gold start index. When given, a ``"loss"`` entry is added to the output.
        NOTE(review): during training the logits are flattened to
        (batch_size, num_passage * passage_length), so this presumably indexes
        into the concatenation of the selected passages -- confirm with the reader.
    span_end : torch.LongTensor, optional (default = None)
        Gold end index, used together with ``span_start``.
    spans : optional (default = None)
        Not referenced anywhere in this method.
    metadata : List[Dict[str, Any]], optional (default = None)
        Per-instance metadata. Outside of training the keys ``question_tokens``,
        ``passage_tokens``, ``original_passage``, ``token_offsets``,
        ``all_passages`` and ``passage_offsets`` are read, plus the optional
        ``answer_texts``.

    Returns
    -------
    A dictionary with ``passage_question_attention``, ``span_start_logits``,
    ``span_start_probs``, ``span_end_logits`` and ``span_end_probs``; ``loss``
    when ``span_start`` is given; and, when not training and ``metadata`` is
    given, ``best_span_str``, ``question_tokens`` and ``passage_tokens``.
    """
    # pylint: disable=arguments-differ
    # Shape: (batch_size, num_passage=4, passage_length, embedding_dim)
    embedded_passage = self._text_field_embedder(passage)
    passage_mask = util.get_text_field_mask(passage, 1).float()
    # get some parameters
    # NOTE(review): cuda_device is later passed to `.cuda(...)`; this looks like it
    # assumes the inputs live on a GPU -- confirm CPU execution is not required.
    cuda_device = embedded_passage.get_device()
    batch_size, num_passage, passage_length, embedding_dim = embedded_passage.size()
    # when training, select randomly 2 passages from 4 passages each epoch
    if self.training:
        num_passage = 2
        # Uniform weights over three choices; the `+ 1` below turns the sampled
        # value into a passage index in 1..3, and passage 0 is always kept.
        probs = torch.Tensor([1, 1, 1]).unsqueeze(0).expand(batch_size, 3)
        indices = torch.multinomial(probs, 1) + 1
        zeros_tensor = torch.zeros(batch_size).long()
        # Shape: (batch_size, 2)
        indices = Variable(
            torch.cat([zeros_tensor.unsqueeze(-1), indices],
                      1).cuda(cuda_device))
        # Shape: (batch_size, num_passage, passage_length, embedding_dim)
        embedded_passage = torch.gather(
            embedded_passage, 1,
            indices.unsqueeze(-1).unsqueeze(-1).expand(
                batch_size, num_passage, passage_length, embedding_dim))
        # Shape: (batch_size, num_passage, passage_length)
        passage_mask = torch.gather(
            passage_mask, 1,
            indices.unsqueeze(-1).expand(batch_size, num_passage,
                                         passage_length))
    # Passages are folded into the batch dimension from here on.
    # Shape: (batch_size*num_passage, passage_length, embedding_dim)
    embedded_passage = embedded_passage.view(-1, passage_length,
                                             embedding_dim)
    embedded_passage = self._highway_layer(embedded_passage)
    # Shape: (batch_size*num_passage, passage_length)
    passage_mask = passage_mask.view(-1, passage_length)
    # Shape: (batch_size, question_length, embedding_dim)
    embedded_question = self._highway_layer(
        self._text_field_embedder(question))
    question_length = embedded_question.size(1)
    # The question is repeated once per passage so it lines up with the folded passages.
    # Shape: (batch_size*numpassage, question_length, embedding_dim)
    embedded_question = embedded_question.unsqueeze(1).expand(
        -1, num_passage, -1, -1).contiguous().view(-1, question_length,
                                                   embedding_dim)
    # Shape: (batch_size, question_length)
    question_mask = util.get_text_field_mask(question).float()
    # Shape: (batch_size*num_passage, question_length)
    question_mask = question_mask.unsqueeze(1).expand(
        -1, num_passage, -1).contiguous().view(-1, question_length)
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None
    encoded_question = self._dropout(
        self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(
        self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)
    # Shape: (batch_size*num_passage, passage_length, question_length)
    #passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # The similarity is the sum of a passage-only term, a question-only term and
    # a weighted dot-product term, broadcast to (B, P, Q).
    # Shape: (B, P, 1)
    p_similarity = torch.matmul(encoded_passage, self._w_p).unsqueeze(2)
    # Shape: (B, 1, Q)
    q_similarity = torch.matmul(encoded_question, self._w_q).unsqueeze(1)
    # Shape: (B, P, Q)
    pq_similarity = torch.bmm(encoded_passage * self._w_pq,
                              encoded_question.transpose(1, 2))
    passage_question_similarity = pq_similarity + p_similarity + q_similarity
    # Shape: (batch_size*num_passage, passage_length, question_length)
    passage_question_attention = util.last_dim_softmax(
        passage_question_similarity, question_mask)
    # Shape: (batch_size*num_passage, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(
        encoded_question, passage_question_attention)
    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(
        passage_question_similarity, question_mask.unsqueeze(1), -1e7)
    # Shape: (batch_size*num_passage, passage_length)
    question_passage_similarity = masked_similarity.max(
        dim=-1)[0].squeeze(-1)
    # Shape: (batch_size*num_passage, passage_length)
    question_passage_attention = util.masked_softmax(
        question_passage_similarity, passage_mask)
    # Shape: (batch_size*num_passage, encoding_dim)
    question_passage_vector = util.weighted_sum(
        encoded_passage, question_passage_attention)
    # Shape: (batch_size*num_passage, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(
        1).expand(batch_size * num_passage, passage_length, encoding_dim)
    # Shape: (batch_size*num_passage, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([
        encoded_passage, passage_question_vectors,
        encoded_passage * passage_question_vectors,
        encoded_passage * tiled_question_passage_vector
    ],
                                     dim=-1)
    # Shape: (batch_size*num_passage, passage_length, encoding_dim)
    question_attended_passage = relu(
        self._linear_layer(final_merged_passage))
    # attach residual self-attention layer
    # Shape: (batch_size*num_passage, passage_length, encoding_dim)
    residual_passage = self._dropout(
        self._residual_encoder(self._dropout(question_attended_passage),
                               passage_lstm_mask))
    # create mask for self-attention
    # Outer product of the passage mask with itself: entry (i, j) is non-zero only
    # when both positions are real tokens.
    mask = passage_mask.resize(
        batch_size * num_passage, passage_length,
        1) * passage_mask.resize(
            batch_size * num_passage, 1, passage_length)
    # An identity matrix is subtracted out so that a position cannot attend to itself.
    self_mask = Variable(
        torch.eye(passage_length,
                  passage_length).cuda(cuda_device)).resize(
                      1, passage_length, passage_length)
    mask = mask * (1 - self_mask)
    # Shape: (batch_size*num_passage, passage_length, passage_length)
    x_similarity = torch.matmul(residual_passage, self._w_x).unsqueeze(2)
    y_similarity = torch.matmul(residual_passage, self._w_y).unsqueeze(1)
    dot_similarity = torch.bmm(residual_passage * self._w_xy,
                               residual_passage.transpose(1, 2))
    passage_self_similarity = dot_similarity + x_similarity + y_similarity
    # Shape: (batch_size*num_passage, passage_length, passage_length)
    passage_self_attention = util.last_dim_softmax(passage_self_similarity,
                                                   mask)
    # Shape: (batch_size*num_passage, passage_length, encoding_dim)
    passage_vectors = util.weighted_sum(residual_passage,
                                        passage_self_attention)
    # Shape: (batch_size*num_passage, passage_length, encoding_dim * 3)
    merged_passage = torch.cat([
        residual_passage, passage_vectors,
        residual_passage * passage_vectors
    ],
                               dim=-1)
    # Shape: (batch_size*num_passage, passage_length, encoding_dim)
    self_attended_passage = relu(
        self._residual_linear_layer(merged_passage))
    # Residual connection around the self-attention layer.
    # Shape: (batch_size*num_passage, passage_length, encoding_dim)
    mixed_passage = question_attended_passage + self_attended_passage
    # Shape: (batch_size*num_passage, passage_length, encoding_dim)
    encoded_span_start = self._dropout(
        self._span_start_encoder(mixed_passage, passage_lstm_mask))
    span_start_logits = self._span_start_predictor(
        encoded_span_start).squeeze(-1)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
    # The end predictor also sees the start encoding.
    # Shape: (batch_size*num_passage, passage_length, encoding_dim * 2)
    concatenated_passage = torch.cat([mixed_passage, encoded_span_start],
                                     dim=-1)
    # Shape: (batch_size*num_passage, passage_length, encoding_dim)
    encoded_span_end = self._dropout(
        self._span_end_encoder(concatenated_passage, passage_lstm_mask))
    span_end_logits = self._span_end_predictor(encoded_span_end).squeeze(
        -1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    # Padding positions get a very negative logit so they are never selected.
    # Shape: (batch_size*num_passage, passage_length)
    span_start_logits = util.replace_masked_values(span_start_logits,
                                                   passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits,
                                                 passage_mask, -1e7)
    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        #"best_span": best_span,
    }
    if span_start is not None:
        if self.training:
            # merge logits of multiple passages in the same context
            # Shape: (batch_size, num_passage*passage_length)
            span_start_logits = span_start_logits.view(
                batch_size, num_passage,
                passage_length).view(batch_size, -1)
            span_end_logits = span_end_logits.view(batch_size, num_passage,
                                                   passage_length).view(
                                                       batch_size, -1)
            # Shape: (batch_size, num_passage*passage_length)
            passage_mask = passage_mask.view(batch_size, num_passage,
                                             passage_length).view(
                                                 batch_size, -1)
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            #self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            #self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            #self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        else:
            # do not care about loss when validating
            # (a constant zero is reported instead of a real loss)
            loss = Variable(torch.Tensor([0]).cuda(cuda_device))
        output_dict["loss"] = loss
    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if not self.training and metadata is not None:
        # Shape: (batch_size*num_passage, 3)
        best_span = self.get_best_span(span_start_logits, span_end_logits)
        # Shape: (batch_size, num_passage, 3)
        best_span = best_span.view(batch_size, num_passage, 3)
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            all_passages = metadata[i]['all_passages']
            passage_offsets = metadata[i]['passage_offsets']
            # Pick the passage whose best span has the largest third component;
            # presumably that component is the span score -- verify against get_best_span.
            _, max_id = torch.max(best_span[i, :, 2], dim=0)
            max_id = int(max_id)
            predicted_span = tuple(best_span[i, max_id].data.cpu().numpy())
            # Token indices are mapped to character offsets within the chosen passage.
            start_offset = passage_offsets[max_id][int(
                predicted_span[0])][0]
            end_offset = passage_offsets[max_id][int(predicted_span[1])][1]
            best_span_string = all_passages[max_id][
                start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Represent each span as an attention-weighted sum of the embeddings it covers.

    Parameters
    ----------
    sequence_tensor : ``torch.FloatTensor``, required.
        The sequence embeddings; scored by ``self._global_attention`` into logits of
        shape (batch_size, sequence_length, 1).
    span_indices : ``torch.LongTensor``, required.
        Span start and end indices along the last dimension (split into two tensors of
        shape (batch_size, num_spans, 1)). The ends are inclusive.
    sequence_mask : ``torch.LongTensor``, optional (default = None).
        Not referenced by this method.
    span_indices_mask : ``torch.LongTensor``, optional (default = None).
        When given, the representations of masked (padding) spans are multiplied by zero.

    Returns
    -------
    A tensor of shape (batch_size, num_spans, embedding_dim).
    """
    # Both of shape (batch_size, num_spans, 1).
    starts, ends = span_indices.split(1, dim=-1)

    # Shape: (batch_size, num_spans, 1). Since the ends are inclusive, these widths are
    # one less than the number of positions in each span.
    widths = ends - starts

    # Every span is gathered with the same number of slots: the widest span in the batch.
    # Narrower spans have their surplus slots masked out below.
    max_width = widths.max().item() + 1

    # Shape: (batch_size, sequence_length, 1)
    attention_logits = self._global_attention(sequence_tensor)

    # Shape: (1, 1, max_width)
    slot_offsets = util.get_range_vector(max_width,
                                         util.get_device_of(sequence_tensor)).view(1, 1, -1)

    # Slots are filled backwards from the span end. Shape: (batch_size, num_spans, max_width)
    unclamped_indices = ends - slot_offsets

    # A slot is valid when it lies inside the span (`<=` because the ends are inclusive)
    # and does not fall before the start of the sequence, which happens for spans whose
    # end index is smaller than max_width.
    within_span = (slot_offsets <= widths).float()
    within_sequence = (unclamped_indices >= 0).float()
    slot_mask = within_span * within_sequence

    # Negative indices are clamped to zero so the gather stays in bounds; those slots
    # are already masked.
    gather_indices = torch.nn.functional.relu(unclamped_indices.float()).long()

    # Shape: (batch_size * num_spans * max_width)
    flat_gather_indices = util.flatten_and_batch_shift_indices(gather_indices,
                                                               sequence_tensor.size(1))

    # Shape: (batch_size, num_spans, max_width, embedding_dim)
    slot_embeddings = util.batched_index_select(sequence_tensor,
                                                gather_indices,
                                                flat_gather_indices)

    # Shape: (batch_size, num_spans, max_width)
    slot_logits = util.batched_index_select(attention_logits,
                                            gather_indices,
                                            flat_gather_indices).squeeze(-1)

    # Shape: (batch_size, num_spans, max_width)
    slot_weights = util.masked_softmax(slot_logits, slot_mask)

    # Shape: (batch_size, num_spans, embedding_dim)
    attended = util.weighted_sum(slot_embeddings, slot_weights)

    if span_indices_mask is None:
        return attended

    # The masking above handled slots beyond each span's width; this handles whole
    # spans that were passed in as padding.
    return attended * span_indices_mask.unsqueeze(-1).float()
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            verb_indicator: torch.LongTensor,
            tags: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``, which should typically be passed directly to a
        ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
        tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
        Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
        for the ``TokenIndexers`` when you created the ``TextField`` representing your
        sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
        which knows how to combine different word representations into a single vector per
        token in your input.
    verb_indicator: torch.LongTensor, required.
        An integer ``SequenceFeatureField`` representation of the position of the verb
        in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
        all zeros, in the case that the sentence has no verbal predicate.
    tags : torch.LongTensor, optional (default = None)
        A torch tensor representing the sequence of integer gold class labels
        of shape ``(batch_size, num_tokens)``
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        metadata containing the original words in the sentence and the verb to compute the
        frame for, under 'words' and 'verb' keys, respectively. When the span metric is
        computed (``tags`` given, not training, metric enabled) the 'verb_index' and
        'gold_tags' keys are read as well, so ``metadata`` must be supplied in that case.

    Returns
    -------
    An output dictionary consisting of:
    logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
        unnormalised log probabilities of the tag classes.
    class_probabilities : torch.FloatTensor
        A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
        a distribution of the tag classes per word.
    mask : the token mask computed from ``tokens``.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised. Present only when ``tags`` is given.
    words, verb : ``List``, optional
        The per-instance 'words' and 'verb' metadata entries. Present only when
        ``metadata`` is given.
    """
    embedded_text_input = self.embedding_dropout(self.text_field_embedder(tokens))
    mask = get_text_field_mask(tokens)
    embedded_verb_indicator = self.binary_feature_embedding(verb_indicator.long())
    # Concatenate the verb feature onto the embedded text. This now
    # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
    embedded_text_with_verb_indicator = torch.cat([embedded_text_input, embedded_verb_indicator], -1)
    batch_size, sequence_length, _ = embedded_text_with_verb_indicator.size()

    encoded_text = self.encoder(embedded_text_with_verb_indicator, mask)

    logits = self.tag_projection_layer(encoded_text)
    reshaped_log_probs = logits.view(-1, self.num_classes)
    class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view([batch_size,
                                                                      sequence_length,
                                                                      self.num_classes])
    output_dict = {"logits": logits, "class_probabilities": class_probabilities}
    # We need to retain the mask in the output dictionary
    # so that we can crop the sequences to remove padding
    # when we do viterbi inference in self.decode.
    output_dict["mask"] = mask

    if tags is not None:
        loss = sequence_cross_entropy_with_logits(logits,
                                                  tags,
                                                  mask,
                                                  label_smoothing=self._label_smoothing)
        if not self.ignore_span_metric and self.span_metric is not None and not self.training:
            batch_verb_indices = [example_metadata["verb_index"] for example_metadata in metadata]
            batch_sentences = [example_metadata["words"] for example_metadata in metadata]
            # Get the BIO tags from decode()
            # TODO (nfliu): This is kind of a hack, consider splitting out part
            # of decode() to a separate function.
            batch_bio_predicted_tags = self.decode(output_dict).pop("tags")
            batch_conll_predicted_tags = [convert_bio_tags_to_conll_format(tags) for
                                          tags in batch_bio_predicted_tags]
            batch_bio_gold_tags = [example_metadata["gold_tags"] for example_metadata in metadata]
            batch_conll_gold_tags = [convert_bio_tags_to_conll_format(tags) for
                                     tags in batch_bio_gold_tags]
            self.span_metric(batch_verb_indices,
                             batch_sentences,
                             batch_conll_predicted_tags,
                             batch_conll_gold_tags)
        output_dict["loss"] = loss

    # The metadata must only be touched once we know it was supplied: iterating it
    # before the None check made the documented default (metadata=None) raise a
    # TypeError, and unpacking ``zip(*[])`` failed for an empty list.
    if metadata is not None:
        output_dict["words"] = [example_metadata["words"] for example_metadata in metadata]
        output_dict["verb"] = [example_metadata["verb"] for example_metadata in metadata]
    return output_dict
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            spans: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            pos_tags: Dict[str, torch.LongTensor] = None,
            span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Label every candidate span of the sentence and, at evaluation time, score the
    decoded trees against the gold trees.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``, which should typically be passed directly to a
        ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
        tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
        Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
        for the ``TokenIndexers`` when you created the ``TextField`` representing your
        sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
        which knows how to combine different word representations into a single vector per
        token in your input.
    spans : ``torch.LongTensor``, required.
        A tensor of shape ``(batch_size, num_spans, 2)`` representing the
        inclusive start and end indices of all possible spans in the sentence.
    metadata : List[Dict[str, Any]], required.
        A dictionary of metadata for each batch element which has keys:
            tokens : ``List[str]``, required.
                The original string tokens in the sentence.
            gold_tree : ``nltk.Tree``, optional (default = None)
                Gold NLTK trees for use in evaluation.
            pos_tags : ``List[str]``, optional.
                The POS tags for the sentence. These can be used in the
                model as embedded features, but they are passed here
                in addition for use in constructing the tree.
    pos_tags : ``torch.LongTensor``, optional (default = None)
        The output of a ``SequenceLabelField`` containing POS tags.
    span_labels : ``torch.LongTensor``, optional (default = None)
        A torch tensor representing the integer gold class labels for all possible
        spans, of shape ``(batch_size, num_spans)``.

    Returns
    -------
    An output dictionary consisting of:
    class_probabilities : ``torch.FloatTensor``
        A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
        representing a distribution over the label classes per span.
    spans : ``torch.LongTensor``
        The original spans tensor.
    tokens : ``List[List[str]]``, required.
        A list of tokens in the sentence for each element in the batch.
    pos_tags : ``List[List[str]]``, required.
        A list of POS tags in the sentence for each element in the batch.
    num_spans : ``torch.LongTensor``, required.
        A tensor of shape (batch_size), representing the lengths of non-padded spans
        in ``enumerated_spans``.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised.

    Raises
    ------
    ConfigurationError
        If the model has a POS tag embedding but no ``pos_tags`` were passed.
    """
    text_embeddings = self.text_field_embedder(tokens)

    if self.pos_tag_embedding is not None:
        if pos_tags is None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")
        pos_embeddings = self.pos_tag_embedding(pos_tags)
        text_embeddings = torch.cat([text_embeddings, pos_embeddings], -1)

    mask = get_text_field_mask(tokens)

    # A span is padding exactly when its start index is negative, so the start
    # index alone determines the mask. Shape: (batch_size, num_spans)
    span_starts = spans[:, :, 0]
    span_mask = (span_starts >= 0).squeeze(-1).long()
    if span_mask.dim() == 1:
        # This happens if you use batch_size 1 and encounter
        # a length 1 sentence in PTB, which do exist. -.-
        span_mask = span_mask.unsqueeze(-1)
    if span_labels is not None and span_labels.dim() == 1:
        span_labels = span_labels.unsqueeze(-1)

    num_spans = get_lengths_from_binary_sequence_mask(span_mask)

    encoded_text = self.encoder(text_embeddings, mask)
    span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
    if self.feedforward_layer is not None:
        span_representations = self.feedforward_layer(span_representations)

    logits = self.tag_projection_layer(span_representations)
    class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

    sentence_tokens = [meta["tokens"] for meta in metadata]
    sentence_pos_tags = [meta.get("pos_tags") for meta in metadata]
    output_dict = {
            "class_probabilities": class_probabilities,
            "spans": spans,
            "tokens": sentence_tokens,
            "pos_tags": sentence_pos_tags,
            "num_spans": num_spans
    }

    if span_labels is not None:
        output_dict["loss"] = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
        self.tag_accuracy(class_probabilities, span_labels, span_mask)

    # The evalb score is expensive to compute, so we only compute
    # it for the validation and test sets.
    batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
    should_score = all(batch_gold_trees) and self._evalb_score is not None and not self.training
    if should_score:
        # ``tree.pos()`` yields (word, tag) pairs; transposing and taking
        # element 1 keeps just the tags.
        gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
                                          for tree in batch_gold_trees]
        predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                               spans.cpu().data,
                                               num_spans.data,
                                               output_dict["tokens"],
                                               gold_pos_tags)
        self._evalb_score(predicted_trees, batch_gold_trees)

    return output_dict
def forward(self,
            input_ids: torch.LongTensor,
            offsets: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None) -> torch.Tensor:
    """
    Parameters
    ----------
    input_ids : ``torch.LongTensor``
        The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
    offsets : ``torch.LongTensor``, optional
        The BERT embeddings are one per wordpiece. However it's possible/likely
        you might want one per original token. In that case, ``offsets``
        represents the indices of the desired wordpiece for each original token.
        Depending on how your token indexer is configured, this could be the
        position of the last wordpiece for each token, or it could be the position
        of the first wordpiece for each token.

        For example, if you had the sentence "Definitely not", and if the corresponding
        wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
        would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
        If offsets are provided, the returned tensor will contain only the wordpiece
        embeddings at those positions, and (in particular) will contain one embedding
        per token. If offsets are not provided, the entire tensor of wordpiece embeddings
        will be returned.
    token_type_ids : ``torch.LongTensor``, optional
        If an input consists of two sentences (as in the BERT paper),
        tokens from the first sentence should have type 0 and tokens from
        the second sentence should have type 1.  If you don't provide this
        (the default BertIndexer doesn't) then it's assumed to be all 0s.
        When the input is longer than ``self.max_pieces`` these ids are split
        into windows in exactly the same way as ``input_ids``.

    Returns
    -------
    The scalar mix of all layers if ``self._scalar_mix`` is set, the last layer if
    ``self.combine_layers == "last"``, and otherwise the stacked layers.
    """
    # pylint: disable=arguments-differ
    batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
    initial_dims = list(input_ids.shape[:-1])

    # The embedder may receive an input tensor that has a sequence length longer than can
    # be fit. In that case, we should expect the wordpiece indexer to create padded windows
    # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
    # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
    # We can then split the sequence into sub-sequences of that length, and concatenate them
    # along the batch dimension so we effectively have one huge batch of partial sentences.
    # This can then be fed into BERT without any sentence length issues. Keep in mind
    # that the memory consumption can dramatically increase for large batches with extremely
    # long sentences.
    needs_split = full_seq_len > self.max_pieces

    def _fold_into_windows(tensor: torch.Tensor) -> torch.Tensor:
        # Split the last dimension by the window size, `max_pieces`.
        windows = list(tensor.split(self.max_pieces, dim=-1))
        # We want all windows to be the same length, so pad the last one with zeros.
        padding_amount = self.max_pieces - windows[-1].size(-1)
        windows[-1] = F.pad(windows[-1], pad=[0, padding_amount], value=0)
        # Now combine the windows along the batch dimension.
        return torch.cat(windows, dim=0)

    if needs_split:
        input_ids = _fold_into_windows(input_ids)
        if token_type_ids is not None:
            # The type ids must be windowed like the input ids; otherwise the two
            # tensors reach the BERT model with different shapes.
            token_type_ids = _fold_into_windows(token_type_ids)

    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids)

    input_mask = (input_ids != 0).long()

    # input_ids may have extra dimensions, so we reshape down to 2-d
    # before calling the BERT model and then reshape back at the end.
    all_encoder_layers, _ = self.bert_model(input_ids=util.combine_initial_dims(input_ids),
                                            token_type_ids=util.combine_initial_dims(token_type_ids),
                                            attention_mask=util.combine_initial_dims(input_mask))
    all_encoder_layers = torch.stack(all_encoder_layers)

    if needs_split:
        # First, unpack the output embeddings into one long sequence again
        unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1)
        unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)

        # Next, select indices of the sequence such that it will result in embeddings representing the original
        # sentence. To capture maximal context, the indices will be the middle part of each embedded window
        # sub-sequence (plus any leftover start and final edge windows), e.g.,
        #  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
        # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
        # with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
        # and final windows with indices [0, 1] and [14, 15] respectively.

        # Find the stride as half the max pieces, ignoring the special start and end tokens
        # Calculate an offset to extract the centermost embeddings of each window
        stride = (self.max_pieces - self.start_tokens - self.end_tokens) // 2
        stride_offset = stride // 2 + self.start_tokens

        first_window = list(range(stride_offset))

        max_context_windows = [i for i in range(full_seq_len)
                               if stride_offset - 1 < i % self.max_pieces < stride_offset + stride]

        final_window_start = full_seq_len - (full_seq_len % self.max_pieces) + stride_offset + stride
        final_window = list(range(final_window_start, full_seq_len))

        select_indices = first_window + max_context_windows + final_window

        initial_dims.append(len(select_indices))

        recombined_embeddings = unpacked_embeddings[:, :, select_indices]
    else:
        recombined_embeddings = all_encoder_layers

    # Recombine the outputs of all layers
    # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
    # recombined = torch.cat(combined, dim=2)
    input_mask = (recombined_embeddings != 0).long()

    # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

    if offsets is None:
        # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
        dims = initial_dims if needs_split else input_ids.size()
        layers = util.uncombine_initial_dims(recombined_embeddings, dims)
    else:
        # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
        offsets2d = util.combine_initial_dims(offsets)
        # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
        range_vector = util.get_range_vector(offsets2d.size(0),
                                             device=util.get_device_of(recombined_embeddings)).unsqueeze(1)
        # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
        selected_embeddings = recombined_embeddings[:, range_vector, offsets2d]

        layers = util.uncombine_initial_dims(selected_embeddings, offsets.size())

    if self._scalar_mix is not None:
        return self._scalar_mix(layers, input_mask)
    elif self.combine_layers == "last":
        return layers[-1]
    else:
        return layers
def _get_linking_probabilities(self,
                               worlds: List[WikiTablesWorld],
                               linking_scores: torch.FloatTensor,
                               question_mask: torch.LongTensor,
                               entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
    """
    Turns linking scores into the probability of an entity given a question word and an entity
    type.  The softmax normalization term only sums over entities of a single type, so entities
    are grouped by type and each group is normalized separately (together with one extra "null"
    entity per type, whose score is fixed to 0).

    Parameters
    ----------
    worlds : ``List[WikiTablesWorld]``
        One world per batch instance; only ``world.table_graph.entities`` is read here.
    linking_scores : ``torch.FloatTensor``
        Has shape (batch_size, num_question_tokens, num_entities).
    question_mask: ``torch.LongTensor``
        Has shape (batch_size, num_question_tokens).
    entity_type_dict : ``Dict[int, int]``
        This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

    Returns
    -------
    batch_probabilities : ``torch.FloatTensor``
        Has shape ``(batch_size, num_question_tokens, num_entities)``.  Contains all the
        probabilities for an entity given a question word.  Rows of masked question tokens are
        zeroed out.
    """
    _, num_question_tokens, num_entities = linking_scores.size()
    per_instance_probabilities = []

    for batch_index, world in enumerate(worlds):
        flat_offset = batch_index * num_entities
        typed_blocks = []
        covered_entities = 0

        # NOTE: The per-type blocks are concatenated in type order, so the output columns only
        # line up with the entity order if entities are already sorted by type.  That holds as
        # long as sorting entities by name implicitly sorts them by type (numbers before
        # "fb:cell", "fb:cell" before "fb:row").  This is a fragile assumption.
        for type_index in range(self._num_entity_types):
            entities = world.table_graph.entities
            members = [entity_index
                       for entity_index, _ in enumerate(entities)
                       if entity_type_dict[flat_offset + entity_index] == type_index]
            if not members:
                # No entities of this type in this instance.
                continue
            covered_entities += len(members)

            # Position 0 is a placeholder column for the per-type null entity (a word that links
            # to nothing).  Selecting on dimension 1 of the (num_question_tokens, num_entities)
            # scores gives a (num_question_tokens, len(members) + 1) tensor.
            selector = linking_scores.new_tensor([0] + members, dtype=torch.long)
            scores_with_null = linking_scores[batch_index].index_select(1, selector)
            # The placeholder column currently holds entity 0's scores; the null entity's score
            # must be 0 instead.
            scores_with_null[:, 0] = 0

            # No mask is needed: this is done per instance, so there is no padding here.
            distribution = torch.nn.functional.softmax(scores_with_null, dim=1)
            typed_blocks.append(distribution[:, 1:])

        # Pad with zeros when this instance has fewer entities than the batch maximum.
        missing = num_entities - covered_entities
        if missing != 0:
            typed_blocks.append(linking_scores.new_zeros(num_question_tokens, missing))

        # Shape: (num_question_tokens, num_entities)
        per_instance_probabilities.append(torch.cat(typed_blocks, dim=1))

    stacked = torch.stack(per_instance_probabilities, dim=0)
    return stacked * question_mask.unsqueeze(-1).float()
def forward(self,  # type: ignore
            utterance: Dict[str, torch.LongTensor],
            world: List[AtisWorld],
            actions: List[List[ProductionRule]],
            linking_scores: torch.Tensor,
            target_action_sequence: torch.LongTensor = None,
            sql_queries: List[List[str]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
    if we're training, or a BeamSearch for inference, if we're not.

    Parameters
    ----------
    utterance : Dict[str, torch.LongTensor]
        The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This will
        be passed through a ``TextFieldEmbedder`` and then through an encoder.
    world : ``List[AtisWorld]``
        We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
        how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``,
    actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    linking_scores: ``torch.Tensor``
        A matrix of the linking the utterance tokens and the entities. This is a binary matrix that
        is deterministically generated where each entry indicates whether a token generated an entity.
        This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``.
    target_action_sequence : torch.Tensor, optional (default=None)
        The action sequence for the correct action sequence, where each action is an index into the list
        of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
        trailing dimension.  It is required when ``self.training`` is true.
    sql_queries : List[List[str]], optional (default=None)
        A list of the SQL queries that are given during training or validation.

    Returns
    -------
    In training mode, the result of ``self._decoder_trainer.decode``.  Otherwise a dictionary
    with ``action_mapping``, ``linking_scores``, optionally ``loss``, and the per-instance lists
    ``best_action_sequence``, ``debug_info``, ``entities``, ``predicted_sql_query``,
    ``sql_queries``, ``utterance`` and ``tokenized_utterance``.
    """
    initial_state = self._get_initial_state(utterance, world, actions, linking_scores)
    batch_size = linking_scores.shape[0]
    if target_action_sequence is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequence = target_action_sequence.squeeze(-1)
        target_mask = target_action_sequence != self._action_padding_index
    else:
        target_mask = None

    if self.training:
        # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we unsqueeze it for
        # the MML trainer.
        return self._decoder_trainer.decode(initial_state,
                                            self._transition_function,
                                            (target_action_sequence.unsqueeze(1),
                                             target_mask.unsqueeze(1)))
    else:
        # TODO(kevin) Move some of this functionality to a separate method for computing validation outputs.
        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]
        outputs: Dict[str, Any] = {'action_mapping': action_mapping}
        outputs['linking_scores'] = linking_scores
        if target_action_sequence is not None:
            outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                           self._transition_function,
                                                           (target_action_sequence.unsqueeze(1),
                                                            target_mask.unsqueeze(1)))['loss']
        num_steps = self._max_decoding_steps
        # This tells the state to start keeping track of debug info, which we'll pass along in
        # our output dictionary.
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._beam_search.search(num_steps,
                                                     initial_state,
                                                     self._transition_function,
                                                     keep_final_unfinished_states=False)
        outputs['best_action_sequence'] = []
        outputs['debug_info'] = []
        outputs['entities'] = []
        outputs['predicted_sql_query'] = []
        outputs['sql_queries'] = []
        outputs['utterance'] = []
        outputs['tokenized_utterance'] = []

        for i in range(batch_size):
            # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i not in best_final_states:
                self._exact_match(0)
                self._denotation_accuracy(0)
                self._valid_sql_query(0)
                self._action_similarity(0)
                # NOTE: only 'predicted_sql_query' receives a placeholder here; the other
                # per-instance lists get no entry for this instance.
                outputs['predicted_sql_query'].append('')
                continue

            best_action_indices = best_final_states[i][0].action_history[0]

            action_strings = [action_mapping[(i, action_index)]
                              for action_index in best_action_indices]
            predicted_sql_query = action_sequence_to_sql(action_strings)

            if target_action_sequence is not None:
                # Use a Tensor, not a Variable, to avoid a memory leak.
                targets = target_action_sequence[i].data
                sequence_in_targets = self._action_history_match(best_action_indices, targets)
                self._exact_match(sequence_in_targets)

                # SequenceMatcher compares elements through hashing.  Elements obtained by
                # iterating a tensor are 0-d tensors that hash by identity, so they would never
                # match the predicted ints.  Compare plain ints instead, and drop padding, which
                # is not part of the gold sequence.
                gold_action_indices = [index for index in targets.tolist()
                                       if index != self._action_padding_index]
                similarity = difflib.SequenceMatcher(None, best_action_indices, gold_action_indices)
                self._action_similarity(similarity.ratio())

            if sql_queries and sql_queries[i]:
                denotation_correct = self._executor.evaluate_sql_query(predicted_sql_query, sql_queries[i])
                self._denotation_accuracy(denotation_correct)
                outputs['sql_queries'].append(sql_queries[i])

            outputs['utterance'].append(world[i].utterances[-1])
            outputs['tokenized_utterance'].append([token.text
                                                   for token in world[i].tokenized_utterances[-1]])
            outputs['entities'].append(world[i].entities)
            outputs['best_action_sequence'].append(action_strings)
            outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
            outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs
def forward(self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRule]],
            identifier: List[str] = None,
            target_action_sequences: torch.LongTensor = None,
            labels: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing type constrained target sequences, trained to maximize marginal
    likelihood over a set of approximate logical forms.

    Parameters
    ----------
    sentence : ``Dict[str, torch.LongTensor]``
        Indexed sentence; used to build the initial RNN state and, through its first tensor, to
        create the initial zero scores.
    worlds : ``List[List[NlvrWorld]]``
        The worlds for each instance.  Only the first world of each instance is used to build
        the grammar state (all worlds are assumed to give the same valid actions).
    actions : ``List[List[ProductionRule]]``
        The possible actions for each instance.
    identifier : ``List[str]``, optional (default=None)
        Instance identifiers; copied into the output under ``"identifier"`` when given.
    target_action_sequences : ``torch.LongTensor``, optional (default=None)
        Target action indices with a trailing singleton dimension, which is removed here.
    labels : ``torch.LongTensor``, optional (default=None)
        Gold labels, converted to strings with ``self._get_label_strings``.
    metadata : ``List[Dict[str, Any]]``, optional (default=None)
        Per-instance metadata; its ``"sentence_tokens"`` entries are copied to the output when
        decoding without targets.

    Returns
    -------
    The output dictionary.  It holds the result of ``self._decoder_trainer.decode`` when targets
    are given.  When not training and no targets are given it also holds ``debug_info``,
    ``best_action_strings``, ``denotations`` and ``action_mapping``.
    """
    batch_size = len(worlds)

    initial_rnn_state = self._get_initial_rnn_state(sentence)
    initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                          for _ in range(batch_size)]
    label_strings = self._get_label_strings(labels) if labels is not None else None
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i])
                             for i in range(batch_size)]

    initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                      action_history=[[] for _ in range(batch_size)],
                                      score=initial_score_list,
                                      rnn_state=initial_rnn_state,
                                      grammar_state=initial_grammar_state,
                                      possible_actions=actions,
                                      extras=label_strings)

    if target_action_sequences is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
    else:
        target_mask = None

    outputs: Dict[str, torch.Tensor] = {}
    if identifier is not None:
        outputs["identifier"] = identifier
    if target_action_sequences is not None:
        # Merge rather than rebind, so that the identifier stored above is not discarded.
        outputs.update(self._decoder_trainer.decode(initial_state,
                                                    self._decoder_step,
                                                    (target_action_sequences, target_mask)))
    if not self.training:
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
                                                             initial_state,
                                                             self._decoder_step,
                                                             keep_final_unfinished_states=False)
        best_action_sequences: Dict[int, List[List[int]]] = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = [best_final_states[i][0].action_history[0]]
                best_action_sequences[i] = best_action_indices
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if target_action_sequences is not None:
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings)
        else:
            if metadata is not None:
                outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
            outputs['debug_info'] = []
            for i in range(batch_size):
                if i in best_final_states:
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                else:
                    # No finished state for this instance (see the comment above).  An empty
                    # list keeps 'debug_info' aligned with the batch instead of raising KeyError.
                    outputs['debug_info'].append([])
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs['action_mapping'] = action_mapping
    return outputs
def forward(
        self,  # type: ignore
        sentence: Dict[str, torch.LongTensor],
        worlds: List[List[NlvrWorld]],
        actions: List[List[ProductionRuleArray]],
        target_action_sequences: torch.LongTensor = None,
        labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing type constrained target sequences, trained to maximize marginal
    likelihood over a set of approximate logical forms.

    The decoder trainer runs only when ``target_action_sequences`` is given, and its result
    becomes the output dictionary.  The beam search always runs.  With targets, the decoded
    action strings update the metrics; without targets, they are returned under
    ``"best_action_strings"`` together with their ``"denotations"``.
    """
    batch_size = len(worlds)
    action_embeddings, action_indices = self._embed_actions(actions)
    initial_rnn_state = self._get_initial_rnn_state(sentence)

    # One zero score per instance, created from the first tensor in ``sentence``.
    initial_score_list = []
    for _ in range(batch_size):
        initial_score_list.append(
                util.new_variable_with_data(list(sentence.values())[0], torch.Tensor([0.0])))

    if labels is not None:
        label_strings = self._get_label_strings(labels)
    else:
        label_strings = None

    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_state = []
    for instance_index in range(batch_size):
        initial_grammar_state.append(
                self._create_grammar_state(worlds[instance_index][0], actions[instance_index]))

    worlds_list = [worlds[instance_index] for instance_index in range(batch_size)]
    initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
                                     action_history=[[] for _ in range(batch_size)],
                                     score=initial_score_list,
                                     rnn_state=initial_rnn_state,
                                     grammar_state=initial_grammar_state,
                                     action_embeddings=action_embeddings,
                                     action_indices=action_indices,
                                     possible_actions=actions,
                                     worlds=worlds_list,
                                     label_strings=label_strings)

    has_targets = target_action_sequences is not None
    outputs: Dict[str, torch.Tensor] = {}
    if has_targets:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index
        outputs = self._decoder_trainer.decode(initial_state,
                                               self._decoder_step,
                                               (target_action_sequences, target_mask))

    best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)

    # Decoding may not have terminated with any completed logical forms, if `num_steps`
    # isn't long enough (or if the model is not trained enough and gets into an
    # infinite action loop), so only instances with a finished state get an entry.
    best_action_sequences: Dict[int, List[List[int]]] = {
            instance_index: [best_final_states[instance_index][0].action_history[0]]
            for instance_index in range(batch_size)
            if instance_index in best_final_states
    }

    batch_action_strings = self._get_action_strings(actions, best_action_sequences)
    batch_denotations = self._get_denotations(batch_action_strings, worlds)
    if has_targets:
        self._update_metrics(action_strings=batch_action_strings,
                             worlds=worlds,
                             label_strings=label_strings)
        return outputs

    outputs["best_action_strings"] = batch_action_strings
    outputs["denotations"] = batch_denotations
    return outputs
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            verb_indicator: torch.LongTensor,
            tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``, which should typically be passed directly to a
        ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
        tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
        Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
        for the ``TokenIndexers`` when you created the ``TextField`` representing your
        sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
        which knows how to combine different word representations into a single vector per
        token in your input.
    verb_indicator: torch.LongTensor, required.
        An integer ``SequenceFeatureField`` representation of the position of the verb
        in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
        all zeros, in the case that the sentence has no verbal predicate.
    tags : torch.LongTensor, optional (default = None)
        A torch tensor representing the sequence of integer gold class labels
        of shape ``(batch_size, num_tokens)``

    Returns
    -------
    An output dictionary consisting of:
    logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
        unnormalised log probabilities of the tag classes.
    class_probabilities : torch.FloatTensor
        A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
        a distribution of the tag classes per word.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.

    Raises
    ------
    ConfigurationError
        If the input dimension of ``self.stacked_encoder`` does not equal the size of the
        embedded text concatenated with the verb indicator embedding.
    """
    embedded_text_input = self.embedding_dropout(self.text_field_embedder(tokens))
    mask = get_text_field_mask(tokens)
    embedded_verb_indicator = self.binary_feature_embedding(verb_indicator.long())
    # Concatenate the verb feature onto the embedded text. This now
    # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
    embedded_text_with_verb_indicator = torch.cat([embedded_text_input, embedded_verb_indicator], -1)
    batch_size, sequence_length, embedding_dim_with_binary_feature = embedded_text_with_verb_indicator.size()

    if self.stacked_encoder.get_input_dim() != embedding_dim_with_binary_feature:
        raise ConfigurationError("The SRL model uses an indicator feature, which makes "
                                 "the embedding dimension one larger than the value "
                                 "specified. Therefore, the 'input_dim' of the stacked_encoder "
                                 "must be equal to total_embedding_dim + 1.")

    encoded_text = self.stacked_encoder(embedded_text_with_verb_indicator, mask)

    logits = self.tag_projection_layer(encoded_text)
    # Shape: (batch_size * sequence_length, num_classes)
    flat_logits = logits.view(-1, self.num_classes)
    # Pass ``dim`` explicitly: relying on softmax's implicit dimension is deprecated.  For this
    # 2D tensor the class dimension is the last one.
    class_probabilities = F.softmax(flat_logits, dim=-1).view([batch_size,
                                                               sequence_length,
                                                               self.num_classes])
    output_dict = {"logits": logits, "class_probabilities": class_probabilities}
    if tags is not None:
        loss = sequence_cross_entropy_with_logits(logits, tags, mask)
        self.span_metric(class_probabilities, tags, mask)
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            mask: torch.LongTensor = None) -> torch.FloatTensor:
    """
    Applies multi-head self attention: the inputs are projected to queries, keys and values,
    each is split across ``self._num_heads`` heads, attention is computed per head, and the
    per-head results are merged and passed through the output projection.

    Parameters
    ----------
    inputs : ``torch.FloatTensor``, required.
        A tensor of shape (batch_size, timesteps, input_dim)
    mask : ``torch.LongTensor``, optional (default = None).
        A tensor of shape (batch_size, timesteps).  When omitted, a mask of all ones is used.

    Returns
    -------
    A tensor of shape (batch_size, timesteps, output_projection_dim),
    where output_projection_dim = input_dim by default.
    """
    head_count = self._num_heads
    batch_size, timesteps, _ = inputs.size()
    if mask is None:
        mask = inputs.new_ones(batch_size, timesteps)

    def split_into_heads(tensor: torch.Tensor, total_dim: int) -> torch.Tensor:
        # (batch_size, timesteps, total_dim) -> (batch_size * num_heads, timesteps, total_dim / num_heads)
        per_head_dim = int(total_dim / head_count)
        stacked = tensor.view(batch_size, timesteps, head_count, per_head_dim)
        stacked = stacked.transpose(1, 2).contiguous()
        return stacked.view(batch_size * head_count, timesteps, per_head_dim)

    # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
    projected = self._combined_projection(inputs)
    # Splitting by attention_dim gives more than three chunks when values_dim > attention_dim.
    # Everything after the first two chunks belongs to the values, so it is concatenated back.
    queries, keys, *value_chunks = projected.split(self._attention_dim, -1)
    values = torch.cat(value_chunks, -1).contiguous()

    values_per_head = split_into_heads(values, self._values_dim)
    queries_per_head = split_into_heads(queries.contiguous(), self._attention_dim)
    keys_per_head = split_into_heads(keys.contiguous(), self._attention_dim)

    # shape (num_heads * batch_size, timesteps, timesteps)
    similarities = torch.bmm(queries_per_head, keys_per_head.transpose(1, 2)) / self._scale

    # The same mask is used for every head.
    head_mask = mask.repeat(1, head_count).view(batch_size * head_count, timesteps)
    # shape (num_heads * batch_size, timesteps, timesteps)
    attention = self._attention_dropout(last_dim_softmax(similarities, head_mask))

    # Weighted sum of the values under each attention distribution.
    # shape (num_heads * batch_size, timesteps, values_dim/num_heads)
    attended = weighted_sum(values_per_head, attention)

    # Undo the head split.
    # shape (batch_size, num_heads, timesteps, values_dim/num_heads)
    attended = attended.view(batch_size, head_count, timesteps, int(self._values_dim / head_count))
    # shape (batch_size, timesteps, num_heads, values_dim/num_heads)
    attended = attended.transpose(1, 2).contiguous()
    # shape (batch_size, timesteps, values_dim)
    attended = attended.view(batch_size, timesteps, self._values_dim)

    # Project back to original input size.
    # shape (batch_size, timesteps, input_size)
    return self._output_projection(attended)
def forward(self,  # pylint: disable=arguments-differ
            embeddings: torch.FloatTensor,
            mask: torch.LongTensor,
            num_items_to_keep: int) -> Tuple[torch.FloatTensor, torch.LongTensor,
                                             torch.LongTensor, torch.FloatTensor]:
    """
    Extracts the top-k scoring items with respect to the scorer. We additionally return
    the indices of the top-k in their original order, not ordered by score, so that downstream
    components can rely on the original ordering (e.g., for knowing what spans are valid
    antecedents in a coreference resolution model).

    Parameters
    ----------
    embeddings : ``torch.FloatTensor``, required.
        A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
        each item in the list that we want to prune.
    mask : ``torch.LongTensor``, required.
        A tensor of shape (batch_size, num_items), denoting unpadded elements of
        ``embeddings``.
    num_items_to_keep : ``int``, required.
        The number of items to keep when pruning.

    Returns
    -------
    top_embeddings : ``torch.FloatTensor``
        The representations of the top-k scoring items.
        Has shape (batch_size, num_items_to_keep, embedding_size).
    top_mask : ``torch.LongTensor``
        The corresponding mask for ``top_embeddings``.
        Has shape (batch_size, num_items_to_keep).
    top_indices : ``torch.LongTensor``
        The indices of the top-k scoring items into the original ``embeddings``
        tensor. This is returned because it can be useful to retain pointers to
        the original items, if each item is being scored by multiple distinct
        scorers, for instance. Has shape (batch_size, num_items_to_keep).
    top_item_scores : ``torch.FloatTensor``
        The values of the top-k scoring items.
        Has shape (batch_size, num_items_to_keep, 1).

    Raises
    ------
    ValueError
        If the scorer does not produce a tensor of shape (batch_size, num_items, 1).
    """
    mask = mask.unsqueeze(-1)
    num_items = embeddings.size(1)
    # Shape: (batch_size, num_items, 1)
    scores = self._scorer(embeddings)

    # Check the rank first, so that a 0-dimensional scorer output hits this ValueError
    # rather than an IndexError from ``size(-1)``.
    if scores.dim() != 3 or scores.size(-1) != 1:
        raise ValueError("The scorer passed to Pruner must produce a tensor of shape "
                         f"(batch_size, num_items, 1), but found shape {scores.size()}")
    # Make sure that we don't select any masked items by setting their scores to be very
    # negative.  These are logits, typically, so -1e20 should be plenty negative.
    scores = util.replace_masked_values(scores, mask, -1e20)

    # Shape: (batch_size, num_items_to_keep, 1)
    _, top_indices = scores.topk(num_items_to_keep, 1)

    # Now we order the selected indices in increasing order with
    # respect to their indices (and hence, with respect to the
    # order they originally appeared in the ``embeddings`` tensor).
    top_indices, _ = torch.sort(top_indices, 1)

    # Shape: (batch_size, num_items_to_keep)
    top_indices = top_indices.squeeze(-1)

    # Shape: (batch_size * num_items_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select items for each element in the batch.
    flat_top_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)

    # Shape: (batch_size, num_items_to_keep, embedding_size)
    top_embeddings = util.batched_index_select(embeddings, top_indices, flat_top_indices)
    # Shape: (batch_size, num_items_to_keep)
    top_mask = util.batched_index_select(mask, top_indices, flat_top_indices)

    # Shape: (batch_size, num_items_to_keep, 1)
    top_scores = util.batched_index_select(scores, top_indices, flat_top_indices)

    return top_embeddings, top_mask.squeeze(-1), top_indices, top_scores
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True,
                                       label_smoothing: float = None) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, where each element's contribution to the
    loss is scaled by a user provided weight.  This differs from the class weighting of
    :func:`torch.nn.CrossEntropyLoss()`: the weights here apply to positions in the sequence,
    which is what allows padded positions to be excluded from the loss.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.
    label_smoothing : ``float``, optional (default = None)
        Whether or not to apply label smoothing to the cross-entropy loss.
        For example, with a label smoothing value of 0.2, a 4 class classification
        target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
        the correct label.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).
    """
    num_classes = logits.size(-1)
    # shape : (batch * sequence_length, num_classes)
    token_log_probs = torch.nn.functional.log_softmax(logits.view(-1, num_classes), dim=-1)
    # shape : (batch * sequence_length, 1)
    gold_indices = targets.view(-1, 1).long()

    smoothing_enabled = label_smoothing is not None and label_smoothing > 0.0
    if not smoothing_enabled:
        # The target distribution is one-hot, so only the log probability at the gold index
        # contributes; torch.gather picks exactly that entry per row.
        # shape : (batch * sequence_length, 1)
        token_nll = -torch.gather(token_log_probs, dim=1, index=gold_indices)
    else:
        uniform_mass = label_smoothing / num_classes
        # Put 1 - label_smoothing on the gold class, then spread label_smoothing uniformly
        # over all classes (including the gold one).
        target_distribution = torch.zeros_like(token_log_probs).scatter_(-1,
                                                                        gold_indices,
                                                                        1.0 - label_smoothing)
        target_distribution = target_distribution + uniform_mass
        token_nll = (-token_log_probs * target_distribution).sum(-1, keepdim=True)

    # shape : (batch, sequence_length)
    weighted_nll = token_nll.view(*targets.size()) * weights.float()
    # shape : (batch_size,)
    weight_per_sequence = weights.sum(1)
    per_batch_loss = weighted_nll.sum(1) / (weight_per_sequence.float() + 1e-13)

    if not batch_average:
        return per_batch_loss

    # Sequences whose weights sum to zero are not counted in the average.
    num_non_empty_sequences = (weight_per_sequence > 0).float().sum() + 1e-13
    return per_batch_loss.sum() / num_non_empty_sequences
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            verb_indicator: torch.LongTensor,
            tags: torch.LongTensor = None,
            training: bool = False,  # added by ph to make function consistent with other model
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``, which should typically be passed directly to a
        ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
        tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
        Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
        for the ``TokenIndexers`` when you created the ``TextField`` representing your
        sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
        which knows how to combine different word representations into a single vector per
        token in your input.  It must contain a ``'tokens'`` entry.
    verb_indicator: torch.LongTensor, required.
        An integer ``SequenceFeatureField`` representation of the position of the verb
        in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
        all zeros, in the case that the sentence has no verbal predicate.
    tags : torch.LongTensor, optional (default = None)
        A torch tensor representing the sequence of integer gold class labels
        of shape ``(batch_size, num_tokens)``
    training : added by ph to make function consistent with other model - does nothing
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        metadata containing the original words in the sentence and the verb to compute the
        frame for, under 'words' and 'verb' keys, respectively.  Not used in this method.

    Returns
    -------
    An output dictionary consisting of:
    logits : torch.FloatTensor
        A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
        unnormalised log probabilities of the tag classes.
    class_probabilities : torch.FloatTensor
        A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
        a distribution of the tag classes per word.
    mask : the text field mask computed from ``tokens``.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    softmax_3d : ``class_probabilities`` detached and converted to a numpy array on the CPU.
    """
    # added by ph: move the inputs to the GPU.  Only done when CUDA is available, so that the
    # method does not crash on CPU-only machines.
    # NOTE(review): this assumes the model itself is on the GPU whenever CUDA is available -
    # confirm against the caller.
    if torch.cuda.is_available():
        # Shallow copy, so that the caller's dictionary is not modified in place.
        tokens = dict(tokens)
        tokens['tokens'] = tokens['tokens'].cuda()
        verb_indicator = verb_indicator.cuda()
        if tags is not None:
            tags = tags.cuda()

    embedded_text_input = self.embedding_dropout(self.text_field_embedder(tokens))
    mask = get_text_field_mask(tokens)
    embedded_verb_indicator = self.binary_feature_embedding(verb_indicator.long())
    # Concatenate the verb feature onto the embedded text. This now
    # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
    embedded_text_with_verb_indicator = torch.cat([embedded_text_input, embedded_verb_indicator], -1)
    batch_size, sequence_length, _ = embedded_text_with_verb_indicator.size()

    encoded_text = self.encoder(embedded_text_with_verb_indicator, mask)

    logits = self.tag_projection_layer(encoded_text)
    # Shape: (batch_size * sequence_length, num_classes)
    flat_logits = logits.view(-1, self.num_classes)
    class_probabilities = F.softmax(flat_logits, dim=-1).view([batch_size,
                                                               sequence_length,
                                                               self.num_classes])
    # We need to retain the mask in the output dictionary
    # so that we can crop the sequences to remove padding
    # when we do viterbi inference in self.decode.
    output_dict = {"logits": logits, "class_probabilities": class_probabilities, "mask": mask}
    if tags is not None:
        loss = sequence_cross_entropy_with_logits(logits,
                                                  tags,
                                                  mask,
                                                  label_smoothing=self._label_smoothing)
        output_dict["loss"] = loss

    # added by ph
    output_dict['softmax_3d'] = class_probabilities.detach().cpu().numpy()

    return output_dict
def forward(self,  # type: ignore
            utterance: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            world: List[SpiderWorld],
            schema: Dict[str, torch.LongTensor],
            action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    """
    Compute the training loss, or at evaluation time run beam search and
    collect validation outputs.

    Parameters
    ----------
    utterance : Dict[str, torch.LongTensor]
        Must contain a ``'tokens'`` tensor; its device is used for the default
        evaluation loss. Example from the original author's notes (padded ids):
        ``{'tokens': tensor([[6, 8, 2, 149, 46, 3, 14, 2, 290, 149, 98, 4, 0, 0, ...], ...])}``
    valid_actions : List[List[ProductionRule]]
        Passed through to ``_get_initial_state`` and ``_compute_validation_outputs``.
        Example entry from the original author's notes:
        ``ProductionRule(rule='arg_list -> [expr, ",", arg_list]', is_global_rule=True,
        rule_id=tensor([0]), nonterminal='arg_list')``
    world : List[SpiderWorld]
        One entry per instance; ``len(world)`` is used as the batch size.
    schema : Dict[str, torch.LongTensor]
        Passed through to ``_get_initial_state``. The original author's notes show
        a structure like ``{'text': {'tokens': tensor([[[519, 35, 0, 0], ...]])},
        'linking': tensor([[[[0.0, ..., -4.25, 0.0, 0.0], ...]]])}`` -- TODO confirm.
    action_sequence : torch.LongTensor, optional (default = None)
        Gold action indices with a trailing singleton dimension, e.g.
        ``tensor([[[137], [118], ..., [138], [-1], [-1], [-1]]])`` in the original
        author's notes. Entries equal to ``self._action_padding_index`` are masked out.

    Returns
    -------
    In training mode: ``{'loss': ...}`` taken from the decoder trainer's output.
    Otherwise: a dictionary holding ``'loss'`` (a zero tensor unless a gold sequence
    longer than one step could be scored), which is then handed to
    ``_compute_validation_outputs`` -- presumably filled in there with further entries.
    """
    batch_size = len(world)
    device = utterance['tokens'].device

    initial_state = self._get_initial_state(utterance, world, schema, valid_actions)

    action_mask = None
    if action_sequence is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        action_sequence = action_sequence.squeeze(-1)
        action_mask = action_sequence != self._action_padding_index

    # Training: only the supervised loss is needed, so return straight away.
    if self.training:
        gold_targets = (action_sequence.unsqueeze(1), action_mask.unsqueeze(1))
        training_output = self._decoder_trainer.decode(initial_state,
                                                       self._transition_function,
                                                       gold_targets)
        return {'loss': training_output['loss']}

    # Evaluation: start from a zero loss and replace it when the gold
    # sequence can actually be scored.
    loss = torch.tensor([0], dtype=torch.float).to(device)
    gold_is_scorable = action_sequence is not None and action_sequence.size(1) > 1
    if gold_is_scorable:
        try:
            scored = self._decoder_trainer.decode(initial_state,
                                                  self._transition_function,
                                                  (action_sequence.unsqueeze(1),
                                                   action_mask.unsqueeze(1)))
            loss = scored['loss']
        except ZeroDivisionError:
            # reached a dead-end during beam search
            pass

    outputs: Dict[str, Any] = {'loss': loss}

    # This tells the state to start keeping track of debug info, which we'll pass along in
    # our output dictionary.
    initial_state.debug_info = [[] for _ in range(batch_size)]

    max_steps = self._max_decoding_steps
    best_final_states = self._beam_search.search(max_steps,
                                                 initial_state,
                                                 self._transition_function,
                                                 keep_final_unfinished_states=False)

    self._compute_validation_outputs(valid_actions,
                                     best_final_states,
                                     world,
                                     action_sequence,
                                     outputs)
    return outputs