コード例 #1
0
    def prepare_config_and_inputs(self):
        """Build a decoder test config together with randomly sampled inputs.

        Returns:
            A 4-tuple ``(config, input_ids, attention_mask, lm_labels)``.
            ``input_ids`` is sampled with ``self.vocab_size``. ``attention_mask``
            is sampled with ``vocab_size=2``, only when ``self.use_attention_mask``.
            ``lm_labels`` is sampled with ``self.vocab_size``, only when
            ``self.use_labels``. Disabled entries are ``None``. Every sample is
            requested with the shape ``[self.batch_size, self.decoder_seq_length]``.
        """
        def sample_ids(upper_bound):
            # Build a fresh shape list per call, as each original call site did.
            return ids_tensor([self.batch_size, self.decoder_seq_length],
                              vocab_size=upper_bound)

        # Keep the sampling order input_ids -> attention_mask -> lm_labels;
        # ids_tensor may draw from a shared RNG, so reordering could change values.
        input_ids = sample_ids(self.vocab_size)
        attention_mask = sample_ids(2) if self.use_attention_mask else None
        lm_labels = sample_ids(self.vocab_size) if self.use_labels else None

        config = PLBartConfig(
            # Model dimensions.
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            decoder_layers=self.decoder_layers,
            decoder_ffn_dim=self.decoder_ffn_dim,
            encoder_attention_heads=self.encoder_attention_heads,
            decoder_attention_heads=self.decoder_attention_heads,
            max_position_embeddings=self.max_position_embeddings,
            # Special token ids.
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
            # Behavioural flags.
            use_cache=self.use_cache,
            is_encoder_decoder=self.is_encoder_decoder,
        )

        return config, input_ids, attention_mask, lm_labels
コード例 #2
0
 def test_plbart_fast_forward(self):
     # Smoke test: run one forward pass (with labels) through a tiny PLBart
     # seq2seq model and check the logits have shape
     # (batch, target_len, vocab_size).
     tiny_config = PLBartConfig(
         vocab_size=99,
         d_model=24,
         encoder_layers=2,
         decoder_layers=2,
         encoder_attention_heads=2,
         decoder_attention_heads=2,
         encoder_ffn_dim=32,
         decoder_ffn_dim=32,
         max_position_embeddings=48,
         add_final_layer_norm=True,
     )
     model = PLBartForConditionalGeneration(tiny_config).to(torch_device)

     # Both inputs are int64 tensors on the test device. The trailing 1s look
     # like padding (assumes pad_token_id == 1; not checked here).
     as_ids = dict(device=torch_device, dtype=torch.long)
     source_ids = torch.tensor(
         [[71, 82, 18, 33, 46, 91, 2],
          [68, 34, 26, 58, 30, 2, 1]],
         **as_ids,
     )
     target_ids = torch.tensor(
         [[82, 71, 82, 18, 2],
          [58, 68, 2, 1, 1]],
         **as_ids,
     )

     outputs = model(input_ids=source_ids,
                     decoder_input_ids=target_ids,
                     labels=target_ids)

     batch_size, target_len = target_ids.shape
     self.assertEqual(outputs.logits.shape,
                      (batch_size, target_len, tiny_config.vocab_size))
コード例 #3
0
ファイル: plbart.py プロジェクト: quantapix/qnarre
def convert_fairseq_plbart_checkpoint_from_disk(
        checkpoint_path,
        hf_config_path="uclanlp/plbart-base",
        finetuned=False,
        classification=False):
    """Load a fairseq PLBart checkpoint into a freshly built PLBart model.

    Args:
        checkpoint_path: Path of the fairseq checkpoint. Its ``"model"`` entry
            is used as the source state dict.
        hf_config_path: Name or path handed to ``PLBartConfig.from_pretrained``.
        finetuned: Applies only when ``classification`` is False. If True,
            ``lm_head`` is replaced by ``make_linear_from_emb`` applied to the
            shared embedding.
        classification: If True, build ``PLBartForSequenceClassification``.
            Weights under ``classification_heads.sentence_classification_head.``
            are moved into its ``classification_head``. If False, build
            ``PLBartForConditionalGeneration``.

    Returns:
        The model with the converted weights loaded.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only convert
    # checkpoints from a trusted source.
    weights = torch.load(checkpoint_path, map_location="cpu")["model"]
    remove_ignore_keys_(weights)
    s_vocab = weights["encoder.embed_tokens.weight"].shape[0]

    # NOTE(review): other PLBartConfig call sites pass ``vocab_size=``. Confirm
    # this config class really exposes ``s_vocab``; otherwise the checkpoint's
    # vocabulary size is silently ignored and load_state_dict may hit a
    # shape mismatch.
    plbart_config = PLBartConfig.from_pretrained(hf_config_path,
                                                 s_vocab=s_vocab)

    # The shared embedding reuses the decoder's token-embedding tensor.
    weights["shared.weight"] = weights["decoder.embed_tokens.weight"]

    if classification:
        head_prefix = "classification_heads.sentence_classification_head."
        # Snapshot the matching keys first, because they are popped below.
        head_keys = [
            name for name in weights
            if name.startswith(
                "classification_heads.sentence_classification_head")
        ]
        head_weights = {}
        for name in head_keys:
            head_weights[name.replace(head_prefix, "")] = weights.pop(name)

        model = PLBartForSequenceClassification(plbart_config)
        model.model.load_state_dict(weights)
        model.classification_head.load_state_dict(head_weights)
        return model

    model = PLBartForConditionalGeneration(plbart_config)
    model.model.load_state_dict(weights)
    if finetuned:
        model.lm_head = make_linear_from_emb(model.model.shared)
    return model
コード例 #4
0
 def test_plbart_java_cs_config(self):
     # Each listed checkpoint's config must carry the expected attribute
     # values. On failure, the checkpoint name and attribute name are appended
     # to the AssertionError's args before it is re-raised.
     checkpoints = ["uclanlp/plbart-java-cs"]
     wanted_attrs = {"scale_embedding": True}
     for checkpoint in checkpoints:
         cfg = PLBartConfig.from_pretrained(checkpoint)
         for attr_name, attr_value in wanted_attrs.items():
             actual = getattr(cfg, attr_name)
             try:
                 self.assertEqual(attr_value, actual)
             except AssertionError as err:
                 err.args += (checkpoint, attr_name)
                 raise
コード例 #5
0
 def get_config(self):
     # Build a PLBartConfig from this tester's hyper-parameters. The encoder
     # and decoder share the same layer count, head count and FFN width.
     config_kwargs = {
         "vocab_size": self.vocab_size,
         "d_model": self.hidden_size,
         "encoder_layers": self.num_hidden_layers,
         "decoder_layers": self.num_hidden_layers,
         "encoder_attention_heads": self.num_attention_heads,
         "decoder_attention_heads": self.num_attention_heads,
         "encoder_ffn_dim": self.intermediate_size,
         "decoder_ffn_dim": self.intermediate_size,
         "dropout": self.hidden_dropout_prob,
         "attention_dropout": self.attention_probs_dropout_prob,
         "max_position_embeddings": self.max_position_embeddings,
         "eos_token_id": self.eos_token_id,
         "bos_token_id": self.bos_token_id,
         "pad_token_id": self.pad_token_id,
     }
     return PLBartConfig(**config_kwargs)