def get_copy_of_bert_encoder_with_adapters(
        bert_encoder: modeling_bert.BertModel,
        adapter_config: AdapterConfig) -> Tuple[nn.Module, Dict]:
    """Returns a copy of BertModel with adapters, and a dictionary of the adapter modules added.

    We make a deepcopy and then reassign the original parameters, so the copy
    shares the pre-trained weights with the original encoder.
    """
    assert isinstance(bert_encoder, modeling_bert.BertModel)
    new_bert_encoder = copy.deepcopy(bert_encoder)
    adapter_modules = add_adapters_to_bert_encoder(
        bert_encoder=new_bert_encoder,
        adapter_config=adapter_config,
    )
    # Point the copied encoder's parameters back at the original tensors, so
    # only the newly added adapter parameters are distinct from the original.
    for name, param in bert_encoder.named_parameters():
        *prefixes, leaf_param_name = name.split(".")
        curr = new_bert_encoder
        for prefix in prefixes:
            curr = getattr(curr, prefix)
        setattr(curr, leaf_param_name, param)
    return new_bert_encoder, adapter_modules
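

# Illustrative usage sketch, not part of the original module: it assumes that
# `AdapterConfig` can be constructed with its defaults (hypothetical) and that
# `add_adapters_to_bert_encoder` is defined elsewhere in this file, as the
# function above already requires. The point being demonstrated is that the
# returned copy shares the original (non-adapter) parameter tensors.
def _example_adapter_copy_shares_weights() -> None:
    config = modeling_bert.BertConfig(hidden_size=32,
                                      num_hidden_layers=2,
                                      num_attention_heads=2,
                                      intermediate_size=64)
    base_encoder = modeling_bert.BertModel(config)
    adapter_encoder, adapter_modules = get_copy_of_bert_encoder_with_adapters(
        bert_encoder=base_encoder,
        adapter_config=AdapterConfig(),  # hypothetical default construction
    )
    # The copied encoder points at the same non-adapter tensors as the
    # original, so only the adapter modules contribute new parameters.
    assert (adapter_encoder.embeddings.word_embeddings.weight
            is base_encoder.embeddings.word_embeddings.weight)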


class DocumentBertLSTM(BertPreTrainedModel):
    """
    BERT output over document in LSTM
    """

    def __init__(self, bert_model_config: BertConfig):
        super(DocumentBertLSTM, self).__init__(bert_model_config)
        self.bert = BertModel(bert_model_config)
        self.bert_batch_size = self.bert.config.bert_batch_size
        self.dropout = nn.Dropout(p=bert_model_config.hidden_dropout_prob)
        self.lstm = LSTM(
            bert_model_config.hidden_size,
            bert_model_config.hidden_size,
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=bert_model_config.hidden_dropout_prob),
            nn.Linear(bert_model_config.hidden_size,
                      bert_model_config.num_labels),
            nn.Tanh())

    #input_ids, token_type_ids, attention_masks
    def forward(self,
                document_batch: torch.Tensor,
                document_sequence_lengths: list,
                device='cuda'):
        #contains all BERT sequences
        #bert should output a (batch_size, num_sequences, bert_hidden_size)
        bert_output = torch.zeros(size=(document_batch.shape[0],
                                        min(document_batch.shape[1],
                                            self.bert_batch_size),
                                        self.bert.config.hidden_size),
                                  dtype=torch.float,
                                  device=device)

        #only pass through bert_batch_size numbers of inputs into bert.
        #this means that we are possibly cutting off the last part of documents.
        #use_grad = not freeze_bert
        #with torch.set_grad_enabled(False):
        for doc_id in range(document_batch.shape[0]):
            bert_output[doc_id][:self.bert_batch_size] = self.dropout(
                self.bert(document_batch[doc_id][:self.bert_batch_size, 0],
                          token_type_ids=document_batch[doc_id]
                          [:self.bert_batch_size, 1],
                          attention_mask=document_batch[doc_id]
                          [:self.bert_batch_size, 2])[1])

        #lstm expects a (num_sequences, batch_size (i.e. number of documents), bert_hidden_size)
        #self.lstm.flatten_parameters()
        output, (_, _) = self.lstm(bert_output.permute(1, 0, 2))
        #print(bert_output.requires_grad)
        #print(output.requires_grad)

        last_layer = output[-1]
        #print("Last LSTM layer shape:", last_layer.shape)

        prediction = self.classifier(last_layer)
        #print("Prediction Shape", prediction.shape)

        assert prediction.shape[0] == document_batch.shape[0]
        return prediction

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True

    def unfreeze_bert_encoder_last_layers(self):
        for name, param in self.bert.named_parameters():
            if "encoder.layer.11" in name or "pooler" in name:
                param.requires_grad = True

    def unfreeze_bert_encoder_pooler_layer(self):
        for name, param in self.bert.named_parameters():
            if "pooler" in name:
                param.requires_grad = True
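

# Illustrative usage sketch, not part of the original module. From the
# indexing in `DocumentBertLSTM.forward`, `document_batch` is assumed to have
# shape (num_documents, num_sequences, 3, seq_len), with the third axis
# holding input_ids, token_type_ids and attention_mask in that order, and
# `bert_batch_size` is assumed to be a custom attribute set on the config.
def _example_document_bert_lstm_forward() -> torch.Tensor:
    config = BertConfig(hidden_size=32,
                        num_hidden_layers=2,
                        num_attention_heads=2,
                        intermediate_size=64,
                        num_labels=1)
    config.bert_batch_size = 4  # custom attribute read by the model
    model = DocumentBertLSTM(config)

    num_docs, num_seqs, seq_len = 2, 4, 16
    input_ids = torch.randint(0, config.vocab_size,
                              (num_docs, num_seqs, seq_len))
    token_type_ids = torch.zeros_like(input_ids)
    attention_mask = torch.ones_like(input_ids)
    document_batch = torch.stack([input_ids, token_type_ids, attention_mask],
                                 dim=2)

    prediction = model(document_batch,
                       document_sequence_lengths=[num_seqs] * num_docs,
                       device='cpu')
    return prediction  # shape: (num_docs, num_labels)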


class DocumentBertTransformer(BertPreTrainedModel):
    """
    BERT -> TransformerEncoder -> Max over attention output.
    """

    def __init__(self, bert_model_config: BertConfig):
        super(DocumentBertTransformer, self).__init__(bert_model_config)
        self.bert = BertModel(bert_model_config)
        self.bert_batch_size = self.bert.config.bert_batch_size
        self.dropout = nn.Dropout(p=bert_model_config.hidden_dropout_prob)

        encoder_layer = TransformerEncoderLayer(
            d_model=bert_model_config.hidden_size,
            nhead=6,
            dropout=bert_model_config.hidden_dropout_prob)
        self.transformer_encoder = TransformerEncoder(encoder_layer,
                                                      num_layers=6)
        self.classifier = nn.Sequential(
            nn.Dropout(p=bert_model_config.hidden_dropout_prob),
            nn.Linear(bert_model_config.hidden_size,
                      bert_model_config.num_labels),
            nn.Tanh())

    #input_ids, token_type_ids, attention_masks
    def forward(self,
                document_batch: torch.Tensor,
                document_sequence_lengths: list,
                device='cuda'):
        #contains all BERT sequences
        #bert should output a (batch_size, num_sequences, bert_hidden_size)
        bert_output = torch.zeros(size=(document_batch.shape[0],
                                        min(document_batch.shape[1],
                                            self.bert_batch_size),
                                        self.bert.config.hidden_size),
                                  dtype=torch.float,
                                  device=device)

        #only pass through bert_batch_size numbers of inputs into bert.
        #this means that we are possibly cutting off the last part of documents.
        for doc_id in range(document_batch.shape[0]):
            bert_output[doc_id][:self.bert_batch_size] = self.dropout(
                self.bert(document_batch[doc_id][:self.bert_batch_size, 0],
                          token_type_ids=document_batch[doc_id]
                          [:self.bert_batch_size, 1],
                          attention_mask=document_batch[doc_id]
                          [:self.bert_batch_size, 2])[1])

        transformer_output = self.transformer_encoder(
            bert_output.permute(1, 0, 2))
        #print(transformer_output.shape)

        prediction = self.classifier(
            transformer_output.permute(1, 0, 2).max(dim=1)[0])
        assert prediction.shape[0] == document_batch.shape[0]
        return prediction

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True

    def unfreeze_bert_encoder_last_layers(self):
        for name, param in self.bert.named_parameters():
            if "encoder.layer.11" in name or "pooler" in name:
                param.requires_grad = True

    def unfreeze_bert_encoder_pooler_layer(self):
        for name, param in self.bert.named_parameters():
            if "pooler" in name:
                param.requires_grad = True
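

# Illustrative fine-tuning sketch, not part of the original module. Both
# document models above expose the same freeze/unfreeze helpers, so a common
# recipe is to freeze BERT and train only its last layer and pooler together
# with the task-specific head. The optimizer choice below is an assumption,
# not taken from the original training script.
def _example_partial_finetuning(
        model: DocumentBertTransformer) -> torch.optim.Optimizer:
    model.freeze_bert_encoder()
    model.unfreeze_bert_encoder_last_layers()  # re-enables layer 11 + pooler
    # The transformer encoder and classifier head were never frozen, so they
    # stay trainable alongside the unfrozen BERT parameters.
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=2e-5)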


class MtlEncoderRanker(BertPreTrainedModel):  # type: ignore

    def __init__(self, config: BertConfig, **kwargs: Any):
        """The classification init is a superset of the LM init."""
        super().__init__(config, **kwargs)
        self.config = config
        self.bert = BertModel(config=self.config)
        self.lm_head = BertOnlyMLMHead(self.config)
        self.lm_head.apply(self._init_weights)
        self.qa_head = BertOnlyMLMHead(self.config)
        self.qa_head.apply(self._init_weights)
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.config.hidden_size,
                                    self.config.num_labels)
        self.classifier.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        mode: str = "summarizer",
        input_weights: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Any:
        """Versatile forward interface.

        By default it behaves as an LM head, so it is compatible with the
        `generate()` interface.

        labels: Labels for ranking.
        """
        model_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        if mode == "summarizer":
            lm_logits = self.lm_head(model_outputs[0])
            if labels is None:
                labels = kwargs.get("lm_labels", None)
            if labels is not None:
                if input_weights is None:
                    lm_loss = F.cross_entropy(
                        lm_logits.view(-1, self.config.vocab_size),
                        labels.reshape(-1))
                else:
                    lm_loss = F.cross_entropy(
                        lm_logits.view(-1, self.config.vocab_size),
                        labels.reshape(-1),
                        reduction="none",
                    )
                    # Weigh different examples
                    lm_loss = lm_loss.reshape(input_ids.size(0), -1)
                    lm_loss = lm_loss * input_weights.reshape(
                        input_ids.size(0), 1)
                    lm_loss = lm_loss[labels != -100].mean()
                outputs = (lm_loss, lm_logits) + model_outputs[2:]
            else:
                outputs = (lm_logits, ) + model_outputs[2:]
            return outputs
        elif mode == "qa":
            qa_logits = self.qa_head(model_outputs[0])
            if labels is not None:
                qa_loss = F.cross_entropy(
                    qa_logits.view(-1, self.config.vocab_size),
                    labels.view(-1))
                outputs = (qa_loss, qa_logits) + model_outputs[2:]
            else:
                outputs = (qa_logits, ) + model_outputs[2:]
            return outputs
        elif mode == "ranker":
            rank_logits = self.classifier(self.dropout(model_outputs[1]))
            if labels is not None:
                loss = F.cross_entropy(
                    rank_logits.view(-1, self.config.num_labels),
                    labels.view(-1))
                outputs = (loss, rank_logits) + model_outputs[2:]
            else:
                outputs = (rank_logits, ) + model_outputs[2:]
            return outputs
        else:
            assert False, f"Unknown mode {mode}"

    def get_output_embeddings(self) -> nn.Module:  # type: ignore
        return self.qa_head.predictions.decoder  # type: ignore

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Dict[str, Union[bool, torch.Tensor, None]]:
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": kwargs.get("token_type_ids"),
        }

    def shared_grads(self) -> Optional[torch.Tensor]:
        grads_list = []
        for name, params in self.bert.named_parameters():
            if name.startswith("pooler."):
                continue
            if params.requires_grad:
                if params.grad is not None:
                    grads_list.append(params.grad.flatten().cpu())
        if not grads_list:
            return None
        grads = torch.cat(grads_list)
        return grads

    def _init_weights(self, module: nn.Module) -> None:  # type: ignore
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


def get_lm_head_cls(arch: str) -> nn.Module:  # type: ignore
    if arch.startswith("albert"):
        return AlbertMLMHead  # type: ignore
    else:
        return BertOnlyMLMHead  # type: ignore
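

# Illustrative usage sketch, not part of the original module. It shows how
# `MtlEncoderRanker.forward` routes one shared encoder to the different heads
# via `mode`, and the output shapes expected from each branch on a toy config.
# The labels below are illustrative only.
def _example_mtl_modes() -> None:
    config = BertConfig(hidden_size=32,
                        num_hidden_layers=2,
                        num_attention_heads=2,
                        intermediate_size=64,
                        num_labels=2)
    model = MtlEncoderRanker(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    attention_mask = torch.ones_like(input_ids)

    # "summarizer" (and, analogously, "qa") produces token-level logits over
    # the vocabulary, MLM-style.
    lm_logits = model(input_ids, attention_mask=attention_mask,
                      mode="summarizer")[0]
    assert lm_logits.shape == (2, 8, config.vocab_size)

    # "ranker" classifies the pooled [CLS] representation; passing labels
    # makes the loss the first element of the output tuple.
    rank_loss, rank_logits = model(input_ids,
                                   attention_mask=attention_mask,
                                   labels=torch.tensor([0, 1]),
                                   mode="ranker")[:2]
    assert rank_logits.shape == (2, config.num_labels)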