Example #1
class lstm_ner(nn.Module):
    def __init__(self,
                 input_size,
                 layer,
                 hidden,
                 n_label,
                 emb_dropout=0.5,
                 word_dropout=0.5):

        super(lstm_ner, self).__init__()

        self.emb_dropout = nn.Dropout(emb_dropout)

        self.word_lstm = nn.LSTM(input_size,
                                 hidden,
                                 layer,
                                 batch_first=True,
                                 bidirectional=True)
        self.word_dropout = nn.Dropout(word_dropout)
        self.crf = ChainCRF(hidden * 2, n_label)

    def forward(self, feat, word_length, target, gather_input, pad):
        word_input = feat

        if self.emb_dropout.p > 0.0:
            word_input = self.emb_dropout(word_input)

        # sort by length (descending) so the batch can be packed
        word_length, word_idx = word_length.sort(0, descending=True)
        word_input = word_input[word_idx]

        word_packed_input = pack_padded_sequence(
            word_input, word_length.cpu().data.numpy(), batch_first=True)
        word_packed_output, _ = self.word_lstm(word_packed_input)
        word_output, _ = pad_packed_sequence(word_packed_output,
                                             batch_first=True,
                                             total_length=feat.size(1))
        # restore the original batch order after the sorted pack/pad round-trip
        word_output = word_output[torch.from_numpy(
            np.argsort(word_idx.cpu().data.numpy())).cuda()]

        if self.word_dropout.p > 0.0:
            word_output = self.word_dropout(word_output)

        # use the representation of the first BPE of each word for prediction.
        word_output = torch.gather(
            torch.cat([word_output, pad], 1), 1,
            gather_input.unsqueeze(2).expand(gather_input.size(0),
                                             gather_input.size(1),
                                             word_output.size(2)))

        crf_loss = self.crf.loss(word_output, target, target.ne(0).float())
        predict = self.crf.decode(word_output, target.ne(0).float(), 1)
        return crf_loss, predict
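
The gather at the end of `forward` implements the "first BPE per word" selection the comment mentions. Below is a minimal, self-contained sketch of just that indexing step, with made-up shapes and plain PyTorch only; the extra `pad` row lets out-of-range word slots point at an all-zero vector.

```python
import torch

# toy shapes: 2 sentences, 6 subword positions, 4 words, hidden size 8
batch, n_subwords, n_words, dim = 2, 6, 4, 8
subword_out = torch.randn(batch, n_subwords, dim)   # per-subword encoder output
pad = torch.zeros(batch, 1, dim)                    # extra row appended at index n_subwords
# for each word, the index of its first subword (n_subwords selects the pad row)
gather_input = torch.randint(0, n_subwords + 1, (batch, n_words))

word_out = torch.gather(
    torch.cat([subword_out, pad], 1), 1,
    gather_input.unsqueeze(2).expand(batch, n_words, dim))
print(word_out.shape)  # torch.Size([2, 4, 8])
```
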
Example #2
class BertForSequentialClassification(BertPreTrainedModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of classes for the classifier.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_labels - 1].
    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, num_labels].
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_labels = 2
    model = BertForSequentialClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self,
                 config,
                 num_labels,
                 tag_space=0,
                 rnn_mode='LSTM',
                 use_crf=False,
                 rnn_hidden_size=None,
                 dropout=None):
        super(BertForSequentialClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if dropout is None:
            dropout = config.hidden_dropout_prob
        self.dropout_other = nn.Dropout(dropout)
        self.use_crf = use_crf

        if rnn_mode == 'RNN':
            RNN = nn.RNN
        elif rnn_mode == 'LSTM':
            RNN = nn.LSTM
        elif rnn_mode == 'GRU':
            RNN = nn.GRU
        else:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)

        if rnn_hidden_size is not None:
            self.rnn = RNN(config.hidden_size,
                           rnn_hidden_size,
                           num_layers=1,
                           batch_first=True,
                           bidirectional=True)
            out_dim = rnn_hidden_size * 2
        else:
            self.rnn = None
            out_dim = config.hidden_size

        if tag_space:
            self.dense = nn.Linear(out_dim, tag_space)
            out_dim = tag_space
        else:
            self.dense = None

        if use_crf:
            self.crf = ChainCRF(out_dim, num_labels, bigram=True)
        else:
            self.dense_softmax = nn.Linear(out_dim, num_labels)

        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                attention_mask=None,
                document_mask=None,
                labels=None,
                input_embeddings=None):
        _, output, embeddings = self.bert(input_ids,
                                          token_type_ids,
                                          attention_mask,
                                          output_all_encoded_layers=False,
                                          input_embeddings=input_embeddings)
        output = self.dropout(output)

        # reshape sentence-level outputs into document-level sequences
        length = document_mask.sum(dim=1).long()
        max_len = length.max()
        output = output.view(-1, max_len, self.config.hidden_size)

        # document level RNN processing
        if self.rnn is not None:
            output, hx, rev_order, mask = utils.prepare_rnn_seq(
                output, length, hx=None, masks=document_mask, batch_first=True)
            output, hn = self.rnn(output, hx=hx)
            output, hn = utils.recover_rnn_seq(output,
                                               rev_order,
                                               hx=hn,
                                               batch_first=True)

        # apply dropout to the RNN output
        output = self.dropout_other(output)
        if self.dense is not None:
            # [batch, length, tag_space]
            output = self.dropout_other(F.elu(self.dense(output)))

        # final output layer
        if not self.use_crf:
            # no CRF: plain softmax classification at every position
            output = self.dense_softmax(output)  # [batch, length, num_labels]
            if labels is None:
                _, preds = torch.max(output, dim=2)
                return preds, None, embeddings
            else:
                return (F.cross_entropy(output.view(-1, output.size(-1)),
                                        labels.view(-1),
                                        reduction='none') *
                        document_mask.view(-1)
                        ).sum() / document_mask.sum(), None, embeddings
        else:
            # CRF processing
            if labels is not None:
                loss, logits = self.crf.loss(output,
                                             labels,
                                             mask=document_mask)
                return loss.mean(), logits, embeddings
            else:
                seq_pred, logits = self.crf.decode(output,
                                                   mask=document_mask,
                                                   leading_symbolic=0)
                return seq_pred, logits, embeddings
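
In the non-CRF branch above, the loss is a token-level cross entropy that is zeroed out on padded positions and then averaged over the real tokens given by `document_mask`. A small stand-alone sketch of that masking pattern with toy tensors (nothing here is project-specific):

```python
import torch
import torch.nn.functional as F

batch, length, num_labels = 2, 5, 3
logits = torch.randn(batch, length, num_labels)
labels = torch.randint(0, num_labels, (batch, length))
mask = torch.tensor([[1., 1., 1., 0., 0.],     # 1 = real token, 0 = padding
                     [1., 1., 1., 1., 1.]])

per_token = F.cross_entropy(logits.view(-1, num_labels),
                            labels.view(-1),
                            reduction='none')          # [batch * length]
loss = (per_token * mask.view(-1)).sum() / mask.sum()  # mean over real tokens only
print(loss.item())
```
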
Example #3
class MainSequenceClassification(nn.Module):
    def __init__(self, config, num_labels, vocab, args):
        super(MainSequenceClassification, self).__init__()
        self.args = args

        if args.task_name == 'roberta' or args.task_name == 'robertaf1c':
            self.bert = RobertaModel.from_pretrained("roberta-base")
            # self.bert.resize_token_embeddings(len(tokenizer))
            self.bert.config.type_vocab_size = 2
            single_emb = self.bert.embeddings.token_type_embeddings
            self.bert.embeddings.token_type_embeddings = torch.nn.Embedding(
                2, single_emb.embedding_dim)
            self.bert.embeddings.token_type_embeddings.weight = torch.nn.Parameter(
                single_emb.weight.repeat([2, 1]))
        else:
            self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.latent_type = self.args.latent_type
        self.first_layer = args.first_layer
        self.second_layer = args.second_layer
        self.latent_dropout = args.latent_dropout
        #self.ner_ed = args.ner_ed # indicate NER or ED task
        self.speaker_reg = args.speaker_reg
        self.lasso_reg = args.lasso_reg
        self.l0_reg = args.l0_reg
        self.hidden = args.rnn_hidden_size
        self.dataset = args.dataset
        self.bilstm_crf = args.bilstm_crf
        self.bert_crf = args.bert_crf
        self.rnn_hidden = args.rnn_hidden_size
        self.sigmoid = nn.Sigmoid()

        #         if self.dataset=='dialogre':
        #             self.NUM=36
        #         else:
        #             self.NUM=4
        self.NUM = 4
        if self.latent_type == 'gdp':
            self.latent_graph = PoolGCN(config, args)

        elif self.latent_type == 'diffmask':
            self.num_layer = args.num_layer
            self.diff_position = args.diff_position
            self.diff_mlp_hidden = args.diff_mlp_hidden
            if args.encoder_type == 'LSTM':
                self.linear_latent = nn.Linear(2 * self.hidden,
                                               self.hidden // 2)
            elif args.encoder_type == 'Bert':
                self.linear_latent = nn.Linear(config.hidden_size,
                                               self.hidden // 2)
            #self.dropout = nn.Dropout(self.latent_dropout)
            self.latent_graph = GNNDiffMaskGate(self.hidden // 2,
                                                self.diff_mlp_hidden,
                                                1,
                                                self.diff_position,
                                                self.latent_dropout,
                                                self.num_layer,
                                                self.first_layer,
                                                self.second_layer,
                                                self.l0_reg,
                                                self.lasso_reg,
                                                self.speaker_reg,
                                                alpha=args.alpha,
                                                beta=args.beta,
                                                gamma=args.gamma)

        # classifier hidden dimension
        if self.latent_type == 'diffmask':
            self.classifier = nn.Linear(self.hidden // 2 + config.hidden_size,
                                        num_labels * self.NUM + 1)

        elif self.args.lstm_only == True:  # only lstm without gdpnet
            self.classifier = nn.Linear(self.rnn_hidden * 2,
                                        num_labels * self.NUM + 1)

        elif args.task_name == 'lstm' or self.args.task_name == 'lstmf1c':
            self.classifier = nn.Linear(300, num_labels * self.NUM + 1)

        elif args.bert_only == True:
            self.classifier = nn.Linear(config.hidden_size,
                                        num_labels * self.NUM + 1)

        else:  # bert or roberta
            self.classifier = nn.Linear(
                config.hidden_size + args.graph_hidden_size,
                num_labels * self.NUM + 1)

        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0,
                                           std=config.initializer_range)
            elif isinstance(module, BERTLayerNorm):
                module.beta.data.normal_(mean=0.0,
                                         std=config.initializer_range)
                module.gamma.data.normal_(mean=0.0,
                                          std=config.initializer_range)
            if isinstance(module, nn.Linear):
                module.bias.data.zero_()

        if self.args.task_name != 'lstm':
            self.apply(init_weights)
        self.DTW_criterion = SoftDTW(use_cuda=True, gamma=0.1)

        if self.args.task_name == 'lstm' or self.args.task_name == 'lstmf1c':
            # layers needed for the LSTM encoder
            self.vocab = vocab
            self.embedding = nn.Embedding(
                self.vocab.n_words,
                self.args.embed_dim,
                padding_idx=constant.PAD_ID)
            self.init_pretrained_embeddings_from_numpy(
                np.load(open(args.embed_f, 'rb'), allow_pickle=True))
            self.lstm = nn.LSTM(args.embed_dim + 30,
                                args.rnn_hidden_size,
                                args.lstm_layers,
                                batch_first=True,
                                dropout=args.lstm_dropout,
                                bidirectional=True)
            self.pooling = MultiHeadedPooling(2 * args.rnn_hidden_size)
            #self.linear = nn.Linear(768*3+40, 768)

            self.pos_emb = nn.Embedding(len(constant.POS_TO_ID), 30)
            self.ner_emb = nn.Embedding(6, 20, padding_idx=5)

        self.crf = ChainCRF(num_labels * self.NUM + 1,
                            num_labels * self.NUM + 1,
                            bigram=True)

    def init_pretrained_embeddings_from_numpy(self, pretrained_word_vectors):
        self.embedding.weight = nn.Parameter(
            torch.from_numpy(pretrained_word_vectors).float())
        if self.args.tune_topk <= 0:
            print("Do not fine tune word embedding layer")
            self.embedding.weight.requires_grad = False
        elif self.args.tune_topk < self.vocab.n_words:
            print(f"Finetune top {self.args.tune_topk} word embeddings")
            self.embedding.weight.register_hook(
                lambda x: torch_utils.keep_partial_grad(
                    x, self.args.tune_topk))
        else:
            print("Finetune all word embeddings")

    def forward(self,
                input_ids,
                token_type_ids,
                attention_mask,
                subj_pos=None,
                obj_pos=None,
                pos_ids=None,
                x_type=None,
                y_type=None,
                labels=None,
                n_class=1,
                train=True):
        seq_length = input_ids.size(2)
        attention_mask_ = attention_mask.view(-1, seq_length)
        batch_size = input_ids.size(0)
        mask_local = ~(attention_mask_.data.eq(0).view(batch_size, seq_length))

        l = (attention_mask_.data.cpu().numpy() != 0).astype(np.int64).sum(1)

        real_length = max(l)

        if self.args.encoder_type == 'LSTM':
            input_ids = input_ids.view(-1, seq_length)
            emb = self.embedding(input_ids)
            pos_ids = pos_ids.view(-1, seq_length)
            pos_emb = self.pos_emb(pos_ids)
            emb = torch.cat([emb, pos_emb], dim=2)

            h0, c0 = rnn_zero_state(batch_size, self.rnn_hidden, 1)
            rnn_input = pack_padded_sequence(emb,
                                             l,
                                             batch_first=True,
                                             enforce_sorted=False)
            rnn_output, (ht, ct) = self.lstm(rnn_input, (h0, c0))
            word_embedding, _ = pad_packed_sequence(rnn_output,
                                                    batch_first=True)

            pooled_output = pool(word_embedding,
                                 ~mask_local[:, :real_length].unsqueeze(-1),
                                 type='max')

            mask = attention_mask.squeeze()
            s_labels = labels.squeeze()

            if self.args.lstm_only:
                output = self.dropout(word_embedding)

                logits = self.classifier(output)
                length = logits.size()[1]
                if self.bilstm_crf:
                    if train:
                        loss = self.crf.loss(logits,
                                             s_labels[:, :length],
                                             mask=mask[:, :length])

                        return loss.sum(), logits
                    else:
                        preds = self.crf.decode(logits,
                                                mask=mask[:, :length],
                                                leading_symbolic=0)

                        return logits, preds
                else:  #bilstm only

                    criterion = CrossEntropyLoss()
                    loss = criterion(logits.transpose(1, 2),
                                     labels.squeeze()[:, 0:length])
                    return loss, logits
        else:  # transformer encoders (BERT / RoBERTa)
            if self.args.task_name == 'roberta' or self.args.task_name == 'robertaf1c':
                word_embedding, pooled_output = self.bert(
                    input_ids.view(-1, seq_length),
                    attention_mask.view(-1, seq_length))

                # depending on the transformers version, RobertaModel may return a
                # dict-like output, in which case tuple-unpacking yields the keys
                # (strings) instead of tensors; fall back to unpacking the values
                if isinstance(word_embedding, str):
                    word_embedding, pooled_output = self.bert(
                        input_ids.view(-1, seq_length),
                        attention_mask.view(-1, seq_length)).values()

                word_embedding = word_embedding[:, :real_length, :]
                mask = attention_mask.squeeze()
                s_labels = labels.squeeze().long()

                if self.args.bert_only:
                    output = self.dropout(word_embedding)

                    logits = self.classifier(output)
                    length = logits.size()[1]

                    # there is some problem here!
                    if self.bert_crf:
                        if train:
                            loss = self.crf.loss(logits,
                                                 s_labels[:, :length],
                                                 mask=mask[:, :length])

                            return loss.sum(), logits
                        else:
                            preds = self.crf.decode(logits,
                                                    mask=mask[:, :length],
                                                    leading_symbolic=0)

                            return logits, preds
                    else:  #bert only
                        criterion = CrossEntropyLoss()
                        loss = criterion(logits.transpose(1, 2),
                                         labels.squeeze()[:, 0:length].long())
                        return loss, logits
            else:  # bert
                word_embedding, pooled_output = self.bert(
                    input_ids.view(-1, seq_length),
                    token_type_ids.view(-1, seq_length),
                    attention_mask.view(-1, seq_length))
                word_embedding = word_embedding[-1]
                word_embedding = word_embedding[:, :real_length, :]
                mask = attention_mask.squeeze()
                s_labels = labels.squeeze().long()

                if self.args.bert_only:
                    output = self.dropout(word_embedding)

                    logits = self.classifier(output)
                    length = logits.size()[1]

                    # there is some problem here!
                    if self.bert_crf:
                        if train:
                            loss = self.crf.loss(logits,
                                                 s_labels[:, :length],
                                                 mask=mask[:, :length])

                            return loss.sum(), logits
                        else:
                            preds = self.crf.decode(logits,
                                                    mask=mask[:, :length],
                                                    leading_symbolic=0)

                            return logits, preds
                    else:  #bert only
                        criterion = CrossEntropyLoss()
                        loss = criterion(logits.transpose(1, 2),
                                         labels.squeeze()[:, 0:length].long())
                        return loss, logits

        if self.latent_type == 'diffmask':
            ctx = self.dropout(self.linear_latent(word_embedding))

            # the speaker mask is always None because there is no subject-object pair in PEDC
            speaker_mask = None

            h, aux_loss = self.latent_graph(
                ctx,
                mask_local.squeeze(-1)[:, :real_length], speaker_mask)
            h = h[:, :real_length]

            #h_out = pool(h, ~mask_local, type="max")
            h_out = pool(h,
                         ~mask_local[:, :real_length].unsqueeze(-1),
                         type='max')

            output = self.dropout(torch.cat([pooled_output, h_out], dim=1))
            logits = self.classifier(output)

            length = logits.size()[1]
            criterion = BCEWithLogitsLoss()
            labels = labels.squeeze()[:, 0:length]
            labels = labels.type(torch.cuda.FloatTensor)
            labels = self.sigmoid(labels)
            loss = criterion(logits, labels)

            loss += aux_loss
            return loss, logits
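
Both the LSTM and the diffmask branches above pool token states with `pool(..., type='max')`, which is project code. The following is only an illustrative stand-in showing what masked max pooling over the token dimension typically looks like (toy shapes; the real `pool` helper may differ in details):

```python
import torch

batch, length, dim = 2, 5, 4
h = torch.randn(batch, length, dim)
mask = torch.tensor([[1, 1, 1, 0, 0],                  # True/1 = real token
                     [1, 1, 1, 1, 1]], dtype=torch.bool)

# set padded positions to -inf so they never win the max
h_masked = h.masked_fill(~mask.unsqueeze(-1), float('-inf'))
pooled = h_masked.max(dim=1).values                    # [batch, dim]
print(pooled.shape)  # torch.Size([2, 4])
```
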
Example #4
class Tagger(Model):
    def __init__(self, hparams):
        super(Tagger, self).__init__(hparams)

        self._comparsion = {
            Task.conllner: "max",
            Task.wikiner: "max",
            Task.udpos: "max",
        }[self.hparams.task]
        self._selection_criterion = {
            Task.conllner: "val_f1",
            Task.wikiner: "val_f1",
            Task.udpos: "val_acc",
        }[self.hparams.task]
        self._nb_labels: Optional[int] = None
        self._nb_labels = {
            Task.conllner: ConllNER.nb_labels(),
            Task.wikiner: WikiAnnNER.nb_labels(),
            Task.udpos: UdPOS.nb_labels(),
        }[self.hparams.task]
        self._metric = {
            Task.conllner: NERMetric(ConllNER.get_labels()),
            Task.wikiner: NERMetric(WikiAnnNER.get_labels()),
            Task.udpos: POSMetric(),
        }[self.hparams.task]

        if self.hparams.tagger_use_crf:
            self.crf = ChainCRF(self.hidden_size, self.nb_labels, bigram=True)
        else:
            self.classifier = nn.Linear(self.hidden_size, self.nb_labels)
        self.padding = {
            "sent": self.tokenizer.pad_token_id,
            "lang": 0,
            "labels": LABEL_PAD_ID,
        }

        self.setup_metrics()

    @property
    def nb_labels(self):
        assert self._nb_labels is not None
        return self._nb_labels

    def preprocess_batch(self, batch):
        return batch

    def forward(self, batch):
        batch = self.preprocess_batch(batch)
        hs = self.encode_sent(batch["sent"], batch["lang"])
        if self.hparams.tagger_use_crf:
            mask = (batch["labels"] != LABEL_PAD_ID).float()
            energy = self.crf(hs, mask=mask)
            target = batch["labels"].masked_fill(
                batch["labels"] == LABEL_PAD_ID, self.nb_labels)
            loss = self.crf.loss(energy, target, mask=mask)
            log_probs = energy
        else:
            logits = self.classifier(hs)
            log_probs = F.log_softmax(logits, dim=-1)

            loss = F.nll_loss(
                log_probs.view(-1, self.nb_labels),
                batch["labels"].view(-1),
                ignore_index=LABEL_PAD_ID,
            )
        return loss, log_probs

    def training_step(self, batch, batch_idx):
        loss, _ = self.forward(batch)
        self.log("loss", loss)
        return loss

    def evaluation_step_helper(self, batch, prefix):
        loss, log_probs = self.forward(batch)

        assert (len(set(batch["lang"])) == 1
                ), "eval batch should contain only one language"
        lang = batch["lang"][0]
        if self.hparams.tagger_use_crf:
            energy = log_probs
            prediction = self.crf.decode(
                energy, mask=(batch["labels"] != LABEL_PAD_ID).float())
            self.metrics[lang].add(batch["labels"], prediction)
        else:
            self.metrics[lang].add(batch["labels"], log_probs)

        result = dict()
        result[f"{prefix}_{lang}_loss"] = loss
        return result

    def prepare_datasets(self, split: str) -> List[Dataset]:
        hparams = self.hparams
        data_class: Type[Dataset]
        if hparams.task == Task.conllner:
            data_class = ConllNER
        elif hparams.task == Task.wikiner:
            data_class = WikiAnnNER
        elif hparams.task == Task.udpos:
            data_class = UdPOS
        else:
            raise ValueError(f"Unsupported task: {hparams.task}")

        if split == Split.train:
            return self.prepare_datasets_helper(data_class, hparams.trn_langs,
                                                Split.train,
                                                hparams.max_trn_len)
        elif split == Split.dev:
            return self.prepare_datasets_helper(data_class, hparams.val_langs,
                                                Split.dev, hparams.max_tst_len)
        elif split == Split.test:
            return self.prepare_datasets_helper(data_class, hparams.tst_langs,
                                                Split.test,
                                                hparams.max_tst_len)
        else:
            raise ValueError(f"Unsupported split: {hparams.split}")

    @classmethod
    def add_model_specific_args(cls, parser):
        parser.add_argument("--tagger_use_crf",
                            default=False,
                            type=util.str2bool)
        return parser
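
In the non-CRF branch of `forward`, padded label positions are excluded from the loss via `ignore_index`. A minimal sketch of that behaviour, assuming a hypothetical `LABEL_PAD_ID` of -1 (the real constant is defined elsewhere in the repo):

```python
import torch
import torch.nn.functional as F

LABEL_PAD_ID = -1                       # assumption; the repo defines its own value
batch, length, nb_labels = 2, 3, 4
logits = torch.randn(batch, length, nb_labels)
labels = torch.tensor([[0, 2, LABEL_PAD_ID],
                       [3, LABEL_PAD_ID, LABEL_PAD_ID]])

log_probs = F.log_softmax(logits, dim=-1)
loss = F.nll_loss(log_probs.view(-1, nb_labels),
                  labels.view(-1),
                  ignore_index=LABEL_PAD_ID)   # padded positions contribute nothing
print(loss.item())
```
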