示例#1
0
def test_NodeBipartite_observation(solving_model):
    """Check that a NodeBipartite observation is a typed object holding arrays.

    The column and row feature matrices must be 2-D arrays, the edge feature
    values must be an array, and the edge feature indices must be a 2-D array
    of ``np.uint64``.
    """
    function = O.NodeBipartite()
    observation = make_obs(function, solving_model)
    assert isinstance(observation, O.NodeBipartiteObs)

    # Per-column and per-row feature matrices.
    assert_array(observation.column_features, ndim=2)
    assert_array(observation.row_features, ndim=2)

    # Edge features: a values array plus a 2-D unsigned index array.
    edges = observation.edge_features
    assert_array(edges.values)
    assert_array(edges.indices, ndim=2, dtype=np.uint64)
示例#2
0
def pytest_generate_tests(metafunc):
    """Parametrize every test that requests the `observation_function` fixture.

    Any test taking `observation_function` as an argument is run once per
    observation function listed below. Add new observation functions to that
    tuple to have them exercised by all such tests automatically.
    """
    if "observation_function" not in metafunc.fixturenames:
        return

    observation_functions = (
        O.Nothing(),
        O.NodeBipartite(),
        O.StrongBranchingScores(True),
        O.StrongBranchingScores(False),
    )
    metafunc.parametrize("observation_function", observation_functions)
示例#3
0
def test_NodeBipartite(solving_model):
    """NodeBipartite yields non-empty feature arrays that are writable in place.

    Checks that:
    * the column/row feature matrices and the edge values/indices are ndarrays;
    * the column/row feature matrices are non-empty and 2-D;
    * the edge features have shape ``(n_rows, n_columns)`` and an index array
      of shape ``(2, nnz)``;
    * writing into each feature array through the observation attribute is
      visible when the attribute is read again, i.e. the attributes do not
      hand out fresh copies.
    """
    obs = O.NodeBipartite().obtain_observation(solving_model)
    assert isinstance(obs, O.NodeBipartiteObs)
    for array in (
        obs.column_features,
        obs.row_features,
        obs.edge_features.values,
        obs.edge_features.indices,
    ):
        assert isinstance(array, np.ndarray)

    assert obs.column_features.size > 0
    assert obs.column_features.ndim == 2
    assert obs.row_features.size > 0
    assert obs.row_features.ndim == 2
    assert obs.edge_features.shape == (
        obs.row_features.shape[0],
        obs.column_features.shape[0],
    )
    assert obs.edge_features.indices.shape == (2, obs.edge_features.nnz)

    def assert_write_persists(get_array):
        """Fill the array returned by *get_array* and check a re-read sees it."""
        # Cast the fill value to the array's own dtype: comparing against a
        # raw float64 scalar would fail spuriously if the array stores values
        # at a different precision.
        fill = get_array().dtype.type(np.random.rand())
        get_array()[:] = fill
        # Re-fetch the attribute rather than reusing a local reference: only a
        # fresh read proves the attribute exposes shared storage, not a copy.
        assert np.all(get_array() == fill)

    assert_write_persists(lambda: obs.column_features)
    assert_write_persists(lambda: obs.row_features)
    assert_write_persists(lambda: obs.edge_features.values)