コード例 #1
0
def test_slow_tensor_dset(max_seq_len: int) -> None:
  """Check that ``SlowTensorDset`` serves every sample as a fixed-length tensor.

  A whitespace tokenizer is built from a tiny corpus and used to wrap the
  ``valid`` split of ``WikiText2Dset``.  The wrapper must keep the length of
  the wrapped dataset, and each sample must be a ``torch.Tensor`` of size
  ``max_seq_len`` that is identical whether obtained by iteration or by
  indexing.

  Parameters
  ----------
  max_seq_len: int
    Sequence length handed to ``SlowTensorDset``; every sample is expected to
    have exactly this many elements.
  """
  tokenizer = WsTknzr(is_uncased=True, max_vocab=-1, min_count=10)
  tokenizer.build_vocab(batch_txt=['a', 'b', 'c'])

  source_dset = WikiText2Dset(ver='valid')
  tensor_dset = lmp.util.dset.SlowTensorDset(
    dset=source_dset,
    max_seq_len=max_seq_len,
    tknzr=tokenizer,
  )

  # The wrapper must be the requested type and expose one sample per source sample.
  assert isinstance(tensor_dset, lmp.util.dset.SlowTensorDset)
  assert len(tensor_dset) == len(source_dset)

  # Every sample shares the same expected shape, so compute it once.
  expected_size = torch.Size([max_seq_len])
  for position, sample in enumerate(tensor_dset):
    assert isinstance(sample, torch.Tensor), 'Each sample in the tensor dataset must be tensor.'
    assert sample.size() == expected_size, 'Each sample in the tensor dataset must have same length.'
    # Indexing and iteration must agree on the sample at the same position.
    assert torch.all(tensor_dset[position] == sample), 'Support ``__getitem__`` and ``__iter__``.'
コード例 #2
0
def test_build_vocab(
    parameters,
    test_input: Sequence[str],
    expected: Dict[str, int],
):
    r"""Check the vocabulary built by ``WsTknzr`` against the expected mapping.

    Parameters
    ----------
    parameters
        Mapping providing the ``is_uncased``, ``max_vocab``, ``min_count`` and
        ``tk2id`` values used to construct the tokenizer.
    test_input: Sequence[str]
        Batch of text passed to ``build_vocab``.
    expected: Dict[str, int]
        Token-to-id mapping the tokenizer must hold after building.
    """
    # Collect the constructor arguments first so the tokenizer setup reads as
    # a single configuration step.
    tknzr_kwargs = {
        'is_uncased': parameters['is_uncased'],
        'max_vocab': parameters['max_vocab'],
        'min_count': parameters['min_count'],
        'tk2id': parameters['tk2id'],
    }
    tokenizer = WsTknzr(**tknzr_kwargs)

    tokenizer.build_vocab(test_input)

    built_tk2id = tokenizer.tk2id
    assert built_tk2id == expected