def save_variables_to_checkpoint(root_dir, params):
    """Save ``params`` as a checkpoint under ``root_dir`` and return its folder.

    A ``testSaveLoadCheckpoint`` sub-directory is created inside ``root_dir``
    (``os.makedirs`` is called without ``exist_ok``, so an already existing
    folder raises), and the protobuf built by ``params.to_model_pb()`` is
    handed to a ``CheckpointSaver`` under ``params.version``.

    Args:
        root_dir: Directory under which the checkpoint folder is created.
        params: Object exposing ``version`` and ``to_model_pb()``.

    Returns:
        The path of the created ``testSaveLoadCheckpoint`` directory.
    """
    target_dir = os.path.join(root_dir, "testSaveLoadCheckpoint")
    os.makedirs(target_dir)
    # NOTE(review): the positional 3, 5, False are saver settings whose
    # meaning is defined by CheckpointSaver and is not visible here.
    saver = CheckpointSaver(target_dir, 3, 5, False)
    serialized_model = params.to_model_pb()
    saver.save(params.version, serialized_model, False)
    return target_dir
def testSaveLoadCheckpoint(self):
    """Round-trip a model's trainable variables through a checkpoint.

    The variables of ``m["custom_model"]`` are stored as non-embedding
    parameters, saved as version 0 inside a temporary directory, restored
    from the ``version-0`` folder, and compared with the originals (version
    number and every variable's value).
    """
    init_var = m["custom_model"]().trainable_variables
    with tempfile.TemporaryDirectory() as workdir:
        save_root = os.path.join(workdir, "testSaveLoadCheckpoint")
        os.makedirs(save_root)
        saver = CheckpointSaver(save_root, 3, 5, False)
        self.assertTrue(saver.is_enabled())

        original = Parameters()
        for variable in init_var:
            original.non_embedding_params[variable.name] = variable
        saver.save(0, original.to_model_pb(), False)

        # NOTE(review): the trailing 0, 1 presumably select shard 0 of 1;
        # confirm against CheckpointSaver.restore_params_from_checkpoint.
        restored = CheckpointSaver.restore_params_from_checkpoint(
            os.path.join(save_root, "version-0"), 0, 1
        )

        self.assertEqual(restored.version, original.version)
        for name in original.non_embedding_params:
            expected = original.non_embedding_params[name].numpy()
            actual = restored.non_embedding_params[name].numpy()
            self.assertTrue(np.array_equal(expected, actual))
def save_checkpoint_without_embedding(model, checkpoint_dir, version=100):
    """Checkpoint only the trainable variables of ``model``.

    Every entry of ``model.trainable_variables`` is registered as a
    non-embedding parameter keyed by its ``name``; no embedding parameters
    are added. The result is saved through a ``CheckpointSaver`` rooted at
    ``checkpoint_dir``.

    Args:
        model: Object exposing ``trainable_variables`` (items with ``name``).
        checkpoint_dir: Directory passed to ``CheckpointSaver``.
        version: Version stored on the parameters and used for the save call.
    """
    saver = CheckpointSaver(checkpoint_dir, 0, 0, False)

    snapshot = Parameters()
    for variable in model.trainable_variables:
        snapshot.non_embedding_params[variable.name] = variable
    snapshot.version = version

    serialized_model = snapshot.to_model_pb()
    saver.save(version, serialized_model, False)
def test_restore_parameters_from_checkpoint(self):
    """Check that each parameter server restores its own share of a checkpoint.

    A checkpoint at version 100 is written containing one embedding table
    (ids 0-3, all-ones vectors) and two dense variables. Two parameter
    servers (``num_ps_pods=2``) are then initialised from it and are expected
    to hold:

    * ps 0: embedding ids ``[0, 2]`` and ``dense/kernel:0``
    * ps 1: embedding ids ``[1, 3]`` and ``dense/bias:0``

    The checkpoint is written to a temporary directory rather than a
    hard-coded path relative to the working directory, so the test does not
    depend on where it is launched from and leaves no files behind.
    """
    # Local import: the file-level import block is not visible from this
    # block, so the dependency is brought into scope here.
    import tempfile

    params = Parameters()
    table = EmbeddingTable("embedding", 2, "random_uniform")
    table.set([0, 1, 2, 3], np.ones((4, 2), dtype=np.float32))
    params.embedding_params["embedding"] = table
    params.non_embedding_params["dense/kernel:0"] = tf.Variable(
        [[1.0], [1.0]]
    )
    params.non_embedding_params["dense/bias:0"] = tf.Variable([1.0])
    params.version = 100
    model_pb = params.to_model_pb()

    # (ps_id, expected embedding ids, expected dense variable, its value)
    expectations = [
        (0, [0, 2], "dense/kernel:0", np.array([[1], [1]], dtype=int)),
        (1, [1, 3], "dense/bias:0", np.array([1], dtype=int)),
    ]

    with tempfile.TemporaryDirectory() as checkpoint_dir:
        checkpoint_saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
        checkpoint_saver.save(100, model_pb, False)
        checkpoint_dir_for_init = checkpoint_dir + "/version-100"

        # All assertions run inside the ``with`` block so the checkpoint
        # still exists while each ParameterServer is constructed.
        for ps_id, embedding_ids, var_name, var_value in expectations:
            args = PserverArgs(
                ps_id=ps_id,
                num_ps_pods=2,
                model_zoo=_test_model_zoo_path,
                model_def="test_module.custom_model",
                checkpoint_dir_for_init=checkpoint_dir_for_init,
            )
            pserver = ParameterServer(args)

            embedding_table = pserver.parameters.embedding_params[
                "embedding"
            ]
            self.assertEqual(
                list(embedding_table.embedding_vectors.keys()),
                embedding_ids,
            )
            self.assertEqual(
                list(pserver.parameters.non_embedding_params.keys()),
                [var_name],
            )
            self.assertTrue(
                np.array_equal(
                    pserver.parameters.non_embedding_params[
                        var_name
                    ].numpy(),
                    var_value,
                )
            )
            self.assertEqual(pserver.parameters.version, 100)
def _mock_model_weights_and_save_checkpoint(self, model):
    """Save mocked parameters for ``model`` as checkpoint version 100.

    The parameters come from ``self._mock_model_parameters(model)`` and are
    written through a ``CheckpointSaver`` rooted at the model handler's
    ``_checkpoint_dir``.

    Args:
        model: Model forwarded to ``self._mock_model_parameters``.
    """
    saver = CheckpointSaver(
        self.model_handler._checkpoint_dir, 0, 0, False
    )
    mocked_params = self._mock_model_parameters(model)
    serialized_model = mocked_params.to_model_pb()
    saver.save(100, serialized_model, False)