def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Replace every Convolution2D in *graph* with lower-level operators.

    Each Convolution2D operator is removed, and its output variable is replaced
    by an equivalent subgraph built from ReinterpretAxis and Tensordot. The
    general case also uses Im2Col. Three strategies are used:

    * Projection (1x1 kernel, stride (1, 1), padding (0, 0)): the input is
      used directly as the column tensor, so no Im2Col is needed.
    * Global convolution (kernel height/width equal to the input height/width,
      padding (0, 0)): the whole (H, W, C) input is contracted against the
      whole (kh, kw, in-channel) kernel.
    * Otherwise: Im2Col, then a Tensordot over the flattened patch axis.

    Args:
        graph: computation graph; it is rewritten in place.

    Returns:
        ``(graph, flag_changed)``, where ``flag_changed`` is True if at least
        one Convolution2D was replaced.
    """
    flag_changed = False
    for op in traverse.filter_nodes(traverse.listup_operators(graph), Convolution2D):  # type: Convolution2D
        x = op.inputs["x"]
        w = op.inputs["w"]
        y = op.outputs["y"]

        flag_changed = True
        op.remove_all()

        # Fresh axes shared by the kernel and the column tensor, so that
        # Tensordot can contract them by identity.
        a_filter, a_kh, a_kw = Axis(), Axis(), Axis()

        # Positionally relabel the kernel's NHWC axes as (out-channel = Axis.C, kh, kw, in-channel).
        # NOTE(review): assumes w is laid out as OrderNHWC with N = output channels — confirm.
        w, = ReinterpretAxis(None, in_order=OrderNHWC, out_order=Order([Axis.C, a_kh, a_kw, a_filter]))(w)

        if op.WH == 1 and op.WW == 1 and op.stride == (1, 1) and op.padding == (0, 0):
            # Projection: x's channel axis already is the column axis.
            col, = ReinterpretAxis(None, in_order=OrderNHWC,
                                   out_order=Order([Axis.N, Axis.H, Axis.W, a_filter]))(x)
            # One axis on the A side vs. three on the B side: this relies on
            # Tensordot matching the total reduced size (kh == kw == 1 here).
            new_y, = Tensordot(None, [[a_filter], [a_kh, a_kw, a_filter]])(col, w)

        elif op.WH == x.shape_dict[Axis.H] and op.WW == x.shape_dict[Axis.W] and op.padding == (0, 0):
            # Global convolution: the kernel spans the entire input, so relabel
            # x's H/W as the kernel axes and contract all three on both sides.
            col, = ReinterpretAxis(None, in_order=OrderNHWC,
                                   out_order=Order([Axis.N, a_kh, a_kw, a_filter]))(x)
            # Use the same two-element [axes_of_A, axes_of_B] form as the other
            # branches. Previously the first element was wrapped in an extra list.
            new_y, = Tensordot(None, [[a_kh, a_kw, a_filter], [a_kh, a_kw, a_filter]])(col, w)
            # NOTE(review): new_y now carries only (N, C). The transpose to y.order
            # below is relied on to restore y's size-1 H/W axes — confirm.

        else:
            # General convolution: Im2Col packs each receptive field into the
            # channel axis, which is then relabelled and contracted.
            col, = Im2Col(None, ksize=op.ksize, stride=op.stride, padding=op.padding,
                          dilation_rate=op.dilation_rate)(x)
            col, = ReinterpretAxis(None, in_order=OrderNHWC,
                                   out_order=Order([Axis.N, Axis.H, Axis.W, a_filter]))(col)
            new_y, = Tensordot(None, [[a_filter], [a_kh, a_kw, a_filter]])(col, w)

        # Match the memory order of the variable being replaced.
        new_y = new_y.transpose(y.order)
        OptimizeRule.replace_variable(graph, new_y, y)

    return graph, flag_changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Lower each Deconvolution2D in *graph* to Tensordot followed by Col2Im.

    For every Deconvolution2D operator:

    * the input is contracted with the kernel over the input-channel axis,
      giving a (kh, kw, C) patch for each input pixel;
    * the patch axes are flattened into a single channel axis (NHWC);
    * Col2Im assembles the output image from those patch columns.

    Args:
        graph: computation graph; it is rewritten in place.

    Returns:
        ``(graph, changed)``, where ``changed`` is True when at least one
        operator was replaced.
    """
    changed = False
    deconvs = traverse.filter_nodes(traverse.listup_operators(graph), Deconvolution2D)
    for deconv in deconvs:  # type: Deconvolution2D
        src = deconv.inputs["x"]
        kernel = deconv.inputs["w"]
        dst = deconv.outputs["y"]

        changed = True
        deconv.remove_all()

        # Fresh axes shared between the kernel and the input so Tensordot can
        # contract the input-channel axis by identity.
        in_ch, k_h, k_w = Axis(), Axis(), Axis()

        # Positionally relabel the kernel's NHWC axes as (out-channel, kh, kw, in-channel).
        kernel_order = Order([Axis.C, k_h, k_w, in_ch])
        kernel, = ReinterpretAxis(None, in_order=OrderNHWC, out_order=kernel_order)(kernel)

        # Relabel the input's channel axis so it matches the kernel's in-channel axis.
        src_order = Order([Axis.N, Axis.H, Axis.W, in_ch])
        src, = ReinterpretAxis(None, in_order=OrderNHWC, out_order=src_order)(src)

        patches, = Tensordot(None, axes=in_ch)(src, kernel)
        patches = patches.transpose(Order([Axis.N, Axis.H, Axis.W, k_h, k_w, Axis.C]))

        # Flatten (kh, kw, C) into one trailing axis so Col2Im receives an NHWC tensor.
        leading_dims = list(patches.shape[0:3])
        patch_len = mul(patches.shape[3:6])
        patches = patches.reshape(shape=leading_dims + [patch_len], order=OrderNHWC)

        assembled, = Col2Im(None, ksize=deconv.ksize, stride=deconv.stride, padding=deconv.padding)(patches)
        OptimizeRule.replace_variable(graph, assembled.transpose_like(dst), dst)

    return graph, changed