# Example #1
# 0
def save_checkpoint_without_embedding(model, checkpoint_dir, version=100):
    """Persist a model's non-embedding weights as a checkpoint.

    Every trainable variable of ``model`` is recorded as a non-embedding
    parameter; embedding tables are intentionally left out.

    Args:
        model: A model exposing ``trainable_variables``.
        checkpoint_dir: Directory in which to write the checkpoint.
        version: Model version recorded in the checkpoint (default 100).
    """
    saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
    parameters = Parameters()
    parameters.non_embedding_params.update(
        {variable.name: variable for variable in model.trainable_variables}
    )
    parameters.version = version
    saver.save(version, parameters.to_model_pb(), False)
# Example #2
# 0
    def restore_params_from_checkpoint(checkpoint_dir, shard_index, shard_num):
        """Restore one shard of model parameters from a checkpoint directory.

        When shard_num == 1, the entire model's parameters are restored.

        Args:
            checkpoint_dir: Directory containing the checkpoint shard files.
            shard_index: Index of the model shard to restore, e.g. the PS
                instance index when using ParameterServerStrategy with
                multiple PS instances.
            shard_num: Total number of model shards, e.g. the number of PS
                instances when using ParameterServerStrategy.

        Return:
            parameters: A Parameters object holding the model version, the
                shard's non-embedding parameters and its embedding tables.
        """

        dense_params = {}
        tables = {}
        model_version = None

        for file_name in os.listdir(checkpoint_dir):
            pb = elasticdl_pb2.Model()
            pb = load_pb_from_file(
                pb, os.path.join(checkpoint_dir, file_name)
            )

            # Every shard file must carry the same model version.
            if model_version is None:
                model_version = pb.version
            elif model_version != pb.version:
                raise ValueError(
                    "The versions in model shards are not consistent"
                )

            # Register each embedding table once, keeping the first
            # definition encountered across shard files.
            for info_pb in pb.embedding_table_infos:
                table = create_embedding_table(info_pb)
                if table.name not in tables:
                    tables[table.name] = table

            shard_dense, shard_embeddings = _get_params_shard_from_pb(
                pb, shard_index, shard_num
            )
            dense_params.update(shard_dense)
            # Each value is an (ids, vectors) pair for one embedding table.
            for table_name, id_values in shard_embeddings.items():
                tables[table_name].set(id_values[0], id_values[1])

        parameters = Parameters()
        parameters.non_embedding_params.update(dense_params)
        parameters.embedding_params.update(tables)
        parameters.version = model_version
        return parameters
 def _mock_model_parameters(self, model):
     """Build deterministic fake Parameters for ``model``.

     Trainable variables whose name contains "embedding" become
     EmbeddingTable entries whose vectors are all ones; every other
     trainable variable becomes an all-ones tensor. The returned
     Parameters object is stamped with version 100.
     """
     mocked = Parameters()
     for variable in model.trainable_variables:
         if "embedding" not in variable.name:
             mocked.non_embedding_params[variable.name] = tf.ones(
                 variable.shape)
         else:
             table = EmbeddingTable(
                 name=variable.name,
                 dim=variable.shape[1],
                 initializer="RandomUniform",
             )
             # Ids 0..rows-1 all map to all-ones vectors.
             table.set(np.arange(variable.shape[0]),
                       np.ones(variable.shape))
             mocked.embedding_params[variable.name] = table
     mocked.version = 100
     return mocked
# Example #4
# 0
    def test_restore_parameters_from_checkpoint(self):
        """Save a checkpoint, then restore it across two PS instances.

        Verifies that with num_ps_pods=2 each parameter server receives
        only its own shard of embedding vectors and dense variables, and
        that the model version round-trips through the checkpoint.
        """
        # Build a small model: a 4x2 all-ones embedding table plus a dense
        # kernel and bias, and save it as checkpoint version 100.
        checkpoint_dir = "elasticdl/python/tests/testdata/ps_ckpt"
        checkpoint_saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
        params = Parameters()
        table = EmbeddingTable("embedding", 2, "random_uniform")
        table.set([0, 1, 2, 3], np.ones((4, 2), dtype=np.float32))
        params.embedding_params["embedding"] = table
        params.non_embedding_params["dense/kernel:0"] = tf.Variable(
            [[1.0], [1.0]]
        )
        params.non_embedding_params["dense/bias:0"] = tf.Variable([1.0])
        params.version = 100
        model_pb = params.to_model_pb()
        checkpoint_saver.save(100, model_pb, False)

        # Start PS instance 0 of 2, initialized from the saved checkpoint.
        checkpoint_dir_for_init = checkpoint_dir + "/version-100"
        args = PserverArgs(
            ps_id=0,
            num_ps_pods=2,
            model_zoo=_test_model_zoo_path,
            model_def="test_module.custom_model",
            checkpoint_dir_for_init=checkpoint_dir_for_init,
        )
        pserver_0 = ParameterServer(args)

        # PS 0 should hold the even embedding ids and the dense kernel.
        embedding_table = pserver_0.parameters.embedding_params["embedding"]
        self.assertEqual(
            list(embedding_table.embedding_vectors.keys()), [0, 2]
        )
        self.assertEqual(
            list(pserver_0.parameters.non_embedding_params.keys()),
            ["dense/kernel:0"],
        )
        self.assertTrue(
            np.array_equal(
                pserver_0.parameters.non_embedding_params[
                    "dense/kernel:0"
                ].numpy(),
                np.array([[1], [1]], dtype=int),
            )
        )
        self.assertEqual(pserver_0.parameters.version, 100)

        # Start PS instance 1 of 2 from the same checkpoint.
        args = PserverArgs(
            ps_id=1,
            num_ps_pods=2,
            model_zoo=_test_model_zoo_path,
            model_def="test_module.custom_model",
            checkpoint_dir_for_init=checkpoint_dir_for_init,
        )
        pserver_1 = ParameterServer(args)

        # PS 1 should hold the odd embedding ids and the dense bias.
        embedding_table = pserver_1.parameters.embedding_params["embedding"]
        self.assertEqual(
            list(embedding_table.embedding_vectors.keys()), [1, 3]
        )
        self.assertEqual(
            list(pserver_1.parameters.non_embedding_params.keys()),
            ["dense/bias:0"],
        )
        self.assertTrue(
            np.array_equal(
                pserver_1.parameters.non_embedding_params[
                    "dense/bias:0"
                ].numpy(),
                np.array([1], dtype=int),
            )
        )
        self.assertEqual(pserver_1.parameters.version, 100)