def test_advanced_inputs(self):
    """Compare ``FSMTModel`` outputs under different explicit mask arguments.

    Steps performed:

    1. The last two positions of every ``input_ids`` row are overwritten
       with ``config.pad_token_id``.
    2. The model is called with ``inputs_dict`` alone, and again with the
       decoder inputs returned by ``_prepare_fsmt_decoder_inputs`` (the
       mask passed through ``invert_mask``); both outputs must be equal.
    3. The model is called with an all-zeros decoder attention mask; the
       first output must be a tensor of shape
       ``(batch_size, seq_length, config.tgt_vocab_size)``, and, when the
       prepared mask contains values below ``-1e3``, it must differ from
       the output of step 2.
    4. The model is called with ``input_ids`` positionally and the encoder
       attention mask cast with ``.long()``; the output must equal the
       baseline from step 2.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs()
    config.use_cache = False
    inputs_dict["input_ids"][:, -2:] = config.pad_token_id

    # The third return value (causal mask) is not needed by this test.
    decoder_input_ids, decoder_attn_mask, _ = _prepare_fsmt_decoder_inputs(
        config, inputs_dict["input_ids"]
    )
    model = FSMTModel(config).to(torch_device).eval()

    # Baseline: no decoder inputs supplied by the caller.
    decoder_features_with_created_mask = model(**inputs_dict)[0]

    # Same call, but with the prepared decoder inputs passed explicitly.
    decoder_features_with_passed_mask = model(
        decoder_attention_mask=invert_mask(decoder_attn_mask),
        decoder_input_ids=decoder_input_ids,
        **inputs_dict,
    )[0]
    _assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)

    # An all-zeros mask with the same shape/dtype as the prepared one.
    useless_mask = torch.zeros_like(decoder_attn_mask)
    decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]

    # The first output is expected to be a plain tensor
    # (no hidden states or attentions).
    self.assertIsInstance(decoder_features, torch.Tensor)
    self.assertEqual(
        decoder_features.size(),
        (self.model_tester.batch_size, self.model_tester.seq_length, config.tgt_vocab_size),
    )

    # A large negative entry in the prepared mask means some tokens were
    # masked, so the all-zeros mask must produce a different output.
    if decoder_attn_mask.min().item() < -1e3:
        self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item())

    # Test different encoder attention masks: casting the mask to long
    # must not change the result.
    decoder_features_with_long_encoder_mask = model(
        inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
    )[0]
    _assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
def test_prepare_fsmt_decoder_inputs(self):
    """Verify the masks returned by ``_prepare_fsmt_decoder_inputs``.

    A three-token decoder sequence whose last token is
    ``config.pad_token_id`` is passed in. The returned decoder attention
    mask must have the same size as the returned decoder input ids, and
    the returned causal mask must equal a 3x3 matrix holding ``0`` on and
    below the diagonal and ``-inf`` above it.
    """
    config, *_ = self._get_config_and_data()
    neg_inf = float("-inf")

    src_ids = _long_tensor([4, 4, 2])
    tgt_ids = _long_tensor([[26388, 2, config.pad_token_id]])

    prepared_ids, padding_mask, causal_mask = _prepare_fsmt_decoder_inputs(config, src_ids, tgt_ids)

    # Row i may look at columns 0..i only; everything to the right is -inf.
    expected_rows = [
        [0, neg_inf, neg_inf],
        [0, 0, neg_inf],
        [0, 0, 0],
    ]
    expected_causal_mask = torch.tensor(expected_rows).to(src_ids.device)

    self.assertEqual(padding_mask.size(), prepared_ids.size())
    masks_match = torch.eq(expected_causal_mask, causal_mask).all()
    self.assertTrue(masks_match)