def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Replace every :class:`Linear` operator in *graph* with an equivalent :class:`Sgemm`.

    Returns:
        The (mutated) graph and a flag telling whether any replacement happened.
    """
    flag_changed = False
    for op in traverse.filter_nodes(traverse.listup_operators(graph), Linear):  # type: Linear
        x = op.inputs["x"]
        w = op.inputs["w"]
        y = op.outputs["y"]

        # Linear is only handled for 2D (NC) or 4D (NHWC) layouts here.
        # BUGFIX: the w/y assertion messages previously labeled the value as
        # "(x.order)" — copy-paste error; each now reports its own variable.
        assert x.order == OrderNC or x.order == OrderNHWC, f"(x.order) = {x.order}"
        assert w.order == OrderCN or w.order == OrderHWCN, f"(w.order) = {w.order}"
        assert y.order == OrderNC or y.order == OrderNHWC, f"(y.order) = {y.order}"
        assert w.ndim == x.ndim

        flag_changed = True
        op.remove_all()

        # M = batch size, N = all output features per sample,
        # K = all input features per sample.
        sgemm = Sgemm(None,
                      M=y.shape_dict[Axis.N],
                      N=y.size // y.shape_dict[Axis.N],
                      K=x.size // x.shape_dict[Axis.N],
                      out_shape=y.shape,
                      out_order=y.order,
                      transpose_A=True,
                      transpose_B=True)
        new_y, = sgemm(x, w)
        sgemm.replace_output(new_y, y)

    return graph, flag_changed
def template(transpose_A=False, transpose_B=False, M=5, N=8, K=6, description: str = ""):
    """Build a random M×K @ K×N Sgemm test case and register it for all backends."""
    mat_a = np.random.rand(M, K).astype(np.float32)
    mat_b = np.random.rand(K, N).astype(np.float32)
    # Pin one row/column so the product contains deterministic, non-trivial values.
    mat_a[0, :] = 2
    mat_b[:, 0] = 2
    expected = np.dot(mat_a, mat_b)

    # Sgemm reads A as [M, K] when transpose_A is set, otherwise as [K, M];
    # likewise B is [K, N] when transpose_B is set, otherwise [N, K].
    stored_a = mat_a if transpose_A else mat_a.transpose()
    stored_b = mat_b if transpose_B else mat_b.transpose()

    a = Variable(stored_a.shape, order=OrderNC)
    b = ConstantVariable(stored_b, order=OrderNC)
    c, = Sgemm(None, M=M, N=N, K=K,
               out_shape=[M, N], out_order=OrderNC,
               transpose_A=transpose_A, transpose_B=transpose_B)(a, b)

    generate_kernel_test_case(
        description=f"Sgemm {description}",
        backend=["webgpu", "webassembly", "webgl"],
        graph=Graph([a], [c]),
        inputs={a: stored_a},
        expected={c: expected})
def test_sgemm_invalid_C_shape():
    # out_shape has 1*2*3*4 = 24 elements while M*N = 10*20 = 200, so applying
    # the operator should trip Sgemm's output-shape validation.
    # NOTE(review): no exception assertion is visible here — presumably an
    # expected-failure decorator (e.g. @raises(AssertionError)) precedes this
    # function outside this view; confirm before relying on this test.
    op = Sgemm(None, M=10, N=20, K=30, out_shape=[1, 2, 3, 4], out_order=OrderNHWC,
               transpose_A=True, transpose_B=True)
    x = Variable((10, 30), OrderNC)
    w = Variable((20, 30), OrderNC)
    op(x, w)
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Lower each Convolution2D into Im2Col (when needed) followed by Sgemm.

    Returns:
        The (mutated) graph and a flag telling whether any lowering happened.
    """
    changed = False
    for conv in traverse.filter_nodes(traverse.listup_operators(graph), Convolution2D):  # type: Convolution2D
        x = conv.inputs["x"]
        w = conv.inputs["w"]
        y = conv.outputs["y"]
        assert x.order == OrderNHWC
        assert y.order == OrderNHWC
        assert isinstance(w, ConstantVariable)

        changed = True
        conv.remove_all()
        w.change_order(OrderHWCN)

        # A pointwise convolution (1x1 kernel, unit stride, no padding) can
        # feed Sgemm directly; anything else must be unfolded first.
        is_pointwise = (conv.WH == 1 and conv.WW == 1
                        and conv.stride == (1, 1) and conv.padding == (0, 0))
        if is_pointwise:
            col = x
        else:
            col, = Im2Col(None, ksize=conv.ksize, stride=conv.stride,
                          padding=conv.padding, dilation_rate=conv.dilation_rate)(x)
            col.change_order(OrderNHWC)

        n = col.shape_dict[Axis.N]
        h = col.shape_dict[Axis.H]
        width = col.shape_dict[Axis.W]
        c = col.shape_dict[Axis.C]
        out_ch = w.shape_dict[Axis.N]

        # M = output pixels, N = output channels, K = unfolded patch size.
        sgemm = Sgemm(None,
                      M=n * h * width,
                      N=out_ch,
                      K=c,
                      out_shape=[n, h, width, out_ch],
                      out_order=OrderNHWC,
                      transpose_A=col.order == OrderNHWC,
                      transpose_B=True)
        new_y, = sgemm(col, w)
        new_y.replace(y)

    return graph, changed
def sgemm(op: Sgemm, memory_layout: MemoryLayout) -> List[Kernel]:
    """Generate the kernel source for a Sgemm operator.

    Uses the Eigen-backed template when the operator carries the
    :class:`SgemmWithEigen` attribute, otherwise the plain generated template.

    Returns:
        A single-element list containing the generated :class:`Kernel`.
    """
    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]

    buffer_injector = BufferInjector()
    # Register buffers and problem size once. The original code re-registered
    # subsets of these exact same values inside each template branch below —
    # pure duplication, removed; the registered key set is unchanged.
    buffer_injector.register({
        "sgemm_A": memory_layout[A],
        "sgemm_B": memory_layout[B],
        "sgemm_C": memory_layout[C],
        "sgemm_M": op.M,
        "sgemm_N": op.N,
        "sgemm_K": op.K
    })

    if op.has_attribute(SgemmWithEigen):
        source = generate_template_eigen(op.transpose_A, op.transpose_B)
    else:
        source = generate_template(op.transpose_A, op.transpose_B)

    name_injector = KernelNameInjector(op)

    source = buffer_injector.inject(source)
    source = name_injector.inject(source)

    kernel = Kernel({name_injector.name: source},
                    name_injector.name,
                    buffer_injector.buffer,
                    buffer_injector.unresolved_value_list)
    return [kernel]
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Lower each Deconvolution2D into Sgemm followed by Col2Im.

    Returns:
        The (mutated) graph and a flag telling whether any lowering happened.
    """
    flag_changed = False
    for op in traverse.filter_nodes(
            traverse.listup_operators(graph), Deconvolution2D):  # type: Deconvolution2D
        x = op.inputs["x"]
        w = op.inputs["w"]
        y = op.outputs["y"]
        # BUGFIX: the first assertion previously re-checked y.order (copy-paste);
        # it must validate the input layout, mirroring the Convolution2D rule.
        assert x.order == OrderNHWC
        assert y.order == OrderNHWC
        assert isinstance(w, ConstantVariable)

        flag_changed = True
        op.remove_all()
        w.change_order(OrderCHWN)

        # M = input pixels, N = unfolded filter size × output channels,
        # K = input channels.
        sgemm = Sgemm(None,
                      M=x.shape_dict[Axis.N] * x.shape_dict[Axis.H] * x.shape_dict[Axis.W],
                      N=w.shape_dict[Axis.H] * w.shape_dict[Axis.W] * w.shape_dict[Axis.N],
                      K=x.shape_dict[Axis.C],
                      out_shape=[
                          x.shape_dict[Axis.N],
                          x.shape_dict[Axis.H],
                          x.shape_dict[Axis.W],
                          w.shape_dict[Axis.H] * w.shape_dict[Axis.W] * w.shape_dict[Axis.N]
                      ],
                      out_order=OrderNHWC,
                      transpose_A=x.order == OrderNHWC,
                      transpose_B=True)
        col, = sgemm(x, w)

        # Fold the column matrix back into the spatial output.
        col2im = Col2Im(None, ksize=op.ksize, stride=op.stride, padding=op.padding)
        new_y, = col2im(col)
        col2im.replace_output(new_y, y)

    return graph, flag_changed
def test_sgemm():
    """Sgemm reshapes its logical [M, N] result into the requested 4-D NHWC output."""
    op = Sgemm(None, M=10, N=20, K=30,
               out_shape=[1, 10, 4, 5], out_order=OrderNHWC,
               transpose_A=True, transpose_B=True)
    x = Variable((10, 30), OrderNC)
    w = Variable((20, 30), OrderNC)
    y, = op(x, w)

    assert y.order == OrderNHWC
    expected_sizes = {Axis.N: 1, Axis.H: 10, Axis.W: 4, Axis.C: 5}
    for axis, size in expected_sizes.items():
        assert y.shape_dict[axis] == size
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Rewrite every Linear as Transpose → Sgemm → Transpose.

    Returns:
        The (mutated) graph and a flag telling whether any rewrite happened.
    """
    changed = False
    for op in traverse.filter_nodes(traverse.listup_operators(graph), Linear):
        x = op.inputs["x"]
        w = op.inputs["w"]
        y = op.outputs["y"]
        changed = True
        op.remove_all()

        # K runs over the channel axis; N is w's other axis; M collects every
        # axis of x that is not K.
        axis_k = Axis.C
        axis_n = w.order.axes[0] if w.order.axes[1] == axis_k else w.order.axes[1]
        axes_m = [a for a in x.order.axes if a != axis_k]

        size_k = x.shape_dict[axis_k]
        size_m = x.size // size_k
        size_n = w.shape_dict[axis_n]

        # Move K to the front of both operands so Sgemm can consume them as
        # A = [K, M...] (not transposed) and B = [K, N] (transposed).
        x_t, = Transpose(None)(x)
        x_t.change_order(Order([axis_k] + axes_m))
        w_t, = Transpose(None)(w)
        w_t.change_order(Order([axis_k, axis_n]))

        out, = Sgemm(None, M=size_m, N=size_n, K=size_k,
                     out_shape=[x_t.shape_dict[a] for a in axes_m] + [size_n],
                     out_order=Order(axes_m + [axis_n]),
                     transpose_A=False, transpose_B=True)(x_t, w_t)
        out, = Transpose(None)(out)
        OptimizeRule.replace_variable(graph, out, y)

    return graph, changed
def _split_sgemm(graph: Graph, op: Sgemm, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    # Rebuild the graph after one operand of `op` (`v`, already cut into
    # `v_pair` along `axis`) has been split: the Sgemm is replaced by two
    # smaller Sgemms whose outputs are summed (split along K) or concatenated
    # and reshaped (split along M or N). Splitting the output C is rejected.
    s1 = v_pair[0].shape_dict[axis]  # size of the first piece along `axis`
    s2 = v_pair[1].shape_dict[axis]  # size of the second piece along `axis`
    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]
    transpose_A, transpose_B = op.transpose_A, op.transpose_B
    M, K, N = op.M, op.K, op.N
    # Fresh anonymous axes used to name the logical M/K/N dimensions when
    # reshaping operands below.
    axis_M, axis_K, axis_N = Axis(None), Axis(None), Axis(None)
    op.remove_all()

    def decompose_logical_axes(logical_shape: Tuple[int, int], v: Variable):
        """
        Decompose logical axes into real axes

        Examples::

            A.order, A.shape
            >>> "NCHW", (1, 128, 8, 8)

            M = 128
            K = 64
            decompose_logical_axes([M, K], A)
            >>> ["<Axis N>", "<Axis C>"], ["<Axis H>", "<Axis W>"]
        """
        total_size = 1
        axes1 = []  # type: List[Axis]
        axes2 = list(v.order.axes)  # type: List[Axis]
        # Greedily move leading physical axes into `axes1` until their product
        # reaches logical_shape[0]; the remaining axes stay in `axes2`.
        # NOTE(review): if the boundary is only reached after consuming every
        # axis, the loop falls through and implicitly returns None — callers
        # presumably never hit that case; confirm.
        for size, a in zip(v.shape, v.order.axes):
            if total_size == logical_shape[0]:
                return axes1, axes2
            elif total_size > logical_shape[0]:
                # The split boundary does not align with physical axis borders.
                raise ValueError
            axes1.append(a)
            axes2.remove(a)
            total_size *= size

    if v == A:
        A1, A2 = v_pair

        if transpose_A:
            # A.shape = [M, K]
            axes_M, axes_K = decompose_logical_axes((M, K), A)
        else:
            # A.shape = [K, M]
            axes_K, axes_M = decompose_logical_axes((K, M), A)

        if axis in axes_K:
            """
            before)

                A -{sgemm}- C

            after) In case `axis` is in `K`,

                A_0 -{sgemm}- C_0 -+
                                   +-{Add}- C
                A_1 -{sgemm}- C_1 -+
            """
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize B's axes included in K into A's corresponding axes,
            # so B can be split along the same `axis` as A.
            if transpose_B:
                # B: [k_b1, k_b2, ..., N] -{reshape}-> [k_a1, k_a2, ..., N]
                B, = Reshape(None, in_order=B.order,
                             out_order=Order(axes_K + [axis_N]),
                             out_shape=[A.shape_dict[a] for a in axes_K] + [N])(B)
            else:
                # B: [N, k_b1, k_b2, ...] -{reshape}-> [N, k_a1, k_a2, ...]
                B, = Reshape(None, in_order=B.order,
                             out_order=Order([axis_N] + axes_K),
                             out_shape=[N] + [A.shape_dict[a] for a in axes_K])(B)

            B1, B2 = SplitAxis(None, axis=axis, sections=[s1])(B)

            C1, = Sgemm(None, M=M, K=K1, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A1, B1)
            C2, = Sgemm(None, M=M, K=K2, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A2, B2)

            # The partial products over the two K slices sum to the full product.
            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_M
            """
            before)

                A -{sgemm}- C

            after) In case `axis` is in `M`,

                A_0 -{sgemm}- C_0 -+
                                   +-{Concat}- C
                A_1 -{sgemm}- C_1 -+
            """
            M1, M2 = M * s1 // (s1 + s2), M * s2 // (s1 + s2)

            # Temporary output layout: M axes followed by the logical N axis.
            c_tmp_order = Order(axes_M + [axis_N])
            c1_shape = [A1.shape_dict[a] for a in axes_M] + [N]
            c2_shape = [A2.shape_dict[a] for a in axes_M] + [N]

            C1, = Sgemm(None, M=M1, K=K, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c1_shape, out_order=c_tmp_order)(A1, B)
            C2, = Sgemm(None, M=M2, K=K, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c2_shape, out_order=c_tmp_order)(A2, B)

            C_new, = Concat(None, axis=axis)(C1, C2)
            C_new, = Reshape(None, in_order=c_tmp_order,
                             out_order=C.order, out_shape=C.shape)(C_new)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == B:
        B1, B2 = v_pair

        if transpose_B:
            # B.shape = [K, N]
            axes_K, axes_N = decompose_logical_axes((K, N), B)
        else:
            # B.shape = [N, K]
            axes_N, axes_K = decompose_logical_axes((N, K), B)

        if axis in axes_K:
            """
            before)

                B -{sgemm}- C

            after) In case `axis` is in `K`,

                B_0 -{sgemm}- C_0 -+
                                   +-{Add}- C
                B_1 -{sgemm}- C_1 -+
            """
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize A's axes included in K into B's corresponding axes,
            # so A can be split along the same `axis` as B.
            if transpose_A:
                # A: [M, k_a1, k_a2, k_a3, ...] -{reshape}-> [M, k_b1, k_b2, ...]
                A, = Reshape(None, in_order=A.order,
                             out_order=Order([axis_M] + axes_K),
                             out_shape=[M] + [B.shape_dict[a] for a in axes_K])(A)
            else:
                # A: [k_a1, k_a2, k_a3, ..., M] -{reshape}-> [k_b1, k_b2, ..., M]
                A, = Reshape(None, in_order=A.order,
                             out_order=Order(axes_K + [axis_M]),
                             out_shape=[B.shape_dict[a] for a in axes_K] + [M])(A)

            A1, A2 = SplitAxis(None, axis=axis, sections=[s1])(A)

            C1, = Sgemm(None, M=M, K=K1, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A1, B1)
            C2, = Sgemm(None, M=M, K=K2, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A2, B2)

            # The partial products over the two K slices sum to the full product.
            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_N
            """
            before)

                C[M, N] = A[M, K] @ B[K, N]

            after) In case `axis` is in `N`,

                C[M, N] = Concat(C1[M, N1], C2[M, N2])
                        = Concat(A[M, K] @ B1[K, N1], A[M, K] @ B2[K, N2])
            """
            N1, N2 = N * s1 // (s1 + s2), N * s2 // (s1 + s2)

            # Temporary output layout: logical M axis followed by the N axes.
            c_tmp_order = Order([axis_M] + axes_N)
            c1_shape = [M] + [B1.shape_dict[a] for a in axes_N]
            c2_shape = [M] + [B2.shape_dict[a] for a in axes_N]

            C1, = Sgemm(None, M=M, K=K, N=N1,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c1_shape, out_order=c_tmp_order)(A, B1)
            # C1.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ..., B1.shape_dict[axis], ...]
            # C1.order = [axis_M, n1, n2, ..., axis, ...]

            C2, = Sgemm(None, M=M, K=K, N=N2,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c2_shape, out_order=c_tmp_order)(A, B2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            # C_new.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ..., B1.shape_dict[axis]+B2.shape_dict[axis], ...]
            # C_new.order = [axis_M, n1, n2, ..., axis, ...]
            C_new, = Reshape(None, in_order=c_tmp_order,
                             out_order=C.order, out_shape=C.shape)(C_new)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == C:
        """
        before)

            C[M, N] = A[M, K] @ B[K, N]

        after) In case `axis` is in `N`,

            C[M, N1] = Concat(A[M, K] @ B1[K, N1])
            C[M, N2] = Concat(A[M, K] @ B2[K, N2])
        """
        # Splitting the output variable itself is not supported.
        raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError