def test_possibly_fix_worker_map(capsys, client):
    """Check how ``_possibly_fix_worker_map_duplicates`` treats port collisions.

    A worker map whose ports are all distinct must come back unchanged and
    without the retry message being printed. A map in which every worker
    shares one port must print the retry message and come back with one
    distinct port per worker.
    """
    client.wait_for_workers(2)
    worker_addresses = list(client.scheduler_info()["workers"].keys())

    retry_msg = 'Searching for a LightGBM training port for worker'

    # should handle worker maps without any duplicates
    map_without_duplicates = {
        worker_address: 12400 + i
        for i, worker_address in enumerate(worker_addresses)
    }
    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
        client=client,
        worker_map=map_without_duplicates
    )
    assert patched_map == map_without_duplicates
    assert retry_msg not in capsys.readouterr().out

    # should handle worker maps with duplicates: every worker claims the same port
    map_with_duplicates = dict.fromkeys(worker_addresses, 12400)
    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
        client=client,
        worker_map=map_with_duplicates
    )
    assert retry_msg in capsys.readouterr().out
    # every worker must still be present, each with its own port
    assert set(patched_map) == set(worker_addresses)
    assert len(set(patched_map.values())) == len(worker_addresses)
def test_network_params_not_required_but_respected_if_given(client, task, listen_port):
    """Network parameters are optional, but are honored when supplied.

    Three models are trained on the same data:

    1. no network parameters: training succeeds, and neither
       ``'local_listen_port'`` nor ``'machines'`` appears in ``get_params()``;
    2. an explicit ``machines`` list of open ports: training succeeds and
       ``'machines'`` is kept in the params;
    3. a single ``local_listen_port``: training must fail, because every
       worker process on this machine would try to use that same port.
    """
    client.wait_for_workers(2)

    _, _, _, _, dX, dy, _, dg = _create_data(
        objective=task,
        output='array',
        chunk_size=10,
        group=None
    )
    model_cls = task_to_dask_factory[task]

    # spread the data so that every worker holds a piece of it
    client.rebalance()

    def _fit_and_check(model, expect_machines):
        # train, then confirm which network params the estimator exposes
        model.fit(dX, dy, group=dg)
        assert model.fitted_
        fitted_params = model.get_params()
        assert 'local_listen_port' not in fitted_params
        if expect_machines:
            assert 'machines' in fitted_params
        else:
            assert 'machines' not in fitted_params

    # model 1 - no network parameters given
    _fit_and_check(
        model_cls(n_estimators=5, num_leaves=5),
        expect_machines=False,
    )

    # model 2 - machines given
    worker_count = len(client.scheduler_info()['workers'])
    ports = [lgb.dask._find_random_open_port() for _ in range(worker_count)]
    machines = ",".join(f"127.0.0.1:{port}" for port in ports)
    _fit_and_check(
        model_cls(n_estimators=5, num_leaves=5, machines=machines),
        expect_machines=True,
    )

    # model 3 - local_listen_port given; all local worker processes would
    # compete for the same port, so LightGBM has to refuse to train
    conflicting_model = model_cls(
        n_estimators=5,
        num_leaves=5,
        local_listen_port=listen_port
    )
    expected = "has multiple Dask worker processes running on it"
    with pytest.raises(lgb.basic.LightGBMError, match=expected):
        conflicting_model.fit(dX, dy, group=dg)

    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
def test_training_succeeds_even_if_some_workers_do_not_have_any_data(client, task, output):
    """Distributed training must work when all data sits in one partition.

    Every Dask collection is merged into a single partition while more than
    one worker is running, so at least one worker receives no data. The Dask
    model's predictions are then compared with those of a local model trained
    on the same in-memory data and parameters.
    """
    if task == 'ranking' and output == 'scipy_csr_matrix':
        pytest.skip('LGBMRanker is not currently tested on sparse matrices')

    def collection_to_single_partition(collection):
        """Merge the parts of a Dask collection into a single partition."""
        if collection is None:
            return None
        if isinstance(collection, da.Array):
            # ``rechunk`` takes the full chunk spec as its first argument, so
            # pass the shape as one tuple. Unpacking it (``rechunk(*shape)``)
            # would pass the row count as the chunk size for *every* axis and
            # the column count as ``threshold``. With more columns than rows,
            # that leaves several chunks.
            return collection.rechunk(collection.shape)
        return collection.repartition(npartitions=1)

    if task == 'ranking':
        X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
            output=output,
            group=None
        )
    else:
        X, y, w, dX, dy, dw = _create_data(
            objective=task,
            output=output
        )
        g = None
        dg = None

    dask_model_factory = task_to_dask_factory[task]
    local_model_factory = task_to_local_factory[task]

    dX = collection_to_single_partition(dX)
    dy = collection_to_single_partition(dy)
    dw = collection_to_single_partition(dw)
    dg = collection_to_single_partition(dg)

    n_workers = len(client.scheduler_info()['workers'])
    assert n_workers > 1
    assert dX.npartitions == 1

    params = {
        'time_out': 5,
        'random_state': 42,
        'num_leaves': 10
    }

    dask_model = dask_model_factory(tree='data', client=client, **params)
    dask_model.fit(dX, dy, group=dg, sample_weight=dw)
    dask_preds = dask_model.predict(dX).compute()

    local_model = local_model_factory(**params)
    if task == 'ranking':
        local_model.fit(X, y, group=g, sample_weight=w)
    else:
        local_model.fit(X, y, sample_weight=w)
    local_preds = local_model.predict(X)

    assert assert_eq(dask_preds, local_preds)

    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
def test_machines_should_be_used_if_provided(task, output):
    """The ``machines`` parameter must be honored, and duplicates rejected.

    The test runs on a dedicated two-worker ``LocalCluster``:

    - one of the ports listed in ``machines`` is held by a local socket, so
      training has to fail with a bind error for exactly that port;
    - ``machines`` is then set to one address repeated once per worker, which
      must raise ``ValueError``.
    """
    if task == 'ranking' and output == 'scipy_csr_matrix':
        pytest.skip('LGBMRanker is not currently tested on sparse matrices')

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        if task == 'ranking':
            _, _, _, _, dX, dy, _, dg = _create_ranking_data(
                output=output,
                group=None,
                chunk_size=10,
            )
        else:
            _, _, _, dX, dy, _ = _create_data(
                objective=task,
                output=output,
                chunk_size=10,
            )
            dg = None

        model_cls = task_to_dask_factory[task]

        # only plain arrays are rebalanced so that every worker owns some rows
        if output == 'array':
            client.rebalance()

        worker_count = len(client.scheduler_info()['workers'])
        assert worker_count > 1

        ports = [lgb.dask._find_random_open_port() for _ in range(worker_count)]
        dask_model = model_cls(
            n_estimators=5,
            num_leaves=5,
            machines=",".join(f"127.0.0.1:{port}" for port in ports),
        )

        # Occupy the first listed port. LightGBM must then fail to bind it,
        # which proves that the "machines" value is actually used.
        blocked_port = ports[0]
        bind_error = "Binding port %s failed" % blocked_port
        with pytest.raises(lgb.basic.LightGBMError, match=bind_error):
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as blocker:
                blocker.bind(('127.0.0.1', blocked_port))
                dask_model.fit(dX, dy, group=dg)

        # the same host:port repeated for every worker is an invalid "machines"
        shared_port = lgb.dask._find_random_open_port()
        dask_model.set_params(
            machines=",".join(f"127.0.0.1:{shared_port}" for _ in range(worker_count))
        )
        with pytest.raises(ValueError, match="Found duplicates in 'machines'"):
            dask_model.fit(dX, dy, group=dg)
def _get_client_workers(client: "Client") -> Dict[str, Dict]: workers = client.scheduler_info()['workers'] return workers
def test_network_params_not_required_but_respected_if_given(
        client, task, output, listen_port):
    """Network parameters are optional, but are honored when supplied.

    Three models are trained on the same data:

    1. no network parameters: training succeeds, and neither
       ``'local_listen_port'`` nor ``'machines'`` appears in ``get_params()``;
    2. an explicit ``machines`` list of open ports: training succeeds and
       ``'machines'`` is kept in the params;
    3. a single ``local_listen_port``: training must fail, because every
       worker process on this machine would try to use that same port.
    """
    if task == 'ranking' and output == 'scipy_csr_matrix':
        pytest.skip('LGBMRanker is not currently tested on sparse matrices')

    if task == 'ranking':
        _, _, _, _, dX, dy, _, dg = _create_ranking_data(
            output=output,
            group=None,
            chunk_size=10,
        )
    else:
        _, _, _, dX, dy, _ = _create_data(
            objective=task,
            output=output,
            chunk_size=10,
        )
        dg = None

    # A dict lookup means an unknown task fails right here with a KeyError.
    # The old if/elif chain left the factory unbound and failed later with an
    # UnboundLocalError.
    dask_model_factory = {
        'classification': lgb.DaskLGBMClassifier,
        'regression': lgb.DaskLGBMRegressor,
        'ranking': lgb.DaskLGBMRanker,
    }[task]

    # only ranking estimators take query-group information
    fit_kwargs = {'group': dg} if task == 'ranking' else {}

    # rebalance data to be sure that each worker has a piece of the data
    if output == 'array':
        client.rebalance()

    # model 1 - no network parameters given
    dask_model1 = dask_model_factory(
        n_estimators=5,
        num_leaves=5,
    )
    dask_model1.fit(dX, dy, **fit_kwargs)
    assert dask_model1.fitted_
    params = dask_model1.get_params()
    assert 'local_listen_port' not in params
    assert 'machines' not in params

    # model 2 - machines given
    n_workers = len(client.scheduler_info()['workers'])
    open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
    dask_model2 = dask_model_factory(
        n_estimators=5,
        num_leaves=5,
        machines=",".join("127.0.0.1:" + str(port) for port in open_ports),
    )
    dask_model2.fit(dX, dy, **fit_kwargs)
    assert dask_model2.fitted_
    params = dask_model2.get_params()
    assert 'local_listen_port' not in params
    assert 'machines' in params

    # model 3 - local_listen_port given
    # training should fail because LightGBM will try to use the same
    # port for multiple worker processes on the same machine
    dask_model3 = dask_model_factory(
        n_estimators=5,
        num_leaves=5,
        local_listen_port=listen_port
    )
    error_msg = "has multiple Dask worker processes running on it"
    with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
        dask_model3.fit(dX, dy, **fit_kwargs)

    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
def _get_client_workers(client): workers = client.scheduler_info()['workers'] return workers