Example #1
 def testStrToVocabTokenAppendEOSFalse(self):
     vocab = test_helper.test_src_dir_path(
         'core/ops/testdata/test_vocab.txt')
     with self.session(use_gpu=False):
         token_ids, target_ids, paddings = self.evaluate(
             ops.str_to_vocab_tokens([
                 'a b c d e',
                 '<epsilon> <S> </S> <UNK>',
                 'øut über ♣ 愤青 ←',
             ],
                                     append_eos=False,
                                     maxlen=10,
                                     vocab_filepath=vocab))
         self.assertEqual(token_ids.tolist(),
                          [[1, 5, 6, 7, 8, 9, 2, 2, 2, 2],
                           [1, 0, 1, 2, 3, 2, 2, 2, 2, 2],
                           [1, 10, 11, 12, 13, 3, 2, 2, 2, 2]])
         self.assertEqual(target_ids.tolist(),
                          [[5, 6, 7, 8, 9, 2, 2, 2, 2, 2],
                           [0, 1, 2, 3, 2, 2, 2, 2, 2, 2],
                           [10, 11, 12, 13, 3, 2, 2, 2, 2, 2]])
         self.assertEqual(paddings.tolist(),
                          [[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
                           [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
                           [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]])
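
The expected values above show how the three outputs relate: token_ids is target_ids shifted right by one with the <S> id (1) prepended, the </S> id (2) doubles as the padding filler, and paddings marks padded positions with 1. (The id assignments can be read off the second test row, where '<epsilon> <S> </S> <UNK>' maps to 0, 1, 2, 3.) A minimal NumPy sketch of that relationship, not part of Lingvo, using the first row from the assertions above:

import numpy as np

token_ids = np.array([[1, 5, 6, 7, 8, 9, 2, 2, 2, 2]])
target_ids = np.array([[5, 6, 7, 8, 9, 2, 2, 2, 2, 2]])
paddings = np.array([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]])

length = int((1.0 - paddings[0]).sum())  # 5 real positions
assert token_ids[0, 0] == 1              # <S> prepended
# token_ids is target_ids shifted right by one within the real positions.
assert token_ids[0, 1:length + 1].tolist() == target_ids[0, :length].tolist()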
Example #2
 def tokenize_words(words_t):
     padded_tokenized_t, _, paddings_t = str_to_vocab_tokens(
         labels=words_t,
         maxlen=longest_word_length,
         append_eos=True,
         pad_to_maxlen=True,
         vocab_filepath=FLAGS.in_units_txt,
         load_token_ids_from_vocab=False,
         delimiter="",
     )
     # Either lengths or paddings are incorrect.
     lengths_t = py_utils.LengthsFromPaddings(paddings_t)
     ragged_tokenized_t = tf.RaggedTensor.from_tensor(padded_tokenized_t,
                                                      lengths=lengths_t)
     # Drop the start-of-sentence token.
     ragged_tokenized_t = ragged_tokenized_t[:, 1:]
     lengths_t -= 1
     letters_t = vocab_id_to_token(
         id=ragged_tokenized_t.flat_values,
         vocab=vocab_tokens,
         load_token_ids_from_vocab=False,
     )
     ragged_letters_t = tf.RaggedTensor.from_row_lengths(
         letters_t, lengths_t)
     # Is capitalization the problem?
     return ragged_tokenized_t, ragged_letters_t
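
The RaggedTensor handling in tokenize_words is the generic pattern of trimming a leading id from padded sequences. A self-contained sketch of the same steps with plain TensorFlow (made-up ids and lengths, no Lingvo dependencies):

import tensorflow as tf

padded_ids = tf.constant([[1, 5, 6, 2, 2],
                          [1, 7, 8, 9, 2]])
lengths = tf.constant([3, 4], dtype=tf.int64)  # real positions per row, including the leading SOS id
ragged = tf.RaggedTensor.from_tensor(padded_ids, lengths=lengths)
ragged = ragged[:, 1:]                         # drop the start-of-sentence id
# tokenize_words maps ragged.flat_values to tokens before regrouping; here the
# ids are regrouped directly with the shortened row lengths.
regrouped = tf.RaggedTensor.from_row_lengths(ragged.flat_values, lengths - 1)
print(regrouped)  # <tf.RaggedTensor [[5, 6], [7, 8, 9]]>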
Example #3
 def testStrToVocabTokenTruncates(self):
   vocab = test_helper.test_src_dir_path('core/ops/testdata/test_vocab.txt')
   with self.session(use_gpu=False) as sess:
     token_ids, target_ids, paddings = sess.run(
         ops.str_to_vocab_tokens(['a b c d e ' * 1000],
                                 append_eos=True,
                                 maxlen=5,
                                 vocab_filepath=vocab))
     self.assertEqual(token_ids.tolist(), [[1, 5, 6, 7, 8]])
     self.assertEqual(target_ids.tolist(), [[5, 6, 7, 8, 9]])
     self.assertEqual(paddings.tolist(), [[0., 0., 0., 0., 0.]])
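
When the input is longer than maxlen, the op simply cuts the sequence at maxlen: all paddings are 0 and, even with append_eos=True, the </S> id (2) never makes it into the window. A short NumPy check, not Lingvo code, using the expected values above:

import numpy as np

token_ids = np.array([[1, 5, 6, 7, 8]])
paddings = np.array([[0., 0., 0., 0., 0.]])
length = int((1.0 - paddings[0]).sum())
assert length == 5            # == maxlen, i.e. the row was truncated
assert 2 not in token_ids[0]  # no room was left for the EOS id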
Example #4
 def testStrToVocabTokenSplitToCharacters(self):
   custom_delimiter = ''
   vocab = test_helper.test_src_dir_path('core/ops/testdata/test_vocab.txt')
   with self.session(use_gpu=False) as sess:
     token_ids, target_ids, paddings = sess.run(
         ops.str_to_vocab_tokens(['abcde'],
                                 append_eos=True,
                                 maxlen=8,
                                 vocab_filepath=vocab,
                                 delimiter=custom_delimiter))
     self.assertEqual(token_ids.tolist(), [[1, 5, 6, 7, 8, 9, 2, 2]])
     self.assertEqual(target_ids.tolist(), [[5, 6, 7, 8, 9, 2, 2, 2]])
     self.assertEqual(paddings.tolist(), [[0., 0., 0., 0., 0., 0., 1., 1.]])
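
An empty delimiter makes the op split each label into individual characters before the vocab lookup, which is why 'abcde' yields the same per-character ids 5..9 as 'a b c d e' in the first example. The split itself can be reproduced with plain TensorFlow (sketch only, not how the op is implemented internally):

import tensorflow as tf

chars = tf.strings.unicode_split(['abcde'], 'UTF-8')
print(chars)  # <tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e']]>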
Example #5
  def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
    self._CheckParams()
    p = self.params

    if p.token_vocab_filepath:
      return ops.str_to_vocab_tokens(
          strs,
          maxlen=max_length,
          pad_to_maxlen=p.pad_to_max_length,
          append_eos=append_eos,
          vocab_filepath=p.token_vocab_filepath,
          load_token_ids_from_vocab=p.load_token_ids_from_vocab,
          delimiter=p.tokens_delimiter)
    elif p.ngram_vocab_filepath:
      raise NotImplementedError('ngram vocab StringsToIds is not supported.')
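
A rough sketch of how callers typically consume the (ids, labels, paddings) triple this method returns: as decoder inputs, prediction targets, and a per-position weight mask. This is an assumed usage pattern for illustration with made-up values, not code from Lingvo:

import tensorflow as tf

ids = tf.constant([[1, 5, 6, 2]])        # SOS-prefixed decoder inputs
labels = tf.constant([[5, 6, 2, 2]])     # targets, EOS-terminated
paddings = tf.constant([[0., 0., 0., 1.]])
weights = 1.0 - paddings                 # 1.0 on real positions, 0.0 on padding
print(tf.reduce_sum(weights).numpy())    # 3.0 real target positions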