def create_and_check_decoder_model_past(
    self,
    config,
    input_ids,
    attention_mask,
    lm_labels,
):
    """Check that decoding one token from the cache matches a full forward pass.

    The decoder is run on ``input_ids`` to obtain ``past_key_values``; one extra
    token is then decoded both from scratch (full sequence) and incrementally
    (new token plus cache), and a randomly chosen hidden feature at the last
    position is compared within ``atol=1e-3``.

    Args:
        config: Model configuration; ``use_cache`` is set to ``True`` on it.
        input_ids: Token ids fed to the decoder.
        attention_mask: Accepted for signature compatibility; not used here.
        lm_labels: Accepted for signature compatibility; not used here.
    """
    config.use_cache = True
    decoder = PLBartDecoder(config=config).to(torch_device).eval()

    # Same prompt, three cache settings: explicitly on, config default, explicitly off.
    cached_outputs = decoder(input_ids, use_cache=True)
    default_outputs = decoder(input_ids)
    uncached_outputs = decoder(input_ids, use_cache=False)

    # The config default must behave like use_cache=True, and disabling the
    # cache must drop exactly one output entry.
    self.parent.assertTrue(len(cached_outputs) == len(default_outputs))
    self.parent.assertTrue(len(cached_outputs) == len(uncached_outputs) + 1)

    cache = cached_outputs["past_key_values"]

    # Draw a hypothetical next token and append it to the prompt.
    new_token = ids_tensor((self.batch_size, 1), config.vocab_size)
    extended_ids = torch.cat([input_ids, new_token], dim=-1)

    full_hidden = decoder(extended_ids)["last_hidden_state"]
    incremental_hidden = decoder(new_token, past_key_values=cache)["last_hidden_state"]

    # Compare a single randomly selected hidden feature at the newest position.
    feature_idx = ids_tensor((1,), incremental_hidden.shape[-1]).item()
    last_position = extended_ids.shape[-1] - 1
    expected_slice = full_hidden[:, last_position, feature_idx].detach()
    actual_slice = incremental_hidden[:, 0, feature_idx].detach()

    self.parent.assertTrue(torch.allclose(actual_slice, expected_slice, atol=1e-3))
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
    """Check that the saved-and-reloaded encoder and decoder reproduce the full model.

    The full model is run once to get reference hidden states. Its encoder and
    decoder are then each saved to a temporary directory, reloaded as
    standalone ``PLBartEncoder`` / ``PLBartDecoder`` instances, and run on the
    matching inputs; the maximum absolute deviation from the reference must be
    below ``1e-3``.

    Args:
        config: Model configuration used to build the full model.
        inputs_dict: Keyword inputs for the full model. The keys ``input_ids``,
            ``attention_mask``, ``decoder_input_ids`` and
            ``decoder_attention_mask`` are read directly.
    """
    tolerance = 1e-3

    full_model = PLBartModel(config=config).to(torch_device).eval()
    reference = full_model(**inputs_dict)
    reference_encoder_hidden = reference.encoder_last_hidden_state
    reference_decoder_hidden = reference.last_hidden_state

    # Round-trip the encoder through disk and run it on its own.
    with tempfile.TemporaryDirectory() as encoder_dir:
        full_model.get_encoder().save_pretrained(encoder_dir)
        standalone_encoder = PLBartEncoder.from_pretrained(encoder_dir).to(torch_device)

    standalone_encoder_hidden = standalone_encoder(
        inputs_dict["input_ids"],
        attention_mask=inputs_dict["attention_mask"],
    )[0]
    encoder_deviation = (standalone_encoder_hidden - reference_encoder_hidden).abs().max().item()
    self.parent.assertTrue(encoder_deviation < tolerance)

    # Round-trip the decoder through disk and run it on its own.
    with tempfile.TemporaryDirectory() as decoder_dir:
        full_model.get_decoder().save_pretrained(decoder_dir)
        standalone_decoder = PLBartDecoder.from_pretrained(decoder_dir).to(torch_device)

    # The standalone decoder is conditioned on the reference encoder output
    # (from the full model), not on the reloaded encoder's output.
    standalone_decoder_hidden = standalone_decoder(
        input_ids=inputs_dict["decoder_input_ids"],
        attention_mask=inputs_dict["decoder_attention_mask"],
        encoder_hidden_states=reference_encoder_hidden,
        encoder_attention_mask=inputs_dict["attention_mask"],
    )[0]
    decoder_deviation = (standalone_decoder_hidden - reference_decoder_hidden).abs().max().item()
    self.parent.assertTrue(decoder_deviation < tolerance)
def create_and_check_decoder_model_attention_mask_past(
    self,
    config,
    input_ids,
    attention_mask,
    lm_labels,
):
    """Check cached decoding against a full pass when half of the input is masked.

    An attention mask that keeps only the first half of the sequence is built,
    the decoder is run once to obtain ``past_key_values``, and a token in the
    masked tail is then overwritten with a random id. One extra token is
    decoded both from scratch (perturbed full sequence) and incrementally (new
    token plus cache); a randomly chosen hidden feature at the last position
    must agree within ``atol=1e-3``, i.e. the masked token must not influence
    the result.

    The caller's ``input_ids`` tensor is left untouched: the overwrite is
    applied to a copy.

    Args:
        config: Model configuration used to build the decoder.
        input_ids: Token ids fed to the decoder.
        attention_mask: Accepted for signature compatibility; a fresh mask is
            built internally instead.
        lm_labels: Accepted for signature compatibility; not used here.
    """
    model = PLBartDecoder(config=config).to(torch_device).eval()

    # Build a mask that attends to the first half of the sequence only.
    attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
    half_seq_length = input_ids.shape[-1] // 2
    attn_mask[:, half_seq_length:] = 0

    # First forward pass: populate the cache from the unperturbed input.
    past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]

    # Draw a hypothetical next token.
    next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

    # Pick a position counted from the end of the sequence; it is intended to
    # fall inside the masked tail (assumes ids_tensor draws from
    # [0, half_seq_length) - TODO confirm against the helper).
    random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
    random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)

    # Overwrite on a copy so the tensor owned by the caller is not mutated.
    perturbed_input_ids = input_ids.clone()
    perturbed_input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens

    # Append the new token to the ids and extend the mask so it is attended to.
    next_input_ids = torch.cat([perturbed_input_ids, next_tokens], dim=-1)
    attn_mask = torch.cat(
        [
            attn_mask,
            torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device),
        ],
        dim=1,
    )

    # Full pass over the perturbed sequence vs. incremental pass from the cache.
    output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
    output_from_past = model(
        next_tokens,
        attention_mask=attn_mask,
        past_key_values=past_key_values,
    )["last_hidden_state"]

    # Compare a single randomly selected hidden feature at the newest position.
    random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
    last_position = next_input_ids.shape[-1] - 1
    output_from_no_past_slice = output_from_no_past[:, last_position, random_slice_idx].detach()
    output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

    self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))