def test_unify_chain():
    a1 = Axis()
    a2 = Axis()
    a3 = Axis()
    a1.unify(a2)
    a1.unify(a3)
    assert a2 == a3
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for op in traverse.filter_nodes(traverse.listup_operators(graph), Convolution2D):  # type: Convolution2D
        x = op.inputs["x"]
        w = op.inputs["w"]
        y = op.outputs["y"]
        flag_changed = True
        op.remove_all()

        a_filter, a_kh, a_kw = Axis(), Axis(), Axis()
        w, = ReinterpretAxis(None, in_order=OrderNHWC,
                             out_order=Order([Axis.C, a_kh, a_kw, a_filter]))(w)

        if op.WH == 1 and op.WW == 1 and op.stride == (1, 1) and op.padding == (0, 0):
            # Projection
            col, = ReinterpretAxis(None, in_order=OrderNHWC,
                                   out_order=Order([Axis.N, Axis.H, Axis.W, a_filter]))(x)
            new_y, = Tensordot(None, [[a_filter], [a_kh, a_kw, a_filter]])(col, w)

        elif op.WH == x.shape_dict[Axis.H] and op.WW == x.shape_dict[Axis.W] and op.padding == (0, 0):
            # Global convolution
            col, = ReinterpretAxis(None, in_order=OrderNHWC,
                                   out_order=Order([Axis.N, a_kh, a_kw, a_filter]))(x)
            new_y, = Tensordot(None, [[a_kh, a_kw, a_filter], [a_kh, a_kw, a_filter]])(col, w)

        else:
            # General convolution
            col, = Im2Col(None, ksize=op.ksize, stride=op.stride,
                          padding=op.padding, dilation_rate=op.dilation_rate)(x)
            col, = ReinterpretAxis(None, in_order=OrderNHWC,
                                   out_order=Order([Axis.N, Axis.H, Axis.W, a_filter]))(col)
            new_y, = Tensordot(None, [[a_filter], [a_kh, a_kw, a_filter]])(col, w)

        new_y = new_y.transpose(y.order)
        OptimizeRule.replace_variable(graph, new_y, y)

    return graph, flag_changed
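# Worked dispatch example for the rule above: an effective window of 1x1 with unit
# stride and zero padding is handled as a projection, a window equal to the spatial
# input size as a global convolution, and anything else via Im2Col. `pick_branch` is
# a purely illustrative helper, not part of the optimizer; the sizes are made up.
def pick_branch(WH, WW, stride, padding, input_h, input_w):
    if WH == 1 and WW == 1 and stride == (1, 1) and padding == (0, 0):
        return "projection"
    if WH == input_h and WW == input_w and padding == (0, 0):
        return "global"
    return "general"


assert pick_branch(1, 1, (1, 1), (0, 0), 14, 14) == "projection"
assert pick_branch(14, 14, (1, 1), (0, 0), 14, 14) == "global"
assert pick_branch(3, 3, (2, 2), (1, 1), 14, 14) == "general"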
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for op in traverse.filter_nodes(traverse.listup_operators(graph), Deconvolution2D):  # type: Deconvolution2D
        x = op.inputs["x"]
        w = op.inputs["w"]
        y = op.outputs["y"]
        flag_changed = True
        op.remove_all()

        a_filter, a_kh, a_kw = Axis(), Axis(), Axis()
        w, = ReinterpretAxis(None, in_order=OrderNHWC,
                             out_order=Order([Axis.C, a_kh, a_kw, a_filter]))(w)
        x, = ReinterpretAxis(None, in_order=OrderNHWC,
                             out_order=Order([Axis.N, Axis.H, Axis.W, a_filter]))(x)

        col, = Tensordot(None, axes=a_filter)(x, w)
        col = col.transpose(Order([Axis.N, Axis.H, Axis.W, a_kh, a_kw, Axis.C]))
        col = col.reshape(shape=[*col.shape[0:3], mul(col.shape[3:6])], order=OrderNHWC)

        new_y, = Col2Im(None, ksize=op.ksize, stride=op.stride, padding=op.padding)(col)
        OptimizeRule.replace_variable(graph, new_y.transpose_like(y), y)

    return graph, flag_changed
def test_unify_deep_chain_different_name_axes():
    a1 = Axis(name="A")
    a2 = Axis()
    a3 = Axis(name="B")
    a4 = Axis()
    a1.unify(a2)
    a3.unify(a4)
    a2.unify(a4)
def test_unify_deep_chain():
    a1 = Axis()
    a2 = Axis()
    a3 = Axis()
    a4 = Axis()
    a1.unify(a2)
    a3.unify(a4)
    a1.unify(a3)
    assert a2 == a4
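# The unify tests above assume that Axis.unify links two axes so that they later
# compare equal, and that the link is transitive across chains. The snippet below is
# NOT the webdnn implementation; it is a minimal, self-contained union-find sketch of
# that assumed behaviour, using a hypothetical `ToyAxis` class for illustration only.
# How conflicting names are rejected is an assumption of this sketch, not something
# the tests above show.
class ToyAxis:
    def __init__(self, name=None):
        self.name = name
        self._parent = None  # root of the unification chain

    def _root(self):
        axis = self
        while axis._parent is not None:
            axis = axis._parent
        return axis

    def unify(self, other: "ToyAxis"):
        r1, r2 = self._root(), other._root()
        if r1 is r2:
            return  # already unified (covers the same-axis / already-resolved cases)
        if r1.name is not None and r2.name is not None and r1.name != r2.name:
            raise TypeError(f"Axes with different names cannot be unified: {r1.name}, {r2.name}")
        if r1.name is None and r2.name is not None:
            r1, r2 = r2, r1  # keep the named root when exactly one side is named
        r2._parent = r1

    def __eq__(self, other):
        return isinstance(other, ToyAxis) and self._root() is other._root()


# mirrors test_unify_deep_chain: a2 and a4 become equal through a1-a2, a3-a4, a1-a3
a1, a2, a3, a4 = ToyAxis(), ToyAxis(), ToyAxis(), ToyAxis()
a1.unify(a2)
a3.unify(a4)
a1.unify(a3)
assert a2 == a4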
def __getitem__(self, slices) -> "Variable":
    slices = list(slices) if isinstance(slices, Sequence) else [slices]

    if Ellipsis in slices:
        ellipsis_position = slices.index(Ellipsis)
        slices.remove(Ellipsis)
    else:
        ellipsis_position = len(slices)

    num_new_axis = slices.count(None)
    while len(slices) - num_new_axis < self.ndim:
        slices.insert(ellipsis_position, slice(None))

    x_axis_index = 0
    indices = AxisKeyDict()
    for index in slices:
        if isinstance(index, (slice, int)):
            indices[self.order.axes[x_axis_index]] = index
            x_axis_index += 1

        elif index is None:
            indices[Axis()] = None

        else:
            raise TypeError("Variable indices must be sequence of integers, slices, ellipsis, or None")

    return webdnn.graph.operators.slice.Slice(None, indices=indices)(self)[0]
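# Hedged usage sketch for the __getitem__ above, based only on the logic shown: an int
# or slice index is assigned to the corresponding existing axis, `None` introduces a
# fresh anonymous Axis, and `...` pads the remaining positions with full slices before
# a Slice operator is applied. The import paths follow the webdnn layout referenced in
# the snippet (webdnn.graph.*); treat them and the exact shapes as assumptions.
from webdnn.graph.axis import Axis
from webdnn.graph.order import OrderNHWC
from webdnn.graph.variable import Variable

v = Variable([2, 3, 4, 5], OrderNHWC)

# keep N as-is, take H rows 0..1, leave the remaining axes untouched
v_sliced = v[:, 0:2, ...]

# insert a new anonymous axis at the front (a `None` index maps to Axis() -> None)
v_expanded = v[None, ...]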
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for op in traverse.filter_nodes(traverse.listup_operators(graph), Linear):
        x = op.inputs["x"]
        w = op.inputs["w"]
        y = op.outputs["y"]
        flag_changed = True
        op.remove_all()

        a_filter = Axis()
        if x.ndim == 2:
            w, = ReinterpretAxis(None, in_order=OrderNC,
                                 out_order=Order([Axis.C, a_filter]))(w)
            new_y, = Tensordot(None, axes=[Axis.C, a_filter])(x, w)

        elif x.ndim == 4:
            w, = ReinterpretAxis(None, in_order=OrderNHWC,
                                 out_order=Order([Axis.C, Axis.H, Axis.W, a_filter]))(w)
            new_y, = Tensordot(None, axes=[[Axis.H, Axis.W, Axis.C],
                                           [Axis.H, Axis.W, a_filter]])(x, w)

        else:
            raise NotImplementedError

        OptimizeRule.replace_variable(graph, new_y.transpose_like(y), y)

    return graph, flag_changed
def __init__(self, axes: Sequence[Union[Axis, None]]):
    axes = tuple(Axis() if a is None else a for a in axes)

    for a1, a2 in itertools.permutations(axes, 2):
        assert a1 != a2, f"""
[Order] Axes are duplicated:
    (axes) = {axes}
"""

    self._axes = axes
def _convert_repeat_vector(converter: KerasConverter, k_op: "keras.layers.RepeatVector"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    new_axis = Axis()

    multiplier = AxisKeyDict(x.order.axes, [1, 1])
    multiplier[new_axis] = k_op.n

    x = x.reshape(shape=(x.shape[0], 1, x.shape[1]),
                  order=Order([x.order.axes[0], new_axis, x.order.axes[1]]))
    y, = Tile(None, multiplier=multiplier)(x)
    converter.set_variable(converter.get_output_tensor(k_op)[0], y)
def _convert_unsqueeze(converter: ONNXConverter, onnx_op: INodeProto):
    x = converter.get_variable(onnx_op.input[0])

    if isinstance(x, ConstantVariable):
        data = np.expand_dims(x.data, 0)
        y = ConstantVariable(data, Order([None] * len(data.shape)))
    else:
        y = x.expand_dims(Axis())

    converter.set_variable(onnx_op.output[0], y)
def _convert_squeeze(converter: ONNXConverter, onnx_op: INodeProto):
    x = converter.get_variable(onnx_op.input[0])
    attrs = attribute_dict(onnx_op)

    if isinstance(x, ConstantVariable):
        # generate actual data as constant
        new_axes = list(x.order.axes)
        new_data = x.data.copy()
        for i in attrs["axes"].ints:
            new_axes.insert(i, Axis())
            new_data = np.expand_dims(new_data, axis=i)
        y = ConstantVariable(new_data, Order(new_axes))

    else:
        y = x
        for i in attrs["axes"].ints:
            y = y.expand_dims(Axis(), i)

    converter.set_variable(onnx_op.output[0], y)
def test_combine_axes_create_new_axis():
    new_axis = Axis()
    v1 = Variable([2, 3, 4, 5], OrderNHWC)
    v2 = v1.combine_axes([Axis.W, Axis.H], new_axis)

    assert v2.order == Order([Axis.N, new_axis, Axis.C])
    assert v2.shape_dict[Axis.N] == 2
    assert v2.shape_dict[new_axis] == 12
    assert v2.shape_dict[Axis.C] == 5
    assert isinstance(v2.output_from, Reshape)
    assert v2.output_from.in_order == Order([Axis.N, Axis.W, Axis.H, Axis.C])
    assert v2.output_from.out_order == Order([Axis.N, new_axis, Axis.C])
    assert v2.output_from.inputs["x"] == v1
def _convert_tile(converter: ChainerConverter, c_op: "chainer.functions.Tile"):
    x = converter.get_variable(c_op.inputs[0])
    reps = c_op.reps

    if x.ndim > len(reps):
        reps = (1,) * (x.ndim - len(reps)) + reps
    else:
        while x.ndim < len(c_op.reps):
            x = x.expand_dims(Axis(), 0)

    y, = Tile(None, AxisKeyDict(x.order.axes, reps))(x)
    converter.set_variable(c_op.outputs[0](), y)
def matmul_handler(converter: TensorFlowConverter, tf_op: "tf.Operation"):
    a = converter.get_variable(tf_op.inputs[0])
    b = converter.get_variable(tf_op.inputs[1])
    transposed_a = tf_op.get_attr("transpose_a")
    transposed_b = tf_op.get_attr("transpose_b")

    if a.ndim > 2 or b.ndim > 2:
        raise NotImplementedError("[TensorFlowConverter] Currently, MatMul is supported only 2D * 2D case.")

    c_axes = []
    if transposed_a:
        c_axes.append(a.order.axes[-1])
        if a.order != OrderCN:
            a = a.reinterpret_axes(OrderCN)
    else:
        c_axes.append(a.order.axes[-2])
        if a.order != OrderNC:
            a = a.reinterpret_axes(OrderNC)

    if transposed_b:
        c_axes.append(Axis())
        if b.order != OrderNC:
            b = b.reinterpret_axes(OrderNC)
    else:
        c_axes.append(Axis())
        if b.order != OrderCN:
            b = b.reinterpret_axes(OrderCN)

    c_normalized, = Linear(None)(a, b)
    c = c_normalized.reinterpret_axes(Order(c_axes))

    converter.set_variable(tf_op.outputs[0], c)
def _convert_global_max_pool(converter: ONNXConverter, onnx_op: INodeProto):
    x = converter.get_variable(onnx_op.input[0])
    if x.ndim == 4:
        x.order.unify(OrderNCHW)

    reduction_size = mul(x.shape[2:])
    reduction_axis = Axis()

    x = x.reshape([x.shape[0], x.shape[1], reduction_size],
                  Order([x.order.axes[0], x.order.axes[1], reduction_axis]))
    y, = Max(None, axis=reduction_axis)(x)

    converter.set_variable(onnx_op.output[0], y)
def expand_dims_handler(converter: TensorFlowConverter, tf_op: "tf.Operation"):
    x = converter.get_variable(tf_op.inputs[0])
    dim = converter.get_variable(tf_op.inputs[1])
    if not isinstance(dim, ConstantVariable):
        raise NotImplementedError("[TensorFlowConverter] Operator 'ExpandDims' with dynamic dimension is not supported.")

    dim = dim.data.astype(np.int32).flatten()[0]

    new_shape = list(x.shape)
    new_shape.insert(dim, 1)

    new_axes = list(x.order.axes)
    new_axes.insert(dim, Axis())

    converter.set_variable(tf_op.outputs[0], x.reshape(order=Order(new_axes), shape=new_shape))
def __getitem__(self, slices) -> "Variable":
    slices = list(slices) if isinstance(slices, Sequence) else [slices]

    if Ellipsis in slices:
        ellipsis_position = slices.index(Ellipsis)
        slices.remove(Ellipsis)
    else:
        ellipsis_position = len(slices)

    while len(slices) < self.ndim:
        slices.insert(ellipsis_position, slice(None))

    x_axis_index = 0
    indices = AxisKeyDict()
    for index in slices:
        if isinstance(index, (slice, int)):
            indices[self.order.axes[x_axis_index]] = index
            x_axis_index += 1

        elif index is None:
            indices[Axis()] = None

    return webdnn.graph.operators.slice.Slice(None, indices=indices)(self)[0]
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for op in traverse.filter_nodes(traverse.listup_operators(graph), Deconvolution2D):  # type: Deconvolution2D
        x = op.inputs["x"]
        w = op.inputs["w"]
        y = op.outputs["y"]
        flag_changed = True
        op.remove_all()

        a_filter = Axis()
        w, = ReinterpretAxis(None,
                             in_order=Order([Axis.N, Axis.KH, Axis.KW, Axis.C]),
                             out_order=Order([Axis.C, Axis.KH, Axis.KW, a_filter]))(w)

        if op.KH == 1 and op.KW == 1 and op.stride == (1, 1) and op.padding == (0, 0):
            # Projection
            w = w.transpose(Order([Axis.C, Axis.KH, Axis.KW, a_filter]))
            w = w.reshape([w.shape_dict[Axis.C], w.shape_dict[a_filter]], Order([Axis.C, a_filter]))

            new_y, = Tensordot(None, [Axis.C, a_filter])(x, w)

        else:
            # General deconvolution
            w = w.transpose(Order([a_filter, Axis.KH, Axis.KW, Axis.C]))

            col, = Tensordot(None, axes=[Axis.C, a_filter])(x, w)
            new_y, = Col2Im(None, ksize=op.ksize, stride=op.stride, padding=op.padding)(col)

        OptimizeRule.replace_variable(graph, new_y.transpose_like(y), y)

    return graph, flag_changed
def _split_tensordot(graph: Graph, op: Tensordot, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]

    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]
    axes_M = tuple(filter(lambda a: a not in op.axes[0], A.order.axes))
    axes_N = tuple(filter(lambda a: a not in op.axes[1], B.order.axes))
    axes_K_A, axes_K_B = op.axes

    K = mul(A.shape_dict[a] for a in axes_K_A)
    M = A.size // K
    N = B.size // K

    shape_M = [A.shape_dict[a] for a in axes_M]
    shape_N = [B.shape_dict[a] for a in axes_N]

    op.remove_all()

    if v == A:
        A1, A2 = v_pair

        if axis in axes_K_A:
            split_axis_A = axis

            if (B.shape_dict[axes_K_B[0]] * s1) % (s1 + s2) == 0:
                split_axis_B = axes_K_B[0]

            else:
                # Factorize B's axes consisting to K into A's corresponding axes
                B = B.transpose(Order(axes_N + axes_K_B))
                B = B.reshape(order=Order((Axis(),) + axes_K_A),
                              shape=[N] + [A.shape_dict[a] for a in axes_K_A])
                split_axis_B = split_axis_A
                axes_K_B = axes_K_A

            B1, B2 = SplitAxis(None, axis=split_axis_B,
                               sections=[(B.shape_dict[split_axis_B] * s1) // (s1 + s2)])(B)

            C1, = Tensordot(None, [axes_K_A, axes_K_B])(A1, B1)
            C2, = Tensordot(None, [axes_K_A, axes_K_B])(A2, B2)
            OptimizeRule.replace_variable(graph,
                                          (C1 + C2).reshape(shape_M + shape_N, Order(axes_M + axes_N)).transpose_like(C),
                                          C)

        else:
            C1, = Tensordot(None, op.axes)(A1, B)
            C2, = Tensordot(None, op.axes)(A2, B)

            for a1, a2 in zip(C1.order.axes, C2.order.axes):
                if a1 == a2 == axis:
                    continue
                a1.unify(a2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == B:
        B1, B2 = v_pair

        if axis in axes_K_B:
            split_axis_B = axis

            if (A.shape_dict[axes_K_A[0]] * s1) % (s1 + s2) == 0:
                split_axis_A = axes_K_A[0]

            else:
                # Factorize A's axes consisting to K into B's corresponding axes
                A = A.transpose(Order(axes_M + axes_K_A))
                A = A.reshape(order=Order((Axis(),) + axes_K_B),
                              shape=[M] + [B.shape_dict[a] for a in axes_K_B])
                split_axis_A = split_axis_B
                axes_K_A = axes_K_B

            A1, A2 = SplitAxis(None, axis=split_axis_A,
                               sections=[(A.shape_dict[split_axis_A] * s1) // (s1 + s2)])(A)

            C1, = Tensordot(None, [axes_K_A, axes_K_B])(A1, B1)
            C2, = Tensordot(None, [axes_K_A, axes_K_B])(A2, B2)
            OptimizeRule.replace_variable(graph,
                                          (C1 + C2).reshape(shape_M + shape_N, Order(axes_M + axes_N)).transpose_like(C),
                                          C)

        else:
            C1, = Tensordot(None, op.axes)(A, B1)
            C2, = Tensordot(None, op.axes)(A, B2)

            for a1, a2 in zip(C1.order.axes, C2.order.axes):
                if a1 == a2 == axis:
                    continue
                a1.unify(a2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == C:
        """
        before)

            C[M, N] = A[M, K] @ B[K, N]

        after) In case `axis` is in `N`,

            C[M, N1] = Concat(A[M, K] @ B1[K, N1])
            C[M, N2] = Concat(A[M, K] @ B2[K, N2])
        """
        raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError
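# Worked arithmetic example of the proportional split used in _split_tensordot above
# (values are made up for illustration). When A is split along a reduced axis into
# sections of size s1 and s2, B must be split along its first reduced axis at the same
# ratio, which is only possible without reshaping when B_k * s1 is divisible by s1 + s2.
s1, s2 = 2, 6          # sections of A along the split axis (total 8)
B_k = 16               # size of B's first reduced axis, B.shape_dict[axes_K_B[0]]

assert (B_k * s1) % (s1 + s2) == 0           # divisibility check from the rule above
section = (B_k * s1) // (s1 + s2)            # SplitAxis section for B
assert section == 4                          # B is split into [4, 12], matching 2:6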
def test_unify_same_axis():
    a1 = Axis()
    a1.unify(a1)
    assert a1 == a1
def _convert_separable_conv2d(converter: KerasConverter, k_op: "keras.layers.SeparableConv2D"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])
    check_data_format(x, k_op.data_format)

    axis_c_in = Axis.C
    axis_c_out = Axis()
    axis_depth_multiplier = Axis()

    w_depthwise = converter.convert_to_constant_variable(
        k_op.depthwise_kernel,
        Order([Axis.KH, Axis.KW, axis_c_in, axis_depth_multiplier]))

    w_pointwise = converter.convert_to_constant_variable(
        k_op.pointwise_kernel,
        Order([Axis.KH, Axis.KW, axis_c_in, axis_c_out]))
    w_pointwise = w_pointwise.reshape(
        shape=[x.shape_dict[axis_c_in], k_op.depth_multiplier, w_pointwise.shape_dict[axis_c_out]],
        order=Order([axis_c_in, axis_depth_multiplier, axis_c_out]))

    ksize = tuple(k_op.kernel_size)
    stride = tuple(k_op.strides)
    dilation_rate = tuple(k_op.dilation_rate)
    padding = (parse_padding(k_op.padding, ksize[0], dilation_rate[0]),
               parse_padding(k_op.padding, ksize[1], dilation_rate[1]))
    if any(p[0] != p[1] for p in padding):
        raise NotImplementedError("[KerasConverter] \"Different size padding\" is not supported yet")
    padding = tuple(p[0] for p in padding)

    h, = Im2Col(None, ksize=ksize, stride=stride, padding=padding, dilation_rate=dilation_rate)(x)

    # TODO: Support depth-wise convolution natively
    # Currently, depth-wise convolution is not supported natively,
    # and is emulated by a composition of small convolution operations.
    ys = []
    for i in range(h.shape_dict[axis_c_in]):
        # 1. Depthwise convolution
        #
        #   Ideal                              | Current implementation
        # -------------------------------------+----------------------------------------------------
        #   h.axes=[N, H, W, KH, KW, C_in]     | h_sub.axes=[N, H, W, KH, KW]
        #   w.axes=[KH, KW, C_in, DM]          | w_sub.axes=[KH, KW, DM]
        #   g.axes=[N, H, W, C_in, DM]         | g_sub.axes=[N, H, W, DM]
        h_sub, = Slice(None, indices=AxisKeyDict(
            h.order.axes,
            [i if a == axis_c_in else slice(None) for a in h.order.axes]))(h)
        w_depthwise_sub = w_depthwise[:, :, i, :]
        g_sub, = Tensordot(None, axes=((Axis.KH, Axis.KW), (Axis.KH, Axis.KW)))(h_sub, w_depthwise_sub)

        # 2. Pointwise (projection) convolution
        #
        #   Ideal                              | Current implementation
        # -------------------------------------+----------------------------------------------------
        #   g.axes=[N, H, W, C_in, DM]         | g_sub.axes=[N, H, W, DM]
        #   w.axes=[DM, Cin, C_out]            | w_sub.axes=[DM, C_out]
        #   y.axes=[N, H, W, C_out]            | y_sub.axes=[N, H, W, C_out]
        w_pointwise_sub = w_pointwise[i, :, :]
        y_sub, = Tensordot(None, axes=((axis_depth_multiplier,), (axis_depth_multiplier,)))(g_sub, w_pointwise_sub)

        ys.append(y_sub)

    # Sum up all sub convolution results to one
    while len(ys) > 1:
        ys.append(ys.pop(0) + ys.pop(0))

    y = ys[0]

    # reinterpret axis "C_out" as C
    axes = list(y.order.axes)
    i = axes.index(axis_c_out)
    axes.pop(i)
    axes.insert(i, Axis.C)
    y = y.reinterpret_axes(Order(axes))

    if k_op.data_format == "channels_last":
        y = y.transpose(OrderNHWC)

    elif k_op.data_format == "channels_first":
        y = y.transpose(OrderNCHW)

    else:
        raise NotImplementedError(f"[KerasConverter] Unknown data format: {k_op.data_format}")

    if k_op.use_bias:
        b = converter.convert_to_constant_variable(k_op.bias, OrderC)
        y = y + b

    y = do_activation(k_op.activation, y)
    converter.set_variable(converter.get_output_tensor(k_op)[0], y)
def test_unify_resolved_axes():
    a1 = Axis()
    a2 = Axis()
    a1.unify(a2)
    a1.unify(a2)
    assert a1 == a2
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False

    matches = traverse.search_sub_structure(graph, [Sgemm, Variable, ElementwiseMul])
    while len(matches) > 0:
        match = matches.pop()
        sgemm = match[0]  # type: Sgemm
        elementwise_mul = match[2]  # type: ElementwiseMul

        out_order = sgemm.parameters["out_order"]
        out_shape = sgemm.parameters["out_shape"]

        axis_k = Axis('AxisK')

        if not isinstance(sgemm.inputs["A"], ConstantVariable) and not isinstance(sgemm.inputs["B"], ConstantVariable):
            # neither x nor w1 is constant
            continue

        elif isinstance(sgemm.inputs["A"], ConstantVariable):
            w1 = sgemm.inputs["A"]  # type: ConstantVariable

            if sgemm.transpose_A:
                # w1.shape = (M, K)
                shape = []
                axes = []
                for axis, size in zip(out_order.axes, out_shape):
                    shape.append(size)
                    axes.append(axis)
                    if mul(shape) >= sgemm.M:
                        break

                if mul(shape) != sgemm.M:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes + [axis_k])
                w1_virtual_shape = shape + [sgemm.K]

            else:
                # w1.shape = (K, M)
                shape = [sgemm.K]
                axes = [axis_k]
                for axis, size in zip(out_order.axes, out_shape):
                    shape.append(size)
                    axes.append(axis)
                    if mul(shape) >= w1.size:
                        break

                if mul(shape) != w1.size:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes)
                w1_virtual_shape = shape

        else:
            w1 = sgemm.inputs["B"]  # type: ConstantVariable

            if sgemm.transpose_B:
                # w1.shape = (K, N)
                shape = []
                axes = []
                for axis, size in reversed(list(zip(out_order.axes, out_shape))):
                    shape.insert(0, size)
                    axes.insert(0, axis)
                    if mul(shape) >= sgemm.N:
                        break

                if mul(shape) != sgemm.N:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order([axis_k] + axes)
                w1_virtual_shape = [sgemm.K] + shape

            else:
                # w1.shape = (N, K)
                shape = [sgemm.K]
                axes = [axis_k]
                for axis, size in reversed(list(zip(out_order.axes, out_shape))):
                    shape.insert(0, size)
                    axes.insert(0, axis)
                    if mul(shape) >= w1.size:
                        break

                if mul(shape) != w1.size:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes)
                w1_virtual_shape = shape

        h = sgemm.outputs["C"]  # type: Variable

        x0 = elementwise_mul.inputs["x0"]
        x1 = elementwise_mul.inputs["x1"]
        if h == x1:
            if not isinstance(x0, ConstantVariable):
                # w2 is not constant
                continue
            w2 = x0  # type: ConstantVariable

        else:
            if not isinstance(x1, ConstantVariable):
                # w2 is not constant
                continue
            w2 = x1  # type: ConstantVariable

        y = elementwise_mul.outputs["y"]  # type: Variable

        if not all(axis in w1_virtual_order.axes for axis in w2.order.axes):
            # w2's axes are derived from both w1 and x
            continue

        elementwise_mul.remove_all()
        y_dummy, = Transpose(None)(h)
        y_dummy.change_order(y.order)
        y_dummy.replace(y)

        w2.change_order(w1_virtual_order)
        w_new = ConstantVariable(w1.data.reshape(w1_virtual_shape), w1_virtual_order) * w2  # type: ConstantVariable
        w1.data = w_new.data.reshape(w1.shape)

        flag_changed = True
        matches = traverse.search_sub_structure(graph, [Sgemm, Variable, ElementwiseMul])

    return graph, flag_changed
def test_unify_same_name_axes():
    a1 = Axis(name="A")
    a2 = Axis(name="A")
    a1.unify(a2)
    assert a1 == a2
def test_equal():
    a1 = Axis()
    assert a1 == a1
def test_unify_different_name_axes():
    a1 = Axis(name="A")
    a2 = Axis(name="B")
    a1.unify(a2)
def _simplify_orders(variables: List[Variable]) -> Tuple[Dict[Variable, Order], Dict[Variable, AxisKeyDict[int]]]:
    """
    Simplify variable orders based on the following rules:

    - Axis whose size is :code:`1` will be removed.

    - If axis :code:`A` and :code:`B` are adjacent in all variables which have both axis :code:`A` and axis :code:`B`,
      :code:`A` and :code:`B` will be merged.
        - For example, :code:`OrderABC` and :code:`OrderCAB` can be simplified as :code:`OrderXC` and :code:`OrderCX`
        - In this case, the size of axis :code:`X` is calculated as :code:`(size of axis A) * (size of axis B)`

    .. code-block:: text

        ex)
            x0.order=NHWC,   simplify    x0.order=X
             y.order=NHWC   ------------>  y.order=X

        ex)
            x0.order=C,      simplify    x0.order=C
            x1.order=NHWC   ------------> x1.order=XC
             y.order=NHWC                  y.order=XC

        ex)
            x0.order=C,      simplify    x0.order=C
            x1.order=HW     ------------> x1.order=X
             y.order=NHWC                  y.order=NXC

    Returns:
        (tuple of dicts) simplified orders and shape
    """
    orders = {}  # type: Dict[Variable, Order]
    shape_dicts = {}  # type: Dict[Variable, AxisKeyDict[int]]
    axis_scalar = Axis("Scalar")

    # remove all axes whose size is `1`.
    for v in variables:
        new_axes = [a for a in v.order.axes if v.shape_dict[a] != 1]
        orders[v] = Order(new_axes)
        shape_dicts[v] = AxisKeyDict(new_axes, [v.shape_dict[a] for a in new_axes])

        if len(new_axes) == 0 and v.size == 1:
            orders[v] = Order([axis_scalar])
            shape_dicts[v] = AxisKeyDict([axis_scalar], [1])

    # list up all axes and variables which have the axis
    var_dict = AxisKeyDict[Set[Variable]]()
    for v in variables:
        for axis in orders[v].axes:
            if axis in var_dict:
                var_dict[axis].add(v)
            else:
                var_dict[axis] = {v}

    # find pair of axes which can be merged
    counter = 0
    flag_continue = True
    while flag_continue:
        flag_continue = False

        for axis1, vars1 in list(var_dict.items()):
            for axis2, vars2 in list(var_dict.items()):
                if axis1 == axis2:
                    continue

                if vars1 != vars2 or any(orders[v].axes_dict[axis1] + 1 != orders[v].axes_dict[axis2] for v in vars1):
                    # `axis1` and `axis2` must be adjacent.
                    continue

                # merge `axis1` and `axis2` into `axis_new`
                axis_new = Axis(f"X{counter}")
                counter += 1

                for v in vars1:
                    shape_dict = shape_dicts[v]
                    shape_dict[axis_new] = shape_dict[axis1] * shape_dict[axis2]
                    del shape_dict[axis1]
                    del shape_dict[axis2]

                    order = orders[v]
                    orders[v] = Order(order.axes[:order.axes_dict[axis1]] + (axis_new,) + order.axes[order.axes_dict[axis2] + 1:])

                var_dict[axis_new] = vars1
                del var_dict[axis1]
                del var_dict[axis2]

                flag_continue = True
                break

            if flag_continue:
                break

    return orders, shape_dicts
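# Hedged usage sketch of _simplify_orders, mirroring the first docstring example above:
# two NHWC variables collapse to a single merged axis. The import paths follow the
# webdnn layout used elsewhere in these snippets, and the helper is assumed to be in
# scope here; both are assumptions of this sketch.
from webdnn.graph.order import OrderNHWC
from webdnn.graph.variable import Variable

x0 = Variable([2, 3, 4, 5], OrderNHWC)
y = Variable([2, 3, 4, 5], OrderNHWC)

orders, shape_dicts = _simplify_orders([x0, y])

# N, H, W and C are pairwise adjacent in every variable, so they merge into one axis
assert len(orders[x0].axes) == 1 and len(orders[y].axes) == 1
merged_axis, = orders[x0].axes
assert shape_dicts[x0][merged_axis] == 2 * 3 * 4 * 5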
def _convert_expand_dims(converter: ChainerConverter, c_op: "chainer.functions.ExpandDims"):
    x = converter.get_variable(c_op.inputs[0])
    y = x.expand_dims(Axis(), c_op.axis)
    converter.set_variable(c_op.outputs[0](), y)
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for sgemm in traverse.filter_nodes(traverse.listup_operators(graph), Sgemm):  # type: Sgemm
        A = sgemm.inputs["A"]
        B = sgemm.inputs["B"]
        M = sgemm.M
        N = sgemm.N
        K = sgemm.K
        transpose_A = sgemm.transpose_A
        transpose_B = sgemm.transpose_B

        if all([
            self.optimize_channel_mode,
            K % 4 == 0,
            isinstance(A, ConstantVariable) or transpose_A == True,
            isinstance(B, ConstantVariable) or transpose_B == False
        ]):
            if transpose_A != True:
                assert isinstance(A, ConstantVariable)
                flag_changed = True

                old_A = A
                A = ConstantVariable(A.data.reshape([K, M]).transpose(), Order([Axis(None), Axis(None)]))
                ChannelMode.set(A, ChannelMode.get(old_A))
                sgemm.replace_input(old_A, A, with_assert=False)
                sgemm.parameters["transpose_A"] = transpose_A = True

            if transpose_B != False:
                assert isinstance(B, ConstantVariable)
                flag_changed = True

                old_B = B
                B = ConstantVariable(B.data.reshape([K, N]).transpose(), Order([Axis(None), Axis(None)]))
                ChannelMode.set(B, ChannelMode.get(old_B))
                sgemm.replace_input(old_B, B, with_assert=False)
                sgemm.parameters["transpose_B"] = transpose_B = False

            if ChannelMode.get(A) != ChannelModeEnum.RGBA:
                flag_changed = True
                ChannelMode.set(A, ChannelModeEnum.RGBA)

            if ChannelMode.get(B) != ChannelModeEnum.RGBA:
                flag_changed = True
                ChannelMode.set(B, ChannelModeEnum.RGBA)

            texture_shape_A = [M, K // 4] if transpose_A else [K // 4, M]
            texture_shape_B = [K // 4, N] if transpose_B else [N, K // 4]

        else:
            if ChannelMode.get(A) != ChannelModeEnum.R:
                flag_changed = True
                ChannelMode.set(A, ChannelModeEnum.R)

            if ChannelMode.get(B) != ChannelModeEnum.R:
                flag_changed = True
                ChannelMode.set(B, ChannelModeEnum.R)

            texture_shape_A = [M, K] if transpose_A else [K, M]
            texture_shape_B = [K, N] if transpose_B else [N, K]

        if TextureShape.get(A) != texture_shape_A:
            flag_changed = True
            TextureShape.set(A, height=texture_shape_A[0], width=texture_shape_A[1])

        if TextureShape.get(B) != texture_shape_B:
            flag_changed = True
            TextureShape.set(B, height=texture_shape_B[0], width=texture_shape_B[1])

    if flag_changed:
        graph, _ = ConstantFolding().optimize(graph)

    return graph, flag_changed
def check_broadcast_constraints(a: Variable, b: Variable, axis: Optional[int] = None):
    """check_broadcast_constraints(a, b, axis=None)

    Check the following constraints corresponding to broadcasting:

    - each axes pair must be same axis.
    - shape must be valid for broadcasting.

    Args:
        a: Variable
        b: Variable
        axis: broadcast start position

            a.shape=(2, 3, 4, 5), b.shape=(5,)    --> If axis=3 (or None), broadcasting is possible.
            a.shape=(2, 3, 4, 5), b.shape=(3, 4)  --> If axis=1, broadcasting is possible.

            Note that `axis=None` (or `a.ndim - b.ndim`) is same as numpy-style broadcasting.

    Returns:

    """
    a_shape = list(a.shape)
    b_shape = list(b.shape)
    a_axes = list(a.order.axes)
    b_axes = list(b.order.axes)
    a_ndim = a.ndim
    b_ndim = b.ndim

    if axis is None:
        axis = a_ndim - b_ndim

    for _ in range(axis):
        b_shape = [1] + b_shape
        b_axes = [Axis()] + b_axes
        b_ndim += 1

    while a_ndim < b_ndim:
        a_shape = a_shape + [1]
        a_axes = a_axes + [Axis()]
        a_ndim += 1

    while b_ndim < a_ndim:
        b_shape = b_shape + [1]
        b_axes = b_axes + [Axis()]
        b_ndim += 1

    for i in range(a_ndim):
        if a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1:
            a_axes[i].unify(b_axes[i])

            if (a_shape[i] == 1 and b_shape[i] == 1) or (a_shape[i] != 1 and b_shape[i] != 1):
                # If broadcast does not occur, sizes must be the same
                assert a_shape[i] == b_shape[i], f"""
Shape mismatch: a.shape[{i}] != b.shape[{i}]
    (a.shape) = {a_shape}
    (b.shape) = {b_shape}
"""

        else:
            raise ValueError(f"""Broadcast is failed:
    (a.shape)={a_shape}
    (b.shape)={b_shape}
    (axis)={axis}""")
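# Hedged usage sketch of check_broadcast_constraints, reusing the shapes from its
# docstring. Import paths follow the webdnn layout assumed in the earlier sketches.
from webdnn.graph.axis import Axis
from webdnn.graph.order import Order, OrderNHWC
from webdnn.graph.variable import Variable

a = Variable([2, 3, 4, 5], OrderNHWC)
b = Variable([3, 4], Order([Axis.H, Axis.W]))

# axis=1 aligns b's (3, 4) with a's H and W axes; a mismatching shape would raise
check_broadcast_constraints(a, b, axis=1)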