Beispiel #1
0
def test_wide_stride_CNHW():
    """Generate Im2Col kernel test cases for stride 2 with CNHW output.

    A WebAssembly and a WebGPU Im2Col operator (ksize=2, padding=1,
    stride=2) are applied to the same NHWC input variable, and each output
    is switched to CNHW order. One test case per backend is generated; both
    are checked against the reference column data after it has been wrapped
    in a constant declared as NHWC and switched to CNHW.
    """
    im_data, col_data = generate_data_212()

    # Reference output: declared as NHWC, then switched to CNHW. Its `.data`
    # is what both test cases below expect.
    reference = ConstantVariable(col_data, order=OrderNHWC)
    reference.change_order(OrderCNHW)

    im = Variable(im_data.shape, order=OrderNHWC)

    # Both backends are exercised with the same operator parameters.
    im2col_params = {"ksize": 2, "padding": 1, "stride": 2}

    (wasm_col,) = WasmIm2Col(None, **im2col_params)(im)
    wasm_col.change_order(OrderCNHW)

    (webgpu_col,) = WebGPUIm2Col(None, **im2col_params)(im)
    webgpu_col.change_order(OrderCNHW)

    description = "Im2Col output=CNHW stride=2"

    # raise_skip=False is passed only for this first case.
    # NOTE(review): presumably so that a skip raised here does not prevent
    # the webgpu case below from being generated -- confirm.
    generate_kernel_test_case(
        description=description,
        backend=["webassembly"],
        graph=Graph([im], [wasm_col]),
        inputs={im: im_data},
        expected={wasm_col: reference.data},
        raise_skip=False,
    )

    generate_kernel_test_case(
        description=description,
        backend=["webgpu"],
        graph=Graph([im], [webgpu_col]),
        inputs={im: im_data},
        expected={webgpu_col: reference.data},
    )
Beispiel #2
0
def test_NHWC():
    """Generate Im2Col kernel test cases with NHWC output.

    A WebAssembly and a WebGPU Im2Col operator (ksize=3, padding=1,
    stride=1, dilation_rate=1) are applied to the same NHWC input variable,
    and each output is set to NHWC order. One test case per backend is
    generated; both are checked directly against the reference column data
    returned by ``generate_data_311``.
    """
    im_data, col_data = generate_data_311()

    im = Variable(im_data.shape, order=OrderNHWC)

    # Both backends are exercised with the same operator parameters.
    im2col_params = {
        "ksize": 3,
        "padding": 1,
        "stride": 1,
        "dilation_rate": 1,
    }

    (wasm_col,) = WasmIm2Col(None, **im2col_params)(im)
    wasm_col.change_order(OrderNHWC)

    (webgpu_col,) = WebGPUIm2Col(None, **im2col_params)(im)
    webgpu_col.change_order(OrderNHWC)

    description = "Im2Col output=NHWC"

    # raise_skip=False is passed only for this first case.
    # NOTE(review): presumably so that a skip raised here does not prevent
    # the webgpu case below from being generated -- confirm.
    generate_kernel_test_case(
        description=description,
        backend=["webassembly"],
        graph=Graph([im], [wasm_col]),
        inputs={im: im_data},
        expected={wasm_col: col_data},
        raise_skip=False,
    )

    generate_kernel_test_case(
        description=description,
        backend=["webgpu"],
        graph=Graph([im], [webgpu_col]),
        inputs={im: im_data},
        expected={webgpu_col: col_data},
    )