def forward(self,
            words: torch.Tensor,
            word_lengths: torch.Tensor,
            word_indices: torch.Tensor,
            word_idxs: torch.Tensor,
            sent_context: torch.Tensor,
            labs: torch.Tensor,
            lens: torch.Tensor = None) -> torch.Tensor:
    # pylint: disable=arguments-differ,not-callable
    """
    Compute the negative of the log-likelihood.

    Parameters
    ----------
    words : torch.Tensor
        Word tensors with shape ``[sent_length x max_word_length x n_chars]``.

    word_lengths : torch.Tensor
        The length (in characters) of each word, with shape ``[sent_length]``.

    word_indices : torch.Tensor
        Indices that sort the words by length, with shape ``[sent_length]``.

    word_idxs : torch.Tensor
        Word indices with shape ``[sent_length]``.

    sent_context : torch.Tensor
        Sentence context.

    labs : torch.Tensor
        The corresponding target label sequence, with shape ``[sent_length]``.

    lens : torch.Tensor, optional (default: None)
        The length of each sentence in the batch, with shape ``[batch_size]``.

    Returns
    -------
    torch.Tensor
        The negative log-likelihood evaluated at the inputs.

    """
    if lens is None:
        # Treat the input as a batch of one sentence whose length is its
        # number of words.
        lens = torch.tensor([words.size(0)], device=words.device)
    mask = sequence_mask(lens)

    # Fake batch dimension for ``labs``.
    labs = labs.unsqueeze(0)
    # labs: ``[1 x sent_length]``

    # Gather word feats.
    feats = self._feats(words, word_lengths, word_indices, word_idxs, sent_context)
    # feats: ``[1 x sent_length x n_labels]``

    # The CRF layer returns the log-likelihood of the label sequence, so the
    # loss is its negation.
    loglik = self.crf(feats, labs, mask=mask)
    return -1. * loglik
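
# NOTE: ``sequence_mask`` is defined elsewhere in the package (and exercised
# by the test at the bottom of this section). The following is a minimal
# sketch of the semantics assumed above -- a length vector mapped to a
# ``[batch_size x max_len]`` boolean mask -- and is illustrative only; the
# actual implementation and mask dtype may differ.
def _sequence_mask_sketch(lens: torch.Tensor, max_len: int = None) -> torch.Tensor:
    """Illustrative only: entry ``(i, j)`` is on exactly when ``j < lens[i]``."""
    if max_len is None:
        max_len = int(lens.max().item())
    # Broadcast a [1 x max_len] position grid against [batch_size x 1] lengths.
    positions = torch.arange(max_len, device=lens.device).unsqueeze(0)
    return positions < lens.unsqueeze(1)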
def predict(self,
            words: torch.Tensor,
            word_lengths: torch.Tensor,
            word_indices: torch.Tensor,
            word_idxs: torch.Tensor,
            sent_context: torch.Tensor,
            lens: torch.Tensor = None) -> List[Tuple[List[int], float]]:
    # pylint: disable=not-callable
    """
    Compute the best tag sequence.

    Parameters
    ----------
    words : torch.Tensor
        Word tensors with shape ``[sent_length x max_word_length x n_chars]``.

    word_lengths : torch.Tensor
        The length (in characters) of each word, with shape ``[sent_length]``.

    word_indices : torch.Tensor
        Indices that sort the words by length, with shape ``[sent_length]``.

    word_idxs : torch.Tensor
        Word indices with shape ``[sent_length]``.

    sent_context : torch.Tensor
        Sentence context.

    lens : torch.Tensor, optional (default: None)
        The length of each sentence in the batch, with shape ``[batch_size]``.

    Returns
    -------
    List[Tuple[List[int], float]]
        The best path and its score for each sentence in the batch.

    """
    if lens is None:
        lens = torch.tensor([words.size(0)], device=words.device)
    mask = sequence_mask(lens)

    # Gather word feats.
    feats = self._feats(words, word_lengths, word_indices, word_idxs, sent_context)
    # feats: ``[1 x sent_length x n_labels]``

    # Run the features through the Viterbi decoding algorithm.
    preds = self.crf.viterbi_tags(feats, mask)
    return preds
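
# NOTE: hypothetical usage sketch. ``model`` and the input tensors are
# assumed to come from this project's data pipeline; none of these names are
# defined in this module. The return type annotation on ``predict`` indicates
# one (tag sequence, score) pair per sentence.
#
#     model.eval()
#     with torch.no_grad():
#         preds = model.predict(words, word_lengths, word_indices, word_idxs,
#                               sent_context)
#     for tags, score in preds:
#         print(f"tags: {tags}, score: {score:.4f}")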
def test_sequence_mask(inp, chk, max_len):
    """Test the ``sequence_mask()`` function."""
    res = utils.sequence_mask(inp, max_len=max_len)
    utils.assert_equal(res, chk)
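
# NOTE: the ``(inp, chk, max_len)`` signature suggests the test above is
# driven by parametrized cases. The wiring below is a hypothetical example:
# the tensors are illustrative (and assume a ``torch.uint8`` mask, which may
# differ from the real implementation), and ``pytest`` and ``torch`` are
# assumed to be imported in this module alongside ``utils``.
@pytest.mark.parametrize("inp, chk, max_len", [
    # Lengths [2, 1] with max_len=3: entry (i, j) is on when j < lens[i].
    (torch.tensor([2, 1]),
     torch.tensor([[1, 1, 0],
                   [1, 0, 0]], dtype=torch.uint8),
     3),
    # max_len=None: the mask width defaults to the longest length.
    (torch.tensor([3]),
     torch.tensor([[1, 1, 1]], dtype=torch.uint8),
     None),
])
def test_sequence_mask_example_cases(inp, chk, max_len):
    """Illustrative parametrized cases for ``sequence_mask()``."""
    res = utils.sequence_mask(inp, max_len=max_len)
    utils.assert_equal(res, chk)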