def _convert_transpose(converter: ChainerConverter, c_op: "chainer.functions.Transpose"):
    """Convert a Chainer Transpose function into a WebDNN Transpose operator.

    The output order is the input order permuted by the Chainer op's `axes`.
    """
    x = converter.get_variable(c_op.inputs[0])

    y, = Transpose(None)(x)
    permuted_axes = [x.order.axes[axis] for axis in c_op.axes]
    y.change_order(Order(permuted_axes))

    converter.set_variable(c_op.outputs[0](), y)
def _convert_transpose(converter: ONNXConverter, onnx_op: INodeProto):
    """Convert an ONNX Transpose node into a WebDNN Transpose operator.

    When the "perm" attribute is absent, the axes are reversed, matching the
    ONNX specification's default.
    """
    x = converter.get_variable(onnx_op.input[0])
    attrs = attribute_dict(onnx_op)

    if "perm" in attrs:
        perm = list(attrs["perm"].ints)
    else:
        perm = list(reversed(range(x.ndim)))

    y, = Transpose(None)(x)
    y.change_order(Order([x.order.axes[i] for i in perm]))

    converter.set_variable(onnx_op.output[0], y)
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Merge duplicated operators.

    When two operators of the same class consume the same inputs under the same
    input names and have equal parameters, one of them is redundant: it is
    removed and its outputs are re-routed to the survivor's outputs.

    Returns:
        The (possibly modified) graph and a flag telling whether any change was made.
    """
    flag_changed = False
    variables = traverse.listup_variables(graph)
    while len(variables) > 0:
        x = variables.pop()
        # Compare every ordered pair of operators consuming `x`.
        for op1, op2 in itertools.permutations(x.input_to, 2):
            if op2 is op1:
                continue

            if op2.__class__ != op1.__class__:
                # class is not same
                continue

            if any((x_name not in op2.inputs) or (op2.inputs[x_name] != op1.inputs[x_name]) for x_name in op1.inputs.keys()):
                # input is not same
                continue

            if any((key not in op2.parameters) or (op2.parameters[key] != op1.parameters[key]) for key in op1.parameters.keys()):
                # parameter is not same
                continue

            # op1 and op2 compute the same thing: drop op2 and redirect its
            # outputs to op1's outputs.
            flag_changed = True
            vs_1 = dict(op1.outputs)
            vs_2 = dict(op2.outputs)
            op2.remove_all()
            for v_name, v1 in vs_1.items():
                v2 = vs_2[v_name]
                if v1.order == v2.order:
                    """
                                +-{op3}-
                    -{op1}- v1 -+
                                +-{op4}-
                    """
                    OptimizeRule.replace_variable(graph, v2, v1)
                else:
                    """
                                +-{op3}-
                    -{op1}- v1 -+
                                +-{Transpose}- v2 -{op4}-
                    """
                    # Orders differ: insert a Transpose so op2's former
                    # consumers still receive data in v2's order.
                    v2_dummy, = Transpose(None)(v1)
                    v2_dummy.change_order(v2.order)
                    OptimizeRule.replace_variable(graph, v2_dummy, v2)

            # The graph changed; restart the scan with a fresh variable list.
            variables = traverse.listup_variables(graph)
            break

    return graph, flag_changed
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Force `op`'s input `var_name` into one of `target_orders`.

    If the input's order is not acceptable, a Transpose into the first target
    order is inserted in front of `op`.

    Returns:
        True when the graph was modified, False otherwise.
    """
    current = op.inputs[var_name]
    orders = [target_orders] if isinstance(target_orders, Order) else target_orders

    if current.order in orders:
        return False

    transposed, = Transpose(None)(current)
    transposed.change_order(orders[0])
    op.remove_input(current)
    op.append_input(var_name, transposed)
    return True
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Force `op`'s input `var_name` into one of `target_orders`.

    If the current order is not acceptable, the input is swapped for a new
    variable produced by a Transpose, carrying the first target order.

    Returns:
        True when the graph was modified, False otherwise.
    """
    candidates = [target_orders] if isinstance(target_orders, Order) else target_orders
    x = op.inputs[var_name]

    if x.order in candidates:
        return False

    x_transposed, = Transpose(None)(x)
    op.replace_input(x, x_transposed, with_assert=False)
    x_transposed.change_order(candidates[0])
    return True
def transpose_handler(converter: TensorFlowConverter, tf_op: "tf.Operation"):
    """Convert a TensorFlow Transpose op into a WebDNN Transpose operator.

    The permutation (second input) must be a compile-time constant; dynamic
    permutations are rejected.
    """
    x = converter.get_variable(tf_op.inputs[0])
    perm_var = converter.get_variable(tf_op.inputs[1])

    if not isinstance(perm_var, ConstantVariable):
        raise NotImplementedError(
            "[TensorFlowConverter] Operator 'Transpose' with dynamic indices is not supported yet."
        )

    perm = perm_var.data.astype(int).flatten().tolist()  # type: List[int]

    y, = Transpose(None)(x)
    y.change_order(Order([x.order.axes[i] for i in perm]))

    converter.set_variable(tf_op.outputs[0], y)
def _convert_linear_function(
        converter: ChainerConverter,
        c_op: "chainer.functions.connection.linear.LinearFunction"):
    """Convert a Chainer LinearFunction (fully-connected layer) into WebDNN ops.

    The input is flattened to 2D (NC), the weight is reinterpreted as NC and
    transposed to CN, a Linear op is applied, and the result's axes are
    reinterpreted back to the caller-visible axes. An optional third input is
    the bias, added with broadcasting.
    """
    x = converter.get_variable(c_op.inputs[0])
    w = converter.get_variable(c_op.inputs[1])  # type: ConstantVariable

    # Flatten all non-batch axes of x into a single C axis.
    # `mul` presumably computes the product of the trailing shape — TODO confirm.
    x2, = Reshape(None, in_order=x.order, out_order=OrderNC, out_shape=[x.shape[0], mul(x.shape[1:])])(x)
    # Treat the weight as (out_features, in_features) = NC, then transpose to CN
    # as expected by Linear.
    w2, = ReinterpretAxis(None, in_order=w.order, out_order=OrderNC)(w)
    w2, = Transpose(None)(w2)
    w2.change_order(OrderCN)

    y, = Linear(None)(x2, w2)
    # Restore the original axis identities: batch axis from x, output axis from w.
    y, = ReinterpretAxis(None, in_order=y.order, out_order=Order([x.order.axes[0], w.order.axes[0]]))(y)

    if len(c_op.inputs) == 3:
        # with bias
        b = converter.get_variable(c_op.inputs[2])
        check_broadcast_constraints(y, b)
        y = y + b

    converter.set_variable(c_op.outputs[0](), y)
def _replace_output(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Force `op`'s output `var_name` into one of `target_orders`.

    If the current order is not acceptable, the op is rewired to emit a new
    variable in the first target order, and a Transpose restores the original
    variable (and order) for downstream consumers.

    Returns:
        True when the graph was modified, False otherwise.
    """
    orders = [target_orders] if isinstance(target_orders, Order) else target_orders
    y = op.outputs[var_name]

    if y.order in orders:
        return False

    replacement = Variable(y.shape, y.order)
    replacement.change_order(orders[0])
    op.replace_output(y, replacement, with_assert=False)

    # Transpose back so consumers of the original variable are unaffected.
    y_restored, = Transpose(None)(replacement)
    y_restored.replace(y, with_assert=False)
    return True
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Replace every Linear operator with an equivalent Sgemm.

    x is transposed so the reduction axis (C) comes first, w is transposed to
    (K, N), and Sgemm computes the matrix product; a trailing Transpose adapts
    the result back to the consumer's variable.

    Returns:
        The (possibly modified) graph and a flag telling whether any change was made.
    """
    flag_changed = False
    for op in traverse.filter_nodes(traverse.listup_operators(graph), Linear):
        x = op.inputs["x"]
        w = op.inputs["w"]
        y = op.outputs["y"]
        flag_changed = True
        op.remove_all()

        # K is the contracted axis (channel); N is w's other axis.
        a_k = Axis.C
        a_n = w.order.axes[0] if w.order.axes[1] == a_k else w.order.axes[1]
        # All of x's axes except K collectively form the M dimension.
        axes_m = [a for a in x.order.axes if a != a_k]

        K = x.shape_dict[a_k]
        M = x.size // K
        N = w.shape_dict[a_n]

        # Reorder x as (K, M...) and w as (K, N).
        x, = Transpose(None)(x)
        x.change_order(Order([a_k] + axes_m))

        w, = Transpose(None)(w)
        w.change_order(Order([a_k, a_n]))

        # transpose_A=False / transpose_B=True match the (K, M) x (K, N) layouts
        # produced above.
        new_y, = Sgemm(None, M=M, N=N, K=K,
                       out_shape=[x.shape_dict[a] for a in axes_m] + [N],
                       out_order=Order(axes_m + [a_n]),
                       transpose_A=False, transpose_B=True)(x, w)

        # Final Transpose lets replace_variable reconcile the order with y.
        new_y, = Transpose(None)(new_y)
        OptimizeRule.replace_variable(graph, new_y, y)

    return graph, flag_changed
def _replace_output(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Force `op`'s output `var_name` into one of `target_orders`.

    If the current order is not acceptable, `op` is rewired to emit a fresh
    variable carrying the first target order, and a Transpose is attached whose
    output is the original variable, so downstream consumers are unaffected.

    Returns:
        True when the graph was modified, False otherwise.
    """
    y_original = op.outputs[var_name]
    orders = [target_orders] if isinstance(target_orders, Order) else target_orders

    if y_original.order in orders:
        return False

    y_replacement = Variable(y_original.shape, y_original.order)
    y_replacement.change_order(orders[0])
    op.remove_output(y_original)
    op.append_output(var_name, y_replacement)

    # Build the restoring Transpose, then swap its auto-created output for the
    # original variable so existing consumers keep their reference.
    restore_op = Transpose(None)
    dummy_out, = restore_op(y_replacement)
    restore_op.remove_output(dummy_out)
    restore_op.append_output("y", y_original)
    return True
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Normalize variable orders around order-sensitive operators.

    Reshape inputs/outputs are forced into the op's declared in/out orders;
    2D image ops are forced into NHWC; Softmax over an arbitrary axis is
    lowered to transpose -> reshape -> 2D softmax over C -> reshape -> transpose.

    Returns:
        The (possibly modified) graph and a flag telling whether any change was made.
    """
    flag_changed = False
    for op in traverse.listup_operators(graph):
        if isinstance(op, Reshape):
            flag_changed |= _replace_input(op, "x", op.parameters["in_order"])
            flag_changed |= _replace_output(op, "y", op.parameters["out_order"])
            continue

        elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D, Deconvolution2D, Space2Depth, Depth2Space)):
            # Image-layout ops operate in NHWC.
            flag_changed |= _replace_input(op, "x", OrderNHWC)
            flag_changed |= _replace_output(op, "y", OrderNHWC)
            continue

        elif isinstance(op, Softmax):
            x = op.inputs["x"]
            y = op.outputs["y"]
            target_axis = op.parameters["axis"]

            # Already a 2D softmax over the last axis: nothing to do.
            if not (x.ndim == 2 and x.order.axes_dict[target_axis] == x.ndim - 1):
                """
                Before)

                | x |                       | y |
                |-----| -{softmax}-------> |-----|
                | XYZ |     axis=Y          | XYZ |

                After)

                | x |              | hx1 |             | hx2 |            | hy1 |             | hy2 |              | y |
                |-----| -{transpose}-> |-----| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |-----| -{transpose}-> |-----|
                | XYZ |              | XZY |             | NC  |   axis=C   | NC  |             | XZY |              | XYZ |
                                       :                   :
                                  order_nd = XZY      order_2d = NC
                """
                op.remove_all()

                # Move the softmax axis to the last position.
                axes_nd = list(x.order.axes)
                axes_nd.remove(target_axis)
                axes_nd.append(target_axis)
                order_nd = Order(axes_nd)
                shape_nd = tuple([x.shape_dict[axis] for axis in axes_nd])

                # Collapse all leading axes into N; softmax axis becomes C.
                order_2d = OrderNC
                shape_2d = tuple([
                    x.size // x.shape_dict[target_axis],
                    x.shape_dict[target_axis]
                ])

                # Each stage is skipped when it would be a no-op.
                if x.order == order_nd:
                    hx1 = x
                else:
                    hx1, = Transpose(None)(x)
                    hx1.change_order(order_nd)
                    flag_changed = True

                if hx1.order == order_2d and hx1.shape == shape_2d:
                    hx2 = hx1
                else:
                    hx2, = Reshape(None, in_order=hx1.order, out_order=order_2d, out_shape=shape_2d)(hx1)
                    flag_changed = True

                hy1, = Softmax(None, axis=Axis.C)(hx2)

                if hy1.order == order_nd and hy1.shape == shape_nd:
                    hy2 = hy1
                else:
                    hy2, = Reshape(None, in_order=hy1.order, out_order=order_nd, out_shape=shape_nd)(hy1)
                    flag_changed = True

                if hy2.order == y.order:
                    y_dummy = hy2
                else:
                    y_dummy, = Transpose(None)(hy2)
                    y_dummy.change_order(y.order)
                    flag_changed = True

                y_dummy.replace(y)
                continue

    return graph, flag_changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Fold a constant ElementwiseMul following an Sgemm into the Sgemm's
    constant operand.

    For every Sgemm -> variable -> ElementwiseMul chain where one Sgemm operand
    (w1) and one multiplier operand (w2) are constant, w2 is multiplied into w1
    offline and the ElementwiseMul is removed from the graph. The rewrite is
    only legal when the Sgemm output axes covered by w2 are derived entirely
    from w1 (not from the non-constant operand).

    Returns:
        The (possibly modified) graph and a flag telling whether any change was made.
    """
    flag_changed = False
    matches = traverse.search_sub_structure(
        graph, [Sgemm, Variable, ElementwiseMul])
    while len(matches) > 0:
        match = matches.pop()
        sgemm = match[0]  # type: Sgemm
        elementwise_mul = match[2]  # type: ElementwiseMul

        out_order = sgemm.parameters["out_order"]
        out_shape = sgemm.parameters["out_shape"]
        # Placeholder axis representing the contracted (K) dimension.
        axis_k = Axis('AxisK')

        # Identify the constant Sgemm operand (w1) and reconstruct the virtual
        # order/shape it holds inside the Sgemm's flattened 2D layout. The four
        # branches cover A/B x transposed/untransposed; each accumulates output
        # axes until their total size matches w1's dimension.
        if not isinstance(sgemm.inputs["A"], ConstantVariable) and not isinstance(
                sgemm.inputs["B"], ConstantVariable):
            # neither x nor w1 is constant
            continue

        elif isinstance(sgemm.inputs["A"], ConstantVariable):
            w1 = sgemm.inputs["A"]  # type: ConstantVariable
            if sgemm.transpose_A:
                # w1.shape = (M, K)
                shape = []
                axes = []
                for axis, size in zip(out_order.axes, out_shape):
                    shape.append(size)
                    axes.append(axis)
                    if mul(shape) >= sgemm.M:
                        break
                if mul(shape) != sgemm.M:
                    # output axes are derived from both w1 and x
                    continue
                w1_virtual_order = Order(axes + [axis_k])
                w1_virtual_shape = shape + [sgemm.K]

            else:
                # w1.shape = (K, M)
                shape = [sgemm.K]
                axes = [axis_k]
                for axis, size in zip(out_order.axes, out_shape):
                    shape.append(size)
                    axes.append(axis)
                    if mul(shape) >= w1.size:
                        break
                if mul(shape) != w1.size:
                    # output axes are derived from both w1 and x
                    continue
                w1_virtual_order = Order(axes)
                w1_virtual_shape = shape

        else:
            w1 = sgemm.inputs["B"]  # type: ConstantVariable
            if sgemm.transpose_B:
                # w1.shape = (K, N)
                shape = []
                axes = []
                # N corresponds to the trailing output axes, so scan in reverse.
                for axis, size in reversed(
                        list(zip(out_order.axes, out_shape))):
                    shape.insert(0, size)
                    axes.insert(0, axis)
                    if mul(shape) >= sgemm.N:
                        break
                if mul(shape) != sgemm.N:
                    # output axes are derived from both w1 and x
                    continue
                w1_virtual_order = Order([axis_k] + axes)
                w1_virtual_shape = [sgemm.K] + shape

            else:
                # w1.shape = (N, K)
                shape = [sgemm.K]
                axes = [axis_k]
                for axis, size in reversed(
                        list(zip(out_order.axes, out_shape))):
                    shape.insert(0, size)
                    axes.insert(0, axis)
                    if mul(shape) >= w1.size:
                        break
                if mul(shape) != w1.size:
                    # output axes are derived from both w1 and x
                    continue
                w1_virtual_order = Order(axes)
                w1_virtual_shape = shape

        h = sgemm.outputs["C"]  # type: Variable

        # The other ElementwiseMul operand (w2) must be constant.
        x0 = elementwise_mul.inputs["x0"]
        x1 = elementwise_mul.inputs["x1"]
        if h == x1:
            if not isinstance(x0, ConstantVariable):
                # w2 is not constant
                continue
            w2 = x0  # type: ConstantVariable
        else:
            if not isinstance(x1, ConstantVariable):
                # w2 is not constant
                continue
            w2 = x1  # type: ConstantVariable

        y = elementwise_mul.outputs["y"]  # type: Variable

        if not all(axis in w1_virtual_order.axes for axis in w2.order.axes):
            # w2's axes are derived from both w1 and x
            continue

        # Remove the multiplication and reconnect the Sgemm output (via a
        # Transpose) where the ElementwiseMul output used to be.
        elementwise_mul.remove_all()
        y_dummy, = Transpose(None)(h)
        y_dummy.change_order(y.order)
        y_dummy.replace(y)

        # Fold w2 into w1 offline, in w1's virtual layout.
        w2.change_order(w1_virtual_order)
        w_new = ConstantVariable(
            w1.data.reshape(w1_virtual_shape),
            w1_virtual_order) * w2  # type: ConstantVariable
        w1.data = w_new.data.reshape(w1.shape)

        flag_changed = True
        # Graph changed; rescan for further fusion opportunities.
        matches = traverse.search_sub_structure(
            graph, [Sgemm, Variable, ElementwiseMul])

    return graph, flag_changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    """Simplify Transpose operators and normalize orders around
    order-sensitive operators.

    No-op Transposes are removed; Transposes feeding only order-agnostic
    consumers (Elementwise, SplitAxis) are bypassed; Reshape and 2D image ops
    get their declared orders enforced; Softmax over >2D input is lowered to
    transpose -> reshape -> 2D softmax -> reshape -> transpose.

    Returns:
        The (possibly modified) graph and a flag telling whether any change was made.
    """
    flag_changed = False
    for op in traverse.listup_operators(graph):
        if isinstance(op, Transpose):
            x = op.inputs["x0"]
            y = op.outputs["y"]

            if x.order == y.order:
                # Identity transpose: drop it.
                op.remove_all()
                x.replace(y)
                flag_changed = True

            if all(isinstance(op2, (Elementwise, SplitAxis)) for op2 in y.input_to):
                # All consumers are order-agnostic: bypass the transpose and
                # feed them the untransposed input directly.
                # NOTE(review): flag_changed is not set in this branch even
                # though the graph is modified — confirm whether intentional.
                op.remove_all()
                for op2 in list(y.input_to):
                    name = op2._get_input_name(y)
                    op2.remove_input(y)
                    op2.append_input(name, x)

        elif isinstance(op, Reshape):
            flag_changed |= _replace_input(op, "x", op.parameters["in_order"])
            flag_changed |= _replace_output(op, "y", op.parameters["out_order"])

        elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D, Deconvolution2D)):
            # Image-layout ops operate in NHWC.
            flag_changed |= _replace_input(op, "x", OrderNHWC)
            flag_changed |= _replace_output(op, "y", OrderNHWC)

        elif isinstance(op, Softmax):
            x = op.inputs["x"]
            y = op.outputs["y"]

            if x.ndim > 2:
                """
                Before)

                | x |                   | y |
                |------| -{softmax}--> |------|
                | NCHW |                | NCHW |

                After)

                | x |                | hx1 |             | hx2 |            | hy1 |             | hy2 |               | y |
                |------| -{transpose}-> |------| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |------| -{transpose}-> |------|
                | NCHW |               | NHWC |             | NC  |            | NC  |             | NHWC |               | NCHW |
                """
                op.remove_all()

                # Move the softmax axis to the last position.
                target_axis = op.parameters["axis"]
                axes_nd = list(x.order.axes)
                axes_nd.remove(target_axis)
                axes_nd.append(target_axis)
                order_nd = Order(axes_nd)
                shape_nd = [x.shape_dict[axis] for axis in axes_nd]

                # Collapse all leading axes into N; softmax axis becomes C.
                order_2d = OrderNC
                shape_2d = [x.size // x.shape_dict[target_axis], x.shape_dict[target_axis]]

                hx1, = Transpose(None)(x)
                hx1.change_order(order_nd)

                hx2, = Reshape(None, in_order=hx1.order, out_order=order_2d, out_shape=shape_2d)(hx1)

                hy1, = Softmax(None, axis=Axis.C)(hx2)

                hy2, = Reshape(None, in_order=hy1.order, out_order=order_nd, out_shape=shape_nd)(hy1)

                y_dummy, = Transpose(None)(hy2)
                y_dummy.change_order(y.order)
                y_dummy.replace(y)

                flag_changed = True

            else:
                # Already 2D: just enforce the NC order.
                flag_changed |= _replace_input(op, "x", OrderNC)
                flag_changed |= _replace_output(op, "y", OrderNC)

    return graph, flag_changed