def test_restore_parameters_from_checkpoint(self):
    """Two PS shards each restore their own slice of a saved checkpoint.

    Both parameter servers (``num_ps_pods=2``) initialize from the same
    version-100 functional checkpoint. Shard 0 is expected to own embedding
    ids 0 and 2 and ``dense/kernel:0``. Shard 1 is expected to own ids 1 and
    3 and ``dense/bias:0``. Both should report version 100.
    """
    checkpoint_dir_for_init = (
        "elasticdl/python/tests/testdata/functional_ckpt/version-100"
    )
    # Each row: (ps_id, owned embedding ids, dense var name, dense value).
    shard_expectations = (
        (0, [0, 2], "dense/kernel:0", np.array([[1], [1]], dtype=int)),
        (1, [1, 3], "dense/bias:0", np.array([1], dtype=int)),
    )
    for shard_id, owned_ids, var_name, var_value in shard_expectations:
        shard = ParameterServer(
            PserverArgs(
                ps_id=shard_id,
                num_ps_pods=2,
                model_zoo=_test_model_zoo_path,
                model_def="test_module.custom_model",
                checkpoint_dir_for_init=checkpoint_dir_for_init,
                lr_scheduler="learning_rate_scheduler",
            )
        )
        restored = shard.parameters
        table = restored.embedding_params["embedding"]
        self.assertEqual(list(table.embedding_vectors.keys()), owned_ids)

        dense = restored.non_embedding_params
        self.assertEqual(list(dense.keys()), [var_name])
        self.assertTrue(np.array_equal(dense[var_name].numpy(), var_value))

        self.assertEqual(restored.version, 100)
def _create_pserver_and_channel(self, ports):
    """Prepare one synchronous parameter server per port and open a channel.

    Args:
        ports: Local TCP ports. For each one a ``ParameterServer`` is built
            (``grads_to_wait=1``, ``use_async=False``) and prepared, and an
            insecure gRPC channel to ``localhost:<port>`` is opened.

    Returns:
        Tuple ``(pservers, channels)``: two lists in the same order as
        ``ports``.
    """
    servers, conns = [], []
    for port in ports:
        server = ParameterServer(
            PserverArgs(
                grads_to_wait=1,
                use_async=False,
                port=port,
                model_zoo=self._model_zoo_path,
                model_def=self._model_def,
            )
        )
        server.prepare()
        servers.append(server)

        target = "localhost:%d" % port
        # Apply the project's GRPC size limits to this channel.
        size_limits = [
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ]
        conns.append(grpc.insecure_channel(target, options=size_limits))
    return servers, conns
def create_server_and_stub(
    self, grads_to_wait, lr_staleness_modulation, use_async, **kwargs
):
    """Build and prepare a parameter server and record handles on ``self``.

    The server uses ``self._port`` and the test model zoo's
    ``test_module.custom_model``. Any extra ``kwargs`` are forwarded to
    ``PserverArgs``. Afterwards ``self._parameters`` and ``self._server``
    point at the server's state. ``self._stub`` is a ``PserverStub`` over the
    pre-existing ``self._channel``, and ``self._lr`` is set to 0.1.
    """
    ps_args = PserverArgs(
        grads_to_wait=grads_to_wait,
        lr_staleness_modulation=lr_staleness_modulation,
        use_async=use_async,
        port=self._port,
        model_zoo=_test_model_zoo_path,
        model_def="test_module.custom_model",
        **kwargs
    )
    ps = ParameterServer(ps_args)
    ps.prepare()

    self._parameters = ps.parameters
    self._server = ps.server
    self._stub = elasticdl_pb2_grpc.PserverStub(self._channel)
    self._lr = 0.1
def _create_pserver(self, model_def, num):
    """Start ``num`` asynchronous parameter servers on consecutive ports.

    Ports start at 12345 and are stored in ``self._ports``. A channel to
    each port (from ``build_channel``) is appended to ``self._channels``
    first. Then ``model_def`` is stored, and each server is built
    (``grads_to_wait=1``, ``use_async=True``), prepared, and appended to
    ``self._pservers``.
    """
    first_port = 12345
    self._ports = list(range(first_port, first_port + num))

    for p in self._ports:
        self._channels.append(build_channel("localhost:%d" % p))

    self._model_def = model_def

    for p in self._ports:
        server = ParameterServer(
            PserverArgs(
                grads_to_wait=1,
                use_async=True,
                port=p,
                model_zoo=self._model_zoo_path,
                model_def=self._model_def,
            )
        )
        server.prepare()
        self._pservers.append(server)
def test_restore_parameters_from_checkpoint(self):
    """Save PS parameters as a checkpoint, then restore them into two shards.

    Builds a ``Parameters`` object with a 4-row ``embedding`` table and two
    dense variables, then saves it with ``CheckpointSaver`` as version 100.
    Next it constructs two ``ParameterServer`` instances (``num_ps_pods=2``)
    that initialize from that checkpoint. Shard 0 is expected to hold
    embedding ids 0 and 2 and ``dense/kernel:0``. Shard 1 is expected to hold
    ids 1 and 3 and ``dense/bias:0``. Both should report version 100.

    The checkpoint goes into a temporary directory instead of the
    repository's ``testdata`` tree. That way the test leaves no files behind
    and cannot be influenced by output left over from an earlier run.
    """
    # Local imports keep this block self-contained.
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_dir = os.path.join(tmp_dir, "ps_ckpt")
        checkpoint_saver = CheckpointSaver(checkpoint_dir, 0, 0, False)

        params = Parameters()
        table = EmbeddingTable("embedding", 2, "random_uniform")
        table.set([0, 1, 2, 3], np.ones((4, 2), dtype=np.float32))
        params.embedding_params["embedding"] = table
        params.non_embedding_params["dense/kernel:0"] = tf.Variable(
            [[1.0], [1.0]]
        )
        params.non_embedding_params["dense/bias:0"] = tf.Variable([1.0])
        params.version = 100
        model_pb = params.to_model_pb()
        checkpoint_saver.save(100, model_pb, False)

        checkpoint_dir_for_init = os.path.join(checkpoint_dir, "version-100")

        # Each row: (ps_id, owned embedding ids, dense var name, value).
        expected_shards = (
            (0, [0, 2], "dense/kernel:0", np.array([[1], [1]], dtype=int)),
            (1, [1, 3], "dense/bias:0", np.array([1], dtype=int)),
        )
        # Servers are constructed while the temp dir still exists, because
        # they read the checkpoint during construction.
        for ps_id, owned_ids, var_name, var_value in expected_shards:
            args = PserverArgs(
                ps_id=ps_id,
                num_ps_pods=2,
                model_zoo=_test_model_zoo_path,
                model_def="test_module.custom_model",
                checkpoint_dir_for_init=checkpoint_dir_for_init,
            )
            pserver = ParameterServer(args)
            restored = pserver.parameters

            embedding_table = restored.embedding_params["embedding"]
            self.assertEqual(
                list(embedding_table.embedding_vectors.keys()), owned_ids
            )
            self.assertEqual(
                list(restored.non_embedding_params.keys()), [var_name]
            )
            self.assertTrue(
                np.array_equal(
                    restored.non_embedding_params[var_name].numpy(),
                    var_value,
                )
            )
            self.assertEqual(restored.version, 100)