Example #1
    def prepare_config_and_inputs(self):
        # Random token ids of shape (batch_size, encoder_seq_length).
        input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)

        attention_mask = None
        if self.use_attention_mask:
            # A random 0/1 attention mask with the same shape as input_ids.
            attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)

        # Build a small ProphetNetConfig from the tester's attributes.
        config = ProphetNetConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_encoder_layers=self.num_encoder_layers,
            num_decoder_layers=self.num_decoder_layers,
            decoder_ffn_dim=self.decoder_ffn_dim,
            encoder_ffn_dim=self.encoder_ffn_dim,
            num_encoder_attention_heads=self.num_encoder_attention_heads,
            num_decoder_attention_heads=self.num_decoder_attention_heads,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            use_cache=self.use_cache,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
            num_buckets=self.num_buckets,
            relative_max_distance=self.relative_max_distance,
            disable_ngram_loss=self.disable_ngram_loss,
            max_position_embeddings=self.max_position_embeddings,
            add_cross_attention=self.add_cross_attention,
            is_encoder_decoder=self.is_encoder_decoder,
            return_dict=self.return_dict,
        )

        return (
            config,
            input_ids,
            attention_mask,
        )
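
The tuple returned above is typically unpacked so the config builds a model and the tensors feed a forward pass. The following is a minimal, self-contained sketch with small illustrative hyperparameter values; it is not part of the original test suite:

    import torch
    from transformers import ProphetNetConfig, ProphetNetModel

    # Small illustrative values; in the test suite these come from the tester's attributes.
    config = ProphetNetConfig(
        vocab_size=99,
        hidden_size=16,
        num_encoder_layers=2,
        num_decoder_layers=2,
        num_encoder_attention_heads=2,
        num_decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=128,
    )
    model = ProphetNetModel(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 7))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        decoder_input_ids=input_ids)
    print(outputs.last_hidden_state.shape)  # (batch, decoder_seq_len, hidden_size)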
Example #2
    def test_config_save(self):
        config = self.model_tester.prepare_config_and_inputs()[0]
        config.add_cross_attention = False
        with tempfile.TemporaryDirectory() as tmp_dirname:
            # Round-trip the config through disk (save_pretrained writes config.json).
            config.save_pretrained(tmp_dirname)
            config = ProphetNetConfig.from_pretrained(tmp_dirname)

        # The non-default value must survive serialization.
        self.assertFalse(config.add_cross_attention)
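
The same round trip can be exercised outside the unittest harness; this is a minimal sketch, with the attribute values chosen only for illustration:

    import tempfile
    from transformers import ProphetNetConfig

    config = ProphetNetConfig(num_encoder_layers=2, num_decoder_layers=2)
    with tempfile.TemporaryDirectory() as tmp_dirname:
        config.save_pretrained(tmp_dirname)          # writes config.json
        reloaded = ProphetNetConfig.from_pretrained(tmp_dirname)
    assert reloaded.num_encoder_layers == 2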
Example #3
    def __init__(self, config, dataset):
        super(ProphetNet, self).__init__(config, dataset)
        # Load the pretrained ProphetNet config, tokenizer and seq2seq model
        # from the path given in the experiment config.
        self.pretrained_model_path = config['pretrained_model_path']
        self.config = ProphetNetConfig.from_pretrained(
            self.pretrained_model_path)
        self.tokenizer = ProphetNetTokenizer.from_pretrained(
            self.pretrained_model_path)
        self.model = ProphetNetForConditionalGeneration.from_pretrained(
            self.pretrained_model_path, config=self.config)

        # Token-level cross entropy: padding positions are ignored, and the
        # per-token losses are kept (reduction='none') so they can be masked
        # and averaged later.
        self.padding_token_idx = self.tokenizer.pad_token_id
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token_idx,
                                        reduction='none')
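
A loss built with ignore_index=pad_token_id and reduction='none' is usually applied by flattening the logits, computing per-token losses, then masking and averaging over non-padding tokens. The snippet below is a standalone sketch of that pattern, not code from the original class; the shapes and pad id are illustrative:

    import torch
    import torch.nn as nn

    pad_token_id = 0
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token_id, reduction='none')

    logits = torch.randn(2, 5, 100)          # (batch, seq_len, vocab_size)
    labels = torch.randint(1, 100, (2, 5))
    labels[:, -2:] = pad_token_id            # simulate padded positions

    # Per-token losses; ignored (padding) positions contribute 0.
    token_loss = loss_fn(logits.view(-1, 100), labels.view(-1)).view(2, 5)

    # Average only over non-padding tokens.
    mask = (labels != pad_token_id).float()
    loss = (token_loss * mask).sum() / mask.sum()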