コード例 #1
0
    def test_save_parameters_to_checkpoint_file(self):
        """Checkpoint every 5 versions and retain only the 3 newest.

        Fills a ``PserverServicer``'s parameters with two dense variables
        and one embedding table, then increments the parameter version 100
        times, calling the checkpoint hook after each increment. Afterwards
        only ``version-90``, ``version-95`` and ``version-100`` should exist
        in the checkpoint directory, and ``version-100`` should contain
        exactly one file, ``variables-0-of-1.ckpt``.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            checkpoint_saver = CheckpointSaver(
                checkpoint_dir=os.path.join(tempdir, "ckpt/"),
                checkpoint_steps=5,
                keep_checkpoint_max=3,
                include_evaluation=False,
            )
            pserver_servicer = PserverServicer(
                parameters=Parameters(),
                grads_to_wait=0,
                optimizer="optimizer",
                checkpoint_saver=checkpoint_saver,
                ps_id=0,
                num_ps_pods=1,
            )
            model_params = {
                "v0": tf.Variable([[1, 1, 1], [1, 1, 1]]),
                "v1": tf.Variable([[2, 2, 2], [2, 2, 2]]),
            }

            server_params = pserver_servicer._parameters
            for var_name, var_value in model_params.items():
                server_params.non_embedding_params[var_name] = var_value

            embedding_table = EmbeddingTable(
                name="embedding_0", dim=3, initializer="random_uniform"
            )
            server_params.embedding_params["embedding_0"] = embedding_table
            server_params.set_embedding_param(
                name="embedding_0",
                indices=np.array([0, 1]),
                values=np.array([[1, 1, 1], [2, 2, 2]]),
            )

            # The loop index itself is not needed: each iteration advances the
            # version by one and lets the servicer decide whether to save.
            for _ in range(100):
                pserver_servicer._parameters.version += 1
                pserver_servicer._save_params_to_checkpoint_if_needed()

            ckpt_dir = checkpoint_saver._directory
            self.assertEqual(len(os.listdir(ckpt_dir)), 3)
            # sorted() compares strings, so "version-100" sorts first.
            self.assertEqual(
                sorted(os.listdir(ckpt_dir)),
                ["version-100", "version-90", "version-95"],
            )
            self.assertEqual(
                os.listdir(os.path.join(ckpt_dir, "version-100")),
                ["variables-0-of-1.ckpt"],
            )
コード例 #2
0
 def prepare(self):
     """Build and start the parameter-server gRPC server, keeping a handle.

     The server's thread pool is sized to ``self.num_workers``, capped at 64
     and floored at 1. The server uses the send/receive message-size limits
     from ``GRPC``. It serves a ``PserverServicer`` built from this
     instance's attributes and listens without TLS on ``[::]:<self.port>``.
     The started server is stored on ``self.server``.
     """
     # Cap the pool at 64 threads. Also keep at least one thread, because
     # ThreadPoolExecutor raises ValueError when max_workers <= 0.
     max_workers = max(1, min(self.num_workers, 64))
     # Hand the argument to the logger rather than %-formatting eagerly, so
     # the message is only built when INFO logging is enabled.
     self.logger.info("The max threads in PS servers is %d", max_workers)
     server = grpc.server(
         futures.ThreadPoolExecutor(max_workers=max_workers),
         options=[
             ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
             (
                 "grpc.max_receive_message_length",
                 GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
             ),
         ],
     )
     pserver_servicer = PserverServicer(
         self.parameters,
         self.grads_to_wait,
         self.optimizer,
         self.lr_scheduler,
         lr_staleness_modulation=self.lr_staleness_modulation,
         sync_version_tolerance=self.sync_version_tolerance,
         use_async=self.use_async,
         evaluation_steps=self.evaluation_steps,
         master_channel=self.master_channel,
         checkpoint_saver=self.checkpoint_saver,
         ps_id=self.ps_id,
         num_ps_pods=self.num_ps_pods,
     )
     elasticdl_pb2_grpc.add_PserverServicer_to_server(
         pserver_servicer, server)
     server.add_insecure_port("[::]:{}".format(self.port))
     server.start()
     self.server = server
     self.logger.info("RPC Server started at port: %d", self.port)
コード例 #3
0
 def prepare(self):
     """Create the parameter-server gRPC server, start it, and keep a handle.

     The server runs on a pool of 64 threads and applies the message-size
     limits from ``GRPC`` to both sending and receiving. It serves a
     ``PserverServicer`` assembled from this instance's settings and
     listens without TLS on ``[::]:<self.port>``. The running server is
     stored on ``self.server``.
     """
     thread_pool = futures.ThreadPoolExecutor(max_workers=64)
     # Apply the project's send/receive message-size limits to the server.
     channel_options = [
         ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
         ("grpc.max_receive_message_length", GRPC.MAX_RECEIVE_MESSAGE_LENGTH),
     ]
     rpc_server = grpc.server(thread_pool, options=channel_options)

     servicer = PserverServicer(
         self.parameters,
         self.grads_to_wait,
         self.optimizer,
         lr_staleness_modulation=self.lr_staleness_modulation,
         use_async=self.use_async,
     )
     elasticdl_pb2_grpc.add_PserverServicer_to_server(servicer, rpc_server)

     listen_address = "[::]:{}".format(self.port)
     rpc_server.add_insecure_port(listen_address)
     rpc_server.start()

     self.server = rpc_server
     self.logger.info("RPC Server started at port: %d", self.port)