def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                        targets: torch.LongTensor,
                                        mask: torch.BoolTensor,
                                        label_smoothing: float = 0.0,
                                        reduce: str = "mean") -> torch.FloatTensor:
    """
    label_smoothing : ``float``, optional (default = 0.0)
        It should be smaller than 1.
    reduce : ``str``, optional (default = "mean")
        One of "mean", "batch", or "batch-sequence".
    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = F.log_softmax(logits_flat, dim=-1)
    # shape : (batch * sequence_length, 1)
    targets_flat = targets.view(-1, 1).long()

    if label_smoothing > 0.0:
        num_classes = logits.size(-1)
        smoothing_value = label_smoothing / float(num_classes)
        # Fill all the correct indices with 1 - smoothing value.
        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(
            -1, targets_flat, 1.0 - label_smoothing)
        smoothed_targets = one_hot_targets + smoothing_value
        negative_log_likelihood_flat = -log_probs_flat * smoothed_targets
        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
    else:
        # shape : (batch * sequence_length, 1)
        negative_log_likelihood_flat = -torch.gather(log_probs_flat, dim=1, index=targets_flat)

    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(-1, logits.shape[1])
    mask = mask.float()
    # shape : (batch, sequence_length)
    loss = negative_log_likelihood * mask

    if reduce == "mean":
        loss = loss.sum() / (mask.sum() + 1e-13)
    elif reduce == "batch":
        # shape : (batch,)
        loss = loss.sum(1) / (mask.sum(1) + 1e-13)
    elif reduce == "batch-sequence":
        # We favor longer sequences, so we don't divide by the sequence length here.
        # shape : (batch,)
        loss = loss.sum(1)
    return loss
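# Usage sketch for the loss above (illustrative only; assumes `torch` and
# `torch.nn.functional as F` are imported alongside the function).
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 10)                      # (batch, sequence_length, num_classes)
targets = torch.randint(0, 10, (2, 5))              # (batch, sequence_length)
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True, True]])

# Mean over all unmasked tokens, with label smoothing.
loss = sequence_cross_entropy_with_logits(logits, targets, mask, label_smoothing=0.1)

# Per-sequence losses: "batch" normalizes by each sequence's length,
# while "batch-sequence" keeps the raw sum and therefore weights longer sequences more.
per_sequence = sequence_cross_entropy_with_logits(
    logits, targets, mask, label_smoothing=0.0, reduce="batch")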
def _joint_likelihood(self,
                      logits: torch.Tensor,
                      tags: torch.Tensor,
                      mask: torch.BoolTensor) -> torch.Tensor:
    """
    Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
    """
    batch_size, sequence_length, _ = logits.data.shape

    # Transpose batch size and sequence dimensions:
    logits = logits.transpose(0, 1).contiguous()
    mask = mask.transpose(0, 1).contiguous()
    tags = tags.transpose(0, 1).contiguous()

    # Start with the transition scores from start_tag to the first tag in each input
    if self.include_start_end_transitions:
        score = self.start_transitions.index_select(0, tags[0])
    else:
        score = 0.0

    # Add up the scores for the observed transitions and all the inputs but the last
    for i in range(sequence_length - 1):
        # Each is shape (batch_size,)
        current_tag, next_tag = tags[i], tags[i + 1]
        # The scores for transitioning from current_tag to next_tag
        transition_score = self.transitions[current_tag.view(-1), next_tag.view(-1)]
        # The score for using current_tag
        emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)
        # Include transition score if next element is unmasked,
        # input_score if this element is unmasked.
        score = score + transition_score * mask[i + 1] + emit_score * mask[i]

    # Transition from last state to "stop" state. To start with, we need to find the last tag
    # for each instance.
    last_tag_index = mask.sum(0).long() - 1
    last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0)

    # Compute score of transitioning to `stop_tag` from each "last tag".
    if self.include_start_end_transitions:
        last_transition_score = self.end_transitions.index_select(0, last_tags)
    else:
        last_transition_score = 0.0

    # Add the last input if it's not masked.
    last_inputs = logits[-1]  # (batch_size, num_tags)
    last_input_score = last_inputs.gather(1, last_tags.view(-1, 1))  # (batch_size, 1)
    last_input_score = last_input_score.squeeze()  # (batch_size,)

    score = score + last_transition_score + last_input_score * mask[-1]

    return score
def forward(self, sequence: Tensor, mask: BoolTensor) -> DynamicRnnOutput:
    """
    Run the rnn. Note that everything is batch first.
    :param sequence: the input sequence, shape: (B, seq_len, input_size)
    :param mask: mask over the sequence, shape: (B, seq_len)
    :return: the decoded result; see the DynamicRnnOutput docs for details
    """
    assert sequence.dim() == 3, \
        f"sequence.dim(): {sequence.dim()} does not match (B, seq_len, input_size)"
    assert sequence.size(-1) == self.rnn.input_size, \
        f"sequence.size(-1): {sequence.size(-1)} does not equal rnn input_size: {self.rnn.input_size}"

    batch_size = sequence.size(0)
    sequence_length = sequence.size(1)
    # pack_padded_sequence expects the lengths on the CPU.
    sequence_lengths = mask.sum(dim=-1).cpu()

    pack = pack_padded_sequence(sequence,
                                lengths=sequence_lengths,
                                batch_first=True,
                                enforce_sorted=False)
    packed_sequence_encoding, last_state = self.rnn(pack)
    encoding, pad_sequence_length = pad_packed_sequence(packed_sequence_encoding,
                                                        batch_first=True,
                                                        padding_value=0.0,
                                                        total_length=sequence_length)

    # LSTM returns the tuple (h_n, c_n); GRU and vanilla RNN return only h_n.
    if self.rnn_type == DynamicRnn.LSTM:
        h_n, c_n = last_state
    else:
        h_n = last_state
        c_n = None

    # h_n shape: (num_layers * num_directions, batch, hidden_size).
    # Since everything is handled batch first, transpose it so that h_n has shape
    # (batch, num_layers, hidden_size * num_directions); c_n is treated the same way.
    h_n = torch.transpose(h_n, 0, 1).contiguous().view(batch_size, self.num_layers, -1)
    last_layer_h_n = h_n[:, -1, :].contiguous().view(batch_size, -1)

    last_layer_c_n = None
    if c_n is not None:
        c_n = torch.transpose(c_n, 0, 1).contiguous().view(batch_size, self.num_layers, -1)
        last_layer_c_n = c_n[:, -1, :].contiguous().view(batch_size, -1)

    return DynamicRnnOutput(last_layer_h_n=last_layer_h_n,
                            last_layer_c_n=last_layer_c_n,
                            h_n=h_n,
                            c_n=c_n,
                            sequence_encoding=encoding)
def _loss(self, pred: torch.Tensor, true: torch.Tensor,
          mask: torch.BoolTensor, sample_weights: torch.Tensor) -> torch.Tensor:
    BATCH_SIZE, _, CLASSES = pred.size()
    valid_positions = mask.sum()
    pred = pred.reshape(-1, CLASSES)
    true = true.reshape(-1)
    mask = mask.reshape(-1)
    loss = utils.masked_cross_entropy(pred, true, mask)
    loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
    return loss.sum() / valid_positions
def get_final_encoder_states(
    encoder_outputs: torch.Tensor, mask: torch.BoolTensor, bidirectional: bool = False
) -> torch.Tensor:
    # The last unmasked position of each sequence holds its final (forward) state.
    last_word_indices = mask.sum(1) - 1
    batch_size, _, encoder_output_dim = encoder_outputs.size()
    expanded_indices = last_word_indices.view(-1, 1, 1).expand(batch_size, 1, encoder_output_dim)
    # Shape: (batch_size, 1, encoder_output_dim)
    final_encoder_output = encoder_outputs.gather(1, expanded_indices)
    final_encoder_output = final_encoder_output.squeeze(1)  # (batch_size, encoder_output_dim)
    if bidirectional:
        # For a bidirectional encoder, the first half of the last timestep is the forward
        # state and the second half of the *first* timestep is the backward state.
        final_forward_output = final_encoder_output[:, : (encoder_output_dim // 2)]
        final_backward_output = encoder_outputs[:, 0, (encoder_output_dim // 2) :]
        final_encoder_output = torch.cat([final_forward_output, final_backward_output], dim=-1)
    return final_encoder_output
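# Usage sketch for `get_final_encoder_states` (illustrative shapes; assumes `torch` is imported).
import torch

batch_size, seq_len, hidden = 2, 6, 8               # hidden = 2 * 4 for a bidirectional encoder
encoder_outputs = torch.randn(batch_size, seq_len, hidden)
mask = torch.tensor([[True] * 4 + [False] * 2,
                     [True] * 6])

# For a unidirectional encoder this is simply the state at each sequence's last unmasked step.
final_states = get_final_encoder_states(encoder_outputs, mask)

# For a bidirectional encoder, the forward half comes from the last unmasked step and the
# backward half from timestep 0, then the two halves are concatenated.
final_states_bi = get_final_encoder_states(encoder_outputs, mask, bidirectional=True)
print(final_states.shape, final_states_bi.shape)    # torch.Size([2, 8]) torch.Size([2, 8])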
def get_lengths_from_binary_sequence_mask(mask: torch.BoolTensor) -> torch.LongTensor:
    """
    Compute sequence lengths for each batch element in a tensor using a binary mask.

    # Parameters

    mask : `torch.BoolTensor`, required.
        A 2D binary mask of shape (batch_size, sequence_length) to
        calculate the per-batch sequence lengths from.

    # Returns

    `torch.LongTensor`
        A torch.LongTensor of shape (batch_size,) representing the lengths
        of the sequences in the batch.
    """
    return mask.sum(-1)
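# Quick illustration: the lengths are just the per-row count of True values.
import torch

mask = torch.tensor([[True, True, False],
                     [True, True, True]])
print(get_lengths_from_binary_sequence_mask(mask))  # tensor([2, 3])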
def remove_sentence_boundaries(
    tensor: torch.Tensor, mask: torch.BoolTensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Strips the first and last unmasked timestep (e.g. sentence boundary tokens)
    # from each sequence and returns the shortened tensor plus its new mask.
    sequence_lengths = mask.sum(dim=1).detach().cpu().numpy()
    tensor_shape = list(tensor.data.shape)
    new_shape = list(tensor_shape)
    new_shape[1] = tensor_shape[1] - 2
    tensor_without_boundary_tokens = tensor.new_zeros(*new_shape)
    new_mask = tensor.new_zeros((new_shape[0], new_shape[1]), dtype=torch.bool)
    for i, j in enumerate(sequence_lengths):
        if j > 2:
            tensor_without_boundary_tokens[i, : (j - 2), :] = tensor[i, 1 : (j - 1), :]
            new_mask[i, : (j - 2)] = True
    return tensor_without_boundary_tokens, new_mask
def viterbi_decode(self, h: FloatTensor, mask: BoolTensor) -> List[List[int]]:
    """
    decode labels using the viterbi algorithm
    :param h: hidden matrix (batch_size, seq_len, num_labels)
    :param mask: mask tensor of each sequence in mini batch (batch_size, seq_len)
    :return: labels of each sequence in mini batch
    """
    batch_size, seq_len, _ = h.size()
    # prepare the sequence lengths of each sequence
    seq_lens = mask.sum(dim=1)
    # for the mini batch, prepare the score
    # from the start state to the first label
    score = [self.start_trans.data + h[:, 0]]
    path = []

    for t in range(1, seq_len):
        # extract the scores of the previous step
        # (batch_size, num_labels, 1)
        previous_score = score[t - 1].view(batch_size, -1, 1)
        # extract the emission scores at step t
        # (batch_size, 1, num_labels)
        h_t = h[:, t].view(batch_size, 1, -1)
        # add the transition scores from the label at t-1 to the label at t;
        # self.trans_matrix holds the score of transitioning
        # from label A to label B
        # (batch_size, num_labels, num_labels)
        score_t = previous_score + self.trans_matrix + h_t
        # keep the maximum value and the position where it occurs
        # (batch_size, num_labels)
        best_score, best_path = score_t.max(1)
        score.append(best_score)
        path.append(best_path)

    # predict labels of the mini batch
    best_paths = [
        self._viterbi_compute_best_path(i, seq_lens, score, path)
        for i in range(batch_size)
    ]

    return best_paths
def __call__(  # type: ignore
        self,
        predictions: Dict[str, torch.Tensor],
        gold_labels: Dict[str, torch.Tensor],
        mask: torch.BoolTensor):
    self.upos_score(predictions["upostag"], gold_labels["upostag"], mask)
    self.xpos_score(predictions["xpostag"], gold_labels["xpostag"], mask)
    self.semrel_score(predictions["semrel"], gold_labels["semrel"], mask)
    self.feats_score(predictions["feats"], gold_labels["feats"], mask)
    self.lemma_score(predictions["lemma"], gold_labels["lemma"], mask)
    self.attachment_scores(predictions["head"], predictions["deprel"],
                           gold_labels["head"], gold_labels["deprel"], mask)
    total = mask.sum()
    correct_indices = (self.upos_score.correct_indices *
                       self.xpos_score.correct_indices *
                       self.semrel_score.correct_indices *
                       self.feats_score.correct_indices *
                       self.lemma_score.correct_indices *
                       self.attachment_scores.correct_indices)
    total, correct_indices = self.detach_tensors(total, correct_indices)
    self.em_score = (correct_indices.float().sum() / total).item()
def _compute_numerator_log_likelihood(self, h: FloatTensor, y: LongTensor,
                                      mask: BoolTensor) -> FloatTensor:
    """
    compute the numerator term for the log-likelihood
    :param h: hidden matrix (batch_size, seq_len, num_labels)
    :param y: answer labels of each sequence in mini batch (batch_size, seq_len)
    :param mask: mask tensor of each sequence in mini batch (batch_size, seq_len)
    :return: The score of the numerator term for the log-likelihood
    """
    batch_size, seq_len, _ = h.size()

    h_unsqueezed = h.unsqueeze(-1)
    trans = self.trans_matrix.unsqueeze(-1)

    arange_b = torch.arange(batch_size)

    # sum the transition/emission scores over all but the last step
    calc_range = seq_len - 1
    score = self.start_trans[y[:, 0]] + sum([
        self._calc_trans_score_for_num_llh(h_unsqueezed, y, trans, mask, t, arange_b)
        for t in range(calc_range)
    ])

    # extract the last label of each sequence in the mini batch
    # (batch_size,)
    last_mask_index = mask.sum(1) - 1
    last_labels = y[arange_b, last_mask_index]
    each_last_score = h[arange_b, -1, last_labels] * mask[:, -1]

    # Add the score of the sequences of the maximum length in the mini batch
    # and the scores from the last tag of each sequence to EOS
    score += each_last_score + self.end_trans[last_labels]

    return score
def compute_log_probability(
    logits: torch.FloatTensor,
    targets: torch.LongTensor,
    mask: Optional[torch.BoolTensor] = None,
    debug_fxn: Callable[[object, str], None] = null_log,
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
    """
    Compute sum of log probs from model logits

    Arguments:
        logits (torch.FloatTensor): Model output logits (B x T x V)
        targets (torch.LongTensor): Target tokens (B x T)
        mask (torch.BoolTensor): Mask revealing only the utterance tokens (B x T)
        debug_fxn (callable): Logging function

    Returns:
        torch.FloatTensor: Target log probabilities (B x T)
        torch.LongTensor: Number of utterance tokens (1)
    """
    # Get log probability from logits via log softmax
    log_probs = torch.log_softmax(logits, dim=-1)
    debug_fxn(log_probs, 'log_probs')
    debug_fxn(targets, 'targets')

    # Extract target token probability - (B x T)
    target_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    debug_fxn(target_log_probs, 'target_log_probs')

    # Mask to utterance tokens
    if mask is not None:
        target_log_probs = target_log_probs.masked_select(mask)
        debug_fxn(target_log_probs, 'target_log_probs (masked)')
        n_tokens = mask.sum()
    else:
        n_tokens = target_log_probs.numel()
    debug_fxn(n_tokens, 'n_tokens')

    return target_log_probs, n_tokens
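# Usage sketch for `compute_log_probability` (assumes the module's `null_log` helper and
# `torch` are available; shapes are illustrative).
import torch

logits = torch.randn(2, 4, 32)                       # (B, T, V)
targets = torch.randint(0, 32, (2, 4))               # (B, T)
mask = torch.tensor([[True, True, False, False],
                     [True, True, True, True]])

token_log_probs, n_tokens = compute_log_probability(logits, targets, mask)
# Average per-token log likelihood over the unmasked (utterance) tokens.
avg_log_prob = token_log_probs.sum() / n_tokens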
def forward(self, inputs: torch.Tensor,
            mask: torch.BoolTensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    assert len(inputs.shape) == 3
    assert len(mask.shape) == 2
    assert inputs.shape[:-1] == mask.shape

    _, seq_len, _ = inputs.shape
    # pack_padded_sequence expects the lengths on the CPU.
    sequence_lengths = mask.sum(-1).cpu()
    packed_inputs = nn.utils.rnn.pack_padded_sequence(inputs,
                                                      sequence_lengths,
                                                      batch_first=True,
                                                      enforce_sorted=False)
    packed_outputs, (h_n, c_n) = super(LstmWrapper, self).forward(packed_inputs)
    output, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs,
                                                 batch_first=True,
                                                 total_length=seq_len)
    return output, (h_n, c_n)
def _loss(
        self,
        pred: torch.Tensor,
        true: torch.Tensor,
        mask: torch.BoolTensor,
        sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    BATCH_SIZE, N, M = pred.size()
    assert N == M
    SENTENCE_LENGTH = N

    valid_positions = mask.sum()

    result = []
    # Ignore the first pred dimension as it is the ROOT token prediction
    for i in range(SENTENCE_LENGTH - 1):
        pred_i = pred[:, i + 1, :].reshape(BATCH_SIZE, SENTENCE_LENGTH)
        true_i = true[:, i].reshape(-1)
        mask_i = mask[:, i]
        cross_entropy_loss = utils.masked_cross_entropy(pred_i, true_i, mask_i)
        result.append(cross_entropy_loss)

    cycle_loss = self._cycle_loss(pred)
    loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
    return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
def _loss(self, pred: torch.Tensor, true: torch.Tensor,
          mask: torch.BoolTensor, sample_weights: torch.Tensor) -> torch.Tensor:
    assert pred.size() == true.size()
    BATCH_SIZE, _, MORPHOLOGICAL_FEATURES = pred.size()

    valid_positions = mask.sum()

    pred = pred.reshape(-1, MORPHOLOGICAL_FEATURES)
    true = true.reshape(-1, MORPHOLOGICAL_FEATURES)
    mask = mask.reshape(-1)

    loss = None
    loss_func = utils.masked_cross_entropy
    for cat, cat_indices in self.slices.items():
        if cat not in ["__PAD__", "_"]:
            if loss is None:
                loss = loss_func(pred[:, cat_indices],
                                 true[:, cat_indices].argmax(dim=1),
                                 mask)
            else:
                loss += loss_func(pred[:, cat_indices],
                                  true[:, cat_indices].argmax(dim=1),
                                  mask)

    loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
    return loss.sum() / valid_positions
def add_sentence_boundary_token_ids(
    tensor: torch.Tensor, mask: torch.BoolTensor, sentence_begin_token: Any, sentence_end_token: Any
) -> Tuple[torch.Tensor, torch.BoolTensor]:
    sequence_lengths = mask.sum(dim=1).detach().cpu().numpy()
    tensor_shape = list(tensor.data.shape)
    new_shape = list(tensor_shape)
    new_shape[1] = tensor_shape[1] + 2
    tensor_with_boundary_tokens = tensor.new_zeros(*new_shape)
    if len(tensor_shape) == 2:
        tensor_with_boundary_tokens[:, 1:-1] = tensor
        tensor_with_boundary_tokens[:, 0] = sentence_begin_token
        for i, j in enumerate(sequence_lengths):
            tensor_with_boundary_tokens[i, j + 1] = sentence_end_token
        new_mask = tensor_with_boundary_tokens != 0
    elif len(tensor_shape) == 3:
        tensor_with_boundary_tokens[:, 1:-1, :] = tensor
        for i, j in enumerate(sequence_lengths):
            tensor_with_boundary_tokens[i, 0, :] = sentence_begin_token
            tensor_with_boundary_tokens[i, j + 1, :] = sentence_end_token
        new_mask = (tensor_with_boundary_tokens > 0).sum(dim=-1) > 0
    else:
        raise ValueError("add_sentence_boundary_token_ids only accepts 2D and 3D input")
    return tensor_with_boundary_tokens, new_mask
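# Usage sketch pairing `add_sentence_boundary_token_ids` with `remove_sentence_boundaries`
# from earlier in this collection (assumes `torch` plus the `typing` imports used above).
import torch

mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True, True]])
# Strictly positive embeddings, zeroed at padded positions, so the `> 0` heuristic
# inside `add_sentence_boundary_token_ids` recovers the right mask.
embeddings = (torch.rand(2, 5, 4) + 0.1) * mask.unsqueeze(-1)

# Insert <S>/</S>-style boundary rows (constant fill values, purely for illustration).
with_bounds, bounds_mask = add_sentence_boundary_token_ids(embeddings, mask, 1.0, 2.0)
print(with_bounds.shape)        # torch.Size([2, 7, 4])

# Strip them again, recovering the original timestep count.
without_bounds, new_mask = remove_sentence_boundaries(with_bounds, bounds_mask)
print(without_bounds.shape)     # torch.Size([2, 5, 4])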
def _unfold_long_sequences(
    self,
    embeddings: torch.FloatTensor,
    mask: torch.BoolTensor,
    batch_size: int,
    num_segment_concat_wordpieces: int,
) -> torch.FloatTensor:
    """
    We take 2D segments of a long sequence and flatten them out to get the whole sequence
    representation while removing unnecessary special tokens.

    [ [ [CLS]_emb A_emb B_emb C_emb [SEP]_emb ], [ [CLS]_emb D_emb E_emb [SEP]_emb [PAD]_emb ] ]
    -> [ [CLS]_emb A_emb B_emb C_emb D_emb E_emb [SEP]_emb ]

    We truncate the start and end tokens for all segments, recombine the segments,
    and manually add back the start and end tokens.

    # Parameters

    embeddings: `torch.FloatTensor`
        Shape: [batch_size * num_segments, self._max_length, embedding_size].
    mask: `torch.BoolTensor`
        Shape: [batch_size * num_segments, self._max_length].
        The mask for the concatenated segments of wordpieces. The same as `segment_concat_mask`
        in `forward()`.
    batch_size: `int`
    num_segment_concat_wordpieces: `int`
        The length of the original "[ [CLS] A B C [SEP] [CLS] D E F [SEP] ]", i.e.
        the original `token_ids.size(1)`.

    # Returns

    embeddings: `torch.FloatTensor`
        Shape: [batch_size, self._num_wordpieces, embedding_size].
    """

    def lengths_to_mask(lengths, max_len, device):
        return torch.arange(max_len, device=device).expand(
            lengths.size(0), max_len
        ) < lengths.unsqueeze(1)

    device = embeddings.device
    num_segments = int(embeddings.size(0) / batch_size)
    embedding_size = embeddings.size(2)

    # We want to remove all segment-level special tokens but maintain sequence-level ones
    num_wordpieces = num_segment_concat_wordpieces - (num_segments - 1) * self._num_added_tokens

    embeddings = embeddings.reshape(batch_size, num_segments * self._max_length, embedding_size)
    mask = mask.reshape(batch_size, num_segments * self._max_length)
    # We assume that all 1s in the mask precede all 0s, and add an assert for that.
    # Open an issue on GitHub if this breaks for you.
    # Shape: (batch_size,)
    seq_lengths = mask.sum(-1)
    if not (lengths_to_mask(seq_lengths, mask.size(1), device) == mask).all():
        raise ValueError(
            "Long sequence splitting only supports masks with all 1s preceding all 0s."
        )
    # Shape: (batch_size, self._num_added_end_tokens); this is a broadcast op
    end_token_indices = (
        seq_lengths.unsqueeze(-1) - torch.arange(self._num_added_end_tokens, device=device) - 1
    )

    # Shape: (batch_size, self._num_added_start_tokens, embedding_size)
    start_token_embeddings = embeddings[:, : self._num_added_start_tokens, :]
    # Shape: (batch_size, self._num_added_end_tokens, embedding_size)
    end_token_embeddings = batched_index_select(embeddings, end_token_indices)

    embeddings = embeddings.reshape(batch_size, num_segments, self._max_length, embedding_size)
    # Truncate the segment-level start/end tokens.
    embeddings = embeddings[:, :, self._num_added_start_tokens : -self._num_added_end_tokens, :]
    embeddings = embeddings.reshape(batch_size, -1, embedding_size)  # flatten

    # Now try to put end token embeddings back, which is a little tricky.

    # The number of segments each sequence spans, excluding padding. Mimicking ceiling operation.
    # Shape: (batch_size,)
    num_effective_segments = (seq_lengths + self._max_length - 1) // self._max_length
    # The number of indices that end tokens should shift back.
    num_removed_non_end_tokens = (
        num_effective_segments * self._num_added_tokens - self._num_added_end_tokens
    )
    # Shape: (batch_size, self._num_added_end_tokens)
    end_token_indices -= num_removed_non_end_tokens.unsqueeze(-1)
    assert (end_token_indices >= self._num_added_start_tokens).all()
    # Add space for end embeddings
    embeddings = torch.cat([embeddings, torch.zeros_like(end_token_embeddings)], 1)
    # Add end token embeddings back
    embeddings.scatter_(
        1, end_token_indices.unsqueeze(-1).expand_as(end_token_embeddings), end_token_embeddings
    )

    # Now put back start tokens. We can do this before putting back end tokens, but then
    # we need to change `num_removed_non_end_tokens` a little.
    embeddings = torch.cat([start_token_embeddings, embeddings], 1)

    # Truncate to original length
    embeddings = embeddings[:, :num_wordpieces, :]
    return embeddings
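# Small sanity check of two index tricks used above (illustrative only):
# integer ceiling division via (x + L - 1) // L, and the lengths-to-mask pattern,
# which assumes every row of the mask has all True values before any False.
import torch

seq_lengths = torch.tensor([1, 512, 513, 1024])
max_length = 512
num_effective_segments = (seq_lengths + max_length - 1) // max_length
print(num_effective_segments)            # tensor([1, 1, 2, 2])

lengths = torch.tensor([2, 3, 6, 1])
mask = torch.arange(6).expand(4, 6) < lengths.unsqueeze(1)
print(mask[0])                           # tensor([ True,  True, False, False, False, False])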
def _construct_loss(
    self,
    head_tag: torch.Tensor,
    child_tag: torch.Tensor,
    score_arc: torch.Tensor,
    head_indices: torch.Tensor,
    head_tags: torch.Tensor,
    mask: torch.BoolTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the arc and tag loss for a sequence given gold head indices and tags.

    # Parameters

    head_tag : `torch.Tensor`, required.
        A tensor of shape (batch_size, sequence_length, tag_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    child_tag : `torch.Tensor`, required.
        A tensor of shape (batch_size, sequence_length, tag_dim),
        which will be used to generate predictions for the dependency tags
        for the given arcs.
    score_arc : `torch.Tensor`, required.
        A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
        a distribution over attachments of a given word to all other words.
    head_indices : `torch.Tensor`, required.
        A tensor of shape (batch_size, sequence_length).
        The indices of the heads for every word.
    head_tags : `torch.Tensor`, required.
        A tensor of shape (batch_size, sequence_length).
        The dependency labels of the heads for every word.
    mask : `torch.BoolTensor`, required.
        A mask of shape (batch_size, sequence_length), denoting unpadded
        elements in the sequence.

    # Returns

    arc_nll : `torch.Tensor`, required.
        The negative log likelihood from the arc loss.
    tag_nll : `torch.Tensor`, required.
        The negative log likelihood from the arc tag loss.
    """
    batch_size, sequence_length, _ = score_arc.size()
    # shape (batch_size, 1)
    range_vector = torch.arange(batch_size, device=score_arc.device).unsqueeze(1)
    # shape (batch_size, sequence_length, sequence_length)
    normalised_arc_logits = (
        masked_log_softmax(score_arc, mask) * mask.unsqueeze(2) * mask.unsqueeze(1)
    )

    # shape (batch_size, sequence_length, num_head_tags)
    head_tag_logits = self._get_head_tags(head_tag, child_tag, head_indices)
    normalised_head_tag_logits = masked_log_softmax(
        head_tag_logits, mask.unsqueeze(-1)
    ) * mask.unsqueeze(-1)
    # index matrix with shape (batch, sequence_length)
    timestep_index = torch.arange(sequence_length, device=score_arc.device)
    child_index = (
        timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
    )
    # shape (batch_size, sequence_length)
    arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
    tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
    # We don't care about predictions for the symbolic ROOT token's head,
    # so we remove it from the loss.
    arc_loss = arc_loss[:, 1:]
    tag_loss = tag_loss[:, 1:]

    # The number of valid positions is equal to the number of unmasked elements minus
    # 1 per sequence in the batch, to account for the symbolic HEAD token.
    valid_positions = mask.sum() - batch_size

    arc_nll = -arc_loss.sum() / valid_positions.float()
    tag_nll = -tag_loss.sum() / valid_positions.float()
    return arc_nll, tag_nll
def forward(
    self,  # type: ignore
    token_ids: torch.LongTensor,
    type_ids: torch.LongTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    pos_tags: torch.LongTensor,
    word_mask: torch.BoolTensor,
    parent_mask: torch.BoolTensor,
    parent_start_mask: torch.BoolTensor,
    parent_end_mask: torch.BoolTensor,
    child_mask: torch.BoolTensor = None,
    parent_idxs: torch.LongTensor = None,
    parent_tags: torch.LongTensor = None,
    parent_starts: torch.BoolTensor = None,
    parent_ends: torch.BoolTensor = None,
    child_idxs: torch.BoolTensor = None,
    child_starts: torch.BoolTensor = None,
    child_ends: torch.BoolTensor = None,
):
    """
    todo implement docstring

    Args:
        token_ids: [batch_size, num_word_pieces]
        type_ids: [batch_size, num_word_pieces]
        offsets: [batch_size, num_words, 2]
        wordpiece_mask: [batch_size, num_word_pieces]
        pos_tags: [batch_size, num_words]
        word_mask: [batch_size, num_words]
        parent_mask: [batch_size, num_words]
        parent_start_mask: [batch_size, num_words]
        parent_end_mask: [batch_size, num_words]
        child_mask: [batch_size, num_words]
        parent_idxs: [batch_size]
        parent_tags: [batch_size]
        parent_starts: [batch_size]
        parent_ends: [batch_size]
        child_idxs: [batch_size, num_words]
        child_starts: [batch_size, num_words]
        child_ends: [batch_size, num_words]

    Returns:
        parent_probs: [batch_size, num_words]
        parent_tag_probs: [batch_size, num_words, num_tags]
        parent_start_probs: [batch_size, num_words]
        parent_end_probs: [batch_size, num_words]
        child_probs: [batch_size, num_words]
        child_start_probs: [batch_size, num_words]
        child_end_probs: [batch_size, num_words]
        arc_loss (if parent_idxs is not None)
        tag_loss (if parent_idxs and parent_tags are not None)
        start_loss (if parent_starts is not None)
        end_loss (if parent_ends is not None)
        child_loss (if child_idxs is not None)
        child_start_loss (if child_starts is not None)
        child_end_loss (if child_ends is not None)
    """
    cls_embedding, embedded_text_input = self.get_word_embedding(
        token_ids=token_ids,
        offsets=offsets,
        wordpiece_mask=wordpiece_mask,
        type_ids=type_ids,
    )
    if self.pos_embedding is not None:
        embedded_pos_tags = self.pos_embedding(pos_tags)
        embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        if self.fuse_layer is not None:
            embedded_text_input = self.fuse_layer(embedded_text_input)
    # todo compare normal dropout with InputVariationalDropout
    embedded_text_input = self._dropout(embedded_text_input)

    if self.additional_encoder is not None:
        if self.config.additional_layer_type == "transformer":
            # bert = self.bert if self.arch == "bert" else self.roberta
            extended_attention_mask = self.bert.get_extended_attention_mask(
                word_mask, word_mask.size(), word_mask.device)
            encoded_text = self.additional_encoder(
                hidden_states=embedded_text_input,
                attention_mask=extended_attention_mask)[0]
        else:
            encoded_text = self.additional_encoder(
                inputs=embedded_text_input, mask=word_mask)
    else:
        encoded_text = embedded_text_input

    batch_size, seq_len, encoding_dim = encoded_text.size()

    # shape (batch_size, sequence_length, tag_classes)
    parent_tag_scores = self.parent_tag_feedforward(encoded_text)
    # shape (batch_size, sequence_length)
    parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)
    parent_start_scores = self.parent_start_feedforward(encoded_text).squeeze(-1)
    parent_end_scores = self.parent_end_feedforward(encoded_text).squeeze(-1)

    # mask out impossible positions
    minus_inf = -1e8
    parent_mask = torch.logical_and(parent_mask, word_mask)
    parent_scores = parent_scores + (~parent_mask).float() * minus_inf
    parent_start_mask = torch.logical_and(parent_start_mask, word_mask)
    parent_start_scores = parent_start_scores + (~parent_start_mask).float() * minus_inf
    parent_end_mask = torch.logical_and(parent_end_mask, word_mask)
    parent_end_scores = parent_end_scores + (~parent_end_mask).float() * minus_inf

    parent_probs = F.softmax(parent_scores, dim=-1)
    parent_start_probs = F.softmax(parent_start_scores, dim=-1)
    parent_end_probs = F.softmax(parent_end_scores, dim=-1)
    parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)

    output = (parent_probs, parent_tag_probs, parent_start_probs, parent_end_probs)

    if self.config.predict_child:
        child_scores = self.child_feedforward(encoded_text).squeeze(-1)
        child_start_scores = self.child_start_feedforward(encoded_text).squeeze(-1)
        child_end_scores = self.child_end_feedforward(encoded_text).squeeze(-1)
        # todo add child mask - child should be inside the origin span
        if child_mask is None:
            child_mask = torch.ones_like(word_mask)
        else:
            child_mask = torch.logical_and(child_mask, word_mask)
        child_scores = child_scores + (~child_mask).float() * minus_inf
        child_start_scores = child_start_scores + (~child_mask).float() * minus_inf
        child_end_scores = child_end_scores + (~child_mask).float() * minus_inf
        child_probs = torch.sigmoid(child_scores)
        child_start_probs = torch.sigmoid(child_start_scores)
        child_end_probs = torch.sigmoid(child_end_scores)
        output = output + (child_probs, child_start_probs, child_end_probs)

    # add losses
    batch_range_vector = get_range_vector(batch_size, get_device_of(encoded_text))  # [bsz]
    if parent_idxs is not None:
        # [bsz, seq_len]
        parent_logits = F.log_softmax(parent_scores, dim=-1)
        parent_arc_nll = -parent_logits[batch_range_vector, parent_idxs]
        parent_arc_nll = parent_arc_nll.mean()
        output = output + (parent_arc_nll, )

        if parent_tags is not None:
            parent_tag_nll = F.cross_entropy(
                parent_tag_scores[batch_range_vector, parent_idxs], parent_tags)
            output = output + (parent_tag_nll, )

    if parent_starts is not None:
        # [bsz, seq_len]
        parent_start_logits = F.log_softmax(parent_start_scores, dim=-1)
        parent_start_nll = -parent_start_logits[batch_range_vector, parent_starts].mean()
        output = output + (parent_start_nll, )

    if parent_ends is not None:
        # [bsz, seq_len]
        parent_end_logits = F.log_softmax(parent_end_scores, dim=-1)
        parent_end_nll = -parent_end_logits[batch_range_vector, parent_ends].mean()
        output = output + (parent_end_nll, )

    if self.config.predict_child:
        if child_idxs is not None:
            child_loss = F.binary_cross_entropy_with_logits(
                child_scores, child_idxs.float(), reduction="none")
            child_loss = (child_loss * child_mask).sum() / (child_mask.sum() + 1e-8)
            output = output + (child_loss, )
        if child_starts is not None:
            child_start_loss = F.binary_cross_entropy_with_logits(
                child_start_scores, child_starts.float(), reduction="none")
            child_start_loss = (child_start_loss * child_mask).sum() / (child_mask.sum() + 1e-8)
            output = output + (child_start_loss, )
        if child_ends is not None:
            child_end_loss = F.binary_cross_entropy_with_logits(
                child_end_scores, child_ends.float(), reduction="none")
            child_end_loss = (child_end_loss * child_mask).sum() / (child_mask.sum() + 1e-8)
            output = output + (child_end_loss, )

    return output
def forward(self, tokens: torch.Tensor, mask: torch.BoolTensor):
    if mask is not None:
        tokens = tokens * mask.unsqueeze(-1)
    else:
        # If mask doesn't exist, create one of shape (batch_size, num_tokens)
        mask = torch.ones(tokens.shape[0], tokens.shape[1], device=tokens.device).bool()

    # Our input is expected to have shape `(batch_size, num_tokens, embedding_dim)`. The
    # convolution layers expect input of shape `(batch_size, in_channels, sequence_length)`,
    # where the conv layer `in_channels` is our `embedding_dim`. We thus need to transpose the
    # tensor first.
    tokens = torch.transpose(tokens, 1, 2)
    # Each convolution layer returns output of size `(batch_size, num_filters, pool_length)`,
    # where `pool_length = num_tokens - ngram_size + 1`. We then do an activation function,
    # masking, then do max pooling over each filter for the whole input sequence.
    # Because our max pooling is simple, we just use `torch.max`. The resultant tensor has shape
    # `(batch_size, num_conv_layers * num_filters)`, which then gets projected using the
    # projection layer, if requested.

    # To ensure the cnn_encoder respects masking, we add a large negative value to
    # the activations of all filters that convolved over a masked token. We do this by
    # first enumerating all filters for a given convolution size (torch.arange()),
    # then comparing it to the index of the last filter that does not involve a masked
    # token (.ge()), and finally adjusting dimensions to allow for addition and multiplying
    # by a large negative value (.unsqueeze()).
    filter_outputs = []
    batch_size = tokens.shape[0]
    # shape: (batch_size, 1)
    last_unmasked_tokens = mask.sum(dim=1).unsqueeze(dim=-1)
    for i in range(len(self._convolution_layers)):
        convolution_layer = getattr(self, "conv_layer_{}".format(i))
        pool_length = tokens.shape[2] - convolution_layer.kernel_size[0] + 1

        # Forward pass of the convolutions.
        # shape: (batch_size, num_filters, pool_length)
        activations = self._activation(convolution_layer(tokens))

        # Create activation mask.
        # shape: (batch_size, pool_length)
        indices = (
            torch.arange(pool_length, device=activations.device)
            .unsqueeze(0)
            .expand(batch_size, pool_length)
        )
        # shape: (batch_size, pool_length)
        activations_mask = indices.ge(
            last_unmasked_tokens - convolution_layer.kernel_size[0] + 1
        )
        # shape: (batch_size, num_filters, pool_length)
        activations_mask = activations_mask.unsqueeze(1).expand_as(activations)

        # Replace masked out values with the smallest possible value of the dtype so
        # that max pooling will ignore these activations.
        # shape: (batch_size, num_filters, pool_length)
        activations = activations + (activations_mask * min_value_of_dtype(activations.dtype))

        # Pick out the max filters
        filter_outputs.append(activations.max(dim=2)[0])

    # Now we have a list of `num_conv_layers` tensors of shape `(batch_size, num_filters)`.
    # Concatenating them gives us a tensor of shape `(batch_size, num_filters * num_conv_layers)`.
    maxpool_output = (
        torch.cat(filter_outputs, dim=1) if len(filter_outputs) > 1 else filter_outputs[0]
    )

    # Replace the maxpool activations that picked up the masks with 0s
    maxpool_output[maxpool_output == min_value_of_dtype(maxpool_output.dtype)] = 0.0

    if self.projection_layer:
        result = self.projection_layer(maxpool_output)
    else:
        result = maxpool_output
    return result
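# Toy illustration of the activation-masking trick above (no encoder class needed, just
# `torch`): with 7 tokens, 5 of them unmasked, and a kernel of size 3, only the first
# 5 - 3 + 1 = 3 convolution windows avoid padded tokens; later windows get masked out.
import torch

pool_length = 7 - 3 + 1                              # num_tokens - ngram_size + 1 = 5
last_unmasked_tokens = torch.tensor([[5]])           # (batch_size, 1)
indices = torch.arange(pool_length).unsqueeze(0)     # (1, pool_length)
activations_mask = indices.ge(last_unmasked_tokens - 3 + 1)
print(activations_mask)                              # tensor([[False, False, False,  True,  True]])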
def get_lengths_from_binary_sequence_mask(mask: torch.BoolTensor) -> torch.LongTensor:
    return mask.sum(-1)