Code example #1
File: crf.py  Project: whitespur/pytext
 def _make_mask_from_seq_lens(self, seq_lens):
     # Build a (batch_size, max_len) float mask: 1.0 for real tokens, 0.0 for padding.
     seq_lens = seq_lens.view(-1, 1)
     max_len = torch.max(seq_lens)
     range_tensor = GetTensor(torch.arange(max_len)).unsqueeze(0)
     range_tensor = range_tensor.expand(seq_lens.size(0), range_tensor.size(1))
     mask = (range_tensor < seq_lens).float()
     return mask
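
A minimal usage sketch of the masking logic above, rewritten as a standalone function so it runs on CPU as-is (GetTensor, PyText's device-placement helper, is dropped here):

import torch

def make_mask_from_seq_lens(seq_lens):
    # Standalone version of the helper above, assuming plain CPU tensors.
    seq_lens = seq_lens.view(-1, 1)
    max_len = int(seq_lens.max())
    range_tensor = torch.arange(max_len).unsqueeze(0)
    range_tensor = range_tensor.expand(seq_lens.size(0), range_tensor.size(1))
    return (range_tensor < seq_lens).float()

print(make_mask_from_seq_lens(torch.tensor([3, 1, 2])))
# tensor([[1., 1., 1.],
#         [1., 0., 0.],
#         [1., 1., 0.]])
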
Code example #2
 def _shift_target(in_sequences, seq_lens, eos_idx, pad_idx):
     # Shift each sequence right by one position: prepend eos_idx (used as the
     # start-of-sequence marker) and drop the trailing EOS.
     shifted_sequence = GetTensor(
         torch.LongTensor(in_sequences.size()).fill_(pad_idx))
     for i, in_seq in enumerate(in_sequences):
         shifted_sequence[i, 0] = eos_idx
         # Copy everything except the trailing EOS and any padding after it.
         shifted_sequence[i, 1:seq_lens[i].item()] = in_seq[0:seq_lens[i].item() - 1]
     return shifted_sequence
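
A small CPU sketch of how _shift_target behaves, with GetTensor dropped; the eos_idx/pad_idx values below are illustrative, not the project's actual vocabulary indices:

import torch

def shift_target(in_sequences, seq_lens, eos_idx, pad_idx):
    # CPU variant of _shift_target above.
    shifted = torch.full(in_sequences.size(), pad_idx, dtype=torch.long)
    for i, in_seq in enumerate(in_sequences):
        shifted[i, 0] = eos_idx
        shifted[i, 1:seq_lens[i].item()] = in_seq[:seq_lens[i].item() - 1]
    return shifted

sequences = torch.tensor([[5, 6, 2, 0],   # tokens ..., EOS=2, then PAD=0
                          [7, 2, 0, 0]])
print(shift_target(sequences, torch.tensor([3, 2]), eos_idx=2, pad_idx=0))
# tensor([[2, 5, 6, 0],
#         [2, 7, 0, 0]])
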
Code example #3
File: crf.py  Project: jingfeidu/pytext-1
    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.FloatTensor) -> torch.Tensor:
        seq_len = emissions.shape[1]
        mask = mask.to(torch.uint8)

        log_prob = emissions[:, 0].clone()
        log_prob += self.transitions[self.start_tag, :self.start_tag].unsqueeze(0)

        # At each step, we need to keep track of the total score as if this
        # step were the last valid one.
        end_scores = log_prob + self.transitions[
            :self.start_tag, self.end_tag].unsqueeze(0)

        best_scores_list = []
        # If a sequence has only one token, the empty tensor in best_paths
        # keeps torch.cat() below from crashing.
        best_paths_list = [GetTensor(torch.Tensor().long())]
        best_scores_list.append(end_scores.unsqueeze(1))

        for idx in range(1, seq_len):
            broadcast_emissions = emissions[:, idx].unsqueeze(1)
            broadcast_transmissions = self.transitions[
                :self.start_tag, :self.start_tag].unsqueeze(0)
            broadcast_log_prob = log_prob.unsqueeze(2)

            score = broadcast_emissions + broadcast_transmissions + broadcast_log_prob

            max_scores, max_score_indices = torch.max(score, 1)

            best_paths_list.append(max_score_indices.unsqueeze(1))

            # Store the scores in case this is the last valid step.
            end_scores = max_scores + self.transitions[
                :self.start_tag, self.end_tag].unsqueeze(0)

            best_scores_list.append(end_scores.unsqueeze(1))
            log_prob = max_scores

        best_scores = torch.cat(best_scores_list, 1).float()
        best_paths = torch.cat(best_paths_list, 1)

        _, max_indices_from_scores = torch.max(best_scores, 2)

        valid_index_tensor = GetTensor(torch.tensor(0)).long()
        padding_tensor = GetTensor(torch.tensor(self.ignore_index)).long()

        # The label for the last position is always based on the index with
        # the max score. For invalid timesteps, we set it to ignore_index.
        labels = max_indices_from_scores[:, seq_len - 1]
        labels = self._mask_tensor(labels, 1.0 - mask[:, seq_len - 1],
                                   padding_tensor)

        all_labels = labels.unsqueeze(1).long()

        # For Viterbi decoding, we start at the last position and move toward the first.
        for idx in range(seq_len - 2, -1, -1):
            # There are two ways to obtain labels for tokens at a particular position.

            # Option 1: Use the labels obtained from the previous position to
            # index the path at the present position. This is used for all
            # positions except the last one in the sequence.
            # Option 2: Find the indices with maximum scores obtained during
            # Viterbi decoding. This is used for the token at the last position.

            # For option 1 we need to convert invalid indices to 0 so that
            # lookups don't fail.
            indices_for_lookup = all_labels[:, -1].clone()
            indices_for_lookup = self._mask_tensor(
                indices_for_lookup,
                indices_for_lookup == self.ignore_index,
                valid_index_tensor,
            )

            # Option 1 is used here when the previous timestep (idx+1) was valid.
            indices_from_prev_pos = (
                best_paths[:, idx, :]
                .gather(1, indices_for_lookup.view(-1, 1).long())
                .squeeze(1)
            )
            indices_from_prev_pos = self._mask_tensor(indices_from_prev_pos,
                                                      (1.0 - mask[:, idx + 1]),
                                                      padding_tensor)

            # Option 2 is used when timestep idx+1 was not valid, which means
            # idx is the last valid position in the sequence.
            indices_from_max_scores = max_indices_from_scores[:, idx]
            indices_from_max_scores = self._mask_tensor(
                indices_from_max_scores, mask[:, idx + 1], padding_tensor)

            # We need to combine the results from options 1 and 2 because rows
            # in a batch can have sequences of varying lengths.
            labels = torch.where(
                indices_from_max_scores == self.ignore_index,
                indices_from_prev_pos,
                indices_from_max_scores,
            )

            # Set to ignore_index if the present position is not valid.
            labels = self._mask_tensor(labels, (1 - mask[:, idx]),
                                       padding_tensor)
            all_labels = torch.cat((all_labels, labels.view(-1, 1).long()), 1)

        return torch.flip(all_labels, [1])
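
For reference, here is a minimal single-sequence Viterbi decoder that follows the same recurrence as _viterbi_decode above but without batching, masking, or the start/end transition handling; it is a simplified sketch, not the project's implementation:

import torch

def viterbi_reference(emissions, transitions):
    # emissions: (seq_len, num_tags); transitions[i, j] = score of tag i -> tag j.
    seq_len, num_tags = emissions.shape
    log_prob = emissions[0].clone()  # best score of any path ending in each tag
    backpointers = []
    for idx in range(1, seq_len):
        # score[i, j] = best path ending in tag i at idx-1, then moving to tag j.
        score = log_prob.unsqueeze(1) + transitions + emissions[idx].unsqueeze(0)
        log_prob, best_prev = score.max(dim=0)
        backpointers.append(best_prev)
    # Follow the backpointers from the best final tag to recover the path.
    best_tag = int(log_prob.argmax())
    path = [best_tag]
    for best_prev in reversed(backpointers):
        best_tag = int(best_prev[best_tag])
        path.append(best_tag)
    return list(reversed(path))

emissions = torch.randn(5, 3)    # 5 timesteps, 3 tags
transitions = torch.randn(3, 3)
print(viterbi_reference(emissions, transitions))
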