Exemplo n.º 1
0
 def __init__(self, config, print_info: bool = True):
     super(NNCRF, self).__init__()
     self.embedder = WordEmbedder(config, print_info=print_info)
     self.encoder = BiLSTMEncoder(config,
                                  self.embedder.get_output_dim(),
                                  print_info=print_info)
     self.inferencer = LinearCRF(config, print_info=print_info)
Exemplo n.º 2
0
 def __init__(self, config, print_info: bool = True):
     super(TransformersCRF, self).__init__()
     self.device = config.device
     self.embedder = TransformersEmbedder(config, print_info=print_info)
     if config.hidden_dim > 0:
         self.encoder = BiLSTMEncoder(config, self.embedder.get_output_dim(), print_info=print_info)
     else:
         self.encoder = LinearEncoder(config, self.embedder.get_output_dim(), print_info=print_info)
     self.inferencer = LinearCRF(config, print_info=print_info)
Exemplo n.º 3
0
class NNCRF(nn.Module):
    def __init__(self, config, print_info: bool = True):
        super(NNCRF, self).__init__()
        self.embedder = WordEmbedder(config, print_info=print_info)
        self.encoder = BiLSTMEncoder(config,
                                     self.embedder.get_output_dim(),
                                     print_info=print_info)
        self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, words: torch.Tensor, word_seq_lens: torch.Tensor,
                context_emb: torch.Tensor, chars: torch.Tensor,
                char_seq_lens: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative loglikelihood.
        :param words: (batch_size x max_seq_len)
        :param word_seq_lens: (batch_size)
        :param context_emb: (batch_size x max_seq_len x context_emb_size)
        :param chars: (batch_size x max_seq_len x max_char_len)
        :param char_seq_lens: (batch_size x max_seq_len)
        :param labels: (batch_size x max_seq_len)
        :return: the total negative log-likelihood loss
        """
        word_rep = self.embedder(words, word_seq_lens, context_emb, chars,
                                 char_seq_lens)
        lstm_scores = self.encoder(word_rep, word_seq_lens)
        batch_size = words.size(0)
        sent_len = words.size(1)

        dev_num = word_rep.get_device()
        curr_dev = torch.device(
            f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
        maskTemp = torch.arange(1,
                                sent_len + 1,
                                dtype=torch.long,
                                device=curr_dev).view(1, sent_len).expand(
                                    batch_size, sent_len)
        mask = torch.le(
            maskTemp,
            word_seq_lens.view(batch_size, 1).expand(batch_size, sent_len))
        unlabed_score, labeled_score = self.inferencer(lstm_scores,
                                                       word_seq_lens, labels,
                                                       mask)
        return unlabed_score - labeled_score

    def decode(self, words: torch.Tensor, word_seq_lens: torch.Tensor,
               context_emb: torch.Tensor, chars: torch.Tensor,
               char_seq_lens: torch.Tensor,
               **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input
        :param batchInput:
        :return:
        """
        word_rep = self.embedder(words, word_seq_lens, context_emb, chars,
                                 char_seq_lens)
        features = self.encoder(word_rep, word_seq_lens)
        bestScores, decodeIdx = self.inferencer.decode(features, word_seq_lens)
        return bestScores, decodeIdx
Exemplo n.º 4
0
class TransformersCRF(nn.Module):

    def __init__(self, config, print_info: bool = True):
        super(TransformersCRF, self).__init__()
        self.device = config.device
        self.embedder = TransformersEmbedder(config, print_info=print_info)
        if config.hidden_dim > 0:
            self.encoder = BiLSTMEncoder(config, self.embedder.get_output_dim(), print_info=print_info)
        else:
            self.encoder = LinearEncoder(config, self.embedder.get_output_dim(), print_info=print_info)
        self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, words: torch.Tensor,
                    word_seq_lens: torch.Tensor,
                    orig_to_tok_index: torch.Tensor,
                    input_mask: torch.Tensor,
                    labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative loglikelihood.
        :param words: (batch_size x max_seq_len)
        :param word_seq_lens: (batch_size)
        :param context_emb: (batch_size x max_seq_len x context_emb_size)
        :param chars: (batch_size x max_seq_len x max_char_len)
        :param char_seq_lens: (batch_size x max_seq_len)
        :param labels: (batch_size x max_seq_len)
        :return: the total negative log-likelihood loss
        """
        word_rep = self.embedder(words, orig_to_tok_index, input_mask)
        lstm_scores = self.encoder(word_rep, word_seq_lens)
        batch_size = word_rep.size(0)
        sent_len = word_rep.size(1)
        maskTemp = torch.arange(1, sent_len + 1, dtype=torch.long).view(1, sent_len).expand(batch_size, sent_len).to(self.device)
        mask = torch.le(maskTemp, word_seq_lens.view(batch_size, 1).expand(batch_size, sent_len)).to(self.device)
        unlabed_score, labeled_score =  self.inferencer(lstm_scores, word_seq_lens, labels, mask)
        return unlabed_score - labeled_score

    def decode(self, words: torch.Tensor,
                    word_seq_lens: torch.Tensor,
                    orig_to_tok_index: torch.Tensor,
                    input_mask,
                    **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input
        :param batchInput:
        :return:
        """
        word_rep = self.embedder(words, orig_to_tok_index, input_mask)
        features = self.encoder(word_rep, word_seq_lens)
        bestScores, decodeIdx = self.inferencer.decode(features, word_seq_lens)
        return bestScores, decodeIdx