    def create_and_check_t5_model(
        self,
        config,
        encoder_input_ids,
        decoder_input_ids,
        encoder_attention_mask,
        decoder_attention_mask,
        decoder_lm_labels,
    ):
        model = T5Model(config=config)
        model.eval()
        decoder_output, encoder_output = model(
            encoder_input_ids=encoder_input_ids,
            decoder_input_ids=decoder_input_ids,
            encoder_attention_mask=encoder_attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
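        # a second forward pass without attention masks exercises the default
        # (all-ones) masking path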
        decoder_output, encoder_output = model(
            encoder_input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids
        )

        result = {
            "encoder_output": encoder_output,
            "decoder_output": decoder_output,
        }
        self.parent.assertListEqual(
            list(result["encoder_output"].size()), [self.batch_size, self.encoder_seq_length, self.hidden_size]
        )
        self.parent.assertListEqual(
            list(result["decoder_output"].size()), [self.batch_size, self.decoder_seq_length, self.hidden_size]
        )

    def check_prepare_lm_labels_via_shift_left(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config)
        model.to(torch_device)
        model.eval()

        # make sure that lm_labels are correctly padded from the right
        lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)

        # add causal pad token mask
        triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
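        # the mask is True strictly above the diagonal, so row i keeps i + 1
        # real tokens and the remaining positions become padding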
        lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
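        # _shift_right prepends decoder_start_token_id and drops the final
        # token, so decoder_input_ids[:, 1:] should mirror lm_labels[:, :-1]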
        decoder_input_ids = model._shift_right(lm_labels)

        for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
            # first item
            self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
            if i < decoder_input_ids_slice.shape[-1]:
                if i < decoder_input_ids.shape[-1] - 1:
                    # items before diagonal
                    self.parent.assertListEqual(
                        decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
                    )
                # pad items after diagonal
                if i < decoder_input_ids.shape[-1] - 2:
                    self.parent.assertListEqual(
                        decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
                    )
            else:
                # all items after square
                self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())

    def create_and_check_model(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = T5Model(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
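        # Seq2SeqModelOutput exposes the tensors as attributes; past_key_values
        # is populated because T5's config enables use_cache by default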
        decoder_output = result.last_hidden_state
        decoder_past = result.past_key_values
        encoder_output = result.encoder_last_hidden_state

        self.parent.assertEqual(
            encoder_output.size(),
            (self.batch_size, self.encoder_seq_length, self.hidden_size))
        self.parent.assertEqual(
            decoder_output.size(),
            (self.batch_size, self.decoder_seq_length, self.hidden_size))
        # There should be `num_layers` key value embeddings stored in decoder_past
        self.parent.assertEqual(len(decoder_past), config.num_layers)
        # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
        self.parent.assertEqual(len(decoder_past[0]), 4)
    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = T5Model(config=config).get_decoder().to(torch_device).eval()
        # first forward pass
        outputs = model(input_ids, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical next tokens and extend next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)

        # append the new tokens to input_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

        output_from_no_past = model(next_input_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
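        # with a correct cache, the hidden states computed for the 3 new tokens
        # must match the last 3 positions of the full-sequence forward pass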

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_t5_decoder_model_past(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config).get_decoder()
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(input_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids)
        outputs_no_past = model(input_ids, use_cache=False)

        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)

        output, past_key_value_states = outputs

        # create a hypothetical next token and extend next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # append the new token to input_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

        output_from_no_past = model(next_input_ids)[0]
        output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]
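        # this older API passes the cache as past_key_value_states; the cached
        # single-token step must reproduce the last position of the full pass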

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_t5_model(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config)
        model.to(torch_device)
        model.eval()
        decoder_output, decoder_past, encoder_output = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        decoder_output, decoder_past, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

        result = {
            "encoder_output": encoder_output,
            "decoder_output": decoder_output,
            "decoder_past": decoder_past,
        }
        self.parent.assertListEqual(
            list(result["encoder_output"].size()), [self.batch_size, self.encoder_seq_length, self.hidden_size]
        )
        self.parent.assertListEqual(
            list(result["decoder_output"].size()), [self.batch_size, self.decoder_seq_length, self.hidden_size]
        )
        self.parent.assertEqual(len(decoder_past), 2)
        # decoder_past[0] should correspond to encoder output
        self.parent.assertTrue(torch.all(decoder_past[0][0] == encoder_output))
        # There should be `num_layers` key value embeddings stored in decoder_past[1]
        self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
        # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
        self.parent.assertEqual(len(decoder_past[1][0]), 4)

    def test_export_to_onnx(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        model = T5Model(config_and_inputs[0]).to(torch_device)
        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.onnx.export(
                model,
                config_and_inputs[1],
                f"{tmpdirname}/t5_test.onnx",
                export_params=True,
                opset_version=9,
            )

    def create_and_check_t5_model_fp16_forward(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config)
        model.to(torch_device)
        model.half()
        model.eval()
        output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[0]
        self.parent.assertFalse(torch.isnan(output).any().item())
import torch

from transformers import T5Config, T5Model, load_tf_weights_in_t5


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = T5Model(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
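
# a minimal usage sketch; the checkpoint, config, and output paths below are
# hypothetical placeholders
if __name__ == "__main__":
    convert_tf_checkpoint_to_pytorch(
        tf_checkpoint_path="/path/to/t5/model.ckpt",
        config_file="/path/to/t5/config.json",
        pytorch_dump_path="/path/to/pytorch_model.bin",
    )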
    def create_and_check_decoder_model_attention_mask_past(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = T5Model(config=config).get_decoder()
        model.to(torch_device)
        model.eval()

        # create attention mask
        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)

        half_seq_length = input_ids.shape[-1] // 2
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
        output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()

        # create a hypothetical next token and extend next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
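        # the altered tokens sit in positions that attn_mask zeroes out, so the
        # final position's output should be unaffected in both passes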

        # append to next input_ids and attn_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        attn_mask = torch.cat(
            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
            dim=1,
        )

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))