def test_new_med_ndwi():
    """Check that a median-only NDWI statistic produces a well-formed dataset.

    Two random bands ('green' and 'nir'), each shaped (time=5, y=100, x=100)
    with values drawn uniformly from [-1, 1), are combined into a dataset and
    passed through ``NormalisedDifferenceStats(..., stats=['median'])``.
    The result must be an ``xr.Dataset`` that carries the input CRS, no longer
    has a ``time`` dimension, and contains an ``ndwi_median`` variable.
    """
    medndwi = NormalisedDifferenceStats('green',
                                        'nir',
                                        'ndwi',
                                        stats=['median'])

    def _random_band():
        # Both bands share the same layout and CRS; only the values differ.
        arr = np.random.uniform(low=-1, high=1, size=(5, 100, 100))
        return xr.DataArray(arr,
                            dims=('time', 'y', 'x'),
                            coords={'time': list(range(5))},
                            attrs={'crs': 'Fake CRS'})

    # Dict entries are evaluated in order, so 'green' is drawn before 'nir'.
    dataset = xr.Dataset(data_vars={'green': _random_band(),
                                    'nir': _random_band()},
                         attrs={'crs': 'Fake CRS'})
    result = medndwi.compute(dataset)

    assert isinstance(result, xr.Dataset)
    assert 'crs' in result.attrs
    assert result.attrs['crs'] == dataset.attrs['crs']
    # The statistic must collapse the time axis, leaving a per-pixel summary.
    assert 'time' not in result.dims
    assert 'ndwi_median' in result.data_vars
# Example #2
def test_normalised_difference_stats(dataset, output_name):
    """Check NormalisedDifferenceStats with its default statistics.

    ``dataset`` and ``output_name`` are supplied by pytest. ``dataset`` must
    hold exactly two data variables, which are used as the two bands.

    The test verifies:
    - the structure of ``compute``'s output;
    - that ``measurements`` rejects an input name that matches neither band;
    - that ``measurements`` reports output names matching the computed
      variables.
    """
    var1, var2 = list(dataset.data_vars)
    ndstat = NormalisedDifferenceStats(var1, var2, output_name)
    result = ndstat.compute(dataset)

    assert isinstance(result, xr.Dataset)
    assert 'time' not in result.dims
    assert dataset.crs == result.crs

    expected_output_varnames = {f'{output_name}_{stat_name}' for stat_name in ndstat.stats}
    assert set(result.data_vars) == expected_output_varnames

    # Check the measurements() function raises an error on bad input_measurements.
    # Build the bad input outside pytest.raises so that only the call under
    # test can satisfy the expected exception.
    invalid_names = [Measurement(name='foo', **FAKE_MEASUREMENT_INFO)]
    with pytest.raises(StatsConfigurationError):
        ndstat.measurements(invalid_names)

    # Check the measurements() function reports the same names compute() produced.
    input_measurements = [Measurement(name=name, **FAKE_MEASUREMENT_INFO) for name in (var1, var2)]
    output_measurements = ndstat.measurements(input_measurements)
    measurement_names = {m.name for m in output_measurements}
    assert measurement_names == expected_output_varnames