예제 #1
0
 def test_vocab_transform_truncate_eos(self):
     """Truncation to max_seq_len=2 keeps one token id and appends EOS (103)."""
     vocab_path = os.path.join(self.base_dir, "vocab_dummy")
     transform = VocabTransform(vocab_path, max_seq_len=2, add_eos=True)
     input_rows = [["<unk>", ",", "."], ["▁que", "▁и", "i", "e"], ["i"]]
     # Every row, regardless of length, collapses to [first-token-id, EOS].
     self.assertEqual(transform(input_rows), [[3, 103], [41, 103], [14, 103]])
예제 #2
0
 def test_vocab_transform_bos_and_eos(self):
     """BOS (101) and EOS (103) ids wrap every looked-up token sequence."""
     vocab_path = os.path.join(self.base_dir, "vocab_dummy")
     transform = VocabTransform(vocab_path, add_bos=True, add_eos=True)
     input_rows = [["<unk>", ",", "."], ["▁que", "▁и", "i", "e"]]
     self.assertEqual(
         transform(input_rows),
         [[101, 3, 4, 5, 103], [101, 41, 35, 14, 13, 103]],
     )
예제 #3
0
 def test_multi_workers_reading(self):
     """Two worker ranks read disjoint partitions of the same iterable."""
     transform_dict = {
         "text": [WhitespaceTokenizerTransform(), VocabTransform(self.vocab)]
     }
     # One dataset per rank, otherwise identically configured.
     datasets = [
         BaseDataset(
             iterable=self.input_iterator,
             batch_size=1,
             is_shuffle=False,
             transforms_dict=transform_dict,
             rank=worker_rank,
             num_workers=2,
         )
         for worker_rank in (0, 1)
     ]
     batches0 = list(datasets[0])
     batches1 = list(datasets[1])
     # expect ds0 and ds1 to read different partitions of the data
     # the last (len(input_iterator) % num_workers) rows of the data
     # will be discarded because distributed training needs to be in sync
     assert len(batches0) == len(batches1) == 2
     assert torch.all(batches0[0]["token_ids"].eq(torch.tensor([[0, 1]])))
     assert torch.all(batches1[0]["token_ids"].eq(torch.tensor([[2, 3, 4]])))
     assert torch.all(batches0[1]["token_ids"].eq(torch.tensor([[0]])))
     assert torch.all(batches1[1]["token_ids"].eq(torch.tensor([[3, 1]])))
예제 #4
0
 def test_vocab_transform_truncate_bos(self):
     """With max_seq_len=2 and add_bos, output per row is [BOS=101, first id]."""
     vocab_path = os.path.join(self.base_dir, "vocab_dummy")
     transform = VocabTransform(vocab_path, max_seq_len=2, add_bos=True)
     # <unk> added by fairseq
     input_rows = [["<unk>", ",", "."], ["▁que", "▁и", "i", "e"], ["i"]]
     self.assertEqual(transform(input_rows), [[101, 3], [101, 41], [101, 14]])
예제 #5
0
 def test_vocab_transform_replace(self):
     """special_token_replacements maps fairseq's <unk> onto our UNK token."""
     replacements = {"<unk>": SpecialTokens.UNK}
     transform = VocabTransform(
         os.path.join(self.base_dir, "vocab_dummy"),
         special_token_replacements=replacements,
     )
     # Replace <unk> added by fairseq with our token
     input_rows = [["__UNKNOWN__", ",", "."], ["▁que", "▁и", "i", "e"]]
     self.assertEqual(transform(input_rows), [[3, 4, 5], [41, 35, 14, 13]])
예제 #6
0
 def test_load_doc_model(self):
     """A DocModel built from the dummy embeddings and vocab is an nn.Module."""
     vocab = VocabTransform(os.path.join(self.base_dir, "vocab_dummy")).vocab
     embeddings_path = os.path.join(self.base_dir, "word_embedding_dummy")
     model = models.DocModel(
         pretrained_embeddings_path=embeddings_path,
         embedding_dim=300,
         mlp_layer_dims=(2,),
         skip_header=True,
         kernel_num=1,
         kernel_sizes=(3, 4, 5),
         decoder_hidden_dims=(2,),
         vocab=vocab,
     )
     self.assertTrue(isinstance(model, nn.Module))
예제 #7
0
 def test_ds_with_pooling_batcher(self):
     """PoolingBatcher(2) yields three batches; short rows are right-padded."""
     transforms = {
         "text": [WhitespaceTokenizerTransform(), VocabTransform(self.vocab)]
     }
     dataset = BaseDataset(
         iterable=self.input_iterator,
         batch_size=2,
         is_shuffle=False,
         transforms_dict=transforms,
     )
     dataset.batch(batcher=PoolingBatcher(2))
     produced = list(dataset)
     assert len(produced) == 3
     # in [0, 1, 0], the trailing 0 is padding index
     expected_first = torch.tensor([[0, 1, 0], [2, 3, 4]])
     assert torch.all(produced[0]["token_ids"].eq(expected_first))
예제 #8
0
 def test_base_dataset(self):
     """Default batching pads within each batch and yields three batches."""
     transforms = {
         "text": [WhitespaceTokenizerTransform(), VocabTransform(self.vocab)]
     }
     dataset = BaseDataset(
         iterable=self.input_iterator,
         batch_size=2,
         is_shuffle=False,
         transforms_dict=transforms,
     )
     produced = list(dataset)
     assert len(produced) == 3
     expected_batches = [
         torch.tensor([[0, 1, 0], [2, 3, 4]]),
         torch.tensor([[0, 0], [3, 1]]),
         torch.tensor([[4, 1]]),
     ]
     for batch, want in zip(produced, expected_batches):
         assert torch.all(batch["token_ids"].eq(want))
예제 #9
0
 def _build_transforms(self,
                       label_vocab: List[str],
                       vocab_path: str,
                       max_seq_len: int = 256):
     """Build (train, inference) transform dicts for the text and label fields.

     Args:
         label_vocab: ordered list of label strings.
         vocab_path: path to the fairseq-format vocab file.
         max_seq_len: truncation length used by TruncateTransform.

     Returns:
         Tuple of (train_transforms, infer_transforms); inference reuses the
         same text pipeline and omits the label field.
     """
     # Custom batching and sampling will be setup here
     vocab = build_fairseq_vocab(vocab_path)
     text_pipeline = [
         SpmTokenizerTransform(),
         VocabTransform(vocab),
         TruncateTransform(
             vocab.get_bos_index(),
             vocab.get_eos_index(),
             max_seq_len=max_seq_len,
         ),
     ]
     train_transforms = {
         "text": text_pipeline,
         "label": [LabelTransform(label_vocab)],
     }
     infer_transforms = {"text": text_pipeline}
     return train_transforms, infer_transforms
예제 #10
0
 def test_vocab_transform_truncate_bos_and_eos_replace(self):
     """
     RoBERTa-style vocab lookup: special tokens remapped so BOS = 0 and
     EOS = 2, as required for pretrained-model compatibility.
     """
     replacements = {
         "<pad>": SpecialTokens.PAD,
         "<s>": SpecialTokens.BOS,
         "</s>": SpecialTokens.EOS,
         "<unk>": SpecialTokens.UNK,
         "<mask>": SpecialTokens.MASK,
     }
     transform = VocabTransform(
         os.path.join(self.base_dir, "vocab_dummy"),
         max_seq_len=3,
         add_bos=True,
         add_eos=True,
         special_token_replacements=replacements,
     )
     input_rows = [["<unk>", ",", "."], ["▁que", "▁и", "i", "e"], ["i"]]
     # Truncation to 3 leaves [BOS=0, one token id, EOS=2] per row.
     self.assertEqual(transform(input_rows), [[0, 3, 2], [0, 41, 2], [0, 14, 2]])
예제 #11
0
 def test_vocab_transform(self):
     """Plain lookup: each token maps to its id in the dummy fairseq vocab."""
     transform = VocabTransform(os.path.join(self.base_dir, "vocab_dummy"))
     # <unk> added by fairseq
     input_rows = [["<unk>", ",", "."], ["▁que", "▁и", "i", "e"]]
     self.assertEqual(transform(input_rows), [[3, 4, 5], [41, 35, 14, 13]])