Ejemplo n.º 1
0
 def testCheckpointAveraging(self):
     """Averages two generated checkpoints and checks the merged variables.

     The test asserts that the float32 variable "x" is averaged element-wise
     (zeros and ones give 0.5), while the int64 variables "global_step" and
     "words_per_sec/features_init" keep their int64 dtype and hold the values
     of the later checkpoint (step 20, value 89).
     """
     model_dir = os.path.join(self.get_temp_dir(), "ckpt")
     os.makedirs(model_dir)
     # Each entry: (step, value of "x", value of "words_per_sec/features_init").
     specs = [
         (10, np.zeros((2, 3), dtype=np.float32), np.int64(42)),
         (20, np.ones((2, 3), dtype=np.float32), np.int64(89)),
     ]
     checkpoints = []
     for step, x_value, features_init in specs:
         variables = {
             "x": x_value,
             "words_per_sec/features_init": features_init,
         }
         # At call time the list holds only the previously generated
         # checkpoints; the result is appended once the helper returns.
         checkpoints.append(
             self._generateCheckpoint(
                 model_dir, step, variables, last_checkpoints=checkpoints))
     avg_dir = os.path.join(model_dir, "avg")
     checkpoint.average_checkpoints(model_dir, avg_dir)
     averaged = checkpoint.get_checkpoint_variables(avg_dir)

     step_value = averaged["global_step"]
     self.assertEqual(step_value.dtype, np.int64)
     self.assertEqual(step_value, 20)

     features_value = averaged["words_per_sec/features_init"]
     self.assertEqual(features_value.dtype, np.int64)
     self.assertEqual(features_value, 89)

     expected_x = np.full((2, 3), 0.5, dtype=np.float32)
     self.assertAllEqual(averaged["x"], expected_x)
 def testCheckpointDTypeConversion(self):
     """Converts a float32 checkpoint to float16 and checks every variable.

     The test asserts that the non-optimizer float32 variable "x" becomes
     float16, that the "optim/x" variable stays float32, and that the int64
     "counter" keeps its dtype. Values are checked as well as dtypes, so a
     conversion that corrupted tensor contents cannot pass on dtypes alone.
     """
     model_dir = os.path.join(self.get_temp_dir(), "ckpt-fp32")
     os.makedirs(model_dir)
     variables = {
         "x": np.ones((2, 3), dtype=np.float32),
         "optim/x": np.ones((2, 3), dtype=np.float32),
         "counter": np.int64(42),
     }
     checkpoint_path, _ = self._generateCheckpoint(model_dir, 10, variables)
     half_dir = os.path.join(model_dir, "fp16")
     checkpoint.convert_checkpoint(
         checkpoint_path, half_dir, tf.float32, tf.float16)
     half_var = checkpoint.get_checkpoint_variables(half_dir)
     self.assertEqual(half_var["global_step"], 10)
     self.assertEqual(half_var["x"].dtype, np.float16)
     # 1.0 is exactly representable in float16, so an exact comparison is safe.
     self.assertAllEqual(half_var["x"], np.ones((2, 3), dtype=np.float16))
     self.assertEqual(half_var["optim/x"].dtype, np.float32)
     self.assertAllEqual(half_var["optim/x"], np.ones((2, 3), dtype=np.float32))
     self.assertEqual(half_var["counter"].dtype, np.int64)
     self.assertEqual(half_var["counter"], 42)