def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Fuse chains of two ScalarAffine operators into a single ScalarAffine.

    ``y1 = s1 * x + b1`` followed by ``y2 = s2 * y1 + b2`` is rewritten as
    ``y2 = (s1 * s2) * x + (b1 * s2 + b2)``: the first operator absorbs the
    second one's coefficients and takes over its output variable.

    Args:
        graph: graph to optimize (modified in place).

    Returns:
        ``(graph, flag_changed)`` where ``flag_changed`` is True if at least
        one pair was fused.
    """
    if not (flags.optimize.OPTIMIZE and flags.optimize.CONCAT_SCALAR_AFFINE):
        return graph, False

    flag_changed = False
    matches = search_sub_structure(graph, [ScalarAffine, Variable, ScalarAffine])
    while len(matches) > 0:
        a1, h, a2 = matches.pop()  # type: ScalarAffine, Variable, ScalarAffine

        if len(h.input_to) > 1:
            # y1 is detached from a1 below; other readers of the intermediate
            # variable would lose their producer, so leave this pair alone.
            continue

        y1 = a1.outputs["y"]
        y2 = a2.outputs["y"]

        a1.scale = a1.scale * a2.scale
        a1.bias = a1.bias * a2.scale + a2.bias

        a2.remove_all()
        a1.replace_output(y1, y2)
        flag_changed = True

        # The graph changed, so the remaining matches may be stale.
        matches = search_sub_structure(graph, [ScalarAffine, Variable, ScalarAffine])

    return graph, flag_changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Rewrite directly connected Elementwise -> Elementwise chains.

    For each ``op1 -> h -> op2`` chain where ``h`` is consumed only by ``op2``
    and ``op2`` has a single input, ``op1`` is replaced with a new generic
    ``Elementwise`` operator and ``op1``'s output ``h`` is redirected to
    ``op2``'s output ``y``.

    NOTE(review): ``op2`` is never removed, and the new ``Elementwise(None)``
    carries nothing about what ``op1``/``op2`` computed. This rule looks
    unfinished; confirm the intended semantics before relying on it.

    Args:
        graph: graph to optimize (modified in place).

    Returns:
        ``(graph, flag_changed)``.
    """
    flag_changed = False
    matches = search_sub_structure(graph, [Elementwise, Variable, Elementwise])
    while len(matches) > 0:
        op1, h, op2 = matches.pop()  # type: Elementwise, Variable, Elementwise
        y = op2.outputs["y"]

        if h.input_to != {op2}:
            continue

        if len(op2.inputs) > 1:
            continue

        op = Elementwise(None)
        op1.replace(op)
        op1.replace_output(h, y)
        # The graph was modified; report it so callers can re-run dependent rules.
        flag_changed = True

        matches = search_sub_structure(graph, [Elementwise, Variable, Elementwise])

    return graph, flag_changed
def optimize(self, graph: Graph):
    """Rewrite channel-mode converters applied to a variable already in the target mode.

    Rewrite 1 (ConvertRtoRGBA whose input and output are both RGBA)::

        before) v0[RGBA] -{ConvertRtoRGBA}- v1[RGBA]
        after)  v0[RGBA] -{ConvertRGBAtoR}- v2[Order=v0.order][R]
                         -{Transpose}- v3[Order=v1.order][R]
                         -{ConvertRtoRGBA}- v1[RGBA]

    Rewrite 2 (ConvertRGBAtoR whose input and output are both R)::

        before) v0[R] -{ConvertRGBAtoR}- v1[R]
        after)  v0[R] -{Transpose}- v1[R]

    Returns:
        ``(graph, flag_changed)``.
    """
    changed = False

    # --- Rewrite 1 ---
    candidates = traverse.search_sub_structure(graph, [Variable, ConvertRtoRGBA, Variable])
    while len(candidates) > 0:
        src, converter, dst = candidates.pop()  # type: Variable, ConvertRtoRGBA, Variable

        mode_src = ChannelMode.get(src)
        mode_dst = ChannelMode.get(dst)
        if not (mode_src == mode_dst and mode_dst == ChannelModeEnum.RGBA):
            continue

        changed = True
        converter.remove_all()

        as_r = convert_rgba_to_r(src)
        as_r.change_order(src.order)
        transposed = as_r.transpose(dst.order)
        rebuilt = convert_r_to_rgba(transposed)
        rebuilt.change_order(dst.order)
        OptimizeRule.replace_variable(graph, rebuilt, dst)

    # --- Rewrite 2 ---
    candidates = traverse.search_sub_structure(graph, [Variable, ConvertRGBAtoR, Variable])
    while len(candidates) > 0:
        src, converter, dst = candidates.pop()  # type: Variable, ConvertRGBAtoR, Variable

        mode_src = ChannelMode.get(src)
        mode_dst = ChannelMode.get(dst)
        if not (mode_src == mode_dst and mode_dst == ChannelModeEnum.R):
            continue

        changed = True
        converter.remove_all()
        OptimizeRule.replace_variable(graph, src.transpose(dst.order), dst)

    return graph, changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Apply ``self.optimize_pair`` to every ``pattern[0] -> Variable -> pattern[1]`` chain.

    When ``optimize_pair`` reports a rewrite, the search restarts because
    the graph has changed.

    Returns:
        ``(graph, flag_changed)``.
    """
    def find_pairs():
        # Re-reads self.pattern on every search, like the inline form did.
        return search_sub_structure(graph, [self.pattern[0], Variable, self.pattern[1]])

    changed = False
    pairs = find_pairs()
    while pairs:
        head, _, tail = pairs.pop()
        if not self.optimize_pair(head, tail):
            continue
        changed = True
        pairs = find_pairs()
    return graph, changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Apply ``self.optimize_pair`` to unshared ``pattern[0] -> Variable -> pattern[1]`` chains.

    Pairs whose intermediate variable has more than one consumer are
    skipped. When ``optimize_pair`` reports a rewrite, the search restarts
    because the graph has changed.

    Returns:
        ``(graph, flag_changed)``.
    """
    def find_pairs():
        return search_sub_structure(graph, [self.pattern[0], Variable, self.pattern[1]])

    changed = False
    pairs = find_pairs()
    while pairs:
        head, link, tail = pairs.pop()  # type: Operator, Variable, Operator
        if len(link.input_to) > 1:
            # The intermediate variable is read by other operators as well.
            continue
        if self.optimize_pair(graph, head, tail):
            changed = True
            pairs = find_pairs()
    return graph, changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Inject InlineInplace operators into the PostInlineInplace operator that feeds them.

    The rewrite applies to every ``op1 -> h -> op2`` chain where ``op1`` carries a
    PostInlineInplace attribute and ``op2`` carries an InlineInplace attribute
    whose in-place input is ``h`` (read by nobody else). For such a chain:

    - ``op2`` is removed from the graph;
    - ``op1`` is made to produce ``op2``'s in-place output directly;
    - ``op2``'s InlineInplace attribute is registered on ``op1`` via
      ``register_injected``.

    Args:
        graph: graph to optimize (modified in place).

    Returns:
        ``(graph, flag_changed)``.
    """
    if not (flags.optimize.OPTIMIZE and flags.optimize.INJECT_INLINE_INPLACE):
        return graph, False

    matches = traverse.search_sub_structure(graph, [PostInlineInplace, Variable, InlineInplace])
    flag_changed = False

    # Matches are collected once up front; an operator removed while handling
    # an earlier match makes any later match mentioning it stale.
    removed = set()
    for op1, h, op2 in matches:
        if op1 in removed or op2 in removed:
            continue

        post_inline_inplace = op1.get_attribute(PostInlineInplace)[0]  # type: PostInlineInplace
        inline_inplace = op2.get_attribute(InlineInplace)[0]  # type: InlineInplace

        x = inline_inplace.get_input()
        y = inline_inplace.get_output()

        if x is not h:
            # op1 feeds an input of op2 other than the in-place one, so op1
            # does not produce x and cannot take over op2's output.
            continue

        if len(x.input_to) > 1:
            continue

        op2.remove_all()
        op1.replace_output(x, y)
        post_inline_inplace.register_injected(inline_inplace)
        removed.add(op2)
        flag_changed = True

    # NOTE(review): if op2 had itself received injected operators earlier
    # (it is also a PostInlineInplace), removing it may drop them. Confirm
    # whether such chains can occur.
    return graph, flag_changed
def test_search_sub_structure_full():
    """A five-element op/var/op/var/op pattern matches the whole chain exactly once."""
    global graph, op1, op2, op3, v1, v2

    pattern = [Operator, Variable, Operator, Variable, Operator]
    found = search_sub_structure(graph, pattern)

    assert len(found) == 1

    expected_chain = (op1, v1, op2, v2, op3)
    assert tuple(found[0]) == expected_chain
def optimize(self, graph):
    """Merge two chained SplitAxis operators that split along the same axis.

    ``x -{SplitAxis}- [.., h, ..]`` followed by ``h -{SplitAxis}- [h0, h1, ..]``
    is rewritten as a single SplitAxis of ``x``. Its outputs are op1's outputs
    with ``h`` replaced by op2's outputs, and op2's section boundaries are
    shifted by the offset at which ``h`` starts inside ``x``.

    Args:
        graph: graph to optimize (modified in place).

    Returns:
        ``(graph, flag_changed)``.
    """
    flag_changed = False
    matches = traverse.search_sub_structure(graph, [SplitAxis, Variable, SplitAxis])
    while len(matches) > 0:
        op1, h, op2 = matches.pop()  # type: SplitAxis, Variable, SplitAxis

        if len(h.input_to) > 1:
            # `h` will be removed by this optimization
            continue

        if op1.axis != op2.axis:
            # These operations cannot be merged.
            continue

        flag_changed = True

        x = op1.inputs["x"]
        axis = op1.axis
        hs = [op1.outputs[f"y{i}"] for i in range(len(op1.outputs))]
        i_h = hs.index(h)

        # Copy the section lists: `op1.sections` must not be mutated in place
        # through an alias while the merged section list is built.
        new_sections = list(op1.sections)
        section_offset = ([0] + new_sections)[i_h]
        op2_sections = [0] + list(op2.sections)

        original_ys = list(hs)
        original_ys.remove(h)
        for i in range(len(op2.outputs)):
            original_ys.insert(i_h + i, op2.outputs[f"y{i}"])
            new_sections.insert(i_h + i, section_offset + op2_sections[i])
        # The i == 0 insertion re-added `section_offset`. Drop one copy: the
        # duplicated boundary, or the leading 0 when h was op1's first output.
        new_sections.remove(section_offset)

        op1.remove_all()
        op2.remove_all()

        new_ys = SplitAxis(None, axis=axis, sections=new_sections)(x)
        for original_y, new_y in zip(original_ys, new_ys):
            new_y.change_order(original_y.order)
            new_y.replace(original_y)

        matches = traverse.search_sub_structure(graph, [SplitAxis, Variable, SplitAxis])

    return graph, flag_changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Fold chained Elementwise operators using the scalar/elementwise helpers.

    For each ``op1 -> Variable -> op2`` pair of Elementwise operators, the
    helpers are tried in this order:

    1. ``_optimizeScalarAdd``
    2. ``_optimizeScalarMul``
    3. ``_optimizeElementwiseAdd``
    4. ``_optimizeElementwiseMul``

    The first helper that returns a truthy value wins, and the search is
    restarted.

    Returns:
        ``(graph, flag_changed)``.
    """
    def find_pairs():
        return search_sub_structure(graph, [Elementwise, Variable, Elementwise])

    changed = False
    pairs = find_pairs()
    while pairs:
        first, _, second = pairs.pop()  # type: Operator, Variable, Operator
        folded = (_optimizeScalarAdd(first, second)
                  or _optimizeScalarMul(first, second)
                  or _optimizeElementwiseAdd(first, second)
                  or _optimizeElementwiseMul(first, second))
        if folded:
            changed = True
            pairs = find_pairs()
    return graph, changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Merge ZeroPadding2D into a following Convolution2D / MaxPooling2D / AveragePooling2D.

    Three steps are applied to each padding -> layer chain:

    - the padding of the ZeroPadding2D layer is added to the consuming
      layer's ``padding`` parameter;
    - the consuming layer is rewired to read the unpadded input directly;
    - the ZeroPadding2D layer is removed once nothing reads its output any more.

    Args:
        graph: graph to optimize (modified in place).

    Returns:
        ``(graph, flag_changed)``.
    """
    # this optimization is always applied (since backends do not implement padding)
    # NOTE(review): for MaxPooling2D/AveragePooling2D, folding explicit zero
    # padding into the layer's own padding is only equivalent if that layer
    # pads with zeros and counts padded cells. Confirm against the kernels.
    flag_changed = False
    for tail_layer in [Convolution2D, MaxPooling2D, AveragePooling2D]:
        matches = search_sub_structure(graph, [ZeroPadding2D, Variable, tail_layer])
        while len(matches) > 0:
            match = matches[0]
            a1: ZeroPadding2D = match[0]
            a2: Union[Convolution2D, MaxPooling2D, AveragePooling2D] = match[2]

            zero_pad = a1.parameters["padding"]
            conv_pad = a2.parameters["padding"]
            a2.parameters["padding"] = (zero_pad[0] + conv_pad[0], zero_pad[1] + conv_pad[1])

            x1 = a1.inputs["x"]
            x2 = a2.inputs["x"]

            # replace_input checks if the shape of x1 and x2 are same, but this restriction does not hold.
            a2.remove_input(x2)
            a2.append_input("x", x1)

            # The padded variable may still feed other operators. Only drop the
            # padding layer once nothing reads its output, so those consumers
            # are not left with a producer-less variable.
            if len(x2.input_to) == 0:
                a1.remove_all()

            flag_changed = True
            # The a1 -> x2 -> a2 edge no longer exists, so each iteration makes progress.
            matches = search_sub_structure(graph, [ZeroPadding2D, Variable, tail_layer])

    return graph, flag_changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Fold a constant ElementwiseMul that follows a Tensordot into the constant operand.

    ``h = tensordot(w1, x)`` (or ``tensordot(x, w1)``) followed by
    ``y = h * w2``, with constant ``w1`` and ``w2``, is rewritten so that
    ``w1`` is replaced by ``w1 * w2`` and ``h`` takes the place of ``y``.
    This is only valid when every axis of ``w2`` is an axis of ``w1`` that
    the Tensordot does not reduce.

    Args:
        graph: graph to optimize (modified in place).

    Returns:
        ``(graph, flag_changed)``.
    """
    flag_changed = False
    matches = traverse.search_sub_structure(graph, [Tensordot, Variable, ElementwiseMul, Variable])
    while len(matches) > 0:
        tensordot, h, elementwise_mul, y = matches.pop()  # type: Tensordot, Variable, ElementwiseMul, Variable

        if len(h.input_to) != 1:
            # h will be removed by this optimization rule
            continue

        if isinstance(tensordot.inputs["A"], ConstantVariable):
            w1 = tensordot.inputs["A"]
            reduced_axes = tensordot.axes[0]
        elif isinstance(tensordot.inputs["B"], ConstantVariable):
            w1 = tensordot.inputs["B"]
            reduced_axes = tensordot.axes[1]
        else:
            continue

        if len(w1.input_to) != 1:
            # w1 is replaced by a scaled copy below; a weight shared with other
            # operators must keep its unscaled values.
            continue

        x0 = elementwise_mul.inputs["x0"]
        x1 = elementwise_mul.inputs["x1"]
        if isinstance(x0, ConstantVariable) and x1 == h:
            w2 = x0
        elif isinstance(x1, ConstantVariable) and x0 == h:
            w2 = x1
        else:
            continue

        if any(axis not in w1.order.axes for axis in w2.order.axes):
            continue

        if any(axis in reduced_axes for axis in w2.order.axes):
            continue

        flag_changed = True
        elementwise_mul.remove_all()
        OptimizeRule.replace_variable(graph, w1, w1.copy() * w2, with_assert=False)
        OptimizeRule.replace_variable(graph, h, y, with_assert=False)

    return graph, flag_changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Attach LSTMOptimized data to every LSTM that does not have it yet.

    For each such LSTM:

    - ``w_input`` and ``w_hidden`` are replaced by one ``w_all`` input
      (stacked along Axis.C, order OrderCN). This is done eagerly with
      ``np.vstack`` when both are constants, otherwise with a Concat operator.
    - New ``x_and_h`` ([C1 + C2, N], OrderCN) and ``workspace``
      ([N, 4 * C2], OrderNC) inputs are appended.

    Returns:
        ``(graph, flag_changed)``.
    """
    changed = False
    for found in traverse.search_sub_structure(graph, [LSTM]):
        lstm = found[0]  # type: LSTM
        if lstm.has_attribute(LSTMOptimized):
            continue

        x = lstm.inputs["x"]
        w_in = lstm.inputs["w_input"]
        w_hid = lstm.inputs["w_hidden"]

        if isinstance(w_in, ConstantVariable) and isinstance(w_hid, ConstantVariable):
            # Both weights are constants: stack them into a single constant now.
            w_in.change_order(OrderCN)
            w_hid.change_order(OrderCN)
            stacked = np.vstack([w_in.data, w_hid.data])
            w_all = ConstantVariable(stacked, OrderCN)
        else:
            # At least one weight is not constant: concatenate inside the graph.
            w_all, = Concat(None, axis=Axis.C)(w_in, w_hid)  # type: Variable
            w_all.change_order(OrderCN)

        attr = LSTMOptimized(lstm)
        batch = x.shape_dict[Axis.N]
        c1 = attr.C1
        c2 = attr.C2

        x_and_h = Variable([c1 + c2, batch], OrderCN)
        workspace = Variable([batch, 4 * c2], OrderNC)

        lstm.remove_input(w_in)
        lstm.remove_input(w_hid)
        for slot, var in (("x_and_h", x_and_h), ("workspace", workspace), ("w_all", w_all)):
            lstm.append_input(slot, var)
        lstm.attributes.add(attr)

        changed = True
    return graph, changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Fold a constant AxiswiseBias along Axis.C into the preceding Sgemm's ``b`` input.

    The bias is merged when the Sgemm output ``C`` feeds only the
    AxiswiseBias and the bias ``b`` is a ConstantVariable. It is added onto
    an existing Sgemm ``b`` input, or appended as ``b`` when there is none.
    The AxiswiseBias is then removed and the Sgemm produces its output directly.

    Returns:
        ``(graph, flag_changed)``.
    """
    if not (flags.optimize.OPTIMIZE and flags.optimize.CONCAT_SGEMM_BIAS):
        return graph, False

    changed = False
    for found in traverse.search_sub_structure(graph, [Sgemm, Variable, AxiswiseBias]):
        sgemm = found[0]  # type: Sgemm
        bias_op = found[2]  # type: AxiswiseBias

        if bias_op.parameters["axis"] != Axis.C:
            continue

        h = sgemm.outputs["C"]
        b = bias_op.inputs["b"]
        y = bias_op.outputs["y"]

        if not isinstance(b, ConstantVariable) or len(h.input_to) != 1:
            continue

        if "b" in sgemm.inputs:
            # Sgemm already has a bias: accumulate into it in place.
            sgemm.inputs["b"].data += b.data
        else:
            sgemm.append_input("b", b)

        bias_op.remove_all()
        sgemm.replace_output(h, y)
        changed = True

    return graph, changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Fold a constant ElementwiseMul that follows an Sgemm into the Sgemm's constant weight.

    ``h = Sgemm(A, B)`` followed by ``y = h * w2`` is rewritten when one
    Sgemm operand ``w1`` and ``w2`` are both ConstantVariables and every axis
    of ``w2`` lies on ``w1``'s side of the product. ``w1.data`` is multiplied
    by ``w2`` in place, and the ElementwiseMul is replaced by a Transpose of ``h``.

    To relate ``w1``'s 2-D layout to the named output axes, a "virtual"
    order/shape for ``w1`` is built. Leading (operand A) or trailing
    (operand B) output axes are taken until their product reaches M (or N),
    and the reduction axis ``AxisK`` is added. If the product overshoots,
    the output axes mix ``w1`` and ``x`` and the pair is skipped.

    Args:
        graph: graph to optimize (modified in place).

    Returns:
        ``(graph, flag_changed)``.
    """
    flag_changed = False
    matches = traverse.search_sub_structure(graph, [Sgemm, Variable, ElementwiseMul])
    while len(matches) > 0:
        match = matches.pop()
        sgemm = match[0]  # type: Sgemm
        elementwise_mul = match[2]  # type: ElementwiseMul
        h = sgemm.outputs["C"]  # type: Variable

        if len(h.input_to) != 1:
            # Scaling w1 changes h itself; any other reader of h would see
            # the scaled values.
            continue

        out_order = sgemm.parameters["out_order"]
        out_shape = sgemm.parameters["out_shape"]

        axis_k = Axis('AxisK')

        if not isinstance(sgemm.inputs["A"], ConstantVariable) and not isinstance(sgemm.inputs["B"], ConstantVariable):
            # neither x nor w1 is constant
            continue

        elif isinstance(sgemm.inputs["A"], ConstantVariable):
            w1 = sgemm.inputs["A"]  # type: ConstantVariable

            if sgemm.transpose_A:
                # w1.shape = (M, K)

                shape = []
                axes = []
                for axis, size in zip(out_order.axes, out_shape):
                    shape.append(size)
                    axes.append(axis)

                    if mul(shape) >= sgemm.M:
                        break

                if mul(shape) != sgemm.M:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes + [axis_k])
                w1_virtual_shape = shape + [sgemm.K]

            else:
                # w1.shape = (K, M)

                shape = [sgemm.K]
                axes = [axis_k]
                for axis, size in zip(out_order.axes, out_shape):
                    shape.append(size)
                    axes.append(axis)

                    if mul(shape) >= w1.size:
                        break

                if mul(shape) != w1.size:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes)
                w1_virtual_shape = shape

        else:
            w1 = sgemm.inputs["B"]  # type: ConstantVariable

            if sgemm.transpose_B:
                # w1.shape = (K, N)

                shape = []
                axes = []
                for axis, size in reversed(list(zip(out_order.axes, out_shape))):
                    shape.insert(0, size)
                    axes.insert(0, axis)

                    if mul(shape) >= sgemm.N:
                        break

                if mul(shape) != sgemm.N:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order([axis_k] + axes)
                w1_virtual_shape = [sgemm.K] + shape

            else:
                # w1.shape = (N, K)
                shape = [sgemm.K]
                axes = [axis_k]
                for axis, size in reversed(list(zip(out_order.axes, out_shape))):
                    shape.insert(0, size)
                    axes.insert(0, axis)

                    if mul(shape) >= w1.size:
                        break

                if mul(shape) != w1.size:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes)
                w1_virtual_shape = shape

        if len(w1.input_to) != 1:
            # w1.data is overwritten below; a weight shared with other
            # operators must keep its original values.
            continue

        x0 = elementwise_mul.inputs["x0"]
        x1 = elementwise_mul.inputs["x1"]
        if h == x1:
            if not isinstance(x0, ConstantVariable):
                # w2 is not constant
                continue

            w2 = x0  # type: ConstantVariable

        else:
            if not isinstance(x1, ConstantVariable):
                # w2 is not constant
                continue

            w2 = x1  # type: ConstantVariable

        y = elementwise_mul.outputs["y"]  # type: Variable

        if not all(axis in w1_virtual_order.axes for axis in w2.order.axes):
            # w2's axes are derived from both w1 and x
            continue

        elementwise_mul.remove_all()
        y_dummy, = Transpose(None)(h)
        y_dummy.change_order(y.order)
        y_dummy.replace(y)

        w2.change_order(w1_virtual_order)
        w_new = ConstantVariable(w1.data.reshape(w1_virtual_shape), w1_virtual_order) * w2  # type: ConstantVariable
        w1.data = w_new.data.reshape(w1.shape)

        flag_changed = True
        matches = traverse.search_sub_structure(graph, [Sgemm, Variable, ElementwiseMul])

    return graph, flag_changed
def optimize(self, graph: Graph):
    """Cancel back-to-back RGBA<->R channel-mode conversions.

    Handles ``x -{ConvertRGBAtoR}- h -{ConvertRtoRGBA}- y`` and the mirrored
    ``x -{ConvertRtoRGBA}- h -{ConvertRGBAtoR}- y``::

        before)                     +-{second}- y
                x -{first}- h ------+
                                    +- (other readers of h)

        after)                      +-{first}- h - (other readers of h)
                x ------------------+
                                    +-{Transpose}- y   (only if orders differ)

    The second converter is always removed, and ``y`` is replaced by ``x``
    (transposed to ``y.order`` when the orders differ). The first converter
    is removed as well once ``h`` has no readers left.

    Returns:
        ``(graph, flag_changed)``.
    """
    def cancel_pairs(first_type, second_type):
        # Rewrite every x -{first}- h -{second}- y chain; report whether any matched.
        found_any = False
        chains = traverse.search_sub_structure(
            graph, [Variable, first_type, Variable, second_type, Variable])
        while len(chains) > 0:
            x, first_op, h, second_op, y = chains.pop()
            found_any = True

            second_op.remove_all()
            replacement = x if x.order == y.order else x.transpose(y.order)
            OptimizeRule.replace_variable(graph, replacement, y)

            if len(h.input_to) == 0:
                first_op.remove_all()

            chains = traverse.search_sub_structure(
                graph, [Variable, first_type, Variable, second_type, Variable])
        return found_any

    changed_rgba_r_rgba = cancel_pairs(ConvertRGBAtoR, ConvertRtoRGBA)
    changed_r_rgba_r = cancel_pairs(ConvertRtoRGBA, ConvertRGBAtoR)
    return graph, changed_rgba_r_rgba or changed_r_rgba_r
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Fuse chains of two scalar operations (ScalarAffine / ScalarAdd / ScalarMul).

    Each ``op1 -> h -> op2`` chain of ScalarOperation operators whose
    intermediate variable ``h`` is read only by ``op2`` is collapsed into a
    single operator (Affine(s, b) means ``s * x + b``):

    ========== ========== ==========================
    op1        op2        result
    ========== ========== ==========================
    Affine     Affine     Affine(s1*s2, b1*s2 + b2)
    Affine     Add(a)     Affine(s1, b1 + a)
    Affine     Mul(m)     Affine(s1*m, b1*m)
    Add(a)     Affine     Affine(s2, b2 + a*s2)
    Add(a1)    Add(a2)    Add(a1 + a2)
    Add(a)     Mul(m)     Affine(m, a*m)
    Mul(m)     Affine     Affine(s2*m, b2)
    Mul(m)     Add(a)     Affine(m, a)
    Mul(m1)    Mul(m2)    Mul(m1*m2)
    ========== ========== ==========================

    Other combinations are logged at debug level and left untouched.

    Args:
        graph: graph to optimize (modified in place).

    Returns:
        ``(graph, flag_changed)``.
    """
    if not (flags.optimize.OPTIMIZE and flags.optimize.CONCAT_SCALAR_OPERATION):
        return graph, False

    flag_changed = False
    matches = search_sub_structure(graph, [ScalarOperation, Variable, ScalarOperation])
    while len(matches) > 0:
        op1, h, op2 = matches.pop()  # type: Operator, Variable, Operator

        if len(h.input_to) > 1:
            # h is bypassed or detached from op1 below; other readers of h
            # would be left with a stale or producer-less variable.
            continue

        y1 = op1.outputs["y"]
        y2 = op2.outputs["y"]
        merged = True

        if isinstance(op1, ScalarAffine):
            if isinstance(op2, ScalarAffine):
                op1.scale = op1.scale * op2.scale
                op1.bias = op1.bias * op2.scale + op2.bias
                op2.remove_all()
                op1.replace_output(y1, y2)

            elif isinstance(op2, ScalarAdd):
                op1.bias += op2.value
                op2.remove_all()
                op1.replace_output(y1, y2)

            elif isinstance(op2, ScalarMul):
                op1.scale *= op2.value
                op1.bias *= op2.value
                op2.remove_all()
                op1.replace_output(y1, y2)

            else:
                merged = False

        elif isinstance(op1, ScalarAdd):
            if isinstance(op2, ScalarAffine):
                op2.bias += op1.value * op2.scale
                x = op1.inputs["x0"]
                op1.remove_all()
                x.replace(y1)

            elif isinstance(op2, ScalarAdd):
                op1.parameters["value"] += op2.value
                op2.remove_all()
                op1.replace_output(y1, y2)

            elif isinstance(op2, ScalarMul):
                x = op1.inputs["x0"]
                new_op = ScalarAffine(None, scale=op2.value, bias=op1.value * op2.value)
                new_y, = new_op(x)
                op1.remove_all()
                op2.remove_all()
                y2.replace(new_y)

            else:
                merged = False

        elif isinstance(op1, ScalarMul):
            if isinstance(op2, ScalarAffine):
                op2.scale *= op1.value
                x = op1.inputs["x0"]
                op1.remove_all()
                x.replace(y1)

            elif isinstance(op2, ScalarAdd):
                x = op1.inputs["x0"]
                new_op = ScalarAffine(None, scale=op1.value, bias=op2.value)
                new_y, = new_op(x)
                op1.remove_all()
                op2.remove_all()
                y2.replace(new_y)

            elif isinstance(op2, ScalarMul):
                op1.parameters["value"] *= op2.value
                op2.remove_all()
                op1.replace_output(y1, y2)

            else:
                merged = False

        else:
            merged = False

        if not merged:
            # Leave the pair untouched. It was popped, so it is not revisited
            # in this pass (no infinite loop).
            console.debug(f"[ConcatScalarOperation] unhandled pair: {type(op1)} and {type(op2)}")
            continue

        flag_changed = True
        # Re-search with the same ScalarOperation pattern; the graph changed.
        matches = search_sub_structure(graph, [ScalarOperation, Variable, ScalarOperation])

    return graph, flag_changed