def testCheckpointDTypeConversion(self):
    """Converting a float32 checkpoint to float16 casts model weights only.

    Optimizer slots ("optim/*") and non-float variables ("counter",
    "global_step") are expected to keep their original dtypes after the
    conversion round-trip.
    """
    ckpt_dir = os.path.join(self.get_temp_dir(), "ckpt-fp32")
    os.makedirs(ckpt_dir)
    # Mix of a model weight, an optimizer slot, and an integer scalar.
    var_values = {
        "x": np.ones((2, 3), dtype=np.float32),
        "optim/x": np.ones((2, 3), dtype=np.float32),
        "counter": np.int64(42),
    }
    ckpt_path, _ = self._generateCheckpoint(ckpt_dir, 10, var_values)
    converted_dir = os.path.join(ckpt_dir, "fp16")
    checkpoint.convert_checkpoint(
        ckpt_path, converted_dir, tf.float32, tf.float16)
    converted_vars = checkpoint.get_checkpoint_variables(converted_dir)
    # Step is preserved; only the plain weight "x" should be down-cast.
    self.assertEqual(converted_vars["global_step"], 10)
    self.assertEqual(converted_vars["x"].dtype, np.float16)
    self.assertEqual(converted_vars["optim/x"].dtype, np.float32)
    self.assertEqual(converted_vars["counter"].dtype, np.int64)
def main():
    """CLI entry point: convert a checkpoint between floating-point dtypes.

    Requires --output_dir and --target_dtype, plus one of --model_dir
    (latest checkpoint is used) or --checkpoint_path. When --source_dtype
    is omitted it is inferred as the "opposite" half/full precision of
    the target. Conversion is forced onto CPU via a GPU-less ConfigProto.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--model_dir",
        default=None,
        help="The path to the model directory.")
    parser.add_argument(
        "--checkpoint_path",
        default=None,
        help="The path to the checkpoint to convert.")
    parser.add_argument(
        "--output_dir",
        required=True,
        help="The output directory where the updated checkpoint will be saved."
    )
    parser.add_argument(
        "--target_dtype",
        required=True,
        help="Target data type (e.g. float16 or float32).")
    parser.add_argument(
        "--source_dtype",
        default=None,
        help="Source data type (e.g. float16 or float32, inferred if not set)."
    )
    flags = parser.parse_args()

    # At least one way to locate the source checkpoint must be given.
    if flags.checkpoint_path is None and flags.model_dir is None:
        raise ValueError(
            "One of --checkpoint_path and --model_dir should be set")

    # An explicit checkpoint path wins; otherwise resolve the newest one
    # inside the model directory.
    ckpt_path = (flags.checkpoint_path
                 if flags.checkpoint_path is not None
                 else tf.train.latest_checkpoint(flags.model_dir))

    target_dtype = tf.as_dtype(flags.target_dtype)
    if flags.source_dtype is not None:
        source_dtype = tf.as_dtype(flags.source_dtype)
    elif target_dtype == tf.float16:
        # Infer: converting to half means the source was full precision.
        source_dtype = tf.float32
    else:
        source_dtype = tf.float16

    checkpoint.convert_checkpoint(
        ckpt_path,
        flags.output_dir,
        source_dtype,
        target_dtype,
        session_config=tf.ConfigProto(device_count={"GPU": 0}))