Example #1
0
    def create_and_check_decoder_model_past_large_inputs(
            self, config, inputs_dict):
        """Check that cached decoding of several new tokens matches a full pass.

        Runs the decoder of a freshly built ``PLBartModel`` once on
        ``inputs_dict["input_ids"]`` with ``use_cache=True`` to obtain
        ``past_key_values``. It then appends a block of random new tokens and
        computes the last hidden states of those tokens in two ways:

        * from scratch, over the whole extended sequence;
        * incrementally, feeding only the new tokens plus the cache.

        A randomly chosen index along the last output dimension of both
        results must agree within ``atol=1e-3``.

        Args:
            config: config used to build ``PLBartModel``. Its ``vocab_size``
                is passed to ``ids_tensor`` as the id range for the new tokens.
            inputs_dict: mapping that provides at least ``"input_ids"``,
                ``"attention_mask"`` and ``"head_mask"``.
        """
        # Number of tokens appended after the cached prefix. It determines
        # the shape of the new tokens/mask AND the slice taken from the
        # no-cache output, so it must be defined in exactly one place.
        num_new_tokens = 3

        model = PLBartModel(
            config=config).get_decoder().to(torch_device).eval()
        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]
        head_mask = inputs_dict["head_mask"]

        # First forward pass: only the cache is needed from it.
        outputs = model(input_ids,
                        attention_mask=attention_mask,
                        head_mask=head_mask,
                        use_cache=True)
        _, past_key_values = outputs.to_tuple()

        # Create hypothetical new tokens (and mask entries) and extend the
        # original input ids / attention mask with them.
        next_tokens = ids_tensor((self.batch_size, num_new_tokens),
                                 config.vocab_size)
        next_attn_mask = ids_tensor((self.batch_size, num_new_tokens), 2)

        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask],
                                        dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask)["last_hidden_state"]
        # With the cache, only the new tokens are fed. The mask still spans
        # the full (prefix + new) sequence.
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values)["last_hidden_state"]

        # Select a random slice along the last dimension.
        random_slice_idx = ids_tensor((1, ), output_from_past.shape[-1]).item()
        # The no-cache output covers prefix + new tokens; keep only the
        # trailing positions that correspond to the new tokens.
        output_from_no_past_slice = output_from_no_past[
            :, -num_new_tokens:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[
            :, :, random_slice_idx].detach()

        self.parent.assertEqual(output_from_past_slice.shape[1],
                                next_tokens.shape[1])

        # Test that outputs are equal for the slice.
        self.parent.assertTrue(
            torch.allclose(output_from_past_slice,
                           output_from_no_past_slice,
                           atol=1e-3))
Example #2
0
    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
        """Verify the encoder and decoder survive a save/load round trip.

        Runs a freshly built ``PLBartModel`` on ``inputs_dict``. Its encoder
        and decoder are then each saved with ``save_pretrained`` and reloaded
        as ``PLBartEncoder`` / ``PLBartDecoder``. Each reloaded module is
        re-run on the matching inputs. The maximum absolute difference between
        its first output and the full model's corresponding hidden state must
        be below 1e-3.

        Args:
            config: config used to build ``PLBartModel``.
            inputs_dict: keyword inputs for the full model. It must also
                contain ``"input_ids"``, ``"attention_mask"``,
                ``"decoder_input_ids"`` and ``"decoder_attention_mask"``.
        """
        tolerance = 1e-3

        def reload_from_disk(submodule, submodule_cls):
            # Serialize into a throwaway directory and load a fresh copy.
            # The directory is deleted only after the copy has been moved to
            # the device.
            with tempfile.TemporaryDirectory() as scratch_dir:
                submodule.save_pretrained(scratch_dir)
                return submodule_cls.from_pretrained(scratch_dir).to(
                    torch_device)

        def max_abs_diff(lhs, rhs):
            # Largest element-wise absolute difference, as a Python number.
            return (lhs - rhs).abs().max().item()

        full_model = PLBartModel(config=config).to(torch_device).eval()
        reference = full_model(**inputs_dict)
        ref_encoder_state = reference.encoder_last_hidden_state
        ref_decoder_state = reference.last_hidden_state

        # Encoder: reload, re-run on the source side, compare.
        reloaded_encoder = reload_from_disk(full_model.get_encoder(),
                                            PLBartEncoder)
        encoder_state = reloaded_encoder(
            inputs_dict["input_ids"],
            attention_mask=inputs_dict["attention_mask"])[0]
        self.parent.assertTrue(
            max_abs_diff(encoder_state, ref_encoder_state) < tolerance)

        # Decoder: reload, re-run on the target side while attending to the
        # full model's encoder states, compare.
        reloaded_decoder = reload_from_disk(full_model.get_decoder(),
                                            PLBartDecoder)
        decoder_state = reloaded_decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=ref_encoder_state,
            encoder_attention_mask=inputs_dict["attention_mask"],
        )[0]
        self.parent.assertTrue(
            max_abs_diff(decoder_state, ref_decoder_state) < tolerance)