Example #1
    def setUp(self):
        self.params = Parameters()

        self.model_pb = Model()
        self.infos_pb = self.model_pb.embedding_table_infos
        self.tensors_pb = self.model_pb.dense_parameters
        self.embedding_tables_pb = self.model_pb.embedding_tables

        self.embedding_table_name = "embedding_1"
        self.embedding_dim = 10
        embedding_pb = self.infos_pb.add()
        embedding_pb.name = self.embedding_table_name
        embedding_pb.dim = self.embedding_dim
        embedding_pb.initializer = "uniform"

        arr1 = np.random.uniform(size=(3, 4))
        serialize_ndarray(arr1, self.tensors_pb["x"])
        arr2 = np.random.uniform(size=(4, 5))
        serialize_ndarray(arr2, self.tensors_pb["y"])

        embedding_vectors = np.random.uniform(size=(2, 10))
        embedding_indices = np.array([0, 8])
        serialize_indexed_slices(
            Tensor(None, embedding_vectors, embedding_indices),
            self.embedding_tables_pb[self.embedding_table_name],
        )
Example #2
    def __init__(self, args):
        self.logger = get_logger("PS", level=args.log_level.upper())
        self.grads_to_wait = args.grads_to_wait
        self.lr_staleness_modulation = args.lr_staleness_modulation
        self.sync_version_tolerance = args.sync_version_tolerance
        self.use_async = args.use_async
        self.port = args.port
        model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)).__dict__
        self.optimizer = model_module[args.optimizer]()
        self._set_lr_scheduler(model_module, args.learning_rate_scheduler)
        self.ps_id = args.ps_id
        self.num_ps_pods = args.num_ps_pods
        self.num_workers = args.num_workers
        # Create Parameters instance
        self.parameters = Parameters()
        if args.master_addr is None:
            raise ValueError("master_addr is missing for parameter servers")
        self.master_channel = build_channel(args.master_addr)
        self.evaluation_steps = args.evaluation_steps

        self.master_name = get_master_pod_name(args.job_name)
        self.namespace = args.namespace
        self._init_checkpoint_saver(args)
        self._restore_params_from_checkpoint(args.checkpoint_dir_for_init)
        self._debug_info_needed = args.log_level.upper() == "DEBUG"
Example #3
    def testSaveLoadCheckpoint(self):
        init_var = m["custom_model"]().trainable_variables
        with tempfile.TemporaryDirectory() as tempdir:
            ckpt_dir = os.path.join(tempdir, "testSaveLoadCheckpoint")
            os.makedirs(ckpt_dir)
            checkpoint_saver = CheckpointSaver(ckpt_dir, 3, 5, False)
            self.assertTrue(checkpoint_saver.is_enabled())
            params = Parameters()

            for var in init_var:
                params.non_embedding_params[var.name] = var
            model_pb = params.to_model_pb()

            checkpoint_saver.save(0, model_pb, False)

            ckpt_version_dir = os.path.join(ckpt_dir, "version-0")
            restore_params = CheckpointSaver.restore_params_from_checkpoint(
                ckpt_version_dir, 0, 1)
            self.assertEqual(restore_params.version, params.version)
            for var_name in params.non_embedding_params:
                self.assertTrue(
                    np.array_equal(
                        params.non_embedding_params[var_name].numpy(),
                        restore_params.non_embedding_params[var_name].numpy(),
                    ))
Example #4
    def test_set_slot_to_optimizer(self):
        embed_name = "test_emb"
        indices = np.ndarray([2], dtype=np.int32)
        embed_values = np.ndarray([2, 2], dtype=np.float32)
        slot_values = {
            "m": np.ndarray([2, 2], dtype=np.float32),
            "v": np.ndarray([2, 2], dtype=np.float32),
        }
        params = Parameters()
        params.embedding_params[embed_name] = EmbeddingTable(embed_name, 8)
        for slot in ["m", "v"]:
            slot_table_name = get_slot_table_name(embed_name, slot)
            params.embedding_params[slot_table_name] = EmbeddingTable(
                slot_table_name, 2, "0.0", True)

        opt = Adam()
        opt_wrapper = OptimizerWrapper(opt, None, params.get_embedding_param)
        opt_wrapper._init_thread_local()

        opt_wrapper._tls._unique_ids_all_layers[embed_name] = indices
        opt_wrapper._create_embedding_variable(embed_name, embed_values)
        opt_wrapper._get_slot_and_set_to_optimizer(embed_name)

        self.assertEqual(len(opt._slots), 1)
        opt_slots = list(opt._slots.values())[0]
        self.assertEqual(sorted(opt_slots.keys()), ["m", "v"])
        for name in ["m", "v"]:
            self.assertTrue(
                np.allclose(opt_slots[name].numpy(), slot_values[name]))
Example #5
    def test_delete_variables(self):
        params = Parameters()
        embed_layers = ["test_1", "test_2"]
        slot_names = ["m", "v"]
        dim = 8
        for layer in embed_layers:
            params.embedding_params[layer] = EmbeddingTable(layer, dim)
            for slot in slot_names:
                slot_key = get_slot_table_name(layer, slot)
                params.embedding_params[slot_key] = EmbeddingTable(
                    slot_key, dim, "0.0", True)

        opt = Adam()
        opt_wrapper = OptimizerWrapper(opt, None, params.get_embedding_param,
                                       params.set_embedding_param)

        opt_wrapper._init_thread_local()
        for name in embed_layers:
            opt_wrapper._tls._unique_ids_all_layers[name] = np.ndarray(
                [2], np.int32)
            opt_wrapper._create_embedding_variable(
                name, np.ndarray([2, dim], np.float32))
            opt_wrapper._get_slot_and_set_to_optimizer(name)

        self.assertTrue(len(opt._weights) == 4)
        self.assertTrue(len(opt._slots) == 2)
        for slot_dict in opt._slots.values():
            self.assertTrue(len(slot_dict) == 2)

        opt_wrapper._delete_slots_and_weights_in_optimizer()
        self.assertTrue(len(opt._weights) == 0)
        self.assertTrue(len(opt._slots) == 0)
Example #6
    def setUp(self):
        self.params = Parameters()

        self.model_pb = Model()
        self.tensors_pb = self.model_pb.param
        self.embeddings_pb = self.model_pb.embedding_table_info

        arr1 = np.random.uniform(size=(3, 4))
        tensor1_pb = Tensor(arr1, name="x").to_tensor_pb()
        arr2 = np.random.uniform(size=(4, 5))
        tensor2_pb = Tensor(arr2, name="y").to_tensor_pb()
        self.tensors_pb.extend([tensor1_pb, tensor2_pb])

        self.embedding_table_name = "embedding_1"
        self.embedding_dim = 10
        embedding_pb = EmbeddingTableInfo()
        embedding_pb.name = self.embedding_table_name
        embedding_pb.dim = self.embedding_dim
        embedding_pb.initializer = "uniform"

        embedding_vectors = np.random.uniform(size=(2, 10))
        embedding_indices = np.array([0, 8])
        embedding_tensor = Tensor(
            embedding_vectors,
            indices=embedding_indices,
            name=self.embedding_table_name,
        )
        embedding_tensor_pb = embedding_tensor.to_tensor_pb()
        self.tensors_pb.append(embedding_tensor_pb)

        self.embeddings_pb.append(embedding_pb)
Example #7
def save_checkpoint_without_embedding(model, checkpoint_dir, version=100):
    checkpoint_saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
    params = Parameters()
    for var in model.trainable_variables:
        params.non_embedding_params[var.name] = var
    params.version = version
    model_pb = params.to_model_pb()
    checkpoint_saver.save(version, model_pb, False)
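
A minimal call sketch for the helper above, assuming a small tf.keras model and a temporary directory (both hypothetical, not part of the original snippet):

# Hypothetical usage, for illustration only.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
with tempfile.TemporaryDirectory() as tmpdir:
    save_checkpoint_without_embedding(model, tmpdir, version=100)
    # CheckpointSaver writes one sub-directory per saved version,
    # e.g. "version-100"
    print(os.listdir(tmpdir))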
Example #8
    def _test_correctness(self, optimizer_class, X, Y, seed, **opt_kwargs):
        """Test the correctness of specific TensorFlow optimizer."""
        _model_file = get_module_file_path(
            os.path.dirname(os.path.realpath(__file__)),
            "embedding_test_module.KerasEmbeddingModel",
        )
        model_module = load_module(_model_file).__dict__

        # train model with TensorFlow optimizer
        dim = 4
        weights = self._random_init_model_weight([(4, dim), (4, dim), (72, 1),
                                                  (1, )], seed)
        loss_fn = model_module["loss"]
        model1 = model_module["KerasEmbeddingModel"](4, dim, weights)
        opt1 = optimizer_class(**opt_kwargs)
        _train(model1, opt1, X, Y, loss_fn, random_seed=seed)

        model2 = model_module["EdlEmbeddingModel"](dim, weights[2:])
        opt2 = optimizer_class(**opt_kwargs)

        embedding_weight_names = [
            layer.embedding_weight_name
            for layer in find_layer(model2, Embedding)
        ]

        # create Parameters object and initialize embedding vectors
        params = Parameters()
        for weight_name, embed_value in zip(embedding_weight_names,
                                            weights[:2]):
            embed_table = EmbeddingTable(weight_name, dim)
            embed_table.set(range(len(embed_value)), embed_value)
            params.embedding_params[weight_name] = embed_table

        _train_edl_embedding_with_optimizer_wrapper(model2,
                                                    opt2,
                                                    X,
                                                    Y,
                                                    loss_fn,
                                                    params,
                                                    random_seed=seed)

        # compare trained parameters
        wrong_msg = (
            "The updated parameters of Optimizer Wrapper and TensorFlow "
            "optimizer %s differ." % opt1.get_config()["name"])

        for layer1, layer2 in zip(model1.layers, model2.layers):
            if "embedding" in layer2.name:
                w1 = layer1.weights[0].numpy()
                w2 = params.get_embedding_param(layer2.embedding_weight_name,
                                                range(4))
                self.assertTrue(np.isclose(w1, w2).all(), msg=wrong_msg)
            else:
                for w1, w2 in zip(layer1.weights, layer2.weights):
                    self.assertTrue(np.isclose(w1.numpy(), w2.numpy()).all(),
                                    msg=wrong_msg)
Example #9
    def restore_params_from_checkpoint(checkpoint_dir, shard_index, shard_num):
        """Restore a shard parameters from the checkpoint directory.
        If shard_num=1, a entire model parameters will be restored.

        Args:
            checkpoint_dir: a directory with checkpoint files.
            shard_index: Model shard index, e.g. the PS instance index
                using ParameterServerStrategy with multiple PS instances.
            shard_num: The total number of model shards, e.g. the total PS
                instancecount using ParameterServerStrategy with multiple
                PS instances.

        Return:
            parameters: A Parameter object which contains model version,
                non-embedding parameters and embedding tables for the
                PS instance with ps_id.
        """

        variable_shard_files = os.listdir(checkpoint_dir)
        non_embedding_vars = {}
        embedding_tables = {}
        version = None
        for shard_file in variable_shard_files:
            shard_file_path = os.path.join(checkpoint_dir, shard_file)
            model_pb = elasticdl_pb2.Model()
            model_pb = load_pb_from_file(model_pb, shard_file_path)
            if version is None:
                version = model_pb.version
            elif version != model_pb.version:
                raise ValueError(
                    "The versions in model shards are not consistent"
                )

            for embedding_info_pb in model_pb.embedding_table_infos:
                embedding_table = create_embedding_table(embedding_info_pb)
                embedding_tables.setdefault(
                    embedding_table.name, embedding_table
                )

            (
                shard_non_embedding_vars,
                shard_embedding_table_values,
            ) = _get_params_shard_from_pb(model_pb, shard_index, shard_num)

            non_embedding_vars.update(shard_non_embedding_vars)
            for name, pair in shard_embedding_table_values.items():
                embedding_tables[name].set(pair[0], pair[1])

        parameters = Parameters()
        parameters.non_embedding_params.update(non_embedding_vars)
        parameters.embedding_params.update(embedding_tables)
        parameters.version = version
        return parameters
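
As the docstring notes, each PS instance restores only its own shard. A minimal call sketch, assuming a checkpoint version directory such as the one written in Example #3 (the path below is hypothetical):

# Hypothetical usage, for illustration only: restore the shard owned by
# PS instance 0 when the model is split across 2 PS instances.
ckpt_version_dir = "/tmp/ckpt/version-100"
shard_params = CheckpointSaver.restore_params_from_checkpoint(
    ckpt_version_dir, shard_index=0, shard_num=2)
print(shard_params.version)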
Example #10
    def test_export_to_model_pb(self):
        self.params.init_from_model_pb(self.model_pb)
        self.params.version = 15
        model_pb = self.params.to_model_pb()

        params = Parameters()
        params.init_from_model_pb(model_pb)
        self.assertEqual(params.version, self.params.version)
        self.assertEqual(
            params.non_embedding_params.keys(),
            self.params.non_embedding_params.keys(),
        )
        self.assertEqual(
            params.embedding_params["embedding_1"].get([0]).tolist(),
            self.params.embedding_params["embedding_1"].get([0]).tolist(),
        )
Example #11
    def _mock_model_parameters(self, model):
        params = Parameters()
        for weight in model.trainable_variables:
            if "embedding" in weight.name:
                embedding_table = EmbeddingTable(
                    name=weight.name,
                    dim=weight.shape[1],
                    initializer="RandomUniform",
                )
                embedding_table.set(np.arange(weight.shape[0]),
                                    np.ones(weight.shape))
                params.embedding_params[weight.name] = embedding_table
            else:
                params.non_embedding_params[weight.name] = tf.ones(
                    weight.shape)
        params.version = 100
        return params
Example #12
    def __init__(self, args):
        self.logger = get_logger("PS", level=args.log_level.upper())

        self.grads_to_wait = args.grads_to_wait
        self.lr_staleness_modulation = args.lr_staleness_modulation
        self.use_async = args.use_async
        self.port = args.port
        model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)
        ).__dict__
        self.optimizer = model_module[args.optimizer]()
        # Create Parameters instance
        self.parameters = Parameters()
Example #13
    def test_save_parameters_to_checkpoint_file(self):
        with tempfile.TemporaryDirectory() as tempdir:
            checkpoint_saver = CheckpointSaver(
                checkpoint_dir=os.path.join(tempdir, "ckpt/"),
                checkpoint_steps=5,
                keep_checkpoint_max=3,
                include_evaluation=False,
            )
            pserver_servicer = PserverServicer(
                parameters=Parameters(),
                grads_to_wait=0,
                optimizer="optimizer",
                checkpoint_saver=checkpoint_saver,
                ps_id=0,
                num_ps_pods=1,
            )
            model_params = {
                "v0": tf.Variable([[1, 1, 1], [1, 1, 1]]),
                "v1": tf.Variable([[2, 2, 2], [2, 2, 2]]),
            }

            server_params = pserver_servicer._parameters
            for var_name, var_value in model_params.items():
                server_params.non_embedding_params[var_name] = var_value

            embedding_table = EmbeddingTable(
                name="embedding_0", dim=3, initializer="random_uniform"
            )
            server_params.embedding_params["embedding_0"] = embedding_table
            server_params.set_embedding_param(
                name="embedding_0",
                indices=np.array([0, 1]),
                values=np.array([[1, 1, 1], [2, 2, 2]]),
            )

            for i in range(100):
                pserver_servicer._parameters.version += 1
                pserver_servicer._save_params_to_checkpoint_if_needed()

            self.assertEqual(len(os.listdir(checkpoint_saver._directory)), 3)
            self.assertEqual(
                sorted(os.listdir(checkpoint_saver._directory)),
                ["version-100", "version-90", "version-95"],
            )
            self.assertEqual(
                os.listdir(checkpoint_saver._directory + "/version-100"),
                ["variables-0-of-1.ckpt"],
            )
Example #14
    def test_update_embedding_param(self):
        params = Parameters()
        for name in ["test_1", "test_2"]:
            params.embedding_params[name] = EmbeddingTable(name, 8)
            slot_key = get_slot_table_name(name, "momentum")
            params.embedding_params[slot_key] = EmbeddingTable(
                slot_key, 8, "0.0", True)

        indices = {
            "test_1": np.array([1, 5]),
            "test_2": np.array([10]),
        }
        embed_vars = {
            "test_1": tf.Variable(np.random.rand(2, 8).astype(np.float32)),
            "test_2": tf.Variable(np.random.rand(1, 8).astype(np.float32)),
        }
        slot_vars = {
            "test_1": {
                "momentum":
                tf.Variable(np.random.rand(2, 8).astype(np.float32))
            },
            "test_2": {
                "momentum":
                tf.Variable(np.random.rand(1, 8).astype(np.float32))
            },
        }

        opt = SGD(momentum=0.1)
        opt_wrapper = OptimizerWrapper(opt, None, None,
                                       params.set_embedding_param)
        opt_wrapper._tls._unique_ids_all_layers = indices
        opt_wrapper._tls._embed_variables = embed_vars
        opt_wrapper._tls._slot_variables = slot_vars
        opt_wrapper._update_embedding_param()

        for name in ["test_1", "test_2"]:
            self.assertTrue(
                np.allclose(
                    embed_vars[name].numpy(),
                    params.get_embedding_param(name, indices[name]),
                ))

            slot = "momentum"
            slot_table_name = get_slot_table_name(name, slot)
            self.assertTrue(
                np.allclose(
                    slot_vars[name][slot].numpy(),
                    params.get_embedding_param(slot_table_name, indices[name]),
                ))
Example #15
    def __init__(self, args):
        self.logger = get_logger("PS", level=args.log_level.upper())

        self.grads_to_wait = args.grads_to_wait
        self.lr_staleness_modulation = args.lr_staleness_modulation
        self.use_async = args.use_async
        self.port = args.port
        model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)).__dict__
        self.optimizer = model_module[args.optimizer]()
        self.ps_id = args.ps_id
        self.num_ps_pods = args.num_ps_pods
        # Create Parameters instance
        self.parameters = Parameters()
        if args.master_addr is None:
            raise ValueError("master_addr is missing for parameter servers")
        self.master_channel = build_channel(args.master_addr)
        self.evaluation_steps = args.evaluation_steps

        self.master_name = get_master_pod_name(args.job_name)
        self.namespace = args.namespace
        self._init_checkpoint_service(args)
Example #16
    def test_restore_parameters_from_checkpoint(self):
        checkpoint_dir = "elasticdl/python/tests/testdata/ps_ckpt"
        checkpoint_saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
        params = Parameters()
        table = EmbeddingTable("embedding", 2, "random_uniform")
        table.set([0, 1, 2, 3], np.ones((4, 2), dtype=np.float32))
        params.embedding_params["embedding"] = table
        params.non_embedding_params["dense/kernel:0"] = tf.Variable(
            [[1.0], [1.0]]
        )
        params.non_embedding_params["dense/bias:0"] = tf.Variable([1.0])
        params.version = 100
        model_pb = params.to_model_pb()
        checkpoint_saver.save(100, model_pb, False)

        checkpoint_dir_for_init = checkpoint_dir + "/version-100"
        args = PserverArgs(
            ps_id=0,
            num_ps_pods=2,
            model_zoo=_test_model_zoo_path,
            model_def="test_module.custom_model",
            checkpoint_dir_for_init=checkpoint_dir_for_init,
        )
        pserver_0 = ParameterServer(args)

        embedding_table = pserver_0.parameters.embedding_params["embedding"]
        self.assertEqual(
            list(embedding_table.embedding_vectors.keys()), [0, 2]
        )
        self.assertEqual(
            list(pserver_0.parameters.non_embedding_params.keys()),
            ["dense/kernel:0"],
        )
        self.assertTrue(
            np.array_equal(
                pserver_0.parameters.non_embedding_params[
                    "dense/kernel:0"
                ].numpy(),
                np.array([[1], [1]], dtype=int),
            )
        )
        self.assertEqual(pserver_0.parameters.version, 100)

        args = PserverArgs(
            ps_id=1,
            num_ps_pods=2,
            model_zoo=_test_model_zoo_path,
            model_def="test_module.custom_model",
            checkpoint_dir_for_init=checkpoint_dir_for_init,
        )
        pserver_1 = ParameterServer(args)

        embedding_table = pserver_1.parameters.embedding_params["embedding"]
        self.assertEqual(
            list(embedding_table.embedding_vectors.keys()), [1, 3]
        )
        self.assertEqual(
            list(pserver_1.parameters.non_embedding_params.keys()),
            ["dense/bias:0"],
        )
        self.assertTrue(
            np.array_equal(
                pserver_1.parameters.non_embedding_params[
                    "dense/bias:0"
                ].numpy(),
                np.array([1], dtype=int),
            )
        )
        self.assertEqual(pserver_1.parameters.version, 100)
Example #17
    def _test_async_correctness(
        self,
        grads_and_vars_batches,
        embed_values,
        expected_non_embed_values,
        expected_embed_values=None,
    ):
        """Checks the correctness of async OptimizerWrapper. This function
        creates many threads and these threads call
        `OptimizerWrapper.apply_gradients` simultaneously.

        Args:
            grads_and_vars_batches: A python list of `grads_and_vars`. Every
                thread takes a `grads_and_vars` and calls `apply_gradients`.
            embed_values: A python dictionary of
                `(layer_name, embedding table)`.
            expected_non_embed_values: A python list of expected non-embedding
                values after applying gradients.
            expected_embed_values: A python dictionary of expected embedding
                values after applying gradients. None means no need to check
                embedding values.
        """
        thread_num = len(grads_and_vars_batches)
        input_dims = {}
        embed_var_n = len(embed_values)
        params = Parameters()
        for layer, values in embed_values.items():
            embed_dim = values.shape[1]
            input_dims[layer] = values.shape[0]
            embed_table = EmbeddingTable(layer, embed_dim)
            embed_table.set(range(input_dims[layer]), values)
            params.embedding_params[layer] = embed_table

        opt = SGD(0.1)
        opt_wrapper = OptimizerWrapper(
            opt,
            True,
            lookup_embedding_func=params.get_embedding_param,
            update_embedding_func=params.set_embedding_param,
        )

        # call optimizer_wrapper.apply_gradients asynchronously
        def _apply_gradients(opt_wrapper, grads_and_vars):
            # sleep 1s so that all threads reach this call before applying
            # gradients
            time.sleep(1)
            opt_wrapper.apply_gradients(grads_and_vars)

        executor = ThreadPoolExecutor(max_workers=thread_num)
        tasks = [
            executor.submit(_apply_gradients, opt_wrapper, grads_and_vars)
            for grads_and_vars in grads_and_vars_batches
        ]
        _ = [task.result() for task in tasks]

        # check updated results of non-embedding variables
        non_embed_vars = [
            var for grad, var in grads_and_vars_batches[0][:-embed_var_n]
        ]
        for var, expected_value in zip(non_embed_vars,
                                       expected_non_embed_values):
            self.assertTrue(np.isclose(var.numpy(), expected_value).all())

        # `expected_embed_values=None` means there is no need to check
        # the embedding table
        if not expected_embed_values:
            return
        # check updated results of embedding table
        for layer, expected_values in expected_embed_values.items():
            value = params.get_embedding_param(layer, range(input_dims[layer]))

            self.assertTrue(
                any([
                    np.isclose(value, expected).all()
                    for expected in expected_values
                ]))
Example #18
class ParametersTest(unittest.TestCase):
    def setUp(self):
        self.params = Parameters()

        self.model_pb = Model()
        self.infos_pb = self.model_pb.embedding_table_infos
        self.tensors_pb = self.model_pb.dense_parameters
        self.embedding_tables_pb = self.model_pb.embedding_tables

        self.embedding_table_name = "embedding_1"
        self.embedding_dim = 10
        embedding_pb = self.infos_pb.add()
        embedding_pb.name = self.embedding_table_name
        embedding_pb.dim = self.embedding_dim
        embedding_pb.initializer = "uniform"

        arr1 = np.random.uniform(size=(3, 4))
        serialize_ndarray(arr1, self.tensors_pb["x"])
        arr2 = np.random.uniform(size=(4, 5))
        serialize_ndarray(arr2, self.tensors_pb["y"])

        embedding_vectors = np.random.uniform(size=(2, 10))
        embedding_indices = np.array([0, 8])
        serialize_indexed_slices(
            Tensor(None, embedding_vectors, embedding_indices),
            self.embedding_tables_pb[self.embedding_table_name],
        )

    def _test_get_embedding_param(self, slot_names=[], slot_init_value={}):
        indices = [0, 3, 7]

        res = self.params.get_embedding_param(
            self.embedding_table_name, indices
        )
        self.assertTupleEqual(res.shape, (3, 10))
        for slot in slot_names:
            res = self.params.get_embedding_param(
                get_slot_table_name(self.embedding_table_name, slot), indices
            )
            self.assertTrue(((res - slot_init_value[slot]) < 0.0001).all())

        res = self.params.get_embedding_param(self.embedding_table_name, [])
        self.assertIsNone(res)

        with self.assertRaises(ValueError):
            self.params.get_embedding_param("tom", indices)

    def test_init_from_model_pb(self):
        self.params.reset()
        self.params.init_from_model_pb(self.model_pb)

        res = self.params.non_embedding_params
        self.assertTrue("x" in res)
        self.assertTrue("y" in res)
        self.assertTrue(res["x"].trainable)
        self.assertTupleEqual(tuple(res["y"].shape.as_list()), (4, 5))

        self._test_get_embedding_param()

    def test_non_embedding_params(self):
        self.params.reset()

        res = self.params.non_embedding_params
        self.assertFalse(any(res))

        variables = {
            "x": tf.Variable(1, name="x"),
            "y": tf.Variable(2, name="y"),
        }

        self.params.non_embedding_params = variables
        self.assertTrue("x" in self.params.non_embedding_params)
        self.assertTrue("y" in self.params.non_embedding_params)

    def test_get_embedding_param(self):
        self.params.reset()
        self.params.init_embedding_params(self.infos_pb)
        self._test_get_embedding_param()

    def test_set_embedding_param(self):
        self.params.reset()
        self.params.init_embedding_params(self.infos_pb)
        indices = [100, 34, 8]
        x = len(indices)
        values = np.random.uniform(size=x * self.embedding_dim).reshape(
            (x, self.embedding_dim)
        )

        self.params.set_embedding_param(
            self.embedding_table_name, indices, values
        )

        row0 = self.params.get_embedding_param(
            self.embedding_table_name, [100]
        )
        row1 = self.params.get_embedding_param(self.embedding_table_name, [34])
        row2 = self.params.get_embedding_param(self.embedding_table_name, [8])

        rows = [row0, row1, row2]
        rows = np.concatenate(rows)
        np.testing.assert_array_equal(rows, values)

        with self.assertRaises(ValueError):
            self.params.set_embedding_param("tom", [0, 1, 2], values)

    def test_check_grad(self):
        self.params.reset()
        self.params.init_from_model_pb(self.model_pb)

        grad0 = Tensor("z", None, None)
        with self.assertRaisesRegex(ValueError, "Name error"):
            self.params.check_grad(grad0)

        grad1 = Tensor("x", np.random.uniform(size=(3, 5)), None)
        with self.assertRaisesRegex(ValueError, "Non embedding param error"):
            self.params.check_grad(grad1)

        grad2 = Tensor(
            name="embedding_1",
            values=np.random.uniform(size=(3, 11)),
            indices=np.array([1, 2, 3]),
        )
        with self.assertRaisesRegex(
            ValueError, "ElasticDL embedding param error"
        ):
            self.params.check_grad(grad2)

        grad3 = Tensor(
            name="x",
            values=np.random.uniform(size=(4, 4)),
            indices=np.array([1, 2, 3, 4]),
        )
        with self.assertRaisesRegex(ValueError, "Keras embedding param error"):
            self.params.check_grad(grad3)

    def test_create_slot_params(self):
        # At first, no embedding tables are in the parameters
        self.assertFalse(self.params.has_embedding_params())

        # create embedding tables in the parameters
        self.params.init_embedding_params(self.infos_pb)
        self.assertTrue(self.params.has_embedding_params())

        slot_names = ["accumulator", "linear"]
        slot_init_value = {slot_names[0]: 3.5, slot_names[1]: 0.0}
        self.params.create_slot_params(slot_names, slot_init_value)
        self._test_get_embedding_param(slot_names, slot_init_value)

    def test_export_to_model_pb(self):
        self.params.init_from_model_pb(self.model_pb)
        self.params.version = 15
        model_pb = self.params.to_model_pb()

        params = Parameters()
        params.init_from_model_pb(model_pb)
        self.assertEqual(params.version, self.params.version)
        self.assertEqual(
            params.non_embedding_params.keys(),
            self.params.non_embedding_params.keys(),
        )
        self.assertEqual(
            params.embedding_params["embedding_1"].get([0]).tolist(),
            self.params.embedding_params["embedding_1"].get([0]).tolist(),
        )
Example #19
    def setUp(self):
        init_var = m["custom_model"]().trainable_variables
        self.params = Parameters()
        for var in init_var:
            self.params.non_embedding_params[var.name] = var
Example #20
class ParameterServer(object):
    def __init__(self, args):
        self.logger = get_logger("PS", level=args.log_level.upper())
        self.grads_to_wait = args.grads_to_wait
        self.lr_staleness_modulation = args.lr_staleness_modulation
        self.sync_version_tolerance = args.sync_version_tolerance
        self.use_async = args.use_async
        self.port = args.port
        model_module = load_module(
            get_module_file_path(args.model_zoo, args.model_def)).__dict__
        self.optimizer = model_module[args.optimizer]()
        self._set_lr_scheduler(model_module, args.learning_rate_scheduler)
        self.ps_id = args.ps_id
        self.num_ps_pods = args.num_ps_pods
        self.num_workers = args.num_workers
        # Create Parameters instance
        self.parameters = Parameters()
        if args.master_addr is None:
            raise ValueError("master_addr is missing for parameter servers")
        self.master_channel = build_channel(args.master_addr)
        self.evaluation_steps = args.evaluation_steps

        self.master_name = get_master_pod_name(args.job_name)
        self.namespace = args.namespace
        self._init_checkpoint_saver(args)
        self._restore_params_from_checkpoint(args.checkpoint_dir_for_init)
        self._debug_info_needed = args.log_level.upper() == "DEBUG"

    def _set_lr_scheduler(self, model_module, learning_rate_scheduler_arg):
        if learning_rate_scheduler_arg in model_module:
            self.lr_scheduler = add_lr_scheduler_to_optimizer(
                self.optimizer, model_module[learning_rate_scheduler_arg])
        else:
            self.lr_scheduler = None

    def _restore_params_from_checkpoint(self, checkpoint_dir_for_init):
        """Restore parameters from a checkpint directory for the PS instance
        """
        if not checkpoint_dir_for_init:
            self.logger.info("checkpoint directory for init is None")
            return

        if not CheckpointSaver.check_checkpoint_valid(checkpoint_dir_for_init):
            raise ValueError("Invalid checkpoint directory")

        self.parameters = CheckpointSaver.restore_params_from_checkpoint(
            checkpoint_dir_for_init, self.ps_id, self.num_ps_pods)
        self.parameters.init_status = True
        self.logger.info("The version of restored parameters is %d" %
                         self.parameters.version)

    def _init_checkpoint_saver(self, args):
        if all([args.checkpoint_dir, args.checkpoint_steps]):
            self.checkpoint_saver = CheckpointSaver(
                args.checkpoint_dir,
                args.checkpoint_steps,
                args.keep_checkpoint_max,
                include_evaluation=False,
            )
        else:
            self.checkpoint_saver = None
            self.logger.warning(
                "Invalid checkpoint config and no model will be saved")

    def prepare(self):
        max_workers = min(self.num_workers, 64)
        self.logger.info("The max threads in PS servers is %d" % max_workers)
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=max_workers),
            options=[
                ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
                (
                    "grpc.max_receive_message_length",
                    GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
                ),
            ],
        )
        pserver_servicer = PserverServicer(
            self.parameters,
            self.grads_to_wait,
            self.optimizer,
            self.lr_scheduler,
            lr_staleness_modulation=self.lr_staleness_modulation,
            sync_version_tolerance=self.sync_version_tolerance,
            use_async=self.use_async,
            evaluation_steps=self.evaluation_steps,
            master_channel=self.master_channel,
            checkpoint_saver=self.checkpoint_saver,
            ps_id=self.ps_id,
            num_ps_pods=self.num_ps_pods,
        )
        elasticdl_pb2_grpc.add_PserverServicer_to_server(
            pserver_servicer, server)
        server.add_insecure_port("[::]:{}".format(self.port))
        server.start()
        self.server = server
        self.logger.info("RPC Server started at port: %d", self.port)

    def run(self):
        config.load_incluster_config()
        api = client.CoreV1Api()
        try:
            while True:
                time.sleep(30)
                master_pod = api.read_namespaced_pod(namespace=self.namespace,
                                                     name=self.master_name)
                if master_pod.status.phase == PodStatus.SUCCEEDED:
                    self.logger.info("Master pod is Succeeded")
                    break
                elif master_pod.status.phase == PodStatus.FAILED:
                    self.logger.info("Master pod is Failed")
                    break
                elif (master_pod.status.phase == PodStatus.RUNNING
                      and master_pod.metadata.labels["status"]
                      == PodStatus.FINISHED):
                    self.logger.info(
                        "Task is finished, "
                        "master pod is still running tensorboard service")
                    break

                if self._debug_info_needed:
                    self.logger.debug("Parameters info:\n%s" %
                                      self.parameters.debug_info())
        except KeyboardInterrupt:
            self.logger.warning("Server stopping")

        self.server.stop(0)
        self.logger.info("RPC server stopped")