Ejemplo n.º 1
0
def template(x_order=OrderNHWC, y_order=OrderNHW, axis=Axis.C, description: str = ""):
    """Generate a kernel test case for the ``Max`` reduction operator.

    Reference data is created in NHWC layout and reduced with ``np.max``
    along ``axis``. Both input and expected output are then transposed to
    match the orders that ``x`` and ``y`` are switched to, so the generated
    kernels are exercised under non-default memory layouts.

    Args:
        x_order: Order the input variable is changed to before generation.
        y_order: Order the output variable is changed to. It must contain
            exactly the NHWC axes other than ``axis`` (the default OrderNHW
            matches only the default ``axis=Axis.C``).
        axis: Axis reduced by ``Max``.
        description: Suffix appended to the test case description.
    """
    vx = np.arange(120).reshape(2, 3, 4, 5)
    vy = np.max(vx, axis=OrderNHWC.axes_dict[axis])
    # np.max drops the reduced axis, so vy's axes are the NHWC axes without
    # `axis`. Hard-coding OrderNHW here would only be correct for axis=C.
    vy_axes = [a for a in OrderNHWC.axes if a != axis]

    x = Variable(vx.shape, order=OrderNHWC)
    y, = Max(None, axis=axis)(x)

    x.change_order(x_order)
    y.change_order(y_order)

    generate_kernel_test_case(
        description=f"Max {description}",
        graph=Graph([x], [y]),
        backend=["webgpu", "webgl", "webassembly"],
        # vx is laid out as NHWC; permute it into x's (possibly changed) order.
        inputs={x: np.transpose(vx, [OrderNHWC.axes_dict[a] for a in x.order.axes])},
        # Permute vy from "NHWC minus axis" into y's (possibly changed) order.
        expected={y: np.transpose(vy, [vy_axes.index(a) for a in y.order.axes])},
    )
Ejemplo n.º 2
0
def template(x_shape=(2, 3, 4, 5),
             x_order=OrderNHWC,
             y_order=OrderNHWC,
             axis=Axis.C,
             description: str = ""):
    """Generate a kernel test case for the ``Max`` reduction operator.

    Random input data of shape ``x_shape`` is interpreted in ``x_order``.
    The expected output is computed with ``np.max(..., keepdims=True)``, so
    the reduced axis is kept with size 1. It is then transposed into
    ``y_order``.

    Args:
        x_shape: Shape of the input tensor, given in ``x_order``. An
            immutable tuple is used as the default to avoid a shared mutable
            default argument; any sequence of ints is accepted.
        x_order: Order of the input variable.
        y_order: Order the output variable is changed to.
        axis: Axis reduced by ``Max``.
        description: Suffix appended to the test case description.
    """
    vx = np.random.rand(*x_shape)
    vy = np.max(vx, axis=x_order.axes_dict[axis], keepdims=True)

    x = Variable(vx.shape, order=x_order)
    y, = Max(None, axis=axis)(x)

    y.change_order(y_order)

    generate_kernel_test_case(
        description=f"Max {description}",
        graph=Graph([x], [y]),
        backend=["webgpu", "webgl", "webassembly"],
        inputs={x: vx},
        expected={
            # keepdims=True keeps vy in x's axis layout, so x's axis indices
            # map y's axes back onto vy's dimensions.
            y: np.transpose(vy, [x.order.axes_dict[a] for a in y.order.axes])
        },
    )