예제 #1
0
    def optimize_operator(self, graph: Graph, op: Reshape):
        """Try to eliminate or simplify a single ``Reshape`` operator.

        The cases below are tried in order; the first one that applies
        rewrites the graph and ends the call.

        Args:
            graph: Graph that contains ``op``.
            op: The reshape operator under inspection; its ``"x"`` input and
                ``"y"`` output are read.

        Returns:
            ``True`` if one of the rewrites was applied, ``False`` otherwise.
        """
        src = op.inputs["x"]
        dst = op.outputs["y"]

        if src.shape == dst.shape:
            if src.order == dst.order:
                # Same order and same shape: the reshape changes nothing.
                _remove_unary_operator(graph, op)
                return True

            # Same shape but a different order: swap the reshape for a
            # ReinterpretAxis from the input order to the output order.
            op.remove_all()
            reinterpret = ReinterpretAxis(None,
                                          in_order=src.order,
                                          out_order=dst.order)
            (reinterpreted,) = reinterpret(src)
            reinterpreted.replace(dst)
            return True

        if isinstance(src, ConstantVariable) and src.output_from is None:
            # A constant that no operator produces: drop the reshape and
            # switch the constant itself to the output's order.
            _remove_unary_operator(graph, op)
            src.change_order(dst.order)
            return True

        # All three conditions are evaluated up front (no short-circuiting),
        # in this order.
        not_graph_output = dst not in graph.outputs
        common_axes = [axis for axis in src.order.axes
                       if axis in dst.order.axes]
        strides_agree = all(src.stride_dict[axis] == dst.stride_dict[axis]
                            for axis in common_axes)
        only_elementwise_consumers = all(isinstance(consumer, Elementwise)
                                         for consumer in dst.input_to)

        if not_graph_output and strides_agree and only_elementwise_consumers:
            # The output is not a graph output, strides match on every axis
            # common to both orders, and every consumer is Elementwise.
            _remove_unary_operator(graph, op)
            return True

        return False
예제 #2
0
    def optimize_operator(self, graph: Graph, op: Reshape):
        """Remove or simplify a ``Reshape`` whose input and output shapes match.

        Args:
            graph: Graph that contains ``op``.
            op: The reshape operator under inspection; its ``"x"`` input and
                ``"y"`` output are read.

        Returns:
            ``True`` if the graph was rewritten, ``False`` if the shapes
            differ and ``op`` was left untouched.
        """
        src = op.inputs["x"]
        dst = op.outputs["y"]

        shapes_equal = src.shape == dst.shape
        if not shapes_equal:
            return False

        if src.order == dst.order:
            # Order and shape both match, so no reshape actually occurs.
            _remove_unary_operator(graph, op)
            return True

        # Only the order differs: replace the reshape with a ReinterpretAxis
        # from the input order to the output order.
        op.remove_all()
        reinterpret = ReinterpretAxis(None,
                                      in_order=src.order,
                                      out_order=dst.order)
        (reinterpreted,) = reinterpret(src)
        reinterpreted.replace(dst)
        return True