def create_and_check_model_fp16_forward(
    self,
    config,
    input_ids,
    decoder_input_ids,
    attention_mask,
    decoder_attention_mask,
    lm_labels,
):
    # A forward pass in half precision must not overflow to NaN
    model = ProphetNetModel(config=config).to(torch_device).half().eval()
    output = model(
        input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
    )["last_hidden_state"]
    self.parent.assertFalse(torch.isnan(output).any().item())
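# Hedged usage sketch (not part of the source): checker methods like the one
# above are typically driven from the accompanying unittest class through the
# shared tester, assuming the usual transformers wiring with a
# `prepare_config_and_inputs` helper:
#
#     def test_fp16_forward(self):
#         config_and_inputs = self.model_tester.prepare_config_and_inputs()
#         self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)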
def create_and_check_model(
    self,
    config,
    input_ids,
    decoder_input_ids,
    attention_mask,
    decoder_attention_mask,
    lm_labels,
):
    model = ProphetNetModel(config=config)
    model.to(torch_device)
    model.eval()
    result = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )
    result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    decoder_output = result.last_hidden_state
    decoder_past = result.past_key_values
    encoder_output = result.encoder_last_hidden_state

    self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
    self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
    # There should be `num_decoder_layers` cache entries stored in decoder_past
    self.parent.assertEqual(len(decoder_past), config.num_decoder_layers)
    # Each layer's tuple holds two caches: one for the uni-directional
    # self-attention and one for the cross-attention
    self.parent.assertEqual(len(decoder_past[0]), 2)
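# Illustrative sketch of the cache layout the two assertions above rely on
# (derived from the assertions themselves, not from inspecting the model):
#
#     result.past_key_values ~ (
#         (self_attn_cache, cross_attn_cache),   # layer 0
#         ...,                                   # num_decoder_layers entries total
#     )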
def check_prepare_lm_labels_via_shift_left(
    self,
    config,
    input_ids,
    decoder_input_ids,
    attention_mask,
    decoder_attention_mask,
    lm_labels,
):
    model = ProphetNetModel(config=config)
    model.to(torch_device)
    model.eval()

    # make sure that lm_labels are correctly padded from the right
    lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)

    # add causal pad token mask
    triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
    lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
    decoder_input_ids = model._shift_right(lm_labels)

    for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
        # first item
        self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
        if i < decoder_input_ids_slice.shape[-1]:
            if i < decoder_input_ids.shape[-1] - 1:
                # items before diagonal
                self.parent.assertListEqual(
                    decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
                )
            # pad items after diagonal
            if i < decoder_input_ids.shape[-1] - 2:
                self.parent.assertListEqual(
                    decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
                )
        else:
            # all items after square
            self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
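# Toy example of the shift-right relation the loop above verifies, with
# illustrative token ids only (assuming decoder_start_token_id == 0):
# `_shift_right` prepends the decoder start token and drops the final label,
# so every decoder input reproduces the label one position earlier:
#
#     lm_labels         = [[5, 6, 7, 8]]
#     decoder_input_ids = [[0, 5, 6, 7]]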
def check_model_with_attn_mask(self, config, input_ids, decoder_input_ids, *args):
    model = ProphetNetModel(config=config)
    model.to(torch_device)
    model.eval()

    outputs_no_mask = model(
        input_ids=input_ids[:, :5], decoder_input_ids=decoder_input_ids[:, :5], return_dict=True
    )
    attention_mask = torch.ones_like(input_ids)
    decoder_attention_mask = torch.ones_like(decoder_input_ids)

    attention_mask[:, 5:] = 0

    outputs_with_mask = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        return_dict=True,
    )

    # check encoder
    self.parent.assertTrue(
        torch.allclose(
            outputs_no_mask.encoder_last_hidden_state[0, :, 0],
            outputs_with_mask.encoder_last_hidden_state[0, :5, 0],
            atol=1e-3,
        )
    )

    # check decoder
    # main stream
    self.parent.assertTrue(
        torch.allclose(
            outputs_no_mask.last_hidden_state[0, :, 0],
            outputs_with_mask.last_hidden_state[0, :5, 0],
            atol=1e-3,
        )
    )
    # predict stream
    self.parent.assertTrue(
        torch.allclose(
            outputs_no_mask.last_hidden_state_ngram[0, :5, 0],
            outputs_with_mask.last_hidden_state_ngram[0, :5, 0],
            atol=1e-3,
        )
    )
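# Why only the first 5 positions are compared above: the attention mask zeroes
# out encoder positions >= 5, so a full-length masked forward pass should be
# numerically equivalent (up to atol) to a forward pass on the truncated
# inputs at those positions, for the encoder output as well as for both of
# ProphetNet's decoder streams:
#
#     outputs_no_mask.last_hidden_state[0, :, 0] ~= outputs_with_mask.last_hidden_state[0, :5, 0]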