import torch
import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel

import layers  # project-specific module providing DistilBERTClassifier


class DistilBERT(nn.Module):
    """DistilBERT model to classify news.

    Based on the paper "DistilBERT, a distilled version of BERT: smaller,
    faster, cheaper and lighter" by Victor Sanh, Lysandre Debut, Julien
    Chaumond and Thomas Wolf (https://arxiv.org/abs/1910.01108).
    """

    def __init__(self, hidden_size, num_labels, drop_prob, freeze, use_img, img_size):
        super(DistilBERT, self).__init__()
        self.img_size = img_size
        self.use_img = use_img
        # vocab_size=119547 matches the multilingual (cased) DistilBERT vocabulary.
        config = DistilBertConfig(vocab_size=119547)
        self.distilbert = DistilBertModel(config)
        # Optionally freeze the encoder so only the classifier head is trained.
        for param in self.distilbert.parameters():
            param.requires_grad = not freeze
        self.classifier = layers.DistilBERTClassifier(hidden_size, num_labels,
                                                      drop_prob=drop_prob,
                                                      use_img=use_img,
                                                      img_size=img_size)

    def forward(self, input_idxs, atten_masks):
        # Hidden state of the [CLS] token is used as the sequence representation.
        con_x = self.distilbert(input_ids=input_idxs, attention_mask=atten_masks)[0][:, 0]
        # img_x = self.resnet18(images).view(-1, self.img_size) if self.use_img else None
        logits = self.classifier(con_x)
        probs = torch.sigmoid(logits)
        return probs


import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel

# CFG is the project-wide configuration object (model name, pretrained flag, etc.).


class TextEncoder(nn.Module):
    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained,
                 trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # We are using the CLS token hidden representation as the sentence's embedding.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]
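
# --- Usage sketch (illustrative; not part of the original project) ----------
# Assumes CFG points at a standard checkpoint such as "distilbert-base-uncased"
# and that the tokenizer matches the encoder; the helper name is hypothetical.
import torch
from transformers import DistilBertTokenizer


def encode_sentences(sentences):
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_encoder_model)
    encoder = TextEncoder()
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        # One CLS embedding per sentence, shape (batch_size, hidden_dim).
        return encoder(batch["input_ids"], batch["attention_mask"])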

from typing import Dict

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from torch.optim import AdamW
from transformers import DistilBertConfig, DistilBertModel, get_linear_schedule_with_warmup


class DistilEmbeddingBertLightning(LightningModule):
    def __init__(self, config: Dict):
        super().__init__()
        self.config = config
        self.model_config = DistilBertConfig(**self.config["model"])
        self.model = DistilBertModel(self.model_config)
        self.criterion = nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')

    def forward(self, input_ids, attention_mask, embedding):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask,
                            return_dict=True, output_hidden_states=True)
        target_embedding = embedding
        # CLS-token vector of the last hidden layer is the predicted embedding.
        predicted_embedding = output.hidden_states[-1][:, 0, :]
        batch_size = target_embedding.size(0)
        # A target of +1 pulls each predicted vector towards its target embedding;
        # the ones tensor must live on the same device as the embeddings.
        loss = self.criterion(target_embedding, predicted_embedding,
                              torch.ones(batch_size, device=predicted_embedding.device))
        return loss

    def training_step(self, batch, batch_nb):
        train_loss = self(**batch)
        self.log("train_loss", train_loss, prog_bar=True, logger=True)
        return train_loss

    def validation_step(self, batch, batch_nb):
        val_loss = self(**batch)
        self.log("val_loss", val_loss, prog_bar=True, logger=True)

    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=self.config["learning_rate"])
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.config["num_warmup_steps"],
            num_training_steps=self.config["num_training_steps"])
        # The linear warmup schedule is defined per optimizer step, not per epoch.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
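
# --- Usage sketch (illustrative; not part of the original project) ----------
# Wires the LightningModule above into a Trainer. The config keys and the
# DataLoaders (batches must be dicts with "input_ids", "attention_mask" and a
# target "embedding" tensor so that self(**batch) works) are assumptions.
from pytorch_lightning import Trainer


def train_embedding_model(train_loader, val_loader):
    config = {
        "model": {"vocab_size": 30522, "dim": 768},  # forwarded to DistilBertConfig
        "learning_rate": 5e-5,
        "num_warmup_steps": 100,
        "num_training_steps": 10_000,
    }
    module = DistilEmbeddingBertLightning(config)
    trainer = Trainer(max_epochs=3)
    trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)
    return module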

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from transformers import DistilBertModel, DistilBertPreTrainedModel

# Assumed to be defined once at module level in the original project.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DistilBertForMultiLabelSequenceClassification(DistilBertPreTrainedModel):
    """DistilBert model adapted for multi-label sequence classification.

    For imbalanced problems, an extra ``pos_weight`` parameter can be supplied
    through the config and passed to the loss function to account for the
    class distribution.
    """

    def __init__(self, config):
        super(DistilBertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob))
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Optional per-class positive weights for BCEWithLogitsLoss on imbalanced data.
        self.pos_weight = torch.Tensor(
            config.pos_weight).to(device) if config.use_pos_weight else None
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, head_mask=None,
                inputs_embeds=None, labels=None):
        """
        :param input_ids: sentence or sentences represented as tokens
        :param attention_mask: tells the model which tokens in the input_ids are words
            and which are padding; 1 indicates a token and 0 indicates padding
        :param head_mask: mask to nullify selected heads of the self-attention modules
        :param inputs_embeds: optionally, instead of passing ``input_ids`` you can pass an
            embedded representation directly; this is useful if you want more control over
            how to convert ``input_ids`` indices into associated vectors than the model's
            internal embedding lookup matrix provides
        :param labels: target for each input
        :return: (loss), logits, (hidden_states), (attentions)
        """
        distilbert_output = self.distilbert(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            head_mask=head_mask,
                                            inputs_embeds=inputs_embeds)
        hidden_state = distilbert_output[0]                  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]                   # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)   # (bs, dim)
        logits = self.classifier(pooled_output)              # (bs, num_labels)
        outputs = (logits, ) + distilbert_output[1:]

        if labels is not None:
            loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
            labels = labels.float()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.view(-1, self.num_labels))
            outputs = (loss, ) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

    def freeze_bert_encoder(self):
        """Freeze DistilBERT layers."""
        for param in self.distilbert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        """Unfreeze DistilBERT layers."""
        for param in self.distilbert.parameters():
            param.requires_grad = True
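
# --- Usage sketch (illustrative; not part of the original project) ----------
# The extra config fields (use_pos_weight, pos_weight, hidden_dropout_prob) are
# custom attributes this class expects on top of the standard DistilBertConfig,
# so they are set explicitly here; the checkpoint name and weights are assumptions.
from transformers import DistilBertConfig


def build_multilabel_model(num_labels, pos_weight=None):
    config = DistilBertConfig.from_pretrained("distilbert-base-uncased")
    config.num_labels = num_labels
    config.hidden_dropout_prob = 0.1
    config.use_pos_weight = pos_weight is not None
    config.pos_weight = pos_weight
    return DistilBertForMultiLabelSequenceClassification(config)


# Example: three labels, weighting the rarer positive classes more heavily.
# model = build_multilabel_model(3, pos_weight=[1.0, 2.5, 4.0])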