Code Example #1
File: modeling.py  Project: zphang/nyu-jiant
def get_copy_of_bert_encoder_with_adapters(
        bert_encoder: modeling_bert.BertModel,
        adapter_config: AdapterConfig) -> Tuple[nn.Module, Dict]:
    """Returns a copy of BertModel with adapters, and a dictionary of adapter modules added

    We're going to make a deepcopy, and then reassign the old parameters
    """
    assert isinstance(bert_encoder, modeling_bert.BertModel)
    new_bert_encoder = copy.deepcopy(bert_encoder)
    adapter_modules = add_adapters_to_bert_encoder(
        bert_encoder=new_bert_encoder,
        adapter_config=adapter_config,
    )
    for name, param in bert_encoder.named_parameters():
        *prefixes, leaf_param_name = name.split(".")
        curr = new_bert_encoder
        for prefix in prefixes:
            curr = getattr(curr, prefix)
        setattr(curr, leaf_param_name, param)
    return new_bert_encoder, adapter_modules
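
A minimal usage sketch for the helper above (not part of the original project): AdapterConfig's constructor arguments, the optimizer choice, and the assumption that the returned dict maps names to nn.Module instances are all illustrative.

import torch
from transformers import BertModel

# Load one shared pretrained encoder.
base_encoder = BertModel.from_pretrained("bert-base-uncased")

# Build a per-task copy: the adapter modules are new, while every other
# parameter tensor is shared with base_encoder via the re-assignment loop above.
adapter_config = AdapterConfig()  # assumed default-constructible; real fields live in the project
task_encoder, adapter_modules = get_copy_of_bert_encoder_with_adapters(
    bert_encoder=base_encoder,
    adapter_config=adapter_config,
)

# Optimize only the adapter parameters; the shared BERT weights stay fixed.
adapter_params = [p for m in adapter_modules.values() for p in m.parameters()]
optimizer = torch.optim.Adam(adapter_params, lr=1e-4)
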
class DocumentBertLSTM(BertPreTrainedModel):
    """
    BERT output over document in LSTM
    """
    def __init__(self, bert_model_config: BertConfig):
        super(DocumentBertLSTM, self).__init__(bert_model_config)
        self.bert = BertModel(bert_model_config)
        self.bert_batch_size = self.bert.config.bert_batch_size
        self.dropout = nn.Dropout(p=bert_model_config.hidden_dropout_prob)
        self.lstm = LSTM(
            bert_model_config.hidden_size,
            bert_model_config.hidden_size,
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=bert_model_config.hidden_dropout_prob),
            nn.Linear(bert_model_config.hidden_size,
                      bert_model_config.num_labels), nn.Tanh())

    # document_batch has shape (num_docs, num_sequences, 3, seq_len); dim 2 stacks
    # input_ids, token_type_ids and attention_masks.
    def forward(self,
                document_batch: torch.Tensor,
                document_sequence_lengths: list,
                device='cuda'):

        # Buffer for the pooled BERT output of every sequence:
        # (batch_size, num_sequences, bert_hidden_size)
        bert_output = torch.zeros(size=(document_batch.shape[0],
                                        min(document_batch.shape[1],
                                            self.bert_batch_size),
                                        self.bert.config.hidden_size),
                                  dtype=torch.float,
                                  device=device)

        # Only pass the first bert_batch_size sequences of each document through BERT,
        # which may truncate the tail of long documents.

        for doc_id in range(document_batch.shape[0]):
            bert_output[doc_id][:self.bert_batch_size] = self.dropout(
                self.bert(document_batch[doc_id][:self.bert_batch_size, 0],
                          token_type_ids=document_batch[doc_id]
                          [:self.bert_batch_size, 1],
                          attention_mask=document_batch[doc_id]
                          [:self.bert_batch_size, 2])[1])

        # The LSTM expects (num_sequences, batch_size (number of documents), bert_hidden_size)
        #self.lstm.flatten_parameters()
        output, (_, _) = self.lstm(bert_output.permute(1, 0, 2))

        # Classify using the LSTM output at the final timestep.
        last_layer = output[-1]

        prediction = self.classifier(last_layer)
        assert prediction.shape[0] == document_batch.shape[0]
        return prediction

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True

    def unfreeze_bert_encoder_last_layers(self):
        for name, param in self.bert.named_parameters():
            if "encoder.layer.11" in name or "pooler" in name:
                param.requires_grad = True

    def unfreeze_bert_encoder_pooler_layer(self):
        for name, param in self.bert.named_parameters():
            if "pooler" in name:
                param.requires_grad = True
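
A sketch of driving DocumentBertLSTM (illustrative, not from the original repository): it assumes the custom bert_batch_size field on the config and a pre-chunked tensor of shape (num_docs, num_sequences, 3, seq_len), where dim 2 stacks input_ids, token_type_ids and attention_masks.

import torch
from transformers import BertConfig

config = BertConfig(num_labels=4)
config.bert_batch_size = 8              # custom field read by the model's __init__

model = DocumentBertLSTM(config)

# 2 documents, each pre-split into 8 sequences of 128 WordPiece ids.
document_batch = torch.zeros(2, 8, 3, 128, dtype=torch.long)
document_sequence_lengths = [8, 5]      # accepted for interface parity; unused by this forward

logits = model(document_batch, document_sequence_lengths, device="cpu")
assert logits.shape == (2, config.num_labels)
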
class DocumentBertTransformer(BertPreTrainedModel):
    """
    BERT -> TransformerEncoder -> Max over attention output.
    """
    def __init__(self, bert_model_config: BertConfig):
        super(DocumentBertTransformer, self).__init__(bert_model_config)
        self.bert = BertModel(bert_model_config)
        self.bert_batch_size = self.bert.config.bert_batch_size
        self.dropout = nn.Dropout(p=bert_model_config.hidden_dropout_prob)

        encoder_layer = TransformerEncoderLayer(
            d_model=bert_model_config.hidden_size,
            nhead=6,
            dropout=bert_model_config.hidden_dropout_prob)
        self.transformer_encoder = TransformerEncoder(encoder_layer,
                                                      num_layers=6)
        self.classifier = nn.Sequential(
            nn.Dropout(p=bert_model_config.hidden_dropout_prob),
            nn.Linear(bert_model_config.hidden_size,
                      bert_model_config.num_labels), nn.Tanh())

    # document_batch has shape (num_docs, num_sequences, 3, seq_len); dim 2 stacks
    # input_ids, token_type_ids and attention_masks.
    def forward(self,
                document_batch: torch.Tensor,
                document_sequence_lengths: list,
                device='cuda'):

        # Buffer for the pooled BERT output of every sequence:
        # (batch_size, num_sequences, bert_hidden_size)
        bert_output = torch.zeros(size=(document_batch.shape[0],
                                        min(document_batch.shape[1],
                                            self.bert_batch_size),
                                        self.bert.config.hidden_size),
                                  dtype=torch.float,
                                  device=device)

        # Only pass the first bert_batch_size sequences of each document through BERT,
        # which may truncate the tail of long documents.
        for doc_id in range(document_batch.shape[0]):
            bert_output[doc_id][:self.bert_batch_size] = self.dropout(
                self.bert(document_batch[doc_id][:self.bert_batch_size, 0],
                          token_type_ids=document_batch[doc_id]
                          [:self.bert_batch_size, 1],
                          attention_mask=document_batch[doc_id]
                          [:self.bert_batch_size, 2])[1])

        transformer_output = self.transformer_encoder(
            bert_output.permute(1, 0, 2))

        prediction = self.classifier(
            transformer_output.permute(1, 0, 2).max(dim=1)[0])
        assert prediction.shape[0] == document_batch.shape[0]
        return prediction

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True

    def unfreeze_bert_encoder_last_layers(self):
        for name, param in self.bert.named_parameters():
            if "encoder.layer.11" in name or "pooler" in name:
                param.requires_grad = True

    def unfreeze_bert_encoder_pooler_layer(self):
        for name, param in self.bert.named_parameters():
            if "pooler" in name:
                param.requires_grad = True
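
Both document models expose the same freezing helpers, so a staged fine-tuning schedule can be sketched as follows (stage lengths and learning rates are illustrative):

import torch

model = DocumentBertTransformer(config)   # config as in the DocumentBertLSTM sketch above

# Stage 1: keep BERT frozen and train only the TransformerEncoder and classifier head.
model.freeze_bert_encoder()
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3)
# ... train for a few epochs ...

# Stage 2: additionally unfreeze BERT's last layer and pooler at a smaller learning rate.
model.unfreeze_bert_encoder_last_layers()
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=2e-5)
# ... continue training ...
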
Code Example #4
class MtlEncoderRanker(BertPreTrainedModel):  # type: ignore
    def __init__(self, config: BertConfig, **kwargs: Any):
        """The classification init is a super set of LM init"""
        super().__init__(config, **kwargs)
        self.config = config
        self.bert = BertModel(config=self.config)

        self.lm_head = BertOnlyMLMHead(self.config)
        self.lm_head.apply(self._init_weights)

        self.qa_head = BertOnlyMLMHead(self.config)
        self.qa_head.apply(self._init_weights)

        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.config.hidden_size,
                                    self.config.num_labels)
        self.classifier.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        mode: str = "summarizer",
        input_weights: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Any:
        """Versatile forward interface. By default it should behaves as an LM head so it's
           compatible with the `generate()` interface.

        labels: Labels for ranking.

        """
        model_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )

        if mode == "summarizer":
            lm_logits = self.lm_head(model_outputs[0])
            if labels is None:
                labels = kwargs.get("lm_labels", None)
            if labels is not None:
                if input_weights is None:
                    lm_loss = F.cross_entropy(
                        lm_logits.view(-1, self.config.vocab_size),
                        labels.reshape(-1))
                else:
                    lm_loss = F.cross_entropy(
                        lm_logits.view(-1, self.config.vocab_size),
                        labels.reshape(-1),
                        reduction="none",
                    )
                    # Weigh different examples
                    lm_loss = lm_loss.reshape(input_ids.size(0), -1)
                    lm_loss = lm_loss * input_weights.reshape(
                        input_ids.size(0), 1)
                    lm_loss = lm_loss[labels != -100].mean()
                outputs = (lm_loss, lm_logits) + model_outputs[2:]
            else:
                outputs = (lm_logits, ) + model_outputs[2:]
            return outputs
        elif mode == "qa":
            qa_logits = self.qa_head(model_outputs[0])
            if labels is not None:
                qa_loss = F.cross_entropy(
                    qa_logits.view(-1, self.config.vocab_size),
                    labels.view(-1))
                outputs = (qa_loss, qa_logits) + model_outputs[2:]
            else:
                outputs = (qa_logits, ) + model_outputs[2:]
            return outputs
        elif mode == "ranker":
            rank_logits = self.classifier(self.dropout(model_outputs[1]))
            if labels is not None:
                loss = F.cross_entropy(
                    rank_logits.view(-1, self.config.num_labels),
                    labels.view(-1))
                outputs = (loss, rank_logits) + model_outputs[2:]
            else:
                outputs = (rank_logits, ) + model_outputs[2:]
            return outputs
        else:
            raise ValueError(f"Unknown mode {mode}")

    def get_output_embeddings(self) -> nn.Module:  # type: ignore
        return self.qa_head.predictions.decoder  # type: ignore

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Dict[str, Union[bool, torch.Tensor, None]]:
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": kwargs.get("token_type_ids"),
        }

    def shared_grads(self) -> Optional[torch.Tensor]:
        grads_list = []
        for name, params in self.bert.named_parameters():
            if name.startswith("pooler."):
                continue
            if params.requires_grad:
                if params.grad is not None:
                    grads_list.append(params.grad.flatten().cpu())
        if not grads_list:
            return None
        grads = torch.cat(grads_list)
        return grads

    def _init_weights(self, module: nn.Module) -> None:  # type: ignore
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @staticmethod
    def get_lm_head_cls(arch: str) -> nn.Module:  # type: ignore
        if arch.startswith("albert"):
            return AlbertMLMHead  # type: ignore
        else:
            return BertOnlyMLMHead  # type: ignore
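
A sketch of one multi-task step with MtlEncoderRanker: the batch dicts lm_batch, qa_batch and rank_batch are hypothetical (each holding input_ids, attention_mask and labels), and the equal loss weighting is an assumption.

import torch
from transformers import BertConfig

model = MtlEncoderRanker(BertConfig(num_labels=2))

# Summarizer/LM step: token-level labels plus optional per-example weights;
# positions labeled -100 are ignored by the loss.
lm_loss = model(
    input_ids=lm_batch["input_ids"],
    attention_mask=lm_batch["attention_mask"],
    labels=lm_batch["labels"],
    input_weights=lm_batch.get("weights"),
    mode="summarizer",
)[0]

# The QA head shares the MLM output shape; the ranker classifies the pooled output.
qa_loss = model(**qa_batch, mode="qa")[0]
rank_loss = model(**rank_batch, mode="ranker")[0]

(lm_loss + qa_loss + rank_loss).backward()
shared = model.shared_grads()   # flattened encoder gradients, e.g. for conflict diagnostics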