def testVocabVariableUpdate(self):
   """Checks that mapped entries of the updated variable copy the old values.

   ``mapping[i]`` is the index in ``old`` that position ``i`` of the updated
   variable should take its value from. Positions mapped to -1 have no
   counterpart in ``old``; their values are not asserted by this test.
   """
   mapping = [0, -1, -1, 2, -1, 4]
   old = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
   vocab_size = 7
   new = checkpoint._update_vocabulary_variable(old, vocab_size, mapping)
   # zip() stops at the shorter sequence, so without this length check a
   # truncated (or empty) result would make the loop below pass vacuously.
   self.assertEqual(len(mapping), len(new))
   for index, value in zip(mapping, new):
     if index >= 0:
       self.assertEqual(value, old[index])
Example #2
0
 def testVocabVariableUpdate(self):
     """Verifies copied and zero-filled entries after a vocabulary remap.

     Each element of the mapping names the old index to copy from; -1
     entries are expected to come back as 0 in the updated variable.
     """
     old_values = np.array([1, 2, 3, 4, 5, 6, 7])
     new_to_old = [0, -1, -1, 2, -1, 4]
     updated = checkpoint._update_vocabulary_variable(old_values, 7, new_to_old)
     expected = [1, 0, 0, 3, 0, 5]
     self.assertAllEqual(expected, updated)