def optimize_operator(self, graph: Graph, op: ReinterpretAxis):
        """Attempt one simplification step on a ``ReinterpretAxis`` operator.

        The following rewrites are tried in order; the first one that
        applies decides the result:

        1. The input ``x`` has exactly one consumer and no producer:
           ``op.remove_all()`` is called. If ``x`` is a
           ``ConstantVariable``, a new constant is built from ``x.data``
           with the output's order and substituted for the output ``y``
           (in ``graph.outputs`` when ``y`` is listed there, otherwise
           through ``y.replace``). Any other ``x`` is asserted to be a
           graph input, and ``y`` takes over its position in
           ``graph.inputs``.
        2. ``in_order`` equals ``out_order``: the operator is handed to
           ``_remove_unary_operator``.
        3. Otherwise the axes of both orders are paired positionally, and
           wherever exactly one axis of a pair is an ``AxisVar`` whose
           ``value`` is ``None``, that axis is unified with its partner.

        Returns:
            ``True`` if a rewrite or at least one unification happened,
            ``False`` otherwise.
        """
        source = op.inputs["x"]
        target = op.outputs["y"]

        if len(source.input_to) == 1 and source.output_from is None:
            op.remove_all()

            if not isinstance(source, ConstantVariable):
                # A non-constant variable without a producer is expected
                # to be one of the graph's inputs.
                assert source in graph.inputs

                slot = graph.inputs.index(source)
                graph.inputs.remove(source)
                graph.inputs.insert(slot, target)
                return True

            # Re-create the constant under the output's order.
            folded = ConstantVariable(source.data, target.order)

            if target not in graph.outputs:
                target.replace(folded)
                return True

            slot = graph.outputs.index(target)
            graph.outputs.remove(target)
            graph.outputs.insert(slot, folded)
            return True

        orders_match = op.parameters["in_order"] == op.parameters["out_order"]
        if orders_match:
            _remove_unary_operator(graph, op)
            return True

        unified_any = False
        axis_pairs = zip(op.parameters["in_order"].axes,
                         op.parameters["out_order"].axes)
        for axis_in, axis_out in axis_pairs:
            in_pending = isinstance(axis_in, AxisVar) and axis_in.value is None
            out_pending = isinstance(axis_out, AxisVar) and axis_out.value is None

            if in_pending == out_pending:
                # Both resolved or both unresolved: nothing to propagate.
                continue

            if out_pending:
                axis_out.unify(axis_in)
            else:
                axis_in.unify(axis_out)
            unified_any = True

        return unified_any
Beispiel #2
0
    def optimize_operator(self, graph: Graph, op: ReinterpretAxis):
        """Attempt to eliminate a ``ReinterpretAxis`` operator.

        Two rewrites are tried in order:

        1. The input ``x`` has exactly one consumer, no producer, and is
           listed in ``graph.inputs``: ``op.remove_all()`` is called and
           the output ``y`` takes over ``x``'s position in
           ``graph.inputs``.
        2. ``in_order`` equals ``out_order``: the operator is handed to
           ``_remove_unary_operator``.

        Returns:
            ``True`` if one of the rewrites was applied, ``False``
            otherwise.
        """
        source = op.inputs["x"]
        target = op.outputs["y"]

        if (len(source.input_to) == 1
                and source.output_from is None
                and source in graph.inputs):
            op.remove_all()
            slot = graph.inputs.index(source)
            graph.inputs.remove(source)
            graph.inputs.insert(slot, target)
            return True

        orders_match = op.parameters["in_order"] == op.parameters["out_order"]
        if not orders_match:
            return False

        _remove_unary_operator(graph, op)
        return True
    def optimize_operator(self, graph: Graph, op: ReinterpretAxis):
        """Attempt to eliminate a ``ReinterpretAxis`` operator.

        Two rewrites are tried in order:

        1. ``in_order`` equals ``out_order``: the operator is handed to
           ``_remove_unary_operator``.
        2. The input ``x`` is listed in ``graph.inputs`` and has exactly
           one consumer: ``op.remove_all()`` is called and ``x`` is
           swapped for the output ``y`` through
           ``OptimizeRule.replace_variable`` (with ``with_assert=False``).

        Returns:
            ``True`` if one of the rewrites was applied, ``False``
            otherwise.
        """
        source = op.inputs["x"]
        target = op.outputs["y"]

        orders_match = op.parameters["in_order"] == op.parameters["out_order"]
        if orders_match:
            _remove_unary_operator(graph, op)
            return True

        if source not in graph.inputs or len(source.input_to) != 1:
            return False

        # Before:  x [graph input] --{ReinterpretAxis}--> h --{op}-->
        # After:   h [graph input] --{op}-->
        op.remove_all()
        OptimizeRule.replace_variable(graph, source, target, with_assert=False)
        return True