Example #1
    def create_and_check_model_fp16_forward(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = ProphetNetModel(config=config).to(torch_device).half().eval()
        output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
        self.parent.assertFalse(torch.isnan(output).any().item())
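A minimal standalone sketch of the same half-precision sanity check, assuming a small made-up ProphetNetConfig and a CUDA device (fp16 inference is typically run on GPU):

import torch
from transformers import ProphetNetConfig, ProphetNetModel

if torch.cuda.is_available():
    # tiny hypothetical config, only for illustration
    config = ProphetNetConfig(
        vocab_size=99,
        hidden_size=16,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        num_encoder_layers=2,
        num_decoder_layers=2,
        num_encoder_attention_heads=2,
        num_decoder_attention_heads=2,
    )
    model = ProphetNetModel(config).to("cuda").half().eval()
    input_ids = torch.randint(0, config.vocab_size, (2, 7), device="cuda")
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    # the half-precision forward pass should not produce NaNs
    assert not torch.isnan(output).any().item()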
Example #2
    def create_and_check_model(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = ProphetNetModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        decoder_output = result.last_hidden_state
        decoder_past = result.past_key_values
        encoder_output = result.encoder_last_hidden_state

        self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
        self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
        # There should be `num_decoder_layers` entries stored in decoder_past, one per decoder layer
        self.parent.assertEqual(len(decoder_past), config.num_decoder_layers)
        # Each decoder_past entry should hold two caches
        self.parent.assertEqual(len(decoder_past[0]), 2)  # cross-attention + uni-directional self-attention
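For context, a hedged standalone sketch of inspecting the same outputs outside the test harness (the tiny configuration values below are made up; the exact cache layout may vary across transformers versions):

import torch
from transformers import ProphetNetConfig, ProphetNetModel

config = ProphetNetConfig(
    vocab_size=99,
    hidden_size=16,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    num_encoder_layers=2,
    num_decoder_layers=2,
    num_encoder_attention_heads=2,
    num_decoder_attention_heads=2,
)
model = ProphetNetModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 5))
decoder_input_ids = torch.randint(0, config.vocab_size, (1, 5))

with torch.no_grad():
    result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, use_cache=True)

print(result.encoder_last_hidden_state.shape)  # (batch, encoder_seq_len, hidden_size)
print(result.last_hidden_state.shape)          # (batch, decoder_seq_len, hidden_size)
print(len(result.past_key_values))             # one cache entry per decoder layer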
Example #3
    def check_prepare_lm_labels_via_shift_left(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = ProphetNetModel(config=config)
        model.to(torch_device)
        model.eval()

        # make sure that lm_labels are correctly padded from the right
        lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)

        # add causal pad token mask
        triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
        lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
        decoder_input_ids = model._shift_right(lm_labels)

        for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
            # first item
            self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
            if i < decoder_input_ids_slice.shape[-1]:
                if i < decoder_input_ids.shape[-1] - 1:
                    # items before diagonal
                    self.parent.assertListEqual(
                        decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
                    )
                # pad items after diagonal
                if i < decoder_input_ids.shape[-1] - 2:
                    self.parent.assertListEqual(
                        decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
                    )
            else:
                # all items after square
                self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
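The shift-right step that the assertions above verify can be illustrated on a toy tensor; this is a hand-rolled sketch of the behaviour, not the private model._shift_right helper itself:

import torch

decoder_start_token_id = 0  # hypothetical value

lm_labels = torch.tensor([[5, 6, 7, 8]])
# shifting right: prepend the decoder start token and drop the last label,
# so the decoder input at position t is the label at position t - 1
decoder_input_ids = lm_labels.new_zeros(lm_labels.shape)
decoder_input_ids[:, 1:] = lm_labels[:, :-1].clone()
decoder_input_ids[:, 0] = decoder_start_token_id
print(decoder_input_ids)  # tensor([[0, 5, 6, 7]])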
Example #4
    def check_model_with_attn_mask(self, config, input_ids, decoder_input_ids, *args):
        model = ProphetNetModel(config=config)
        model.to(torch_device)
        model.eval()

        outputs_no_mask = model(
            input_ids=input_ids[:, :5], decoder_input_ids=decoder_input_ids[:, :5], return_dict=True
        )
        attention_mask = torch.ones_like(input_ids)
        decoder_attention_mask = torch.ones_like(decoder_input_ids)

        attention_mask[:, 5:] = 0

        outputs_with_mask = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            return_dict=True,
        )

        # check encoder
        self.parent.assertTrue(
            torch.allclose(
                outputs_no_mask.encoder_last_hidden_state[0, :, 0],
                outputs_with_mask.encoder_last_hidden_state[0, :5, 0],
                atol=1e-3,
            )
        )

        # check decoder
        # main stream
        self.parent.assertTrue(
            torch.allclose(
                outputs_no_mask.last_hidden_state[0, :, 0], outputs_with_mask.last_hidden_state[0, :5, 0], atol=1e-3
            )
        )
        # predict stream
        self.parent.assertTrue(
            torch.allclose(
                outputs_no_mask.last_hidden_state_ngram[0, :5, 0],
                outputs_with_mask.last_hidden_state_ngram[0, :5, 0],
                atol=1e-3,
            )
        )
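As a quick usage sketch of the same masking idea on the encoder alone (toy sizes, made-up config values): tokens masked out by attention_mask should not change the hidden states of the unmasked positions, up to numerical tolerance:

import torch
from transformers import ProphetNetConfig, ProphetNetModel

config = ProphetNetConfig(
    vocab_size=99,
    hidden_size=16,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    num_encoder_layers=2,
    num_decoder_layers=2,
    num_encoder_attention_heads=2,
    num_decoder_attention_heads=2,
)
encoder = ProphetNetModel(config).eval().get_encoder()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)
attention_mask[:, 5:] = 0  # mask out the last three positions

with torch.no_grad():
    short = encoder(input_ids=input_ids[:, :5]).last_hidden_state
    masked = encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

# the first five positions should agree between the truncated and the masked run
print(torch.allclose(short[0, :5], masked[0, :5], atol=1e-3))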