def template(x_order=OrderNHWC, y_order=OrderNHW, axis=Axis.C, description: str = ""):
    """Generate a kernel test case for ``Max`` reducing ``axis`` of a 2x3x4x5 NHWC input.

    The reference input ``vx`` is built in NHWC layout. The reference output
    ``vy`` is ``np.max`` over ``axis`` without ``keepdims``, so it keeps the
    remaining NHWC axes in their original relative order. Both arrays are then
    transposed to match the orders the graph variables end up in.

    Args:
        x_order: Order the input variable is changed to before generating the test.
        y_order: Order the output variable is changed to. Presumably this must be a
            permutation of the NHWC axes with ``axis`` removed (TODO confirm).
        axis: Axis reduced by ``Max``.
        description: Suffix appended to the test-case description.
    """
    vx = np.arange(120).reshape(2, 3, 4, 5)
    reduced_index = OrderNHWC.axes_dict[axis]
    vy = np.max(vx, axis=reduced_index)

    # vy's layout is NHWC with the reduced axis dropped. Derive it from `axis`
    # instead of assuming NHW, which is only correct when axis == Axis.C.
    vy_axes = [a for i, a in enumerate(OrderNHWC.axes) if i != reduced_index]
    vy_axes_index = {a: i for i, a in enumerate(vy_axes)}

    x = Variable(vx.shape, order=OrderNHWC)
    y, = Max(None, axis=axis)(x)

    x.change_order(x_order)
    y.change_order(y_order)

    generate_kernel_test_case(
        description=f"Max {description}",
        graph=Graph([x], [y]),
        backend=["webgpu", "webgl", "webassembly"],
        # Permute the NHWC reference data into each variable's current order.
        inputs={x: np.transpose(vx, [OrderNHWC.axes_dict[a] for a in x.order.axes])},
        expected={y: np.transpose(vy, [vy_axes_index[a] for a in y.order.axes])},
    )
def template(x_shape=(2, 3, 4, 5), x_order=OrderNHWC, y_order=OrderNHWC, axis=Axis.C, description: str = ""):
    """Generate a kernel test case for ``Max`` reducing ``axis`` with dimensions kept.

    Random input of shape ``x_shape`` is interpreted in ``x_order``. The reference
    output ``vy`` is ``np.max`` over ``axis`` with ``keepdims=True``, so it has the
    same axes (in ``x_order`` layout) as the input, with the reduced one of size 1.

    Args:
        x_shape: Shape of the random input, given in ``x_order`` layout. A tuple
            default is used to avoid a shared mutable default; any sequence works
            because it is only unpacked.
        x_order: Order of the input variable.
        y_order: Order the output variable is changed to. Presumably this must be a
            permutation of ``x_order``'s axes (TODO confirm).
        axis: Axis reduced by ``Max``.
        description: Suffix appended to the test-case description.
    """
    vx = np.random.rand(*x_shape)
    vy = np.max(vx, axis=x_order.axes_dict[axis], keepdims=True)

    x = Variable(vx.shape, order=x_order)
    y, = Max(None, axis=axis)(x)
    y.change_order(y_order)

    generate_kernel_test_case(
        description=f"Max {description}",
        graph=Graph([x], [y]),
        backend=["webgpu", "webgl", "webassembly"],
        inputs={x: vx},
        # vy is laid out in x's order (keepdims); permute it into y's order.
        expected={
            y: np.transpose(vy, [x.order.axes_dict[a] for a in y.order.axes])
        },
    )