def test_advanced_inputs(self):
    """Run BartModel with explicit, implicit, and zeroed decoder attention masks.

    Checks that (a) the encoder token embedding is tied to the shared table,
    (b) weights are initialized from N(0, config.init_std), (c) passing the
    precomputed decoder mask matches the mask the model builds internally,
    (d) a zero mask changes outputs when some tokens were masked, and
    (e) a long-dtype encoder attention mask behaves like the default one.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
        config, inputs_dict["input_ids"])
    model = BartModel(config)
    model.to(torch_device)
    model.eval()

    # Encoder token embedding must share weights with the shared embedding table.
    self.assertTrue(
        (model.encoder.embed_tokens.weight == model.shared.weight).all().item())

    def _check_var(module):
        """Check that we initialized various parameters from N(0, config.init_std)."""
        self.assertAlmostEqual(torch.std(module.weight).item(), config.init_std, 2)

    _check_var(model.encoder.embed_tokens)
    _check_var(model.encoder.layers[0].self_attn.k_proj)
    _check_var(model.encoder.layers[0].fc1)
    _check_var(model.encoder.embed_positions)

    # Call the module itself rather than .forward() directly so that
    # nn.Module.__call__ runs (hooks are honored) — PyTorch convention.
    decoder_features_with_created_mask = model(**inputs_dict)[0]
    decoder_features_with_passed_mask = model(
        decoder_attention_mask=decoder_attn_mask,
        decoder_input_ids=decoder_input_ids,
        **inputs_dict)[0]
    _assert_tensors_equal(decoder_features_with_passed_mask,
                          decoder_features_with_created_mask)

    useless_mask = torch.zeros_like(decoder_attn_mask)
    decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
    # No hidden states or attentions were requested, so outputs[0] is a tensor.
    self.assertTrue(isinstance(decoder_features, torch.Tensor))
    self.assertEqual(
        decoder_features.size(),
        (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model))
    if decoder_attn_mask.min().item() < -1e3:  # some tokens were masked
        self.assertFalse(
            (decoder_features_with_created_mask == decoder_features).all().item())

    # A long-dtype encoder attention mask must produce the same features.
    decoder_features_with_long_encoder_mask = model(
        inputs_dict["input_ids"],
        attention_mask=inputs_dict["attention_mask"].long())[0]
    _assert_tensors_equal(decoder_features_with_long_encoder_mask,
                          decoder_features_with_created_mask)
def test_initialization_more(self):
    """Sanity-check weight tying and N(0, init_std) initialization of BartModel."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = BartModel(config)
    model.to(torch_device)
    model.eval()

    # The encoder's token embedding must be tied to the shared embedding table.
    weights_tied = (model.encoder.embed_tokens.weight == model.shared.weight).all().item()
    self.assertTrue(weights_tied)

    def _assert_std_matches(module):
        """Check that we initialized various parameters from N(0, config.init_std)."""
        observed_std = torch.std(module.weight).item()
        self.assertAlmostEqual(observed_std, config.init_std, 2)

    for checked_module in (model.encoder.embed_tokens,
                           model.encoder.layers[0].self_attn.k_proj,
                           model.encoder.layers[0].fc1,
                           model.encoder.embed_positions):
        _assert_std_matches(checked_module)