def test_inference_masked_lm_long_input(self):
    """Check YOSO-4096 masked-LM output on a 4096-token input.

    Loads the pretrained ``uw-madison/yoso-4096`` checkpoint and feeds it one
    sequence of the token ids ``0, 1, ..., 4095``. It then checks two things:
    the first model output has shape ``(1, 4096, 50265)``, and its top-left
    3x3 corner matches stored reference values within ``atol=1e-4``.
    """
    seq_len = 4096
    vocab_size = 50265

    model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096")
    # A batch of one sequence holding the consecutive ids 0 .. seq_len - 1.
    input_ids = torch.arange(seq_len).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_ids)[0]

    self.assertEqual(logits.shape, torch.Size((1, seq_len, vocab_size)))

    reference = torch.tensor(
        [
            [
                [-2.3914, -4.3742, -5.0956],
                [-4.0988, -4.2384, -7.0406],
                [-3.1427, -3.7192, -6.6800],
            ]
        ]
    )
    corner = logits[:, :3, :3]
    self.assertTrue(torch.allclose(corner, reference, atol=1e-4))
def test_inference_masked_lm(self):
    """Check YOSO-4096 masked-LM output on a short six-token input.

    Runs the pretrained ``uw-madison/yoso-4096`` checkpoint on the ids
    ``[0, 1, 2, 3, 4, 5]``. It verifies that the first model output has shape
    ``(1, 6, 50265)`` and that its top-left 3x3 corner matches stored
    reference values within ``atol=1e-4``.
    """
    vocab_size = 50265
    token_ids = [0, 1, 2, 3, 4, 5]

    model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096")
    input_ids = torch.tensor([token_ids])

    with torch.no_grad():
        logits = model(input_ids)[0]

    self.assertEqual(logits.shape, torch.Size((1, len(token_ids), vocab_size)))

    reference = torch.tensor(
        [
            [
                [-2.1313, -3.7285, -2.2407],
                [-2.7047, -3.3314, -2.6408],
                [0.0629, -2.5166, -0.3356],
            ]
        ]
    )
    corner = logits[:, :3, :3]
    self.assertTrue(torch.allclose(corner, reference, atol=1e-4))
def create_and_check_for_masked_lm(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Build a ``YosoForMaskedLM`` from *config* and check its logits shape.

    The model is moved to ``torch_device`` and put in eval mode. It is then
    called with the ids, attention mask, token-type ids and token-level labels.
    The check asserts (via ``self.parent``) that ``logits`` has shape
    ``(self.batch_size, self.seq_length, self.vocab_size)``.
    ``sequence_labels`` and ``choice_labels`` are not used by this check.
    """
    mlm_model = YosoForMaskedLM(config=config)
    mlm_model.to(torch_device)
    mlm_model.eval()

    outputs = mlm_model(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=token_labels,
    )

    expected_shape = (self.batch_size, self.seq_length, self.vocab_size)
    self.parent.assertEqual(outputs.logits.shape, expected_shape)
def convert_yoso_checkpoint(checkpoint_path, yoso_config_file, pytorch_dump_path):
    """Convert an original YOSO checkpoint into a saved ``YosoForMaskedLM``.

    Args:
        checkpoint_path: Path to the original checkpoint file. The loaded
            object must contain the weights under the ``"model_state_dict"`` key.
        yoso_config_file: Path to a JSON config read with
            ``YosoConfig.from_json_file``.
        pytorch_dump_path: Destination passed to ``model.save_pretrained``.
    """
    # NOTE(review): torch.load unpickles the file, and unpickling can execute
    # arbitrary code. Only convert checkpoints from a trusted source. Consider
    # ``weights_only=True`` once it is confirmed that the checkpoint holds only
    # tensors and plain containers.
    orig_state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
    config = YosoConfig.from_json_file(yoso_config_file)
    model = YosoForMaskedLM(config)

    # Map the original state dict onto this model's layout. The mapping itself
    # lives in convert_checkpoint_helper and depends on max_position_embeddings.
    new_state_dict = convert_checkpoint_helper(config.max_position_embeddings, orig_state_dict)

    # load_state_dict returns a report of missing/unexpected keys. Print it so
    # the conversion result can be inspected.
    print(model.load_state_dict(new_state_dict))
    model.eval()
    model.save_pretrained(pytorch_dump_path)

    print(f"Checkpoint successfully converted. Model saved at {pytorch_dump_path}")