def _convert_transpose(converter: ONNXConverter, onnx_op: INodeProto):
    """Convert an ONNX ``Transpose`` node into a WebDNN :class:`Transpose`.

    The axis permutation is read from the node's ``"perm"`` attribute; when
    that attribute is absent, the code falls back to reversing every axis.
    """
    in_var = converter.get_variable(onnx_op.input[0])
    attributes = attribute_dict(onnx_op)

    # Fallback when "perm" is missing: reverse the full axis order.
    if "perm" in attributes:
        permutation = list(attributes["perm"].ints)
    else:
        permutation = list(reversed(range(in_var.ndim)))

    out_var, = Transpose(None)(in_var)
    out_var.change_order(Order([in_var.order.axes[i] for i in permutation]))
    converter.set_variable(onnx_op.output[0], out_var)
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Ensure that input ``var_name`` of ``op`` uses one of ``target_orders``.

    Returns ``False`` (no change) when the current input order is already
    acceptable.  Otherwise a :class:`Transpose` to the first target order is
    inserted in front of ``op`` and ``True`` is returned.
    """
    # Normalize a single Order into a one-element list of candidates.
    allowed = [target_orders] if isinstance(target_orders, Order) else target_orders

    current = op.inputs[var_name]
    if current.order in allowed:
        # Already in an accepted order; nothing to do.
        return False

    transposed, = Transpose(None)(current)
    transposed.change_order(allowed[0])
    op.remove_input(current)
    op.append_input(var_name, transposed)
    return True
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Force input ``var_name`` of ``op`` into one of ``target_orders``.

    Returns ``False`` when the input already has an accepted order; otherwise
    a :class:`Transpose` is spliced onto the input edge and ``True`` is
    returned.
    """
    if isinstance(target_orders, Order):
        candidates = [target_orders]
    else:
        candidates = target_orders

    old_var = op.inputs[var_name]
    if old_var.order in candidates:
        # No transpose required.
        return False

    new_var, = Transpose(None)(old_var)
    # Rewire the edge first, then fix the order of the inserted variable.
    op.replace_input(old_var, new_var, with_assert=False)
    new_var.change_order(candidates[0])
    return True
def transpose_handler(converter: TensorFlowConverter, tf_op: "tf.Operation"):
    """Convert a TensorFlow ``Transpose`` operation.

    The permutation input must be a compile-time constant; a dynamic
    permutation tensor is rejected with :class:`NotImplementedError`.
    """
    x = converter.get_variable(tf_op.inputs[0])
    perm_var = converter.get_variable(tf_op.inputs[1])
    if not isinstance(perm_var, ConstantVariable):
        raise NotImplementedError(
            "[TensorFlowConverter] Operator 'Transpose' with dynamic indices is not supported yet."
        )

    perm = perm_var.data.astype(int).flatten().tolist()  # type: List[int]

    y, = Transpose(None)(x)
    y.change_order(Order([x.order.axes[axis_index] for axis_index in perm]))
    converter.set_variable(tf_op.outputs[0], y)
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Replace every :class:`Linear` operator in *graph* by an :class:`Sgemm`.

    For each Linear ``y = x @ w``, the inputs are transposed into the layouts
    Sgemm expects and a single Sgemm node is inserted in its place.

    Returns the (mutated) graph and a flag telling whether anything changed.
    """
    flag_changed = False
    for op in traverse.filter_nodes(traverse.listup_operators(graph), Linear):
        x = op.inputs["x"]
        w = op.inputs["w"]
        y = op.outputs["y"]
        flag_changed = True
        # Detach the Linear from the graph before building its replacement.
        op.remove_all()

        # a_k: contracted (channel) axis; a_n: the other axis of the 2-D weight.
        a_k = Axis.C  # type: Axis
        a_n = w.order.axes[0] if w.order.axes[1] == a_k else w.order.axes[1]  # type: Axis
        # All non-contracted axes of x form the "M" side of the GEMM.
        axes_m = [a for a in x.order.axes if a != a_k]  # type: List[Axis]
        K = x.shape_dict[a_k]
        M = x.size // K
        N = w.shape_dict[a_n]

        # Reorder x to (K, M...) and w to (K, N) via explicit Transposes.
        x, = Transpose(None)(x)
        x.change_order(Order([a_k] + axes_m))

        w, = Transpose(None)(w)
        w.change_order(Order([a_k, a_n]))

        new_y, = Sgemm(None, M=M, N=N, K=K, out_shape=[x.shape_dict[a] for a in axes_m] + [N], out_order=Order(axes_m + [a_n]), transpose_A=False, transpose_B=True)(x, w)
        # NOTE(review): this trailing Transpose gets no change_order() call —
        # presumably replace_variable() below aligns its order with y; confirm.
        new_y, = Transpose(None)(new_y)
        OptimizeRule.replace_variable(graph, new_y, y)
    return graph, flag_changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Normalize variable orders around order-sensitive operators.

    - Reshape: pin its input/output to the operator's declared in/out orders.
    - 2-D spatial ops (convolution, pooling, space/depth shuffles): pin NHWC.
    - Softmax: if the softmax axis is not the last axis of a 2-D input,
      rewrite it as transpose -> reshape -> 2-D softmax -> reshape -> transpose.

    Returns the (mutated) graph and whether any change was made.
    """
    flag_changed = False
    for op in traverse.listup_operators(graph):
        if isinstance(op, Reshape):
            flag_changed |= _replace_input(op, "x", op.parameters["in_order"])
            flag_changed |= _replace_output(op, "y", op.parameters["out_order"])
            continue

        elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D, Deconvolution2D, Space2Depth, Depth2Space)):
            # These kernels assume NHWC data layout.
            flag_changed |= _replace_input(op, "x", OrderNHWC)
            flag_changed |= _replace_output(op, "y", OrderNHWC)
            continue

        elif isinstance(op, Softmax):
            x = op.inputs["x"]
            y = op.outputs["y"]
            target_axis = op.parameters["axis"]

            # Already a 2-D softmax over the last axis? Then nothing to do.
            if not (x.ndim == 2 and x.order.axes_dict[target_axis] == x.ndim - 1):
                """
                Before)

                | x |                | y |
                |-----| -{softmax}-> |-----|
                | XYZ |    axis=Y    | XYZ |

                After)

                | x |                 | hx1 |              | hx2 |               | hy1 |              | hy2 |                 | y |
                |-----| -{transpose}-> |-----| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |-----| -{transpose}-> |-----|
                | XYZ |                | XZY |              | NC  |    axis=C    | NC  |              | XZY |                | XYZ |
                                         :                    :
                                 order_nd = XZY        order_2d = NC
                """
                op.remove_all()

                # order_nd: original axes with the softmax axis moved last.
                axes_nd = list(x.order.axes)
                axes_nd.remove(target_axis)
                axes_nd.append(target_axis)
                order_nd = Order(axes_nd)
                shape_nd = tuple([x.shape_dict[axis] for axis in axes_nd])

                # order_2d: (everything else, softmax axis) flattened to 2-D.
                order_2d = OrderNC
                shape_2d = tuple([
                    x.size // x.shape_dict[target_axis],
                    x.shape_dict[target_axis]
                ])

                # Each stage is skipped when it would be an identity,
                # so no redundant Transpose/Reshape nodes are created.
                if x.order == order_nd:
                    hx1 = x
                else:
                    hx1, = Transpose(None)(x)
                    hx1.change_order(order_nd)
                    flag_changed = True

                if hx1.order == order_2d and hx1.shape == shape_2d:
                    hx2 = hx1
                else:
                    hx2, = Reshape(None, in_order=hx1.order, out_order=order_2d, out_shape=shape_2d)(hx1)
                    flag_changed = True

                # 2-D softmax over the channel (last) axis.
                hy1, = Softmax(None, axis=Axis.C)(hx2)

                if hy1.order == order_nd and hy1.shape == shape_nd:
                    hy2 = hy1
                else:
                    hy2, = Reshape(None, in_order=hy1.order, out_order=order_nd, out_shape=shape_nd)(hy1)
                    flag_changed = True

                if hy2.order == y.order:
                    y_dummy = hy2
                else:
                    y_dummy, = Transpose(None)(hy2)
                    y_dummy.change_order(y.order)
                    flag_changed = True

                # Splice the rewritten chain in place of the original output.
                y_dummy.replace(y)

                continue

    return graph, flag_changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Fold a constant ElementwiseMul following an Sgemm into the Sgemm weight.

    Pattern: ``Sgemm -> Variable -> ElementwiseMul`` where one Sgemm input
    (the weight ``w1``) and one multiplier input (``w2``) are constant.  When
    every axis of ``w2`` can be attributed to the weight side of the GEMM,
    ``w1`` is scaled by ``w2`` in place and the multiply is removed.

    Returns the (mutated) graph and whether any change was made.
    """
    flag_changed = False
    matches = traverse.search_sub_structure(
        graph, [Sgemm, Variable, ElementwiseMul])
    while len(matches) > 0:
        match = matches.pop()
        sgemm = match[0]  # type: Sgemm
        elementwise_mul = match[2]  # type: ElementwiseMul

        out_order = sgemm.parameters["out_order"]
        out_shape = sgemm.parameters["out_shape"]
        # Synthetic axis standing in for the contracted K dimension.
        axis_k = Axis('AxisK')

        # Figure out which Sgemm input is the constant weight, and build the
        # "virtual" N-D order/shape its flat 2-D data corresponds to.  The
        # loops below greedily take output axes until their product covers
        # the weight-side extent; if it doesn't land exactly, the output axes
        # mix weight and data contributions and folding is impossible.
        if not isinstance(sgemm.inputs["A"], ConstantVariable) and not isinstance(
                sgemm.inputs["B"], ConstantVariable):
            # neither x nor w1 is constant
            continue

        elif isinstance(sgemm.inputs["A"], ConstantVariable):
            w1 = sgemm.inputs["A"]  # type: ConstantVariable
            if sgemm.transpose_A:
                # w1.shape = (M, K)
                shape = []
                axes = []
                for axis, size in zip(out_order.axes, out_shape):
                    shape.append(size)
                    axes.append(axis)
                    if mul(shape) >= sgemm.M:
                        break

                if mul(shape) != sgemm.M:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes + [axis_k])
                w1_virtual_shape = shape + [sgemm.K]

            else:
                # w1.shape = (K, M)
                shape = [sgemm.K]
                axes = [axis_k]
                for axis, size in zip(out_order.axes, out_shape):
                    shape.append(size)
                    axes.append(axis)
                    if mul(shape) >= w1.size:
                        break

                if mul(shape) != w1.size:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes)
                w1_virtual_shape = shape

        else:
            w1 = sgemm.inputs["B"]  # type: ConstantVariable
            if sgemm.transpose_B:
                # w1.shape = (K, N)
                shape = []
                axes = []
                # N axes sit at the tail of the output order, so walk it
                # from the right.
                for axis, size in reversed(
                        list(zip(out_order.axes, out_shape))):
                    shape.insert(0, size)
                    axes.insert(0, axis)
                    if mul(shape) >= sgemm.N:
                        break

                if mul(shape) != sgemm.N:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order([axis_k] + axes)
                w1_virtual_shape = [sgemm.K] + shape

            else:
                # w1.shape = (N, K)
                shape = [sgemm.K]
                axes = [axis_k]
                for axis, size in reversed(
                        list(zip(out_order.axes, out_shape))):
                    shape.insert(0, size)
                    axes.insert(0, axis)
                    if mul(shape) >= w1.size:
                        break

                if mul(shape) != w1.size:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes)
                w1_virtual_shape = shape

        # h is the intermediate Sgemm output consumed by the multiply.
        h = sgemm.outputs["C"]  # type: Variable
        x0 = elementwise_mul.inputs["x0"]
        x1 = elementwise_mul.inputs["x1"]
        if h == x1:
            if not isinstance(x0, ConstantVariable):
                # w2 is not constant
                continue
            w2 = x0  # type: ConstantVariable
        else:
            if not isinstance(x1, ConstantVariable):
                # w2 is not constant
                continue
            w2 = x1  # type: ConstantVariable

        y = elementwise_mul.outputs["y"]  # type: Variable
        if not all(axis in w1_virtual_order.axes for axis in w2.order.axes):
            # w2's axes are derived from both w1 and x
            continue

        # Drop the multiply and route the Sgemm output (re-ordered to match
        # y) to y's consumers.
        elementwise_mul.remove_all()
        y_dummy, = Transpose(None)(h)
        y_dummy.change_order(y.order)
        y_dummy.replace(y)

        # Fold w2 into w1: view w1's data in its virtual N-D layout,
        # multiply (broadcasting over w2's axes), then flatten back.
        w2.change_order(w1_virtual_order)
        w_new = ConstantVariable(
            w1.data.reshape(w1_virtual_shape),
            w1_virtual_order) * w2  # type: ConstantVariable
        w1.data = w_new.data.reshape(w1.shape)

        flag_changed = True
        # The graph changed; recompute matches from scratch.
        matches = traverse.search_sub_structure(
            graph, [Sgemm, Variable, ElementwiseMul])

    return graph, flag_changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Simplify Transposes and normalize orders around order-sensitive ops.

    - Transpose: drop it when it is an identity (same in/out order), or when
      every consumer is order-agnostic (Elementwise / SplitAxis), in which
      case consumers are rewired directly to the Transpose's input.
    - Reshape: pin its input/output to the declared in/out orders.
    - 2-D spatial ops: pin NHWC.
    - Softmax over >2-D input: rewrite as
      transpose -> reshape -> 2-D softmax -> reshape -> transpose;
      2-D input is simply pinned to NC.

    Returns the (mutated) graph and whether any change was made.
    """
    flag_changed = False
    for op in traverse.listup_operators(graph):
        if isinstance(op, Transpose):
            x = op.inputs["x0"]
            y = op.outputs["y"]

            if x.order == y.order:
                # Identity transpose: remove it entirely.
                op.remove_all()
                x.replace(y)

                flag_changed = True

            if all(isinstance(op2, (Elementwise, SplitAxis)) for op2 in y.input_to):
                # All consumers ignore order; feed them x directly.
                # NOTE(review): this branch does not set flag_changed even
                # though it mutates the graph — confirm that is intended.
                op.remove_all()
                for op2 in list(y.input_to):
                    name = op2._get_input_name(y)
                    op2.remove_input(y)
                    op2.append_input(name, x)

        elif isinstance(op, Reshape):
            flag_changed |= _replace_input(op, "x", op.parameters["in_order"])
            flag_changed |= _replace_output(op, "y", op.parameters["out_order"])

        elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D, Deconvolution2D)):
            # These kernels assume NHWC data layout.
            flag_changed |= _replace_input(op, "x", OrderNHWC)
            flag_changed |= _replace_output(op, "y", OrderNHWC)

        elif isinstance(op, Softmax):
            x = op.inputs["x"]
            y = op.outputs["y"]

            if x.ndim > 2:
                """
                Before)

                | x  |                | y  |
                |------| -{softmax}-> |------|
                | NCHW |              | NCHW |

                After)

                | x  |                 | hx1 |              | hx2 |              | hy1 |              | hy2 |                 | y  |
                |------| -{transpose}-> |------| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |------| -{transpose}-> |------|
                | NCHW |                | NHWC |              | NC  |              | NC  |              | NHWC |                | NCHW |
                """
                op.remove_all()

                target_axis = op.parameters["axis"]

                # order_nd: original axes with the softmax axis moved last.
                axes_nd = list(x.order.axes)
                axes_nd.remove(target_axis)
                axes_nd.append(target_axis)
                order_nd = Order(axes_nd)
                shape_nd = [x.shape_dict[axis] for axis in axes_nd]

                # order_2d: (everything else, softmax axis) flattened to 2-D.
                order_2d = OrderNC
                shape_2d = [x.size // x.shape_dict[target_axis], x.shape_dict[target_axis]]

                hx1, = Transpose(None)(x)
                hx1.change_order(order_nd)

                hx2, = Reshape(None, in_order=hx1.order, out_order=order_2d, out_shape=shape_2d)(hx1)

                # 2-D softmax over the channel (last) axis.
                hy1, = Softmax(None, axis=Axis.C)(hx2)

                hy2, = Reshape(None, in_order=hy1.order, out_order=order_nd, out_shape=shape_nd)(hy1)

                y_dummy, = Transpose(None)(hy2)
                y_dummy.change_order(y.order)

                # Splice the rewritten chain in place of the original output.
                y_dummy.replace(y)
                flag_changed = True

            else:
                # 2-D case: just force NC layout on both sides.
                flag_changed |= _replace_input(op, "x", OrderNC)
                flag_changed |= _replace_output(op, "y", OrderNC)

    return graph, flag_changed