from typing import Tuple

import torch
import torch.nn as nn
from overrides import overrides

# WordEmbedder, TransformersEmbedder, BiLSTMEncoder, LinearEncoder and LinearCRF
# are this project's own embedding, encoder and CRF modules.


class NNCRF(nn.Module):

    def __init__(self, config, print_info: bool = True):
        super(NNCRF, self).__init__()
        self.embedder = WordEmbedder(config, print_info=print_info)
        self.encoder = BiLSTMEncoder(config, self.embedder.get_output_dim(), print_info=print_info)
        self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, words: torch.Tensor,
                word_seq_lens: torch.Tensor,
                context_emb: torch.Tensor,
                chars: torch.Tensor,
                char_seq_lens: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative log-likelihood.
        :param words: (batch_size x max_seq_len)
        :param word_seq_lens: (batch_size)
        :param context_emb: (batch_size x max_seq_len x context_emb_size)
        :param chars: (batch_size x max_seq_len x max_char_len)
        :param char_seq_lens: (batch_size x max_seq_len)
        :param labels: (batch_size x max_seq_len)
        :return: the total negative log-likelihood loss
        """
        # Word representations -> per-position label scores from the BiLSTM encoder.
        word_rep = self.embedder(words, word_seq_lens, context_emb, chars, char_seq_lens)
        lstm_scores = self.encoder(word_rep, word_seq_lens)
        batch_size = words.size(0)
        sent_len = words.size(1)
        # Build the padding mask on the same device as the word representations.
        dev_num = word_rep.get_device()
        curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
        mask_temp = torch.arange(1, sent_len + 1, dtype=torch.long, device=curr_dev).view(1, sent_len).expand(batch_size, sent_len)
        mask = torch.le(mask_temp, word_seq_lens.view(batch_size, 1).expand(batch_size, sent_len))
        # NLL = log partition (unlabeled score) minus the score of the gold label sequence.
        unlabeled_score, labeled_score = self.inferencer(lstm_scores, word_seq_lens, labels, mask)
        return unlabeled_score - labeled_score

    def decode(self, words: torch.Tensor,
               word_seq_lens: torch.Tensor,
               context_emb: torch.Tensor,
               chars: torch.Tensor,
               char_seq_lens: torch.Tensor,
               **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input with Viterbi decoding.
        :return: (best path scores, best label index sequences)
        """
        word_rep = self.embedder(words, word_seq_lens, context_emb, chars, char_seq_lens)
        features = self.encoder(word_rep, word_seq_lens)
        best_scores, decode_idx = self.inferencer.decode(features, word_seq_lens)
        return best_scores, decode_idx
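
# A minimal, standalone sketch of the length-to-mask construction used by both
# NNCRF above and TransformersCRF below; `build_word_mask` is a hypothetical
# helper name, not part of this codebase.
def build_word_mask(word_seq_lens: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    """Return a (batch_size x max_seq_len) bool mask that is True at real token positions."""
    batch_size = word_seq_lens.size(0)
    # Positions 1..max_seq_len, broadcast over the batch.
    positions = torch.arange(1, max_seq_len + 1, dtype=torch.long,
                             device=word_seq_lens.device).view(1, max_seq_len).expand(batch_size, max_seq_len)
    # A position is valid iff it does not exceed that sentence's length.
    return torch.le(positions, word_seq_lens.view(batch_size, 1).expand(batch_size, max_seq_len))

# Example: build_word_mask(torch.tensor([3, 1]), 4)
# -> tensor([[ True,  True,  True, False],
#            [ True, False, False, False]])
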
class TransformersCRF(nn.Module):

    def __init__(self, config, print_info: bool = True):
        super(TransformersCRF, self).__init__()
        self.device = config.device
        self.embedder = TransformersEmbedder(config, print_info=print_info)
        # Use a BiLSTM encoder on top of the transformer when hidden_dim > 0,
        # otherwise a plain linear projection to the label scores.
        if config.hidden_dim > 0:
            self.encoder = BiLSTMEncoder(config, self.embedder.get_output_dim(), print_info=print_info)
        else:
            self.encoder = LinearEncoder(config, self.embedder.get_output_dim(), print_info=print_info)
        self.inferencer = LinearCRF(config, print_info=print_info)

    @overrides
    def forward(self, words: torch.Tensor,
                word_seq_lens: torch.Tensor,
                orig_to_tok_index: torch.Tensor,
                input_mask: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative log-likelihood.
        :param words: (batch_size x max_subword_len) subword ids fed to the transformer
        :param word_seq_lens: (batch_size) number of words per sentence
        :param orig_to_tok_index: (batch_size x max_seq_len) position of each word's first subword
        :param input_mask: (batch_size x max_subword_len) attention mask over subwords
        :param labels: (batch_size x max_seq_len)
        :return: the total negative log-likelihood loss
        """
        word_rep = self.embedder(words, orig_to_tok_index, input_mask)
        lstm_scores = self.encoder(word_rep, word_seq_lens)
        batch_size = word_rep.size(0)
        sent_len = word_rep.size(1)
        mask_temp = torch.arange(1, sent_len + 1, dtype=torch.long).view(1, sent_len).expand(batch_size, sent_len).to(self.device)
        mask = torch.le(mask_temp, word_seq_lens.view(batch_size, 1).expand(batch_size, sent_len)).to(self.device)
        unlabeled_score, labeled_score = self.inferencer(lstm_scores, word_seq_lens, labels, mask)
        return unlabeled_score - labeled_score

    def decode(self, words: torch.Tensor,
               word_seq_lens: torch.Tensor,
               orig_to_tok_index: torch.Tensor,
               input_mask: torch.Tensor,
               **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the batch input with Viterbi decoding.
        :return: (best path scores, best label index sequences)
        """
        word_rep = self.embedder(words, orig_to_tok_index, input_mask)
        features = self.encoder(word_rep, word_seq_lens)
        best_scores, decode_idx = self.inferencer.decode(features, word_seq_lens)
        return best_scores, decode_idx
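
# A hedged usage sketch of turning decode() output back into label strings.
# `predict_labels`, the batch dictionary keys and `idx2label` are illustrative
# assumptions, not part of this module; the decoded indices are assumed to be
# (batch_size x max_seq_len) with each best path left-aligned, matching the
# masking convention above.
from typing import Dict, List

def predict_labels(model: TransformersCRF,
                   batch: Dict[str, torch.Tensor],
                   idx2label: List[str]) -> List[List[str]]:
    model.eval()
    with torch.no_grad():
        _, decode_idx = model.decode(words=batch["words"],
                                     word_seq_lens=batch["word_seq_lens"],
                                     orig_to_tok_index=batch["orig_to_tok_index"],
                                     input_mask=batch["input_mask"])
    predictions = []
    for idx_seq, length in zip(decode_idx, batch["word_seq_lens"]):
        # Keep only the real (unpadded) positions and map indices to label names.
        predictions.append([idx2label[int(i)] for i in idx_seq[:int(length)]])
    return predictions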