示例#1
0
def test_sqrt_sum_wis_with_mask_with_unit_fails(test_spec, wav_unit, flux_unit,
                                                trans_unit2):
    """Check that a transmission mask with a unit is rejected with TypeError.

    A random transmission is multiplied by the ``trans_unit2`` fixture, and
    its square is passed as ``mask`` to ``sqrt_sum_wis`` and ``rv_precision``
    in turn; each call is expected to raise.
    """
    wavelength = test_spec[0] * wav_unit
    spectrum = test_spec[1] * flux_unit
    unit_transmission = np.random.rand(len(wavelength)) * trans_unit2

    # Same expectation for both entry points, checked in the original order.
    for func in (sqrt_sum_wis, rv_precision):
        with pytest.raises(TypeError):
            func(wavelength, spectrum, mask=unit_transmission**2)
示例#2
0
def test_sqrt_sum_wis(test_spec, wav_unit, flux_unit, trans_unit):
    """Check the result type of sqrt_sum_wis for Quantity and unitless inputs.

    If any input ends up as a Quantity, the result must be a scalar Quantity
    that is dimensionless and unscaled; otherwise it must be a plain scalar.
    """
    wav = test_spec[0] * wav_unit
    flux = test_spec[1] * flux_unit
    mask = test_spec[2]
    if test_spec[2] is not None:
        mask *= trans_unit

    result = sqrt_sum_wis(wav, flux, mask)

    # Diagnostic output, shown by pytest when the test fails.
    for label, value in (
        ("wav", wav),
        ("flux", flux),
        ("mask", mask),
        ("sqrtsumwis", result),
    ):
        print(label, value, type(value))

    mask_is_quantity = isinstance(mask, Quantity) and (test_spec[2] is not None)
    expects_quantity = (
        isinstance(wav, Quantity) or isinstance(flux, Quantity) or mask_is_quantity
    )

    if not expects_quantity:
        # Unitless inputs: the result itself must be a scalar.
        assert not hasattr(result, "__len__")
        return

    assert isinstance(result, u.Quantity)
    # The Quantity must be unscaled and dimensionless.
    assert result.unit == u.dimensionless_unscaled
    # The wrapped value must be a scalar.
    assert not hasattr(result.value, "__len__")
示例#3
0
def test_sqrt_sum_wis_with_no_units(test_spec):
    """Check that sqrt_sum_wis returns a plain scalar for raw fixture inputs.

    The three ``test_spec`` entries are passed through without attaching any
    unit, and the result must be neither a Quantity nor a sized object.
    """
    wavelength, spectrum, mask = test_spec[0], test_spec[1], test_spec[2]
    result = sqrt_sum_wis(wavelength, spectrum, mask)

    # Must not be promoted to a Quantity when it does not have to be.
    is_quantity = isinstance(result, u.Quantity)
    assert not is_quantity

    # A scalar has no length.
    is_sized = hasattr(result, "__len__")
    assert not is_sized
示例#4
0
def test_sqrtsumwis_warns_nonfinite(grad_flag):
    """Check that sqrt_sum_wis emits a UserWarning for two degenerate inputs.

    Each case pairs the expected warning pattern with the wavelength and mask
    values that should trigger it; the flux values are the same in both.
    """
    cases = (
        # All masked
        ("This will cause infinite errors.", [1, 2, 3], [0, 0, 0]),
        # infinite gradient
        ("Weight sum is not finite", [2, 2, 2], [1, 1, 1]),
    )

    for pattern, wavelength_values, mask_values in cases:
        with pytest.warns(UserWarning, match=pattern):
            # Fresh arrays are built for every call, as in a hand-written case.
            sqrt_sum_wis(
                np.array(wavelength_values),
                np.array([1, 2, 3]),
                np.array(mask_values),
                grad=grad_flag,
            )
示例#5
0
def test_relation_of_rv_to_sqrtsumwis(test_spec, wav_unit, flux_unit, trans_unit):
    """Check that rv_precision equals ``c / sqrt_sum_wis`` for the same inputs.

    When the fixture supplies a mask it is multiplied by ``trans_unit`` and
    then squared before being handed to both functions.
    """
    wavelength = test_spec[0] * wav_unit
    spectrum = test_spec[1] * flux_unit

    weights = test_spec[2]
    if weights is not None:
        weights *= trans_unit
        weights = weights ** 2

    precision = rv_precision(wavelength, spectrum, mask=weights)
    from_sqrt_sum = c / sqrt_sum_wis(wavelength, spectrum, mask=weights)

    # Exact equality is intended here, not an approximate comparison.
    assert np.all(precision == from_sqrt_sum)
示例#6
0
def test_sqrt_sum_wis_transmission_outofbounds(test_spec, wav_unit, flux_unit):
    """Transmission must be within 0-1.

    Both ``rv_precision`` and ``sqrt_sum_wis`` are expected to raise
    ValueError when a single mask value lies above 1 (``mask_1``) or
    below 0 (``mask_2``).
    """
    wav = test_spec[0] * wav_unit
    flux = test_spec[1] * flux_unit

    # rand() draws from [0, 1), so the only out-of-range entry in each mask
    # is the one overwritten below. (randn() would already produce values
    # outside 0-1 and would not isolate the upper-bound check.)
    mask_1 = np.random.rand(len(wav))
    mask_2 = np.random.rand(len(wav))

    mask_1[0] = 5  # Outside 0-1
    mask_2[-1] = -2  # Outside 0-1

    # Higher value
    with pytest.raises(ValueError):
        rv_precision(wav, flux, mask=mask_1)

    with pytest.raises(ValueError):
        sqrt_sum_wis(wav, flux, mask=mask_1)

    # Lower value
    with pytest.raises(ValueError):
        rv_precision(wav, flux, mask=mask_2)

    with pytest.raises(ValueError):
        sqrt_sum_wis(wav, flux, mask=mask_2)