Esempio n. 1
0
 def test_test_preprocess_fn_return_dataset_element_spec(self):
     """Checks the element spec produced by the test-set preprocess fn.

     A single-sentence dataset is preprocessed with `max_seq_len=10`. Each
     element must be a pair of int64 tensor specs of shape [None, 10].
     """
     raw_examples = collections.OrderedDict([('tokens', ['one must imagine'])])
     raw_ds = tf.data.Dataset.from_tensor_slices(raw_examples)
     preprocess = dataset.create_test_dataset_preprocess_fn(
         max_seq_len=10, vocab=['one', 'must'])
     processed_ds = preprocess(raw_ds)
     # Both components of the element are expected to share the same spec.
     batch_spec = tf.TensorSpec(shape=[None, 10], dtype=tf.int64)
     self.assertEqual(processed_ds.element_spec, (batch_spec, batch_spec))
Esempio n. 2
0
 def test_test_preprocess_fn_returns_correct_sequence(self):
     """Checks the token ids the test-set preprocess fn emits for one sentence.

     With vocab ['one', 'must'] and `max_seq_len=6`, the first component of
     the first element for 'one must imagine' must be the ids
     [BOS, one, must, OOV, EOS, pad].
     """
     raw_examples = collections.OrderedDict([('tokens', ['one must imagine'])])
     raw_ds = tf.data.Dataset.from_tensor_slices(raw_examples)
     preprocess = dataset.create_test_dataset_preprocess_fn(
         max_seq_len=6, vocab=['one', 'must'])
     first_batch = next(iter(preprocess(raw_ds)))
     # Id layout: pad is 0, OOV is len(vocab)+1, BOS is len(vocab)+2 and EOS
     # is len(vocab)+3. With the two-word vocab, the out-of-vocab word
     # 'imagine' therefore maps to 3.
     expected_ids = np.array([[4, 1, 2, 3, 5, 0]])
     self.assertAllEqual(self.evaluate(first_batch[0]), expected_ids)