Example #1
import torch
import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel

import layers  # project-local module providing DistilBERTClassifier


class DistilBERT(nn.Module):
    """DistilBERT model to classify news

    Based on the paper:
    DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter
    by Victor Sanh, Lysandre Debut, Julien Chaumond, Thomas Wolf
    (https://arxiv.org/abs/1910.01108)
    """
    def __init__(self, hidden_size, num_labels, drop_prob, freeze, use_img,
                 img_size):
        super(DistilBERT, self).__init__()
        self.img_size = img_size
        self.use_img = use_img
        # 119547 is the vocabulary size of the multilingual DistilBERT checkpoint
        config = DistilBertConfig(vocab_size=119547)
        self.distilbert = DistilBertModel(config)
        for param in self.distilbert.parameters():
            param.requires_grad = not freeze  # optionally freeze the encoder
        self.classifier = layers.DistilBERTClassifier(hidden_size,
                                                      num_labels,
                                                      drop_prob=drop_prob,
                                                      use_img=use_img,
                                                      img_size=img_size)

    def forward(self, input_idxs, atten_masks):
        # CLS-token hidden state: (batch_size, hidden_size)
        con_x = self.distilbert(input_ids=input_idxs,
                                attention_mask=atten_masks)[0][:, 0]
        # img_x = self.resnet18(images).view(-1, self.img_size) if self.use_img else None
        logits = self.classifier(con_x)
        probs = torch.sigmoid(logits)  # per-label probabilities

        return probs
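A minimal usage sketch for the classifier above, assuming the project-local `layers.DistilBERTClassifier` is importable; the hyperparameters and the multilingual checkpoint name are illustrative, chosen to match the `vocab_size=119547` set in `__init__`.

import torch
from transformers import DistilBertTokenizerFast

# tokenizer of the multilingual checkpoint (119547 tokens)
tokenizer = DistilBertTokenizerFast.from_pretrained(
    "distilbert-base-multilingual-cased")

# hypothetical hyperparameters for illustration only
model = DistilBERT(hidden_size=768, num_labels=2, drop_prob=0.1,
                   freeze=False, use_img=False, img_size=1000)
model.eval()

batch = tokenizer(["Breaking: markets rally after rate decision."],
                  padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    probs = model(batch["input_ids"], batch["attention_mask"])
print(probs)  # sigmoid outputs from the classification head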
Example #2
import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel

# CFG is the project's configuration object (defines text_encoder_model,
# pretrained and trainable); it is assumed to be defined elsewhere.


class TextEncoder(nn.Module):
    def __init__(self,
                 model_name=CFG.text_encoder_model,
                 pretrained=CFG.pretrained,
                 trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]
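A short usage sketch, assuming the `CFG` object referenced by the constructor defaults is available; the checkpoint name is passed explicitly here so only that assumption matters.

import torch
from transformers import DistilBertTokenizer

encoder = TextEncoder(model_name="distilbert-base-uncased",
                      pretrained=True, trainable=False)
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

batch = tokenizer(["a photo of a dog", "a photo of a cat"],
                  padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    embeddings = encoder(batch["input_ids"], batch["attention_mask"])
print(embeddings.shape)  # torch.Size([2, 768]): one CLS embedding per text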
Example #3
from typing import Dict

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from torch.optim import AdamW
from transformers import (DistilBertConfig, DistilBertModel,
                          get_linear_schedule_with_warmup)


class DistilEmbeddingBertLightning(LightningModule):
    def __init__(self, config: Dict):
        super().__init__()

        self.config = config
        self.model_config = DistilBertConfig(**self.config["model"])
        self.model = DistilBertModel(self.model_config)
        self.criterion = nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')

    def forward(self, input_ids, attention_mask, embedding):
        output = self.model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            return_dict=True,
                            output_hidden_states=True)
        target_embedding = embedding
        # CLS-token hidden state of the last layer is the predicted embedding
        predicted_embedding = output.hidden_states[-1][:, 0, :]
        batch_size = target_embedding.size(0)
        # cosine target of 1 for every pair; created on the embeddings' device
        loss = self.criterion(target_embedding, predicted_embedding,
                              torch.ones(batch_size,
                                         device=predicted_embedding.device))
        return loss

    def training_step(self, batch, batch_nb):
        train_loss = self(**batch)
        self.log("train_loss", train_loss, prog_bar=True, logger=True)
        return train_loss

    def validation_step(self, batch, batch_nb):
        val_loss = self(**batch)
        self.log("val_loss", val_loss, prog_bar=True, logger=True)

    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(),
                          lr=self.config["learning_rate"])
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.config["num_warmup_steps"],
            num_training_steps=self.config["num_training_steps"])
        # the warmup schedule is defined per optimizer step, so step it every
        # batch instead of Lightning's default of once per epoch
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
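A hypothetical configuration and training sketch for the Lightning module above; the key names mirror the fields read in `__init__` and `configure_optimizers`, and the data loaders are assumed to yield dicts matching `forward`'s keyword arguments.

import pytorch_lightning as pl

config = {
    "model": {"vocab_size": 30522},   # kwargs forwarded to DistilBertConfig
    "learning_rate": 5e-5,
    "num_warmup_steps": 100,
    "num_training_steps": 10_000,
}
module = DistilEmbeddingBertLightning(config)
trainer = pl.Trainer(max_epochs=1)
# train_loader / val_loader must yield dicts with `input_ids`,
# `attention_mask` and `embedding` tensors:
# trainer.fit(module, train_loader, val_loader)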
Example #4
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from transformers import DistilBertModel, DistilBertPreTrainedModel

# `device` is a module-level global in the original project; a standard
# definition is included here so the snippet is self-contained.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DistilBertForMultiLabelSequenceClassification(DistilBertPreTrainedModel):
    """
    DistilBert model adapted for multi-label sequence classification.
    Note that for imbalance problems will also provide an extra parameter to add inside
    the loss function to integrate the classes distribution.
    """
    def __init__(self, config):
        super(DistilBertForMultiLabelSequenceClassification,
              self).__init__(config)
        self.num_labels = config.num_labels
        self.distilbert = DistilBertModel(config)
        # hidden_dropout_prob, pos_weight and use_pos_weight are not native
        # DistilBertConfig fields; they are expected to be set on the config
        self.pre_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size), nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob))
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.pos_weight = torch.Tensor(
            config.pos_weight).to(device) if config.use_pos_weight else None

        self.init_weights()

    def forward(self,
                input_ids=None,
                attention_mask=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None):
        """
        :param input_ids: sentence or sentences represented as tokens
        :param attention_mask: tells the model which tokens in the input_ids are words and which are padding.
                               1 indicates a token and 0 indicates padding.
        :param head_mask: mask to nullify selected heads of the self-attention modules
        :param inputs_embeds: Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an
                              embedded representation. This is useful if you want more control over how to convert
                              :obj:`input_ids` indices into associated vectors than the model's internal embedding
                              lookup matrix.
        :param labels: target for each input
        :return: (loss), logits, (hidden_states), (attentions); the loss is
                 only included when labels are provided
        """
        distilbert_output = self.distilbert(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            head_mask=head_mask,
                                            inputs_embeds=inputs_embeds)
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        outputs = (logits, ) + distilbert_output[1:]
        if labels is not None:
            loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
            labels = labels.float()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.view(-1, self.num_labels))
            outputs = (loss, ) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

    def freeze_bert_encoder(self):
        """Freeze DistilBERT layers"""
        for param in self.distilbert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        """Unfreeze DistilBERT layers"""
        for param in self.distilbert.parameters():
            param.requires_grad = True
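A usage sketch for the multi-label head, with a hypothetical label count; the custom fields the class reads (`hidden_dropout_prob`, `use_pos_weight`, `pos_weight`) are supplied as extra attributes on the config.

import torch
from transformers import DistilBertConfig, DistilBertTokenizer

config = DistilBertConfig.from_pretrained(
    "distilbert-base-uncased",
    num_labels=5,                # hypothetical number of labels
    hidden_dropout_prob=0.1,     # custom field read in __init__
    use_pos_weight=False,
    pos_weight=None)
model = DistilBertForMultiLabelSequenceClassification.from_pretrained(
    "distilbert-base-uncased", config=config)
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

batch = tokenizer(["sample headline"], return_tensors="pt")
labels = torch.tensor([[1., 0., 1., 0., 0.]])
loss, logits = model(input_ids=batch["input_ids"],
                     attention_mask=batch["attention_mask"],
                     labels=labels)[:2]
probs = torch.sigmoid(logits)    # independent per-label probabilities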