def __call__(self, *xs: "variable.Variable"):
    y_axes = []
    y_shape_dict = AxisKeyDict()

    # Check variables in descending order of their number of dimensions.
    # Without this procedure, when x0.order=C and x1.order=NC, the output order
    # would become CN. The expected result is NC.
    xs_order = [(i, x) for i, x in enumerate(xs)]
    xs_order.sort(key=lambda d: d[1].ndim, reverse=True)

    for i, x in xs_order:
        for axis in x.order.axes:
            if axis in y_axes:
                if y_shape_dict[axis] == 1:
                    # broadcast
                    y_shape_dict[axis] = x.shape_dict[axis]
            else:
                y_axes.append(axis)
                y_shape_dict[axis] = x.shape_dict[axis]

            if Placeholder.check_resolved(x.shape_dict[axis]):
                if Placeholder.check_resolved(y_shape_dict[axis]):
                    assert y_shape_dict[axis] == x.shape_dict[axis] or x.shape_dict[axis] == 1, \
                        "All input variables of an elementwise operator must have the same shape: " \
                        f"y.shape_dict[{axis}]={y_shape_dict[axis]}, " \
                        f"x{i}.shape_dict[{axis}]={x.shape_dict[axis]}"
                else:
                    y_shape_dict[axis] = x.shape_dict[axis]

    y = variable.Variable([y_shape_dict[axis] for axis in y_axes], Order(y_axes))
    ChannelMode.set(y, ChannelModeEnum.R)

    for i, x in enumerate(xs):
        self.append_input(f"x{i}", x)
    self.append_output("y", y)
    return y,

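# The sketch below is a minimal, framework-free re-creation of the axis-merging
# rule above, using plain strings instead of Axis objects and a dict instead of
# AxisKeyDict; the helper name `_merge_axes_sketch` is hypothetical. It shows
# why inputs are visited in descending order of ndim: x0.order=C with
# x1.order=NC must yield NC, not CN.
def _merge_axes_sketch(inputs):
    y_axes, y_shape = [], {}
    # Visit inputs with more dimensions first.
    for order, shape in sorted(inputs, key=lambda d: len(d[0]), reverse=True):
        for axis, size in zip(order, shape):
            if axis in y_axes:
                if y_shape[axis] == 1:  # broadcast a size-1 axis
                    y_shape[axis] = size
            else:
                y_axes.append(axis)
                y_shape[axis] = size
    return y_axes, [y_shape[a] for a in y_axes]

assert _merge_axes_sketch([(["C"], [3]), (["N", "C"], [2, 3])]) == (["N", "C"], [2, 3])
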
def fold_constance(self):
    # Remove this conversion operator and let its constant input stand in for
    # the output: the constant adopts the output's RGBA channel mode and order.
    x = self.inputs["x0"]  # type: ConstantVariable
    y = self.outputs["y"]  # type: Variable

    self.remove_all()
    y.replace(x)
    ChannelMode.set(x, ChannelModeEnum.RGBA)
    x.change_order(y.order)

def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False

    for op in traverse.filter_nodes(traverse.listup_operators(graph), Convolution2D):  # type: Convolution2D
        x = op.inputs["x"]
        w = op.inputs["w"]  # type: ConstantVariable
        old_y = op.outputs["y"]

        flag_changed = True
        op.remove_all()

        assert x.order == OrderNHWC
        assert isinstance(w, ConstantVariable)
        assert old_y.order == OrderNHWC
        w.change_order(OrderNHWC)

        col, = Im2Col(None, ksize=op.ksize, stride=op.stride, padding=op.padding,
                      dilation_rate=op.dilation_rate)(x)
        col.change_order(OrderNHWC)
        ChannelMode.set(col, ChannelModeEnum.R)

        M = col.shape_dict[Axis.N] * col.shape_dict[Axis.H] * col.shape_dict[Axis.W]
        N = w.shape_dict[Axis.N]
        K = col.shape_dict[Axis.C]

        if K > (w.size // N):
            # Zero-pad the filter's reduction axis so it matches the im2col
            # channel count K.
            w2_data = np.hstack([w.data.reshape(N, w.size // N),
                                 np.zeros([N, K - w.size // N])])
        else:
            w2_data = w.data.reshape(N, w.size // N)

        w = ConstantVariable(w2_data, OrderNC)
        ChannelMode.set(w, ChannelModeEnum.R)

        sgemm = Sgemm(None, M=M, N=N, K=K,
                      out_shape=[col.shape_dict[Axis.N], col.shape_dict[Axis.H],
                                 col.shape_dict[Axis.W], w.shape_dict[Axis.N]],
                      out_order=OrderNHWC,
                      transpose_A=True, transpose_B=False)
        new_y, = sgemm(col, w)

        sgemm.replace_output(new_y, old_y)

    return graph, flag_changed

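# A hypothetical numpy sketch of the GEMM view built above: im2col flattens the
# input to an (M, K) matrix, the filter is reshaped to (N, w.size // N) and
# zero-padded along the reduction axis when the im2col channel count K exceeds
# it. All concrete sizes below are made up for illustration.
import numpy as np

n, h, w_, c_out = 1, 4, 4, 8
M, N = n * h * w_, c_out
k_filter = 27                     # filter reduction length, i.e. w.size // N
K = 28                            # suppose im2col produced one extra channel

w_data = np.random.rand(N, k_filter)
if K > k_filter:
    w2 = np.hstack([w_data, np.zeros([N, K - k_filter])])  # pad reduction axis
else:
    w2 = w_data
col = np.random.rand(M, K)        # stands in for the Im2Col output
y = col @ w2.T                    # Sgemm: (M, K) x (K, N) -> (M, N)
assert y.shape == (M, N)
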
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False

    for op in traverse.filter_nodes(traverse.listup_operators(graph), Tensordot):  # type: Tensordot
        A = op.inputs["A"]
        B = op.inputs["B"]
        axes = op.axes
        K = mul(A.shape_dict[a] for a in axes[0])
        M = A.size // K
        N = B.size // K

        if all([self.optimize_channel_mode, K % 4 == 0]):
            if ChannelMode.get(A) != ChannelModeEnum.RGBA:
                flag_changed = True
                ChannelMode.set(A, ChannelModeEnum.RGBA)

            if ChannelMode.get(B) != ChannelModeEnum.RGBA:
                flag_changed = True
                ChannelMode.set(B, ChannelModeEnum.RGBA)

            texture_shape_A = [M, K // 4]
            texture_shape_B = [N, K // 4]

        else:
            if ChannelMode.get(A) != ChannelModeEnum.R:
                flag_changed = True
                ChannelMode.set(A, ChannelModeEnum.R)

            if ChannelMode.get(B) != ChannelModeEnum.R:
                flag_changed = True
                ChannelMode.set(B, ChannelModeEnum.R)

            texture_shape_A = [M, K]
            texture_shape_B = [N, K]

    if TextureShape.get(A) != texture_shape_A:
        flag_changed = True
        TextureShape.set(A, height=texture_shape_A[0], width=texture_shape_A[1])

    if TextureShape.get(B) != texture_shape_B:
        flag_changed = True
        TextureShape.set(B, height=texture_shape_B[0], width=texture_shape_B[1])

    return graph, flag_changed

def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """
    before)

        -{op}- v

    after)

        -{op}- v' -{conversion}- v
    """
    v = op.outputs[var_name]

    if ChannelMode.get(v) == target:
        return False

    v_new = Variable(v.shape, v.order)
    ChannelMode.set(v_new, target)
    op.replace_output(v, v_new)
    if target == ChannelModeEnum.RGBA:
        ConvertRGBAtoR(None)(v_new)[0].replace(v)
    else:
        ConvertRtoRGBA(None)(v_new)[0].replace(v)

    return True

def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """
    before)

        -{op}- v

    after)

        -{op}- v' -{conversion}- v
    """
    v = op.outputs[var_name]

    if ChannelMode.get(v) == target:
        return False

    v_new = Variable(v.shape, v.order)
    ChannelMode.set(v_new, target)
    op.replace_output(v, v_new)
    if target == ChannelModeEnum.RGBA:
        convert_rgba_to_r(v_new).change_order(v.order).replace(v)
    else:
        convert_r_to_rgba(v_new).change_order(v.order).replace(v)

    return True

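# A toy illustration of the splice performed by the two `_replace_output`
# variants above, using throwaway classes (`_ToyVar`, `_ToyOp`) rather than
# WebDNN's real Operator/Variable API: the operator keeps producing a fresh
# variable v', and a conversion node re-produces the original v so every
# former consumer still reads a value in v's original channel mode.
class _ToyVar:
    def __init__(self, name):
        self.name = name
        self.producer = None

class _ToyOp:
    def __init__(self, name, inp, out):
        self.name = name
        self.inputs = [inp]
        self.outputs = [out]
        out.producer = self

def _splice_conversion(op, v):
    v_new = _ToyVar(v.name + "'")
    op.outputs = [v_new]            # op now produces v' instead of v
    v_new.producer = op
    _ToyOp("conversion", v_new, v)  # conversion node re-produces the original v
    return v_new

x, v = _ToyVar("x"), _ToyVar("v")
op = _ToyOp("op", x, v)
_splice_conversion(op, v)
assert v.producer.name == "conversion"
assert v.producer.inputs[0].producer is op
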
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False

    for op in traverse.filter_nodes(traverse.listup_operators(graph), Tensordot):
        A = op.inputs["A"]
        B = op.inputs["B"]
        axes = op.axes
        K = mul(A.shape_dict[a] for a in axes[0])
        M = A.size // K
        N = B.size // K

        if K % 4 == 0:
            if ChannelMode.get(A) != ChannelModeEnum.RGBA:
                flag_changed = True
                ChannelMode.set(A, ChannelModeEnum.RGBA)

            if ChannelMode.get(B) != ChannelModeEnum.RGBA:
                flag_changed = True
                ChannelMode.set(B, ChannelModeEnum.RGBA)

        else:
            if ChannelMode.get(A) != ChannelModeEnum.R:
                flag_changed = True
                ChannelMode.set(A, ChannelModeEnum.R)

            if ChannelMode.get(B) != ChannelModeEnum.R:
                flag_changed = True
                ChannelMode.set(B, ChannelModeEnum.R)

        # Compare against lists, as the sibling rules do: TextureShape.get
        # returns [height, width], so comparing against a tuple would always
        # differ and report a spurious change on every pass.
        if TextureShape.get(A) != [M, K]:
            flag_changed = True
            TextureShape.set(A, height=M, width=K)

        if TextureShape.get(B) != [N, K]:
            flag_changed = True
            TextureShape.set(B, height=N, width=K)

    return graph, flag_changed

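# Quick arithmetic check of the texture shapes chosen by the rules above:
# packing four scalars per RGBA texel shrinks the texture width by 4x when
# K % 4 == 0, while R mode stores one scalar per texel. Concrete values are
# illustrative only.
M, N, K = 64, 32, 128
assert K % 4 == 0
rgba_shape_A = [M, K // 4]   # [height, width], four scalars per RGBA texel
r_shape_A = [M, K]           # [height, width], one scalar per texel
assert rgba_shape_A[0] * rgba_shape_A[1] * 4 == r_shape_A[0] * r_shape_A[1]
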
def fold_constance(self, graph: Graph):
    # Remove this conversion operator and substitute its constant input for
    # the output; the constant is reordered and falls back to R channel mode.
    x = self.inputs["x0"]  # type: ConstantVariable
    y = self.outputs["y"]  # type: Variable

    self.remove_all()
    OptimizeRule.replace_variable(graph, y, x.change_order(y.order))
    ChannelMode.set(x, ChannelModeEnum.R)

def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False

    for sgemm in traverse.filter_nodes(traverse.listup_operators(graph), Sgemm):  # type: Sgemm
        A = sgemm.inputs["A"]
        B = sgemm.inputs["B"]
        M = sgemm.M
        N = sgemm.N
        K = sgemm.K
        transpose_A = sgemm.transpose_A
        transpose_B = sgemm.transpose_B

        if all([self.optimize_channel_mode,
                K % 4 == 0,
                isinstance(A, ConstantVariable) or transpose_A,
                isinstance(B, ConstantVariable) or not transpose_B]):
            if not transpose_A:
                # A is constant, so it can be physically transposed at
                # compile time to the layout the RGBA kernel expects.
                assert isinstance(A, ConstantVariable)
                flag_changed = True

                old_A = A
                A = ConstantVariable(A.data.reshape([K, M]).transpose(),
                                     Order([Axis(None), Axis(None)]))
                ChannelMode.set(A, ChannelMode.get(old_A))
                sgemm.replace_input(old_A, A, with_assert=False)
                sgemm.parameters["transpose_A"] = transpose_A = True

            if transpose_B:
                # Likewise, a constant B is re-laid-out so that
                # transpose_B becomes False.
                assert isinstance(B, ConstantVariable)
                flag_changed = True

                old_B = B
                B = ConstantVariable(B.data.reshape([K, N]).transpose(),
                                     Order([Axis(None), Axis(None)]))
                ChannelMode.set(B, ChannelMode.get(old_B))
                sgemm.replace_input(old_B, B, with_assert=False)
                sgemm.parameters["transpose_B"] = transpose_B = False

            if ChannelMode.get(A) != ChannelModeEnum.RGBA:
                flag_changed = True
                ChannelMode.set(A, ChannelModeEnum.RGBA)

            if ChannelMode.get(B) != ChannelModeEnum.RGBA:
                flag_changed = True
                ChannelMode.set(B, ChannelModeEnum.RGBA)

            texture_shape_A = [M, K // 4] if transpose_A else [K // 4, M]
            texture_shape_B = [K // 4, N] if transpose_B else [N, K // 4]

        else:
            if ChannelMode.get(A) != ChannelModeEnum.R:
                flag_changed = True
                ChannelMode.set(A, ChannelModeEnum.R)

            if ChannelMode.get(B) != ChannelModeEnum.R:
                flag_changed = True
                ChannelMode.set(B, ChannelModeEnum.R)

            texture_shape_A = [M, K] if transpose_A else [K, M]
            texture_shape_B = [K, N] if transpose_B else [N, K]

        if TextureShape.get(A) != texture_shape_A:
            flag_changed = True
            TextureShape.set(A, height=texture_shape_A[0], width=texture_shape_A[1])

        if TextureShape.get(B) != texture_shape_B:
            flag_changed = True
            TextureShape.set(B, height=texture_shape_B[0], width=texture_shape_B[1])

    if flag_changed:
        graph, _ = ConstantFolding().optimize(graph)

    return graph, flag_changed

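# A hypothetical numpy sketch of the constant pre-transposition above: a
# constant operand can be physically re-laid-out once at compile time so the
# kernel always sees transpose_A=True / transpose_B=False, and the RGBA path
# only has to handle one memory layout. Sizes are illustrative.
import numpy as np

K, M = 8, 4
A_data = np.arange(K * M).reshape(K, M)     # constant stored as [K, M]
A_new = A_data.reshape([K, M]).transpose()  # re-laid-out to [M, K], as in the rule
assert np.array_equal(A_new.T, A_data)      # same values, transposed layout
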
def __call__(self, x0: Variable):
    y, = super(ConvertRtoRGBA, self).__call__(x0)
    ChannelMode.set(y, ChannelModeEnum.RGBA)
    return y,