Example #1
    def create_and_check_bert_encoder_decoder_model_labels(
            self, config, input_ids, attention_mask, encoder_hidden_states,
            decoder_config, decoder_input_ids, decoder_attention_mask, labels,
            **kwargs):
        encoder_model = BertModel(config)
        decoder_model = BertLMHeadModel(decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )

        # With labels provided, element 0 of the tuple output is the LM loss;
        # element 1 holds the decoder logits and element 2 the encoder's
        # last hidden state (both checked below).
        mlm_loss = outputs_encoder_decoder[0]
        # check that backprop works
        mlm_loss.backward()

        self.assertEqual(outputs_encoder_decoder[1].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(outputs_encoder_decoder[2].shape,
                         (input_ids.shape + (config.hidden_size, )))
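
The indices used above assume the tuple return format of older transformers releases: with labels passed, element 0 is the LM loss, element 1 the decoder logits, and element 2 the encoder's last hidden state. A minimal standalone sketch of the same pattern, assuming hub access; the checkpoint names and input strings are illustrative, not taken from the test suite:

    import torch
    from transformers import BertTokenizer, EncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased")

    inputs = tokenizer("The encoder input.", return_tensors="pt")
    labels = tokenizer("The decoder target.", return_tensors="pt").input_ids

    # Passing labels makes the model compute the LM loss over the decoder
    # logits; the loss is differentiable, so backprop works as in the test.
    outputs = model(input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    decoder_input_ids=labels,
                    labels=labels)
    outputs[0].backward()
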
Example #2
    def create_and_check_save_and_load(self, config, input_ids, attention_mask,
                                       encoder_hidden_states, decoder_config,
                                       decoder_input_ids,
                                       decoder_attention_mask, **kwargs):
        encoder_model = BertModel(config)
        decoder_model = BertLMHeadModel(decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        enc_dec_model.to(torch_device)
        enc_dec_model.eval()
        with torch.no_grad():
            outputs = enc_dec_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )
            out_2 = outputs[0].cpu().numpy()
            out_2[np.isnan(out_2)] = 0

            with tempfile.TemporaryDirectory() as tmpdirname:
                enc_dec_model.save_pretrained(tmpdirname)
                # Reload from disk and move to the device, so the comparison
                # below actually exercises the save/load round trip.
                enc_dec_model = EncoderDecoderModel.from_pretrained(tmpdirname)
                enc_dec_model.to(torch_device)

                after_outputs = enc_dec_model(
                    input_ids=input_ids,
                    decoder_input_ids=decoder_input_ids,
                    attention_mask=attention_mask,
                    decoder_attention_mask=decoder_attention_mask,
                )
                out_1 = after_outputs[0].cpu().numpy()
                out_1[np.isnan(out_1)] = 0
                max_diff = np.amax(np.abs(out_1 - out_2))
                self.assertLessEqual(max_diff, 1e-5)
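
Note the reassignment after from_pretrained: without it the test would compare the in-memory model against itself and the round trip would never be exercised. Outside a test harness, the same save/load cycle looks roughly like this (checkpoint name illustrative):

    import tempfile
    from transformers import EncoderDecoderModel

    model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased")

    with tempfile.TemporaryDirectory() as tmpdirname:
        model.save_pretrained(tmpdirname)
        # The reloaded model should reproduce the original outputs up to
        # numerical noise; the test above allows a max deviation of 1e-5.
        reloaded = EncoderDecoderModel.from_pretrained(tmpdirname)
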
Example #3
    def create_and_check_bert_encoder_decoder_model_from_pretrained(
            self, config, input_ids, attention_mask, encoder_hidden_states,
            decoder_config, decoder_input_ids, decoder_attention_mask,
            **kwargs):
        encoder_model = BertModel(config)
        decoder_model = BertLMHeadModel(decoder_config)
        # from_encoder_decoder_pretrained splits these kwargs on the
        # "encoder_"/"decoder_" prefixes, so already-instantiated models
        # can be passed in directly instead of checkpoint names.
        kwargs = {
            "encoder_model": encoder_model,
            "decoder_model": decoder_model
        }
        enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
            **kwargs)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(outputs_encoder_decoder[1].shape,
                         (input_ids.shape + (config.hidden_size, )))
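
The encoder_model/decoder_model keywords work because from_encoder_decoder_pretrained routes prefixed kwargs to the respective submodel and accepts already-built models. The more common call loads both halves from checkpoints instead; a minimal sketch with illustrative names:

    from transformers import EncoderDecoderModel

    # The decoder config is patched automatically (is_decoder=True,
    # add_cross_attention=True), so cross-attention weights are added
    # and will be randomly initialized until fine-tuned.
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased")
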
Example #4
    def create_and_check_bert_for_causal_lm(
            self,
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
    ):
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        loss, prediction_scores = model(input_ids,
                                        attention_mask=input_mask,
                                        token_type_ids=token_type_ids,
                                        labels=token_labels)
        result = {
            "loss": loss,
            "prediction_scores": prediction_scores,
        }
        self.parent.assertListEqual(
            list(result["prediction_scores"].size()),
            [self.batch_size, self.seq_length, self.vocab_size])
        self.check_loss_output(result)
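
Outside the test harness, BertLMHeadModel needs is_decoder=True in its config to act as a causal LM. A minimal sketch, assuming hub access and an illustrative checkpoint:

    from transformers import BertConfig, BertLMHeadModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    config = BertConfig.from_pretrained("bert-base-uncased", is_decoder=True)
    model = BertLMHeadModel.from_pretrained("bert-base-uncased", config=config)

    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    # With labels, the first two outputs are the causal LM loss and the
    # prediction scores of shape (batch_size, seq_length, vocab_size).
    loss, prediction_scores = model(**inputs, labels=inputs.input_ids)[:2]
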
Example #5
    def create_and_check_bert_encoder_decoder_model_generate(
            self, input_ids, config, decoder_config, **kwargs):
        encoder_model = BertModel(config)
        decoder_model = BertLMHeadModel(decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        enc_dec_model.to(torch_device)

        # Bert does not have a bos token id, so use pad_token_id instead
        generated_output = enc_dec_model.generate(
            input_ids,
            decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id)
        self.assertEqual(generated_output.shape, (input_ids.shape[0], ) +
                         (decoder_config.max_length, ))
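
The same workaround applies outside the test: BERT has no bos token, so generation is started from pad_token_id. A hedged end-to-end sketch (checkpoint names and input text illustrative):

    from transformers import BertTokenizer, EncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased")

    input_ids = tokenizer("The input sentence.",
                          return_tensors="pt").input_ids
    # Generation runs until the decoder's max_length, matching the shape
    # asserted in the test above.
    generated = model.generate(
        input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))
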
Example #6
    def create_and_check_bert_encoder_decoder_model(
            self, config, input_ids, attention_mask, encoder_hidden_states,
            decoder_config, decoder_input_ids, decoder_attention_mask,
            **kwargs):
        encoder_model = BertModel(config)
        decoder_model = BertLMHeadModel(decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        self.assertTrue(enc_dec_model.config.decoder.is_decoder)
        self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
        self.assertTrue(enc_dec_model.config.is_encoder_decoder)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(outputs_encoder_decoder[1].shape,
                         (input_ids.shape + (config.hidden_size, )))
        # The forward pass also accepts precomputed encoder outputs, in which
        # case the encoder is skipped entirely.
        encoder_outputs = (encoder_hidden_states, )
        outputs_encoder_decoder = enc_dec_model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(outputs_encoder_decoder[1].shape,
                         (input_ids.shape + (config.hidden_size, )))
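
The second forward pass above feeds precomputed encoder_outputs, so the source is never re-encoded. This is handy when scoring several decoder candidates against one input; a sketch under the same assumptions as the test (model and tensors as defined above; candidate_decoder_inputs is a hypothetical iterable of decoder input tensors):

    import torch

    # Run the encoder once and cache its outputs.
    with torch.no_grad():
        encoder_outputs = enc_dec_model.encoder(
            input_ids=input_ids, attention_mask=attention_mask)

    # Each decoder pass reuses the cached encoder states instead of
    # re-encoding the source sequence.
    for decoder_input_ids in candidate_decoder_inputs:
        outputs = enc_dec_model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
        )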