コード例 #1
0
ファイル: select_test.py プロジェクト: zhangaz1/webdnn
def template(shape=(2, 3, 4, 5), x0_order=OrderNHWC, x1_order=OrderNHWC, x2_order=OrderNHWC, y_order=OrderNHWC, description: str = ""):
    """Generate a kernel test case for ``Select`` with configurable data orders.

    Random reference data is created assuming an NHWC axis layout. The
    expected output takes the value of the second operand wherever the first
    operand equals 1, and the value of the third operand elsewhere.

    Args:
        shape: Shape shared by all three inputs and by the output.
        x0_order: Order applied to the first input (the 0/1 selector).
        x1_order: Order applied to the second input.
        x2_order: Order applied to the third input.
        y_order: Order applied to the output.
        description: Suffix appended to the test-case description.
    """

    def as_layout_of(data, variable):
        # `data` is laid out as NHWC; permute it to follow `variable`'s
        # current axis order.
        return np.transpose(data, [OrderNHWC.axes_dict[axis] for axis in variable.order.axes])

    # Reference data. The selector holds integer 0/1 values; the two value
    # operands are float32 samples shifted into [-0.5, 0.5).
    selector = np.where(np.random.rand(*shape).astype(np.float32) > 0.5, 0, 1)
    when_one = np.random.rand(*shape).astype(np.float32) - 0.5
    otherwise = np.random.rand(*shape).astype(np.float32) - 0.5
    reference = np.where(selector == 1, when_one, otherwise)

    # Build the graph in NHWC first, then switch each variable to the order
    # requested by the caller.
    x0 = Variable(selector.shape, order=OrderNHWC)
    x1 = Variable(when_one.shape, order=OrderNHWC)
    x2 = Variable(otherwise.shape, order=OrderNHWC)
    y, = Select(None)(x0, x1, x2)
    for variable, target_order in ((x0, x0_order), (x1, x1_order), (x2, x2_order), (y, y_order)):
        variable.change_order(target_order)

    generate_kernel_test_case(
        description=f"Select {description}",
        graph=Graph([x0, x1, x2], [y]),
        inputs={
            x0: as_layout_of(selector, x0),
            x1: as_layout_of(when_one, x1),
            x2: as_layout_of(otherwise, x2),
        },
        expected={y: as_layout_of(reference, y)},
    )
コード例 #2
0
ファイル: select_test.py プロジェクト: zhangaz1/webdnn
def test_broadcast():
    """Check ``Select`` when its three operands have different ranks.

    The selector has shape (2, 5) in NC order, the second operand has shape
    (3, 4) in HW order, and the third operand has the full (2, 3, 4, 5) NHWC
    shape. The expected output is computed with NumPy by inserting length-1
    axes so the two smaller operands broadcast against the NHWC operand.
    """
    # Integer 0/1 selector and two float32 operands shifted into [-0.5, 0.5).
    selector = np.where(np.random.rand(2, 5).astype(np.float32) > 0.5, 0, 1)
    when_one = np.random.rand(3, 4).astype(np.float32) - 0.5
    otherwise = np.random.rand(2, 3, 4, 5).astype(np.float32) - 0.5

    # Expand (N, C) -> (N, 1, 1, C) and (H, W) -> (1, H, W, 1) before selecting.
    selector_nhwc = selector[:, np.newaxis, np.newaxis, :]
    when_one_nhwc = when_one[np.newaxis, :, :, np.newaxis]
    reference = np.where(selector_nhwc == 1, when_one_nhwc, otherwise)

    x0 = Variable(selector.shape, order=OrderNC)
    x1 = Variable(when_one.shape, order=Order([Axis.H, Axis.W]))
    x2 = Variable(otherwise.shape, order=OrderNHWC)
    y, = Select(None)(x0, x1, x2)
    y.change_order(OrderNHWC)

    graph = Graph([x0, x1, x2], [y])
    generate_kernel_test_case(
        description=f"Select broadcast",
        graph=graph,
        inputs={x0: selector, x1: when_one, x2: otherwise},
        expected={y: reference},
    )