def test_invalid_dispatcher_ids(self):
    """Every client call taking a dispatcher id must raise IndexError for out-of-range ids."""
    key = secret.make_secret_key()
    service = ComputeService(2, 4, key, nics=None)
    try:
        client = ComputeClient(service.addresses(), key, verbose=2)
        # Each entry exercises one client method; -1 is below range and 2 equals
        # the number of dispatchers (valid ids are 0 and 1).
        calls = [
            lambda i: client.register_dispatcher(i, 'grpc://localhost:10000'),
            lambda i: client.wait_for_dispatcher_registration(i, 0.1),
            lambda i: client.register_worker_for_dispatcher(i, 0),
            lambda i: client.wait_for_dispatcher_worker_registration(i, 0.1),
        ]
        for call in calls:
            for invalid_id in (-1, 2):
                with self.assertRaises(IndexError):
                    call(invalid_id)
    finally:
        service.shutdown()
def test_register_dispatcher_duplicate(self):
    """Registering the same dispatcher id under a second address must be rejected."""
    key = secret.make_secret_key()
    service = ComputeService(2, 1, key, nics=None)
    try:
        client = ComputeClient(service.addresses(), key, verbose=2)
        client.register_dispatcher(0, 'grpc://localhost:10000')
        expected_message = (
            'Dispatcher with id 0 has already been registered under '
            'different address grpc://localhost:10000: grpc://localhost:10001')
        with self.assertRaisesRegex(ValueError, expected_message):
            client.register_dispatcher(0, 'grpc://localhost:10001')
    finally:
        service.shutdown()
def test_register_dispatcher_worker_timeout(self):
    """Waiting for worker registration must time out when no worker ever registers."""
    key = secret.make_secret_key()
    service = ComputeService(1, 1, key, nics=None)
    try:
        client = ComputeClient(service.addresses(), key, verbose=2)
        timeout_message = (
            'Timed out waiting for workers for dispatcher 0 to register. '
            'Try to find out what takes the workers so long '
            'to register or increase timeout. Timeout after 0.1 seconds.')
        with self.assertRaisesRegex(TimeoutException, expected_regex=timeout_message):
            client.wait_for_dispatcher_worker_registration(0, timeout=0.1)
    finally:
        service.shutdown()
def test_register_dispatcher_worker_replay(self):
    """Replayed worker registrations must not be double-counted."""
    key = secret.make_secret_key()
    service = ComputeService(1, 2, key, nics=None)
    try:
        client = ComputeClient(service.addresses(), key, verbose=2)
        client.register_dispatcher(0, 'grpc://localhost:10000')
        client.register_worker_for_dispatcher(0, 0)
        # Re-registering worker 0 is a replay and must not count as a
        # second worker, so registration stays incomplete and waiting times out.
        client.register_worker_for_dispatcher(0, 0)
        with self.assertRaises(TimeoutException):
            client.wait_for_dispatcher_worker_registration(0, timeout=2)
        # Only a distinct second worker completes the registration.
        client.register_worker_for_dispatcher(0, 1)
        client.wait_for_dispatcher_worker_registration(0, timeout=2)
    finally:
        service.shutdown()
def main(dispatchers: int, dispatcher_side: str, configfile: str, timeout: int):
    """Run the compute service and its workers across all Horovod ranks.

    Rank 0 starts the ComputeService and writes the service config to
    ``configfile``; the config is then broadcast to all ranks, and every
    rank runs a compute worker against it.

    Args:
        dispatchers: number of tf.data service dispatchers.
        dispatcher_side: where dispatchers run (passed through to the config).
        configfile: path where rank 0 writes the TfDataServiceConfig.
        timeout: timeout (seconds) stored in the config.

    Raises:
        ValueError: if the world size is not a multiple of ``dispatchers``.
    """
    hvd.init()
    rank, size = hvd.rank(), hvd.size()
    if size % dispatchers:
        raise ValueError(
            f'Number of processes ({size}) must be a multiple of number of dispatchers ({dispatchers}).')
    workers_per_dispatcher = size // dispatchers

    # start the compute service on rank 0
    compute = None
    try:
        compute_config = None
        if rank == 0:
            key = secret.make_secret_key()
            compute = ComputeService(dispatchers, workers_per_dispatcher, key=key)

            compute_config = TfDataServiceConfig(
                dispatchers=dispatchers,
                workers_per_dispatcher=workers_per_dispatcher,
                dispatcher_side=dispatcher_side,
                addresses=compute.addresses(),
                key=key,
                timeout=timeout)
            compute_config.write(configfile)

        # broadcast this config to all ranks via CPU ops
        # (plain string literal here — the original f'/cpu:0' had no placeholders)
        with tf.device('/cpu:0'):
            compute_config = hvd.broadcast_object(compute_config, name='TfDataServiceConfig')

        # start all compute workers
        compute_worker_fn(compute_config)
    finally:
        # only rank 0 owns a service instance; shut it down even on error
        if compute is not None:
            compute.shutdown()
spark_context = spark.sparkContext workers = spark_context.defaultParallelism if workers % parsed_args.dispatchers: raise ValueError(f'Number of processes ({workers}) must be ' f'a multiple of number of dispatchers ({parsed_args.dispatchers}).') workers_per_dispatcher = workers // parsed_args.dispatchers key = secret.make_secret_key() compute = ComputeService(parsed_args.dispatchers, workers_per_dispatcher, key=key) compute_config = TfDataServiceConfig( dispatchers=parsed_args.dispatchers, workers_per_dispatcher=workers_per_dispatcher, dispatcher_side=parsed_args.dispatcher_side, addresses=compute.addresses(), key=key ) compute_config.write(parsed_args.configfile) def _exit_gracefully(): logging.info('Spark driver receiving SIGTERM. Exiting gracefully') spark_context.stop() signal.signal(signal.SIGTERM, _exit_gracefully) ret = run(compute_worker_fn, args=(compute_config,), stdout=sys.stdout, stderr=sys.stderr, num_proc=workers,
def test_good_path(self):
    """Full registration/shutdown flow succeeds for a range of service sizes."""
    sizes = [(1, 1), (1, 2), (1, 4), (2, 1), (2, 2), (2, 4), (32, 16), (1, 512)]
    for dispatchers_num, workers_per_dispatcher in sizes:
        with self.subTest(dispatchers=dispatchers_num,
                          workers_per_dispatcher=workers_per_dispatcher):
            key = secret.make_secret_key()
            service = ComputeService(dispatchers_num, workers_per_dispatcher, key, nics=None)
            try:
                client = ComputeClient(service.addresses(), key, verbose=2)

                # thread that blocks until the service announces shutdown
                shutdown = Queue()
                shutdown_thread = self.wait_for_shutdown(client, shutdown)

                # --- dispatcher registration ---
                # start threads that wait for each dispatcher to appear
                dispatchers = Queue()
                threads = [self.wait_for_dispatcher(client, idx, dispatchers)
                           for idx in range(dispatchers_num)]

                # register all dispatchers
                for idx in range(dispatchers_num):
                    client.register_dispatcher(idx, f'grpc://localhost:{10000 + idx}')

                # the waiting threads must now terminate
                for thread in threads:
                    thread.join(10)
                    self.assertFalse(
                        thread.is_alive(),
                        msg="threads waiting for dispatchers did not terminate")

                # the threads must have observed the registered addresses
                self.assertEqual(
                    [(idx, f'grpc://localhost:{10000 + idx}')
                     for idx in range(dispatchers_num)],
                    sorted(self.get_all(dispatchers)))

                # --- worker registration ---
                # start threads that wait for each dispatcher's workers
                dispatchers = Queue()
                threads = [self.wait_for_dispatcher_workers(client, idx, dispatchers)
                           for idx in range(dispatchers_num)]

                # register every worker with its dispatcher
                for worker in range(dispatchers_num * workers_per_dispatcher):
                    client.register_worker_for_dispatcher(
                        dispatcher_id=worker // workers_per_dispatcher,
                        worker_id=worker)

                # the waiting threads must now terminate
                for thread in threads:
                    thread.join(10)
                    self.assertFalse(
                        thread.is_alive(),
                        msg="threads waiting for dispatchers' workers did not terminate")

                # each dispatcher must have reported completed worker registration
                self.assertEqual(sorted(range(dispatchers_num)),
                                 sorted(self.get_all(dispatchers)))

                # --- shutdown ---
                self.assertTrue(
                    shutdown_thread.is_alive(),
                    msg="thread waiting for shutdown terminated early")
                client.shutdown()
                shutdown_thread.join(10)
                self.assertFalse(
                    shutdown_thread.is_alive(),
                    msg="thread waiting for shutdown did not terminate")
                self.assertEqual([True], list(self.get_all(shutdown)))
            finally:
                service.shutdown()