def testSaveLoadCheckpoint(self):
    init_var = m["custom_model"]().trainable_variables
    with tempfile.TemporaryDirectory() as tempdir:
        ckpt_dir = os.path.join(tempdir, "testSaveLoadCheckpoint")
        os.makedirs(ckpt_dir)
        checkpoint_saver = CheckpointSaver(ckpt_dir, 3, 5, False)
        self.assertTrue(checkpoint_saver.is_enabled())

        # Collect the model's trainable variables into a Parameters
        # object and save them as a checkpoint at version 0.
        params = Parameters()
        for var in init_var:
            params.non_embedding_params[var.name] = var
        model_pb = params.to_model_pb()
        checkpoint_saver.save(0, model_pb, False)

        # Restore from the "version-0" sub-directory and verify that
        # every variable round-trips unchanged.
        ckpt_version_dir = os.path.join(ckpt_dir, "version-0")
        restore_params = CheckpointSaver.restore_params_from_checkpoint(
            ckpt_version_dir, 0, 1
        )
        self.assertEqual(restore_params.version, params.version)
        for var_name in params.non_embedding_params:
            self.assertTrue(
                np.array_equal(
                    params.non_embedding_params[var_name].numpy(),
                    restore_params.non_embedding_params[var_name].numpy(),
                )
            )

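# A minimal sketch (not part of the test above) of how the last two
# arguments to restore_params_from_checkpoint are used: they are the
# shard index and shard count, matching the (ps_id, num_ps_pods) pair
# passed by the parameter server further below. restore_all_shards is
# a hypothetical helper; it assumes each shard holds its own subset of
# the dense (non-embedding) parameters.
def restore_all_shards(ckpt_version_dir, num_ps_pods):
    """Restore every PS shard of a checkpoint and merge the dense params."""
    merged = {}
    for ps_id in range(num_ps_pods):
        shard = CheckpointSaver.restore_params_from_checkpoint(
            ckpt_version_dir, ps_id, num_ps_pods
        )
        merged.update(shard.non_embedding_params)
    return merged
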
def _get_trained_params_from_checkpoint(checkpoint_dir):
    """Get parameters from a checkpoint directory saved by ElasticDL"""
    parameters = CheckpointSaver.restore_params_from_checkpoint(
        checkpoint_dir, 0, 1
    )
    trained_params = parameters.non_embedding_params
    # Merge the embedding tables in under their raw table names.
    for name, table in parameters.embedding_params.items():
        trained_params[name] = table
    return trained_params

def _get_trained_params_from_checkpoint(checkpoint_dir):
    """Get parameters from a checkpoint directory saved by ElasticDL"""
    parameters = CheckpointSaver.restore_params_from_checkpoint(
        checkpoint_dir, 0, 1
    )
    trained_params = parameters.non_embedding_params
    for name, table in parameters.embedding_params.items():
        # The name of a variable in a tf.keras.layers.Embedding layer is
        # "{layer_name}/embeddings:0", so key the table under that name.
        var_name = name + "/embeddings:0"
        trained_params[var_name] = table
    return trained_params

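# A hedged usage sketch: assigning the restored dense parameters back
# into a freshly built Keras model. build_model() is a hypothetical
# stand-in for however the model is constructed; the sketch assumes the
# checkpoint came from the same model definition, so variable names
# match. Embedding tables may need separate handling, depending on how
# the embedding table type exposes its values.
def load_dense_params_into_model(checkpoint_dir):
    model = build_model()  # hypothetical model factory
    trained_params = _get_trained_params_from_checkpoint(checkpoint_dir)
    for var in model.trainable_variables:
        if var.name in trained_params:
            var.assign(trained_params[var.name].numpy())
    return model
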
def testSaveLoadCheckpoint(self):
    with tempfile.TemporaryDirectory() as tempdir:
        self.params.version = 0
        ckpt_dir = save_variables_to_checkpoint(tempdir, self.params)
        ckpt_version_dir = os.path.join(ckpt_dir, "version-0")
        restore_params = CheckpointSaver.restore_params_from_checkpoint(
            ckpt_version_dir, 0, 1
        )
        self.assertEqual(restore_params.version, self.params.version)
        for var_name in self.params.non_embedding_params:
            self.assertTrue(
                np.array_equal(
                    self.params.non_embedding_params[var_name].numpy(),
                    restore_params.non_embedding_params[var_name].numpy(),
                )
            )

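# save_variables_to_checkpoint is referenced above but not defined in
# this section. A plausible sketch, assuming it simply wraps the
# CheckpointSaver flow from the first test (the positional arguments
# 3, 5, False are copied from that test, not confirmed for this one):
def save_variables_to_checkpoint(tempdir, params):
    ckpt_dir = os.path.join(tempdir, "testSaveLoadCheckpoint")
    os.makedirs(ckpt_dir)
    saver = CheckpointSaver(ckpt_dir, 3, 5, False)
    saver.save(params.version, params.to_model_pb(), False)
    return ckpt_dir
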
def _restore_params_from_checkpoint(self, checkpoint_dir_for_init):
    """Restore parameters from a checkpoint directory for the PS instance."""
    if not checkpoint_dir_for_init:
        self.logger.info("checkpoint directory for init is None")
        return

    if not CheckpointSaver.check_checkpoint_valid(checkpoint_dir_for_init):
        raise ValueError("Invalid checkpoint directory")

    # Each PS pod restores only its own shard of the parameters.
    self.parameters = CheckpointSaver.restore_params_from_checkpoint(
        checkpoint_dir_for_init, self.ps_id, self.num_ps_pods
    )
    self.parameters.init_status = True
    self.logger.info(
        "The version of restored parameters is %d" % self.parameters.version
    )