def test_mix_order():
    """Generate a kernel test for SplitAxis whose outputs use mixed data orders.

    Four random ``(2, 3, 4, 5)`` arrays are concatenated along numpy axis 1 to
    build the input, which is declared as ``OrderNHWC`` and split on ``Axis.H``.
    The 2nd, 3rd and 4th outputs are switched to other orders and their
    expected data is permuted to match.
    """
    parts = [np.random.rand(2, 3, 4, 5) for _ in range(4)]
    vy = np.concatenate(parts, 1)
    vx1, vx2, vx3, vx4 = parts

    y = Variable(vy.shape, order=OrderNHWC)
    x1, x2, x3, x4 = SplitAxis(None, axis=Axis.H, sections=[3, 6, 9])(y)

    x2.change_order(OrderCNHW)
    x3.change_order(OrderCHWN)
    x4.change_order(OrderNCHW)

    expected = {
        x1: vx1,
        x2: np.transpose(vx2, (3, 0, 1, 2)),  # NHWC -> CNHW
        x3: np.transpose(vx3, (3, 1, 2, 0)),  # NHWC -> CHWN
        x4: np.transpose(vx4, (0, 3, 1, 2)),  # NHWC -> NCHW
    }

    generate_kernel_test_case(
        description=f"SplitAxis with mix order",
        graph=Graph([y], [x1, x2, x3, x4]),
        inputs={y: vy},
        expected=expected,
    )
def _convert_split_axis(converter: ChainerConverter, c_op: "chainer.functions.SplitAxis"):
    """Convert a Chainer ``SplitAxis`` function into a ``SplitAxis`` operator.

    Args:
        converter: converter used to look up the input variable and to
            register the converted outputs.
        c_op: the Chainer function being converted.

    Raises:
        NotImplementedError: if ``c_op.indices_or_sections`` is an ``int``
            (a section count rather than explicit split indices).
    """
    x = converter.get_variable(c_op.inputs[0])

    # An int means "number of sections"; only explicit indices are handled.
    # The message is worded to match what is actually unsupported.
    if isinstance(c_op.indices_or_sections, int):
        raise NotImplementedError("[ChainerConverter] SplitAxis with sections are not supported.")

    ys = SplitAxis(None, sections=c_op.indices_or_sections, axis=x.order.axes[c_op.axis])(x)
    for i, y in enumerate(ys):
        # c_op.outputs[i] is called to obtain the Chainer-side output object.
        converter.set_variable(c_op.outputs[i](), y)
def _split_tensorwise(graph: Graph, op: Operator, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Split operator ``op`` into two copies, each working on one half of ``v``.

    Args:
        graph: graph passed through to ``OptimizeRule.replace_variable``.
        op: operator to duplicate; it is detached with ``remove_all()`` first.
        v: the variable of ``op`` that has been split.
        v_pair: the two parts of ``v``.
        axis: axis along which ``v`` was split.

    Raises:
        UnexpectedAndPleaseReportError: if an output other than ``v`` does not
            have ``axis`` in its order.
    """
    head_size = v_pair[0].shape_dict[axis]
    tail_size = v_pair[1].shape_dict[axis]
    inputs = dict(op.inputs)
    outputs = dict(op.outputs)
    op.remove_all()

    op_0 = op.copy()
    op_1 = op.copy()

    for name, var in inputs.items():
        if var == v:
            part_0, part_1 = v_pair
        elif axis in var.order.axes:
            part_0, part_1 = SplitAxis(None, axis=axis, sections=[head_size])(var)
        else:
            # No split happens: both copies share the same input.
            part_0 = part_1 = var

        op_0.append_input(name, part_0)
        op_1.append_input(name, part_1)

    for name, var in outputs.items():
        if var == v:
            part_0, part_1 = v_pair
        elif axis in var.order.axes:
            # TODO (Kiikurage)
            # Attribute attached to "var" is not copied to neither "part_0" or "part_1"
            axes = var.order.axes
            part_0 = Variable([head_size if a == axis else var.shape_dict[a] for a in axes], var.order)
            part_1 = Variable([tail_size if a == axis else var.shape_dict[a] for a in axes], var.order)
            merged, = Concat(None, axis=axis)(part_0, part_1)
            OptimizeRule.replace_variable(graph, var, merged)
        else:
            raise UnexpectedAndPleaseReportError

        op_0.append_output(name, part_0)
        op_1.append_output(name, part_1)
def optimize(self, graph):
    """Merge ``SplitAxis -> Variable -> SplitAxis`` chains that share an axis.

    For every match whose intermediate variable feeds only the second
    SplitAxis and whose two operators split along the same axis, both
    operators are removed and replaced by one SplitAxis with the combined
    section boundaries.

    Args:
        graph: graph to optimize (modified in place).

    Returns:
        Tuple ``(graph, flag_changed)``; ``flag_changed`` is True when at
        least one pair of operators was merged.
    """
    pattern = [SplitAxis, Variable, SplitAxis]
    flag_changed = False
    matches = traverse.search_sub_structure(graph, pattern)

    while len(matches) > 0:
        op1, h, op2 = matches.pop()  # type: SplitAxis, Variable, SplitAxis

        if len(h.input_to) > 1:
            # `h` will be removed by this optimization, so it must not be
            # consumed by anything else.
            continue

        if op1.axis != op2.axis:
            # These operations cannot be merged.
            continue

        flag_changed = True

        x = op1.inputs["x"]
        hs = [op1.outputs[f"y{i}"] for i in range(len(op1.outputs))]
        i_h = hs.index(h)

        original_ys = list(hs)
        original_ys.remove(h)

        # Work on copies: the boundaries are edited in place below and must
        # not leak back into the operators' own parameter lists.
        op1_sections = list(op1.sections)
        op2_sections = [0] + list(op2.sections)
        new_sections = list(op1_sections)

        # Offset (along the axis) at which `h` starts inside `x`.
        section_offset = ([0] + op1_sections)[i_h]

        for i in range(len(op2.outputs)):
            original_ys.insert(i_h + i, op2.outputs[f"y{i}"])
            new_sections.insert(i_h + i, section_offset + op2_sections[i])

        # The first inserted boundary equals `section_offset`, which is either
        # 0 or already present as the boundary preceding `h`; drop one copy.
        new_sections.remove(section_offset)

        op1.remove_all()
        op2.remove_all()
        new_ys = SplitAxis(None, axis=op1.axis, sections=new_sections)(x)
        for original_y, new_y in zip(original_ys, new_ys):
            new_y.change_order(original_y.order)
            new_y.replace(original_y)

        # The graph changed; search again from scratch.
        matches = traverse.search_sub_structure(graph, pattern)

    return graph, flag_changed
def _convert_split(converter: ONNXConverter, onnx_op: INodeProto):
    """Convert an ONNX ``Split`` node into a ``SplitAxis`` operator.

    The ``split`` attribute (a list of part sizes) is turned into cumulative
    split positions, dropping the last one.

    Raises:
        NotImplementedError: if the node has no ``split`` attribute.
    """
    x = converter.get_variable(onnx_op.input[0])
    attrs = attribute_dict(onnx_op)
    split_axis = x.order.axes[attrs["axis"].i]

    if "split" not in attrs:
        raise NotImplementedError(
            "[ONNXConverter] Operator \"Split\" without \"split\" parameter is not supported yet."
        )

    part_sizes = attrs["split"].ints
    boundaries = np.cumsum(part_sizes)[:-1].tolist()

    outputs = SplitAxis(None, axis=split_axis, sections=boundaries)(x)
    for index, out_var in enumerate(outputs):
        converter.set_variable(onnx_op.output[index], out_var)
def test_middle_axis():
    """Generate a kernel test for SplitAxis on ``Axis.H`` of an ``OrderNHWC`` input.

    The input is four random ``(2, 3, 4, 5)`` arrays concatenated along numpy
    axis 1; each output is expected to equal the matching source array.
    """
    parts = [np.random.rand(2, 3, 4, 5) for _ in range(4)]
    vy = np.concatenate(parts, 1)

    y = Variable(vy.shape, order=OrderNHWC)
    xs = SplitAxis(None, axis=Axis.H, sections=[3, 6, 9])(y)
    x1, x2, x3, x4 = xs

    generate_kernel_test_case(
        description=f"SplitAxis in middle axis",
        graph=Graph([y], [x1, x2, x3, x4]),
        inputs={y: vy},
        expected={x1: parts[0], x2: parts[1], x3: parts[2], x4: parts[3]},
    )
def test_2d():
    """Generate a kernel test for SplitAxis on ``Axis.N`` of a 2D ``OrderNC`` input.

    The input is four random ``(2, 3)`` arrays concatenated along numpy axis 0;
    each output is expected to equal the matching source array.
    """
    parts = [np.random.rand(2, 3) for _ in range(4)]
    vy = np.concatenate(parts, 0)

    y = Variable(vy.shape, order=OrderNC)
    xs = SplitAxis(None, axis=Axis.N, sections=[2, 4, 6])(y)
    x1, x2, x3, x4 = xs

    generate_kernel_test_case(
        description=f"SplitAxis 2D",
        graph=Graph([y], [x1, x2, x3, x4]),
        inputs={y: vy},
        expected={x1: parts[0], x2: parts[1], x3: parts[2], x4: parts[3]},
    )
def _split_tensorwise(graph: Graph, op: Operator, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Split operator ``op`` into two copies, each working on one half of ``v``.

    Inputs other than ``v`` are split along ``axis`` too, unless they lack that
    axis or have size 1 on it, in which case both copies share them. Both
    copies are executed with ``exec()`` and their outputs are wired back
    through ``OptimizeRule.replace_variable``.

    Args:
        graph: graph passed through to ``OptimizeRule.replace_variable``.
        op: operator to duplicate; it is detached with ``remove_all()`` first.
        v: the variable of ``op`` that has been split.
        v_pair: the two parts of ``v``.
        axis: axis along which ``v`` was split.
    """
    head_size = v_pair[0].shape_dict[axis]
    inputs = dict(op.inputs)
    outputs = dict(op.outputs)
    op.remove_all()

    op_0 = op.copy()
    op_1 = op.copy()

    for name, var in inputs.items():
        if var == v:
            part_0, part_1 = v_pair
        elif axis not in var.order.axes or var.shape_dict[axis] == 1:
            # broadcasting
            part_0 = part_1 = var
        else:
            part_0, part_1 = SplitAxis(None, axis=axis, sections=[head_size])(var)

        op_0.append_input(name, part_0)
        op_1.append_input(name, part_1)

    op_0.exec()
    op_1.exec()

    for name, var in outputs.items():
        if var == v:
            for half_op, half_v in zip((op_0, op_1), (v_pair[0], v_pair[1])):
                OptimizeRule.replace_variable(graph, half_op.outputs[name].transpose_like(half_v), half_v)
        else:
            out_0 = op_0.outputs[name]
            out_1 = op_1.outputs[name]
            merged, = Concat(None, axis=axis)(out_0, out_1)
            OptimizeRule.replace_variable(graph, merged.transpose_like(var), var)
def test_minor_axis():
    """Generate a kernel test for SplitAxis on ``Axis.C`` of an ``OrderNHWC`` input.

    The input is four random ``(2, 3, 4, 5)`` arrays concatenated along numpy
    axis 3. The case is generated for the webgpu, webassembly and fallback
    backends only.
    """
    parts = [np.random.rand(2, 3, 4, 5) for _ in range(4)]
    vy = np.concatenate(parts, 3)

    y = Variable(vy.shape, order=OrderNHWC)
    xs = SplitAxis(None, axis=Axis.C, sections=[5, 10, 15])(y)
    x1, x2, x3, x4 = xs

    generate_kernel_test_case(
        description=f"SplitAxis in minor axis",
        backend=["webgpu", "webassembly", "fallback"],
        graph=Graph([y], [x1, x2, x3, x4]),
        inputs={y: vy},
        expected={x1: parts[0], x2: parts[1], x3: parts[2], x4: parts[3]},
    )
def _convert_split_axis(converter: ChainerConverter, c_op: "chainer.functions.SplitAxis"):
    """Convert a Chainer ``SplitAxis`` function into a ``SplitAxis`` operator.

    The split indices are read from ``c_op.indices`` when the major version of
    Chainer is 4 or later, and from ``c_op.indices_or_sections`` otherwise.

    Raises:
        NotImplementedError: if no explicit indices are available
            (``indices`` is None, or ``indices_or_sections`` is an int).
    """
    x = converter.get_variable(c_op.inputs[0])

    major, _minor, _patch = semver(chainer.__version__)
    if major >= 4:
        # Internal data structure changed
        # https://github.com/chainer/chainer/commit/906a8e9b0837cd9a4e5ee6f1dbda26431a1e12d1#diff-9e610d281c820d44c4a0cbf0ca6263fd
        indices = c_op.indices
        unsupported = indices is None
    else:
        indices = c_op.indices_or_sections
        unsupported = isinstance(indices, int)

    if unsupported:
        raise NotImplementedError("[ChainerConverter] SplitAxis with sections are not supported.")

    outputs = SplitAxis(None, sections=indices, axis=x.order.axes[c_op.axis])(x)
    for index, out_var in enumerate(outputs):
        converter.set_variable(c_op.outputs[index](), out_var)
def _split_tensordot(graph: Graph, op: Tensordot, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Rewrite a Tensordot whose input ``v`` has been split along ``axis``.

    If ``axis`` is one of the reduced (K) axes, the other operand is split
    with the same ratio and the two partial products are added. Otherwise the
    two partial products are concatenated along ``axis``.

    Args:
        graph: graph passed through to ``OptimizeRule.replace_variable``.
        op: Tensordot operator; it is detached with ``remove_all()``.
        v: the split variable (``A``, ``B`` or ``C`` of ``op``).
        v_pair: the two parts of ``v``.
        axis: axis along which ``v`` was split.

    Raises:
        NotImplementedError: if ``v`` is the output ``C``.
        UnexpectedAndPleaseReportError: if ``v`` is not a variable of ``op``.
    """
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]
    axes_M = tuple(filter(lambda a: a not in op.axes[0], A.order.axes))
    axes_N = tuple(filter(lambda a: a not in op.axes[1], B.order.axes))
    axes_K_A, axes_K_B = op.axes

    K = mul(A.shape_dict[a] for a in axes_K_A)
    M = A.size // K
    N = B.size // K

    shape_M = [A.shape_dict[a] for a in axes_M]
    shape_N = [B.shape_dict[a] for a in axes_N]

    op.remove_all()

    if v == A:
        A1, A2 = v_pair

        if axis in axes_K_A:
            split_axis_A = axis

            # B's first K axis can be split directly only if the s1:(s1+s2)
            # ratio yields an integer boundary on it.
            if (B.shape_dict[axes_K_B[0]] * s1) % (s1 + s2) == 0:
                split_axis_B = axes_K_B[0]

            else:
                # Factorize B's axes consisting to K into A's corresponding axes
                B = B.transpose(Order(axes_N + axes_K_B))
                B = B.reshape(order=Order((Axis(), ) + axes_K_A), shape=[N] + [A.shape_dict[a] for a in axes_K_A])
                split_axis_B = split_axis_A
                axes_K_B = axes_K_A

            B1, B2 = SplitAxis(None, axis=split_axis_B,
                               sections=[(B.shape_dict[split_axis_B] * s1) // (s1 + s2)])(B)

            C1, = Tensordot(None, [axes_K_A, axes_K_B])(A1, B1)
            C2, = Tensordot(None, [axes_K_A, axes_K_B])(A2, B2)
            OptimizeRule.replace_variable(
                graph,
                (C1 + C2).reshape(shape_M + shape_N, Order(axes_M + axes_N)).transpose_like(C),
                C)

        else:
            C1, = Tensordot(None, op.axes)(A1, B)
            C2, = Tensordot(None, op.axes)(A2, B)

            # Unify every pair of corresponding output axes except `axis`.
            for a1, a2 in zip(C1.order.axes, C2.order.axes):
                if a1 == a2 == axis:
                    continue
                a1.unify(a2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == B:
        B1, B2 = v_pair

        if axis in axes_K_B:
            split_axis_B = axis

            # Same integrality test as in the `v == A` branch. It has to agree
            # with the section computed below, (size * s1) // (s1 + s2);
            # the previous test, (size * (s1 + s2)) % s1, did not.
            if (A.shape_dict[axes_K_A[0]] * s1) % (s1 + s2) == 0:
                split_axis_A = axes_K_A[0]

            else:
                # Factorize A's axes consisting to K into B's corresponding axes
                A = A.transpose(Order(axes_M + axes_K_A))
                A = A.reshape(order=Order((Axis(), ) + axes_K_B), shape=[M] + [B.shape_dict[a] for a in axes_K_B])
                split_axis_A = split_axis_B
                axes_K_A = axes_K_B

            A1, A2 = SplitAxis(None, axis=split_axis_A,
                               sections=[(A.shape_dict[split_axis_A] * s1) // (s1 + s2)])(A)

            C1, = Tensordot(None, [axes_K_A, axes_K_B])(A1, B1)
            C2, = Tensordot(None, [axes_K_A, axes_K_B])(A2, B2)
            OptimizeRule.replace_variable(
                graph,
                (C1 + C2).reshape(shape_M + shape_N, Order(axes_M + axes_N)).transpose_like(C),
                C)

        else:
            C1, = Tensordot(None, op.axes)(A, B1)
            C2, = Tensordot(None, op.axes)(A, B2)

            # Unify every pair of corresponding output axes except `axis`.
            for a1, a2 in zip(C1.order.axes, C2.order.axes):
                if a1 == a2 == axis:
                    continue
                a1.unify(a2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == C:
        # before)
        #     C[M, N] = A[M, K] @ B[K, N]
        #
        # after) In case `axis` is in `N`,
        #     C[M, N1] = Concat(A[M, K] @ B1[K, N1])
        #     C[M, N2] = Concat(A[M, K] @ B2[K, N2])
        #
        # Splitting the output is not implemented.
        raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError
def _split_reshape(graph: Graph, op: Reshape, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Rewrite a Reshape whose input or output ``v`` has been split along ``axis``.

    Args:
        graph: graph passed through to ``OptimizeRule.replace_variable``.
        op: Reshape operator; it is detached with ``remove_all()``.
        v: the split variable (``x`` or ``y`` of ``op``).
        v_pair: the two parts of ``v``.
        axis: axis along which ``v`` was split.

    Raises:
        NotImplementedError: if the accumulated size overshoots before a
            matching boundary is found.
        UnexpectedAndPleaseReportError: if ``v`` is neither ``x`` nor ``y``.

    NOTE(review): if the search loop ends without finding a match, nothing is
    replaced and no error is raised; an empty ``axis_k_candidate`` list raises
    IndexError. Both behaviors are unchanged here.
    """
    x = op.inputs["x"]
    y = op.outputs["y"]
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    op.remove_all()
    in_order = op.in_order
    out_order = op.out_order
    x_shape = [x.shape_dict[a] for a in in_order.axes]
    y_shape = [y.shape_dict[a] for a in out_order.axes]

    if v == x:
        # before)
        #     x -{reshape}- y
        #
        # after)
        #     x_0 -{reshape}- t_0 -+
        #                          +-{concat[axis_k]}- t -{reshape}- y
        #     x_1 -{reshape}- t_1 -+
        #
        # shape and order is changed as follows:
        #
        #     x.shape   = [dx_0, dx_1, ..., dx_m,   ..., dx_M-1]
        #     x_0.shape = [dx_0, dx_1, ..., dx_m/2, ..., dx_M-1]
        #     ------------------------------------------------------------
        #     t_0.shape = [dy_0, dy_1, ..., dy_n,   ..., dy_k/2, ..., dy_N-1]
        #     t.shape   = [dy_0, dy_1, ..., dy_n*2, ..., dy_k/2, ..., dy_N-1]
        #     y.shape   = [dy_0, dy_1, ..., dy_n,   ..., dy_k,   ..., dy_N-1]
        #
        # m: split target axis
        #
        # find axis_k and axis_n, which satisfies follow conditions
        #
        #     dy_n * dy_n+1 * ... * dy_N-1 == dx_m * dx_m+1 * ... * dx_M-1
        #     dy_k % 2 == 0
        #     n <= k
        x_0, x_1 = v_pair
        dx_prod = mul(x_shape[in_order.axes_dict[axis]:])
        dy_prod = 1
        axis_k_candidate = []
        for axis_n in reversed(out_order.axes):
            dy_prod *= y.shape_dict[axis_n]
            if y.shape_dict[axis_n] % 2 == 0:
                axis_k_candidate.append(axis_n)

            if dx_prod == dy_prod:
                # Split most large axis
                axis_k = (sorted(axis_k_candidate, key=lambda a: y.shape_dict[a], reverse=True))[0]

                t_0_shape = [y.shape_dict[a] for a in out_order.axes]
                t_0_shape[out_order.axes_dict[axis_k]] = t_0_shape[out_order.axes_dict[axis_k]] // 2  # TODO
                t_0, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=t_0_shape)(x_0)

                t_1_shape = [y.shape_dict[a] for a in out_order.axes]
                t_1_shape[out_order.axes_dict[axis_k]] = t_1_shape[out_order.axes_dict[axis_k]] // 2  # TODO
                t_1, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=t_1_shape)(x_1)

                t, = Concat(None, axis=axis_n)(t_0, t_1)
                y_new, = Reshape(None, in_order=out_order, out_order=out_order, out_shape=y_shape)(t)

                OptimizeRule.replace_variable(graph, y_new.transpose_like(y), y)
                break

            elif dy_prod > (s1 + s2) * dx_prod:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    elif v == y:
        # algorithm is almost same as the case `v == x` (above).
        #
        # before)
        #     x -{reshape}- y
        #
        # after)
        #                             +- t_0 -{reshape}- y_0
        #     x -{reshape}- t-{split}-+
        #                             +- t_1 -{reshape}- y_1
        #
        # shape and order is changed as follows:
        #
        #     x.shape   = [dx_0, dx_1, ..., dx_m,   ..., dx_k,   ..., dx_M-1]
        #     t.shape   = [dx_0, dx_1, ..., dx_m*2, ..., dx_k/2, ..., dx_M-1]
        #     t_0.shape = [dx_0, dx_1, ..., dx_m,   ..., dx_k/2, ..., dx_M-1]
        #     ------------------------------------------------------------
        #     y_0.shape = [dy_0, dy_1, ..., dy_n/2, ..., dy_N-1]
        #     y.shape   = [dy_0, dy_1, ..., dy_n,   ..., dy_N-1]
        #
        # n: split target axis
        #
        # find axis_k and axis_m, which satisfies follow conditions
        #
        #     dx_m * dx_m+1 * ... * dx_M-1 == dy_n * dy_n+1 * ... * dy_N-1
        #     dx_k % 2 == 0
        #     m <= k
        y_0, y_1 = v_pair
        dx_prod = 1
        # `axis` belongs to `y` here, so its position in `out_order` must
        # slice the output shape (previously it sliced `x_shape`).
        dy_prod = mul(y_shape[out_order.axes_dict[axis]:])
        axis_k_candidate = []
        for axis_m in reversed(in_order.axes):
            dx_prod *= x.shape_dict[axis_m]
            if x.shape_dict[axis_m] % 2 == 0:
                axis_k_candidate.append(axis_m)

            if dx_prod == dy_prod:
                # Split most large axis
                axis_k = (sorted(axis_k_candidate, key=lambda a: x.shape_dict[a], reverse=True))[0]

                t_shape = [x.shape_dict[a] for a in in_order.axes]
                t_shape[in_order.axes_dict[axis_m]] = 2 * t_shape[in_order.axes_dict[axis_m]]  # TODO
                t_shape[in_order.axes_dict[axis_k]] = t_shape[in_order.axes_dict[axis_k]] // 2  # TODO
                t, = Reshape(None, in_order=in_order, out_order=in_order, out_shape=t_shape)(x)

                t_0, t_1 = SplitAxis(None, axis=axis_m, sections=[t.shape_dict[axis_m] // 2])(t)  # TODO

                y_0_new, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=y_0.shape)(t_0)
                y_1_new, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=y_1.shape)(t_1)

                OptimizeRule.replace_variable(graph, y_0_new.reshape_like(y_0), y_0)
                OptimizeRule.replace_variable(graph, y_1_new.reshape_like(y_1), y_1)
                break

            elif dx_prod > dy_prod:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError
def _split_splitaxis(graph: Graph, op: SplitAxis, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Rewrite a SplitAxis whose input or one of its outputs has been split.

    Args:
        graph: graph passed through to ``OptimizeRule.replace_variable``.
        op: SplitAxis operator; it is detached with ``remove_all()``.
        v: the split variable (the input ``x`` or one of the outputs).
        v_pair: the two parts of ``v``.
        axis: axis along which ``v`` was split.

    Raises:
        UnexpectedAndPleaseReportError: if ``v`` is not a variable of ``op``,
            or if one side of the split ends up with no outputs.
    """
    s1 = v_pair[0].shape_dict[axis]
    x = op.inputs["x"]
    ys = [op.outputs[f"y{i}"] for i in range(len(op.outputs))]
    sections = op.parameters["sections"]
    op.remove_all()

    def rebuild(x_part, ys_part):
        # Produce the outputs `ys_part` from `x_part` alone.
        if len(ys_part) > 1:
            # Cumulative sizes along `axis`, without the leading 0 and the
            # trailing total, are the split positions inside `x_part`.
            bounds = [0]
            for y_part in ys_part:
                bounds.append(bounds[-1] + y_part.shape_dict[axis])
            bounds.pop(0)
            bounds.pop()

            for y_new, y_part in zip(SplitAxis(None, axis=axis, sections=bounds)(x_part), ys_part):
                y_new.change_order(y_part.order)
                OptimizeRule.replace_variable(graph, y_new, y_part)

        elif len(ys_part) == 1:
            OptimizeRule.replace_variable(graph, ys_part[0], x_part)

        else:
            raise UnexpectedAndPleaseReportError

    if v == x:
        x_0, x_1 = v_pair
        if axis == op.axis:
            # before)
            #                           +- y1
            #                           |
            #     x -{split[axis=axis]}-+- y2
            #                           |
            #                           +- y3
            #
            # after)
            #                             +- y1
            #     x_0 -{split[axis=axis]}-+
            #                             +- y2_0 -+
            #                                      +-{concat[axis=axis]}- y2
            #                             +- y2_1 -+
            #     x_1 -{split[axis=axis]}-+
            #                             +- y3

            # find output variable which should be split ("y2" in above figure)
            total_size = 0
            ys_0 = []  # type: List[Variable]
            ys_1 = list(ys)  # type: List[Variable]
            for y in ys:
                ys_1.remove(y)

                if total_size + y.shape_dict[axis] == s1:
                    # splitting is not needed
                    #
                    #       x_0        |        x_1
                    # <--------------> | <--------------->
                    # y0, y1, ..., yn, | yn+1, ..., ys[-1]
                    ys_0.append(y)
                    break

                elif total_size + y.shape_dict[axis] > s1:
                    # this `y` must be split
                    #
                    #         x_0         |         x_1
                    # <-----------------> | <----------------->
                    # y0, y1, ..., yn_0,  | yn_1, ..., ys[-1]
                    #              <----------->
                    #                   yn
                    yn_0 = Variable([s1 - total_size if a == axis else y.shape_dict[a]
                                     for a in y.order.axes], y.order)
                    yn_1 = Variable([y.shape_dict[axis] - (s1 - total_size) if a == axis else y.shape_dict[a]
                                     for a in y.order.axes], y.order)
                    OptimizeRule.replace_variable(
                        graph, Concat(None, axis=axis)(yn_0, yn_1)[0].change_order(y.order), y)
                    ys_0.append(yn_0)
                    ys_1.insert(0, yn_1)
                    break

                else:
                    ys_0.append(y)
                    total_size += y.shape_dict[axis]

            rebuild(x_0, ys_0)
            rebuild(x_1, ys_1)

        else:
            # before)
            #                              +- y1
            #                              |
            #     x -{split[axis=op.axis]}-+- y2
            #                              |
            #                              +- y3
            #
            # after)
            #                                  +--- y1_0 -+
            #                                  |          +-{concat[axis=axis]}- y1
            #                                  | +- y1_1 -+
            #                                  | |
            #     x_0 -{split[axis=op.axis]}---+-|- y2_0 -+
            #                                  | |        +-{concat[axis=axis]}- y2
            #     x_1 -{split[axis=op.axis]}-----+- y2_1 -+
            #                                  | |
            #                                  +-|- y3_0 -+
            #                                    |        +-{concat[axis=axis]}- y3
            #                                    +- y3_1 -+
            ys_0 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_0)
            ys_1 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_1)

            for y, y_0, y_1 in zip(ys, ys_0, ys_1):
                y_new, = Concat(None, axis=axis)(y_0, y_1)
                OptimizeRule.replace_variable(graph, y_new, y)

    elif v in ys:
        # NOTE(review): `op.remove_all()` was already called above; this
        # second call is kept as in the original.
        op.remove_all()

        if axis == op.axis:
            # before)
            #                +- y1
            #                |
            #     x -{split}-+- y2
            #                |
            #                +- y3
            #
            # after)
            #                +- y1
            #                |
            #                +- y2_0
            #     x -{split}-+
            #                +- y2_1
            #                |
            #                +- y3
            target_i = ys.index(v)

            # Absolute position inside `x` at which `v` is cut.
            s_insert = (0 if target_i == 0 else sections[target_i - 1]) + s1
            new_sections = list(sections)
            new_sections.insert(target_i, s_insert)

            new_ys = SplitAxis(None, axis=axis, sections=new_sections)(x)
            for i, new_y in enumerate(new_ys):
                if i == target_i:
                    # Drop `v` itself from the queue; its first part takes this slot.
                    ys.pop(0)
                    y = v_pair[0]
                    new_y.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y, new_y)

                elif i == target_i + 1:
                    y = v_pair[1]
                    new_y.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y, new_y)

                else:
                    y = ys.pop(0)
                    new_y.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y, new_y)

        else:
            # `x` is first cut along `axis` at `s1`; each part is then split
            # along `op.axis` with the original sections. The parts that
            # correspond to `v` are connected to `v_pair` directly, every
            # other output is rebuilt by concatenation along `axis`.
            #
            #                        +- x_0 -{split[op.axis]}-+--- y1_0 -+
            #                        |                        |          +-{concat[axis]}- y1
            #                        |                        | +- y1_1 -+
            #                        |                        | |
            #     x -{split[axis]}---+                        +-|------------------------- y2_0
            #                        |                        | |
            #                        |                        | +------------------------- y2_1
            #                        |                        | |
            #                        |                        +-|- y3_0 -+
            #                        |                          |        +-{concat[axis]}- y3
            #                        +- x_1 -{split[op.axis]}---+- y3_1 -+
            x_0, x_1 = SplitAxis(None, axis=axis, sections=[s1])(x)

            # Keep every output of each split: they are zipped against `ys`
            # one-to-one below (single-element unpacking was wrong here).
            ys_0 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_0)
            ys_1 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_1)

            for y, y_0, y_1 in zip(ys, ys_0, ys_1):
                if y == v:
                    OptimizeRule.replace_variable(graph, y_0, v_pair[0])
                    OptimizeRule.replace_variable(graph, y_1, v_pair[1])

                else:
                    y_new, = Concat(None, axis=axis)(y_0, y_1)
                    OptimizeRule.replace_variable(graph, y_new, y)

    else:
        raise UnexpectedAndPleaseReportError
def _split_concat(graph: Graph, op: Concat, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Rewrite a Concat whose output or one of its inputs has been split.

    Args:
        graph: graph passed through to ``OptimizeRule.replace_variable``.
        op: Concat operator; it is detached with ``remove_all()``.
        v: the split variable (one of the inputs, or the output ``y``).
        v_pair: the two parts of ``v``.
        axis: axis along which ``v`` was split.

    Raises:
        UnexpectedAndPleaseReportError: if ``v`` is not a variable of ``op``,
            or if one side of the split ends up with no inputs.
    """
    s1 = v_pair[0].shape_dict[axis]

    # Order the "x<i>" inputs by (length, text) so that "x10" follows "x9"
    # rather than "x1"; a plain lexicographic sort misorders 11+ inputs.
    input_keys = sorted((key for key in op.inputs.keys() if key.startswith("x")),
                        key=lambda key: (len(key), key))
    xs = [op.inputs[key] for key in input_keys]
    y = op.outputs["y"]
    op.remove_all()

    if v in xs:
        x_0, x_1 = v_pair
        if axis == op.axis:
            # before)
            #     x1 -+
            #         |
            #     x2 -+-{concat}- y
            #         |
            #     x3 -+
            #
            # after)
            #     x1 ---+
            #           |
            #     x2_0 -+
            #           +-{concat}- y
            #     x2_1 -+
            #           |
            #     x3 ---+
            i = xs.index(v)
            xs.pop(i)
            xs.insert(i + 0, x_0)
            xs.insert(i + 1, x_1)

            y_new, = Concat(None, axis=axis)(*xs)
            OptimizeRule.replace_variable(graph, y, y_new)

        else:
            # before)
            #     x1 -+
            #         |
            #     x2 -+-{concat[op.axis]}- y
            #         |
            #     x3 -+
            #
            # after)
            #                       +- x1_0 -+
            #     x1 -{split[axis]}-+        |
            #                       +- x1_1 -|-+
            #                                | |
            #     x2_0 ----------------------+---{concat[op.axis]}- y_0 -+
            #                                | |                         +-{concat[axis]}- y
            #     x2_1 ----------------------|-+-{concat[op.axis]}- y_1 -+
            #                                | |
            #                       +- x3_0 -+ |
            #     x3 -{split[axis]}-+          |
            #                       +- x3_1 ---+
            xs_0, xs_1 = zip(*[v_pair if x == v else SplitAxis(None, axis=axis, sections=[s1])(x) for x in xs])
            y_0, = Concat(None, axis=op.axis)(*xs_0)
            y_1, = Concat(None, axis=op.axis)(*xs_1)
            y_new, = Concat(None, axis=axis)(y_0, y_1)
            OptimizeRule.replace_variable(graph, y_new, y)

    elif v == y:
        y_0, y_1 = v_pair
        if axis == op.axis:
            # before)
            #     x1 -+
            #         |
            #     x2 -+-{concat[axis=op.axis]}- y
            #         |
            #     x3 -+
            #
            # after)
            #     x1 ------------------------------+
            #                                      +-{concat[axis=axis]}- y_0
            #                            +- y_0_1 -+
            #     x2 -{split[axis=axis]}-+
            #                            +- y_1_0 -+
            #                                      +-{concat[axis=axis]}- y_1
            #     x3 ------------------------------+

            # find input variable which should be split
            total_size = 0
            xs_0 = []  # type: List[Variable]
            xs_1 = list(xs)  # type: List[Variable]
            for x in xs:
                xs_1.remove(x)
                xs_0.append(x)
                total_size += x.shape_dict[axis]

                if total_size == s1:
                    # splitting is not needed
                    #
                    # x0, x1, ..., xn, | xn+1, ..., xs[-1]
                    # <--------------> | <--------------->
                    #       y_0        |        y_1
                    break

                elif total_size > s1:
                    # this `x` must be split
                    #
                    # x0, x1, ..., xn, ..., xs[-1]
                    # <-------------><------------->
                    #       y_0            y_1
                    xn_0, xn_1 = SplitAxis(None, axis=axis,
                                           sections=[s1 - (total_size - x.shape_dict[axis])])(x)
                    xs_0.remove(x)
                    xs_0.append(xn_0)
                    xs_1.insert(0, xn_1)
                    break

            if len(xs_0) > 1:
                y_0, = Concat(None, axis=axis)(*xs_0)
                y_0.change_order(v_pair[0].order)

            elif len(xs_0) == 1:
                y_0 = xs_0[0]

            else:
                raise UnexpectedAndPleaseReportError

            if len(xs_1) > 1:
                y_1, = Concat(None, axis=axis)(*xs_1)
                y_1.change_order(v_pair[1].order)

            elif len(xs_1) == 1:
                y_1 = xs_1[0]

            else:
                raise UnexpectedAndPleaseReportError

            OptimizeRule.replace_variable(graph, y_0, v_pair[0])
            OptimizeRule.replace_variable(graph, y_1, v_pair[1])

        else:
            # before)
            #     x1 -+
            #         |
            #     x2 -+-{concat[op.axis]}- y
            #         |
            #     x3 -+
            #
            # after)
            #                       +- x1_0 -+
            #     x1 -{split[axis]}-+        |
            #                       +- x1_1 ---+
            #                                | |
            #                       +- x2_0 -+-|-{concat[op.axis]}- y_0
            #     x2 -{split[axis]}-+        | |
            #                       +- x2_1 ---+-{concat[op.axis]}- y_1
            #                                | |
            #                       +- x3_0 -+ |
            #     x3 -{split[axis]}-+          |
            #                       +- x3_1 ---+
            xs_0, xs_1 = zip(*[SplitAxis(None, axis=axis, sections=[s1])(x) for x in xs])
            y_new_0, = Concat(None, axis=op.axis)(*xs_0)
            y_new_1, = Concat(None, axis=op.axis)(*xs_1)
            OptimizeRule.replace_variable(graph, y_new_0, y_0)
            OptimizeRule.replace_variable(graph, y_new_1, y_1)

    else:
        raise UnexpectedAndPleaseReportError
def _split_reshape(graph: Graph, op: Reshape, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Rewrite a Reshape whose input or output ``v`` has been split along ``axis``.

    The trailing axes of the other variable are accumulated (innermost first)
    until their size product equals the product of ``v``'s sizes from ``axis``
    to the end; the axis at which they match is the one that is cut.

    Args:
        graph: graph passed through to ``OptimizeRule.replace_variable``.
        op: Reshape operator; it is detached with ``remove_all()``.
        v: the split variable (``x`` or ``y`` of ``op``).
        v_pair: the two parts of ``v``.
        axis: axis along which ``v`` was split.

    Raises:
        NotImplementedError: if the size guard trips before a match is found.
        UnexpectedAndPleaseReportError: if ``v`` is neither ``x`` nor ``y``.
    """
    x = op.inputs["x"]
    y = op.outputs["y"]
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    total = s1 + s2
    op.remove_all()

    if v == x:
        # Regard x's order as `[D1, D2]`, shape as `[d1x, d2x]`, where the most
        # outside axis in D2 is the split target axis. If y's shape can be
        # converted as `[d1x, d2x]` by merging some adjacent axes in y, split
        # can be performed.
        #
        # before)
        #     x -{reshape}- y
        #
        # after)
        #     x_0 -{reshape}- y_0 -+
        #                          +-{concat[axis]}- y
        #     x_1 -{reshape}- y_1 -+
        x_0, x_1 = v_pair
        d2x = mul(x.shape[x.order.axes_dict[axis]:])
        d2y = 1
        for axis_y in reversed(y.order.axes):
            d2y *= y.shape_dict[axis_y]

            if d2y == d2x:
                y_axes = y.order.axes
                y_0_shape = [y.shape_dict[axis_y] * s1 // total if a == axis_y else y.shape_dict[a]
                             for a in y_axes]
                y_1_shape = [y.shape_dict[axis_y] * s2 // total if a == axis_y else y.shape_dict[a]
                             for a in y_axes]
                y_0 = x_0.reshape(y_0_shape, y.order)
                y_1 = x_1.reshape(y_1_shape, y.order)
                y_new, = Concat(None, axis=axis_y)(y_0, y_1)
                OptimizeRule.replace_variable(graph, y_new, y)
                break

            if d2y > total * d2x:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    elif v == y:
        # Same algorithm as the case `v == x` (above), with roles swapped.
        #
        # before)
        #     x -{reshape}- y
        #
        # after)
        #                +- x_0 -{reshape}- y_0
        #     x -{split}-+
        #                +- x_1 -{reshape}- y_1
        y_0, y_1 = v_pair
        d2y = mul(y.shape[y.order.axes_dict[axis]:])
        d2x = 1
        for axis_x in reversed(x.order.axes):
            d2x *= x.shape_dict[axis_x]

            if d2x == d2y:
                cut = x.shape_dict[axis_x] * s1 // total
                x_0, x_1 = SplitAxis(None, axis=axis_x, sections=[cut])(x)
                OptimizeRule.replace_variable(graph, x_0.reshape_like(y_0), y_0)
                OptimizeRule.replace_variable(graph, x_1.reshape_like(y_1), y_1)
                break

            # NOTE(review): unlike the branch above, this guard compares the
            # fixed `d2y` against the growing `d2x`, so it can trip on an
            # early iteration — confirm whether that is intended.
            if d2y > total * d2x:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError
def _split_sgemm(graph: Graph, op: Sgemm, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Rewrite an Sgemm whose input ``v`` has been split along ``axis``.

    If ``axis`` belongs to the reduced dimension K, the other operand is
    reshaped and split as well, and the two partial products are added.
    Otherwise the two partial products are concatenated along ``axis`` and
    reshaped back to ``C``'s order and shape.

    Args:
        graph: graph passed through to ``OptimizeRule.replace_variable``.
        op: Sgemm operator; it is detached with ``remove_all()``.
        v: the split variable (``A``, ``B`` or ``C`` of ``op``).
        v_pair: the two parts of ``v``.
        axis: axis along which ``v`` was split.

    Raises:
        NotImplementedError: if ``v`` is the output ``C``.
        ValueError: if the real axes of ``v`` cannot be grouped into the two
            logical dimensions.
        UnexpectedAndPleaseReportError: if ``v`` is not a variable of ``op``.
    """
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]

    transpose_A, transpose_B = op.transpose_A, op.transpose_B
    M, K, N = op.M, op.K, op.N
    axis_M, axis_K, axis_N = Axis(None), Axis(None), Axis(None)

    op.remove_all()

    def decompose_logical_axes(logical_shape: Tuple[int, int], v: Variable):
        """
        Decompose logical axes into real axes

        Examples::

            A.order, A.shape
            >>> "NCHW", (1, 128, 8, 8)

            M = 128
            K = 64

            decompose_logical_axes([M, K], A)
            >>> ["<Axis N>", "<Axis C>"], ["<Axis H>", "<Axis W>"]
        """
        total_size = 1
        axes1 = []  # type: List[Axis]
        axes2 = list(v.order.axes)  # type: List[Axis]
        for size, a in zip(v.shape, v.order.axes):
            if total_size == logical_shape[0]:
                return axes1, axes2

            elif total_size > logical_shape[0]:
                raise ValueError

            axes1.append(a)
            axes2.remove(a)
            total_size *= size

        # All axes consumed. The loop only tests the size *before* adding an
        # axis, so the final product still needs a check; previously the
        # function fell through and returned None.
        if total_size == logical_shape[0]:
            return axes1, axes2

        raise ValueError

    if v == A:
        A1, A2 = v_pair

        if transpose_A:
            # A.shape = [M, K]
            axes_M, axes_K = decompose_logical_axes((M, K), A)

        else:
            # A.shape = [K, M]
            axes_K, axes_M = decompose_logical_axes((K, M), A)

        if axis in axes_K:
            # before)
            #     A -{sgemm}- C
            #
            # after) In case `axis` is in `K`,
            #     A_0 -{sgemm}- C_0 -+
            #                        +-{Add}- C
            #     A_1 -{sgemm}- C_1 -+
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize B's axes included in K into A's corresponding axes
            if transpose_B:
                # B: [k_b1, k_b2, ..., N] -{reshape}-> [k_a1, k_a2, ..., N]
                B, = Reshape(None, in_order=B.order, out_order=Order(axes_K + [axis_N]),
                             out_shape=[A.shape_dict[a] for a in axes_K] + [N])(B)

            else:
                # B: [N, k_b1, k_b2, ...] -{reshape}-> [N, k_a1, k_a2, ...]
                B, = Reshape(None, in_order=B.order, out_order=Order([axis_N] + axes_K),
                             out_shape=[N] + [A.shape_dict[a] for a in axes_K])(B)

            B1, B2 = SplitAxis(None, axis=axis, sections=[s1])(B)

            C1, = Sgemm(None, M=M, K=K1, N=N, transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"], out_order=op.parameters["out_order"])(A1, B1)
            C2, = Sgemm(None, M=M, K=K2, N=N, transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"], out_order=op.parameters["out_order"])(A2, B2)

            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_M
            # before)
            #     A -{sgemm}- C
            #
            # after) In case `axis` is in `M`,
            #     A_0 -{sgemm}- C_0 -+
            #                        +-{Concat}- C
            #     A_1 -{sgemm}- C_1 -+
            M1, M2 = M * s1 // (s1 + s2), M * s2 // (s1 + s2)

            c_tmp_order = Order(axes_M + [axis_N])
            c1_shape = [A1.shape_dict[a] for a in axes_M] + [N]
            c2_shape = [A2.shape_dict[a] for a in axes_M] + [N]

            C1, = Sgemm(None, M=M1, K=K, N=N, transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c1_shape, out_order=c_tmp_order)(A1, B)
            C2, = Sgemm(None, M=M2, K=K, N=N, transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c2_shape, out_order=c_tmp_order)(A2, B)

            C_new, = Concat(None, axis=axis)(C1, C2)
            C_new, = Reshape(None, in_order=c_tmp_order, out_order=C.order, out_shape=C.shape)(C_new)

            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == B:
        B1, B2 = v_pair

        if transpose_B:
            # B.shape = [K, N]
            axes_K, axes_N = decompose_logical_axes((K, N), B)

        else:
            # B.shape = [N, K]
            axes_N, axes_K = decompose_logical_axes((N, K), B)

        if axis in axes_K:
            # before)
            #     B -{sgemm}- C
            #
            # after) In case `axis` is in `K`,
            #     B_0 -{sgemm}- C_0 -+
            #                        +-{Add}- C
            #     B_1 -{sgemm}- C_1 -+
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize A's axes included in K into B's corresponding axes
            if transpose_A:
                # A: [M, k_a1, k_a2, k_a3, ...] -{reshape}-> [M, k_b1, k_b2, ...]
                A, = Reshape(None, in_order=A.order, out_order=Order([axis_M] + axes_K),
                             out_shape=[M] + [B.shape_dict[a] for a in axes_K])(A)

            else:
                # A: [k_a1, k_a2, k_a3, ..., M] -{reshape}-> [k_b1, k_b2, ..., M]
                A, = Reshape(None, in_order=A.order, out_order=Order(axes_K + [axis_M]),
                             out_shape=[B.shape_dict[a] for a in axes_K] + [M])(A)

            A1, A2 = SplitAxis(None, axis=axis, sections=[s1])(A)

            C1, = Sgemm(None, M=M, K=K1, N=N, transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"], out_order=op.parameters["out_order"])(A1, B1)
            C2, = Sgemm(None, M=M, K=K2, N=N, transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"], out_order=op.parameters["out_order"])(A2, B2)

            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_N
            # before)
            #     C[M, N] = A[M, K] @ B[K, N]
            #
            # after) In case `axis` is in `N`,
            #     C[M, N] = Concat(C1[M, N1], C2[M, N2])
            #             = Concat(A[M, K] @ B1[K, N1], A[M, K] @ B2[K, N2])
            N1, N2 = N * s1 // (s1 + s2), N * s2 // (s1 + s2)

            c_tmp_order = Order([axis_M] + axes_N)
            c1_shape = [M] + [B1.shape_dict[a] for a in axes_N]
            c2_shape = [M] + [B2.shape_dict[a] for a in axes_N]

            C1, = Sgemm(None, M=M, K=K, N=N1, transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c1_shape, out_order=c_tmp_order)(A, B1)
            # C1.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ..., B1.shape_dict[axis], ...]
            # C1.order = [axis_M, n1, n2, ..., axis, ...]

            C2, = Sgemm(None, M=M, K=K, N=N2, transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c2_shape, out_order=c_tmp_order)(A, B2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            # C_new.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ...,
            #                B1.shape_dict[axis]+B2.shape_dict[axis], ...]
            # C_new.order = [axis_M, n1, n2, ..., axis, ...]

            C_new, = Reshape(None, in_order=c_tmp_order, out_order=C.order, out_shape=C.shape)(C_new)

            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == C:
        # before)
        #     C[M, N] = A[M, K] @ B[K, N]
        #
        # after) In case `axis` is in `N`,
        #     C[M, N1] = Concat(A[M, K] @ B1[K, N1])
        #     C[M, N2] = Concat(A[M, K] @ B2[K, N2])
        #
        # Splitting the output is not implemented.
        raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError