def prepare_config_and_inputs(self):
    """Assemble a ``PLBartConfig`` plus the tensors a decoder test consumes.

    Returns:
        A 4-tuple ``(config, input_ids, attention_mask, lm_labels)``.
        ``attention_mask`` is ``None`` unless ``self.use_attention_mask`` is
        truthy, and ``lm_labels`` is ``None`` unless ``self.use_labels`` is
        truthy.
    """
    seq_shape = (self.batch_size, self.decoder_seq_length)

    # The three ids_tensor calls stay in this order (ids, mask, labels) and
    # each one receives its own fresh list, exactly as before.
    input_ids = ids_tensor(list(seq_shape), self.vocab_size)
    attention_mask = (
        ids_tensor(list(seq_shape), vocab_size=2)
        if self.use_attention_mask
        else None
    )
    lm_labels = (
        ids_tensor(list(seq_shape), self.vocab_size)
        if self.use_labels
        else None
    )

    # Every config keyword is read from the attribute of the same name.
    mirrored_fields = (
        "vocab_size",
        "d_model",
        "decoder_layers",
        "decoder_ffn_dim",
        "encoder_attention_heads",
        "decoder_attention_heads",
        "eos_token_id",
        "bos_token_id",
        "use_cache",
        "pad_token_id",
        "decoder_start_token_id",
        "max_position_embeddings",
        "is_encoder_decoder",
    )
    config = PLBartConfig(
        **{field: getattr(self, field) for field in mirrored_fields}
    )

    return (config, input_ids, attention_mask, lm_labels)
def test_plbart_fast_forward(self):
    """Run one forward pass on a tiny PLBart model and check the logits shape.

    The encoder input and decoder input have different sequence lengths
    (7 vs. 5); the logits are expected to have the decoder input's shape
    with ``config.vocab_size`` appended as the last dimension.
    """
    tiny_settings = {
        "vocab_size": 99,
        "d_model": 24,
        "encoder_layers": 2,
        "decoder_layers": 2,
        "encoder_attention_heads": 2,
        "decoder_attention_heads": 2,
        "encoder_ffn_dim": 32,
        "decoder_ffn_dim": 32,
        "max_position_embeddings": 48,
        "add_final_layer_norm": True,
    }
    config = PLBartConfig(**tiny_settings)
    lm_model = PLBartForConditionalGeneration(config).to(torch_device)

    encoder_rows = [[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]
    decoder_rows = [[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]
    context = torch.tensor(encoder_rows, device=torch_device, dtype=torch.long)
    summary = torch.tensor(decoder_rows, device=torch_device, dtype=torch.long)

    # The decoder ids double as the labels for this forward pass.
    result = lm_model(
        input_ids=context,
        decoder_input_ids=summary,
        labels=summary,
    )

    expected_shape = tuple(summary.shape) + (config.vocab_size,)
    self.assertEqual(result.logits.shape, expected_shape)
def convert_fairseq_plbart_checkpoint_from_disk(
    checkpoint_path,
    hf_config_path="uclanlp/plbart-base",
    finetuned=False,
    classification=False,
):
    """Load a checkpoint from disk and build the matching PLBart model.

    Args:
        checkpoint_path: Path handed to ``torch.load``; the loaded object must
            have a ``"model"`` entry holding the state dict.
        hf_config_path: Name or path given to ``PLBartConfig.from_pretrained``.
        finetuned: When truthy (and ``classification`` is falsy), ``lm_head``
            is replaced with the result of ``make_linear_from_emb`` applied to
            the model's shared embedding.
        classification: When truthy, build a
            ``PLBartForSequenceClassification`` and load its classification
            head from the checkpoint; otherwise build a
            ``PLBartForConditionalGeneration``.

    Returns:
        The constructed model with the checkpoint weights loaded.
    """
    # NOTE: torch.load unpickles the file, so only convert checkpoints that
    # come from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    remove_ignore_keys_(state_dict)

    # The first dimension of the encoder embedding is the vocabulary size.
    s_vocab = state_dict["encoder.embed_tokens.weight"].shape[0]

    # PLBartConfig's parameter is ``vocab_size``. Passing the value under any
    # other keyword would not override the config attribute, leaving the
    # model's embedding size out of sync with the checkpoint weights.
    plbart_config = PLBartConfig.from_pretrained(
        hf_config_path, vocab_size=s_vocab
    )

    state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]

    if not classification:
        model = PLBartForConditionalGeneration(plbart_config)
        model.model.load_state_dict(state_dict)
        if finetuned:
            model.lm_head = make_linear_from_emb(model.model.shared)
    else:
        # Move the classification-head weights out of the main state dict
        # into their own dict, with the head prefix stripped from each key.
        classification_head = {}
        # Iterate over a snapshot because entries are popped during the loop.
        for key, value in list(state_dict.items()):
            if key.startswith(
                "classification_heads.sentence_classification_head"
            ):
                head_key = key.replace(
                    "classification_heads.sentence_classification_head.", ""
                )
                classification_head[head_key] = value
                state_dict.pop(key)
        model = PLBartForSequenceClassification(plbart_config)
        model.model.load_state_dict(state_dict)
        model.classification_head.load_state_dict(classification_head)

    return model
def test_plbart_java_cs_config(self):
    """Check selected config attributes of the listed PLBart checkpoints.

    On a mismatch, the checkpoint name and the attribute name are appended
    to the ``AssertionError`` arguments before it is re-raised.
    """
    checkpoint_names = ["uclanlp/plbart-java-cs"]
    required_values = {"scale_embedding": True}

    for checkpoint in checkpoint_names:
        loaded = PLBartConfig.from_pretrained(checkpoint)
        for attribute in required_values:
            wanted = required_values[attribute]
            try:
                self.assertEqual(wanted, getattr(loaded, attribute))
            except AssertionError as failure:
                # Record which checkpoint/attribute pair failed.
                failure.args = failure.args + (checkpoint, attribute)
                raise
def get_config(self):
    """Return a ``PLBartConfig`` built from this object's attributes.

    The encoder and decoder share the same layer count, attention-head
    count and feed-forward size, each read from a single attribute.
    """
    layer_count = self.num_hidden_layers
    head_count = self.num_attention_heads
    ffn_width = self.intermediate_size

    settings = {
        "vocab_size": self.vocab_size,
        "d_model": self.hidden_size,
        "encoder_layers": layer_count,
        "decoder_layers": layer_count,
        "encoder_attention_heads": head_count,
        "decoder_attention_heads": head_count,
        "encoder_ffn_dim": ffn_width,
        "decoder_ffn_dim": ffn_width,
        "dropout": self.hidden_dropout_prob,
        "attention_dropout": self.attention_probs_dropout_prob,
        "max_position_embeddings": self.max_position_embeddings,
        "eos_token_id": self.eos_token_id,
        "bos_token_id": self.bos_token_id,
        "pad_token_id": self.pad_token_id,
    }
    return PLBartConfig(**settings)