コード例 #1
0
def save_variables_to_checkpoint(root_dir, params):
    """Save ``params`` as a checkpoint below ``root_dir``; return its dir.

    The checkpoint directory is ``<root_dir>/testSaveLoadCheckpoint``. It is
    written by a ``CheckpointSaver`` created with checkpoint_steps=3,
    keep_checkpoint_max=5 and no evaluation. The checkpoint version is
    ``params.version``.

    Raises:
        FileExistsError: if the checkpoint directory already exists, because
            ``os.makedirs`` is called without ``exist_ok``.
    """
    target_dir = os.path.join(root_dir, "testSaveLoadCheckpoint")
    os.makedirs(target_dir)
    saver = CheckpointSaver(target_dir, 3, 5, False)
    serialized = params.to_model_pb()
    saver.save(params.version, serialized, False)
    return target_dir
コード例 #2
0
ファイル: save_utils_test.py プロジェクト: zuodh/elasticdl
 def testGetCheckpointPath(self):
     """The checkpoint file for version 100 lives under ``version-100/``."""
     saver = CheckpointSaver("test/checkpoint_dir", 3, 5, False)
     expected_path = "test/checkpoint_dir/version-100/variables-0-of-1.ckpt"
     self.assertEqual(saver._get_checkpoint_file(100), expected_path)
コード例 #3
0
ファイル: test_utils.py プロジェクト: xinan-jiang/elasticdl
def save_checkpoint_without_embedding(model, checkpoint_dir, version=100):
    """Write ``model``'s trainable variables to a checkpoint.

    Each trainable variable is stored as a non-embedding parameter keyed by
    its ``name``. The checkpoint is saved under ``checkpoint_dir`` with the
    given ``version``, using a saver created with checkpoint_steps=0 and
    keep_checkpoint_max=0.
    """
    saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
    params = Parameters()
    for variable in model.trainable_variables:
        params.non_embedding_params[variable.name] = variable
    params.version = version
    saver.save(version, params.to_model_pb(), False)
コード例 #4
0
ファイル: task_manager.py プロジェクト: Terry1504/elasticdl
    def _set_completed_steps_by_checkpoint(self, checkpoint_dir_for_init):
        """Initialize ``self._completed_steps`` from a checkpoint's version.

        Does nothing when ``checkpoint_dir_for_init`` is falsy.

        Raises:
            ValueError: if ``CheckpointSaver.check_checkpoint_valid`` rejects
                ``checkpoint_dir_for_init``.
        """
        if not checkpoint_dir_for_init:
            return

        is_valid = CheckpointSaver.check_checkpoint_valid(
            checkpoint_dir_for_init
        )
        if not is_valid:
            message = "Invalid checkpoint directory {}".format(
                checkpoint_dir_for_init
            )
            raise ValueError(message)

        self._completed_steps = CheckpointSaver.get_version_from_checkpoint(
            checkpoint_dir_for_init
        )
コード例 #5
0
    def _set_completed_steps_by_checkpoint(self, checkpoint_dir_for_init):
        """Pass a checkpoint's version to every ``MaxStepsStopping`` callback.

        The version read from ``checkpoint_dir_for_init`` is handed to
        ``set_completed_steps`` on each ``MaxStepsStopping`` instance in
        ``self.callbacks_list.callbacks``. Does nothing when
        ``checkpoint_dir_for_init`` is falsy.

        Raises:
            ValueError: if ``CheckpointSaver.check_checkpoint_valid`` rejects
                ``checkpoint_dir_for_init``.
        """
        if not checkpoint_dir_for_init:
            return

        is_valid = CheckpointSaver.check_checkpoint_valid(
            checkpoint_dir_for_init
        )
        if not is_valid:
            message = "Invalid checkpoint directory {}".format(
                checkpoint_dir_for_init
            )
            raise ValueError(message)

        model_version = CheckpointSaver.get_version_from_checkpoint(
            checkpoint_dir_for_init
        )
        stoppers = (
            callback
            for callback in self.callbacks_list.callbacks
            if isinstance(callback, MaxStepsStopping)
        )
        for stopper in stoppers:
            stopper.set_completed_steps(model_version)
コード例 #6
0
    def _restore_params_from_checkpoint(self, checkpoint_dir_for_init):
        """Restore parameters from a checkpoint directory for this PS instance.

        ``self.ps_id`` and ``self.num_ps_pods`` are passed to
        ``CheckpointSaver.restore_params_from_checkpoint`` so it can select
        this instance's portion of the checkpoint. The result replaces
        ``self.parameters`` and is marked initialized. When
        ``checkpoint_dir_for_init`` is falsy, an info message is logged and
        nothing is restored.

        Raises:
            ValueError: if ``CheckpointSaver.check_checkpoint_valid`` rejects
                ``checkpoint_dir_for_init``.
        """
        if not checkpoint_dir_for_init:
            self.logger.info("checkpoint directory for init is None")
            return

        if not CheckpointSaver.check_checkpoint_valid(checkpoint_dir_for_init):
            # Include the offending path, as the other checkpoint validators
            # do, so a misconfigured directory is easy to diagnose.
            raise ValueError(
                "Invalid checkpoint directory {}".format(
                    checkpoint_dir_for_init
                )
            )

        self.parameters = CheckpointSaver.restore_params_from_checkpoint(
            checkpoint_dir_for_init, self.ps_id, self.num_ps_pods
        )
        self.parameters.init_status = True
        # Pass the value as a logging argument so formatting is deferred
        # until the record is actually emitted.
        self.logger.info(
            "The version of restored parameters is %d",
            self.parameters.version,
        )
コード例 #7
0
ファイル: save_utils_test.py プロジェクト: zuodh/elasticdl
    def testNeedToCheckpoint(self):
        """Once enabled, ``need_to_checkpoint`` is true every ``_steps``."""
        saver = CheckpointSaver("", 0, 5, False)
        # checkpoint_steps=0 leaves the saver disabled until _steps is set.
        self.assertFalse(saver.is_enabled())
        saver._steps = 3
        self.assertTrue(saver.is_enabled())

        expectations = [
            (1, False),
            (2, False),
            (3, True),
            (4, False),
            (5, False),
            (6, True),
        ]
        for version, should_save in expectations:
            check = self.assertTrue if should_save else self.assertFalse
            check(saver.need_to_checkpoint(version))
コード例 #8
0
    def get_model_to_export(self, model, dataset):
        """Return a Keras model that ``tf.saved_model.save`` can export.

        The model definition is restored with ``_restore_keras_model_def``.
        A model with no inputs yet is built on ``dataset`` so that tf-serving
        can consume its inputs and outputs.

        If ``self._checkpoint_dir`` holds a valid checkpoint, its trained
        values are assigned to the model's trainable variables. Otherwise a
        warning is logged and the untrained model is returned.
        """
        model = self._restore_keras_model_def(model)
        if not model.inputs:
            # Build the model so it has inputs and outputs that
            # tf-serving can consume.
            model._build_model_with_inputs(inputs=dataset, targets=None)

        latest_dir = CheckpointSaver.get_valid_lastest_version_dir(
            self._checkpoint_dir
        )
        if latest_dir is None:
            logger.warning("No available checkpoint to export model")
            return model

        trained_params = _get_trained_params_from_checkpoint(latest_dir)
        for variable in model.trainable_variables:
            param = trained_params[variable.name]
            if isinstance(param, EmbeddingTable):
                # Embedding tables are converted to a NumPy array shaped
                # like the target variable before assignment.
                value = _convert_embedding_table_to_numpy_array(
                    param, variable.shape
                )
            else:
                value = param.numpy()
            variable.assign(value)
        return model
コード例 #9
0
 def testGetVersionFromCheckpoint(self):
     """``get_version_from_checkpoint`` returns the version that was saved."""
     with tempfile.TemporaryDirectory() as tempdir:
         self.params.version = 100
         ckpt_dir = save_variables_to_checkpoint(tempdir, self.params)
         ckpt_version_dir = os.path.join(ckpt_dir, "version-100")
         model_version = CheckpointSaver.get_version_from_checkpoint(
             ckpt_version_dir)
         # assertTrue(x, 100) treats 100 as the failure message and passes
         # for any truthy version, so compare the values explicitly.
         self.assertEqual(model_version, 100)
コード例 #10
0
def _get_trained_params_from_checkpoint(checkpoint_dir):
    """Get parameters from a checkpoint directory saved by ElasticDL.

    The checkpoint is restored with ps_id=0 and num_ps_pods=1. The function
    returns a single dict holding both the non-embedding parameters and the
    embedding tables, each table keyed by its name in ``embedding_params``.
    """
    restored = CheckpointSaver.restore_params_from_checkpoint(
        checkpoint_dir, 0, 1
    )
    merged = restored.non_embedding_params
    merged.update(restored.embedding_params)
    return merged
コード例 #11
0
def _get_trained_params_from_checkpoint(checkpoint_dir):
    """Get parameters from a checkpoint directory saved by ElasticDL.

    The checkpoint is restored with ps_id=0 and num_ps_pods=1. The function
    returns one dict holding the non-embedding parameters plus every
    embedding table. Each table is re-keyed to the variable name used by a
    Keras Embedding layer.
    """
    restored = CheckpointSaver.restore_params_from_checkpoint(
        checkpoint_dir, 0, 1
    )
    merged = restored.non_embedding_params
    # A tf.keras.layers.Embedding layer names its variable
    # "{layer_name}/embeddings:0", so key each table the same way.
    merged.update(
        (table_name + "/embeddings:0", table)
        for table_name, table in restored.embedding_params.items()
    )
    return merged
コード例 #12
0
 def _init_checkpoint_saver(self, args):
     """Set ``self.checkpoint_saver`` from the checkpoint options in ``args``.

     A saver is created only when both ``args.checkpoint_dir`` and
     ``args.checkpoint_steps`` are truthy. Otherwise the attribute is set to
     None and a warning is logged.
     """
     if not all((args.checkpoint_dir, args.checkpoint_steps)):
         self.checkpoint_saver = None
         self.logger.warning(
             "Invalid checkpoint config and no model will be saved")
         return

     self.checkpoint_saver = CheckpointSaver(
         args.checkpoint_dir,
         args.checkpoint_steps,
         args.keep_checkpoint_max,
         include_evaluation=False,
     )
コード例 #13
0
ファイル: save_utils_test.py プロジェクト: zuodh/elasticdl
    def testSaveLoadCheckpoint(self):
        """Variables saved at version 0 are restored with identical values."""
        init_var = m["custom_model"]().trainable_variables
        with tempfile.TemporaryDirectory() as tempdir:
            ckpt_dir = os.path.join(tempdir, "testSaveLoadCheckpoint")
            os.makedirs(ckpt_dir)
            saver = CheckpointSaver(ckpt_dir, 3, 5, False)
            self.assertTrue(saver.is_enabled())

            params = Parameters()
            for variable in init_var:
                params.non_embedding_params[variable.name] = variable
            saver.save(0, params.to_model_pb(), False)

            restored = CheckpointSaver.restore_params_from_checkpoint(
                os.path.join(ckpt_dir, "version-0"), 0, 1
            )
            self.assertEqual(restored.version, params.version)
            for name, original in params.non_embedding_params.items():
                expected = original.numpy()
                actual = restored.non_embedding_params[name].numpy()
                self.assertTrue(np.array_equal(expected, actual))
コード例 #14
0
 def testSaveLoadCheckpoint(self):
     """Parameters saved at version 0 round-trip through a checkpoint."""
     with tempfile.TemporaryDirectory() as tempdir:
         self.params.version = 0
         ckpt_dir = save_variables_to_checkpoint(tempdir, self.params)
         restored = CheckpointSaver.restore_params_from_checkpoint(
             os.path.join(ckpt_dir, "version-0"), 0, 1
         )
         self.assertEqual(restored.version, self.params.version)
         saved = self.params.non_embedding_params
         for name, variable in saved.items():
             expected = variable.numpy()
             actual = restored.non_embedding_params[name].numpy()
             self.assertTrue(np.array_equal(expected, actual))
コード例 #15
0
    def test_save_parameters_to_checkpoint_file(self):
        """Only the 3 newest checkpoints survive 100 version bumps.

        With checkpoint_steps=5 and keep_checkpoint_max=3, the expected
        survivors are version-90, version-95 and version-100.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            checkpoint_saver = CheckpointSaver(
                checkpoint_dir=os.path.join(tempdir, "ckpt/"),
                checkpoint_steps=5,
                keep_checkpoint_max=3,
                include_evaluation=False,
            )
            pserver_servicer = PserverServicer(
                parameters=Parameters(),
                grads_to_wait=0,
                optimizer="optimizer",
                checkpoint_saver=checkpoint_saver,
                ps_id=0,
                num_ps_pods=1,
            )
            dense_values = {
                "v0": tf.Variable([[1, 1, 1], [1, 1, 1]]),
                "v1": tf.Variable([[2, 2, 2], [2, 2, 2]]),
            }
            server_params = pserver_servicer._parameters
            server_params.non_embedding_params.update(dense_values)

            server_params.embedding_params["embedding_0"] = EmbeddingTable(
                name="embedding_0", dim=3, initializer="random_uniform"
            )
            server_params.set_embedding_param(
                name="embedding_0",
                indices=np.array([0, 1]),
                values=np.array([[1, 1, 1], [2, 2, 2]]),
            )

            # Bump the version 100 times; the servicer decides when to save.
            for _ in range(100):
                pserver_servicer._parameters.version += 1
                pserver_servicer._save_params_to_checkpoint_if_needed()

            saved_versions = os.listdir(checkpoint_saver._directory)
            self.assertEqual(len(saved_versions), 3)
            self.assertEqual(
                sorted(saved_versions),
                ["version-100", "version-90", "version-95"],
            )
            self.assertEqual(
                os.listdir(checkpoint_saver._directory + "/version-100"),
                ["variables-0-of-1.ckpt"],
            )
コード例 #16
0
    def test_restore_parameters_from_checkpoint(self):
        """Each of two PS instances restores its own share of a checkpoint.

        A checkpoint with a 4-row embedding table and two dense variables is
        saved at version 100. PS 0 of 2 is expected to restore embedding ids
        [0, 2] and "dense/kernel:0". PS 1 of 2 is expected to restore ids
        [1, 3] and "dense/bias:0".
        """
        # Function-scope imports keep this method self-contained.
        import os
        import tempfile

        # Write into a throwaway directory instead of the checked-in testdata
        # tree. Runs then leave no files behind, do not see stale versions
        # from earlier runs, and do not depend on the working directory.
        with tempfile.TemporaryDirectory() as tempdir:
            checkpoint_dir = os.path.join(tempdir, "ps_ckpt")
            checkpoint_saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
            params = Parameters()
            table = EmbeddingTable("embedding", 2, "random_uniform")
            table.set([0, 1, 2, 3], np.ones((4, 2), dtype=np.float32))
            params.embedding_params["embedding"] = table
            params.non_embedding_params["dense/kernel:0"] = tf.Variable(
                [[1.0], [1.0]]
            )
            params.non_embedding_params["dense/bias:0"] = tf.Variable([1.0])
            params.version = 100
            model_pb = params.to_model_pb()
            checkpoint_saver.save(100, model_pb, False)

            checkpoint_dir_for_init = os.path.join(
                checkpoint_dir, "version-100"
            )
            args = PserverArgs(
                ps_id=0,
                num_ps_pods=2,
                model_zoo=_test_model_zoo_path,
                model_def="test_module.custom_model",
                checkpoint_dir_for_init=checkpoint_dir_for_init,
            )
            pserver_0 = ParameterServer(args)

            embedding_table = pserver_0.parameters.embedding_params[
                "embedding"
            ]
            self.assertEqual(
                list(embedding_table.embedding_vectors.keys()), [0, 2]
            )
            self.assertEqual(
                list(pserver_0.parameters.non_embedding_params.keys()),
                ["dense/kernel:0"],
            )
            self.assertTrue(
                np.array_equal(
                    pserver_0.parameters.non_embedding_params[
                        "dense/kernel:0"
                    ].numpy(),
                    np.array([[1], [1]], dtype=int),
                )
            )
            self.assertEqual(pserver_0.parameters.version, 100)

            args = PserverArgs(
                ps_id=1,
                num_ps_pods=2,
                model_zoo=_test_model_zoo_path,
                model_def="test_module.custom_model",
                checkpoint_dir_for_init=checkpoint_dir_for_init,
            )
            pserver_1 = ParameterServer(args)

            embedding_table = pserver_1.parameters.embedding_params[
                "embedding"
            ]
            self.assertEqual(
                list(embedding_table.embedding_vectors.keys()), [1, 3]
            )
            self.assertEqual(
                list(pserver_1.parameters.non_embedding_params.keys()),
                ["dense/bias:0"],
            )
            self.assertTrue(
                np.array_equal(
                    pserver_1.parameters.non_embedding_params[
                        "dense/bias:0"
                    ].numpy(),
                    np.array([1], dtype=int),
                )
            )
            self.assertEqual(pserver_1.parameters.version, 100)
コード例 #17
0
 def _mock_model_weights_and_save_checkpoint(self, model):
     """Save mocked parameters for ``model`` as a version-100 checkpoint.

     The checkpoint is written to ``self.model_handler._checkpoint_dir``.
     """
     saver = CheckpointSaver(self.model_handler._checkpoint_dir, 0, 0, False)
     mocked_params = self._mock_model_parameters(model)
     saver.save(100, mocked_params.to_model_pb(), False)