예제 #1
0
    def test_invalid_dispatcher_ids(self):
        """Check that out-of-range dispatcher ids are rejected with IndexError.

        The service is created with 2 dispatchers, and each client method
        that takes a dispatcher id is called with the ids -1 and 2.
        """
        secret_key = secret.make_secret_key()
        compute_service = ComputeService(2, 4, secret_key, nics=None)
        try:
            compute_client = ComputeClient(compute_service.addresses(),
                                           secret_key,
                                           verbose=2)

            # (client method, argument that follows the dispatcher id)
            calls_under_test = (
                (compute_client.register_dispatcher, 'grpc://localhost:10000'),
                (compute_client.wait_for_dispatcher_registration, 0.1),
                (compute_client.register_worker_for_dispatcher, 0),
                (compute_client.wait_for_dispatcher_worker_registration, 0.1),
            )
            for method, trailing_arg in calls_under_test:
                for out_of_range_id in (-1, 2):
                    with self.assertRaises(IndexError):
                        method(out_of_range_id, trailing_arg)
        finally:
            # always stop the service, even when an assertion fails
            compute_service.shutdown()
예제 #2
0
 def test_register_dispatcher_duplicate(self):
     """Check that re-registering a dispatcher id under a new address fails.

     Dispatcher 0 is first registered at port 10000; a second registration
     of the same id at port 10001 must raise ValueError with a message
     naming both addresses.
     """
     secret_key = secret.make_secret_key()
     compute_service = ComputeService(2, 1, secret_key, nics=None)
     try:
         compute_client = ComputeClient(compute_service.addresses(),
                                        secret_key,
                                        verbose=2)
         compute_client.register_dispatcher(0, 'grpc://localhost:10000')

         # assertRaisesRegex searches the exception text for this pattern
         expected_message = (
             'Dispatcher with id 0 has already been registered under '
             'different address grpc://localhost:10000: grpc://localhost:10001'
         )
         with self.assertRaisesRegex(ValueError, expected_message):
             compute_client.register_dispatcher(0, 'grpc://localhost:10001')
     finally:
         # always stop the service, even when an assertion fails
         compute_service.shutdown()
예제 #3
0
 def test_register_dispatcher_worker_timeout(self):
     """Check that waiting for workers that never register times out.

     No worker is registered for dispatcher 0, so the wait with a 0.1
     second timeout must raise TimeoutException with the expected message.
     """
     secret_key = secret.make_secret_key()
     compute_service = ComputeService(1, 1, secret_key, nics=None)
     try:
         compute_client = ComputeClient(compute_service.addresses(),
                                        secret_key,
                                        verbose=2)

         # used as a regular expression, so each '.' matches any character
         expected_message = (
             'Timed out waiting for workers for dispatcher 0 to register. '
             'Try to find out what takes the workers so long '
             'to register or increase timeout. Timeout after 0.1 seconds.'
         )
         with self.assertRaisesRegex(TimeoutException, expected_message):
             compute_client.wait_for_dispatcher_worker_registration(
                 0, timeout=0.1)
     finally:
         # always stop the service, even when an assertion fails
         compute_service.shutdown()
예제 #4
0
    def test_register_dispatcher_worker_replay(self):
        """Check that registering the same worker twice is not counted twice.

        The service is created with one dispatcher and two workers per
        dispatcher. Registering worker 0 twice must leave the registration
        incomplete; it completes only once worker 1 registers as well.
        """
        secret_key = secret.make_secret_key()
        compute_service = ComputeService(1, 2, secret_key, nics=None)
        try:
            compute_client = ComputeClient(compute_service.addresses(),
                                           secret_key,
                                           verbose=2)
            compute_client.register_dispatcher(0, 'grpc://localhost:10000')

            # worker 0 registers twice; the replayed registration must not
            # stand in for the missing second worker, so the wait times out
            for _ in range(2):
                compute_client.register_worker_for_dispatcher(0, 0)
            self.assertRaises(
                TimeoutException,
                compute_client.wait_for_dispatcher_worker_registration,
                0,
                timeout=2)

            # registering the second worker (id 1) completes the registration
            compute_client.register_worker_for_dispatcher(0, 1)
            compute_client.wait_for_dispatcher_worker_registration(0,
                                                                   timeout=2)
        finally:
            # always stop the service, even when an assertion fails
            compute_service.shutdown()
예제 #5
0
def main(dispatchers: int, dispatcher_side: str, configfile: str,
         timeout: int):
    """Start the compute service on rank 0 and run a compute worker on every rank.

    Rank 0 creates a ``ComputeService``, builds a ``TfDataServiceConfig``
    from it and writes that config to ``configfile``. The config is then
    broadcast to all ranks, and every rank calls ``compute_worker_fn`` with it.
    The service is shut down on rank 0 when the worker function returns or
    raises.

    Args:
        dispatchers: Number of dispatchers; must be positive and divide the
            number of Horovod processes evenly.
        dispatcher_side: Passed through to ``TfDataServiceConfig``.
        configfile: Path the config is written to on rank 0.
        timeout: Passed through to ``TfDataServiceConfig``.

    Raises:
        ValueError: If ``dispatchers`` is not positive, or if the number of
            processes is not a multiple of ``dispatchers``.
    """
    # reject zero / negative counts up front: zero would otherwise surface as
    # a ZeroDivisionError below and a negative count would silently yield a
    # negative number of workers per dispatcher
    if dispatchers < 1:
        raise ValueError(
            f'Number of dispatchers ({dispatchers}) must be a positive integer.'
        )

    hvd.init()
    rank, size = hvd.rank(), hvd.size()

    if size % dispatchers:
        raise ValueError(
            f'Number of processes ({size}) must be a multiple of number of dispatchers ({dispatchers}).'
        )
    workers_per_dispatcher = size // dispatchers

    # start the compute service on rank 0
    compute = None
    try:
        # stays None on every rank but 0 until the broadcast below
        compute_config = None

        if rank == 0:
            key = secret.make_secret_key()
            compute = ComputeService(dispatchers,
                                     workers_per_dispatcher,
                                     key=key)

            compute_config = TfDataServiceConfig(
                dispatchers=dispatchers,
                workers_per_dispatcher=workers_per_dispatcher,
                dispatcher_side=dispatcher_side,
                addresses=compute.addresses(),
                key=key,
                timeout=timeout)
            compute_config.write(configfile)

        # broadcast this config to all ranks via CPU ops
        with tf.device('/cpu:0'):
            compute_config = hvd.broadcast_object(compute_config,
                                                  name='TfDataServiceConfig')

        # start all compute workers
        compute_worker_fn(compute_config)
    finally:
        # only rank 0 created a service, so only rank 0 has one to stop
        if compute is not None:
            compute.shutdown()
예제 #6
0
                        help=f"Where do the dispatcher run? On 'compute' side or 'training' side.",
                        dest="dispatcher_side")

    parsed_args = parser.parse_args()

    spark = SparkSession.builder.getOrCreate()
    spark_context = spark.sparkContext
    workers = spark_context.defaultParallelism

    if workers % parsed_args.dispatchers:
        raise ValueError(f'Number of processes ({workers}) must be '
                         f'a multiple of number of dispatchers ({parsed_args.dispatchers}).')
    workers_per_dispatcher = workers // parsed_args.dispatchers

    key = secret.make_secret_key()
    compute = ComputeService(parsed_args.dispatchers, workers_per_dispatcher, key=key)

    compute_config = TfDataServiceConfig(
        dispatchers=parsed_args.dispatchers,
        workers_per_dispatcher=workers_per_dispatcher,
        dispatcher_side=parsed_args.dispatcher_side,
        addresses=compute.addresses(),
        key=key
    )
    compute_config.write(parsed_args.configfile)

    def _exit_gracefully(signum=None, frame=None):
        """Stop the Spark context when the driver receives SIGTERM.

        Python calls a handler installed with ``signal.signal`` with two
        positional arguments, the signal number and the current stack frame.
        Both are accepted and ignored; without them the handler itself would
        raise ``TypeError`` on SIGTERM. The defaults keep a direct call
        without arguments working.
        """
        logging.info('Spark driver receiving SIGTERM. Exiting gracefully')
        spark_context.stop()

    signal.signal(signal.SIGTERM, _exit_gracefully)
예제 #7
0
    def test_good_path(self):
        """Run the register / wait / shutdown cycle for several service sizes.

        For each (dispatchers, workers per dispatcher) combination, a service
        and client are created and the test checks that:

        * threads waiting for dispatcher registration finish once all
          dispatchers are registered, and report the registered addresses,
        * threads waiting for worker registration finish once all workers
          are registered, and report every dispatcher id,
        * the thread waiting for shutdown stays alive until the client
          requests shutdown, then finishes and reports ``True``.
        """
        service_sizes = [(1, 1), (1, 2), (1, 4),
                         (2, 1), (2, 2), (2, 4),
                         (32, 16), (1, 512)]
        for dispatchers_num, workers_per_dispatcher in service_sizes:
            with self.subTest(dispatchers=dispatchers_num,
                              workers_per_dispatcher=workers_per_dispatcher):
                secret_key = secret.make_secret_key()
                compute_service = ComputeService(dispatchers_num,
                                                 workers_per_dispatcher,
                                                 secret_key,
                                                 nics=None)
                try:
                    compute_client = ComputeClient(
                        compute_service.addresses(), secret_key, verbose=2)

                    # start the thread that waits for shutdown
                    shutdown_results = Queue()
                    shutdown_waiter = self.wait_for_shutdown(
                        compute_client, shutdown_results)

                    dispatcher_ids = range(dispatchers_num)
                    expected_dispatchers = [
                        (dispatcher_id,
                         f'grpc://localhost:{10000+dispatcher_id}')
                        for dispatcher_id in dispatcher_ids
                    ]

                    # --- dispatcher registration ---
                    # waiters are started first, then the dispatchers register
                    registered_dispatchers = Queue()
                    dispatcher_waiters = [
                        self.wait_for_dispatcher(compute_client,
                                                 dispatcher_id,
                                                 registered_dispatchers)
                        for dispatcher_id in dispatcher_ids
                    ]

                    for dispatcher_id, address in expected_dispatchers:
                        compute_client.register_dispatcher(
                            dispatcher_id, address)

                    # every waiter must finish within its 10 second join
                    for waiter in dispatcher_waiters:
                        waiter.join(10)
                        self.assertFalse(
                            waiter.is_alive(),
                            msg=
                            "threads waiting for dispatchers did not terminate"
                        )

                    # the waiters must have reported exactly the addresses
                    # that were registered
                    self.assertEqual(
                        expected_dispatchers,
                        sorted(self.get_all(registered_dispatchers)))

                    # --- worker registration ---
                    # waiters are started first, then the workers register
                    completed_dispatchers = Queue()
                    worker_waiters = [
                        self.wait_for_dispatcher_workers(
                            compute_client, dispatcher_id,
                            completed_dispatchers)
                        for dispatcher_id in dispatcher_ids
                    ]

                    # consecutive worker ids are assigned to each dispatcher
                    # in blocks of workers_per_dispatcher
                    total_workers = dispatchers_num * workers_per_dispatcher
                    for worker_id in range(total_workers):
                        compute_client.register_worker_for_dispatcher(
                            dispatcher_id=worker_id // workers_per_dispatcher,
                            worker_id=worker_id)

                    # every waiter must finish within its 10 second join
                    for waiter in worker_waiters:
                        waiter.join(10)
                        self.assertFalse(
                            waiter.is_alive(),
                            msg=
                            "threads waiting for dispatchers' workers did not terminate"
                        )

                    # every dispatcher id must have been reported
                    self.assertEqual(
                        list(dispatcher_ids),
                        sorted(self.get_all(completed_dispatchers)))

                    # --- shutdown ---
                    # the waiter must still be blocked before shutdown is
                    # requested, and must finish afterwards
                    self.assertTrue(
                        shutdown_waiter.is_alive(),
                        msg="thread waiting for shutdown, terminated early")
                    compute_client.shutdown()
                    shutdown_waiter.join(10)
                    self.assertFalse(
                        shutdown_waiter.is_alive(),
                        msg="thread waiting for shutdown did not terminate")
                    self.assertEqual([True],
                                     list(self.get_all(shutdown_results)))
                finally:
                    # always stop the service, even when an assertion fails
                    compute_service.shutdown()