コード例 #1
0
    def test_cdgmm_exceptions(self, backend):
        """Check that ``backend.cdgmm`` rejects invalid operand combinations.

        Three failure modes are exercised, each with the exception type and
        message fragment the backend is expected to produce:

        * a real (float64) first operand -> ``TypeError``
        * an integer (int64) second operand -> ``TypeError``
        * operands whose trailing shapes differ -> ``RuntimeError``
        """
        # Allocate every array directly with its target dtype. Going through
        # ``np.empty(...).astype(...)`` casts uninitialised float64 memory,
        # which may contain NaN/inf and can emit an "invalid value encountered
        # in cast" RuntimeWarning on recent NumPy when converting to int64.
        signal_shape = (3, 4, 5)
        filter_shape = (4, 5)

        real_signal = np.empty(signal_shape, dtype=np.float64)
        complex_signal = np.empty(signal_shape, dtype=np.complex128)
        complex_filter = np.empty(filter_shape, dtype=np.complex128)
        integer_filter = np.empty(filter_shape, dtype=np.int64)
        # Last dimension (6) deliberately disagrees with the signal's (5).
        mismatched_filter = np.empty((4, 6), dtype=np.complex128)

        with pytest.raises(TypeError) as record:
            backend.cdgmm(real_signal, complex_filter)
        assert 'first input must be complex' in record.value.args[0]

        with pytest.raises(TypeError) as record:
            backend.cdgmm(complex_signal, integer_filter)
        assert 'second input must be complex or real' in record.value.args[0]

        with pytest.raises(RuntimeError) as record:
            backend.cdgmm(complex_signal, mismatched_filter)
        assert 'not compatible for multiplication' in record.value.args[0]
コード例 #2
0
    def test_cdgmm_forward(self, data, backend, inplace):
        x, filt, y = data

        z = backend.cdgmm(x, filt, inplace=inplace)

        assert np.allclose(y, z)