def convolution_2d(op: Convolution2D) -> List[Kernel]:
    """Build the single kernel that evaluates a 2D convolution.

    Tensors are handed to the kernel by their ``parameters["name"]``; the
    filter ``w`` is registered as a weight rather than as a regular input.
    Spatial sizes, channel counts and per-tensor strides are passed through
    ``call_option``.
    """
    x, w = op.inputs["x"], op.inputs["w"]
    y = op.outputs["y"]

    options = {
        "in_spatial": [x.shape_dict[Axis.H], x.shape_dict[Axis.W]],
        "n": x.shape_dict[Axis.N],
        "out_size": y.shape_dict[Axis.C],
        "in_size": x.shape_dict[Axis.C],
        "out_spatial": [y.shape_dict[Axis.H], y.shape_dict[Axis.W]],
        "strides_x": calculate_all_strides(x),
        "strides_w": calculate_all_strides(w),
        "strides_y": calculate_all_strides(y),
        "padding": op.padding,
        "stride": op.stride,
        "ksize": op.ksize,
    }

    return [
        Kernel(
            {"convolution_2d": source},
            "convolution_2d",
            inputs=[x.parameters["name"]],
            outputs=[y.parameters["name"]],
            weights=[w.parameters["name"]],
            call_option=options,
        )
    ]
def axiswise_bias(op: AxiswiseBias, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the kernel that adds a 1-D bias ``b`` to ``x`` along ``op.parameters["axis"]``.

    The kernel is given the size and stride of the target axis. ``axis_stride``
    is the product of all dimensions of ``x`` after the target axis (for NCHW
    with axis=C that is H*W).
    """
    x = op.inputs["x"]
    bias = op.inputs["b"]
    y = op.outputs["y"]

    assert bias.ndim == 1

    # Index of the target axis in x's order (1 for NCHW with axis=C).
    pos = x.order.axes_dict[op.parameters["axis"]]
    size = x.shape[pos]
    assert size == bias.size

    # Derived from x's shape only -- NOTE(review): assumes x is stored
    # contiguously in its order; confirm against the memory layout.
    stride = mul(x.shape[pos + 1:])

    return [
        Kernel(
            {"axiswise_bias": source},
            "axiswise_bias",
            inputs=[x, bias],
            outputs=[y],
            call_option={"n": x.size, "axis_stride": stride, "axis_size": size},
        )
    ]
def convolution_2d(op: Convolution2D, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the single kernel that evaluates a (possibly dilated) 2D convolution.

    ``x`` and ``w`` are passed as inputs and ``y`` as output. Geometry
    (spatial sizes, channels, strides of every tensor, padding, stride,
    kernel size and dilation rate) goes into ``call_option``.
    """
    x, w, y = op.inputs["x"], op.inputs["w"], op.outputs["y"]

    geometry = {
        "in_spatial": [x.shape_dict[Axis.H], x.shape_dict[Axis.W]],
        "n": x.shape_dict[Axis.N],
        "out_size": y.shape_dict[Axis.C],
        "in_size": x.shape_dict[Axis.C],
        "out_spatial": [y.shape_dict[Axis.H], y.shape_dict[Axis.W]],
        "strides_x": calculate_all_strides(x),
        "strides_w": calculate_all_strides(w),
        "strides_y": calculate_all_strides(y),
        "padding": op.padding,
        "stride": op.stride,
        "ksize": op.ksize,
        "dilation_rate": op.dilation_rate,
    }

    kernel = Kernel({"convolution_2d": source}, "convolution_2d",
                    inputs=[x, w], outputs=[y], call_option=geometry)
    return [kernel]
def axiswise_scale(op: AxiswiseScale) -> List[Kernel]:
    """Build the kernel that multiplies ``x`` by a 1-D scale ``s`` along ``op.parameters["axis"]``.

    The kernel is given the total element count plus the size and stride of
    the target axis. Tensors are passed by their ``parameters["name"]``, and the
    scale is registered as a weight.
    """
    x = op.inputs["x"]
    s = op.inputs["s"]
    y = op.outputs["y"]

    assert s.ndim == 1

    # Index of the target axis in x's order (1 for NCHW with axis=C).
    axis_pos = x.order.axes_dict[op.parameters["axis"]]
    axis_size = x.shape[axis_pos]
    assert axis_size == s.size

    # Product of the trailing dimensions (H*W for NCHW with axis=C).
    # np.prod returns a NumPy scalar, and np.prod of an empty sequence is the
    # float 1.0; convert to a plain Python int so the option is always an int.
    axis_stride = int(np.prod(x.shape[axis_pos + 1:]))

    kernel = Kernel(
        {"axiswise_scale": source},
        "axiswise_scale",
        inputs=[x.parameters["name"]],
        outputs=[y.parameters["name"]],
        weights=[s.parameters["name"]],
        call_option={
            "n": x.size,
            "axis_stride": axis_stride,
            "axis_size": axis_size,
        },
    )
    return [kernel]
def split_axis(op: SplitAxis, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the kernel that splits ``x`` along ``op.parameters["axis"]`` into y0, y1, ...

    Element ``ys[i][d_0, ..., d_n]`` lives in ``x`` at address
    ``y_offsets[i] + y_strides[i][0] * d_0 + ... + y_strides[i][n] * d_n``.
    """
    x = op.inputs["x"]
    ys = [op.outputs[f"y{i}"] for i in range(len(op.outputs))]
    target_axis = op.parameters["axis"]

    y_shapes = [y.shape for y in ys]

    # y_strides[i][j]: stride inside x of the axis ys[i].order.axes[j].
    y_strides = [[x.stride[x.order.axes_dict[axis]] for axis in y.order.axes] for y in ys]

    # y_offsets[i]: offset inside x at which ys[i]'s slice begins.
    y_offsets = []
    position = 0
    for y in ys:
        y_offsets.append(position * x.stride[x.order.axes_dict[target_axis]])
        position += y.shape_dict[target_axis]

    # NOTE(review): the module-level ``source`` is registered under the name
    # "concat" although this op is SplitAxis; confirm it matches the function
    # name actually defined in ``source``.
    return [
        Kernel(
            {"concat": source},
            "concat",
            inputs=[x],
            outputs=ys,
            call_option={"y_shapes": y_shapes, "y_strides": y_strides, "y_offsets": y_offsets},
        )
    ]
def tensordot(op: Tensordot, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the kernel computing ``C`` as the tensordot of ``A`` and ``B``.

    ``op.axes[0]`` lists A's reduced axes and ``op.axes[1]`` B's. Buffers are
    resolved through ``memory_layout``. For every axis of C the kernel gets the
    matching stride in A and B (0 when that operand lacks the axis or reduces
    it). For the reduced axes it gets their shape, packed row-major strides and
    strides within the full operand.
    """
    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]

    shape_A_reduced_axes = [A.shape_dict[a] for a in op.axes[0]]
    shape_B_reduced_axes = [B.shape_dict[a] for a in op.axes[1]]

    def strides_for_c_axes(v, reduced):
        # Stride in v of each C axis; 0 if v lacks it or it is reduced.
        return [
            0 if a not in v.order.axes or a in reduced else v.stride_dict[a]
            for a in C.order.axes
        ]

    def packed_strides(shape):
        # Row-major strides of a dense tensor with the given shape.
        return [mul(shape[i + 1:]) for i in range(len(shape))]

    options = {
        "reduction_size": mul(shape_A_reduced_axes),
        "stride_A": A.stride,
        "stride_B": B.stride,
        "stride_C": C.stride,
        "shape_C": C.shape,
        "stride_A_for_C_axes": strides_for_c_axes(A, op.axes[0]),
        "stride_B_for_C_axes": strides_for_c_axes(B, op.axes[1]),
        "shape_A_reduced_axes": shape_A_reduced_axes,
        "stride_A_reduced_axes": packed_strides(shape_A_reduced_axes),
        "stride_A_reduced_axes_for_whole": [A.stride_dict[a] for a in op.axes[0]],
        "shape_B_reduced_axes": shape_B_reduced_axes,
        "stride_B_reduced_axes": packed_strides(shape_B_reduced_axes),
        "stride_B_reduced_axes_for_whole": [B.stride_dict[a] for a in op.axes[1]],
    }

    return [
        Kernel(
            {"tensordot": source},
            "tensordot",
            inputs=[memory_layout[A], memory_layout[B]],
            outputs=[memory_layout[C]],
            call_option=options,
        )
    ]
def elementwise_kernel_base(op: Elementwise, command_buffer: CommandBuffer, buffer_injector: BufferInjector):
    """Turn an already-built command buffer into a single named kernel for ``op``.

    The encoded source is passed through ``buffer_injector`` and then through a
    ``KernelNameInjector`` for ``op``. The injector's name is used both as the
    source key and as the kernel name. Returns a one-element list.
    """
    name_injector = KernelNameInjector(op)
    kernel_source, inputs, outputs, call_option = encode_command(command_buffer)

    # Order matters: buffer references first, then the kernel name.
    for injector in (buffer_injector, name_injector):
        kernel_source = injector.inject(kernel_source)

    kernel_name = name_injector.name
    return [Kernel({kernel_name: kernel_source}, kernel_name, inputs, outputs, call_option=call_option)]
def reshape(op: Reshape, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the kernel for a reshape, passing ``x.size`` as the element count.

    Only reshapes that need no transposition are currently supported: x and y
    must already be in the orders recorded on the op, and y's size must match
    the requested output shape.
    """
    x, y = op.inputs["x"], op.outputs["y"]
    params = op.parameters

    assert x.order == params["in_order"]
    assert y.order == params["out_order"]
    assert y.size == mul(params["out_shape"])

    return [
        Kernel({"reshape": source}, "reshape",
               inputs=[x], outputs=[y],
               call_option={"length": x.size})
    ]
def reinterpret_axis(op: ReinterpretAxis, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the kernel for ReinterpretAxis, passing ``x.size`` as the element count.

    Only cases that need no transposition are currently supported: x and y must
    already be in the orders recorded on the op. Buffers are resolved through
    ``memory_layout``.
    """
    x = op.inputs["x"]
    y = op.outputs["y"]

    expected_in, expected_out = op.parameters["in_order"], op.parameters["out_order"]
    assert x.order == expected_in
    assert y.order == expected_out

    kernel = Kernel(
        {"reinterpret_axis": source},
        "reinterpret_axis",
        inputs=[memory_layout[x]],
        outputs=[memory_layout[y]],
        call_option={"length": x.size},
    )
    return [kernel]
def elementwise_sum(op: ElementwiseSum) -> List[Kernel]:
    """Build the kernel that adds exactly two equally-shaped inputs x0 and x1 into y.

    Tensors are passed by their ``parameters["name"]``; no weights are used.
    """
    assert len(op.inputs) == 2
    lhs, rhs = op.inputs["x0"], op.inputs["x1"]
    y = op.outputs["y"]

    assert lhs.shape == rhs.shape
    assert lhs.shape == y.shape

    input_names = [v.parameters["name"] for v in (lhs, rhs)]
    return [
        Kernel(
            {"elementwise_sum": source},
            "elementwise_sum",
            inputs=input_names,
            outputs=[y.parameters["name"]],
            weights=[],
            call_option={"length": lhs.size},
        )
    ]
def flatten(op: Flatten, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the kernel for a flatten that needs no data rearrangement.

    Currently only supported when no data conversion is needed: x must be
    NCHW or NHWC and y must be NC. The kernel receives ``x.size`` as its
    element count.

    Raises:
        AssertionError: if x's order is unsupported or y is not NC.
    """
    x = op.inputs["x"]
    y = op.outputs["y"]

    if not (x.order == OrderNCHW or x.order == OrderNHWC):
        raise AssertionError("Unsupported order")
    assert y.order == OrderNC

    return [Kernel({"flatten": source}, "flatten", inputs=[x], outputs=[y],
                   call_option={"length": x.size})]
def average_pooling_2d(op: AveragePooling2D, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the kernel for 2D average pooling.

    Spatial sizes, batch and channel counts, strides of x and y, and the
    pooling padding/stride/ksize from ``op.parameters`` go into ``call_option``.
    """
    x = op.inputs["x"]
    y = op.outputs["y"]
    params = op.parameters

    options = {
        "in_spatial": [x.shape_dict[Axis.H], x.shape_dict[Axis.W]],
        "n": x.shape_dict[Axis.N],
        "out_size": y.shape_dict[Axis.C],
        "out_spatial": [y.shape_dict[Axis.H], y.shape_dict[Axis.W]],
        "strides_x": calculate_all_strides(x),
        "strides_y": calculate_all_strides(y),
        "padding": params["padding"],
        "stride": params["stride"],
        "ksize": params["ksize"],
    }

    return [Kernel({"average_pooling_2d": source}, "average_pooling_2d",
                   inputs=[x], outputs=[y], call_option=options)]
def softmax(op: Softmax, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the softmax kernel; only reduction over x's last axis is supported.

    The kernel is given ``C`` (the length of the reduced axis) and ``N``
    (the number of rows, i.e. ``y.size // C``).
    """
    x, y = op.inputs["x"], op.outputs["y"]

    assert y.order == x.order
    assert y.shape == x.shape

    axis = op.parameters["axis"]
    assert axis == x.order.axes[-1], "[Fallback] Softmax supports only for aggregating last axis."

    channels = y.shape_dict[axis]
    return [
        Kernel({"softmax": source}, "softmax", inputs=[x], outputs=[y],
               call_option={"N": y.size // channels, "C": channels})
    ]
def local_response_normalization(op: LocalResponseNormalization, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the local response normalization kernel.

    Buffers are resolved through ``memory_layout``. The LRN hyper-parameters
    are pre-processed for the kernel: half window ``n // 2`` as an int, ``k``
    and ``alpha`` as floats, and ``beta`` passed negated.
    """
    x = op.inputs["x"]
    y = op.outputs["y"]
    p = op.parameters

    options = {
        "out_spatial": [y.shape_dict[Axis.H], y.shape_dict[Axis.W]],
        "n": x.shape_dict[Axis.N],
        "out_size": y.shape_dict[Axis.C],
        "strides_x": calculate_all_strides(x),
        "strides_y": calculate_all_strides(y),
        "p_half_n": int(p["n"] // 2),
        "p_k": float(p["k"]),
        "p_alpha": float(p["alpha"]),
        "p_minus_beta": float(-p["beta"]),
    }

    kernel = Kernel(
        {"local_response_normalization": source},
        "local_response_normalization",
        inputs=[memory_layout[x]],
        outputs=[memory_layout[y]],
        call_option=options,
    )
    return [kernel]
def elementwise_kernel(op: Elementwise, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the generated kernel for a registered elementwise op.

    Looks up the registration for ``op``'s class, evaluates its parameter
    functions on ``op``, and generates the source from the item's code. For
    each input the kernel gets its shape and, for each of its axes, the
    matching stride in a dense row-major ``y``.
    """
    xs = [op.inputs[f"x{i}"] for i in range(len(op.inputs))]
    y = op.outputs["y"]

    item = _registered_items[op.__class__]
    parameters = {key: fn(op) for key, fn in item.parameters.items()}

    x_shapes = [x.shape for x in xs]

    # Dense row-major strides of y, computed from its shape alone.
    y_strides = []
    running = 1
    for dim in reversed(y.shape):
        y_strides.append(running)
        running *= dim
    y_strides.reverse()

    # x_strides_in_y[i][j]: stride in y of the axis xs[i].order.axes[j].
    x_strides_in_y = [[y_strides[y.order.axes_dict[axis]] for axis in x.order.axes] for x in xs]

    call_options = {
        "x_shapes": x_shapes,
        "x_strides_in_y": x_strides_in_y,
        **{f"elementwise_parameters_{key}": val for key, val in parameters.items()},
    }

    name_injector = KernelNameInjector(op)
    kernel_source = name_injector.inject(_generate_source(xs, item.code, parameters))

    return [
        Kernel(
            {name_injector.name: kernel_source},
            name_injector.name,
            inputs=xs,
            outputs=[y],
            call_option=call_options,
        )
    ]
def normalize(op: Normalize, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the Normalize kernel; only reduction over x's last axis is supported.

    Buffers are resolved through ``memory_layout``. The kernel is given ``C``
    (the reduced axis length), ``N = y.size // C`` and ``eps`` from the op.
    """
    x = op.inputs["x"]
    y = op.outputs["y"]

    assert y.order == x.order
    assert y.shape == x.shape

    axis = op.parameters["axis"]
    assert axis == x.order.axes[-1], "[Fallback] Normalize supports only for aggregating last axis."

    reduced_len = y.shape_dict[axis]
    options = {
        "N": y.size // reduced_len,
        "C": reduced_len,
        "eps": op.parameters["eps"],
    }
    return [Kernel({"normalize": source}, "normalize",
                   inputs=[memory_layout[x]], outputs=[memory_layout[y]],
                   call_option=options)]
def concat(op: Concat) -> List[Kernel]:
    """Build the kernel that concatenates inputs x0, x1, ... into y along ``op.axis``.

    Element ``xs[i][d_0, ..., d_n]`` is written to y at address
    ``x_offsets[i] + x_strides[i][0] * d_0 + ... + x_strides[i][n] * d_n``.
    Tensors are passed by their ``parameters["name"]``.
    """
    xs = [op.inputs[f"x{i}"] for i in range(len(op.inputs))]
    y = op.outputs["y"]
    target_axis = op.axis

    x_shapes = [x.shape for x in xs]

    # Dense row-major strides of y, computed from its shape alone.
    y_strides = []
    acc = 1
    for dim in reversed(y.shape):
        y_strides.append(acc)
        acc *= dim
    y_strides.reverse()

    # x_strides[i][j]: stride in y of the axis xs[i].order.axes[j].
    x_strides = [[y_strides[y.order.axes_dict[axis]] for axis in x.order.axes] for x in xs]

    # x_offsets[i]: offset in y at which xs[i]'s data begins.
    x_offsets = []
    position = 0
    for x in xs:
        x_offsets.append(position * y_strides[y.order.axes_dict[target_axis]])
        position += x.shape_dict[target_axis]

    return [
        Kernel(
            {"concat": source},
            "concat",
            inputs=[x.parameters["name"] for x in xs],
            outputs=[y.parameters["name"]],
            weights=[],
            call_option={"x_shapes": x_shapes, "x_strides": x_strides, "x_offsets": x_offsets},
        )
    ]
def convolution_2d(op: Convolution2D, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the 2D convolution kernel with explicitly ordered strides.

    Buffers are resolved through ``memory_layout``. Strides are listed in the
    fixed orders N,H,W,C for x and y and N,KH,KW,C for w, independent of each
    tensor's storage order.
    """
    x = op.inputs["x"]
    w = op.inputs["w"]
    y = op.outputs["y"]

    def ordered_strides(v, axes):
        # v's strides listed in the given axis order.
        return [v.stride_dict[a] for a in axes]

    nhwc = [Axis.N, Axis.H, Axis.W, Axis.C]
    n_kh_kw_c = [Axis.N, Axis.KH, Axis.KW, Axis.C]

    options = {
        "in_spatial": [x.shape_dict[Axis.H], x.shape_dict[Axis.W]],
        "n": x.shape_dict[Axis.N],
        "out_size": y.shape_dict[Axis.C],
        "in_size": x.shape_dict[Axis.C],
        "out_spatial": [y.shape_dict[Axis.H], y.shape_dict[Axis.W]],
        "strides_x": ordered_strides(x, nhwc),
        "strides_w": ordered_strides(w, n_kh_kw_c),
        "strides_y": ordered_strides(y, nhwc),
        "padding": op.padding,
        "stride": op.stride,
        "ksize": op.ksize,
        "dilation_rate": op.dilation_rate,
    }

    return [
        Kernel(
            {"convolution_2d": source},
            "convolution_2d",
            inputs=[memory_layout[x], memory_layout[w]],
            outputs=[memory_layout[y]],
            call_option=options,
        )
    ]
def max_pooling_2d(op: MaxPooling2D) -> List[Kernel]:
    """Build the kernel for 2D max pooling.

    Tensors are passed by their ``parameters["name"]``; there are no weights.
    Geometry and the pooling padding/stride/ksize go into ``call_option``.
    """
    x, y = op.inputs["x"], op.outputs["y"]
    pool = op.parameters

    options = {
        "in_spatial": [x.shape_dict[Axis.H], x.shape_dict[Axis.W]],
        "n": x.shape_dict[Axis.N],
        "out_size": y.shape_dict[Axis.C],
        "out_spatial": [y.shape_dict[Axis.H], y.shape_dict[Axis.W]],
        "strides_x": calculate_all_strides(x),
        "strides_y": calculate_all_strides(y),
        "padding": pool["padding"],
        "stride": pool["stride"],
        "ksize": pool["ksize"],
    }

    kernel = Kernel({"max_pooling_2d": source}, "max_pooling_2d",
                    inputs=[x.parameters["name"]], outputs=[y.parameters["name"]],
                    weights=[], call_option=options)
    return [kernel]
def _linear_4d_strides(v, n_size, chw_size, label):
    """Return ``(k_stride, n_stride)`` for a 4-D ``linear`` operand viewed as a matrix.

    ``k`` walks the flattened non-N axes and ``n_stride`` is the step between
    consecutive N indices. Only layouts with N as the first (N***) or last
    (***N) axis are supported.

    Raises:
        ValueError: if N is neither the first nor the last axis (e.g. HWNC).
    """
    if v.order.axes[0] == Axis.N:  # N***: one sample per contiguous run
        return 1, chw_size
    if v.order.axes[3] == Axis.N:  # ***N: N is the innermost axis
        return n_size, 1
    raise ValueError(
        f"linear: unsupported order for {label}: {v.order} (N must be the first or last axis)"
    )


def linear(op: Linear, memory_layout: MemoryLayout) -> List[Kernel]:
    """Build the kernel for a Linear (fully connected) op.

    The kernel is given matrix sizes ``m`` (x's N), ``n`` (w's N) and ``k``
    (the reduced size: C for 2-D inputs, C*H*W for 4-D inputs), plus the
    strides of x along k/m and of w along k/n. The output y must be NC.

    Raises:
        ValueError: if x is not 2-D or 4-D, or a 4-D operand does not have N
            as its first or last axis.
        AssertionError: if the shapes/orders of x and w are inconsistent.
    """
    x = op.inputs["x"]
    w = op.inputs["w"]
    y = op.outputs["y"]

    assert y.order == OrderNC

    if x.order.ndim == 2:
        assert w.order.ndim == 2
        k = x.shape_dict[Axis.C]
        m = x.shape_dict[Axis.N]
        n = w.shape_dict[Axis.N]

        # Stride along each matrix direction: the product of the sizes of the
        # dimensions to the right of (inner-loop relative to) that axis.
        x_k_stride = calculate_stride(x, Axis.C)
        x_m_stride = calculate_stride(x, Axis.N)
        w_k_stride = calculate_stride(w, Axis.C)
        w_n_stride = calculate_stride(w, Axis.N)

    elif x.order.ndim == 4:
        assert w.order.ndim == 4

        # Only supported when the non-N axes (C, H, W) are contiguous and in
        # the same order in x and w (NCHW/NCHW, NHWC/HWCN, ...).
        x_order_wo_n = list(x.order.axes)
        x_order_wo_n.remove(Axis.N)
        x_n_size = x.shape_dict[Axis.N]
        x_chw_size = x.size // x_n_size

        w_order_wo_n = list(w.order.axes)
        w_order_wo_n.remove(Axis.N)
        w_n_size = w.shape_dict[Axis.N]
        w_chw_size = w.size // w_n_size

        assert x_chw_size == w_chw_size
        assert x_order_wo_n == w_order_wo_n

        k = x_chw_size
        m = x_n_size
        n = w_n_size

        x_k_stride, x_m_stride = _linear_4d_strides(x, x_n_size, x_chw_size, "x")
        w_k_stride, w_n_stride = _linear_4d_strides(w, w_n_size, w_chw_size, "w")

    else:
        raise ValueError(f"linear: unsupported ndim {x.order.ndim} for x (expected 2 or 4)")

    kernel = Kernel(
        {"linear": source},
        "linear",
        inputs=[x, w],
        outputs=[y],
        call_option={
            "m": m,
            "n": n,
            "k": k,
            "x_k_stride": x_k_stride,
            "x_m_stride": x_m_stride,
            "w_k_stride": w_k_stride,
            "w_n_stride": w_n_stride,
        },
    )
    return [kernel]