def extract_features(
    self, tokens: torch.LongTensor, return_all_hiddens: bool = False
) -> torch.Tensor:
    """Run the encoder-decoder over ``tokens`` and return decoder features.

    The decoder input is ``tokens`` rotated right by one position. Slot 0
    receives each row's last non-pad token; this assumes right padding,
    i.e. no pad tokens in the middle of a row.

    Args:
        tokens: a 1-D sequence (treated as a batch of one) or a 2-D
            ``(batch, seq_len)`` tensor of token ids.
        return_all_hiddens: when True, return every layer's hidden states
            instead of only the final features.

    Returns:
        The final-layer features. When ``return_all_hiddens`` is True,
        returns a list with one tensor per entry of
        ``extra["inner_states"]``, each transposed with ``transpose(0, 1)``
        (T x B x C -> B x T x C according to the original comment).

    Raises:
        ValueError: if the sequence is longer than the smallest entry of
            ``self.model.max_positions()``.
    """
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)
    if tokens.size(-1) > min(self.model.max_positions()):
        raise ValueError(
            "tokens exceeds maximum length: {} > {}".format(
                tokens.size(-1), self.model.max_positions()
            )
        )
    # ``Tensor.to`` returns a new tensor rather than moving in place; the
    # result must be rebound or the model sees tokens on the wrong device.
    tokens = tokens.to(device=self.device)

    prev_output_tokens = tokens.clone()
    # Index of the last non-pad token in each row (assumes right padding).
    last_index = tokens.ne(self.task.source_dictionary.pad()).sum(dim=1) - 1
    # squeeze(-1) (not squeeze()) so a batch of one keeps its batch dim.
    prev_output_tokens[:, 0] = tokens.gather(1, last_index.unsqueeze(-1)).squeeze(-1)
    prev_output_tokens[:, 1:] = tokens[:, :-1]

    features, extra = self.model(
        src_tokens=tokens,
        src_lengths=None,
        prev_output_tokens=prev_output_tokens,
        features_only=True,
        return_all_hiddens=return_all_hiddens,
    )
    if return_all_hiddens:
        # convert from T x B x C -> B x T x C
        inner_states = extra["inner_states"]
        return [inner_state.transpose(0, 1) for inner_state in inner_states]
    return features  # just the last layer's features
def forward(self, emissions: FloatTensor, labels: LongTensor, mask: Optional[BoolTensor] = None) -> FloatTensor:
    """Mean negative log-likelihood of ``labels`` under the CRF.

    The score of the gold label path is compared with the log-partition
    function, which is computed with the forward algorithm.

    Args:
        emissions: per-timestep tag scores, (seq_length, batch_size, num_tags),
            or (batch_size, seq_length, num_tags) when ``self.batch_first``.
        labels: gold tag indices, (seq_length, batch_size), or batch-first.
        mask: optional boolean mask laid out like ``labels``; defaults to
            all True. NOTE: the recursion always seeds from ``emissions[0]``,
            so the first timestep of every sequence is expected to be valid.

    Returns:
        Scalar tensor: the summed negative log-likelihood divided by the
        number of unmasked positions.
    """
    if mask is None:
        mask = torch.ones_like(labels, dtype=torch.bool)
    if self.batch_first:
        # Work time-major from here on.
        emissions, labels, mask = (t.transpose(0, 1) for t in (emissions, labels, mask))

    # ---- score of the gold path -------------------------------------------
    gold = self.starts[labels[0]]
    picked = emissions.gather(2, labels.unsqueeze(-1)).squeeze(-1)
    gold += (picked * mask).sum(dim=0)
    gold += (self.transitions[labels[:-1], labels[1:]] * mask[1:]).sum(dim=0)
    final_pos = mask.long().sum(dim=0) - 1
    final_labels = labels.gather(0, final_pos.unsqueeze(0)).squeeze(0)
    gold += self.ends[final_labels]

    # ---- log-partition via the forward algorithm ---------------------------
    alpha = self.starts + emissions[0]
    for step in range(1, labels.shape[0]):
        # (batch, prev_tag, next_tag) candidate scores, reduced over prev_tag.
        candidates = alpha.unsqueeze(2) + self.transitions + emissions[step].unsqueeze(1)
        # Padded steps keep the previous alpha unchanged.
        alpha = torch.where(mask[step].unsqueeze(1), candidates.logsumexp(dim=1), alpha)
    log_partition = (alpha + self.ends).logsumexp(dim=1)

    return -(gold - log_partition).sum() / mask.sum()
def _compute_numerator_log_likelihood( self, h: torch.FloatTensor, y: torch.LongTensor, mask: torch.FloatTensor) -> torch.FloatTensor: """ compute the numerator term for the log-likelihood :param h: hidden matrix (batch_size, seq_len, num_labels) :param y: answer labels of each sequence in mini batch (batch_size, seq_len) :param mask: mask tensor of each sequence in mini batch (batch_size, seq_len) :return: The score of numerator term for the log-likelihood """ batch_size, seq_len, _ = h.size() # 系列のスタート位置のベクトルを抽出 # extract first vector of sequences in mini batch score = self.start_trans[y[:, 0]] h = h.unsqueeze(-1) trans = self.trans_matrix.unsqueeze(-1) for t in range(seq_len - 1): mask_t = mask[:, t] mask_t1 = mask[:, t + 1] # t+1番目のラベルのスコアを抽出 # extract the score of t+1 label # (batch_size) h_t = torch.cat([h[b, t, y[b, t]] for b in range(batch_size)]) # t番目のラベルからt+1番目のラベルへの遷移スコアを抽出 # extract the transition score from t-th label to t+1 label # (batch_size) trans_t = torch.cat([trans[s[t], s[t + 1]] for s in y]) # 足し合わせる # add the score of t+1 and the transition score # (batch_size) score += h_t * mask_t + trans_t * mask_t1 # バッチ内の各系列の最後尾のラベル番号を抽出する # extract end label number of each sequence in mini batch # (batch_size) last_mask_index = mask.long().sum(1) - 1 last_labels = y.gather(1, last_mask_index.unsqueeze(-1)) # hの形を元に戻す # restore the shape of h h = h.unsqueeze(-1).view(batch_size, seq_len, self.num_labels) # バッチ内の最大長の系列のスコアを足し合わせる # Add the score of the sequences of the maximum length in mini batch score += h[:, -1].gather(1, last_labels).squeeze(1) * mask[:, -1] # 各系列の最後尾のタグからEOSまでのスコアを足し合わせる # Add the scores from the last tag of each sequence to EOS score += self.end_trans[last_labels].view(batch_size) return score
def _numerator_score(self, emissions: torch.Tensor, tags: torch.LongTensor, mask: torch.ByteTensor) -> torch.Tensor: """ Parameters: emissions: (batch_size, sequence_length, num_tags) tags: (batch_size, sequence_length) mask: Show padding tags. 0 don't calculate score. (batch_size, sequence_length) Returns: scores: (batch_size) """ batch_size, sequence_length, _ = emissions.data.shape emissions = emissions.transpose(0, 1).contiguous() tags = tags.transpose(0, 1).contiguous() mask = mask.float().transpose(0, 1).contiguous() # Start transition score and first emission score = self.start_transitions.index_select(0, tags[0]) for i in range(sequence_length - 1): current_tag, next_tag = tags[i], tags[i + 1] # Emissions score for next tag emissions_score = emissions[i].gather( 1, current_tag.view(batch_size, 1)).squeeze(1) # Transition score from current_tag to next_tag transition_score = self.transitions[current_tag.view(-1), next_tag.view(-1)] # Add all score score += transition_score * mask[i + 1] + emissions_score * mask[i] # Add end transition score last_tag_index = mask.sum(0).long() - 1 last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0) # Compute score of transitioning to STOP_TAG from each LAST_TAG last_transition_score = self.end_transitions.index_select(0, last_tags) last_inputs = emissions[-1] # (batch_size, num_tags) last_input_score = last_inputs.gather(1, last_tags.view( -1, 1)) # (batch_size, 1) last_input_score = last_input_score.squeeze() # (batch_size,) score = score + last_transition_score + last_input_score * mask[-1] return score
def _compute_numerator_log_likelihood(self, h: FloatTensor, y: LongTensor, mask: BoolTensor) -> FloatTensor: """ compute the numerator term for the log-likelihood :param h: hidden matrix (batch_size, seq_len, num_labels) :param y: answer labels of each sequence in mini batch (batch_size, seq_len) :param mask: mask tensor of each sequence in mini batch (batch_size, seq_len) :return: The score of numerator term for the log-likelihood """ batch_size, seq_len, _ = h.size() # extract first vector of sequences in mini batch score = self.start_trans[y[:, 0]] h = h.unsqueeze(-1) trans = self.trans_matrix.unsqueeze(-1) for t in range(seq_len - 1): mask_t = mask[:, t].cuda() if CRF.CUDA else mask[:, t] mask_t1 = mask[:, t + 1] if CRF.CUDA else mask[:, t + 1] # extract the score of t+1 label # (batch_size) h_t = torch.cat([h[b, t, y[b, t]] for b in range(batch_size)]) # extract the transition score from t-th label to t+1 label # (batch_size) trans_t = torch.cat([trans[s[t], s[t + 1]] for s in y]) # add the score of t+1 and the transition score # (batch_size) score += h_t * mask_t + trans_t * mask_t1 # extract end label number of each sequence in mini batch # (batch_size) last_mask_index = mask.long().sum(1) - 1 last_labels = y.gather(1, last_mask_index.unsqueeze(-1)) # restore the shape of h h = h.unsqueeze(-1).view(batch_size, seq_len, self.num_labels) # Add the score of the sequences of the maximum length in mini batch score += h[:, -1].gather(1, last_labels).squeeze(1) * mask[:, -1] # Add the scores from the last tag of each sequence to EOS score += self.end_trans[last_labels].view(batch_size) return score
def _compute_joint_llh(self, emissions: torch.Tensor, tags: torch.LongTensor, mask: torch.ByteTensor) -> torch.Tensor: # emissions: (seq_length, batch_size, num_tags) # tags: (seq_length, batch_size) # mask: (seq_length, batch_size) assert emissions.dim() == 3 and tags.dim() == 2 assert emissions.size()[:2] == tags.size() assert emissions.size(2) == self.num_tags assert mask.size() == tags.size() assert all(mask[0]) seq_length = emissions.size(0) mask = mask.float() # Start transition score llh = self.start_transitions[tags[0]] # (batch_size,) for i in range(seq_length - 1): cur_tag, next_tag = tags[i], tags[i + 1] # Emission score for current tag llh += emissions[i].gather(1, cur_tag.view(-1, 1)).squeeze(1) * mask[i] # Transition score to next tag transition_score = self.transitions[cur_tag, next_tag] # Only add transition score if the next tag is not masked (mask == 1) llh += transition_score * mask[i + 1] # Find last tag index last_tag_indices = mask.long().sum(0) - 1 # (batch_size,) last_tags = tags.gather(0, last_tag_indices.view(1, -1)).squeeze(0) # End transition score llh += self.end_transitions[last_tags] # Emission score for the last tag, if mask is valid (mask == 1) llh += emissions[-1].gather(1, last_tags.view(-1, 1)).squeeze(1) * mask[-1] return llh
def _action_to_token(self, action_tokens: torch.LongTensor, draft_tokens: torch.LongTensor) -> torch.LongTensor: predicted_pointer = action_tokens.new_zeros((draft_tokens.size(0), 1)) draft_pointer = draft_tokens.new_ones((draft_tokens.size(0), 1)) predicted_tokens = action_tokens.new_full((action_tokens.size()), self.END) for act_step in action_tokens.t(): # KEEP, DELETE, COPY, ADD (other) keep_mask = act_step == self.KEEP drop_mask = act_step == self.DROP add_mask = ~(keep_mask | drop_mask) predicted_tokens.scatter_(1, predicted_pointer, draft_tokens.gather(1, draft_pointer)) predicted_tokens[add_mask] = predicted_tokens[add_mask].scatter( 1, predicted_pointer[add_mask], act_step[add_mask].unsqueeze(1)) draft_pointer[keep_mask | drop_mask] += 1 predicted_pointer[~drop_mask] += 1 return predicted_tokens
def _compute_joint_llh( self, emissions: torch.FloatTensor, tags: torch.LongTensor, mask: torch.FloatTensor, ) -> torch.Tensor: seq_len = emissions.shape[1] # Log-likelihood for a given input is calculated by using the known # correct tag for each timestep and its respective emission value. # Since actual tags for each time step is also known, sum of transition # probabilities is also calculated. # Sum of emission and transition probabilities gives the final score for # the input. llh = self.transitions[self.start_tag, tags[:, 0]].unsqueeze(1) llh += emissions[:, 0, :].gather(1, tags[:, 0].view( -1, 1)) * mask[:, 0].unsqueeze(1) for idx in range(1, seq_len): old_state, new_state = ( tags[:, idx - 1].view(-1, 1), tags[:, idx].view(-1, 1), ) emission_scores = emissions[:, idx, :].gather(1, new_state) transition_scores = self.transitions[old_state, new_state] llh += (emission_scores + transition_scores) * mask[:, idx].unsqueeze(1) # Index of the last tag is calculated by taking the sum of mask matrix # for each input row and subtracting 1 from the sum. last_tag_indices = mask.sum(1, dtype=torch.long) - 1 last_tags = tags.gather(1, last_tag_indices.view(-1, 1)) llh += self.transitions[last_tags.squeeze(1), self.end_tag].unsqueeze(1) return llh.squeeze(1)