# Example No. 1 (rating: 0)
def test_add_centered_entries():
    """Check that centered entries are stored in un-centered form.

    The history is seeded with two points of ones. The last stored point and
    its residuals then serve as the center, with radius 0.5. A centered point of
    ones and centered residuals of twos are added. The expected stored values
    (x == 1.5, residuals == 3) agree with ``center_x + radius * xs`` and
    ``center_residuals + residuals``. The exact formula lives in
    ``LeastSquaresHistory``; confirm there. The criterion value 36 equals the
    sum of the squared residuals (4 * 3**2).
    """
    hist = LeastSquaresHistory()
    hist.add_entries(np.ones((2, 2)), np.ones((2, 4)))

    center = dict(
        x=hist.get_xs(index=-1),
        residuals=hist.get_residuals(index=-1),
        radius=0.5,
    )
    centered_x = np.ones(2)
    centered_res = 2 * np.ones(4)
    hist.add_centered_entries(
        xs=centered_x, residuals=centered_res, center_info=center
    )

    last_x, last_res, last_crit = hist.get_entries(index=-1)

    aaae(last_x, np.array([1.5, 1.5]))
    aaae(last_res, np.array([3, 3, 3, 3]))
    assert last_crit == 36
    # Two seed evaluations plus the centered one.
    assert hist.get_n_fun() == 3
# Example No. 2 (rating: 0)
def test_add_entries_not_initialized(entries, is_center):
    """Check adding entries to a freshly created, empty history.

    ``entries`` and ``is_center`` are supplied by pytest; the fixture or
    parametrization is not shown here. The assertions assume ``entries``
    unpacks to xs ``np.arange(3)`` and residuals ``np.arange(5)``; confirm
    against the parametrization.

    When ``is_center`` is true, the entries go in as centered entries with a
    zero center and radius 1. The test expects exactly the same stored values
    as in the plain ``add_entries`` branch.
    """
    history = LeastSquaresHistory()

    if is_center:
        # Zero center and unit radius: the expected results below are shared
        # with the non-centered branch.
        c_info = {"x": np.zeros(3), "residuals": np.zeros(5), "radius": 1}
        history.add_centered_entries(*entries, c_info)
    else:
        history.add_entries(*entries)

    xs, residuals, critvals = history.get_entries()
    xs_single = history.get_xs()
    residuals_single = history.get_residuals()
    critvals_single = history.get_critvals()

    # Without an index, the test expects every getter to return an array.
    for entry in xs, residuals, critvals:
        assert isinstance(entry, np.ndarray)

    # One stored entry, so each result has a leading axis of length 1.
    aaae(xs, np.arange(3).reshape(1, 3))
    aaae(xs_single, np.arange(3).reshape(1, 3))
    aaae(residuals, np.arange(5).reshape(1, 5))
    aaae(residuals_single, np.arange(5).reshape(1, 5))
    # 0**2 + 1**2 + 2**2 + 3**2 + 4**2 == 30
    aaae(critvals, np.array([30.0]))
    aaae(critvals_single, np.array([30.0]))
# Example No. 3 (rating: 0)
def test_add_entries_initialized_with_space(entries, is_center):
    """Check adding entries to a history that already holds entries.

    The history is first given four entries, with xs of ones and residuals of
    zeros. The fixture-supplied ``entries`` are then added, and only the last
    stored entry (``index=-1``) is inspected. The assertions assume ``entries``
    unpacks to xs ``np.arange(3)`` and residuals ``np.arange(5)``; confirm
    against the parametrization.

    When ``is_center`` is true, the entries go in as centered entries with a
    zero center and radius 1. The test expects exactly the same stored values
    as in the plain ``add_entries`` branch.
    """
    history = LeastSquaresHistory()
    history.add_entries(np.ones((4, 3)), np.zeros((4, 5)))

    if is_center:
        # Zero center and unit radius: the expected results below are shared
        # with the non-centered branch.
        c_info = {"x": np.zeros(3), "residuals": np.zeros(5), "radius": 1}
        history.add_centered_entries(*entries, c_info)
    else:
        history.add_entries(*entries)

    xs, residuals, critvals = history.get_entries(index=-1)
    xs_single = history.get_xs(index=-1)
    residuals_single = history.get_residuals(index=-1)
    critvals_single = history.get_critvals(index=-1)

    # critvals for a single index is compared as a scalar below, so only the
    # vector-valued results are type-checked here.
    for entry in xs, residuals:
        assert isinstance(entry, np.ndarray)

    aaae(xs, np.arange(3))
    aaae(xs_single, np.arange(3))
    aaae(residuals, np.arange(5))
    aaae(residuals_single, np.arange(5))
    # 0**2 + 1**2 + 2**2 + 3**2 + 4**2 == 30
    assert critvals == 30
    assert critvals_single == 30