Example #1
    def restore_params_from_checkpoint(checkpoint_dir, shard_index, shard_num):
        """Restore a shard parameters from the checkpoint directory.
        If shard_num=1, a entire model parameters will be restored.

        Args:
            checkpoint_dir: a directory with checkpoint files.
            shard_index: The model shard index, e.g. the PS instance index
                when using ParameterServerStrategy with multiple PS instances.
            shard_num: The total number of model shards, e.g. the total PS
                instance count when using ParameterServerStrategy with
                multiple PS instances.

        Returns:
            parameters: A Parameters object containing the model version,
                non-embedding parameters and embedding tables for the
                PS instance with the given shard_index.
        """

        variable_shard_files = os.listdir(checkpoint_dir)
        non_embedding_vars = {}
        embedding_tables = {}
        version = None
        # Each file in the checkpoint directory holds one serialized Model
        # protobuf; merge all of them into a single Parameters object.
        for shard_file in variable_shard_files:
            shard_file_path = os.path.join(checkpoint_dir, shard_file)
            model_pb = elasticdl_pb2.Model()
            model_pb = load_pb_from_file(model_pb, shard_file_path)
            if version is None:
                version = model_pb.version
            elif version != model_pb.version:
                raise ValueError(
                    "The versions in model shards are not consistent"
                )

            # Register embedding table metadata the first time a name is
            # seen; setdefault keeps the entry from the earliest shard file.
            for embedding_info_pb in model_pb.embedding_table_infos:
                embedding_table = create_embedding_table(embedding_info_pb)
                embedding_tables.setdefault(
                    embedding_table.name, embedding_table
                )

            # Keep only the slice of parameters that belongs to this shard.
            (
                shard_non_embedding_vars,
                shard_embedding_table_values,
            ) = _get_params_shard_from_pb(model_pb, shard_index, shard_num)

            non_embedding_vars.update(shard_non_embedding_vars)
            for name, pair in shard_embedding_table_values.items():
                embedding_tables[name].set(pair[0], pair[1])

        parameters = Parameters()
        parameters.non_embedding_params.update(non_embedding_vars)
        parameters.embedding_params.update(embedding_tables)
        parameters.version = version
        return parameters
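
The helper above is typically called once per parameter server shard. A minimal usage sketch, assuming two PS instances and an illustrative checkpoint path (both values are assumptions, not from the source):

    # Hypothetical call site: restore the parameter slice owned by each of
    # two PS shards from the same checkpoint directory.
    checkpoint_dir = "/tmp/elasticdl_checkpoint"  # assumed path
    shard_num = 2  # assumed total number of PS instances
    for shard_index in range(shard_num):
        params = restore_params_from_checkpoint(
            checkpoint_dir, shard_index, shard_num
        )
        print(shard_index, params.version, len(params.non_embedding_params))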
Example #2
 def test_create_embedding_table(self):
     embedding_pb = EmbeddingTableInfo()
     embedding_pb.name = self.name
     embedding_pb.dim = self.dim
     embedding_pb.initializer = self.initializer
     table = create_embedding_table(embedding_pb)
     self.assertIsNotNone(table)
     self.assertEqual(table.name, self.name)
     self.assertEqual(
         tf.keras.initializers.get(self.initializer).__class__,
         table.initializer.__class__,
     )
     self.assertEqual(table.dim, self.dim)
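
The test above reads self.name, self.dim and self.initializer from the test case, which are not shown here. A minimal setUp sketch with illustrative fixture values (the concrete values are assumptions, not from the source):

 def setUp(self):
     # Illustrative fixture: any valid Keras initializer alias works here.
     self.name = "item_embedding"
     self.dim = 8
     self.initializer = "uniform"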
Example #3
 def init_embedding_params(self, embeddings_pb):
     """Create an embedding table for every EmbeddingTableInfo protobuf
     whose name has not been registered yet."""
     for pb in embeddings_pb:
         if pb.name not in self.embedding_params:
             self.embedding_params[pb.name] = create_embedding_table(pb)
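
Because of the membership check, init_embedding_params only creates a table for names that are not yet registered; repeated calls with the same protobuf leave the existing table untouched. A minimal usage sketch, assuming the method belongs to the Parameters class from Example #1 and reusing the EmbeddingTableInfo fields shown in Example #2 (the field values are illustrative):

 # Hypothetical usage: register one embedding table, then call again
 # with the same name to show that the existing entry is kept.
 info = EmbeddingTableInfo()
 info.name = "user_embedding"  # assumed name
 info.dim = 16                 # assumed dimension
 info.initializer = "uniform"  # assumed initializer
 params = Parameters()
 params.init_embedding_params([info])
 params.init_embedding_params([info])  # no-op for "user_embedding"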