async def test_worker():
    """Drive a single distributed ``worker`` through a full job cycle.

    A listening ``Pair1`` socket stands in for the manager. The test checks
    that the worker announces itself with ``b"Ready"``, replies ``b"Ready"``
    again after each of two dummy work items, and then, after an empty
    payload is sent, stops replying (the next receive times out) and leaves
    no connected pipes.

    The background worker task is always cancelled, even when an assertion
    fails, so it cannot leak into later tests.
    """
    with Pair1(
        listen=f"{SERVER_URL}:{SERVER_PORT}", polyamorous=True, recv_timeout=500
    ) as worker_socket:
        worker_task = asyncio.create_task(
            worker(SERVER_URL, SERVER_PORT, num_workers=1)
        )
        try:
            message = await worker_socket.arecv()
            assert message == b"Ready"

            # Serialized description of a test builder defined in this module.
            dummy_work = {
                "@module": "tests.cli.test_distributed",
                "@class": "DummyBuilder",
                "@version": None,
                "dummy_prechunk": False,
                "val": 0,
            }

            for _ in range(2):
                await worker_socket.asend(json.dumps(dummy_work).encode("utf-8"))
                # Presumably gives the worker time to process the item
                # before we wait for its next "Ready".
                await asyncio.sleep(1)
                message = await worker_socket.arecv()
                assert message == b"Ready"

            # After an empty payload the worker is expected to go quiet and
            # disconnect; the two checks below assert exactly that.
            await worker_socket.asend(json.dumps({}).encode("utf-8"))
            with pytest.raises(Timeout):
                await worker_socket.arecv()

            assert len(worker_socket.pipes) == 0
        finally:
            # Cancel unconditionally: before this fix, a failed assertion
            # skipped the cancel and left the worker task pending.
            worker_task.cancel()
def run(builders, verbosity, reporting_store, num_workers, url, port, num_chunks, no_bars):
    """Configure logging, load the requested builders and execute them.

    Args:
        builders: Builder specs. Specs whose string form ends in ``.py`` or
            ``.ipynb`` are loaded with ``load_builder_from_source``; anything
            else is loaded with ``loadfn``. A spec may produce a single
            builder or a list of builders; lists are flattened.
        verbosity: 0 -> WARNING, 1 -> INFO, 2 or more -> DEBUG. Negative
            values are treated as 0.
        reporting_store: Optional spec loaded with ``loadfn`` and attached to
            the root logger through a ``ReportingHandler``.
        num_workers: Worker count passed to ``worker``/``multi``; a value of
            1 without ``url`` runs builders serially.
        url: When truthy, run in distributed mode (manager or worker).
        port: Port for distributed mode. If it is ``None`` in manager mode,
            a port is picked with ``find_port``.
        num_chunks: In distributed mode, ``> 0`` selects manager mode and is
            forwarded to ``manager``; otherwise this process runs ``worker``.
        no_bars: Forwarded to ``serial``/``multi``.
    """
    # --- Logging -----------------------------------------------------------
    levels = [logging.WARNING, logging.INFO, logging.DEBUG]
    # Clamp on both ends: a negative verbosity previously indexed from the
    # end of the list and enabled DEBUG/INFO instead of WARNING.
    level = levels[max(0, min(len(levels) - 1, verbosity))]
    root = logging.getLogger()
    root.setLevel(level)

    ch = TqdmLoggingHandler()
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    ch.setFormatter(formatter)
    root.addHandler(ch)

    # --- Builders ----------------------------------------------------------
    builder_objects = []
    for spec in builders:
        if str(spec).endswith((".py", ".ipynb")):
            loaded = load_builder_from_source(spec)
        else:
            loaded = loadfn(spec)
        # One spec may yield a single builder or a list of builders.
        builder_objects.extend(loaded if isinstance(loaded, list) else [loaded])

    if reporting_store:
        reporting_store = loadfn(reporting_store)
        root.addHandler(ReportingHandler(reporting_store))

    # --- Execution ---------------------------------------------------------
    # asyncio.run replaces the deprecated get_event_loop().run_until_complete
    # pattern: it creates, runs and closes a fresh event loop for each call.
    if url:
        if num_chunks > 0:
            # Manager
            if port is None:
                port = find_port()
                # CRITICAL so the chosen port is shown at every verbosity;
                # the root level set above is never stricter than WARNING.
                root.critical(f"Using random port for mrun manager: {port}")
            asyncio.run(
                manager(
                    url=url,
                    port=port,
                    builders=builder_objects,
                    num_chunks=num_chunks,
                )
            )
        else:
            # Worker
            asyncio.run(worker(url=url, port=port, num_workers=num_workers))
    elif num_workers == 1:
        for builder in builder_objects:
            serial(builder, no_bars)
    else:
        for builder in builder_objects:
            asyncio.run(
                multi(builder=builder, num_workers=num_workers, no_bars=no_bars)
            )
def run(builders, verbosity, reporting_store, num_workers, url, num_chunks):
    """Set up logging, load builders with ``loadfn`` and run them.

    Each entry of ``builders`` is loaded with ``loadfn``. A loaded value that
    is a list contributes all of its items. If ``reporting_store`` is given,
    it is loaded the same way and attached to the root logger through a
    ``ReportingHandler``.

    Dispatch:
      * ``url`` set, ``num_chunks > 0``  -> ``master(url, builders, num_chunks)``
      * ``url`` set, otherwise           -> ``worker(url, num_workers)``
      * no ``url``, ``num_workers == 1`` -> ``serial`` on each builder
      * no ``url``, otherwise            -> ``multi`` on each builder
    """
    log_levels = [logging.WARNING, logging.INFO, logging.DEBUG]
    # verbosity counts up from WARNING; anything past the end maps to DEBUG.
    chosen_level = log_levels[min(len(log_levels) - 1, verbosity)]

    root_logger = logging.getLogger()
    root_logger.setLevel(chosen_level)

    console_handler = TqdmLoggingHandler()
    console_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    root_logger.addHandler(console_handler)

    loaded_builders = []
    for spec in builders:
        obj = loadfn(spec)
        if isinstance(obj, list):
            loaded_builders.extend(obj)
        else:
            loaded_builders.append(obj)

    if reporting_store:
        store = loadfn(reporting_store)
        root_logger.addHandler(ReportingHandler(store))

    if not url:
        # Local execution.
        if num_workers == 1:
            for item in loaded_builders:
                serial(item)
        else:
            for item in loaded_builders:
                asyncio.run(multi(item, num_workers))
        return

    # Distributed execution: master when chunking is requested, else worker.
    if num_chunks > 0:
        asyncio.run(master(url, loaded_builders, num_chunks))
    else:
        asyncio.run(worker(url, num_workers))