Example #1
0
def test_orchestration_run_one_step(make_random_dataset):
    """Check that ``run_one_step`` cannot run on the bare base class.

    ``PALBase`` provides no prediction function, so even after a training
    set has been registered, a single orchestration step must raise
    ``NotImplementedError``.
    """
    features, targets = make_random_dataset
    pal = PALBase(features, ["model"], 3, beta_scale=1)

    # Register a handful of points so the failure is due to the missing
    # prediction step rather than to an empty training set.
    indices = np.array([1, 2, 3, 4])
    pal.update_train_set(indices, targets[indices])

    with pytest.raises(NotImplementedError):
        pal.run_one_step()
Example #2
0
def test__replace_by_measurements(make_random_dataset):
    """Check that ``_replace_by_measurements`` swaps in the measured values.

    After training points with explicit measurement uncertainties are added,
    the stored mean/std are replaced by the measurements, and ``y`` must then
    match ``std`` element-wise.
    """
    features, targets = make_random_dataset
    pal = PALBase(features, ["model"], 3, beta_scale=1)

    # Nothing has been measured yet, so all uncertainties are zero.
    assert pal.measurement_uncertainty.sum() == 0

    indices = np.array([1, 2, 3, 4])
    # The measured targets are also passed as their own uncertainties
    # (two separate fancy-index copies, as in the original test).
    pal.update_train_set(indices, targets[indices], targets[indices])

    # NOTE(review): both attributes are pointed at ``measurement_uncertainty``.
    # If that returns the same stored array each time, mean and std alias each
    # other, so the assertion cannot distinguish mean replacement from std
    # replacement. Consider independent copies -- TODO confirm intent.
    pal._means = pal.measurement_uncertainty
    pal.std = pal.measurement_uncertainty

    pal._replace_by_measurements()

    matches = pal.y == pal.std
    assert matches.all()
Example #3
0
def test_update_train_set(make_random_dataset):
    """Check that ``update_train_set`` registers a newly sampled point.

    A fresh instance has no training set and no sampled points. After the
    first dataset row is added, exactly index 0 is recorded as sampled, the
    sample count is one, and that row's targets are stored in ``y``.
    """
    X, y = make_random_dataset  # pylint:disable=invalid-name

    palinstance = PALBase(X, ["model"], 3)
    assert not palinstance._has_train_set
    assert palinstance.sampled.sum() == 0

    palinstance.update_train_set(np.array([0]), y[0, :].reshape(-1, 3))
    # Compare whole arrays explicitly. ``assert array == array`` relies on the
    # truth value of an element-wise result, which only works for one-element
    # arrays and raises ValueError ("ambiguous") for longer ones.
    assert np.array_equal(palinstance.sampled_indices, np.array([0]))
    assert palinstance.number_sampled_points == 1
    assert (palinstance.y[0] == y[0, :]).all()
Example #4
0
def test_turn_to_maximization(make_random_dataset):
    """Check the sign flip applied for minimization objectives.

    With the default goals, ``y`` equals the raw targets. With
    ``goals=[1, 1, -1]``, the third objective is negated in ``y``, while
    ``_y`` keeps the raw, unflipped values in both cases.
    """
    features, targets = make_random_dataset
    # A view onto the first dataset row (shares memory with ``targets``).
    raw_row = targets[0, :]

    # Default goals: no objective is flipped.
    pal = PALBase(features, ["model"], 3)
    pal.update_train_set(np.array([0]), raw_row.reshape(-1, 3))
    assert (pal.y[0] == raw_row).all()
    assert (pal._y[0] == raw_row).all()

    # A goal of -1 on the last objective should negate it in ``y`` only.
    pal = PALBase(features, ["model"], 3, goals=[1, 1, -1])
    pal.update_train_set(np.array([0]), raw_row.reshape(-1, 3))
    flipped_row = raw_row * np.array([1, 1, -1])
    assert (pal.y[0] == flipped_row).all()
    assert (pal._y[0] == raw_row).all()