def testCheckpointAveraging(self):
    """Averaging two checkpoints halves float variables and keeps the last int64 values."""
    model_dir = os.path.join(self.get_temp_dir(), "ckpt")
    os.makedirs(model_dir)
    saved = []
    # Generate two checkpoints with diverging variable values.
    for step, x_value, wps_value in (
        (10, np.zeros((2, 3), dtype=np.float32), np.int64(42)),
        (20, np.ones((2, 3), dtype=np.float32), np.int64(89)),
    ):
        saved.append(
            self._generateCheckpoint(
                model_dir,
                step,
                {"x": x_value, "words_per_sec/features_init": wps_value},
                last_checkpoints=saved))
    avg_dir = os.path.join(model_dir, "avg")
    checkpoint.average_checkpoints(model_dir, avg_dir)
    avg_var = checkpoint.get_checkpoint_variables(avg_dir)
    # Integer variables are not averaged: the newest checkpoint's value wins.
    self.assertEqual(avg_var["global_step"].dtype, np.int64)
    self.assertEqual(avg_var["global_step"], 20)
    self.assertEqual(avg_var["words_per_sec/features_init"].dtype, np.int64)
    self.assertEqual(avg_var["words_per_sec/features_init"], 89)
    # Float variables hold the element-wise mean: (0 + 1) / 2 = 0.5.
    self.assertAllEqual(avg_var["x"], np.full((2, 3), 0.5, dtype=np.float32))
def testCheckpointAveraging(self):
    """The averaged checkpoint holds the per-variable mean and the same variable set."""
    model = _DummyModel()
    optimizer = tf.keras.optimizers.Adam()

    @tf.function
    def _build_model():
        # Run a single training step so model and optimizer variables are created.
        features = tf.random.uniform([4, 10])
        loss = tf.reduce_mean(model(features))
        gradients = optimizer.get_gradients(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    def _assign_var(var, scalar):
        var.assign(tf.ones_like(var) * scalar)

    def _all_equal(var, scalar):
        return tf.size(tf.where(tf.not_equal(var, scalar))).numpy() == 0

    def _get_var_list(checkpoint_path):
        return [name for name, _ in tf.train.list_variables(checkpoint_path)]

    _build_model()

    # Save one checkpoint per step, with every variable set to the step value,
    # so the expected average is simply the mean of the steps.
    steps = [10, 20, 30, 40]
    avg_value = sum(steps) / len(steps)
    directory = os.path.join(self.get_temp_dir(), "src")
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, directory, max_to_keep=len(steps))
    for step in steps:
        _assign_var(model.layers[0].kernel, step)
        _assign_var(model.layers[0].bias, step)
        checkpoint_manager.save(checkpoint_number=step)

    output_dir = os.path.join(self.get_temp_dir(), "dst")
    checkpoint_util.average_checkpoints(
        directory, output_dir, dict(model=model, optimizer=optimizer))
    avg_checkpoint = tf.train.latest_checkpoint(output_dir)
    self.assertIsNotNone(avg_checkpoint)
    # The averaged checkpoint is saved under the last averaged step number.
    self.assertEqual(
        checkpoint_util.get_step_from_checkpoint_prefix(avg_checkpoint),
        steps[-1])
    checkpoint.restore(avg_checkpoint)
    self.assertTrue(_all_equal(model.layers[0].kernel, avg_value))
    self.assertTrue(_all_equal(model.layers[0].bias, avg_value))
    # The averaged checkpoint exposes the same variables as a regular one.
    self.assertListEqual(
        _get_var_list(avg_checkpoint),
        _get_var_list(checkpoint_manager.latest_checkpoint))
def average_checkpoints(self, output_dir, max_count=8):
    """Averages checkpoints.

    Args:
      output_dir: The directory that will contain the averaged checkpoint.
      max_count: The maximum number of checkpoints to average.

    Returns:
      The path to the directory containing the averaged checkpoint.
    """
    config = self._finalize_config()
    model = self._init_model(config)
    optimizer = model.get_optimizer()
    # Restore the latest checkpoint and materialize all variables (including
    # optimizer slots) so every variable to be averaged exists in memory.
    checkpoint = checkpoint_util.Checkpoint.from_config(
        config, model, optimizer=optimizer)
    checkpoint.restore()
    model.create_variables(optimizer=optimizer)
    output_dir = checkpoint_util.average_checkpoints(
        checkpoint.model_dir,
        output_dir,
        dict(model=model, optimizer=optimizer),
        max_count=max_count)
    # Copy the model description next to the averaged checkpoint and point
    # this runner's configuration at the new directory.
    _forward_model_description(self.model_dir, output_dir)
    self._config["model_dir"] = output_dir
    return output_dir
def main():
    """Command line entry point for averaging the checkpoints of a model directory."""
    tf.logging.set_verbosity(tf.logging.INFO)
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--model_dir",
        required=True,
        help="The model directory containing the checkpoints.")
    parser.add_argument(
        "--output_dir",
        required=True,
        help="The output directory where the averaged checkpoint will be saved.")
    parser.add_argument(
        "--max_count",
        type=int,
        default=8,
        help="The maximal number of checkpoints to average.")
    args = parser.parse_args()
    average_checkpoints(
        args.model_dir, args.output_dir, max_count=args.max_count)
def average_checkpoints(self, output_dir, max_count=8):
    """Averages checkpoints.

    Args:
      output_dir: The directory that will contain the averaged checkpoint.
      max_count: The maximum number of checkpoints to average.

    Returns:
      The path to the directory containing the averaged checkpoint.
    """
    # Delegate to the checkpoint utilities, averaging the checkpoints found
    # in the configured model directory.
    model_dir = self._config["model_dir"]
    return checkpoint.average_checkpoints(
        model_dir,
        output_dir,
        max_count=max_count,
        session_config=self._session_config)
def average_checkpoints(self, output_dir, max_count=8):
    """Averages checkpoints.

    Args:
      output_dir: The directory that will contain the averaged checkpoint.
      max_count: The maximum number of checkpoints to average.

    Returns:
      The path to the directory containing the averaged checkpoint.
    """
    checkpoint, _ = self._init_run()
    # Restore and materialize all variables (including optimizer slots) so
    # that every variable to be averaged exists in memory.
    checkpoint.restore()
    model = checkpoint.model
    optimizer = checkpoint.optimizer
    model.create_variables(optimizer=optimizer)
    return checkpoint_util.average_checkpoints(
        checkpoint.model_dir,
        output_dir,
        dict(model=model, optimizer=optimizer),
        max_count=max_count)