Ejemplo n.º 1
0
def test_calculate_or_validate_base_steps_wrong_shape():
    """Passing three base steps for a two-element ``x`` must raise ValueError."""
    steps = np.full(3, 0.01)
    step_floor = np.full(3, 1e-8)
    with pytest.raises(ValueError):
        # x has only two entries, so the three-element ``steps`` cannot match it.
        _calculate_or_validate_base_steps(
            steps, np.ones(2), "first_derivative", step_floor, scaling_factor=1
        )
Ejemplo n.º 2
0
def test_calculate_or_validate_base_steps_invalid_too_small():
    """A user-supplied base step below its ``min_steps`` entry must raise ValueError."""
    steps = np.full(3, 0.01)
    steps[0] = 1e-10  # far below the 1e-8 floor used below
    step_floor = np.full(3, 1e-8)
    with pytest.raises(ValueError):
        _calculate_or_validate_base_steps(
            steps, np.ones(3), "first_derivative", step_floor, scaling_factor=1
        )
Ejemplo n.º 3
0
def test_scalars_as_base_steps():
    """A scalar ``base_steps`` must give the same result as a constant array."""
    # Evaluate once with the scalar 0.1 and once with an explicit array of 0.1s,
    # each call getting its own fresh ``x``.
    from_scalar, from_array = [
        _calculate_or_validate_base_steps(
            steps, np.ones(3), "first_derivative", None, scaling_factor=1
        )
        for steps in (0.1, np.full(3, 0.1))
    ]

    aaae(from_scalar, from_array)
Ejemplo n.º 4
0
def test_calculate_or_validate_base_steps_hessian():
    """Default second-derivative steps for a fixed ``x`` match hand-computed values."""
    cube_root_eps = np.finfo(float).eps ** (1 / 3)
    x = np.array([0.05, 1, -5])
    # The expected magnitudes are [0.1, 1, 5]. For the first entry, |x| = 0.05
    # is replaced by 0.1 (presumably a lower bound on the scale; confirm in
    # the implementation).
    expected = cube_root_eps * np.array([0.1, 1, 5])

    steps = _calculate_or_validate_base_steps(
        None, x, "second_derivative", 0, scaling_factor=1.0
    )
    aaae(steps, expected, decimal=12)
Ejemplo n.º 5
0
def test_calculate_or_validate_base_steps_jacobian_with_scaling_factor():
    """A scaling factor of 2 doubles the default first-derivative steps."""
    factor = 2.0
    x = np.array([0.05, 1, -5])
    unscaled = np.array([0.1, 1, 5]) * np.sqrt(np.finfo(float).eps)
    expected = unscaled * 2

    steps = _calculate_or_validate_base_steps(
        None, x, "first_derivative", 0, scaling_factor=factor
    )
    aaae(steps, expected, decimal=12)
Ejemplo n.º 6
0
def test_calculate_or_validate_base_steps_binding_min_step():
    """A scalar ``min_steps`` replaces any default step that falls below it."""
    floor = 1e-8
    x = np.array([0.05, 1, -5])
    expected = np.array([0.1, 1, 5]) * np.sqrt(np.finfo(float).eps)
    # 0.1 * sqrt(eps) is about 1.5e-9, below the floor, so only the first
    # entry is clipped. 1 * sqrt(eps) is about 1.49e-8, which stays above it.
    expected[0] = floor

    steps = _calculate_or_validate_base_steps(
        None, x, "first_derivative", floor, scaling_factor=1.0
    )
    aaae(steps, expected, decimal=12)