def save_checkpoint_without_embedding(model, checkpoint_dir, version=100):
    """Save the non-embedding trainable variables of a model into a
    checkpoint under checkpoint_dir."""
    checkpoint_saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
    params = Parameters()
    for var in model.trainable_variables:
        params.non_embedding_params[var.name] = var
    params.version = version
    model_pb = params.to_model_pb()
    checkpoint_saver.save(version, model_pb, False)
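# A minimal usage sketch (illustrative, not part of the original code):
# save the dense weights of a small tf.keras model. The helper name and
# its directory argument are hypothetical.
def _demo_save_checkpoint(checkpoint_dir):
    import tensorflow as tf

    # A toy model with no embedding layers, so every trainable variable
    # (e.g. "dense/kernel:0") goes into non_embedding_params.
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(2,))]
    )
    save_checkpoint_without_embedding(model, checkpoint_dir, version=1)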
def restore_params_from_checkpoint(checkpoint_dir, shard_index, shard_num):
    """Restore the parameters of one model shard from the checkpoint
    directory. If shard_num=1, the entire model parameters will be
    restored.

    Args:
        checkpoint_dir: A directory with checkpoint files.
        shard_index: The model shard index, e.g. the PS instance index
            when using ParameterServerStrategy with multiple PS
            instances.
        shard_num: The total number of model shards, e.g. the total PS
            instance count when using ParameterServerStrategy with
            multiple PS instances.

    Returns:
        parameters: A Parameters object which contains the model version,
            the non-embedding parameters and the embedding tables for
            the PS instance with shard_index.
    """
    variable_shard_files = os.listdir(checkpoint_dir)
    non_embedding_vars = {}
    embedding_tables = {}
    version = None
    for shard_file in variable_shard_files:
        shard_file_path = os.path.join(checkpoint_dir, shard_file)
        model_pb = elasticdl_pb2.Model()
        model_pb = load_pb_from_file(model_pb, shard_file_path)
        # All shard files must come from the same model version.
        if version is None:
            version = model_pb.version
        elif version != model_pb.version:
            raise ValueError(
                "The versions in model shards are not consistent"
            )
        for embedding_info_pb in model_pb.embedding_table_infos:
            embedding_table = create_embedding_table(embedding_info_pb)
            embedding_tables.setdefault(
                embedding_table.name, embedding_table
            )
        (
            shard_non_embedding_vars,
            shard_embedding_table_values,
        ) = _get_params_shard_from_pb(model_pb, shard_index, shard_num)
        non_embedding_vars.update(shard_non_embedding_vars)
        for name, pair in shard_embedding_table_values.items():
            embedding_tables[name].set(pair[0], pair[1])

    parameters = Parameters()
    parameters.non_embedding_params.update(non_embedding_vars)
    parameters.embedding_params.update(embedding_tables)
    parameters.version = version
    return parameters
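# A usage sketch for restoring (hypothetical helper): with shard_num=1,
# the whole model lands on one shard. CheckpointSaver writes each
# version into a "version-<N>" subdirectory, as the test below shows
# with checkpoint_dir + "/version-100", so that subdirectory is what
# gets listed here.
def _demo_restore_full_model(checkpoint_dir):
    params = restore_params_from_checkpoint(
        os.path.join(checkpoint_dir, "version-1"),
        shard_index=0,
        shard_num=1,
    )
    return params.non_embedding_params, params.version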
def _mock_model_parameters(self, model):
    params = Parameters()
    for weight in model.trainable_variables:
        if "embedding" in weight.name:
            embedding_table = EmbeddingTable(
                name=weight.name,
                dim=weight.shape[1],
                initializer="RandomUniform",
            )
            embedding_table.set(
                np.arange(weight.shape[0]), np.ones(weight.shape)
            )
            params.embedding_params[weight.name] = embedding_table
        else:
            params.non_embedding_params[weight.name] = tf.ones(weight.shape)
    params.version = 100
    return params
def test_restore_parameters_from_checkpoint(self):
    checkpoint_dir = "elasticdl/python/tests/testdata/ps_ckpt"
    checkpoint_saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
    params = Parameters()
    table = EmbeddingTable("embedding", 2, "random_uniform")
    table.set([0, 1, 2, 3], np.ones((4, 2), dtype=np.float32))
    params.embedding_params["embedding"] = table
    params.non_embedding_params["dense/kernel:0"] = tf.Variable(
        [[1.0], [1.0]]
    )
    params.non_embedding_params["dense/bias:0"] = tf.Variable([1.0])
    params.version = 100
    model_pb = params.to_model_pb()
    checkpoint_saver.save(100, model_pb, False)

    checkpoint_dir_for_init = checkpoint_dir + "/version-100"
    args = PserverArgs(
        ps_id=0,
        num_ps_pods=2,
        model_zoo=_test_model_zoo_path,
        model_def="test_module.custom_model",
        checkpoint_dir_for_init=checkpoint_dir_for_init,
    )
    pserver_0 = ParameterServer(args)
    embedding_table = pserver_0.parameters.embedding_params["embedding"]
    self.assertEqual(
        list(embedding_table.embedding_vectors.keys()), [0, 2]
    )
    self.assertEqual(
        list(pserver_0.parameters.non_embedding_params.keys()),
        ["dense/kernel:0"],
    )
    self.assertTrue(
        np.array_equal(
            pserver_0.parameters.non_embedding_params[
                "dense/kernel:0"
            ].numpy(),
            np.array([[1], [1]], dtype=int),
        )
    )
    self.assertEqual(pserver_0.parameters.version, 100)

    args = PserverArgs(
        ps_id=1,
        num_ps_pods=2,
        model_zoo=_test_model_zoo_path,
        model_def="test_module.custom_model",
        checkpoint_dir_for_init=checkpoint_dir_for_init,
    )
    pserver_1 = ParameterServer(args)
    embedding_table = pserver_1.parameters.embedding_params["embedding"]
    self.assertEqual(
        list(embedding_table.embedding_vectors.keys()), [1, 3]
    )
    self.assertEqual(
        list(pserver_1.parameters.non_embedding_params.keys()),
        ["dense/bias:0"],
    )
    self.assertTrue(
        np.array_equal(
            pserver_1.parameters.non_embedding_params[
                "dense/bias:0"
            ].numpy(),
            np.array([1], dtype=int),
        )
    )
    self.assertEqual(pserver_1.parameters.version, 100)
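# The expected keys in the test above are consistent with partitioning
# embedding ids by id % shard_num (an assumption about how
# _get_params_shard_from_pb splits shards, not verified here): ids 0 and
# 2 land on shard 0, ids 1 and 3 on shard 1.
def _expected_ids_for_shard(ids, shard_index, shard_num):
    # Hypothetical helper mirroring the assertions above.
    return [i for i in ids if i % shard_num == shard_index]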