예제 #1
0
    def testSaveLoadCheckpoint(self):
        """Round-trip dense (non-embedding) parameters through a checkpoint.

        The trainable variables of ``m["custom_model"]()`` are stored in a
        ``Parameters`` object and saved as version 0. They are then restored
        from ``<dir>/version-0`` with
        ``CheckpointSaver.restore_params_from_checkpoint(dir, 0, 1)``. The
        restored version and every restored tensor must equal the saved ones.
        """
        model_vars = m["custom_model"]().trainable_variables
        with tempfile.TemporaryDirectory() as root_dir:
            save_dir = os.path.join(root_dir, "testSaveLoadCheckpoint")
            os.makedirs(save_dir)
            # Positional arguments after the directory are (3, 5, False);
            # see CheckpointSaver's signature for their exact meaning.
            saver = CheckpointSaver(save_dir, 3, 5, False)
            self.assertTrue(saver.is_enabled())

            original = Parameters()
            for variable in model_vars:
                original.non_embedding_params[variable.name] = variable
            saver.save(0, original.to_model_pb(), False)

            restored = CheckpointSaver.restore_params_from_checkpoint(
                os.path.join(save_dir, "version-0"), 0, 1
            )
            self.assertEqual(restored.version, original.version)
            for name, variable in original.non_embedding_params.items():
                expected = variable.numpy()
                actual = restored.non_embedding_params[name].numpy()
                self.assertTrue(np.array_equal(expected, actual))
예제 #2
0
def save_checkpoint_without_embedding(model, checkpoint_dir, version=100):
    """Save a checkpoint containing only ``model``'s trainable variables.

    Each variable in ``model.trainable_variables`` is stored under its
    ``name`` in a fresh ``Parameters`` object. That object is stamped with
    ``version``, converted to a model protobuf, and written via
    ``CheckpointSaver.save`` at the same version. Returns nothing.
    """
    saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
    dense_params = Parameters()
    for trainable in model.trainable_variables:
        dense_params.non_embedding_params[trainable.name] = trainable
    dense_params.version = version
    saver.save(version, dense_params.to_model_pb(), False)
예제 #3
0
    def test_restore_parameters_from_checkpoint(self):
        """A version-100 checkpoint is split between two parameter servers.

        The checkpoint holds one embedding table ``embedding`` (ids 0-3, all
        ones) and two dense variables. Two ``ParameterServer`` instances
        (``num_ps_pods=2``) are initialised from it. Each must load only its
        share, and both must report version 100:

        - pserver 0 gets embedding ids [0, 2] and ``dense/kernel:0``.
        - pserver 1 gets embedding ids [1, 3] and ``dense/bias:0``.
        """
        import os
        import tempfile

        # Write into a throwaway directory instead of the repository's
        # testdata tree. The test then leaves no artefacts behind and does
        # not depend on the current working directory.
        with tempfile.TemporaryDirectory() as tempdir:
            checkpoint_dir = os.path.join(tempdir, "ps_ckpt")
            os.makedirs(checkpoint_dir)
            checkpoint_saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
            params = Parameters()
            table = EmbeddingTable("embedding", 2, "random_uniform")
            table.set([0, 1, 2, 3], np.ones((4, 2), dtype=np.float32))
            params.embedding_params["embedding"] = table
            params.non_embedding_params["dense/kernel:0"] = tf.Variable(
                [[1.0], [1.0]]
            )
            params.non_embedding_params["dense/bias:0"] = tf.Variable([1.0])
            params.version = 100
            model_pb = params.to_model_pb()
            checkpoint_saver.save(100, model_pb, False)

            checkpoint_dir_for_init = os.path.join(
                checkpoint_dir, "version-100"
            )

            def create_pserver(ps_id):
                """Build pserver ``ps_id`` (of 2) initialised from the ckpt."""
                args = PserverArgs(
                    ps_id=ps_id,
                    num_ps_pods=2,
                    model_zoo=_test_model_zoo_path,
                    model_def="test_module.custom_model",
                    checkpoint_dir_for_init=checkpoint_dir_for_init,
                )
                return ParameterServer(args)

            # Assertions stay inside the ``with`` block so the checkpoint
            # files still exist while the servers are inspected.
            pserver_0 = create_pserver(0)
            embedding_table = pserver_0.parameters.embedding_params[
                "embedding"
            ]
            self.assertEqual(
                list(embedding_table.embedding_vectors.keys()), [0, 2]
            )
            self.assertEqual(
                list(pserver_0.parameters.non_embedding_params.keys()),
                ["dense/kernel:0"],
            )
            self.assertTrue(
                np.array_equal(
                    pserver_0.parameters.non_embedding_params[
                        "dense/kernel:0"
                    ].numpy(),
                    np.array([[1.0], [1.0]], dtype=np.float32),
                )
            )
            self.assertEqual(pserver_0.parameters.version, 100)

            pserver_1 = create_pserver(1)
            embedding_table = pserver_1.parameters.embedding_params[
                "embedding"
            ]
            self.assertEqual(
                list(embedding_table.embedding_vectors.keys()), [1, 3]
            )
            self.assertEqual(
                list(pserver_1.parameters.non_embedding_params.keys()),
                ["dense/bias:0"],
            )
            self.assertTrue(
                np.array_equal(
                    pserver_1.parameters.non_embedding_params[
                        "dense/bias:0"
                    ].numpy(),
                    np.array([1.0], dtype=np.float32),
                )
            )
            self.assertEqual(pserver_1.parameters.version, 100)
예제 #4
0
class ParametersTest(unittest.TestCase):
    """Unit tests for ``Parameters`` (dense and embedding parameter store)."""

    def setUp(self):
        """Build a fresh ``Parameters`` and a ``Model`` protobuf fixture.

        The protobuf contains:

        - one embedding table info ``embedding_1`` (dim 10, initializer
          ``"uniform"``);
        - random dense tensors ``x`` (3x4) and ``y`` (4x5);
        - rows for ids 0 and 8 of ``embedding_1``, serialized as indexed
          slices.
        """
        self.params = Parameters()

        self.model_pb = Model()
        self.infos_pb = self.model_pb.embedding_table_infos
        self.tensors_pb = self.model_pb.dense_parameters
        self.embedding_tables_pb = self.model_pb.embedding_tables

        self.embedding_table_name = "embedding_1"
        self.embedding_dim = 10
        embedding_pb = self.infos_pb.add()
        embedding_pb.name = self.embedding_table_name
        embedding_pb.dim = self.embedding_dim
        embedding_pb.initializer = "uniform"

        arr1 = np.random.uniform(size=(3, 4))
        serialize_ndarray(arr1, self.tensors_pb["x"])
        arr2 = np.random.uniform(size=(4, 5))
        serialize_ndarray(arr2, self.tensors_pb["y"])

        embedding_vectors = np.random.uniform(size=(2, 10))
        embedding_indices = np.array([0, 8])
        serialize_indexed_slices(
            Tensor(None, embedding_vectors, embedding_indices),
            self.embedding_tables_pb[self.embedding_table_name],
        )

    def _test_get_embedding_param(
        self, slot_names=None, slot_init_value=None
    ):
        """Check embedding lookups on ``self.params`` (shared helper).

        Looking up ids [0, 3, 7] in ``embedding_1`` must return shape
        (3, 10). An empty id list must return ``None``, and an unknown table
        name must raise ``ValueError``.

        Args:
            slot_names: optional iterable of slot names. For each slot, the
                slot table's rows at the probe ids must all lie within 1e-4
                of ``slot_init_value[slot]``.
            slot_init_value: mapping from slot name to its expected value.
        """
        # ``None`` defaults avoid sharing one mutable list/dict across calls.
        if slot_names is None:
            slot_names = []
        if slot_init_value is None:
            slot_init_value = {}
        indices = [0, 3, 7]

        res = self.params.get_embedding_param(
            self.embedding_table_name, indices
        )
        self.assertTupleEqual(res.shape, (3, 10))
        for slot in slot_names:
            res = self.params.get_embedding_param(
                get_slot_table_name(self.embedding_table_name, slot), indices
            )
            # Compare the magnitude of the difference. Without abs(), any
            # value far *below* the expected init value would also pass.
            self.assertTrue(
                (np.abs(res - slot_init_value[slot]) < 0.0001).all()
            )

        res = self.params.get_embedding_param(self.embedding_table_name, [])
        self.assertIsNone(res)

        with self.assertRaises(ValueError):
            self.params.get_embedding_param("tom", indices)

    def test_init_from_model_pb(self):
        """Loading the fixture protobuf creates dense vars and embeddings."""
        self.params.reset()
        self.params.init_from_model_pb(self.model_pb)

        res = self.params.non_embedding_params
        self.assertIn("x", res)
        self.assertIn("y", res)
        self.assertTrue(res["x"].trainable)
        self.assertTupleEqual(tuple(res["y"].shape.as_list()), (4, 5))

        self._test_get_embedding_param()

    def test_non_embedding_params(self):
        """``non_embedding_params`` starts empty after reset and is settable."""
        self.params.reset()

        res = self.params.non_embedding_params
        self.assertFalse(any(res))

        variables = {
            "x": tf.Variable(1, name="x"),
            "y": tf.Variable(2, name="y"),
        }

        self.params.non_embedding_params = variables
        self.assertIn("x", self.params.non_embedding_params)
        self.assertIn("y", self.params.non_embedding_params)

    def test_get_embedding_param(self):
        """Lookups work on tables created from embedding infos alone."""
        self.params.reset()
        self.params.init_embedding_params(self.infos_pb)
        self._test_get_embedding_param()

    def test_set_embedding_param(self):
        """Rows written with ``set_embedding_param`` read back unchanged.

        Writing to an unknown table name must raise ``ValueError``.
        """
        self.params.reset()
        self.params.init_embedding_params(self.infos_pb)
        indices = [100, 34, 8]
        x = len(indices)
        values = np.random.uniform(size=x * self.embedding_dim).reshape(
            (x, self.embedding_dim)
        )

        self.params.set_embedding_param(
            self.embedding_table_name, indices, values
        )

        row0 = self.params.get_embedding_param(
            self.embedding_table_name, [100]
        )
        row1 = self.params.get_embedding_param(self.embedding_table_name, [34])
        row2 = self.params.get_embedding_param(self.embedding_table_name, [8])

        rows = [row0, row1, row2]
        rows = np.concatenate(rows)
        np.testing.assert_array_equal(rows, values)

        with self.assertRaises(ValueError):
            self.params.set_embedding_param("tom", [0, 1, 2], values)

    def test_check_grad(self):
        """``check_grad`` rejects malformed gradients with specific errors."""
        self.params.reset()
        self.params.init_from_model_pb(self.model_pb)

        # Unknown parameter name.
        grad0 = Tensor("z", None, None)
        with self.assertRaisesRegex(ValueError, "Name error"):
            self.params.check_grad(grad0)

        # Dense gradient whose shape (3, 5) differs from x's (3, 4).
        grad1 = Tensor("x", np.random.uniform(size=(3, 5)), None)
        with self.assertRaisesRegex(ValueError, "Non embedding param error"):
            self.params.check_grad(grad1)

        # Embedding gradient with width 11; the table's dim is 10.
        grad2 = Tensor(
            name="embedding_1",
            values=np.random.uniform(size=(3, 11)),
            indices=np.array([1, 2, 3]),
        )
        with self.assertRaisesRegex(
            ValueError, "ElasticDL embedding param error"
        ):
            self.params.check_grad(grad2)

        # Indexed-slices gradient on dense var x (3 rows). Presumably
        # rejected because indices 3 and 4 are out of range; confirm
        # against check_grad.
        grad3 = Tensor(
            name="x",
            values=np.random.uniform(size=(4, 4)),
            indices=np.array([1, 2, 3, 4]),
        )
        with self.assertRaisesRegex(ValueError, "Keras embedding param error"):
            self.params.check_grad(grad3)

    def test_create_slot_params(self):
        """Slot tables are created and initialised to the given values."""
        # At first, no embedding tables are in the parameters.
        self.assertFalse(self.params.has_embedding_params())

        # Create embedding tables in the parameters.
        self.params.init_embedding_params(self.infos_pb)
        self.assertTrue(self.params.has_embedding_params())

        slot_names = ["accumulator", "linear"]
        slot_init_value = {slot_names[0]: 3.5, slot_names[1]: 0.0}
        self.params.create_slot_params(slot_names, slot_init_value)
        self._test_get_embedding_param(slot_names, slot_init_value)

    def test_export_to_model_pb(self):
        """``to_model_pb`` round-trips version, dense keys and embeddings."""
        self.params.init_from_model_pb(self.model_pb)
        self.params.version = 15
        model_pb = self.params.to_model_pb()

        params = Parameters()
        params.init_from_model_pb(model_pb)
        self.assertEqual(params.version, self.params.version)
        self.assertEqual(
            params.non_embedding_params.keys(),
            self.params.non_embedding_params.keys(),
        )
        self.assertEqual(
            params.embedding_params["embedding_1"].get([0]).tolist(),
            self.params.embedding_params["embedding_1"].get([0]).tolist(),
        )