def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
    """Check that decoding with cached key/values matches a full forward pass.

    The decoder is run once on the original inputs with ``use_cache=True``.
    Three additional random tokens per batch row are then decoded twice:
    once from scratch on the concatenated sequence, and once on the new
    tokens alone using the cached ``past_key_values``.  A randomly chosen
    hidden-state channel of the last three positions must agree between the
    two runs within ``atol=1e-2``.

    Args:
        config: Configuration used to build the ``M2M100Model``; its
            ``vocab_size`` bounds the randomly drawn token ids.
        inputs_dict: Mapping that provides ``"input_ids"``,
            ``"attention_mask"`` and ``"head_mask"``.
    """
    num_new_tokens = 3

    decoder = M2M100Model(config=config).get_decoder().to(torch_device).eval()
    input_ids = inputs_dict["input_ids"]
    attention_mask = inputs_dict["attention_mask"]
    head_mask = inputs_dict["head_mask"]

    # Initial pass: only the cache is needed, the hidden states are discarded.
    first_pass = decoder(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
    _, past_key_values = first_pass.to_tuple()

    # Draw the continuation tokens first and their 0/1 mask second, so the
    # random draws happen in the same order as before.
    next_tokens = ids_tensor((self.batch_size, num_new_tokens), config.vocab_size)
    next_attn_mask = ids_tensor((self.batch_size, num_new_tokens), 2)

    # Extend both the ids and the mask along the sequence dimension.
    next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
    next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

    # Reference: recompute everything without a cache.
    uncached_outputs = decoder(next_input_ids, attention_mask=next_attention_mask)
    output_from_no_past = uncached_outputs["last_hidden_state"]

    # Candidate: feed only the new tokens together with the cache.
    cached_outputs = decoder(
        next_tokens,
        attention_mask=next_attention_mask,
        past_key_values=past_key_values,
    )
    output_from_past = cached_outputs["last_hidden_state"]

    # Compare a single randomly selected hidden-state channel.
    channel = ids_tensor((1,), output_from_past.shape[-1]).item()
    no_past_slice = output_from_no_past[:, -num_new_tokens:, channel].detach()
    past_slice = output_from_past[:, :, channel].detach()

    # The cached run must yield exactly one position per new token.
    self.parent.assertTrue(past_slice.shape[1] == next_tokens.shape[1])

    # Both decoding strategies must produce matching values for the slice.
    self.parent.assertTrue(torch.allclose(past_slice, no_past_slice, atol=1e-2))
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
    """Check that the encoder and decoder work when saved and reloaded alone.

    A full ``M2M100Model`` is run once to obtain reference hidden states.
    Its encoder and decoder are then each written to a temporary directory,
    reloaded through ``M2M100Encoder`` / ``M2M100Decoder`` and run on the
    same inputs.  The largest absolute difference from the reference must
    stay below ``1e-3`` for both.

    Args:
        config: Configuration used to build the ``M2M100Model``.
        inputs_dict: Keyword arguments for the full model; must contain
            ``"input_ids"``, ``"attention_mask"``, ``"decoder_input_ids"``
            and ``"decoder_attention_mask"``.
    """
    full_model = M2M100Model(config=config).to(torch_device).eval()
    reference = full_model(**inputs_dict)

    ref_encoder_states = reference.encoder_last_hidden_state
    ref_decoder_states = reference.last_hidden_state

    # Round-trip the encoder through disk and run it standalone.
    with tempfile.TemporaryDirectory() as encoder_dir:
        full_model.get_encoder().save_pretrained(encoder_dir)
        standalone_encoder = M2M100Encoder.from_pretrained(encoder_dir).to(torch_device)

    standalone_encoder_states = standalone_encoder(
        inputs_dict["input_ids"],
        attention_mask=inputs_dict["attention_mask"],
    )[0]

    encoder_max_diff = (standalone_encoder_states - ref_encoder_states).abs().max().item()
    self.parent.assertTrue(encoder_max_diff < 1e-3)

    # Round-trip the decoder through disk and run it standalone, feeding it
    # the reference encoder output as cross-attention input.
    with tempfile.TemporaryDirectory() as decoder_dir:
        full_model.get_decoder().save_pretrained(decoder_dir)
        standalone_decoder = M2M100Decoder.from_pretrained(decoder_dir).to(torch_device)

    standalone_decoder_states = standalone_decoder(
        input_ids=inputs_dict["decoder_input_ids"],
        attention_mask=inputs_dict["decoder_attention_mask"],
        encoder_hidden_states=ref_encoder_states,
        encoder_attention_mask=inputs_dict["attention_mask"],
    )[0]

    decoder_max_diff = (standalone_decoder_states - ref_decoder_states).abs().max().item()
    self.parent.assertTrue(decoder_max_diff < 1e-3)
def test_inference_no_head(self):
    """Run the pretrained base model on a fixed input and compare outputs.

    Loads ``facebook/m2m100_418M``, performs one forward pass without
    gradients, and checks both the shape of the first output and its
    leading 3x3 corner against stored reference values within ``TOLERANCE``.
    """
    model = M2M100Model.from_pretrained("facebook/m2m100_418M").to(torch_device)

    source_ids = [[128028, 98, 12, 30527, 2732, 159, 7755, 61904, 39144, 38, 2]]
    target_ids = [[2, 128028, 98, 12, 30527, 2732, 159, 7755, 61904, 39144, 38]]
    input_ids = _long_tensor(source_ids)
    decoder_input_ids = _long_tensor(target_ids)
    inputs_dict = prepare_m2m_100_inputs_dict(model.config, input_ids, decoder_input_ids)

    with torch.no_grad():
        outputs = model(**inputs_dict)
    output = outputs[0]

    expected_shape = torch.Size((1, 11, 1024))
    self.assertEqual(output.shape, expected_shape)

    # Reference values for the first three positions and first three channels.
    expected_slice = torch.tensor(
        [
            [-0.7780, -0.1676, 0.1038],
            [-6.7556, -1.3992, 0.0567],
            [-7.5383, -0.5920, -0.2779],
        ],
        device=torch_device,
    )
    actual_slice = output[:, :3, :3]
    self.assertTrue(torch.allclose(actual_slice, expected_slice, atol=TOLERANCE))
def load(args):
    """Load the pretrained M2M-100 model and tokenizer onto a device.

    The ``facebook/m2m100_418M`` checkpoint is always used.  CUDA is selected
    when it is available and ``args.no_cuda`` is falsy; otherwise the CPU is
    used.

    Args:
        args: Object exposing ``no_cuda`` (truthy forces CPU execution) and
            ``num_gpus`` (a value above 1 requests multi-GPU data
            parallelism).

    Returns:
        A ``(model, tokenizer, device)`` tuple.  The model is in eval mode
        and is wrapped in ``torch.nn.DataParallel`` only when it lives on a
        CUDA device and ``args.num_gpus > 1``.
    """
    # Progress message kept for callers/logs that expect it.
    print('loading M2M-100 model')
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    model = M2M100Model.from_pretrained('facebook/m2m100_418M')
    tokenizer = M2M100Tokenizer.from_pretrained('facebook/m2m100_418M')
    model.to(device)

    # DataParallel requires the module's parameters to sit on its first CUDA
    # device.  Wrapping a model that was kept on the CPU (args.no_cuda) on a
    # CUDA-capable host would make every forward pass raise, so only wrap
    # when the model was actually moved to CUDA.
    if args.num_gpus > 1 and device.type == "cuda":
        model = torch.nn.DataParallel(model)

    model.eval()
    return model, tokenizer, device