def test_mix_order():
    """Concat along H of four inputs, three of which use non-NHWC layouts.

    Every input starts as an NHWC array of shape (2, 3, 4, 5). x2, x3 and x4 are
    switched to CNHW, CHWN and NCHW (with their data transposed to match) before
    the Concat, and the output is switched back to NHWC for comparison.
    """
    shape = (2, 3, 4, 5)
    vx1, vx2, vx3, vx4 = [np.random.rand(*shape) for _ in range(4)]

    # Reference output, computed while every array is still NHWC (axis 1 is H).
    vy = np.concatenate((vx1, vx2, vx3, vx4), 1)

    x1, x2, x3, x4 = [Variable(arr.shape, order=OrderNHWC) for arr in (vx1, vx2, vx3, vx4)]

    # Re-layout x2..x4 and permute the NHWC data into each new axis order.
    x2.change_order(OrderCNHW)
    vx2 = vx2.transpose(3, 0, 1, 2)  # NHWC -> CNHW
    x3.change_order(OrderCHWN)
    vx3 = vx3.transpose(3, 1, 2, 0)  # NHWC -> CHWN
    x4.change_order(OrderNCHW)
    vx4 = vx4.transpose(0, 3, 1, 2)  # NHWC -> NCHW

    y, = Concat(None, axis=Axis.H)(x1, x2, x3, x4)
    y.change_order(OrderNHWC)

    generate_kernel_test_case(
        description="concat_mix_order",
        graph=Graph([x1, x2, x3, x4], [y]),
        inputs={x1: vx1, x2: vx2, x3: vx3, x4: vx4},
        expected={y: vy},
    )
def optimize_pair(self, graph: Graph, op1: Concat, op2: ElementwiseMul):
    """Rewrite ``Concat(x0, ..., xn) * c`` as ``Concat(x0 * c0, ..., xn * cn)``.

    Applies only when one operand of ``op2`` is a constant (as reported by
    ``_get_constant_and_variable``) whose order is exactly the single concat axis
    and whose length equals the concatenated size. The constant is sliced along
    that axis into one chunk per concat input.

    NOTE(review): this assumes ``op2`` multiplies ``op1``'s output (a
    producer/consumer pair supplied by the caller); that is not checked here.

    Args:
        graph: graph being optimized.
        op1: the Concat operator.
        op2: the ElementwiseMul operator consuming the concat result.

    Returns:
        True if the graph was rewritten, False if the pattern did not match.
    """
    c, _ = _get_constant_and_variable(op2, "x0", "x1")
    if c is None:
        return False

    axis = op1.axis
    if c.order != Order([axis]):
        return False

    # Concat inputs are keyed "x0", "x1", ...; sort by (length, name) so that
    # "x10" follows "x9". All inputs are kept: dropping any of them would
    # disconnect it from the graph once op1 is removed.
    keys = sorted((key for key in op1.inputs.keys() if key.startswith("x")),
                  key=lambda k: (len(k), k))
    xs = [op1.inputs[key] for key in keys]
    sizes = [x.shape_dict[axis] for x in xs]

    # A constant that does not span the whole concat output (e.g. a broadcast
    # scalar) cannot be sliced per input.
    if c.shape_dict[axis] != sum(sizes):
        return False

    y2 = op2.outputs["y"]

    # Consecutive slices of the constant, one per concat input.
    cs = []
    offset = 0
    for size in sizes:
        cs.append(ConstantVariable(c.data[offset:offset + size], c.order))
        offset += size

    op1.remove_all()
    op2.remove_all()

    y, = Concat(None, axis=axis)(*[x * c_i for x, c_i in zip(xs, cs)])
    OptimizeRule.replace_variable(graph, y2, y.change_order(y2.order))
    return True
def _split_concat(graph: Graph, op: Concat, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Rewrite the graph around ``op`` after ``v`` has been split into ``v_pair`` along ``axis``.

    ``v`` must be either one of ``op``'s inputs or its output ``y``. ``v_pair[0]`` is
    the leading part of ``v`` along ``axis`` (size ``s1 = v_pair[0].shape_dict[axis]``)
    and ``v_pair[1]`` is the remainder. ``op`` is removed and replaced by equivalent
    Concat / SplitAxis operators that consume (or produce) ``v_pair``.

    Raises:
        UnexpectedAndPleaseReportError: if ``v`` is neither an input nor the output of
            ``op``, or if (splitting ``y`` along ``op.axis``) one side of the split would
            receive no inputs.
    """
    s1 = v_pair[0].shape_dict[axis]

    # Inputs are keyed "x0", "x1", ...; sort by (length, name) so that "x10" follows
    # "x9". A plain lexicographic sort would put "x10" before "x2" and reorder operands.
    keys = sorted((key for key in op.inputs.keys() if key.startswith("x")),
                  key=lambda k: (len(k), k))
    xs = [op.inputs[key] for key in keys]
    y = op.outputs["y"]
    concat_axis = op.axis
    op.remove_all()

    if v in xs:
        x_0, x_1 = v_pair

        if axis == concat_axis:
            # Split along the concat axis: substitute (x_0, x_1) for v in the input list.
            #
            #   before)  x1, x2, x3       -{concat}- y
            #   after)   x1, x2_0, x2_1, x3 -{concat}- y
            i = xs.index(v)
            xs[i:i + 1] = [x_0, x_1]

            y_new, = Concat(None, axis=axis)(*xs)
            y_new.change_order(y.order)
            OptimizeRule.replace_variable(graph, y, y_new)

        else:
            # Split along another axis: split every other input along `axis` as well,
            # concat the leading parts and the trailing parts separately along the
            # concat axis, then join the two partial results along `axis`.
            #
            #   x1 -{split[axis]}- (x1_0, x1_1)   (likewise for x3)
            #   (x1_0, x2_0, x3_0) -{concat[op.axis]}- y_0
            #   (x1_1, x2_1, x3_1) -{concat[op.axis]}- y_1
            #   (y_0, y_1)         -{concat[axis]}-    y
            xs_0, xs_1 = zip(*[
                v_pair if x == v else SplitAxis(None, axis=axis, sections=[s1])(x)
                for x in xs
            ])
            y_0, = Concat(None, axis=concat_axis)(*xs_0)
            y_1, = Concat(None, axis=concat_axis)(*xs_1)
            y_new, = Concat(None, axis=axis)(y_0, y_1)
            y_new.change_order(y.order)
            OptimizeRule.replace_variable(graph, y_new, y)

    elif v == y:
        y_0, y_1 = v_pair

        if axis == concat_axis:
            # Split the output along the concat axis: partition the inputs at offset
            # s1, splitting the single input (if any) that straddles the boundary.
            #
            #   x1, x2_0 -{concat[axis]}- y_0
            #   x2_1, x3 -{concat[axis]}- y_1      where x2 -{split[axis]}- (x2_0, x2_1)
            #
            # If the loop never reaches s1, xs_1 stays empty and we raise below.
            xs_0, xs_1 = list(xs), []
            offset = 0
            for i, x in enumerate(xs):
                size = x.shape_dict[axis]
                if offset + size == s1:
                    # The boundary falls exactly between inputs: no split needed.
                    xs_0, xs_1 = xs[:i + 1], xs[i + 1:]
                    break
                if offset + size > s1:
                    # `x` straddles the boundary; its first (s1 - offset) elements go to y_0.
                    # Slicing by index (not list.remove) keeps the order correct even when
                    # the same variable is concatenated more than once.
                    xn_0, xn_1 = SplitAxis(None, axis=axis, sections=[s1 - offset])(x)
                    xs_0, xs_1 = xs[:i] + [xn_0], [xn_1] + xs[i + 1:]
                    break
                offset += size

            if len(xs_0) > 1:
                y_new_0, = Concat(None, axis=axis)(*xs_0)
                y_new_0.change_order(y_0.order)
            elif len(xs_0) == 1:
                y_new_0 = xs_0[0]
            else:
                raise UnexpectedAndPleaseReportError

            if len(xs_1) > 1:
                y_new_1, = Concat(None, axis=axis)(*xs_1)
                y_new_1.change_order(y_1.order)
            elif len(xs_1) == 1:
                y_new_1 = xs_1[0]
            else:
                raise UnexpectedAndPleaseReportError

            OptimizeRule.replace_variable(graph, y_new_0, y_0)
            OptimizeRule.replace_variable(graph, y_new_1, y_1)

        else:
            # Split the output along another axis: split every input along `axis`, then
            # concat the leading parts into y_0 and the trailing parts into y_1.
            xs_0, xs_1 = zip(*[SplitAxis(None, axis=axis, sections=[s1])(x) for x in xs])
            y_new_0, = Concat(None, axis=concat_axis)(*xs_0)
            y_new_1, = Concat(None, axis=concat_axis)(*xs_1)
            y_new_0.change_order(y_0.order)
            y_new_1.change_order(y_1.order)
            OptimizeRule.replace_variable(graph, y_new_0, y_0)
            OptimizeRule.replace_variable(graph, y_new_1, y_1)

    else:
        raise UnexpectedAndPleaseReportError
def _split_splitaxis(graph: Graph, op: SplitAxis, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Rewrite the graph around ``op`` after ``v`` has been split into ``v_pair`` along ``axis``.

    ``v`` must be either the input ``x`` of ``op`` or one of its outputs ``y0, y1, ...``.
    ``v_pair[0]`` is the leading part of ``v`` along ``axis`` (size
    ``s1 = v_pair[0].shape_dict[axis]``) and ``v_pair[1]`` is the remainder. ``op`` is
    removed and replaced by equivalent SplitAxis / Concat operators.

    ``op.parameters["sections"]`` is treated as the internal split offsets along
    ``op.axis`` (cumulative output sizes, excluding 0 and the total).

    Raises:
        UnexpectedAndPleaseReportError: if ``v`` is neither the input nor an output of
            ``op``, or if (splitting ``x`` along ``op.axis``) one side of the split
            would receive no outputs.
    """
    s1 = v_pair[0].shape_dict[axis]
    x = op.inputs["x"]
    ys = [op.outputs[f"y{i}"] for i in range(len(op.outputs))]
    sections = list(op.parameters["sections"])
    split_axis = op.axis
    op.remove_all()

    def _offsets(parts):
        # Internal split offsets for `parts` laid end to end along `axis`.
        offsets = []
        total = 0
        for part in parts[:-1]:
            total += part.shape_dict[axis]
            offsets.append(total)
        return offsets

    def _rebind_outputs(source, parts):
        # Make `parts` the outputs of a new SplitAxis over `source`, or alias `source`
        # directly when there is exactly one part.
        if len(parts) > 1:
            new_parts = SplitAxis(None, axis=axis, sections=_offsets(parts))(source)
            for y_new, y_old in zip(new_parts, parts):
                y_new.change_order(y_old.order)
                OptimizeRule.replace_variable(graph, y_new, y_old)
        elif len(parts) == 1:
            OptimizeRule.replace_variable(graph, parts[0], source)
        else:
            raise UnexpectedAndPleaseReportError

    if v == x:
        x_0, x_1 = v_pair

        if axis == split_axis:
            # The input is split along the same axis as `op`: outputs entirely inside
            # x_0 / x_1 are re-split from x_0 / x_1. An output that straddles the
            # boundary becomes concat(hn_0, hn_1), with hn_0 cut from x_0 and hn_1
            # from x_1.
            #
            # If the loop never reaches s1, ys_1 stays empty and _rebind_outputs raises.
            ys_0, ys_1 = list(ys), []
            offset = 0
            for i, y in enumerate(ys):
                size = y.shape_dict[axis]
                if offset + size == s1:
                    # The boundary falls exactly between outputs.
                    ys_0, ys_1 = ys[:i + 1], ys[i + 1:]
                    break
                if offset + size > s1:
                    # `y` straddles the boundary: its first (s1 - offset) elements come
                    # from x_0 and the remaining elements come from x_1.
                    size_0 = s1 - offset
                    size_1 = size - size_0
                    hn_0 = Variable([size_0 if a == axis else y.shape_dict[a] for a in y.order.axes], y.order)
                    hn_1 = Variable([size_1 if a == axis else y.shape_dict[a] for a in y.order.axes], y.order)
                    yn_new, = Concat(None, axis=axis)(hn_0, hn_1)
                    yn_new.change_order(y.order)
                    OptimizeRule.replace_variable(graph, yn_new, y)
                    ys_0, ys_1 = ys[:i] + [hn_0], [hn_1] + ys[i + 1:]
                    break
                offset += size

            _rebind_outputs(x_0, ys_0)
            _rebind_outputs(x_1, ys_1)

        else:
            # The input is split along another axis: apply the original split to both
            # halves, then concat corresponding pieces along `axis`.
            ys_0 = SplitAxis(None, axis=split_axis, sections=list(sections))(x_0)
            ys_1 = SplitAxis(None, axis=split_axis, sections=list(sections))(x_1)
            for y, y_0, y_1 in zip(ys, ys_0, ys_1):
                y_new, = Concat(None, axis=axis)(y_0, y_1)
                y_new.change_order(y.order)
                OptimizeRule.replace_variable(graph, y_new, y)

    elif v in ys:
        if axis == split_axis:
            # An output is split along the same axis: add one more split offset so
            # that v's slot becomes two outputs, v_pair[0] and v_pair[1].
            target_i = ys.index(v)
            s_insert = (0 if target_i == 0 else sections[target_i - 1]) + s1
            new_sections = list(sections)
            new_sections.insert(target_i, s_insert)
            new_ys = SplitAxis(None, axis=axis, sections=new_sections)(x)

            old_ys = ys[:target_i] + [v_pair[0], v_pair[1]] + ys[target_i + 1:]
            for new_y, old_y in zip(new_ys, old_ys):
                new_y.change_order(old_y.order)
                OptimizeRule.replace_variable(graph, old_y, new_y)

        else:
            # An output is split along another axis: split x along `axis` first, apply
            # the original split to both halves, then reassemble every output except v
            # with a Concat along `axis`. v's two halves become v_pair directly.
            x_0, x_1 = SplitAxis(None, axis=axis, sections=[s1])(x)
            # SplitAxis has several outputs, so these must not be unpacked with `ys_0, = ...`.
            ys_0 = SplitAxis(None, axis=split_axis, sections=list(sections))(x_0)
            ys_1 = SplitAxis(None, axis=split_axis, sections=list(sections))(x_1)
            for y, y_0, y_1 in zip(ys, ys_0, ys_1):
                if y == v:
                    y_0.change_order(v_pair[0].order)
                    y_1.change_order(v_pair[1].order)
                    OptimizeRule.replace_variable(graph, y_0, v_pair[0])
                    OptimizeRule.replace_variable(graph, y_1, v_pair[1])
                else:
                    y_new, = Concat(None, axis=axis)(y_0, y_1)
                    y_new.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y_new, y)

    else:
        raise UnexpectedAndPleaseReportError