def optimize_operator(self, graph: Graph, op: Reshape):
    """Try to eliminate or simplify a single ``Reshape`` operator.

    The cases below are tried in order; the first one that matches is
    applied and the method returns immediately.

    1. Input and output share both order and shape: ``op`` is handed to
       ``_remove_unary_operator``.
    2. Shapes match but orders differ: ``op`` is detached with
       ``remove_all()`` and a ``ReinterpretAxis`` operator is applied to
       the input instead; its output is substituted for the old output.
    3. The input is a ``ConstantVariable`` with ``output_from is None``:
       ``op`` is removed and the constant's order is changed to the
       output's order.
    4. The output is not one of ``graph.outputs``, every axis present in
       both orders has equal strides on both variables, and every
       operator in ``y.input_to`` is ``Elementwise``: ``op`` is removed.

    Args:
        graph: Graph that contains ``op``.
        op: The reshape operator under consideration.

    Returns:
        ``True`` if one of the cases applied (the graph was modified),
        ``False`` otherwise.
    """
    src = op.inputs["x"]
    dst = op.outputs["y"]

    if src.shape == dst.shape:
        if src.order == dst.order:
            # Case 1: nothing changes between input and output.
            _remove_unary_operator(graph, op)
        else:
            # Case 2: only the axis labelling differs.
            op.remove_all()
            reinterpret = ReinterpretAxis(None, in_order=src.order, out_order=dst.order)
            (placeholder,) = reinterpret(src)
            placeholder.replace(dst)
        return True

    if isinstance(src, ConstantVariable) and src.output_from is None:
        # Case 3: constant with no producing operator.
        _remove_unary_operator(graph, op)
        src.change_order(dst.order)
        return True

    # Case 4: all three conditions are computed up front, exactly as the
    # list-based all([...]) form would evaluate them.
    is_internal = dst not in graph.outputs
    shared_axes = [axis for axis in src.order.axes if axis in dst.order.axes]
    strides_agree = all(
        src.stride_dict[axis] == dst.stride_dict[axis] for axis in shared_axes
    )
    only_elementwise_consumers = all(
        isinstance(consumer, Elementwise) for consumer in dst.input_to
    )
    if is_internal and strides_agree and only_elementwise_consumers:
        _remove_unary_operator(graph, op)
        return True

    return False
def optimize_operator(self, graph: Graph, op: Reshape):
    """Try to eliminate or simplify a single ``Reshape`` operator.

    Only reshapes whose input and output shapes are equal are touched:

    * If the orders are equal as well, ``op`` is handed to
      ``_remove_unary_operator``.
    * Otherwise ``op`` is detached with ``remove_all()`` and a
      ``ReinterpretAxis`` operator is applied to the input instead; its
      output is substituted for the old output.

    Args:
        graph: Graph that contains ``op``.
        op: The reshape operator under consideration.

    Returns:
        ``True`` if the graph was modified, ``False`` if the shapes differ
        and nothing was done.
    """
    src = op.inputs["x"]
    dst = op.outputs["y"]

    if src.shape != dst.shape:
        # A real reshape: leave it in place.
        return False

    if src.order == dst.order:
        # Same order and same shape: nothing changes between input and output.
        _remove_unary_operator(graph, op)
    else:
        # Same shape, different order: only the axis labelling differs.
        op.remove_all()
        reinterpret = ReinterpretAxis(None, in_order=src.order, out_order=dst.order)
        (placeholder,) = reinterpret(src)
        placeholder.replace(dst)
    return True