def test_advanced_inputs(self):
    """End-to-end checks on BartModel: embedding weight tying, N(0, init_std)
    initialization, and that a created decoder mask matches an explicitly
    passed one (plus long-dtype encoder masks)."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
        config, inputs_dict["input_ids"])
    model = BartModel(config)
    model.to(torch_device)
    model.eval()

    # Weight tying: the encoder embedding table must be the shared table.
    self.assertTrue(
        (model.encoder.embed_tokens.weight == model.shared.weight).all().item())

    def _check_var(module):
        """Check that we initialized various parameters from N(0, config.init_std)."""
        self.assertAlmostEqual(torch.std(module.weight).item(), config.init_std, 2)

    _check_var(model.encoder.embed_tokens)
    _check_var(model.encoder.layers[0].self_attn.k_proj)
    _check_var(model.encoder.layers[0].fc1)
    _check_var(model.encoder.embed_positions)

    # Call the module (not .forward) so __call__ hooks run; passing the mask
    # the model would create itself must not change the output.
    decoder_features_with_created_mask = model(**inputs_dict)[0]
    decoder_features_with_passed_mask = model(
        decoder_attention_mask=decoder_attn_mask,
        decoder_input_ids=decoder_input_ids,
        **inputs_dict)[0]
    _assert_tensors_equal(decoder_features_with_passed_mask,
                          decoder_features_with_created_mask)

    # An all-zero decoder mask should still produce a plain tensor of the
    # expected (batch, seq, d_model) shape.
    useless_mask = torch.zeros_like(decoder_attn_mask)
    decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
    self.assertTrue(isinstance(decoder_features, torch.Tensor))  # no hidden states or attentions
    self.assertEqual(
        decoder_features.size(),
        (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model))

    if decoder_attn_mask.min().item() < -1e3:  # some tokens were masked
        # Masked and unmasked runs must disagree somewhere.
        self.assertFalse(
            (decoder_features_with_created_mask == decoder_features).all().item())

    # Test different encoder attention masks: a long-dtype mask must give the
    # same features as the default one.
    decoder_features_with_long_encoder_mask = model(
        inputs_dict["input_ids"],
        attention_mask=inputs_dict["attention_mask"].long())[0]
    _assert_tensors_equal(decoder_features_with_long_encoder_mask,
                          decoder_features_with_created_mask)
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak a fairseq BART checkpoint's weights into our structure.

    Args:
        checkpoint_path: fairseq hub name, e.g. "bart.large", "bart.large.cnn",
            or an MNLI-finetuned variant (anything else falls into the MNLI branch).
        pytorch_dump_folder_path: directory to save the converted model in.
    """
    bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
    bart.eval()  # disable dropout
    bart.model.upgrade_state_dict(bart.model.state_dict())
    hf_model_name = checkpoint_path.replace(".", "-")
    config = BartConfig.from_pretrained(hf_model_name)

    # Sanity check: our tokenizer must reproduce fairseq's encoding exactly.
    tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
    tokens2 = BartTokenizer.from_pretrained(hf_model_name).encode(
        SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
    assert torch.eq(tokens, tokens2).all()

    if checkpoint_path in ["bart.large", "bart.large.cnn"]:
        state_dict = bart.model.state_dict()
        for k in IGNORE_KEYS:
            state_dict.pop(k, None)
        # fairseq ties encoder/decoder embeddings; mirror that via "shared".
        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
        model = BartModel(config)
        their_output = bart.extract_features(tokens)
    else:  # MNLI Case
        state_dict = bart.state_dict()
        for k in IGNORE_KEYS:
            state_dict.pop(k, None)
        state_dict["model.shared.weight"] = state_dict[
            "model.decoder.embed_tokens.weight"]
        for src, dest in rename_keys:
            rename_key(state_dict, src, dest)
        model = BartForSequenceClassification(config)
        # Bug fix: nn.Module.eval() takes no arguments — the fairseq hub API
        # for MNLI inference is predict(head, tokens, return_logits=...).
        their_output = bart.predict("mnli", tokens, return_logits=True)

    # Load state dict
    model.load_state_dict(state_dict)
    model.eval()

    # Check results
    if checkpoint_path == "bart.large.cnn":  # generate doesn't work yet
        model = BartForMaskedLM(config, base_model=model)
        assert "lm_head.weight" in model.state_dict()
        # Bug fix: lm_head projects hidden states to vocabulary logits, so its
        # output dimension is the vocab size, not max_position_embeddings.
        assert model.lm_head.out_features == config.vocab_size
        model.eval()
        our_outputs = model.model(tokens)[0]
    else:
        our_outputs = model(tokens)[0]
    assert their_output.shape == our_outputs.shape
    assert (their_output == our_outputs).all().item()

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak a fairseq BART checkpoint's weights to our structure.

    Args:
        checkpoint_path: fairseq hub name, e.g. "bart.large" (anything else
            is treated as an MNLI-finetuned classification checkpoint).
        pytorch_dump_folder_path: directory to save the converted model in.
    """
    b2 = torch.hub.load("pytorch/fairseq", checkpoint_path)
    b2.eval()  # disable dropout
    b2.model.upgrade_state_dict(b2.model.state_dict())
    config = BartConfig()

    # Sanity check: our tokenizer must reproduce fairseq's encoding exactly.
    # Bug fix: without return_tensors="pt", encode() returns a plain list,
    # which has no .unsqueeze — request a tensor instead.
    tokens = b2.encode(SAMPLE_TEXT).unsqueeze(0)
    tokens2 = BartTokenizer.from_pretrained("bart-large").encode(
        SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
    assert torch.eq(tokens, tokens2).all()

    if checkpoint_path == "bart.large":
        state_dict = b2.model.state_dict()
        # fairseq ties encoder/decoder embeddings; mirror that via "shared".
        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
        model = BartModel(config)
        their_output = b2.extract_features(tokens)
    else:  # MNLI Case
        state_dict = b2.state_dict()
        state_dict["model.shared.weight"] = state_dict[
            "model.decoder.embed_tokens.weight"]
        for src, dest in rename_keys:
            rename_key(state_dict, src, dest)
        state_dict.pop("_float_tensor", None)
        model = BartForSequenceClassification(config)
        their_output = b2.predict("mnli", tokens, return_logits=True)

    # Drop fairseq-only buffers that have no counterpart in our model.
    for k in IGNORE_KEYS:
        state_dict.pop(k, None)
    model.load_state_dict(state_dict)
    model.eval()

    # Outputs must match fairseq's exactly, element for element.
    our_outputs = model(tokens)[0]
    assert their_output.shape == our_outputs.shape
    assert (their_output == our_outputs).all().item()

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)