def test_apply(hist_match: HistogramMatching) -> None:
    """Calling the matcher must give exactly what ``_apply`` returns."""
    # pylint: disable=protected-access (W0212)
    hist_match.channels = (0,)
    src_3d, ref_3d = (img[:, :, np.newaxis].astype(float)
                      for img in (TEST_SRC_IMAGE, TEST_REF_IMAGE))

    via_call = hist_match(src_3d, ref_3d)
    via_apply = hist_match._apply(src_3d, ref_3d)

    np.testing.assert_array_equal(via_call, via_apply)
def test_match_channel_prop(match_prop: float,
                            expected_result: np.ndarray) -> None:
    """Single-channel matching with a given ``match_prop`` must reproduce
    the parametrized expected image (up to float rounding)."""
    # pylint: disable=protected-access (W0212)
    matcher = HistogramMatching(CHANNELS_DEFAULT, check_input=True,
                                match_prop=match_prop)
    matched = matcher._match_channel(TEST_SRC_IMAGE, TEST_REF_IMAGE)

    assert matched.shape == TEST_SRC_IMAGE.shape
    np.testing.assert_array_almost_equal(matched, expected_result)
def test_match_channel(hist_match: HistogramMatching) -> None:
    """Single-channel matching must reproduce the precomputed result.

    ``TEST_RES_IMAGE`` is the scikit-image histogram-matching output for
    the same source/reference pair.
    """
    # pylint: disable=protected-access (W0212)
    matched = hist_match._match_channel(TEST_SRC_IMAGE, TEST_REF_IMAGE)

    assert matched.shape == TEST_SRC_IMAGE.shape
    np.testing.assert_array_equal(matched, TEST_RES_IMAGE)
def test_apply_channels(channels: ChannelsType) -> None:
    """Only the selected channels change; skipped channels pass through.

    Also checks the output shape and dtype, and that the arrays handed to
    the matcher are not modified in place.
    """
    source = cv2.imread(MUNICH_1_PATH)
    reference = cv2.imread(MUNICH_2_PATH)
    # cv2.imread signals an unreadable file by returning None, not raising.
    assert source is not None, MUNICH_1_PATH
    assert reference is not None, MUNICH_2_PATH

    # Convert once and pass these exact arrays to the matcher. Calling
    # ``astype`` at the call site would hand over a fresh copy, so the
    # no-mutation checks at the end could never fail.
    source = source.astype(float)
    reference = reference.astype(float)
    original_source = np.copy(source)
    original_reference = np.copy(reference)

    hist_match = HistogramMatching(channels, check_input=True)
    result = hist_match(source, reference)

    # check channels to be matched (against the pre-call values)
    for channel in channels:
        with np.testing.assert_raises(AssertionError):
            np.testing.assert_array_equal(original_source[:, :, channel],
                                          result[:, :, channel])

    # check skipped channels
    skipped_channels = tuple(set(CHANNELS_DEFAULT) - set(channels))
    for channel in skipped_channels:
        np.testing.assert_array_equal(original_source[:, :, channel],
                                      result[:, :, channel])

    assert result.shape == original_source.shape
    assert result.dtype == np.float32

    # the matcher must not have modified its inputs in place
    np.testing.assert_array_equal(source, original_source)
    np.testing.assert_array_equal(reference, original_reference)
def test_make_plot(color_space: str, image_channels: str, source_path: str,
                   reference_path: str) -> None:
    """End-to-end: convert, match and plot; the plot file must be written.

    The generated file is removed afterwards even when plotting or an
    assertion fails, so it never leaks into ``TEST_DIR``. A stale file from
    an earlier aborted run is removed first so it cannot satisfy the
    existence check on its own.
    """
    params = Params(
        {
            'color_space': color_space,
            'channels': image_channels,
            'match_proportion': 0.8,
            'src_path': source_path,
            'ref_path': reference_path
        }
    )
    converter = build_cs_converter(params.color_space)

    source = read_image(params.src_path)
    reference = read_image(params.ref_path)

    source = converter.convert(source)
    reference = converter.convert(reference)

    channels = tuple(int(c) for c in params.channels.split(','))
    # Pass the proportion by keyword: elsewhere in this module the
    # constructor is called with ``check_input`` before ``match_prop``, so a
    # bare positional second argument may not land on ``match_prop``.
    hist_match = HistogramMatching(channels,
                                   match_prop=params.match_proportion)
    result = hist_match(source, reference)

    file_name = os.path.join(TEST_DIR, 'histogram_matching_plot.png')
    if os.path.exists(file_name):
        os.remove(file_name)

    try:
        images = hm_plot.Images(source, reference, result)
        hm_plot.make_plot(file_name, images, converter, params.color_space,
                          params.channels)

        assert os.path.isfile(file_name)
    finally:
        # clean up regardless of the outcome above
        if os.path.exists(file_name):
            os.remove(file_name)
def test_match_channel_images(source_path: str, reference_path: str,
                              hist_match: HistogramMatching) -> None:
    """Per-channel matching on real images must equal scikit-image's
    ``match_histograms`` output for every channel."""
    # pylint: disable=protected-access (W0212)
    src_img = cv2.imread(source_path)
    ref_img = cv2.imread(reference_path)
    n_channels = src_img.shape[-1]

    for idx in range(n_channels):
        src_plane, ref_plane = src_img[:, :, idx], ref_img[:, :, idx]
        matched = hist_match._match_channel(src_plane, ref_plane)
        # reference implementation: scikit-image histogram matching
        skimage_matched = match_histograms(src_plane, ref_plane)

        assert matched.shape == src_plane.shape
        np.testing.assert_array_equal(matched, skimage_matched)
def test_apply_2d_image(hist_match: HistogramMatching) -> None:
    """A 2-D test image given a trailing channel axis is matched correctly.

    Checks the output shape, dtype and rank, that the arrays passed to the
    matcher are not modified in place, and that the values agree with the
    precomputed scikit-image result.
    """
    hist_match.channels = (0,)
    # Convert up front and pass these exact arrays to the matcher; converting
    # at the call site would hand over a copy, making the no-mutation checks
    # below unable to fail.
    source = TEST_SRC_IMAGE[:, :, np.newaxis].astype(float)
    reference = TEST_REF_IMAGE[:, :, np.newaxis].astype(float)
    original_source = np.copy(source)
    original_reference = np.copy(reference)

    result = hist_match(source, reference)

    assert result.shape == original_source.shape
    assert result.dtype == np.float32
    assert result.ndim == DIM_3

    np.testing.assert_array_equal(source, original_source)
    np.testing.assert_array_equal(reference, original_reference)

    # we test against scikit image histogram matching
    np.testing.assert_array_equal(result, TEST_RES_IMAGE[:, :, np.newaxis])
def test_hm_match_prop_invalid_value(match_prop: float) -> None:
    """An invalid ``match_prop`` value must make the constructor raise
    ValueError."""
    ctor_kwargs = {'check_input': True, 'match_prop': match_prop}
    with pytest.raises(ValueError):
        HistogramMatching(CHANNELS_DEFAULT, **ctor_kwargs)
def test_hm_match_prop_valid_value(match_prop: float) -> None:
    """A valid ``match_prop`` is accepted; the stored attribute compares
    equal to ``float(match_prop)``."""
    stored = HistogramMatching(CHANNELS_DEFAULT, check_input=True,
                               match_prop=match_prop).match_prop
    assert stored == float(match_prop)
def test_hm_match_prop_type_error(match_prop: float) -> None:
    """A ``match_prop`` of an unsupported type must raise TypeError."""
    def _build() -> HistogramMatching:
        return HistogramMatching(CHANNELS_DEFAULT, check_input=True,
                                 match_prop=match_prop)

    with pytest.raises(TypeError):
        _build()
def fixture_histogram_matching() -> HistogramMatching:
    """Return a matcher on ``CHANNELS_DEFAULT`` with ``check_input=True``."""
    matcher = HistogramMatching(CHANNELS_DEFAULT, check_input=True)
    return matcher