def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
    """Collate tokenized examples into a span-corrupted encoder/decoder batch.

    Every key of the first example is stacked across all examples into one
    numpy array. A noise mask is sampled per row with
    ``self.random_spans_noise_mask``. ``input_ids`` are then rebuilt by
    ``self.filter_input_ids`` using sentinel ids derived from that mask.
    ``labels`` are rebuilt the same way from the complementary mask.
    (Presumably the noise spans become sentinels in the inputs and are kept
    as targets in the labels; the exact semantics live in
    ``create_sentinel_ids`` / ``filter_input_ids`` — confirm there.)

    Args:
        examples: Non-empty list of per-example dicts sharing the same keys.
            The stacked ``input_ids`` must be 2-D, so every example must have
            the same ``input_ids`` length.

    Returns:
        A ``BatchEncoding`` with the stacked fields, the rewritten
        ``input_ids`` and ``labels``, and ``decoder_input_ids`` built from
        ``labels`` via ``shift_tokens_right``.

    Raises:
        ValueError: If the rebuilt ``input_ids`` length is not
            ``self.input_length``, or the rebuilt ``labels`` length is not
            ``self.target_length``.
    """
    # Convert the list of dicts into a dict of stacked numpy arrays.
    batch = BatchEncoding(
        {key: np.array([example[key] for example in examples]) for key in examples[0]}
    )

    input_ids = batch["input_ids"]
    batch_size, expanded_input_length = input_ids.shape

    # Sample an independent noise mask for every row; labels use the complement.
    mask_indices = np.asarray(
        [self.random_spans_noise_mask(expanded_input_length) for _ in range(batch_size)]
    )
    labels_mask = ~mask_indices

    input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
    labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))

    # Both outputs are derived from the ORIGINAL (uncorrupted) input_ids.
    batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
    batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)

    if batch["input_ids"].shape[-1] != self.input_length:
        raise ValueError(
            f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.input_length}."
        )

    if batch["labels"].shape[-1] != self.target_length:
        raise ValueError(
            f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}."
        )

    # To check that tokens are correctly preprocessed, one can run
    # `self.tokenizer.batch_decode(input_ids)` and
    # `self.tokenizer.batch_decode(labels)` here...
    batch["decoder_input_ids"] = shift_tokens_right(
        batch["labels"], self.pad_token_id, self.decoder_start_token_id
    )

    return batch
def test_small_byt5_integration_test(self):
    """
    For comparison run:
    >>> import t5  # pip install t5==0.9.1

    >>> path_to_byt5_small_checkpoint = '<fill_in>'
    >>> t5_model = t5.models.MtfModel(model_dir=path_to_byt5_small_checkpoint, batch_size=1, tpu=None)
    >>> vocab = t5.data.ByteVocabulary()
    >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
    """
    model = FlaxT5ForConditionalGeneration.from_pretrained("google/byt5-small")
    tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")

    # A single (batch size 1) input/target pair, matching the MTF snippet above.
    input_ids = tokenizer("Hello there", return_tensors="np").input_ids
    labels = tokenizer("Hi I am", return_tensors="np").input_ids

    # Teacher forcing: the decoder consumes the labels shifted one step right.
    decoder_input_ids = shift_tokens_right(
        labels, model.config.pad_token_id, model.config.decoder_start_token_id
    )

    logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits

    # `.mean()` averages the per-token cross-entropy over the single sequence.
    # Multiplying by the label length recovers the summed negative
    # log-likelihood. Negating it yields the score compared against the MTF value.
    loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
    mtf_score = -(labels.shape[-1] * loss.item())

    EXPECTED_SCORE = -60.7397
    self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
def test_shift_right(self):
    """Check that shift_tokens_right prepends the start token, shifts right, and pads -100 positions."""
    start_id, pad_id = 0, 1

    labels = np.arange(2, 102).reshape(5, 20)
    # The first two rows carry ignored (-100) positions from column 15 onward.
    labels[:2, 15:] = -100

    shifted = np.array(shift_tokens_right(labels, pad_id, start_id))

    # After the one-step shift the ignored region begins one column later
    # and must contain only the pad id.
    first_shifted_pad_col = 15 + 1
    self.assertTrue(np.all(shifted[:2, first_shifted_pad_col:] == pad_id))

    # Rows without ignored positions are the labels moved one column right.
    expected_tail = np.roll(labels[2:], 1, axis=1)[:, 1:]
    self.assertTrue(np.all(shifted[2:, 1:] == expected_tail))

    # Column 0 of every row holds the decoder start token.
    self.assertTrue(np.all(shifted[:, 0] == start_id))