def test_orchestration_run_one_step(make_random_dataset):
    """Check that ``run_one_step`` is not implemented in the base class.

    ``PALBase`` has no prediction function, so running an orchestration step
    must raise ``NotImplementedError`` even after a training set was added.
    """
    features, targets = make_random_dataset
    pal = PALBase(features, ["model"], 3, beta_scale=1)

    training_indices = np.array([1, 2, 3, 4])
    pal.update_train_set(training_indices, targets[training_indices])

    with pytest.raises(NotImplementedError):
        pal.run_one_step()
def test__replace_by_measurements(make_random_dataset):
    """Test that replacing the mean/std by the measured ones works.

    ``_means`` and ``std`` are set to fresh zero arrays that share no memory
    with ``y`` or ``measurement_uncertainty``. The assertions can therefore
    only pass if ``_replace_by_measurements`` actually writes the measured
    values into the sampled rows. If the placeholders aliased
    ``measurement_uncertainty``, every replacement would be a self-copy and
    the test would pass vacuously.
    """
    X, y = make_random_dataset  # pylint:disable=invalid-name
    palinstance = PALBase(X, ["model"], 3, beta_scale=1)
    assert palinstance.measurement_uncertainty.sum() == 0

    sample_idx = np.array([1, 2, 3, 4])
    # The targets double as measurement uncertainties, so the expected
    # replacement values are non-trivial.
    palinstance.update_train_set(sample_idx, y[sample_idx], y[sample_idx])

    # Independent placeholders: zeros, not views of any measured array.
    palinstance._means = np.zeros_like(  # pylint:disable=protected-access
        palinstance.measurement_uncertainty
    )
    palinstance.std = np.zeros_like(palinstance.measurement_uncertainty)
    palinstance._replace_by_measurements()  # pylint:disable=protected-access

    # The sampled rows must now carry the measured means and uncertainties.
    assert (
        palinstance._means[sample_idx]  # pylint:disable=protected-access
        == palinstance.y[sample_idx]
    ).all()
    assert (
        palinstance.std[sample_idx]
        == palinstance.measurement_uncertainty[sample_idx]
    ).all()
    # The uncertainties were passed as the targets themselves.
    assert (palinstance.std[sample_idx] == palinstance.y[sample_idx]).all()
def test_update_train_set(make_random_dataset):
    """Check if the update of the training set works"""
    X, y = make_random_dataset  # pylint:disable=invalid-name
    palinstance = PALBase(X, ["model"], 3)

    # A fresh instance starts without any training data.
    assert not palinstance._has_train_set  # pylint:disable=protected-access
    assert palinstance.sampled.sum() == 0

    # Add row 0 as a single (1, 3) target array.
    palinstance.update_train_set(np.array([0]), y[0, :].reshape(-1, 3))

    # Compare arrays explicitly. A bare ``==`` between arrays is only
    # truth-testable when the result has exactly one element.
    assert np.array_equal(palinstance.sampled_indices, np.array([0]))
    assert palinstance.number_sampled_points == 1
    assert (palinstance.y[0] == y[0, :]).all()
def test_turn_to_maximization(make_random_dataset):
    """Check that objectives marked for minimization are sign-flipped.

    With the default goals, the public ``y`` equals the raw targets. With
    ``goals=[1, 1, -1]``, the third column of ``y`` is negated, while the
    internal ``_y`` still holds the unmodified targets in both cases.
    """
    X, y = make_random_dataset  # pylint:disable=invalid-name
    first_row = y[0, :]

    # Default goals: no sign change anywhere.
    default_pal = PALBase(X, ["model"], 3)
    default_pal.update_train_set(np.array([0]), first_row.reshape(-1, 3))
    assert (default_pal.y[0] == first_row).all()
    assert (default_pal._y[0] == first_row).all()  # pylint:disable=protected-access

    # Minimizing the last objective flips only its sign in the public ``y``.
    mixed_pal = PALBase(X, ["model"], 3, goals=[1, 1, -1])
    mixed_pal.update_train_set(np.array([0]), first_row.reshape(-1, 3))
    expected_signs = np.array([1, 1, -1])
    assert (mixed_pal.y[0] == first_row * expected_signs).all()
    assert (mixed_pal._y[0] == first_row).all()  # pylint:disable=protected-access