def calculate_all_strides(var):
    """Return the strides of ``var`` along the N, H, W and C axes, in that order.

    Each entry is whatever ``calculate_stride(var, axis)`` yields for that axis.
    """
    strides = []
    for axis in (Axis.N, Axis.H, Axis.W, Axis.C):
        strides.append(calculate_stride(var, axis))
    return strides
def _linear_4d_strides(var, n_size, chw_size):
    """Return ``(k_stride, n_stride)`` for a 4-D variable viewed as an N x (CHW) matrix.

    Only layouts where N is the outermost (``N***``) or innermost (``***N``)
    axis are supported, so that the remaining three axes form one block of
    ``chw_size`` elements.

    Raises:
        ValueError: if N is neither the first nor the last axis (e.g. HWNC).
    """
    axes = var.order.axes
    if axes[0] == Axis.N:
        # N***: consecutive K elements are adjacent; rows are chw_size apart.
        return 1, chw_size
    if axes[3] == Axis.N:
        # ***N: N is innermost, so consecutive K elements are n_size apart.
        return n_size, 1
    # e.g. HWNC: N sits in the middle, so CHW is not a single strided block.
    raise ValueError(
        "linear: N must be the first or last axis of a 4-D input, got axes {}".format(axes))


def linear(op: Linear, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the kernel that evaluates a ``Linear`` operator as a matrix product.

    ``x`` is treated as an ``m x k`` matrix and ``w`` as an ``n x k`` matrix:
    m is the size of x's N axis, n is the size of w's N axis, and k is the
    reduction size (the C axis for 2-D inputs, the flattened C*H*W block for
    4-D inputs). These sizes and the element strides of each matrix direction
    are passed to the ``linear`` kernel through ``call_option``.

    Args:
        op: the operator; reads inputs ``"x"`` and ``"w"`` and output ``"y"``.
        memory_layout: not used by this function.

    Returns:
        A one-element list containing the configured ``Kernel``.

    Raises:
        AssertionError: if y is not OrderNC, x and w differ in rank, their
            reduction sizes differ, or (4-D) their non-N axes are ordered
            differently.
        ValueError: if x is neither 2-D nor 4-D, or a 4-D input has its N
            axis in a middle position.
    """
    x = op.inputs["x"]
    w = op.inputs["w"]
    y = op.outputs["y"]

    assert y.order == OrderNC

    if x.order.ndim == 2:
        assert w.order.ndim == 2

        k = x.shape_dict[Axis.C]
        m = x.shape_dict[Axis.N]
        n = w.shape_dict[Axis.N]

        # Stride along each matrix direction: the product of the sizes of the
        # dimensions to the right of (i.e. inner-loop relative to) that axis.
        x_k_stride = calculate_stride(x, Axis.C)
        x_m_stride = calculate_stride(x, Axis.N)
        w_k_stride = calculate_stride(w, Axis.C)
        w_n_stride = calculate_stride(w, Axis.N)

        # The reduction size must agree, mirroring the 4-D branch's check;
        # otherwise the kernel would presumably index past the smaller operand.
        assert w.shape_dict[Axis.C] == k

    elif x.order.ndim == 4:
        assert w.order.ndim == 4

        # Only supported when C, H, W are contiguous and appear in the same
        # order in x and w (NCHW/NCHW, NHWC/HWCN, ...).
        x_order_wo_n = list(x.order.axes)
        x_order_wo_n.remove(Axis.N)  # e.g. [Axis.C, Axis.H, Axis.W] for NCHW
        x_n_size = x.shape_dict[Axis.N]
        x_chw_size = x.size // x_n_size

        w_order_wo_n = list(w.order.axes)
        w_order_wo_n.remove(Axis.N)
        w_n_size = w.shape_dict[Axis.N]
        w_chw_size = w.size // w_n_size

        assert x_chw_size == w_chw_size
        assert x_order_wo_n == w_order_wo_n

        k = x_chw_size
        m = x_n_size
        n = w_n_size

        # x is validated before w, matching the original error ordering.
        x_k_stride, x_m_stride = _linear_4d_strides(x, x_n_size, x_chw_size)
        w_k_stride, w_n_stride = _linear_4d_strides(w, w_n_size, w_chw_size)

    else:
        raise ValueError(
            "linear: x must be 2-D or 4-D, got ndim={}".format(x.order.ndim))

    # Key order is kept identical to the original in case call_option is
    # serialized positionally (NOTE(review): confirm against Kernel).
    kernel = Kernel(
        {"linear": source},
        "linear",
        inputs=[x, w],
        outputs=[y],
        call_option={"m": m,
                     "n": n,
                     "k": k,
                     "x_k_stride": x_k_stride,
                     "x_m_stride": x_m_stride,
                     "w_k_stride": w_k_stride,
                     "w_n_stride": w_n_stride}
    )

    return [kernel]