from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F

# Note: get_range_vector / get_device_of (used further below) are assumed to
# be in scope, e.g. the helpers of the same name in AllenNLP's nn.util.


def reduce_output(
        self, tensor: torch.Tensor, mask: torch.BoolTensor
) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]:
    """
    Reduce transformer output at end of forward pass.

    :param tensor: encoded input
    :param mask: mask for encoded input
    :return (tensor, mask): returns the reduced tensor, and mask if appropriate
    """
    tensor *= self.output_scaling
    if self.reduction_type == 'first':
        return tensor[:, 0, :], None
    elif self.reduction_type == 'max':
        return tensor.max(dim=1)[0], None
    elif self.reduction_type == 'mean':
        # clamp(min=1) guards against division by zero on all-padding rows
        divisor = mask.float().sum(dim=1).unsqueeze(-1).clamp(min=1).type_as(tensor)
        output = tensor.sum(dim=1) / divisor
        return output, None
    elif self.reduction_type is None or 'none' in self.reduction_type:
        return tensor, mask
    else:
        raise ValueError(
            "Can't handle --reduction-type {}".format(self.reduction_type)
        )
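
# Hedged usage sketch for reduce_output: it expects a host object exposing
# reduction_type and output_scaling, so a SimpleNamespace stands in for
# `self` here (hypothetical values, not the original module's defaults).
from types import SimpleNamespace

_host = SimpleNamespace(reduction_type='mean', output_scaling=1.0)
_tensor = torch.randn(2, 4, 8)                      # (batch, seq_len, dim)
_mask = torch.tensor([[True, True, True, False],
                      [True, True, False, False]])  # True on real tokens
_reduced, _out_mask = reduce_output(_host, _tensor, _mask)
assert _reduced.shape == (2, 8)  # sum over time / count of unmasked positions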

def forward(self, scores: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
    """Map a score vector to the uniform probability distribution.

    Args:
        scores (torch.Tensor): (Batch x Sequence Length)
            Attention scores (also referred to as weights)
        mask (torch.BoolTensor): (Batch x Sequence Length)
            Specifies which indices are just padding

    Returns:
        torch.Tensor: the uniform distribution
    """
    lengths = mask.float().sum(dim=-1, keepdim=True)
    scores = 1.0 / lengths
    uniform = mask.float() * scores
    return uniform
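
# Worked example of the uniform distribution above: three real tokens and one
# pad position yield [1/3, 1/3, 1/3, 0]; the incoming scores are ignored.
_mask = torch.tensor([[True, True, True, False]])
_lengths = _mask.float().sum(dim=-1, keepdim=True)   # tensor([[3.]])
_uniform = _mask.float() / _lengths
assert torch.allclose(_uniform, torch.tensor([[1 / 3, 1 / 3, 1 / 3, 0.0]]))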

def sequence_cross_entropy_with_logits(
        logits: torch.FloatTensor,
        targets: torch.LongTensor,
        mask: torch.BoolTensor,
        label_smoothing: float = 0.0,
        reduce: str = "mean") -> torch.FloatTensor:
    """
    Compute token-level cross entropy over a padded sequence.

    label_smoothing : ``float``, optional (default = 0.0)
        It should be smaller than 1.
    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = F.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.view(-1, 1).long()

    if label_smoothing > 0.0:
        num_classes = logits.size(-1)
        smoothing_value = label_smoothing / float(num_classes)
        # Fill all the correct indices with 1 - smoothing value.
        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(
            -1, targets_flat, 1.0 - label_smoothing)
        smoothed_targets = one_hot_targets + smoothing_value
        negative_log_likelihood_flat = -log_probs_flat * smoothed_targets
        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(
            -1, keepdim=True)
    else:
        # shape : (batch * sequence_length, 1)
        negative_log_likelihood_flat = -torch.gather(
            log_probs_flat, dim=1, index=targets_flat)

    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(
        -1, logits.shape[1])
    mask = mask.float()
    # shape : (batch, sequence_length)
    loss = negative_log_likelihood * mask

    if reduce == "mean":
        loss = loss.sum() / (mask.sum() + 1e-13)
    elif reduce == "batch":
        # shape : (batch,)
        loss = loss.sum(1) / (mask.sum(1) + 1e-13)
    elif reduce == "batch-sequence":
        # we favor longer sequences, so we don't divide by the total sequence length here
        # shape : (batch,)
        loss = loss.sum(1)
    return loss
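
# Hedged sanity check: with no label smoothing and a full mask, the "mean"
# reduction should match plain F.cross_entropy over the flattened tokens.
_logits = torch.randn(2, 5, 7)
_targets = torch.randint(0, 7, (2, 5))
_mask = torch.ones(2, 5, dtype=torch.bool)
_loss = sequence_cross_entropy_with_logits(_logits, _targets, _mask, 0.0)
_reference = F.cross_entropy(_logits.view(-1, 7), _targets.view(-1))
assert torch.allclose(_loss, _reference, atol=1e-6)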

def _compute_score(
        self, emissions: torch.Tensor, tags: torch.LongTensor,
        mask: torch.BoolTensor) -> torch.Tensor:
    # emissions: (seq_length, batch_size, num_tags)
    # tags: (seq_length, batch_size)
    # mask: (seq_length, batch_size)
    assert emissions.dim() == 3 and tags.dim() == 2
    assert emissions.shape[:2] == tags.shape
    assert emissions.size(2) == self.num_tags
    assert mask.shape == tags.shape
    assert mask[0].all()

    seq_length, batch_size = tags.shape
    mask = mask.float()

    # Start transition score and first emission
    # shape: (batch_size,)
    score = self.start_transitions[tags[0]]
    score += emissions[0, torch.arange(batch_size), tags[0]]

    for i in range(1, seq_length):
        # Transition score to next tag, only added if next timestep is valid (mask == 1)
        # shape: (batch_size,)
        score += self.transitions[tags[i - 1], tags[i]] * mask[i]
        # Emission score for next tag, only added if next timestep is valid (mask == 1)
        # shape: (batch_size,)
        score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

    # End transition score
    # shape: (batch_size,)
    seq_ends = mask.long().sum(dim=0) - 1
    # shape: (batch_size,)
    last_tags = tags[seq_ends, torch.arange(batch_size)]
    # shape: (batch_size,)
    score += self.end_transitions[last_tags]

    return score
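
# Hedged sketch of _compute_score on a tiny 2-tag CRF; a SimpleNamespace
# stands in for the CRF module (in practice these are nn.Parameter fields).
from types import SimpleNamespace

_num_tags = 2
_crf = SimpleNamespace(
    num_tags=_num_tags,
    start_transitions=torch.zeros(_num_tags),
    end_transitions=torch.zeros(_num_tags),
    transitions=torch.zeros(_num_tags, _num_tags),
)
_emissions = torch.randn(3, 1, _num_tags)        # (seq_length, batch, num_tags)
_tags = torch.tensor([[0], [1], [1]])            # (seq_length, batch)
_mask = torch.tensor([[True], [True], [False]])  # last step is padding
_score = _compute_score(_crf, _emissions, _tags, _mask)
# with all-zero transitions, only the two unmasked emissions contribute
assert torch.allclose(_score, _emissions[0, 0, 0] + _emissions[1, 0, 1])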

def softmax_mask(w: torch.Tensor, dim: int = -1, mask: torch.BoolTensor = None) -> torch.Tensor:
    """
    Masked softmax. Positions can be masked out either by setting them to
    -np.inf in ``w`` or by passing an explicit boolean mask.

    :param w: scores to normalize
    :param dim: dimension to normalize over
    :param mask: True for positions to keep; if None, inferred as ``w != -np.inf``
    :return: probabilities, with masked positions set to exactly 0
    """
    if mask is None:
        mask = w != -np.inf
    # Fill masked-out entries with the smallest masked-out value so they do
    # not affect torch.max below (fall back to w.min() if nothing is masked).
    minval = torch.min(w[~mask]) if (~mask).any() else w.min()
    w1 = w.clone()
    w1[~mask] = minval
    # subtract the max to prevent over/underflow
    w1 = w1 - torch.max(w1, dim=dim, keepdim=True)[0]
    w1 = torch.exp(w1)
    p = w1 / torch.sum(w1 * mask.float(), dim=dim, keepdim=True)
    p[~mask] = 0.
    return p
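
# Hedged check: with an explicit mask, the kept entries renormalize to the
# ordinary softmax over just those entries, and masked entries are exactly 0.
_w = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
_keep = torch.tensor([[True, True, True, False]])
_p = softmax_mask(_w, mask=_keep)
assert torch.allclose(_p[:, :3], F.softmax(_w[:, :3], dim=-1))
assert _p[0, 3] == 0.0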

def forward(  # type: ignore
    self,
    token_ids: torch.LongTensor,
    type_ids: torch.LongTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    pos_tags: torch.LongTensor,
    word_mask: torch.BoolTensor,
    parent_mask: torch.BoolTensor,
    parent_start_mask: torch.BoolTensor,
    parent_end_mask: torch.BoolTensor,
    child_mask: torch.BoolTensor = None,
    parent_idxs: torch.LongTensor = None,
    parent_tags: torch.LongTensor = None,
    parent_starts: torch.BoolTensor = None,
    parent_ends: torch.BoolTensor = None,
    child_idxs: torch.BoolTensor = None,
    child_starts: torch.BoolTensor = None,
    child_ends: torch.BoolTensor = None,
):
    """
    Args:
        token_ids: [batch_size, num_word_pieces]
        type_ids: [batch_size, num_word_pieces]
        offsets: [batch_size, num_words, 2]
        wordpiece_mask: [batch_size, num_word_pieces]
        pos_tags: [batch_size, num_words]
        word_mask: [batch_size, num_words]
        parent_mask: [batch_size, num_words]
        parent_start_mask: [batch_size, num_words]
        parent_end_mask: [batch_size, num_words]
        child_mask: [batch_size, num_words]
        parent_idxs: [batch_size]
        parent_tags: [batch_size]
        parent_starts: [batch_size]
        parent_ends: [batch_size]
        child_idxs: [batch_size, num_words]
        child_starts: [batch_size, num_words]
        child_ends: [batch_size, num_words]

    Returns:
        parent_probs: [batch_size, num_words]
        parent_tag_probs: [batch_size, num_words, num_tags]
        parent_start_probs: [batch_size, num_words]
        parent_end_probs: [batch_size, num_words]
        child_probs: [batch_size, num_words]
        child_start_probs: [batch_size, num_words]
        child_end_probs: [batch_size, num_words]
        arc_loss (if parent_idxs is not None)
        tag_loss (if parent_idxs and parent_tags are not None)
        start_loss (if parent_starts is not None)
        end_loss (if parent_ends is not None)
        child_loss (if child_idxs is not None)
        child_start_loss (if child_starts is not None)
        child_end_loss (if child_ends is not None)
    """
    cls_embedding, embedded_text_input = self.get_word_embedding(
        token_ids=token_ids,
        offsets=offsets,
        wordpiece_mask=wordpiece_mask,
        type_ids=type_ids,
    )
    if self.pos_embedding is not None:
        embedded_pos_tags = self.pos_embedding(pos_tags)
        embedded_text_input = torch.cat(
            [embedded_text_input, embedded_pos_tags], -1)
        if self.fuse_layer is not None:
            embedded_text_input = self.fuse_layer(embedded_text_input)
    # todo compare normal dropout with InputVariationalDropout
    embedded_text_input = self._dropout(embedded_text_input)

    if self.additional_encoder is not None:
        if self.config.additional_layer_type == "transformer":
            # bert = self.bert if self.arch == "bert" else self.roberta
            extended_attention_mask = self.bert.get_extended_attention_mask(
                word_mask, word_mask.size(), word_mask.device)
            encoded_text = self.additional_encoder(
                hidden_states=embedded_text_input,
                attention_mask=extended_attention_mask)[0]
        else:
            encoded_text = self.additional_encoder(
                inputs=embedded_text_input, mask=word_mask)
    else:
        encoded_text = embedded_text_input

    batch_size, seq_len, encoding_dim = encoded_text.size()

    # shape (batch_size, sequence_length, tag_classes)
    parent_tag_scores = self.parent_tag_feedforward(encoded_text)
    # shape (batch_size, sequence_length)
    parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)
    parent_start_scores = self.parent_start_feedforward(encoded_text).squeeze(-1)
    parent_end_scores = self.parent_end_feedforward(encoded_text).squeeze(-1)

    # mask out impossible positions
    minus_inf = -1e8
    parent_mask = torch.logical_and(parent_mask, word_mask)
    parent_scores = parent_scores + (~parent_mask).float() * minus_inf
    parent_start_mask = torch.logical_and(parent_start_mask, word_mask)
    parent_start_scores = parent_start_scores + (
        ~parent_start_mask).float() * minus_inf
    parent_end_mask = torch.logical_and(parent_end_mask, word_mask)
    parent_end_scores = parent_end_scores + (
        ~parent_end_mask).float() * minus_inf

    parent_probs = F.softmax(parent_scores, dim=-1)
    parent_start_probs = F.softmax(parent_start_scores, dim=-1)
    parent_end_probs = F.softmax(parent_end_scores, dim=-1)
    parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)

    output = (parent_probs, parent_tag_probs, parent_start_probs, parent_end_probs)

    if self.config.predict_child:
        child_scores = self.child_feedforward(encoded_text).squeeze(-1)
        child_start_scores = self.child_start_feedforward(encoded_text).squeeze(-1)
        child_end_scores = self.child_end_feedforward(encoded_text).squeeze(-1)
        # todo add child mask - child should be inside the origin span
        if child_mask is None:
            child_mask = torch.ones_like(word_mask)
        else:
            child_mask = torch.logical_and(child_mask, word_mask)
        child_scores = child_scores + (~child_mask).float() * minus_inf
        child_start_scores = child_start_scores + (
            ~child_mask).float() * minus_inf
        child_end_scores = child_end_scores + (
            ~child_mask).float() * minus_inf
        child_probs = torch.sigmoid(child_scores)
        child_start_probs = torch.sigmoid(child_start_scores)
        child_end_probs = torch.sigmoid(child_end_scores)
        output = output + (child_probs, child_start_probs, child_end_probs)

    # add losses
    batch_range_vector = get_range_vector(
        batch_size, get_device_of(encoded_text))  # [bsz]
    if parent_idxs is not None:
        # [bsz, seq_len]
        parent_logits = F.log_softmax(parent_scores, dim=-1)
        parent_arc_nll = -parent_logits[batch_range_vector, parent_idxs]
        parent_arc_nll = parent_arc_nll.mean()
        output = output + (parent_arc_nll, )

        if parent_tags is not None:
            parent_tag_nll = F.cross_entropy(
                parent_tag_scores[batch_range_vector, parent_idxs],
                parent_tags)
            output = output + (parent_tag_nll, )

    if parent_starts is not None:
        # [bsz, seq_len]
        parent_start_logits = F.log_softmax(parent_start_scores, dim=-1)
        parent_start_nll = -parent_start_logits[batch_range_vector,
                                                parent_starts].mean()
        output = output + (parent_start_nll, )

    if parent_ends is not None:
        # [bsz, seq_len]
        parent_end_logits = F.log_softmax(parent_end_scores, dim=-1)
        parent_end_nll = -parent_end_logits[batch_range_vector,
                                            parent_ends].mean()
        output = output + (parent_end_nll, )

    if self.config.predict_child:
        if child_idxs is not None:
            child_loss = F.binary_cross_entropy_with_logits(
                child_scores, child_idxs.float(), reduction="none")
            child_loss = (child_loss * child_mask).sum() / (child_mask.sum() + 1e-8)
            output = output + (child_loss, )
        if child_starts is not None:
            child_start_loss = F.binary_cross_entropy_with_logits(
                child_start_scores, child_starts.float(), reduction="none")
            child_start_loss = (child_start_loss * child_mask).sum() / (
                child_mask.sum() + 1e-8)
            output = output + (child_start_loss, )
        if child_ends is not None:
            child_end_loss = F.binary_cross_entropy_with_logits(
                child_end_scores, child_ends.float(), reduction="none")
            child_end_loss = (child_end_loss * child_mask).sum() / (
                child_mask.sum() + 1e-8)
            output = output + (child_end_loss, )

    return output
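
# The additive-mask idiom used in the forward pass above, in isolation:
# adding -1e8 to disallowed positions drives their softmax mass to zero.
_scores = torch.randn(2, 5)
_allowed = torch.tensor([[True, True, False, False, True],
                         [True, False, True, True, False]])
_masked_scores = _scores + (~_allowed).float() * -1e8
_probs = F.softmax(_masked_scores, dim=-1)
assert torch.allclose(_probs[~_allowed], torch.zeros_like(_probs[~_allowed]))
assert torch.allclose(_probs.sum(dim=-1), torch.ones(2))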

def forward(  # type: ignore
    self,
    token_ids: torch.LongTensor,
    type_ids: torch.LongTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    span_idx: torch.LongTensor,
    span_tag: torch.LongTensor,
    child_arcs: torch.LongTensor,
    child_tags: torch.LongTensor,
    pos_tags: torch.LongTensor,
    word_mask: torch.BoolTensor,
    mrc_mask: torch.BoolTensor,
):
    """
    Args:
        token_ids: [batch_size, num_word_pieces]
        type_ids: [batch_size, num_word_pieces]
        offsets: [batch_size, num_words, 2]
        wordpiece_mask: [batch_size, num_word_pieces]
        span_idx: [batch_size, 2]
        span_tag: [batch_size, 1]
        child_arcs: [batch_size, num_words]
        child_tags: [batch_size, num_words]
        pos_tags: [batch_size, num_words]
        word_mask: [batch_size, num_words]
        mrc_mask: [batch_size, num_words]

    Returns:
        parent_probs: [batch_size, num_words]
        parent_tag_probs: [batch_size, num_words, tag_classes]
        child_probs: [batch_size, num_words]
        child_tag_probs: [batch_size, num_words, tag_classes]
        parent_arc_nll: [1]
        parent_tag_nll: [1]
        child_arc_loss: [1]
        child_tag_loss: [1]
    """
    embedded_text_input = self.get_word_embedding(
        token_ids=token_ids,
        offsets=offsets,
        wordpiece_mask=wordpiece_mask,
        type_ids=type_ids,
    )
    if self.pos_embedding is not None:
        embedded_pos_tags = self.pos_embedding(pos_tags)
        embedded_text_input = torch.cat(
            [embedded_text_input, embedded_pos_tags], -1)
        if self.fuse_layer is not None:
            embedded_text_input = self.fuse_layer(embedded_text_input)
    # todo compare normal dropout with InputVariationalDropout
    embedded_text_input = self._dropout(embedded_text_input)

    if self.additional_encoder is not None:
        if self.config.additional_layer_type == "transformer":
            extended_attention_mask = self.bert.get_extended_attention_mask(
                word_mask, word_mask.size(), word_mask.device)
            encoded_text = self.additional_encoder(
                hidden_states=embedded_text_input,
                attention_mask=extended_attention_mask)[0]
        else:
            encoded_text = self.additional_encoder(
                inputs=embedded_text_input, mask=word_mask)
    else:
        encoded_text = embedded_text_input

    batch_size, seq_len, encoding_dim = encoded_text.size()

    # shape (batch_size, sequence_length, tag_classes)
    parent_tag_scores = self.parent_tag_feedforward(encoded_text)
    # shape (batch_size, sequence_length)
    parent_scores = self.parent_feedforward(encoded_text).squeeze(-1)
    # [bsz, seq_len, tag_classes]
    child_tag_scores = self.child_tag_feedforward(encoded_text)
    # [bsz, seq_len]
    child_scores = self.child_feedforward(encoded_text).squeeze(-1)

    # todo support cases that span_idx and span_tag are None
    # [bsz]
    batch_range_vector = get_range_vector(batch_size, get_device_of(encoded_text))
    # [bsz]
    gold_positions = span_idx[:, 0]

    # compute parent arc loss
    minus_inf = -1e8
    mrc_mask = torch.logical_and(mrc_mask, word_mask)
    parent_scores = parent_scores + (~mrc_mask).float() * minus_inf
    child_scores = child_scores + (~mrc_mask).float() * minus_inf

    # [bsz, seq_len]
    parent_logits = F.log_softmax(parent_scores, dim=-1)
    parent_arc_nll = -parent_logits[batch_range_vector, gold_positions].mean()

    # compute parent tag loss
    parent_tag_nll = F.cross_entropy(
        parent_tag_scores[batch_range_vector, gold_positions], span_tag)

    parent_probs = F.softmax(parent_scores, dim=-1)
    parent_tag_probs = F.softmax(parent_tag_scores, dim=-1)
    child_probs = torch.sigmoid(child_scores)
    child_tag_probs = F.softmax(child_tag_scores, dim=-1)

    child_arc_loss = F.binary_cross_entropy_with_logits(
        child_scores, child_arcs.float(), reduction="none")
    child_arc_loss = (child_arc_loss * mrc_mask.float()).sum() / mrc_mask.float().sum()
    child_tag_loss = F.cross_entropy(
        child_tag_scores.view(batch_size * seq_len, -1),
        child_tags.view(-1),
        reduction="none")
    child_tag_loss = (child_tag_loss * child_arcs.float().view(-1)).sum() / (
        child_arcs.float().sum() + 1e-8)

    return (parent_probs, parent_tag_probs, child_probs, child_tag_probs,
            parent_arc_nll, parent_tag_nll, child_arc_loss, child_tag_loss)
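
# The masked BCE reduction used for child_arc_loss above, in isolation:
# per-element losses are zeroed at padded positions, then averaged over the
# number of real positions only (shapes here are illustrative).
_logits = torch.randn(2, 5)
_labels = torch.randint(0, 2, (2, 5)).float()
_mrc_mask = torch.tensor([[True, True, True, False, False],
                          [True, True, True, True, False]])
_per_elem = F.binary_cross_entropy_with_logits(_logits, _labels, reduction="none")
_loss = (_per_elem * _mrc_mask.float()).sum() / _mrc_mask.float().sum()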

def masked_cross_entropy(pred: torch.Tensor,
                         true: torch.Tensor,
                         mask: torch.BoolTensor) -> torch.Tensor:
    # Add log(1) = 0 to the logits of kept rows and a large negative constant
    # to masked rows; the shift is uniform across classes, so per-row losses
    # are unchanged, and the trailing multiplication zeroes the masked rows.
    pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log()
    return F.cross_entropy(pred, true, reduction="none") * mask
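
# Hedged check on masked_cross_entropy (assuming pred is (batch, num_classes)
# and mask is (batch,)): the masked row contributes exactly 0, and unmasked
# rows match plain F.cross_entropy, since log(1 + 1e-45) is numerically 0.
_pred = torch.randn(3, 4)
_true = torch.tensor([0, 2, 1])
_mask = torch.tensor([True, True, False])
_loss = masked_cross_entropy(_pred, _true, _mask)
assert _loss[2] == 0.0
assert torch.allclose(_loss[:2], F.cross_entropy(_pred[:2], _true[:2], reduction="none"))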