def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
    """Copy/paste/tweak a fairseq BART checkpoint's weights into our BART structure.

    Loads ``checkpoint_path`` from the ``pytorch/fairseq`` torch hub, copies its
    weights into the matching model class, verifies that tokenization and model
    outputs agree with fairseq on ``SAMPLE_TEXT``, and saves the converted model.

    Args:
        checkpoint_path: fairseq hub name. ``"bart.large"`` and
            ``"bart.large.cnn"`` are converted to a ``BartModel`` (the CNN one is
            then wrapped in ``BartForMaskedLM``); any other name is treated as an
            MNLI classification checkpoint. The config/tokenizer name is derived
            by replacing every ``.`` with ``-``.
        pytorch_dump_folder_path: directory the converted model is saved to;
            created (including missing parents) if it does not exist.

    Raises:
        AssertionError: if tokenization or model outputs differ between the
            fairseq model and the converted model.
    """
    bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
    bart.eval()  # disable dropout
    bart.model.upgrade_state_dict(bart.model.state_dict())
    hf_model_name = checkpoint_path.replace(".", "-")
    config = BartConfig.from_pretrained(hf_model_name)

    # Both tokenizers must produce the same ids, otherwise the output comparison
    # below would not be comparing like with like.
    tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)  # add a leading batch dim
    # encode(..., return_tensors="pt") already returns a batch of one; an extra
    # unsqueeze(0) would produce a 3-D tensor that only matched via broadcasting.
    tokens2 = BartTokenizer.from_pretrained(hf_model_name).encode(
        SAMPLE_TEXT, return_tensors="pt"
    )
    assert torch.eq(tokens, tokens2).all()

    if checkpoint_path in ["bart.large", "bart.large.cnn"]:
        state_dict = bart.model.state_dict()
        for k in IGNORE_KEYS:
            state_dict.pop(k, None)
        # Populate the shared embedding from the decoder's input embeddings.
        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
        model = BartModel(config)
        their_output = bart.extract_features(tokens)
    else:  # MNLI case
        state_dict = bart.state_dict()
        for k in IGNORE_KEYS:
            state_dict.pop(k, None)
        state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
        for src, dest in rename_keys:
            rename_key(state_dict, src, dest)
        model = BartForSequenceClassification(config)
        # nn.Module.eval() takes no task arguments; predict() runs the named
        # classification head and returns its logits.
        their_output = bart.predict("mnli", tokens, return_logits=True)

    # Load state dict
    model.load_state_dict(state_dict)
    model.eval()

    # Check results
    if checkpoint_path == "bart.large.cnn":  # generate doesn't work yet
        model = BartForMaskedLM(config, base_model=model)
        assert "lm_head.weight" in model.state_dict()
        # NOTE(review): this compares the LM head width to max_position_embeddings
        # rather than vocab_size — confirm against BartForMaskedLM's lm_head.
        assert model.lm_head.out_features == config.max_position_embeddings
        model.eval()
        our_outputs = model.model.forward(tokens)[0]
    else:
        our_outputs = model.forward(tokens)[0]

    assert their_output.shape == our_outputs.shape
    assert (their_output == our_outputs).all().item()
    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
    """Copy/paste/tweak a fairseq BART checkpoint's weights into our BART structure.

    Loads ``checkpoint_path`` from the ``pytorch/fairseq`` torch hub, copies its
    weights into a ``BartModel`` (for ``"bart.large"``) or a
    ``BartForSequenceClassification`` (any other name, treated as MNLI), checks
    that the outputs match fairseq on ``SAMPLE_TEXT``, and saves the result.

    Args:
        checkpoint_path: fairseq hub name, e.g. ``"bart.large"``.
        pytorch_dump_folder_path: directory the converted model is saved to;
            created (including missing parents) if it does not exist.

    Raises:
        AssertionError: if tokenization or model outputs differ between the
            fairseq model and the converted model.
    """
    bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
    bart.eval()  # disable dropout
    bart.model.upgrade_state_dict(bart.model.state_dict())
    # Default config is used for every checkpoint — presumably matches the
    # bart-large architecture; TODO confirm for other checkpoints.
    config = BartConfig()

    tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)  # add a leading batch dim
    # Without return_tensors, encode() returns a plain list (no .unsqueeze);
    # request a tensor, which already comes batched as a batch of one.
    tokens2 = BartTokenizer.from_pretrained("bart-large").encode(
        SAMPLE_TEXT, return_tensors="pt"
    )
    assert torch.eq(tokens, tokens2).all()

    if checkpoint_path == "bart.large":
        state_dict = bart.model.state_dict()
        # Populate the shared embedding from the decoder's input embeddings.
        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
        model = BartModel(config)
        their_output = bart.extract_features(tokens)
    else:  # MNLI case
        state_dict = bart.state_dict()
        state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
        for src, dest in rename_keys:
            rename_key(state_dict, src, dest)
        state_dict.pop("_float_tensor", None)
        model = BartForSequenceClassification(config)
        their_output = bart.predict("mnli", tokens, return_logits=True)

    # Drop keys the target model does not expect (applied after any renames).
    for k in IGNORE_KEYS:
        state_dict.pop(k, None)
    model.load_state_dict(state_dict)
    model.eval()
    our_outputs = model.forward(tokens)[0]

    assert their_output.shape == our_outputs.shape
    assert (their_output == our_outputs).all().item()
    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)