def create_and_check_bert_encoder_decoder_model_labels(
        self, config, input_ids, attention_mask, encoder_hidden_states,
        decoder_config, decoder_input_ids, decoder_attention_mask, labels,
        **kwargs):
    encoder_model = BertModel(config)
    decoder_model = BertLMHeadModel(decoder_config)
    enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                        decoder=decoder_model)
    enc_dec_model.to(torch_device)
    outputs_encoder_decoder = enc_dec_model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        labels=labels,
    )

    mlm_loss = outputs_encoder_decoder[0]
    # check that backprop works
    mlm_loss.backward()

    # with labels, index 1 holds the decoder logits and index 2 the encoder
    # hidden states
    self.assertEqual(outputs_encoder_decoder[1].shape,
                     (decoder_input_ids.shape + (decoder_config.vocab_size, )))
    self.assertEqual(outputs_encoder_decoder[2].shape,
                     (input_ids.shape + (config.hidden_size, )))

def create_and_check_save_and_load(self, config, input_ids, attention_mask,
                                   encoder_hidden_states, decoder_config,
                                   decoder_input_ids, decoder_attention_mask,
                                   **kwargs):
    encoder_model = BertModel(config)
    decoder_model = BertLMHeadModel(decoder_config)
    enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                        decoder=decoder_model)
    enc_dec_model.to(torch_device)
    enc_dec_model.eval()
    with torch.no_grad():
        outputs = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        out_2 = outputs[0].cpu().numpy()
        out_2[np.isnan(out_2)] = 0

        with tempfile.TemporaryDirectory() as tmpdirname:
            enc_dec_model.save_pretrained(tmpdirname)
            # reassign to the reloaded model; without the assignment the
            # check below would compare the original model against itself
            enc_dec_model = EncoderDecoderModel.from_pretrained(tmpdirname)
            enc_dec_model.to(torch_device)

            after_outputs = enc_dec_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )
            out_1 = after_outputs[0].cpu().numpy()
            out_1[np.isnan(out_1)] = 0
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)

def create_and_check_bert_encoder_decoder_model_from_pretrained(
        self, config, input_ids, attention_mask, encoder_hidden_states,
        decoder_config, decoder_input_ids, decoder_attention_mask, **kwargs):
    encoder_model = BertModel(config)
    decoder_model = BertLMHeadModel(decoder_config)
    kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
    enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        **kwargs)
    enc_dec_model.to(torch_device)
    outputs_encoder_decoder = enc_dec_model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )

    self.assertEqual(outputs_encoder_decoder[0].shape,
                     (decoder_input_ids.shape + (decoder_config.vocab_size, )))
    self.assertEqual(outputs_encoder_decoder[1].shape,
                     (input_ids.shape + (config.hidden_size, )))

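# A minimal usage sketch, not part of the original suite: the same factory is
# more commonly called with checkpoint names instead of in-memory model
# instances. The checkpoint names below are assumptions for illustration.
def example_from_pretrained_checkpoints(self):
    enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased")
    enc_dec_model.to(torch_device)
    return enc_dec_model
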
def create_and_check_bert_for_causal_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels,
        token_labels, choice_labels, encoder_hidden_states,
        encoder_attention_mask):
    model = BertLMHeadModel(config=config)
    model.to(torch_device)
    model.eval()
    loss, prediction_scores = model(input_ids,
                                    attention_mask=input_mask,
                                    token_type_ids=token_type_ids,
                                    labels=token_labels)
    result = {
        "loss": loss,
        "prediction_scores": prediction_scores,
    }
    self.parent.assertListEqual(
        list(result["prediction_scores"].size()),
        [self.batch_size, self.seq_length, self.vocab_size])
    self.check_loss_output(result)

def create_and_check_bert_encoder_decoder_model_generate(
        self, input_ids, config, decoder_config, **kwargs):
    encoder_model = BertModel(config)
    decoder_model = BertLMHeadModel(decoder_config)
    enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                        decoder=decoder_model)
    enc_dec_model.to(torch_device)

    # Bert does not have a bos token id, so use pad_token_id instead
    generated_output = enc_dec_model.generate(
        input_ids,
        decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id)
    self.assertEqual(generated_output.shape,
                     (input_ids.shape[0], ) + (decoder_config.max_length, ))

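# A minimal sketch, not part of the original suite, of how the same generate
# call looks outside the test harness with a real tokenizer. The checkpoint
# names and input sentence are assumptions for illustration only.
def example_generate_with_tokenizer(self):
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased")
    input_ids = tokenizer.encode("An example input sentence.",
                                 return_tensors="pt")
    # Bert has no bos token id, so decoding starts from the pad token, as in
    # the test above
    generated = enc_dec_model.generate(
        input_ids,
        decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
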
def create_and_check_bert_encoder_decoder_model(
        self, config, input_ids, attention_mask, encoder_hidden_states,
        decoder_config, decoder_input_ids, decoder_attention_mask, **kwargs):
    encoder_model = BertModel(config)
    decoder_model = BertLMHeadModel(decoder_config)
    enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                        decoder=decoder_model)
    self.assertTrue(enc_dec_model.config.decoder.is_decoder)
    self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
    self.assertTrue(enc_dec_model.config.is_encoder_decoder)
    enc_dec_model.to(torch_device)
    outputs_encoder_decoder = enc_dec_model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )
    self.assertEqual(outputs_encoder_decoder[0].shape,
                     (decoder_input_ids.shape + (decoder_config.vocab_size, )))
    self.assertEqual(outputs_encoder_decoder[1].shape,
                     (input_ids.shape + (config.hidden_size, )))

    # the same forward pass, but with precomputed encoder hidden states
    encoder_outputs = (encoder_hidden_states, )
    outputs_encoder_decoder = enc_dec_model(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )

    self.assertEqual(outputs_encoder_decoder[0].shape,
                     (decoder_input_ids.shape + (decoder_config.vocab_size, )))
    self.assertEqual(outputs_encoder_decoder[1].shape,
                     (input_ids.shape + (config.hidden_size, )))
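
# A sketch under the assumption that encoder_outputs accepts the encoder's
# return tuple, as the test above demonstrates with a bare hidden-states
# tuple: the encoder can be run once and its states reused across several
# decoder passes, e.g. when scoring multiple candidate decodings.
def example_reuse_encoder_outputs(self, config, input_ids, attention_mask,
                                  decoder_config, decoder_input_ids,
                                  decoder_attention_mask, **kwargs):
    encoder_model = BertModel(config)
    decoder_model = BertLMHeadModel(decoder_config)
    enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                        decoder=decoder_model)
    enc_dec_model.to(torch_device)

    # run the encoder a single time ...
    encoder_outputs = enc_dec_model.encoder(input_ids=input_ids,
                                            attention_mask=attention_mask)
    # ... and feed its hidden states to as many decoder passes as needed
    outputs = enc_dec_model(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )
    return outputs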