def test_W_NC():
    template(w_order=Order([Axis.N, Axis.C]))
def rank_handler(converter: TensorFlowConverter, tf_op: "tf.Operation"):
    x = converter.get_variable(tf_op.inputs[0])
    y = ConstantVariable(np.array([x.ndim]), Order([None]))
    converter.set_variable(tf_op.outputs[0], y)
def _convert_separable_conv2d(converter: KerasConverter, k_op: "keras.layers.SeparableConv2D"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])
    check_data_format(x, k_op.data_format)

    axis_c_in = Axis.C
    axis_c_out = Axis()
    axis_depth_multiplier = Axis()

    w_depthwise = converter.convert_to_constant_variable(
        k_op.depthwise_kernel,
        Order([Axis.KH, Axis.KW, axis_c_in, axis_depth_multiplier]))

    w_pointwise = converter.convert_to_constant_variable(
        k_op.pointwise_kernel,
        Order([Axis.KH, Axis.KW, axis_c_in, axis_c_out]))
    w_pointwise = w_pointwise.reshape(
        shape=[x.shape_dict[axis_c_in], k_op.depth_multiplier, w_pointwise.shape_dict[axis_c_out]],
        order=Order([axis_c_in, axis_depth_multiplier, axis_c_out]))

    ksize = tuple(k_op.kernel_size)
    stride = tuple(k_op.strides)
    dilation_rate = tuple(k_op.dilation_rate)
    padding = (parse_padding(k_op.padding, ksize[0], dilation_rate[0]),
               parse_padding(k_op.padding, ksize[1], dilation_rate[1]))
    if any(p[0] != p[1] for p in padding):
        raise NotImplementedError("[KerasConverter] \"Different size padding\" is not supported yet")
    padding = tuple(p[0] for p in padding)

    h, = Im2Col(None, ksize=ksize, stride=stride, padding=padding, dilation_rate=dilation_rate)(x)

    # TODO: Support depth-wise convolution natively
    # Currently, depth-wise convolution is not supported natively, and is emulated by a
    # composition of small convolution operations.
    ys = []
    for i in range(h.shape_dict[axis_c_in]):
        # 1. Depthwise convolution
        #
        # Ideal                              | Current implementation
        # -----------------------------------+----------------------------------
        # h.axes=[N, H, W, KH, KW, C_in]     | h_sub.axes=[N, H, W, KH, KW]
        # w.axes=[KH, KW, C_in, DM]          | w_sub.axes=[KH, KW, DM]
        # g.axes=[N, H, W, C_in, DM]         | g_sub.axes=[N, H, W, DM]
        h_sub, = Slice(
            None,
            indices=AxisKeyDict(h.order.axes,
                                [i if a == axis_c_in else slice(None) for a in h.order.axes]))(h)
        w_depthwise_sub = w_depthwise[:, :, i, :]
        g_sub, = Tensordot(None, axes=((Axis.KH, Axis.KW), (Axis.KH, Axis.KW)))(h_sub, w_depthwise_sub)

        # 2. Pointwise (projection) convolution
        #
        # Ideal                              | Current implementation
        # -----------------------------------+----------------------------------
        # g.axes=[N, H, W, C_in, DM]         | g_sub.axes=[N, H, W, DM]
        # w.axes=[DM, C_in, C_out]           | w_sub.axes=[DM, C_out]
        # y.axes=[N, H, W, C_out]            | y_sub.axes=[N, H, W, C_out]
        w_pointwise_sub = w_pointwise[i, :, :]
        y_sub, = Tensordot(None, axes=((axis_depth_multiplier,), (axis_depth_multiplier,)))(g_sub, w_pointwise_sub)
        ys.append(y_sub)

    # Sum up all sub-convolution results into one
    while len(ys) > 1:
        ys.append(ys.pop(0) + ys.pop(0))
    y = ys[0]

    # Reinterpret axis "C_out" as C
    axes = list(y.order.axes)
    i = axes.index(axis_c_out)
    axes.pop(i)
    axes.insert(i, Axis.C)
    y = y.reinterpret_axes(Order(axes))

    if k_op.data_format == "channels_last":
        y = y.transpose(OrderNHWC)

    elif k_op.data_format == "channels_first":
        y = y.transpose(OrderNCHW)

    else:
        raise NotImplementedError(f"[KerasConverter] Unknown data format: {k_op.data_format}")

    if k_op.use_bias:
        b = converter.convert_to_constant_variable(k_op.bias, OrderC)
        y = y + b

    y = do_activation(k_op.activation, y)
    converter.set_variable(converter.get_output_tensor(k_op)[0], y)
def test():
    x = Variable([2, 3, 4, 5], OrderNHWC)
    y, = Min(None, axis=Axis.C)(x)

    assert y.order == Order([Axis.N, Axis.H, Axis.W])
    assert y.shape == [2, 3, 4]
def convert(self, inputs: List["chainer.Variable"], outputs: List["chainer.Variable"]) -> Graph: """convert(inputs, outputs) Convert chainer computational graph into WebDNN IR. Args: inputs(list of chainer.Variable): input chainer variables outputs(list of chainer.Variable): output chainer variables .. admonition:: example Convert pre-trained ResNet model .. code:: model = chainer.links.model.vision.resnet.ResNet50Layers() # Forward propagation with dummy input to build computational graph x = chainer.Variable(np.empty((1, 3, 224, 224), dtype=np.float32)) y = model(x, layers=["fc6"])["fc6"] graph = ChainerConverter().convert([x], [y]) Returns: (:class:`~webdnn.Graph`): WebDNN Graph """ for v in inputs: if isinstance(v, PlaceholderVariable): n_var = Variable(v.actual_shape, Order([None] * v.ndim)) self.set_variable(to_variable_node(v), n_var) inputs = [to_variable_node(v) for v in inputs] outputs = [to_variable_node(v) for v in outputs] # Convert parameters into constant variable input_set = set(inputs) for node in chainer.computational_graph.build_computational_graph( outputs).nodes: if isinstance(node, T_VARIABLE) and not self.has_variable( node) and node.creator is None: # If "c_var.creator" is None, it's input variable or parameters. # NOTE(Kiikurage): # In chainer v1, "Variable" doesn't support "__eq__" method, so "list.__contains__" cannot be used for list of variables. # However, "Variable.__hash__" is implemented and "set.__contains__" is available. self._convert_var(node, constant=node not in input_set) # Convert each Chainer function into WebDNN operators for c_opr in _listup_functions(inputs, outputs): self._convert_operator(c_opr) # Build graph graph = Graph([self.get_variable(c_var) for c_var in inputs], [self.get_variable(c_var) for c_var in outputs]) for v in graph.inputs: v.attributes.add(Input()) for v in graph.outputs: v.attributes.add(Output()) return graph
def conv2_d_backprop_input_handler(converter: TensorFlowConverter, tf_op: "tf.Operation"):
    input_sizes = converter.get_variable(tf_op.inputs[0])
    if not isinstance(input_sizes, ConstantVariable):
        raise NotImplementedError(
            "[TensorFlowConverter] Conv2DBackpropInput with dynamic shape of output "
            "(input of convolution) variable is not supported.")
    input_sizes = tuple(input_sizes.data.astype(np.int32).tolist())

    w = converter.get_variable(tf_op.inputs[1])  # HWNC
    w.order.unify(Order([Axis.KH, Axis.KW, Axis.N, Axis.C]))

    gy = converter.get_variable(tf_op.inputs[2])  # NHWC
    data_format = tf_op.get_attr("data_format")
    check_data_format(gy, data_format)

    input_size = np.array([input_sizes[gy.order.axes_dict[Axis.H]],
                           input_sizes[gy.order.axes_dict[Axis.W]]])
    ksize = np.array([w.shape_dict[Axis.KH], w.shape_dict[Axis.KW]])
    stride = np.array(tf_op.get_attr("strides"))
    assert stride[gy.order.axes_dict[Axis.N]] == 1
    assert stride[gy.order.axes_dict[Axis.C]] == 1
    stride = stride[[gy.order.axes_dict[Axis.H], gy.order.axes_dict[Axis.W]]]
    padding = np.array([parse_padding(tf_op.get_attr("padding"), ksize[0], 1),
                        parse_padding(tf_op.get_attr("padding"), ksize[1], 1)])

    x, = Deconvolution2D(None, ksize=ksize.tolist(), stride=stride.tolist(), padding=0)(gy, w)

    # The actual padding size depends on two factors:
    #   1. padding mode
    #   2. extra apron size (= (input size of convolution) - (size of the tensor expanded by deconvolution))

    expanded_size = np.array([x.shape_dict[Axis.H], x.shape_dict[Axis.W]])
    apron_size = input_size - (expanded_size - padding.sum(axis=1))

    # cancel padding by apron if possible
    for i in (0, 1):
        if padding[i, 0] > apron_size[i]:
            padding[i, 0] -= apron_size[i]
            apron_size[i] = 0
        else:
            apron_size[i] -= padding[i, 0]
            padding[i, 0] = 0

        if padding[i, 1] > apron_size[i]:
            padding[i, 1] -= apron_size[i]
            apron_size[i] = 0
        else:
            apron_size[i] -= padding[i, 1]
            padding[i, 1] = 0

    # append extra apron
    for i, axis in enumerate((Axis.H, Axis.W)):
        if apron_size[i] == 0:
            continue

        data = np.zeros([apron_size[i] if a == axis else x.shape_dict[a] for a in x.order.axes])
        x, = Concat(None, axis=axis)(x, ConstantVariable(data, x.order))

    # crop without padding
    padding = padding.tolist()  # type: List[List[int]]
    slice_h = slice(None) if padding[0] == [0, 0] else slice(padding[0][0], -padding[0][1])
    slice_w = slice(None) if padding[1] == [0, 0] else slice(padding[1][0], -padding[1][1])
    if data_format == b"NCHW":
        x = x[:, :, slice_h, slice_w]

    elif data_format == b"NHWC":
        x = x[:, slice_h, slice_w, :]

    else:
        raise NotImplementedError(f"Unknown data format: {data_format}")

    converter.set_variable(tf_op.outputs[0], x)
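# A hand-worked example of the padding/apron bookkeeping in the handler above. This is
# illustrative only and assumes that parse_padding returns (1, 1) for "SAME" padding with
# ksize=3; the concrete sizes are made up for the example.
#
#   forward convolution: input 6x6, ksize 3, stride 2, "SAME" padding  ->  gy is 3x3
#   Deconvolution2D with padding=0 should expand gy to (3 - 1) * 2 + 3 = 7, so expanded_size = [7, 7]
#   apron_size = 6 - (7 - 2) = 1 for each spatial axis
#   cancelling: padding [1, 1] with apron 1 becomes padding [0, 1] and apron 0
#   no extra apron is appended, and the final crop x[:, 0:-1, 0:-1, :] yields the desired 6x6 output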
def _convert_transpose(converter: ChainerConverter, c_op: "chainer.functions.Transpose"):
    x = converter.get_variable(c_op.inputs[0])
    y = x.transpose(Order([x.order.axes[axis] for axis in c_op.axes]))
    converter.set_variable(c_op.outputs[0](), y)
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for op in traverse.listup_operators(graph):
        if isinstance(op, (Reshape, ReinterpretAxis)):
            flag_changed |= _replace_input(graph, op, "x", op.parameters["in_order"])
            flag_changed |= _replace_output(graph, op, "y", op.parameters["out_order"])
            continue

        elif isinstance(op, LSTM):
            flag_changed |= _replace_input(graph, op, "x", OrderNTC)
            flag_changed |= _replace_input(graph, op, "w_input", OrderCN)
            flag_changed |= _replace_input(graph, op, "w_hidden", OrderCN)
            flag_changed |= _replace_output(graph, op, "y",
                                            OrderNTC if op.parameters["return_sequences"] else OrderNC)
            flag_changed |= _replace_output(graph, op, "final_c", OrderNC)
            continue

        elif isinstance(op, Embedding):
            flag_changed |= _replace_input(graph, op, "x", OrderNT)
            flag_changed |= _replace_input(graph, op, "w", OrderCN)
            flag_changed |= _replace_output(graph, op, "y", OrderNTC)
            continue

        elif isinstance(op, Im2Col):
            flag_changed |= _replace_input(graph, op, "im", OrderNHWC)
            flag_changed |= _replace_output(graph, op, "col", [
                Order([Axis.N, Axis.H, Axis.W, Axis.KH, Axis.KW, Axis.C]),
                Order([Axis.KH, Axis.KW, Axis.C, Axis.N, Axis.H, Axis.W])
            ])
            continue

        elif isinstance(op, Col2Im):
            flag_changed |= _replace_input(graph, op, "col", [
                Order([Axis.N, Axis.H, Axis.W, Axis.KH, Axis.KW, Axis.C])
            ])
            flag_changed |= _replace_output(graph, op, "im", OrderNHWC)
            continue

        elif isinstance(op, (Tensordot,)):
            op = op  # type: Tensordot
            A = op.inputs["A"]
            B = op.inputs["B"]
            C = op.outputs["C"]

            # Reduced axes must be located on the inner side.
            a_axes = list(A.order.axes)
            for axis in op.axes[0]:
                a_axes.remove(axis)
                a_axes.append(axis)

            b_axes = list(B.order.axes)
            for axis in op.axes[1]:
                b_axes.remove(axis)
                b_axes.append(axis)

            # Remaining axes must be kept in the same order as A's and B's axes.
            if all(axis in op.axes[0] for axis in C.order.axes[:A.ndim - len(op.axes[0])]):
                # C's order is [*a_remained_axes, *b_remained_axes], so C doesn't need to be transposed.
                for i, axis in enumerate(C.order.axes[:A.ndim - len(op.axes[0])]):
                    a_axes.remove(axis)
                    a_axes.insert(i, axis)

                for i, axis in enumerate(C.order.axes[A.ndim - len(op.axes[0]):]):
                    b_axes.remove(axis)
                    b_axes.insert(i, axis)

            else:
                c_axes = a_axes[:(A.ndim - len(op.axes[0]))] + b_axes[:(B.ndim - len(op.axes[1]))]
                flag_changed |= _replace_output(graph, op, "C", Order(c_axes))

            flag_changed |= _replace_input(graph, op, "A", Order(a_axes))
            flag_changed |= _replace_input(graph, op, "B", Order(b_axes))
            continue

        elif isinstance(op, (Convolution2D, Deconvolution2D, MaxPooling2D, AveragePooling2D,
                             Space2Depth, Depth2Space, LocalResponseNormalization, Unpooling2D)):
            flag_changed |= _replace_input(graph, op, "x", OrderNHWC)
            flag_changed |= _replace_output(graph, op, "y", OrderNHWC)
            continue

        elif isinstance(op, Softmax):
            x = op.inputs["x"]
            y = op.outputs["y"]
            target_axis = op.parameters["axis"]

            if not (x.ndim == 2 and x.order.axes_dict[target_axis] == x.ndim - 1):
                """
                Before)

                | x |                    | y |
                |-----| --{softmax}----> |-----|
                | XYZ |   axis=Y         | XYZ |

                After)

                | x |                | hx1 |             | hx2 |            | hy1 |             | hy2 |                | y |
                |-----| -{transpose}-> |-----| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |-----| -{transpose}-> |-----|
                | XYZ |              | XZY |             | NC  |  axis=C    | NC  |             | XZY |                | XYZ |
                                        :                                                          :
                                  order_nd = XZY                                             order_2d = NC
                """
                op.remove_all()

                axes_nd = list(x.order.axes)
                axes_nd.remove(target_axis)
                axes_nd.append(target_axis)
                order_nd = Order(axes_nd)
                shape_nd = tuple([x.shape_dict[axis] for axis in axes_nd])

                order_2d = OrderNC
                shape_2d = tuple([x.size // x.shape_dict[target_axis], x.shape_dict[target_axis]])

                if x.order == order_nd:
                    hx1 = x
                else:
                    hx1 = x.transpose(order_nd)
                    flag_changed = True

                if hx1.order == order_2d and hx1.shape == shape_2d:
                    hx2 = hx1
                else:
                    hx2 = hx1.reshape(shape_2d, order_2d)
                    flag_changed = True

                hy1, = Softmax(None, axis=Axis.C)(hx2)

                if hy1.order == order_nd and hy1.shape == shape_nd:
                    hy2 = hy1
                else:
                    hy2 = hy1.reshape(shape_nd, order_nd)
                    flag_changed = True

                if hy2.order == y.order:
                    y_dummy = hy2
                else:
                    y_dummy = hy2.transpose(y.order)
                    flag_changed = True

                OptimizeRule.replace_variable(graph, y_dummy, y)
            continue

        else:
            # "op" accepts any order. Remove redundant transpose operations if they exist.
            for key in op.inputs:
                flag_changed |= _optimize_redundant_transposed_input(graph, op, key, None)
            for key in op.outputs:
                flag_changed |= _optimize_redundant_transposed_output(graph, op, key, None)
            continue

    return graph, flag_changed
def optimize_loop_structure(variables: List[Variable], key_variable: Variable):
    """
    Optimize loop structure to iterate each element in variables

    Returns:
        (tuple): two elements are returned

        - First one is the shape dictionary of all variables.
        - Second one is the stride dictionary of all variables.
    """
    orders, shape_dicts = _simplify_orders(variables)  # type: Dict[Variable, Order], Dict[Variable, AxisKeyDict[List[int]]]
    shapes = {v: [shape_dicts[v][a] for a in orders[v].axes] for v in variables}
    strides = {v: [mul(shapes[v][orders[v].axes_dict[a] + 1:]) for a in orders[v].axes] for v in variables}
    stride_dicts = {v: AxisKeyDict(orders[v].axes, strides[v]) for v in variables}

    # re-ordering
    axes = []
    for v in sorted(variables, key=lambda v: orders[v].ndim):
        axes += [axis for axis in orders[v].axes if axis not in axes]

    orders = {v: Order(list(filter(lambda x: x in orders[v].axes, axes))) for v in variables}

    key_order = orders[key_variable]
    if key_order.ndim > 4:
        raise NotImplementedError('Currently, loop nest depth larger than 4 is not supported')

    shapes = {v: [shape_dicts[v][a] if a in orders[v].axes else 1 for a in key_order.axes] for v in variables}
    strides = {v: [stride_dicts[v][a] if a in orders[v].axes else 1 for a in key_order.axes] for v in variables}

    for v in variables:
        shape = shapes[v]
        stride = strides[v]
        while len(shape) < 4:
            stride.append(1)
            shape.append(1)

    return shapes, strides
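# Illustrative sketch (not part of the library source): how optimize_loop_structure collapses
# the loop nest for an elementwise-style kernel. The variables and shapes below are made up
# for the example, and the exact numbers follow from the code above.
#
#   x = Variable([2, 3, 4, 5], OrderNHWC)   # contiguous NHWC tensor
#   b = Variable([5], OrderC)               # per-channel tensor broadcast over N, H, W
#
#   shapes, strides = optimize_loop_structure([x, b], x)
#
# _simplify_orders should merge x's adjacent N, H, W axes into one axis of size 24 while C
# stays separate (it also appears in b), so after re-ordering and padding to 4 levels the
# expected result is roughly:
#   shapes[x] == [5, 24, 1, 1]   strides[x] == [1, 5, 1, 1]
#   shapes[b] == [5, 1, 1, 1]    strides[b] == [1, 1, 1, 1]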
def test_compare_custom_order():
    order1 = Order([Axis.N, Axis.C])

    assert OrderNC == order1
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for op in traverse.listup_operators(graph):
        if isinstance(op, Transpose):
            x = op.inputs["x0"]
            y = op.outputs["y"]

            if x.order == y.order:
                op.remove_all()
                x.replace(y)
                flag_changed = True

            elif all(isinstance(op2, (Elementwise, SplitAxis)) for op2 in y.input_to):
                op.remove_all()
                for op2 in list(y.input_to):
                    name = op2._get_input_name(y)
                    op2.remove_input(y)
                    op2.append_input(name, x)

        elif isinstance(op, Reshape):
            flag_changed |= _replace_input(op, "x", op.parameters["in_order"])
            flag_changed |= _replace_output(op, "y", op.parameters["out_order"])

        elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D, Deconvolution2D)):
            flag_changed |= _replace_input(op, "x", OrderNHWC)
            flag_changed |= _replace_output(op, "y", OrderNHWC)

        elif isinstance(op, Softmax):
            x = op.inputs["x"]
            y = op.outputs["y"]

            if x.ndim > 2:
                """
                Before)

                | x |                     | y |
                |------| --{softmax}----> |------|
                | NCHW |                  | NCHW |

                After)

                | x |                 | hx1 |              | hx2 |            | hy1 |             | hy2 |                 | y |
                |------| -{transpose}-> |------| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |------| -{transpose}-> |------|
                | NCHW |              | NHWC |              | NC  |            | NC  |             | NHWC |               | NCHW |
                """
                op.remove_all()

                target_axis = op.parameters["axis"]
                axes_nd = list(x.order.axes)
                axes_nd.remove(target_axis)
                axes_nd.append(target_axis)
                order_nd = Order(axes_nd)
                shape_nd = [x.shape_dict[axis] for axis in axes_nd]

                order_2d = OrderNC
                shape_2d = [x.size // x.shape_dict[target_axis], x.shape_dict[target_axis]]

                hx1, = Transpose(None)(x)
                hx1.change_order(order_nd)

                hx2, = Reshape(None, in_order=hx1.order, out_order=order_2d, out_shape=shape_2d)(hx1)

                hy1, = Softmax(None, axis=Axis.C)(hx2)

                hy2, = Reshape(None, in_order=hy1.order, out_order=order_nd, out_shape=shape_nd)(hy1)

                y_dummy, = Transpose(None)(hy2)
                y_dummy.change_order(y.order)

                y_dummy.replace(y)
                flag_changed = True

            else:
                flag_changed |= _replace_input(op, "x", OrderNC)
                flag_changed |= _replace_output(op, "y", OrderNC)

    return graph, flag_changed
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for op1 in traverse.filter_nodes(traverse.listup_operators(graph), Concat):  # type: Concat
        if len(op1.inputs) != 2:
            continue

        x0 = op1.inputs["x0"]
        x1 = op1.inputs["x1"]
        y = op1.outputs["y"]
        op2 = x0.output_from
        op3 = x1.output_from

        if isinstance(op2, ElementwiseAdd) and isinstance(op3, ElementwiseAdd) \
                and len(x0.input_to) == 1 and len(x1.input_to) == 1:
            """
            before)

                v1 -+
                    +-[op2: ElementwiseAdd]-> x0 -+
                c1 -+                             |
                                                  +-[op1: Concat]-> y
                v2 -+                             |
                    +-[op3: ElementwiseAdd]-> x1 -+
                c2 -+

            after)

                v1 -+
                    +-[Concat]-> x6 -+
                v2 -+                |
                                     +-[ElementwiseAdd]-> y
                                     |
                c3 ------------------+
            """
            x2 = op2.inputs["x0"]
            x3 = op2.inputs["x1"]
            x4 = op3.inputs["x0"]
            x5 = op3.inputs["x1"]

            if isinstance(x2, ConstantVariable):
                c1 = x2
                v1 = x3

            elif isinstance(x3, ConstantVariable):
                c1 = x3
                v1 = x2

            else:
                continue

            if isinstance(x4, ConstantVariable):
                c2 = x4
                v2 = x5

            elif isinstance(x5, ConstantVariable):
                c2 = x5
                v2 = x4

            else:
                continue

            if not (c1.order == c2.order == Order([op1.axis])):
                continue

            op1.remove_all()
            op2.remove_all()
            op3.remove_all()

            c3 = ConstantVariable(np.hstack([c1.data, c2.data]), c1.order)
            x6, = Concat(None, axis=op1.axis)(v1, v2)
            y_dummy, = ElementwiseAdd(None)(x6, c3)

            y_dummy.replace(y)
            flag_changed = True

    return graph, flag_changed
def expand_dims(self, axis: Axis, index: int) -> "Variable":
    """expand_dims(axis, index)

    Insert a new axis whose size is 1.

    This is an alias of the following code.

        new_axes = list(v.order.axes)
        new_axes.insert(index, axis)
        Reshape(None, in_order=v.order, out_order=Order(new_axes),
                out_shape=[1 if a == axis else self.shape_dict[a] for a in new_axes])(v)[0]

    Args:
        axis (:class:`~Axis`): inserted axis
        index (int): insert position

    Returns:
        (:class:`~Variable`) expanded variable
    """
    if index < 0:
        index += 1

    new_axes = list(self.order.axes)
    new_axes.insert(index, axis)
    return self.reshape(shape=[1 if a == axis else self.shape_dict[a] for a in new_axes],
                        order=Order(new_axes))
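# Illustrative usage sketch (not from the library source): inserting a time axis into an
# NC-ordered variable. The shapes and the choice of Axis.T are made up for the example.
#
#   v = Variable([2, 3], OrderNC)      # v.order.axes == (N, C)
#   v2 = v.expand_dims(Axis.T, 1)      # insert T at position 1
#   assert v2.order == Order([Axis.N, Axis.T, Axis.C])
#   assert v2.shape == [2, 1, 3]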
def test_Y_CTN():
    template(w_order=Order([Axis.C, Axis.T, Axis.N]))
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for op in traverse.listup_operators(graph):
        if isinstance(op, (Tensordot,)):
            op = op  # type: Tensordot
            A = op.inputs["A"]
            B = op.inputs["B"]
            C = op.outputs["C"]

            # Reduced axes must be located on the inner side.
            a_axes = list(A.order.axes)
            for axis in op.axes[0]:
                a_axes.remove(axis)
                a_axes.append(axis)

            b_axes = list(B.order.axes)
            for axis in op.axes[1]:
                b_axes.remove(axis)
                b_axes.append(axis)

            # Remaining axes must be kept in the same order as A's and B's axes.
            if all(axis in op.axes[0] for axis in C.order.axes[:A.ndim - len(op.axes[0])]):
                # C's order is [*a_remained_axes, *b_remained_axes], so C doesn't need to be transposed.
                for i, axis in enumerate(C.order.axes[:A.ndim - len(op.axes[0])]):
                    a_axes.remove(axis)
                    a_axes.insert(i, axis)

                for i, axis in enumerate(C.order.axes[A.ndim - len(op.axes[0]):]):
                    b_axes.remove(axis)
                    b_axes.insert(i, axis)

            else:
                c_axes = a_axes[:(A.ndim - len(op.axes[0]))] + b_axes[:(B.ndim - len(op.axes[1]))]
                flag_changed |= _replace_output(graph, op, "C", Order(c_axes))

            flag_changed |= _replace_input(graph, op, "A", Order(a_axes))
            flag_changed |= _replace_input(graph, op, "B", Order(b_axes))
            continue

        elif isinstance(op, (Im2Col,)):
            op = op  # type: Im2Col
            col = op.outputs["col"]

            # In variable "col", Axis.KH, Axis.KW, and Axis.C must be placed in this order.
            col_axes = list(col.order.axes)
            for axis in (Axis.KH, Axis.KW, Axis.C):
                col_axes.remove(axis)
                col_axes.append(axis)
            flag_changed |= _replace_output(graph, op, "col", Order(col_axes))
            continue

        elif isinstance(op, (Col2Im,)):
            op = op  # type: Col2Im
            col = op.inputs["col"]

            # In variable "col", Axis.KH, Axis.KW, and Axis.C must be placed in this order.
            col_axes = list(col.order.axes)
            for axis in (Axis.KH, Axis.KW, Axis.C):
                col_axes.remove(axis)
                col_axes.append(axis)
            flag_changed |= _replace_input(graph, op, "col", Order(col_axes))
            continue

        elif isinstance(op, (ConvertRGBAtoR, ConvertRtoRGBA)):
            flag_changed |= _replace_input(graph, op, "x0", op.outputs["y"].order)
            continue

        else:
            # "op" accepts any order. Remove redundant transpose operations if they exist.
            for key in op.inputs:
                flag_changed |= _optimize_redundant_transposed_input(graph, op, key, None)
            for key in op.outputs:
                flag_changed |= _optimize_redundant_transposed_output(graph, op, key, None)
            continue

    return graph, flag_changed
def __call__(self, x: Variable):
    # assert index is valid
    for axis, index in self.indices.items():
        if axis in x.order.axes:
            if isinstance(index, slice):
                index = normalize_slice(index, x.shape_dict[axis])

                valid_start = -x.shape_dict[axis] <= index.start <= x.shape_dict[axis]
                valid_stop = -x.shape_dict[axis] <= index.stop <= x.shape_dict[axis]
                if not valid_start or not valid_stop:
                    raise ValueError(f"""
[Slice] Index {index} in {axis} is out of range:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

                if ((abs(index.stop - index.start) - 1) // abs(index.step)) + 1 < 0:
                    raise ValueError(f"""
[Slice] Slice operator doesn't support 0-size output:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

            elif isinstance(index, int):
                if not -x.shape_dict[axis] <= index < x.shape_dict[axis]:
                    raise ValueError(f"""
[Slice] Index {index} in {axis} is out of range:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
    (valid range) = [{-x.shape_dict[axis]}, {x.shape_dict[axis]})
""")

            elif index is None:
                raise ValueError(f"""
[Slice] Axis {axis} already exists:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

        else:
            if index is not None:
                raise ValueError(f"""
[Slice] Axis {axis} does not exist in the input variable. In this case, the index must be "None" (=insert new axis):
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

    if all(isinstance(index, int) for index in self.indices.values()):
        raise NotImplementedError(f"""
[Slice] Accessing a single element is not supported:
    (indices) = {self.indices}
""")

    y_shape_dict = AxisKeyDict()
    for axis, index in self.indices.items():
        if isinstance(index, slice):
            index = normalize_slice(index, x.shape_dict[axis])
            y_shape_dict[axis] = ((abs(index.stop - index.start) - 1) // abs(index.step)) + 1

        elif isinstance(index, int):
            pass  # Remove axis

        elif index is None:
            y_shape_dict[axis] = 1  # Insert axis

    y = Variable(list(y_shape_dict.values()), Order(list(y_shape_dict.keys())))

    for axis in x.order.axes:
        if axis in self.indices:
            index = self.indices[axis]
            if isinstance(index, slice) and index.start is None and index.stop is None and index.step is None:
                # This axis is not sliced.
                self.attributes.add(Tensorwise(axis))

        else:
            # This axis is not sliced.
            self.attributes.add(Tensorwise(axis))

    self.append_input("x", x)
    self.append_output("y", y)
    return y,
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for sgemm in traverse.filter_nodes(traverse.listup_operators(graph), Sgemm):  # type: Sgemm
        A = sgemm.inputs["A"]
        B = sgemm.inputs["B"]
        M = sgemm.M
        N = sgemm.N
        K = sgemm.K
        transpose_A = sgemm.transpose_A
        transpose_B = sgemm.transpose_B

        if all([
            self.optimize_channel_mode,
            K % 4 == 0,
            isinstance(A, ConstantVariable) or transpose_A == True,
            isinstance(B, ConstantVariable) or transpose_B == False
        ]):
            if transpose_A != True:
                assert isinstance(A, ConstantVariable)
                flag_changed = True
                old_A = A
                A = ConstantVariable(A.data.reshape([K, M]).transpose(),
                                     Order([Axis(None), Axis(None)]))
                ChannelMode.set(A, ChannelMode.get(old_A))
                sgemm.replace_input(old_A, A, with_assert=False)
                sgemm.parameters["transpose_A"] = transpose_A = True

            if transpose_B != False:
                assert isinstance(B, ConstantVariable)
                flag_changed = True
                old_B = B
                B = ConstantVariable(B.data.reshape([K, N]).transpose(),
                                     Order([Axis(None), Axis(None)]))
                ChannelMode.set(B, ChannelMode.get(old_B))
                sgemm.replace_input(old_B, B, with_assert=False)
                sgemm.parameters["transpose_B"] = transpose_B = False

            if ChannelMode.get(A) != ChannelModeEnum.RGBA:
                flag_changed = True
                ChannelMode.set(A, ChannelModeEnum.RGBA)

            if ChannelMode.get(B) != ChannelModeEnum.RGBA:
                flag_changed = True
                ChannelMode.set(B, ChannelModeEnum.RGBA)

            texture_shape_A = [M, K // 4] if transpose_A else [K // 4, M]
            texture_shape_B = [K // 4, N] if transpose_B else [N, K // 4]

        else:
            if ChannelMode.get(A) != ChannelModeEnum.R:
                flag_changed = True
                ChannelMode.set(A, ChannelModeEnum.R)

            if ChannelMode.get(B) != ChannelModeEnum.R:
                flag_changed = True
                ChannelMode.set(B, ChannelModeEnum.R)

            texture_shape_A = [M, K] if transpose_A else [K, M]
            texture_shape_B = [K, N] if transpose_B else [N, K]

        if TextureShape.get(A) != texture_shape_A:
            flag_changed = True
            TextureShape.set(A, height=texture_shape_A[0], width=texture_shape_A[1])

        if TextureShape.get(B) != texture_shape_B:
            flag_changed = True
            TextureShape.set(B, height=texture_shape_B[0], width=texture_shape_B[1])

    if flag_changed:
        graph, _ = ConstantFolding().optimize(graph)

    return graph, flag_changed
def im2col(op: Im2Col) -> List[Kernel]:
    im = op.inputs["im"]
    col = op.outputs["col"]

    H1 = im.shape_dict[Axis.H]
    W1 = im.shape_dict[Axis.W]
    C1 = im.shape_dict[Axis.C]

    assert col.order.check_same_axes(Order([Axis.N, Axis.H, Axis.W, Axis.KH, Axis.KW, Axis.C]))
    assert col.order.axes_dict[Axis.KH] + 2 == col.order.axes_dict[Axis.KW] + 1 == col.order.axes_dict[Axis.C] == 5
    assert im.order.check_same_axes(OrderNHWC)
    assert ChannelMode.get(im) == ChannelModeEnum.R

    col_shape = col.shape[0:3] + (mul(col.shape[3:6]),)
    col_stride = [mul(col_shape[i + 1:]) for i in range(len(col_shape))]
    col_order = Order(col.order.axes[0:3] + (Axis.C,))

    if ChannelMode.get(col) == ChannelModeEnum.R:
        code = KernelCode(["""
void main() {
    ivec4 variable_position_col = """,
            change_order(convert_position("gl_FragCoord.yx", texture_shape(col)[:2],
                                          texture_stride(col)[:2], col_shape, col_stride),
                         col_order, OrderNHWC),
            f""";

    int n = variable_position_col.x;
    int h2 = variable_position_col.y;
    int w2 = variable_position_col.z;
    int khkwc1 = variable_position_col.w;

    int kh = khkwc1 / {C1} / {op.KW};
    int kw = khkwc1 / {C1} - kh * {op.KW};
    int c1 = khkwc1 - (kh * {op.KW} + kw) * {C1};

    int h1 = h2 * {op.SH} - {op.PH} + kh * {op.DH};
    int w1 = w2 * {op.SW} - {op.PW} + kw * {op.DW};

    if (h1 < 0 || h1 >= {H1} || w1 < 0 || w1 >= {W1}) {{
        gl_FragColor.r = 0.0;
    }} else {{
        gl_FragColor.r = """,
            texel_fetch(im, change_order("vec4(n, h1, w1, c1)", OrderNHWC, im.order)),
            f""".r;
    }}
}}
"""], name="Im2Col_R")

    elif ChannelMode.get(col) == ChannelModeEnum.RGBA:
        code = KernelCode(["""
void main() {
    ivec4 variable_position_col = """,
            change_order(convert_position("gl_FragCoord.yx", texture_shape(col)[:2],
                                          texture_stride(col)[:2], col_shape, col_stride),
                         col_order, OrderNHWC),
            f""";

    int n = variable_position_col.x;
    int h2 = variable_position_col.y;
    int w2 = variable_position_col.z;
    int khkwc1 = variable_position_col.w;

    int kh = khkwc1 / {C1} / {op.KW};
    int kw = khkwc1 / {C1} - kh * {op.KW};
    int c1 = khkwc1 - (kh * {op.KW} + kw) * {C1};

    int h1 = h2 * {op.SH} - {op.PH} + kh * {op.DH};
    int w1 = w2 * {op.SW} - {op.PW} + kw * {op.DW};

    if (h1 < 0 || h1 >= {H1} || w1 < 0 || w1 >= {W1}) {{
        gl_FragColor = vec4(0.0, 0.0, 0.0, 0.0);
    }} else {{
        gl_FragColor.r = """,
            texel_fetch(im, change_order("vec4(n, h1, w1, c1 + 0)", OrderNHWC, im.order)),
            f""".r;
        gl_FragColor.g = """,
            texel_fetch(im, change_order("vec4(n, h1, w1, c1 + 1)", OrderNHWC, im.order)),
            f""".r;
        gl_FragColor.b = """,
            texel_fetch(im, change_order("vec4(n, h1, w1, c1 + 2)", OrderNHWC, im.order)),
            f""".r;
        gl_FragColor.a = """,
            texel_fetch(im, change_order("vec4(n, h1, w1, c1 + 3)", OrderNHWC, im.order)),
            f""".r;
    }}
}}
"""], name="Im2Col_RGBA")

    else:
        raise NotImplementedError

    source = code.generate()

    return [Kernel(source, code.name, code.samplers, code.uniforms, col)]
def test_slice_invalid_type():
    v1 = Variable([2, 3, 4, 5, 6], Order([None, None, None, None, None]))
    v1[:, 2, 3, :, None, "hoge"]
def _optimize_loop_structure(variables: List[Variable], key_variable: Variable, keep_axes: List[Axis] = None):
    """
    Optimize loop structure to iterate each element in variables

    Returns:
        (tuple): two elements are returned

        - First one is the shape dictionary of all variables.
        - Second one is the stride dictionary of all variables.
    """
    orders, shape_dicts = simplify_orders(variables, keep_axes=keep_axes)  # type: Dict[Variable, Order], Dict[Variable, AxisKeyDict[List[int]]]
    shapes = {v: [shape_dicts[v][a] for a in orders[v].axes] for v in variables}
    strides = {v: [mul(shapes[v][orders[v].axes_dict[a] + 1:]) for a in orders[v].axes] for v in variables}
    stride_dicts = {v: AxisKeyDict(orders[v].axes, strides[v]) for v in variables}

    # Re-order shapes and strides along the key variable's order
    axes = []
    axes += [axis for axis in orders[key_variable].axes if axis not in axes]
    for v in sorted(variables, key=lambda v: orders[v].ndim):
        axes += [axis for axis in orders[v].axes if axis not in axes]

    orders = {v: Order(list(filter(lambda x: x in orders[v].axes, axes))) for v in variables}

    key_order = orders[key_variable]
    shapes = {v: [shape_dicts[v][a] if a in orders[v].axes else 1 for a in key_order.axes] for v in variables}
    strides = {v: [stride_dicts[v][a] if a in orders[v].axes else 1 for a in key_order.axes] for v in variables}

    # Pad shapes and strides to 4D
    if key_order.ndim > 4:
        raise NotImplementedError(f"Too large number of dimension: {v}")

    for v in variables:
        shape = shapes[v]
        stride = strides[v]
        while len(shape) < 4:
            stride.append(1)
            shape.append(1)

    return shapes, strides
def _convert_flatten(converter: ChainerConverter, c_op: "chainer.functions.Flatten"):
    x = converter.get_variable(c_op.inputs[0])
    y = x.reshape([x.size], Order([None]))
    converter.set_variable(c_op.outputs[0](), y)
import numpy as np

from test.util import generate_kernel_test_case, wrap_template
from webdnn.graph.axis import Axis
from webdnn.graph.graph import Graph
from webdnn.graph.operators.max import Max
from webdnn.graph.operators.sum import Sum
from webdnn.graph.order import OrderNHWC, OrderNCHW, Order
from webdnn.graph.variable import Variable

OrderNHW = Order([Axis.N, Axis.H, Axis.W])


@wrap_template
def template(x_order=OrderNHWC, y_order=OrderNHW, axis=Axis.C, description: str = ""):
    vx = np.arange(120).reshape(2, 3, 4, 5)
    vy = np.sum(vx, axis=OrderNHWC.axes_dict[axis])

    x = Variable(vx.shape, order=OrderNHWC)
    y, = Sum(None, axis=axis)(x)

    x.change_order(x_order)
    y.change_order(y_order)

    generate_kernel_test_case(
        description=f"Sum {description}",
        graph=Graph([x], [y]),
        backend=["webgl"],
from webdnn.graph.axis import Axis, AxisKeyDict
from webdnn.graph.operators.im2col import Im2Col
from webdnn.graph.order import Order, OrderNHWC
from webdnn.graph.variable import Variable

OrderNHWKKC = Order([Axis.N, Axis.H, Axis.W, Axis.KH, Axis.KW, Axis.C])


def main(im_shape=[1, 5, 5, 6], im_order=OrderNHWC, ksize=3, stride=1, padding=1, dilation_rate=1,
         expected_shape_dict: AxisKeyDict[int] = AxisKeyDict(OrderNHWKKC.axes, [1, 5, 5, 3, 3, 6])):
    op = Im2Col(None, ksize=ksize, stride=stride, padding=padding, dilation_rate=dilation_rate)

    x = Variable(im_shape, im_order)
    y, = op(x)

    for axis in y.order.axes:
        assert y.shape_dict[axis] == expected_shape_dict[axis]


def test_normal():
    main()
def _split_tensordot(graph: Graph, op: Tensordot, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]

    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]

    axes_M = tuple(filter(lambda a: a not in op.axes[0], A.order.axes))
    axes_N = tuple(filter(lambda a: a not in op.axes[1], B.order.axes))
    axes_K_A, axes_K_B = op.axes

    K = mul(A.shape_dict[a] for a in axes_K_A)
    M = A.size // K
    N = B.size // K

    shape_M = [A.shape_dict[a] for a in axes_M]
    shape_N = [B.shape_dict[a] for a in axes_N]

    op.remove_all()

    if v == A:
        A1, A2 = v_pair

        if axis in axes_K_A:
            split_axis_A = axis

            if (B.shape_dict[axes_K_B[0]] * s1) % (s1 + s2) == 0:
                split_axis_B = axes_K_B[0]

            else:
                # Factorize B's axes that compose K into A's corresponding axes
                B = B.transpose(Order(axes_N + axes_K_B))
                B = B.reshape(order=Order((Axis(),) + axes_K_A),
                              shape=[N] + [A.shape_dict[a] for a in axes_K_A])
                split_axis_B = split_axis_A
                axes_K_B = axes_K_A

            B1, B2 = SplitAxis(None, axis=split_axis_B,
                               sections=[(B.shape_dict[split_axis_B] * s1) // (s1 + s2)])(B)

            C1, = Tensordot(None, [axes_K_A, axes_K_B])(A1, B1)
            C2, = Tensordot(None, [axes_K_A, axes_K_B])(A2, B2)
            OptimizeRule.replace_variable(
                graph,
                (C1 + C2).reshape(shape_M + shape_N, Order(axes_M + axes_N)).transpose_like(C),
                C)

        else:
            C1, = Tensordot(None, op.axes)(A1, B)
            C2, = Tensordot(None, op.axes)(A2, B)

            for a1, a2 in zip(C1.order.axes, C2.order.axes):
                if a1 == a2 == axis:
                    continue
                a1.unify(a2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == B:
        B1, B2 = v_pair

        if axis in axes_K_B:
            split_axis_B = axis

            if (A.shape_dict[axes_K_A[0]] * (s1 + s2)) % s1 == 0:
                split_axis_A = axes_K_A[0]

            else:
                # Factorize A's axes that compose K into B's corresponding axes
                A = A.transpose(Order(axes_M + axes_K_A))
                A = A.reshape(order=Order((Axis(),) + axes_K_B),
                              shape=[M] + [B.shape_dict[a] for a in axes_K_B])
                split_axis_A = split_axis_B
                axes_K_A = axes_K_B

            A1, A2 = SplitAxis(None, axis=split_axis_A,
                               sections=[(A.shape_dict[split_axis_A] * s1) // (s1 + s2)])(A)

            C1, = Tensordot(None, [axes_K_A, axes_K_B])(A1, B1)
            C2, = Tensordot(None, [axes_K_A, axes_K_B])(A2, B2)
            OptimizeRule.replace_variable(
                graph,
                (C1 + C2).reshape(shape_M + shape_N, Order(axes_M + axes_N)).transpose_like(C),
                C)

        else:
            C1, = Tensordot(None, op.axes)(A, B1)
            C2, = Tensordot(None, op.axes)(A, B2)

            for a1, a2 in zip(C1.order.axes, C2.order.axes):
                if a1 == a2 == axis:
                    continue
                a1.unify(a2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == C:
        """
        before)
            C[M, N] = A[M, K] @ B[K, N]

        after) In case `axis` is in `N`,
            C[M, N1] = Concat(A[M, K] @ B1[K, N1])
            C[M, N2] = Concat(A[M, K] @ B2[K, N2])
        """
        raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError
def _simplify_orders(variables: List[Variable]) -> Tuple[Dict[Variable, Order], Dict[Variable, AxisKeyDict[int]]]:
    """
    Simplify variable orders based on the following rules

    - An axis whose size is :code:`1` will be removed.

    - If axes :code:`A` and :code:`B` are adjacent in all variables which have both axis :code:`A` and axis :code:`B`,
      :code:`A` and :code:`B` will be merged.
        - For example, :code:`OrderABC` and :code:`OrderCAB` can be simplified as :code:`OrderXC` and :code:`OrderCX`
        - In this case, the size of axis :code:`X` is calculated as :code:`(size of axis A) * (size of axis B)`

    .. code-block:: text

        ex)
            x0.order=NHWC    simplify    x0.order=X
             y.order=NHWC  ------------>  y.order=X

        ex)
            x0.order=C       simplify    x0.order=C
            x1.order=NHWC  ------------> x1.order=XC
             y.order=NHWC                 y.order=XC

        ex)
            x0.order=C       simplify    x0.order=C
            x1.order=HW    ------------> x1.order=X
             y.order=NHWC                 y.order=NXC

    Returns:
        (tuple of dicts) simplified orders and shape
    """
    orders = {}  # type: Dict[Variable, Order]
    shape_dicts = {}  # type: Dict[Variable, AxisKeyDict[int]]
    axis_scalar = Axis("Scalar")

    # Remove all axes whose size is 1.
    for v in variables:
        new_axes = [a for a in v.order.axes if v.shape_dict[a] != 1]
        orders[v] = Order(new_axes)
        shape_dicts[v] = AxisKeyDict(new_axes, [v.shape_dict[a] for a in new_axes])

        if len(new_axes) == 0 and v.size == 1:
            orders[v] = Order([axis_scalar])
            shape_dicts[v] = AxisKeyDict([axis_scalar], [1])

    # List up all axes and the variables which have each axis
    var_dict = AxisKeyDict[Set[Variable]]()
    for v in variables:
        for axis in orders[v].axes:
            if axis in var_dict:
                var_dict[axis].add(v)
            else:
                var_dict[axis] = {v}

    # Find pairs of axes which can be merged
    counter = 0
    flag_continue = True
    while flag_continue:
        flag_continue = False

        for axis1, vars1 in list(var_dict.items()):
            for axis2, vars2 in list(var_dict.items()):
                if axis1 == axis2:
                    continue

                if vars1 != vars2 or any(orders[v].axes_dict[axis1] + 1 != orders[v].axes_dict[axis2]
                                         for v in vars1):
                    # `axis1` and `axis2` must be adjacent.
                    continue

                # Merge `axis1` and `axis2` into `axis_new`
                axis_new = Axis(f"X{counter}")
                counter += 1

                for v in vars1:
                    shape_dict = shape_dicts[v]
                    shape_dict[axis_new] = shape_dict[axis1] * shape_dict[axis2]
                    del shape_dict[axis1]
                    del shape_dict[axis2]

                    order = orders[v]
                    orders[v] = Order(order.axes[:order.axes_dict[axis1]] +
                                      (axis_new,) +
                                      order.axes[order.axes_dict[axis2] + 1:])

                var_dict[axis_new] = vars1
                del var_dict[axis1]
                del var_dict[axis2]

                flag_continue = True
                break

            if flag_continue:
                break

    return orders, shape_dicts
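# Illustrative sketch (not part of the library source): one concrete run of _simplify_orders.
# The variables below are made up, and the merged axis is named "X0"/"X1" by the function itself.
#
#   x = Variable([2, 3, 4, 5], OrderNHWC)   # axes N, H, W appear only in x ...
#   b = Variable([5], OrderC)               # ... while C also appears in b
#
#   orders, shape_dicts = _simplify_orders([x, b])
#
# N, H and W can be merged because they always appear together and adjacently, but C cannot,
# so the expected result is roughly:
#   orders[x] -> Order([X1, C]) with shape_dicts[x] == {X1: 24, C: 5}
#   orders[b] -> Order([C])     with shape_dicts[b] == {C: 5}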
def convert(self, inputs: List["tf.Tensor"], outputs: List["tf.Tensor"],
            order_hints: Optional[Dict[Union["tf.Tensor", "tf.Variable"], Order]] = None) -> Graph:
    """convert(inputs, outputs, order_hints=None)

    Args:
        inputs (list of `tf.Tensor`): tensorflow input tensors
        outputs (list of `tf.Tensor`): tensorflow output tensors
        order_hints: Order annotations which help webdnn's optimizer.

    .. admonition:: Example

        .. code::

            # y = x @ W + b
            x = tf.placeholder(tf.float32, [None, 784])
            W = tf.Variable(tf.zeros([784, 10]))
            b = tf.Variable(tf.zeros([10]))
            y = tf.nn.softmax(tf.matmul(x, W) + b)

            webdnn_graph = TensorFlowConverter().convert([x], [y])

    Returns:
        (:class:`~webdnn.graph.graph.Graph`): WebDNN IR Graph
    """
    for tensor in inputs:
        shape = [Placeholder() if dim.value is None else dim.value for dim in tensor.shape.dims]
        if isinstance(shape[0], Placeholder):
            shape[0] = self._batch_size
        self.set_variable(tensor, Variable(shape, Order([None] * len(shape))))

    ops = _listup_operations(inputs, outputs)
    for op in ops:
        self._convert_operator(op)

    if order_hints:
        for tensor, order in order_hints.items():
            if isinstance(tensor, tf.Variable):
                tensor = tensor.value()

            variable = self.get_variable(tensor)
            for axis1, axis2 in zip(variable.order.axes, order.axes):
                axis1.unify(axis2)

    # Remove redundant ReinterpretAxis operators
    graph = Graph([self.get_variable(tensor) for tensor in inputs],
                  [self.get_variable(tensor) for tensor in outputs])
    graph, _ = TensorFlowFrontendOptimizeRule().optimize(graph)

    for v in graph.inputs:
        v.attributes.add(Input(v))

    for v in graph.outputs:
        v.attributes.add(Output(v))

    return graph
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False

    matches = traverse.search_sub_structure(graph, [Sgemm, Variable, ElementwiseMul])
    while len(matches) > 0:
        match = matches.pop()
        sgemm = match[0]  # type: Sgemm
        elementwise_mul = match[2]  # type: ElementwiseMul

        out_order = sgemm.parameters["out_order"]
        out_shape = sgemm.parameters["out_shape"]
        axis_k = Axis('AxisK')

        if not isinstance(sgemm.inputs["A"], ConstantVariable) and not isinstance(sgemm.inputs["B"], ConstantVariable):
            # neither A nor B is constant
            continue

        elif isinstance(sgemm.inputs["A"], ConstantVariable):
            w1 = sgemm.inputs["A"]  # type: ConstantVariable

            if sgemm.transpose_A:
                # w1.shape = (M, K)
                shape = []
                axes = []
                for axis, size in zip(out_order.axes, out_shape):
                    shape.append(size)
                    axes.append(axis)
                    if mul(shape) >= sgemm.M:
                        break

                if mul(shape) != sgemm.M:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes + [axis_k])
                w1_virtual_shape = shape + [sgemm.K]

            else:
                # w1.shape = (K, M)
                shape = [sgemm.K]
                axes = [axis_k]
                for axis, size in zip(out_order.axes, out_shape):
                    shape.append(size)
                    axes.append(axis)
                    if mul(shape) >= w1.size:
                        break

                if mul(shape) != w1.size:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes)
                w1_virtual_shape = shape

        else:
            w1 = sgemm.inputs["B"]  # type: ConstantVariable

            if sgemm.transpose_B:
                # w1.shape = (K, N)
                shape = []
                axes = []
                for axis, size in reversed(list(zip(out_order.axes, out_shape))):
                    shape.insert(0, size)
                    axes.insert(0, axis)
                    if mul(shape) >= sgemm.N:
                        break

                if mul(shape) != sgemm.N:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order([axis_k] + axes)
                w1_virtual_shape = [sgemm.K] + shape

            else:
                # w1.shape = (N, K)
                shape = [sgemm.K]
                axes = [axis_k]
                for axis, size in reversed(list(zip(out_order.axes, out_shape))):
                    shape.insert(0, size)
                    axes.insert(0, axis)
                    if mul(shape) >= w1.size:
                        break

                if mul(shape) != w1.size:
                    # output axes are derived from both w1 and x
                    continue

                w1_virtual_order = Order(axes)
                w1_virtual_shape = shape

        h = sgemm.outputs["C"]  # type: Variable
        x0 = elementwise_mul.inputs["x0"]
        x1 = elementwise_mul.inputs["x1"]
        if h == x1:
            if not isinstance(x0, ConstantVariable):
                # w2 is not constant
                continue
            w2 = x0  # type: ConstantVariable

        else:
            if not isinstance(x1, ConstantVariable):
                # w2 is not constant
                continue
            w2 = x1  # type: ConstantVariable

        y = elementwise_mul.outputs["y"]  # type: Variable

        if not all(axis in w1_virtual_order.axes for axis in w2.order.axes):
            # w2's axes are derived from both w1 and x
            continue

        elementwise_mul.remove_all()
        y_dummy, = Transpose(None)(h)
        y_dummy.change_order(y.order)
        y_dummy.replace(y)

        w2.change_order(w1_virtual_order)
        w_new = ConstantVariable(w1.data.reshape(w1_virtual_shape), w1_virtual_order) * w2  # type: ConstantVariable
        w1.data = w_new.data.reshape(w1.shape)

        flag_changed = True
        matches = traverse.search_sub_structure(graph, [Sgemm, Variable, ElementwiseMul])

    return graph, flag_changed
def _broadcasted_order(order1: Order, order2: Order):
    axes = list(order1.axes)
    axes.extend([a for a in order2.axes if a not in order1.axes])
    return Order(axes)
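# Illustrative sketch (not from the library source): the broadcasted order keeps order1's axes
# and appends any axes that appear only in order2.
#
#   _broadcasted_order(OrderNC, OrderC)                   # -> Order([Axis.N, Axis.C])
#   _broadcasted_order(Order([Axis.N, Axis.C]),
#                      Order([Axis.C, Axis.H, Axis.W]))   # -> Order([Axis.N, Axis.C, Axis.H, Axis.W])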
def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
    flag_changed = False
    for op in traverse.listup_operators(graph):
        if isinstance(op, Reshape):
            flag_changed |= _replace_input(op, "x", op.parameters["in_order"])
            flag_changed |= _replace_output(op, "y", op.parameters["out_order"])
            continue

        elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D, Deconvolution2D,
                             Space2Depth, Depth2Space)):
            flag_changed |= _replace_input(op, "x", OrderNHWC)
            flag_changed |= _replace_output(op, "y", OrderNHWC)
            continue

        elif isinstance(op, Softmax):
            x = op.inputs["x"]
            y = op.outputs["y"]
            target_axis = op.parameters["axis"]

            if not (x.ndim == 2 and x.order.axes_dict[target_axis] == x.ndim - 1):
                """
                Before)

                | x |                    | y |
                |-----| --{softmax}----> |-----|
                | XYZ |   axis=Y         | XYZ |

                After)

                | x |                | hx1 |             | hx2 |            | hy1 |             | hy2 |                | y |
                |-----| -{transpose}-> |-----| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |-----| -{transpose}-> |-----|
                | XYZ |              | XZY |             | NC  |  axis=C    | NC  |             | XZY |                | XYZ |
                                        :                                                          :
                                  order_nd = XZY                                             order_2d = NC
                """
                op.remove_all()

                axes_nd = list(x.order.axes)
                axes_nd.remove(target_axis)
                axes_nd.append(target_axis)
                order_nd = Order(axes_nd)
                shape_nd = tuple([x.shape_dict[axis] for axis in axes_nd])

                order_2d = OrderNC
                shape_2d = tuple([x.size // x.shape_dict[target_axis], x.shape_dict[target_axis]])

                if x.order == order_nd:
                    hx1 = x
                else:
                    hx1, = Transpose(None)(x)
                    hx1.change_order(order_nd)
                    flag_changed = True

                if hx1.order == order_2d and hx1.shape == shape_2d:
                    hx2 = hx1
                else:
                    hx2, = Reshape(None, in_order=hx1.order, out_order=order_2d, out_shape=shape_2d)(hx1)
                    flag_changed = True

                hy1, = Softmax(None, axis=Axis.C)(hx2)

                if hy1.order == order_nd and hy1.shape == shape_nd:
                    hy2 = hy1
                else:
                    hy2, = Reshape(None, in_order=hy1.order, out_order=order_nd, out_shape=shape_nd)(hy1)
                    flag_changed = True

                if hy2.order == y.order:
                    y_dummy = hy2
                else:
                    y_dummy, = Transpose(None)(hy2)
                    y_dummy.change_order(y.order)
                    flag_changed = True

                y_dummy.replace(y)
            continue

    return graph, flag_changed
def test_X_TN():
    template(x_order=Order([Axis.T, Axis.N]))