def test_wide_stride_CNHW():
    """Check Im2Col (ksize=2, padding=1, stride=2) with CNHW-ordered output.

    Builds the same Im2Col graph for the WebAssembly and WebGPU backends and
    registers one kernel test case per backend. The expected columns come from
    ``generate_data_212()`` in NHWC order and are re-ordered to CNHW so they
    line up with the kernels' CNHW outputs.
    """
    v_im, v_col = generate_data_212()

    # Wrap the reference columns in a NHWC-ordered constant so change_order can
    # re-lay them out as CNHW.
    # NOTE(review): assumes change_order on a ConstantVariable also permutes
    # its .data buffer (the expected values below are read from .data) — confirm.
    col_dummy = ConstantVariable(v_col, order=OrderNHWC)
    col_dummy.change_order(OrderCNHW)

    im = Variable(v_im.shape, order=OrderNHWC)

    # Each backend's operator returns a 1-tuple; unpack and request CNHW output.
    col_wasm, = WasmIm2Col(None, ksize=2, padding=1, stride=2)(im)
    col_wasm.change_order(OrderCNHW)

    col_webgpu, = WebGPUIm2Col(None, ksize=2, padding=1, stride=2)(im)
    col_webgpu.change_order(OrderCNHW)

    # Plain string: the original f-strings contained no placeholders (F541).
    description = "Im2Col output=CNHW stride=2"

    # NOTE(review): raise_skip=False is passed only for the WebAssembly case;
    # its exact meaning is defined by generate_kernel_test_case.
    generate_kernel_test_case(
        description=description,
        backend=["webassembly"],
        graph=Graph([im], [col_wasm]),
        inputs={im: v_im},
        expected={col_wasm: col_dummy.data},
        raise_skip=False,
    )

    generate_kernel_test_case(
        description=description,
        backend=["webgpu"],
        graph=Graph([im], [col_webgpu]),
        inputs={im: v_im},
        expected={col_webgpu: col_dummy.data},
    )
def test_NHWC():
    """Check Im2Col (ksize=3, padding=1, stride=1, dilation_rate=1) with NHWC output.

    Builds the same Im2Col graph for the WebAssembly and WebGPU backends and
    registers one kernel test case per backend, comparing against the
    reference columns returned by ``generate_data_311()``.
    """
    v_im, v_col = generate_data_311()

    im = Variable(v_im.shape, order=OrderNHWC)

    # Each backend's operator returns a 1-tuple; unpack and request NHWC output.
    col_wasm, = WasmIm2Col(None, ksize=3, padding=1, stride=1, dilation_rate=1)(im)
    col_wasm.change_order(OrderNHWC)

    col_webgpu, = WebGPUIm2Col(None, ksize=3, padding=1, stride=1, dilation_rate=1)(im)
    col_webgpu.change_order(OrderNHWC)

    # Plain string: the original f-strings contained no placeholders (F541).
    description = "Im2Col output=NHWC"

    # NOTE(review): raise_skip=False is passed only for the WebAssembly case;
    # its exact meaning is defined by generate_kernel_test_case.
    generate_kernel_test_case(
        description=description,
        backend=["webassembly"],
        graph=Graph([im], [col_wasm]),
        inputs={im: v_im},
        expected={col_wasm: v_col},
        raise_skip=False,
    )

    generate_kernel_test_case(
        description=description,
        backend=["webgpu"],
        graph=Graph([im], [col_webgpu]),
        inputs={im: v_im},
        expected={col_webgpu: v_col},
    )