def _make_mask_from_seq_lens(self, seq_lens):
    """Build a float mask marking the valid (non-padded) positions of each row.

    Args:
        seq_lens: tensor holding one length per row; it is flattened into a
            column, so any shape with one entry per row is accepted.

    Returns:
        Float tensor of shape ``(num_rows, max(seq_lens))`` whose entry
        ``[i, j]`` is 1.0 when ``j < seq_lens[i]`` and 0.0 otherwise.
    """
    lengths = seq_lens.view(-1, 1)
    positions = GetTensor(torch.arange(torch.max(lengths))).unsqueeze(0)
    # (1, max_len) compared against (num_rows, 1) broadcasts to
    # (num_rows, max_len), so no explicit expand is needed.
    valid = positions < lengths
    return valid.float()
def _shift_target(in_sequences, seq_lens, eos_idx, pad_idx):
    """Shift each target row right by one step, prepending ``eos_idx``.

    Row ``i`` of the result is
    ``[eos_idx, in_sequences[i, 0], ..., in_sequences[i, seq_lens[i] - 2],
    pad_idx, ...]``. The last valid token of the row is dropped, and every
    position ``>= seq_lens[i]`` holds ``pad_idx``. Position 0 is always
    ``eos_idx``, even for a row of length 0. Lengths larger than the row width
    are effectively clamped to the width.

    Args:
        in_sequences: 2-D tensor of token indices, indexed ``[row, position]``.
        seq_lens: tensor with one length per row of ``in_sequences``.
        eos_idx: index written at position 0 of every row.
        pad_idx: index written at every position past a row's length.

    Returns:
        A LongTensor with the same shape as ``in_sequences``, passed through
        ``GetTensor``.
    """
    shifted_sequence = GetTensor(
        torch.full(in_sequences.size(), pad_idx, dtype=torch.long))
    max_len = shifted_sequence.size(1)
    if max_len == 0:
        # No positions at all: there is nothing to shift or mark with EOS.
        return shifted_sequence

    device = shifted_sequence.device
    # Every position j >= 1 takes the token that was at position j - 1.
    # Tensor assignment copies across dtype/device, matching the original
    # per-row copy.
    shifted_sequence[:, 1:] = in_sequences[:, :-1]
    # Then restore padding beyond each row's length. A single broadcast
    # comparison replaces the per-row Python loop and its ``.item()`` host
    # syncs. It also avoids the size-mismatch error the per-row slice
    # assignment raised for zero-length rows.
    positions = torch.arange(max_len, device=device).unsqueeze(0)
    lengths = seq_lens.view(-1, 1).to(device)
    shifted_sequence.masked_fill_(positions >= lengths, pad_idx)
    shifted_sequence[:, 0] = eos_idx
    return shifted_sequence
def _viterbi_decode(self, emissions: torch.FloatTensor,
                    mask: torch.FloatTensor) -> torch.Tensor:
    """Find the highest-scoring tag sequence for every row via Viterbi decoding.

    **Forward pass.** For every tag, accumulate the best score of any path
    ending in that tag, and record a back-pointer per step. At every step,
    also store the score the path would have if that step were the last one.

    **Backward pass.** Starting from each row's last valid position, follow
    the back-pointers.

    Args:
        emissions: per-position tag scores indexed as
            ``emissions[:, position, tag]``. Broadcasting against
            ``self.transitions[:self.start_tag, :self.start_tag]`` requires
            the tag dimension to equal ``self.start_tag``.
        mask: ``mask[:, position]`` is nonzero where that position is valid.
            NOTE(review): the combination logic below appears to assume each
            row's valid positions form a prefix (right padding); confirm this
            against callers.

    Returns:
        Long tensor indexed ``[row, position]`` with the decoded tag indices.
        Positions the mask marks invalid are filled with
        ``self.ignore_index``. This assumes ``self._mask_tensor(t, m, v)``
        replaces entries of ``t`` where ``m`` is set with ``v``; verify
        against its definition.
    """
    seq_len = emissions.shape[1]
    mask = mask.to(torch.uint8)

    # These slices of the transition matrix do not change across steps, so
    # build them once. tag_to_tag broadcasts as (1, prev_tag, cur_tag).
    # tag_to_end holds the tag -> END transition scores.
    tag_to_tag = self.transitions[:self.start_tag, :self.start_tag].unsqueeze(0)
    tag_to_end = self.transitions[:self.start_tag, self.end_tag].unsqueeze(0)

    log_prob = emissions[:, 0].clone()
    log_prob += self.transitions[
        self.start_tag, :self.start_tag].unsqueeze(0)

    # At each step, we need to keep track of the total score as if this step
    # were the last valid step.
    end_scores = log_prob + tag_to_end

    best_scores_list = []
    # Back-pointers are only produced for positions 1..seq_len-1. When every
    # row has a single token, the loop never appends anything, so seeding the
    # list with an empty tensor keeps torch.cat() from failing.
    best_paths_list = [GetTensor(torch.Tensor().long())]
    best_scores_list.append(end_scores.unsqueeze(1))

    for idx in range(1, seq_len):
        broadcast_emissions = emissions[:, idx].unsqueeze(1)
        broadcast_log_prob = log_prob.unsqueeze(2)

        # score[:, prev, cur] = emission of ``cur`` at idx
        #   + transition prev -> cur + best score ending in ``prev``.
        score = broadcast_emissions + tag_to_tag + broadcast_log_prob

        # Best previous tag (and its score) for each current tag.
        max_scores, max_score_indices = torch.max(score, 1)

        best_paths_list.append(max_score_indices.unsqueeze(1))

        # Store the scores in case this was the last step.
        end_scores = max_scores + tag_to_end

        best_scores_list.append(end_scores.unsqueeze(1))
        log_prob = max_scores

    best_scores = torch.cat(best_scores_list, 1).float()
    # best_paths[:, i] holds the back-pointers computed at step i + 1.
    best_paths = torch.cat(best_paths_list, 1)

    # Best final tag at every position, if that position were the last one.
    _, max_indices_from_scores = torch.max(best_scores, 2)

    valid_index_tensor = GetTensor(torch.tensor(0)).long()
    padding_tensor = GetTensor(torch.tensor(self.ignore_index)).long()

    # The label for the last position is always based on the index with the
    # max score. For invalid timesteps, we set it to ignore_index.
    labels = max_indices_from_scores[:, seq_len - 1]
    labels = self._mask_tensor(labels, 1.0 - mask[:, seq_len - 1],
                               padding_tensor)

    all_labels = labels.unsqueeze(1).long()

    # For Viterbi decoding, we start at the last position and go towards the
    # first. Labels are appended in reverse order and flipped at the end.
    for idx in range(seq_len - 2, -1, -1):
        # There are two ways to obtain the label for a particular position.
        # Option 1: use the label already chosen for position idx + 1 to
        # follow the back-pointer into this position. This applies when
        # idx + 1 is valid.
        # Option 2: take the index with the maximum end score computed during
        # the forward pass. This applies when idx is the row's last valid
        # position.
        # For option 1, invalid (ignore_index) labels are first converted to
        # 0 so that the gather lookup does not fail.
        indices_for_lookup = all_labels[:, -1].clone()
        indices_for_lookup = self._mask_tensor(
            indices_for_lookup,
            indices_for_lookup == self.ignore_index,
            valid_index_tensor,
        )

        # Option 1 is kept only where the next position (idx + 1, visited
        # just before in this backward pass) is valid.
        indices_from_prev_pos = (best_paths[:, idx, :].gather(
            1, indices_for_lookup.view(-1, 1).long()).squeeze(1))
        indices_from_prev_pos = self._mask_tensor(
            indices_from_prev_pos, (1.0 - mask[:, idx + 1]), padding_tensor)

        # Option 2 is kept only where position idx + 1 is invalid, i.e. idx
        # is the last valid position of the row.
        indices_from_max_scores = max_indices_from_scores[:, idx]
        indices_from_max_scores = self._mask_tensor(
            indices_from_max_scores, mask[:, idx + 1], padding_tensor)

        # Combine options 1 and 2, because rows in a batch can have sequences
        # of varying lengths.
        labels = torch.where(
            indices_from_max_scores == self.ignore_index,
            indices_from_prev_pos,
            indices_from_max_scores,
        )

        # Set to ignore_index if the current position is not valid.
        labels = self._mask_tensor(labels, (1 - mask[:, idx]), padding_tensor)
        all_labels = torch.cat((all_labels, labels.view(-1, 1).long()), 1)

    return torch.flip(all_labels, [1])