def test_save_parameters_to_checkpoint_file(self):
    """Check periodic checkpointing and retention on the parameter server.

    A CheckpointSaver configured with ``checkpoint_steps=5`` and
    ``keep_checkpoint_max=3`` is attached to a PserverServicer. The
    parameter version is bumped 100 times, asking the servicer to save
    after each bump. Afterwards exactly three version directories
    (90, 95, 100) are expected on disk. The newest one should contain a
    single ``variables-0-of-1.ckpt`` entry.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        checkpoint_saver = CheckpointSaver(
            checkpoint_dir=os.path.join(tempdir, "ckpt/"),
            checkpoint_steps=5,
            keep_checkpoint_max=3,
            include_evaluation=False,
        )
        pserver_servicer = PserverServicer(
            parameters=Parameters(),
            grads_to_wait=0,
            optimizer="optimizer",
            checkpoint_saver=checkpoint_saver,
            ps_id=0,
            num_ps_pods=1,
        )
        model_params = {
            "v0": tf.Variable([[1, 1, 1], [1, 1, 1]]),
            "v1": tf.Variable([[2, 2, 2], [2, 2, 2]]),
        }

        # Populate both dense (non-embedding) parameters and one embedding
        # table with two rows before any checkpoint is taken.
        server_params = pserver_servicer._parameters
        for var_name, var_value in model_params.items():
            server_params.non_embedding_params[var_name] = var_value
        embedding_table = EmbeddingTable(
            name="embedding_0", dim=3, initializer="random_uniform"
        )
        server_params.embedding_params["embedding_0"] = embedding_table
        server_params.set_embedding_param(
            name="embedding_0",
            indices=np.array([0, 1]),
            values=np.array([[1, 1, 1], [2, 2, 2]]),
        )

        for _ in range(100):
            pserver_servicer._parameters.version += 1
            pserver_servicer._save_params_to_checkpoint_if_needed()

        ckpt_dir = checkpoint_saver._directory
        self.assertEqual(len(os.listdir(ckpt_dir)), 3)
        # sorted() is lexicographic, so "version-100" precedes "version-90".
        self.assertEqual(
            sorted(os.listdir(ckpt_dir)),
            ["version-100", "version-90", "version-95"],
        )
        self.assertEqual(
            os.listdir(os.path.join(ckpt_dir, "version-100")),
            ["variables-0-of-1.ckpt"],
        )
def prepare(self):
    """Create, register and start the parameter server's gRPC server.

    Steps:
      * Build a gRPC server on a thread pool of ``min(self.num_workers, 64)``
        threads, clamped to at least one.
      * Apply the configured ``GRPC`` send/receive message-size limits.
      * Register a ``PserverServicer`` built from this object's parameters,
        optimizer, learning-rate, evaluation and checkpoint settings.
      * Add an insecure port on ``[::]:<self.port>`` and start the server.
      * Keep the server on ``self.server``.
    """
    # At most 64 handler threads. Clamp to 1 because ThreadPoolExecutor
    # raises ValueError when max_workers <= 0 (e.g. num_workers == 0).
    max_workers = max(1, min(self.num_workers, 64))
    self.logger.info("The max threads in PS servers is %d", max_workers)
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=max_workers),
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )
    pserver_servicer = PserverServicer(
        self.parameters,
        self.grads_to_wait,
        self.optimizer,
        self.lr_scheduler,
        lr_staleness_modulation=self.lr_staleness_modulation,
        sync_version_tolerance=self.sync_version_tolerance,
        use_async=self.use_async,
        evaluation_steps=self.evaluation_steps,
        master_channel=self.master_channel,
        checkpoint_saver=self.checkpoint_saver,
        ps_id=self.ps_id,
        num_ps_pods=self.num_ps_pods,
    )
    elasticdl_pb2_grpc.add_PserverServicer_to_server(pserver_servicer, server)
    server.add_insecure_port("[::]:{}".format(self.port))
    server.start()
    self.server = server
    self.logger.info("RPC Server started at port: %d", self.port)
def prepare(self):
    """Build and launch the gRPC server that exposes the PS servicer.

    The server runs on a fixed pool of 64 handler threads and uses the
    message-size limits from ``GRPC``. It serves a ``PserverServicer`` over
    this object's parameters and optimizer, listening on
    ``[::]:<self.port>``. Once started, the server is kept on ``self.server``.
    """
    executor = futures.ThreadPoolExecutor(max_workers=64)
    # Send/receive message-size caps taken from the project's GRPC settings.
    channel_options = [
        ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
        ("grpc.max_receive_message_length", GRPC.MAX_RECEIVE_MESSAGE_LENGTH),
    ]
    rpc_server = grpc.server(executor, options=channel_options)

    servicer = PserverServicer(
        self.parameters,
        self.grads_to_wait,
        self.optimizer,
        lr_staleness_modulation=self.lr_staleness_modulation,
        use_async=self.use_async,
    )
    elasticdl_pb2_grpc.add_PserverServicer_to_server(servicer, rpc_server)

    listen_address = "[::]:{}".format(self.port)
    rpc_server.add_insecure_port(listen_address)
    rpc_server.start()
    self.server = rpc_server
    self.logger.info("RPC Server started at port: %d", self.port)