Example #1
def test_possibly_fix_worker_map(capsys, client):
    client.wait_for_workers(2)
    worker_addresses = list(client.scheduler_info()["workers"].keys())

    retry_msg = 'Searching for a LightGBM training port for worker'

    # should handle worker maps without any duplicates
    map_without_duplicates = {
        worker_address: 12400 + i
        for i, worker_address in enumerate(worker_addresses)
    }
    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
        client=client,
        worker_map=map_without_duplicates
    )
    assert patched_map == map_without_duplicates
    assert retry_msg not in capsys.readouterr().out

    # should handle worker maps with duplicates
    map_with_duplicates = {
        worker_address: 12400
        for worker_address in worker_addresses
    }
    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
        client=client,
        worker_map=map_with_duplicates
    )
    assert retry_msg in capsys.readouterr().out
    assert len(set(patched_map.values())) == len(worker_addresses)
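For context, here is a minimal sketch of the deduplication behavior the test above exercises. It is not LightGBM's actual implementation (the real lgb.dask._possibly_fix_worker_map_duplicates resolves new ports on the affected workers themselves); it only illustrates the idea of keeping the first claim on a port and re-assigning later duplicates:

import socket

def sketch_fix_duplicate_ports(worker_map):
    """Re-assign any duplicated port in a worker-to-port map (illustrative only)."""
    seen_ports = set()
    fixed = {}
    for worker_address, port in worker_map.items():
        if port in seen_ports:
            # duplicate port: ask the OS for a currently free one instead
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(('', 0))
                port = s.getsockname()[1]
        seen_ports.add(port)
        fixed[worker_address] = port
    return fixed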
Example #2
def test_network_params_not_required_but_respected_if_given(client, task, listen_port):
    client.wait_for_workers(2)

    _, _, _, _, dX, dy, _, dg = _create_data(
        objective=task,
        output='array',
        chunk_size=10,
        group=None
    )

    dask_model_factory = task_to_dask_factory[task]

    # rebalance data to be sure that each worker has a piece of the data
    client.rebalance()

    # model 1 - no network parameters given
    dask_model1 = dask_model_factory(
        n_estimators=5,
        num_leaves=5,
    )
    dask_model1.fit(dX, dy, group=dg)
    assert dask_model1.fitted_
    params = dask_model1.get_params()
    assert 'local_listen_port' not in params
    assert 'machines' not in params

    # model 2 - machines given
    n_workers = len(client.scheduler_info()['workers'])
    open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
    dask_model2 = dask_model_factory(
        n_estimators=5,
        num_leaves=5,
        machines=",".join([
            "127.0.0.1:" + str(port)
            for port in open_ports
        ]),
    )

    dask_model2.fit(dX, dy, group=dg)
    assert dask_model2.fitted_
    params = dask_model2.get_params()
    assert 'local_listen_port' not in params
    assert 'machines' in params

    # model 3 - local_listen_port given
    # training should fail because LightGBM will try to use the same
    # port for multiple worker processes on the same machine
    dask_model3 = dask_model_factory(
        n_estimators=5,
        num_leaves=5,
        local_listen_port=listen_port
    )
    error_msg = "has multiple Dask worker processes running on it"
    with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
        dask_model3.fit(dX, dy, group=dg)

    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
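The test above relies on lgb.dask._find_random_open_port() to pick ports for the machines string. A minimal sketch of how such a helper can be implemented, using the standard bind-to-port-0 trick (illustrative, not necessarily the library's exact code):

import socket

def find_random_open_port() -> int:
    """Bind to port 0 so the OS assigns a currently free TCP port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        port = s.getsockname()[1]
    return port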
Example #3
def test_training_succeeds_even_if_some_workers_do_not_have_any_data(client, task, output):
    if task == 'ranking' and output == 'scipy_csr_matrix':
        pytest.skip('LGBMRanker is not currently tested on sparse matrices')

    def collection_to_single_partition(collection):
        """Merge the parts of a Dask collection into a single partition."""
        if collection is None:
            return
        if isinstance(collection, da.Array):
            return collection.rechunk(*collection.shape)
        return collection.repartition(npartitions=1)

    if task == 'ranking':
        X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
            output=output,
            group=None
        )
    else:
        X, y, w, dX, dy, dw = _create_data(
            objective=task,
            output=output
        )
        g = None
        dg = None

    dask_model_factory = task_to_dask_factory[task]
    local_model_factory = task_to_local_factory[task]

    dX = collection_to_single_partition(dX)
    dy = collection_to_single_partition(dy)
    dw = collection_to_single_partition(dw)
    dg = collection_to_single_partition(dg)

    n_workers = len(client.scheduler_info()['workers'])
    assert n_workers > 1
    assert dX.npartitions == 1

    params = {
        'time_out': 5,
        'random_state': 42,
        'num_leaves': 10
    }

    # 'tree' is an alias for LightGBM's tree_learner parameter
    dask_model = dask_model_factory(tree='data', client=client, **params)
    dask_model.fit(dX, dy, group=dg, sample_weight=dw)
    dask_preds = dask_model.predict(dX).compute()

    local_model = local_model_factory(**params)
    if task == 'ranking':
        local_model.fit(X, y, group=g, sample_weight=w)
    else:
        local_model.fit(X, y, sample_weight=w)
    local_preds = local_model.predict(X)

    assert_eq(dask_preds, local_preds)

    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
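The collection_to_single_partition() helper works because rechunking a Dask array to its full shape leaves exactly one chunk, which is what forces all of the data onto a single worker. A quick standalone check of that behavior:

import numpy as np
import dask.array as da

x = da.from_array(np.arange(20), chunks=5)  # starts as 4 chunks
single = x.rechunk(*x.shape)                # one chunk covering the whole array
assert single.npartitions == 1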
Example #4
def test_machines_should_be_used_if_provided(task, output):
    if task == 'ranking' and output == 'scipy_csr_matrix':
        pytest.skip('LGBMRanker is not currently tested on sparse matrices')

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        if task == 'ranking':
            _, _, _, _, dX, dy, _, dg = _create_ranking_data(
                output=output,
                group=None,
                chunk_size=10,
            )
        else:
            _, _, _, dX, dy, _ = _create_data(
                objective=task,
                output=output,
                chunk_size=10,
            )
            dg = None

        dask_model_factory = task_to_dask_factory[task]

        # rebalance data to be sure that each worker has a piece of the data
        if output == 'array':
            client.rebalance()

        n_workers = len(client.scheduler_info()['workers'])
        assert n_workers > 1
        open_ports = [
            lgb.dask._find_random_open_port() for _ in range(n_workers)
        ]
        dask_model = dask_model_factory(
            n_estimators=5,
            num_leaves=5,
            machines=",".join(["127.0.0.1:" + str(port) for port in open_ports]),
        )

        # test that "machines" is actually respected by creating a socket that uses
        # one of the ports mentioned in "machines"
        error_msg = "Binding port %s failed" % open_ports[0]
        with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(('127.0.0.1', open_ports[0]))
                dask_model.fit(dX, dy, group=dg)

        # an informative error should be raised if "machines" has duplicates
        one_open_port = lgb.dask._find_random_open_port()
        dask_model.set_params(
            machines=",".join(["127.0.0.1:" + str(one_open_port) for _ in range(n_workers)])
        )
        with pytest.raises(ValueError, match="Found duplicates in 'machines'"):
            dask_model.fit(dX, dy, group=dg)
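The "Binding port ... failed" assertion works because the test binds a plain socket to one of the advertised ports before training starts, so LightGBM cannot bind it again. The same conflict is easy to reproduce without LightGBM at all:

import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1:
    s1.bind(('127.0.0.1', 0))  # occupy any free port
    host, port = s1.getsockname()
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
        try:
            s2.bind((host, port))  # second bind on the same address fails
        except OSError as exc:
            print(f"port {port} already in use: {exc}")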
Example #5
def _get_client_workers(client: "Client") -> Dict[str, Dict]:
    """Return the scheduler's mapping of worker address to worker metadata."""
    workers = client.scheduler_info()['workers']
    return workers
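Hypothetical usage: the returned dict maps each worker address to its metadata, so callers can iterate workers directly (the exact metadata keys vary by Dask version):

workers = _get_client_workers(client)
for address, info in workers.items():
    print(address, info.get('nthreads'))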
Example #6
def test_network_params_not_required_but_respected_if_given(
        client, task, output, listen_port):
    if task == 'ranking' and output == 'scipy_csr_matrix':
        pytest.skip('LGBMRanker is not currently tested on sparse matrices')

    if task == 'ranking':
        _, _, _, _, dX, dy, _, dg = _create_ranking_data(
            output=output,
            group=None,
            chunk_size=10,
        )
        dask_model_factory = lgb.DaskLGBMRanker
    else:
        _, _, _, dX, dy, _ = _create_data(
            objective=task,
            output=output,
            chunk_size=10,
        )
        dg = None
        if task == 'classification':
            dask_model_factory = lgb.DaskLGBMClassifier
        elif task == 'regression':
            dask_model_factory = lgb.DaskLGBMRegressor

    # rebalance data to be sure that each worker has a piece of the data
    if output == 'array':
        client.rebalance()

    # model 1 - no network parameters given
    dask_model1 = dask_model_factory(
        n_estimators=5,
        num_leaves=5,
    )
    if task == 'ranking':
        dask_model1.fit(dX, dy, group=dg)
    else:
        dask_model1.fit(dX, dy)
    assert dask_model1.fitted_
    params = dask_model1.get_params()
    assert 'local_listen_port' not in params
    assert 'machines' not in params

    # model 2 - machines given
    n_workers = len(client.scheduler_info()['workers'])
    open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
    dask_model2 = dask_model_factory(
        n_estimators=5,
        num_leaves=5,
        machines=",".join(["127.0.0.1:" + str(port) for port in open_ports]),
    )

    if task == 'ranking':
        dask_model2.fit(dX, dy, group=dg)
    else:
        dask_model2.fit(dX, dy)
    assert dask_model2.fitted_
    params = dask_model2.get_params()
    assert 'local_listen_port' not in params
    assert 'machines' in params

    # model 3 - local_listen_port given
    # training should fail because LightGBM will try to use the same
    # port for multiple worker processes on the same machine
    dask_model3 = dask_model_factory(
        n_estimators=5,
        num_leaves=5,
        local_listen_port=listen_port
    )
    error_msg = "has multiple Dask worker processes running on it"
    with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
        if task == 'ranking':
            dask_model3.fit(dX, dy, group=dg)
        else:
            dask_model3.fit(dX, dy)

    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
Example #7
def _get_client_workers(client):
    """Return the scheduler's mapping of worker address to worker metadata."""
    workers = client.scheduler_info()['workers']
    return workers