Example #1
    def forward(self,
                words: torch.Tensor,
                word_lengths: torch.Tensor,
                word_indices: torch.Tensor,
                word_idxs: torch.Tensor,
                sent_context: torch.Tensor,
                labs: torch.Tensor,
                lens: torch.Tensor = None) -> torch.Tensor:
        # pylint: disable=arguments-differ,not-callable
        """
        Compute the negative of the log-likelihood.

        Parameters
        ----------
        words : torch.Tensor
            Word tensors ``[sent_length x max_word_length x n_chars]``.

        word_lengths : torch.Tensor
            Contains the length of each word (in characters), with shape
            ``[sent_length]``.

        word_indices : torch.Tensor
            Contains sorted indices of words by length, with shape
            ``[sent_length]``.

        word_idxs : torch.Tensor
            Word indices with shape ``[sent_length]``.

        sent_context : torch.Tensor
            Sentence context.

        labs : torch.Tensor
            Corresponding target label sequence with shape ``[sent_length]``.

        lens : torch.Tensor, optional (default: None)
            The length of each sentence in the batch, with shape
            ``[batch_size]``.

        Returns
        -------
        torch.Tensor
            The negative log-likelihood evaluated at the inputs.

        """
        if lens is None:
            lens = torch.tensor([words.size(0)], device=words.device)
        mask = sequence_mask(lens)

        # Add a fake batch dimension to ``labs``.
        labs = labs.unsqueeze(0)
        # labs: ``[1 x sent_length]``

        # Gather word feats.
        feats = self._feats(words, word_lengths, word_indices, word_idxs,
                            sent_context)
        # feats: ``[1 x sent_length x n_labels]``

        loglik = self.crf(feats, labs, mask=mask)

        return -1. * loglik
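
Both methods rely on a ``sequence_mask`` helper that is not shown in these examples. A minimal sketch consistent with how it is called above, turning a ``[batch_size]`` vector of lengths into a ``[batch_size x max_len]`` boolean mask, might look like the following; the signature is an assumption inferred from the call sites.

import torch

def sequence_mask(lens: torch.Tensor, max_len: int = None) -> torch.Tensor:
    # The longest sequence in the batch defines the mask width by default.
    if max_len is None:
        max_len = int(lens.max().item())
    # positions: [0, 1, ..., max_len - 1], broadcast against each length so
    # that entry (i, j) is True exactly when j < lens[i].
    positions = torch.arange(max_len, device=lens.device)
    return positions.unsqueeze(0) < lens.unsqueeze(1)
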
Example #2
    def predict(self,
                words: torch.Tensor,
                word_lengths: torch.Tensor,
                word_indices: torch.Tensor,
                word_idxs: torch.Tensor,
                sent_context: torch.Tensor,
                lens: torch.Tensor = None) -> List[Tuple[List[int], float]]:
        # pylint: disable=not-callable
        """
        Compute the best tag sequence.

        Parameters
        ----------
        words : torch.Tensor
            Word tensors ``[sent_length x max_word_length x n_chars]``.

        word_lengths : torch.Tensor
            Contains the length of each word (in characters), with shape
            ``[sent_length]``.

        word_indices : torch.Tensor
            Contains sorted indices of words by length, with shape
            ``[sent_length]``.

        word_idxs : torch.Tensor
            Word indices with shape ``[sent_length]``.

        sent_context : torch.Tensor
            Sentence context.

        lens : torch.Tensor, optional (default: None)
            The length of each sentence in the batch, with shape
            ``[batch_size]``.

        Returns
        -------
        List[Tuple[List[int], float]]
            The best path and its score for each sentence in the batch.

        """
        if lens is None:
            lens = torch.tensor([words.size(0)], device=words.device)
        mask = sequence_mask(lens)

        # Gather word feats.
        feats = self._feats(words, word_lengths, word_indices, word_idxs,
                            sent_context)
        # feats: ``[1 x sent_length x n_labels]``

        # Run the features through the Viterbi decoding algorithm.
        preds = self.crf.viterbi_tags(feats, mask)

        return preds
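
The return annotation, ``List[Tuple[List[int], float]]``, indicates that the underlying ``crf.viterbi_tags`` yields a ``(path, score)`` pair per sentence (as in AllenNLP's ``ConditionalRandomField``, for instance). A hypothetical call site, assuming ``model`` is an instance of this class and the input tensors have already been built, might unpack the result like this:

preds = model.predict(words, word_lengths, word_indices, word_idxs, sent_context)
for path, score in preds:
    # path: the best label id for each word; score: the unnormalized Viterbi score.
    print(path, score)
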
Example #3
def test_sequence_mask(inp, chk, max_len):
    """Test `sequence_mask()` method."""
    res = utils.sequence_mask(inp, max_len=max_len)
    utils.assert_equal(res, chk)
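
The test above receives ``inp``, ``chk``, and ``max_len`` as arguments, which in pytest usually means they are supplied by a ``@pytest.mark.parametrize`` decorator that was not captured here. A hedged sketch of such a parametrization, with dummy expected masks consistent with the ``sequence_mask`` sketch in Example #1 (the ``utils`` import path is repo-specific):

import pytest
import torch

import utils  # the module under test; the import path is an assumption

@pytest.mark.parametrize("inp, chk, max_len", [
    # The mask width defaults to the longest length in the batch (here 3).
    (torch.tensor([2, 3]),
     torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.bool),
     None),
    # An explicit max_len pads the mask out to the requested width.
    (torch.tensor([1]),
     torch.tensor([[1, 0, 0, 0]], dtype=torch.bool),
     4),
])
def test_sequence_mask(inp, chk, max_len):
    """Test the `sequence_mask()` function."""
    res = utils.sequence_mask(inp, max_len=max_len)
    utils.assert_equal(res, chk)
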