import pytest
import torch

from allennlp.common.cached_transformers import get_tokenizer
from allennlp.data import Batch, Instance
from allennlp.data.fields import TransformerTextField


def test_transformer_text_field_batching():
    batch = Batch(
        [
            Instance({"text": TransformerTextField(torch.IntTensor([1, 2, 3]))}),
            Instance({"text": TransformerTextField(torch.IntTensor([2, 3, 4, 5]))}),
            Instance({"text": TransformerTextField(torch.IntTensor())}),
        ]
    )
    tensors = batch.as_tensor_dict(batch.get_padding_lengths())
    # Every instance is padded to the longest sequence in the batch (4 tokens).
    assert tensors["text"]["input_ids"].shape == (3, 4)
    # Padding positions get input id 0 and a False attention mask.
    assert tensors["text"]["input_ids"][0, -1] == 0
    assert tensors["text"]["attention_mask"][0, -1] == torch.tensor([False])
    # The empty instance comes out as all padding.
    assert torch.all(tensors["text"]["input_ids"][-1] == 0)
    assert torch.all(tensors["text"]["attention_mask"][-1] == torch.tensor([False]))
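

# --- A hedged aside, not part of the test file: the padded TensorDict produced above
# maps directly onto a HuggingFace model's forward arguments. This sketch assumes the
# `transformers` package is installed; the helper name and checkpoint are illustrative only.
def _feed_batch_to_bert_sketch(tensors):
    from transformers import AutoModel

    model = AutoModel.from_pretrained("bert-base-cased")
    text = tensors["text"]
    output = model(
        input_ids=text["input_ids"].long(),      # embedding layers expect integer (long) ids
        attention_mask=text["attention_mask"],   # bool mask; padding positions are False
    )
    return output.last_hidden_state              # shape: (batch_size, max_length, hidden_size)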


# Parametrize over tensor ("pt") and plain-list (None) tokenizer output.
@pytest.mark.parametrize("return_tensors", ["pt", None])
def test_transformer_text_field_from_huggingface(return_tensors):
    tokenizer = get_tokenizer("bert-base-cased")

    batch = Batch(
        [
            Instance(
                {"text": TransformerTextField(**tokenizer(text, return_tensors=return_tensors))}
            )
            for text in [
                "Hello, World!",
                "The fox jumped over the fence",
                "Humpty dumpty sat on a wall",
            ]
        ]
    )
    tensors = batch.as_tensor_dict(batch.get_padding_lengths())
    assert tensors["text"]["input_ids"].shape == (3, 11)
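

# --- A hedged aside, not from the original file: as the parametrized test shows,
# TransformerTextField accepts the tokenizer's output keys (input_ids, token_type_ids,
# attention_mask) directly as keyword arguments, whether they are plain lists
# (return_tensors=None) or PyTorch tensors (return_tensors="pt"). The helper name and
# variable names below are illustrative only.
def _field_from_tokenizer_sketch():
    tokenizer = get_tokenizer("bert-base-cased")
    encoded = tokenizer("Hello, World!")      # a dict of Python lists when return_tensors is None
    return TransformerTextField(**encoded)    # also works with tensors from return_tensors="pt"
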
Example #3
from typing import List

from allennlp.data import Batch, Instance
from allennlp.data.data_loaders.data_loader import TensorDict


def allennlp_collate(instances: List[Instance]) -> TensorDict:
    batch = Batch(instances)
    # Pad every field to the batch-wide maximum lengths and stack into tensors.
    return batch.as_tensor_dict(batch.get_padding_lengths())
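

# --- A hedged usage sketch, not part of the function above: `allennlp_collate` can be
# handed to a standard PyTorch DataLoader as its collate_fn, so a plain list of
# Instances is batched into padded TensorDicts. The helper name is illustrative only.
def _iterate_batches_sketch(instances: List[Instance]):
    from torch.utils.data import DataLoader

    loader = DataLoader(instances, batch_size=2, collate_fn=allennlp_collate)
    for tensor_dict in loader:   # each element is a TensorDict of padded, stacked tensors
        yield tensor_dict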