def forward(  # type: ignore
    self, tokens: TextFieldTensors, target_ids: TextFieldTensors = None
) -> Dict[str, torch.Tensor]:
    """
    Predict the token that follows each sequence in `tokens`.

    # Parameters

    tokens : `TextFieldTensors`
        The output of `TextField.as_tensor()` for a batch of sequences.
    target_ids : `TextFieldTensors`, optional (default = `None`)
        One target token id per sequence. When given, a cross-entropy loss is
        computed, the perplexity metric is updated, and `loss` is returned.

    # Returns

    A dictionary with `probabilities` and `top_indices` (the top-k next-token
    probabilities and their vocabulary ids, k = min(vocab_size, 5)),
    `token_ids` (the input ids) and, when `target_ids` is given, `loss`.
    """
    # Shape: (batch_size, num_tokens, embedding_dim)
    embeddings = self._text_field_embedder(tokens)
    batch_size = embeddings.size(0)

    # `get_text_field_mask` works on the token tensors, not on embeddings.
    # Shape: (batch_size, num_tokens)
    mask = util.get_text_field_mask(tokens)

    # Shape: (batch_size, num_tokens, encoding_dim)
    if self._contextualizer:
        encoded = self._contextualizer(embeddings, mask)
    else:
        encoded = embeddings
    # Take the last *unpadded* position of every sequence; a plain `[:, -1]`
    # would read a padding position for the shorter sequences in a batch.
    final_embeddings = util.get_final_encoder_states(encoded, mask)

    target_logits = self._language_model_head(self._dropout(final_embeddings))

    vocab_size = target_logits.size(-1)
    probs = torch.nn.functional.softmax(target_logits, dim=-1)
    k = min(vocab_size, 5)  # min here largely because tests use small vocab
    top_probs, top_indices = probs.topk(k=k, dim=-1)

    output_dict = {"probabilities": top_probs, "top_indices": top_indices}
    output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(tokens)

    if target_ids is not None:
        targets = util.get_token_ids_from_text_field_tensors(target_ids).view(batch_size)
        target_logits = target_logits.view(batch_size, vocab_size)
        loss = torch.nn.functional.cross_entropy(target_logits, targets)
        self._perplexity(loss)
        output_dict["loss"] = loss

    return output_dict
def forward(
    self,
    text: TextFieldTensors,
    masked_text: Optional[TextFieldTensors] = None,
    masked_positions: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:  # type: ignore
    """
    Encode `text` and, when provided, a masked copy of it.

    # Parameters

    text : `TextFieldTensors`
        Must come from exactly one token indexer.
    masked_text : `TextFieldTensors`, optional
        A masked version of the input. It is encoded only when
        `masked_positions` is also given.
    masked_positions : `torch.Tensor`, optional
        Passed through unchanged to the output.

    # Returns

    `encoded_text`, `encoded_text_mask` and `token_ids` for `text`; plus
    `masked_positions`, `encoded_masked_text` and `encoded_masked_text_mask`
    when both masked inputs are given.

    # Raises

    ValueError if `text` does not contain exactly one indexer entry.
    """
    if len(text) != 1:
        raise ValueError(
            "PretrainedTransformerBackbone is only compatible with using a single TokenIndexer"
        )
    mask = util.get_text_field_mask(text)
    encoded_text = self.embedder(text)
    outputs = {
        "encoded_text": encoded_text,
        "encoded_text_mask": mask,
        "token_ids": util.get_token_ids_from_text_field_tensors(text),
    }
    if masked_text is not None and masked_positions is not None:
        masked_text_mask = util.get_text_field_mask(masked_text)
        # Encode the masked input itself; encoding `text` here would just
        # duplicate `encoded_text` and ignore the masking.
        encoded_masked_text = self.embedder(masked_text)
        # Store the tensor itself (a trailing comma here would store a 1-tuple).
        outputs["masked_positions"] = masked_positions
        outputs["encoded_masked_text"] = encoded_masked_text
        outputs["encoded_masked_text_mask"] = masked_text_mask
    return outputs
def copy_reference_policy(self,
                          timestep,
                          last_predictions: torch.LongTensor,
                          state: Dict[str, torch.Tensor],
                          target_tokens: Dict[str, torch.LongTensor],
                         ) -> torch.FloatTensor:
    """
    Reference policy that copies the gold token for the next step.

    Despite the annotation, this returns a pair ``(target_logits, state)``:
    ``target_logits`` has shape ``(batch_size, self._num_classes)`` and holds
    log(1) = 0 at the gold token and log(1e-45), a very large negative number,
    everywhere else. ``state`` is returned unchanged.
    """
    gold_ids = util.get_token_ids_from_text_field_tensors(target_tokens)
    num_gold_steps = gold_ids.size(1)
    batch_size = last_predictions.shape[0]

    # `timestep` is a 0-based index and we may be overriding the token predicted
    # for the following step, so read column timestep + 1. Once past the end of
    # the targets, repeat the final column (the end token or padding).
    next_column = timestep + 1
    column = next_column if num_gold_steps > next_column else -1
    gold_next = gold_ids[:, column]

    # TODO: Add support to allow other types of reference policies.
    # A tiny positive floor everywhere and 1.0 at the gold ids, so the log
    # gives 0 at the gold ids and (near) -inf elsewhere.
    floor = gold_next.new_zeros((batch_size, self._num_classes)) + 1e-45
    one_hot = floor.scatter_(dim=1, index=gold_next.unsqueeze(1), value=1.0)
    return one_hot.log(), state
def _evaluate(self, tokens, eval_mask):
    """
    Score `tokens` with the wrapped language model.

    Returns ``(token_log_likelihood, suffix_log_likelihood)``:
    - ``token_log_likelihood``: the log-probability the model assigns to each
      token given its prefix. It has one fewer position than the input, since
      the first token has no prediction.
    - ``suffix_log_likelihood``: those log-probabilities weighted by
      ``eval_mask`` (shifted by one position) and summed over the last dimension.

    NOTE(review): assumes logits are (batch, seq_len, vocab) and `eval_mask`
    is (batch, seq_len) — confirm against callers.
    """
    model_inputs = self._adapt_for_transformer(tokens)
    logits, *_ = self.model(**model_inputs)
    ids = util.get_token_ids_from_text_field_tensors(tokens)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # The prediction at position i is for token i + 1, so drop the last
    # prediction and the first token before pairing them up.
    predictions = log_probs[:, :-1]
    next_ids = ids[:, 1:].unsqueeze(-1)
    token_log_likelihood = torch.gather(predictions, -1, next_ids).squeeze(-1)

    weighted = eval_mask[:, 1:] * token_log_likelihood
    suffix_log_likelihood = weighted.sum(-1)
    return token_log_likelihood, suffix_log_likelihood
def forward(self, text: TextFieldTensors) -> TaskOutput:  # type: ignore
    """
    Compute the (optionally bidirectional) language-modelling loss for `text`.

    The backbone encodes `text`. After dropout, `self._compute_loss` scores the
    encodings against the next-token targets (and previous-token targets when
    `self.bidirectional`). The summed loss is averaged over the positive target
    ids, and every metric in `self.metrics` is updated with the average.

    # Returns

    A `TaskOutput` with `loss` set, no logits/probs, and the extra fields
    `lm_embeddings` (the backbone output, before dropout) and `mask`.

    # Raises

    IndexError if a token id falls outside the softmax vocabulary.
    """
    mask = get_text_field_mask(text)
    contextual_embeddings = self.backbone.forward(text, mask)
    token_ids = get_token_ids_from_text_field_tensors(text)
    assert isinstance(contextual_embeddings, torch.Tensor)

    # Forward targets are the next token ids, with 0 in the final position,
    # e.g. token_ids [[1, 3, 5, 7], ...] -> forward_targets [[3, 5, 7, 0], ...]
    forward_targets = torch.zeros_like(token_ids)
    forward_targets[:, 0:-1] = token_ids[:, 1:]
    if self.bidirectional:
        # Backward targets are the previous token ids, with 0 in the first position.
        backward_targets = torch.zeros_like(token_ids)
        backward_targets[:, 1:] = token_ids[:, 0:-1]
    else:
        backward_targets = None

    contextual_embeddings_with_dropout = self._dropout(contextual_embeddings)

    # compute softmax loss
    try:
        forward_loss, backward_loss = self._compute_loss(
            contextual_embeddings_with_dropout, forward_targets, backward_targets)
    except IndexError as error:
        # Chain the original error so its traceback points at the bad index.
        raise IndexError(
            "Word token out of vocabulary boundaries, please check your vocab is correctly set"
            " or created before starting training.") from error

    # Only positive target ids are counted; 0 is used as the filler above.
    num_targets = torch.sum((forward_targets > 0).long())
    if num_targets > 0:
        if self.bidirectional:
            average_loss = 0.5 * (forward_loss + backward_loss) / num_targets.float()
        else:
            average_loss = forward_loss / num_targets.float()
    else:
        average_loss = torch.tensor(0.0).to(forward_targets.device)

    for metric in self.metrics.values():
        metric(average_loss)

    return TaskOutput(
        logits=None,
        probs=None,
        loss=average_loss,
        **{"lm_embeddings": contextual_embeddings, "mask": mask},
    )
def forward(  # type: ignore
    self, tokens: TextFieldTensors, labels: torch.IntTensor = None
) -> Dict[str, torch.Tensor]:
    """
    Multi-label classification of `tokens`.

    # Parameters

    tokens : `TextFieldTensors`
        From a `TextField`.
    labels : `torch.IntTensor`, optional (default = `None`)
        From a `MultiLabelField`. When given, the loss is computed and the
        micro/macro F1 metrics are updated.

    # Returns

    A dictionary with:

    - `logits` (`torch.FloatTensor`): shape `(batch_size, num_labels)`, the
      unnormalized scores of each label.
    - `probs` (`torch.FloatTensor`): shape `(batch_size, num_labels)`, the
      independent (sigmoid) probability of each label.
    - `token_ids`: the input token ids.
    - `loss` (`torch.FloatTensor`, only when `labels` is given): a scalar loss.
    """
    encoded = self._text_field_embedder(tokens)
    mask = get_text_field_mask(tokens)

    if self._seq2seq_encoder:
        encoded = self._seq2seq_encoder(encoded, mask=mask)
    pooled = self._seq2vec_encoder(encoded, mask=mask)

    if self._dropout:
        pooled = self._dropout(pooled)
    if self._feedforward is not None:
        pooled = self._feedforward(pooled)

    logits = self._classification_layer(pooled)
    output_dict = {
        "logits": logits,
        "probs": torch.sigmoid(logits),
        "token_ids": util.get_token_ids_from_text_field_tensors(tokens),
    }

    if labels is None:
        return output_dict

    output_dict["loss"] = self._loss(logits, labels.float().view(-1, self._num_labels))

    # TODO (John): This shouldn't be necessary as __call__ of the metrics detaches these
    # tensors anyways?
    logits_copy = logits.clone()
    labels_copy = labels.clone()
    self._micro_f1(logits_copy, labels_copy)
    self._macro_f1(logits_copy, labels_copy)
    return output_dict
def rollin_policy(self,
                  timestep: int,
                  last_predictions: torch.LongTensor,
                  target_tokens: Dict[str, torch.Tensor] = None,
                  rollin_mode = None) -> torch.LongTensor:
    """
    Choose the input tokens for the next decoding step during roll-in.

    The roll-in mode (``rollin_mode`` if given, else ``self._rollin_mode``)
    decides between the gold tokens and the model's own predictions:

      - ``'teacher_forcing'``: feed the gold token at column ``timestep``.
      - ``'learned'``: feed ``last_predictions``.
      - ``'mixed'``: during training, feed ``last_predictions`` with
        probability ``self._scheduled_sampling_ratio`` and the gold token
        otherwise; outside training, always feed the gold token.

    The model's own predictions are always used at ``timestep == 0`` (the input
    is the start token) and whenever ``target_tokens`` is ``None``.

    Arguments:
        timestep {int} -- Current 0-based decoding step; selects the target column.
        last_predictions {torch.LongTensor} -- Tokens predicted at the previous step.

    Keyword Arguments:
        target_tokens {Dict[str, torch.Tensor]} -- Gold target tensors, if available.
            (default: {None})
        rollin_mode {str} -- Overrides ``self._rollin_mode`` when truthy. (default: {None})

    Returns:
        torch.LongTensor -- The input tokens for predicting the next token.

    Raises:
        ConfigurationError -- if the resolved mode is not one of the above.
    """
    mode = rollin_mode or self._rollin_mode

    # At the first step the input is the start token; without targets (or in
    # learned mode) there is nothing to teacher-force with.
    if timestep == 0 or target_tokens is None or mode == 'learned':
        # shape: (batch_size,)
        return last_predictions

    gold = util.get_token_ids_from_text_field_tensors(target_tokens)

    if mode == 'teacher_forcing':
        # shape: (batch_size,)
        return gold[:, timestep]

    if mode == 'mixed':
        # Scheduled sampling: only while training, and only with probability
        # self._scheduled_sampling_ratio, feed back the model's own predictions.
        use_own_prediction = (self.training
                              and torch.rand(1).item() < self._scheduled_sampling_ratio)
        # shape: (batch_size,)
        return last_predictions if use_own_prediction else gold[:, timestep]

    raise ConfigurationError(f"invalid configuration for rollin policy: {mode}")
def forward(
    self,
    sentence1: TextFieldTensors,
    sentence2: TextFieldTensors,
    label: torch.IntTensor = None,
) -> Dict[str, torch.Tensor]:
    """
    Classify a pair of sentences.

    The shared text-field embedder embeds both sentences. Each is optionally
    passed through `self._seq2seq_encoder` and pooled with `self._seq2vec_encoder`.
    `self._pair_vec_to_vec` combines the two vectors, and the result is
    classified after optional dropout.

    # Parameters

    sentence1 : `TextFieldTensors`
    sentence2 : `TextFieldTensors`
    label : `torch.IntTensor`, optional (default = `None`)
        Gold labels. When given, the loss is computed and accuracy is updated.
        (Defaulting to `None` lets the model run without labels, e.g. for
        prediction; the body already handled a `None` label.)

    # Returns

    `logits`, `probs` (softmax over the last dimension), `sentence1_token_ids`,
    `sentence2_token_ids`, and `loss` when `label` is given.
    """
    embedded_sentence1 = self._text_field_embedder(sentence1)
    embedded_sentence2 = self._text_field_embedder(sentence2)
    sentence1_mask = get_text_field_mask(sentence1)
    sentence2_mask = get_text_field_mask(sentence2)

    if self._seq2seq_encoder:
        embedded_sentence1 = self._seq2seq_encoder(embedded_sentence1, mask=sentence1_mask)
        embedded_sentence2 = self._seq2seq_encoder(embedded_sentence2, mask=sentence2_mask)

    embedded_sentence1 = self._seq2vec_encoder(embedded_sentence1, mask=sentence1_mask)
    embedded_sentence2 = self._seq2vec_encoder(embedded_sentence2, mask=sentence2_mask)

    pair_vec = self._pair_vec_to_vec(embedded_sentence1, embedded_sentence2)
    if self._dropout:
        pair_vec = self._dropout(pair_vec)

    logits = self._classification_layer(pair_vec)
    probs = torch.softmax(logits, dim=-1)

    output_dict = {
        "logits": logits,
        "probs": probs,
        "sentence1_token_ids": get_token_ids_from_text_field_tensors(sentence1),
        "sentence2_token_ids": get_token_ids_from_text_field_tensors(sentence2),
    }

    if label is not None:
        loss = self._loss(logits, label.long().view(-1))
        output_dict["loss"] = loss
        self._accuracy(logits, label)

    return output_dict
def forward(self, text: TextFieldTensors) -> Dict[str, Any]:  # type: ignore
    """
    Compute the (optionally bidirectional) language-modelling loss for `text`.

    The backbone encodes `text`. After dropout, `self._compute_loss` scores the
    encodings against the next-token targets (and previous-token targets when
    `self.bidirectional`). The summed loss is averaged over the positive target
    ids. The metrics for the current mode (train/eval) are updated with a CPU
    copy of that average.

    # Returns

    A dict with `loss`, `lm_embeddings` (the backbone output, before dropout)
    and `mask`.

    # Raises

    IndexError if a token id falls outside the softmax vocabulary.
    """
    mask = get_text_field_mask(text)
    contextual_embeddings = self.backbone.forward(text, mask)
    token_ids = get_token_ids_from_text_field_tensors(text)
    assert isinstance(contextual_embeddings, torch.Tensor)

    # Forward targets are the next token ids, with 0 in the final position,
    # e.g. token_ids [[1, 3, 5, 7], ...] -> forward_targets [[3, 5, 7, 0], ...]
    forward_targets = torch.zeros_like(token_ids)
    forward_targets[:, 0:-1] = token_ids[:, 1:]
    if self.bidirectional:
        # Backward targets are the previous token ids, with 0 in the first position.
        backward_targets = torch.zeros_like(token_ids)
        backward_targets[:, 1:] = token_ids[:, 0:-1]
    else:
        backward_targets = None

    contextual_embeddings_with_dropout = self._dropout(contextual_embeddings)

    # compute softmax loss
    try:
        forward_loss, backward_loss = self._compute_loss(
            contextual_embeddings_with_dropout, forward_targets, backward_targets)
    except IndexError as error:
        # Chain the original error so its traceback points at the bad index.
        raise IndexError(
            "Word token out of vocabulary boundaries, please check your vocab is correctly set"
            " or created before starting training.") from error

    # Only positive target ids are counted; 0 is used as the filler above.
    num_targets = torch.sum((forward_targets > 0).long())
    if num_targets > 0:
        if self.bidirectional:
            average_loss = 0.5 * (forward_loss + backward_loss) / num_targets.float()
        else:
            average_loss = forward_loss / num_targets.float()
    else:
        # Keep the fallback loss on the same device as the computed losses,
        # so the returned `loss` does not change device depending on the batch.
        average_loss = torch.tensor(0.0, device=forward_targets.device)

    for metric in self._metrics.get_dict(is_train=self.training).values():
        # Perplexity needs the value to be on the cpu
        metric(average_loss.to("cpu"))

    return dict(
        loss=average_loss,
        lm_embeddings=contextual_embeddings,
        mask=mask,
    )
def forward(  # type: ignore
    self,
    tokens: TextFieldTensors,
    label: torch.IntTensor = None,
    metadata: MetadataField = None,
) -> Dict[str, torch.Tensor]:
    """
    Single-label classification of `tokens`.

    # Parameters

    tokens : `TextFieldTensors`
        From a `TextField`.
    label : `torch.IntTensor`, optional (default = `None`)
        From a `LabelField`. When given, the loss is computed and accuracy is updated.
    metadata : optional
        Accepted but not used by this method.

    # Returns

    A dictionary with:

    - `logits` (`torch.FloatTensor`): shape `(batch_size, num_labels)`, the
      unnormalized log probabilities of each label.
    - `probs` (`torch.FloatTensor`): shape `(batch_size, num_labels)`, the
      softmax probabilities of each label.
    - `token_ids`: the input token ids.
    - `loss` (`torch.FloatTensor`, only when `label` is given): a scalar loss.
    """
    encoded = self._text_field_embedder(tokens)
    mask = get_text_field_mask(tokens)

    if self._seq2seq_encoder:
        encoded = self._seq2seq_encoder(encoded, mask=mask)
    pooled = self._seq2vec_encoder(encoded, mask=mask)

    if self._dropout:
        pooled = self._dropout(pooled)
    if self._feedforward is not None:
        pooled = self._feedforward(pooled)

    logits = self._classification_layer(pooled)
    output_dict = {
        "logits": logits,
        "probs": torch.nn.functional.softmax(logits, dim=-1),
        "token_ids": util.get_token_ids_from_text_field_tensors(tokens),
    }

    if label is None:
        return output_dict

    output_dict["loss"] = self._loss(logits, label.long().view(-1))
    self._accuracy(logits, label)
    return output_dict
def forward(self, text: TextFieldTensors) -> Dict[str, torch.Tensor]:  # type: ignore
    """
    Embed `text` and expose the embedder's outputs under backbone keys.

    The entries of `self._embed(text)` are renamed as follows:
    `orig_embeddings` -> `encoded_text`, `orig_mask` -> `encoded_text_mask`,
    `wordpiece_embeddings` -> `wordpiece_encoded_text`,
    `wordpiece_mask` -> `wordpiece_encoded_text_mask`.
    `token_ids` holds the ids of `text`. The dict is then handed to
    `self._extend_with_masked_text` (presumably adding entries in place —
    confirm) before being returned.
    """
    embedded = self._embed(text)
    renames = (
        ("encoded_text", "orig_embeddings"),
        ("encoded_text_mask", "orig_mask"),
        ("wordpiece_encoded_text", "wordpiece_embeddings"),
        ("wordpiece_encoded_text_mask", "wordpiece_mask"),
    )
    outputs = {out_key: embedded[source_key] for out_key, source_key in renames}
    outputs["token_ids"] = util.get_token_ids_from_text_field_tensors(text)
    self._extend_with_masked_text(outputs, text)
    return outputs
def _compute_rollin_loss_batch(self,
                               rollin_output_dict: Dict[str, torch.Tensor],
                               state: Dict[str, torch.Tensor],
                               target_tokens: Dict[str, torch.Tensor]) -> torch.FloatTensor:
    """
    Cross-entropy loss of the roll-in logits of the first beam against the targets.

    ``state`` is accepted for interface compatibility and is not used here.
    The loss itself is computed by ``self._get_cross_entropy_loss``.
    """
    # Rank 4 per the indexing below; presumably
    # (batch_size, beam_size, num_decoding_steps, num_classes) — TODO confirm.
    logits = rollin_output_dict['logits']
    targets = util.get_token_ids_from_text_field_tensors(target_tokens)

    # Keep only the first beam. Integer indexing already removes the beam
    # dimension, so no squeeze is needed; an extra `.squeeze(1)` would also
    # drop the step dimension whenever there is a single decoding step.
    best_logits = logits[:, 0, :, :]
    target_masks = util.get_text_field_mask(target_tokens)

    # Compute loss.
    loss_batch = self._get_cross_entropy_loss(best_logits, targets, target_masks)
    return loss_batch
def compute_sentence_probs(self,
                           sequences_dict: Dict[str, torch.LongTensor],
                          ) -> torch.FloatTensor:
    """
    Score each sequence in a batch under the model using teacher forcing.

    Arguments:
        sequences_dict {Dict[str, torch.LongTensor]} -- Text-field tensors of the
            sequences to score.

    Returns:
        A tuple of three tensors (despite the single-tensor annotation):
        seq_probs -- exp of the mean per-token log-probability over the
            non-padded positions after the first token.
        seq_lens -- number of non-padded positions after the first token.
        per_step_seq_probs -- probability of each token (from the second on)
            given its prefix; padding positions are not masked out.
    """
    state = {}
    sequences = util.get_token_ids_from_text_field_tensors(sequences_dict)
    batch_size = sequences.size(0)
    seq_len = sequences.size(1)

    start_predictions = self._get_start_predictions(state, sequences_dict, batch_size)

    # Teacher forcing makes roll-in consume the given tokens, so the logits
    # are the model's predictions for exactly these sequences.
    rollin_output_dict = self.rollin(state={},
                                     start_predictions=start_predictions,
                                     rollin_steps=seq_len - 1,
                                     target_tokens=sequences_dict,
                                     rollin_mode='teacher_forcing',
                                    )

    step_log_probs = F.log_softmax(rollin_output_dict['logits'].squeeze(1), dim=-1)
    next_tokens = sequences[:, 1:].unsqueeze(2)
    token_log_probs = step_log_probs.gather(2, next_tokens).squeeze(2)

    sequence_mask = util.get_text_field_mask(sequences_dict)
    scored_mask = sequence_mask[:, 1:]
    log_prob_totals = torch.sum(token_log_probs * scored_mask, dim=-1)

    reduce_dims = tuple(range(1, len(sequence_mask.shape)))
    # shape : (batch_size,)
    scored_lengths = scored_mask.sum(dim=reduce_dims)

    mean_probs = torch.exp(log_prob_totals / scored_lengths)
    return mean_probs, scored_lengths, torch.exp(token_log_probs)
def forward(self, text: TextFieldTensors) -> Dict[str, torch.Tensor]:  # type: ignore
    """
    Encode `text` with the pretrained transformer embedder.

    `text` must contain exactly one indexer entry. Its tensors are passed to
    `self._embedder` as keyword arguments.

    Returns `encoded_text` and `encoded_text_mask`, plus `token_ids` when
    `self._output_token_strings` is set.

    Raises ValueError when `text` has more or fewer than one indexer entry.
    """
    if len(text) != 1:
        raise ValueError(
            "PretrainedTransformerBackbone is only compatible with using a single TokenIndexer"
        )
    (indexer_tensors,) = text.values()

    mask = util.get_text_field_mask(text)
    outputs = {
        "encoded_text": self._embedder(**indexer_tensors),
        "encoded_text_mask": mask,
    }
    if self._output_token_strings:
        outputs["token_ids"] = util.get_token_ids_from_text_field_tensors(text)
    return outputs
def forward(self, text: TextFieldTensors) -> TaskOutput:  # type: ignore
    """
    Compute the forward (next-token) language-modelling loss for `text`.

    The backbone encodes `text`. After dropout, `self._compute_loss` scores the
    encodings against the next-token targets. The loss is averaged over the
    positive target ids, and every metric in `self.metrics` is updated with it.

    # Returns

    A `TaskOutput` with `loss` set, no logits/probs, and the extra fields
    `lm_embeddings` (the backbone output, before dropout) and `mask`.

    # Raises

    IndexError if a token id falls outside the softmax vocabulary.
    """
    mask = get_text_field_mask(text)
    contextual_embeddings = self.backbone.forward(text, mask)

    # AllenNLP 1.0 nests the indexer outputs one level deeper;
    # `get_token_ids_from_text_field_tensors` pulls out the id tensor for us.
    token_ids = get_token_ids_from_text_field_tensors(text)
    assert isinstance(contextual_embeddings, torch.Tensor)

    # Targets are the next token ids, with 0 in the final position,
    # e.g. token_ids [[1, 3, 5, 7], ...] -> forward_targets [[3, 5, 7, 0], ...]
    forward_targets = torch.zeros_like(token_ids)
    forward_targets[:, 0:-1] = token_ids[:, 1:]

    contextual_embeddings_with_dropout = self._dropout(contextual_embeddings)

    # compute softmax loss
    try:
        forward_loss = self._compute_loss(
            contextual_embeddings_with_dropout, forward_targets)
    except IndexError as error:
        # Chain the original error so its traceback points at the bad index.
        raise IndexError(
            "Word token out of vocabulary boundaries, please check your vocab is correctly set"
            " or created before starting training.") from error

    # Only positive target ids are counted; 0 is used as the filler above.
    num_targets = torch.sum((forward_targets > 0).long())
    if num_targets > 0:
        average_loss = forward_loss / num_targets.float()
    else:
        average_loss = torch.tensor(0.0).to(forward_targets.device)

    for metric in self.metrics.values():
        metric(average_loss)

    return TaskOutput(
        logits=None,
        probs=None,
        loss=average_loss,
        **{"lm_embeddings": contextual_embeddings, "mask": mask},
    )
def forward(
    self,
    encoder_out: Dict[str, torch.LongTensor],
    target_tokens: TextFieldTensors = None,
) -> Dict[str, torch.Tensor]:
    """
    Decode from the encoder output.

    ``encoder_out`` itself serves as the decoding state and is updated in place
    with ``self._decoder_net.init_decoder_state(...)``.

    With ``target_tokens`` the loss is computed via ``self._forward_loss``;
    outside training it runs on a cloned state, leaving the original state for
    beam search. Outside training, beam-search predictions are also merged into
    the output. When targets are present, the tensor-based metric (on the best
    beam) and the token-based metric (after ``post_process``) are updated if
    configured.
    """
    state = encoder_out
    state.update(self._decoder_net.init_decoder_state(state))

    output_dict = {}
    if target_tokens:
        if self.training:
            loss_state = state
        else:
            loss_state = {key: value.clone() for key, value in state.items()}
        output_dict = self._forward_loss(loss_state, target_tokens)

    if self.training:
        return output_dict

    output_dict.update(self._forward_beam_search(state))
    if not target_tokens:
        return output_dict

    targets = util.get_token_ids_from_text_field_tensors(target_tokens)

    if self._tensor_based_metric is not None:
        # shape: (batch_size, beam_size, max_sequence_length)
        top_k_predictions = output_dict["predictions"]
        # Score only the best beam; shape: (batch_size, max_predicted_sequence_length)
        self._tensor_based_metric(top_k_predictions[:, 0, :], targets)  # type: ignore

    if self._token_based_metric is not None:
        output_dict = self.post_process(output_dict)
        self._token_based_metric(  # type: ignore
            output_dict["predicted_tokens"],
            self.indices_to_tokens(targets[:, 1:]),
        )

    return output_dict
def forward(
    self,
    transactions: TextFieldTensors,
    label: Optional[torch.Tensor] = None,
    amounts: Optional[TextFieldTensors] = None,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """
    Embed `transactions` and run the model head on the embeddings.

    Extra keyword arguments are accepted and ignored. The result is the dict
    returned by `forward_on_transaction_embeddings`, with `token_ids` (the ids
    of `transactions`) added.
    """
    embedding_output = self.get_transaction_embeddings(transactions)
    head_kwargs = {
        "transaction_embeddings": embedding_output["transaction_embeddings"],
        "mask": embedding_output["mask"],
        "label": label,
        "amounts": amounts,
    }
    result = self.forward_on_transaction_embeddings(**head_kwargs)
    result["token_ids"] = util.get_token_ids_from_text_field_tensors(transactions)
    return result
def _get_start_predictions(self,
                           state: Dict[str, torch.Tensor],
                           target_tokens: Dict[str, torch.LongTensor] = None,
                           generation_batch_size:int = None) -> torch.LongTensor:
    """
    Build the first decoder inputs: one start token per batch element.

    The batch size is taken from ``state["source_mask"]`` in seq2seq mode,
    otherwise from ``target_tokens`` when truthy, otherwise from
    ``generation_batch_size``.

    Returns:
        A LongTensor of shape (batch_size,) filled with ``self._start_index``
        on ``self.current_device``.
    """
    if self._seq2seq_mode:
        batch_size = state["source_mask"].size()[0]
    elif target_tokens:
        batch_size = util.get_token_ids_from_text_field_tensors(target_tokens).size(0)
    else:
        batch_size = generation_batch_size

    # shape: (batch_size,)
    return torch.full((batch_size,), self._start_index,
                      dtype=torch.long, device=self.current_device)
def forward(  # type: ignore
    self,
    tokens: TextFieldTensors,
    mask_positions: torch.BoolTensor,
    target_ids: TextFieldTensors = None,
) -> Dict[str, torch.Tensor]:
    """
    Predict the tokens at the masked positions of `tokens`.

    # Parameters

    tokens : `TextFieldTensors`
        The output of `TextField.as_tensor()` for a batch of sentences.
    mask_positions : `torch.LongTensor`
        The positions in `tokens` that correspond to [MASK] tokens that we should try to fill
        in. Shape should be (batch_size, num_masks); a trailing singleton dimension is
        squeezed away. They are used as integer indices below, despite the `BoolTensor`
        annotation.
    target_ids : `TextFieldTensors`
        This is a list of token ids that correspond to the mask positions we're trying to fill.
        It is the output of a `TextField`, purely for convenience, so we can handle wordpiece
        tokenizers and such without having to do crazy things in the dataset reader. We assume
        that there is exactly one entry in the dictionary, and that it has a shape identical to
        `mask_positions` - one target token per mask position.

    # Returns

    `probabilities` and `top_indices` (top-k per mask, k = min(vocab_size, 5)),
    `token_ids`, and `loss` when `target_ids` is given.

    # Raises

    ValueError if the targets and `mask_positions` have different shapes.
    """
    targets = None
    if target_ids is not None:
        targets = util.get_token_ids_from_text_field_tensors(target_ids)
    mask_positions = mask_positions.squeeze(-1)
    batch_size, num_masks = mask_positions.size()
    if targets is not None and targets.size() != mask_positions.size():
        raise ValueError(
            f"Number of targets ({targets.size()}) and number of masks "
            f"({mask_positions.size()}) are not equal"
        )

    # Shape: (batch_size, num_tokens, embedding_dim)
    embeddings = self._text_field_embedder(tokens)

    # Shape: (batch_size, num_tokens, encoding_dim)
    if self._contextualizer:
        # `get_text_field_mask` works on the token tensors, not on embeddings.
        mask = util.get_text_field_mask(tokens)
        contextual_embeddings = self._contextualizer(embeddings, mask)
    else:
        contextual_embeddings = embeddings

    # Advanced indexing: row b of `batch_index` pairs with row b of
    # `mask_positions`, selecting just the embeddings we are trying to predict.
    # The index is built on the positions' device to avoid a cross-device index.
    # Shape: (batch_size, num_masks, encoding_dim)
    batch_index = torch.arange(0, batch_size, device=mask_positions.device).long().unsqueeze(1)
    mask_embeddings = contextual_embeddings[batch_index, mask_positions]

    target_logits = self._language_model_head(self._dropout(mask_embeddings))

    vocab_size = target_logits.size(-1)
    probs = torch.nn.functional.softmax(target_logits, dim=-1)
    k = min(vocab_size, 5)  # min here largely because tests use small vocab
    top_probs, top_indices = probs.topk(k=k, dim=-1)

    output_dict = {"probabilities": top_probs, "top_indices": top_indices}
    output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(tokens)

    if targets is not None:
        target_logits = target_logits.view(batch_size * num_masks, vocab_size)
        targets = targets.view(batch_size * num_masks)
        loss = torch.nn.functional.cross_entropy(target_logits, targets)
        self._perplexity(loss)
        output_dict["loss"] = loss

    return output_dict
def rollout(self,
            state: Dict[str, torch.Tensor],
            start_predictions: torch.LongTensor,
            rollout_steps: int,
            beam_size: int = None,
            per_node_beam_size: int = None,
            target_tokens: Dict[str, torch.LongTensor] = None,
            sampled: bool = True,
            truncate_at_end_all: bool = True,
            # shape (prediction_prefixes): (batch_size, prefix_length)
            prediction_prefixes: torch.LongTensor = None,
            target_prefixes: torch.LongTensor = None,
            rollout_mixing_func: RolloutMixingProbFuncType = None,
            reference_policy_type:str = "copy",
            rollout_mode: str = None,
           ):
    """
    Roll out the decoder from ``start_predictions`` with beam search.

    Each step is produced by ``self.take_step`` driven by ``self.rollout_policy``.
    The rollout policy receives a reference policy:
    ``self.oracle_reference_policy`` when ``reference_policy_type == 'oracle'``,
    otherwise ``self.copy_reference_policy`` bound to ``target_tokens``.

    ``state`` is modified in place: ``state['rollout_params']`` is reset, and
    in oracle mode it records ``prediction_prefixes`` as ``rollout_prefixes``.

    Returns a dict with:
        predictions -- ``prediction_prefixes`` (if given), then the start token,
            then the beam-search output; shape (batch_size, beam_size, length).
        prediction_masks -- ``self._get_mask`` of the predictions.
        logits -- per-step logits concatenated along dim 2.
        class_log_probabilities -- beam log-probabilities.
        targets / target_masks -- the gold ids (with ``target_prefixes``
            prepended if given) and their mask; ``None`` without ``target_tokens``.
    """
    state['rollout_params'] = {}
    if reference_policy_type == 'oracle':
        reference_policy = partial(self.oracle_reference_policy,
                                   token_to_idx=self._vocab._token_to_index['target_tokens'],
                                   idx_to_token=self._vocab._index_to_token['target_tokens'],
                                  )
        state['rollout_params']['rollout_prefixes'] = prediction_prefixes
    else:
        reference_policy = partial(self.copy_reference_policy,
                                   target_tokens=target_tokens)
    # Both reference policies roll out for the same number of steps.
    num_steps_to_take = rollout_steps

    rollout_policy = partial(self.rollout_policy,
                             rollout_mode=rollout_mode,
                             rollout_mixing_func=rollout_mixing_func,
                             reference_policy=reference_policy,
                            )

    rolling_policy = partial(self.take_step,
                             rollout_policy=rollout_policy)

    # shape (step_predictions): (batch_size, beam_size, num_decoding_steps)
    # shape (log_probabilities): (batch_size, beam_size)
    # shape (logits): (batch_size, beam_size, num_decoding_steps, num_classes)
    step_predictions, log_probabilities, logits = \
        self._beam_search.search(start_predictions,
                                 state,
                                 rolling_policy,
                                 max_steps=num_steps_to_take,
                                 beam_size=beam_size,
                                 per_node_beam_size=per_node_beam_size,
                                 sampled=sampled,
                                 truncate_at_end_all=truncate_at_end_all)

    logits = torch.cat(logits, dim=2)

    # Prepend the start tokens, which the beam search does not include in
    # its predictions.
    batch_size, beam_size, _ = step_predictions.shape
    start_column = start_predictions.unsqueeze(1) \
                                    .expand(batch_size, beam_size) \
                                    .reshape(batch_size, beam_size, 1)
    step_predictions = torch.cat([start_column, step_predictions], dim=-1)

    # Predictions made earlier (e.g. by the roll-in policy) may be passed in;
    # if so, prepend them as well.
    if prediction_prefixes is not None:
        prefixes_length = prediction_prefixes.size(1)
        prefix_block = prediction_prefixes.unsqueeze(1) \
                                          .expand(batch_size, beam_size, prefixes_length)
        step_predictions = torch.cat([prefix_block, step_predictions], dim=-1)

    step_prediction_masks = self._get_mask(step_predictions
                                           .reshape(batch_size * beam_size, -1)) \
                                .reshape(batch_size, beam_size, -1)

    output_dict = {
        "predictions": step_predictions,
        "prediction_masks": step_prediction_masks,
        "logits": logits,
        "class_log_probabilities": log_probabilities,
    }

    step_targets = None
    step_target_masks = None
    if target_tokens is not None:
        step_targets = util.get_token_ids_from_text_field_tensors(target_tokens)
        if target_prefixes is not None:
            step_targets = torch.cat([target_prefixes, step_targets], dim=-1)
        step_target_masks = util.get_text_field_mask({'tokens': {'tokens': step_targets}})

    output_dict.update({
        "targets": step_targets,
        "target_masks": step_target_masks,
    })
    return output_dict
def forward(  # type: ignore
    self,
    question_with_context: Dict[str, Dict[str, torch.LongTensor]],
    context_span: torch.IntTensor,
    answer_span: Optional[torch.IntTensor] = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Parameters
    ----------
    question_with_context : Dict[str, torch.LongTensor]
        From a ``TextField``. The text field holds both the question and the context.
        NOTE(review): the offset arithmetic in the metadata branch assumes that
        ``1 + len(question_tokens) + 2`` word pieces precede the context — confirm
        against the dataset reader.
    context_span : ``torch.IntTensor``
        From a ``SpanField``. This marks the (inclusive) span of word pieces in
        ``question_with_context`` from which answers can come; logits outside it are
        masked out.
    answer_span : ``torch.IntTensor``, optional
        From a ``SpanField``. This is the thing we are trying to predict - the span of text
        that marks the answer. Rows whose start is -1 are ignored. If given, we compute a
        loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question id, and the original texts of context,
        question, tokenized version of both, and a list of possible answers. The length of
        the ``metadata`` list should be the batch size, and each dictionary should have the
        keys ``id``, ``question``, ``context``, ``question_tokens``, ``context_tokens``,
        and (optionally) ``answers``.

    Returns
    -------
    An output dictionary consisting of:
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position.
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)``
        and each offset is a token index.
    best_span_scores : torch.FloatTensor
        The score for each of the best spans.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[str]
        If sufficient metadata was provided for the instances in the batch, we also return
        the string from the original passage that the model thinks is the best answer to
        the question.
    context_tokens : list
        With metadata, the ``context_tokens`` entry of each instance.
    """
    embedded_question = self._text_field_embedder(question_with_context)
    logits = self._linear_layer(embedded_question)
    span_start_logits, span_end_logits = logits.split(1, dim=-1)
    span_start_logits = span_start_logits.squeeze(-1)
    span_end_logits = span_end_logits.squeeze(-1)

    # Only word pieces inside the context span may start or end an answer.
    # Build a boolean mask (an integer mask cannot be inverted logically with
    # `~`, and `masked_fill` requires a bool mask) and apply it directly.
    possible_answer_mask = torch.zeros_like(
        get_token_ids_from_text_field_tensors(question_with_context), dtype=torch.bool
    )
    for i, (start, end) in enumerate(context_span):
        possible_answer_mask[i, start:end + 1] = True

    span_start_logits = span_start_logits.masked_fill(~possible_answer_mask, -1e32)
    span_end_logits = span_end_logits.masked_fill(~possible_answer_mask, -1e32)
    span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
    span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
    best_spans = get_best_span(span_start_logits, span_end_logits)
    best_span_scores = torch.gather(
        span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) + torch.gather(
        span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
    best_span_scores = best_span_scores.squeeze(1)

    output_dict = {
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_spans,
        "best_span_scores": best_span_scores,
    }

    # Compute the loss for training.
    if answer_span is not None:
        span_start = answer_span[:, 0]
        span_end = answer_span[:, 1]
        # Instances whose answer start is -1 are excluded from the metrics
        # and (via ignore_index) from the loss.
        span_mask = span_start != -1
        self._span_accuracy(best_spans, answer_span,
                            span_mask.unsqueeze(-1).expand_as(best_spans))

        start_loss = cross_entropy(span_start_logits, span_start, ignore_index=-1)
        if torch.any(start_loss > 1e9):
            logger.critical("Start loss too high (%r)", start_loss)
            logger.critical("span_start_logits: %r", span_start_logits)
            logger.critical("span_start: %r", span_start)
            assert False

        end_loss = cross_entropy(span_end_logits, span_end, ignore_index=-1)
        if torch.any(end_loss > 1e9):
            logger.critical("End loss too high (%r)", end_loss)
            logger.critical("span_end_logits: %r", span_end_logits)
            logger.critical("span_end: %r", span_end)
            assert False

        loss = (start_loss + end_loss) / 2

        self._span_start_accuracy(span_start_logits, span_start, span_mask)
        self._span_end_accuracy(span_end_logits, span_end, span_mask)

        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        best_spans = best_spans.detach().cpu().numpy()

        output_dict["best_span_str"] = []
        context_tokens = []
        for metadata_entry, best_span in zip(metadata, best_spans):
            context_tokens_for_question = metadata_entry["context_tokens"]
            context_tokens.append(context_tokens_for_question)

            # Shift word-piece indices of the combined input to indices into
            # the context tokens (see the NOTE in the docstring).
            best_span -= 1 + len(metadata_entry["question_tokens"]) + 2
            assert np.all(best_span >= 0)

            predicted_start, predicted_end = tuple(best_span)

            # Walk left to the nearest token with a character offset (`idx`).
            while (predicted_start >= 0
                   and context_tokens_for_question[predicted_start].idx is None):
                predicted_start -= 1
            if predicted_start < 0:
                logger.warning(
                    f"Could not map the token '{context_tokens_for_question[best_span[0]].text}' at index "
                    f"'{best_span[0]}' to an offset in the original text.")
                character_start = 0
            else:
                character_start = context_tokens_for_question[predicted_start].idx

            # Walk right to the nearest token with a character offset.
            while (predicted_end < len(context_tokens_for_question)
                   and context_tokens_for_question[predicted_end].idx is None):
                predicted_end += 1
            if predicted_end >= len(context_tokens_for_question):
                logger.warning(
                    f"Could not map the token '{context_tokens_for_question[best_span[1]].text}' at index "
                    f"'{best_span[1]}' to an offset in the original text.")
                character_end = len(metadata_entry["context"])
            else:
                end_token = context_tokens_for_question[predicted_end]
                character_end = end_token.idx + len(sanitize_wordpiece(end_token.text))

            best_span_string = metadata_entry["context"][character_start:character_end]
            output_dict["best_span_str"].append(best_span_string)

            # `answers` may be missing (`.get` returns None); only score when
            # there is at least one gold answer.
            answers = metadata_entry.get("answers")
            if answers:
                self._per_instance_metrics(best_span_string, answers)
        output_dict["context_tokens"] = context_tokens
    return output_dict
def forward(self,  # type: ignore
            encoder_out: Dict[str, torch.LongTensor] = None,
            target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Make a forward pass with decoder logic for producing the entire target sequence.

    When ``target_tokens`` are given, the decoder is rolled in over the targets, the loss
    criterion outputs are added to the result and the perplexity metric is updated with the
    MLE loss. When not training, the learned policy is additionally rolled out; its
    predictions feed the secondary metrics (BLEU, Hamming, ...) when targets are available.

    Parameters
    ----------
    encoder_out : ``Dict[str, torch.LongTensor]``, optional (default = None)
        Encoder outputs used to initialise the decoder state. The dict is shallow-copied
        into the state, so the caller's dict is not modified. ``None`` is treated as an
        empty dict.
    target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
        Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
        target tokens are also represented as a `TextField`. When given, decoding runs for
        ``target_sequence_length - 1`` steps; otherwise for ``self._max_decoding_steps``.

    Returns
    -------
    Dict[str, torch.Tensor]
        The roll-in / loss outputs (when targets are given) and, when not training, the
        rollout outputs together with ``"decoded_predictions"``.
    """
    if encoder_out is None:
        encoder_out = {}

    output_dict: Dict[str, torch.Tensor] = {}
    state: Dict[str, torch.Tensor] = {}
    decoder_init_state: Dict[str, torch.Tensor] = {}
    state.update(copy.copy(encoder_out))

    # In Seq2Seq setting, we will encode the source sequence,
    # and init the state object with encoder output and decoder
    # cell will use these encoder outputs for attention/initing
    # the decoder states.
    if self._seq2seq_mode:
        decoder_init_state = self._decoder_net.init_decoder_state(state)
        state.update(decoder_init_state)

    # Initialize target predictions with the start index.
    # shape: (batch_size,)
    start_predictions: torch.LongTensor = \
        self._get_start_predictions(state, target_tokens, self._generation_batch_size)

    # Roll-in outputs and gold targets; all stay None when no targets are given.
    decoder_output_dict = None
    targets = None
    decoded_targets = None

    if target_tokens:
        # shape: (batch_size, max_target_sequence_length)
        targets = util.get_token_ids_from_text_field_tensors(target_tokens)
        _, target_sequence_length = targets.size()

        # With target_tokens we roll-in and roll-out only for as many steps as the target has.
        # The last input from the target is either padding or the end symbol.
        # Either way, we don't have to process it.
        num_decoding_steps = target_sequence_length - 1

        decoder_output_dict, rollin_dict, rollout_dict_iter = self._forward_loop(
            state=state,
            start_predictions=start_predictions,
            num_decoding_steps=num_decoding_steps,
            target_tokens=target_tokens)
        output_dict.update(decoder_output_dict)

        output_dict["decoded_predictions"] = self._decode_tokens(
            decoder_output_dict['predictions'],
            vocab_namespace=self._target_namespace,
            truncate=True)
        decoded_targets = self._decode_tokens(targets,
                                              vocab_namespace=self._target_namespace,
                                              truncate=True)
        output_dict["decoded_targets"] = decoded_targets

        output_dict.update(self._loss_criterion(
            rollin_output_dict=rollin_dict,
            rollout_output_dict_iter=rollout_dict_iter,
            state=state,
            target_tokens=target_tokens))

        mle_loss_output = self._mle_loss(
            rollin_output_dict=rollin_dict,
            rollout_output_dict_iter=rollout_dict_iter,
            state=state,
            target_tokens=target_tokens)
        self._perplexity(mle_loss_output['loss'])
    else:
        # Without targets we roll out for `self._max_decoding_steps`.
        num_decoding_steps = self._max_decoding_steps

    if not self.training:
        # While validating or testing we need to roll out the learned policy and the output
        # of this rollout is used to compute the secondary metrics like BLEU. The state is
        # rebuilt from scratch because the roll-in above received (and may have changed) it.
        state = {}
        state.update(copy.copy(encoder_out))
        state.update(decoder_init_state)
        rollout_output_dict = self.rollout(state,
                                           start_predictions,
                                           rollout_steps=num_decoding_steps,
                                           rollout_mode='learned',
                                           sampled=self._sample_rollouts,
                                           beam_size=self._eval_beam_size,
                                           # TODO #6 (Kushal): Add a reason why truncate_at_end_all is False here.
                                           truncate_at_end_all=False)
        output_dict.update(rollout_output_dict)

        # With targets, decode the roll-in predictions (as before). Without targets there is
        # no roll-in output (previously a NameError here), so decode the rollout instead.
        if decoder_output_dict is not None:
            predictions_to_decode = decoder_output_dict['predictions']
        else:
            predictions_to_decode = rollout_output_dict['predictions']
        output_dict["decoded_predictions"] = self._decode_tokens(
            predictions_to_decode,
            vocab_namespace=self._target_namespace,
            truncate=True)
        # First decoded entry per instance (presumably the top beam).
        decoded_predictions = [beams[0] for beams in output_dict["decoded_predictions"]]

        # shape (predictions): (batch_size, beam_size, num_decoding_steps)
        predictions = rollout_output_dict['predictions']

        # shape (best_predictions): (batch_size, num_decoding_steps)
        best_predictions = predictions[:, 0, :]

        # Every metric below compares against gold targets, so they only run when targets exist.
        if target_tokens:
            target_mask = util.get_text_field_mask(target_tokens)

            # TODO #3 (Kushal): Maybe abstract out these losses and use loss_metric like AllenNLP uses.
            if self._bleu:
                self._bleu(best_predictions, targets)

            if self._hamming:
                self._hamming(best_predictions, targets, target_mask)

            if self._tensor_based_metric is not None:
                self._tensor_based_metric(  # type: ignore
                    predictions=best_predictions,
                    gold_targets=targets,
                )

            if self._tensor_based_metric_mask is not None:
                self._tensor_based_metric_mask(  # type: ignore
                    predictions=best_predictions,
                    gold_targets=targets,
                    mask=~target_mask,
                )

            if self._token_based_metric is not None:
                self._token_based_metric(  # type: ignore
                    predictions=decoded_predictions,
                    gold_targets=decoded_targets,
                )

    return output_dict
def rollin_parallel(self,
                    state: Dict[str, torch.Tensor],
                    start_predictions: torch.LongTensor,
                    rollin_steps: int,
                    target_tokens: Dict[str, torch.LongTensor] = None,
                    beam_size: int = 1,
                    per_node_beam_size: int = None,
                    sampled: bool = False,
                    truncate_at_end_all: bool = False,
                    rollin_mode: str = None,
                    ):
    """
    Teacher-forced roll-in that runs the decoder over the whole target in one call.

    Only valid for decoders whose ``decodes_parallel`` flag is set, with no scheduled
    sampling and with ``rollin_mode`` ``None`` or ``"learned"``. Each call increments
    ``self.training_iteration``.

    Parameters
    ----------
    state : ``Dict[str, torch.Tensor]``
        Decoder state; ``"encoder_outputs"`` and ``"source_mask"`` are read from it and the
        whole dict is passed to the decoder as ``previous_state``.
    target_tokens : ``Dict[str, torch.LongTensor]``
        Gold target ``TextField`` tensors. Despite the ``None`` default this is required:
        the gold embeddings are the decoder inputs (teacher forcing).
    rollin_mode : ``str``, optional (default = None)
        Must be ``None`` or ``"learned"``.
    start_predictions, rollin_steps, beam_size, per_node_beam_size, sampled, truncate_at_end_all
        Not used by this method (presumably kept for signature compatibility with the
        sequential roll-in).

    Returns
    -------
    Dict[str, torch.Tensor]
        ``"logits"`` with an extra beam dimension of size 1 inserted at dim 1,
        ``"predictions"`` (argmax over classes) and ``"class_log_probabilities"``. NOTE:
        the latter is the max over classes of the *raw* logits, not a log-probability.

    Raises
    ------
    AssertionError
        If the decoder / rollin mode / scheduled sampling preconditions do not hold.
    ValueError
        If ``target_tokens`` is ``None``.
    """
    assert self._decoder_net.decodes_parallel, \
        "Rollin Parallel is only applicable for transformer style decoders " + \
        "that decode whole sequence in parallel."

    assert not rollin_mode or rollin_mode == "learned", \
        "Parallel Decoding only works when following " + \
        "teacher forcing rollin policy (rollin_mode='learned')."

    assert self._scheduled_sampling_ratio == 0, \
        "For learned rollin mode, scheduled sampling ratio should always be 0."

    if target_tokens is None:
        raise ValueError("rollin_parallel requires `target_tokens`: the gold targets are "
                         "fed to the decoder (teacher forcing).")

    self.training_iteration += 1

    # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
    encoder_outputs = state["encoder_outputs"]

    # shape: (batch_size, max_input_sequence_length)
    source_mask = state["source_mask"]

    # shape: (batch_size, max_target_sequence_length)
    targets = util.get_token_ids_from_text_field_tensors(target_tokens)

    # Prepare embeddings for targets. They will be used as gold embeddings during decoder training
    # shape: (batch_size, max_target_sequence_length, embedding_dim)
    target_embedding = self.target_embedder(targets)

    # shape: (batch_size, max_target_batch_sequence_length)
    target_mask = util.get_text_field_mask(target_tokens)

    # The last target position is dropped from the inputs: it is either padding or the end
    # symbol, so there is nothing to predict after it.
    _, decoder_output = self._decoder_net(
        previous_state=state,
        previous_steps_predictions=target_embedding[:, :-1, :],
        encoder_outputs=encoder_outputs,
        source_mask=source_mask,
        previous_steps_mask=target_mask[:, :-1],
    )

    # shape: (group_size, max_target_sequence_length, num_classes)
    logits = self._output_projection_layer(decoder_output)

    # Unsqueeze logit to add beam size dimension.
    # shape: (group_size, 1, max_target_sequence_length, num_classes)
    logits = logits.unsqueeze(dim=1)

    # Max over the class dimension: values are raw logits, indices are the predicted classes.
    log_probabilities, step_predictions = torch.max(logits, dim=-1)

    return {
        "predictions": step_predictions,
        "logits": logits,
        "class_log_probabilities": log_probabilities,
    }
def forward(  # type: ignore
    self,
    tokens: TextFieldTensors,
    verb_indicator: torch.Tensor,
    frame_indicator: torch.Tensor,
    metadata: List[Any],
    tags: torch.LongTensor = None,
    frame_tags: torch.LongTensor = None,
):
    """
    # Parameters

    tokens : `TextFieldTensors`, required
        The output of `TextField.as_array()`, which should typically be passed directly to a
        `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
        indexes wordpieces from the BERT vocabulary.
    verb_indicator : `torch.LongTensor`, required
        An integer `SequenceFeatureField` representation of the position of the verb
        in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
        all zeros, in the case that the sentence has no verbal predicate. It is passed to the
        transformer as `token_type_ids`.
    frame_indicator : `torch.LongTensor`, required
        An integer `SequenceFeatureField` representation of the position of the frame in the
        sentence, of shape (batch_size, num_tokens). Similar to `verb_indicator`, but handles
        the BERT wordpiece tokenizer by considering only the first subtoken of a frame.
        Positions equal to 1 select the embeddings that are classified into frames.
    metadata : `List[Dict[str, Any]]`, required
        One dict per instance containing the original words in the sentence, the verb to
        compute the frame for, and start offsets for converting wordpieces back to a sequence
        of words, under 'words', 'verb' and 'offsets' keys, respectively. 'lemmas' is also
        read, and 'verb_index' and 'gold_tags' when span metrics are computed.
    tags : `torch.LongTensor`, optional (default = `None`)
        A torch tensor representing the sequence of integer gold class labels
        of shape `(batch_size, num_tokens)`.
    frame_tags : `torch.LongTensor`, optional (default = `None`)
        A torch tensor representing the gold frames of shape `(batch_size, num_tokens)`.
        Required whenever `tags` is given.

    # Returns

    An output dictionary consisting of:
    logits : `torch.FloatTensor`
        A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
        unnormalised log probabilities of the tag classes.
    class_probabilities : `torch.FloatTensor`
        A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
        a distribution of the tag classes per word.
    frame_logits, frame_probabilities : `torch.FloatTensor`
        One row per position marked with 1 in `frame_indicator`, across the whole batch.
    mask, words, lemma, verb, wordpiece_offsets
        Kept for post-processing in `make_output_human_readable`.
    loss : `torch.FloatTensor`, optional
        Mean of the role loss and the frame loss (also returned as `role_loss` and
        `frame_loss`). Only present when `tags` is given.

    # Raises

    ValueError
        If `tags` is given without `frame_tags`.
    """
    mask = get_text_field_mask(tokens)
    input_ids = util.get_token_ids_from_text_field_tensors(tokens)
    bert_embeddings, _ = self.transformer(
        input_ids=input_ids,
        token_type_ids=verb_indicator,
        attention_mask=mask,
        return_dict=False,
    )
    # extract embeddings
    embedded_text_input = self.embedding_dropout(bert_embeddings)
    # Boolean-mask indexing flattens over the batch: one row per marked frame position.
    frame_embeddings = embedded_text_input[frame_indicator == 1]
    # get sizes
    batch_size, sequence_length, _ = embedded_text_input.size()
    # outputs
    logits = self.tag_projection_layer(embedded_text_input)
    frame_logits = self.frame_projection_layer(frame_embeddings)

    reshaped_log_probs = logits.view(-1, self.num_classes)
    class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
        [batch_size, sequence_length, self.num_classes])
    frame_probabilities = F.softmax(frame_logits, dim=-1)
    # We need to retain the mask in the output dictionary
    # so that we can crop the sequences to remove padding
    # when we do viterbi inference in self.make_output_human_readable.
    output_dict = {
        "logits": logits,
        "frame_logits": frame_logits,
        "class_probabilities": class_probabilities,
        "frame_probabilities": frame_probabilities,
        "mask": mask,
    }
    # We add in the offsets here so we can compute the un-wordpieced tags.
    words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"]) for x in metadata])
    # Lemmas are flattened over the whole batch (not grouped per instance), matching the
    # batch-flattened row order of `frame_logits`.
    lemmas = [lemma for x in metadata for lemma in x["lemmas"]]
    output_dict["words"] = list(words)
    output_dict["lemma"] = list(lemmas)
    output_dict["verb"] = list(verbs)
    output_dict["wordpiece_offsets"] = list(offsets)

    if tags is not None:
        if frame_tags is None:
            raise ValueError(
                "`frame_tags` must be provided together with `tags` to compute the frame loss."
            )
        # compute role loss
        role_loss = sequence_cross_entropy_with_logits(
            logits, tags, mask, label_smoothing=self._label_smoothing)
        # compute frame loss
        frame_tags_filtered = frame_tags[frame_indicator == 1]
        frame_loss = self.frame_criterion(frame_logits, frame_tags_filtered)

        if not self.ignore_span_metric and self.span_metric is not None and not self.training:
            batch_verb_indices = [
                example_metadata["verb_index"] for example_metadata in metadata
            ]
            batch_sentences = [
                example_metadata["words"] for example_metadata in metadata
            ]
            # Get the BIO tags from make_output_human_readable()
            batch_bio_predicted_tags = self.make_output_human_readable(
                output_dict).pop("tags")
            from allennlp_models.structured_prediction.models.srl import (
                convert_bio_tags_to_conll_format,
            )

            batch_conll_predicted_tags = [
                convert_bio_tags_to_conll_format(bio_tags)
                for bio_tags in batch_bio_predicted_tags
            ]
            batch_bio_gold_tags = [
                example_metadata["gold_tags"] for example_metadata in metadata
            ]
            batch_conll_gold_tags = [
                convert_bio_tags_to_conll_format(bio_tags)
                for bio_tags in batch_bio_gold_tags
            ]
            self.span_metric(
                batch_verb_indices,
                batch_sentences,
                batch_conll_predicted_tags,
                batch_conll_gold_tags,
            )
        self.f1_frame_metric(frame_logits, frame_tags_filtered)
        output_dict["frame_loss"] = frame_loss
        output_dict["role_loss"] = role_loss
        output_dict["loss"] = (role_loss + frame_loss) / 2
    return output_dict
def forward(  # type: ignore
        self,
        tokens: TextFieldTensors = None,
        label: torch.IntTensor = None,
        **metadata) -> Dict[str, torch.Tensor]:
    """
    Classify (or, with a single label, regress on) a batch of texts.

    # Parameters

    tokens : `TextFieldTensors`, optional (default = `None`)
        From a `TextField`. When omitted, the tensors are popped from `metadata["sentence"]`.
    label : `torch.IntTensor`, optional (default = `None`)
        From a `LabelField`. When given, the loss is computed and the metrics are updated.
    **metadata
        Extra batch fields; `"idx"` and `"pair_id"` are copied into the output
        (a list of `None`, one per instance, when missing).

    # Returns

    An output dictionary consisting of:

    - `logits` (`torch.FloatTensor`) : shape `(batch_size, num_labels)`, unnormalized
      log probabilities of the label.
    - `probs` (`torch.FloatTensor`) : shape `(batch_size, num_labels)`, probabilities of
      the label; only when there is more than one label.
    - `idx`, `pair_id`, `token_ids` : passed-through identifiers and input token ids.
    - `loss` (`torch.FloatTensor`, optional) : A scalar loss to be optimised.
    """
    if tokens is None:
        tokens = metadata.pop("sentence")

    embedded_tokens = self._text_field_embedder(tokens)
    encoded = self._seq2vec_encoder(embedded_tokens, mask=get_text_field_mask(tokens))

    if self._dropout:
        encoded = self._dropout(encoded)
    if self._feedforward is not None:
        encoded = self._feedforward(encoded)

    logits = self._classification_layer(encoded)
    output_dict = {"logits": logits}

    is_classification = self._num_labels > 1
    if is_classification:
        output_dict["probs"] = torch.nn.functional.softmax(logits, dim=-1)

    batch_size = len(logits)
    output_dict.update(
        {key: metadata.get(key, [None] * batch_size) for key in ("idx", "pair_id")}
    )
    output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(tokens)

    if label is None:
        return output_dict

    if is_classification:
        output_dict["loss"] = self._loss(logits, label.long().view(-1))
        assert self._accuracy is not None
        self._accuracy(logits, label)
        # Shape: (batch_size,)
        predictions = logits.argmax(axis=-1)
    else:
        # Shape: (batch_size,)
        predictions = logits.squeeze(-1)
        output_dict["loss"] = self._loss(predictions, label)

    for metric in self._metrics:
        metric(predictions, label)
    return output_dict
def forward(  # type: ignore
    self, tokens: TextFieldTensors, target_ids: TextFieldTensors = None
) -> Dict[str, torch.Tensor]:
    """
    Score the next token of every input sequence.

    Returns a dictionary with:

    - `"token_ids"`: `(batch_size, num_input_tokens)` IDs of the input tokens.
    - `"top_indices"`: `(batch_size, n_best, num_predicted_tokens)` IDs of the predicted
      tokens; `num_predicted_tokens` is 1 unless beam search is used, in which case it
      depends on the beam search parameters.
    - `"probabilities"`: `(batch_size, n_best)` probabilities of those predictions, where
      `n_best` is `min(vocab_size, self._n_best)` without beam search, or the beam size
      with it.
    - `"loss"` (only when `target_ids` is given): cross-entropy of the batch; the
      perplexity metric is updated with it.
    """
    output_dict = {"token_ids": util.get_token_ids_from_text_field_tensors(tokens)}

    # Shape: (batch_size, vocab_size)
    target_logits = self._next_token_scores(tokens)

    if target_ids is not None:
        batch_size, _ = target_logits.size()
        candidate_ids = util.get_token_ids_from_text_field_tensors(target_ids)
        # target_ids may hold a ranked (e.g. top-k by probability) list per instance; only
        # the first, most desirable id is used as the target.
        if candidate_ids.dim() == 2:
            candidate_ids = candidate_ids[:, 0]
        assert candidate_ids.dim() <= 2
        loss = torch.nn.functional.cross_entropy(target_logits, candidate_ids.view(batch_size))
        self._perplexity(loss)
        output_dict["loss"] = loss

    if self._beam_search_generator is None:
        # Shape: (batch_size, vocab_size)
        probs = torch.nn.functional.softmax(target_logits, dim=-1)
        # min here largely because tests use small vocab
        n_best = min(target_logits.size(-1), self._n_best)
        # Shape (both): (batch_size, n_best)
        top_probs, top_indices = probs.topk(k=n_best, dim=-1)
        # Shape: (batch_size, n_best, 1)
        top_indices = top_indices.unsqueeze(-1)
    else:
        # Dummy start predictions, shape: (batch_size,)
        start_predictions = torch.zeros(
            target_logits.size(0), device=target_logits.device, dtype=torch.int
        )
        step_state = self._beam_search_generator.get_step_state(tokens)
        # Reuse the already computed logits for the first beam search step.
        step_state["start_target_logits"] = target_logits
        # Shapes: top_indices (batch_size, beam_size, num_predicted_tokens),
        #         top_log_probs (batch_size, beam_size)
        top_indices, top_log_probs = self._beam_search_generator.search(
            start_predictions, step_state, self._beam_search_step
        )
        top_probs = top_log_probs.exp()

    output_dict["top_indices"] = top_indices
    output_dict["probabilities"] = top_probs
    return output_dict
def forward(  # type: ignore
    self,
    question_with_context: Dict[str, Dict[str, torch.LongTensor]],
    context_span: torch.IntTensor,
    cls_index: torch.LongTensor = None,
    answer_span: torch.IntTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Predict the answer span for each (question, context) pair in the batch.

    # Parameters

    question_with_context : `Dict[str, torch.LongTensor]`
        From a `TextField` holding both the question and the context. The tokens are
        expected to carry type ids such that any token that can be part of the answer
        (i.e., tokens from the context) has type id 0, and any other token (including
        `[CLS]` and `[SEP]`) has type id 1.
    context_span : `torch.IntTensor`
        From a `SpanField`. Marks the (inclusive) span of word pieces in
        `question_with_context` from which answers can come.
    cls_index : `torch.LongTensor`, optional
        Shape `(batch_size,)`: index of the `[CLS]` token in `question_with_context` for
        each instance. `[CLS]` is used to indicate that the question is impossible, so it
        stays selectable. When `None`, `[CLS]` is assumed to be at index 0.
    answer_span : `torch.IntTensor`, optional
        From a `SpanField`. The gold answer span; when given, a loss is computed and
        included in the output dictionary.
    metadata : `List[Dict[str, Any]]`, optional
        One dict per instance with the keys `id`, `question`, `context`,
        `question_tokens`, `context_tokens`, and `answers`.

    # Returns

    `Dict[str, torch.Tensor]` : An output dictionary with the following fields:

    - span_start_logits (`torch.FloatTensor`) : shape `(batch_size, sequence_length)`,
      unnormalized log probabilities of the span start position, with positions outside
      the context (and other than `[CLS]`) replaced by a very negative number.
    - span_end_logits (`torch.FloatTensor`) : same as above for the (inclusive) span end.
    - best_span_scores (`torch.FloatTensor`) : shape `(batch_size,)`, start logit plus
      end logit of the best span.
    - loss (`torch.FloatTensor`, optional) : A scalar loss to be optimised, evaluated
      against `answer_span`.
    - best_span (`torch.IntTensor`, optional) : Provided when not in train mode and
      metadata is given. Shape is `(batch_size, 2)` and each offset is a token index,
      unless the best span for an instance was predicted to be the `[CLS]` token, in
      which case the span will be (-1, -1).
    - best_span_str (`List[str]`, optional) : Provided when not in train mode and
      metadata is given. The string from the original passage that the model thinks is
      the best answer to the question.
    """
    embedded = self._text_field_embedder(question_with_context)
    # shape: (batch_size, sequence_length, 2) -> two (batch_size, sequence_length) tensors
    start_and_end = self._linear_layer(embedded)
    start_logits, end_logits = (part.squeeze(-1) for part in start_and_end.split(1, dim=-1))

    # Only context tokens and the [CLS] token may be selected as answer boundaries.
    # shape: (batch_size, sequence_length)
    token_ids = get_token_ids_from_text_field_tensors(question_with_context)
    possible_answer_mask = torch.zeros_like(token_ids, dtype=torch.bool)
    for row, (start, end) in enumerate(context_span):
        possible_answer_mask[row, start:end + 1] = True
        cls_position = 0 if cls_index is None else cls_index[row]
        possible_answer_mask[row, cls_position] = True

    # Push disallowed positions to a very negative value (we are in log-space).
    start_logits = replace_masked_values_with_big_negative_number(start_logits, possible_answer_mask)
    end_logits = replace_masked_values_with_big_negative_number(end_logits, possible_answer_mask)

    # shape: (batch_size, 2)
    best_spans = get_best_span(start_logits, end_logits)

    # shape: (batch_size,)
    best_span_scores = (
        torch.gather(start_logits, 1, best_spans[:, 0:1])
        + torch.gather(end_logits, 1, best_spans[:, 1:2])
    ).squeeze(1)

    output_dict = {
        "span_start_logits": start_logits,
        "span_end_logits": end_logits,
        "best_span_scores": best_span_scores,
    }

    if answer_span is not None:
        output_dict["loss"] = self._evaluate_span(best_spans, start_logits, end_logits, answer_span)

    # Outside training, recover the answer strings (and update the EM/F1 metrics via the helper).
    if metadata is not None and not self.training:
        best_span_strings, best_span_offsets = self._collect_best_span_strings(
            best_spans, context_span, metadata, cls_index
        )
        output_dict["best_span_str"] = best_span_strings
        output_dict["best_span"] = best_span_offsets

    return output_dict
def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,  # batch * words
        options: TextFieldTensors,  # batch * num_options * words
        labels: torch.IntTensor = None  # batch * num_options
) -> Dict[str, torch.Tensor]:
    """
    Score every (query, option) pair in the batch with the relevance matcher.

    # Parameters

    tokens : `TextFieldTensors`
        The query text, batch * words.
    options : `TextFieldTensors`
        The candidate documents, batch * num_options * words (embedded with one wrapping dim).
    labels : `torch.IntTensor`, optional
        One relevance label per option, batch * num_options. Options labelled `-1` are
        padding and are excluded from the metrics and the loss.

    # Returns

    - `logits`: the squeezed matcher scores (presumably batch * num_options).
    - `probs`: `sigmoid(logits)`.
    - `token_ids`: the query token ids.
    - `loss` (only with `labels`): the per-option loss averaged over non-padded options,
      or 0 if every option is padding (previously NaN). Assumes `self._loss` returns an
      unreduced, element-wise loss.
    """
    embedded_text = self._text_field_embedder(tokens)
    mask = get_text_field_mask(tokens).long()

    embedded_options = self._text_field_embedder(options, num_wrapping_dims=1)  # options_mask.dim() - 2
    options_mask = get_text_field_mask(options).long()

    if self._dropout:
        embedded_text = self._dropout(embedded_text)
        embedded_options = self._dropout(embedded_options)

    # The matcher scores (query, document) pairs, but we have (query, [d_0, ..., d_n]).
    # Expand the query embeddings (and mask) along a new num_options dimension so every
    # option is paired with its query:
    #   [(q_0, [d_00 .. d_0n]), ...] => [[(q_0, d_00) .. (q_0, d_0n)], ...]
    # `expand` only creates views (no copy). The resulting 4-D tensors are handed to the
    # matcher as-is (no flattening happens here).
    num_options = embedded_options.size(1)
    # shape: [batch, num_options, words, dim]
    embedded_text = embedded_text.unsqueeze(1).expand(-1, num_options, -1, -1)
    mask = mask.unsqueeze(1).expand(-1, num_options, -1)

    scores = self._relevance_matcher(embedded_text, embedded_options, mask, options_mask).squeeze(-1)
    probs = torch.sigmoid(scores)

    output_dict = {"logits": scores, "probs": probs}
    output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(tokens)

    if labels is not None:
        label_mask = labels != -1

        self._mrr(probs, labels, label_mask)
        self._ndcg(probs, labels, label_mask)

        flat_probs = probs.view(-1)
        flat_labels = labels.view(-1)
        flat_mask = label_mask.view(-1)

        self._auc(flat_probs, flat_labels.ge(0.5).long(), flat_mask)

        loss = self._loss(flat_probs, flat_labels)
        # Guard against an all-padding batch, which would otherwise divide 0 by 0 (NaN).
        num_valid = flat_mask.sum().clamp(min=1)
        output_dict["loss"] = loss.masked_fill(~flat_mask, 0).sum() / num_valid

    return output_dict
def attack_from_json(
    self,
    inputs: JsonDict,
    input_field_to_attack: str = "tokens",
    grad_input_field: str = "grad_input_1",
    ignore_tokens: List[str] = None,
    target: JsonDict = None,
) -> JsonDict:
    """
    Replaces one token at a time from the input until the model's prediction changes.
    ``input_field_to_attack`` is for example ``tokens``, it says what the input field is
    called. ``grad_input_field`` is for example ``grad_input_1``, which is a key into a grads
    dictionary.

    The method computes the gradient w.r.t. the tokens, finds the token with the maximum
    gradient (by L2 norm), and replaces it with another token based on the first-order Taylor
    approximation of the loss. This process is iteratively repeated until the prediction
    changes. Once a token is replaced, it is not flipped again.

    # Parameters

    inputs : ``JsonDict``
        The model inputs, the same as what is passed to a ``Predictor``.
    input_field_to_attack : ``str``, optional (default='tokens')
        The field that has the tokens that we're going to be flipping. This must be a
        ``TextField``.
    grad_input_field : ``str``, optional (default='grad_input_1')
        If there is more than one field that gets embedded in your model (e.g., a question and
        a passage, or a premise and a hypothesis), this tells us the key to use to get the
        correct gradients. This selects from the output of :func:`Predictor.get_gradients`.
    ignore_tokens : ``List[str]``, optional (default=DEFAULT_IGNORE_TOKENS)
        These tokens will not be flipped. The default list includes some simple punctuation,
        OOV and padding tokens, and common control tokens for BERT, etc.
    target : ``JsonDict``, optional (default=None)
        If given, this will be a `targeted` hotflip attack, where instead of just trying to
        change a model's prediction from what it is currently predicting, we try to change it
        to a `specific` target value. This is a ``JsonDict`` because it needs to specify the
        field name and target value. For example, for a masked LM, this would be something
        like ``{"words": ["she"]}``, because ``"words"`` is the field name, there is one mask
        token (hence the list of length one), and we want to change the prediction from
        whatever it was to ``"she"``.

    # Returns

    ``JsonDict``
        A sanitized dict with:
        ``"final"``: for each attacked labeled instance, its token list after flipping;
        ``"original"``: a copy of the tokens of the first labeled instance before any flip;
        ``"outputs"``: the outputs of the last ``get_gradients`` call, i.e. for the last
        attacked instance. If that instance was never flipped (every token ignored), these
        are its raw, unconverted initial outputs.
    """
    if self.embedding_matrix is None:
        self.initialize()
    ignore_tokens = DEFAULT_IGNORE_TOKENS if ignore_tokens is None else ignore_tokens

    # If `target` is `None`, we move away from the current prediction, otherwise we move
    # _towards_ the target.
    sign = -1 if target is None else 1
    instance = self.predictor._json_to_instance(inputs)
    if target is None:
        output_dict = self.predictor._model.forward_on_instance(instance)
    else:
        output_dict = target

    # This now holds the predictions that we want to change (either away from or towards,
    # depending on whether `target` was passed). We'll use this in the loop below to check for
    # when we've met our stopping criterion.
    original_instances = self.predictor.predictions_to_labeled_instances(instance, output_dict)

    # This is just for ease of access in the UI, so we know the original tokens. It's not used
    # in the logic below.
    original_text_field: TextField = original_instances[0][  # type: ignore
        input_field_to_attack
    ]
    original_tokens = deepcopy(original_text_field.tokens)

    final_tokens = []
    # `original_instances` is a list because there might be several different predictions that
    # we're trying to attack (e.g., all of the NER tags for an input sentence). We attack them
    # one at a time.
    for instance in original_instances:
        # Gets a list of the fields that we want to check to see if they change.
        fields_to_compare = utils.get_fields_to_compare(inputs, instance, input_field_to_attack)

        # We'll be modifying the tokens in this text field below, and grabbing the modified
        # list after the `while` loop.
        text_field: TextField = instance[input_field_to_attack]  # type: ignore

        # Because we can save computation by getting grads and outputs at the same time, we do
        # them together at the end of the loop, even though we use grads at the beginning and
        # outputs at the end. This is our initial gradient for the beginning of the loop. The
        # output can be ignored here.
        grads, outputs = self.predictor.get_gradients([instance])

        # Ignore any token that is in the ignore_tokens list by setting the token to already
        # flipped.
        flipped: List[int] = []
        for index, token in enumerate(text_field.tokens):
            if token.text in ignore_tokens:
                flipped.append(index)
        if "clusters" in outputs:
            # Coref unfortunately needs a special case here. We don't want to flip words in
            # the same predicted coref cluster, but we can't really specify a list of tokens,
            # because, e.g., "he" could show up in several different clusters.
            # TODO(mattg): perhaps there's a way to get `predictions_to_labeled_instances` to
            # return the set of tokens that shouldn't be changed for each instance? E.g., you
            # could imagine setting a field on the `Token` object, that we could then read
            # here...
            for cluster in outputs["clusters"]:
                for mention in cluster:
                    # Mention spans are treated as inclusive, hence the + 1.
                    for index in range(mention[0], mention[1] + 1):
                        flipped.append(index)

        while True:
            # Compute L2 norm of all grads. `g.dot(g)` is the *squared* norm of each row of
            # `grad` (presumably one embedding-gradient vector per token); squaring does not
            # change which token has the largest norm.
            grad = grads[grad_input_field][0]
            grads_magnitude = [g.dot(g) for g in grad]

            # only flip a token once; -1 is below any squared norm, so these are never chosen
            # unless every token is excluded.
            for index in flipped:
                grads_magnitude[index] = -1

            # We flip the token with highest gradient norm.
            index_of_token_to_flip = numpy.argmax(grads_magnitude)
            if grads_magnitude[index_of_token_to_flip] == -1:
                # If we've already flipped all of the tokens, we give up.
                break
            flipped.append(index_of_token_to_flip)

            text_field_tensors = text_field.as_tensor(text_field.get_padding_lengths())
            input_tokens = util.get_token_ids_from_text_field_tensors(text_field_tensors)
            original_id_of_token_to_flip = input_tokens[index_of_token_to_flip]

            # Get new token using taylor approximation.
            new_id = self._first_order_taylor(
                grad[index_of_token_to_flip], original_id_of_token_to_flip, sign
            )

            # Flip token. We need to tell the instance to re-index itself, so the text field
            # will actually update. The token list is mutated in place.
            new_token = Token(self.vocab._index_to_token[self.namespace][new_id])  # type: ignore
            text_field.tokens[index_of_token_to_flip] = new_token
            instance.indexed = False

            # Get model predictions on instance, and then label the instances
            grads, outputs = self.predictor.get_gradients([instance])  # predictions
            # Convert tensors to squeezed numpy arrays and keep only the first element of list
            # outputs, in place, before re-labeling.
            for key, output in outputs.items():
                if isinstance(output, torch.Tensor):
                    outputs[key] = output.detach().cpu().numpy().squeeze()
                elif isinstance(output, list):
                    outputs[key] = output[0]

            # TODO(mattg): taking the first result here seems brittle, if we're in a case where
            # there are multiple predictions.
            labeled_instance = self.predictor.predictions_to_labeled_instances(instance, outputs)[0]

            # If we've met our stopping criterion, we stop.
            has_changed = utils.instance_has_changed(labeled_instance, fields_to_compare)
            if target is None and has_changed:
                # With no target, we just want to change the prediction.
                break
            if target is not None and not has_changed:
                # With a given target, we want to *match* the target, which we check by
                # `not has_changed`.
                break

        final_tokens.append(text_field.tokens)

    return sanitize({"final": final_tokens, "original": original_tokens, "outputs": outputs})
def forward(  # type: ignore
    self,
    tokens: TextFieldTensors,
    verb_indicator: torch.Tensor,
    metadata: List[Any],
    tags: torch.LongTensor = None,
):
    """
    Run BERT over the wordpieces, project each position to tag logits, and (when gold
    `tags` are given) compute the loss and, outside training, the span metric.

    # Parameters

    tokens : `TextFieldTensors`, required
        The output of `TextField.as_array()`, which should typically be passed directly to a
        `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
        indexes wordpieces from the BERT vocabulary.
    verb_indicator : `torch.Tensor`, required
        An integer `SequenceFeatureField` representation of the position of the verb
        in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
        all zeros, in the case that the sentence has no verbal predicate. It is passed to the
        BERT model as its `token_type_ids`.
    metadata : `List[Dict[str, Any]]`, required
        One dict per instance in the batch. It is always read, so it must not be `None`.
        The keys 'words' (the original words in the sentence), 'verb' (the verb to compute
        the frame for) and 'offsets' (start offsets for converting wordpieces back to a
        sequence of words) are always used. When `tags` is given and the span metric is
        computed, the keys 'verb_index' and 'gold_tags' are also read.
    tags : `torch.LongTensor`, optional (default = `None`)
        A torch tensor representing the sequence of integer gold class labels
        of shape `(batch_size, num_tokens)`.

    # Returns

    An output dictionary consisting of:

    logits : `torch.FloatTensor`
        A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
        unnormalised log probabilities of the tag classes.
    class_probabilities : `torch.FloatTensor`
        A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
        a distribution of the tag classes per word.
    mask : `torch.Tensor`
        The text-field mask for `tokens`. It is retained so padding can be cropped later
        in `make_output_human_readable`.
    words, verb, wordpiece_offsets : `List`
        Per-instance values copied from `metadata` ('words', 'verb', 'offsets').
    loss : `torch.FloatTensor`, optional
        A scalar loss to be optimised. Only present when `tags` is given.
    """
    mask = get_text_field_mask(tokens)
    bert_embeddings, _ = self.bert_model(
        input_ids=util.get_token_ids_from_text_field_tensors(tokens),
        # The verb indicator is supplied to BERT as its token_type_ids (segment ids).
        token_type_ids=verb_indicator,
        attention_mask=mask,
    )

    embedded_text_input = self.embedding_dropout(bert_embeddings)
    batch_size, sequence_length, _ = embedded_text_input.size()
    logits = self.tag_projection_layer(embedded_text_input)

    # Flatten to (batch_size * num_tokens, num_classes) for the softmax, then restore the
    # per-token layout.
    reshaped_logits = logits.view(-1, self.num_classes)
    class_probabilities = F.softmax(reshaped_logits, dim=-1).view(
        [batch_size, sequence_length, self.num_classes]
    )
    output_dict = {"logits": logits, "class_probabilities": class_probabilities}
    # We need to retain the mask in the output dictionary
    # so that we can crop the sequences to remove padding
    # when we do viterbi inference in self.make_output_human_readable.
    output_dict["mask"] = mask
    # We add in the offsets here so we can compute the un-wordpieced tags.
    words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"]) for x in metadata])
    output_dict["words"] = list(words)
    output_dict["verb"] = list(verbs)
    output_dict["wordpiece_offsets"] = list(offsets)

    if tags is not None:
        loss = sequence_cross_entropy_with_logits(
            logits, tags, mask, label_smoothing=self._label_smoothing
        )
        if not self.ignore_span_metric and self.span_metric is not None and not self.training:
            batch_verb_indices = [example_metadata["verb_index"] for example_metadata in metadata]
            batch_sentences = [example_metadata["words"] for example_metadata in metadata]
            # Get the BIO tags from make_output_human_readable(). The "tags" entry is popped
            # from its result. NOTE(review): this assumes make_output_human_readable mutates
            # and returns output_dict, so the pop keeps "tags" out of the returned dict —
            # confirm.
            # TODO (nfliu): This is kind of a hack, consider splitting out part
            # of make_output_human_readable() to a separate function.
            batch_bio_predicted_tags = self.make_output_human_readable(output_dict).pop("tags")
            # Loop variables are named `bio_tags` so they are not confused with the gold
            # `tags` tensor argument.
            batch_conll_predicted_tags = [
                convert_bio_tags_to_conll_format(bio_tags)
                for bio_tags in batch_bio_predicted_tags
            ]
            batch_bio_gold_tags = [
                example_metadata["gold_tags"] for example_metadata in metadata
            ]
            batch_conll_gold_tags = [
                convert_bio_tags_to_conll_format(bio_tags) for bio_tags in batch_bio_gold_tags
            ]
            self.span_metric(
                batch_verb_indices,
                batch_sentences,
                batch_conll_predicted_tags,
                batch_conll_gold_tags,
            )
        output_dict["loss"] = loss
    return output_dict