Exemplo n.º 1
0
def calculate_all_strides(var):
    """Return the strides of ``var`` for the N, H, W and C axes, in that order.

    Each entry is the value ``calculate_stride(var, axis)`` yields for the
    corresponding axis; the result is always a four-element list.
    """
    axes_in_order = (Axis.N, Axis.H, Axis.W, Axis.C)
    strides = []
    for current_axis in axes_in_order:
        strides.append(calculate_stride(var, current_axis))
    return strides
Exemplo n.º 2
0
def _linear_4d_strides(axes, n_size, rest_size):
    """Return ``(k_stride, n_stride)`` for a 4-D operand of ``linear``.

    Args:
        axes: the operand's axis order (indexable, four entries).
        n_size: length of the operand's ``Axis.N`` dimension.
        rest_size: number of elements in the other three axes combined,
            i.e. the flattened reduction length.

    Returns:
        ``(1, rest_size)`` when ``Axis.N`` is the first axis,
        ``(n_size, 1)`` when ``Axis.N`` is the fourth axis.

    Raises:
        ValueError: if ``Axis.N`` is in neither of those positions
            (such as HWNC).
    """
    if axes[0] == Axis.N:
        # N***: stepping along k moves by 1, stepping along N by rest_size.
        return 1, rest_size
    if axes[3] == Axis.N:
        # ***N: stepping along N moves by 1, stepping along k by n_size.
        return n_size, 1
    raise ValueError(
        "Axis.N must be the first or last axis of a 4-D order, got {}".format(axes)
    )


def linear(op: Linear, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the kernel list for a ``Linear`` operator.

    Reads inputs ``x`` and ``w`` and output ``y`` from ``op``, derives the
    matrix sizes ``m``, ``n``, ``k`` and the strides used to walk ``x`` and
    ``w``, and hands them to a single ``Kernel`` through ``call_option``.

    Two ranks of ``x`` are handled:

    * 2-D: ``k``/``m`` come from the C/N sizes of ``x``, ``n`` from the N
      size of ``w``; strides are obtained from ``calculate_stride``.
    * 4-D: the three non-N axes are flattened into ``k``. ``x`` and ``w``
      must list those axes in the same relative order and each must have
      ``Axis.N`` as its first or last axis.

    Args:
        op: operator exposing ``inputs["x"]``, ``inputs["w"]`` and
            ``outputs["y"]``.
        memory_layout: not referenced in this function body; presumably kept
            so all kernel generators share one signature (confirm against
            callers).

    Returns:
        A one-element list holding the generated ``Kernel``.

    Raises:
        AssertionError: if ``y`` is not in ``OrderNC``, if ``w`` does not
            have the same rank as ``x``, or (4-D) if the flattened sizes or
            the non-N axis orders of ``x`` and ``w`` differ.
        ValueError: if ``x`` is neither 2-D nor 4-D, or (4-D) if ``Axis.N``
            is neither the first nor the last axis of ``x`` or ``w``.
    """
    x = op.inputs["x"]
    w = op.inputs["w"]
    y = op.outputs["y"]

    assert y.order == OrderNC, "linear: output y must be in OrderNC"
    x_ndim = x.order.ndim
    if x_ndim == 2:
        assert w.order.ndim == 2, "linear: w must be 2-D when x is 2-D"
        k = x.shape_dict[Axis.C]
        m = x.shape_dict[Axis.N]
        n = w.shape_dict[Axis.N]
        # Compute the stride along each direction of the matrix operation.
        # (Original author's note: the stride is the product of the element
        # counts of the dimensions to the right of, i.e. inner-loop side of,
        # the axis being stepped.)
        x_k_stride = calculate_stride(x, Axis.C)
        x_m_stride = calculate_stride(x, Axis.N)
        w_k_stride = calculate_stride(w, Axis.C)
        w_n_stride = calculate_stride(w, Axis.N)
    elif x_ndim == 4:
        assert w.order.ndim == 4, "linear: w must be 4-D when x is 4-D"
        # Only supported when C, H, W are contiguous and appear in the same
        # order in x and w (NCHW/NCHW, NHWC/HWCN, ...).
        x_order_wo_n = list(x.order.axes)
        x_order_wo_n.remove(Axis.N)  # e.g. [Axis.C, Axis.H, Axis.W]
        x_n_size = x.shape_dict[Axis.N]
        x_chw_size = x.size // x_n_size
        w_order_wo_n = list(w.order.axes)
        w_order_wo_n.remove(Axis.N)
        w_n_size = w.shape_dict[Axis.N]
        w_chw_size = w.size // w_n_size

        assert x_chw_size == w_chw_size, \
            "linear: x and w must have the same number of non-N elements"
        assert x_order_wo_n == w_order_wo_n, \
            "linear: x and w must order their non-N axes identically"
        k = x_chw_size
        m = x_n_size
        n = w_n_size
        # x is validated before w, so an unsupported x layout is reported first.
        x_k_stride, x_m_stride = _linear_4d_strides(x.order.axes, x_n_size, x_chw_size)
        w_k_stride, w_n_stride = _linear_4d_strides(w.order.axes, w_n_size, w_chw_size)
    else:
        raise ValueError(
            "linear: x must be 2-D or 4-D, got ndim={}".format(x_ndim)
        )

    kernel = Kernel(
        {"linear": source},
        "linear",
        inputs=[x, w],
        outputs=[y],
        call_option={
            "m": m,
            "n": n,
            "k": k,
            "x_k_stride": x_k_stride,
            "x_m_stride": x_m_stride,
            "w_k_stride": w_k_stride,
            "w_n_stride": w_n_stride,
        },
    )

    return [kernel]