Exemplo n.º 1
0
def save_variables_to_checkpoint(root_dir, params):
    """Save ``params`` as a checkpoint below ``root_dir``.

    A "testSaveLoadCheckpoint" sub-directory is created inside ``root_dir``
    (``os.makedirs`` raises if it already exists). ``params`` is converted
    with ``to_model_pb()`` and passed to a ``CheckpointSaver`` together with
    ``params.version``.

    Args:
        root_dir: Directory in which the checkpoint sub-directory is created.
        params: Object exposing ``version`` and ``to_model_pb()``.

    Returns:
        The path of the checkpoint directory that was created.
    """
    target_dir = os.path.join(root_dir, "testSaveLoadCheckpoint")
    os.makedirs(target_dir)
    saver = CheckpointSaver(target_dir, 3, 5, False)
    serialized = params.to_model_pb()
    saver.save(params.version, serialized, False)
    return target_dir
Exemplo n.º 2
0
    def testSaveLoadCheckpoint(self):
        """Round-trip model variables through a checkpoint and compare them.

        The trainable variables of ``m["custom_model"]()`` are stored in a
        ``Parameters`` object, saved as version 0 inside a temporary
        directory, restored with
        ``CheckpointSaver.restore_params_from_checkpoint`` and then checked
        for an equal version and element-wise equal values.
        """
        model_vars = m["custom_model"]().trainable_variables
        with tempfile.TemporaryDirectory() as tmp_root:
            save_dir = os.path.join(tmp_root, "testSaveLoadCheckpoint")
            os.makedirs(save_dir)
            saver = CheckpointSaver(save_dir, 3, 5, False)
            self.assertTrue(saver.is_enabled())

            original = Parameters()
            for variable in model_vars:
                original.non_embedding_params[variable.name] = variable
            serialized = original.to_model_pb()
            saver.save(0, serialized, False)

            # The restore step reads from the "version-0" sub-directory.
            version_dir = os.path.join(save_dir, "version-0")
            restored = CheckpointSaver.restore_params_from_checkpoint(
                version_dir, 0, 1
            )

            self.assertEqual(restored.version, original.version)
            for name in original.non_embedding_params:
                expected = original.non_embedding_params[name].numpy()
                actual = restored.non_embedding_params[name].numpy()
                self.assertTrue(np.array_equal(expected, actual))
Exemplo n.º 3
0
def save_checkpoint_without_embedding(model, checkpoint_dir, version=100):
    """Save the trainable variables of ``model`` as a checkpoint.

    Only ``non_embedding_params`` of a fresh ``Parameters`` object are
    filled (keyed by variable name); no embedding parameters are added.

    Args:
        model: Object exposing ``trainable_variables``.
        checkpoint_dir: Directory handed to ``CheckpointSaver``.
        version: Version stored on the parameters and passed to ``save``.
    """
    saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
    dense_params = Parameters()
    for variable in model.trainable_variables:
        dense_params.non_embedding_params[variable.name] = variable
    dense_params.version = version
    serialized = dense_params.to_model_pb()
    saver.save(version, serialized, False)
Exemplo n.º 4
0
    def test_restore_parameters_from_checkpoint(self):
        """Check what each of two parameter servers restores from a checkpoint.

        A checkpoint (version 100) holding one embedding table with ids 0-3
        and two dense variables is saved. Two ``ParameterServer`` instances
        (``ps_id`` 0 and 1, ``num_ps_pods=2``) are then initialised from it.
        The assertions pin which embedding ids and which dense variable each
        server holds, and that both report version 100.
        """
        checkpoint_dir = "elasticdl/python/tests/testdata/ps_ckpt"
        saver = CheckpointSaver(checkpoint_dir, 0, 0, False)

        # Build and save the parameters that the servers will restore.
        saved = Parameters()
        embedding = EmbeddingTable("embedding", 2, "random_uniform")
        embedding.set([0, 1, 2, 3], np.ones((4, 2), dtype=np.float32))
        saved.embedding_params["embedding"] = embedding
        saved.non_embedding_params["dense/kernel:0"] = tf.Variable(
            [[1.0], [1.0]]
        )
        saved.non_embedding_params["dense/bias:0"] = tf.Variable([1.0])
        saved.version = 100
        serialized = saved.to_model_pb()
        saver.save(100, serialized, False)

        init_dir = checkpoint_dir + "/version-100"

        # (ps_id, expected embedding ids, expected dense name, expected value)
        expected_shards = (
            (0, [0, 2], "dense/kernel:0", np.array([[1], [1]], dtype=int)),
            (1, [1, 3], "dense/bias:0", np.array([1], dtype=int)),
        )
        # Hold on to every server so none is released while the test runs.
        pservers = []
        for ps_id, expected_ids, dense_name, dense_value in expected_shards:
            args = PserverArgs(
                ps_id=ps_id,
                num_ps_pods=2,
                model_zoo=_test_model_zoo_path,
                model_def="test_module.custom_model",
                checkpoint_dir_for_init=init_dir,
            )
            pserver = ParameterServer(args)
            pservers.append(pserver)

            table = pserver.parameters.embedding_params["embedding"]
            self.assertEqual(
                list(table.embedding_vectors.keys()), expected_ids
            )
            self.assertEqual(
                list(pserver.parameters.non_embedding_params.keys()),
                [dense_name],
            )
            restored = pserver.parameters.non_embedding_params[
                dense_name
            ].numpy()
            self.assertTrue(np.array_equal(restored, dense_value))
            self.assertEqual(pserver.parameters.version, 100)
Exemplo n.º 5
0
 def _mock_model_weights_and_save_checkpoint(self, model):
     """Save mocked parameters for ``model`` as checkpoint version 100.

     The parameters come from ``self._mock_model_parameters(model)`` and are
     saved through a ``CheckpointSaver`` created for
     ``self.model_handler._checkpoint_dir``.
     """
     target_dir = self.model_handler._checkpoint_dir
     saver = CheckpointSaver(target_dir, 0, 0, False)
     mocked_params = self._mock_model_parameters(model)
     serialized = mocked_params.to_model_pb()
     saver.save(100, serialized, False)