Example #1
class SequenceTagger(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.lstm = Lstm(config)
        self.crf = CRF(config.ntags)
        self.linear = LinearClassifier(config, layers=[config.hidden_size_lstm*2, config.ntags], drops=[0.5])
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, word_input, mask=None, labels=None):
        # word_input shape: (batch_size, seq_length)
        if mask is None:
            raise ValueError("a padding mask is required")

        lstm_output, (_, _) = self.lstm(word_input)  # shape: (batch_size, seq_length, hidden_size_lstm * 2)
        lstm_output = self.dropout(lstm_output)
        lstm_output = self.linear(lstm_output)
        preds = self.crf.decode(lstm_output.transpose(0,1), mask=mask.transpose(0,1))
        if labels is not None:
            loss = -1 * self.crf(lstm_output.transpose(0,1), labels.transpose(0,1), mask=mask.transpose(0,1))
            return loss, preds

        return preds  # list of best tag sequences, one per batch element
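# A minimal, self-contained sketch of the pattern above, assuming the CRF class is
# pytorch-crf (torchcrf) or API-compatible: by default it expects (seq_len, batch)
# inputs, which is why the batch-first emissions and mask are transposed, and its
# forward() returns the log-likelihood, hence the `-1 *` that turns it into a loss.
import torch
from torchcrf import CRF

crf = CRF(num_tags=4)                                   # batch_first=False by default
emissions = torch.randn(2, 7, 4)                        # (batch, seq_len, num_tags)
labels = torch.randint(0, 4, (2, 7))                    # (batch, seq_len)
mask = torch.ones(2, 7, dtype=torch.bool)               # (batch, seq_len)

loss = -crf(emissions.transpose(0, 1), labels.transpose(0, 1), mask=mask.transpose(0, 1))
preds = crf.decode(emissions.transpose(0, 1), mask=mask.transpose(0, 1))  # list of tag lists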
Example #2
class UtteranceClassificationHRNN(nn.Module):
    """docstring for UtteranceClassificationHRNN"""
    def __init__(self, utterance_encoder: PreTrainedModel,
                 conversation_encoder: nn.RNNBase, n_classes: int) -> None:
        super(UtteranceClassificationHRNN, self).__init__()
        self.hrnn = HierarchicalRNN(utterance_encoder, conversation_encoder)
        self.output_layer = nn.Linear(
            in_features=conversation_encoder.hidden_size,
            out_features=n_classes)
        self.crf = CRF(n_classes)

    def forward(self, conversation_tokens: List[torch.LongTensor],
                conversation_attention_mask: List[torch.FloatTensor],
                labels: torch.LongTensor):
        output, state_n = self.hrnn(conversation_tokens,
                                    conversation_attention_mask)
        emissions = self.output_layer(output.reshape(
            -1, output.shape[2])).reshape(output.shape[0], output.shape[1], -1)
        return self.crf(emissions, labels.T)

    def decode(self, conversation_tokens: List[torch.LongTensor],
               conversation_attention_mask: List[torch.FloatTensor]):
        output, state_n = self.hrnn(conversation_tokens,
                                    conversation_attention_mask)
        emissions = self.output_layer(output.reshape(
            -1, output.shape[2])).reshape(output.shape[0], output.shape[1], -1)
        return self.crf.decode(emissions)
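# A small sketch of the reshape-apply-reshape step in forward()/decode() above:
# nn.Linear broadcasts over leading dimensions, so applying the layer directly to
# the (seq_len, batch, hidden) tensor yields the same emissions.
import torch
import torch.nn as nn

layer = nn.Linear(16, 4)
output = torch.randn(10, 2, 16)                              # (seq_len, batch, hidden)
a = layer(output.reshape(-1, output.shape[2])).reshape(output.shape[0], output.shape[1], -1)
b = layer(output)
print(torch.allclose(a, b))                                  # True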
Example #3
class LanguageCrfForNer(BertPreTrainedModel):
    def __init__(self, model_dir, args):
        self.args = args
        self.label_num = args.num_labels
        self.num_labels = args.num_labels

        self.config_class, _, config_model = MODEL_CLASSES[args.model_type]
        bert_config = self.config_class.from_pretrained(args.model_name_or_path)
        super(LanguageCrfForNer, self).__init__(bert_config)

        self.bert = config_model.from_pretrained(args.model_name_or_path, config=bert_config)  # Load pretrained bert

        # self.dropout = nn.Dropout(self.args.dropout_rate)
        # self.classifier = nn.Linear(bert_config.hidden_size, self.label_num)
        self.classifier = FCLayer1(bert_config.hidden_size, self.label_num)
        self.crf = CRF(num_tags=self.label_num, batch_first=True)
        # self.init_weights()

    def forward(self, input_ids, attention_mask, token_type_ids, label=None, is_test=False):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        # sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,)
        if label is not None:
            # the CRF forward pass returns the log-likelihood of `label`, so negate it for the loss
            mask = attention_mask.bool()  # CRF masks must be boolean/byte, not long
            loss = self.crf(emissions=logits, tags=label, mask=mask)
            if is_test:
                tags = self.crf.decode(logits, mask)
                outputs = (tags,)
            outputs = (-1 * loss,) + outputs
        if label is None and not is_test:
            tags = self.crf.decode(logits, attention_mask.bool())
            return tags
        return outputs  # (loss, tags) at test time; (loss, logits) during training; (logits,) otherwise
Example #4
class MyBertPlusCRFForTokenClassification(BertPreTrainedModel):
    def __init__(self, config, num_labels, device):
        super(MyBertPlusCRFForTokenClassification, self).__init__(config)
        self.num_labels = num_labels
        self.device = device
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels+3) # +3 for bos, eos, pad
        self.crf = CRF(num_labels, device=device).to(device)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None, label_mask=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        sequence_output = self.classifier(self.dropout(sequence_output))
        
        if labels is not None:
            orig_seq_len = sequence_output.size(1)
            seq_logits_list, seq_labels_list, seq_mask_list = [], [], []
            for seq_logits, seq_labels, seq_mask in zip(sequence_output, labels, label_mask):
                seq_logits = seq_logits[seq_mask != 0].unsqueeze(0)  # keep only positions where the label mask is nonzero
                seq_logits = torch.cat([seq_logits, seq_logits.new_zeros((1, orig_seq_len - seq_logits.size(1), seq_logits.size(2)))], 1)
                
                seq_labels = seq_labels[seq_mask != 0].unsqueeze(0)
                seq_mask = torch.cat([seq_labels.new_ones((1, seq_labels.size(1))), seq_labels.new_zeros((1, orig_seq_len - seq_labels.size(1)))], 1)
                seq_labels = torch.cat([seq_labels, seq_labels.new_zeros((1, orig_seq_len - seq_labels.size(1)))], 1)
                
                seq_logits_list.append(seq_logits)
                seq_labels_list.append(seq_labels)
                seq_mask_list.append(seq_mask)
            cat_seq_logits = torch.cat(seq_logits_list, 0)
            cat_seq_labels = torch.cat(seq_labels_list, 0)
            cat_seq_mask = torch.cat(seq_mask_list, 0)
            
            return self.crf(cat_seq_logits, cat_seq_labels, mask=cat_seq_mask)
        else:
            return_logits_list = []
            for seq_logits, seq_mask in zip(sequence_output, label_mask): # inefficient, need to optimize later
                seq_logits = seq_logits[seq_mask != 0].unsqueeze(0)
                _, path = self.crf.decode(seq_logits) # use Viterbi
                return_logits = torch.ones((sequence_output.size(1), self.num_labels)).to(self.device)
                i_path = 0
                for j_seq_mask in range(len(seq_mask)):
                    if seq_mask[j_seq_mask] != 0:
                        return_logits[j_seq_mask][path[0][i_path]] = 10000
                        i_path += 1
                return_logits_list.append(return_logits.unsqueeze(0))
            return torch.cat(return_logits_list, 0)
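# A small sketch of the masking step used in forward() above: boolean indexing with
# `seq_mask != 0` compacts a sequence by dropping masked (e.g. sub-word or padding)
# positions before the result is re-padded to the original length.
import torch

seq_logits = torch.randn(6, 4)                    # (seq_len, num_tags)
seq_mask = torch.tensor([1, 0, 1, 1, 0, 0])
kept = seq_logits[seq_mask != 0]                  # only unmasked positions remain
print(kept.shape)                                 # torch.Size([3, 4])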
Example #5
class Transformer_CRF(nn.Module):
    def __init__(self, config, src_vocab, nb_labels):
        super(Transformer_CRF, self).__init__()
        self.config = config
        self.src_vocab = src_vocab
        
        self.transformer = Transformer(self.config, self.src_vocab)
        self.crf = CRF(
            nb_labels, 
            Const.BOS_TAG_ID, 
            Const.EOS_TAG_ID, 
            pad_tag_id=Const.PAD_TAG_ID, 
            batch_first=True,
        )

    def forward(self, x, mask=None):
        emissions = self.transformer(x)
        score, path = self.crf.decode(emissions, mask=mask)
        return score, path
Example #6
class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, nb_labels, emb_dim=5, hidden_dim=4):
        super().__init__()
        self.lstm = SimpleLSTM(
            vocab_size, nb_labels, emb_dim=emb_dim, hidden_dim=hidden_dim
        )
        self.crf = CRF(
            nb_labels,
            Const.BOS_TAG_ID,
            Const.EOS_TAG_ID,
            pad_tag_id=Const.PAD_TAG_ID,  # try setting pad_tag_id to None
            batch_first=True,
        )

    def forward(self, x, mask=None):
        emissions = self.lstm(x)
        score, path = self.crf.decode(emissions, mask=mask)
        return score, path

    def loss(self, x, y, mask=None):
        emissions = self.lstm(x)
        nll = self.crf(emissions, y, mask=mask)
        return nll
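# Both forward() and loss() above take a padding mask. A minimal sketch of how such
# a (batch, seq_len) boolean mask can be built from sequence lengths (assuming
# batch-first inputs padded to a common length):
import torch

lengths = torch.tensor([5, 3, 4])                 # true length of each sequence
max_len = int(lengths.max())
mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
print(mask)                                       # shape (3, 5); True on real tokens, False on padding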
Example #7
class CharWordSeg(nn.Module):
    def __init__(self,
                 vocab_tag,
                 char_embed_size,
                 num_hidden_layer,
                 channel_size,
                 kernel_size,
                 dropout_rate=0.2):
        super(CharWordSeg, self).__init__()
        self.vocab_tag = vocab_tag
        self.char_embed_size = char_embed_size
        self.num_hidden_layer = num_hidden_layer
        self.channel_size = channel_size
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate

        num_tags = len(self.vocab_tag['tag_to_index'])
        vocab_size = len(self.vocab_tag['token_to_index'])
        self.char_embedding = nn.Embedding(vocab_size, char_embed_size)
        self.dropout_embed = nn.Dropout(dropout_rate)
        self.glu_layers = nn.ModuleList([
            ConvGLUBlock(in_channels=char_embed_size,
                         out_channels=channel_size,
                         kernel_size=kernel_size,
                         drop_out=dropout_rate,
                         padding=1)
        ] + [
            ConvGLUBlock(in_channels=channel_size,
                         out_channels=channel_size,
                         kernel_size=kernel_size,
                         drop_out=dropout_rate,
                         padding=1) for _ in range(num_hidden_layer - 1)
        ])
        self.hidden_to_tag = nn.Linear(channel_size, num_tags)  # GLU blocks output channel_size features
        self.crf_layer = CRF(num_tags, batch_first=True)

    def forward(self, input_sentences, tags, mask):
        """ 
        Args:
            input_sentences: torch.tensor -> shape: (batch_size, max_sent_length)
            tags: torch.tensor -> (batch_size, max_sent_length)
            masks: torch.tensor -> (batch_size, max_sent_length)

        Example:
        >>> vocab_size = 10
        >>> char_embed_size = 5
        >>> tag_size = 3
        >>> batch_size = 3
        >>> max_sent_size = 5
        >>> input_sents = torch.tensor([[3, 9, 8, 3, 4],[9, 2, 7, 0, 1],[9, 7, 9, 1, 6]])
        >>> model = CharWordSeg(vocab_size, char_embed_size, 2, 5, 3)
        >>> model(input_sents).shape
        torch.Size([3, 5, 3])
        """
        x = self.dropout_embed(
            self.char_embedding(input_sentences)
        )  # X shape: (batch_size, max_sent_length, char_embed_size)

        # glu layers
        for glu_layers in self.glu_layers:
            x = glu_layers(x)
        # output dim: (batch_size, max_sent_length, channel_size)

        emission = self.hidden_to_tag(
            x)  # output dim: (batch_size, max_sen_length, num_tags)
        x = self.crf_layer(emission, tags, mask)
        return x

    def decode(self, input_sentences, mask):
        x = self.dropout_embed(
            self.char_embedding(input_sentences)
        )  # X shape: (batch_size, max_sent_length, char_embed_size)

        # glu layers
        for glu_layers in self.glu_layers:
            x = glu_layers(x)
        # output dim: (batch_size, max_sent_length, channel_size)

        emission = self.hidden_to_tag(
            x)  # output dim: (batch_size, max_sen_length, num_tags)
        x = self.crf_layer.decode(
            emission, mask)  # output dim: (batch_size, max_sen_length)
        return x

    @staticmethod
    def load(model_path: str):
        """ Load the model from a file.
        @param model_path (str): path to model
        """
        params = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
        args = params['args']
        model = CharWordSeg(vocab_tag=params['vocab_tag'], **args)
        model.load_state_dict(params['state_dict'])

        return model

    def save(self, path: str):
        """ Save the odel to a file.
        @param path (str): path to the model
        """
        logging.info(f'save model parameters to [{path}]')

        params = {
            'args':
            dict(char_embed_size=self.char_embed_size,
                 num_hidden_layer=self.num_hidden_layer,
                 channel_size=self.channel_size,
                 kernel_size=self.kernel_size,
                 dropout_rate=self.dropout_rate),
            'vocab_tag':
            self.vocab_tag,
            'state_dict':
            self.state_dict()
        }

        torch.save(params, path)
Example #8
class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim,
                 batch_size, max_len):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        self.max_len = max_len
        self.crf = CRF(len(tag_to_ix), batch_first=True)

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim // 2,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)

        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        # Matrix of transition parameters.  Entry i,j is the score of
        # transitioning *to* i *from* j.  Note: this matrix appears to be a leftover
        # from the classic BiLSTM-CRF tutorial; it is never used below, since
        # self.crf maintains its own transition parameters.
        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size))

        # These two statements enforce the constraint that we never transition
        # *to* the start tag or *from* the stop tag.
        self.transitions.data[tag_to_ix["<START>"], :] = -10000
        self.transitions.data[:, tag_to_ix["<STOP>"]] = -10000

        self.hidden = self.init_hidden()

    def init_hidden(self):
        return (torch.randn(2, self.batch_size, self.hidden_dim // 2),
                torch.randn(2, self.batch_size, self.hidden_dim // 2))

    def _get_lstm_features(self, sentence):
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        #lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def forward(self, sentence, mask):
        # Get the emission scores from the BiLSTM
        lstm_feats = self._get_lstm_features(sentence)

        # Find the best path, given the features.

        tag_seq = self.crf.decode(lstm_feats, mask=mask)
        return tag_seq

    def loss_fn(self, sentence, tags, mask):
        # Get the emission scores from the BiLSTM
        lstm_feats = self._get_lstm_features(sentence)
        loss = self.crf(lstm_feats, tags, mask=mask)
        return -loss
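# A minimal sketch of the transition-constraint trick used in __init__ above, with a
# hypothetical tag set: forbidden transitions get a very negative score so a Viterbi
# path never selects them.
import torch

tag_to_ix = {"B": 0, "I": 1, "O": 2, "<START>": 3, "<STOP>": 4}
transitions = torch.randn(len(tag_to_ix), len(tag_to_ix))  # entry (i, j): score of j -> i
transitions[tag_to_ix["<START>"], :] = -10000   # never transition *to* <START>
transitions[:, tag_to_ix["<STOP>"]] = -10000    # never transition *from* <STOP>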
Example #9
class LstmCrf(nn.Module):
    def __init__(self, weighted_matrix, vocab_size, nb_labels, embedding_dim,
                 hidden_dim, char_vocab_size, char_embedding_dim,
                 char_in_channels, device):

        super().__init__()

        self.device = device
        self.hidden_dim = hidden_dim

        self.embeddings = nn.Embedding.from_pretrained(
            torch.Tensor(weighted_matrix))
        self.lstm = nn.LSTM(embedding_dim + char_in_channels,
                            hidden_dim,
                            bidirectional=True,
                            batch_first=True)
        self.cnn_embeddings = CNNEmbeddings(char_vocab_size,
                                            char_embedding_dim,
                                            char_in_channels)
        self.fc = nn.Linear(hidden_dim * 2, nb_labels)
        #self.crf = CRFDevice(nb_labels, Const.BOS_TAG_ID, Const.EOS_TAG_ID, device=device)
        self.crf = CRF(nb_labels, batch_first=True)

    def _init_hidden(self, batch_size):
        return (torch.randn(2, batch_size, self.hidden_dim,
                            device=self.device),
                torch.randn(2, batch_size, self.hidden_dim,
                            device=self.device))

    def _get_mask(self, tokens):
        # the CRF expects a boolean (or byte) mask, not a float one
        return tokens != Const.PAD_ID

    def _lstm(self, x, x_char):
        emb = self.embeddings(x)
        char_emb = self.cnn_embeddings(x_char)

        char_emb = char_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
        emb_cat = torch.cat([emb, char_emb], dim=2)
        hidden = self._init_hidden(x.shape[0])

        x, _ = self.lstm(emb_cat, hidden)

        # the CRF consumes unnormalized emission scores, so no softmax is applied here
        emissions = self.fc(x)

        return emissions

    def forward(self, features):
        x, x_char = features
        mask = self._get_mask(x)
        emissions = self._lstm(x, x_char)
        #score, path = self.crf.decode(emissions, mask=mask)
        path = self.crf.decode(emissions, mask=mask)

        #return score, path
        return path

    def loss(self, features, y):
        x, x_char = features
        mask = self._get_mask(x)
        emissions = self._lstm(x, x_char)
        nll = -self.crf(emissions, y, mask=mask)
        return nll
Example #10
class BilstmCrf(nn.Module):
    def __init__(self, vocab_size, tag_to_idx, hidden_size, embed_size,
                 num_layers, dropout):
        super(BilstmCrf, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.tag_to_idx = tag_to_idx
        self.tag_size = len(tag_to_idx)

        # architecture
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=PAD_IDX)
        self.lstm = nn.LSTM(input_size=embed_size,
                            hidden_size=self.hidden_size // 2,
                            num_layers=self.num_layers,
                            bias=True,
                            dropout=dropout,
                            bidirectional=True)
        # LSTM output to tag
        self.out = nn.Linear(self.hidden_size, self.tag_size)
        self.crflayer = CRF(self.tag_size)

        # use xavier_normal to init weight
        self.init_weight()

    def init_weight(self):
        nn.init.xavier_normal_(self.embed.weight)
        for name, param in self.lstm.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(param)
        nn.init.xavier_normal_(self.out.weight)

    def init_hidden(self, batch_size):  # initialize hidden states
        h = torch.zeros(self.num_layers * 2, batch_size,
                        self.hidden_size // 2)  # hidden states
        c = torch.zeros(self.num_layers * 2, batch_size,
                        self.hidden_size // 2)  # cell states
        return h, c

    def loss(self, x, y):
        mask = x.data.gt(0)  # nonzero token ids mark real (non-pad) positions
        emissions = self.lstm_forward(x)
        # note: a pytorch-crf style CRF returns the log-likelihood here; callers
        # are expected to negate it before minimizing
        return self.crflayer(emissions, y, mask=mask)

    def forward(self, x):
        """ prediction """
        mask = x.data.gt(0)
        emissions = self.lstm_forward(x)
        return self.crflayer.decode(emissions, mask=mask)

    def lstm_forward(self, x):  # LSTM forward pass
        """
        LSTM forward pass to compute emissions (unnormalized per-tag scores)
        """
        batch_size = x.size(1)
        # transpose (seq_len, batch_size) ==> (batch_size, seq_len) to measure each sequence's length
        lengths = [self.len_unpadded(seq) for seq in x.transpose(0, 1)]
        embed = self.embed(x)  # (seq_len, batch_size, embed_size)

        # pack the padded batch so the LSTM skips padded positions
        embed_pack = torch.nn.utils.rnn.pack_padded_sequence(embed, lengths, enforce_sorted=False)
        lstm_output_pack, _ = self.lstm(
            embed_pack, self.init_hidden(batch_size=batch_size))
        lstm_output, _ = torch.nn.utils.rnn.pad_packed_sequence(
            lstm_output_pack)

        emissions = self.out(lstm_output)
        return emissions

    def len_unpadded(self, x):
        """ get unpadded sequence length (index of the first padding token) """
        return next((i for i, j in enumerate(x) if j.item() == 0), len(x))
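# A minimal sketch of the pack/unpack step in lstm_forward above: packing lets the
# LSTM skip padded positions, and pad_packed_sequence restores the padded layout.
# (enforce_sorted=False is assumed available, i.e. PyTorch >= 1.1.)
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=6)
embed = torch.randn(7, 3, 8)                      # (seq_len, batch, embed_size), zero-padded
lengths = [7, 2, 5]                               # true sequence lengths
packed = nn.utils.rnn.pack_padded_sequence(embed, lengths, enforce_sorted=False)
out_packed, _ = lstm(packed)
out, out_lengths = nn.utils.rnn.pad_packed_sequence(out_packed)
print(out.shape, out_lengths)                     # torch.Size([7, 3, 6]) tensor([7, 2, 5])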