def restore_params_from_checkpoint(checkpoint_dir, shard_index, shard_num):
    """Restore one model shard's parameters from a checkpoint directory.

    With shard_num=1, the entire set of model parameters is restored.

    Args:
        checkpoint_dir: Directory containing the checkpoint shard files.
        shard_index: Index of the model shard to restore, e.g. the PS
            instance index when using ParameterServerStrategy with
            multiple PS instances.
        shard_num: Total number of model shards, e.g. the total PS
            instance count when using ParameterServerStrategy with
            multiple PS instances.

    Returns:
        A Parameters object holding the model version, the non-embedding
        parameters and the embedding tables for this shard.

    Raises:
        ValueError: If the checkpoint shard files carry inconsistent
            model versions.
    """
    dense_params = {}
    tables = {}
    checkpoint_version = None

    for file_name in os.listdir(checkpoint_dir):
        model_pb = elasticdl_pb2.Model()
        model_pb = load_pb_from_file(
            model_pb, os.path.join(checkpoint_dir, file_name)
        )

        # Every shard file must agree on the model version.
        if checkpoint_version is None:
            checkpoint_version = model_pb.version
        elif checkpoint_version != model_pb.version:
            raise ValueError(
                "The versions in model shards are not consistent"
            )

        # Register each embedding table once, keyed by its name.
        for info_pb in model_pb.embedding_table_infos:
            table = create_embedding_table(info_pb)
            tables.setdefault(table.name, table)

        shard_dense, shard_embeddings = _get_params_shard_from_pb(
            model_pb, shard_index, shard_num
        )
        dense_params.update(shard_dense)
        # Each value is an (ids, embedding vectors) pair for the table.
        for table_name, id_values in shard_embeddings.items():
            tables[table_name].set(id_values[0], id_values[1])

    params = Parameters()
    params.non_embedding_params.update(dense_params)
    params.embedding_params.update(tables)
    params.version = checkpoint_version
    return params
def test_create_embedding_table(self):
    """create_embedding_table copies name, dim and initializer from the
    EmbeddingTableInfo protobuf into the returned table."""
    info = EmbeddingTableInfo()
    info.name = self.name
    info.dim = self.dim
    info.initializer = self.initializer

    result = create_embedding_table(info)

    self.assertIsNotNone(result)
    self.assertEqual(self.name, result.name)
    self.assertEqual(self.dim, result.dim)
    # The initializer string must resolve to the same Keras initializer class.
    expected_initializer = tf.keras.initializers.get(self.initializer)
    self.assertEqual(
        expected_initializer.__class__, result.initializer.__class__
    )
def init_embedding_params(self, embeddings_pb):
    """Create an embedding table for each protobuf info whose name is
    not yet registered in self.embedding_params; existing tables are
    left untouched."""
    for info_pb in embeddings_pb:
        name = info_pb.name
        if name in self.embedding_params:
            continue
        self.embedding_params[name] = create_embedding_table(info_pb)