def test_config_save(self):
    """Check that ``add_cross_attention=False`` survives a save/load round trip.

    The first element of ``prepare_config_and_inputs()`` is used as the
    config. It is written to a temporary directory with ``save_pretrained``
    and read back with ``ProphetNetConfig.from_pretrained``. The reloaded
    config must still report ``add_cross_attention`` as falsy.
    """
    prepared = self.model_tester.prepare_config_and_inputs()
    original_config = prepared[0]
    original_config.add_cross_attention = False

    with tempfile.TemporaryDirectory() as save_dir:
        original_config.save_pretrained(save_dir)
        # Reload from disk while the directory still exists.
        config = ProphetNetConfig.from_pretrained(save_dir)

    self.assertFalse(config.add_cross_attention)
def __init__(self, config, dataset):
    """Build the ProphetNet wrapper from a pretrained checkpoint.

    Args:
        config: Run configuration; indexed with ``'pretrained_model_path'``
            to locate the checkpoint.
        dataset: Forwarded unchanged to the parent class constructor.

    Sets ``pretrained_model_path``, ``config`` (the loaded
    ``ProphetNetConfig``), ``tokenizer``, ``model``, ``padding_token_idx``
    and ``loss`` on the instance.
    """
    super(ProphetNet, self).__init__(config, dataset)

    checkpoint = config['pretrained_model_path']
    self.pretrained_model_path = checkpoint

    # NOTE(review): this rebinds self.config to the HF model config,
    # replacing whatever the parent __init__ may have stored there -- confirm
    # nothing downstream expects the original run config on self.config.
    self.config = ProphetNetConfig.from_pretrained(checkpoint)
    self.tokenizer = ProphetNetTokenizer.from_pretrained(checkpoint)
    self.model = ProphetNetForConditionalGeneration.from_pretrained(
        checkpoint, config=self.config
    )

    pad_id = self.tokenizer.pad_token_id
    self.padding_token_idx = pad_id
    # reduction='none' keeps per-element losses; positions whose target is
    # the pad id are ignored by the criterion.
    self.loss = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='none')