Code example #1
def test_evaluate_analysis_test_pattern_output():
    # In this test we check that the decoded values are plausible by
    # verifying that they are close to the predicted signal range

    wavelet_index = WaveletFilters.haar_with_shift
    wavelet_index_ho = WaveletFilters.le_gall_5_3
    dwt_depth = 1
    dwt_depth_ho = 0

    picture_bit_width = 10

    h_filter_params = LIFTING_FILTERS[wavelet_index_ho]
    v_filter_params = LIFTING_FILTERS[wavelet_index]

    input_min, input_max = signed_integer_range(picture_bit_width)

    input_array = SymbolArray(2)
    _, intermediate_arrays = analysis_transform(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        input_array,
    )

    for (level, array_name), target_array in intermediate_arrays.items():
        for x in range(target_array.period[0]):
            for y in range(target_array.period[1]):
                # Compute the expected bounds for this value
                lower_bound, upper_bound = evaluate_analysis_filter_bounds(
                    *analysis_filter_bounds(target_array[x, y]),
                    num_bits=picture_bit_width)

                # Create a test pattern
                test_pattern = make_analysis_maximising_pattern(
                    input_array,
                    target_array,
                    x,
                    y,
                )

                # Find the actual values
                lower_value, upper_value = evaluate_analysis_test_pattern_output(
                    h_filter_params,
                    v_filter_params,
                    dwt_depth,
                    dwt_depth_ho,
                    level,
                    array_name,
                    test_pattern,
                    input_min,
                    input_max,
                )

                assert np.isclose(lower_value, lower_bound, rtol=0.01)
                assert np.isclose(upper_value, upper_bound, rtol=0.01)
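For reference (not part of the original example), the 10-bit input range used above follows directly from the bit width: a signed 10-bit integer spans -512 to 511, which matches the random test signals used in the later examples. A minimal sketch of that convention follows; signed_integer_range_sketch is a hypothetical helper illustrating the usual two's-complement rule, not the library's own implementation.

def signed_integer_range_sketch(num_bits):
    # Two's-complement range for a signed integer of the given width,
    # e.g. 10 bits -> (-512, 511).
    return (-(1 << (num_bits - 1)), (1 << (num_bits - 1)) - 1)

assert signed_integer_range_sketch(10) == (-512, 511)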
Code example #2
    def analysis_transform_output(self, input_array, wavelet_index,
                                  wavelet_index_ho, dwt_depth, dwt_depth_ho):
        h_filter_params = LIFTING_FILTERS[wavelet_index_ho]
        v_filter_params = LIFTING_FILTERS[wavelet_index]

        return analysis_transform(
            h_filter_params,
            v_filter_params,
            dwt_depth,
            dwt_depth_ho,
            input_array,
        )
Code example #3
    def analysis_coeff_linexp_arrays(
        self,
        filter_params,
        dwt_depth,
        dwt_depth_ho,
        analysis_input_linexp_array,
    ):
        return analysis_transform(
            filter_params,
            filter_params,
            dwt_depth,
            dwt_depth_ho,
            analysis_input_linexp_array,
        )[0]
Code example #4
    def test_analysis_intermediate_steps_as_expected(self, dwt_depth,
                                                     dwt_depth_ho):
        filter_params = tables.LIFTING_FILTERS[
            tables.WaveletFilters.haar_with_shift]

        input_picture = SymbolArray(2, "p")

        _, intermediate_values = analysis_transform(
            filter_params,
            filter_params,
            dwt_depth,
            dwt_depth_ho,
            input_picture,
        )

        # 2D stages have all expected values
        for level in range(dwt_depth_ho + 1, dwt_depth + dwt_depth_ho + 1):
            names = set(n for l, n in intermediate_values if l == level)
            assert names == set([
                "Input",
                "DC",
                "DC'",
                "DC''",
                "L",
                "L'",
                "L''",
                "H",
                "H'",
                "H''",
                "LL",
                "LH",
                "HL",
                "HH",
            ])

        # HO stages have all expected values
        for level in range(1, dwt_depth_ho + 1):
            names = set(n for l, n in intermediate_values if l == level)
            assert names == set([
                "Input",
                "DC",
                "DC'",
                "DC''",
                "L",
                "H",
            ])
Code example #5
def test_add_missing_analysis_values(
    wavelet_index,
    wavelet_index_ho,
    dwt_depth,
    dwt_depth_ho,
):
    h_filter_params = tables.LIFTING_FILTERS[wavelet_index_ho]
    v_filter_params = tables.LIFTING_FILTERS[wavelet_index]

    _, intermediate_values = analysis_transform(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        SymbolArray(2),
    )

    all_expressions = {
        (level, array_name, x, y): array[x, y]
        for (level, array_name), array in intermediate_values.items()
        for x in range(array.period[0]) for y in range(array.period[1])
    }

    non_nop_expressions = {
        (level, array_name, x, y): array[x, y]
        for (level, array_name), array in intermediate_values.items()
        for x in range(array.period[0]) for y in range(array.period[1])
        if not array.nop
    }

    # Sanity check
    assert all_expressions != non_nop_expressions

    refilled_expressions = add_missing_analysis_values(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        non_nop_expressions,
    )

    assert set(refilled_expressions) == set(all_expressions)
    assert refilled_expressions == all_expressions
Code example #6
    def test_integration(self):
        h_filter_params = tables.LIFTING_FILTERS[
            tables.WaveletFilters.le_gall_5_3]
        v_filter_params = tables.LIFTING_FILTERS[
            tables.WaveletFilters.haar_with_shift]

        # Run against a real-world, more complex set of expressions
        symbol_array = SymbolArray(2)
        coeff_arrays, intermediate_arrays = analysis_transform(
            h_filter_params=h_filter_params,
            v_filter_params=v_filter_params,
            dwt_depth=1,
            dwt_depth_ho=2,
            array=symbol_array,
        )

        for array in intermediate_arrays.values():
            cached_array = SymbolicPeriodicCachingArray(array, symbol_array)

            for x in range(array.period[0] * 2):
                for y in range(array.period[1] * 2):
                    assert (
                        strip_affine_errors(cached_array[x, y])
                        == strip_affine_errors(array[x, y])
                    )
Code example #7
    def test_filters_invert_eachother(self, wavelet_index, wavelet_index_ho,
                                      dwt_depth, dwt_depth_ho):
        # Test that the analysis and synthesis filters invert each other as a
        # check of consistency (and, indirectly, the correctness of the
        # analysis implementation and convert_between_synthesis_and_analysis)

        h_filter_params = tables.LIFTING_FILTERS[wavelet_index_ho]
        v_filter_params = tables.LIFTING_FILTERS[wavelet_index]

        input_picture = SymbolArray(2, "p")

        transform_coeffs, _ = analysis_transform(
            h_filter_params,
            v_filter_params,
            dwt_depth,
            dwt_depth_ho,
            input_picture,
        )
        output_picture, _ = synthesis_transform(
            h_filter_params,
            v_filter_params,
            dwt_depth,
            dwt_depth_ho,
            transform_coeffs,
        )

        # In this example, no quantisation is applied between the two filters.
        # As a consequence the only error terms arise from rounding errors in
        # the analysis and synthesis filters. Since this implementation does
        # not account for divisions of the same numbers producing the same
        # rounding errors, these rounding errors do not cancel out here.
        # However, aside from these terms, the input and output of the filters
        # should be identical.
        rounding_errors = output_picture[0, 0] - input_picture[0, 0]
        assert all(
            isinstance(sym, AAError) for sym in rounding_errors.symbols())
Code example #8
def test_evaluate_synthesis_test_pattern_output():
    # In this test we simply check that the decoded values match those
    # computed by the optimise_synthesis_maximising_test_pattern function

    wavelet_index = WaveletFilters.haar_with_shift
    wavelet_index_ho = WaveletFilters.le_gall_5_3
    dwt_depth = 1
    dwt_depth_ho = 0

    picture_bit_width = 10

    max_quantisation_index = 64

    quantisation_matrix = {
        0: {
            "LL": 0
        },
        1: {
            "LH": 1,
            "HL": 2,
            "HH": 3
        },
    }

    h_filter_params = LIFTING_FILTERS[wavelet_index_ho]
    v_filter_params = LIFTING_FILTERS[wavelet_index]

    input_min, input_max = signed_integer_range(picture_bit_width)

    input_array = SymbolArray(2)
    analysis_transform_coeff_arrays, _ = analysis_transform(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        input_array,
    )

    symbolic_coeff_arrays = make_symbol_coeff_arrays(dwt_depth, dwt_depth_ho)
    symbolic_output_array, symbolic_intermediate_arrays = synthesis_transform(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        symbolic_coeff_arrays,
    )

    pyexp_coeff_arrays = make_variable_coeff_arrays(dwt_depth, dwt_depth_ho)
    _, pyexp_intermediate_arrays = synthesis_transform(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        pyexp_coeff_arrays,
    )

    for (level, array_name), target_array in symbolic_intermediate_arrays.items():
        for x in range(target_array.period[0]):
            for y in range(target_array.period[1]):
                # Create a test pattern
                test_pattern = make_synthesis_maximising_pattern(
                    input_array,
                    analysis_transform_coeff_arrays,
                    target_array,
                    symbolic_output_array,
                    x,
                    y,
                )

                synthesis_pyexp = (
                    pyexp_intermediate_arrays[(level, array_name)][x, y]
                )
                # Run with no optimisation iterations but, as a side effect,
                # compute the actual decoded value to compare with
                test_pattern = optimise_synthesis_maximising_test_pattern(
                    h_filter_params,
                    v_filter_params,
                    dwt_depth,
                    dwt_depth_ho,
                    quantisation_matrix,
                    synthesis_pyexp,
                    test_pattern,
                    input_min,
                    input_max,
                    max_quantisation_index,
                    None,
                    1,
                    None,
                    0.0,
                    0.0,
                    0,
                    0,
                )

                # Find the actual values
                lower_value, upper_value = evaluate_synthesis_test_pattern_output(
                    h_filter_params,
                    v_filter_params,
                    dwt_depth,
                    dwt_depth_ho,
                    quantisation_matrix,
                    synthesis_pyexp,
                    test_pattern,
                    input_min,
                    input_max,
                    max_quantisation_index,
                )

                assert upper_value[0] == test_pattern.decoded_value
                assert upper_value[1] == test_pattern.quantisation_index
Code example #9
File: helpers.py  Project: bbc/vc2_bit_widths
def static_filter_analysis(
    wavelet_index,
    wavelet_index_ho,
    dwt_depth,
    dwt_depth_ho,
    num_batches=1,
    batch_num=0,
):
    r"""
    Performs a complete static analysis of a VC-2 filter configuration,
    computing theoretical upper- and lower-bounds for signal values (see
    :ref:`theory-affine-arithmetic`) and heuristic test patterns (see
    :ref:`theory-test-patterns`) for all intermediate and final analysis and
    synthesis filter values.
    
    Parameters
    ==========
    wavelet_index : :py:class:`vc2_data_tables.WaveletFilters` or int
    wavelet_index_ho : :py:class:`vc2_data_tables.WaveletFilters` or int
    dwt_depth : int
    dwt_depth_ho : int
        The filter parameters.
    
    num_batches : int
    batch_num : int
        Though for most filters this function runs either instantaneously or at
        worst in the space of a couple of hours, unusually large filters can
        take an extremely long time to run. For example, a 4-level Fidelity
        transform may take around a month to evaluate.
        
        These arguments may be used to split this job into separate batches
        which may be computed separately (and in parallel) and later combined.
        For example, setting ``num_batches`` to 3 results in only analysing
        every third filter phase. The ``batch_num`` parameter should then be
        set to either 0, 1 or 2 to specify which third.
        
        The skipped phases are simply omitted from the returned dictionaries.
        The dictionaries returned for each batch should be unified to produce
        the complete analysis.
    
    Returns
    =======
    analysis_signal_bounds : {(level, array_name, x, y): (lower_bound_exp, upper_bound_exp), ...}
    synthesis_signal_bounds : {(level, array_name, x, y): (lower_bound_exp, upper_bound_exp), ...}
        Expressions defining the upper and lower bounds for all intermediate
        and final analysis and synthesis filter values.
        
        The keys of the returned dictionaries give the level, array name and
        filter phase to which each pair of bounds corresponds (see
        :ref:`terminology`). The naming
        conventions used are those defined by
        :py:func:`vc2_bit_widths.vc2_filters.analysis_transform` and
        :py:func:`vc2_bit_widths.vc2_filters.synthesis_transform`. Arrays which
        are just interleavings, subsamplings or renamings of other arrays are
        omitted.
        
        The lower and upper bounds are given algebraically as
        :py:class:`~vc2_bit_widths.linexp.LinExp`\ s.
        
        For the analysis filter bounds, the expressions are defined in terms of
        the variables ``LinExp("signal_min")`` and ``LinExp("signal_max")``.
        These should be substituted for the minimum and maximum picture signal
        level to find the upper and lower bounds for a particular picture bit
        width.
        
        For the synthesis filter bounds, the expressions are defined in terms
        of variables of the form ``LinExp("coeff_LEVEL_ORIENT_min")`` and
        ``LinExp("coeff_LEVEL_ORIENT_max")`` which give lower and upper bounds
        for the transform coefficients with the named level and orientation.
        
        The :py:func:`~vc2_bit_widths.helpers.evaluate_filter_bounds` function
        may be used to substitute concrete values into these expressions for a
        particular picture bit width.
        
    analysis_test_patterns : {(level, array_name, x, y): :py:class:`~vc2_bit_widths.patterns.TestPatternSpecification`, ...}
    synthesis_test_patterns : {(level, array_name, x, y): :py:class:`~vc2_bit_widths.patterns.TestPatternSpecification`, ...}
        Heuristic test patterns which are designed to maximise a particular
        intermediate or final filter value. For a minimising test pattern,
        invert the polarities of the pixels.
        
        The keys of the returned dictionaries give the level, array name and
        filter phase to which each test pattern corresponds (see
        :ref:`terminology`). Arrays which are just interleavings, subsamplings
        or renamings of other arrays are omitted.
    """
    v_filter_params = LIFTING_FILTERS[wavelet_index]
    h_filter_params = LIFTING_FILTERS[wavelet_index_ho]

    # Create the algebraic representation of the analysis transform
    picture_array = SymbolArray(2)
    analysis_coeff_arrays, intermediate_analysis_arrays = analysis_transform(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        picture_array,
    )

    # Count the total number of arrays for use in logging messages
    num_arrays = sum(array.period[0] * array.period[1]
                     for array in intermediate_analysis_arrays.values()
                     if not array.nop)
    array_num = 0

    # Compute bounds/test pattern for every intermediate/output analysis value
    analysis_signal_bounds = OrderedDict()
    analysis_test_patterns = OrderedDict()
    for (level, array_name), target_array in intermediate_analysis_arrays.items():
        # Skip arrays which are just views of other arrays
        if target_array.nop:
            continue

        for x in range(target_array.period[0]):
            for y in range(target_array.period[1]):
                array_num += 1
                if (array_num - 1) % num_batches != batch_num:
                    continue

                logger.info(
                    "Analysing analysis filter %d of %d (level %d, %s[%d, %d])",
                    array_num,
                    num_arrays,
                    level,
                    array_name,
                    x,
                    y,
                )

                # Compute signal bounds
                analysis_signal_bounds[(level, array_name, x, y)] = (
                    analysis_filter_bounds(target_array[x, y])
                )

                # Generate test pattern
                analysis_test_patterns[(level, array_name, x, y)] = (
                    make_analysis_maximising_pattern(
                        picture_array,
                        target_array,
                        x,
                        y,
                    )
                )

    # Create the algebraic representation of the synthesis transform
    coeff_arrays = make_symbol_coeff_arrays(dwt_depth, dwt_depth_ho)
    synthesis_output_array, intermediate_synthesis_arrays = synthesis_transform(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        coeff_arrays,
    )

    # Create a view of the analysis coefficient arrays which avoids recomputing
    # already-known analysis filter phases
    cached_analysis_coeff_arrays = {
        level: {
            orient: SymbolicPeriodicCachingArray(array, picture_array)
            for orient, array in orients.items()
        }
        for level, orients in analysis_coeff_arrays.items()
    }

    # Count the total number of arrays for use in logging messages
    num_arrays = sum(array.period[0] * array.period[1]
                     for array in intermediate_synthesis_arrays.values()
                     if not array.nop)
    array_num = 0

    # Compute bounds/test pattern for every intermediate/output synthesis value
    synthesis_signal_bounds = OrderedDict()
    synthesis_test_patterns = OrderedDict()
    for (level, array_name), target_array in intermediate_synthesis_arrays.items():
        # Skip arrays which are just views of other arrays
        if target_array.nop:
            continue

        for x in range(target_array.period[0]):
            for y in range(target_array.period[1]):
                array_num += 1
                if (array_num - 1) % num_batches != batch_num:
                    continue

                logger.info(
                    "Analysing synthesis filter %d of %d (level %d, %s[%d, %d])",
                    array_num,
                    num_arrays,
                    level,
                    array_name,
                    x,
                    y,
                )

                # Compute signal bounds
                synthesis_signal_bounds[(level, array_name, x, y)] = (
                    synthesis_filter_bounds(target_array[x, y])
                )

                # Compute test pattern
                synthesis_test_patterns[(level, array_name, x, y)] = (
                    make_synthesis_maximising_pattern(
                        picture_array,
                        cached_analysis_coeff_arrays,
                        target_array,
                        synthesis_output_array,
                        x,
                        y,
                    )
                )

                # For extremely large filters, a noteworthy amount of overall
                # RAM can be saved by not caching synthesis filters. These
                # filters generally benefit little from caching, so clearing
                # the cache has essentially no impact on runtime.
                for a in intermediate_synthesis_arrays.values():
                    a.clear_cache()

    return (
        analysis_signal_bounds,
        synthesis_signal_bounds,
        analysis_test_patterns,
        synthesis_test_patterns,
    )
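As a usage sketch (not from the original file): the docstring above describes splitting the analysis into batches via num_batches and batch_num and unifying the returned dictionaries afterwards. The snippet below illustrates that workflow under stated assumptions; the chosen wavelet and depth values and the vc2_data_tables import path are assumptions made for the sake of the example.

from collections import OrderedDict

from vc2_data_tables import WaveletFilters

# Illustrative only: run the three batches (e.g. in separate processes)...
combined = [OrderedDict() for _ in range(4)]
for batch_num in range(3):
    results = static_filter_analysis(
        wavelet_index=WaveletFilters.le_gall_5_3,
        wavelet_index_ho=WaveletFilters.le_gall_5_3,
        dwt_depth=2,
        dwt_depth_ho=0,
        num_batches=3,        # only every third filter phase is analysed
        batch_num=batch_num,  # selects which third
    )
    # ...then unify the per-batch dictionaries. Skipped phases are simply
    # omitted from each batch, so a plain dictionary update suffices.
    for combined_dict, batch_dict in zip(combined, results):
        combined_dict.update(batch_dict)

(analysis_signal_bounds,
 synthesis_signal_bounds,
 analysis_test_patterns,
 synthesis_test_patterns) = combined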
Code example #10
def test_integration():
    # A simple integration test which computes signal bounds for a small
    # transform operation

    filter_params = LIFTING_FILTERS[WaveletFilters.haar_with_shift]
    dwt_depth = 1
    dwt_depth_ho = 1

    input_picture_array = SymbolArray(2)
    analysis_coeff_arrays, analysis_intermediate_values = analysis_transform(
        filter_params,
        filter_params,
        dwt_depth,
        dwt_depth_ho,
        input_picture_array,
    )

    input_coeff_arrays = make_symbol_coeff_arrays(dwt_depth, dwt_depth_ho)
    synthesis_output, synthesis_intermediate_values = synthesis_transform(
        filter_params,
        filter_params,
        dwt_depth,
        dwt_depth_ho,
        input_coeff_arrays,
    )

    signal_min = LinExp("signal_min")
    signal_max = LinExp("signal_max")

    example_range = {signal_min: -512, signal_max: 511}

    # Input signal bounds should be as specified
    assert analysis_filter_bounds(
        analysis_intermediate_values[(2, "Input")][0, 0]
    ) == (signal_min, signal_max)

    # Output of the final analysis filter should require a greater bit depth
    # (NB: for the Haar transform it is the high-pass bands which gain the
    # largest signal range)
    analysis_output_lower, analysis_output_upper = analysis_filter_bounds(
        analysis_intermediate_values[(1, "H")][0, 0]
    )
    assert analysis_output_lower.subs(example_range) < signal_min.subs(
        example_range)
    assert analysis_output_upper.subs(example_range) > signal_max.subs(
        example_range)

    example_coeff_range = {
        "coeff_{}_{}_{}".format(level, orient, minmax):
        maximum_dequantised_magnitude(
            int(round(value.subs(example_range).constant)))
        for level, orients in analysis_coeff_arrays.items()
        for orient, expr in orients.items()
        for minmax, value in zip(["min", "max"], analysis_filter_bounds(expr))
    }

    # Signal range should shrink back down by the end of the synthesis process
    # but should still be larger than the original signal range
    final_output_lower, final_output_upper = synthesis_filter_bounds(
        synthesis_output[0, 0])

    assert final_output_upper.subs(
        example_coeff_range) < analysis_output_upper.subs(example_range)
    assert final_output_lower.subs(
        example_coeff_range) > analysis_output_lower.subs(example_range)

    assert final_output_upper.subs(example_coeff_range) > signal_max.subs(
        example_range)
    assert final_output_lower.subs(example_coeff_range) < signal_min.subs(
        example_range)
Code example #11
    def test_filters_match_pseudocode(self, wavelet_index, wavelet_index_ho):
        # This test checks that the filters implement the same behaviour as the
        # VC-2 pseudocode, including compatible operation ordering. This test is
        # carried out on a relatively small Haar transform because:
        #
        # * The Haar transform is free from edge effects, making the
        #   InfiniteArray implementation straightforwardly equivalent to the
        #   pseudocode behaviour in all cases (not just non-edge cases)
        # * The Haar transform is available in a form with and without the bit
        #   shift so we can check that the bit shift parameter is used
        #   correctly and taken from the correct wavelet index.
        # * Using large transform depths or filters produces very large
        #   functions for analysis transforms under PyExp which can crash
        #   Python interpreters. (In practice they'll only ever be generated
        #   for synthesis transforms which produce small code even for large
        #   transforms)

        width = 16
        height = 8

        dwt_depth = 1
        dwt_depth_ho = 2

        # Create a random picture to analyse
        rand = np.random.RandomState(1)
        random_input_picture = rand.randint(-512, 511, (height, width))

        # Analyse using pseudocode
        state = State(
            wavelet_index=wavelet_index,
            wavelet_index_ho=wavelet_index_ho,
            dwt_depth=dwt_depth,
            dwt_depth_ho=dwt_depth_ho,
        )
        pseudocode_coeffs = dwt(state, random_input_picture.tolist())

        # Analyse using InfiniteArrays
        h_filter_params = tables.LIFTING_FILTERS[wavelet_index_ho]
        v_filter_params = tables.LIFTING_FILTERS[wavelet_index]
        ia_coeffs, _ = analysis_transform(
            h_filter_params,
            v_filter_params,
            dwt_depth,
            dwt_depth_ho,
            VariableArray(2, Argument("picture")),
        )

        # Compare analysis results
        for level in pseudocode_coeffs:
            for orient in pseudocode_coeffs[level]:
                pseudocode_data = pseudocode_coeffs[level][orient]
                for row, row_data in enumerate(pseudocode_data):
                    for col, pseudocode_value in enumerate(row_data):
                        # Create and call a function to compute this value via
                        # InfiniteArrays/PyExp
                        expr = ia_coeffs[level][orient][col, row]
                        f = expr.make_function()
                        # NB: Array is transposed to support (x, y) indexing
                        ia_value = f(random_input_picture.T)

                        assert ia_value == pseudocode_value

        # Synthesise using pseudocode
        pseudocode_picture = idwt(state, pseudocode_coeffs)

        # Synthesise using InfiniteArrays
        ia_picture, _ = synthesis_transform(
            h_filter_params,
            v_filter_params,
            dwt_depth,
            dwt_depth_ho,
            make_variable_coeff_arrays(dwt_depth, dwt_depth_ho),
        )

        # Create numpy-array based coeff data for use by
        # InfiniteArray-generated functions (NB: arrays are transposed to
        # support (x, y) indexing).
        ia_coeffs_data = {
            level: {
                orient: np.array(array, dtype=np.int64).T
                for orient, array in orients.items()
            }
            for level, orients in pseudocode_coeffs.items()
        }

        # Compare synthesis results

        for row, row_data in enumerate(pseudocode_picture):
            for col, pseudocode_value in enumerate(row_data):
                # Create and call a function to compute this value via
                # InfiniteArrays/PyExp
                expr = ia_picture[col, row]
                f = expr.make_function()
                # NB: Arrays are transposed to support (x, y) indexing
                ia_value = f(ia_coeffs_data)

                assert ia_value == pseudocode_value
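A small aside (not from the original test): the transpositions above work because indexing a transposed NumPy array with (x, y) is equivalent to indexing the original row-major array with (row, column) = (y, x):

import numpy as np

a = np.arange(12).reshape(3, 4)  # 3 rows (height = 3), 4 columns (width = 4)
x, y = 2, 1                      # column index 2, row index 1
assert a.T[x, y] == a[y, x]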
Code example #12
def test_fast_partial_analysis_transform_no_target(wavelet_index,
                                                   wavelet_index_ho, dwt_depth,
                                                   dwt_depth_ho):
    # This test verifies that the analysis transform produces identical results
    # to the pseudocode in the case where no edge effects are encountered.

    width = 32
    height = 8

    rand = np.random.RandomState(1)
    signal = rand.randint(-512, 511, (height, width))

    # Process using pseudocode
    state = State(
        wavelet_index=wavelet_index,
        wavelet_index_ho=wavelet_index_ho,
        dwt_depth=dwt_depth,
        dwt_depth_ho=dwt_depth_ho,
    )
    pseudocode_out = dwt(state, signal.tolist())

    # Process using matrix transform
    h_filter_params = tables.LIFTING_FILTERS[wavelet_index_ho]
    v_filter_params = tables.LIFTING_FILTERS[wavelet_index]
    matrix_out = fast_partial_analysis_transform(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        signal.copy(),
    )

    # Use a symbolic representation of the transform operation to create masks
    # identifying the coefficients which are edge-effect free.
    symbolic_out, _ = analysis_transform(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        SymbolArray(2),
    )
    edge_effect_free_pixel_mask = {
        level: {
            orient: np.array([[
                all(0 <= sym[1] < width and 0 <= sym[2] < height
                    for sym in strip_affine_errors(symbolic_out[level][orient][
                        col, row]).symbols() if sym is not None)
                for col in range(matrix_out[level][orient].shape[1])
            ] for row in range(matrix_out[level][orient].shape[0])])
            for orient in matrix_out[level]
        }
        for level in matrix_out
    }

    # Sanity check: Ensure that in every transform subband there is at least
    # one edge-effect free value (otherwise the test needs to be modified to
    # use a larger input picture).
    assert all(
        np.any(mask) for level, orients in edge_effect_free_pixel_mask.items()
        for orient, mask in orients.items())

    # Compare the two outputs and ensure all edge-effect free pixels are
    # identical
    assert set(pseudocode_out) == set(matrix_out)
    for level in matrix_out:
        assert set(pseudocode_out[level]) == set(matrix_out[level])
        for orient in matrix_out[level]:
            pseudocode_array = np.array(pseudocode_out[level][orient])
            matrix_array = matrix_out[level][orient]
            mask = edge_effect_free_pixel_mask[level][orient]
            assert np.array_equal(pseudocode_array[mask], matrix_array[mask])
Code example #13
def test_aggregation_flag(tmpdir, capsys, arg, exp_phases):
    # Check that aggregation of filter phases works

    f = str(tmpdir.join("file.json"))

    # vc2-static-filter-analysis
    assert sfa(shlex.split("-w haar_with_shift -d 1 -o") + [f]) == 0

    # vc2-bit-widths-table
    assert bwt([f] + shlex.split("-b 10 {}".format(arg))) == 0

    csv_rows = list(csv.reader(capsys.readouterr().out.splitlines()))

    columns = csv_rows[0][:-5]

    # Check that the phase columns are present (or absent) as expected
    if exp_phases:
        assert columns == ["type", "level", "array_name", "x", "y"]
    else:
        assert columns == ["type", "level", "array_name"]

    # Check the rows are as expected
    row_headers = [tuple(row[:-5]) for row in csv_rows[1:]]

    # ...by comparing with the intermediate arrays expected for this filter...
    h_filter_params = LIFTING_FILTERS[WaveletFilters.haar_with_shift]
    v_filter_params = LIFTING_FILTERS[WaveletFilters.haar_with_shift]
    dwt_depth = 1
    dwt_depth_ho = 0

    _, analysis_intermediate_arrays = analysis_transform(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        SymbolArray(2),
    )
    _, synthesis_intermediate_arrays = synthesis_transform(
        h_filter_params,
        v_filter_params,
        dwt_depth,
        dwt_depth_ho,
        make_symbol_coeff_arrays(dwt_depth, dwt_depth_ho),
    )

    if exp_phases:
        assert row_headers == [
            (type_name, str(level), array_name, str(x), str(y))
            for type_name, intermediate_arrays in [
                ("analysis", analysis_intermediate_arrays),
                ("synthesis", synthesis_intermediate_arrays),
            ] for (level, array_name), array in intermediate_arrays.items()
            for x in range(array.period[0]) for y in range(array.period[1])
        ]
    else:
        assert row_headers == [
            (type_name, str(level), array_name)
            for type_name, intermediate_arrays in [
                ("analysis", analysis_intermediate_arrays),
                ("synthesis", synthesis_intermediate_arrays),
            ]
            for level, array_name in intermediate_arrays
        ]