from typing import List

import pytest
import torch

from allennlp.common.cached_transformers import get_tokenizer
from allennlp.data import Batch, Instance
from allennlp.data.data_loaders.data_loader import TensorDict
from allennlp.data.fields import TransformerTextField


def test_transformer_text_field_batching():
    batch = Batch(
        [
            Instance({"text": TransformerTextField(torch.IntTensor([1, 2, 3]))}),
            Instance({"text": TransformerTextField(torch.IntTensor([2, 3, 4, 5]))}),
            Instance({"text": TransformerTextField(torch.IntTensor())}),
        ]
    )
    tensors = batch.as_tensor_dict(batch.get_padding_lengths())

    # All sequences are padded out to the length of the longest one (4).
    assert tensors["text"]["input_ids"].shape == (3, 4)
    # The first sequence has length 3, so its final position is padding ...
    assert tensors["text"]["input_ids"][0, -1] == 0
    # ... and is masked out accordingly.
    assert not tensors["text"]["attention_mask"][0, -1]
    # The empty third sequence is entirely padding and entirely masked out.
    assert torch.all(tensors["text"]["input_ids"][-1] == 0)
    assert not torch.any(tensors["text"]["attention_mask"][-1])
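
# Sketch (an assumption, not one of the original tests): the same padding
# behavior can be exercised on a single field through the generic Field API
# that Batch uses under the hood -- get_padding_lengths() followed by
# as_tensor(). The target length of 5 here is arbitrary.
def test_transformer_text_field_as_tensor_directly():
    field = TransformerTextField(torch.IntTensor([1, 2, 3]))
    # Pad out to length 5 instead of the field's natural length of 3, as a
    # batch containing a longer sequence would.
    padding_lengths = {key: 5 for key in field.get_padding_lengths()}
    tensors = field.as_tensor(padding_lengths)
    assert tensors["input_ids"].shape == (5,)
    # The last two positions are padding.
    assert torch.all(tensors["input_ids"][-2:] == 0)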
# `return_tensors="pt"` makes the HuggingFace tokenizer return PyTorch
# tensors, while None returns plain Python lists; TransformerTextField
# accepts either. (The parametrization values here are an assumption.)
@pytest.mark.parametrize("return_tensors", ["pt", None])
def test_transformer_text_field_from_huggingface(return_tensors):
    tokenizer = get_tokenizer("bert-base-cased")

    batch = Batch(
        [
            Instance(
                {"text": TransformerTextField(**tokenizer(text, return_tensors=return_tensors))}
            )
            for text in [
                "Hello, World!",
                "The fox jumped over the fence",
                "Humpty dumpty sat on a wall",
            ]
        ]
    )

    # The ** expansion works because the tokenizer output is a mapping whose
    # keys ("input_ids", "token_type_ids", "attention_mask") match
    # TransformerTextField's constructor parameters.
    tensors = batch.as_tensor_dict(batch.get_padding_lengths())
    assert tensors["text"]["input_ids"].shape == (3, 11)
def allennlp_collate(instances: List[Instance]) -> TensorDict:
    """Collate a list of Instances into a single padded TensorDict."""
    batch = Batch(instances)
    return batch.as_tensor_dict(batch.get_padding_lengths())
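
# Usage sketch (an assumption, not from the original code): `allennlp_collate`
# has the signature PyTorch expects of a `collate_fn`, so a plain
# torch.utils.data.DataLoader over a list of Instances can yield padded
# TensorDicts directly.
def test_allennlp_collate_as_collate_fn():
    from torch.utils.data import DataLoader

    instances = [
        Instance({"text": TransformerTextField(torch.IntTensor([1, 2, 3]))}),
        Instance({"text": TransformerTextField(torch.IntTensor([2, 3, 4, 5]))}),
    ]
    loader = DataLoader(instances, batch_size=2, collate_fn=allennlp_collate)
    tensors = next(iter(loader))
    # Both sequences are padded to the longest length in the batch (4).
    assert tensors["text"]["input_ids"].shape == (2, 4)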