Пример #1
0
    def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
        """Collate a list of examples into one batch with inputs, labels and decoder inputs.

        The values of every key are stacked into a NumPy array. One noise mask
        per row is drawn with `self.random_spans_noise_mask`; the mask and its
        complement are converted to sentinel ids by `self.create_sentinel_ids`
        and handed, together with the original ids, to `self.filter_input_ids`
        to build the new `input_ids` and `labels`.

        Args:
            examples: Non-empty list of examples. The keys of the first example
                are read from every example, and `"input_ids"` must be among
                them. The stacked `"input_ids"` must be 2-D (its shape is
                unpacked into batch size and sequence length).

        Returns:
            The batch with `"input_ids"` and `"labels"` replaced/added as
            described above, plus `"decoder_input_ids"` computed from the labels
            by `shift_tokens_right`.

        Raises:
            ValueError: If the last dimension of the processed `input_ids` is
                not `self.input_length`, or that of `labels` is not
                `self.target_length`.
        """
        # convert list to dict and tensorize input
        batch = BatchEncoding(
            {key: np.array([example[key] for example in examples]) for key in examples[0]}
        )

        input_ids = batch["input_ids"]
        batch_size, expanded_input_length = input_ids.shape

        # One independently drawn mask per row of the batch.
        mask_indices = np.asarray(
            [self.random_spans_noise_mask(expanded_input_length) for _ in range(batch_size)]
        )
        # The labels use the complement of the input mask.
        labels_mask = ~mask_indices

        input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
        labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))

        batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
        batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)

        if batch["input_ids"].shape[-1] != self.input_length:
            # Report the length that was actually checked (`input_length`).
            raise ValueError(
                f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, "
                f"but should be {self.input_length}."
            )

        if batch["labels"].shape[-1] != self.target_length:
            raise ValueError(
                f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, "
                f"but should be {self.target_length}."
            )

        # to check that tokens are correctly preprocessed, one can run
        # `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
        batch["decoder_input_ids"] = shift_tokens_right(
            batch["labels"], self.pad_token_id, self.decoder_start_token_id
        )

        return batch
    def test_small_byt5_integration_test(self):
        """
        Check the score of `google/byt5-small` on a fixed input/target pair against a stored value.

        For comparison run:
        >>> import t5  # pip install t5==0.9.1

        >>> path_to_byt5_small_checkpoint = '<fill_in>'
        >>> t5_model = t5.models.MtfModel(model_dir=path_to_byt5_small_checkpoint, batch_size=1, tpu=None)
        >>> vocab = t5.data.ByteVocabulary()
        >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
        """
        checkpoint = "google/byt5-small"
        model = FlaxT5ForConditionalGeneration.from_pretrained(checkpoint)
        tokenizer = ByT5Tokenizer.from_pretrained(checkpoint)

        source_ids = tokenizer("Hello there", return_tensors="np").input_ids
        target_ids = tokenizer("Hi I am", return_tensors="np").input_ids

        config = model.config
        decoder_inputs = shift_tokens_right(
            target_ids, config.pad_token_id, config.decoder_start_token_id
        )

        logits = model(source_ids, decoder_input_ids=decoder_inputs).logits
        vocab_size = logits.shape[-1]
        per_token_loss = optax.softmax_cross_entropy(logits, onehot(target_ids, vocab_size))
        mean_loss = per_token_loss.mean()

        # Scale the mean loss by the target length and negate it to get the score.
        target_length = target_ids.shape[-1]
        mtf_score = -(target_length * mean_loss.item())

        # Presumably the value produced by the snippet in the docstring -- confirm before changing.
        EXPECTED_SCORE = -60.7397
        deviation = abs(mtf_score - EXPECTED_SCORE)
        self.assertTrue(deviation < 1e-4)
Пример #3
0
    def test_shift_right(self):
        """Check `shift_tokens_right` on a 5x20 label matrix whose first two rows end in -100."""
        pad_token_id = 1
        decoder_start_token_id = 0
        ignored_from = 15  # first column set to -100 in the first two rows

        labels = np.arange(2, 102).reshape(5, 20)
        labels[:2, ignored_from:] = -100

        shifted = np.array(shift_tokens_right(labels, pad_token_id, decoder_start_token_id))

        # After the shift, the -100 entries sit one column further right and must read as padding.
        pad_region = shifted[:2, ignored_from + 1 :]
        self.assertTrue((pad_region == pad_token_id).all())

        # Rows without -100: every column after the first equals the label one column to the left.
        expected_tail = labels[2:, :-1]
        actual_tail = shifted[2:, 1:]
        self.assertTrue((actual_tail == expected_tail).all())

        # Every row starts with the decoder start token.
        first_column = shifted[:, 0]
        self.assertTrue((first_column == decoder_start_token_id).all())