Ejemplo n.º 1
0
 def test_vocab_transform(self):
     """Check VocabTransform maps tokens to ids, eagerly and under TorchScript.

     The vocab is built from the ``vocab_test2.txt`` asset. The file is opened
     in a ``with`` block so the handle is closed even if an assertion fails.
     """
     asset_name = 'vocab_test2.txt'
     asset_path = get_asset_path(asset_name)
     with open(asset_path, 'r') as f:
         vocab_transform = VocabTransform(vocab_from_file(f))
         self.assertEqual(vocab_transform(['of', 'that', 'new']), [7, 18, 24])
         # Script the to_ivalue() form and confirm it gives the same ids as eager mode.
         jit_vocab_transform = torch.jit.script(vocab_transform.to_ivalue())
         self.assertEqual(jit_vocab_transform(['of', 'that', 'new']), [7, 18, 24])
Ejemplo n.º 2
0
 def test_vocab_transform(self):
     """VocabTransform maps a batch of token lists to id lists, eagerly and scripted."""
     batch = [['of', 'that', 'new'],
              ['of', 'that', 'new', 'that']]
     expected_ids = [[21, 26, 20], [21, 26, 20, 26]]
     with open(get_asset_path('vocab_test2.txt'), 'r') as vocab_file:
         transform = VocabTransform(vocab_from_file(vocab_file))
         self.assertEqual(transform(batch), expected_ids)
         # The TorchScript version must produce the same ids for the same batch.
         scripted = torch.jit.script(transform.to_ivalue())
         self.assertEqual(scripted(batch), expected_ids)
Ejemplo n.º 3
0
 def test_vocab_transform(self):
     """Token-to-id lookup via VocabTransform, checked eagerly and after torch.jit.script."""
     vocab = load_vocab_from_file(get_asset_path('vocab_test2.txt'))
     vocab_transform = VocabTransform(vocab)
     self.assertEqual(vocab_transform(['of', 'that', 'new']), [7, 18, 24])
     # The scripted module is exercised on a longer input with a repeated token.
     scripted_transform = torch.jit.script(vocab_transform)
     self.assertEqual(
         scripted_transform(['of', 'that', 'new', 'that']),
         [7, 18, 24, 18],
     )
Ejemplo n.º 4
0
def build_text_vocab_pipeline(hf_vocab_file):
    """Build a tokenizer + vocab-lookup + ToLongTensor pipeline and script it.

    Args:
        hf_vocab_file: path of the vocab file; it is opened in text mode and
            passed to ``vocab_from_file_object``.

    Returns:
        A 3-tuple ``(pipeline, pipeline.to_ivalue(), jit_pipeline)``:
        the eager ``TextSequentialTransforms`` pipeline, its ``to_ivalue()``
        form, and the result of ``torch.jit.script`` on that form.
    """
    tokenizer = basic_english_normalize()
    # Close the vocab file as soon as the vocab has been built from it.
    with open(hf_vocab_file, 'r') as f:
        vocab = vocab_from_file_object(f)

    # NOTE(review): an earlier comment said a token is inserted into the vocab
    # here to match a pretrained vocab, but no insertion happens; confirm
    # whether one is needed.
    pipeline = TextSequentialTransforms(tokenizer, VocabTransform(vocab), ToLongTensor())
    jit_pipeline = torch.jit.script(pipeline.to_ivalue())
    print('jit text vocab pipeline success!')
    return pipeline, pipeline.to_ivalue(), jit_pipeline