コード例 #1
0
ファイル: test_scipy.py プロジェクト: tdglobalnews/autograd
def test_convolve_generalization():
    """Check autograd's generalized ``convolve`` against reference computations.

    For each of the ``"valid"`` and ``"full"`` modes, the output of
    ``autograd.scipy.signal.convolve`` is compared (with ``npo.allclose``) to:

    * a 1-D ``sp_convolve`` of matching slices, for a single convolved axis;
    * ``npo.tensordot``, when only ``dot_axes`` are given (no convolved axes);
    * a sum of 1-D ``sp_convolve`` results over one contracted (dot) axis;
    * a sum of 2-D ``sp_convolve`` results over two contracted (dot) axes.

    NOTE(review): ``R``, ``npo`` and ``sp_convolve`` come from module scope;
    presumably a random-array factory, plain NumPy, and
    ``scipy.signal.convolve`` respectively -- confirm against the file header.
    """
    ag_convolve = autograd.scipy.signal.convolve
    A_35 = R(3, 5)
    A_34 = R(3, 4)
    A_342 = R(3, 4, 2)
    A_2543 = R(2, 5, 4, 3)
    A_24232 = R(2, 4, 2, 3, 2)

    # Sizes of the contracted (dot) axes, read from the arrays themselves so the
    # reference sums below stay in sync if the shapes above are ever edited.
    n_dot_35 = A_35.shape[0]  # A_35 axis 0, contracted with A_342 axis 0
    n_dot_0 = A_2543.shape[0]  # A_2543 axis 0, contracted with A_24232 axis 0
    n_dot_3 = A_2543.shape[3]  # A_2543 axis 3, contracted with A_24232 axis 3

    for mode in ["valid", "full"]:
        # One convolved axis (A_35 axis 1 with A_34 axis 0); the reference pairs
        # output index [1, 2] with row 1 of A_35 and column 2 of A_34.
        assert npo.allclose(
            ag_convolve(A_35, A_34, axes=([1], [0]), mode=mode)[1, 2],
            sp_convolve(A_35[1, :], A_34[:, 2], mode=mode),
        )

        # No convolved axes, only a contraction over axis 0 of both inputs:
        # the result must match a plain tensordot.
        assert npo.allclose(
            ag_convolve(A_35, A_34, axes=([], []), dot_axes=([0], [0]), mode=mode),
            npo.tensordot(A_35, A_34, axes=([0], [0])),
        )

        # One convolved axis plus one contracted axis: the reference sums the
        # 1-D convolutions over the contracted axis; [2] selects A_342 axis 1 = 2.
        assert npo.allclose(
            ag_convolve(A_35, A_342, axes=([1], [2]), dot_axes=([0], [0]), mode=mode)[2],
            sum(sp_convolve(A_35[i, :], A_342[i, 2, :], mode=mode) for i in range(n_dot_35)),
        )

        # Two convolved axes plus two contracted axes: the reference sums the
        # 2-D convolutions over both contracted axes; [2] selects A_24232 axis 1 = 2.
        assert npo.allclose(
            ag_convolve(A_2543, A_24232, axes=([1, 2], [2, 4]), dot_axes=([0, 3], [0, 3]), mode=mode)[2],
            sum(
                sp_convolve(A_2543[i, :, :, j], A_24232[i, 2, :, j, :], mode=mode)
                for j in range(n_dot_3)
                for i in range(n_dot_0)
            ),
        )
コード例 #2
0
    def test_convolve_generalization():
        """Check autograd's generalized ``convolve`` against reference results.

        For each of the ``'valid'`` and ``'full'`` modes, the output of
        ``autograd.scipy.signal.convolve`` is compared (``npo.allclose``)
        with ``scipy.signal.convolve`` on matching slices, with
        ``npo.tensordot`` when only dot axes are given, and with sums of
        ``scipy.signal.convolve`` results over the contracted (dot) axes.

        NOTE(review): ``R`` and ``npo`` come from the enclosing scope;
        presumably a random-array factory and plain NumPy -- confirm.
        """
        from scipy.signal import convolve as sp_convolve

        ag_convolve = autograd.scipy.signal.convolve
        A_35 = R(3, 5)
        A_34 = R(3, 4)
        A_342 = R(3, 4, 2)
        A_2543 = R(2, 5, 4, 3)
        A_24232 = R(2, 4, 2, 3, 2)

        # Sizes of the contracted (dot) axes, read from the arrays so the
        # reference sums below stay in sync if the shapes above change.
        # A_35 axis 0 is contracted with A_342 axis 0.
        n_dot_35 = A_35.shape[0]
        # A_2543 axes 0 and 3 are contracted with A_24232 axes 0 and 3.
        n_dot_0 = A_2543.shape[0]
        n_dot_3 = A_2543.shape[3]

        for mode in ['valid', 'full']:
            # One convolved axis (A_35 axis 1 with A_34 axis 0); the reference
            # pairs output [1, 2] with row 1 of A_35 and column 2 of A_34.
            assert npo.allclose(
                ag_convolve(A_35, A_34, axes=([1], [0]), mode=mode)[1, 2],
                sp_convolve(A_35[1, :], A_34[:, 2], mode=mode))

            # No convolved axes, only a contraction over axis 0 of both:
            # the result must match a plain tensordot.
            assert npo.allclose(
                ag_convolve(A_35,
                            A_34,
                            axes=([], []),
                            dot_axes=([0], [0]),
                            mode=mode),
                npo.tensordot(A_35, A_34, axes=([0], [0])))

            # One convolved axis plus one contracted axis: sum the 1-D
            # convolutions over the contracted axis ([2] -> A_342 axis 1 = 2).
            assert npo.allclose(
                ag_convolve(A_35,
                            A_342,
                            axes=([1], [2]),
                            dot_axes=([0], [0]),
                            mode=mode)[2],
                sum(sp_convolve(A_35[i, :], A_342[i, 2, :], mode=mode)
                    for i in range(n_dot_35)))

            # Two convolved axes plus two contracted axes: sum the 2-D
            # convolutions over both contracted axes ([2] -> A_24232 axis 1 = 2).
            assert npo.allclose(
                ag_convolve(A_2543,
                            A_24232,
                            axes=([1, 2], [2, 4]),
                            dot_axes=([0, 3], [0, 3]),
                            mode=mode)[2],
                sum(sp_convolve(A_2543[i, :, :, j],
                                A_24232[i, 2, :, j, :],
                                mode=mode)
                    for j in range(n_dot_3)
                    for i in range(n_dot_0)))