def test_band_info_init_error():
    """Test init error of BandInfo class.

    Each case passes one list whose length differs from ``num_channels``
    (too-short centers, too-long bandwidths, bandwidths_std or centers_std)
    and expects a ValueError carrying the length-mismatch message.
    """
    num_ch = 3
    centers_true = [100, 200, 300]
    expected_msg = "Length of list has to match number of channels"

    # One entry per constructor call; only a single list per entry has a
    # length different from num_ch.
    invalid_kwargs = (
        {"centers": [100, 200]},
        {"centers": centers_true, "bandwidths": [1, 2, 3, 5]},
        {"centers": centers_true, "bandwidths_std": [0.2, 0.3, 0.4, 1]},
        {"centers": centers_true,
         "centers_std": [0.1, 0.1, 0.2, 0.6, 0.2]},
    )

    for kwargs in invalid_kwargs:
        with raises(ValueError) as excinfo:
            BandInfo(num_channels=num_ch, unit="nm", **kwargs)
        assert str(excinfo.value) == expected_msg
def test_band_info_subband():
    """Test get_subband() of BandInfo class.

    Builds a 13-channel BandInfo, requests the sub-band with start=5 and
    stop=11, and checks that the result has 6 channels whose per-channel
    arrays match the values expected for channels 5 through 10.
    """
    n_channels = 13
    band_type = "reflectance"

    full_band = BandInfo(
        num_channels=n_channels,
        centers=np.arange(400, 400 + n_channels),
        unit="nm",
        bandwidths=np.arange(n_channels),
        bandwidths_std=0.25*np.arange(n_channels),
        centers_std=0.5*np.arange(n_channels),
        type=band_type,
    )

    sub_band = full_band.get_subband(start=5, stop=11)

    # Expected per-channel values, keyed by the attribute they are read from.
    expected = {
        "centers": [405, 406, 407, 408, 409, 410],
        "centers_std": [2.5, 3.0, 3.5, 4.0, 4.5, 5.0],
        "bandwidths": [5, 6, 7, 8, 9, 10],
        "bandwidths_std": [1.25, 1.5, 1.75, 2.0, 2.25, 2.5],
    }

    assert full_band.type == sub_band.type
    assert 6 == sub_band.num_channels
    for attr_name, values in expected.items():
        assert np.array_equal(np.asarray(values), getattr(sub_band, attr_name))
def test_band_info_downsampled():
    """Test get_downsampled() of BandInfo class.

    Builds a 13-channel BandInfo with centers evenly spaced from 400 to 600,
    downsamples it to 7 channels, and checks that the type is kept, the
    centers equal ``np.linspace(400, 600, 7)``, and that centers_std,
    bandwidths and bandwidths_std are None on the result.
    """
    num_ch = 13
    centers = np.linspace(400, 600, num_ch)
    unit = "nm"
    bandwidths = np.arange(num_ch)
    centers_std = 0.5*np.arange(num_ch)
    bandwidths_std = 0.25*np.arange(num_ch)
    band_type = "reflectance"  # local renamed so the builtin `type` is not shadowed

    band = BandInfo(num_channels=num_ch, centers=centers, unit=unit,
                    bandwidths=bandwidths, bandwidths_std=bandwidths_std,
                    centers_std=centers_std, type=band_type)

    band_down = band.get_downsampled(num=7)
    centers_exp = np.linspace(400, 600, 7)

    assert band.type == band_down.type
    assert 7 == band_down.num_channels
    assert np.array_equal(centers_exp, band_down.centers)
    # The test expects the remaining per-channel attributes to be None.
    assert band_down.centers_std is None
    assert band_down.bandwidths is None
    assert band_down.bandwidths_std is None
def test_band_info_copy():
    """Test copy() of BandInfo class.

    Builds a fully populated 3-channel BandInfo and checks that the object
    returned by copy() compares equal to the original.
    """
    original = BandInfo(
        num_channels=3,
        centers=[100, 200, 300],
        unit="nm",
        bandwidths=[1, 2, 3],
        bandwidths_std=[0.2, 0.3, 0.4],
        centers_std=[0.1, 0.1, 0.2],
        type="reflectance",
    )

    duplicate = original.copy()

    assert original == duplicate
def test_band_info_init():
    """Test init of BandInfo class.

    Constructs a BandInfo from a full set of keyword arguments and checks
    that each argument can be read back from the attribute of the same name.
    """
    # Insertion order fixes the order in which attributes are checked below.
    init_args = {
        "num_channels": 3,
        "centers": [100, 200, 300],
        "unit": "nm",
        "bandwidths": [1, 2, 3],
        "centers_std": [0.1, 0.1, 0.2],
        "bandwidths_std": [0.2, 0.3, 0.4],
        "type": "reflectance",
    }

    band = BandInfo(**init_args)

    for attr_name, given in init_args.items():
        assert getattr(band, attr_name) == given