Beispiel #1
0
def test_estimate_gradient(fir_filter, art_warning, expected_values):
    """
    Check `LFilter.estimate_gradient` against reference values from the fixture.

    :param fir_filter: When truthy the filter is built with denominator `[1.0]`; otherwise with the
                       four-coefficient denominator `[1.0, 0.1, 0.3, 0.4]`.
    :param art_warning: Callback that receives any `ARTTestException` raised by the test body.
    :param expected_values: Callable returning the indexable fixture. Entries 0-2 are read as signals,
                            3-5 as input gradients and 6-8 as the expected estimated gradients.
    """
    try:
        fixture = expected_values()

        # Keep explicit indexing so the fixture only needs to support `[]`.
        signals = [fixture[idx] for idx in range(0, 3)]
        input_grads = [fixture[idx] for idx in range(3, 6)]
        expected_grads = [fixture[idx] for idx in range(6, 9)]

        # Each signal is multiplied by 2 before wrapping.
        # NOTE(review): if the fixture yields Python lists this repeats the samples rather than
        # scaling them - confirm against the fixture data.
        # dtype=object allows the three entries to differ in length.
        x = np.array([np.array(signal * 2) for signal in signals], dtype=object)
        grad = np.array([np.array(item) for item in input_grads], dtype=object)

        # Filter coefficients; only the denominator depends on `fir_filter`.
        numerator_coef = np.array([0.1, 0.2, -0.1, -0.2])
        denominator_coef = np.array([1.0]) if fir_filter else np.array([1.0, 0.1, 0.3, 0.4])

        audio_filter = LFilter(numerator_coef=numerator_coef, denominator_coef=denominator_coef)
        estimated_grad = audio_filter.estimate_gradient(x, grad=grad)

        # The estimate must have the input's shape and match the references entry by entry
        # (decimal=0 is a deliberately coarse tolerance).
        assert estimated_grad.shape == x.shape
        for row, expected in enumerate(expected_grads):
            np.testing.assert_array_almost_equal(expected, estimated_grad[row], decimal=0)

    except ARTTestException as e:
        art_warning(e)
Beispiel #2
0
def test_relation_clip_values_error(art_warning):
    """
    Check that `LFilter` rejects `clip_values` whose minimum is not below the maximum.

    The constructor is called with `clip_values=(1, 0)` and must raise a `ValueError` carrying the
    expected message.

    :param art_warning: Callback that receives any `ARTTestException` raised by the test body.
    """
    import re

    try:
        exc_msg = "Invalid `clip_values`: min >= max."
        # `match` is applied with `re.search`, so the message is escaped to compare it literally;
        # unescaped, the trailing "." would act as a wildcard.
        with pytest.raises(ValueError, match=re.escape(exc_msg)):
            LFilter(numerator_coef=np.array([0.1, 0.2, 0.3]),
                    denominator_coef=np.array([0.1, 0.2, 0.3]),
                    clip_values=(1, 0))

    except ARTTestException as e:
        art_warning(e)
Beispiel #3
0
def test_triple_clip_values_error(art_warning):
    """
    Check that `LFilter` rejects a `clip_values` tuple with three entries.

    The constructor is called with `clip_values=(0, 1, 2)` and must raise a `ValueError` carrying the
    expected message.

    :param art_warning: Callback that receives any `ARTTestException` raised by the test body.
    """
    import re

    try:
        exc_msg = "`clip_values` should be a tuple of 2 floats containing the allowed data range."
        # `match` is applied with `re.search`, so the message is escaped to compare it literally;
        # unescaped, the trailing "." would act as a wildcard.
        with pytest.raises(ValueError, match=re.escape(exc_msg)):
            LFilter(
                numerator_coef=np.array([0.1, 0.2, 0.3]),
                denominator_coef=np.array([0.1, 0.2, 0.3]),
                clip_values=(0, 1, 2),
            )

    except ARTTestException as e:
        art_warning(e)
Beispiel #4
0
def test_default(art_warning):
    """
    Check an `LFilter` built with default arguments on a tiny one-row signal.

    The call must return `None` in position 1 and, in position 0, data that matches the input under
    the coarse `decimal=0` tolerance.

    :param art_warning: Callback that receives any `ARTTestException` raised by the test body.
    """
    try:
        # One short signal is enough to exercise the default configuration.
        signal = np.array([[0.37, 0.68, 0.63, 0.48, 0.48, 0.18, 0.19]])

        default_filter = LFilter()
        outputs = default_filter(signal)

        # Position 1 must be None; position 0 is compared with the input.
        assert outputs[1] is None
        np.testing.assert_array_almost_equal(signal, outputs[0], decimal=0)

    except ARTTestException as e:
        art_warning(e)
Beispiel #5
0
def load_audio_channel(delay, attenuation, pytorch=True):
    """
    Build an LFilter object modelling a single-echo (multipath) delay channel.

    The numerator has a 1.0 at index 0 and `attenuation` at index `delay`; the denominator is 1.0
    followed by zeros. When `delay == 0` or `attenuation == 0` both coefficient arrays are `[1.0]`
    (identity channel); otherwise they have length `delay + 1`.

    NOTE (kept from the original documentation): lfilter truncates the end of the echo, so output
    length equals input length.

    :param delay: Echo delay in samples; converted with `int()`. Must not be negative.
    :param attenuation: Echo gain; converted with `float()`. A warning is logged when it lies
                        outside [-1, 1].
    :param pytorch: When truthy, try `LFilterPyTorch` first and fall back to `LFilter` if its
                    construction raises `ImportError`.
    :return: An `LFilterPyTorch` or `LFilter` instance built from the coefficients.
    :raises ValueError: If the converted delay is negative.
    """
    delay = int(delay)
    attenuation = float(attenuation)
    if delay < 0:
        raise ValueError(
            f"delay {delay} must be a nonnegative number (of samples)")

    identity = delay == 0 or attenuation == 0
    if identity:
        logger.warning("Using an identity channel")
    elif not abs(attenuation) <= 1:
        # Written as a negated `<=` so that a NaN attenuation also triggers the warning.
        logger.warning(f"filter attenuation {attenuation} not in [-1, 1]")

    # Both arrays share one length: a single tap for the identity channel, delay + 1 otherwise.
    tap_count = 1 if identity else delay + 1

    numerator_coef = np.zeros(tap_count)
    numerator_coef[0] = 1.0
    if not identity:
        # Single multipath component arriving `delay` samples after the direct path.
        numerator_coef[delay] = attenuation

    denominator_coef = np.zeros(tap_count)
    denominator_coef[0] = 1.0

    coefficients = {
        "numerator_coef": numerator_coef,
        "denominator_coef": denominator_coef,
    }

    if pytorch:
        try:
            return LFilterPyTorch(**coefficients)
        except ImportError:
            logger.exception(
                "PyTorch not available. Resorting to scipy filter")

    logger.warning(
        "Scipy LFilter does not currently implement proper gradients")
    return LFilter(**coefficients)
Beispiel #6
0
def test_audio_filter(fir_filter, art_warning, expected_values):
    """
    Check the output of applying `LFilter` to three signals against fixture references.

    :param fir_filter: When truthy the filter is built with denominator `[1.0]`; otherwise with the
                       four-coefficient denominator `[1.0, 0.1, 0.3, 0.4]`.
    :param art_warning: Callback that receives any `ARTTestException` raised by the test body.
    :param expected_values: Callable returning the indexable fixture. Entries 0-2 are read as signals
                            and 3-5 as the expected filtered outputs.
    """
    try:
        fixture = expected_values()

        # Keep explicit indexing so the fixture only needs to support `[]`.
        signals = [fixture[idx] for idx in range(0, 3)]
        expected_outputs = [fixture[idx] for idx in range(3, 6)]

        # Each signal is multiplied by 2 before wrapping.
        # NOTE(review): if the fixture yields Python lists this repeats the samples rather than
        # scaling them - confirm against the fixture data.
        # dtype=object allows the three entries to differ in length.
        x = np.array([np.array(signal * 2) for signal in signals], dtype=object)

        # Filter coefficients; only the denominator depends on `fir_filter`.
        numerator_coef = np.array([0.1, 0.2, -0.1, -0.2])
        denominator_coef = np.array([1.0]) if fir_filter else np.array([1.0, 0.1, 0.3, 0.4])

        audio_filter = LFilter(numerator_coef=numerator_coef, denominator_coef=denominator_coef)
        result = audio_filter(x)

        # Position 1 must be None; position 0 holds one filtered output per input signal,
        # compared under the coarse decimal=0 tolerance.
        assert result[1] is None
        for row, expected in enumerate(expected_outputs):
            np.testing.assert_array_almost_equal(expected, result[0][row], decimal=0)

    except ARTTestException as e:
        art_warning(e)
Beispiel #7
0
def test_check_params(art_warning):
    """
    Check that `LFilter` raises `ValueError` for each of several invalid argument combinations.

    :param art_warning: Callback that receives any `ARTTestException` raised by the test body.
    """

    def numerator():
        # Fresh array per case so no two constructor calls share an object.
        return np.array([0.1, 0.2, -0.1, -0.2])

    def denominator():
        return np.array([1.0, 0.1, 0.3, 0.4])

    try:
        invalid_configurations = [
            # Denominator passed as a plain list (its leading coefficient is also 0.0).
            {"numerator_coef": numerator(), "denominator_coef": [0.0, 0.1, 0.3, 0.4]},
            # Denominator array whose leading coefficient is 0.0.
            {"numerator_coef": numerator(), "denominator_coef": np.array([0.0, 0.1, 0.3, 0.4])},
            # Numerator passed as a plain list.
            {"numerator_coef": [0.1, 0.2, -0.1, -0.2], "denominator_coef": denominator()},
            # `axis` given as a float.
            {"numerator_coef": numerator(), "denominator_coef": denominator(), "axis": 1.0},
            # `initial_cond` given as a float.
            {"numerator_coef": numerator(), "denominator_coef": denominator(), "initial_cond": 1.0},
            # `verbose` given as a string.
            {"numerator_coef": numerator(), "denominator_coef": denominator(), "verbose": "True"},
        ]

        for kwargs in invalid_configurations:
            with pytest.raises(ValueError):
                _ = LFilter(**kwargs)

    except ARTTestException as e:
        art_warning(e)