def optimize_operator(self, graph: Graph, op: Reshape): x = op.inputs["x"] y = op.outputs["y"] if x.order == y.order and x.shape == y.shape: _remove_unary_operator(graph, op) return True if x.shape == y.shape: op.remove_all() y_dummy, = ReinterpretAxis(None, in_order=x.order, out_order=y.order)(x) y_dummy.replace(y) return True if isinstance(x, ConstantVariable) and x.output_from is None: _remove_unary_operator(graph, op) x.change_order(y.order) return True if all([ y not in graph.outputs, all(x.stride_dict[axis] == y.stride_dict[axis] for axis in [axis for axis in x.order.axes if axis in y.order.axes]), all(isinstance(op2, Elementwise) for op2 in y.input_to) ]): _remove_unary_operator(graph, op) return True return False
def optimize_operator(self, graph: Graph, op: Reshape): x = op.inputs["x"] y = op.outputs["y"] if x.order == y.order and x.shape == y.shape: # no reshape is required _remove_unary_operator(graph, op) return True if x.shape == y.shape: # only reinterpret_axis is required op.remove_all() y_dummy = x.reinterpret_axes(y.order) OptimizeRule.replace_variable(graph, y_dummy, y) return True return False
def optimize_operator(self, graph: Graph, op: Reshape): x = op.inputs["x"] y = op.outputs["y"] if x.order == y.order and x.shape == y.shape: # no reshape is occurred _remove_unary_operator(graph, op) return True if x.shape == y.shape: # only reinterpret_axis is occurred op.remove_all() y_dummy, = ReinterpretAxis(None, in_order=x.order, out_order=y.order)(x) y_dummy.replace(y) return True return False
def _split_reshape(graph: Graph, op: Reshape, v: Variable, v_pair: Sequence[Variable], axis: Axis): x = op.inputs["x"] y = op.outputs["y"] s1 = v_pair[0].shape_dict[axis] s2 = v_pair[1].shape_dict[axis] op.remove_all() in_order = op.in_order out_order = op.out_order x_shape = [x.shape_dict[a] for a in in_order.axes] y_shape = [y.shape_dict[a] for a in out_order.axes] if v == x: """ before) x -{reshape}- y after) x_0 -{reshape}- t_0 -+ +-{concat[axis_k]}- t -{reshape}- y x_1 -{reshape}- t_1 -+ shape and order is changed as follows: x.shape = [dx_0, dx_1, ..., dx_m, ..., dx_M-1] x_0.shape = [dx_0, dx_1, ..., dx_m/2, ..., dx_M-1] --------------------------------------------------------------------------------- t_0.shape = [dy_0, dy_1, ..., dy_n, ..., dy_k/2, ..., dy_N-1] t.shape = [dy_0, dy_1, ..., dy_n*2, ..., dy_k/2, ..., dy_N-1] y.shape = [dy_0, dy_1, ..., dy_n, ..., dy_k, ..., dy_N-1] m: split target axis find axis_k and axis_n, which satisfies follow conditions dy_n * dy_n+1 * ... * dy_N-1 == dx_m * dx_m+1 * ... * dx_M-1 dy_k % 2 == 0 n <= k """ x_0, x_1 = v_pair dx_prod = mul(x_shape[in_order.axes_dict[axis]:]) dy_prod = 1 axis_k_candidate = [] for axis_n in reversed(out_order.axes): dy_prod *= y.shape_dict[axis_n] if y.shape_dict[axis_n] % 2 == 0: axis_k_candidate.append(axis_n) if dx_prod == dy_prod: # Split most large axis axis_k = (sorted(axis_k_candidate, key=lambda a: y.shape_dict[a], reverse=True))[0] t_0_shape = [y.shape_dict[a] for a in out_order.axes] t_0_shape[out_order.axes_dict[axis_k]] = t_0_shape[ out_order.axes_dict[axis_k]] // 2 # TODO t_0, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=t_0_shape)(x_0) t_1_shape = [y.shape_dict[a] for a in out_order.axes] t_1_shape[out_order.axes_dict[axis_k]] = t_1_shape[ out_order.axes_dict[axis_k]] // 2 # TODO t_1, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=t_1_shape)(x_1) t, = Concat(None, axis=axis_n)(t_0, t_1) y_new, = Reshape(None, in_order=out_order, out_order=out_order, out_shape=y_shape)(t) OptimizeRule.replace_variable(graph, y_new.transpose_like(y), y) break elif dy_prod > (s1 + s2) * dx_prod: raise NotImplementedError( f"Variable is too large to handle in WebGL backend: {v}") elif v == y: """ algorithm is almost same as the case `v == x` (above). before) x -{reshape}- y after) +- t_0 -{reshape}- y_0 x -{reshape}- t-{split}-+ +- t_1 -{reshape}- y_1 shape and order is changed as follows: x.shape = [dx_0, dx_1, ..., dx_m, ..., dx_k, ..., dx_M-1] t.shape = [dx_0, dx_1, ..., dx_m*2, ..., dx_k/2, ..., dx_M-1] t_0.shape = [dx_0, dx_1, ..., dx_m, ..., dx_k/2, ..., dx_M-1] --------------------------------------------------------------------------------- y_0.shape = [dy_0, dy_1, ..., dy_n/2, ..., dy_N-1] y.shape = [dy_0, dy_1, ..., dy_n, ..., dy_N-1] m: split target axis find axis_k and axis_m, which satisfies follow conditions dx_m * dx_m+1 * ... * dx_M-1 == dy_n * dy_n+1 * ... * dy_N-1 dx_k % 2 == 0 m <= k """ y_0, y_1 = v_pair dx_prod = 1 dy_prod = mul(x_shape[out_order.axes_dict[axis]:]) axis_k_candidate = [] for axis_m in reversed(in_order.axes): dx_prod *= x.shape_dict[axis_m] if x.shape_dict[axis_m] % 2 == 0: axis_k_candidate.append(axis_m) if dx_prod == dy_prod: # Split most large axis axis_k = (sorted(axis_k_candidate, key=lambda a: x.shape_dict[a], reverse=True))[0] t_shape = [x.shape_dict[a] for a in in_order.axes] t_shape[in_order.axes_dict[axis_m]] = 2 * t_shape[ in_order.axes_dict[axis_m]] # TODO t_shape[in_order.axes_dict[axis_k]] = t_shape[ in_order.axes_dict[axis_k]] // 2 # TODO t, = Reshape(None, in_order=in_order, out_order=in_order, out_shape=t_shape)(x) t_0, t_1 = SplitAxis(None, axis=axis_m, sections=[t.shape_dict[axis_m] // 2 ])(t) # TODO y_0_new, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=y_0.shape)(t_0) y_1_new, = Reshape(None, in_order=in_order, out_order=out_order, out_shape=y_1.shape)(t_1) OptimizeRule.replace_variable(graph, y_0_new.reshape_like(y_0), y_0) OptimizeRule.replace_variable(graph, y_1_new.reshape_like(y_1), y_1) break elif dx_prod > dy_prod: raise NotImplementedError( f"Variable is too large to handle in WebGL backend: {v}") else: raise UnexpectedAndPleaseReportError
def _split_reshape(graph: Graph, op: Reshape, v: Variable, v_pair: Sequence[Variable], axis: Axis): x = op.inputs["x"] y = op.outputs["y"] s1 = v_pair[0].shape_dict[axis] s2 = v_pair[1].shape_dict[axis] op.remove_all() if v == x: """ Regard x's order as `[D1, D2]`, shape as `[d1x, d2x]`, where the most outside axis in D2 is the split target axis. If y's shape can be converted as `[d1x, d2x]` by merging some adjacent axes in y, split can be performed. before) x -{reshape}- y after) x_0 -{reshape}- y_0 -+ +-{concat[axis]}- y x_1 -{reshape}- y_1 -+ """ x_0, x_1 = v_pair d2x = mul(x.shape[x.order.axes_dict[axis]:]) d2y = 1 for axis_y in reversed(y.order.axes): d2y *= y.shape_dict[axis_y] if d2y == d2x: y_0_shape = [y.shape_dict[axis_y] * s1 // (s1 + s2) if a == axis_y else y.shape_dict[a] for a in y.order.axes] y_1_shape = [y.shape_dict[axis_y] * s2 // (s1 + s2) if a == axis_y else y.shape_dict[a] for a in y.order.axes] y_0 = x_0.reshape(y_0_shape, y.order) y_1 = x_1.reshape(y_1_shape, y.order) y_new, = Concat(None, axis=axis_y)(y_0, y_1) OptimizeRule.replace_variable(graph, y_new, y) break elif d2y > (s1 + s2) * d2x: raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}") elif v == y: """ Same algorithm in case `v == y` (above). before) x -{reshape}- y after) +- x_0 -{reshape}- y_0 x -{split}-+ +- x_1 -{reshape}- y_1 """ y_0, y_1 = v_pair d2y = mul(y.shape[y.order.axes_dict[axis]:]) d2x = 1 for axis_x in reversed(x.order.axes): d2x *= x.shape_dict[axis_x] if d2x == d2y: x_0, x_1 = SplitAxis(None, axis=axis_x, sections=[x.shape_dict[axis_x] * s1 // (s1 + s2)])(x) OptimizeRule.replace_variable(graph, x_0.reshape_like(y_0), y_0) OptimizeRule.replace_variable(graph, x_1.reshape_like(y_1), y_1) break elif d2y > (s1 + s2) * d2x: raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}") else: raise UnexpectedAndPleaseReportError