def _convert_average_pooling1d(converter: KerasConverter, k_op: "keras.layers.AveragePooling1D"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    # FIXME: More efficient implementation
    y, = Reshape(None, in_order=x.order, out_order=OrderNHWC,
                 out_shape=[x.shape[0], x.shape[1], 1, x.shape[2]])(x)

    if k_op.padding == "valid":
        padding = (0, 0)

    elif k_op.padding == "same":
        padding = (k_op.pool_size[0] // 2, 0)

    else:
        raise NotImplementedError(f"Unknown padding: {k_op.padding}")

    y, = AveragePooling2D(None, ksize=(k_op.pool_size[0], 1), stride=(1, 1), padding=padding)(y)
    z, = Reshape(None, in_order=y.order, out_order=OrderNTC,
                 out_shape=[y.shape[0], y.shape[1], y.shape[3]])(y)

    converter.set_variable(converter.get_output_tensor(k_op)[0], z)

def optimize_operator(self, graph: Graph, op: Reshape):
    x = op.inputs["x"]
    y = op.outputs["y"]

    if x.order == y.order and x.shape == y.shape:
        # no reshape is required
        _remove_unary_operator(graph, op)
        return True

    if x.shape == y.shape:
        # only axis reinterpretation is required
        op.remove_all()
        y_dummy, = ReinterpretAxis(None, in_order=x.order, out_order=y.order)(x)
        y_dummy.replace(y)
        return True

    if isinstance(x, ConstantVariable) and x.output_from is None:
        # constant input: fold the reshape by changing the constant's order directly
        _remove_unary_operator(graph, op)
        x.change_order(y.order)
        return True

    if all([
        y not in graph.outputs,
        all(x.stride_dict[axis] == y.stride_dict[axis]
            for axis in [axis for axis in x.order.axes if axis in y.order.axes]),
        all(isinstance(op2, Elementwise) for op2 in y.input_to)
    ]):
        # memory layout is unchanged and every consumer is elementwise, so the reshape can be removed
        _remove_unary_operator(graph, op)
        return True

    return False

def template(x_order=OrderNHWC, x_shape=(2, 3, 4, 5), y_order=OrderNHWC, y_shape=(1, 12, 2, 5), description: str = ""):
    vx = np.random.rand(*x_shape) - 0.5

    x = Variable(vx.shape, order=OrderNHWC)
    y, = Reshape(None, in_order=x_order, out_order=y_order, out_shape=y_shape)(x)

    x.change_order(x_order)
    y.change_order(y_order)

    generate_kernel_test_case(
        description=f"Reshape {description}",
        graph=Graph([x], [y]),
        inputs={x: np.transpose(vx, [OrderNHWC.axes_dict[a] for a in x.order.axes]).flatten()},
        expected={y: np.transpose(vx, [OrderNHWC.axes_dict[a] for a in y.order.axes]).flatten()},
    )

def _convert_global_average_pooling1d(converter: KerasConverter, k_op: "keras.layers.GlobalAveragePooling1D"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    # FIXME: More efficient implementation
    y, = Reshape(None, in_order=OrderNTC, out_order=OrderNHWC,
                 out_shape=[x.shape[0], x.shape[1], 1, x.shape[2]])(x)
    y, = AveragePooling2D(None, ksize=(x.shape[1], 1), stride=(1, 1), padding=(0, 0))(y)

    # flatten without changing memory layout
    z, = Reshape(None, in_order=y.order, out_order=OrderNC,
                 out_shape=[y.shape[0], mul(y.shape[1:])])(y)

    converter.set_variable(converter.get_output_tensor(k_op)[0], z)

def _convert_reshape(converter: KerasConverter, k_op: "keras.layers.Reshape"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    target_shape = [x.shape[0]] + list(k_op.target_shape)
    if len(target_shape) == 2:
        target_order = OrderNC

    elif len(target_shape) == 3:
        target_order = OrderNTC

    elif len(target_shape) == 4:
        target_order = OrderNHWC

    else:
        raise NotImplementedError(f"[KerasConverter] Unknown default order: shape={target_shape}")

    console.warning(
        "[KerasConverter] keras.layers.Reshape is parsed with the default data order for the new shape "
        "(OrderNC in 2D, OrderNTC in 3D, OrderNHWC in 4D). To handle other orders, please overwrite the "
        "keras.layers.Reshape converter handler.")

    y, = Reshape(None, in_order=x.order, out_order=target_order, out_shape=target_shape)(x)
    converter.set_variable(converter.get_output_tensor(k_op)[0], y)

def convert_layer_global_average_pooling2d(converter: KerasConverter,
                                           k_op: "keras.layers.GlobalAveragePooling2D"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    if k_op.data_format == "channels_first":
        assert x.order == OrderNCHW

    elif k_op.data_format == "channels_last":
        assert x.order == OrderNHWC

    else:
        raise ValueError(f"[KerasConverter] Unknown data format: {k_op.data_format}")

    y, = AveragePooling2D(None, ksize=(x.shape_dict[Axis.H], x.shape_dict[Axis.W]),
                          stride=(1, 1), padding=(0, 0))(x)

    # flatten without changing memory layout
    z, = Reshape(None, in_order=y.order, out_order=OrderNC,
                 out_shape=[y.shape[0], mul(y.shape[1:])])(y)

    converter.set_variable(converter.get_output_tensor(k_op)[0], z)

def _convert_linear_function(converter: ChainerConverter,
                             c_op: "chainer.functions.connection.linear.LinearFunction"):
    x = converter.get_variable(c_op.inputs[0])
    w = converter.get_variable(c_op.inputs[1])  # type: ConstantVariable

    x2, = Reshape(None, in_order=x.order, out_order=OrderNC,
                  out_shape=[x.shape[0], mul(x.shape[1:])])(x)
    w2, = ReinterpretAxis(None, in_order=w.order, out_order=OrderNC)(w)
    w2, = Transpose(None)(w2)
    w2.change_order(OrderCN)

    y, = Linear(None)(x2, w2)
    y, = ReinterpretAxis(None, in_order=y.order,
                         out_order=Order([x.order.axes[0], w.order.axes[0]]))(y)

    if len(c_op.inputs) == 3:
        # with bias
        b = converter.get_variable(c_op.inputs[2])
        check_broadcast_constraints(y, b)
        y = y + b

    converter.set_variable(c_op.outputs[0](), y)

def _convert_reshape(converter: ChainerConverter, c_op: "chainer.functions.Reshape"):
    assert len(c_op.inputs) == 1, \
        f"For 'Reshape' operator in chainer, expected number of inputs is 1, but actual is {len(c_op.inputs)}"

    x = converter.get_variable(c_op.inputs[0])

    out_shape = list(c_op.shape)  # c_op.shape is a tuple
    if len(out_shape) == 1:
        out_order = OrderC

    elif len(out_shape) == 2:
        out_order = OrderNC

    elif len(out_shape) == 4:
        out_order = OrderNCHW

    else:
        raise NotImplementedError("Reshaping into a shape whose number of dimensions is not 1, 2, or 4 is not supported.")

    assert mul(out_shape) == x.size

    y, = Reshape(None, in_order=x.order, out_order=out_order, out_shape=out_shape)(x)
    converter.set_variable(c_op.outputs[0](), y)

def _convert_reshape(converter: ONNXConverter, onnx_op: INodeProto):
    x = converter.get_variable(onnx_op.input[0])

    if converter.opset_version >= 5:
        # The output shape is specified by onnx_op.input[1], which has to be a ConstantVariable.
        # TODO: test for different operator set versions
        shape_var = converter.get_variable(onnx_op.input[1])
        assert isinstance(shape_var, ConstantVariable), \
            "Shape specifier of Reshape operator has to be constant."
        out_shape = [int(d) for d in shape_var.data]

    else:
        # Reshape-1
        attrs = attribute_dict(onnx_op)
        out_shape = [r if s == 0 else s for r, s in zip(x.shape, attrs["shape"].ints)]

    if -1 in out_shape:
        i = out_shape.index(-1)
        out_shape.remove(-1)
        out_shape.insert(i, x.size // mul(out_shape))

    out_order = Order([None] * len(out_shape))

    y, = Reshape(None, in_order=x.order, out_order=out_order, out_shape=out_shape)(x)
    converter.set_variable(onnx_op.output[0], y)

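# Illustrative check (not part of the converter): a minimal, standalone sketch of how a single
# `-1` entry in the requested shape is resolved from the total element count, mirroring the
# `-1` handling in _convert_reshape above. `resolve_shape` is a hypothetical helper introduced
# only for this example.
from functools import reduce
from operator import mul as _mul


def resolve_shape(size, requested):
    out_shape = list(requested)
    if -1 in out_shape:
        i = out_shape.index(-1)
        rest = reduce(_mul, (s for s in out_shape if s != -1), 1)
        out_shape[i] = size // rest  # the remaining extent is inferred from the total size
    return out_shape


assert resolve_shape(24, [2, -1, 4]) == [2, 3, 4]
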
def _convert_repeat_vector(converter: KerasConverter, k_op: "keras.layers.RepeatVector"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    assert x.order == OrderNC, f"[KerasConverter] Currently only OrderNC is supported for input variable order of " \
                               f"keras.layers.RepeatVector: x.order={x.order}"

    N = x.shape_dict[Axis.N]
    n = k_op.n
    C = x.shape_dict[Axis.C]

    # TODO: Implement more efficient version
    # ex) x.shape=(N=2, C=3), n=2
    #
    #    x(N, C)   *      w(C, n*C)       =       y(N, n*C)      =       y(N, n, C)
    #  -----------------------------------------------------------------------------
    #   [1, 2, 3]     [1, 0, 0, 1, 0, 0]     [1, 2, 3, 1, 2, 3]    [[1, 2, 3], [1, 2, 3]]
    #   [4, 5, 6]  *  [0, 1, 0, 0, 1, 0]  =  [4, 5, 6, 4, 5, 6]  = [[4, 5, 6], [4, 5, 6]]
    #                 [0, 0, 1, 0, 0, 1]
    #
    w = ConstantVariable(np.tile(np.eye(C), (1, n)), OrderCN)
    y, = Linear(None)(x, w)
    y, = Reshape(None, in_order=OrderNC, out_order=OrderNTC, out_shape=[N, n, C])(y)

    converter.set_variable(converter.get_output_tensor(k_op)[0], y)

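# Illustrative check (not part of the converter): a small NumPy sketch verifying the
# tiled-identity trick used by _convert_repeat_vector above. `np.tile(np.eye(C), (1, n))`
# builds the w(C, n*C) matrix from the comment; multiplying and then reshaping to
# (N, n, C) reproduces the repeat-vector semantics.
import numpy as np

N, n, C = 2, 2, 3
x = np.arange(1, N * C + 1, dtype=float).reshape(N, C)    # [[1, 2, 3], [4, 5, 6]]
w = np.tile(np.eye(C), (1, n))                            # shape (C, n*C)
y = (x @ w).reshape(N, n, C)                              # shape (N, n, C)
assert np.array_equal(y, np.repeat(x[:, np.newaxis, :], n, axis=1))
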
def template(in_order, in_shape, out_order, out_shape):
    op = Reshape(None, in_order=in_order, out_order=out_order,
                 out_shape=[out_shape[a] for a in out_order.axes])

    x = Variable([in_shape[a] for a in in_order.axes], in_order)
    y, = op(x)

    assert_shape(y, out_shape)

def optimize_operator(self, graph: Graph, op: Reshape):
    x = op.inputs["x"]
    y = op.outputs["y"]

    if x.order == y.order and x.shape == y.shape:
        # no reshape is required
        _remove_unary_operator(graph, op)
        return True

    if x.shape == y.shape:
        # only reinterpret_axis is required
        op.remove_all()
        y_dummy = x.reinterpret_axes(y.order)
        OptimizeRule.replace_variable(graph, y_dummy, y)
        return True

    return False

def _convert_flatten(converter: KerasConverter, k_op: "keras.layers.Flatten"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    # flatten without changing memory layout
    y, = Reshape(None, in_order=x.order, out_order=OrderNC,
                 out_shape=[x.shape[0], mul(x.shape[1:])])(x)

    converter.set_variable(converter.get_output_tensor(k_op)[0], y)

def optimize_operator(self, graph: Graph, op: Reshape):
    x = op.inputs["x"]
    y = op.outputs["y"]

    if x.order == y.order and x.shape == y.shape:
        # no reshape is required
        _remove_unary_operator(graph, op)
        return True

    if x.shape == y.shape:
        # only reinterpret_axis is required
        op.remove_all()
        y_dummy, = ReinterpretAxis(None, in_order=x.order, out_order=y.order)(x)
        y_dummy.replace(y)
        return True

    return False

def _convert_flatten(converter: ChainerConverter, c_op: "chainer.functions.Flatten"):
    x = converter.get_variable(c_op.inputs[0])
    y, = Reshape(None, in_order=x.order, out_shape=[x.size], out_order=OrderC)(x)
    converter.set_variable(c_op.outputs[0](), y)

    console.warning(
        "[ChainerConverter] In chainer.functions.Flatten, output data order is parsed as OrderC. To "
        "customize this, please overwrite the chainer.functions.Flatten converter handler.")

def _convert_reshape(converter: ONNXConverter, onnx_op: INodeProto):
    x = converter.get_variable(onnx_op.input[0])

    attrs = attribute_dict(onnx_op)
    out_shape = [r if s == 0 else s for r, s in zip(x.shape, attrs["shape"].ints)]

    if -1 in out_shape:
        i = out_shape.index(-1)
        out_shape.remove(-1)
        out_shape.insert(i, x.size // mul(out_shape))

    out_order = Order([None] * len(out_shape))

    y, = Reshape(None, in_order=x.order, out_order=out_order, out_shape=out_shape)(x)
    converter.set_variable(onnx_op.output[0], y)

def _convert_reshape(converter: ChainerConverter, c_op: "chainer.functions.Reshape"):
    x = converter.get_variable(c_op.inputs[0])

    out_shape = c_op.shape
    # noinspection PyTypeChecker
    out_order = Order([AxisVar() for _ in out_shape])

    assert mul(out_shape) == x.size, \
        f"[ChainerConverter] Shape mismatch: mul(out_shape)={mul(out_shape)}, x.size={x.size}"

    y, = Reshape(None, in_order=x.order, out_order=out_order, out_shape=out_shape)(x)
    converter.set_variable(c_op.outputs[0](), y)

def _split_reshape(graph: Graph, op: Reshape, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    x = op.inputs["x"]
    y = op.outputs["y"]
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    op.remove_all()

    in_order = op.in_order
    out_order = op.out_order
    x_shape = [x.shape_dict[a] for a in in_order.axes]
    y_shape = [y.shape_dict[a] for a in out_order.axes]

    if v == x:
        """
        before)

            x -{reshape}- y

        after)

            x_0 -{reshape}- t_0 -+
                                 +-{concat[axis_n]}- t -{reshape}- y
            x_1 -{reshape}- t_1 -+

        shape and order are changed as follows:

            x.shape   = [dx_0, dx_1, ..., dx_m,   ..., dx_M-1]
            x_0.shape = [dx_0, dx_1, ..., dx_m/2, ..., dx_M-1]
            ---------------------------------------------------------------------------------
            t_0.shape = [dy_0, dy_1, ..., dy_n,   ..., dy_k/2, ..., dy_N-1]
            t.shape   = [dy_0, dy_1, ..., dy_n*2, ..., dy_k/2, ..., dy_N-1]
            y.shape   = [dy_0, dy_1, ..., dy_n,   ..., dy_k,   ..., dy_N-1]

            m: split target axis

        Find axis_k and axis_n which satisfy the following conditions:

            dy_n * dy_n+1 * ... * dy_N-1 == dx_m * dx_m+1 * ... * dx_M-1
            dy_k % 2 == 0
            n <= k
        """
        x_0, x_1 = v_pair
        dx_prod = mul(x_shape[in_order.axes_dict[axis]:])
        dy_prod = 1
        axis_k_candidate = []
        for axis_n in reversed(out_order.axes):
            dy_prod *= y.shape_dict[axis_n]
            if y.shape_dict[axis_n] % 2 == 0:
                axis_k_candidate.append(axis_n)

            if dx_prod == dy_prod:
                # Split the largest axis
                axis_k = (sorted(axis_k_candidate, key=lambda a: y.shape_dict[a], reverse=True))[0]

                t_0_shape = [y.shape_dict[a] for a in out_order.axes]
                t_0_shape[out_order.axes_dict[axis_k]] = t_0_shape[out_order.axes_dict[axis_k]] // 2  # TODO
                t_0, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=t_0_shape)(x_0)

                t_1_shape = [y.shape_dict[a] for a in out_order.axes]
                t_1_shape[out_order.axes_dict[axis_k]] = t_1_shape[out_order.axes_dict[axis_k]] // 2  # TODO
                t_1, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=t_1_shape)(x_1)

                t, = Concat(None, axis=axis_n)(t_0, t_1)
                y_new, = Reshape(None, in_order=out_order, out_order=out_order, out_shape=y_shape)(t)
                OptimizeRule.replace_variable(graph, y_new.transpose_like(y), y)
                break

            elif dy_prod > (s1 + s2) * dx_prod:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    elif v == y:
        """
        The algorithm is almost the same as in the case `v == x` (above).

        before)

            x -{reshape}- y

        after)

                                      +- t_0 -{reshape}- y_0
            x -{reshape}- t -{split}-+
                                      +- t_1 -{reshape}- y_1

        shape and order are changed as follows:

            x.shape   = [dx_0, dx_1, ..., dx_m,   ..., dx_k,   ..., dx_M-1]
            t.shape   = [dx_0, dx_1, ..., dx_m*2, ..., dx_k/2, ..., dx_M-1]
            t_0.shape = [dx_0, dx_1, ..., dx_m,   ..., dx_k/2, ..., dx_M-1]
            ---------------------------------------------------------------------------------
            y_0.shape = [dy_0, dy_1, ..., dy_n/2, ..., dy_N-1]
            y.shape   = [dy_0, dy_1, ..., dy_n,   ..., dy_N-1]

            n: split target axis

        Find axis_k and axis_m which satisfy the following conditions:

            dx_m * dx_m+1 * ... * dx_M-1 == dy_n * dy_n+1 * ... * dy_N-1
            dx_k % 2 == 0
            m <= k
        """
        y_0, y_1 = v_pair
        dx_prod = 1
        dy_prod = mul(y_shape[out_order.axes_dict[axis]:])
        axis_k_candidate = []
        for axis_m in reversed(in_order.axes):
            dx_prod *= x.shape_dict[axis_m]
            if x.shape_dict[axis_m] % 2 == 0:
                axis_k_candidate.append(axis_m)

            if dx_prod == dy_prod:
                # Split the largest axis
                axis_k = (sorted(axis_k_candidate, key=lambda a: x.shape_dict[a], reverse=True))[0]

                t_shape = [x.shape_dict[a] for a in in_order.axes]
                t_shape[in_order.axes_dict[axis_m]] = 2 * t_shape[in_order.axes_dict[axis_m]]  # TODO
                t_shape[in_order.axes_dict[axis_k]] = t_shape[in_order.axes_dict[axis_k]] // 2  # TODO
                t, = Reshape(None, in_order=in_order, out_order=in_order, out_shape=t_shape)(x)

                t_0, t_1 = SplitAxis(None, axis=axis_m, sections=[t.shape_dict[axis_m] // 2])(t)  # TODO

                y_0_new, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=y_0.shape)(t_0)
                y_1_new, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=y_1.shape)(t_1)

                OptimizeRule.replace_variable(graph, y_0_new.reshape_like(y_0), y_0)
                OptimizeRule.replace_variable(graph, y_1_new.reshape_like(y_1), y_1)
                break

            elif dx_prod > dy_prod:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError

def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for op in traverse.listup_operators(graph):
        if isinstance(op, Transpose):
            x = op.inputs["x0"]
            y = op.outputs["y"]

            if x.order == y.order:
                op.remove_all()
                x.replace(y)
                flag_changed = True

            if all(isinstance(op2, (Elementwise, SplitAxis)) for op2 in y.input_to):
                op.remove_all()
                for op2 in list(y.input_to):
                    name = op2._get_input_name(y)
                    op2.remove_input(y)
                    op2.append_input(name, x)

        elif isinstance(op, Reshape):
            flag_changed |= _replace_input(op, "x", op.parameters["in_order"])
            flag_changed |= _replace_output(op, "y", op.parameters["out_order"])

        elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D, Deconvolution2D)):
            flag_changed |= _replace_input(op, "x", OrderNHWC)
            flag_changed |= _replace_output(op, "y", OrderNHWC)

        elif isinstance(op, Softmax):
            x = op.inputs["x"]
            y = op.outputs["y"]

            if x.ndim > 2:
                """
                Before)

                    | x    |              | y    |
                    |------| -{softmax}-> |------|
                    | NCHW |              | NCHW |

                After)

                    | x    |                | hx1  |              | hx2 |              | hy1 |              | hy2  |                | y    |
                    |------| -{transpose}-> |------| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |------| -{transpose}-> |------|
                    | NCHW |                | NHWC |              | NC  |              | NC  |              | NHWC |                | NCHW |
                """
                op.remove_all()

                target_axis = op.parameters["axis"]
                axes_nd = list(x.order.axes)
                axes_nd.remove(target_axis)
                axes_nd.append(target_axis)
                order_nd = Order(axes_nd)
                shape_nd = [x.shape_dict[axis] for axis in axes_nd]

                order_2d = OrderNC
                shape_2d = [x.size // x.shape_dict[target_axis], x.shape_dict[target_axis]]

                hx1, = Transpose(None)(x)
                hx1.change_order(order_nd)

                hx2, = Reshape(None, in_order=hx1.order, out_order=order_2d, out_shape=shape_2d)(hx1)

                hy1, = Softmax(None, axis=Axis.C)(hx2)

                hy2, = Reshape(None, in_order=hy1.order, out_order=order_nd, out_shape=shape_nd)(hy1)

                y_dummy, = Transpose(None)(hy2)
                y_dummy.change_order(y.order)

                y_dummy.replace(y)
                flag_changed = True

            else:
                flag_changed |= _replace_input(op, "x", OrderNC)
                flag_changed |= _replace_output(op, "y", OrderNC)

    return graph, flag_changed

def _split_sgemm(graph: Graph, op: Sgemm, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]
    transpose_A, transpose_B = op.transpose_A, op.transpose_B
    M, K, N = op.M, op.K, op.N
    axis_M, axis_K, axis_N = Axis(None), Axis(None), Axis(None)
    op.remove_all()

    def decompose_logical_axes(logical_shape: Tuple[int, int], v: Variable):
        """
        Decompose logical axes into real axes

        Examples::

            A.order, A.shape
            >>> "NCHW", (1, 128, 8, 8)

            M = 128
            K = 64

            decompose_logical_axes([M, K], A)
            >>> ["<Axis N>", "<Axis C>"], ["<Axis H>", "<Axis W>"]
        """
        total_size = 1
        axes1 = []  # type: List[Axis]
        axes2 = list(v.order.axes)  # type: List[Axis]
        for size, a in zip(v.shape, v.order.axes):
            if total_size == logical_shape[0]:
                return axes1, axes2

            elif total_size > logical_shape[0]:
                raise ValueError

            axes1.append(a)
            axes2.remove(a)
            total_size *= size

    if v == A:
        A1, A2 = v_pair

        if transpose_A:
            # A.shape = [M, K]
            axes_M, axes_K = decompose_logical_axes((M, K), A)
        else:
            # A.shape = [K, M]
            axes_K, axes_M = decompose_logical_axes((K, M), A)

        if axis in axes_K:
            """
            before)

                A -{sgemm}- C

            after) In case `axis` is in `K`,

                A_0 -{sgemm}- C_0 -+
                                   +-{Add}- C
                A_1 -{sgemm}- C_1 -+
            """
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize B's axes included in K into A's corresponding axes
            if transpose_B:
                # B: [k_b1, k_b2, ..., N] -{reshape}-> [k_a1, k_a2, ..., N]
                B, = Reshape(None, in_order=B.order,
                             out_order=Order(axes_K + [axis_N]),
                             out_shape=[A.shape_dict[a] for a in axes_K] + [N])(B)
            else:
                # B: [N, k_b1, k_b2, ...] -{reshape}-> [N, k_a1, k_a2, ...]
                B, = Reshape(None, in_order=B.order,
                             out_order=Order([axis_N] + axes_K),
                             out_shape=[N] + [A.shape_dict[a] for a in axes_K])(B)

            B1, B2 = SplitAxis(None, axis=axis, sections=[s1])(B)

            C1, = Sgemm(None, M=M, K=K1, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"], out_order=op.parameters["out_order"])(A1, B1)
            C2, = Sgemm(None, M=M, K=K2, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"], out_order=op.parameters["out_order"])(A2, B2)

            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_M
            """
            before)

                A -{sgemm}- C

            after) In case `axis` is in `M`,

                A_0 -{sgemm}- C_0 -+
                                   +-{Concat}- C
                A_1 -{sgemm}- C_1 -+
            """
            M1, M2 = M * s1 // (s1 + s2), M * s2 // (s1 + s2)

            c_tmp_order = Order(axes_M + [axis_N])
            c1_shape = [A1.shape_dict[a] for a in axes_M] + [N]
            c2_shape = [A2.shape_dict[a] for a in axes_M] + [N]

            C1, = Sgemm(None, M=M1, K=K, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c1_shape, out_order=c_tmp_order)(A1, B)
            C2, = Sgemm(None, M=M2, K=K, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c2_shape, out_order=c_tmp_order)(A2, B)

            C_new, = Concat(None, axis=axis)(C1, C2)
            C_new, = Reshape(None, in_order=c_tmp_order, out_order=C.order, out_shape=C.shape)(C_new)

            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == B:
        B1, B2 = v_pair

        if transpose_B:
            # B.shape = [K, N]
            axes_K, axes_N = decompose_logical_axes((K, N), B)
        else:
            # B.shape = [N, K]
            axes_N, axes_K = decompose_logical_axes((N, K), B)

        if axis in axes_K:
            """
            before)

                B -{sgemm}- C

            after) In case `axis` is in `K`,

                B_0 -{sgemm}- C_0 -+
                                   +-{Add}- C
                B_1 -{sgemm}- C_1 -+
            """
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize A's axes included in K into B's corresponding axes
            if transpose_A:
                # A: [M, k_a1, k_a2, k_a3, ...] -{reshape}-> [M, k_b1, k_b2, ...]
                A, = Reshape(None, in_order=A.order,
                             out_order=Order([axis_M] + axes_K),
                             out_shape=[M] + [B.shape_dict[a] for a in axes_K])(A)
            else:
                # A: [k_a1, k_a2, k_a3, ..., M] -{reshape}-> [k_b1, k_b2, ..., M]
                A, = Reshape(None, in_order=A.order,
                             out_order=Order(axes_K + [axis_M]),
                             out_shape=[B.shape_dict[a] for a in axes_K] + [M])(A)

            A1, A2 = SplitAxis(None, axis=axis, sections=[s1])(A)

            C1, = Sgemm(None, M=M, K=K1, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"], out_order=op.parameters["out_order"])(A1, B1)
            C2, = Sgemm(None, M=M, K=K2, N=N,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"], out_order=op.parameters["out_order"])(A2, B2)

            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_N
            """
            before)

                C[M, N] = A[M, K] @ B[K, N]

            after) In case `axis` is in `N`,

                C[M, N] = Concat(C1[M, N1], C2[M, N2])
                        = Concat(A[M, K] @ B1[K, N1], A[M, K] @ B2[K, N2])
            """
            N1, N2 = N * s1 // (s1 + s2), N * s2 // (s1 + s2)

            c_tmp_order = Order([axis_M] + axes_N)
            c1_shape = [M] + [B1.shape_dict[a] for a in axes_N]
            c2_shape = [M] + [B2.shape_dict[a] for a in axes_N]

            C1, = Sgemm(None, M=M, K=K, N=N1,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c1_shape, out_order=c_tmp_order)(A, B1)
            # C1.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ..., B1.shape_dict[axis], ...]
            # C1.order = [axis_M, n1, n2, ..., axis, ...]

            C2, = Sgemm(None, M=M, K=K, N=N2,
                        transpose_A=transpose_A, transpose_B=transpose_B,
                        out_shape=c2_shape, out_order=c_tmp_order)(A, B2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            # C_new.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ..., B1.shape_dict[axis]+B2.shape_dict[axis], ...]
            # C_new.order = [axis_M, n1, n2, ..., axis, ...]

            C_new, = Reshape(None, in_order=c_tmp_order, out_order=C.order, out_shape=C.shape)(C_new)

            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == C:
        """
        before)

            C[M, N] = A[M, K] @ B[K, N]

        after) In case `axis` is in `N`,

            C[M, N1] = A[M, K] @ B1[K, N1]
            C[M, N2] = A[M, K] @ B2[K, N2]
        """
        raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError

def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for op in traverse.listup_operators(graph):
        if isinstance(op, Reshape):
            flag_changed |= _replace_input(op, "x", op.parameters["in_order"])
            flag_changed |= _replace_output(op, "y", op.parameters["out_order"])
            continue

        elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D,
                             Deconvolution2D, Space2Depth, Depth2Space)):
            flag_changed |= _replace_input(op, "x", OrderNHWC)
            flag_changed |= _replace_output(op, "y", OrderNHWC)
            continue

        elif isinstance(op, Softmax):
            x = op.inputs["x"]
            y = op.outputs["y"]
            target_axis = op.parameters["axis"]

            if not (x.ndim == 2 and x.order.axes_dict[target_axis] == x.ndim - 1):
                """
                Before)

                    | x   |              | y   |
                    |-----| -{softmax}-> |-----|
                    | XYZ |    axis=Y    | XYZ |

                After)

                    | x   |                | hx1 |              | hx2 |              | hy1 |              | hy2 |                | y   |
                    |-----| -{transpose}-> |-----| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |-----| -{transpose}-> |-----|
                    | XYZ |                | XZY |              | NC  |    axis=C    | NC  |              | XZY |                | XYZ |
                                              :                    :
                                        order_nd = XZY       order_2d = NC
                """
                op.remove_all()

                axes_nd = list(x.order.axes)
                axes_nd.remove(target_axis)
                axes_nd.append(target_axis)
                order_nd = Order(axes_nd)
                shape_nd = tuple([x.shape_dict[axis] for axis in axes_nd])

                order_2d = OrderNC
                shape_2d = tuple([x.size // x.shape_dict[target_axis], x.shape_dict[target_axis]])

                if x.order == order_nd:
                    hx1 = x
                else:
                    hx1, = Transpose(None)(x)
                    hx1.change_order(order_nd)
                    flag_changed = True

                if hx1.order == order_2d and hx1.shape == shape_2d:
                    hx2 = hx1
                else:
                    hx2, = Reshape(None, in_order=hx1.order, out_order=order_2d, out_shape=shape_2d)(hx1)
                    flag_changed = True

                hy1, = Softmax(None, axis=Axis.C)(hx2)

                if hy1.order == order_nd and hy1.shape == shape_nd:
                    hy2 = hy1
                else:
                    hy2, = Reshape(None, in_order=hy1.order, out_order=order_nd, out_shape=shape_nd)(hy1)
                    flag_changed = True

                if hy2.order == y.order:
                    y_dummy = hy2
                else:
                    y_dummy, = Transpose(None)(hy2)
                    y_dummy.change_order(y.order)
                    flag_changed = True

                y_dummy.replace(y)
                continue

    return graph, flag_changed

def _split_reshape(graph: Graph, op: Reshape, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    x = op.inputs["x"]
    y = op.outputs["y"]
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    op.remove_all()

    if v == x:
        """
        Regard x's order as `[D1, D2]` and its shape as `[d1x, d2x]`, where the outermost axis in D2 is
        the split target axis. If y's shape can be converted into `[d1x, d2x]` by merging some adjacent
        axes in y, the split can be performed.

        before)

            x -{reshape}- y

        after)

            x_0 -{reshape}- y_0 -+
                                 +-{concat[axis]}- y
            x_1 -{reshape}- y_1 -+
        """
        x_0, x_1 = v_pair
        d2x = mul(x.shape[x.order.axes_dict[axis]:])
        d2y = 1
        for axis_y in reversed(y.order.axes):
            d2y *= y.shape_dict[axis_y]

            if d2y == d2x:
                y_0_shape = [y.shape_dict[axis_y] * s1 // (s1 + s2) if a == axis_y else y.shape_dict[a]
                             for a in y.order.axes]
                y_1_shape = [y.shape_dict[axis_y] * s2 // (s1 + s2) if a == axis_y else y.shape_dict[a]
                             for a in y.order.axes]
                y_0 = x_0.reshape(y_0_shape, y.order)
                y_1 = x_1.reshape(y_1_shape, y.order)
                y_new, = Concat(None, axis=axis_y)(y_0, y_1)
                OptimizeRule.replace_variable(graph, y_new, y)
                break

            elif d2y > (s1 + s2) * d2x:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    elif v == y:
        """
        The same algorithm as in the case `v == x` (above).

        before)

            x -{reshape}- y

        after)

                       +- x_0 -{reshape}- y_0
            x -{split}-+
                       +- x_1 -{reshape}- y_1
        """
        y_0, y_1 = v_pair
        d2y = mul(y.shape[y.order.axes_dict[axis]:])
        d2x = 1
        for axis_x in reversed(x.order.axes):
            d2x *= x.shape_dict[axis_x]

            if d2x == d2y:
                x_0, x_1 = SplitAxis(None, axis=axis_x,
                                     sections=[x.shape_dict[axis_x] * s1 // (s1 + s2)])(x)
                OptimizeRule.replace_variable(graph, x_0.reshape_like(y_0), y_0)
                OptimizeRule.replace_variable(graph, x_1.reshape_like(y_1), y_1)
                break

            elif d2y > (s1 + s2) * d2x:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError

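# Illustrative check (not part of the optimizer): a NumPy sketch, with hypothetical shapes, of the
# invariant _split_reshape above relies on. When the trailing block of x that contains the split
# axis has the same total size as a trailing block of y, reshaping the two halves separately and
# concatenating along the matched y-axis equals reshaping the whole tensor.
import numpy as np

x = np.arange(2 * 6 * 4).reshape(2, 6, 4)        # split target axis has size 6
x_0, x_1 = x[:, :3, :], x[:, 3:, :]              # the two halves (s1 = s2 = 3)

y = x.reshape(2, 24)                             # original reshape; trailing block 6 * 4 == 24
y_0, y_1 = x_0.reshape(2, 12), x_1.reshape(2, 12)
assert np.array_equal(np.concatenate([y_0, y_1], axis=1), y)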