コード例 #1
0
def create_pserver(model_zoo_path, model_def, grads_to_wait, use_async,
                   num_ps_pods, base_port=12345):
    """Start ``num_ps_pods`` local parameter servers and connect to them.

    Server ``i`` is configured to listen on ``localhost:(base_port + i)``.

    Args:
        model_zoo_path: Model zoo directory forwarded to every ``PserverArgs``.
        model_def: Model definition name forwarded to every ``PserverArgs``.
        grads_to_wait: Forwarded as ``PserverArgs.grads_to_wait``.
        use_async: Forwarded as ``PserverArgs.use_async``.
        num_ps_pods: Number of parameter servers to start.
        base_port: Port of the first server (default 12345, the value that
            used to be hard-coded).

    Returns:
        A ``(ports, channels, pservers)`` tuple of lists, index-aligned so
        that ``channels[i]`` talks to ``pservers[i]`` on ``ports[i]``.
    """
    ports = [base_port + i for i in range(num_ps_pods)]
    channels = [build_channel("localhost:%d" % port) for port in ports]

    pservers = []
    for port in ports:
        args = PserverArgs(
            grads_to_wait=grads_to_wait,
            # Honor the caller's choice; this used to be hard-coded to True,
            # which silently ignored ``use_async=False``.
            use_async=use_async,
            port=port,
            model_zoo=model_zoo_path,
            model_def=model_def,
        )
        pserver = ParameterServer(args)
        pserver.prepare()
        pservers.append(pserver)

    # Block until every channel reports ready, so callers can issue RPCs
    # immediately after this function returns.
    for channel in channels:
        grpc.channel_ready_future(channel).result()

    return ports, channels, pservers
コード例 #2
0
    def _create_pserver_and_channel(self, ports):
        """Start one synchronous parameter server per port and open a channel.

        Each server is built with ``grads_to_wait=1`` and ``use_async=False``
        from this test's model zoo / model definition. Each channel is an
        insecure gRPC channel to ``localhost:<port>`` with its send/receive
        message size limits taken from ``GRPC``.

        Returns:
            ``(pservers, channels)``: two lists in the same order as ``ports``.
        """
        servers, chans = [], []
        for port in ports:
            server = ParameterServer(
                PserverArgs(
                    grads_to_wait=1,
                    use_async=False,
                    port=port,
                    model_zoo=self._model_zoo_path,
                    model_def=self._model_def,
                )
            )
            server.prepare()
            servers.append(server)

            size_limits = [
                ("grpc.max_send_message_length",
                 GRPC.MAX_SEND_MESSAGE_LENGTH),
                ("grpc.max_receive_message_length",
                 GRPC.MAX_RECEIVE_MESSAGE_LENGTH),
            ]
            chans.append(
                grpc.insecure_channel("localhost:%d" % port,
                                      options=size_limits)
            )
        return servers, chans
コード例 #3
0
    def test_restore_parameters_from_checkpoint(self):
        """Restore the checked-in version-100 checkpoint into two PS shards.

        For each ``ps_id`` of a two-pod setup, builds a ``ParameterServer``
        from the ``functional_ckpt/version-100`` fixture. It then checks which
        embedding ids and which dense variable that shard holds, the
        variable's value, and the restored model version.
        """
        ckpt_dir = (
            "elasticdl/python/tests/testdata/functional_ckpt/version-100")
        # (ps_id, embedding ids, dense variable name, expected value)
        cases = (
            (0, [0, 2], "dense/kernel:0", [[1], [1]]),
            (1, [1, 3], "dense/bias:0", [1]),
        )
        for ps_id, want_ids, var_name, want_value in cases:
            server = ParameterServer(
                PserverArgs(
                    ps_id=ps_id,
                    num_ps_pods=2,
                    model_zoo=_test_model_zoo_path,
                    model_def="test_module.custom_model",
                    checkpoint_dir_for_init=ckpt_dir,
                    lr_scheduler="learning_rate_scheduler",
                ))
            restored = server.parameters

            table = restored.embedding_params["embedding"]
            self.assertEqual(list(table.embedding_vectors.keys()), want_ids)
            self.assertEqual(list(restored.non_embedding_params.keys()),
                             [var_name])
            self.assertTrue(
                np.array_equal(
                    restored.non_embedding_params[var_name].numpy(),
                    np.array(want_value, dtype=int),
                ))
            self.assertEqual(restored.version, 100)
コード例 #4
0
    def create_server_and_stub(self, grads_to_wait, lr_staleness_modulation,
                               use_async, **kwargs):
        """Prepare a parameter server for this test and a stub to reach it.

        Extra ``kwargs`` are passed straight through to ``PserverArgs``.
        Stores the server's ``parameters`` and ``server`` on ``self`` and
        builds a ``PserverStub`` over ``self._channel``. ``self._channel``
        and ``self._port`` are assumed to be set up elsewhere (e.g. in
        ``setUp``); confirm. Finally sets ``self._lr`` to 0.1.
        """
        server_args = PserverArgs(
            grads_to_wait=grads_to_wait,
            lr_staleness_modulation=lr_staleness_modulation,
            use_async=use_async,
            port=self._port,
            model_zoo=_test_model_zoo_path,
            model_def="test_module.custom_model",
            **kwargs
        )
        ps = ParameterServer(server_args)
        ps.prepare()

        self._parameters = ps.parameters
        self._server = ps.server
        self._stub = elasticdl_pb2_grpc.PserverStub(self._channel)

        self._lr = 0.1
コード例 #5
0
    def _create_pserver(self, model_def, num):
        """Open channels to, then start, ``num`` async local parameter servers.

        Ports are 12345 .. 12345 + num - 1 and are recorded in ``self._ports``.
        All channels are built with ``build_channel`` before any server is
        prepared. Channels and servers are appended to ``self._channels`` and
        ``self._pservers``, which must already exist (presumably initialised
        in ``setUp``; confirm). Unlike the module-level ``create_pserver``,
        this does not wait for the channels to become ready.
        """
        self._ports = list(range(12345, 12345 + num))
        for address in ("localhost:%d" % p for p in self._ports):
            self._channels.append(build_channel(address))

        self._model_def = model_def
        for port in self._ports:
            ps = ParameterServer(
                PserverArgs(
                    grads_to_wait=1,
                    use_async=True,
                    port=port,
                    model_zoo=self._model_zoo_path,
                    model_def=self._model_def,
                )
            )
            ps.prepare()
            self._pservers.append(ps)
コード例 #6
0
    def test_restore_parameters_from_checkpoint(self):
        """Save a version-100 checkpoint, then restore it into two PS shards.

        The checkpoint holds a 4-row embedding table (all ones) plus the
        ``dense/kernel:0`` and ``dense/bias:0`` variables. It is restored into
        two parameter servers (``num_ps_pods=2``), and the test checks which
        embedding ids and which dense variable each shard receives, the
        variable's value, and the restored model version.
        """
        # Local imports keep this edit self-contained. The checkpoint is
        # written to a throwaway directory rather than into the source tree,
        # so every run starts from a clean state and leaves no files behind.
        import os
        import tempfile

        with tempfile.TemporaryDirectory() as checkpoint_dir:
            checkpoint_saver = CheckpointSaver(checkpoint_dir, 0, 0, False)
            params = Parameters()
            table = EmbeddingTable("embedding", 2, "random_uniform")
            table.set([0, 1, 2, 3], np.ones((4, 2), dtype=np.float32))
            params.embedding_params["embedding"] = table
            params.non_embedding_params["dense/kernel:0"] = tf.Variable(
                [[1.0], [1.0]]
            )
            params.non_embedding_params["dense/bias:0"] = tf.Variable([1.0])
            params.version = 100
            model_pb = params.to_model_pb()
            checkpoint_saver.save(100, model_pb, False)

            checkpoint_dir_for_init = os.path.join(
                checkpoint_dir, "version-100"
            )
            # (ps_id, expected embedding ids, expected dense variable name,
            #  expected variable value)
            expectations = (
                (0, [0, 2], "dense/kernel:0", np.array([[1], [1]], dtype=int)),
                (1, [1, 3], "dense/bias:0", np.array([1], dtype=int)),
            )
            for ps_id, ids, var_name, value in expectations:
                args = PserverArgs(
                    ps_id=ps_id,
                    num_ps_pods=2,
                    model_zoo=_test_model_zoo_path,
                    model_def="test_module.custom_model",
                    checkpoint_dir_for_init=checkpoint_dir_for_init,
                )
                pserver = ParameterServer(args)

                embedding_table = pserver.parameters.embedding_params[
                    "embedding"
                ]
                self.assertEqual(
                    list(embedding_table.embedding_vectors.keys()), ids
                )
                self.assertEqual(
                    list(pserver.parameters.non_embedding_params.keys()),
                    [var_name],
                )
                self.assertTrue(
                    np.array_equal(
                        pserver.parameters.non_embedding_params[
                            var_name
                        ].numpy(),
                        value,
                    )
                )
                self.assertEqual(pserver.parameters.version, 100)
コード例 #7
0
def main():
    """Entry point: build a ParameterServer from ``parse_ps_args()`` and run it.

    The server is prepared before ``run`` is called.
    """
    server = ParameterServer(parse_ps_args())
    server.prepare()
    server.run()